markytools committed · Commit d61b9c7 · 1 Parent(s): 5f5c8d7

added strexp

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitignore +224 -0
  2. attribution_ops.py +87 -0
  3. augmentation/blur.py +189 -0
  4. augmentation/camera.py +120 -0
  5. augmentation/frost/frost4.jpg +0 -0
  6. augmentation/frost/frost5.jpg +0 -0
  7. augmentation/frost/frost6.jpg +0 -0
  8. augmentation/geometry.py +233 -0
  9. augmentation/noise.py +94 -0
  10. augmentation/ops.py +87 -0
  11. augmentation/pattern.py +115 -0
  12. augmentation/process.py +123 -0
  13. augmentation/test.py +43 -0
  14. augmentation/warp.py +241 -0
  15. augmentation/weather.py +231 -0
  16. callbacks.py +360 -0
  17. captum/__init__.py +3 -0
  18. captum/_utils/__init__.py +0 -0
  19. captum/_utils/av.py +499 -0
  20. captum/_utils/common.py +679 -0
  21. captum/_utils/gradient.py +865 -0
  22. captum/_utils/models/__init__.py +25 -0
  23. captum/_utils/models/linear_model/__init__.py +23 -0
  24. captum/_utils/models/linear_model/model.py +341 -0
  25. captum/_utils/models/linear_model/train.py +364 -0
  26. captum/_utils/models/model.py +66 -0
  27. captum/_utils/progress.py +138 -0
  28. captum/_utils/sample_gradient.py +184 -0
  29. captum/_utils/typing.py +37 -0
  30. captum/attr/__init__.py +143 -0
  31. captum/attr/_core/__init__.py +0 -0
  32. captum/attr/_core/deep_lift.py +1151 -0
  33. captum/attr/_core/feature_ablation.py +591 -0
  34. captum/attr/_core/feature_permutation.py +305 -0
  35. captum/attr/_core/gradient_shap.py +414 -0
  36. captum/attr/_core/guided_backprop_deconvnet.py +322 -0
  37. captum/attr/_core/guided_grad_cam.py +226 -0
  38. captum/attr/_core/input_x_gradient.py +130 -0
  39. captum/attr/_core/integrated_gradients.py +390 -0
  40. captum/attr/_core/kernel_shap.py +348 -0
  41. captum/attr/_core/layer/__init__.py +0 -0
  42. captum/attr/_core/layer/grad_cam.py +217 -0
  43. captum/attr/_core/layer/internal_influence.py +309 -0
  44. captum/attr/_core/layer/layer_activation.py +136 -0
  45. captum/attr/_core/layer/layer_conductance.py +395 -0
  46. captum/attr/_core/layer/layer_deep_lift.py +682 -0
  47. captum/attr/_core/layer/layer_feature_ablation.py +302 -0
  48. captum/attr/_core/layer/layer_gradient_shap.py +474 -0
  49. captum/attr/_core/layer/layer_gradient_x_activation.py +201 -0
  50. captum/attr/_core/layer/layer_integrated_gradients.py +528 -0
.gitignore ADDED
@@ -0,0 +1,224 @@
1
+ **/saved_models/*
2
+ **/data_lmdb_release/*
3
+ **/image_release/*
4
+ **/vitstr_base_patch*
5
+ **/result/*
6
+ **/results/*
7
+ **/oldData/
8
+ *.mdb
9
+ *.xlsx
10
+ *.pth
11
+ *.json
12
+ *.pkl
13
+ *.tar
14
+ *.ipynb
15
+ *.zip
16
+ *.eps
17
+ *.pdf
18
+ **/grcnn_straug/*
19
+ **/augmentation/results/*
20
+ **/tmp/*
21
+ *.sh
22
+ **/__pycache__
23
+ workdir/
24
+ .remote-sync.json
25
+ *.png
26
+ pretrained/
27
+ attributionImgs/
28
+ attributionImgsOld/
29
+ attrSelectivityOld/
30
+
31
+ ### Linux ###
32
+ *~
33
+
34
+ # temporary files which can be created if a process still has a handle open of a deleted file
35
+ .fuse_hidden*
36
+
37
+ # KDE directory preferences
38
+ .directory
39
+
40
+ # Linux trash folder which might appear on any partition or disk
41
+ .Trash-*
42
+
43
+ # .nfs files are created when an open file is removed but is still being accessed
44
+ .nfs*
45
+
46
+ ### OSX ###
47
+ # General
48
+ .DS_Store
49
+ .AppleDouble
50
+ .LSOverride
51
+
52
+ # Icon must end with two \r
53
+ Icon
54
+
55
+ # Thumbnails
56
+ ._*
57
+
58
+ # Files that might appear in the root of a volume
59
+ .DocumentRevisions-V100
60
+ .fseventsd
61
+ .Spotlight-V100
62
+ .TemporaryItems
63
+ .Trashes
64
+ .VolumeIcon.icns
65
+ .com.apple.timemachine.donotpresent
66
+
67
+ # Directories potentially created on remote AFP share
68
+ .AppleDB
69
+ .AppleDesktop
70
+ Network Trash Folder
71
+ Temporary Items
72
+ .apdisk
73
+
74
+ ### Python ###
75
+ # Byte-compiled / optimized / DLL files
76
+ __pycache__/
77
+ *.py[cod]
78
+ *$py.class
79
+
80
+ # C extensions
81
+ *.so
82
+
83
+ # Distribution / packaging
84
+ .Python
85
+ build/
86
+ develop-eggs/
87
+ dist/
88
+ downloads/
89
+ eggs/
90
+ .eggs/
91
+ lib/
92
+ lib64/
93
+ parts/
94
+ sdist/
95
+ var/
96
+ wheels/
97
+ *.egg-info/
98
+ .installed.cfg
99
+ *.egg
100
+ MANIFEST
101
+
102
+ # PyInstaller
103
+ # Usually these files are written by a python script from a template
104
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
105
+ *.manifest
106
+ *.spec
107
+
108
+ # Installer logs
109
+ pip-log.txt
110
+ pip-delete-this-directory.txt
111
+
112
+ # Unit test / coverage reports
113
+ htmlcov/
114
+ .tox/
115
+ .coverage
116
+ .coverage.*
117
+ .cache
118
+ nosetests.xml
119
+ coverage.xml
120
+ *.cover
121
+ .hypothesis/
122
+ .pytest_cache/
123
+
124
+ # Translations
125
+ *.mo
126
+ *.pot
127
+
128
+ # Django stuff:
129
+ *.log
130
+ local_settings.py
131
+ db.sqlite3
132
+
133
+ # Flask stuff:
134
+ instance/
135
+ .webassets-cache
136
+
137
+ # Scrapy stuff:
138
+ .scrapy
139
+
140
+ # Sphinx documentation
141
+ docs/_build/
142
+
143
+ # PyBuilder
144
+ target/
145
+
146
+ # Jupyter Notebook
147
+ .ipynb_checkpoints
148
+
149
+ # IPython
150
+ profile_default/
151
+ ipython_config.py
152
+
153
+ # pyenv
154
+ .python-version
155
+
156
+ # celery beat schedule file
157
+ celerybeat-schedule
158
+
159
+ # SageMath parsed files
160
+ *.sage.py
161
+
162
+ # Environments
163
+ .env
164
+ .venv
165
+ env/
166
+ venv/
167
+ ENV/
168
+ env.bak/
169
+ venv.bak/
170
+
171
+ # Spyder project settings
172
+ .spyderproject
173
+ .spyproject
174
+
175
+ # Rope project settings
176
+ .ropeproject
177
+
178
+ # mkdocs documentation
179
+ /site
180
+
181
+ # mypy
182
+ .mypy_cache/
183
+ .dmypy.json
184
+ dmypy.json
185
+
186
+ ### Python Patch ###
187
+ .venv/
188
+
189
+ ### Python.VirtualEnv Stack ###
190
+ # Virtualenv
191
+ # http://iamzed.com/2009/05/07/a-primer-on-virtualenv/
192
+ [Bb]in
193
+ [Ii]nclude
194
+ [Ll]ib
195
+ [Ll]ib64
196
+ [Ll]ocal
197
+ [Ss]cripts
198
+ pyvenv.cfg
199
+ pip-selfcheck.json
200
+
201
+ ### Windows ###
202
+ # Windows thumbnail cache files
203
+ Thumbs.db
204
+ ehthumbs.db
205
+ ehthumbs_vista.db
206
+
207
+ # Dump file
208
+ *.stackdump
209
+
210
+ # Folder config file
211
+ [Dd]esktop.ini
212
+
213
+ # Recycle Bin used on file shares
214
+ $RECYCLE.BIN/
215
+
216
+ # Windows Installer files
217
+ *.cab
218
+ *.msi
219
+ *.msix
220
+ *.msm
221
+ *.msp
222
+
223
+ # Windows shortcuts
224
+ *.lnk
attribution_ops.py ADDED
@@ -0,0 +1,87 @@
1
+ import os
2
+ import pickle
3
+ from captum_improve_vitstr import rankedAttributionsBySegm
4
+ import matplotlib.pyplot as plt
5
+ from skimage.color import gray2rgb
6
+ from captum.attr._utils.visualization import visualize_image_attr
7
+ import torch
8
+ import numpy as np
9
+
10
+ def attr_one_dataset():
11
+ modelName = "vitstr"
12
+ datasetName = "IIIT5k_3000"
13
+
14
+ rootDir = f"/data/goo/strattr/attributionData/{modelName}/{datasetName}/"
15
+ attrOutputImgs = f"/data/goo/strattr/attributionDataImgs/{modelName}/{datasetName}/"
16
+ if not os.path.exists(attrOutputImgs):
17
+ os.makedirs(attrOutputImgs)
18
+
19
+ minNumber = 1000000
20
+ maxNumber = 0
21
+ # From a folder containing saved attribution pickle files, convert them into attribution images
22
+ for path, subdirs, files in os.walk(rootDir):
23
+ for name in files:
24
+ fullfilename = os.path.join(rootDir, name) # Value
25
+ # fullfilename: /data/goo/strattr/attributionData/trba/CUTE80/66_featablt.pkl
26
+ partfilename = fullfilename[fullfilename.rfind('/')+1:]
27
+ print("fullfilename: ", fullfilename)
28
+ imgNum = int(partfilename.split('_')[0])
29
+ attrImgName = partfilename.replace('.pkl', '.png')
30
+ minNumber = min(minNumber, imgNum)
31
+ maxNumber = max(maxNumber, imgNum)
32
+ with open(fullfilename, 'rb') as f:
33
+ pklData = pickle.load(f)
34
+ attributions = pklData['attribution']
35
+ segmDataNP = pklData['segmData']
36
+ origImgNP = pklData['origImg']
37
+ if np.isnan(attributions).any():
38
+ continue
39
+ attributions = torch.from_numpy(attributions)
40
+ rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP)
41
+ rankedAttr = rankedAttr.detach().cpu().numpy()[0][0]
42
+ rankedAttr = gray2rgb(rankedAttr)
43
+ mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map', cmap='RdYlGn')
44
+ mplotfig.savefig(attrOutputImgs + attrImgName)
45
+ mplotfig.clear()
46
+ plt.close(mplotfig)
47
+
48
+ def attr_all_dataset():
49
+ modelName = "vitstr"
50
+
51
+ datasetNameList = ['IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', 'IC13_857', 'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80']
52
+
53
+ for datasetName in datasetNameList:
54
+ rootDir = f"/data/goo/strattr/attributionData/{modelName}/{datasetName}/"
55
+ attrOutputImgs = f"/data/goo/strattr/attributionDataImgs/{modelName}/{datasetName}/"
56
+ if not os.path.exists(attrOutputImgs):
57
+ os.makedirs(attrOutputImgs)
58
+
59
+ minNumber = 1000000
60
+ maxNumber = 0
61
+ # From a folder containing saved attribution pickle files, convert them into attribution images
62
+ for path, subdirs, files in os.walk(rootDir):
63
+ for name in files:
64
+ fullfilename = os.path.join(rootDir, name) # Value
65
+ # fullfilename: /data/goo/strattr/attributionData/trba/CUTE80/66_featablt.pkl
66
+ partfilename = fullfilename[fullfilename.rfind('/')+1:]
67
+ imgNum = int(partfilename.split('_')[0])
68
+ attrImgName = partfilename.replace('.pkl', '.png')
69
+ minNumber = min(minNumber, imgNum)
70
+ maxNumber = max(maxNumber, imgNum)
71
+ with open(fullfilename, 'rb') as f:
72
+ pklData = pickle.load(f)
73
+ attributions = pklData['attribution']
74
+ segmDataNP = pklData['segmData']
75
+ origImgNP = pklData['origImg']
76
+ attributions = torch.from_numpy(attributions)
77
+ rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP)
78
+ rankedAttr = rankedAttr.detach().cpu().numpy()[0][0]
79
+ rankedAttr = gray2rgb(rankedAttr)
80
+ mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map', cmap='RdYlGn')
81
+ mplotfig.savefig(attrOutputImgs + attrImgName)
82
+ mplotfig.clear()
83
+ plt.close(mplotfig)
84
+
85
+ if __name__ == '__main__':
86
+ attr_one_dataset()
87
+ # attr_all_dataset()
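For reference, each pickle file consumed by the two loops above is expected to hold three arrays under the keys 'attribution', 'segmData' and 'origImg'. A minimal inspection sketch, reusing the example path from the comment above (the exact file is assumed to exist on disk):

    import pickle
    import numpy as np

    # Path follows the <imgNum>_<method>.pkl pattern noted in the comments above
    pkl_path = "/data/goo/strattr/attributionData/trba/CUTE80/66_featablt.pkl"
    with open(pkl_path, "rb") as f:
        pklData = pickle.load(f)

    attributions = pklData["attribution"]  # per-pixel attribution scores (numpy array)
    segmDataNP = pklData["segmData"]       # character segmentation used by rankedAttributionsBySegm
    origImgNP = pklData["origImg"]         # original image passed to visualize_image_attr
    print(attributions.shape, "contains NaN:", np.isnan(attributions).any())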
augmentation/blur.py ADDED
@@ -0,0 +1,189 @@
1
+
2
+ import cv2
3
+ import numpy as np
4
+ from PIL import Image, ImageOps
5
+ import torchvision.transforms as transforms
6
+ from wand.image import Image as WandImage
7
+ from scipy.ndimage import zoom as scizoom
8
+ from skimage.filters import gaussian
9
+ from wand.api import library as wandlibrary
10
+ from io import BytesIO
11
+
12
+ #from skimage import color
13
+ from .ops import MotionImage, clipped_zoom, disk, plasma_fractal
14
+ '''
15
+ PIL resize (W,H)
16
+ '''
17
+ class GaussianBlur:
18
+ def __init__(self):
19
+ pass
20
+
21
+ def __call__(self, img, mag=-1, prob=1.):
22
+ if np.random.uniform(0,1) > prob:
23
+ return img
24
+
25
+ W, H = img.size
26
+ #kernel = [(31,31)] prev 1 level only
27
+ kernel = (31, 31)
28
+ sigmas = [.5, 1, 2]
29
+ if mag<0 or mag>=len(sigmas):  # compare against sigmas; kernel is a fixed (31, 31) tuple
30
+ index = np.random.randint(0, len(sigmas))
31
+ else:
32
+ index = mag
33
+
34
+ sigma = sigmas[index]
35
+ return transforms.GaussianBlur(kernel_size=kernel, sigma=sigma)(img)
36
+
37
+
38
+ class DefocusBlur:
39
+ def __init__(self):
40
+ pass
41
+
42
+ def __call__(self, img, mag=-1, prob=1.):
43
+ if np.random.uniform(0,1) > prob:
44
+ return img
45
+
46
+ n_channels = len(img.getbands())
47
+ isgray = n_channels == 1
48
+ #c = [(3, 0.1), (4, 0.5), (6, 0.5), (8, 0.5), (10, 0.5)]
49
+ c = [(2, 0.1), (3, 0.1), (4, 0.1)] #, (6, 0.5)] #prev 2 levels only
50
+ if mag<0 or mag>=len(c):
51
+ index = np.random.randint(0, len(c))
52
+ else:
53
+ index = mag
54
+ c = c[index]
55
+
56
+ img = np.array(img) / 255.
57
+ if isgray:
58
+ img = np.expand_dims(img, axis=2)
59
+ img = np.repeat(img, 3, axis=2)
60
+ n_channels = 3
61
+ kernel = disk(radius=c[0], alias_blur=c[1])
62
+
63
+ channels = []
64
+ for d in range(n_channels):
65
+ channels.append(cv2.filter2D(img[:, :, d], -1, kernel))
66
+ channels = np.array(channels).transpose((1, 2, 0)) # 3x224x224 -> 224x224x3
67
+
68
+ #if isgray:
69
+ # img = img[:,:,0]
70
+ # img = np.squeeze(img)
71
+
72
+ img = np.clip(channels, 0, 1) * 255
73
+ img = Image.fromarray(img.astype(np.uint8))
74
+ if isgray:
75
+ img = ImageOps.grayscale(img)
76
+
77
+ return img
78
+
79
+
80
+ class MotionBlur:
81
+ def __init__(self):
82
+ pass
83
+
84
+ def __call__(self, img, mag=-1, prob=1.):
85
+ if np.random.uniform(0,1) > prob:
86
+ return img
87
+
88
+ n_channels = len(img.getbands())
89
+ isgray = n_channels == 1
90
+ #c = [(10, 3), (15, 5), (15, 8), (15, 12), (20, 15)]
91
+ c = [(10, 3), (12, 4), (14, 5)]
92
+ if mag<0 or mag>=len(c):
93
+ index = np.random.randint(0, len(c))
94
+ else:
95
+ index = mag
96
+ c = c[index]
97
+
98
+ output = BytesIO()
99
+ img.save(output, format='PNG')
100
+ img = MotionImage(blob=output.getvalue())
101
+
102
+ img.motion_blur(radius=c[0], sigma=c[1], angle=np.random.uniform(-45, 45))
103
+ img = cv2.imdecode(np.frombuffer(img.make_blob(), np.uint8), cv2.IMREAD_UNCHANGED)  # np.fromstring is deprecated for binary data
104
+ if len(img.shape) > 2:
105
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
106
+
107
+ img = Image.fromarray(img.astype(np.uint8))
108
+
109
+ if isgray:
110
+ img = ImageOps.grayscale(img)
111
+
112
+ return img
113
+
114
+ class GlassBlur:
115
+ def __init__(self):
116
+ pass
117
+
118
+ def __call__(self, img, mag=-1, prob=1.):
119
+ if np.random.uniform(0,1) > prob:
120
+ return img
121
+
122
+ W, H = img.size
123
+ #c = [(0.7, 1, 2), (0.9, 2, 1), (1, 2, 3), (1.1, 3, 2), (1.5, 4, 2)][severity - 1]
124
+ c = [(0.7, 1, 2), (0.75, 1, 2), (0.8, 1, 2)] #, (1, 2, 3)] #prev 2 levels only
125
+ if mag<0 or mag>=len(c):
126
+ index = np.random.randint(0, len(c))
127
+ else:
128
+ index = mag
129
+
130
+ c = c[index]
131
+
132
+ img = np.uint8(gaussian(np.array(img) / 255., sigma=c[0], multichannel=True) * 255)
133
+
134
+ # locally shuffle pixels
135
+ for i in range(c[2]):
136
+ for h in range(H - c[1], c[1], -1):
137
+ for w in range(W - c[1], c[1], -1):
138
+ dx, dy = np.random.randint(-c[1], c[1], size=(2,))
139
+ h_prime, w_prime = h + dy, w + dx
140
+ # swap
141
+ img[h, w], img[h_prime, w_prime] = img[h_prime, w_prime], img[h, w]
142
+
143
+ img = np.clip(gaussian(img / 255., sigma=c[0], multichannel=True), 0, 1) * 255
144
+ return Image.fromarray(img.astype(np.uint8))
145
+
146
+
147
+ class ZoomBlur:
148
+ def __init__(self):
149
+ pass
150
+
151
+ def __call__(self, img, mag=-1, prob=1.):
152
+ if np.random.uniform(0,1) > prob:
153
+ return img
154
+
155
+ W, H = img.size
156
+ c = [np.arange(1, 1.11, .01),
157
+ np.arange(1, 1.16, .01),
158
+ np.arange(1, 1.21, .02)]
159
+ if mag<0 or mag>=len(c):
160
+ index = np.random.randint(0, len(c))
161
+ else:
162
+ index = mag
163
+
164
+ c = c[index]
165
+
166
+ n_channels = len(img.getbands())
167
+ isgray = n_channels == 1
168
+
169
+ uint8_img = img
170
+ img = (np.array(img) / 255.).astype(np.float32)
171
+
172
+ out = np.zeros_like(img)
173
+ for zoom_factor in c:
174
+ ZW = int(W*zoom_factor)
175
+ ZH = int(H*zoom_factor)
176
+ zoom_img = uint8_img.resize((ZW, ZH), Image.BICUBIC)
177
+ x1 = (ZW - W) // 2
178
+ y1 = (ZH - H) // 2
179
+ x2 = x1 + W
180
+ y2 = y1 + H
181
+ zoom_img = zoom_img.crop((x1,y1,x2,y2))
182
+ out += (np.array(zoom_img) / 255.).astype(np.float32)
183
+
184
+ img = (img + out) / (len(c) + 1)
185
+
186
+ img = np.clip(img, 0, 1) * 255
187
+ img = Image.fromarray(img.astype(np.uint8))
188
+
189
+ return img
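Every class in this module (and in the other augmentation modules in this commit) follows the same callable convention: op(img, mag=..., prob=...) takes a PIL image and returns a PIL image, mag picks one of three severity levels (-1 or an out-of-range value falls back to a random or default level), and prob gates whether the op fires at all. MotionBlur and GlassBlur additionally need ImageMagick via the wand package. A minimal sketch, assuming the directory is importable as the augmentation package and using a placeholder input path:

    from PIL import Image
    from augmentation.blur import GaussianBlur, DefocusBlur, MotionBlur

    img = Image.open("demo_word.png").convert("RGB")  # placeholder image
    img = img.resize((100, 32))                       # the size used by augmentation/test.py

    out1 = GaussianBlur()(img, mag=1)           # medium sigma (1.0)
    out2 = DefocusBlur()(img, mag=-1)           # random severity level
    out3 = MotionBlur()(img, mag=0, prob=0.5)   # applied only about half the time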
augmentation/camera.py ADDED
@@ -0,0 +1,120 @@
1
+
2
+ import cv2
3
+ import numpy as np
4
+ import skimage as sk
5
+ from PIL import Image, ImageOps
6
+ from io import BytesIO
7
+
8
+ from skimage import color
9
+ '''
10
+ PIL resize (W,H)
11
+ cv2 image is BGR
12
+ PIL image is RGB
13
+ '''
14
+ class Contrast:
15
+ def __init__(self):
16
+ pass
17
+
18
+ def __call__(self, img, mag=-1, prob=1.):
19
+ if np.random.uniform(0,1) > prob:
20
+ return img
21
+
22
+ #c = [0.4, .3, .2, .1, .05]
23
+ c = [0.4, .3, .2]
24
+ if mag<0 or mag>=len(c):
25
+ index = np.random.randint(0, len(c))
26
+ else:
27
+ index = mag
28
+ c = c[index]
29
+ img = np.array(img) / 255.
30
+ means = np.mean(img, axis=(0, 1), keepdims=True)
31
+ img = np.clip((img - means) * c + means, 0, 1) * 255
32
+
33
+ return Image.fromarray(img.astype(np.uint8))
34
+
35
+
36
+ class Brightness:
37
+ def __init__(self):
38
+ pass
39
+
40
+ def __call__(self, img, mag=-1, prob=1.):
41
+ if np.random.uniform(0,1) > prob:
42
+ return img
43
+
44
+ #W, H = img.size
45
+ #c = [.1, .2, .3, .4, .5]
46
+ c = [.1, .2, .3]
47
+ if mag<0 or mag>=len(c):
48
+ index = np.random.randint(0, len(c))
49
+ else:
50
+ index = mag
51
+ c = c[index]
52
+
53
+ n_channels = len(img.getbands())
54
+ isgray = n_channels == 1
55
+
56
+ img = np.array(img) / 255.
57
+ if isgray:
58
+ img = np.expand_dims(img, axis=2)
59
+ img = np.repeat(img, 3, axis=2)
60
+
61
+ img = sk.color.rgb2hsv(img)
62
+ img[:, :, 2] = np.clip(img[:, :, 2] + c, 0, 1)
63
+ img = sk.color.hsv2rgb(img)
64
+
65
+ #if isgray:
66
+ # img = img[:,:,0]
67
+ # img = np.squeeze(img)
68
+
69
+ img = np.clip(img, 0, 1) * 255
70
+ img = Image.fromarray(img.astype(np.uint8))
71
+ if isgray:
72
+ img = ImageOps.grayscale(img)
73
+
74
+ return img
75
+ #if isgray:
76
+ #if isgray:
77
+ # img = color.rgb2gray(img)
78
+
79
+ #return Image.fromarray(img.astype(np.uint8))
80
+
81
+
82
+ class JpegCompression:
83
+ def __init__(self):
84
+ pass
85
+
86
+ def __call__(self, img, mag=-1, prob=1.):
87
+ if np.random.uniform(0,1) > prob:
88
+ return img
89
+
90
+ #c = [25, 18, 15, 10, 7]
91
+ c = [25, 18, 15]
92
+ if mag<0 or mag>=len(c):
93
+ index = np.random.randint(0, len(c))
94
+ else:
95
+ index = mag
96
+ c = c[index]
97
+ output = BytesIO()
98
+ img.save(output, 'JPEG', quality=c)
99
+ return Image.open(output)
100
+
101
+
102
+ class Pixelate:
103
+ def __init__(self):
104
+ pass
105
+
106
+ def __call__(self, img, mag=-1, prob=1.):
107
+ if np.random.uniform(0,1) > prob:
108
+ return img
109
+
110
+ W, H = img.size
111
+ #c = [0.6, 0.5, 0.4, 0.3, 0.25]
112
+ c = [0.6, 0.5, 0.4]
113
+ if mag<0 or mag>=len(c):
114
+ index = np.random.randint(0, len(c))
115
+ else:
116
+ index = mag
117
+ c = c[index]
118
+ img = img.resize((int(W* c), int(H * c)), Image.BOX)
119
+ return img.resize((W, H), Image.BOX)
120
+
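Same convention here; of note, JpegCompression round-trips the image through an in-memory JPEG (BytesIO) so the quality value directly controls the compression artifacts, and Pixelate downsamples with a box filter and resizes back. A short sketch under the same import and path assumptions as above:

    from PIL import Image
    from augmentation.camera import Contrast, JpegCompression, Pixelate

    img = Image.open("demo_word.png").convert("RGB")  # placeholder image

    out1 = Contrast()(img, mag=2)         # strongest level: contrast scaled by 0.2
    out2 = JpegCompression()(img, mag=2)  # re-encode at JPEG quality 15
    out3 = Pixelate()(img, mag=2)         # shrink to 40% and resize back with Image.BOX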
augmentation/frost/frost4.jpg ADDED
augmentation/frost/frost5.jpg ADDED
augmentation/frost/frost6.jpg ADDED
augmentation/geometry.py ADDED
@@ -0,0 +1,233 @@
1
+
2
+ import cv2
3
+ import numpy as np
4
+ from PIL import Image, ImageOps
5
+
6
+ '''
7
+ PIL resize (W,H)
8
+ Torch resize is (H,W)
9
+ '''
10
+ class Shrink:
11
+ def __init__(self):
12
+ self.tps = cv2.createThinPlateSplineShapeTransformer()
13
+ self.translateXAbs = TranslateXAbs()
14
+ self.translateYAbs = TranslateYAbs()
15
+
16
+ def __call__(self, img, mag=-1, prob=1.):
17
+ if np.random.uniform(0,1) > prob:
18
+ return img
19
+
20
+ W, H = img.size
21
+ img = np.array(img)
22
+ srcpt = list()
23
+ dstpt = list()
24
+
25
+ W_33 = 0.33 * W
26
+ W_50 = 0.50 * W
27
+ W_66 = 0.66 * W
28
+
29
+ H_50 = 0.50 * H
30
+
31
+ P = 0
32
+
33
+ #frac = 0.4
34
+
35
+ b = [.2, .3, .4]
36
+ if mag<0 or mag>=len(b):
37
+ index = 0
38
+ else:
39
+ index = mag
40
+ frac = b[index]
41
+
42
+ # left-most
43
+ srcpt.append([P, P])
44
+ srcpt.append([P, H-P])
45
+ x = np.random.uniform(frac-.1, frac)*W_33
46
+ y = np.random.uniform(frac-.1, frac)*H_50
47
+ dstpt.append([P+x, P+y])
48
+ dstpt.append([P+x, H-P-y])
49
+
50
+ # 2nd left-most
51
+ srcpt.append([P+W_33, P])
52
+ srcpt.append([P+W_33, H-P])
53
+ dstpt.append([P+W_33, P+y])
54
+ dstpt.append([P+W_33, H-P-y])
55
+
56
+ # 3rd left-most
57
+ srcpt.append([P+W_66, P])
58
+ srcpt.append([P+W_66, H-P])
59
+ dstpt.append([P+W_66, P+y])
60
+ dstpt.append([P+W_66, H-P-y])
61
+
62
+ # right-most
63
+ srcpt.append([W-P, P])
64
+ srcpt.append([W-P, H-P])
65
+ dstpt.append([W-P-x, P+y])
66
+ dstpt.append([W-P-x, H-P-y])
67
+
68
+ N = len(dstpt)
69
+ matches = [cv2.DMatch(i, i, 0) for i in range(N)]
70
+ dst_shape = np.array(dstpt).reshape((-1, N, 2))
71
+ src_shape = np.array(srcpt).reshape((-1, N, 2))
72
+ self.tps.estimateTransformation(dst_shape, src_shape, matches)
73
+ img = self.tps.warpImage(img)
74
+ img = Image.fromarray(img)
75
+
76
+ if np.random.uniform(0, 1) < 0.5:
77
+ img = self.translateXAbs(img, val=x)
78
+ else:
79
+ img = self.translateYAbs(img, val=y)
80
+
81
+ return img
82
+
83
+
84
+ class Rotate:
85
+ def __init__(self, square_side=224):
86
+ self.side = square_side
87
+
88
+ def __call__(self, img, iscurve=False, mag=-1, prob=1.):
89
+ if np.random.uniform(0,1) > prob:
90
+ return img
91
+
92
+ W, H = img.size
93
+
94
+ if H!=self.side or W!=self.side:
95
+ img = img.resize((self.side, self.side), Image.BICUBIC)
96
+
97
+ b = [20., 40, 60]
98
+ if mag<0 or mag>=len(b):
99
+ index = 1
100
+ else:
101
+ index = mag
102
+ rotate_angle = b[index]
103
+
104
+ angle = np.random.uniform(rotate_angle-20, rotate_angle)
105
+ if np.random.uniform(0, 1) < 0.5:
106
+ angle = -angle
107
+
108
+ #angle = np.random.normal(loc=0., scale=rotate_angle)
109
+ #angle = min(angle, 2*rotate_angle)
110
+ #angle = max(angle, -2*rotate_angle)
111
+
112
+ expand = False if iscurve else True
113
+ img = img.rotate(angle=angle, resample=Image.BICUBIC, expand=expand)
114
+ img = img.resize((W, H), Image.BICUBIC)
115
+
116
+ return img
117
+
118
+ class Perspective:
119
+ def __init__(self):
120
+ pass
121
+
122
+ def __call__(self, img, mag=-1, prob=1.):
123
+ if np.random.uniform(0,1) > prob:
124
+ return img
125
+
126
+ W, H = img.size
127
+
128
+ # upper-left, upper-right, lower-left, lower-right
129
+ src = np.float32([[0, 0], [W, 0], [0, H], [W, H]])
130
+ #low = 0.3
131
+
132
+ b = [.1, .2, .3]
133
+ if mag<0 or mag>=len(b):
134
+ index = 2
135
+ else:
136
+ index = mag
137
+ low = b[index]
138
+
139
+ high = 1 - low
140
+ if np.random.uniform(0, 1) > 0.5:
141
+ toprightY = np.random.uniform(low, low+.1)*H
142
+ bottomrightY = np.random.uniform(high-.1, high)*H
143
+ dest = np.float32([[0, 0], [W, toprightY], [0, H], [W, bottomrightY]])
144
+ else:
145
+ topleftY = np.random.uniform(low, low+.1)*H
146
+ bottomleftY = np.random.uniform(high-.1, high)*H
147
+ dest = np.float32([[0, topleftY], [W, 0], [0, bottomleftY], [W, H]])
148
+ M = cv2.getPerspectiveTransform(src, dest)
149
+ img = np.array(img)
150
+ img = cv2.warpPerspective(img, M, (W, H) )
151
+ img = Image.fromarray(img)
152
+
153
+ return img
154
+
155
+
156
+ class TranslateX:
157
+ def __init__(self):
158
+ pass
159
+
160
+ def __call__(self, img, mag=-1, prob=1.):
161
+ if np.random.uniform(0,1) > prob:
162
+ return img
163
+
164
+ b = [.03, .06, .09]
165
+ if mag<0 or mag>=len(b):
166
+ index = 2
167
+ else:
168
+ index = mag
169
+ v = b[index]
170
+ v = np.random.uniform(v-0.03, v)
171
+
172
+ v = v * img.size[0]
173
+ if np.random.uniform(0,1) > 0.5:
174
+ v = -v
175
+ return img.transform(img.size, Image.AFFINE, (1, 0, v, 0, 1, 0))
176
+
177
+
178
+ class TranslateY:
179
+ def __init__(self):
180
+ pass
181
+
182
+ def __call__(self, img, mag=-1, prob=1.):
183
+ if np.random.uniform(0,1) > prob:
184
+ return img
185
+
186
+ b = [.07, .14, .21]
187
+ if mag<0 or mag>=len(b):
188
+ index = 2
189
+ else:
190
+ index = mag
191
+ v = b[index]
192
+ v = np.random.uniform(v-0.07, v)
193
+
194
+ v = v * img.size[1]
195
+ if np.random.uniform(0,1) > 0.5:
196
+ v = -v
197
+ return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, v))
198
+
199
+
200
+ class TranslateXAbs:
201
+ def __init__(self):
202
+ pass
203
+
204
+ def __call__(self, img, val=0, prob=1.):
205
+ if np.random.uniform(0,1) > prob:
206
+ return img
207
+
208
+ v = np.random.uniform(0, val)
209
+
210
+ if np.random.uniform(0,1) > 0.5:
211
+ v = -v
212
+ return img.transform(img.size, Image.AFFINE, (1, 0, v, 0, 1, 0))
213
+
214
+
215
+ class TranslateYAbs:
216
+ def __init__(self):
217
+ pass
218
+
219
+ def __call__(self, img, val=0, prob=1.):
220
+ if np.random.uniform(0,1) > prob:
221
+ return img
222
+
223
+ v = np.random.uniform(0, val)
224
+
225
+ if np.random.uniform(0,1) > 0.5:
226
+ v = -v
227
+ return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, v))
228
+
229
+
230
+
231
+
232
+
233
+
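Shrink (and the warp classes later in this commit) relies on OpenCV's thin-plate-spline shape transformer, while Rotate internally resizes the crop to a square canvas (224 by default) before rotating, so thin word images survive large angles. A brief sketch, under the same assumptions as above:

    from PIL import Image
    from augmentation.geometry import Rotate, Perspective, TranslateX

    img = Image.open("demo_word.png").resize((100, 32))  # placeholder image

    out1 = Rotate(square_side=224)(img, mag=0)  # rotation of up to ~20 degrees on a 224x224 canvas
    out2 = Perspective()(img, mag=2)            # strongest of the three perspective levels
    out3 = TranslateX()(img, mag=1)             # horizontal shift of roughly 3-6% of the width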
augmentation/noise.py ADDED
@@ -0,0 +1,94 @@
1
+
2
+ import numpy as np
3
+ import skimage as sk
4
+ from PIL import Image
5
+
6
+ '''
7
+ PIL resize (W,H)
8
+ '''
9
+ class GaussianNoise:
10
+ def __init__(self):
11
+ pass
12
+
13
+ def __call__(self, img, mag=-1, prob=1.):
14
+ if np.random.uniform(0,1) > prob:
15
+ return img
16
+
17
+ W, H = img.size
18
+ #c = np.random.uniform(.08, .38)
19
+ b = [.08, 0.1, 0.12]
20
+ if mag<0 or mag>=len(b):
21
+ index = 0
22
+ else:
23
+ index = mag
24
+ a = b[index]
25
+ c = np.random.uniform(a, a+0.03)
26
+ img = np.array(img) / 255.
27
+ img = np.clip(img + np.random.normal(size=img.shape, scale=c), 0, 1) * 255
28
+ return Image.fromarray(img.astype(np.uint8))
29
+
30
+
31
+ class ShotNoise:
32
+ def __init__(self):
33
+ pass
34
+
35
+ def __call__(self, img, mag=-1, prob=1.):
36
+ if np.random.uniform(0,1) > prob:
37
+ return img
38
+
39
+ W, H = img.size
40
+ #c = np.random.uniform(3, 60)
41
+ b = [13, 8, 3]
42
+ if mag<0 or mag>=len(b):
43
+ index = 2
44
+ else:
45
+ index = mag
46
+ a = b[index]
47
+ c = np.random.uniform(a, a+7)
48
+ img = np.array(img) / 255.
49
+ img = np.clip(np.random.poisson(img * c) / float(c), 0, 1) * 255
50
+ return Image.fromarray(img.astype(np.uint8))
51
+
52
+
53
+ class ImpulseNoise:
54
+ def __init__(self):
55
+ pass
56
+
57
+ def __call__(self, img, mag=-1, prob=1.):
58
+ if np.random.uniform(0,1) > prob:
59
+ return img
60
+
61
+ W, H = img.size
62
+ #c = np.random.uniform(.03, .27)
63
+ b = [.03, .07, .11]
64
+ if mag<0 or mag>=len(b):
65
+ index = 0
66
+ else:
67
+ index = mag
68
+ a = b[index]
69
+ c = np.random.uniform(a, a+.04)
70
+ img = sk.util.random_noise(np.array(img) / 255., mode='s&p', amount=c) * 255
71
+ return Image.fromarray(img.astype(np.uint8))
72
+
73
+
74
+ class SpeckleNoise:
75
+ def __init__(self):
76
+ pass
77
+
78
+ def __call__(self, img, mag=-1, prob=1.):
79
+ if np.random.uniform(0,1) > prob:
80
+ return img
81
+
82
+ W, H = img.size
83
+ # c = np.random.uniform(.15, .6)
84
+ b = [.15, .2, .25]
85
+ if mag<0 or mag>=len(b):
86
+ index = 0
87
+ else:
88
+ index = mag
89
+ a = b[index]
90
+ c = np.random.uniform(a, a+.05)
91
+ img = np.array(img) / 255.
92
+ img = np.clip(img + img * np.random.normal(size=img.shape, scale=c), 0, 1) * 255
93
+ return Image.fromarray(img.astype(np.uint8))
94
+
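The noise severity lists are worth a second look: GaussianNoise, ImpulseNoise and SpeckleNoise list their levels in increasing order, while ShotNoise lists the Poisson rate in decreasing order because a lower rate means stronger noise. A quick sketch, same assumptions as above:

    from PIL import Image
    from augmentation.noise import GaussianNoise, ShotNoise, ImpulseNoise

    img = Image.open("demo_word.png").convert("RGB")  # placeholder image

    out1 = GaussianNoise()(img, mag=0)  # sigma drawn from [0.08, 0.11]
    out2 = ShotNoise()(img, mag=2)      # rate drawn from [3, 10]; lowest rate = noisiest level
    out3 = ImpulseNoise()(img, mag=1)   # salt-and-pepper amount drawn from [0.07, 0.11]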
augmentation/ops.py ADDED
@@ -0,0 +1,87 @@
1
+
2
+ import cv2
3
+ import numpy as np
4
+ from wand.image import Image as WandImage
5
+ from scipy.ndimage import zoom as scizoom
6
+ from wand.api import library as wandlibrary
7
+
8
+ class MotionImage(WandImage):
9
+ def motion_blur(self, radius=0.0, sigma=0.0, angle=0.0):
10
+ wandlibrary.MagickMotionBlurImage(self.wand, radius, sigma, angle)
11
+
12
+ def clipped_zoom(img, zoom_factor):
13
+ h = img.shape[1]
14
+ # ceil crop height(= crop width)
15
+ ch = int(np.ceil(h / float(zoom_factor)))
16
+
17
+ top = (h - ch) // 2
18
+ img = scizoom(img[top:top + ch, top:top + ch], (zoom_factor, zoom_factor, 1), order=1)
19
+ # trim off any extra pixels
20
+ trim_top = (img.shape[0] - h) // 2
21
+
22
+ return img[trim_top:trim_top + h, trim_top:trim_top + h]
23
+
24
+ def disk(radius, alias_blur=0.1, dtype=np.float32):
25
+ if radius <= 8:
26
+ L = np.arange(-8, 8 + 1)
27
+ ksize = (3, 3)
28
+ else:
29
+ L = np.arange(-radius, radius + 1)
30
+ ksize = (5, 5)
31
+ X, Y = np.meshgrid(L, L)
32
+ aliased_disk = np.array((X ** 2 + Y ** 2) <= radius ** 2, dtype=dtype)
33
+ aliased_disk /= np.sum(aliased_disk)
34
+
35
+ # supersample disk to antialias
36
+ return cv2.GaussianBlur(aliased_disk, ksize=ksize, sigmaX=alias_blur)
37
+
38
+ # modification of https://github.com/FLHerne/mapgen/blob/master/diamondsquare.py
39
+ def plasma_fractal(mapsize=256, wibbledecay=3):
40
+ """
41
+ Generate a heightmap using diamond-square algorithm.
42
+ Return square 2d array, side length 'mapsize', of floats normalized to the range 0-1.
43
+ 'mapsize' must be a power of two.
44
+ """
45
+ assert (mapsize & (mapsize - 1) == 0)
46
+ maparray = np.empty((mapsize, mapsize), dtype=np.float_)
47
+ maparray[0, 0] = 0
48
+ stepsize = mapsize
49
+ wibble = 100
50
+
51
+ def wibbledmean(array):
52
+ return array / 4 + wibble * np.random.uniform(-wibble, wibble, array.shape)
53
+
54
+ def fillsquares():
55
+ """For each square of points stepsize apart,
56
+ calculate middle value as mean of points + wibble"""
57
+ cornerref = maparray[0:mapsize:stepsize, 0:mapsize:stepsize]
58
+ squareaccum = cornerref + np.roll(cornerref, shift=-1, axis=0)
59
+ squareaccum += np.roll(squareaccum, shift=-1, axis=1)
60
+ maparray[stepsize // 2:mapsize:stepsize,
61
+ stepsize // 2:mapsize:stepsize] = wibbledmean(squareaccum)
62
+
63
+ def filldiamonds():
64
+ """For each diamond of points stepsize apart,
65
+ calculate middle value as mean of points + wibble"""
66
+ mapsize = maparray.shape[0]
67
+ drgrid = maparray[stepsize // 2:mapsize:stepsize, stepsize // 2:mapsize:stepsize]
68
+ ulgrid = maparray[0:mapsize:stepsize, 0:mapsize:stepsize]
69
+ ldrsum = drgrid + np.roll(drgrid, 1, axis=0)
70
+ lulsum = ulgrid + np.roll(ulgrid, -1, axis=1)
71
+ ltsum = ldrsum + lulsum
72
+ maparray[0:mapsize:stepsize, stepsize // 2:mapsize:stepsize] = wibbledmean(ltsum)
73
+ tdrsum = drgrid + np.roll(drgrid, 1, axis=1)
74
+ tulsum = ulgrid + np.roll(ulgrid, -1, axis=0)
75
+ ttsum = tdrsum + tulsum
76
+ maparray[stepsize // 2:mapsize:stepsize, 0:mapsize:stepsize] = wibbledmean(ttsum)
77
+
78
+ while stepsize >= 2:
79
+ fillsquares()
80
+ filldiamonds()
81
+ stepsize //= 2
82
+ wibble /= wibbledecay
83
+
84
+ maparray -= maparray.min()
85
+ return maparray / maparray.max()
86
+
87
+
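These helpers back the other modules: disk builds the circular kernel that DefocusBlur convolves with, clipped_zoom performs the centre zoom-and-crop, and plasma_fractal produces the diamond-square heightmap (normalized to [0, 1]) that Fog blends over the image. They can also be called directly; note that importing this module pulls in wand. The values below are arbitrary:

    import numpy as np
    from augmentation.ops import plasma_fractal, disk, clipped_zoom

    heightmap = plasma_fractal(mapsize=256, wibbledecay=2.0)  # 256x256 floats in [0, 1]; mapsize must be a power of two
    kernel = disk(radius=3, alias_blur=0.1)                   # normalized circular blur kernel
    zoomed = clipped_zoom(np.random.rand(64, 64, 3), 1.2)     # zoom in by 20% and crop back to 64x64
    print(heightmap.shape, kernel.shape, zoomed.shape)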
augmentation/pattern.py ADDED
@@ -0,0 +1,115 @@
1
+
2
+ import cv2
3
+ import numpy as np
4
+ from PIL import Image, ImageOps, ImageDraw
5
+
6
+ '''
7
+ PIL resize (W,H)
8
+ Torch resize is (H,W)
9
+ '''
10
+ class VGrid:
11
+ def __init__(self):
12
+ pass
13
+
14
+ def __call__(self, img, copy=True, max_width=4, mag=-1, prob=1.):
15
+ if np.random.uniform(0,1) > prob:
16
+ return img
17
+
18
+ if copy:
19
+ img = img.copy()
20
+ W, H = img.size
21
+
22
+ if mag<0 or mag>max_width:
23
+ line_width = np.random.randint(1, max_width)
24
+ image_stripe = np.random.randint(1, max_width)
25
+ else:
26
+ line_width = 1
27
+ image_stripe = 3 - mag
28
+
29
+ n_lines = W // (line_width + image_stripe) + 1
30
+ draw = ImageDraw.Draw(img)
31
+ for i in range(1, n_lines):
32
+ x = image_stripe*i + line_width*(i-1)
33
+ draw.line([(x,0), (x,H)], width=line_width, fill='black')
34
+
35
+ return img
36
+
37
+ class HGrid:
38
+ def __init__(self):
39
+ pass
40
+
41
+ def __call__(self, img, copy=True, max_width=4, mag=-1, prob=1.):
42
+ if np.random.uniform(0,1) > prob:
43
+ return img
44
+
45
+ if copy:
46
+ img = img.copy()
47
+ W, H = img.size
48
+ if mag<0 or mag>max_width:
49
+ line_width = np.random.randint(1, max_width)
50
+ image_stripe = np.random.randint(1, max_width)
51
+ else:
52
+ line_width = 1
53
+ image_stripe = 3 - mag
54
+
55
+ n_lines = H // (line_width + image_stripe) + 1
56
+ draw = ImageDraw.Draw(img)
57
+ for i in range(1, n_lines):
58
+ y = image_stripe*i + line_width*(i-1)
59
+ draw.line([(0,y), (W, y)], width=line_width, fill='black')
60
+
61
+ return img
62
+
63
+ class Grid:
64
+ def __init__(self):
65
+ pass
66
+
67
+ def __call__(self, img, mag=-1, prob=1.):
68
+ if np.random.uniform(0,1) > prob:
69
+ return img
70
+
71
+ img = VGrid()(img, copy=True, mag=mag)
72
+ img = HGrid()(img, copy=False, mag=mag)
73
+ return img
74
+
75
+ class RectGrid:
76
+ def __init__(self):
77
+ pass
78
+
79
+ def __call__(self, img, isellipse=False, mag=-1, prob=1.):
80
+ if np.random.uniform(0,1) > prob:
81
+ return img
82
+
83
+ img = img.copy()
84
+ W, H = img.size
85
+ line_width = 1
86
+ image_stripe = 3 - mag #np.random.randint(2, 6)
87
+ offset = 4 if isellipse else 1
88
+ n_lines = ((H//2) // (line_width + image_stripe)) + offset
89
+ draw = ImageDraw.Draw(img)
90
+ x_center = W // 2
91
+ y_center = H // 2
92
+ for i in range(1, n_lines):
93
+ dx = image_stripe*i + line_width*(i-1)
94
+ dy = image_stripe*i + line_width*(i-1)
95
+ x1 = x_center - (dx * W//H)
96
+ y1 = y_center - dy
97
+ x2 = x_center + (dx * W/H)
98
+ y2 = y_center + dy
99
+ if isellipse:
100
+ draw.ellipse([(x1,y1), (x2, y2)], width=line_width, outline='black')
101
+ else:
102
+ draw.rectangle([(x1,y1), (x2, y2)], width=line_width, outline='black')
103
+
104
+ return img
105
+
106
+ class EllipseGrid:
107
+ def __init__(self):
108
+ pass
109
+
110
+ def __call__(self, img, mag=-1, prob=1.):
111
+ if np.random.uniform(0,1) > prob:
112
+ return img
113
+
114
+ img = RectGrid()(img, isellipse=True, mag=mag, prob=prob)
115
+ return img
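For the grid patterns, mag sets the gap between lines (image_stripe = 3 - mag when mag is in range), so a larger mag means denser stripes; out-of-range values fall back to random widths for VGrid/HGrid. A short sketch, same assumptions as above:

    from PIL import Image
    from augmentation.pattern import VGrid, Grid, EllipseGrid

    img = Image.open("demo_word.png").convert("RGB")  # placeholder image

    out1 = VGrid()(img, mag=2)        # densest setting: 1px black lines with 1px gaps
    out2 = Grid()(img, mag=0)         # vertical + horizontal lines combined
    out3 = EllipseGrid()(img, mag=1)  # concentric ellipses centred on the image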
augmentation/process.py ADDED
@@ -0,0 +1,123 @@
1
+
2
+ from PIL import Image
3
+ import PIL.ImageOps, PIL.ImageEnhance
4
+ import numpy as np
5
+
6
+ class Posterize:
7
+ def __init__(self):
8
+ pass
9
+
10
+ def __call__(self, img, mag=-1, prob=1.):
11
+ if np.random.uniform(0,1) > prob:
12
+ return img
13
+
14
+ c = [1, 3, 6]
15
+ if mag<0 or mag>=len(c):
16
+ index = np.random.randint(0, len(c))
17
+ else:
18
+ index = mag
19
+ c = c[index]
20
+ bit = np.random.randint(c, c+2)
21
+ img = PIL.ImageOps.posterize(img, bit)
22
+
23
+ return img
24
+
25
+
26
+ class Solarize:
27
+ def __init__(self):
28
+ pass
29
+
30
+ def __call__(self, img, mag=-1, prob=1.):
31
+ if np.random.uniform(0,1) > prob:
32
+ return img
33
+
34
+ c = [64, 128, 192]
35
+ if mag<0 or mag>=len(c):
36
+ index = np.random.randint(0, len(c))
37
+ else:
38
+ index = mag
39
+ c = c[index]
40
+ thresh = np.random.randint(c, c+64)
41
+ img = PIL.ImageOps.solarize(img, thresh)
42
+
43
+ return img
44
+
45
+ class Invert:
46
+ def __init__(self):
47
+ pass
48
+
49
+ def __call__(self, img, mag=-1, prob=1.):
50
+ if np.random.uniform(0,1) > prob:
51
+ return img
52
+
53
+ img = PIL.ImageOps.invert(img)
54
+
55
+ return img
56
+
57
+
58
+ class Equalize:
59
+ def __init__(self):
60
+ pass
61
+
62
+ def __call__(self, img, mag=-1, prob=1.):
63
+ if np.random.uniform(0,1) > prob:
64
+ return img
65
+
66
+ img = PIL.ImageOps.equalize(img)
67
+
68
+ return img
69
+
70
+
71
+ class AutoContrast:
72
+ def __init__(self):
73
+ pass
74
+
75
+ def __call__(self, img, mag=-1, prob=1.):
76
+ if np.random.uniform(0,1) > prob:
77
+ return img
78
+
79
+ img = PIL.ImageOps.autocontrast(img)
80
+
81
+ return img
82
+
83
+
84
+ class Sharpness:
85
+ def __init__(self):
86
+ pass
87
+
88
+ def __call__(self, img, mag=-1, prob=1.):
89
+ if np.random.uniform(0,1) > prob:
90
+ return img
91
+
92
+ c = [.1, .7, 1.3]
93
+ if mag<0 or mag>=len(c):
94
+ index = np.random.randint(0, len(c))
95
+ else:
96
+ index = mag
97
+ c = c[index]
98
+ magnitude = np.random.uniform(c, c+.6)
99
+ img = PIL.ImageEnhance.Sharpness(img).enhance(magnitude)
100
+
101
+ return img
102
+
103
+
104
+ class Color:
105
+ def __init__(self):
106
+ pass
107
+
108
+ def __call__(self, img, mag=-1, prob=1.):
109
+ if np.random.uniform(0,1) > prob:
110
+ return img
111
+
112
+ c = [.1, .7, 1.3]
113
+ if mag<0 or mag>=len(c):
114
+ index = np.random.randint(0, len(c))
115
+ else:
116
+ index = mag
117
+ c = c[index]
118
+ magnitude = np.random.uniform(c, c+.6)
119
+ img = PIL.ImageEnhance.Color(img).enhance(magnitude)
120
+
121
+ return img
122
+
123
+
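These wrap PIL.ImageOps and PIL.ImageEnhance, so they expect an RGB or grayscale PIL image; mag again chooses the strength (for example Posterize keeps 1-2, 3-4 or 6-7 bits per channel depending on the level). Same assumptions as the sketches above:

    from PIL import Image
    from augmentation.process import Posterize, Solarize, Sharpness

    img = Image.open("demo_word.png").convert("RGB")  # placeholder image

    out1 = Posterize()(img, mag=0)  # harshest level: 1-2 bits per channel
    out2 = Solarize()(img, mag=2)   # invert pixels above a threshold drawn from [192, 255]
    out3 = Sharpness()(img, mag=1)  # enhancement factor drawn from [0.7, 1.3]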
augmentation/test.py ADDED
@@ -0,0 +1,43 @@
1
+
2
+ import os
3
+ import cv2
4
+ from warp import Curve, Distort, Stretch
5
+ from geometry import Rotate, Perspective, Shrink, TranslateX, TranslateY
6
+ from pattern import VGrid, HGrid, Grid, RectGrid, EllipseGrid
7
+ from noise import GaussianNoise, ShotNoise, ImpulseNoise, SpeckleNoise
8
+ from blur import GaussianBlur, DefocusBlur, MotionBlur, GlassBlur, ZoomBlur
9
+ from camera import Contrast, Brightness, JpegCompression, Pixelate
10
+ from weather import Fog, Snow, Frost, Rain, Shadow
11
+ from process import Posterize, Solarize, Invert, Equalize, AutoContrast, Sharpness, Color
12
+
13
+ from PIL import Image
14
+ import PIL.ImageOps
15
+ import numpy as np
16
+ import argparse
17
+
18
+
19
+ if __name__ == '__main__':
20
+ parser = argparse.ArgumentParser()
21
+ parser.add_argument('--image', default="images/delivery.png", help='Load image file')
22
+ parser.add_argument('--results', default="results", help='Load image file')
23
+ parser.add_argument('--gray', action='store_true', help='Convert to grayscale 1st')
24
+ opt = parser.parse_args()
25
+ os.makedirs(opt.results, exist_ok=True)
26
+
27
+ img = Image.open(opt.image)
28
+ img = img.resize( (100,32) )
29
+ ops = [Curve(), Rotate(), Perspective(), Distort(), Stretch(), Shrink(), TranslateX(), TranslateY(), VGrid(), HGrid(), Grid(), RectGrid(), EllipseGrid()]
30
+ ops.extend([GaussianNoise(), ShotNoise(), ImpulseNoise(), SpeckleNoise()])
31
+ ops.extend([GaussianBlur(), DefocusBlur(), MotionBlur(), GlassBlur(), ZoomBlur()])
32
+ ops.extend([Contrast(), Brightness(), JpegCompression(), Pixelate()])
33
+ ops.extend([Fog(), Snow(), Frost(), Rain(), Shadow()])
34
+ ops.extend([Posterize(), Solarize(), Invert(), Equalize(), AutoContrast(), Sharpness(), Color()])
35
+ for op in ops:
36
+ for mag in range(3):
37
+ filename = type(op).__name__ + "-" + str(mag) + ".png"
38
+ out_img = op(img, mag=mag)
39
+ if opt.gray:
40
+ out_img = PIL.ImageOps.grayscale(out_img)
41
+ out_img.save(os.path.join(opt.results, filename))
42
+
43
+
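Run from inside the augmentation/ directory (the script uses flat imports), something like `python test.py --image images/delivery.png --results results --gray` writes one PNG per (op, mag) pair. The body of a single loop iteration looks roughly like this, under the same assumptions about the working directory and input image:

    import os
    import PIL.ImageOps
    from PIL import Image
    from warp import Curve  # flat import, as in test.py

    os.makedirs("results", exist_ok=True)
    img = Image.open("images/delivery.png").convert("RGB").resize((100, 32))  # the script's default input
    op, mag = Curve(), 1
    out_img = op(img, mag=mag)
    out_img = PIL.ImageOps.grayscale(out_img)  # what --gray does
    out_img.save(os.path.join("results", type(op).__name__ + "-" + str(mag) + ".png"))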
augmentation/warp.py ADDED
@@ -0,0 +1,241 @@
1
+
2
+ import cv2
3
+ import numpy as np
4
+ from PIL import Image, ImageOps
5
+
6
+ '''
7
+ PIL resize (W,H)
8
+ Torch resize is (H,W)
9
+ '''
10
+ class Stretch:
11
+ def __init__(self):
12
+ self.tps = cv2.createThinPlateSplineShapeTransformer()
13
+
14
+ def __call__(self, img, mag=-1, prob=1.):
15
+ if np.random.uniform(0,1) > prob:
16
+ return img
17
+
18
+ W, H = img.size
19
+ img = np.array(img)
20
+ srcpt = list()
21
+ dstpt = list()
22
+
23
+ W_33 = 0.33 * W
24
+ W_50 = 0.50 * W
25
+ W_66 = 0.66 * W
26
+
27
+ H_50 = 0.50 * H
28
+
29
+ P = 0
30
+ #frac = 0.4
31
+
32
+ b = [.2, .3, .4]
33
+ if mag<0 or mag>=len(b):
34
+ index = len(b)-1
35
+ else:
36
+ index = mag
37
+ frac = b[index]
38
+
39
+ # left-most
40
+ srcpt.append([P, P])
41
+ srcpt.append([P, H-P])
42
+ srcpt.append([P, H_50])
43
+ x = np.random.uniform(0, frac)*W_33 #if np.random.uniform(0,1) > 0.5 else 0
44
+ dstpt.append([P+x, P])
45
+ dstpt.append([P+x, H-P])
46
+ dstpt.append([P+x, H_50])
47
+
48
+ # 2nd left-most
49
+ srcpt.append([P+W_33, P])
50
+ srcpt.append([P+W_33, H-P])
51
+ x = np.random.uniform(-frac, frac)*W_33
52
+ dstpt.append([P+W_33+x, P])
53
+ dstpt.append([P+W_33+x, H-P])
54
+
55
+ # 3rd left-most
56
+ srcpt.append([P+W_66, P])
57
+ srcpt.append([P+W_66, H-P])
58
+ x = np.random.uniform(-frac, frac)*W_33
59
+ dstpt.append([P+W_66+x, P])
60
+ dstpt.append([P+W_66+x, H-P])
61
+
62
+ # right-most
63
+ srcpt.append([W-P, P])
64
+ srcpt.append([W-P, H-P])
65
+ srcpt.append([W-P, H_50])
66
+ x = np.random.uniform(-frac, 0)*W_33 #if np.random.uniform(0,1) > 0.5 else 0
67
+ dstpt.append([W-P+x, P])
68
+ dstpt.append([W-P+x, H-P])
69
+ dstpt.append([W-P+x, H_50])
70
+
71
+ N = len(dstpt)
72
+ matches = [cv2.DMatch(i, i, 0) for i in range(N)]
73
+ dst_shape = np.array(dstpt).reshape((-1, N, 2))
74
+ src_shape = np.array(srcpt).reshape((-1, N, 2))
75
+ self.tps.estimateTransformation(dst_shape, src_shape, matches)
76
+ img = self.tps.warpImage(img)
77
+ img = Image.fromarray(img)
78
+
79
+ return img
80
+
81
+
82
+ class Distort:
83
+ def __init__(self):
84
+ self.tps = cv2.createThinPlateSplineShapeTransformer()
85
+
86
+ def __call__(self, img, mag=-1, prob=1.):
87
+ if np.random.uniform(0,1) > prob:
88
+ return img
89
+
90
+ W, H = img.size
91
+ img = np.array(img)
92
+ srcpt = list()
93
+ dstpt = list()
94
+
95
+ W_33 = 0.33 * W
96
+ W_50 = 0.50 * W
97
+ W_66 = 0.66 * W
98
+
99
+ H_50 = 0.50 * H
100
+
101
+ P = 0
102
+ #frac = 0.4
103
+
104
+ b = [.2, .3, .4]
105
+ if mag<0 or mag>=len(b):
106
+ index = len(b)-1
107
+ else:
108
+ index = mag
109
+ frac = b[index]
110
+
111
+ # top pts
112
+ srcpt.append([P, P])
113
+ x = np.random.uniform(0, frac)*W_33
114
+ y = np.random.uniform(0, frac)*H_50
115
+ dstpt.append([P+x, P+y])
116
+
117
+ srcpt.append([P+W_33, P])
118
+ x = np.random.uniform(-frac, frac)*W_33
119
+ y = np.random.uniform(0, frac)*H_50
120
+ dstpt.append([P+W_33+x, P+y])
121
+
122
+ srcpt.append([P+W_66, P])
123
+ x = np.random.uniform(-frac, frac)*W_33
124
+ y = np.random.uniform(0, frac)*H_50
125
+ dstpt.append([P+W_66+x, P+y])
126
+
127
+ srcpt.append([W-P, P])
128
+ x = np.random.uniform(-frac, 0)*W_33
129
+ y = np.random.uniform(0, frac)*H_50
130
+ dstpt.append([W-P+x, P+y])
131
+
132
+ # bottom pts
133
+ srcpt.append([P, H-P])
134
+ x = np.random.uniform(0, frac)*W_33
135
+ y = np.random.uniform(-frac, 0)*H_50
136
+ dstpt.append([P+x, H-P+y])
137
+
138
+ srcpt.append([P+W_33, H-P])
139
+ x = np.random.uniform(-frac, frac)*W_33
140
+ y = np.random.uniform(-frac, 0)*H_50
141
+ dstpt.append([P+W_33+x, H-P+y])
142
+
143
+ srcpt.append([P+W_66, H-P])
144
+ x = np.random.uniform(-frac, frac)*W_33
145
+ y = np.random.uniform(-frac, 0)*H_50
146
+ dstpt.append([P+W_66+x, H-P+y])
147
+
148
+ srcpt.append([W-P, H-P])
149
+ x = np.random.uniform(-frac, 0)*W_33
150
+ y = np.random.uniform(-frac, 0)*H_50
151
+ dstpt.append([W-P+x, H-P+y])
152
+
153
+ N = len(dstpt)
154
+ matches = [cv2.DMatch(i, i, 0) for i in range(N)]
155
+ dst_shape = np.array(dstpt).reshape((-1, N, 2))
156
+ src_shape = np.array(srcpt).reshape((-1, N, 2))
157
+ self.tps.estimateTransformation(dst_shape, src_shape, matches)
158
+ img = self.tps.warpImage(img)
159
+ img = Image.fromarray(img)
160
+
161
+ return img
162
+
163
+
164
+ class Curve:
165
+ def __init__(self, square_side=224):
166
+ self.tps = cv2.createThinPlateSplineShapeTransformer()
167
+ self.side = square_side
168
+
169
+ def __call__(self, img, mag=-1, prob=1.):
170
+ if np.random.uniform(0,1) > prob:
171
+ return img
172
+
173
+ W, H = img.size
174
+
175
+ if H!=self.side or W!=self.side:
176
+ img = img.resize((self.side, self.side), Image.BICUBIC)
177
+
178
+ isflip = np.random.uniform(0,1) > 0.5
179
+ if isflip:
180
+ img = ImageOps.flip(img)
181
+ #img = TF.vflip(img)
182
+
183
+ img = np.array(img)
184
+ w = self.side
185
+ h = self.side
186
+ w_25 = 0.25 * w
187
+ w_50 = 0.50 * w
188
+ w_75 = 0.75 * w
189
+
190
+ b = [1.1, .95, .8]
191
+ if mag<0 or mag>=len(b):
192
+ index = 0
193
+ else:
194
+ index = mag
195
+ rmin = b[index]
196
+
197
+ r = np.random.uniform(rmin, rmin+.1)*h
198
+ x1 = (r**2 - w_50**2)**0.5
199
+ h1 = r - x1
200
+
201
+ t = np.random.uniform(0.4, 0.5)*h
202
+
203
+ w2 = w_50*t/r
204
+ hi = x1*t/r
205
+ h2 = h1 + hi
206
+
207
+ sinb_2 = ((1 - x1/r)/2)**0.5
208
+ cosb_2 = ((1 + x1/r)/2)**0.5
209
+ w3 = w_50 - r*sinb_2
210
+ h3 = r - r*cosb_2
211
+
212
+ w4 = w_50 - (r-t)*sinb_2
213
+ h4 = r - (r-t)*cosb_2
214
+
215
+ w5 = 0.5*w2
216
+ h5 = h1 + 0.5*hi
217
+ h_50 = 0.50*h
218
+
219
+ srcpt = [(0,0 ), (w,0 ), (w_50,0), (0,h ), (w,h ), (w_25,0), (w_75,0 ), (w_50,h), (w_25,h), (w_75,h ), (0,h_50), (w,h_50 )]
220
+ dstpt = [(0,h1), (w,h1), (w_50,0), (w2,h2), (w-w2,h2), (w3, h3), (w-w3,h3), (w_50,t), (w4,h4 ), (w-w4,h4), (w5,h5 ), (w-w5,h5)]
221
+
222
+ N = len(dstpt)
223
+ matches = [cv2.DMatch(i, i, 0) for i in range(N)]
224
+ dst_shape = np.array(dstpt).reshape((-1, N, 2))
225
+ src_shape = np.array(srcpt).reshape((-1, N, 2))
226
+ self.tps.estimateTransformation(dst_shape, src_shape, matches)
227
+ img = self.tps.warpImage(img)
228
+ img = Image.fromarray(img)
229
+
230
+ if isflip:
231
+ #img = TF.vflip(img)
232
+ img = ImageOps.flip(img)
233
+ rect = (0, self.side//2, self.side, self.side)
234
+ else:
235
+ rect = (0, 0, self.side, self.side//2)
236
+
237
+ img = img.crop(rect)
238
+ img = img.resize((W, H), Image.BICUBIC)
239
+ return img
240
+
241
+
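Stretch, Distort and Curve all follow the same recipe: choose source control points on the image border, jitter them into destination points by random fractions of the width/height, then let the thin-plate-spline transformer interpolate a dense, smooth warp. A stripped-down sketch of that recipe with hand-picked points (cast to float32 for OpenCV's shape transformer):

    import cv2
    import numpy as np

    tps = cv2.createThinPlateSplineShapeTransformer()
    src = np.float32([[0, 0], [100, 0], [0, 32], [100, 32], [50, 16]]).reshape(1, -1, 2)
    dst = np.float32([[0, 0], [100, 0], [0, 32], [100, 32], [60, 16]]).reshape(1, -1, 2)  # push the centre 10px right
    matches = [cv2.DMatch(i, i, 0) for i in range(src.shape[1])]
    tps.estimateTransformation(dst, src, matches)  # same (dst, src) argument order as the classes above
    warped = tps.warpImage(np.zeros((32, 100, 3), dtype=np.uint8))
    print(warped.shape)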
augmentation/weather.py ADDED
@@ -0,0 +1,231 @@
1
+
2
+ import cv2
3
+ import numpy as np
4
+ import math
5
+ from PIL import Image, ImageOps, ImageDraw
6
+ from skimage import color
7
+ from pkg_resources import resource_filename
8
+ from io import BytesIO
9
+ from .ops import plasma_fractal, clipped_zoom, MotionImage
10
+
11
+ '''
12
+ PIL resize (W,H)
13
+ '''
14
+ class Fog:
15
+ def __init__(self):
16
+ pass
17
+
18
+ def __call__(self, img, mag=-1, prob=1.):
19
+ if np.random.uniform(0,1) > prob:
20
+ return img
21
+
22
+ W, H = img.size
23
+ c = [(1.5, 2), (2., 2), (2.5, 1.7)]
24
+ if mag<0 or mag>=len(c):
25
+ index = np.random.randint(0, len(c))
26
+ else:
27
+ index = mag
28
+ c = c[index]
29
+
30
+ n_channels = len(img.getbands())
31
+ isgray = n_channels == 1
32
+
33
+ img = np.array(img) / 255.
34
+ max_val = img.max()
35
+ fog = c[0] * plasma_fractal(wibbledecay=c[1])[:H, :W][..., np.newaxis]
36
+ #x += c[0] * plasma_fractal(wibbledecay=c[1])[:224, :224][..., np.newaxis]
37
+ #return np.clip(x * max_val / (max_val + c[0]), 0, 1) * 255
38
+ if isgray:
39
+ fog = np.squeeze(fog)
40
+ else:
41
+ fog = np.repeat(fog, 3, axis=2)
42
+
43
+ img += fog
44
+ img = np.clip(img * max_val / (max_val + c[0]), 0, 1) * 255
45
+ return Image.fromarray(img.astype(np.uint8))
46
+
47
+
48
+ class Frost:
49
+ def __init__(self):
50
+ pass
51
+
52
+ def __call__(self, img, mag=-1, prob=1.):
53
+ if np.random.uniform(0,1) > prob:
54
+ return img
55
+
56
+ W, H = img.size
57
+ c = [(1, 0.4), (0.8, 0.6), (0.7, 0.7)]
58
+ if mag<0 or mag>=len(c):
59
+ index = np.random.randint(0, len(c))
60
+ else:
61
+ index = mag
62
+ c = c[index]
63
+
64
+ filename = [resource_filename(__name__, 'frost/frost1.png'),
65
+ resource_filename(__name__, 'frost/frost2.png'),
66
+ resource_filename(__name__, 'frost/frost3.png'),
67
+ resource_filename(__name__, 'frost/frost4.jpg'),
68
+ resource_filename(__name__, 'frost/frost5.jpg'),
69
+ resource_filename(__name__, 'frost/frost6.jpg')]
70
+ index = np.random.randint(0, len(filename))
71
+ filename = filename[index]
72
+ frost = cv2.imread(filename)
73
+ #randomly crop and convert to rgb
74
+ x_start, y_start = np.random.randint(0, frost.shape[0] - H), np.random.randint(0, frost.shape[1] - W)
75
+ frost = frost[x_start:x_start + H, y_start:y_start + W][..., [2, 1, 0]]
76
+
77
+ n_channels = len(img.getbands())
78
+ isgray = n_channels == 1
79
+
80
+ img = np.array(img)
81
+
82
+ if isgray:
83
+ img = np.expand_dims(img, axis=2)
84
+ img = np.repeat(img, 3, axis=2)
85
+
86
+ img = img * c[0]
87
+ frost = frost * c[1]
88
+ img = np.clip(c[0] * img + c[1] * frost, 0, 255)
89
+ img = Image.fromarray(img.astype(np.uint8))
90
+ if isgray:
91
+ img = ImageOps.grayscale(img)
92
+
93
+ return img
94
+
95
+ class Snow:
96
+ def __init__(self):
97
+ pass
98
+
99
+ def __call__(self, img, mag=-1, prob=1.):
100
+ if np.random.uniform(0,1) > prob:
101
+ return img
102
+
103
+ W, H = img.size
104
+ c = [(0.1, 0.3, 3, 0.5, 10, 4, 0.8),
105
+ (0.2, 0.3, 2, 0.5, 12, 4, 0.7),
106
+ (0.55, 0.3, 4, 0.9, 12, 8, 0.7)]
107
+ if mag<0 or mag>=len(c):
108
+ index = np.random.randint(0, len(c))
109
+ else:
110
+ index = mag
111
+ c = c[index]
112
+
113
+ n_channels = len(img.getbands())
114
+ isgray = n_channels == 1
115
+
116
+ img = np.array(img, dtype=np.float32) / 255.
117
+ if isgray:
118
+ img = np.expand_dims(img, axis=2)
119
+ img = np.repeat(img, 3, axis=2)
120
+
121
+ snow_layer = np.random.normal(size=img.shape[:2], loc=c[0], scale=c[1]) # [:2] for monochrome
122
+
123
+ #snow_layer = clipped_zoom(snow_layer[..., np.newaxis], c[2])
124
+ snow_layer[snow_layer < c[3]] = 0
125
+
126
+ snow_layer = Image.fromarray((np.clip(snow_layer.squeeze(), 0, 1) * 255).astype(np.uint8), mode='L')
127
+ output = BytesIO()
128
+ snow_layer.save(output, format='PNG')
129
+ snow_layer = MotionImage(blob=output.getvalue())
130
+
131
+ snow_layer.motion_blur(radius=c[4], sigma=c[5], angle=np.random.uniform(-135, -45))
132
+
133
+ snow_layer = cv2.imdecode(np.frombuffer(snow_layer.make_blob(), np.uint8),  # np.fromstring is deprecated for binary data
134
+ cv2.IMREAD_UNCHANGED) / 255.
135
+
136
+ #snow_layer = cv2.cvtColor(snow_layer, cv2.COLOR_BGR2RGB)
137
+
138
+ snow_layer = snow_layer[..., np.newaxis]
139
+
140
+ img = c[6] * img
141
+ gray_img = (1 - c[6]) * np.maximum(img, cv2.cvtColor(img, cv2.COLOR_RGB2GRAY).reshape(H, W, 1) * 1.5 + 0.5)
142
+ img += gray_img
143
+ img = np.clip(img + snow_layer + np.rot90(snow_layer, k=2), 0, 1) * 255
144
+ img = Image.fromarray(img.astype(np.uint8))
145
+ if isgray:
146
+ img = ImageOps.grayscale(img)
147
+
148
+ return img
149
+
150
+ class Rain:
151
+ def __init__(self):
152
+ pass
153
+
154
+ def __call__(self, img, mag=-1, prob=1.):
155
+ if np.random.uniform(0,1) > prob:
156
+ return img
157
+
158
+ img = img.copy()
159
+ W, H = img.size
160
+ n_channels = len(img.getbands())
161
+ isgray = n_channels == 1
162
+ line_width = np.random.randint(1, 2)
163
+
164
+ c =[50, 70, 90]
165
+ if mag<0 or mag>=len(c):
166
+ index = 0
167
+ else:
168
+ index = mag
169
+ c = c[index]
170
+
171
+ n_rains = np.random.randint(c, c+20)
172
+ slant = np.random.randint(-60, 60)
173
+ fillcolor = 200 if isgray else (200,200,200)
174
+
175
+ draw = ImageDraw.Draw(img)
176
+ for i in range(1, n_rains):
177
+ length = np.random.randint(5, 10)
178
+ x1 = np.random.randint(0, W-length)
179
+ y1 = np.random.randint(0, H-length)
180
+ x2 = x1 + length*math.sin(slant*math.pi/180.)
181
+ y2 = y1 + length*math.cos(slant*math.pi/180.)
182
+ x2 = int(x2)
183
+ y2 = int(y2)
184
+ draw.line([(x1,y1), (x2,y2)], width=line_width, fill=fillcolor)
185
+
186
+ return img
187
+
188
+ class Shadow:
189
+ def __init__(self):
190
+ pass
191
+
192
+ def __call__(self, img, mag=-1, prob=1.):
193
+ if np.random.uniform(0,1) > prob:
194
+ return img
195
+
196
+ #img = img.copy()
197
+ W, H = img.size
198
+ n_channels = len(img.getbands())
199
+ isgray = n_channels == 1
200
+
201
+ c =[64, 96, 128]
202
+ if mag<0 or mag>=len(c):
203
+ index = 0
204
+ else:
205
+ index = mag
206
+ c = c[index]
207
+
208
+ img = img.convert('RGBA')
209
+ overlay = Image.new('RGBA', img.size, (255,255,255,0))
210
+ draw = ImageDraw.Draw(overlay)
211
+ transparency = np.random.randint(c, c+32)
212
+ x1 = np.random.randint(0, W//2)
213
+ y1 = 0
214
+
215
+ x2 = np.random.randint(W//2, W)
216
+ y2 = 0
217
+
218
+ x3 = np.random.randint(W//2, W)
219
+ y3 = H - 1
220
+
221
+ x4 = np.random.randint(0, W//2)
222
+ y4 = H - 1
223
+
224
+ draw.polygon([(x1,y1), (x2,y2), (x3,y3), (x4,y4)], fill=(0,0,0,transparency))
225
+
226
+ img = Image.alpha_composite(img, overlay)
227
+ img = img.convert("RGB")
228
+ if isgray:
229
+ img = ImageOps.grayscale(img)
230
+
231
+ return img
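Frost locates its textures through pkg_resources.resource_filename, so the frost/ images must sit next to this module in an importable package (only the three JPEGs are visible in this 50-file view); Fog, Rain and Shadow need nothing beyond the image itself, while Snow also requires wand/ImageMagick for its motion blur. A final sketch under the same assumptions:

    from PIL import Image
    from augmentation.weather import Fog, Rain, Shadow

    img = Image.open("demo_word.png").convert("RGB")  # placeholder image
    img = img.resize((100, 32))

    out1 = Fog()(img, mag=0)     # lightest plasma-fractal haze
    out2 = Rain()(img, mag=2)    # roughly 90-110 rain streaks
    out3 = Shadow()(img, mag=1)  # dark polygon overlay with alpha drawn from [96, 128]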
callbacks.py ADDED
@@ -0,0 +1,360 @@
1
+ import logging
2
+ import shutil
3
+ import time
4
+
5
+ import editdistance as ed
6
+ import torchvision.utils as vutils
7
+ from fastai.callbacks.tensorboard import (LearnerTensorboardWriter,
8
+ SummaryWriter, TBWriteRequest,
9
+ asyncTBWriter)
10
+ from fastai.vision import *
11
+ from torch.nn.parallel import DistributedDataParallel
12
+ from torchvision import transforms
13
+
14
+ import dataset_abinet
15
+ from utils_abinet import CharsetMapper, Timer, blend_mask
16
+
17
+
18
+ class IterationCallback(LearnerTensorboardWriter):
19
+ "A `TrackerCallback` that monitors at each iteration."
20
+ def __init__(self, learn:Learner, name:str='model', checpoint_keep_num=5,
21
+ show_iters:int=50, eval_iters:int=1000, save_iters:int=20000,
22
+ start_iters:int=0, stats_iters=20000):
23
+ #if self.learn.rank is not None: time.sleep(self.learn.rank) # keep all event files
24
+ super().__init__(learn, base_dir='.', name=learn.path, loss_iters=show_iters,
25
+ stats_iters=stats_iters, hist_iters=stats_iters)
26
+ self.name, self.bestname = Path(name).name, f'best-{Path(name).name}'
27
+ self.show_iters = show_iters
28
+ self.eval_iters = eval_iters
29
+ self.save_iters = save_iters
30
+ self.start_iters = start_iters
31
+ self.checpoint_keep_num = checpoint_keep_num
32
+ self.metrics_root = 'metrics/' # rewrite
33
+ self.timer = Timer()
34
+ self.host = self.learn.rank is None or self.learn.rank == 0
35
+
36
+ def _write_metrics(self, iteration:int, names:List[str], last_metrics:MetricsList)->None:
37
+ "Writes training metrics to Tensorboard."
38
+ for i, name in enumerate(names):
39
+ if last_metrics is None or len(last_metrics) < i+1: return
40
+ scalar_value = last_metrics[i]
41
+ self._write_scalar(name=name, scalar_value=scalar_value, iteration=iteration)
42
+
43
+ def _write_sub_loss(self, iteration:int, last_losses:dict)->None:
44
+ "Writes sub loss to Tensorboard."
45
+ for name, loss in last_losses.items():
46
+ scalar_value = to_np(loss)
47
+ tag = self.metrics_root + name
48
+ self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=iteration)
49
+
50
+ def _save(self, name):
51
+ if isinstance(self.learn.model, DistributedDataParallel):
52
+ tmp = self.learn.model
53
+ self.learn.model = self.learn.model.module
54
+ self.learn.save(name)
55
+ self.learn.model = tmp
56
+ else: self.learn.save(name)
57
+
58
+ def _validate(self, dl=None, callbacks=None, metrics=None, keeped_items=False):
59
+ "Validate on `dl` with potential `callbacks` and `metrics`."
60
+ dl = ifnone(dl, self.learn.data.valid_dl)
61
+ metrics = ifnone(metrics, self.learn.metrics)
62
+ cb_handler = CallbackHandler(ifnone(callbacks, []), metrics)
63
+ cb_handler.on_train_begin(1, None, metrics); cb_handler.on_epoch_begin()
64
+ if keeped_items: cb_handler.state_dict.update(dict(keeped_items=[]))
65
+ val_metrics = validate(self.learn.model, dl, self.loss_func, cb_handler)
66
+ cb_handler.on_epoch_end(val_metrics)
67
+ if keeped_items: return cb_handler.state_dict['keeped_items']
68
+ else: return cb_handler.state_dict['last_metrics']
69
+
70
+ def jump_to_epoch_iter(self, epoch:int, iteration:int)->None:
71
+ try:
72
+ self.learn.load(f'{self.name}_{epoch}_{iteration}', purge=False)
73
+ logging.info(f'Loaded {self.name}_{epoch}_{iteration}')
74
+ except Exception: logging.info(f'Model {self.name}_{epoch}_{iteration} not found.')
75
+
76
+ def on_train_begin(self, n_epochs, **kwargs):
77
+ # TODO: can not write graph here
78
+ # super().on_train_begin(**kwargs)
79
+ self.best = -float('inf')
80
+ self.timer.tic()
81
+ if self.host:
82
+ checkpoint_path = self.learn.path/'checkpoint.yaml'
83
+ if checkpoint_path.exists():
84
+ os.remove(checkpoint_path)
85
+ open(checkpoint_path, 'w').close()
86
+ return {'skip_validate': True, 'iteration':self.start_iters} # disable default validate
87
+
88
+ def on_batch_begin(self, **kwargs:Any)->None:
89
+ self.timer.toc_data()
90
+ super().on_batch_begin(**kwargs)
91
+
92
+ def on_batch_end(self, iteration, epoch, last_loss, smooth_loss, train, **kwargs):
93
+ super().on_batch_end(last_loss, iteration, train, **kwargs)
94
+ if iteration == 0: return
95
+
96
+ if iteration % self.loss_iters == 0:
97
+ last_losses = self.learn.loss_func.last_losses
98
+ self._write_sub_loss(iteration=iteration, last_losses=last_losses)
99
+ self.tbwriter.add_scalar(tag=self.metrics_root + 'lr',
100
+ scalar_value=self.opt.lr, global_step=iteration)
101
+
102
+ if iteration % self.show_iters == 0:
103
+ log_str = f'epoch {epoch} iter {iteration}: loss = {last_loss:6.4f}, ' \
104
+ f'smooth loss = {smooth_loss:6.4f}'
105
+ logging.info(log_str)
106
+ # log_str = f'data time = {self.timer.data_diff:.4f}s, running time = {self.timer.running_diff:.4f}s'
107
+ # logging.info(log_str)
108
+
109
+ if iteration % self.eval_iters == 0:
110
+ # TODO: or remove time to on_epoch_end
111
+ # 1. Record time
112
+ log_str = f'average data time = {self.timer.average_data_time():.4f}s, ' \
113
+ f'average running time = {self.timer.average_running_time():.4f}s'
114
+ logging.info(log_str)
115
+
116
+ # 2. Call validate
117
+ last_metrics = self._validate()
118
+ self.learn.model.train()
119
+ log_str = f'epoch {epoch} iter {iteration}: eval loss = {last_metrics[0]:6.4f}, ' \
120
+ f'ccr = {last_metrics[1]:6.4f}, cwr = {last_metrics[2]:6.4f}, ' \
121
+ f'ted = {last_metrics[3]:6.4f}, ned = {last_metrics[4]:6.4f}, ' \
122
+ f'ted/w = {last_metrics[5]:6.4f}, '
123
+ logging.info(log_str)
124
+ names = ['eval_loss', 'ccr', 'cwr', 'ted', 'ned', 'ted/w']
125
+ self._write_metrics(iteration, names, last_metrics)
126
+
127
+ # 3. Save best model
128
+ current = last_metrics[2]
129
+ if current is not None and current > self.best:
130
+ logging.info(f'Better model found at epoch {epoch}, '\
131
+ f'iter {iteration} with accuracy value: {current:6.4f}.')
132
+ self.best = current
133
+ self._save(f'{self.bestname}')
134
+
135
+ if iteration % self.save_iters == 0 and self.host:
136
+ logging.info(f'Save model {self.name}_{epoch}_{iteration}')
137
+ filename = f'{self.name}_{epoch}_{iteration}'
138
+ self._save(filename)
139
+
140
+ checkpoint_path = self.learn.path/'checkpoint.yaml'
141
+ if not checkpoint_path.exists():
142
+ open(checkpoint_path, 'w').close()
143
+ with open(checkpoint_path, 'r') as file:
144
+ checkpoints = yaml.load(file, Loader=yaml.FullLoader) or dict()
145
+ checkpoints['all_checkpoints'] = (
146
+ checkpoints.get('all_checkpoints') or list())
147
+ checkpoints['all_checkpoints'].insert(0, filename)
148
+ if len(checkpoints['all_checkpoints']) > self.checpoint_keep_num:
149
+ removed_checkpoint = checkpoints['all_checkpoints'].pop()
150
+ removed_checkpoint = self.learn.path/self.learn.model_dir/f'{removed_checkpoint}.pth'
151
+ os.remove(removed_checkpoint)
152
+ checkpoints['current_checkpoint'] = filename
153
+ with open(checkpoint_path, 'w') as file:
154
+ yaml.dump(checkpoints, file)
155
+
156
+
157
+ self.timer.toc_running()
158
+
159
+ def on_train_end(self, **kwargs):
160
+ #self.learn.load(f'{self.bestname}', purge=False)
161
+ pass
162
+
163
+ def on_epoch_end(self, last_metrics:MetricsList, iteration:int, **kwargs)->None:
164
+ self._write_embedding(iteration=iteration)
165
+
166
+
167
+ class TextAccuracy(Callback):
168
+ _names = ['ccr', 'cwr', 'ted', 'ned', 'ted/w']
169
+ def __init__(self, charset_path, max_length, case_sensitive, model_eval):
170
+ self.charset_path = charset_path
171
+ self.max_length = max_length
172
+ self.case_sensitive = case_sensitive
173
+ self.charset = CharsetMapper(charset_path, self.max_length)
174
+ self.names = self._names
175
+
176
+ self.model_eval = model_eval or 'alignment'
177
+ assert self.model_eval in ['vision', 'language', 'alignment']
178
+
179
+ def on_epoch_begin(self, **kwargs):
180
+ self.total_num_char = 0.
181
+ self.total_num_word = 0.
182
+ self.correct_num_char = 0.
183
+ self.correct_num_word = 0.
184
+ self.total_ed = 0.
185
+ self.total_ned = 0.
186
+
187
+ def _get_output(self, last_output):
188
+ if isinstance(last_output, (tuple, list)):
189
+ for res in last_output:
190
+ if res['name'] == self.model_eval: output = res
191
+ else: output = last_output
192
+ return output
193
+
194
+ def _update_output(self, last_output, items):
195
+ if isinstance(last_output, (tuple, list)):
196
+ for res in last_output:
197
+ if res['name'] == self.model_eval: res.update(items)
198
+ else: last_output.update(items)
199
+ return last_output
200
+
201
+ def on_batch_end(self, last_output, last_target, **kwargs):
202
+ output = self._get_output(last_output)
203
+ logits, pt_lengths = output['logits'], output['pt_lengths']
204
+ pt_text, pt_scores, pt_lengths_ = self.decode(logits)
205
+ assert (pt_lengths == pt_lengths_).all(), f'{pt_lengths} != {pt_lengths_} for {pt_text}'
206
+ last_output = self._update_output(last_output, {'pt_text':pt_text, 'pt_scores':pt_scores})
207
+
208
+ pt_text = [self.charset.trim(t) for t in pt_text]
209
+ label = last_target[0]
210
+ if label.dim() == 3: label = label.argmax(dim=-1) # one-hot label
211
+ gt_text = [self.charset.get_text(l, trim=True) for l in label]
212
+
213
+ for i in range(len(gt_text)):
214
+ if not self.case_sensitive:
215
+ gt_text[i], pt_text[i] = gt_text[i].lower(), pt_text[i].lower()
216
+ distance = ed.eval(gt_text[i], pt_text[i])
217
+ self.total_ed += distance
218
+ self.total_ned += float(distance) / max(len(gt_text[i]), 1)
219
+
220
+ if gt_text[i] == pt_text[i]:
221
+ self.correct_num_word += 1
222
+ self.total_num_word += 1
223
+
224
+ for j in range(min(len(gt_text[i]), len(pt_text[i]))):
225
+ if gt_text[i][j] == pt_text[i][j]:
226
+ self.correct_num_char += 1
227
+ self.total_num_char += len(gt_text[i])
228
+
229
+ return {'last_output': last_output}
230
+
231
+ def on_epoch_end(self, last_metrics, **kwargs):
232
+ mets = [self.correct_num_char / self.total_num_char,
233
+ self.correct_num_word / self.total_num_word,
234
+ self.total_ed,
235
+ self.total_ned,
236
+ self.total_ed / self.total_num_word]
237
+ return add_metrics(last_metrics, mets)
238
+
239
+ def decode(self, logit):
240
+ """ Greed decode """
241
+ # TODO: test running time and decode on GPU
242
+ out = F.softmax(logit, dim=2)
243
+ pt_text, pt_scores, pt_lengths = [], [], []
244
+ for o in out:
245
+ text = self.charset.get_text(o.argmax(dim=1), padding=False, trim=False)
246
+ text = text.split(self.charset.null_char)[0] # end at end-token
247
+ pt_text.append(text)
248
+ pt_scores.append(o.max(dim=1)[0])
249
+ pt_lengths.append(min(len(text) + 1, self.max_length)) # one for end-token
250
+ pt_scores = torch.stack(pt_scores)
251
+ pt_lengths = pt_scores.new_tensor(pt_lengths, dtype=torch.long)
252
+ return pt_text, pt_scores, pt_lengths
253
+
254
+
255
+ class TopKTextAccuracy(TextAccuracy):
256
+ _names = ['ccr', 'cwr']
257
+ def __init__(self, k, charset_path, max_length, case_sensitive, model_eval):
258
+ self.k = k
259
+ self.charset_path = charset_path
260
+ self.max_length = max_length
261
+ self.case_sensitive = case_sensitive
262
+ self.charset = CharsetMapper(charset_path, self.max_length)
263
+ self.names = self._names
264
+
265
+ def on_epoch_begin(self, **kwargs):
266
+ self.total_num_char = 0.
267
+ self.total_num_word = 0.
268
+ self.correct_num_char = 0.
269
+ self.correct_num_word = 0.
270
+
271
+ def on_batch_end(self, last_output, last_target, **kwargs):
272
+ logits, pt_lengths = last_output['logits'], last_output['pt_lengths']
273
+ gt_labels, gt_lengths = last_target[:]
274
+
275
+ for logit, pt_length, label, length in zip(logits, pt_lengths, gt_labels, gt_lengths):
276
+ word_flag = True
277
+ for i in range(length):
278
+ char_logit = logit[i].topk(self.k)[1]
279
+ char_label = label[i].argmax(-1)
280
+ if char_label in char_logit: self.correct_num_char += 1
281
+ else: word_flag = False
282
+ self.total_num_char += 1
283
+ if pt_length == length and word_flag:
284
+ self.correct_num_word += 1
285
+ self.total_num_word += 1
286
+
287
+ def on_epoch_end(self, last_metrics, **kwargs):
288
+ mets = [self.correct_num_char / self.total_num_char,
289
+ self.correct_num_word / self.total_num_word,
290
+ 0., 0., 0.]
291
+ return add_metrics(last_metrics, mets)
292
+
293
+
294
+ class DumpPrediction(LearnerCallback):
295
+
296
+ def __init__(self, learn, dataset, charset_path, model_eval, image_only=False, debug=False):
297
+ super().__init__(learn=learn)
298
+ self.debug = debug
299
+ self.model_eval = model_eval or 'alignment'
300
+ self.image_only = image_only
301
+ assert self.model_eval in ['vision', 'language', 'alignment']
302
+
303
+ self.dataset, self.root = dataset, Path(self.learn.path)/f'{dataset}-{self.model_eval}'
304
+ self.attn_root = self.root/'attn'
305
+ self.charset = CharsetMapper(charset_path)
306
+ if self.root.exists(): shutil.rmtree(self.root)
307
+ self.root.mkdir(), self.attn_root.mkdir()
308
+
309
+ self.pil = transforms.ToPILImage()
310
+ self.tensor = transforms.ToTensor()
311
+ size = self.learn.data.img_h, self.learn.data.img_w
312
+ self.resize = transforms.Resize(size=size, interpolation=0)
313
+ self.c = 0
314
+
315
+ def on_batch_end(self, last_input, last_output, last_target, **kwargs):
316
+ if isinstance(last_output, (tuple, list)):
317
+ for res in last_output:
318
+ if res['name'] == self.model_eval: pt_text = res['pt_text']
319
+ if res['name'] == 'vision': attn_scores = res['attn_scores'].detach().cpu()
320
+ if res['name'] == self.model_eval: logits = res['logits']
321
+ else:
322
+ pt_text = last_output['pt_text']
323
+ attn_scores = last_output['attn_scores'].detach().cpu()
324
+ logits = last_output['logits']
325
+
326
+ images = last_input[0] if isinstance(last_input, (tuple, list)) else last_input
327
+ images = images.detach().cpu()
328
+ pt_text = [self.charset.trim(t) for t in pt_text]
329
+ gt_label = last_target[0]
330
+ if gt_label.dim() == 3: gt_label = gt_label.argmax(dim=-1) # one-hot label
331
+ gt_text = [self.charset.get_text(l, trim=True) for l in gt_label]
332
+
333
+ prediction, false_prediction = [], []
334
+ for gt, pt, image, attn, logit in zip(gt_text, pt_text, images, attn_scores, logits):
335
+ prediction.append(f'{gt}\t{pt}\n')
336
+ if gt != pt:
337
+ if self.debug:
338
+ scores = torch.softmax(logit, dim=-1)[:max(len(pt), len(gt)) + 1]
339
+ logging.info(f'{self.c} gt {gt}, pt {pt}, logit {logit.shape}, scores {scores.topk(5, dim=-1)}')
340
+ false_prediction.append(f'{gt}\t{pt}\n')
341
+
342
+ image = self.learn.data.denorm(image)
343
+ if not self.image_only:
344
+ image_np = np.array(self.pil(image))
345
+ attn_pil = [self.pil(a) for a in attn[:, None, :, :]]
346
+ attn = [self.tensor(self.resize(a)).repeat(3, 1, 1) for a in attn_pil]
347
+ attn_sum = np.array([np.array(a) for a in attn_pil[:len(pt)]]).sum(axis=0)
348
+ blended_sum = self.tensor(blend_mask(image_np, attn_sum))
349
+ blended = [self.tensor(blend_mask(image_np, np.array(a))) for a in attn_pil]
350
+ save_image = torch.stack([image] + attn + [blended_sum] + blended)
351
+ save_image = save_image.view(2, -1, *save_image.shape[1:])
352
+ save_image = save_image.permute(1, 0, 2, 3, 4).flatten(0, 1)
353
+ vutils.save_image(save_image, self.attn_root/f'{self.c}_{gt}_{pt}.jpg',
354
+ nrow=2, normalize=True, scale_each=True)
355
+ else:
356
+ self.pil(image).save(self.attn_root/f'{self.c}_{gt}_{pt}.jpg')
357
+ self.c += 1
358
+
359
+ with open(self.root/f'{self.model_eval}.txt', 'a') as f: f.writelines(prediction)
360
+ with open(self.root/f'{self.model_eval}-false.txt', 'a') as f: f.writelines(false_prediction)
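A rough sketch of how the callbacks above might be wired into a fastai v1 Learner; `learn` is assumed to be an existing Learner, and the charset path, epoch count, and learning rate are placeholders (the actual training setup lives elsewhere in this repo).

# sketch only: TextAccuracy as a metric, IterationCallback for logging and checkpointing
metrics = [TextAccuracy(charset_path='data/charset_36.txt', max_length=26,
                        case_sensitive=False, model_eval='alignment')]
learn.metrics = metrics
learn.callbacks.append(IterationCallback(learn, name='model',
                                         show_iters=50, eval_iters=1000, save_iters=20000))
learn.fit(10, lr=1e-4)  # hypothetical epoch count and learning rate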
captum/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ #!/usr/bin/env python3
2
+
3
+ __version__ = "0.5.0"
captum/_utils/__init__.py ADDED
File without changes
captum/_utils/av.py ADDED
@@ -0,0 +1,499 @@
1
+ #!/usr/bin/env python3
2
+
3
+ import glob
4
+ import os
5
+ import re
6
+ import warnings
7
+ from typing import Any, List, Optional, Tuple, Union
8
+
9
+ import captum._utils.common as common
10
+ import torch
11
+ from captum.attr import LayerActivation
12
+ from torch import Tensor
13
+ from torch.nn import Module
14
+ from torch.utils.data import DataLoader, Dataset
15
+
16
+
17
+ class AV:
18
+ r"""
19
+ This class provides functionality to store and load activation vectors
20
+ generated for pre-defined neural network layers.
21
+ It also provides functionality to check if activation vectors already
22
+ exist in the manifold, along with other auxiliary functions.
23
+
24
+ This class also defines a torch `Dataset`, representing Activation Vectors,
25
+ which enables lazy access to activation vectors and layers stored in the manifold.
26
+
27
+ """
28
+
29
+ r"""
30
+ The name of the subfolder in the manifold where the activation vectors
31
+ are stored.
32
+ """
33
+
34
+ class AVDataset(Dataset):
35
+ r"""
36
+ This dataset enables access to activation vectors for a given `model` stored
37
+ under a pre-defined path.
38
+ The iterator of this dataset returns a batch of data tensors.
39
+ Additionally, subsets of the model activations can be loaded based on layer
40
+ or identifier or num_id (representing batch number in source dataset).
41
+ """
42
+
43
+ def __init__(
44
+ self,
45
+ path: str,
46
+ model_id: str,
47
+ identifier: Optional[str] = None,
48
+ layer: Optional[str] = None,
49
+ num_id: Optional[str] = None,
50
+ ):
51
+ r"""
52
+ Loads into memory the list of all activation file paths associated
53
+ with the input `model_id`.
54
+
55
+ Args:
56
+ path (str): The path where the activation vectors
57
+ for the `layer` are stored.
58
+ model_id (str): The name/version of the model for which layer
59
+ activations are being computed and stored.
60
+ identifier (str or None): An optional identifier for the layer
61
+ activations. Can be used to distinguish between activations for
62
+ different training batches.
63
+ layer (str or None): The layer for which the activation vectors
64
+ are computed.
65
+ num_id (str): An optional string representing the batch number for
66
+ which the activation vectors are computed
67
+ """
68
+
69
+ self.av_filesearch = AV._construct_file_search(
70
+ path, model_id, identifier, layer, num_id
71
+ )
72
+
73
+ files = glob.glob(self.av_filesearch)
74
+
75
+ self.files = AV.sort_files(files)
76
+
77
+ def __getitem__(self, idx: int) -> Union[Tensor, Tuple[Tensor, ...]]:
78
+ assert idx < len(self.files), "Layer index is out of bounds!"
79
+ fl = self.files[idx]
80
+ av = torch.load(fl)
81
+ return av
82
+
83
+ def __len__(self):
84
+ return len(self.files)
85
+
86
+ AV_DIR_NAME: str = "av"
87
+
88
+ def __init__(self) -> None:
89
+ pass
90
+
91
+ @staticmethod
92
+ def _assemble_model_dir(path: str, model_id: str) -> str:
93
+ r"""
94
+ Returns a directory path for the given source path `path` and `model_id.`
95
+ This path is suffixed with the '/' delimiter.
96
+ """
97
+ return "/".join([path, AV.AV_DIR_NAME, model_id, ""])
98
+
99
+ @staticmethod
100
+ def _assemble_file_path(source_dir: str, identifier: str, layer: str) -> str:
101
+ r"""
102
+ Returns a full filepath given a source directory, layer, and required
103
+ identifier. The source dir is not required to end with a "/" delimiter.
104
+ """
105
+ if not source_dir.endswith("/"):
106
+ source_dir += "/"
107
+
108
+ filepath = os.path.join(source_dir, identifier)
109
+
110
+ filepath = os.path.join(filepath, layer)
111
+
112
+ return filepath
113
+
114
+ @staticmethod
115
+ def _construct_file_search(
116
+ source_dir: str,
117
+ model_id: str,
118
+ identifier: Optional[str] = None,
119
+ layer: Optional[str] = None,
120
+ num_id: Optional[str] = None,
121
+ ) -> str:
122
+ r"""
123
+ Returns a search string that can be used by glob to search `source_dir/model_id`
124
+ for the desired layer/identifier pair. Leaving `layer` as None will search ids
125
+ over all layers, and leaving `identifier` as none will search layers over all
126
+ ids. Leaving both as none will return a path to glob for every activation.
127
+ Assumes identifier is always specified when saving activations, so that
128
+ activations live at source_dir/model_id/identifier/layer
129
+ (and never source_dir/model_id/layer)
130
+ """
131
+
132
+ av_filesearch = AV._assemble_model_dir(source_dir, model_id)
133
+
134
+ av_filesearch = os.path.join(
135
+ av_filesearch, "*" if identifier is None else identifier
136
+ )
137
+
138
+ av_filesearch = os.path.join(av_filesearch, "*" if layer is None else layer)
139
+
140
+ av_filesearch = os.path.join(
141
+ av_filesearch, "*.pt" if num_id is None else "%s.pt" % num_id
142
+ )
143
+
144
+ return av_filesearch
145
+
146
+ @staticmethod
147
+ def exists(
148
+ path: str,
149
+ model_id: str,
150
+ identifier: Optional[str] = None,
151
+ layer: Optional[str] = None,
152
+ num_id: Optional[str] = None,
153
+ ) -> bool:
154
+ r"""
155
+ Verifies whether the model + layer activations exist
156
+ under the path.
157
+
158
+ Args:
159
+ path (str): The path where the activation vectors
160
+ for the `model_id` are stored.
161
+ model_id (str): The name/version of the model for which layer activations
162
+ are being computed and stored.
163
+ identifier (str or None): An optional identifier for the layer activations.
164
+ Can be used to distinguish between activations for different
165
+ training batches. For example, the id could be a suffix composed of
166
+ a train/test label and numerical value, such as "-train-xxxxx".
167
+ The numerical id is often a monotonic sequence taken from datetime.
168
+ layer (str or None): The layer for which the activation vectors are
169
+ computed.
170
+ num_id (str): An optional string representing the batch number for which
171
+ the activation vectors are computed
172
+
173
+ Returns:
174
+ exists (bool): Indicating whether the activation vectors for the `layer`
175
+ and `identifier` (if provided) and num_id (if provided) were stored
176
+ in the manifold. If no `identifier` is provided, will return `True`
177
+ if any layer activation exists, whether it has an identifier or
178
+ not, and vice-versa.
179
+ """
180
+ av_dir = AV._assemble_model_dir(path, model_id)
181
+ av_filesearch = AV._construct_file_search(
182
+ path, model_id, identifier, layer, num_id
183
+ )
184
+ return os.path.exists(av_dir) and len(glob.glob(av_filesearch)) > 0
185
+
186
+ @staticmethod
187
+ def save(
188
+ path: str,
189
+ model_id: str,
190
+ identifier: str,
191
+ layers: Union[str, List[str]],
192
+ act_tensors: Union[Tensor, List[Tensor]],
193
+ num_id: str,
194
+ ) -> None:
195
+ r"""
196
+ Saves the activation vectors `act_tensor` for the
197
+ `layer` under the manifold `path`.
198
+
199
+ Args:
200
+ path (str): The path where the activation vectors
201
+ for the `layer` are stored.
202
+ model_id (str): The name/version of the model for which layer activations
203
+ are being computed and stored.
204
+ identifier (str or None): An optional identifier for the layer
205
+ activations. Can be used to distinguish between activations for
206
+ different training batches. For example, the identifier could be
207
+ a suffix composed of a train/test label and numerical value, such
208
+ as "-src-abc".
209
+ Additionally, (abc) could be a unique identifying number. For
210
+ example, it is automatically created in
211
+ AV.generate_dataset_activations from batch index.
212
+ It assumes identifier is same for all layers if a list of
213
+ `layers` is provided.
214
+ layers (str or List of str): The layer(s) for which the activation vectors
215
+ are computed.
216
+ act_tensors (Tensor or List of Tensor): A batch of activation vectors.
217
+ This must match the dimension of `layers`.
218
+ num_id (str): string representing the batch number for which the activation
219
+ vectors are computed
220
+ """
221
+ if isinstance(layers, str):
222
+ layers = [layers]
223
+ if isinstance(act_tensors, Tensor):
224
+ act_tensors = [act_tensors]
225
+
226
+ if len(layers) != len(act_tensors):
227
+ raise ValueError("The dimension of `layers` and `act_tensors` must match!")
228
+
229
+ av_dir = AV._assemble_model_dir(path, model_id)
230
+
231
+ for i, layer in enumerate(layers):
232
+ av_save_fl_path = os.path.join(
233
+ AV._assemble_file_path(av_dir, identifier, layer), "%s.pt" % num_id
234
+ )
235
+
236
+ layer_dir = os.path.dirname(av_save_fl_path)
237
+ if not os.path.exists(layer_dir):
238
+ os.makedirs(layer_dir)
239
+ torch.save(act_tensors[i], av_save_fl_path)
240
+
241
+ @staticmethod
242
+ def load(
243
+ path: str,
244
+ model_id: str,
245
+ identifier: Optional[str] = None,
246
+ layer: Optional[str] = None,
247
+ num_id: Optional[str] = None,
248
+ ) -> AVDataset:
249
+ r"""
250
+ Loads lazily the activation vectors for given `model_id` and
251
+ `layer` saved under the `path`.
252
+
253
+ Args:
254
+ path (str): The path where the activation vectors
255
+ for the `layer` are stored.
256
+ model_id (str): The name/version of the model for which layer activations
257
+ are being computed and stored.
258
+ identifier (str or None): An optional identifier for the layer
259
+ activations. Can be used to distinguish between activations for
260
+ different training batches.
261
+ layer (str or None): The layer for which the activation vectors
262
+ are computed.
263
+ num_id (str): An optional string representing the batch number for which
264
+ the activation vectors are computed
265
+
266
+ Returns:
267
+ dataset (AV.AVDataset): AV.AVDataset that allows to iterate
268
+ over the activation vectors for given layer, identifier (if
269
+ provided), num_id (if provided). Returning an AV.AVDataset as
270
+ opposed to a DataLoader constructed from it offers more
271
+ flexibility. Raises RuntimeError if activation vectors are not
272
+ found.
273
+ """
274
+
275
+ av_save_dir = AV._assemble_model_dir(path, model_id)
276
+
277
+ if os.path.exists(av_save_dir):
278
+ avdataset = AV.AVDataset(path, model_id, identifier, layer, num_id)
279
+ return avdataset
280
+ else:
281
+ raise RuntimeError(
282
+ f"Activation vectors for model {model_id} was not found at path {path}"
283
+ )
284
+
285
+ @staticmethod
286
+ def _manage_loading_layers(
287
+ path: str,
288
+ model_id: str,
289
+ layers: Union[str, List[str]],
290
+ load_from_disk: bool = True,
291
+ identifier: Optional[str] = None,
292
+ num_id: Optional[str] = None,
293
+ ) -> List[str]:
294
+ r"""
295
+ Returns unsaved layers, and deletes saved layers if load_from_disk is False.
296
+
297
+ Args:
298
+ path (str): The path where the activation vectors
299
+ for the `layer` are stored.
300
+ model_id (str): The name/version of the model for which layer activations
301
+ are being computed and stored.
302
+ layers (str or List of str): The layer(s) for which the activation vectors
303
+ are computed.
304
+ identifier (str or None): An optional identifier for the layer
305
+ activations. Can be used to distinguish between activations for
306
+ different training batches.
307
+ num_id (str): An optional string representing the batch number for which the
308
+ activation vectors are computed
309
+
310
+ Returns:
311
+ List of layer names for which activations should be generated
312
+ """
313
+
314
+ layers = [layers] if isinstance(layers, str) else layers
315
+ unsaved_layers = []
316
+
317
+ if load_from_disk:
318
+ for layer in layers:
319
+ if not AV.exists(path, model_id, identifier, layer, num_id):
320
+ unsaved_layers.append(layer)
321
+ else:
322
+ unsaved_layers = layers
323
+ warnings.warn(
324
+ "Overwriting activations: load_from_disk is set to False. Removing all "
325
+ f"activations matching specified parameters {{path: {path}, "
326
+ f"model_id: {model_id}, layers: {layers}, identifier: {identifier}}} "
327
+ "before generating new activations."
328
+ )
329
+ for layer in layers:
330
+ files = glob.glob(
331
+ AV._construct_file_search(path, model_id, identifier, layer)
332
+ )
333
+ for filename in files:
334
+ os.remove(filename)
335
+
336
+ return unsaved_layers
337
+
338
+ @staticmethod
339
+ def _compute_and_save_activations(
340
+ path: str,
341
+ model: Module,
342
+ model_id: str,
343
+ layers: Union[str, List[str]],
344
+ inputs: Union[Tensor, Tuple[Tensor, ...]],
345
+ identifier: str,
346
+ num_id: str,
347
+ additional_forward_args: Any = None,
348
+ load_from_disk: bool = True,
349
+ ) -> None:
350
+ r"""
351
+ Computes layer activations for the given inputs and specified `layers`
352
+
353
+ Args:
354
+ path (str): The path where the activation vectors
355
+ for the `layer` are stored.
356
+ model (torch.nn.Module): An instance of pytorch model. This model should
357
+ define all of its layers as attributes of the model.
358
+ model_id (str): The name/version of the model for which layer activations
359
+ are being computed and stored.
360
+ layers (str or List of str): The layer(s) for which the activation vectors
361
+ are computed.
362
+ inputs (tensor or tuple of tensors): Batch of examples for
363
+ which influential instances are computed. They are passed to the
364
+ input `model`. The first dimension in `inputs` tensor or tuple of
365
+ tensors corresponds to the batch size.
366
+ identifier (str or None): An optional identifier for the layer
367
+ activations. Can be used to distinguish between activations for
368
+ different training batches.
369
+ num_id (str): A required string representing the batch number for which the
370
+ activation vectors are computed
371
+ additional_forward_args (optional): Additional arguments that will be
372
+ passed to `model` after inputs.
373
+ Default: None
374
+ load_from_disk (bool): Forces function to regenerate activations if False.
375
+ Default: True
376
+ """
377
+ unsaved_layers = AV._manage_loading_layers(
378
+ path,
379
+ model_id,
380
+ layers,
381
+ load_from_disk,
382
+ identifier,
383
+ num_id,
384
+ )
385
+ layer_modules = [
386
+ common._get_module_from_name(model, layer) for layer in unsaved_layers
387
+ ]
388
+ if len(unsaved_layers) > 0:
389
+ layer_act = LayerActivation(model, layer_modules)
390
+ new_activations = layer_act.attribute.__wrapped__( # type: ignore
391
+ layer_act, inputs, additional_forward_args
392
+ )
393
+ AV.save(path, model_id, identifier, unsaved_layers, new_activations, num_id)
394
+
395
+ @staticmethod
396
+ def _unpack_data(data: Union[Any, Tuple[Any, Any]]) -> Any:
397
+ r"""
398
+ Helper to extract input from labels when getting items from a Dataset. Assumes
399
+ that data is either a single value, or a tuple containing two elements.
400
+ The input could itself be a Tuple containing multiple values. If your
401
+ dataset returns a Tuple with more than 2 elements, please reformat it such that
402
+ all inputs are formatted into a tuple stored at the first position.
403
+ """
404
+ if isinstance(data, tuple) or isinstance(data, list):
405
+ data = data[0]
406
+ return data
407
+
408
+ r"""TODO:
409
+ 1. Can propagate saving labels along with activations.
410
+ 2. Use of additional_forward_args when sourcing from dataset?
411
+ """
412
+
413
+ @staticmethod
414
+ def generate_dataset_activations(
415
+ path: str,
416
+ model: Module,
417
+ model_id: str,
418
+ layers: Union[str, List[str]],
419
+ dataloader: DataLoader,
420
+ identifier: str = "default",
421
+ load_from_disk: bool = True,
422
+ return_activations: bool = False,
423
+ ) -> Optional[Union[AVDataset, List[AVDataset]]]:
424
+ r"""
425
+ Computes layer activations for a source dataset and specified `layers`. Assumes
426
+ that the dataset returns a single value, or a tuple containing two elements
427
+ (see AV._unpack_data).
428
+
429
+ Args:
430
+ path (str): The path where the activation vectors
431
+ for the `layer` are stored.
432
+ model (torch.nn.Module): An instance of pytorch model. This model should
433
+ define all of its layers as attributes of the model.
434
+ model_id (str): The name/version of the model for which layer activations
435
+ are being computed and stored.
436
+ layers (str or List of str): The layer(s) for which the activation vectors
437
+ are computed.
438
+ dataloader (torch.utils.data.DataLoader): DataLoader that yields Dataset
439
+ for which influential instances are computed. They are passed to
440
+ input `model`.
441
+ identifier (str or None): An identifier for the layer
442
+ activations. Can be used to distinguish between activations for
443
+ different training batches.
444
+ Default: "default"
445
+ load_from_disk (bool): Forces function to regenerate activations if False.
446
+ Default: True
447
+ return_activations (bool, optional): Whether to return the activations.
448
+ Default: False
449
+ Returns: If `return_activations == True`, returns a single `AVDataset` if
450
+ `layers` is a str, otherwise, a list of `AVDataset`s of the length
451
+ of `layers`, where each element corresponds to a layer. In either
452
+ case, `AVDataset`'s represent the activations for a single layer,
453
+ over the entire `dataloader`. If `return_activations == False`,
454
+ does not return anything.
455
+
456
+ """
457
+
458
+ unsaved_layers = AV._manage_loading_layers(
459
+ path,
460
+ model_id,
461
+ layers,
462
+ load_from_disk,
463
+ identifier,
464
+ )
465
+ if len(unsaved_layers) > 0:
466
+ for i, data in enumerate(dataloader):
467
+ AV._compute_and_save_activations(
468
+ path,
469
+ model,
470
+ model_id,
471
+ layers,
472
+ AV._unpack_data(data),
473
+ identifier,
474
+ str(i),
475
+ )
476
+
477
+ if not return_activations:
478
+ return None
479
+ if isinstance(layers, str):
480
+ return AV.load(path, model_id, identifier, layers)
481
+ else:
482
+ return [AV.load(path, model_id, identifier, layer) for layer in layers]
483
+
484
+ @staticmethod
485
+ def sort_files(files: List[str]) -> List[str]:
486
+ r"""
487
+ Utility for sorting files based on natural sorting instead of the default
488
+ lexicographical sort.
489
+ """
490
+
491
+ def split_alphanum(s):
492
+ r"""
493
+ Splits string into a list of strings and numbers
494
+ "z23a" -> ["z", 23, "a"]
495
+ """
496
+
497
+ return [int(x) if x.isdigit() else x for x in re.split("([0-9]+)", s)]
498
+
499
+ return sorted(files, key=split_alphanum)
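A hedged usage sketch for the AV utility above; the model, dataloader, layer name, and storage path are placeholders.

# sketch: compute and cache layer activations for a dataset, then load them back lazily
acts = AV.generate_dataset_activations(
    path="./av_store", model=model, model_id="resnet18-v1",
    layers=["layer4"], dataloader=train_loader,
    identifier="train", return_activations=True)

# later, reload the cached activations without recomputing
dataset = AV.load("./av_store", "resnet18-v1", identifier="train", layer="layer4")
first_batch = dataset[0]   # activation tensor saved for batch 0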
captum/_utils/common.py ADDED
@@ -0,0 +1,679 @@
1
+ #!/usr/bin/env python3
2
+ import typing
3
+ from enum import Enum
4
+ from functools import reduce
5
+ from inspect import signature
6
+ from typing import Any, Callable, cast, Dict, List, overload, Tuple, Union
7
+
8
+ import numpy as np
9
+ import torch
10
+ from captum._utils.typing import (
11
+ BaselineType,
12
+ Literal,
13
+ TargetType,
14
+ TensorOrTupleOfTensorsGeneric,
15
+ TupleOrTensorOrBoolGeneric,
16
+ )
17
+ from torch import device, Tensor
18
+ from torch.nn import Module
19
+
20
+
21
+ class ExpansionTypes(Enum):
22
+ repeat = 1
23
+ repeat_interleave = 2
24
+
25
+
26
+ def safe_div(
27
+ numerator: Tensor,
28
+ denom: Union[Tensor, int, float],
29
+ default_denom: Union[Tensor, int, float] = 1.0,
30
+ ) -> Tensor:
31
+ r"""
32
+ A simple utility function to perform `numerator / denom`
33
+ if the denominator is zero => result will be `numerator / default_denom`
34
+ """
35
+ if isinstance(denom, (int, float)):
36
+ return numerator / (denom if denom != 0 else default_denom)
37
+
38
+ # convert default_denom to tensor if it is float
39
+ if not torch.is_tensor(default_denom):
40
+ default_denom = torch.tensor(
41
+ default_denom, dtype=denom.dtype, device=denom.device
42
+ )
43
+
44
+ return numerator / torch.where(denom != 0, denom, default_denom)
45
+
46
+
47
+ @typing.overload
48
+ def _is_tuple(inputs: Tensor) -> Literal[False]:
49
+ ...
50
+
51
+
52
+ @typing.overload
53
+ def _is_tuple(inputs: Tuple[Tensor, ...]) -> Literal[True]:
54
+ ...
55
+
56
+
57
+ def _is_tuple(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> bool:
58
+ return isinstance(inputs, tuple)
59
+
60
+
61
+ def _validate_target(num_samples: int, target: TargetType) -> None:
62
+ if isinstance(target, list) or (
63
+ isinstance(target, torch.Tensor) and torch.numel(target) > 1
64
+ ):
65
+ assert num_samples == len(target), (
66
+ "The number of samples provied in the"
67
+ "input {} does not match with the number of targets. {}".format(
68
+ num_samples, len(target)
69
+ )
70
+ )
71
+
72
+
73
+ def _validate_input(
74
+ inputs: Tuple[Tensor, ...],
75
+ baselines: Tuple[Union[Tensor, int, float], ...],
76
+ draw_baseline_from_distrib: bool = False,
77
+ ) -> None:
78
+ assert len(inputs) == len(baselines), (
79
+ "Input and baseline must have the same "
80
+ "dimensions, baseline has {} features whereas input has {}.".format(
81
+ len(baselines), len(inputs)
82
+ )
83
+ )
84
+
85
+ for input, baseline in zip(inputs, baselines):
86
+ if draw_baseline_from_distrib:
87
+ assert (
88
+ isinstance(baseline, (int, float))
89
+ or input.shape[1:] == baseline.shape[1:]
90
+ ), (
91
+ "The samples in input and baseline batches must have"
92
+ " the same shape or the baseline corresponding to the"
93
+ " input tensor must be a scalar."
94
+ " Found baseline: {} and input: {} ".format(baseline, input)
95
+ )
96
+ else:
97
+ assert (
98
+ isinstance(baseline, (int, float))
99
+ or input.shape == baseline.shape
100
+ or baseline.shape[0] == 1
101
+ ), (
102
+ "Baseline can be provided as a tensor for just one input and"
103
+ " broadcasted to the batch or input and baseline must have the"
104
+ " same shape or the baseline corresponding to each input tensor"
105
+ " must be a scalar. Found baseline: {} and input: {}".format(
106
+ baseline, input
107
+ )
108
+ )
109
+
110
+
111
+ def _zeros(inputs: Tuple[Tensor, ...]) -> Tuple[int, ...]:
112
+ r"""
113
+ Takes a tuple of tensors as input and returns a tuple that has the same
114
+ length as `inputs` with each element as the integer 0.
115
+ """
116
+ return tuple(0 if input.dtype is not torch.bool else False for input in inputs)
117
+
118
+
119
+ def _format_baseline(
120
+ baselines: BaselineType, inputs: Tuple[Tensor, ...]
121
+ ) -> Tuple[Union[Tensor, int, float], ...]:
122
+ if baselines is None:
123
+ return _zeros(inputs)
124
+
125
+ if not isinstance(baselines, tuple):
126
+ baselines = (baselines,)
127
+
128
+ for baseline in baselines:
129
+ assert isinstance(
130
+ baseline, (torch.Tensor, int, float)
131
+ ), "baseline input argument must be either a torch.Tensor or a number \
132
+ however {} detected".format(
133
+ type(baseline)
134
+ )
135
+
136
+ return baselines
137
+
138
+
139
+ @overload
140
+ def _format_tensor_into_tuples(inputs: None) -> None:
141
+ ...
142
+
143
+
144
+ @overload
145
+ def _format_tensor_into_tuples(
146
+ inputs: Union[Tensor, Tuple[Tensor, ...]]
147
+ ) -> Tuple[Tensor, ...]:
148
+ ...
149
+
150
+
151
+ def _format_tensor_into_tuples(
152
+ inputs: Union[None, Tensor, Tuple[Tensor, ...]]
153
+ ) -> Union[None, Tuple[Tensor, ...]]:
154
+ if inputs is None:
155
+ return None
156
+ if not isinstance(inputs, tuple):
157
+ assert isinstance(
158
+ inputs, torch.Tensor
159
+ ), "`inputs` must have type " "torch.Tensor but {} found: ".format(type(inputs))
160
+ inputs = (inputs,)
161
+ return inputs
162
+
163
+
164
+ def _format_inputs(inputs: Any, unpack_inputs: bool = True) -> Any:
165
+ return (
166
+ inputs
167
+ if (isinstance(inputs, tuple) or isinstance(inputs, list)) and unpack_inputs
168
+ else (inputs,)
169
+ )
170
+
171
+
172
+ def _format_float_or_tensor_into_tuples(
173
+ inputs: Union[float, Tensor, Tuple[Union[float, Tensor], ...]]
174
+ ) -> Tuple[Union[float, Tensor], ...]:
175
+ if not isinstance(inputs, tuple):
176
+ assert isinstance(
177
+ inputs, (torch.Tensor, float)
178
+ ), "`inputs` must have type float or torch.Tensor but {} found: ".format(
179
+ type(inputs)
180
+ )
181
+ inputs = (inputs,)
182
+ return inputs
183
+
184
+
185
+ @overload
186
+ def _format_additional_forward_args(additional_forward_args: None) -> None:
187
+ ...
188
+
189
+
190
+ @overload
191
+ def _format_additional_forward_args(
192
+ additional_forward_args: Union[Tensor, Tuple]
193
+ ) -> Tuple:
194
+ ...
195
+
196
+
197
+ @overload
198
+ def _format_additional_forward_args(additional_forward_args: Any) -> Union[None, Tuple]:
199
+ ...
200
+
201
+
202
+ def _format_additional_forward_args(additional_forward_args: Any) -> Union[None, Tuple]:
203
+ if additional_forward_args is not None and not isinstance(
204
+ additional_forward_args, tuple
205
+ ):
206
+ additional_forward_args = (additional_forward_args,)
207
+ return additional_forward_args
208
+
209
+
210
+ def _expand_additional_forward_args(
211
+ additional_forward_args: Any,
212
+ n_steps: int,
213
+ expansion_type: ExpansionTypes = ExpansionTypes.repeat,
214
+ ) -> Union[None, Tuple]:
215
+ def _expand_tensor_forward_arg(
216
+ additional_forward_arg: Tensor,
217
+ n_steps: int,
218
+ expansion_type: ExpansionTypes = ExpansionTypes.repeat,
219
+ ) -> Tensor:
220
+ if len(additional_forward_arg.size()) == 0:
221
+ return additional_forward_arg
222
+ if expansion_type == ExpansionTypes.repeat:
223
+ return torch.cat([additional_forward_arg] * n_steps, dim=0)
224
+ elif expansion_type == ExpansionTypes.repeat_interleave:
225
+ return additional_forward_arg.repeat_interleave(n_steps, dim=0)
226
+ else:
227
+ raise NotImplementedError(
228
+ "Currently only `repeat` and `repeat_interleave`"
229
+ " expansion_types are supported"
230
+ )
231
+
232
+ if additional_forward_args is None:
233
+ return None
234
+
235
+ return tuple(
236
+ _expand_tensor_forward_arg(additional_forward_arg, n_steps, expansion_type)
237
+ if isinstance(additional_forward_arg, torch.Tensor)
238
+ else additional_forward_arg
239
+ for additional_forward_arg in additional_forward_args
240
+ )
241
+
242
+
243
+ def _expand_target(
244
+ target: TargetType,
245
+ n_steps: int,
246
+ expansion_type: ExpansionTypes = ExpansionTypes.repeat,
247
+ ) -> TargetType:
248
+ if isinstance(target, list):
249
+ if expansion_type == ExpansionTypes.repeat:
250
+ return target * n_steps
251
+ elif expansion_type == ExpansionTypes.repeat_interleave:
252
+ expanded_target = []
253
+ for i in target:
254
+ expanded_target.extend([i] * n_steps)
255
+ return cast(Union[List[Tuple[int, ...]], List[int]], expanded_target)
256
+ else:
257
+ raise NotImplementedError(
258
+ "Currently only `repeat` and `repeat_interleave`"
259
+ " expansion_types are supported"
260
+ )
261
+
262
+ elif isinstance(target, torch.Tensor) and torch.numel(target) > 1:
263
+ if expansion_type == ExpansionTypes.repeat:
264
+ return torch.cat([target] * n_steps, dim=0)
265
+ elif expansion_type == ExpansionTypes.repeat_interleave:
266
+ return target.repeat_interleave(n_steps, dim=0)
267
+ else:
268
+ raise NotImplementedError(
269
+ "Currently only `repeat` and `repeat_interleave`"
270
+ " expansion_types are supported"
271
+ )
272
+
273
+ return target
274
+
275
+
276
+ def _expand_feature_mask(
277
+ feature_mask: Union[Tensor, Tuple[Tensor, ...]], n_samples: int
278
+ ):
279
+ is_feature_mask_tuple = _is_tuple(feature_mask)
280
+ feature_mask = _format_tensor_into_tuples(feature_mask)
281
+ feature_mask_new = tuple(
282
+ feature_mask_elem.repeat_interleave(n_samples, dim=0)
283
+ if feature_mask_elem.size(0) > 1
284
+ else feature_mask_elem
285
+ for feature_mask_elem in feature_mask
286
+ )
287
+ return _format_output(is_feature_mask_tuple, feature_mask_new)
288
+
289
+
290
+ def _expand_and_update_baselines(
291
+ inputs: Tuple[Tensor, ...],
292
+ n_samples: int,
293
+ kwargs: dict,
294
+ draw_baseline_from_distrib: bool = False,
295
+ ):
296
+ def get_random_baseline_indices(bsz, baseline):
297
+ num_ref_samples = baseline.shape[0]
298
+ return np.random.choice(num_ref_samples, n_samples * bsz).tolist()
299
+
300
+ # expand baselines to match the sizes of input
301
+ if "baselines" not in kwargs:
302
+ return
303
+
304
+ baselines = kwargs["baselines"]
305
+ baselines = _format_baseline(baselines, inputs)
306
+ _validate_input(
307
+ inputs, baselines, draw_baseline_from_distrib=draw_baseline_from_distrib
308
+ )
309
+
310
+ if draw_baseline_from_distrib:
311
+ bsz = inputs[0].shape[0]
312
+ baselines = tuple(
313
+ baseline[get_random_baseline_indices(bsz, baseline)]
314
+ if isinstance(baseline, torch.Tensor)
315
+ else baseline
316
+ for baseline in baselines
317
+ )
318
+ else:
319
+ baselines = tuple(
320
+ baseline.repeat_interleave(n_samples, dim=0)
321
+ if isinstance(baseline, torch.Tensor)
322
+ and baseline.shape[0] == input.shape[0]
323
+ and baseline.shape[0] > 1
324
+ else baseline
325
+ for input, baseline in zip(inputs, baselines)
326
+ )
327
+ # update kwargs with expanded baseline
328
+ kwargs["baselines"] = baselines
329
+
330
+
331
+ def _expand_and_update_additional_forward_args(n_samples: int, kwargs: dict):
332
+ if "additional_forward_args" not in kwargs:
333
+ return
334
+ additional_forward_args = kwargs["additional_forward_args"]
335
+ additional_forward_args = _format_additional_forward_args(additional_forward_args)
336
+ if additional_forward_args is None:
337
+ return
338
+ additional_forward_args = _expand_additional_forward_args(
339
+ additional_forward_args,
340
+ n_samples,
341
+ expansion_type=ExpansionTypes.repeat_interleave,
342
+ )
343
+ # update kwargs with expanded baseline
344
+ kwargs["additional_forward_args"] = additional_forward_args
345
+
346
+
347
+ def _expand_and_update_target(n_samples: int, kwargs: dict):
348
+ if "target" not in kwargs:
349
+ return
350
+ target = kwargs["target"]
351
+ target = _expand_target(
352
+ target, n_samples, expansion_type=ExpansionTypes.repeat_interleave
353
+ )
354
+ # update kwargs with expanded baseline
355
+ kwargs["target"] = target
356
+
357
+
358
+ def _expand_and_update_feature_mask(n_samples: int, kwargs: dict):
359
+ if "feature_mask" not in kwargs:
360
+ return
361
+
362
+ feature_mask = kwargs["feature_mask"]
363
+ if feature_mask is None:
364
+ return
365
+
366
+ feature_mask = _expand_feature_mask(feature_mask, n_samples)
367
+ kwargs["feature_mask"] = feature_mask
368
+
369
+
370
+ @typing.overload
371
+ def _format_output(
372
+ is_inputs_tuple: Literal[True], output: Tuple[Tensor, ...]
373
+ ) -> Tuple[Tensor, ...]:
374
+ ...
375
+
376
+
377
+ @typing.overload
378
+ def _format_output(
379
+ is_inputs_tuple: Literal[False], output: Tuple[Tensor, ...]
380
+ ) -> Tensor:
381
+ ...
382
+
383
+
384
+ @typing.overload
385
+ def _format_output(
386
+ is_inputs_tuple: bool, output: Tuple[Tensor, ...]
387
+ ) -> Union[Tensor, Tuple[Tensor, ...]]:
388
+ ...
389
+
390
+
391
+ def _format_output(
392
+ is_inputs_tuple: bool, output: Tuple[Tensor, ...]
393
+ ) -> Union[Tensor, Tuple[Tensor, ...]]:
394
+ r"""
395
+ In case input is a tensor and the output is returned in form of a
396
+ tuple we take the first element of the output's tuple to match the
397
+ same shape signature as the inputs.
398
+ """
399
+ assert isinstance(output, tuple), "Output must be in shape of a tuple"
400
+ assert is_inputs_tuple or len(output) == 1, (
401
+ "The input is a single tensor however the output isn't."
402
+ "The number of output tensors is: {}".format(len(output))
403
+ )
404
+ return output if is_inputs_tuple else output[0]
405
+
406
+
407
+ @typing.overload
408
+ def _format_outputs(
409
+ is_multiple_inputs: Literal[False], outputs: List[Tuple[Tensor, ...]]
410
+ ) -> Union[Tensor, Tuple[Tensor, ...]]:
411
+ ...
412
+
413
+
414
+ @typing.overload
415
+ def _format_outputs(
416
+ is_multiple_inputs: Literal[True], outputs: List[Tuple[Tensor, ...]]
417
+ ) -> List[Union[Tensor, Tuple[Tensor, ...]]]:
418
+ ...
419
+
420
+
421
+ @typing.overload
422
+ def _format_outputs(
423
+ is_multiple_inputs: bool, outputs: List[Tuple[Tensor, ...]]
424
+ ) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]:
425
+ ...
426
+
427
+
428
+ def _format_outputs(
429
+ is_multiple_inputs: bool, outputs: List[Tuple[Tensor, ...]]
430
+ ) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]:
431
+ assert isinstance(outputs, list), "Outputs must be a list"
432
+ assert is_multiple_inputs or len(outputs) == 1, (
433
+ "outputs should contain multiple inputs or have a single output"
434
+ f"however the number of outputs is: {len(outputs)}"
435
+ )
436
+
437
+ return (
438
+ [_format_output(len(output) > 1, output) for output in outputs]
439
+ if is_multiple_inputs
440
+ else _format_output(len(outputs[0]) > 1, outputs[0])
441
+ )
442
+
443
+
444
+ def _run_forward(
445
+ forward_func: Callable,
446
+ inputs: Any,
447
+ target: TargetType = None,
448
+ additional_forward_args: Any = None,
449
+ ) -> Tensor:
450
+ forward_func_args = signature(forward_func).parameters
451
+ if len(forward_func_args) == 0:
452
+ output = forward_func()
453
+ return output if target is None else _select_targets(output, target)
454
+
455
+ # make everything a tuple so that it is easy to unpack without
456
+ # using if-statements
457
+ inputs = _format_inputs(inputs)
458
+ additional_forward_args = _format_additional_forward_args(additional_forward_args)
459
+
460
+ output = forward_func(
461
+ *(*inputs, *additional_forward_args)
462
+ if additional_forward_args is not None
463
+ else inputs
464
+ )
465
+ return _select_targets(output, target)
466
+
467
+
468
+ def _select_targets(output: Tensor, target: TargetType) -> Tensor:
469
+ if target is None:
470
+ return output
471
+
472
+ num_examples = output.shape[0]
473
+ dims = len(output.shape)
474
+ device = output.device
475
+ if isinstance(target, (int, tuple)):
476
+ return _verify_select_column(output, target)
477
+ elif isinstance(target, torch.Tensor):
478
+ if torch.numel(target) == 1 and isinstance(target.item(), int):
479
+ return _verify_select_column(output, cast(int, target.item()))
480
+ elif len(target.shape) == 1 and torch.numel(target) == num_examples:
481
+ assert dims == 2, "Output must be 2D to select tensor of targets."
482
+ return torch.gather(output, 1, target.reshape(len(output), 1))
483
+ else:
484
+ raise AssertionError(
485
+ "Tensor target dimension %r is not valid. %r"
486
+ % (target.shape, output.shape)
487
+ )
488
+ elif isinstance(target, list):
489
+ assert len(target) == num_examples, "Target list length does not match output!"
490
+ if isinstance(target[0], int):
491
+ assert dims == 2, "Output must be 2D to select tensor of targets."
492
+ return torch.gather(
493
+ output, 1, torch.tensor(target, device=device).reshape(len(output), 1)
494
+ )
495
+ elif isinstance(target[0], tuple):
496
+ return torch.stack(
497
+ [
498
+ output[(i,) + cast(Tuple, targ_elem)]
499
+ for i, targ_elem in enumerate(target)
500
+ ]
501
+ )
502
+ else:
503
+ raise AssertionError("Target element type in list is not valid.")
504
+ else:
505
+ raise AssertionError("Target type %r is not valid." % target)
506
+
507
+
508
+ def _contains_slice(target: Union[int, Tuple[Union[int, slice], ...]]) -> bool:
509
+ if isinstance(target, tuple):
510
+ for index in target:
511
+ if isinstance(index, slice):
512
+ return True
513
+ return False
514
+ return isinstance(target, slice)
515
+
516
+
517
+ def _verify_select_column(
518
+ output: Tensor, target: Union[int, Tuple[Union[int, slice], ...]]
519
+ ) -> Tensor:
520
+ target = (target,) if isinstance(target, int) else target
521
+ assert (
522
+ len(target) <= len(output.shape) - 1
523
+ ), "Cannot choose target column with output shape %r." % (output.shape,)
524
+ return output[(slice(None), *target)]
525
+
526
+
527
+ def _verify_select_neuron(
528
+ layer_output: Tuple[Tensor, ...],
529
+ selector: Union[int, Tuple[Union[int, slice], ...], Callable],
530
+ ) -> Tensor:
531
+ if callable(selector):
532
+ return selector(layer_output if len(layer_output) > 1 else layer_output[0])
533
+
534
+ assert len(layer_output) == 1, (
535
+ "Cannot select neuron index from layer with multiple tensors,"
536
+ "consider providing a neuron selector function instead."
537
+ )
538
+
539
+ selected_neurons = _verify_select_column(layer_output[0], selector)
540
+ if _contains_slice(selector):
541
+ return selected_neurons.reshape(selected_neurons.shape[0], -1).sum(1)
542
+ return selected_neurons
543
+
544
+
545
+ def _extract_device(
546
+ module: Module,
547
+ hook_inputs: Union[None, Tensor, Tuple[Tensor, ...]],
548
+ hook_outputs: Union[None, Tensor, Tuple[Tensor, ...]],
549
+ ) -> device:
550
+ params = list(module.parameters())
551
+ if (
552
+ (hook_inputs is None or len(hook_inputs) == 0)
553
+ and (hook_outputs is None or len(hook_outputs) == 0)
554
+ and len(params) == 0
555
+ ):
556
+ raise RuntimeError(
557
+ """Unable to extract device information for the module
558
+ {}. Both inputs and outputs to the forward hook and
559
+ `module.parameters()` are empty.
560
+ The reason that the inputs to the forward hook are empty
561
+ could be due to the fact that the arguments to that
562
+ module {} are all named and are passed as named
563
+ variables to its forward function.
564
+ """.format(
565
+ module, module
566
+ )
567
+ )
568
+ if hook_inputs is not None and len(hook_inputs) > 0:
569
+ return hook_inputs[0].device
570
+ if hook_outputs is not None and len(hook_outputs) > 0:
571
+ return hook_outputs[0].device
572
+
573
+ return params[0].device
574
+
575
+
576
+ def _reduce_list(
577
+ val_list: List[TupleOrTensorOrBoolGeneric],
578
+ red_func: Callable[[List], Any] = torch.cat,
579
+ ) -> TupleOrTensorOrBoolGeneric:
580
+ """
581
+ Applies reduction function to given list. If each element in the list is
582
+ a Tensor, applies reduction function to all elements of the list, and returns
583
+ the output Tensor / value. If each element is a boolean, the `any` reduction (logical or) is applied.
584
+ If each element is a tuple, applies reduction
585
+ function to corresponding elements of each tuple in the list, and returns
586
+ tuple of reduction function outputs with length matching the length of tuple
587
+ val_list[0]. It is assumed that all tuples in the list have the same length
588
+ and red_func can be applied to all elements in each corresponding position.
589
+ """
590
+ assert len(val_list) > 0, "Cannot reduce empty list!"
591
+ if isinstance(val_list[0], torch.Tensor):
592
+ first_device = val_list[0].device
593
+ return red_func([elem.to(first_device) for elem in val_list])
594
+ elif isinstance(val_list[0], bool):
595
+ return any(val_list)
596
+ elif isinstance(val_list[0], tuple):
597
+ final_out = []
598
+ for i in range(len(val_list[0])):
599
+ final_out.append(
600
+ _reduce_list([val_elem[i] for val_elem in val_list], red_func)
601
+ )
602
+ else:
603
+ raise AssertionError(
604
+ "Elements to be reduced can only be"
605
+ "either Tensors or tuples containing Tensors."
606
+ )
607
+ return tuple(final_out)
608
+
609
+
610
+ def _sort_key_list(
611
+ keys: List[device], device_ids: Union[None, List[int]] = None
612
+ ) -> List[device]:
613
+ """
614
+ Sorts list of torch devices (keys) by given index list, device_ids. If keys
615
+ contains only one device, then the list is returned unchanged. If keys
616
+ contains a device for which the id is not contained in device_ids, then
617
+ an error is returned. This method is used to identify the order of DataParallel
618
+ batched devices, given the device ID ordering.
619
+ """
620
+ if len(keys) == 1:
621
+ return keys
622
+ id_dict: Dict[int, device] = {}
623
+ assert device_ids is not None, "Device IDs must be provided with multiple devices."
624
+ for key in keys:
625
+ if key.index in id_dict:
626
+ raise AssertionError("Duplicate CUDA Device ID identified in device list.")
627
+ id_dict[key.index] = key
628
+
629
+ out_list = [
630
+ id_dict[device_id]
631
+ for device_id in filter(lambda device_id: device_id in id_dict, device_ids)
632
+ ]
633
+
634
+ assert len(out_list) == len(keys), ("Given Device ID List does not match"
635
+ " devices with computed tensors.")
636
+
637
+ return out_list
638
+
639
+
640
+ def _flatten_tensor_or_tuple(inp: TensorOrTupleOfTensorsGeneric) -> Tensor:
641
+ if isinstance(inp, Tensor):
642
+ return inp.flatten()
643
+ return torch.cat([single_inp.flatten() for single_inp in inp])
644
+
645
+
646
+ def _get_module_from_name(model: Module, layer_name: str) -> Any:
647
+ r"""
648
+ Returns the module (layer) object, given its (string) name
649
+ in the model.
650
+
651
+ Args:
652
+ name (str): Module or nested modules name string in self.model
653
+
654
+ Returns:
655
+ The module (layer) in self.model.
656
+ """
657
+
658
+ return reduce(getattr, layer_name.split("."), model)
659
+
660
+
661
+ def _register_backward_hook(
662
+ module: Module, hook: Callable, attr_obj: Any
663
+ ) -> torch.utils.hooks.RemovableHandle:
664
+ # Special case for supporting output attributions for neuron methods
665
+ # This can be removed after deprecation of neuron output attributions
666
+ # for NeuronDeepLift, NeuronDeconvolution, and NeuronGuidedBackprop
667
+ # in v0.6.0
668
+ if (
669
+ hasattr(attr_obj, "skip_new_hook_layer")
670
+ and attr_obj.skip_new_hook_layer == module
671
+ ):
672
+ return module.register_backward_hook(hook)
673
+
674
+ if torch.__version__ >= "1.9":
675
+ # Only supported for torch >= 1.9
676
+ return module.register_full_backward_hook(hook)
677
+ else:
678
+ # Fallback for previous versions of PyTorch
679
+ return module.register_backward_hook(hook)
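These helpers close out the additions to captum/_utils/common.py. As a minimal sketch of how `_get_module_from_name` resolves a dotted layer name by chaining `getattr` calls, consider the snippet below; the `ToyModel` class and the name "backbone.0" are hypothetical and exist only for illustration:

from functools import reduce

import torch.nn as nn


class ToyModel(nn.Module):
    # Hypothetical model defined only to demonstrate dotted-name lookup.
    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(4, 8), nn.ReLU())
        self.head = nn.Linear(8, 2)


model = ToyModel()
# Same traversal as _get_module_from_name: split the name on "." and getattr step by step.
layer = reduce(getattr, "backbone.0".split("."), model)
print(layer)  # Linear(in_features=4, out_features=8, bias=True)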
captum/_utils/gradient.py ADDED
@@ -0,0 +1,865 @@
1
+ #!/usr/bin/env python3
2
+ import threading
3
+ import typing
4
+ import warnings
5
+ from collections import defaultdict
6
+ from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
7
+
8
+ import torch
9
+ from captum._utils.common import (
10
+ _reduce_list,
11
+ _run_forward,
12
+ _sort_key_list,
13
+ _verify_select_neuron,
14
+ )
15
+ from captum._utils.sample_gradient import SampleGradientWrapper
16
+ from captum._utils.typing import (
17
+ Literal,
18
+ ModuleOrModuleList,
19
+ TargetType,
20
+ TensorOrTupleOfTensorsGeneric,
21
+ )
22
+ from torch import device, Tensor
23
+ from torch.nn import Module
24
+
25
+
26
+ def apply_gradient_requirements(
27
+ inputs: Tuple[Tensor, ...], warn: bool = True
28
+ ) -> List[bool]:
29
+ """
30
+ Iterates through tuple on input tensors and sets requires_grad to be true on
31
+ each Tensor, and ensures all grads are set to zero. To ensure that the input
32
+ is returned to its initial state, a list of flags representing whether or not
33
+ a tensor originally required grad is returned.
34
+ """
35
+ assert isinstance(
36
+ inputs, tuple
37
+ ), "Inputs should be wrapped in a tuple prior to preparing for gradients"
38
+ grad_required = []
39
+ for index, input in enumerate(inputs):
40
+ assert isinstance(input, torch.Tensor), "Given input is not a torch.Tensor"
41
+ grad_required.append(input.requires_grad)
42
+ inputs_dtype = input.dtype
43
+ # Note: torch 1.2 doesn't support is_complex for dtype, which is why we check
44
+ # for the existence of the is_complex method.
45
+ if not inputs_dtype.is_floating_point and not (
46
+ hasattr(inputs_dtype, "is_complex") and inputs_dtype.is_complex
47
+ ):
48
+ if warn:
49
+ warnings.warn(
50
+ """Input Tensor %d has a dtype of %s.
51
+ Gradients cannot be activated
52
+ for these data types."""
53
+ % (index, str(inputs_dtype))
54
+ )
55
+ elif not input.requires_grad:
56
+ if warn:
57
+ warnings.warn(
58
+ "Input Tensor %d did not already require gradients, "
59
+ "required_grads has been set automatically." % index
60
+ )
61
+ input.requires_grad_()
62
+ return grad_required
63
+
64
+
65
+ def undo_gradient_requirements(
66
+ inputs: Tuple[Tensor, ...], grad_required: List[bool]
67
+ ) -> None:
68
+ """
69
+ Iterates through the list of tensors and sets requires_grad to False
70
+ if the corresponding index in grad_required is False.
71
+ This method is used to undo the effects of apply_gradient_requirements, making
72
+ grads not required for any input tensor that did not initially require
73
+ gradients.
74
+ """
75
+
76
+ assert isinstance(
77
+ inputs, tuple
78
+ ), "Inputs should be wrapped in a tuple prior to preparing for gradients."
79
+ assert len(inputs) == len(
80
+ grad_required
81
+ ), "Input tuple length should match gradient mask."
82
+ for index, input in enumerate(inputs):
83
+ assert isinstance(input, torch.Tensor), "Given input is not a torch.Tensor"
84
+ if not grad_required[index]:
85
+ input.requires_grad_(False)
86
+
87
+
88
+ def compute_gradients(
89
+ forward_fn: Callable,
90
+ inputs: Union[Tensor, Tuple[Tensor, ...]],
91
+ target_ind: TargetType = None,
92
+ additional_forward_args: Any = None,
93
+ ) -> Tuple[Tensor, ...]:
94
+ r"""
95
+ Computes gradients of the output with respect to inputs for an
96
+ arbitrary forward function.
97
+
98
+ Args:
99
+
100
+ forward_fn: forward function. This can be for example model's
101
+ forward function.
102
+ inputs: Input at which gradients are evaluated,
103
+ will be passed to forward_fn.
104
+ target_ind: Index of the target class for which gradients
105
+ must be computed (classification only).
106
+ additional_forward_args: Additional input arguments that forward
107
+ function requires. It takes an empty tuple (no additional
108
+ arguments) if no additional arguments are required
109
+ """
110
+ with torch.autograd.set_grad_enabled(True):
111
+ # runs forward pass
112
+ outputs = _run_forward(forward_fn, inputs, target_ind, additional_forward_args)
113
+ assert outputs[0].numel() == 1, (
114
+ "Target not provided when necessary, cannot"
115
+ " take gradient with respect to multiple outputs."
116
+ )
117
+ # torch.unbind(forward_out) is a list of scalar tensor tuples and
118
+ # contains batch_size * #steps elements
119
+ grads = torch.autograd.grad(torch.unbind(outputs), inputs)
120
+ return grads
121
+
122
+
123
+ def _neuron_gradients(
124
+ inputs: Union[Tensor, Tuple[Tensor, ...]],
125
+ saved_layer: Dict[device, Tuple[Tensor, ...]],
126
+ key_list: List[device],
127
+ gradient_neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
128
+ ) -> Tuple[Tensor, ...]:
129
+ with torch.autograd.set_grad_enabled(True):
130
+ gradient_tensors = []
131
+ for key in key_list:
132
+ current_out_tensor = _verify_select_neuron(
133
+ saved_layer[key], gradient_neuron_selector
134
+ )
135
+ gradient_tensors.append(
136
+ torch.autograd.grad(
137
+ torch.unbind(current_out_tensor)
138
+ if current_out_tensor.numel() > 1
139
+ else current_out_tensor,
140
+ inputs,
141
+ )
142
+ )
143
+ _total_gradients = _reduce_list(gradient_tensors, sum)
144
+ return _total_gradients
145
+
146
+
147
+ @typing.overload
148
+ def _forward_layer_eval(
149
+ forward_fn: Callable,
150
+ inputs: Union[Tensor, Tuple[Tensor, ...]],
151
+ layer: Module,
152
+ additional_forward_args: Any = None,
153
+ device_ids: Union[None, List[int]] = None,
154
+ attribute_to_layer_input: bool = False,
155
+ grad_enabled: bool = False,
156
+ ) -> Tuple[Tensor, ...]:
157
+ ...
158
+
159
+
160
+ @typing.overload
161
+ def _forward_layer_eval(
162
+ forward_fn: Callable,
163
+ inputs: Union[Tensor, Tuple[Tensor, ...]],
164
+ layer: List[Module],
165
+ additional_forward_args: Any = None,
166
+ device_ids: Union[None, List[int]] = None,
167
+ attribute_to_layer_input: bool = False,
168
+ grad_enabled: bool = False,
169
+ ) -> List[Tuple[Tensor, ...]]:
170
+ ...
171
+
172
+
173
+ def _forward_layer_eval(
174
+ forward_fn: Callable,
175
+ inputs: Union[Tensor, Tuple[Tensor, ...]],
176
+ layer: ModuleOrModuleList,
177
+ additional_forward_args: Any = None,
178
+ device_ids: Union[None, List[int]] = None,
179
+ attribute_to_layer_input: bool = False,
180
+ grad_enabled: bool = False,
181
+ ) -> Union[Tuple[Tensor, ...], List[Tuple[Tensor, ...]]]:
182
+ return _forward_layer_eval_with_neuron_grads(
183
+ forward_fn,
184
+ inputs,
185
+ layer,
186
+ additional_forward_args=additional_forward_args,
187
+ gradient_neuron_selector=None,
188
+ grad_enabled=grad_enabled,
189
+ device_ids=device_ids,
190
+ attribute_to_layer_input=attribute_to_layer_input,
191
+ )
192
+
193
+
194
+ @typing.overload
195
+ def _forward_layer_distributed_eval(
196
+ forward_fn: Callable,
197
+ inputs: Any,
198
+ layer: ModuleOrModuleList,
199
+ target_ind: TargetType = None,
200
+ additional_forward_args: Any = None,
201
+ attribute_to_layer_input: bool = False,
202
+ forward_hook_with_return: Literal[False] = False,
203
+ require_layer_grads: bool = False,
204
+ ) -> Dict[Module, Dict[device, Tuple[Tensor, ...]]]:
205
+ ...
206
+
207
+
208
+ @typing.overload
209
+ def _forward_layer_distributed_eval(
210
+ forward_fn: Callable,
211
+ inputs: Any,
212
+ layer: ModuleOrModuleList,
213
+ target_ind: TargetType = None,
214
+ additional_forward_args: Any = None,
215
+ attribute_to_layer_input: bool = False,
216
+ *,
217
+ forward_hook_with_return: Literal[True],
218
+ require_layer_grads: bool = False,
219
+ ) -> Tuple[Dict[Module, Dict[device, Tuple[Tensor, ...]]], Tensor]:
220
+ ...
221
+
222
+
223
+ def _forward_layer_distributed_eval(
224
+ forward_fn: Callable,
225
+ inputs: Any,
226
+ layer: ModuleOrModuleList,
227
+ target_ind: TargetType = None,
228
+ additional_forward_args: Any = None,
229
+ attribute_to_layer_input: bool = False,
230
+ forward_hook_with_return: bool = False,
231
+ require_layer_grads: bool = False,
232
+ ) -> Union[
233
+ Tuple[Dict[Module, Dict[device, Tuple[Tensor, ...]]], Tensor],
234
+ Dict[Module, Dict[device, Tuple[Tensor, ...]]],
235
+ ]:
236
+ r"""
237
+ A helper function that allows to set a hook on model's `layer`, run the forward
238
+ pass and returns intermediate layer results, stored in a dictionary,
239
+ and optionally also the output of the forward function. The keys in the
240
+ dictionary are the device ids and the values are corresponding intermediate layer
241
+ results, either the inputs or the outputs of the layer depending on whether we set
242
+ `attribute_to_layer_input` to True or False.
243
+ This is especially useful when we execute forward pass in a distributed setting,
244
+ using `DataParallel`s for example.
245
+ """
246
+ saved_layer: Dict[Module, Dict[device, Tuple[Tensor, ...]]] = defaultdict(dict)
247
+ lock = threading.Lock()
248
+ all_layers: List[Module] = [layer] if isinstance(layer, Module) else layer
249
+
250
+ # Set a forward hook on specified module and run forward pass to
251
+ # get layer output tensor(s).
252
+ # For DataParallel models, each partition adds entry to dictionary
253
+ # with key as device and value as corresponding Tensor.
254
+ def hook_wrapper(original_module):
255
+ def forward_hook(module, inp, out=None):
256
+ eval_tsrs = inp if attribute_to_layer_input else out
257
+ is_eval_tuple = isinstance(eval_tsrs, tuple)
258
+
259
+ if not is_eval_tuple:
260
+ eval_tsrs = (eval_tsrs,)
261
+ if require_layer_grads:
262
+ apply_gradient_requirements(eval_tsrs, warn=False)
263
+ with lock:
264
+ nonlocal saved_layer
265
+ # Note that cloning behaviour of `eval_tsr` is different
266
+ # when `forward_hook_with_return` is set to True. This is because
267
+ # otherwise `backward()` on the last output layer won't execute.
268
+ if forward_hook_with_return:
269
+ saved_layer[original_module][eval_tsrs[0].device] = eval_tsrs
270
+ eval_tsrs_to_return = tuple(
271
+ eval_tsr.clone() for eval_tsr in eval_tsrs
272
+ )
273
+ if not is_eval_tuple:
274
+ eval_tsrs_to_return = eval_tsrs_to_return[0]
275
+ return eval_tsrs_to_return
276
+ else:
277
+ saved_layer[original_module][eval_tsrs[0].device] = tuple(
278
+ eval_tsr.clone() for eval_tsr in eval_tsrs
279
+ )
280
+
281
+ return forward_hook
282
+
283
+ all_hooks = []
284
+ try:
285
+ for single_layer in all_layers:
286
+ if attribute_to_layer_input:
287
+ all_hooks.append(
288
+ single_layer.register_forward_pre_hook(hook_wrapper(single_layer))
289
+ )
290
+ else:
291
+ all_hooks.append(
292
+ single_layer.register_forward_hook(hook_wrapper(single_layer))
293
+ )
294
+ output = _run_forward(
295
+ forward_fn,
296
+ inputs,
297
+ target=target_ind,
298
+ additional_forward_args=additional_forward_args,
299
+ )
300
+ finally:
301
+ for hook in all_hooks:
302
+ hook.remove()
303
+
304
+ if len(saved_layer) == 0:
305
+ raise AssertionError("Forward hook did not obtain any outputs for given layer")
306
+
307
+ if forward_hook_with_return:
308
+ return saved_layer, output
309
+ return saved_layer
310
+
311
+
312
+ def _gather_distributed_tensors(
313
+ saved_layer: Dict[device, Tuple[Tensor, ...]],
314
+ device_ids: Union[None, List[int]] = None,
315
+ key_list: Union[None, List[device]] = None,
316
+ ) -> Tuple[Tensor, ...]:
317
+ r"""
318
+ A helper function to concatenate intermediate layer results stored on
319
+ different devices in `saved_layer`. `saved_layer` is a dictionary that
320
+ contains `device_id` as a key and intermediate layer results (either
321
+ the input or the output of the layer) stored on the device corresponding to
322
+ the key.
323
+ `key_list` is a list of devices in appropriate ordering for concatenation
324
+ and if not provided, keys are sorted based on device ids.
325
+
326
+ If only one key exists (standard model), key list simply has one element.
327
+ """
328
+ if key_list is None:
329
+ key_list = _sort_key_list(list(saved_layer.keys()), device_ids)
330
+ return _reduce_list([saved_layer[device_id] for device_id in key_list])
331
+
332
+
333
+ def _extract_device_ids(
334
+ forward_fn: Callable,
335
+ saved_layer: Dict[Module, Dict[device, Tuple[Tensor, ...]]],
336
+ device_ids: Union[None, List[int]],
337
+ ) -> Union[None, List[int]]:
338
+ r"""
339
+ A helper function to extract device_ids from `forward_function` in case it is
340
+ provided as part of a `DataParallel` model or if is accessible from
341
+ `forward_fn`.
342
+ In case input device_ids is not None, this function returns that value.
343
+ """
344
+ # Multiple devices / keys implies a DataParallel model, so we look for
345
+ # device IDs if given or available from forward function
346
+ # (DataParallel model object).
347
+ if (
348
+ max(len(saved_layer[single_layer]) for single_layer in saved_layer) > 1
349
+ and device_ids is None
350
+ ):
351
+ if (
352
+ hasattr(forward_fn, "device_ids")
353
+ and cast(Any, forward_fn).device_ids is not None
354
+ ):
355
+ device_ids = cast(Any, forward_fn).device_ids
356
+ else:
357
+ raise AssertionError(
358
+ "Layer tensors are saved on multiple devices, however unable to access"
359
+ " device ID list from the `forward_fn`. Device ID list must be"
360
+ " accessible from `forward_fn`. For example, they can be retrieved"
361
+ " if `forward_fn` is a model of type `DataParallel`. It is used"
362
+ " for identifying device batch ordering."
363
+ )
364
+ return device_ids
365
+
366
+
367
+ @typing.overload
368
+ def _forward_layer_eval_with_neuron_grads(
369
+ forward_fn: Callable,
370
+ inputs: Union[Tensor, Tuple[Tensor, ...]],
371
+ layer: Module,
372
+ additional_forward_args: Any = None,
373
+ *,
374
+ gradient_neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
375
+ grad_enabled: bool = False,
376
+ device_ids: Union[None, List[int]] = None,
377
+ attribute_to_layer_input: bool = False,
378
+ ) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]:
379
+ ...
380
+
381
+
382
+ @typing.overload
383
+ def _forward_layer_eval_with_neuron_grads(
384
+ forward_fn: Callable,
385
+ inputs: Union[Tensor, Tuple[Tensor, ...]],
386
+ layer: Module,
387
+ additional_forward_args: Any = None,
388
+ gradient_neuron_selector: None = None,
389
+ grad_enabled: bool = False,
390
+ device_ids: Union[None, List[int]] = None,
391
+ attribute_to_layer_input: bool = False,
392
+ ) -> Tuple[Tensor, ...]:
393
+ ...
394
+
395
+
396
+ @typing.overload
397
+ def _forward_layer_eval_with_neuron_grads(
398
+ forward_fn: Callable,
399
+ inputs: Union[Tensor, Tuple[Tensor, ...]],
400
+ layer: List[Module],
401
+ additional_forward_args: Any = None,
402
+ gradient_neuron_selector: None = None,
403
+ grad_enabled: bool = False,
404
+ device_ids: Union[None, List[int]] = None,
405
+ attribute_to_layer_input: bool = False,
406
+ ) -> List[Tuple[Tensor, ...]]:
407
+ ...
408
+
409
+
410
+ def _forward_layer_eval_with_neuron_grads(
411
+ forward_fn: Callable,
412
+ inputs: Union[Tensor, Tuple[Tensor, ...]],
413
+ layer: ModuleOrModuleList,
414
+ additional_forward_args: Any = None,
415
+ gradient_neuron_selector: Union[
416
+ None, int, Tuple[Union[int, slice], ...], Callable
417
+ ] = None,
418
+ grad_enabled: bool = False,
419
+ device_ids: Union[None, List[int]] = None,
420
+ attribute_to_layer_input: bool = False,
421
+ ) -> Union[
422
+ Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]],
423
+ Tuple[Tensor, ...],
424
+ List[Tuple[Tensor, ...]],
425
+ ]:
426
+ """
427
+ This method computes forward evaluation for a particular layer using a
428
+ forward hook. If a gradient_neuron_selector is provided, then gradients with
429
+ respect to that neuron in the layer output are also returned.
430
+
431
+ These functionalities are combined due to the behavior of DataParallel models
432
+ with hooks, in which hooks are executed once per device. We need to internally
433
+ combine the separated tensors from devices by concatenating based on device_ids.
434
+ Any necessary gradients must be taken with respect to each independent batched
435
+ tensor, so the gradients are computed and combined appropriately.
436
+
437
+ More information regarding the behavior of forward hooks with DataParallel models
438
+ can be found in the PyTorch data parallel documentation. We maintain the separate
439
+ evals in a dictionary protected by a lock, analogous to the gather implementation
440
+ for the core PyTorch DataParallel implementation.
441
+ """
442
+ grad_enabled = True if gradient_neuron_selector is not None else grad_enabled
443
+
444
+ with torch.autograd.set_grad_enabled(grad_enabled):
445
+ saved_layer = _forward_layer_distributed_eval(
446
+ forward_fn,
447
+ inputs,
448
+ layer,
449
+ additional_forward_args=additional_forward_args,
450
+ attribute_to_layer_input=attribute_to_layer_input,
451
+ )
452
+ device_ids = _extract_device_ids(forward_fn, saved_layer, device_ids)
453
+ # Identifies correct device ordering based on device ids.
454
+ # key_list is a list of devices in appropriate ordering for concatenation.
455
+ # If only one key exists (standard model), key list simply has one element.
456
+ key_list = _sort_key_list(list(next(iter(saved_layer.values())).keys()), device_ids)
457
+ if gradient_neuron_selector is not None:
458
+ assert isinstance(
459
+ layer, Module
460
+ ), "Cannot compute neuron gradients for multiple layers simultaneously!"
461
+ inp_grads = _neuron_gradients(
462
+ inputs, saved_layer[layer], key_list, gradient_neuron_selector
463
+ )
464
+ return (
465
+ _gather_distributed_tensors(saved_layer[layer], key_list=key_list),
466
+ inp_grads,
467
+ )
468
+ else:
469
+ if isinstance(layer, Module):
470
+ return _gather_distributed_tensors(saved_layer[layer], key_list=key_list)
471
+ else:
472
+ return [
473
+ _gather_distributed_tensors(saved_layer[curr_layer], key_list=key_list)
474
+ for curr_layer in layer
475
+ ]
476
+
477
+
478
+ @typing.overload
479
+ def compute_layer_gradients_and_eval(
480
+ forward_fn: Callable,
481
+ layer: Module,
482
+ inputs: Union[Tensor, Tuple[Tensor, ...]],
483
+ target_ind: TargetType = None,
484
+ additional_forward_args: Any = None,
485
+ *,
486
+ gradient_neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
487
+ device_ids: Union[None, List[int]] = None,
488
+ attribute_to_layer_input: bool = False,
489
+ output_fn: Union[None, Callable] = None,
490
+ ) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...], Tuple[Tensor, ...]]:
491
+ ...
492
+
493
+
494
+ @typing.overload
495
+ def compute_layer_gradients_and_eval(
496
+ forward_fn: Callable,
497
+ layer: List[Module],
498
+ inputs: Union[Tensor, Tuple[Tensor, ...]],
499
+ target_ind: TargetType = None,
500
+ additional_forward_args: Any = None,
501
+ gradient_neuron_selector: None = None,
502
+ device_ids: Union[None, List[int]] = None,
503
+ attribute_to_layer_input: bool = False,
504
+ output_fn: Union[None, Callable] = None,
505
+ ) -> Tuple[List[Tuple[Tensor, ...]], List[Tuple[Tensor, ...]]]:
506
+ ...
507
+
508
+
509
+ @typing.overload
510
+ def compute_layer_gradients_and_eval(
511
+ forward_fn: Callable,
512
+ layer: Module,
513
+ inputs: Union[Tensor, Tuple[Tensor, ...]],
514
+ target_ind: TargetType = None,
515
+ additional_forward_args: Any = None,
516
+ gradient_neuron_selector: None = None,
517
+ device_ids: Union[None, List[int]] = None,
518
+ attribute_to_layer_input: bool = False,
519
+ output_fn: Union[None, Callable] = None,
520
+ ) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]:
521
+ ...
522
+
523
+
524
+ def compute_layer_gradients_and_eval(
525
+ forward_fn: Callable,
526
+ layer: ModuleOrModuleList,
527
+ inputs: Union[Tensor, Tuple[Tensor, ...]],
528
+ target_ind: TargetType = None,
529
+ additional_forward_args: Any = None,
530
+ gradient_neuron_selector: Union[
531
+ None, int, Tuple[Union[int, slice], ...], Callable
532
+ ] = None,
533
+ device_ids: Union[None, List[int]] = None,
534
+ attribute_to_layer_input: bool = False,
535
+ output_fn: Union[None, Callable] = None,
536
+ ) -> Union[
537
+ Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]],
538
+ Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...], Tuple[Tensor, ...]],
539
+ Tuple[List[Tuple[Tensor, ...]], List[Tuple[Tensor, ...]]],
540
+ ]:
541
+ r"""
542
+ Computes gradients of the output with respect to a given layer as well
543
+ as the output evaluation of the layer for an arbitrary forward function
544
+ and given input.
545
+
546
+ For data parallel models, hooks are executed once per device, so we
547
+ need to internally combine the separated tensors from devices by
548
+ concatenating based on device_ids. Any necessary gradients must be taken
549
+ with respect to each independent batched tensor, so the gradients are
550
+ computed and combined appropriately.
551
+
552
+ More information regarding the behavior of forward hooks with DataParallel
553
+ models can be found in the PyTorch data parallel documentation. We maintain
554
+ the separate inputs in a dictionary protected by a lock, analogous to the
555
+ gather implementation for the core PyTorch DataParallel implementation.
556
+
557
+ NOTE: To properly handle inplace operations, a clone of the layer output
558
+ is stored. This structure inhibits execution of a backward hook on the last
559
+ module for the layer output when computing the gradient with respect to
560
+ the input, since we store an intermediate clone, as
561
+ opposed to the true module output. If backward module hooks are necessary
562
+ for the final module when computing input gradients, utilize
563
+ _forward_layer_eval_with_neuron_grads instead.
564
+
565
+ Args:
566
+
567
+ forward_fn: forward function. This can be for example model's
568
+ forward function.
569
+ layer: Layer for which gradients / output will be evaluated.
570
+ inputs: Input at which gradients are evaluated,
571
+ will be passed to forward_fn.
572
+ target_ind: Index of the target class for which gradients
573
+ must be computed (classification only).
574
+ output_fn: An optional function that is applied to the layer inputs or
575
+ outputs depending whether the `attribute_to_layer_input` is
576
+ set to `True` or `False`
577
+ additional_forward_args: Additional input arguments that forward function requires.
578
+ It takes an empty tuple (no additional arguments) if no
579
+ additional arguments are required
580
+
581
+
582
+ Returns:
583
+ 2-element tuple of **gradients**, **evals**:
584
+ - **gradients**:
585
+ Gradients of output with respect to target layer output.
586
+ - **evals**:
587
+ Target layer output for given input.
588
+ """
589
+ with torch.autograd.set_grad_enabled(True):
590
+ # saved_layer is a dictionary mapping device to a tuple of
591
+ # layer evaluations on that device.
592
+ saved_layer, output = _forward_layer_distributed_eval(
593
+ forward_fn,
594
+ inputs,
595
+ layer,
596
+ target_ind=target_ind,
597
+ additional_forward_args=additional_forward_args,
598
+ attribute_to_layer_input=attribute_to_layer_input,
599
+ forward_hook_with_return=True,
600
+ require_layer_grads=True,
601
+ )
602
+ assert output[0].numel() == 1, (
603
+ "Target not provided when necessary, cannot"
604
+ " take gradient with respect to multiple outputs."
605
+ )
606
+
607
+ device_ids = _extract_device_ids(forward_fn, saved_layer, device_ids)
608
+
609
+ # Identifies correct device ordering based on device ids.
610
+ # key_list is a list of devices in appropriate ordering for concatenation.
611
+ # If only one key exists (standard model), key list simply has one element.
612
+ key_list = _sort_key_list(
613
+ list(next(iter(saved_layer.values())).keys()), device_ids
614
+ )
615
+ all_outputs: Union[Tuple[Tensor, ...], List[Tuple[Tensor, ...]]]
616
+ if isinstance(layer, Module):
617
+ all_outputs = _reduce_list(
618
+ [
619
+ saved_layer[layer][device_id]
620
+ if output_fn is None
621
+ else output_fn(saved_layer[layer][device_id])
622
+ for device_id in key_list
623
+ ]
624
+ )
625
+ else:
626
+ all_outputs = [
627
+ _reduce_list(
628
+ [
629
+ saved_layer[single_layer][device_id]
630
+ if output_fn is None
631
+ else output_fn(saved_layer[single_layer][device_id])
632
+ for device_id in key_list
633
+ ]
634
+ )
635
+ for single_layer in layer
636
+ ]
637
+ all_layers: List[Module] = [layer] if isinstance(layer, Module) else layer
638
+ grad_inputs = tuple(
639
+ layer_tensor
640
+ for single_layer in all_layers
641
+ for device_id in key_list
642
+ for layer_tensor in saved_layer[single_layer][device_id]
643
+ )
644
+ saved_grads = torch.autograd.grad(torch.unbind(output), grad_inputs)
645
+
646
+ offset = 0
647
+ all_grads: List[Tuple[Tensor, ...]] = []
648
+ for single_layer in all_layers:
649
+ num_tensors = len(next(iter(saved_layer[single_layer].values())))
650
+ curr_saved_grads = [
651
+ saved_grads[i : i + num_tensors]
652
+ for i in range(
653
+ offset, offset + len(key_list) * num_tensors, num_tensors
654
+ )
655
+ ]
656
+ offset += len(key_list) * num_tensors
657
+ if output_fn is not None:
658
+ curr_saved_grads = [
659
+ output_fn(curr_saved_grad) for curr_saved_grad in curr_saved_grads
660
+ ]
661
+
662
+ all_grads.append(_reduce_list(curr_saved_grads))
663
+
664
+ layer_grads: Union[Tuple[Tensor, ...], List[Tuple[Tensor, ...]]]
665
+ layer_grads = all_grads
666
+ if isinstance(layer, Module):
667
+ layer_grads = all_grads[0]
668
+
669
+ if gradient_neuron_selector is not None:
670
+ assert isinstance(
671
+ layer, Module
672
+ ), "Cannot compute neuron gradients for multiple layers simultaneously!"
673
+ inp_grads = _neuron_gradients(
674
+ inputs, saved_layer[layer], key_list, gradient_neuron_selector
675
+ )
676
+ return (
677
+ cast(Tuple[Tensor, ...], layer_grads),
678
+ cast(Tuple[Tensor, ...], all_outputs),
679
+ inp_grads,
680
+ )
681
+ return layer_grads, all_outputs # type: ignore
682
+
683
+
684
+ def construct_neuron_grad_fn(
685
+ layer: Module,
686
+ neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
687
+ device_ids: Union[None, List[int]] = None,
688
+ attribute_to_neuron_input: bool = False,
689
+ ) -> Callable:
690
+ def grad_fn(
691
+ forward_fn: Callable,
692
+ inputs: TensorOrTupleOfTensorsGeneric,
693
+ target_ind: TargetType = None,
694
+ additional_forward_args: Any = None,
695
+ ) -> Tuple[Tensor, ...]:
696
+ _, grads = _forward_layer_eval_with_neuron_grads(
697
+ forward_fn,
698
+ inputs,
699
+ layer,
700
+ additional_forward_args,
701
+ gradient_neuron_selector=neuron_selector,
702
+ device_ids=device_ids,
703
+ attribute_to_layer_input=attribute_to_neuron_input,
704
+ )
705
+ return grads
706
+
707
+ return grad_fn
708
+
709
+
710
+ def _compute_jacobian_wrt_params(
711
+ model: Module,
712
+ inputs: Tuple[Any, ...],
713
+ labels: Optional[Tensor] = None,
714
+ loss_fn: Optional[Union[Module, Callable]] = None,
715
+ ) -> Tuple[Tensor, ...]:
716
+ r"""
717
+ Computes the Jacobian of a batch of test examples given a model, and optional
718
+ loss function and target labels. This method is equivalent to calculating the
719
+ gradient for every individual example in the minibatch.
720
+
721
+ Args:
722
+ model (torch.nn.Module): The trainable model providing the forward pass
723
+ inputs (tuple of Any): The minibatch for which the forward pass is computed.
724
+ It is unpacked before passing to `model`, so it must be a tuple. The
725
+ individual elements of `inputs` can be anything.
726
+ labels (Tensor or None): Labels for input if computing a loss function.
727
+ loss_fn (torch.nn.Module or Callable or None): The loss function. If a library
728
+ defined loss function is provided, it would be expected to be a
729
+ torch.nn.Module. If a custom loss is provided, it can be either type,
730
+ but must behave as a library loss function would if `reduction='none'`.
731
+
732
+ Returns:
733
+ grads (Tuple of Tensor): Returns the Jacobian for the minibatch as a
734
+ tuple of gradients corresponding to the tuple of trainable parameters
735
+ returned by `model.parameters()`. Each object grads[i] refers to the
736
+ gradients for the parameters in the i-th trainable layer of the model.
737
+ Each grads[i] object is a tensor with the gradients for the `inputs`
738
+ batch. For example, grads[i][j] would reference the gradients for the
739
+ parameters of the i-th layer, for the j-th member of the minibatch.
740
+ """
741
+ with torch.autograd.set_grad_enabled(True):
742
+ out = model(*inputs)
743
+ assert out.dim() != 0, "Please ensure model output has at least one dimension."
744
+
745
+ if labels is not None and loss_fn is not None:
746
+ loss = loss_fn(out, labels)
747
+ if hasattr(loss_fn, "reduction"):
748
+ msg0 = "Please ensure loss_fn.reduction is set to `none`"
749
+ assert loss_fn.reduction == "none", msg0 # type: ignore
750
+ else:
751
+ msg1 = (
752
+ "Loss function is applying a reduction. Please ensure "
753
+ f"Output shape: {out.shape} and Loss shape: {loss.shape} "
754
+ "are matching."
755
+ )
756
+ assert loss.dim() != 0, msg1
757
+ assert out.shape[0] == loss.shape[0], msg1
758
+ out = loss
759
+
760
+ grads_list = [
761
+ torch.autograd.grad(
762
+ outputs=out[i],
763
+ inputs=model.parameters(), # type: ignore
764
+ grad_outputs=torch.ones_like(out[i]),
765
+ retain_graph=True,
766
+ )
767
+ for i in range(out.shape[0])
768
+ ]
769
+
770
+ grads = tuple([torch.stack(x) for x in zip(*grads_list)])
771
+
772
+ return tuple(grads)
773
+
774
+
775
+ def _compute_jacobian_wrt_params_with_sample_wise_trick(
776
+ model: Module,
777
+ inputs: Tuple[Any, ...],
778
+ labels: Optional[Tensor] = None,
779
+ loss_fn: Optional[Union[Module, Callable]] = None,
780
+ reduction_type: Optional[str] = "sum",
781
+ ) -> Tuple[Any, ...]:
782
+ r"""
783
+ Computes the Jacobian of a batch of test examples given a model, and optional
784
+ loss function and target labels. This method uses sample-wise gradients per
785
+ batch trick to fully vectorize the Jacobian calculation. Currently, only
786
+ linear and conv2d layers are supported.
787
+
788
+ User must `add_hooks(model)` before calling this function.
789
+
790
+ Args:
791
+ model (torch.nn.Module): The trainable model providing the forward pass
792
+ inputs (tuple of Any): The minibatch for which the forward pass is computed.
793
+ It is unpacked before passing to `model`, so it must be a tuple. The
794
+ individual elements of `inputs` can be anything.
795
+ labels (Tensor or None): Labels for input if computing a loss function.
796
+ loss_fn (torch.nn.Module or Callable or None): The loss function. If a library
797
+ defined loss function is provided, it would be expected to be a
798
+ torch.nn.Module. If a custom loss is provided, it can be either type,
799
+ but must behave as a library loss function would if `reduction='sum'` or
800
+ `reduction='mean'`.
801
+ reduction_type (str): The type of reduction applied. If a loss_fn is passed,
802
+ this should match `loss_fn.reduction`. Else if gradients are being
803
+ computed on direct model outputs (scores), then 'sum' should be used.
804
+ Defaults to 'sum'.
805
+
806
+ Returns:
807
+ grads (Tuple of Tensor): Returns the Jacobian for the minibatch as a
808
+ tuple of gradients corresponding to the tuple of trainable parameters
809
+ returned by `model.parameters()`. Each object grads[i] refers to the
810
+ gradients for the parameters in the i-th trainable layer of the model.
811
+ Each grads[i] object is a tensor with the gradients for the `inputs`
812
+ batch. For example, grads[i][j] would reference the gradients for the
813
+ parameters of the i-th layer, for the j-th member of the minibatch.
814
+ """
815
+ with torch.autograd.set_grad_enabled(True):
816
+ sample_grad_wrapper = SampleGradientWrapper(model)
817
+ try:
818
+ sample_grad_wrapper.add_hooks()
819
+
820
+ out = model(*inputs)
821
+ assert (
822
+ out.dim() != 0
823
+ ), "Please ensure model output has at least one dimension."
824
+
825
+ if labels is not None and loss_fn is not None:
826
+ loss = loss_fn(out, labels)
827
+ # TODO: allow loss_fn to be Callable
828
+ if isinstance(loss_fn, Module) and hasattr(loss_fn, "reduction"):
829
+ msg0 = (
830
+ "Please ensure that loss_fn.reduction is set to `sum` or `mean`"
831
+ )
832
+
833
+ assert loss_fn.reduction != "none", msg0
834
+ msg1 = (
835
+ f"loss_fn.reduction ({loss_fn.reduction}) does not match"
836
+ f"reduction type ({reduction_type}). Please ensure they are"
837
+ " matching."
838
+ )
839
+ assert loss_fn.reduction == reduction_type, msg1
840
+ msg2 = (
841
+ "Please ensure custom loss function is applying either a "
842
+ "sum or mean reduction."
843
+ )
844
+ assert out.shape != loss.shape, msg2
845
+
846
+ if reduction_type != "sum" and reduction_type != "mean":
847
+ raise ValueError(
848
+ f"{reduction_type} is not a valid value for reduction_type. "
849
+ "Must be either 'sum' or 'mean'."
850
+ )
851
+ out = loss
852
+
853
+ sample_grad_wrapper.compute_param_sample_gradients(
854
+ out, loss_mode=reduction_type
855
+ )
856
+
857
+ grads = tuple(
858
+ param.sample_grad # type: ignore
859
+ for param in model.parameters()
860
+ if hasattr(param, "sample_grad")
861
+ )
862
+ finally:
863
+ sample_grad_wrapper.remove_hooks()
864
+
865
+ return grads
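Before the model utilities below, here is a minimal sketch of how `compute_gradients` defined above is typically called; the toy classifier, shapes, and target index are hypothetical and used only to show the expected call pattern:

import torch
import torch.nn as nn

from captum._utils.gradient import compute_gradients

# Hypothetical two-logit classifier used only for illustration.
model = nn.Sequential(nn.Linear(3, 2))
inputs = (torch.randn(4, 3, requires_grad=True),)

# target_ind selects one logit per example, so each unbound output is a scalar
# and torch.autograd.grad can accumulate gradients over the whole batch.
grads = compute_gradients(model, inputs, target_ind=1)
print(grads[0].shape)  # torch.Size([4, 3])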
captum/_utils/models/__init__.py ADDED
@@ -0,0 +1,25 @@
1
+ from captum._utils.models.linear_model import (
2
+ LinearModel,
3
+ SGDLasso,
4
+ SGDLinearModel,
5
+ SGDLinearRegression,
6
+ SGDRidge,
7
+ SkLearnLasso,
8
+ SkLearnLinearModel,
9
+ SkLearnLinearRegression,
10
+ SkLearnRidge,
11
+ )
12
+ from captum._utils.models.model import Model
13
+
14
+ __all__ = [
15
+ "Model",
16
+ "LinearModel",
17
+ "SGDLinearModel",
18
+ "SGDLasso",
19
+ "SGDRidge",
20
+ "SGDLinearRegression",
21
+ "SkLearnLinearModel",
22
+ "SkLearnLasso",
23
+ "SkLearnRidge",
24
+ "SkLearnLinearRegression",
25
+ ]
captum/_utils/models/linear_model/__init__.py ADDED
@@ -0,0 +1,23 @@
1
+ from captum._utils.models.linear_model.model import (
2
+ LinearModel,
3
+ SGDLasso,
4
+ SGDLinearModel,
5
+ SGDLinearRegression,
6
+ SGDRidge,
7
+ SkLearnLasso,
8
+ SkLearnLinearModel,
9
+ SkLearnLinearRegression,
10
+ SkLearnRidge,
11
+ )
12
+
13
+ __all__ = [
14
+ "LinearModel",
15
+ "SGDLinearModel",
16
+ "SGDLasso",
17
+ "SGDRidge",
18
+ "SGDLinearRegression",
19
+ "SkLearnLinearModel",
20
+ "SkLearnLasso",
21
+ "SkLearnRidge",
22
+ "SkLearnLinearRegression",
23
+ ]
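Both `__init__` modules above re-export the same classes from captum._utils.models.linear_model.model, so either import path can be used; a small sketch:

# The two import paths resolve to the same class object.
from captum._utils.models import SkLearnLasso
from captum._utils.models.linear_model import SkLearnLasso as SkLearnLassoAlias

assert SkLearnLasso is SkLearnLassoAlias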
captum/_utils/models/linear_model/model.py ADDED
@@ -0,0 +1,341 @@
1
+ from typing import Callable, cast, List, Optional
2
+
3
+ import torch.nn as nn
4
+ from captum._utils.models.model import Model
5
+ from torch import Tensor
6
+ from torch.utils.data import DataLoader
7
+
8
+
9
+ class LinearModel(nn.Module, Model):
10
+ SUPPORTED_NORMS: List[Optional[str]] = [None, "batch_norm", "layer_norm"]
11
+
12
+ def __init__(self, train_fn: Callable, **kwargs) -> None:
13
+ r"""
14
+ Constructs a linear model with a training function and additional
15
+ construction arguments that will be sent to
16
+ `self._construct_model_params` after a `self.fit` is called. Please note
17
+ that this assumes the `self.train_fn` will call
18
+ `self._construct_model_params`.
19
+
20
+ Please note that this is an experimental feature.
21
+
22
+ Args:
23
+ train_fn (callable)
24
+ The function to train with. See
25
+ `captum._utils.models.linear_model.train.sgd_train_linear_model`
26
+ and
27
+ `captum._utils.models.linear_model.train.sklearn_train_linear_model`
28
+ for examples
29
+ kwargs
30
+ Any additional keyword arguments to send to
31
+ `self._construct_model_params` once a `self.fit` is called.
32
+ """
33
+ super().__init__()
34
+
35
+ self.norm: Optional[nn.Module] = None
36
+ self.linear: Optional[nn.Linear] = None
37
+ self.train_fn = train_fn
38
+ self.construct_kwargs = kwargs
39
+
40
+ def _construct_model_params(
41
+ self,
42
+ in_features: Optional[int] = None,
43
+ out_features: Optional[int] = None,
44
+ norm_type: Optional[str] = None,
45
+ affine_norm: bool = False,
46
+ bias: bool = True,
47
+ weight_values: Optional[Tensor] = None,
48
+ bias_value: Optional[Tensor] = None,
49
+ classes: Optional[Tensor] = None,
50
+ ):
51
+ r"""
52
+ Lazily initializes a linear model. This will be called for you in a
53
+ train method.
54
+
55
+ Args:
56
+ in_features (int):
57
+ The number of input features
58
+ out_features (int):
59
+ The number of output features.
60
+ norm_type (str, optional):
61
+ The type of normalization that can occur. Please assign this
62
+ to one of `LinearModel.SUPPORTED_NORMS`.
63
+ affine_norm (bool):
64
+ Whether or not to learn an affine transformation of the
65
+ normalization parameters used.
66
+ bias (bool):
67
+ Whether to add a bias term. Not needed if normalized input.
68
+ weight_values (tensor, optional):
69
+ The values to initialize the linear model with. This must be a
70
+ 1D or 2D tensor, and of the form `(num_outputs, num_features)` or
71
+ `(num_features,)`. Additionally, if this is provided you need not
72
+ provide `in_features` or `out_features`.
73
+ bias_value (tensor, optional):
74
+ The bias value to initialize the model with.
75
+ classes (tensor, optional):
76
+ The list of prediction classes supported by the model in case it
77
+ performs classification. In case of regression it is set to None.
78
+ Default: None
79
+ """
80
+ if norm_type not in LinearModel.SUPPORTED_NORMS:
81
+ raise ValueError(
82
+ f"{norm_type} not supported. Please use {LinearModel.SUPPORTED_NORMS}"
83
+ )
84
+
85
+ if weight_values is not None:
86
+ in_features = weight_values.shape[-1]
87
+ out_features = (
88
+ 1 if len(weight_values.shape) == 1 else weight_values.shape[0]
89
+ )
90
+
91
+ if in_features is None or out_features is None:
92
+ raise ValueError(
93
+ "Please provide `in_features` and `out_features` or `weight_values`"
94
+ )
95
+
96
+ if norm_type == "batch_norm":
97
+ self.norm = nn.BatchNorm1d(in_features, eps=1e-8, affine=affine_norm)
98
+ elif norm_type == "layer_norm":
99
+ self.norm = nn.LayerNorm(
100
+ in_features, eps=1e-8, elementwise_affine=affine_norm
101
+ )
102
+ else:
103
+ self.norm = None
104
+
105
+ self.linear = nn.Linear(in_features, out_features, bias=bias)
106
+
107
+ if weight_values is not None:
108
+ self.linear.weight.data = weight_values
109
+
110
+ if bias_value is not None:
111
+ if not bias:
112
+ raise ValueError("`bias_value` is not None and bias is False")
113
+
114
+ self.linear.bias.data = bias_value
115
+
116
+ if classes is not None:
117
+ self.linear.classes = classes
118
+
119
+ def fit(self, train_data: DataLoader, **kwargs):
120
+ r"""
121
+ Calls `self.train_fn`
122
+ """
123
+ return self.train_fn(
124
+ self,
125
+ dataloader=train_data,
126
+ construct_kwargs=self.construct_kwargs,
127
+ **kwargs,
128
+ )
129
+
130
+ def forward(self, x: Tensor) -> Tensor:
131
+ assert self.linear is not None
132
+ if self.norm is not None:
133
+ x = self.norm(x)
134
+ return self.linear(x)
135
+
136
+ def representation(self) -> Tensor:
137
+ r"""
138
+ Returns a tensor which describes the hyper-plane input space. This does
139
+ not include the bias. For bias/intercept, please use `self.bias`
140
+ """
141
+ assert self.linear is not None
142
+ return self.linear.weight.detach()
143
+
144
+ def bias(self) -> Optional[Tensor]:
145
+ r"""
146
+ Returns the bias of the linear model
147
+ """
148
+ if self.linear is None or self.linear.bias is None:
149
+ return None
150
+ return self.linear.bias.detach()
151
+
152
+ def classes(self) -> Optional[Tensor]:
153
+ if self.linear is None or self.linear.classes is None:
154
+ return None
155
+ return cast(Tensor, self.linear.classes).detach()
156
+
157
+
158
+ class SGDLinearModel(LinearModel):
159
+ def __init__(self, **kwargs) -> None:
160
+ r"""
161
+ Factory class. Constructs a `LinearModel` with the
162
+ `sgd_train_linear_model` as the train method
163
+
164
+ Args:
165
+ kwargs
166
+ Arguments sent to `self._construct_model_params` after
167
+ `self.fit` is called. Please refer to that method for parameter
168
+ documentation.
169
+ """
170
+ # avoid cycles
171
+ from captum._utils.models.linear_model.train import sgd_train_linear_model
172
+
173
+ super().__init__(train_fn=sgd_train_linear_model, **kwargs)
174
+
175
+
176
+ class SGDLasso(SGDLinearModel):
177
+ def __init__(self, **kwargs) -> None:
178
+ r"""
179
+ Factory class to train a `LinearModel` with SGD
180
+ (`sgd_train_linear_model`) whilst setting appropriate parameters to
181
+ optimize for lasso regression loss. This optimizes L2 loss + alpha * L1
182
+ regularization.
183
+
184
+ Please note that with SGD it is not guaranteed that weights will
185
+ converge to 0.
186
+ """
187
+ super().__init__(**kwargs)
188
+
189
+ def fit(self, train_data: DataLoader, **kwargs):
190
+ # avoid cycles
191
+ from captum._utils.models.linear_model.train import l2_loss
192
+
193
+ return super().fit(train_data=train_data, loss_fn=l2_loss, reg_term=1, **kwargs)
194
+
195
+
196
+ class SGDRidge(SGDLinearModel):
197
+ def __init__(self, **kwargs) -> None:
198
+ r"""
199
+ Factory class to train a `LinearModel` with SGD
200
+ (`sgd_train_linear_model`) whilst setting appropriate parameters to
201
+ optimize for ridge regression loss. This optimizes L2 loss + alpha *
202
+ L2 regularization.
203
+ """
204
+ super().__init__(**kwargs)
205
+
206
+ def fit(self, train_data: DataLoader, **kwargs):
207
+ # avoid cycles
208
+ from captum._utils.models.linear_model.train import l2_loss
209
+
210
+ return super().fit(train_data=train_data, loss_fn=l2_loss, reg_term=2, **kwargs)
211
+
212
+
213
+ class SGDLinearRegression(SGDLinearModel):
214
+ def __init__(self, **kwargs) -> None:
215
+ r"""
216
+ Factory class to train a `LinearModel` with SGD
217
+ (`sgd_train_linear_model`). For linear regression this assigns the loss
218
+ to L2 and no regularization.
219
+ """
220
+ super().__init__(**kwargs)
221
+
222
+ def fit(self, train_data: DataLoader, **kwargs):
223
+ # avoid cycles
224
+ from captum._utils.models.linear_model.train import l2_loss
225
+
226
+ return super().fit(
227
+ train_data=train_data, loss_fn=l2_loss, reg_term=None, **kwargs
228
+ )
229
+
230
+
231
+ class SkLearnLinearModel(LinearModel):
232
+ def __init__(self, sklearn_module: str, **kwargs) -> None:
233
+ r"""
234
+ Factory class to construct a `LinearModel` with sklearn training method.
235
+
236
+ Please note that this assumes:
237
+
238
+ 0. You have sklearn and numpy installed
239
+ 1. The dataset can fit into memory
240
+
241
+ SkLearn support does introduce some slight overhead as we convert the
242
+ tensors to numpy and then convert the resulting trained model to a
243
+ `LinearModel` object. However, this conversion should be negligible.
244
+
245
+ Args:
246
+ sklearn_module
247
+ The module under sklearn to construct and use for training, e.g.
248
+ use "svm.LinearSVC" for an SVM or "linear_model.Lasso" for Lasso.
249
+
250
+ There are factory classes defined for you for common use cases,
251
+ such as `SkLearnLasso`.
252
+ kwargs
253
+ The kwargs to pass to the construction of the sklearn model
254
+ """
255
+ # avoid cycles
256
+ from captum._utils.models.linear_model.train import sklearn_train_linear_model
257
+
258
+ super().__init__(train_fn=sklearn_train_linear_model, **kwargs)
259
+
260
+ self.sklearn_module = sklearn_module
261
+
262
+ def fit(self, train_data: DataLoader, **kwargs):
263
+ r"""
264
+ Args:
265
+ train_data
266
+ Train data to use
267
+ kwargs
268
+ Arguments to feed to `.fit` method for sklearn
269
+ """
270
+ return super().fit(
271
+ train_data=train_data, sklearn_trainer=self.sklearn_module, **kwargs
272
+ )
273
+
274
+
275
+ class SkLearnLasso(SkLearnLinearModel):
276
+ def __init__(self, **kwargs) -> None:
277
+ r"""
278
+ Factory class. Trains a `LinearModel` model with
279
+ `sklearn.linear_model.Lasso`. You will need sklearn version >= 0.23 to
280
+ support sample weights.
281
+ """
282
+ super().__init__(sklearn_module="linear_model.Lasso", **kwargs)
283
+
284
+ def fit(self, train_data: DataLoader, **kwargs):
285
+ return super().fit(train_data=train_data, **kwargs)
286
+
287
+
288
+ class SkLearnRidge(SkLearnLinearModel):
289
+ def __init__(self, **kwargs) -> None:
290
+ r"""
291
+ Factory class. Trains a model with `sklearn.linear_model.Ridge`.
292
+
293
+ Any arguments provided to the sklearn constructor can be provided
294
+ as kwargs here.
295
+ """
296
+ super().__init__(sklearn_module="linear_model.Ridge", **kwargs)
297
+
298
+ def fit(self, train_data: DataLoader, **kwargs):
299
+ return super().fit(train_data=train_data, **kwargs)
300
+
301
+
302
+ class SkLearnLinearRegression(SkLearnLinearModel):
303
+ def __init__(self, **kwargs) -> None:
304
+ r"""
305
+ Factory class. Trains a model with `sklearn.linear_model.LinearRegression`.
306
+
307
+ Any arguments provided to the sklearn constructor can be provided
308
+ as kwargs here.
309
+ """
310
+ super().__init__(sklearn_module="linear_model.LinearRegression", **kwargs)
311
+
312
+ def fit(self, train_data: DataLoader, **kwargs):
313
+ return super().fit(train_data=train_data, **kwargs)
314
+
315
+
316
+ class SkLearnLogisticRegression(SkLearnLinearModel):
317
+ def __init__(self, **kwargs) -> None:
318
+ r"""
319
+ Factory class. Trains a model with `sklearn.linear_model.LogisticRegression`.
320
+
321
+ Any arguments provided to the sklearn constructor can be provided
322
+ as kwargs here.
323
+ """
324
+ super().__init__(sklearn_module="linear_model.LogisticRegression", **kwargs)
325
+
326
+ def fit(self, train_data: DataLoader, **kwargs):
327
+ return super().fit(train_data=train_data, **kwargs)
328
+
329
+
330
+ class SkLearnSGDClassifier(SkLearnLinearModel):
331
+ def __init__(self, **kwargs) -> None:
332
+ r"""
333
+ Factory class. Trains a model with `sklearn.linear_model.SGDClassifier`.
334
+
335
+ Any arguments provided to the sklearn constructor can be provided
336
+ as kwargs here.
337
+ """
338
+ super().__init__(sklearn_module="linear_model.SGDClassifier", **kwargs)
339
+
340
+ def fit(self, train_data: DataLoader, **kwargs):
341
+ return super().fit(train_data=train_data, **kwargs)
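To show how these factory classes are meant to be used, here is a minimal sketch that fits an `SGDLinearRegression` on synthetic data; the dataset, batch size, and epoch budget are hypothetical and chosen only for illustration:

import torch
from torch.utils.data import DataLoader, TensorDataset

from captum._utils.models.linear_model import SGDLinearRegression

# Hypothetical regression data (y = x @ w_true), used only for this example.
x = torch.randn(256, 5)
y = x @ torch.tensor([1.0, -2.0, 0.5, 0.0, 3.0])
loader = DataLoader(TensorDataset(x, y), batch_size=64)

model = SGDLinearRegression()            # construct kwargs are forwarded to _construct_model_params
stats = model.fit(loader, max_epoch=50)  # delegates to sgd_train_linear_model
print(stats["train_loss"], model.representation().shape)  # weight matrix of shape (1, 5)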
captum/_utils/models/linear_model/train.py ADDED
@@ -0,0 +1,364 @@
1
+ import time
2
+ import warnings
3
+ from typing import Any, Callable, Dict, List, Optional
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from captum._utils.models.linear_model.model import LinearModel
8
+ from torch.utils.data import DataLoader
9
+
10
+
11
+ def l2_loss(x1, x2, weights=None):
12
+ if weights is None:
13
+ return torch.mean((x1 - x2) ** 2) / 2.0
14
+ else:
15
+ return torch.sum((weights / weights.norm(p=1)) * ((x1 - x2) ** 2)) / 2.0
16
+
17
+
18
+ def sgd_train_linear_model(
19
+ model: LinearModel,
20
+ dataloader: DataLoader,
21
+ construct_kwargs: Dict[str, Any],
22
+ max_epoch: int = 100,
23
+ reduce_lr: bool = True,
24
+ initial_lr: float = 0.01,
25
+ alpha: float = 1.0,
26
+ loss_fn: Callable = l2_loss,
27
+ reg_term: Optional[int] = 1,
28
+ patience: int = 10,
29
+ threshold: float = 1e-4,
30
+ running_loss_window: Optional[int] = None,
31
+ device: Optional[str] = None,
32
+ init_scheme: str = "zeros",
33
+ debug: bool = False,
34
+ ) -> Dict[str, float]:
35
+ r"""
36
+ Trains a linear model with SGD. This will continue to iterate your
37
+ dataloader until it converges to a solution or, alternatively, until we have
38
+ exhausted `max_epoch`.
39
+
40
+ Convergence is defined by the loss not changing by `threshold` amount for
41
+ `patience` number of iterations.
42
+
43
+ Args:
44
+ model
45
+ The model to train
46
+ dataloader
47
+ The data to train it with. We will assume the dataloader produces
48
+ either pairs or triples of the form (x, y) or (x, y, w), where x and
49
+ y are typical pairs for supervised learning and w is a weight
50
+ vector.
51
+
52
+ We will call `model._construct_model_params` with construct_kwargs
53
+ and the input features set to `x.shape[1]` (`x.shape[0]` corresponds
54
+ to the batch size). We assume that `len(x.shape) == 2`, i.e. the
55
+ tensor is flat. The number of output features will be set to
56
+ y.shape[1] or 1 (if `len(y.shape) == 1`); we require `len(y.shape)
57
+ <= 2`.
58
+ max_epoch
59
+ The maximum number of epochs to exhaust
60
+ reduce_lr
61
+ Whether or not to reduce the learning rate as iterations progress.
62
+ Halves the learning rate when the training loss does not move. This
63
+ uses torch.optim.lr_scheduler.ReduceLROnPlateau and uses the
64
+ parameters `patience` and `threshold`
65
+ initial_lr
66
+ The initial learning rate to use.
67
+ alpha
68
+ A constant for the regularization term.
69
+ loss_fn
70
+ The loss to optimise for. This must accept three parameters:
71
+ x1 (predicted), x2 (labels) and a weight vector
72
+ reg_term
73
+ Regularization is defined by the `reg_term` norm of the weights.
74
+ Please use `None` if you do not wish to use regularization.
75
+ patience
76
+ Defines the number of iterations in a row the loss must remain
77
+ within `threshold` in order to be classified as converged.
78
+ threshold
79
+ Threshold for convergence detection.
80
+ running_loss_window
81
+ Used to report the training loss once we have finished training and
82
+ to determine when we have converged (along with reducing the
83
+ learning rate).
84
+
85
+ The reported training loss will take the last `running_loss_window`
86
+ iterations and average them.
87
+
88
+ If `None` we will approximate this to be the number of examples in
89
+ an epoch.
90
+ init_scheme
91
+ Initialization to use prior to training the linear model.
92
+ device
93
+ The device to send the model and data to. If None then no `.to` call
94
+ will be used.
95
+ debug
96
+ Whether to print the loss, learning rate per iteration
97
+
98
+ Returns
99
+ A dict with the training time, the final training loss (averaged over the
100
+ last `running_loss_window` iterations), and the iteration and epoch counts.
101
+ """
102
+
103
+ loss_window: List[torch.Tensor] = []
104
+ min_avg_loss = None
105
+ convergence_counter = 0
106
+ converged = False
107
+
108
+ def get_point(datapoint):
109
+ if len(datapoint) == 2:
110
+ x, y = datapoint
111
+ w = None
112
+ else:
113
+ x, y, w = datapoint
114
+
115
+ if device is not None:
116
+ x = x.to(device)
117
+ y = y.to(device)
118
+ if w is not None:
119
+ w = w.to(device)
120
+
121
+ return x, y, w
122
+
123
+ # get a point and construct the model
124
+ data_iter = iter(dataloader)
125
+ x, y, w = get_point(next(data_iter))
126
+
127
+ model._construct_model_params(
128
+ in_features=x.shape[1],
129
+ out_features=y.shape[1] if len(y.shape) == 2 else 1,
130
+ **construct_kwargs,
131
+ )
132
+ model.train()
133
+
134
+ assert model.linear is not None
135
+
136
+ if init_scheme is not None:
137
+ assert init_scheme in ["xavier", "zeros"]
138
+
139
+ with torch.no_grad():
140
+ if init_scheme == "xavier":
141
+ torch.nn.init.xavier_uniform_(model.linear.weight)
142
+ else:
143
+ model.linear.weight.zero_()
144
+
145
+ if model.linear.bias is not None:
146
+ model.linear.bias.zero_()
147
+
148
+ optim = torch.optim.SGD(model.parameters(), lr=initial_lr)
149
+ if reduce_lr:
150
+ scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
151
+ optim, factor=0.5, patience=patience, threshold=threshold
152
+ )
153
+
154
+ t1 = time.time()
155
+ epoch = 0
156
+ i = 0
157
+ while epoch < max_epoch:
158
+ while True: # for x, y, w in dataloader
159
+ if running_loss_window is None:
160
+ running_loss_window = x.shape[0] * len(dataloader)
161
+
162
+ y = y.view(x.shape[0], -1)
163
+ if w is not None:
164
+ w = w.view(x.shape[0], -1)
165
+
166
+ i += 1
167
+
168
+ out = model(x)
169
+
170
+ loss = loss_fn(y, out, w)
171
+ if reg_term is not None:
172
+ reg = torch.norm(model.linear.weight, p=reg_term)
173
+ loss += reg.sum() * alpha
174
+
175
+ if len(loss_window) >= running_loss_window:
176
+ loss_window = loss_window[1:]
177
+ loss_window.append(loss.clone().detach())
178
+ assert len(loss_window) <= running_loss_window
179
+
180
+ average_loss = torch.mean(torch.stack(loss_window))
181
+ if min_avg_loss is not None:
182
+ # if we haven't improved by at least `threshold`
183
+ if average_loss > min_avg_loss or torch.isclose(
184
+ min_avg_loss, average_loss, atol=threshold
185
+ ):
186
+ convergence_counter += 1
187
+ if convergence_counter >= patience:
188
+ converged = True
189
+ break
190
+ else:
191
+ convergence_counter = 0
192
+ if min_avg_loss is None or min_avg_loss >= average_loss:
193
+ min_avg_loss = average_loss.clone()
194
+
195
+ if debug:
196
+ print(
197
+ f"lr={optim.param_groups[0]['lr']}, Loss={loss},"
198
+ + "Aloss={average_loss}, min_avg_loss={min_avg_loss}"
199
+ )
200
+
201
+ loss.backward()
202
+
203
+ optim.step()
204
+ model.zero_grad()
205
+ if reduce_lr:
206
+ scheduler.step(average_loss)
207
+
208
+ temp = next(data_iter, None)
209
+ if temp is None:
210
+ break
211
+ x, y, w = get_point(temp)
212
+
213
+ if converged:
214
+ break
215
+
216
+ epoch += 1
217
+ data_iter = iter(dataloader)
218
+ x, y, w = get_point(next(data_iter))
219
+
220
+ t2 = time.time()
221
+ return {
222
+ "train_time": t2 - t1,
223
+ "train_loss": torch.mean(torch.stack(loss_window)).item(),
224
+ "train_iter": i,
225
+ "train_epoch": epoch,
226
+ }
227
+
228
+
229
+ class NormLayer(nn.Module):
230
+ def __init__(self, mean, std, n=None, eps=1e-8) -> None:
231
+ super().__init__()
232
+ self.mean = mean
233
+ self.std = std
234
+ self.eps = eps
235
+
236
+ def forward(self, x):
237
+ return (x - self.mean) / (self.std + self.eps)
238
+
239
+
240
+ def sklearn_train_linear_model(
241
+ model: LinearModel,
242
+ dataloader: DataLoader,
243
+ construct_kwargs: Dict[str, Any],
244
+ sklearn_trainer: str = "Lasso",
245
+ norm_input: bool = False,
246
+ **fit_kwargs,
247
+ ):
248
+ r"""
249
+ Alternative method to train with sklearn. This does introduce some slight
250
+ overhead as we convert the tensors to numpy and then convert the resulting
251
+ trained model to a `LinearModel` object. However, this conversion
252
+ should be negligible.
253
+
254
+ Please note that this assumes:
255
+
256
+ 0. You have sklearn and numpy installed
257
+ 1. The dataset can fit into memory
258
+
259
+ Args
260
+ model
261
+ The model to train.
262
+ dataloader
263
+ The data to use. This will be exhausted and converted to numpy
264
+ arrays. Therefore please do not feed an infinite dataloader.
265
+ norm_input
266
+ Whether or not to normalize the input
267
+ sklearn_trainer
268
+ The sklearn model to use to train the model. Please refer to
269
+ sklearn.linear_model for a list of modules to use.
270
+ construct_kwargs
271
+ Additional arguments provided to the `sklearn_trainer` constructor
272
+ fit_kwargs
273
+ Other arguments to send to `sklearn_trainer`'s `.fit` method
274
+ """
275
+ from functools import reduce
276
+
277
+ try:
278
+ import numpy as np
279
+ except ImportError:
280
+ raise ValueError("numpy is not available. Please install numpy.")
281
+
282
+ try:
283
+ import sklearn
284
+ import sklearn.linear_model
285
+ import sklearn.svm
286
+ except ImportError:
287
+ raise ValueError("sklearn is not available. Please install sklearn >= 0.23")
288
+
289
+ if not sklearn.__version__ >= "0.23.0":
290
+ warnings.warn(
291
+ "Must have sklearn version 0.23.0 or higher to use "
292
+ "sample_weight in Lasso regression."
293
+ )
294
+
295
+ num_batches = 0
296
+ xs, ys, ws = [], [], []
297
+ for data in dataloader:
298
+ if len(data) == 3:
299
+ x, y, w = data
300
+ else:
301
+ assert len(data) == 2
302
+ x, y = data
303
+ w = None
304
+
305
+ xs.append(x.cpu().numpy())
306
+ ys.append(y.cpu().numpy())
307
+ if w is not None:
308
+ ws.append(w.cpu().numpy())
309
+ num_batches += 1
310
+
311
+ x = np.concatenate(xs, axis=0)
312
+ y = np.concatenate(ys, axis=0)
313
+ if len(ws) > 0:
314
+ w = np.concatenate(ws, axis=0)
315
+ else:
316
+ w = None
317
+
318
+ if norm_input:
319
+ mean, std = x.mean(0), x.std(0)
320
+ x -= mean
321
+ x /= std
322
+
323
+ t1 = time.time()
324
+ sklearn_model = reduce(
325
+ lambda val, el: getattr(val, el), [sklearn] + sklearn_trainer.split(".")
326
+ )(**construct_kwargs)
327
+ try:
328
+ sklearn_model.fit(x, y, sample_weight=w, **fit_kwargs)
329
+ except TypeError:
330
+ sklearn_model.fit(x, y, **fit_kwargs)
331
+ warnings.warn(
332
+ "Sample weight is not supported for the provided linear model!"
333
+ " Trained model without weighting inputs. For Lasso, please"
334
+ " upgrade sklearn to a version >= 0.23.0."
335
+ )
336
+
337
+ t2 = time.time()
338
+
339
+ # Convert weights to pytorch
340
+ classes = (
341
+ torch.IntTensor(sklearn_model.classes_)
342
+ if hasattr(sklearn_model, "classes_")
343
+ else None
344
+ )
345
+
346
+ # extract model device
347
+ device = model.device if hasattr(model, "device") else "cpu"
348
+
349
+ num_outputs = sklearn_model.coef_.shape[0] if sklearn_model.coef_.ndim > 1 else 1
350
+ weight_values = torch.FloatTensor(sklearn_model.coef_).to(device) # type: ignore
351
+ bias_values = torch.FloatTensor([sklearn_model.intercept_]).to( # type: ignore
352
+ device # type: ignore
353
+ ) # type: ignore
354
+ model._construct_model_params(
355
+ norm_type=None,
356
+ weight_values=weight_values.view(num_outputs, -1),
357
+ bias_value=bias_values.squeeze().unsqueeze(0),
358
+ classes=classes,
359
+ )
360
+
361
+ if norm_input:
362
+ model.norm = NormLayer(mean, std)
363
+
364
+ return {"train_time": t2 - t1}
captum/_utils/models/model.py ADDED
@@ -0,0 +1,66 @@
1
+ #!/usr/bin/env python3
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import Dict, Optional, Union
5
+
6
+ from captum._utils.typing import TensorOrTupleOfTensorsGeneric
7
+ from torch import Tensor
8
+ from torch.utils.data import DataLoader
9
+
10
+
11
+ class Model(ABC):
12
+ r"""
13
+ Abstract Class to describe the interface of a trainable model to be used
14
+ within the algorithms of captum.
15
+
16
+ Please note that this is an experimental feature.
17
+ """
18
+
19
+ @abstractmethod
20
+ def fit(
21
+ self, train_data: DataLoader, **kwargs
22
+ ) -> Optional[Dict[str, Union[int, float, Tensor]]]:
23
+ r"""
24
+ Override this method to actually train your model.
25
+
26
+ The specification of the dataloader will be supplied by the algorithm
27
+ you are using within captum. This will likely be a supervised learning
28
+ task, thus you should expect batched (x, y) pairs or (x, y, w) triples.
29
+
30
+ Args:
31
+ train_data (DataLoader):
32
+ The data to train on
33
+
34
+ Returns:
35
+ Optional statistics about training, e.g. iterations it took to
36
+ train, training loss, etc.
37
+ """
38
+ pass
39
+
40
+ @abstractmethod
41
+ def representation(self) -> Tensor:
42
+ r"""
43
+ Returns the underlying representation of the interpretable model. For a
44
+ linear model this is simply a tensor (the concatenation of weights
45
+ and bias). For something slightly more complicated, such as a decision
46
+ tree, this could be the nodes of a decision tree.
47
+
48
+ Returns:
49
+ A Tensor describing the representation of the model.
50
+ """
51
+ pass
52
+
53
+ @abstractmethod
54
+ def __call__(
55
+ self, x: TensorOrTupleOfTensorsGeneric
56
+ ) -> TensorOrTupleOfTensorsGeneric:
57
+ r"""
58
+ Predicts with the interpretable model.
59
+
60
+ Args:
61
+ x (TensorOrTupleOfTensorsGeneric)
62
+ A batched input of tensor(s) to the model to predict
63
+ Returns:
64
+ The prediction of the input as a TensorOrTupleOfTensorsGeneric.
65
+ """
66
+ pass
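For orientation (not part of the committed file), a minimal sketch of what a concrete `Model` could look like: a closed-form least-squares fit wrapped in the three abstract methods above. The class is illustrative only and is not one of the interpretable models shipped with captum; the import path assumes this vendored package is importable.

import torch
from torch import Tensor
from torch.utils.data import DataLoader, TensorDataset

from captum._utils.models.model import Model  # the ABC defined above

class LeastSquaresModel(Model):
    # Illustrative Model subclass, not part of captum itself.
    def __init__(self) -> None:
        self.weights: Tensor = torch.zeros(0)

    def fit(self, train_data: DataLoader, **kwargs):
        xs, ys = [], []
        for batch in train_data:  # (x, y) or (x, y, w) batches
            xs.append(batch[0])
            ys.append(batch[1].float().view(batch[0].shape[0], -1))
        x, y = torch.cat(xs), torch.cat(ys)
        # Closed-form least squares instead of iterative training.
        self.weights = torch.linalg.lstsq(x, y).solution
        mse = (x @ self.weights - y).pow(2).mean().item()
        return {"train_loss": mse}

    def representation(self) -> Tensor:
        # Flattened parameters, as the docstring above suggests for linear models.
        return self.weights.flatten()

    def __call__(self, x: Tensor) -> Tensor:
        return x @ self.weights

x = torch.randn(64, 8)
y = x @ torch.randn(8, 1)
m = LeastSquaresModel()
print(m.fit(DataLoader(TensorDataset(x, y), batch_size=16)))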
captum/_utils/progress.py ADDED
@@ -0,0 +1,138 @@
1
+ #!/usr/bin/env python3
2
+
3
+ import sys
4
+ import warnings
5
+ from time import time
6
+ from typing import cast, Iterable, Sized, TextIO
7
+
8
+ try:
9
+ from tqdm import tqdm
10
+ except ImportError:
11
+ tqdm = None
12
+
13
+
14
+ class DisableErrorIOWrapper(object):
15
+ def __init__(self, wrapped: TextIO):
16
+ """
17
+ The wrapper around a TextIO object to ignore write errors like tqdm
18
+ https://github.com/tqdm/tqdm/blob/bcce20f771a16cb8e4ac5cc5b2307374a2c0e535/tqdm/utils.py#L131
19
+ """
20
+ self._wrapped = wrapped
21
+
22
+ def __getattr__(self, name):
23
+ return getattr(self._wrapped, name)
24
+
25
+ @staticmethod
26
+ def _wrapped_run(func, *args, **kwargs):
27
+ try:
28
+ return func(*args, **kwargs)
29
+ except OSError as e:
30
+ if e.errno != 5:
31
+ raise
32
+ except ValueError as e:
33
+ if "closed" not in str(e):
34
+ raise
35
+
36
+ def write(self, *args, **kwargs):
37
+ return self._wrapped_run(self._wrapped.write, *args, **kwargs)
38
+
39
+ def flush(self, *args, **kwargs):
40
+ return self._wrapped_run(self._wrapped.flush, *args, **kwargs)
41
+
42
+
43
+ class SimpleProgress:
44
+ def __init__(
45
+ self,
46
+ iterable: Iterable = None,
47
+ desc: str = None,
48
+ total: int = None,
49
+ file: TextIO = None,
50
+ mininterval: float = 0.5,
51
+ ):
52
+ """
53
+ Simple progress output used when tqdm is unavailable.
54
+ Like tqdm, output goes to the stderr channel by default.
55
+ """
56
+ self.cur = 0
57
+
58
+ self.iterable = iterable
59
+ self.total = total
60
+ if total is None and hasattr(iterable, "__len__"):
61
+ self.total = len(cast(Sized, iterable))
62
+
63
+ self.desc = desc
64
+
65
+ file = DisableErrorIOWrapper(file if file else sys.stderr)
66
+ cast(TextIO, file)
67
+ self.file = file
68
+
69
+ self.mininterval = mininterval
70
+ self.last_print_t = 0.0
71
+ self.closed = False
72
+
73
+ def __iter__(self):
74
+ if self.closed or not self.iterable:
75
+ return
76
+ self._refresh()
77
+ for it in self.iterable:
78
+ yield it
79
+ self.update()
80
+ self.close()
81
+
82
+ def _refresh(self):
83
+ progress_str = self.desc + ": " if self.desc else ""
84
+ if self.total:
85
+ # e.g., progress: 60% 3/5
86
+ progress_str += f"{100 * self.cur // self.total}% {self.cur}/{self.total}"
87
+ else:
88
+ # e.g., progress: .....
89
+ progress_str += "." * self.cur
90
+
91
+ print("\r" + progress_str, end="", file=self.file)
92
+
93
+ def update(self, amount: int = 1):
94
+ if self.closed:
95
+ return
96
+ self.cur += amount
97
+
98
+ cur_t = time()
99
+ if cur_t - self.last_print_t >= self.mininterval:
100
+ self._refresh()
101
+ self.last_print_t = cur_t
102
+
103
+ def close(self):
104
+ if not self.closed:
105
+ self._refresh()
106
+ print(file=self.file) # end with new line
107
+ self.closed = True
108
+
109
+
110
+ def progress(
111
+ iterable: Iterable = None,
112
+ desc: str = None,
113
+ total: int = None,
114
+ use_tqdm=True,
115
+ file: TextIO = None,
116
+ mininterval: float = 0.5,
117
+ **kwargs,
118
+ ):
119
+ # Try to use tqdm is possible. Fall back to simple progress print
120
+ if tqdm and use_tqdm:
121
+ return tqdm(
122
+ iterable,
123
+ desc=desc,
124
+ total=total,
125
+ file=file,
126
+ mininterval=mininterval,
127
+ **kwargs,
128
+ )
129
+ else:
130
+ if not tqdm and use_tqdm:
131
+ warnings.warn(
132
+ "Tried to show progress with tqdm "
133
+ "but tqdm is not installed. "
134
+ "Fall back to simply print out the progress."
135
+ )
136
+ return SimpleProgress(
137
+ iterable, desc=desc, total=total, file=file, mininterval=mininterval
138
+ )
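For orientation (not part of the committed file), a small usage sketch of the helpers above: `progress` transparently returns a `tqdm` bar when tqdm is installed and a `SimpleProgress` otherwise, and `SimpleProgress` can also be driven manually. The import path assumes this vendored `captum` package is importable.

from time import sleep

from captum._utils.progress import progress, SimpleProgress

# Wrap any iterable; iteration looks the same whether tqdm is installed or not.
for item in progress(range(5), desc="processing", total=5):
    sleep(0.1)  # stand-in for real work

# Manual updates when there is no iterable to wrap.
bar = SimpleProgress(desc="steps", total=3)
for _ in range(3):
    bar.update()
bar.close()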
captum/_utils/sample_gradient.py ADDED
@@ -0,0 +1,184 @@
1
+ from collections import defaultdict
2
+ from enum import Enum
3
+ from typing import cast, Iterable, Tuple, Union
4
+
5
+ import torch
6
+ from captum._utils.common import _format_tensor_into_tuples, _register_backward_hook
7
+ from torch import Tensor
8
+ from torch.nn import Module
9
+
10
+
11
+ def _reset_sample_grads(module: Module):
12
+ module.weight.sample_grad = 0 # type: ignore
13
+ if module.bias is not None:
14
+ module.bias.sample_grad = 0 # type: ignore
15
+
16
+
17
+ def linear_param_grads(
18
+ module: Module, activation: Tensor, gradient_out: Tensor, reset: bool = False
19
+ ) -> None:
20
+ r"""
21
+ Computes parameter gradients per sample for nn.Linear module, given module
22
+ input activations and output gradients.
23
+
24
+ Gradients are accumulated in the sample_grad attribute of each parameter
25
+ (weight and bias). If reset = True, any current sample_grad values are reset,
26
+ otherwise computed gradients are accumulated and added to the existing
27
+ stored gradients.
28
+
29
+ Inputs with more than 2 dimensions are only supported with torch 1.8 or later
30
+ """
31
+ if reset:
32
+ _reset_sample_grads(module)
33
+
34
+ module.weight.sample_grad += torch.einsum( # type: ignore
35
+ "n...i,n...j->nij", gradient_out, activation
36
+ )
37
+ if module.bias is not None:
38
+ module.bias.sample_grad += torch.einsum( # type: ignore
39
+ "n...i->ni", gradient_out
40
+ )
41
+
42
+
43
+ def conv2d_param_grads(
44
+ module: Module, activation: Tensor, gradient_out: Tensor, reset: bool = False
45
+ ) -> None:
46
+ r"""
47
+ Computes parameter gradients per sample for nn.Conv2d module, given module
48
+ input activations and output gradients.
49
+
50
+ nn.Conv2d modules with padding set to a string option ('same' or 'valid') are
51
+ currently unsupported.
52
+
53
+ Gradients are accumulated in the sample_grad attribute of each parameter
54
+ (weight and bias). If reset = True, any current sample_grad values are reset,
55
+ otherwise computed gradients are accumulated and added to the existing
56
+ stored gradients.
57
+ """
58
+ if reset:
59
+ _reset_sample_grads(module)
60
+
61
+ batch_size = cast(int, activation.shape[0])
62
+ unfolded_act = torch.nn.functional.unfold(
63
+ activation,
64
+ cast(Union[int, Tuple[int, ...]], module.kernel_size),
65
+ dilation=cast(Union[int, Tuple[int, ...]], module.dilation),
66
+ padding=cast(Union[int, Tuple[int, ...]], module.padding),
67
+ stride=cast(Union[int, Tuple[int, ...]], module.stride),
68
+ )
69
+ reshaped_grad = gradient_out.reshape(batch_size, -1, unfolded_act.shape[-1])
70
+ grad1 = torch.einsum("ijk,ilk->ijl", reshaped_grad, unfolded_act)
71
+ shape = [batch_size] + list(cast(Iterable[int], module.weight.shape))
72
+ module.weight.sample_grad += grad1.reshape(shape) # type: ignore
73
+ if module.bias is not None:
74
+ module.bias.sample_grad += torch.sum(reshaped_grad, dim=2) # type: ignore
75
+
76
+
77
+ SUPPORTED_MODULES = {
78
+ torch.nn.Conv2d: conv2d_param_grads,
79
+ torch.nn.Linear: linear_param_grads,
80
+ }
81
+
82
+
83
+ class LossMode(Enum):
84
+ SUM = 0
85
+ MEAN = 1
86
+
87
+
88
+ class SampleGradientWrapper:
89
+ r"""
90
+ Wrapper which allows computing sample-wise gradients in a single backward pass.
91
+
92
+ This is accomplished by adding hooks to capture activations and output
93
+ gradients for supported modules, and using these activations and gradients
94
+ to compute the parameter gradients per-sample.
95
+
96
+ Currently, only nn.Linear and nn.Conv2d modules are supported.
97
+
98
+ Similar reference implementations of sample-based gradients include:
99
+ - https://github.com/cybertronai/autograd-hacks
100
+ - https://github.com/pytorch/opacus/tree/main/opacus/grad_sample
101
+ """
102
+
103
+ def __init__(self, model):
104
+ self.model = model
105
+ self.hooks_added = False
106
+ self.activation_dict = defaultdict(list)
107
+ self.gradient_dict = defaultdict(list)
108
+ self.forward_hooks = []
109
+ self.backward_hooks = []
110
+
111
+ def add_hooks(self):
112
+ self.hooks_added = True
113
+ self.model.apply(self._register_module_hooks)
114
+
115
+ def _register_module_hooks(self, module: torch.nn.Module):
116
+ if isinstance(module, tuple(SUPPORTED_MODULES.keys())):
117
+ self.forward_hooks.append(
118
+ module.register_forward_hook(self._forward_hook_fn)
119
+ )
120
+ self.backward_hooks.append(
121
+ _register_backward_hook(module, self._backward_hook_fn, None)
122
+ )
123
+
124
+ def _forward_hook_fn(
125
+ self,
126
+ module: Module,
127
+ module_input: Union[Tensor, Tuple[Tensor, ...]],
128
+ module_output: Union[Tensor, Tuple[Tensor, ...]],
129
+ ):
130
+ inp_tuple = _format_tensor_into_tuples(module_input)
131
+ self.activation_dict[module].append(inp_tuple[0].clone().detach())
132
+
133
+ def _backward_hook_fn(
134
+ self,
135
+ module: Module,
136
+ grad_input: Union[Tensor, Tuple[Tensor, ...]],
137
+ grad_output: Union[Tensor, Tuple[Tensor, ...]],
138
+ ):
139
+ grad_output_tuple = _format_tensor_into_tuples(grad_output)
140
+ self.gradient_dict[module].append(grad_output_tuple[0].clone().detach())
141
+
142
+ def remove_hooks(self):
143
+ self.hooks_added = False
144
+
145
+ for hook in self.forward_hooks:
146
+ hook.remove()
147
+
148
+ for hook in self.backward_hooks:
149
+ hook.remove()
150
+
151
+ self.forward_hooks = []
152
+ self.backward_hooks = []
153
+
154
+ def _reset(self):
155
+ self.activation_dict = defaultdict(list)
156
+ self.gradient_dict = defaultdict(list)
157
+
158
+ def compute_param_sample_gradients(self, loss_blob, loss_mode="mean"):
159
+ assert (
160
+ loss_mode.upper() in LossMode.__members__
161
+ ), f"Provided loss mode {loss_mode} is not valid"
162
+ mode = LossMode[loss_mode.upper()]
163
+
164
+ self.model.zero_grad()
165
+ loss_blob.backward(gradient=torch.ones_like(loss_blob))
166
+
167
+ for module in self.gradient_dict:
168
+ sample_grad_fn = SUPPORTED_MODULES[type(module)]
169
+ activations = self.activation_dict[module]
170
+ gradients = self.gradient_dict[module]
171
+ assert len(activations) == len(gradients), (
172
+ "Number of saved activations do not match number of saved gradients."
173
+ " This may occur if multiple forward passes are run without calling"
174
+ " reset or computing param gradients."
175
+ )
176
+ # Reversing grads since when a module is used multiple times,
177
+ # the activations will be aligned with the reverse order of the gradients,
178
+ # since the order is reversed in backprop.
179
+ for i, (act, grad) in enumerate(
180
+ zip(activations, list(reversed(gradients)))
181
+ ):
182
+ mult = 1 if mode is LossMode.SUM else act.shape[0]
183
+ sample_grad_fn(module, act, grad * mult, reset=(i == 0))
184
+ self._reset()
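For orientation (not part of the committed file), a sketch of driving `SampleGradientWrapper` end-to-end: add the hooks, pass an unreduced per-example loss to `compute_param_sample_gradients`, and read the per-sample gradients from each supported module's `weight.sample_grad` / `bias.sample_grad`. The toy model and data are made up; the import path assumes this vendored `captum` package is importable.

import torch
import torch.nn as nn

from captum._utils.sample_gradient import SampleGradientWrapper

model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 1))
x = torch.randn(8, 4)
target = torch.randn(8, 1)

wrapper = SampleGradientWrapper(model)
wrapper.add_hooks()
try:
    # Unreduced per-example loss; the wrapper backprops it with a vector of
    # ones and scales activations/gradients according to loss_mode.
    per_sample_loss = (model(x) - target).pow(2).sum(dim=1)
    wrapper.compute_param_sample_gradients(per_sample_loss, loss_mode="sum")
    print(model[0].weight.sample_grad.shape)  # torch.Size([8, 3, 4])
    print(model[0].bias.sample_grad.shape)    # torch.Size([8, 3])
finally:
    wrapper.remove_hooks()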
captum/_utils/typing.py ADDED
@@ -0,0 +1,37 @@
1
+ #!/usr/bin/env python3
2
+
3
+ from typing import List, Tuple, TYPE_CHECKING, TypeVar, Union
4
+
5
+ from torch import Tensor
6
+ from torch.nn import Module
7
+
8
+ if TYPE_CHECKING:
9
+ import sys
10
+
11
+ if sys.version_info >= (3, 8):
12
+ from typing import Literal # noqa: F401
13
+ else:
14
+ from typing_extensions import Literal # noqa: F401
15
+ else:
16
+ Literal = {True: bool, False: bool, (True, False): bool}
17
+
18
+ TensorOrTupleOfTensorsGeneric = TypeVar(
19
+ "TensorOrTupleOfTensorsGeneric", Tensor, Tuple[Tensor, ...]
20
+ )
21
+ TupleOrTensorOrBoolGeneric = TypeVar("TupleOrTensorOrBoolGeneric", Tuple, Tensor, bool)
22
+ ModuleOrModuleList = TypeVar("ModuleOrModuleList", Module, List[Module])
23
+ TargetType = Union[None, int, Tuple[int, ...], Tensor, List[Tuple[int, ...]], List[int]]
24
+ BaselineType = Union[None, Tensor, int, float, Tuple[Union[Tensor, int, float], ...]]
25
+
26
+ TensorLikeList1D = List[float]
27
+ TensorLikeList2D = List[TensorLikeList1D]
28
+ TensorLikeList3D = List[TensorLikeList2D]
29
+ TensorLikeList4D = List[TensorLikeList3D]
30
+ TensorLikeList5D = List[TensorLikeList4D]
31
+ TensorLikeList = Union[
32
+ TensorLikeList1D,
33
+ TensorLikeList2D,
34
+ TensorLikeList3D,
35
+ TensorLikeList4D,
36
+ TensorLikeList5D,
37
+ ]
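For orientation (not part of the committed file), a tiny illustration of how the aliases above are meant to be used in signatures: `TensorOrTupleOfTensorsGeneric` is a TypeVar, so a function annotated with it accepts either a single tensor or a tuple of tensors, and `TargetType` covers the usual target formats. The helper below is invented purely for illustration.

import torch

from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric

def batch_size_of(
    inputs: TensorOrTupleOfTensorsGeneric, target: TargetType = None
) -> int:
    # Invented helper: dimension 0 is the batch dimension by convention.
    first = inputs[0] if isinstance(inputs, tuple) else inputs
    return first.shape[0]

print(batch_size_of(torch.zeros(4, 3)))                       # 4
print(batch_size_of((torch.zeros(4, 3), torch.zeros(4, 5))))  # 4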
captum/attr/__init__.py ADDED
@@ -0,0 +1,143 @@
1
+ #!/usr/bin/env python3
2
+ from captum.attr._core.deep_lift import DeepLift, DeepLiftShap # noqa
3
+ from captum.attr._core.feature_ablation import FeatureAblation # noqa
4
+ from captum.attr._core.feature_permutation import FeaturePermutation # noqa
5
+ from captum.attr._core.gradient_shap import GradientShap # noqa
6
+ from captum.attr._core.guided_backprop_deconvnet import ( # noqa
7
+ Deconvolution,
8
+ GuidedBackprop,
9
+ )
10
+ from captum.attr._core.guided_grad_cam import GuidedGradCam # noqa
11
+ from captum.attr._core.input_x_gradient import InputXGradient # noqa
12
+ from captum.attr._core.integrated_gradients import IntegratedGradients # noqa
13
+ from captum.attr._core.kernel_shap import KernelShap # noqa
14
+ from captum.attr._core.layer.grad_cam import LayerGradCam # noqa
15
+ from captum.attr._core.layer.internal_influence import InternalInfluence # noqa
16
+ from captum.attr._core.layer.layer_activation import LayerActivation # noqa
17
+ from captum.attr._core.layer.layer_conductance import LayerConductance # noqa
18
+ from captum.attr._core.layer.layer_deep_lift import ( # noqa
19
+ LayerDeepLift,
20
+ LayerDeepLiftShap,
21
+ )
22
+ from captum.attr._core.layer.layer_feature_ablation import LayerFeatureAblation # noqa
23
+ from captum.attr._core.layer.layer_gradient_shap import LayerGradientShap # noqa
24
+ from captum.attr._core.layer.layer_gradient_x_activation import ( # noqa
25
+ LayerGradientXActivation,
26
+ )
27
+ from captum.attr._core.layer.layer_integrated_gradients import ( # noqa
28
+ LayerIntegratedGradients,
29
+ )
30
+ from captum.attr._core.layer.layer_lrp import LayerLRP # noqa
31
+ from captum.attr._core.lime import Lime, LimeBase # noqa
32
+ from captum.attr._core.lrp import LRP # noqa
33
+ from captum.attr._core.neuron.neuron_conductance import NeuronConductance # noqa
34
+ from captum.attr._core.neuron.neuron_deep_lift import ( # noqa
35
+ NeuronDeepLift,
36
+ NeuronDeepLiftShap,
37
+ )
38
+ from captum.attr._core.neuron.neuron_feature_ablation import ( # noqa
39
+ NeuronFeatureAblation,
40
+ )
41
+ from captum.attr._core.neuron.neuron_gradient import NeuronGradient # noqa
42
+ from captum.attr._core.neuron.neuron_gradient_shap import NeuronGradientShap # noqa
43
+ from captum.attr._core.neuron.neuron_guided_backprop_deconvnet import ( # noqa
44
+ NeuronDeconvolution,
45
+ NeuronGuidedBackprop,
46
+ )
47
+ from captum.attr._core.neuron.neuron_integrated_gradients import ( # noqa
48
+ NeuronIntegratedGradients,
49
+ )
50
+ from captum.attr._core.noise_tunnel import NoiseTunnel # noqa
51
+ from captum.attr._core.occlusion import Occlusion # noqa
52
+ from captum.attr._core.saliency import Saliency # noqa
53
+ from captum.attr._core.shapley_value import ShapleyValues, ShapleyValueSampling # noqa
54
+ from captum.attr._models.base import ( # noqa
55
+ configure_interpretable_embedding_layer,
56
+ InterpretableEmbeddingBase,
57
+ remove_interpretable_embedding_layer,
58
+ TokenReferenceBase,
59
+ )
60
+ from captum.attr._utils import visualization # noqa
61
+ from captum.attr._utils.attribution import ( # noqa # noqa # noqa # noqa # noqa
62
+ Attribution,
63
+ GradientAttribution,
64
+ LayerAttribution,
65
+ NeuronAttribution,
66
+ PerturbationAttribution,
67
+ )
68
+ from captum.attr._utils.class_summarizer import ClassSummarizer
69
+ from captum.attr._utils.stat import (
70
+ CommonStats,
71
+ Count,
72
+ Max,
73
+ Mean,
74
+ Min,
75
+ MSE,
76
+ StdDev,
77
+ Sum,
78
+ Var,
79
+ )
80
+ from captum.attr._utils.summarizer import Summarizer
81
+
82
+ __all__ = [
83
+ "Attribution",
84
+ "GradientAttribution",
85
+ "PerturbationAttribution",
86
+ "NeuronAttribution",
87
+ "LayerAttribution",
88
+ "IntegratedGradients",
89
+ "DeepLift",
90
+ "DeepLiftShap",
91
+ "InputXGradient",
92
+ "Saliency",
93
+ "GuidedBackprop",
94
+ "Deconvolution",
95
+ "GuidedGradCam",
96
+ "FeatureAblation",
97
+ "FeaturePermutation",
98
+ "Occlusion",
99
+ "ShapleyValueSampling",
100
+ "ShapleyValues",
101
+ "LimeBase",
102
+ "Lime",
103
+ "LRP",
104
+ "KernelShap",
105
+ "LayerConductance",
106
+ "LayerGradientXActivation",
107
+ "LayerActivation",
108
+ "LayerFeatureAblation",
109
+ "InternalInfluence",
110
+ "LayerGradCam",
111
+ "LayerDeepLift",
112
+ "LayerDeepLiftShap",
113
+ "LayerGradientShap",
114
+ "LayerIntegratedGradients",
115
+ "LayerLRP",
116
+ "NeuronConductance",
117
+ "NeuronFeatureAblation",
118
+ "NeuronGradient",
119
+ "NeuronIntegratedGradients",
120
+ "NeuronDeepLift",
121
+ "NeuronDeepLiftShap",
122
+ "NeuronGradientShap",
123
+ "NeuronDeconvolution",
124
+ "NeuronGuidedBackprop",
125
+ "NoiseTunnel",
126
+ "GradientShap",
127
+ "InterpretableEmbeddingBase",
128
+ "TokenReferenceBase",
129
+ "visualization",
130
+ "configure_interpretable_embedding_layer",
131
+ "remove_interpretable_embedding_layer",
132
+ "Summarizer",
133
+ "CommonStats",
134
+ "ClassSummarizer",
135
+ "Mean",
136
+ "StdDev",
137
+ "MSE",
138
+ "Var",
139
+ "Min",
140
+ "Max",
141
+ "Sum",
142
+ "Count",
143
+ ]
captum/attr/_core/__init__.py ADDED
File without changes
captum/attr/_core/deep_lift.py ADDED
@@ -0,0 +1,1151 @@
1
+ #!/usr/bin/env python3
2
+ import typing
3
+ import warnings
4
+ from typing import Any, Callable, cast, List, Tuple, Union
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from captum._utils.common import (
10
+ _expand_additional_forward_args,
11
+ _expand_target,
12
+ _format_additional_forward_args,
13
+ _format_baseline,
14
+ _format_output,
15
+ _format_tensor_into_tuples,
16
+ _is_tuple,
17
+ _register_backward_hook,
18
+ _run_forward,
19
+ _select_targets,
20
+ ExpansionTypes,
21
+ )
22
+ from captum._utils.gradient import (
23
+ apply_gradient_requirements,
24
+ undo_gradient_requirements,
25
+ )
26
+ from captum._utils.typing import (
27
+ BaselineType,
28
+ Literal,
29
+ TargetType,
30
+ TensorOrTupleOfTensorsGeneric,
31
+ )
32
+ from captum.attr._utils.attribution import GradientAttribution
33
+ from captum.attr._utils.common import (
34
+ _call_custom_attribution_func,
35
+ _compute_conv_delta_and_format_attrs,
36
+ _format_callable_baseline,
37
+ _tensorize_baseline,
38
+ _validate_input,
39
+ )
40
+ from captum.log import log_usage
41
+ from torch import Tensor
42
+ from torch.nn import Module
43
+ from torch.utils.hooks import RemovableHandle
44
+
45
+
46
+ # Check if module backward hook can safely be used for the module that produced
47
+ # this inputs / outputs mapping
48
+ def _check_valid_module(inputs_grad_fn, outputs) -> bool:
49
+ def is_output_cloned(output_fn, input_grad_fn) -> bool:
50
+ """
51
+ Checks if the output has been cloned. This happens especially in case of
52
+ layer deeplift.
53
+ """
54
+ return (
55
+ output_fn[0].next_functions is not None
56
+ and output_fn[0].next_functions[0][0] == input_grad_fn
57
+ )
58
+
59
+ curr_fn = outputs.grad_fn
60
+ first_next = curr_fn.next_functions[0]
61
+ try:
62
+ # if `inputs` in the input to the network then the grad_fn is None and
63
+ # for that input backward_hook isn't computed. That's the reason why we
64
+ # need to check on `inputs_grad_fns[first_next[1]]` being None.
65
+ return (
66
+ inputs_grad_fn is None
67
+ or first_next[0] == inputs_grad_fn
68
+ or is_output_cloned(first_next, inputs_grad_fn)
69
+ )
70
+ except IndexError:
71
+ return False
72
+
73
+
74
+ class DeepLift(GradientAttribution):
75
+ r"""
76
+ Implements DeepLIFT algorithm based on the following paper:
77
+ Learning Important Features Through Propagating Activation Differences,
78
+ Avanti Shrikumar, et. al.
79
+ https://arxiv.org/abs/1704.02685
80
+
81
+ and the gradient formulation proposed in:
82
+ Towards better understanding of gradient-based attribution methods for
83
+ deep neural networks, Marco Ancona, et.al.
84
+ https://openreview.net/pdf?id=Sy21R9JAW
85
+
86
+ This implementation supports only Rescale rule. RevealCancel rule will
87
+ be supported in later releases.
88
+ In addition to that, in order to keep the implementation cleaner, DeepLIFT
89
+ for internal neurons and layers extends current implementation and is
90
+ implemented separately in LayerDeepLift and NeuronDeepLift.
91
+ Although DeepLIFT's(Rescale Rule) attribution quality is comparable with
92
+ Integrated Gradients, it runs significantly faster than Integrated
93
+ Gradients and is preferred for large datasets.
94
+
95
+ Currently we only support a limited number of non-linear activations
96
+ but the plan is to expand the list in the future.
97
+
98
+ Note: Currently we cannot access the building blocks
99
+ of PyTorch's built-in LSTMs, RNNs and GRUs, such as Tanh and Sigmoid.
100
+ Nonetheless, it is possible to build custom LSTMs, RNNs and GRUs
101
+ with performance similar to built-in ones using TorchScript.
102
+ More details on how to build custom RNNs can be found here:
103
+ https://pytorch.org/blog/optimizing-cuda-rnn-with-torchscript/
104
+ """
105
+
106
+ def __init__(
107
+ self,
108
+ model: Module,
109
+ multiply_by_inputs: bool = True,
110
+ eps: float = 1e-10,
111
+ ) -> None:
112
+ r"""
113
+ Args:
114
+
115
+ model (nn.Module): The reference to PyTorch model instance. Model cannot
116
+ contain any in-place nonlinear submodules; these are not
117
+ supported by the register_full_backward_hook PyTorch API
118
+ starting from PyTorch v1.9.
119
+ multiply_by_inputs (bool, optional): Indicates whether to factor
120
+ model inputs' multiplier in the final attribution scores.
121
+ In the literature this is also known as local vs global
122
+ attribution. If inputs' multiplier isn't factored in
123
+ then that type of attribution method is also called local
124
+ attribution. If it is, then that type of attribution
125
+ method is called global.
126
+ More detailed can be found here:
127
+ https://arxiv.org/abs/1711.06104
128
+
129
+ In case of DeepLift, if `multiply_by_inputs`
130
+ is set to True, final sensitivity scores
131
+ are being multiplied by (inputs - baselines).
132
+ This flag applies only if `custom_attribution_func` is
133
+ set to None.
134
+
135
+ eps (float, optional): A value at which to consider output/input change
136
+ significant when computing the gradients for non-linear layers.
137
+ This is useful to adjust, depending on your model's bit depth,
138
+ to avoid numerical issues during the gradient computation.
139
+ Default: 1e-10
140
+ """
141
+ GradientAttribution.__init__(self, model)
142
+ self.model = model
143
+ self.eps = eps
144
+ self.forward_handles: List[RemovableHandle] = []
145
+ self.backward_handles: List[RemovableHandle] = []
146
+ self._multiply_by_inputs = multiply_by_inputs
147
+
148
+ @typing.overload
149
+ def attribute(
150
+ self,
151
+ inputs: TensorOrTupleOfTensorsGeneric,
152
+ baselines: BaselineType = None,
153
+ target: TargetType = None,
154
+ additional_forward_args: Any = None,
155
+ return_convergence_delta: Literal[False] = False,
156
+ custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
157
+ ) -> TensorOrTupleOfTensorsGeneric:
158
+ ...
159
+
160
+ @typing.overload
161
+ def attribute(
162
+ self,
163
+ inputs: TensorOrTupleOfTensorsGeneric,
164
+ baselines: BaselineType = None,
165
+ target: TargetType = None,
166
+ additional_forward_args: Any = None,
167
+ *,
168
+ return_convergence_delta: Literal[True],
169
+ custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
170
+ ) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]:
171
+ ...
172
+
173
+ @log_usage()
174
+ def attribute( # type: ignore
175
+ self,
176
+ inputs: TensorOrTupleOfTensorsGeneric,
177
+ baselines: BaselineType = None,
178
+ target: TargetType = None,
179
+ additional_forward_args: Any = None,
180
+ return_convergence_delta: bool = False,
181
+ custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
182
+ ) -> Union[
183
+ TensorOrTupleOfTensorsGeneric, Tuple[TensorOrTupleOfTensorsGeneric, Tensor]
184
+ ]:
185
+ r"""
186
+ Args:
187
+
188
+ inputs (tensor or tuple of tensors): Input for which
189
+ attributions are computed. If forward_func takes a single
190
+ tensor as input, a single input tensor should be provided.
191
+ If forward_func takes multiple tensors as input, a tuple
192
+ of the input tensors should be provided. It is assumed
193
+ that for all given input tensors, dimension 0 corresponds
194
+ to the number of examples (aka batch size), and if
195
+ multiple input tensors are provided, the examples must
196
+ be aligned appropriately.
197
+ baselines (scalar, tensor, tuple of scalars or tensors, optional):
198
+ Baselines define reference samples that are compared with
199
+ the inputs. In order to assign attribution scores DeepLift
200
+ computes the differences between the inputs/outputs and
201
+ corresponding references.
202
+ Baselines can be provided as:
203
+
204
+ - a single tensor, if inputs is a single tensor, with
205
+ exactly the same dimensions as inputs or the first
206
+ dimension is one and the remaining dimensions match
207
+ with inputs.
208
+
209
+ - a single scalar, if inputs is a single tensor, which will
210
+ be broadcasted for each input value in input tensor.
211
+
212
+ - a tuple of tensors or scalars, the baseline corresponding
213
+ to each tensor in the inputs' tuple can be:
214
+
215
+ - either a tensor with matching dimensions to
216
+ corresponding tensor in the inputs' tuple
217
+ or the first dimension is one and the remaining
218
+ dimensions match with the corresponding
219
+ input tensor.
220
+
221
+ - or a scalar, corresponding to a tensor in the
222
+ inputs' tuple. This scalar value is broadcasted
223
+ for corresponding input tensor.
224
+
225
+ In the cases when `baselines` is not provided, we internally
226
+ use zero scalar corresponding to each input tensor.
227
+
228
+ Default: None
229
+ target (int, tuple, tensor or list, optional): Output indices for
230
+ which gradients are computed (for classification cases,
231
+ this is usually the target class).
232
+ If the network returns a scalar value per example,
233
+ no target index is necessary.
234
+ For general 2D outputs, targets can be either:
235
+
236
+ - a single integer or a tensor containing a single
237
+ integer, which is applied to all input examples
238
+
239
+ - a list of integers or a 1D tensor, with length matching
240
+ the number of examples in inputs (dim 0). Each integer
241
+ is applied as the target for the corresponding example.
242
+
243
+ For outputs with > 2 dimensions, targets can be either:
244
+
245
+ - A single tuple, which contains #output_dims - 1
246
+ elements. This target index is applied to all examples.
247
+
248
+ - A list of tuples with length equal to the number of
249
+ examples in inputs (dim 0), and each tuple containing
250
+ #output_dims - 1 elements. Each tuple is applied as the
251
+ target for the corresponding example.
252
+
253
+ Default: None
254
+ additional_forward_args (any, optional): If the forward function
255
+ requires additional arguments other than the inputs for
256
+ which attributions should not be computed, this argument
257
+ can be provided. It must be either a single additional
258
+ argument of a Tensor or arbitrary (non-tuple) type or a tuple
259
+ containing multiple additional arguments including tensors
260
+ or any arbitrary python types. These arguments are provided to
261
+ forward_func in order, following the arguments in inputs.
262
+ Note that attributions are not computed with respect
263
+ to these arguments.
264
+ Default: None
265
+ return_convergence_delta (bool, optional): Indicates whether to return
266
+ convergence delta or not. If `return_convergence_delta`
267
+ is set to True convergence delta will be returned in
268
+ a tuple following attributions.
269
+ Default: False
270
+ custom_attribution_func (callable, optional): A custom function for
271
+ computing final attribution scores. This function can take
272
+ at least one and at most three arguments with the
273
+ following signature:
274
+
275
+ - custom_attribution_func(multipliers)
276
+ - custom_attribution_func(multipliers, inputs)
277
+ - custom_attribution_func(multipliers, inputs, baselines)
278
+
279
+ In case this function is not provided, we use the default
280
+ logic defined as: multipliers * (inputs - baselines)
281
+ It is assumed that all input arguments, `multipliers`,
282
+ `inputs` and `baselines` are provided in tuples of same
283
+ length. `custom_attribution_func` returns a tuple of
284
+ attribution tensors that have the same length as the
285
+ `inputs`.
286
+
287
+ Default: None
288
+
289
+ Returns:
290
+ **attributions** or 2-element tuple of **attributions**, **delta**:
291
+ - **attributions** (*tensor* or tuple of *tensors*):
292
+ Attribution score computed based on DeepLift rescale rule with respect
293
+ to each input feature. Attributions will always be
294
+ the same size as the provided inputs, with each value
295
+ providing the attribution of the corresponding input index.
296
+ If a single tensor is provided as inputs, a single tensor is
297
+ returned. If a tuple is provided for inputs, a tuple of
298
+ corresponding sized tensors is returned.
299
+ - **delta** (*tensor*, returned if return_convergence_delta=True):
300
+ This is computed using the property that
301
+ the total sum of forward_func(inputs) - forward_func(baselines)
302
+ must equal the total sum of the attributions computed
303
+ based on DeepLift's rescale rule.
304
+ Delta is calculated per example, meaning that the number of
305
+ elements in returned delta tensor is equal to the number of
306
+ of examples in input.
307
+ Note that the logic described for deltas is guaranteed when the
308
+ default logic for attribution computations is used, meaning that the
309
+ `custom_attribution_func=None`, otherwise it is not guaranteed and
310
+ depends on the specifics of the `custom_attribution_func`.
311
+
312
+ Examples::
313
+
314
+ >>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
315
+ >>> # and returns an Nx10 tensor of class probabilities.
316
+ >>> net = ImageClassifier()
317
+ >>> dl = DeepLift(net)
318
+ >>> input = torch.randn(2, 3, 32, 32, requires_grad=True)
319
+ >>> # Computes deeplift attribution scores for class 3.
320
+ >>> attribution = dl.attribute(input, target=3)
321
+ """
322
+
323
+ # Keeps track whether original input is a tuple or not before
324
+ # converting it into a tuple.
325
+ is_inputs_tuple = _is_tuple(inputs)
326
+
327
+ inputs = _format_tensor_into_tuples(inputs)
328
+ baselines = _format_baseline(baselines, inputs)
329
+
330
+ gradient_mask = apply_gradient_requirements(inputs)
331
+
332
+ _validate_input(inputs, baselines)
333
+
334
+ # set hooks for baselines
335
+ warnings.warn(
336
+ """Setting forward, backward hooks and attributes on non-linear
337
+ activations. The hooks and attributes will be removed
338
+ after the attribution is finished"""
339
+ )
340
+ baselines = _tensorize_baseline(inputs, baselines)
341
+ main_model_hooks = []
342
+ try:
343
+ main_model_hooks = self._hook_main_model()
344
+
345
+ self.model.apply(self._register_hooks)
346
+
347
+ additional_forward_args = _format_additional_forward_args(
348
+ additional_forward_args
349
+ )
350
+
351
+ expanded_target = _expand_target(
352
+ target, 2, expansion_type=ExpansionTypes.repeat
353
+ )
354
+
355
+ wrapped_forward_func = self._construct_forward_func(
356
+ self.model,
357
+ (inputs, baselines),
358
+ expanded_target,
359
+ additional_forward_args,
360
+ )
361
+ gradients = self.gradient_func(wrapped_forward_func, inputs)
362
+ if custom_attribution_func is None:
363
+ if self.multiplies_by_inputs:
364
+ attributions = tuple(
365
+ (input - baseline) * gradient
366
+ for input, baseline, gradient in zip(
367
+ inputs, baselines, gradients
368
+ )
369
+ )
370
+ else:
371
+ attributions = gradients
372
+ else:
373
+ attributions = _call_custom_attribution_func(
374
+ custom_attribution_func, gradients, inputs, baselines
375
+ )
376
+ finally:
377
+ # Even if any error is raised, remove all hooks before raising
378
+ self._remove_hooks(main_model_hooks)
379
+
380
+ undo_gradient_requirements(inputs, gradient_mask)
381
+ return _compute_conv_delta_and_format_attrs(
382
+ self,
383
+ return_convergence_delta,
384
+ attributions,
385
+ baselines,
386
+ inputs,
387
+ additional_forward_args,
388
+ target,
389
+ is_inputs_tuple,
390
+ )
391
+
392
+ def _construct_forward_func(
393
+ self,
394
+ forward_func: Callable,
395
+ inputs: Tuple,
396
+ target: TargetType = None,
397
+ additional_forward_args: Any = None,
398
+ ) -> Callable:
399
+ def forward_fn():
400
+ model_out = _run_forward(
401
+ forward_func, inputs, None, additional_forward_args
402
+ )
403
+ return _select_targets(
404
+ torch.cat((model_out[:, 0], model_out[:, 1])), target
405
+ )
406
+
407
+ if hasattr(forward_func, "device_ids"):
408
+ forward_fn.device_ids = forward_func.device_ids # type: ignore
409
+ return forward_fn
410
+
411
+ def _is_non_linear(self, module: Module) -> bool:
412
+ return type(module) in SUPPORTED_NON_LINEAR.keys()
413
+
414
+ def _forward_pre_hook_ref(
415
+ self, module: Module, inputs: Union[Tensor, Tuple[Tensor, ...]]
416
+ ) -> None:
417
+ inputs = _format_tensor_into_tuples(inputs)
418
+ module.input_ref = tuple( # type: ignore
419
+ input.clone().detach() for input in inputs
420
+ )
421
+
422
+ def _forward_pre_hook(
423
+ self, module: Module, inputs: Union[Tensor, Tuple[Tensor, ...]]
424
+ ) -> None:
425
+ """
426
+ For the modules that perform in-place operations such as ReLUs, we cannot
427
+ use inputs from forward hooks. This is because in that case inputs
428
+ and outputs are the same. We need access the inputs in pre-hooks and
429
+ set necessary hooks on inputs there.
430
+ """
431
+ inputs = _format_tensor_into_tuples(inputs)
432
+ module.input = inputs[0].clone().detach()
433
+ module.input_grad_fns = inputs[0].grad_fn # type: ignore
434
+
435
+ def tensor_backward_hook(grad):
436
+ if module.saved_grad is None:
437
+ raise RuntimeError(
438
+ """Module {} was detected as not supporting correctly module
439
+ backward hook. You should modify your hook to ignore the given
440
+ grad_inputs (recompute them by hand if needed) and save the
441
+ newly computed grad_inputs in module.saved_grad. See MaxPool1d
442
+ as an example.""".format(
443
+ module
444
+ )
445
+ )
446
+ return module.saved_grad
447
+
448
+ # the hook is set by default but it will be used only for
449
+ # failure cases and will be removed otherwise
450
+ handle = inputs[0].register_hook(tensor_backward_hook)
451
+ module.input_hook = handle
452
+
453
+ def _forward_hook(
454
+ self,
455
+ module: Module,
456
+ inputs: Union[Tensor, Tuple[Tensor, ...]],
457
+ outputs: Union[Tensor, Tuple[Tensor, ...]],
458
+ ) -> None:
459
+ r"""
460
+ we need forward hook to access and detach the inputs and
461
+ outputs of a neuron
462
+ """
463
+ outputs = _format_tensor_into_tuples(outputs)
464
+ module.output = outputs[0].clone().detach()
465
+ if not _check_valid_module(module.input_grad_fns, outputs[0]):
466
+ warnings.warn(
467
+ """An invalid module {} is detected. Saved gradients will
468
+ be used as the gradients of the module's input tensor.
469
+ See MaxPool1d as an example.""".format(
470
+ module
471
+ )
472
+ )
473
+ module.is_invalid = True # type: ignore
474
+ module.saved_grad = None # type: ignore
475
+ self.forward_handles.append(cast(RemovableHandle, module.input_hook))
476
+ else:
477
+ module.is_invalid = False # type: ignore
478
+ # removing the hook if there is no failure case
479
+ cast(RemovableHandle, module.input_hook).remove()
480
+ del module.input_hook
481
+ del module.input_grad_fns
482
+
483
+ def _backward_hook(
484
+ self,
485
+ module: Module,
486
+ grad_input: Union[Tensor, Tuple[Tensor, ...]],
487
+ grad_output: Union[Tensor, Tuple[Tensor, ...]],
488
+ ):
489
+ r"""
490
+ `grad_input` is the gradient of the neuron with respect to its input
491
+ `grad_output` is the gradient of the neuron with respect to its output
492
+ we can override `grad_input` according to chain rule with.
493
+ `grad_output` * delta_out / delta_in.
494
+
495
+ """
496
+ # before accessing the attributes from the module we want
497
+ # to ensure that the properties exist, if not, then it is
498
+ # likely that the module is being reused.
499
+ attr_criteria = self.satisfies_attribute_criteria(module)
500
+ if not attr_criteria:
501
+ raise RuntimeError(
502
+ "A Module {} was detected that does not contain some of "
503
+ "the input/output attributes that are required for DeepLift "
504
+ "computations. This can occur, for example, if "
505
+ "your module is being used more than once in the network."
506
+ "Please, ensure that module is being used only once in the "
507
+ "network.".format(module)
508
+ )
509
+ multipliers = tuple(
510
+ SUPPORTED_NON_LINEAR[type(module)](
511
+ module,
512
+ module.input,
513
+ module.output,
514
+ grad_input,
515
+ grad_output,
516
+ eps=self.eps,
517
+ )
518
+ )
519
+ # remove all the properies that we set for the inputs and output
520
+ del module.input
521
+ del module.output
522
+
523
+ return multipliers
524
+
525
+ def satisfies_attribute_criteria(self, module: Module) -> bool:
526
+ return hasattr(module, "input") and hasattr(module, "output")
527
+
528
+ def _can_register_hook(self, module: Module) -> bool:
529
+ # TODO find a better way of checking if a module is a container or not
530
+ module_fullname = str(type(module))
531
+ has_already_hooks = len(module._backward_hooks) > 0 # type: ignore
532
+ return not (
533
+ "nn.modules.container" in module_fullname
534
+ or has_already_hooks
535
+ or not self._is_non_linear(module)
536
+ )
537
+
538
+ def _register_hooks(
539
+ self, module: Module, attribute_to_layer_input: bool = True
540
+ ) -> None:
541
+ if not self._can_register_hook(module) or (
542
+ not attribute_to_layer_input and module is self.layer # type: ignore
543
+ ):
544
+ return
545
+ # adds forward hook to leaf nodes that are non-linear
546
+ forward_handle = module.register_forward_hook(self._forward_hook)
547
+ pre_forward_handle = module.register_forward_pre_hook(self._forward_pre_hook)
548
+ backward_handle = _register_backward_hook(module, self._backward_hook, self)
549
+ self.forward_handles.append(forward_handle)
550
+ self.forward_handles.append(pre_forward_handle)
551
+ self.backward_handles.append(backward_handle)
552
+
553
+ def _remove_hooks(self, extra_hooks_to_remove: List[RemovableHandle]) -> None:
554
+ for handle in extra_hooks_to_remove:
555
+ handle.remove()
556
+ for forward_handle in self.forward_handles:
557
+ forward_handle.remove()
558
+ for backward_handle in self.backward_handles:
559
+ backward_handle.remove()
560
+
561
+ def _hook_main_model(self) -> List[RemovableHandle]:
562
+ def pre_hook(module: Module, baseline_inputs_add_args: Tuple) -> Tuple:
563
+ inputs = baseline_inputs_add_args[0]
564
+ baselines = baseline_inputs_add_args[1]
565
+ additional_args = None
566
+ if len(baseline_inputs_add_args) > 2:
567
+ additional_args = baseline_inputs_add_args[2:]
568
+
569
+ baseline_input_tsr = tuple(
570
+ torch.cat([input, baseline])
571
+ for input, baseline in zip(inputs, baselines)
572
+ )
573
+ if additional_args is not None:
574
+ expanded_additional_args = cast(
575
+ Tuple,
576
+ _expand_additional_forward_args(
577
+ additional_args, 2, ExpansionTypes.repeat
578
+ ),
579
+ )
580
+ return (*baseline_input_tsr, *expanded_additional_args)
581
+ return baseline_input_tsr
582
+
583
+ def forward_hook(module: Module, inputs: Tuple, outputs: Tensor):
584
+ return torch.stack(torch.chunk(outputs, 2), dim=1)
585
+
586
+ if isinstance(
587
+ self.model, (nn.DataParallel, nn.parallel.DistributedDataParallel)
588
+ ):
589
+ return [
590
+ self.model.module.register_forward_pre_hook(pre_hook), # type: ignore
591
+ self.model.module.register_forward_hook(forward_hook),
592
+ ] # type: ignore
593
+ else:
594
+ return [
595
+ self.model.register_forward_pre_hook(pre_hook), # type: ignore
596
+ self.model.register_forward_hook(forward_hook),
597
+ ] # type: ignore
598
+
599
+ def has_convergence_delta(self) -> bool:
600
+ return True
601
+
602
+ @property
603
+ def multiplies_by_inputs(self):
604
+ return self._multiply_by_inputs
605
+
606
+
607
+ class DeepLiftShap(DeepLift):
608
+ r"""
609
+ Extends DeepLift algorithm and approximates SHAP values using Deeplift.
610
+ For each input sample it computes DeepLift attribution with respect to
611
+ each baseline and averages resulting attributions.
612
+ More details about the algorithm can be found here:
613
+
614
+ http://papers.nips.cc/paper/7062-a-unified-approach-to-interpreting-model-predictions.pdf
615
+
616
+ Note that the explanation model:
617
+ 1. Assumes that input features are independent of one another
618
+ 2. Is linear, meaning that the explanations are modeled through
619
+ the additive composition of feature effects.
620
+ Although, it assumes a linear model for each explanation, the overall
621
+ model across multiple explanations can be complex and non-linear.
622
+ """
623
+
624
+ def __init__(self, model: Module, multiply_by_inputs: bool = True) -> None:
625
+ r"""
626
+ Args:
627
+
628
+ model (nn.Module): The reference to PyTorch model instance. Model cannot
629
+ contain any in-place nonlinear submodules; these are not
630
+ supported by the register_full_backward_hook PyTorch API.
631
+ multiply_by_inputs (bool, optional): Indicates whether to factor
632
+ model inputs' multiplier in the final attribution scores.
633
+ In the literature this is also known as local vs global
634
+ attribution. If inputs' multiplier isn't factored in
635
+ then that type of attribution method is also called local
636
+ attribution. If it is, then that type of attribution
637
+ method is called global.
638
+ More detailed can be found here:
639
+ https://arxiv.org/abs/1711.06104
640
+
641
+ In case of DeepLiftShap, if `multiply_by_inputs`
642
+ is set to True, final sensitivity scores
643
+ are being multiplied by (inputs - baselines).
644
+ This flag applies only if `custom_attribution_func` is
645
+ set to None.
646
+ """
647
+ DeepLift.__init__(self, model, multiply_by_inputs=multiply_by_inputs)
648
+
649
+ # There's a mismatch between the signatures of DeepLift.attribute and
650
+ # DeepLiftShap.attribute, so we ignore typing here
651
+ @typing.overload # type: ignore
652
+ def attribute(
653
+ self,
654
+ inputs: TensorOrTupleOfTensorsGeneric,
655
+ baselines: Union[
656
+ TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric]
657
+ ],
658
+ target: TargetType = None,
659
+ additional_forward_args: Any = None,
660
+ return_convergence_delta: Literal[False] = False,
661
+ custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
662
+ ) -> TensorOrTupleOfTensorsGeneric:
663
+ ...
664
+
665
+ @typing.overload
666
+ def attribute(
667
+ self,
668
+ inputs: TensorOrTupleOfTensorsGeneric,
669
+ baselines: Union[
670
+ TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric]
671
+ ],
672
+ target: TargetType = None,
673
+ additional_forward_args: Any = None,
674
+ *,
675
+ return_convergence_delta: Literal[True],
676
+ custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
677
+ ) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]:
678
+ ...
679
+
680
+ @log_usage()
681
+ def attribute( # type: ignore
682
+ self,
683
+ inputs: TensorOrTupleOfTensorsGeneric,
684
+ baselines: Union[
685
+ TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric]
686
+ ],
687
+ target: TargetType = None,
688
+ additional_forward_args: Any = None,
689
+ return_convergence_delta: bool = False,
690
+ custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
691
+ ) -> Union[
692
+ TensorOrTupleOfTensorsGeneric, Tuple[TensorOrTupleOfTensorsGeneric, Tensor]
693
+ ]:
694
+ r"""
695
+ Args:
696
+
697
+ inputs (tensor or tuple of tensors): Input for which
698
+ attributions are computed. If forward_func takes a single
699
+ tensor as input, a single input tensor should be provided.
700
+ If forward_func takes multiple tensors as input, a tuple
701
+ of the input tensors should be provided. It is assumed
702
+ that for all given input tensors, dimension 0 corresponds
703
+ to the number of examples (aka batch size), and if
704
+ multiple input tensors are provided, the examples must
705
+ be aligned appropriately.
706
+ baselines (tensor, tuple of tensors, callable):
707
+ Baselines define reference samples that are compared with
708
+ the inputs. In order to assign attribution scores DeepLift
709
+ computes the differences between the inputs/outputs and
710
+ corresponding references. Baselines can be provided as:
711
+
712
+ - a single tensor, if inputs is a single tensor, with
713
+ the first dimension equal to the number of examples
714
+ in the baselines' distribution. The remaining dimensions
715
+ must match with input tensor's dimension starting from
716
+ the second dimension.
717
+
718
+ - a tuple of tensors, if inputs is a tuple of tensors,
719
+ with the first dimension of any tensor inside the tuple
720
+ equal to the number of examples in the baseline's
721
+ distribution. The remaining dimensions must match
722
+ the dimensions of the corresponding input tensor
723
+ starting from the second dimension.
724
+
725
+ - callable function, optionally takes `inputs` as an
726
+ argument and either returns a single tensor
727
+ or a tuple of those.
728
+
729
+ It is recommended that the number of samples in the baselines'
730
+ tensors is larger than one.
731
+ target (int, tuple, tensor or list, optional): Output indices for
732
+ which gradients are computed (for classification cases,
733
+ this is usually the target class).
734
+ If the network returns a scalar value per example,
735
+ no target index is necessary.
736
+ For general 2D outputs, targets can be either:
737
+
738
+ - a single integer or a tensor containing a single
739
+ integer, which is applied to all input examples
740
+
741
+ - a list of integers or a 1D tensor, with length matching
742
+ the number of examples in inputs (dim 0). Each integer
743
+ is applied as the target for the corresponding example.
744
+
745
+ For outputs with > 2 dimensions, targets can be either:
746
+
747
+ - A single tuple, which contains #output_dims - 1
748
+ elements. This target index is applied to all examples.
749
+
750
+ - A list of tuples with length equal to the number of
751
+ examples in inputs (dim 0), and each tuple containing
752
+ #output_dims - 1 elements. Each tuple is applied as the
753
+ target for the corresponding example.
754
+
755
+ Default: None
756
+ additional_forward_args (any, optional): If the forward function
757
+ requires additional arguments other than the inputs for
758
+ which attributions should not be computed, this argument
759
+ can be provided. It must be either a single additional
760
+ argument of a Tensor or arbitrary (non-tuple) type or a tuple
761
+ containing multiple additional arguments including tensors
762
+ or any arbitrary python types. These arguments are provided to
763
+ forward_func in order, following the arguments in inputs.
764
+ Note that attributions are not computed with respect
765
+ to these arguments.
766
+ Default: None
767
+ return_convergence_delta (bool, optional): Indicates whether to return
768
+ convergence delta or not. If `return_convergence_delta`
769
+ is set to True convergence delta will be returned in
770
+ a tuple following attributions.
771
+ Default: False
772
+ custom_attribution_func (callable, optional): A custom function for
773
+ computing final attribution scores. This function can take
774
+ at least one and at most three arguments with the
775
+ following signature:
776
+
777
+ - custom_attribution_func(multipliers)
778
+ - custom_attribution_func(multipliers, inputs)
779
+ - custom_attribution_func(multipliers, inputs, baselines)
780
+
781
+ In case this function is not provided, we use the default
782
+ logic defined as: multipliers * (inputs - baselines)
783
+ It is assumed that all input arguments, `multipliers`,
784
+ `inputs` and `baselines` are provided in tuples of same
785
+ length. `custom_attribution_func` returns a tuple of
786
+ attribution tensors that have the same length as the
787
+ `inputs`.
788
+ Default: None
789
+
790
+ Returns:
791
+ **attributions** or 2-element tuple of **attributions**, **delta**:
792
+ - **attributions** (*tensor* or tuple of *tensors*):
793
+ Attribution score computed based on DeepLift rescale rule with
794
+ respect to each input feature. Attributions will always be
795
+ the same size as the provided inputs, with each value
796
+ providing the attribution of the corresponding input index.
797
+ If a single tensor is provided as inputs, a single tensor is
798
+ returned. If a tuple is provided for inputs, a tuple of
799
+ corresponding sized tensors is returned.
800
+ - **delta** (*tensor*, returned if return_convergence_delta=True):
801
+ This is computed using the property that the
802
+ total sum of forward_func(inputs) - forward_func(baselines)
803
+ must be very close to the total sum of attributions
804
+ computed based on approximated SHAP values using
805
+ Deeplift's rescale rule.
806
+ Delta is calculated for each example input and baseline pair,
+ meaning that the number of elements in the returned delta tensor
+ is equal to
+ `number of examples in input` * `number of examples in baseline`.
+ The deltas are ordered first by input example and then by baseline.
+ Note that the logic described for deltas is guaranteed only
+ when the default attribution logic is used, i.e. when
+ `custom_attribution_func=None`; otherwise it depends on the
+ specifics of the `custom_attribution_func`.
817
+
818
+ Examples::
819
+
820
+ >>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
821
+ >>> # and returns an Nx10 tensor of class probabilities.
822
+ >>> net = ImageClassifier()
823
+ >>> dl = DeepLiftShap(net)
824
+ >>> input = torch.randn(2, 3, 32, 32, requires_grad=True)
825
+ >>> # Computes shap values using deeplift for class 3.
826
+ >>> attribution = dl.attribute(input, target=3)
827
+ """
828
+ baselines = _format_callable_baseline(baselines, inputs)
829
+
830
+ assert isinstance(baselines[0], torch.Tensor) and baselines[0].shape[0] > 1, (
831
+ "Baselines distribution has to be provided in form of a torch.Tensor"
832
+ " with more than one example but found: {}."
833
+ " If baselines are provided in shape of scalars or with a single"
834
+ " baseline example, `DeepLift`"
835
+ " approach can be used instead.".format(baselines[0])
836
+ )
837
+
838
+ # Keeps track whether original input is a tuple or not before
839
+ # converting it into a tuple.
840
+ is_inputs_tuple = _is_tuple(inputs)
841
+
842
+ inputs = _format_tensor_into_tuples(inputs)
843
+
844
+ # batch sizes
845
+ inp_bsz = inputs[0].shape[0]
846
+ base_bsz = baselines[0].shape[0]
847
+
848
+ (
849
+ exp_inp,
850
+ exp_base,
851
+ exp_tgt,
852
+ exp_addit_args,
853
+ ) = self._expand_inputs_baselines_targets(
854
+ baselines, inputs, target, additional_forward_args
855
+ )
856
+ attributions = super().attribute.__wrapped__( # type: ignore
857
+ self,
858
+ exp_inp,
859
+ exp_base,
860
+ target=exp_tgt,
861
+ additional_forward_args=exp_addit_args,
862
+ return_convergence_delta=cast(
863
+ Literal[True, False], return_convergence_delta
864
+ ),
865
+ custom_attribution_func=custom_attribution_func,
866
+ )
867
+ if return_convergence_delta:
868
+ attributions, delta = cast(Tuple[Tuple[Tensor, ...], Tensor], attributions)
869
+
870
+ attributions = tuple(
871
+ self._compute_mean_across_baselines(
872
+ inp_bsz, base_bsz, cast(Tensor, attribution)
873
+ )
874
+ for attribution in attributions
875
+ )
876
+
877
+ if return_convergence_delta:
878
+ return _format_output(is_inputs_tuple, attributions), delta
879
+ else:
880
+ return _format_output(is_inputs_tuple, attributions)
881
+
882
+ def _expand_inputs_baselines_targets(
883
+ self,
884
+ baselines: Tuple[Tensor, ...],
885
+ inputs: Tuple[Tensor, ...],
886
+ target: TargetType,
887
+ additional_forward_args: Any,
888
+ ) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...], TargetType, Any]:
889
+ inp_bsz = inputs[0].shape[0]
890
+ base_bsz = baselines[0].shape[0]
891
+
892
+ expanded_inputs = tuple(
893
+ [
894
+ input.repeat_interleave(base_bsz, dim=0).requires_grad_()
895
+ for input in inputs
896
+ ]
897
+ )
898
+ expanded_baselines = tuple(
899
+ [
900
+ baseline.repeat(
901
+ (inp_bsz,) + tuple([1] * (len(baseline.shape) - 1))
902
+ ).requires_grad_()
903
+ for baseline in baselines
904
+ ]
905
+ )
906
+ expanded_target = _expand_target(
907
+ target, base_bsz, expansion_type=ExpansionTypes.repeat_interleave
908
+ )
909
+ input_additional_args = (
910
+ _expand_additional_forward_args(
911
+ additional_forward_args,
912
+ base_bsz,
913
+ expansion_type=ExpansionTypes.repeat_interleave,
914
+ )
915
+ if additional_forward_args is not None
916
+ else None
917
+ )
918
+ return (
919
+ expanded_inputs,
920
+ expanded_baselines,
921
+ expanded_target,
922
+ input_additional_args,
923
+ )
924
+
925
+ def _compute_mean_across_baselines(
926
+ self, inp_bsz: int, base_bsz: int, attribution: Tensor
927
+ ) -> Tensor:
928
+ # Average for multiple references
929
+ attr_shape: Tuple = (inp_bsz, base_bsz)
930
+ if len(attribution.shape) > 1:
931
+ attr_shape += attribution.shape[1:]
932
+ return torch.mean(attribution.view(attr_shape), dim=1, keepdim=False)
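+
+ # A minimal illustrative sketch (not part of the upstream source): how the
+ # per-(input, baseline) attributions computed above are reduced back to one
+ # attribution row per input example, mirroring the reshape-and-mean logic of
+ # `_compute_mean_across_baselines`. The tiny shapes below are assumed for clarity.
+ def _example_mean_across_baselines_sketch() -> Tensor:
+     inp_bsz, base_bsz = 2, 3
+     # 2 * 3 = 6 attribution rows: each input example is repeated once per baseline.
+     attribution = torch.arange(6 * 4, dtype=torch.float32).view(6, 4)
+     # Group rows as (inp_bsz, base_bsz, ...) and average over the baseline axis.
+     averaged = attribution.view((inp_bsz, base_bsz) + attribution.shape[1:]).mean(dim=1)
+     assert averaged.shape == (2, 4)
+     return averaged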
933
+
934
+
935
+ def nonlinear(
936
+ module: Module,
937
+ inputs: Tensor,
938
+ outputs: Tensor,
939
+ grad_input: Tensor,
940
+ grad_output: Tensor,
941
+ eps: float = 1e-10,
942
+ ):
943
+ r"""
944
+ grad_input: (dLoss / dprev_layer_out, dLoss / wij, dLoss / bij)
945
+ grad_output: (dLoss / dlayer_out)
946
+ https://github.com/pytorch/pytorch/issues/12331
947
+ """
948
+ delta_in, delta_out = _compute_diffs(inputs, outputs)
949
+
950
+ new_grad_inp = list(grad_input)
951
+
952
+ # supported non-linear modules take only a single tensor as input, hence we access
+ # only the first element in `grad_input` and `grad_output`
954
+ new_grad_inp[0] = torch.where(
955
+ abs(delta_in) < eps, new_grad_inp[0], grad_output[0] * delta_out / delta_in
956
+ )
957
+
958
+ # If the module is invalid, save the newly computed gradients
959
+ # The original_grad_input will be overridden later in the Tensor hook
960
+ if module.is_invalid:
961
+ module.saved_grad = new_grad_inp[0]
962
+ return new_grad_inp
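+
+ # Hedged numeric sketch (illustration only, not part of the library): the rescale
+ # rule applied by `nonlinear` above replaces the local gradient with the
+ # multiplier delta_out / delta_in whenever |delta_in| >= eps. Values are made up.
+ def _example_rescale_rule_sketch() -> Tensor:
+     inp, ref = torch.tensor(2.0), torch.tensor(-1.0)
+     out, out_ref = torch.relu(inp), torch.relu(ref)
+     delta_in, delta_out = inp - ref, out - out_ref  # 3.0 and 2.0
+     grad_output = torch.tensor(1.0)
+     eps = 1e-10
+     # ReLU's local gradient at 2.0 is 1.0; the rescale multiplier is 2/3 instead.
+     return torch.where(
+         delta_in.abs() < eps, grad_output, grad_output * delta_out / delta_in
+     )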
963
+
964
+
965
+ def softmax(
966
+ module: Module,
967
+ inputs: Tensor,
968
+ outputs: Tensor,
969
+ grad_input: Tensor,
970
+ grad_output: Tensor,
971
+ eps: float = 1e-10,
972
+ ):
973
+ delta_in, delta_out = _compute_diffs(inputs, outputs)
974
+
975
+ new_grad_inp = list(grad_input)
976
+ grad_input_unnorm = torch.where(
977
+ abs(delta_in) < eps, new_grad_inp[0], grad_output[0] * delta_out / delta_in
978
+ )
979
+ # normalizing
980
+ n = grad_input[0].numel()
981
+
982
+ # updating only the first half
983
+ new_grad_inp[0] = grad_input_unnorm - grad_input_unnorm.sum() * 1 / n
984
+ return new_grad_inp
985
+
986
+
987
+ def maxpool1d(
988
+ module: Module,
989
+ inputs: Tensor,
990
+ outputs: Tensor,
991
+ grad_input: Tensor,
992
+ grad_output: Tensor,
993
+ eps: float = 1e-10,
994
+ ):
995
+ return maxpool(
996
+ module,
997
+ F.max_pool1d,
998
+ F.max_unpool1d,
999
+ inputs,
1000
+ outputs,
1001
+ grad_input,
1002
+ grad_output,
1003
+ eps=eps,
1004
+ )
1005
+
1006
+
1007
+ def maxpool2d(
1008
+ module: Module,
1009
+ inputs: Tensor,
1010
+ outputs: Tensor,
1011
+ grad_input: Tensor,
1012
+ grad_output: Tensor,
1013
+ eps: float = 1e-10,
1014
+ ):
1015
+ return maxpool(
1016
+ module,
1017
+ F.max_pool2d,
1018
+ F.max_unpool2d,
1019
+ inputs,
1020
+ outputs,
1021
+ grad_input,
1022
+ grad_output,
1023
+ eps=eps,
1024
+ )
1025
+
1026
+
1027
+ def maxpool3d(
1028
+ module: Module, inputs, outputs, grad_input, grad_output, eps: float = 1e-10
1029
+ ):
1030
+ return maxpool(
1031
+ module,
1032
+ F.max_pool3d,
1033
+ F.max_unpool3d,
1034
+ inputs,
1035
+ outputs,
1036
+ grad_input,
1037
+ grad_output,
1038
+ eps=eps,
1039
+ )
1040
+
1041
+
1042
+ def maxpool(
1043
+ module: Module,
1044
+ pool_func: Callable,
1045
+ unpool_func: Callable,
1046
+ inputs,
1047
+ outputs,
1048
+ grad_input,
1049
+ grad_output,
1050
+ eps: float = 1e-10,
1051
+ ):
1052
+ with torch.no_grad():
1053
+ input, input_ref = inputs.chunk(2)
1054
+ output, output_ref = outputs.chunk(2)
1055
+
1056
+ delta_in = input - input_ref
1057
+ delta_in = torch.cat(2 * [delta_in])
1058
+ # Extracts the cross maximum between the outputs of maxpool for the
+ # actual inputs and their corresponding references. In case the delta outputs
+ # for the references are larger, the method relies on the references and
+ # corresponding gradients to compute the multipliers and contributions.
1062
+ delta_out_xmax = torch.max(output, output_ref)
1063
+ delta_out = torch.cat([delta_out_xmax - output_ref, output - delta_out_xmax])
1064
+
1065
+ _, indices = pool_func(
1066
+ module.input,
1067
+ module.kernel_size,
1068
+ module.stride,
1069
+ module.padding,
1070
+ module.dilation,
1071
+ module.ceil_mode,
1072
+ True,
1073
+ )
1074
+ grad_output_updated = grad_output[0]
1075
+ unpool_grad_out_delta, unpool_grad_out_ref_delta = torch.chunk(
1076
+ unpool_func(
1077
+ grad_output_updated * delta_out,
1078
+ indices,
1079
+ module.kernel_size,
1080
+ module.stride,
1081
+ module.padding,
1082
+ list(cast(torch.Size, module.input.shape)),
1083
+ ),
1084
+ 2,
1085
+ )
1086
+
1087
+ unpool_grad_out_delta = unpool_grad_out_delta + unpool_grad_out_ref_delta
1088
+ unpool_grad_out_delta = torch.cat(2 * [unpool_grad_out_delta])
1089
+
1090
+ # If the module is invalid, we need to recompute the grad_input
1091
+ if module.is_invalid:
1092
+ original_grad_input = grad_input
1093
+ grad_input = (
1094
+ unpool_func(
1095
+ grad_output_updated,
1096
+ indices,
1097
+ module.kernel_size,
1098
+ module.stride,
1099
+ module.padding,
1100
+ list(cast(torch.Size, module.input.shape)),
1101
+ ),
1102
+ )
1103
+ if grad_input[0].shape != inputs.shape:
1104
+ raise AssertionError(
1105
+ "A problem occurred during maxpool modul's backward pass. "
1106
+ "The gradients with respect to inputs include only a "
1107
+ "subset of inputs. More details about this issue can "
1108
+ "be found here: "
1109
+ "https://pytorch.org/docs/stable/"
1110
+ "nn.html#torch.nn.Module.register_backward_hook "
1111
+ "This can happen for example if you attribute to the outputs of a "
1112
+ "MaxPool. As a workaround, please, attribute to the inputs of "
1113
+ "the following layer."
1114
+ )
1115
+
1116
+ new_grad_inp = torch.where(
1117
+ abs(delta_in) < eps, grad_input[0], unpool_grad_out_delta / delta_in
1118
+ )
1119
+ # If the module is invalid, save the newly computed gradients
1120
+ # The original_grad_input will be overridden later in the Tensor hook
1121
+ if module.is_invalid:
1122
+ module.saved_grad = new_grad_inp
1123
+ return original_grad_input
1124
+ else:
1125
+ return (new_grad_inp,)
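+
+ # Small sketch (not part of the module) of the "cross maximum" construction used
+ # in `maxpool` above: the output delta is split so that, when the pooled reference
+ # value is the larger one, contributions flow through the reference branch.
+ # The pooled values below are assumed for illustration.
+ def _example_maxpool_cross_maximum_sketch() -> Tensor:
+     output = torch.tensor([1.0, 5.0])      # pooled outputs for the actual inputs
+     output_ref = torch.tensor([2.0, 3.0])  # pooled outputs for the references
+     delta_out_xmax = torch.max(output, output_ref)  # [2.0, 5.0]
+     # First half: xmax - reference, second half: input - xmax -> [0., 2., -1., 0.]
+     return torch.cat([delta_out_xmax - output_ref, output - delta_out_xmax])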
1126
+
1127
+
1128
+ def _compute_diffs(inputs: Tensor, outputs: Tensor) -> Tuple[Tensor, Tensor]:
1129
+ input, input_ref = inputs.chunk(2)
1130
+ # if the model is a single non-linear module and we apply the Rescale rule on it,
+ # we might not be able to perform chunking because the output of the module is
+ # usually replaced by the model output.
1133
+ output, output_ref = outputs.chunk(2)
1134
+ delta_in = input - input_ref
1135
+ delta_out = output - output_ref
1136
+
1137
+ return torch.cat(2 * [delta_in]), torch.cat(2 * [delta_out])
1138
+
1139
+
1140
+ SUPPORTED_NON_LINEAR = {
1141
+ nn.ReLU: nonlinear,
1142
+ nn.ELU: nonlinear,
1143
+ nn.LeakyReLU: nonlinear,
1144
+ nn.Sigmoid: nonlinear,
1145
+ nn.Tanh: nonlinear,
1146
+ nn.Softplus: nonlinear,
1147
+ nn.MaxPool1d: maxpool1d,
1148
+ nn.MaxPool2d: maxpool2d,
1149
+ nn.MaxPool3d: maxpool3d,
1150
+ nn.Softmax: softmax,
1151
+ }
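+
+
+ # Hedged usage sketch: the mapping above decides which backward hook overrides the
+ # gradient of each supported non-linear layer; layers not listed keep their
+ # ordinary gradients. The tiny sequential model below is hypothetical.
+ def _example_supported_non_linear_sketch() -> list:
+     model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.MaxPool1d(2))
+     return [
+         SUPPORTED_NON_LINEAR[type(m)]
+         for m in model.modules()
+         if type(m) in SUPPORTED_NON_LINEAR
+     ]  # [nonlinear, maxpool1d] -- the ReLU and MaxPool1d layers get hooks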
captum/attr/_core/feature_ablation.py ADDED
@@ -0,0 +1,591 @@
1
+ #!/usr/bin/env python3
2
+
3
+ import math
4
+ from typing import Any, Callable, cast, Tuple, Union
5
+
6
+ import torch
7
+ from captum._utils.common import (
8
+ _expand_additional_forward_args,
9
+ _expand_target,
10
+ _format_additional_forward_args,
11
+ _format_output,
12
+ _format_tensor_into_tuples,
13
+ _is_tuple,
14
+ _run_forward,
15
+ )
16
+ from captum._utils.progress import progress
17
+ from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
18
+ from captum.attr._utils.attribution import PerturbationAttribution
19
+ from captum.attr._utils.common import _format_input_baseline
20
+ from captum.log import log_usage
21
+ from torch import dtype, Tensor
22
+
23
+
24
+ class FeatureAblation(PerturbationAttribution):
25
+ r"""
26
+ A perturbation based approach to computing attribution, involving
27
+ replacing each input feature with a given baseline / reference, and
28
+ computing the difference in output. By default, each scalar value within
29
+ each input tensor is taken as a feature and replaced independently. Passing
30
+ a feature mask, allows grouping features to be ablated together. This can
31
+ be used in cases such as images, where an entire segment or region
32
+ can be ablated, measuring the importance of the segment (feature group).
33
+ Each input scalar in the group will be given the same attribution value
34
+ equal to the change in target as a result of ablating the entire feature
35
+ group.
36
+
37
+ The forward function can either return a scalar per example or a
+ fixed-size tensor (or scalar value) for the full batch, i.e. the
+ output does not grow as the batch size increases. If the output is fixed
+ we consider this model to be an "aggregation" of the inputs. In the
+ fixed-size output mode we require `perturbations_per_eval == 1` and the
+ `feature_mask` to be either `None` or for all of them to have 1 as their
+ first dimension (i.e. the feature mask needs to be applied to all inputs).
44
+ """
45
+
46
+ def __init__(self, forward_func: Callable) -> None:
47
+ r"""
48
+ Args:
49
+
50
+ forward_func (callable): The forward function of the model or
51
+ any modification of it
52
+ """
53
+ PerturbationAttribution.__init__(self, forward_func)
54
+ self.use_weights = False
55
+
56
+ @log_usage()
57
+ def attribute(
58
+ self,
59
+ inputs: TensorOrTupleOfTensorsGeneric,
60
+ baselines: BaselineType = None,
61
+ target: TargetType = None,
62
+ additional_forward_args: Any = None,
63
+ feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
64
+ perturbations_per_eval: int = 1,
65
+ show_progress: bool = False,
66
+ **kwargs: Any,
67
+ ) -> TensorOrTupleOfTensorsGeneric:
68
+ r"""
69
+ Args:
70
+
71
+ inputs (tensor or tuple of tensors): Input for which ablation
72
+ attributions are computed. If forward_func takes a single
73
+ tensor as input, a single input tensor should be provided.
74
+ If forward_func takes multiple tensors as input, a tuple
75
+ of the input tensors should be provided. It is assumed
76
+ that for all given input tensors, dimension 0 corresponds
77
+ to the number of examples (aka batch size), and if
78
+ multiple input tensors are provided, the examples must
79
+ be aligned appropriately.
80
+ baselines (scalar, tensor, tuple of scalars or tensors, optional):
81
+ Baselines define the reference value which replaces each
82
+ feature when ablated.
83
+ Baselines can be provided as:
84
+
85
+ - a single tensor, if inputs is a single tensor, with
86
+ exactly the same dimensions as inputs or
87
+ broadcastable to match the dimensions of inputs
88
+
89
+ - a single scalar, if inputs is a single tensor, which will
90
+ be broadcasted for each input value in input tensor.
91
+
92
+ - a tuple of tensors or scalars, the baseline corresponding
93
+ to each tensor in the inputs' tuple can be:
94
+
95
+ - either a tensor with dimensions matching the
+ corresponding tensor in the inputs' tuple,
+ or a tensor whose first dimension is one and whose
+ remaining dimensions match the corresponding
+ input tensor.
100
+
101
+ - or a scalar, corresponding to a tensor in the
102
+ inputs' tuple. This scalar value is broadcasted
103
+ for corresponding input tensor.
104
+ In the cases when `baselines` is not provided, we internally
105
+ use a zero scalar corresponding to each input tensor.
106
+ Default: None
107
+ target (int, tuple, tensor or list, optional): Output indices for
108
+ which gradients are computed (for classification cases,
109
+ this is usually the target class).
110
+ If the network returns a scalar value per example,
111
+ no target index is necessary.
112
+ For general 2D outputs, targets can be either:
113
+
114
+ - a single integer or a tensor containing a single
115
+ integer, which is applied to all input examples
116
+
117
+ - a list of integers or a 1D tensor, with length matching
118
+ the number of examples in inputs (dim 0). Each integer
119
+ is applied as the target for the corresponding example.
120
+
121
+ For outputs with > 2 dimensions, targets can be either:
122
+
123
+ - A single tuple, which contains #output_dims - 1
124
+ elements. This target index is applied to all examples.
125
+
126
+ - A list of tuples with length equal to the number of
127
+ examples in inputs (dim 0), and each tuple containing
128
+ #output_dims - 1 elements. Each tuple is applied as the
129
+ target for the corresponding example.
130
+
131
+ Default: None
132
+ additional_forward_args (any, optional): If the forward function
133
+ requires additional arguments other than the inputs for
134
+ which attributions should not be computed, this argument
135
+ can be provided. It must be either a single additional
136
+ argument of a Tensor or arbitrary (non-tuple) type or a
137
+ tuple containing multiple additional arguments including
138
+ tensors or any arbitrary python types. These arguments
139
+ are provided to forward_func in order following the
140
+ arguments in inputs.
141
+ For a tensor, the first dimension of the tensor must
142
+ correspond to the number of examples. For all other types,
143
+ the given argument is used for all forward evaluations.
144
+ Note that attributions are not computed with respect
145
+ to these arguments.
146
+ Default: None
147
+ feature_mask (tensor or tuple of tensors, optional):
148
+ feature_mask defines a mask for the input, grouping
149
+ features which should be ablated together. feature_mask
150
+ should contain the same number of tensors as inputs.
151
+ Each tensor should
152
+ be the same size as the corresponding input or
153
+ broadcastable to match the input tensor. Each tensor
154
+ should contain integers in the range 0 to num_features
155
+ - 1, and indices corresponding to the same feature should
156
+ have the same value.
157
+ Note that features within each input tensor are ablated
158
+ independently (not across tensors).
159
+ If the forward function returns a single scalar per batch,
160
+ we enforce that the first dimension of each mask must be 1,
161
+ since attributions are returned batch-wise rather than per
162
+ example, so the attributions must correspond to the
163
+ same features (indices) in each input example.
164
+ If None, then a feature mask is constructed which assigns
165
+ each scalar within a tensor as a separate feature, which
166
+ is ablated independently.
167
+ Default: None
168
+ perturbations_per_eval (int, optional): Allows ablation of multiple
169
+ features to be processed simultaneously in one call to
170
+ forward_fn.
171
+ Each forward pass will contain a maximum of
172
+ perturbations_per_eval * #examples samples.
173
+ For DataParallel models, each batch is split among the
174
+ available devices, so evaluations on each available
175
+ device contain at most
176
+ (perturbations_per_eval * #examples) / num_devices
177
+ samples.
178
+ If the forward function's number of outputs does not
179
+ change as the batch size grows (e.g. if it outputs a
180
+ scalar value), you must set perturbations_per_eval to 1
181
+ and use a single feature mask to describe the features
182
+ for all examples in the batch.
183
+ Default: 1
184
+ show_progress (bool, optional): Displays the progress of computation.
185
+ It will try to use tqdm if available for advanced features
186
+ (e.g. time estimation). Otherwise, it will fall back to
187
+ a simple output of progress.
188
+ Default: False
189
+ **kwargs (Any, optional): Any additional arguments used by child
190
+ classes of FeatureAblation (such as Occlusion) to construct
191
+ ablations. These arguments are ignored when using
192
+ FeatureAblation directly.
193
+ Default: None
194
+
195
+ Returns:
196
+ *tensor* or tuple of *tensors* of **attributions**:
197
+ - **attributions** (*tensor* or tuple of *tensors*):
198
+ The attributions with respect to each input feature.
199
+ If the forward function returns
200
+ a scalar value per example, attributions will be
201
+ the same size as the provided inputs, with each value
202
+ providing the attribution of the corresponding input index.
203
+ If the forward function returns a scalar per batch, then
204
+ attribution tensor(s) will have first dimension 1 and
205
+ the remaining dimensions will match the input.
206
+ If a single tensor is provided as inputs, a single tensor is
207
+ returned. If a tuple of tensors is provided for inputs, a
208
+ tuple of corresponding sized tensors is returned.
209
+
210
+
211
+ Examples::
212
+
213
+ >>> # SimpleClassifier takes a single input tensor of size Nx4x4,
214
+ >>> # and returns an Nx3 tensor of class probabilities.
215
+ >>> net = SimpleClassifier()
216
+ >>> # Generating random input with size 2 x 4 x 4
217
+ >>> input = torch.randn(2, 4, 4)
218
+ >>> # Defining FeatureAblation interpreter
219
+ >>> ablator = FeatureAblation(net)
220
+ >>> # Computes ablation attribution, ablating each of the 16
221
+ >>> # scalar input independently.
222
+ >>> attr = ablator.attribute(input, target=1)
223
+
224
+ >>> # Alternatively, we may want to ablate features in groups, e.g.
225
+ >>> # grouping each 2x2 square of the inputs and ablating them together.
226
+ >>> # This can be done by creating a feature mask as follows, which
227
+ >>> # defines the feature groups, e.g.:
228
+ >>> # +---+---+---+---+
229
+ >>> # | 0 | 0 | 1 | 1 |
230
+ >>> # +---+---+---+---+
231
+ >>> # | 0 | 0 | 1 | 1 |
232
+ >>> # +---+---+---+---+
233
+ >>> # | 2 | 2 | 3 | 3 |
234
+ >>> # +---+---+---+---+
235
+ >>> # | 2 | 2 | 3 | 3 |
236
+ >>> # +---+---+---+---+
237
+ >>> # With this mask, all inputs with the same value are ablated
238
+ >>> # simultaneously, and the attribution for each input in the same
239
+ >>> # group (0, 1, 2, and 3) per example are the same.
240
+ >>> # The attributions can be calculated as follows:
241
+ >>> # feature mask has dimensions 1 x 4 x 4
242
+ >>> feature_mask = torch.tensor([[[0,0,1,1],[0,0,1,1],
243
+ >>> [2,2,3,3],[2,2,3,3]]])
244
+ >>> attr = ablator.attribute(input, target=1, feature_mask=feature_mask)
245
+ """
246
+ # Keeps track whether original input is a tuple or not before
247
+ # converting it into a tuple.
248
+ is_inputs_tuple = _is_tuple(inputs)
249
+ inputs, baselines = _format_input_baseline(inputs, baselines)
250
+ additional_forward_args = _format_additional_forward_args(
251
+ additional_forward_args
252
+ )
253
+ num_examples = inputs[0].shape[0]
254
+ feature_mask = (
255
+ _format_tensor_into_tuples(feature_mask)
256
+ if feature_mask is not None
257
+ else None
258
+ )
259
+ assert (
260
+ isinstance(perturbations_per_eval, int) and perturbations_per_eval >= 1
261
+ ), "Perturbations per evaluation must be an integer and at least 1."
262
+ with torch.no_grad():
263
+ if show_progress:
264
+ feature_counts = self._get_feature_counts(
265
+ inputs, feature_mask, **kwargs
266
+ )
267
+ total_forwards = (
268
+ sum(
269
+ math.ceil(count / perturbations_per_eval)
270
+ for count in feature_counts
271
+ )
272
+ + 1
273
+ ) # add 1 for the initial eval
274
+ attr_progress = progress(
275
+ desc=f"{self.get_name()} attribution", total=total_forwards
276
+ )
277
+ attr_progress.update(0)
278
+
279
+ # Computes initial evaluation with all features, which is compared
280
+ # to each ablated result.
281
+ initial_eval = _run_forward(
282
+ self.forward_func, inputs, target, additional_forward_args
283
+ )
284
+
285
+ if show_progress:
286
+ attr_progress.update()
287
+
288
+ agg_output_mode = FeatureAblation._find_output_mode(
289
+ perturbations_per_eval, feature_mask
290
+ )
291
+
292
+ # get as a 2D tensor (if it is not a scalar)
293
+ if isinstance(initial_eval, torch.Tensor):
294
+ initial_eval = initial_eval.reshape(1, -1)
295
+ num_outputs = initial_eval.shape[1]
296
+ else:
297
+ num_outputs = 1
298
+
299
+ if not agg_output_mode:
300
+ assert (
301
+ isinstance(initial_eval, torch.Tensor)
302
+ and num_outputs == num_examples
303
+ ), (
304
+ "expected output of `forward_func` to have "
305
+ + "`batch_size` elements for perturbations_per_eval > 1 "
306
+ + "and all feature_mask.shape[0] > 1"
307
+ )
308
+
309
+ # Initialize attribution totals and counts
310
+ attrib_type = cast(
311
+ dtype,
312
+ initial_eval.dtype
313
+ if isinstance(initial_eval, Tensor)
314
+ else type(initial_eval),
315
+ )
316
+
317
+ total_attrib = [
318
+ torch.zeros(
319
+ (num_outputs,) + input.shape[1:],
320
+ dtype=attrib_type,
321
+ device=input.device,
322
+ )
323
+ for input in inputs
324
+ ]
325
+
326
+ # Weights are used in cases where ablations may be overlapping.
327
+ if self.use_weights:
328
+ weights = [
329
+ torch.zeros(
330
+ (num_outputs,) + input.shape[1:], device=input.device
331
+ ).float()
332
+ for input in inputs
333
+ ]
334
+
335
+ # Iterate through each feature tensor for ablation
336
+ for i in range(len(inputs)):
337
+ # Skip any empty input tensors
338
+ if torch.numel(inputs[i]) == 0:
339
+ continue
340
+
341
+ for (
342
+ current_inputs,
343
+ current_add_args,
344
+ current_target,
345
+ current_mask,
346
+ ) in self._ith_input_ablation_generator(
347
+ i,
348
+ inputs,
349
+ additional_forward_args,
350
+ target,
351
+ baselines,
352
+ feature_mask,
353
+ perturbations_per_eval,
354
+ **kwargs,
355
+ ):
356
+ # modified_eval dimensions: 1D tensor with length
357
+ # equal to #num_examples * #features in batch
358
+ modified_eval = _run_forward(
359
+ self.forward_func,
360
+ current_inputs,
361
+ current_target,
362
+ current_add_args,
363
+ )
364
+
365
+ if show_progress:
366
+ attr_progress.update()
367
+
368
+ # eval_diff dimensions: (#features in batch, #num_examples, 1, .., 1)
+ # (contains 1 more dimension than inputs). This adds extra
369
+ # dimensions of 1 to make the tensor broadcastable with the inputs
370
+ # tensor.
371
+ if not isinstance(modified_eval, torch.Tensor):
372
+ eval_diff = initial_eval - modified_eval
373
+ else:
374
+ if not agg_output_mode:
375
+ assert (
376
+ modified_eval.numel() == current_inputs[0].shape[0]
377
+ ), """expected output of forward_func to grow with
378
+ batch_size. If this is not the case for your model
379
+ please set perturbations_per_eval = 1"""
380
+
381
+ eval_diff = (
382
+ initial_eval - modified_eval.reshape((-1, num_outputs))
383
+ ).reshape((-1, num_outputs) + (len(inputs[i].shape) - 1) * (1,))
384
+ eval_diff = eval_diff.to(total_attrib[i].device)
385
+ if self.use_weights:
386
+ weights[i] += current_mask.float().sum(dim=0)
387
+ total_attrib[i] += (eval_diff * current_mask.to(attrib_type)).sum(
388
+ dim=0
389
+ )
390
+
391
+ if show_progress:
392
+ attr_progress.close()
393
+
394
+ # Divide total attributions by counts and return formatted attributions
395
+ if self.use_weights:
396
+ attrib = tuple(
397
+ single_attrib.float() / weight
398
+ for single_attrib, weight in zip(total_attrib, weights)
399
+ )
400
+ else:
401
+ attrib = tuple(total_attrib)
402
+ _result = _format_output(is_inputs_tuple, attrib)
403
+ return _result
404
+
405
+ def _ith_input_ablation_generator(
406
+ self,
407
+ i,
408
+ inputs,
409
+ additional_args,
410
+ target,
411
+ baselines,
412
+ input_mask,
413
+ perturbations_per_eval,
414
+ **kwargs,
415
+ ):
416
+ """
417
+ This method returns a generator of ablation perturbations of the i-th input.
418
+
419
+ Returns:
420
+ ablation_iter (generator): yields each perturbation to be evaluated
421
+ as a tuple (inputs, additional_forward_args, targets, mask).
422
+ """
423
+ extra_args = {}
424
+ for key, value in kwargs.items():
425
+ # For any tuple argument in kwargs, we choose index i of the tuple.
426
+ if isinstance(value, tuple):
427
+ extra_args[key] = value[i]
428
+ else:
429
+ extra_args[key] = value
430
+
431
+ input_mask = input_mask[i] if input_mask is not None else None
432
+ min_feature, num_features, input_mask = self._get_feature_range_and_mask(
433
+ inputs[i], input_mask, **extra_args
434
+ )
435
+ num_examples = inputs[0].shape[0]
436
+ perturbations_per_eval = min(perturbations_per_eval, num_features)
437
+ baseline = baselines[i] if isinstance(baselines, tuple) else baselines
438
+ if isinstance(baseline, torch.Tensor):
439
+ baseline = baseline.reshape((1,) + baseline.shape)
440
+
441
+ if perturbations_per_eval > 1:
442
+ # Repeat features and additional args for batch size.
443
+ all_features_repeated = [
444
+ torch.cat([inputs[j]] * perturbations_per_eval, dim=0)
445
+ for j in range(len(inputs))
446
+ ]
447
+ additional_args_repeated = (
448
+ _expand_additional_forward_args(additional_args, perturbations_per_eval)
449
+ if additional_args is not None
450
+ else None
451
+ )
452
+ target_repeated = _expand_target(target, perturbations_per_eval)
453
+ else:
454
+ all_features_repeated = list(inputs)
455
+ additional_args_repeated = additional_args
456
+ target_repeated = target
457
+
458
+ num_features_processed = min_feature
459
+ while num_features_processed < num_features:
460
+ current_num_ablated_features = min(
461
+ perturbations_per_eval, num_features - num_features_processed
462
+ )
463
+
464
+ # Store appropriate inputs and additional args based on batch size.
465
+ if current_num_ablated_features != perturbations_per_eval:
466
+ current_features = [
467
+ feature_repeated[0 : current_num_ablated_features * num_examples]
468
+ for feature_repeated in all_features_repeated
469
+ ]
470
+ current_additional_args = (
471
+ _expand_additional_forward_args(
472
+ additional_args, current_num_ablated_features
473
+ )
474
+ if additional_args is not None
475
+ else None
476
+ )
477
+ current_target = _expand_target(target, current_num_ablated_features)
478
+ else:
479
+ current_features = all_features_repeated
480
+ current_additional_args = additional_args_repeated
481
+ current_target = target_repeated
482
+
483
+ # Store existing tensor before modifying
484
+ original_tensor = current_features[i]
485
+ # Construct ablated batch for features in range num_features_processed
486
+ # to num_features_processed + current_num_ablated_features and return
487
+ # mask with same size as ablated batch. ablated_features has dimension
488
+ # (current_num_ablated_features, num_examples, inputs[i].shape[1:])
489
+ # Note that in the case of sparse tensors, the second dimension
490
+ # may not necessarily be num_examples and will match the first
491
+ # dimension of this tensor.
492
+ current_reshaped = current_features[i].reshape(
493
+ (current_num_ablated_features, -1) + current_features[i].shape[1:]
494
+ )
495
+
496
+ ablated_features, current_mask = self._construct_ablated_input(
497
+ current_reshaped,
498
+ input_mask,
499
+ baseline,
500
+ num_features_processed,
501
+ num_features_processed + current_num_ablated_features,
502
+ **extra_args,
503
+ )
504
+
505
+ # current_features[i] has dimension
506
+ # (current_num_ablated_features * num_examples, inputs[i].shape[1:]),
507
+ # which can be provided to the model as input.
508
+ current_features[i] = ablated_features.reshape(
509
+ (-1,) + ablated_features.shape[2:]
510
+ )
511
+ yield tuple(
512
+ current_features
513
+ ), current_additional_args, current_target, current_mask
514
+ # Replace existing tensor at index i.
515
+ current_features[i] = original_tensor
516
+ num_features_processed += current_num_ablated_features
517
+
518
+ def _construct_ablated_input(
519
+ self, expanded_input, input_mask, baseline, start_feature, end_feature, **kwargs
520
+ ):
521
+ r"""
522
+ Ablates given expanded_input tensor with given feature mask, feature range,
523
+ and baselines. expanded_input shape is (`num_features`, `num_examples`, ...)
524
+ with remaining dimensions corresponding to remaining original tensor
525
+ dimensions and `num_features` = `end_feature` - `start_feature`.
526
+ input_mask has same number of dimensions as original input tensor (one less
527
+ than `expanded_input`), and can have first dimension either 1, applying same
528
+ feature mask to all examples, or `num_examples`. baseline is expected to
529
+ be broadcastable to match `expanded_input`.
530
+
531
+ This method returns the ablated input tensor, which has the same
532
+ dimensionality as `expanded_input` as well as the corresponding mask with
533
+ either the same dimensionality as `expanded_input` or second dimension
534
+ being 1. This mask contains 1s in locations which have been ablated (and
535
+ thus counted towards ablations for that feature) and 0s otherwise.
536
+ """
537
+ current_mask = torch.stack(
538
+ [input_mask == j for j in range(start_feature, end_feature)], dim=0
539
+ ).long()
540
+ ablated_tensor = (
541
+ expanded_input * (1 - current_mask).to(expanded_input.dtype)
542
+ ) + (baseline * current_mask.to(expanded_input.dtype))
543
+ return ablated_tensor, current_mask
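+
+     # Illustrative sketch (assumed toy shapes, not part of the upstream class):
+     # ablating feature group 0 of a 2-example input against a zero baseline with
+     # the formula above, input * (1 - mask) + baseline * mask.
+     @staticmethod
+     def _example_construct_ablated_input_sketch() -> Tensor:
+         expanded_input = torch.arange(6, dtype=torch.float32).view(1, 2, 3)
+         input_mask = torch.tensor([[0, 0, 1], [1, 2, 2]])  # three feature groups
+         current_mask = torch.stack([input_mask == 0], dim=0).long()
+         baseline = 0.0
+         ablated = expanded_input * (1 - current_mask).float() + baseline * current_mask.float()
+         return ablated  # positions labelled 0 are zeroed; the rest stay unchanged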
544
+
545
+ def _get_feature_range_and_mask(self, input, input_mask, **kwargs):
546
+ if input_mask is None:
547
+ # Obtain feature mask for selected input tensor, matches size of
548
+ # 1 input example, (1 x inputs[i].shape[1:])
549
+ input_mask = torch.reshape(
550
+ torch.arange(torch.numel(input[0]), device=input.device),
551
+ input[0:1].shape,
552
+ ).long()
553
+ return (
554
+ torch.min(input_mask).item(),
555
+ torch.max(input_mask).item() + 1,
556
+ input_mask,
557
+ )
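+
+     # Sketch (illustration only): the default mask built above gives every scalar of
+     # one example its own feature id, shared across the batch. A 2 x 3 input is assumed.
+     @staticmethod
+     def _example_default_feature_mask_sketch() -> Tensor:
+         inp = torch.zeros(2, 3)  # 2 examples, 3 scalars each
+         mask = torch.reshape(torch.arange(torch.numel(inp[0])), inp[0:1].shape).long()
+         return mask  # tensor([[0, 1, 2]]): three independent features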
558
+
559
+ def _get_feature_counts(self, inputs, feature_mask, **kwargs):
560
+ """return the numbers of input features"""
561
+ if not feature_mask:
562
+ return tuple(inp[0].numel() if inp.numel() else 0 for inp in inputs)
563
+
564
+ return tuple(
565
+ (mask.max() - mask.min()).item() + 1
566
+ if mask is not None
567
+ else (inp[0].numel() if inp.numel() else 0)
568
+ for inp, mask in zip(inputs, feature_mask)
569
+ )
570
+
571
+ @staticmethod
572
+ def _find_output_mode(
573
+ perturbations_per_eval: int,
574
+ feature_mask: Union[None, TensorOrTupleOfTensorsGeneric],
575
+ ) -> bool:
576
+ """
577
+ Returns True if the output mode is "aggregation output mode"
578
+
579
+ Aggregation output mode is defined as the case when there is no 1:1
+ correspondence between `num_examples` (`batch_size`) and the number of
+ outputs your model produces, i.e. the model output does not grow in size as the input becomes
582
+ larger.
583
+
584
+ We assume this is the case if `perturbations_per_eval == 1`
585
+ and your feature mask is None or is associated with all
586
+ examples in a batch (fm.shape[0] == 1 for all fm in feature_mask).
587
+ """
588
+ return perturbations_per_eval == 1 and (
589
+ feature_mask is None
590
+ or all(len(sm.shape) == 0 or sm.shape[0] == 1 for sm in feature_mask)
591
+ )
captum/attr/_core/feature_permutation.py ADDED
@@ -0,0 +1,305 @@
1
+ #!/usr/bin/env python3
2
+ from typing import Any, Callable, Tuple, Union
3
+
4
+ import torch
5
+ from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric
6
+ from captum.attr._core.feature_ablation import FeatureAblation
7
+ from captum.log import log_usage
8
+ from torch import Tensor
9
+
10
+
11
+ def _permute_feature(x: Tensor, feature_mask: Tensor) -> Tensor:
12
+ n = x.size(0)
13
+ assert n > 1, "cannot permute features with batch_size = 1"
14
+
15
+ perm = torch.randperm(n)
16
+ no_perm = torch.arange(n)
17
+ while (perm == no_perm).all():
18
+ perm = torch.randperm(n)
19
+
20
+ return (x[perm] * feature_mask.to(dtype=x.dtype)) + (
21
+ x * feature_mask.bitwise_not().to(dtype=x.dtype)
22
+ )
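+
+
+ # Small sketch (illustration only): `_permute_feature` above shuffles only the
+ # masked positions across the batch and leaves the rest untouched. The 3 x 2
+ # input below is assumed.
+ def _example_permute_feature_sketch() -> Tensor:
+     torch.manual_seed(0)
+     x = torch.tensor([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]])
+     mask = torch.tensor([True, False])  # shuffle only the first column
+     permuted = _permute_feature(x, mask)
+     assert torch.equal(permuted[:, 1], x[:, 1])  # unmasked column unchanged
+     assert sorted(permuted[:, 0].tolist()) == [1.0, 2.0, 3.0]  # same values, new order
+     return permuted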
23
+
24
+
25
+ class FeaturePermutation(FeatureAblation):
26
+ r"""
27
+ A perturbation based approach to compute attribution, which
28
+ takes each input feature, permutes the feature values within a batch,
29
+ and computes the difference between original and shuffled outputs for
30
+ the given batch. This difference signifies the feature importance
31
+ for the permuted feature.
32
+
33
+ Example pseudocode for the algorithm is as follows::
34
+
35
+ perm_feature_importance(batch):
36
+ importance = dict()
37
+ baseline_error = error_metric(model(batch), batch_labels)
38
+ for each feature:
39
+ permute this feature across the batch
40
+ error = error_metric(model(permuted_batch), batch_labels)
41
+ importance[feature] = baseline_error - error
42
+ "un-permute" the feature across the batch
43
+
44
+ return importance
45
+
46
+ It should be noted that the `error_metric` must be called in the
47
+ `forward_func`. You do not need to have an error metric, e.g. you
48
+ could simply return the logits (the model output), but this may or may
49
+ not provide a meaningful attribution.
50
+
51
+ This method, unlike other attribution methods, requires a batch
52
+ of examples to compute attributions and cannot be performed on a single example.
53
+
54
+ By default, each scalar value within
55
+ each input tensor is taken as a feature and shuffled independently. Passing
56
+ a feature mask allows features to be grouped and shuffled together.
57
+ Each input scalar in the group will be given the same attribution value
58
+ equal to the change in target as a result of shuffling the entire feature
59
+ group.
60
+
61
+ The forward function can either return a scalar per example, or a single
62
+ scalar for the full batch. If a single scalar is returned for the batch,
63
+ `perturbations_per_eval` must be 1, and the returned attributions will have
64
+ first dimension 1, corresponding to feature importance across all
65
+ examples in the batch.
66
+
67
+ More information can be found in the permutation feature
68
+ importance algorithm description here:
69
+ https://christophm.github.io/interpretable-ml-book/feature-importance.html
70
+ """
71
+
72
+ def __init__(
73
+ self, forward_func: Callable, perm_func: Callable = _permute_feature
74
+ ) -> None:
75
+ r"""
76
+ Args:
77
+
78
+ forward_func (callable): The forward function of the model or
79
+ any modification of it
80
+ perm_func (callable, optional): A function that accepts a batch of
81
+ inputs and a feature mask, and "permutes" the feature using
82
+ feature mask across the batch. This defaults to a function
83
+ which applies a random permutation, this argument only needs
84
+ to be provided if a custom permutation behavior is desired.
85
+ Default: `_permute_feature`
86
+ """
87
+ FeatureAblation.__init__(self, forward_func=forward_func)
88
+ self.perm_func = perm_func
89
+
90
+ # suppressing error caused by the child class not having a matching
91
+ # signature to the parent
92
+ @log_usage()
93
+ def attribute( # type: ignore
94
+ self,
95
+ inputs: TensorOrTupleOfTensorsGeneric,
96
+ target: TargetType = None,
97
+ additional_forward_args: Any = None,
98
+ feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None,
99
+ perturbations_per_eval: int = 1,
100
+ show_progress: bool = False,
101
+ **kwargs: Any,
102
+ ) -> TensorOrTupleOfTensorsGeneric:
103
+ r"""
104
+ This function is almost equivalent to `FeatureAblation.attribute`. The
105
+ main difference is the way ablated examples are generated. Specifically
106
+ they are generated through the `perm_func`, as we set the baselines for
107
+ `FeatureAblation.attribute` to None.
108
+
109
+
110
+ Args:
111
+ inputs (tensor or tuple of tensors): Input for which
112
+ permutation attributions are computed. If
113
+ forward_func takes a single tensor as input, a
114
+ single input tensor should be provided. If
115
+ forward_func takes multiple tensors as input, a
116
+ tuple of the input tensors should be provided. It is
117
+ assumed that for all given input tensors, dimension
118
+ 0 corresponds to the number of examples (aka batch
119
+ size), and if multiple input tensors are provided,
120
+ the examples must be aligned appropriately.
121
+ target (int, tuple, tensor or list, optional): Output indices for
122
+ which difference is computed (for classification cases,
123
+ this is usually the target class).
124
+ If the network returns a scalar value per example,
125
+ no target index is necessary.
126
+ For general 2D outputs, targets can be either:
127
+
128
+ - a single integer or a tensor containing a single
129
+ integer, which is applied to all input examples
130
+
131
+ - a list of integers or a 1D tensor, with length matching
132
+ the number of examples in inputs (dim 0). Each integer
133
+ is applied as the target for the corresponding example.
134
+
135
+ For outputs with > 2 dimensions, targets can be either:
136
+
137
+ - A single tuple, which contains #output_dims - 1
138
+ elements. This target index is applied to all examples.
139
+
140
+ - A list of tuples with length equal to the number of
141
+ examples in inputs (dim 0), and each tuple containing
142
+ #output_dims - 1 elements. Each tuple is applied as the
143
+ target for the corresponding example.
144
+
145
+ Default: None
146
+ additional_forward_args (any, optional): If the forward function
147
+ requires additional arguments other than the inputs for
148
+ which attributions should not be computed, this argument
149
+ can be provided. It must be either a single additional
150
+ argument of a Tensor or arbitrary (non-tuple) type or a
151
+ tuple containing multiple additional arguments including
152
+ tensors or any arbitrary python types. These arguments
153
+ are provided to forward_func in order following the
154
+ arguments in inputs.
155
+ For a tensor, the first dimension of the tensor must
156
+ correspond to the number of examples. For all other types,
157
+ the given argument is used for all forward evaluations.
158
+ Note that attributions are not computed with respect
159
+ to these arguments.
160
+ Default: None
161
+ feature_mask (tensor or tuple of tensors, optional):
162
+ feature_mask defines a mask for the input, grouping
163
+ features which should be ablated together. feature_mask
164
+ should contain the same number of tensors as inputs.
165
+ Each tensor should be the same size as the
166
+ corresponding input or broadcastable to match the
167
+ input tensor. Each tensor should contain integers in
168
+ the range 0 to num_features - 1, and indices
169
+ corresponding to the same feature should have the
170
+ same value. Note that features within each input
171
+ tensor are ablated independently (not across
172
+ tensors).
173
+
174
+ The first dimension of each mask must be 1, as we require
175
+ to have the same group of features for each input sample.
176
+
177
+ If None, then a feature mask is constructed which assigns
178
+ each scalar within a tensor as a separate feature, which
179
+ is permuted independently.
180
+ Default: None
181
+ perturbations_per_eval (int, optional): Allows permutations
182
+ of multiple features to be processed simultaneously
183
+ in one call to forward_fn. Each forward pass will
184
+ contain a maximum of perturbations_per_eval * #examples
185
+ samples. For DataParallel models, each batch is
186
+ split among the available devices, so evaluations on
187
+ each available device contain at most
188
+ (perturbations_per_eval * #examples) / num_devices
189
+ samples.
190
+ If the forward function returns a single scalar per batch,
191
+ perturbations_per_eval must be set to 1.
192
+ Default: 1
193
+ show_progress (bool, optional): Displays the progress of computation.
194
+ It will try to use tqdm if available for advanced features
195
+ (e.g. time estimation). Otherwise, it will fall back to
196
+ a simple output of progress.
197
+ Default: False
198
+ **kwargs (Any, optional): Any additional arguments used by child
199
+ classes of FeatureAblation (such as Occlusion) to construct
200
+ ablations. These arguments are ignored when using
201
+ FeatureAblation directly.
202
+ Default: None
203
+
204
+ Returns:
205
+ *tensor* or tuple of *tensors* of **attributions**:
206
+ - **attributions** (*tensor* or tuple of *tensors*):
207
+ The attributions with respect to each input feature.
208
+ If the forward function returns
209
+ a scalar value per example, attributions will be
210
+ the same size as the provided inputs, with each value
211
+ providing the attribution of the corresponding input index.
212
+ If the forward function returns a scalar per batch, then
213
+ attribution tensor(s) will have first dimension 1 and
214
+ the remaining dimensions will match the input.
215
+ If a single tensor is provided as inputs, a single tensor is
216
+ returned. If a tuple of tensors is provided for inputs,
217
+ a tuple of corresponding sized tensors is returned.
218
+
219
+
220
+ Examples::
221
+
222
+ >>> # SimpleClassifier takes a single input tensor of size Nx4x4,
223
+ >>> # and returns an Nx3 tensor of class probabilities.
224
+ >>> net = SimpleClassifier()
225
+ >>> # Generating random input with size 10 x 4 x 4
226
+ >>> input = torch.randn(10, 4, 4)
227
+ >>> # Defining FeaturePermutation interpreter
228
+ >>> feature_perm = FeaturePermutation(net)
229
+ >>> # Computes permutation attribution, shuffling each of the 16
230
+ >>> # scalar input independently.
231
+ >>> attr = feature_perm.attribute(input, target=1)
232
+
233
+ >>> # Alternatively, we may want to permute features in groups, e.g.
234
+ >>> # grouping each 2x2 square of the inputs and shuffling them together.
235
+ >>> # This can be done by creating a feature mask as follows, which
236
+ >>> # defines the feature groups, e.g.:
237
+ >>> # +---+---+---+---+
238
+ >>> # | 0 | 0 | 1 | 1 |
239
+ >>> # +---+---+---+---+
240
+ >>> # | 0 | 0 | 1 | 1 |
241
+ >>> # +---+---+---+---+
242
+ >>> # | 2 | 2 | 3 | 3 |
243
+ >>> # +---+---+---+---+
244
+ >>> # | 2 | 2 | 3 | 3 |
245
+ >>> # +---+---+---+---+
246
+ >>> # With this mask, all inputs with the same value are shuffled
247
+ >>> # simultaneously, and the attribution for each input in the same
248
+ >>> # group (0, 1, 2, and 3) per example are the same.
249
+ >>> # The attributions can be calculated as follows:
250
+ >>> # feature mask has dimensions 1 x 4 x 4
251
+ >>> feature_mask = torch.tensor([[[0,0,1,1],[0,0,1,1],
252
+ >>> [2,2,3,3],[2,2,3,3]]])
253
+ >>> attr = feature_perm.attribute(input, target=1,
254
+ >>> feature_mask=feature_mask)
255
+ """
256
+ return FeatureAblation.attribute.__wrapped__(
257
+ self,
258
+ inputs,
259
+ baselines=None,
260
+ target=target,
261
+ additional_forward_args=additional_forward_args,
262
+ feature_mask=feature_mask,
263
+ perturbations_per_eval=perturbations_per_eval,
264
+ show_progress=show_progress,
265
+ **kwargs,
266
+ )
267
+
268
+ def _construct_ablated_input(
269
+ self,
270
+ expanded_input: Tensor,
271
+ input_mask: Tensor,
272
+ baseline: Union[int, float, Tensor],
273
+ start_feature: int,
274
+ end_feature: int,
275
+ **kwargs: Any,
276
+ ) -> Tuple[Tensor, Tensor]:
277
+ r"""
278
+ This function permutes the features of `expanded_input` with a given
279
+ feature mask and feature range. Permutation occurs via calling
280
+ `self.perm_func` across each batch within `expanded_input`. As with
281
+ `FeatureAblation._construct_ablated_input`:
282
+ - `expanded_input.shape = (num_features, num_examples, ...)`
283
+ - `num_features = end_feature - start_feature` (i.e. start and end is a
284
+ half-closed interval)
285
+ - `input_mask` is a tensor of the same shape as one input, which
286
+ describes the locations of each feature via their "index"
287
+
288
+ Since `baselines` is set to None for `FeatureAblation.attribute, this
289
+ will be the zero tensor, however, it is not used.
290
+ """
291
+ assert input_mask.shape[0] == 1, (
292
+ "input_mask.shape[0] != 1: pass in one mask in order to permute"
293
+ "the same features for each input"
294
+ )
295
+ current_mask = torch.stack(
296
+ [input_mask == j for j in range(start_feature, end_feature)], dim=0
297
+ ).bool()
298
+
299
+ output = torch.stack(
300
+ [
301
+ self.perm_func(x, mask.squeeze(0))
302
+ for x, mask in zip(expanded_input, current_mask)
303
+ ]
304
+ )
305
+ return output, current_mask
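+
+
+ # Hedged sketch of supplying a custom `perm_func`: a deterministic "roll by one"
+ # permutation instead of the default random shuffle. `net`, `inputs` and the
+ # target index are hypothetical stand-ins for a user's model and data.
+ def _example_custom_perm_func_sketch(net: Callable, inputs: Tensor, target: int = 1) -> Tensor:
+     def roll_by_one(x: Tensor, feature_mask: Tensor) -> Tensor:
+         rolled = torch.roll(x, shifts=1, dims=0)
+         return rolled * feature_mask.to(x.dtype) + x * (~feature_mask).to(x.dtype)
+
+     feature_perm = FeaturePermutation(net, perm_func=roll_by_one)
+     return feature_perm.attribute(inputs, target=target)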
captum/attr/_core/gradient_shap.py ADDED
@@ -0,0 +1,414 @@
1
+ #!/usr/bin/env python3
2
+ import typing
3
+ from typing import Any, Callable, Tuple, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ from captum._utils.common import _is_tuple
8
+ from captum._utils.typing import (
9
+ BaselineType,
10
+ Literal,
11
+ TargetType,
12
+ Tensor,
13
+ TensorOrTupleOfTensorsGeneric,
14
+ )
15
+ from captum.attr._core.noise_tunnel import NoiseTunnel
16
+ from captum.attr._utils.attribution import GradientAttribution
17
+ from captum.attr._utils.common import (
18
+ _compute_conv_delta_and_format_attrs,
19
+ _format_callable_baseline,
20
+ _format_input_baseline,
21
+ )
22
+ from captum.log import log_usage
23
+
24
+
25
+ class GradientShap(GradientAttribution):
26
+ r"""
27
+ Implements gradient SHAP based on the implementation from SHAP's primary
28
+ author. For reference, please view the original
29
+ `implementation
30
+ <https://github.com/slundberg/shap#deep-learning-example-with-gradientexplainer-tensorflowkeraspytorch-models>`_
31
+ and the paper: `A Unified Approach to Interpreting Model Predictions
32
+ <https://papers.nips.cc/paper/7062-a-unified-approach-to-interpreting-model-predictions>`_
33
+
34
+ GradientShap approximates SHAP values by computing the expectations of
35
+ gradients by randomly sampling from the distribution of baselines/references.
36
+ It adds white noise to each input sample `n_samples` times, selects a
37
+ random baseline from baselines' distribution and a random point along the
38
+ path between the baseline and the input, and computes the gradient of outputs
39
+ with respect to those selected random points. The final SHAP values represent
40
+ the expected values of gradients * (inputs - baselines).
41
+
42
+ GradientShap makes an assumption that the input features are independent
43
+ and that the explanation model is linear, meaning that the explanations
44
+ are modeled through the additive composition of feature effects.
45
+ Under those assumptions, SHAP value can be approximated as the expectation
46
+ of gradients that are computed for randomly generated `n_samples` input
47
+ samples after adding gaussian noise `n_samples` times to each input for
48
+ different baselines/references.
49
+
50
+ In some sense it can be viewed as an approximation of integrated gradients
51
+ by computing the expectations of gradients for different baselines.
52
+
53
+ Current implementation uses Smoothgrad from `NoiseTunnel` in order to
54
+ randomly draw samples from the distribution of baselines, add noise to input
55
+ samples and compute the expectation (smoothgrad).
56
+ """
57
+
58
+ def __init__(self, forward_func: Callable, multiply_by_inputs: bool = True) -> None:
59
+ r"""
60
+ Args:
61
+
62
+ forward_func (function): The forward function of the model or
63
+ any modification of it.
64
+ multiply_by_inputs (bool, optional): Indicates whether to factor
65
+ model inputs' multiplier in the final attribution scores.
66
+ In the literature this is also known as local vs global
67
+ attribution. If inputs' multiplier isn't factored in
68
+ then this type of attribution method is also called local
69
+ attribution. If it is, then that type of attribution
70
+ method is called global.
71
+ More detailed can be found here:
72
+ https://arxiv.org/abs/1711.06104
73
+
74
+ In case of gradient shap, if `multiply_by_inputs`
75
+ is set to True, the sensitivity scores of scaled inputs
76
+ are being multiplied by (inputs - baselines).
77
+ """
78
+ GradientAttribution.__init__(self, forward_func)
79
+ self._multiply_by_inputs = multiply_by_inputs
80
+
81
+ @typing.overload
82
+ def attribute(
83
+ self,
84
+ inputs: TensorOrTupleOfTensorsGeneric,
85
+ baselines: Union[
86
+ TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric]
87
+ ],
88
+ n_samples: int = 5,
89
+ stdevs: Union[float, Tuple[float, ...]] = 0.0,
90
+ target: TargetType = None,
91
+ additional_forward_args: Any = None,
92
+ *,
93
+ return_convergence_delta: Literal[True],
94
+ ) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]:
95
+ ...
96
+
97
+ @typing.overload
98
+ def attribute(
99
+ self,
100
+ inputs: TensorOrTupleOfTensorsGeneric,
101
+ baselines: Union[
102
+ TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric]
103
+ ],
104
+ n_samples: int = 5,
105
+ stdevs: Union[float, Tuple[float, ...]] = 0.0,
106
+ target: TargetType = None,
107
+ additional_forward_args: Any = None,
108
+ return_convergence_delta: Literal[False] = False,
109
+ ) -> TensorOrTupleOfTensorsGeneric:
110
+ ...
111
+
112
+ @log_usage()
113
+ def attribute(
114
+ self,
115
+ inputs: TensorOrTupleOfTensorsGeneric,
116
+ baselines: Union[
117
+ TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric]
118
+ ],
119
+ n_samples: int = 5,
120
+ stdevs: Union[float, Tuple[float, ...]] = 0.0,
121
+ target: TargetType = None,
122
+ additional_forward_args: Any = None,
123
+ return_convergence_delta: bool = False,
124
+ ) -> Union[
125
+ TensorOrTupleOfTensorsGeneric, Tuple[TensorOrTupleOfTensorsGeneric, Tensor]
126
+ ]:
127
+ r"""
128
+ Args:
129
+
130
+ inputs (tensor or tuple of tensors): Input for which SHAP attribution
131
+ values are computed. If `forward_func` takes a single
132
+ tensor as input, a single input tensor should be provided.
133
+ If `forward_func` takes multiple tensors as input, a tuple
134
+ of the input tensors should be provided. It is assumed
135
+ that for all given input tensors, dimension 0 corresponds
136
+ to the number of examples, and if multiple input tensors
137
+ are provided, the examples must be aligned appropriately.
138
+ baselines (tensor, tuple of tensors, callable):
139
+ Baselines define the starting point from which expectation
140
+ is computed and can be provided as:
141
+
142
+ - a single tensor, if inputs is a single tensor, with
143
+ the first dimension equal to the number of examples
144
+ in the baselines' distribution. The remaining dimensions
145
+ must match with input tensor's dimension starting from
146
+ the second dimension.
147
+
148
+ - a tuple of tensors, if inputs is a tuple of tensors,
149
+ with the first dimension of any tensor inside the tuple
150
+ equal to the number of examples in the baseline's
151
+ distribution. The remaining dimensions must match
152
+ the dimensions of the corresponding input tensor
153
+ starting from the second dimension.
154
+
155
+ - callable function, optionally takes `inputs` as an
156
+ argument and either returns a single tensor
157
+ or a tuple of those.
158
+
159
+ It is recommended that the number of samples in the baselines'
160
+ tensors is larger than one.
161
+ n_samples (int, optional): The number of randomly generated examples
162
+ per sample in the input batch. Random examples are
163
+ generated by adding gaussian random noise to each sample.
164
+ Default: `5` if `n_samples` is not provided.
165
+ stdevs (float, or a tuple of floats, optional): The standard deviation
166
+ of gaussian noise with zero mean that is added to each
167
+ input in the batch. If `stdevs` is a single float value
168
+ then that same value is used for all inputs. If it is
169
+ a tuple, then it must have the same length as the inputs
170
+ tuple. In this case, each stdev value in the stdevs tuple
171
+ corresponds to the input with the same index in the inputs
172
+ tuple.
173
+ Default: 0.0
174
+ target (int, tuple, tensor or list, optional): Output indices for
175
+ which gradients are computed (for classification cases,
176
+ this is usually the target class).
177
+ If the network returns a scalar value per example,
178
+ no target index is necessary.
179
+ For general 2D outputs, targets can be either:
180
+
181
+ - a single integer or a tensor containing a single
182
+ integer, which is applied to all input examples
183
+
184
+ - a list of integers or a 1D tensor, with length matching
185
+ the number of examples in inputs (dim 0). Each integer
186
+ is applied as the target for the corresponding example.
187
+
188
+ For outputs with > 2 dimensions, targets can be either:
189
+
190
+ - A single tuple, which contains #output_dims - 1
191
+ elements. This target index is applied to all examples.
192
+
193
+ - A list of tuples with length equal to the number of
194
+ examples in inputs (dim 0), and each tuple containing
195
+ #output_dims - 1 elements. Each tuple is applied as the
196
+ target for the corresponding example.
197
+
198
+ Default: None
199
+ additional_forward_args (any, optional): If the forward function
200
+ requires additional arguments other than the inputs for
201
+ which attributions should not be computed, this argument
202
+ can be provided. It can contain a tuple of ND tensors or
203
+ any arbitrary python type of any shape.
204
+ In case of the ND tensor the first dimension of the
205
+ tensor must correspond to the batch size. It will be
206
+ repeated for each of the `n_samples` randomly generated
207
+ input samples.
208
+ Note that the gradients are not computed with respect
209
+ to these arguments.
210
+ Default: None
211
+ return_convergence_delta (bool, optional): Indicates whether to return
212
+ convergence delta or not. If `return_convergence_delta`
213
+ is set to True convergence delta will be returned in
214
+ a tuple following attributions.
215
+ Default: False
216
+ Returns:
217
+ **attributions** or 2-element tuple of **attributions**, **delta**:
218
+ - **attributions** (*tensor* or tuple of *tensors*):
219
+ Attribution score computed based on GradientSHAP with respect
220
+ to each input feature. Attributions will always be
221
+ the same size as the provided inputs, with each value
222
+ providing the attribution of the corresponding input index.
223
+ If a single tensor is provided as inputs, a single tensor is
224
+ returned. If a tuple is provided for inputs, a tuple of
225
+ corresponding sized tensors is returned.
226
+ - **delta** (*tensor*, returned if return_convergence_delta=True):
227
+ This is computed using the property that the total
228
+ sum of forward_func(inputs) - forward_func(baselines)
229
+ must be very close to the total sum of the attributions
230
+ based on GradientSHAP.
231
+ Delta is calculated for each example in the input after adding
232
+ `n_samples` times gaussian noise to each of them. Therefore,
233
+ the dimensionality of the deltas tensor is equal to the
234
+ `number of examples in the input` * `n_samples`.
235
+ The deltas are ordered by each input example and `n_samples`
236
+ noisy samples generated for it.
237
+
238
+ Examples::
239
+
240
+ >>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
241
+ >>> # and returns an Nx10 tensor of class probabilities.
242
+ >>> net = ImageClassifier()
243
+ >>> gradient_shap = GradientShap(net)
244
+ >>> input = torch.randn(3, 3, 32, 32, requires_grad=True)
245
+ >>> # choosing baselines randomly
246
+ >>> baselines = torch.randn(20, 3, 32, 32)
247
+ >>> # Computes gradient shap for the input
248
+ >>> # Attribution size matches input size: 3x3x32x32
249
+ >>> attribution = gradient_shap.attribute(input, baselines,
250
+ target=5)
251
+
252
+ """
253
+ # since `baselines` is a distribution, we can generate it using a function
254
+ # rather than passing it as an input argument
255
+ baselines = _format_callable_baseline(baselines, inputs)
256
+ assert isinstance(baselines[0], torch.Tensor), (
257
+ "Baselines distribution has to be provided in a form "
258
+ "of a torch.Tensor {}.".format(baselines[0])
259
+ )
260
+
261
+ input_min_baseline_x_grad = InputBaselineXGradient(
262
+ self.forward_func, self.multiplies_by_inputs
263
+ )
264
+ input_min_baseline_x_grad.gradient_func = self.gradient_func
265
+
266
+ nt = NoiseTunnel(input_min_baseline_x_grad)
267
+
268
+ # NOTE: using attribute.__wrapped__ to not log
269
+ attributions = nt.attribute.__wrapped__(
270
+ nt, # self
271
+ inputs,
272
+ nt_type="smoothgrad",
273
+ nt_samples=n_samples,
274
+ stdevs=stdevs,
275
+ draw_baseline_from_distrib=True,
276
+ baselines=baselines,
277
+ target=target,
278
+ additional_forward_args=additional_forward_args,
279
+ return_convergence_delta=return_convergence_delta,
280
+ )
281
+
282
+ return attributions
283
+
284
+ def has_convergence_delta(self) -> bool:
285
+ return True
286
+
287
+ @property
288
+ def multiplies_by_inputs(self):
289
+ return self._multiply_by_inputs
290
+
291
+
292
+ class InputBaselineXGradient(GradientAttribution):
293
+ def __init__(self, forward_func: Callable, multiply_by_inputs=True) -> None:
294
+ r"""
295
+ Args:
296
+
297
+ forward_func (function): The forward function of the model or
298
+ any modification of it
299
+ multiply_by_inputs (bool, optional): Indicates whether to factor
300
+ model inputs' multiplier in the final attribution scores.
301
+ In the literature this is also known as local vs global
302
+ attribution. If inputs' multiplier isn't factored in
303
+ then this type of attribution method is also called local
304
+ attribution. If it is, then that type of attribution
305
+ method is called global.
306
+ More details can be found here:
307
+ https://arxiv.org/abs/1711.06104
308
+
309
+ In case of gradient shap, if `multiply_by_inputs`
310
+ is set to True, the sensitivity scores of scaled inputs
311
+ are multiplied by (inputs - baselines).
312
+
313
+ """
314
+ GradientAttribution.__init__(self, forward_func)
315
+ self._multiply_by_inputs = multiply_by_inputs
316
+
317
+ @typing.overload
318
+ def attribute(
319
+ self,
320
+ inputs: TensorOrTupleOfTensorsGeneric,
321
+ baselines: BaselineType = None,
322
+ target: TargetType = None,
323
+ additional_forward_args: Any = None,
324
+ *,
325
+ return_convergence_delta: Literal[True],
326
+ ) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]:
327
+ ...
328
+
329
+ @typing.overload
330
+ def attribute(
331
+ self,
332
+ inputs: TensorOrTupleOfTensorsGeneric,
333
+ baselines: BaselineType = None,
334
+ target: TargetType = None,
335
+ additional_forward_args: Any = None,
336
+ return_convergence_delta: Literal[False] = False,
337
+ ) -> TensorOrTupleOfTensorsGeneric:
338
+ ...
339
+
340
+ @log_usage()
341
+ def attribute( # type: ignore
342
+ self,
343
+ inputs: TensorOrTupleOfTensorsGeneric,
344
+ baselines: BaselineType = None,
345
+ target: TargetType = None,
346
+ additional_forward_args: Any = None,
347
+ return_convergence_delta: bool = False,
348
+ ) -> Union[
349
+ TensorOrTupleOfTensorsGeneric, Tuple[TensorOrTupleOfTensorsGeneric, Tensor]
350
+ ]:
351
+ # Keeps track whether original input is a tuple or not before
352
+ # converting it into a tuple.
353
+ is_inputs_tuple = _is_tuple(inputs)
354
+ inputs, baselines = _format_input_baseline(inputs, baselines)
355
+
356
+ rand_coefficient = torch.tensor(
357
+ np.random.uniform(0.0, 1.0, inputs[0].shape[0]),
358
+ device=inputs[0].device,
359
+ dtype=inputs[0].dtype,
360
+ )
361
+
362
+ input_baseline_scaled = tuple(
363
+ _scale_input(input, baseline, rand_coefficient)
364
+ for input, baseline in zip(inputs, baselines)
365
+ )
366
+ grads = self.gradient_func(
367
+ self.forward_func, input_baseline_scaled, target, additional_forward_args
368
+ )
369
+
370
+ if self.multiplies_by_inputs:
371
+ input_baseline_diffs = tuple(
372
+ input - baseline for input, baseline in zip(inputs, baselines)
373
+ )
374
+ attributions = tuple(
375
+ input_baseline_diff * grad
376
+ for input_baseline_diff, grad in zip(input_baseline_diffs, grads)
377
+ )
378
+ else:
379
+ attributions = grads
380
+
381
+ return _compute_conv_delta_and_format_attrs(
382
+ self,
383
+ return_convergence_delta,
384
+ attributions,
385
+ baselines,
386
+ inputs,
387
+ additional_forward_args,
388
+ target,
389
+ is_inputs_tuple,
390
+ )
391
+
392
+ def has_convergence_delta(self) -> bool:
393
+ return True
394
+
395
+ @property
396
+ def multiplies_by_inputs(self):
397
+ return self._multiply_by_inputs
398
+
399
+
400
+ def _scale_input(
401
+ input: Tensor, baseline: Union[Tensor, int, float], rand_coefficient: Tensor
402
+ ) -> Tensor:
403
+ # batch size
404
+ bsz = input.shape[0]
405
+ inp_shape_wo_bsz = input.shape[1:]
406
+ inp_shape = (bsz,) + tuple([1] * len(inp_shape_wo_bsz))
407
+
408
+ # expand and reshape the indices
409
+ rand_coefficient = rand_coefficient.view(inp_shape)
410
+
411
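+ # One coefficient per example: select a random point on the straight line
+ # between the baseline and the input (0 -> baseline, 1 -> input).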
+ input_baseline_scaled = (
412
+ rand_coefficient * input + (1.0 - rand_coefficient) * baseline
413
+ ).requires_grad_()
414
+ return input_baseline_scaled
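A minimal usage sketch of the GradientShap API added above; the model, tensor shapes and hyperparameters below are hypothetical illustrations, not part of this change:

import torch
import torch.nn as nn
from captum.attr import GradientShap

class SmallNet(nn.Module):
    # toy classifier used only for illustration
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 3)

    def forward(self, x):
        return self.fc(x)

net = SmallNet()
gradient_shap = GradientShap(net)

inputs = torch.randn(4, 8, requires_grad=True)
# the baselines tensor acts as a distribution; more than one sample is recommended
baselines = torch.randn(20, 8)

attributions, delta = gradient_shap.attribute(
    inputs,
    baselines=baselines,
    n_samples=10,
    stdevs=0.1,
    target=0,
    return_convergence_delta=True,
)
# attributions has the same shape as inputs (4 x 8); delta has
# 4 * n_samples elements, one per noisy sample per input example.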
captum/attr/_core/guided_backprop_deconvnet.py ADDED
@@ -0,0 +1,322 @@
1
+ #!/usr/bin/env python3
2
+ import warnings
3
+ from typing import Any, List, Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from captum._utils.common import (
8
+ _format_output,
9
+ _format_tensor_into_tuples,
10
+ _is_tuple,
11
+ _register_backward_hook,
12
+ )
13
+ from captum._utils.gradient import (
14
+ apply_gradient_requirements,
15
+ undo_gradient_requirements,
16
+ )
17
+ from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric
18
+ from captum.attr._utils.attribution import GradientAttribution
19
+ from captum.log import log_usage
20
+ from torch import Tensor
21
+ from torch.nn import Module
22
+ from torch.utils.hooks import RemovableHandle
23
+
24
+
25
+ class ModifiedReluGradientAttribution(GradientAttribution):
26
+ def __init__(self, model: Module, use_relu_grad_output: bool = False) -> None:
27
+ r"""
28
+ Args:
29
+
30
+ model (nn.Module): The reference to PyTorch model instance.
31
+ """
32
+ GradientAttribution.__init__(self, model)
33
+ self.model = model
34
+ self.backward_hooks: List[RemovableHandle] = []
35
+ self.use_relu_grad_output = use_relu_grad_output
36
+ assert isinstance(self.model, torch.nn.Module), (
37
+ "Given model must be an instance of torch.nn.Module to properly hook"
38
+ " ReLU layers."
39
+ )
40
+
41
+ @log_usage()
42
+ def attribute(
43
+ self,
44
+ inputs: TensorOrTupleOfTensorsGeneric,
45
+ target: TargetType = None,
46
+ additional_forward_args: Any = None,
47
+ ) -> TensorOrTupleOfTensorsGeneric:
48
+ r"""
49
+ Computes attribution by overriding relu gradients. Based on constructor
50
+ flag use_relu_grad_output, performs either GuidedBackpropagation if False
51
+ and Deconvolution if True. This class is the parent class of both these
52
+ methods, more information on usage can be found in the docstrings for each
53
+ implementing class.
54
+ """
55
+
56
+ # Keeps track whether original input is a tuple or not before
57
+ # converting it into a tuple.
58
+ is_inputs_tuple = _is_tuple(inputs)
59
+
60
+ inputs = _format_tensor_into_tuples(inputs)
61
+ gradient_mask = apply_gradient_requirements(inputs)
62
+
63
+ # set hooks for overriding ReLU gradients
64
+ warnings.warn(
65
+ "Setting backward hooks on ReLU activations. "
66
+ "The hooks will be removed after the attribution is finished"
67
+ )
68
+ try:
69
+ self.model.apply(self._register_hooks)
70
+
71
+ gradients = self.gradient_func(
72
+ self.forward_func, inputs, target, additional_forward_args
73
+ )
74
+ finally:
75
+ self._remove_hooks()
76
+
77
+ undo_gradient_requirements(inputs, gradient_mask)
78
+ return _format_output(is_inputs_tuple, gradients)
79
+
80
+ def _register_hooks(self, module: Module):
81
+ if isinstance(module, torch.nn.ReLU):
82
+ hook = _register_backward_hook(module, self._backward_hook, self)
83
+ self.backward_hooks.append(hook)
84
+
85
+ def _backward_hook(
86
+ self,
87
+ module: Module,
88
+ grad_input: Union[Tensor, Tuple[Tensor, ...]],
89
+ grad_output: Union[Tensor, Tuple[Tensor, ...]],
90
+ ):
91
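+ # GuidedBackprop (use_relu_grad_output=False) rectifies grad_input, which
+ # already carries the forward ReLU mask; Deconvolution rectifies grad_output,
+ # propagating only positive gradients regardless of the sign of the ReLU's
+ # input. The tensors returned below are used in place of grad_input.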
+ to_override_grads = grad_output if self.use_relu_grad_output else grad_input
92
+ if isinstance(to_override_grads, tuple):
93
+ return tuple(
94
+ F.relu(to_override_grad) for to_override_grad in to_override_grads
95
+ )
96
+ else:
97
+ return F.relu(to_override_grads)
98
+
99
+ def _remove_hooks(self):
100
+ for hook in self.backward_hooks:
101
+ hook.remove()
102
+
103
+
104
+ class GuidedBackprop(ModifiedReluGradientAttribution):
105
+ r"""
106
+ Computes attribution using guided backpropagation. Guided backpropagation
107
+ computes the gradient of the target output with respect to the input,
108
+ but gradients of ReLU functions are overridden so that only
109
+ non-negative gradients are backpropagated.
110
+
111
+ More details regarding the guided backpropagation algorithm can be found
112
+ in the original paper here:
113
+ https://arxiv.org/abs/1412.6806
114
+
115
+ Warning: Ensure that all ReLU operations in the forward function of the
116
+ given model are performed using a module (torch.nn.ReLU).
117
+ If nn.functional.relu is used, gradients are not overridden appropriately.
118
+ """
119
+
120
+ def __init__(self, model: Module) -> None:
121
+ r"""
122
+ Args:
123
+
124
+ model (nn.Module): The reference to PyTorch model instance. Model cannot
125
+ contain any in-place ReLU submodules; these are not
126
+ supported by the register_full_backward_hook PyTorch API.
127
+ """
128
+ ModifiedReluGradientAttribution.__init__(
129
+ self, model, use_relu_grad_output=False
130
+ )
131
+
132
+ @log_usage()
133
+ def attribute(
134
+ self,
135
+ inputs: TensorOrTupleOfTensorsGeneric,
136
+ target: TargetType = None,
137
+ additional_forward_args: Any = None,
138
+ ) -> TensorOrTupleOfTensorsGeneric:
139
+ r"""
140
+ Args:
141
+
142
+ inputs (tensor or tuple of tensors): Input for which
143
+ attributions are computed. If forward_func takes a single
144
+ tensor as input, a single input tensor should be provided.
145
+ If forward_func takes multiple tensors as input, a tuple
146
+ of the input tensors should be provided. It is assumed
147
+ that for all given input tensors, dimension 0 corresponds
148
+ to the number of examples (aka batch size), and if
149
+ multiple input tensors are provided, the examples must
150
+ be aligned appropriately.
151
+ target (int, tuple, tensor or list, optional): Output indices for
152
+ which gradients are computed (for classification cases,
153
+ this is usually the target class).
154
+ If the network returns a scalar value per example,
155
+ no target index is necessary.
156
+ For general 2D outputs, targets can be either:
157
+
158
+ - a single integer or a tensor containing a single
159
+ integer, which is applied to all input examples
160
+
161
+ - a list of integers or a 1D tensor, with length matching
162
+ the number of examples in inputs (dim 0). Each integer
163
+ is applied as the target for the corresponding example.
164
+
165
+ For outputs with > 2 dimensions, targets can be either:
166
+
167
+ - A single tuple, which contains #output_dims - 1
168
+ elements. This target index is applied to all examples.
169
+
170
+ - A list of tuples with length equal to the number of
171
+ examples in inputs (dim 0), and each tuple containing
172
+ #output_dims - 1 elements. Each tuple is applied as the
173
+ target for the corresponding example.
174
+
175
+ Default: None
176
+ additional_forward_args (any, optional): If the forward function
177
+ requires additional arguments other than the inputs for
178
+ which attributions should not be computed, this argument
179
+ can be provided. It must be either a single additional
180
+ argument of a Tensor or arbitrary (non-tuple) type or a tuple
181
+ containing multiple additional arguments including tensors
182
+ or any arbitrary python types. These arguments are provided to
183
+ forward_func in order, following the arguments in inputs.
184
+ Note that attributions are not computed with respect
185
+ to these arguments.
186
+ Default: None
187
+
188
+ Returns:
189
+ *tensor* or tuple of *tensors* of **attributions**:
190
+ - **attributions** (*tensor* or tuple of *tensors*):
191
+ The guided backprop gradients with respect to each
192
+ input feature. Attributions will always
193
+ be the same size as the provided inputs, with each value
194
+ providing the attribution of the corresponding input index.
195
+ If a single tensor is provided as inputs, a single tensor is
196
+ returned. If a tuple is provided for inputs, a tuple of
197
+ corresponding sized tensors is returned.
198
+
199
+ Examples::
200
+
201
+ >>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
202
+ >>> # and returns an Nx10 tensor of class probabilities.
203
+ >>> net = ImageClassifier()
204
+ >>> gbp = GuidedBackprop(net)
205
+ >>> input = torch.randn(2, 3, 32, 32, requires_grad=True)
206
+ >>> # Computes Guided Backprop attribution scores for class 3.
207
+ >>> attribution = gbp.attribute(input, target=3)
208
+ """
209
+ return super().attribute.__wrapped__(
210
+ self, inputs, target, additional_forward_args
211
+ )
212
+
213
+
214
+ class Deconvolution(ModifiedReluGradientAttribution):
215
+ r"""
216
+ Computes attribution using deconvolution. Deconvolution
217
+ computes the gradient of the target output with respect to the input,
218
+ but gradients of ReLU functions are overridden so that the gradient
219
+ of the ReLU input is simply computed taking ReLU of the output gradient,
220
+ essentially only propagating non-negative gradients (without
221
+ dependence on the sign of the ReLU input).
222
+
223
+ More details regarding the deconvolution algorithm can be found
224
+ in these papers:
225
+ https://arxiv.org/abs/1311.2901
226
+ https://link.springer.com/chapter/10.1007/978-3-319-46466-4_8
227
+
228
+ Warning: Ensure that all ReLU operations in the forward function of the
229
+ given model are performed using a module (torch.nn.ReLU).
230
+ If nn.functional.relu is used, gradients are not overridden appropriately.
231
+ """
232
+
233
+ def __init__(self, model: Module) -> None:
234
+ r"""
235
+ Args:
236
+
237
+ model (nn.Module): The reference to PyTorch model instance. Model cannot
238
+ contain any in-place ReLU submodules; these are not
239
+ supported by the register_full_backward_hook PyTorch API.
240
+ """
241
+ ModifiedReluGradientAttribution.__init__(self, model, use_relu_grad_output=True)
242
+
243
+ @log_usage()
244
+ def attribute(
245
+ self,
246
+ inputs: TensorOrTupleOfTensorsGeneric,
247
+ target: TargetType = None,
248
+ additional_forward_args: Any = None,
249
+ ) -> TensorOrTupleOfTensorsGeneric:
250
+ r"""
251
+ Args:
252
+
253
+ inputs (tensor or tuple of tensors): Input for which
254
+ attributions are computed. If forward_func takes a single
255
+ tensor as input, a single input tensor should be provided.
256
+ If forward_func takes multiple tensors as input, a tuple
257
+ of the input tensors should be provided. It is assumed
258
+ that for all given input tensors, dimension 0 corresponds
259
+ to the number of examples (aka batch size), and if
260
+ multiple input tensors are provided, the examples must
261
+ be aligned appropriately.
262
+ target (int, tuple, tensor or list, optional): Output indices for
263
+ which gradients are computed (for classification cases,
264
+ this is usually the target class).
265
+ If the network returns a scalar value per example,
266
+ no target index is necessary.
267
+ For general 2D outputs, targets can be either:
268
+
269
+ - a single integer or a tensor containing a single
270
+ integer, which is applied to all input examples
271
+
272
+ - a list of integers or a 1D tensor, with length matching
273
+ the number of examples in inputs (dim 0). Each integer
274
+ is applied as the target for the corresponding example.
275
+
276
+ For outputs with > 2 dimensions, targets can be either:
277
+
278
+ - A single tuple, which contains #output_dims - 1
279
+ elements. This target index is applied to all examples.
280
+
281
+ - A list of tuples with length equal to the number of
282
+ examples in inputs (dim 0), and each tuple containing
283
+ #output_dims - 1 elements. Each tuple is applied as the
284
+ target for the corresponding example.
285
+
286
+ Default: None
287
+ additional_forward_args (any, optional): If the forward function
288
+ requires additional arguments other than the inputs for
289
+ which attributions should not be computed, this argument
290
+ can be provided. It must be either a single additional
291
+ argument of a Tensor or arbitrary (non-tuple) type or a tuple
292
+ containing multiple additional arguments including tensors
293
+ or any arbitrary python types. These arguments are provided to
294
+ forward_func in order, following the arguments in inputs.
295
+ Note that attributions are not computed with respect
296
+ to these arguments.
297
+ Default: None
298
+
299
+ Returns:
300
+ *tensor* or tuple of *tensors* of **attributions**:
301
+ - **attributions** (*tensor* or tuple of *tensors*):
302
+ The deconvolution attributions with respect to each
303
+ input feature. Attributions will always
304
+ be the same size as the provided inputs, with each value
305
+ providing the attribution of the corresponding input index.
306
+ If a single tensor is provided as inputs, a single tensor is
307
+ returned. If a tuple is provided for inputs, a tuple of
308
+ corresponding sized tensors is returned.
309
+
310
+ Examples::
311
+
312
+ >>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
313
+ >>> # and returns an Nx10 tensor of class probabilities.
314
+ >>> net = ImageClassifier()
315
+ >>> deconv = Deconvolution(net)
316
+ >>> input = torch.randn(2, 3, 32, 32, requires_grad=True)
317
+ >>> # Computes Deconvolution attribution scores for class 3.
318
+ >>> attribution = deconv.attribute(input, target=3)
319
+ """
320
+ return super().attribute.__wrapped__(
321
+ self, inputs, target, additional_forward_args
322
+ )
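A hedged sketch contrasting the two classes defined above on the same input; the model and shapes are hypothetical:

import torch
import torch.nn as nn
from captum.attr import Deconvolution, GuidedBackprop

# hypothetical model; both methods require nn.ReLU modules rather than
# the functional form, as noted in the docstrings above
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 30 * 30, 10),
)
x = torch.randn(2, 3, 32, 32, requires_grad=True)

gbp_attr = GuidedBackprop(model).attribute(x, target=3)    # rectifies grad_input
deconv_attr = Deconvolution(model).attribute(x, target=3)  # rectifies grad_output
# both attribution tensors have the same shape as x: (2, 3, 32, 32)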
captum/attr/_core/guided_grad_cam.py ADDED
@@ -0,0 +1,226 @@
1
+ #!/usr/bin/env python3
2
+ import warnings
3
+ from typing import Any, List, Union
4
+
5
+ import torch
6
+ from captum._utils.common import _format_output, _format_tensor_into_tuples, _is_tuple
7
+ from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric
8
+ from captum.attr._core.guided_backprop_deconvnet import GuidedBackprop
9
+ from captum.attr._core.layer.grad_cam import LayerGradCam
10
+ from captum.attr._utils.attribution import GradientAttribution, LayerAttribution
11
+ from captum.log import log_usage
12
+ from torch import Tensor
13
+ from torch.nn import Module
14
+
15
+
16
+ class GuidedGradCam(GradientAttribution):
17
+ r"""
18
+ Computes element-wise product of guided backpropagation attributions
19
+ with upsampled (non-negative) GradCAM attributions.
20
+ GradCAM attributions are computed with respect to the layer
21
+ provided in the constructor, and attributions
22
+ are upsampled to match the input size. GradCAM is designed for
23
+ convolutional neural networks, and is usually applied to the last
24
+ convolutional layer.
25
+
26
+ Note that if multiple input tensors are provided, attributions for
27
+ each input tensor are computed by upsampling the GradCAM
28
+ attributions to match that input's dimensions. If interpolation is
29
+ not possible for the input tensor dimensions and interpolation mode,
30
+ then an empty tensor is returned in the attributions for the
31
+ corresponding position of that input tensor. This can occur if the
32
+ input tensor does not have the same number of dimensions as the chosen
33
+ layer's output or is not either 3D, 4D or 5D.
34
+
35
+ Note that attributions are only meaningful for input tensors
36
+ which are spatially aligned with the chosen layer, e.g. an input
37
+ image tensor for a convolutional layer.
38
+
39
+ More details regarding GuidedGradCAM can be found in the original
40
+ GradCAM paper here:
41
+ https://arxiv.org/pdf/1610.02391.pdf
42
+
43
+ Warning: Ensure that all ReLU operations in the forward function of the
44
+ given model are performed using a module (torch.nn.ReLU).
45
+ If nn.functional.relu is used, gradients are not overridden appropriately.
46
+ """
47
+
48
+ def __init__(
49
+ self, model: Module, layer: Module, device_ids: Union[None, List[int]] = None
50
+ ) -> None:
51
+ r"""
52
+ Args:
53
+
54
+ model (nn.Module): The reference to PyTorch model instance. Model cannot
55
+ contain any in-place ReLU submodules; these are not
56
+ supported by the register_full_backward_hook PyTorch API
57
+ starting from PyTorch v1.9.
58
+ layer (torch.nn.Module): Layer for which GradCAM attributions are computed.
59
+ Currently, only layers with a single tensor output are
60
+ supported.
61
+ device_ids (list(int)): Device ID list, necessary only if forward_func
62
+ applies a DataParallel model. This allows reconstruction of
63
+ intermediate outputs from batched results across devices.
64
+ If forward_func is given as the DataParallel model itself,
65
+ then it is not necessary to provide this argument.
66
+ """
67
+ GradientAttribution.__init__(self, model)
68
+ self.grad_cam = LayerGradCam(model, layer, device_ids)
69
+ self.guided_backprop = GuidedBackprop(model)
70
+
71
+ @log_usage()
72
+ def attribute(
73
+ self,
74
+ inputs: TensorOrTupleOfTensorsGeneric,
75
+ target: TargetType = None,
76
+ additional_forward_args: Any = None,
77
+ interpolate_mode: str = "nearest",
78
+ attribute_to_layer_input: bool = False,
79
+ ) -> TensorOrTupleOfTensorsGeneric:
80
+ r"""
81
+ Args:
82
+
83
+ inputs (tensor or tuple of tensors): Input for which attributions
84
+ are computed. If forward_func takes a single
85
+ tensor as input, a single input tensor should be provided.
86
+ If forward_func takes multiple tensors as input, a tuple
87
+ of the input tensors should be provided. It is assumed
88
+ that for all given input tensors, dimension 0 corresponds
89
+ to the number of examples, and if multiple input tensors
90
+ are provided, the examples must be aligned appropriately.
91
+ target (int, tuple, tensor or list, optional): Output indices for
92
+ which gradients are computed (for classification cases,
93
+ this is usually the target class).
94
+ If the network returns a scalar value per example,
95
+ no target index is necessary.
96
+ For general 2D outputs, targets can be either:
97
+
98
+ - a single integer or a tensor containing a single
99
+ integer, which is applied to all input examples
100
+
101
+ - a list of integers or a 1D tensor, with length matching
102
+ the number of examples in inputs (dim 0). Each integer
103
+ is applied as the target for the corresponding example.
104
+
105
+ For outputs with > 2 dimensions, targets can be either:
106
+
107
+ - A single tuple, which contains #output_dims - 1
108
+ elements. This target index is applied to all examples.
109
+
110
+ - A list of tuples with length equal to the number of
111
+ examples in inputs (dim 0), and each tuple containing
112
+ #output_dims - 1 elements. Each tuple is applied as the
113
+ target for the corresponding example.
114
+
115
+ Default: None
116
+ additional_forward_args (any, optional): If the forward function
117
+ requires additional arguments other than the inputs for
118
+ which attributions should not be computed, this argument
119
+ can be provided. It must be either a single additional
120
+ argument of a Tensor or arbitrary (non-tuple) type or a
121
+ tuple containing multiple additional arguments including
122
+ tensors or any arbitrary python types. These arguments
123
+ are provided to forward_func in order following the
124
+ arguments in inputs.
125
+ Note that attributions are not computed with respect
126
+ to these arguments.
127
+ Default: None
128
+ interpolate_mode (str, optional): Method for interpolation, which
129
+ must be a valid input interpolation mode for
130
+ torch.nn.functional. These methods are
131
+ "nearest", "area", "linear" (3D-only), "bilinear"
132
+ (4D-only), "bicubic" (4D-only), "trilinear" (5D-only)
133
+ based on the number of dimensions of the chosen layer
134
+ output (which must also match the number of
135
+ dimensions for the input tensor). Note that
136
+ the original GradCAM paper uses "bilinear"
137
+ interpolation, but we default to "nearest" for
138
+ applicability to any of 3D, 4D or 5D tensors.
139
+ Default: "nearest"
140
+ attribute_to_layer_input (bool, optional): Indicates whether to
141
+ compute the attribution with respect to the layer input
142
+ or output in `LayerGradCam`.
143
+ If `attribute_to_layer_input` is set to True
144
+ then the attributions will be computed with respect to
145
+ layer inputs, otherwise it will be computed with respect
146
+ to layer outputs.
147
+ Note that currently it is assumed that either the input
148
+ or the output of internal layer, depending on whether we
149
+ attribute to the input or output, is a single tensor.
150
+ Support for multiple tensors will be added later.
151
+ Default: False
152
+
153
+ Returns:
154
+ *tensor* of **attributions**:
155
+ - **attributions** (*tensor*):
156
+ Element-wise product of (upsampled) GradCAM
157
+ and Guided Backprop attributions.
158
+ If a single tensor is provided as inputs, a single tensor is
159
+ returned. If a tuple is provided for inputs, a tuple of
160
+ corresponding sized tensors is returned.
161
+ Attributions will be the same size as the provided inputs,
162
+ with each value providing the attribution of the
163
+ corresponding input index.
164
+ If the GradCAM attributions cannot be upsampled to the shape
165
+ of a given input tensor, None is returned in the corresponding
166
+ index position.
167
+
168
+
169
+ Examples::
170
+
171
+ >>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
172
+ >>> # and returns an Nx10 tensor of class probabilities.
173
+ >>> # It contains an attribute conv4, which is an instance of nn.conv2d,
174
+ >>> # and the output of this layer has dimensions Nx50x8x8.
175
+ >>> # It is the last convolution layer, which is the recommended
176
+ >>> # use case for GuidedGradCAM.
177
+ >>> net = ImageClassifier()
178
+ >>> guided_gc = GuidedGradCam(net, net.conv4)
179
+ >>> input = torch.randn(2, 3, 32, 32, requires_grad=True)
180
+ >>> # Computes guided GradCAM attributions for class 3.
181
+ >>> # attribution size matches input size, Nx3x32x32
182
+ >>> attribution = guided_gc.attribute(input, 3)
183
+ """
184
+ is_inputs_tuple = _is_tuple(inputs)
185
+ inputs = _format_tensor_into_tuples(inputs)
186
+ grad_cam_attr = self.grad_cam.attribute.__wrapped__(
187
+ self.grad_cam, # self
188
+ inputs=inputs,
189
+ target=target,
190
+ additional_forward_args=additional_forward_args,
191
+ attribute_to_layer_input=attribute_to_layer_input,
192
+ relu_attributions=True,
193
+ )
194
+ if isinstance(grad_cam_attr, tuple):
195
+ assert len(grad_cam_attr) == 1, (
196
+ "GuidedGradCAM attributions for layer with multiple inputs / "
197
+ "outputs is not supported."
198
+ )
199
+ grad_cam_attr = grad_cam_attr[0]
200
+
201
+ guided_backprop_attr = self.guided_backprop.attribute.__wrapped__(
202
+ self.guided_backprop, # self
203
+ inputs=inputs,
204
+ target=target,
205
+ additional_forward_args=additional_forward_args,
206
+ )
207
+ output_attr: List[Tensor] = []
208
+ for i in range(len(inputs)):
209
+ try:
210
+ output_attr.append(
211
+ guided_backprop_attr[i]
212
+ * LayerAttribution.interpolate(
213
+ grad_cam_attr,
214
+ inputs[i].shape[2:],
215
+ interpolate_mode=interpolate_mode,
216
+ )
217
+ )
218
+ except Exception:
219
+ warnings.warn(
220
+ "Couldn't appropriately interpolate GradCAM attributions for some "
221
+ "input tensors, returning empty tensor for corresponding "
222
+ "attributions."
223
+ )
224
+ output_attr.append(torch.empty(0))
225
+
226
+ return _format_output(is_inputs_tuple, tuple(output_attr))
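The element-wise combination performed above reduces to the following for a single 4D input; the shapes are hypothetical and `LayerAttribution.interpolate` is essentially a wrapper around torch.nn.functional.interpolate:

import torch
import torch.nn.functional as F

gbp_attr = torch.randn(2, 3, 32, 32)                 # guided-backprop attribution (input-sized)
gradcam_attr = torch.relu(torch.randn(2, 1, 8, 8))   # non-negative GradCAM map (layer-sized)
upsampled = F.interpolate(gradcam_attr, size=(32, 32), mode="nearest")
guided_gradcam = gbp_attr * upsampled                # broadcasts over the channel dimension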
captum/attr/_core/input_x_gradient.py ADDED
@@ -0,0 +1,130 @@
1
+ #!/usr/bin/env python3
2
+ from typing import Any, Callable
3
+
4
+ from captum._utils.common import _format_output, _format_tensor_into_tuples, _is_tuple
5
+ from captum._utils.gradient import (
6
+ apply_gradient_requirements,
7
+ undo_gradient_requirements,
8
+ )
9
+ from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric
10
+ from captum.attr._utils.attribution import GradientAttribution
11
+ from captum.log import log_usage
12
+
13
+
14
+ class InputXGradient(GradientAttribution):
15
+ r"""
16
+ A baseline approach for computing the attribution. It multiplies input with
17
+ the gradient with respect to input.
18
+ https://arxiv.org/abs/1605.01713
19
+ """
20
+
21
+ def __init__(self, forward_func: Callable) -> None:
22
+ r"""
23
+ Args:
24
+
25
+ forward_func (callable): The forward function of the model or any
26
+ modification of it
27
+ """
28
+ GradientAttribution.__init__(self, forward_func)
29
+
30
+ @log_usage()
31
+ def attribute(
32
+ self,
33
+ inputs: TensorOrTupleOfTensorsGeneric,
34
+ target: TargetType = None,
35
+ additional_forward_args: Any = None,
36
+ ) -> TensorOrTupleOfTensorsGeneric:
37
+ r"""
38
+ Args:
39
+
40
+ inputs (tensor or tuple of tensors): Input for which
41
+ attributions are computed. If forward_func takes a single
42
+ tensor as input, a single input tensor should be provided.
43
+ If forward_func takes multiple tensors as input, a tuple
44
+ of the input tensors should be provided. It is assumed
45
+ that for all given input tensors, dimension 0 corresponds
46
+ to the number of examples (aka batch size), and if
47
+ multiple input tensors are provided, the examples must
48
+ be aligned appropriately.
49
+ target (int, tuple, tensor or list, optional): Output indices for
50
+ which gradients are computed (for classification cases,
51
+ this is usually the target class).
52
+ If the network returns a scalar value per example,
53
+ no target index is necessary.
54
+ For general 2D outputs, targets can be either:
55
+
56
+ - a single integer or a tensor containing a single
57
+ integer, which is applied to all input examples
58
+
59
+ - a list of integers or a 1D tensor, with length matching
60
+ the number of examples in inputs (dim 0). Each integer
61
+ is applied as the target for the corresponding example.
62
+
63
+ For outputs with > 2 dimensions, targets can be either:
64
+
65
+ - A single tuple, which contains #output_dims - 1
66
+ elements. This target index is applied to all examples.
67
+
68
+ - A list of tuples with length equal to the number of
69
+ examples in inputs (dim 0), and each tuple containing
70
+ #output_dims - 1 elements. Each tuple is applied as the
71
+ target for the corresponding example.
72
+
73
+ Default: None
74
+ additional_forward_args (any, optional): If the forward function
75
+ requires additional arguments other than the inputs for
76
+ which attributions should not be computed, this argument
77
+ can be provided. It must be either a single additional
78
+ argument of a Tensor or arbitrary (non-tuple) type or a tuple
79
+ containing multiple additional arguments including tensors
80
+ or any arbitrary python types. These arguments are provided to
81
+ forward_func in order following the arguments in inputs.
82
+ Note that attributions are not computed with respect
83
+ to these arguments.
84
+ Default: None
85
+
86
+ Returns:
87
+ *tensor* or tuple of *tensors* of **attributions**:
88
+ - **attributions** (*tensor* or tuple of *tensors*):
89
+ The input x gradient with
90
+ respect to each input feature. Attributions will always be
91
+ the same size as the provided inputs, with each value
92
+ providing the attribution of the corresponding input index.
93
+ If a single tensor is provided as inputs, a single tensor is
94
+ returned. If a tuple is provided for inputs, a tuple of
95
+ corresponding sized tensors is returned.
96
+
97
+
98
+ Examples::
99
+
100
+ >>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
101
+ >>> # and returns an Nx10 tensor of class probabilities.
102
+ >>> net = ImageClassifier()
103
+ >>> # Generating random input with size 2x3x32x32
104
+ >>> input = torch.randn(2, 3, 32, 32, requires_grad=True)
105
+ >>> # Defining InputXGradient interpreter
106
+ >>> input_x_gradient = InputXGradient(net)
107
+ >>> # Computes inputXgradient for class 4.
108
+ >>> attribution = input_x_gradient.attribute(input, target=4)
109
+ """
110
+ # Keeps track whether original input is a tuple or not before
111
+ # converting it into a tuple.
112
+ is_inputs_tuple = _is_tuple(inputs)
113
+
114
+ inputs = _format_tensor_into_tuples(inputs)
115
+ gradient_mask = apply_gradient_requirements(inputs)
116
+
117
+ gradients = self.gradient_func(
118
+ self.forward_func, inputs, target, additional_forward_args
119
+ )
120
+
121
+ attributions = tuple(
122
+ input * gradient for input, gradient in zip(inputs, gradients)
123
+ )
124
+
125
+ undo_gradient_requirements(inputs, gradient_mask)
126
+ return _format_output(is_inputs_tuple, attributions)
127
+
128
+ @property
129
+ def multiplies_by_inputs(self):
130
+ return True
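Conceptually the class above computes input * dF/dinput; a plain-autograd sketch of the same quantity for a hypothetical single-tensor model:

import torch
import torch.nn as nn

model = nn.Linear(8, 3)                       # hypothetical model
x = torch.randn(4, 8, requires_grad=True)
target = 1

scores = model(x)[:, target]                  # per-example target outputs
grads, = torch.autograd.grad(scores.sum(), x)
manual_attr = x * grads                       # same values as InputXGradient(model).attribute(x, target=1) for this model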
captum/attr/_core/integrated_gradients.py ADDED
@@ -0,0 +1,390 @@
1
+ #!/usr/bin/env python3
2
+ import typing
3
+ from typing import Any, Callable, List, Tuple, Union
4
+
5
+ import torch
6
+ from captum._utils.common import (
7
+ _expand_additional_forward_args,
8
+ _expand_target,
9
+ _format_additional_forward_args,
10
+ _format_output,
11
+ _is_tuple,
12
+ )
13
+ from captum._utils.typing import (
14
+ BaselineType,
15
+ Literal,
16
+ TargetType,
17
+ TensorOrTupleOfTensorsGeneric,
18
+ )
19
+ from captum.attr._utils.approximation_methods import approximation_parameters
20
+ from captum.attr._utils.attribution import GradientAttribution
21
+ from captum.attr._utils.batching import _batch_attribution
22
+ from captum.attr._utils.common import (
23
+ _format_input_baseline,
24
+ _reshape_and_sum,
25
+ _validate_input,
26
+ )
27
+ from captum.log import log_usage
28
+ from torch import Tensor
29
+
30
+
31
+ class IntegratedGradients(GradientAttribution):
32
+ r"""
33
+ Integrated Gradients is an axiomatic model interpretability algorithm that
34
+ assigns an importance score to each input feature by approximating the
35
+ integral of gradients of the model's output with respect to the inputs
36
+ along the path (straight line) from given baselines / references to inputs.
37
+
38
+ Baselines can be provided as input arguments to attribute method.
39
+ To approximate the integral we can choose to use either a variant of
40
+ Riemann sum or Gauss-Legendre quadrature rule.
41
+
42
+ More details regarding the integrated gradients method can be found in the
43
+ original paper:
44
+ https://arxiv.org/abs/1703.01365
45
+
46
+ """
47
+
48
+ def __init__(
49
+ self,
50
+ forward_func: Callable,
51
+ multiply_by_inputs: bool = True,
52
+ ) -> None:
53
+ r"""
54
+ Args:
55
+
56
+ forward_func (callable): The forward function of the model or any
57
+ modification of it
58
+ multiply_by_inputs (bool, optional): Indicates whether to factor
59
+ model inputs' multiplier in the final attribution scores.
60
+ In the literature this is also known as local vs global
61
+ attribution. If inputs' multiplier isn't factored in,
62
+ then that type of attribution method is also called local
63
+ attribution. If it is, then that type of attribution
64
+ method is called global.
65
+ More details can be found here:
66
+ https://arxiv.org/abs/1711.06104
67
+
68
+ In case of integrated gradients, if `multiply_by_inputs`
69
+ is set to True, final sensitivity scores are multiplied by
70
+ (inputs - baselines).
71
+ """
72
+ GradientAttribution.__init__(self, forward_func)
73
+ self._multiply_by_inputs = multiply_by_inputs
74
+
75
+ # The following overloaded method signatures correspond to the case where
76
+ # return_convergence_delta is False, then only attributions are returned,
77
+ # and when return_convergence_delta is True, the return type is
78
+ # a tuple with both attributions and deltas.
79
+ @typing.overload
80
+ def attribute(
81
+ self,
82
+ inputs: TensorOrTupleOfTensorsGeneric,
83
+ baselines: BaselineType = None,
84
+ target: TargetType = None,
85
+ additional_forward_args: Any = None,
86
+ n_steps: int = 50,
87
+ method: str = "gausslegendre",
88
+ internal_batch_size: Union[None, int] = None,
89
+ return_convergence_delta: Literal[False] = False,
90
+ ) -> TensorOrTupleOfTensorsGeneric:
91
+ ...
92
+
93
+ @typing.overload
94
+ def attribute(
95
+ self,
96
+ inputs: TensorOrTupleOfTensorsGeneric,
97
+ baselines: BaselineType = None,
98
+ target: TargetType = None,
99
+ additional_forward_args: Any = None,
100
+ n_steps: int = 50,
101
+ method: str = "gausslegendre",
102
+ internal_batch_size: Union[None, int] = None,
103
+ *,
104
+ return_convergence_delta: Literal[True],
105
+ ) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]:
106
+ ...
107
+
108
+ @log_usage()
109
+ def attribute( # type: ignore
110
+ self,
111
+ inputs: TensorOrTupleOfTensorsGeneric,
112
+ baselines: BaselineType = None,
113
+ target: TargetType = None,
114
+ additional_forward_args: Any = None,
115
+ n_steps: int = 50,
116
+ method: str = "gausslegendre",
117
+ internal_batch_size: Union[None, int] = None,
118
+ return_convergence_delta: bool = False,
119
+ ) -> Union[
120
+ TensorOrTupleOfTensorsGeneric, Tuple[TensorOrTupleOfTensorsGeneric, Tensor]
121
+ ]:
122
+ r"""
123
+ This method attributes the output of the model with given target index
124
+ (in case it is provided, otherwise it assumes that output is a
125
+ scalar) to the inputs of the model using the approach described above.
126
+
127
+ In addition to that it also returns, if `return_convergence_delta` is
128
+ set to True, integral approximation delta based on the completeness
129
+ property of integrated gradients.
130
+
131
+ Args:
132
+
133
+ inputs (tensor or tuple of tensors): Input for which integrated
134
+ gradients are computed. If forward_func takes a single
135
+ tensor as input, a single input tensor should be provided.
136
+ If forward_func takes multiple tensors as input, a tuple
137
+ of the input tensors should be provided. It is assumed
138
+ that for all given input tensors, dimension 0 corresponds
139
+ to the number of examples, and if multiple input tensors
140
+ are provided, the examples must be aligned appropriately.
141
+ baselines (scalar, tensor, tuple of scalars or tensors, optional):
142
+ Baselines define the starting point from which integral
143
+ is computed and can be provided as:
144
+
145
+ - a single tensor, if inputs is a single tensor, with
146
+ exactly the same dimensions as inputs or the first
147
+ dimension is one and the remaining dimensions match
148
+ with inputs.
149
+
150
+ - a single scalar, if inputs is a single tensor, which will
151
+ be broadcasted for each input value in input tensor.
152
+
153
+ - a tuple of tensors or scalars, the baseline corresponding
154
+ to each tensor in the inputs' tuple can be:
155
+
156
+ - either a tensor with matching dimensions to
157
+ corresponding tensor in the inputs' tuple
158
+ or the first dimension is one and the remaining
159
+ dimensions match with the corresponding
160
+ input tensor.
161
+
162
+ - or a scalar, corresponding to a tensor in the
163
+ inputs' tuple. This scalar value is broadcasted
164
+ for corresponding input tensor.
165
+ In the cases when `baselines` is not provided, we internally
166
+ use zero scalar corresponding to each input tensor.
167
+
168
+ Default: None
169
+ target (int, tuple, tensor or list, optional): Output indices for
170
+ which gradients are computed (for classification cases,
171
+ this is usually the target class).
172
+ If the network returns a scalar value per example,
173
+ no target index is necessary.
174
+ For general 2D outputs, targets can be either:
175
+
176
+ - a single integer or a tensor containing a single
177
+ integer, which is applied to all input examples
178
+
179
+ - a list of integers or a 1D tensor, with length matching
180
+ the number of examples in inputs (dim 0). Each integer
181
+ is applied as the target for the corresponding example.
182
+
183
+ For outputs with > 2 dimensions, targets can be either:
184
+
185
+ - A single tuple, which contains #output_dims - 1
186
+ elements. This target index is applied to all examples.
187
+
188
+ - A list of tuples with length equal to the number of
189
+ examples in inputs (dim 0), and each tuple containing
190
+ #output_dims - 1 elements. Each tuple is applied as the
191
+ target for the corresponding example.
192
+
193
+ Default: None
194
+ additional_forward_args (any, optional): If the forward function
195
+ requires additional arguments other than the inputs for
196
+ which attributions should not be computed, this argument
197
+ can be provided. It must be either a single additional
198
+ argument of a Tensor or arbitrary (non-tuple) type or a
199
+ tuple containing multiple additional arguments including
200
+ tensors or any arbitrary python types. These arguments
201
+ are provided to forward_func in order following the
202
+ arguments in inputs.
203
+ For a tensor, the first dimension of the tensor must
204
+ correspond to the number of examples. It will be
205
+ repeated for each of `n_steps` along the integrated
206
+ path. For all other types, the given argument is used
207
+ for all forward evaluations.
208
+ Note that attributions are not computed with respect
209
+ to these arguments.
210
+ Default: None
211
+ n_steps (int, optional): The number of steps used by the approximation
212
+ method. Default: 50.
213
+ method (string, optional): Method for approximating the integral,
214
+ one of `riemann_right`, `riemann_left`, `riemann_middle`,
215
+ `riemann_trapezoid` or `gausslegendre`.
216
+ Default: `gausslegendre` if no method is provided.
217
+ internal_batch_size (int, optional): Divides total #steps * #examples
218
+ data points into chunks of size at most internal_batch_size,
219
+ which are computed (forward / backward passes)
220
+ sequentially. internal_batch_size must be at least equal to
221
+ #examples.
222
+ For DataParallel models, each batch is split among the
223
+ available devices, so evaluations on each available
224
+ device contain internal_batch_size / num_devices examples.
225
+ If internal_batch_size is None, then all evaluations are
226
+ processed in one batch.
227
+ Default: None
228
+ return_convergence_delta (bool, optional): Indicates whether to return
229
+ convergence delta or not. If `return_convergence_delta`
230
+ is set to True convergence delta will be returned in
231
+ a tuple following attributions.
232
+ Default: False
233
+ Returns:
234
+ **attributions** or 2-element tuple of **attributions**, **delta**:
235
+ - **attributions** (*tensor* or tuple of *tensors*):
236
+ Integrated gradients with respect to each input feature.
237
+ attributions will always be the same size as the provided
238
+ inputs, with each value providing the attribution of the
239
+ corresponding input index.
240
+ If a single tensor is provided as inputs, a single tensor is
241
+ returned. If a tuple is provided for inputs, a tuple of
242
+ corresponding sized tensors is returned.
243
+ - **delta** (*tensor*, returned if return_convergence_delta=True):
244
+ The difference between the total approximated and true
245
+ integrated gradients. This is computed using the property
246
+ that the total sum of forward_func(inputs) -
247
+ forward_func(baselines) must equal the total sum of the
248
+ integrated gradient.
249
+ Delta is calculated per example, meaning that the number of
250
+ elements in the returned delta tensor is equal to the number
251
+ of examples in inputs.
252
+
253
+ Examples::
254
+
255
+ >>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
256
+ >>> # and returns an Nx10 tensor of class probabilities.
257
+ >>> net = ImageClassifier()
258
+ >>> ig = IntegratedGradients(net)
259
+ >>> input = torch.randn(2, 3, 32, 32, requires_grad=True)
260
+ >>> # Computes integrated gradients for class 3.
261
+ >>> attribution = ig.attribute(input, target=3)
262
+ """
263
+ # Keeps track whether original input is a tuple or not before
264
+ # converting it into a tuple.
265
+ is_inputs_tuple = _is_tuple(inputs)
266
+
267
+ inputs, baselines = _format_input_baseline(inputs, baselines)
268
+
269
+ _validate_input(inputs, baselines, n_steps, method)
270
+
271
+ if internal_batch_size is not None:
272
+ num_examples = inputs[0].shape[0]
273
+ attributions = _batch_attribution(
274
+ self,
275
+ num_examples,
276
+ internal_batch_size,
277
+ n_steps,
278
+ inputs=inputs,
279
+ baselines=baselines,
280
+ target=target,
281
+ additional_forward_args=additional_forward_args,
282
+ method=method,
283
+ )
284
+ else:
285
+ attributions = self._attribute(
286
+ inputs=inputs,
287
+ baselines=baselines,
288
+ target=target,
289
+ additional_forward_args=additional_forward_args,
290
+ n_steps=n_steps,
291
+ method=method,
292
+ )
293
+
294
+ if return_convergence_delta:
295
+ start_point, end_point = baselines, inputs
296
+ # computes approximation error based on the completeness axiom
297
+ delta = self.compute_convergence_delta(
298
+ attributions,
299
+ start_point,
300
+ end_point,
301
+ additional_forward_args=additional_forward_args,
302
+ target=target,
303
+ )
304
+ return _format_output(is_inputs_tuple, attributions), delta
305
+ return _format_output(is_inputs_tuple, attributions)
306
+
307
+ def _attribute(
308
+ self,
309
+ inputs: Tuple[Tensor, ...],
310
+ baselines: Tuple[Union[Tensor, int, float], ...],
311
+ target: TargetType = None,
312
+ additional_forward_args: Any = None,
313
+ n_steps: int = 50,
314
+ method: str = "gausslegendre",
315
+ step_sizes_and_alphas: Union[None, Tuple[List[float], List[float]]] = None,
316
+ ) -> Tuple[Tensor, ...]:
317
+ if step_sizes_and_alphas is None:
318
+ # retrieve step size and scaling factor for specified
319
+ # approximation method
320
+ step_sizes_func, alphas_func = approximation_parameters(method)
321
+ step_sizes, alphas = step_sizes_func(n_steps), alphas_func(n_steps)
322
+ else:
323
+ step_sizes, alphas = step_sizes_and_alphas
324
+
325
+ # scale features and compute gradients. (batch size is abbreviated as bsz)
326
+ # scaled_features' dim -> (bsz * #steps x inputs[0].shape[1:], ...)
327
+ scaled_features_tpl = tuple(
328
+ torch.cat(
329
+ [baseline + alpha * (input - baseline) for alpha in alphas], dim=0
330
+ ).requires_grad_()
331
+ for input, baseline in zip(inputs, baselines)
332
+ )
333
+
334
+ additional_forward_args = _format_additional_forward_args(
335
+ additional_forward_args
336
+ )
337
+ # apply number of steps to additional forward args
338
+ # currently, number of steps is applied only to additional forward arguments
339
+ # that are nd-tensors. It is assumed that the first dimension is
340
+ # the number of batches.
341
+ # dim -> (bsz * #steps x additional_forward_args[0].shape[1:], ...)
342
+ input_additional_args = (
343
+ _expand_additional_forward_args(additional_forward_args, n_steps)
344
+ if additional_forward_args is not None
345
+ else None
346
+ )
347
+ expanded_target = _expand_target(target, n_steps)
348
+
349
+ # grads: dim -> (bsz * #steps x inputs[0].shape[1:], ...)
350
+ grads = self.gradient_func(
351
+ forward_fn=self.forward_func,
352
+ inputs=scaled_features_tpl,
353
+ target_ind=expanded_target,
354
+ additional_forward_args=input_additional_args,
355
+ )
356
+
357
+ # flattening grads so that we can multiply them by the step sizes
358
+ # calling contiguous to avoid `memory whole` problems
359
+ scaled_grads = [
360
+ grad.contiguous().view(n_steps, -1)
361
+ * torch.tensor(step_sizes).view(n_steps, 1).to(grad.device)
362
+ for grad in grads
363
+ ]
364
+
365
+ # aggregates across all steps for each tensor in the input tuple
366
+ # total_grads has the same dimensionality as inputs
367
+ total_grads = tuple(
368
+ _reshape_and_sum(
369
+ scaled_grad, n_steps, grad.shape[0] // n_steps, grad.shape[1:]
370
+ )
371
+ for (scaled_grad, grad) in zip(scaled_grads, grads)
372
+ )
373
+
374
+ # computes attribution for each tensor in input tuple
375
+ # attributions has the same dimensionality as inputs
376
+ if not self.multiplies_by_inputs:
377
+ attributions = total_grads
378
+ else:
379
+ attributions = tuple(
380
+ total_grad * (input - baseline)
381
+ for total_grad, input, baseline in zip(total_grads, inputs, baselines)
382
+ )
383
+ return attributions
384
+
385
+ def has_convergence_delta(self) -> bool:
386
+ return True
387
+
388
+ @property
389
+ def multiplies_by_inputs(self):
390
+ return self._multiply_by_inputs
captum/attr/_core/kernel_shap.py ADDED
@@ -0,0 +1,348 @@
1
+ #!/usr/bin/env python3
2
+
3
+ from typing import Any, Callable, Generator, Tuple, Union
4
+
5
+ import torch
6
+ from captum._utils.models.linear_model import SkLearnLinearRegression
7
+ from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
8
+ from captum.attr._core.lime import construct_feature_mask, Lime
9
+ from captum.attr._utils.common import _format_input_baseline
10
+ from captum.log import log_usage
11
+ from torch import Tensor
12
+ from torch.distributions.categorical import Categorical
13
+
14
+
15
+ class KernelShap(Lime):
16
+ r"""
17
+ Kernel SHAP is a method that uses the LIME framework to compute
18
+ Shapley Values. Setting the loss function, weighting kernel and
19
+ regularization terms appropriately in the LIME framework makes it
20
+ possible, in theory, to obtain Shapley Values more efficiently than
21
+ computing them directly.
22
+
23
+ More information regarding this method and proof of equivalence
24
+ can be found in the original paper here:
25
+ https://arxiv.org/abs/1705.07874
26
+ """
27
+
28
+ def __init__(self, forward_func: Callable) -> None:
29
+ r"""
30
+ Args:
31
+
32
+ forward_func (callable): The forward function of the model or
33
+ any modification of it
34
+ """
35
+ Lime.__init__(
36
+ self,
37
+ forward_func,
38
+ interpretable_model=SkLearnLinearRegression(),
39
+ similarity_func=self.kernel_shap_similarity_kernel,
40
+ perturb_func=self.kernel_shap_perturb_generator,
41
+ )
42
+ self.inf_weight = 1000000.0
43
+
44
+ @log_usage()
45
+ def attribute( # type: ignore
46
+ self,
47
+ inputs: TensorOrTupleOfTensorsGeneric,
48
+ baselines: BaselineType = None,
49
+ target: TargetType = None,
50
+ additional_forward_args: Any = None,
51
+ feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
52
+ n_samples: int = 25,
53
+ perturbations_per_eval: int = 1,
54
+ return_input_shape: bool = True,
55
+ show_progress: bool = False,
56
+ ) -> TensorOrTupleOfTensorsGeneric:
57
+ r"""
58
+ This method attributes the output of the model with given target index
59
+ (in case it is provided, otherwise it assumes that output is a
60
+ scalar) to the inputs of the model using the approach described above,
61
+ training an interpretable model based on KernelSHAP and returning a
62
+ representation of the interpretable model.
63
+
64
+ It is recommended to only provide a single example as input (tensors
65
+ with first dimension or batch size = 1). This is because LIME / KernelShap
66
+ is generally used for sample-based interpretability, training a separate
67
+ interpretable model to explain a model's prediction on each individual example.
68
+
69
+ A batch of inputs can also be provided as inputs, similar to
70
+ other perturbation-based attribution methods. In this case, if forward_fn
71
+ returns a scalar per example, attributions will be computed for each
72
+ example independently, with a separate interpretable model trained for each
73
+ example. Note that the provided similarity and perturbation functions will
74
+ be given each example separately (first dimension = 1) in this case.
75
+ If forward_fn returns a scalar per batch (e.g. loss), attributions will
76
+ still be computed using a single interpretable model for the full batch.
77
+ In this case, similarity and perturbation functions will be provided the
78
+ same original input containing the full batch.
79
+
80
+ The number of interpretable features is determined from the provided
81
+ feature mask, or if none is provided, from the default feature mask,
82
+ which considers each scalar input as a separate feature. It is
83
+ generally recommended to provide a feature mask which groups features
84
+ into a small number of interpretable features / components (e.g.
85
+ superpixels in images).
86
+
87
+ Args:
88
+
89
+ inputs (tensor or tuple of tensors): Input for which KernelShap
90
+ is computed. If forward_func takes a single
91
+ tensor as input, a single input tensor should be provided.
92
+ If forward_func takes multiple tensors as input, a tuple
93
+ of the input tensors should be provided. It is assumed
94
+ that for all given input tensors, dimension 0 corresponds
95
+ to the number of examples, and if multiple input tensors
96
+ are provided, the examples must be aligned appropriately.
97
+ baselines (scalar, tensor, tuple of scalars or tensors, optional):
98
+ Baselines define the reference value which replaces each
99
+ feature when the corresponding interpretable feature
100
+ is set to 0.
101
+ Baselines can be provided as:
102
+
103
+ - a single tensor, if inputs is a single tensor, with
104
+ exactly the same dimensions as inputs or the first
105
+ dimension is one and the remaining dimensions match
106
+ with inputs.
107
+
108
+ - a single scalar, if inputs is a single tensor, which will
109
+ be broadcasted for each input value in input tensor.
110
+
111
+ - a tuple of tensors or scalars, the baseline corresponding
112
+ to each tensor in the inputs' tuple can be:
113
+
114
+ - either a tensor with matching dimensions to
115
+ corresponding tensor in the inputs' tuple
116
+ or the first dimension is one and the remaining
117
+ dimensions match with the corresponding
118
+ input tensor.
119
+
120
+ - or a scalar, corresponding to a tensor in the
121
+ inputs' tuple. This scalar value is broadcasted
122
+ for corresponding input tensor.
123
+ In the cases when `baselines` is not provided, we internally
124
+ use zero scalar corresponding to each input tensor.
125
+ Default: None
126
+ target (int, tuple, tensor or list, optional): Output indices for
127
+ which surrogate model is trained
128
+ (for classification cases,
129
+ this is usually the target class).
130
+ If the network returns a scalar value per example,
131
+ no target index is necessary.
132
+ For general 2D outputs, targets can be either:
133
+
134
+ - a single integer or a tensor containing a single
135
+ integer, which is applied to all input examples
136
+
137
+ - a list of integers or a 1D tensor, with length matching
138
+ the number of examples in inputs (dim 0). Each integer
139
+ is applied as the target for the corresponding example.
140
+
141
+ For outputs with > 2 dimensions, targets can be either:
142
+
143
+ - A single tuple, which contains #output_dims - 1
144
+ elements. This target index is applied to all examples.
145
+
146
+ - A list of tuples with length equal to the number of
147
+ examples in inputs (dim 0), and each tuple containing
148
+ #output_dims - 1 elements. Each tuple is applied as the
149
+ target for the corresponding example.
150
+
151
+ Default: None
152
+ additional_forward_args (any, optional): If the forward function
153
+ requires additional arguments other than the inputs for
154
+ which attributions should not be computed, this argument
155
+ can be provided. It must be either a single additional
156
+ argument of a Tensor or arbitrary (non-tuple) type or a
157
+ tuple containing multiple additional arguments including
158
+ tensors or any arbitrary python types. These arguments
159
+ are provided to forward_func in order following the
160
+ arguments in inputs.
161
+ For a tensor, the first dimension of the tensor must
162
+ correspond to the number of examples. It will be
163
+ repeated for each of `n_steps` along the integrated
164
+ path. For all other types, the given argument is used
165
+ for all forward evaluations.
166
+ Note that attributions are not computed with respect
167
+ to these arguments.
168
+ Default: None
169
+ feature_mask (tensor or tuple of tensors, optional):
170
+ feature_mask defines a mask for the input, grouping
171
+ features which correspond to the same
172
+ interpretable feature. feature_mask
173
+ should contain the same number of tensors as inputs.
174
+ Each tensor should
175
+ be the same size as the corresponding input or
176
+ broadcastable to match the input tensor. Values across
177
+ all tensors should be integers in the range 0 to
178
+ num_interp_features - 1, and indices corresponding to the
179
+ same feature should have the same value.
180
+ Note that features are grouped across tensors
181
+ (unlike feature ablation and occlusion), so
182
+ if the same index is used in different tensors, those
183
+ features are still grouped and added simultaneously.
184
+ If None, then a feature mask is constructed which assigns
185
+ each scalar within a tensor as a separate feature.
186
+ Default: None
187
+ n_samples (int, optional): The number of samples of the original
188
+ model used to train the surrogate interpretable model.
189
+ Default: `25` if `n_samples` is not provided.
190
+ perturbations_per_eval (int, optional): Allows multiple samples
191
+ to be processed simultaneously in one call to forward_fn.
192
+ Each forward pass will contain a maximum of
193
+ perturbations_per_eval * #examples samples.
194
+ For DataParallel models, each batch is split among the
195
+ available devices, so evaluations on each available
196
+ device contain at most
197
+ (perturbations_per_eval * #examples) / num_devices
198
+ samples.
199
+ If the forward function returns a single scalar per batch,
200
+ perturbations_per_eval must be set to 1.
201
+ Default: 1
202
+ return_input_shape (bool, optional): Determines whether the returned
203
+ tensor(s) only contain the coefficients for each
204
+ interpretable feature from the trained surrogate model, or
205
+ whether the returned attributions match the input shape.
206
+ When return_input_shape is True, the return type of attribute
207
+ matches the input shape, with each element containing the
208
+ coefficient of the corresponding interpretable feature.
209
+ All elements with the same value in the feature mask
210
+ will contain the same coefficient in the returned
211
+ attributions. If return_input_shape is False, a 1D
212
+ tensor is returned, containing only the coefficients
213
+ of the trained interpretable model, with length
214
+ num_interp_features.
215
+ show_progress (bool, optional): Displays the progress of computation.
216
+ It will try to use tqdm if available for advanced features
217
+ (e.g. time estimation). Otherwise, it will fall back to
218
+ a simple output of progress.
219
+ Default: False
220
+
221
+ Returns:
222
+ *tensor* or tuple of *tensors* of **attributions**:
223
+ - **attributions** (*tensor* or tuple of *tensors*):
224
+ The attributions with respect to each input feature.
225
+ If return_input_shape = True, attributions will be
226
+ the same size as the provided inputs, with each value
227
+ providing the coefficient of the corresponding
228
+ interpretable feature.
229
+ If return_input_shape is False, a 1D
230
+ tensor is returned, containing only the coefficients
231
+ of the trained interpretable model, with length
232
+ num_interp_features.
233
+ Examples::
234
+ >>> # SimpleClassifier takes a single input tensor of size Nx4x4,
235
+ >>> # and returns an Nx3 tensor of class probabilities.
236
+ >>> net = SimpleClassifier()
237
+
238
+ >>> # Generating random input with size 1 x 4 x 4
239
+ >>> input = torch.randn(1, 4, 4)
240
+
241
+ >>> # Defining KernelShap interpreter
242
+ >>> ks = KernelShap(net)
243
+ >>> # Computes attribution, with each of the 4 x 4 = 16
244
+ >>> # features as a separate interpretable feature
245
+ >>> attr = ks.attribute(input, target=1, n_samples=200)
246
+
247
+ >>> # Alternatively, we can group each 2x2 square of the inputs
248
+ >>> # as one 'interpretable' feature and perturb them together.
249
+ >>> # This can be done by creating a feature mask as follows, which
250
+ >>> # defines the feature groups, e.g.:
251
+ >>> # +---+---+---+---+
252
+ >>> # | 0 | 0 | 1 | 1 |
253
+ >>> # +---+---+---+---+
254
+ >>> # | 0 | 0 | 1 | 1 |
255
+ >>> # +---+---+---+---+
256
+ >>> # | 2 | 2 | 3 | 3 |
257
+ >>> # +---+---+---+---+
258
+ >>> # | 2 | 2 | 3 | 3 |
259
+ >>> # +---+---+---+---+
260
+ >>> # With this mask, all inputs with the same value are set to their
261
+ >>> # baseline value, when the corresponding binary interpretable
262
+ >>> # feature is set to 0.
263
+ >>> # The attributions can be calculated as follows:
264
+ >>> # feature mask has dimensions 1 x 4 x 4
265
+ >>> feature_mask = torch.tensor([[[0,0,1,1],[0,0,1,1],
266
+ >>> [2,2,3,3],[2,2,3,3]]])
267
+
268
+ >>> # Computes KernelSHAP attributions with feature mask.
269
+ >>> attr = ks.attribute(input, target=1, feature_mask=feature_mask)
270
+ """
271
+ formatted_inputs, baselines = _format_input_baseline(inputs, baselines)
272
+ feature_mask, num_interp_features = construct_feature_mask(
273
+ feature_mask, formatted_inputs
274
+ )
275
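+ # Build the sampling distribution over coalition sizes k used by the
+ # perturbation generator: p(k) is proportional to
+ # (num_interp_features - 1) / (k * (num_interp_features - k)).
+ # k = 0 would divide by zero, so its probability is explicitly zeroed;
+ # the all-zeros and all-ones coalitions are produced deterministically
+ # by the generator instead.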
+ num_features_list = torch.arange(num_interp_features, dtype=torch.float)
276
+ denom = num_features_list * (num_interp_features - num_features_list)
277
+ probs = (num_interp_features - 1) / denom
278
+ probs[0] = 0.0
279
+ return self._attribute_kwargs(
280
+ inputs=inputs,
281
+ baselines=baselines,
282
+ target=target,
283
+ additional_forward_args=additional_forward_args,
284
+ feature_mask=feature_mask,
285
+ n_samples=n_samples,
286
+ perturbations_per_eval=perturbations_per_eval,
287
+ return_input_shape=return_input_shape,
288
+ num_select_distribution=Categorical(probs),
289
+ show_progress=show_progress,
290
+ )
291
+
292
+ def kernel_shap_similarity_kernel(
293
+ self, _, __, interpretable_sample: Tensor, **kwargs
294
+ ) -> Tensor:
295
+ assert (
296
+ "num_interp_features" in kwargs
297
+ ), "Must provide num_interp_features to use default similarity kernel"
298
+ num_selected_features = int(interpretable_sample.sum(dim=1).item())
299
+ num_features = kwargs["num_interp_features"]
300
+ if num_selected_features == 0 or num_selected_features == num_features:
301
+ # weight should be theoretically infinite when
302
+ # num_selected_features = 0 or num_features
303
+ # enforcing that trained linear model must satisfy
304
+ # end-point criteria. In practice, it is sufficient to
305
+ # make this weight substantially larger, so it is set
306
+ # to 1000000 (all other weights are 1).
307
+ similarities = self.inf_weight
308
+ else:
309
+ similarities = 1.0
310
+ return torch.tensor([similarities])
311
+
312
+ def kernel_shap_perturb_generator(
313
+ self, original_inp: Union[Tensor, Tuple[Tensor, ...]], **kwargs
314
+ ) -> Generator[Tensor, None, None]:
315
+ r"""
316
+ Perturbations are sampled by the following process:
317
+ - Choose k (number of selected features), based on the distribution
318
+ p(k) = (M - 1) / (k * (M - k))
319
+ where M is the total number of features in the interpretable space
320
+ - Randomly select a binary vector with k ones, each sample is equally
321
+ likely. This is done by generating a random vector of normal
322
+ values and thresholding based on the top k elements.
323
+
324
+ Since there are M choose k vectors with k ones, this weighted sampling
325
+ is equivalent to applying the Shapley kernel for the sample weight,
326
+ defined as:
327
+ k(M, k) = (M - 1) / (k * (M - k) * (M choose k))
328
+ """
329
+ assert (
330
+ "num_select_distribution" in kwargs and "num_interp_features" in kwargs
331
+ ), (
332
+ "num_select_distribution and num_interp_features are necessary"
333
+ " to use kernel_shap_perturb_func"
334
+ )
335
+ if isinstance(original_inp, Tensor):
336
+ device = original_inp.device
337
+ else:
338
+ device = original_inp[0].device
339
+ num_features = kwargs["num_interp_features"]
340
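+ # The all-ones and all-zeros coalitions are always yielded first; the
+ # similarity kernel above assigns them the large end-point weight
+ # (self.inf_weight) so the surrogate model honors both end points.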
+ yield torch.ones(1, num_features, device=device, dtype=torch.long)
341
+ yield torch.zeros(1, num_features, device=device, dtype=torch.long)
342
+ while True:
343
+ num_selected_features = kwargs["num_select_distribution"].sample()
344
+ rand_vals = torch.randn(1, num_features)
345
+ threshold = torch.kthvalue(
346
+ rand_vals, num_features - num_selected_features
347
+ ).values.item()
348
+ yield (rand_vals > threshold).to(device=device).long()
captum/attr/_core/layer/__init__.py ADDED
File without changes
captum/attr/_core/layer/grad_cam.py ADDED
@@ -0,0 +1,217 @@
1
+ #!/usr/bin/env python3
2
+ from typing import Any, Callable, List, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from captum._utils.common import (
7
+ _format_additional_forward_args,
8
+ _format_output,
9
+ _format_tensor_into_tuples,
10
+ )
11
+ from captum._utils.gradient import compute_layer_gradients_and_eval
12
+ from captum._utils.typing import TargetType
13
+ from captum.attr._utils.attribution import GradientAttribution, LayerAttribution
14
+ from captum.log import log_usage
15
+ from torch import Tensor
16
+ from torch.nn import Module
17
+
18
+
19
+ class LayerGradCam(LayerAttribution, GradientAttribution):
20
+ r"""
21
+ Computes GradCAM attribution for chosen layer. GradCAM is designed for
22
+ convolutional neural networks, and is usually applied to the last
23
+ convolutional layer.
24
+
25
+ GradCAM computes the gradients of the target output with respect to
26
+ the given layer, averages for each output channel (dimension 2 of
27
+ output), and multiplies the average gradient for each channel by the
28
+ layer activations. The results are summed over all channels.
29
+
30
+ Note that in the original GradCAM algorithm described in the paper,
31
+ ReLU is applied to the output, returning only non-negative attributions.
32
+ For providing more flexibility to the user, we choose to not perform the
33
+ ReLU internally by default and return the sign information. To match the
34
+ original GradCAM algorithm, it is necessary to pass the parameter
35
+ relu_attributions=True to apply ReLU on the final
36
+ attributions or alternatively only visualize the positive attributions.
37
+
38
+ Note: this procedure sums over the second dimension (# of channels),
39
+ so the output of GradCAM attributions will have a second
40
+ dimension of 1, but all other dimensions will match that of the layer
41
+ output.
42
+
43
+ GradCAM attributions are generally upsampled and can be viewed as a
44
+ mask to the input, since a convolutional layer output generally
45
+ matches the input image spatially. This upsampling can be performed
46
+ using LayerAttribution.interpolate, as shown in the example below.
47
+
48
+ More details regarding the GradCAM method can be found in the
49
+ original paper here:
50
+ https://arxiv.org/pdf/1610.02391.pdf
51
+ """
52
+
53
+ def __init__(
54
+ self,
55
+ forward_func: Callable,
56
+ layer: Module,
57
+ device_ids: Union[None, List[int]] = None,
58
+ ) -> None:
59
+ r"""
60
+ Args:
61
+
62
+ forward_func (callable): The forward function of the model or any
63
+ modification of it
64
+ layer (torch.nn.Module): Layer for which attributions are computed.
65
+ Output size of attribute matches this layer's output
66
+ dimensions, except for dimension 2, which will be 1,
67
+ since GradCAM sums over channels.
68
+ device_ids (list(int)): Device ID list, necessary only if forward_func
69
+ applies a DataParallel model. This allows reconstruction of
70
+ intermediate outputs from batched results across devices.
71
+ If forward_func is given as the DataParallel model itself,
72
+ then it is not necessary to provide this argument.
73
+ """
74
+ LayerAttribution.__init__(self, forward_func, layer, device_ids)
75
+ GradientAttribution.__init__(self, forward_func)
76
+
77
+ @log_usage()
78
+ def attribute(
79
+ self,
80
+ inputs: Union[Tensor, Tuple[Tensor, ...]],
81
+ target: TargetType = None,
82
+ additional_forward_args: Any = None,
83
+ attribute_to_layer_input: bool = False,
84
+ relu_attributions: bool = False,
85
+ ) -> Union[Tensor, Tuple[Tensor, ...]]:
86
+ r"""
87
+ Args:
88
+
89
+ inputs (tensor or tuple of tensors): Input for which attributions
90
+ are computed. If forward_func takes a single
91
+ tensor as input, a single input tensor should be provided.
92
+ If forward_func takes multiple tensors as input, a tuple
93
+ of the input tensors should be provided. It is assumed
94
+ that for all given input tensors, dimension 0 corresponds
95
+ to the number of examples, and if multiple input tensors
96
+ are provided, the examples must be aligned appropriately.
97
+ target (int, tuple, tensor or list, optional): Output indices for
98
+ which gradients are computed (for classification cases,
99
+ this is usually the target class).
100
+ If the network returns a scalar value per example,
101
+ no target index is necessary.
102
+ For general 2D outputs, targets can be either:
103
+
104
+ - a single integer or a tensor containing a single
105
+ integer, which is applied to all input examples
106
+
107
+ - a list of integers or a 1D tensor, with length matching
108
+ the number of examples in inputs (dim 0). Each integer
109
+ is applied as the target for the corresponding example.
110
+
111
+ For outputs with > 2 dimensions, targets can be either:
112
+
113
+ - A single tuple, which contains #output_dims - 1
114
+ elements. This target index is applied to all examples.
115
+
116
+ - A list of tuples with length equal to the number of
117
+ examples in inputs (dim 0), and each tuple containing
118
+ #output_dims - 1 elements. Each tuple is applied as the
119
+ target for the corresponding example.
120
+
121
+ Default: None
122
+ additional_forward_args (any, optional): If the forward function
123
+ requires additional arguments other than the inputs for
124
+ which attributions should not be computed, this argument
125
+ can be provided. It must be either a single additional
126
+ argument of a Tensor or arbitrary (non-tuple) type or a
127
+ tuple containing multiple additional arguments including
128
+ tensors or any arbitrary python types. These arguments
129
+ are provided to forward_func in order following the
130
+ arguments in inputs.
131
+ Note that attributions are not computed with respect
132
+ to these arguments.
133
+ Default: None
134
+ attribute_to_layer_input (bool, optional): Indicates whether to
135
+ compute the attributions with respect to the layer input
136
+ or output. If `attribute_to_layer_input` is set to True
137
+ then the attributions will be computed with respect to the
138
+ layer input, otherwise it will be computed with respect
139
+ to layer output.
140
+ Note that currently it is assumed that either the input
141
+ or the outputs of internal layers, depending on whether we
142
+ attribute to the input or output, are single tensors.
143
+ Support for multiple tensors will be added later.
144
+ Default: False
145
+ relu_attributions (bool, optional): Indicates whether to
146
+ apply a ReLU operation on the final attribution,
147
+ returning only non-negative attributions. Setting this
148
+ flag to True matches the original GradCAM algorithm,
149
+ otherwise, by default, both positive and negative
150
+ attributions are returned.
151
+ Default: False
152
+
153
+ Returns:
154
+ *tensor* or tuple of *tensors* of **attributions**:
155
+ - **attributions** (*tensor* or tuple of *tensors*):
156
+ Attributions based on GradCAM method.
157
+ Attributions will be the same size as the
158
+ output of the given layer, except for dimension 2,
159
+ which will be 1 due to summing over channels.
160
+ Attributions are returned in a tuple if
161
+ the layer inputs / outputs contain multiple tensors,
162
+ otherwise a single tensor is returned.
163
+ Examples::
164
+
165
+ >>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
166
+ >>> # and returns an Nx10 tensor of class probabilities.
167
+ >>> # It contains a layer conv4, which is an instance of nn.conv2d,
168
+ >>> # and the output of this layer has dimensions Nx50x8x8.
169
+ >>> # It is the last convolution layer, which is the recommended
170
+ >>> # use case for GradCAM.
171
+ >>> net = ImageClassifier()
172
+ >>> layer_gc = LayerGradCam(net, net.conv4)
173
+ >>> input = torch.randn(2, 3, 32, 32, requires_grad=True)
174
+ >>> # Computes layer GradCAM for class 3.
175
+ >>> # attribution size matches layer output except for dimension
176
+ >>> # 1, so dimensions of attr would be Nx1x8x8.
177
+ >>> attr = layer_gc.attribute(input, 3)
178
+ >>> # GradCAM attributions are often upsampled and viewed as a
179
+ >>> # mask to the input, since the convolutional layer output
180
+ >>> # spatially matches the original input image.
181
+ >>> # This can be done with LayerAttribution's interpolate method.
182
+ >>> upsampled_attr = LayerAttribution.interpolate(attr, (32, 32))
183
+ """
184
+ inputs = _format_tensor_into_tuples(inputs)
185
+ additional_forward_args = _format_additional_forward_args(
186
+ additional_forward_args
187
+ )
188
+ # Returns gradient of output with respect to
189
+ # hidden layer and hidden layer evaluated at each input.
190
+ layer_gradients, layer_evals = compute_layer_gradients_and_eval(
191
+ self.forward_func,
192
+ self.layer,
193
+ inputs,
194
+ target,
195
+ additional_forward_args,
196
+ device_ids=self.device_ids,
197
+ attribute_to_layer_input=attribute_to_layer_input,
198
+ )
199
+
200
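+ # GradCAM channel weights: average the gradients over all spatial
+ # dimensions (dims 2 and up), giving one scalar weight per channel;
+ # the weighted activations are then summed over the channel dimension
+ # (dim 1, kept as a singleton) further below.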
+ summed_grads = tuple(
201
+ torch.mean(
202
+ layer_grad,
203
+ dim=tuple(x for x in range(2, len(layer_grad.shape))),
204
+ keepdim=True,
205
+ )
206
+ if len(layer_grad.shape) > 2
207
+ else layer_grad
208
+ for layer_grad in layer_gradients
209
+ )
210
+
211
+ scaled_acts = tuple(
212
+ torch.sum(summed_grad * layer_eval, dim=1, keepdim=True)
213
+ for summed_grad, layer_eval in zip(summed_grads, layer_evals)
214
+ )
215
+ if relu_attributions:
216
+ scaled_acts = tuple(F.relu(scaled_act) for scaled_act in scaled_acts)
217
+ return _format_output(len(scaled_acts) > 1, scaled_acts)
captum/attr/_core/layer/internal_influence.py ADDED
@@ -0,0 +1,309 @@
1
+ #!/usr/bin/env python3
2
+ from typing import Any, Callable, List, Tuple, Union
3
+
4
+ import torch
5
+ from captum._utils.common import (
6
+ _expand_additional_forward_args,
7
+ _expand_target,
8
+ _format_additional_forward_args,
9
+ _format_output,
10
+ )
11
+ from captum._utils.gradient import compute_layer_gradients_and_eval
12
+ from captum._utils.typing import BaselineType, TargetType
13
+ from captum.attr._utils.approximation_methods import approximation_parameters
14
+ from captum.attr._utils.attribution import GradientAttribution, LayerAttribution
15
+ from captum.attr._utils.batching import _batch_attribution
16
+ from captum.attr._utils.common import (
17
+ _format_input_baseline,
18
+ _reshape_and_sum,
19
+ _validate_input,
20
+ )
21
+ from captum.log import log_usage
22
+ from torch import Tensor
23
+ from torch.nn import Module
24
+
25
+
26
+ class InternalInfluence(LayerAttribution, GradientAttribution):
27
+ r"""
28
+ Computes internal influence by approximating the integral of gradients
29
+ for a particular layer along the path from a baseline input to the
30
+ given input.
31
+ If no baseline is provided, the default baseline is the zero tensor.
32
+ More details on this approach can be found here:
33
+ https://arxiv.org/pdf/1802.03788.pdf
34
+
35
+ Note that this method is similar to applying integrated gradients and
36
+ taking the layer as input, integrating the gradient of the output with
37
+ respect to the layer.
38
+ """
39
+
40
+ def __init__(
41
+ self,
42
+ forward_func: Callable,
43
+ layer: Module,
44
+ device_ids: Union[None, List[int]] = None,
45
+ ) -> None:
46
+ r"""
47
+ Args:
48
+
49
+ forward_func (callable): The forward function of the model or any
50
+ modification of it
51
+ layer (torch.nn.Module): Layer for which attributions are computed.
52
+ Output size of attribute matches this layer's input or
53
+ output dimensions, depending on whether we attribute to
54
+ the inputs or outputs of the layer, corresponding to
55
+ attribution of each neuron in the input or output of
56
+ this layer.
57
+ device_ids (list(int)): Device ID list, necessary only if forward_func
58
+ applies a DataParallel model. This allows reconstruction of
59
+ intermediate outputs from batched results across devices.
60
+ If forward_func is given as the DataParallel model itself,
61
+ then it is not necessary to provide this argument.
62
+ """
63
+ LayerAttribution.__init__(self, forward_func, layer, device_ids)
64
+ GradientAttribution.__init__(self, forward_func)
65
+
66
+ @log_usage()
67
+ def attribute(
68
+ self,
69
+ inputs: Union[Tensor, Tuple[Tensor, ...]],
70
+ baselines: BaselineType = None,
71
+ target: TargetType = None,
72
+ additional_forward_args: Any = None,
73
+ n_steps: int = 50,
74
+ method: str = "gausslegendre",
75
+ internal_batch_size: Union[None, int] = None,
76
+ attribute_to_layer_input: bool = False,
77
+ ) -> Union[Tensor, Tuple[Tensor, ...]]:
78
+ r"""
79
+ Args:
80
+
81
+ inputs (tensor or tuple of tensors): Input for which internal
82
+ influence is computed. If forward_func takes a single
83
+ tensor as input, a single input tensor should be provided.
84
+ If forward_func takes multiple tensors as input, a tuple
85
+ of the input tensors should be provided. It is assumed
86
+ that for all given input tensors, dimension 0 corresponds
87
+ to the number of examples, and if multiple input tensors
88
+ are provided, the examples must be aligned appropriately.
89
+ baselines (scalar, tensor, tuple of scalars or tensors, optional):
90
+ Baselines define a starting point from which integral
91
+ is computed and can be provided as:
92
+
93
+ - a single tensor, if inputs is a single tensor, with
94
+ exactly the same dimensions as inputs or the first
95
+ dimension is one and the remaining dimensions match
96
+ with inputs.
97
+
98
+ - a single scalar, if inputs is a single tensor, which will
99
+ be broadcasted for each input value in input tensor.
100
+
101
+ - a tuple of tensors or scalars, the baseline corresponding
102
+ to each tensor in the inputs' tuple can be:
103
+
104
+ - either a tensor with matching dimensions to
105
+ corresponding tensor in the inputs' tuple
106
+ or the first dimension is one and the remaining
107
+ dimensions match with the corresponding
108
+ input tensor.
109
+
110
+ - or a scalar, corresponding to a tensor in the
111
+ inputs' tuple. This scalar value is broadcasted
112
+ for corresponding input tensor.
113
+
114
+ In the cases when `baselines` is not provided, we internally
115
+ use zero scalar corresponding to each input tensor.
116
+
117
+ Default: None
118
+ target (int, tuple, tensor or list, optional): Output indices for
119
+ which gradients are computed (for classification cases,
120
+ this is usually the target class).
121
+ If the network returns a scalar value per example,
122
+ no target index is necessary.
123
+ For general 2D outputs, targets can be either:
124
+
125
+ - a single integer or a tensor containing a single
126
+ integer, which is applied to all input examples
127
+
128
+ - a list of integers or a 1D tensor, with length matching
129
+ the number of examples in inputs (dim 0). Each integer
130
+ is applied as the target for the corresponding example.
131
+
132
+ For outputs with > 2 dimensions, targets can be either:
133
+
134
+ - A single tuple, which contains #output_dims - 1
135
+ elements. This target index is applied to all examples.
136
+
137
+ - A list of tuples with length equal to the number of
138
+ examples in inputs (dim 0), and each tuple containing
139
+ #output_dims - 1 elements. Each tuple is applied as the
140
+ target for the corresponding example.
141
+
142
+ Default: None
143
+ additional_forward_args (any, optional): If the forward function
144
+ requires additional arguments other than the inputs for
145
+ which attributions should not be computed, this argument
146
+ can be provided. It must be either a single additional
147
+ argument of a Tensor or arbitrary (non-tuple) type or a
148
+ tuple containing multiple additional arguments including
149
+ tensors or any arbitrary python types. These arguments
150
+ are provided to forward_func in order following the
151
+ arguments in inputs.
152
+ For a tensor, the first dimension of the tensor must
153
+ correspond to the number of examples. It will be
154
+ repeated for each of `n_steps` along the integrated
155
+ path. For all other types, the given argument is used
156
+ for all forward evaluations.
157
+ Note that attributions are not computed with respect
158
+ to these arguments.
159
+ Default: None
160
+ n_steps (int, optional): The number of steps used by the approximation
161
+ method. Default: 50.
162
+ method (string, optional): Method for approximating the integral,
163
+ one of `riemann_right`, `riemann_left`, `riemann_middle`,
164
+ `riemann_trapezoid` or `gausslegendre`.
165
+ Default: `gausslegendre` if no method is provided.
166
+ internal_batch_size (int, optional): Divides total #steps * #examples
167
+ data points into chunks of size at most internal_batch_size,
168
+ which are computed (forward / backward passes)
169
+ sequentially. internal_batch_size must be at least equal to
170
+ #examples.
171
+ For DataParallel models, each batch is split among the
172
+ available devices, so evaluations on each available
173
+ device contain internal_batch_size / num_devices examples.
174
+ If internal_batch_size is None, then all evaluations
175
+ are processed in one batch.
176
+ Default: None
177
+ attribute_to_layer_input (bool, optional): Indicates whether to
178
+ compute the attribution with respect to the layer input
179
+ or output. If `attribute_to_layer_input` is set to True
180
+ then the attributions will be computed with respect to
181
+ layer inputs, otherwise it will be computed with respect
182
+ to layer outputs.
183
+ Note that currently it is assumed that either the input
184
+ or the output of internal layer, depending on whether we
185
+ attribute to the input or output, is a single tensor.
186
+ Support for multiple tensors will be added later.
187
+ Default: False
188
+
189
+ Returns:
190
+ *tensor* or tuple of *tensors* of **attributions**:
191
+ - **attributions** (*tensor* or tuple of *tensors*):
192
+ Internal influence of each neuron in given
193
+ layer output. Attributions will always be the same size
194
+ as the output or input of the given layer depending on
195
+ whether `attribute_to_layer_input` is set to `False` or
196
+ `True` respectively.
197
+ Attributions are returned in a tuple if
198
+ the layer inputs / outputs contain multiple tensors,
199
+ otherwise a single tensor is returned.
200
+
201
+ Examples::
202
+
203
+ >>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
204
+ >>> # and returns an Nx10 tensor of class probabilities.
205
+ >>> # It contains an attribute conv1, which is an instance of nn.conv2d,
206
+ >>> # and the output of this layer has dimensions Nx12x32x32.
207
+ >>> net = ImageClassifier()
208
+ >>> layer_int_inf = InternalInfluence(net, net.conv1)
209
+ >>> input = torch.randn(2, 3, 32, 32, requires_grad=True)
210
+ >>> # Computes layer internal influence.
211
+ >>> # attribution size matches layer output, Nx12x32x32
212
+ >>> attribution = layer_int_inf.attribute(input)
213
+ """
214
+ inputs, baselines = _format_input_baseline(inputs, baselines)
215
+ _validate_input(inputs, baselines, n_steps, method)
216
+ if internal_batch_size is not None:
217
+ num_examples = inputs[0].shape[0]
218
+ attrs = _batch_attribution(
219
+ self,
220
+ num_examples,
221
+ internal_batch_size,
222
+ n_steps,
223
+ inputs=inputs,
224
+ baselines=baselines,
225
+ target=target,
226
+ additional_forward_args=additional_forward_args,
227
+ method=method,
228
+ attribute_to_layer_input=attribute_to_layer_input,
229
+ )
230
+ else:
231
+ attrs = self._attribute(
232
+ inputs=inputs,
233
+ baselines=baselines,
234
+ target=target,
235
+ additional_forward_args=additional_forward_args,
236
+ n_steps=n_steps,
237
+ method=method,
238
+ attribute_to_layer_input=attribute_to_layer_input,
239
+ )
240
+
241
+ return attrs
242
+
243
+ def _attribute(
244
+ self,
245
+ inputs: Tuple[Tensor, ...],
246
+ baselines: Tuple[Union[Tensor, int, float], ...],
247
+ target: TargetType = None,
248
+ additional_forward_args: Any = None,
249
+ n_steps: int = 50,
250
+ method: str = "gausslegendre",
251
+ attribute_to_layer_input: bool = False,
252
+ step_sizes_and_alphas: Union[None, Tuple[List[float], List[float]]] = None,
253
+ ) -> Union[Tensor, Tuple[Tensor, ...]]:
254
+ if step_sizes_and_alphas is None:
255
+ # retrieve step size and scaling factor for specified approximation method
256
+ step_sizes_func, alphas_func = approximation_parameters(method)
257
+ step_sizes, alphas = step_sizes_func(n_steps), alphas_func(n_steps)
258
+ else:
259
+ step_sizes, alphas = step_sizes_and_alphas
260
+
261
+ # Compute scaled inputs from baseline to final input.
262
+ scaled_features_tpl = tuple(
263
+ torch.cat(
264
+ [baseline + alpha * (input - baseline) for alpha in alphas], dim=0
265
+ ).requires_grad_()
266
+ for input, baseline in zip(inputs, baselines)
267
+ )
268
+
269
+ additional_forward_args = _format_additional_forward_args(
270
+ additional_forward_args
271
+ )
272
+ # apply number of steps to additional forward args
273
+ # currently, number of steps is applied only to additional forward arguments
274
+ # that are nd-tensors. It is assumed that the first dimension is
275
+ # the number of batches.
276
+ # dim -> (bsz * #steps x additional_forward_args[0].shape[1:], ...)
277
+ input_additional_args = (
278
+ _expand_additional_forward_args(additional_forward_args, n_steps)
279
+ if additional_forward_args is not None
280
+ else None
281
+ )
282
+ expanded_target = _expand_target(target, n_steps)
283
+
284
+ # Returns gradient of output with respect to hidden layer.
285
+ layer_gradients, _ = compute_layer_gradients_and_eval(
286
+ forward_fn=self.forward_func,
287
+ layer=self.layer,
288
+ inputs=scaled_features_tpl,
289
+ target_ind=expanded_target,
290
+ additional_forward_args=input_additional_args,
291
+ device_ids=self.device_ids,
292
+ attribute_to_layer_input=attribute_to_layer_input,
293
+ )
294
+ # flattening grads so that we can multiply it with step-size
295
+ # calling contiguous so that view() does not fail on non-contiguous memory
296
+ scaled_grads = tuple(
297
+ layer_grad.contiguous().view(n_steps, -1)
298
+ * torch.tensor(step_sizes).view(n_steps, 1).to(layer_grad.device)
299
+ for layer_grad in layer_gradients
300
+ )
301
+
302
+ # aggregates across all steps for each tensor in the input tuple
303
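+ # Each aggregated attribution has shape (num_examples, *layer_shape):
+ # one internal-influence score per neuron of the chosen layer.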
+ attrs = tuple(
304
+ _reshape_and_sum(
305
+ scaled_grad, n_steps, inputs[0].shape[0], layer_grad.shape[1:]
306
+ )
307
+ for scaled_grad, layer_grad in zip(scaled_grads, layer_gradients)
308
+ )
309
+ return _format_output(len(attrs) > 1, attrs)
captum/attr/_core/layer/layer_activation.py ADDED
@@ -0,0 +1,136 @@
1
+ #!/usr/bin/env python3
2
+ from typing import Any, Callable, List, Tuple, Union
3
+
4
+ import torch
5
+ from captum._utils.common import _format_output
6
+ from captum._utils.gradient import _forward_layer_eval
7
+ from captum._utils.typing import ModuleOrModuleList
8
+ from captum.attr._utils.attribution import LayerAttribution
9
+ from captum.log import log_usage
10
+ from torch import Tensor
11
+ from torch.nn import Module
12
+
13
+
14
+ class LayerActivation(LayerAttribution):
15
+ r"""
16
+ Computes activation of selected layer for given input.
17
+ """
18
+
19
+ def __init__(
20
+ self,
21
+ forward_func: Callable,
22
+ layer: ModuleOrModuleList,
23
+ device_ids: Union[None, List[int]] = None,
24
+ ) -> None:
25
+ r"""
26
+ Args:
27
+
28
+ forward_func (callable): The forward function of the model or any
29
+ modification of it
30
+ layer (torch.nn.Module or list(torch.nn.Module)): Layer or layers
31
+ for which attributions are computed.
32
+ Output size of attribute matches this layer's input or
33
+ output dimensions, depending on whether we attribute to
34
+ the inputs or outputs of the layer, corresponding to
35
+ attribution of each neuron in the input or output of
36
+ this layer. If multiple layers are provided, attributions
37
+ are returned as a list, each element corresponding to the
38
+ activations of the corresponding layer.
39
+ device_ids (list(int)): Device ID list, necessary only if forward_func
40
+ applies a DataParallel model. This allows reconstruction of
41
+ intermediate outputs from batched results across devices.
42
+ If forward_func is given as the DataParallel model itself,
43
+ then it is not necessary to provide this argument.
44
+ """
45
+ LayerAttribution.__init__(self, forward_func, layer, device_ids)
46
+
47
+ @log_usage()
48
+ def attribute(
49
+ self,
50
+ inputs: Union[Tensor, Tuple[Tensor, ...]],
51
+ additional_forward_args: Any = None,
52
+ attribute_to_layer_input: bool = False,
53
+ ) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]:
54
+ r"""
55
+ Args:
56
+
57
+ inputs (tensor or tuple of tensors): Input for which layer
58
+ activation is computed. If forward_func takes a single
59
+ tensor as input, a single input tensor should be provided.
60
+ If forward_func takes multiple tensors as input, a tuple
61
+ of the input tensors should be provided. It is assumed
62
+ that for all given input tensors, dimension 0 corresponds
63
+ to the number of examples, and if multiple input tensors
64
+ are provided, the examples must be aligned appropriately.
65
+ additional_forward_args (any, optional): If the forward function
66
+ requires additional arguments other than the inputs for
67
+ which attributions should not be computed, this argument
68
+ can be provided. It must be either a single additional
69
+ argument of a Tensor or arbitrary (non-tuple) type or a
70
+ tuple containing multiple additional arguments including
71
+ tensors or any arbitrary python types. These arguments
72
+ are provided to forward_func in order following the
73
+ arguments in inputs.
74
+ Note that attributions are not computed with respect
75
+ to these arguments.
76
+ Default: None
77
+ attribute_to_layer_input (bool, optional): Indicates whether to
78
+ compute the attribution with respect to the layer input
79
+ or output. If `attribute_to_layer_input` is set to True
80
+ then the attributions will be computed with respect to
81
+ layer input, otherwise it will be computed with respect
82
+ to layer output.
83
+ Note that currently it is assumed that either the input
84
+ or the output of internal layer, depending on whether we
85
+ attribute to the input or output, is a single tensor.
86
+ Support for multiple tensors will be added later.
87
+ Default: False
88
+
89
+ Returns:
90
+ *tensor* or tuple of *tensors* or *list* of **attributions**:
91
+ - **attributions** (*tensor* or tuple of *tensors* or *list*):
92
+ Activation of each neuron in given layer output.
93
+ Attributions will always be the same size as the
94
+ output of the given layer.
95
+ Attributions are returned in a tuple if
96
+ the layer inputs / outputs contain multiple tensors,
97
+ otherwise a single tensor is returned.
98
+ If multiple layers are provided, attributions
99
+ are returned as a list, each element corresponding to the
100
+ activations of the corresponding layer.
101
+
102
+
103
+
104
+ Examples::
105
+
106
+ >>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
107
+ >>> # and returns an Nx10 tensor of class probabilities.
108
+ >>> # It contains an attribute conv1, which is an instance of nn.conv2d,
109
+ >>> # and the output of this layer has dimensions Nx12x32x32.
110
+ >>> net = ImageClassifier()
111
+ >>> layer_act = LayerActivation(net, net.conv1)
112
+ >>> input = torch.randn(2, 3, 32, 32, requires_grad=True)
113
+ >>> # Computes layer activation.
114
+ >>> # attribution is layer output, with size Nx12x32x32
115
+ >>> attribution = layer_act.attribute(input)
116
+ """
117
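+ # Only a forward pass is needed to record activations, so gradient
+ # tracking is disabled for efficiency.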
+ with torch.no_grad():
118
+ layer_eval = _forward_layer_eval(
119
+ self.forward_func,
120
+ inputs,
121
+ self.layer,
122
+ additional_forward_args,
123
+ device_ids=self.device_ids,
124
+ attribute_to_layer_input=attribute_to_layer_input,
125
+ )
126
+ if isinstance(self.layer, Module):
127
+ return _format_output(len(layer_eval) > 1, layer_eval)
128
+ else:
129
+ return [
130
+ _format_output(len(single_layer_eval) > 1, single_layer_eval)
131
+ for single_layer_eval in layer_eval
132
+ ]
133
+
134
+ @property
135
+ def multiplies_by_inputs(self):
136
+ return True
captum/attr/_core/layer/layer_conductance.py ADDED
@@ -0,0 +1,395 @@
1
+ #!/usr/bin/env python3
2
+ import typing
3
+ from typing import Any, Callable, List, Tuple, Union
4
+
5
+ import torch
6
+ from captum._utils.common import (
7
+ _expand_additional_forward_args,
8
+ _expand_target,
9
+ _format_additional_forward_args,
10
+ _format_output,
11
+ )
12
+ from captum._utils.gradient import compute_layer_gradients_and_eval
13
+ from captum._utils.typing import BaselineType, Literal, TargetType
14
+ from captum.attr._utils.approximation_methods import approximation_parameters
15
+ from captum.attr._utils.attribution import GradientAttribution, LayerAttribution
16
+ from captum.attr._utils.batching import _batch_attribution
17
+ from captum.attr._utils.common import (
18
+ _format_input_baseline,
19
+ _reshape_and_sum,
20
+ _validate_input,
21
+ )
22
+ from captum.log import log_usage
23
+ from torch import Tensor
24
+ from torch.nn import Module
25
+
26
+
27
+ class LayerConductance(LayerAttribution, GradientAttribution):
28
+ r"""
29
+ Computes conductance with respect to the given layer. The
30
+ returned output is in the shape of the layer's output, showing the total
31
+ conductance of each hidden layer neuron.
32
+
33
+ The details of the approach can be found here:
34
+ https://arxiv.org/abs/1805.12233
35
+ https://arxiv.org/pdf/1807.09946.pdf
36
+
37
+ Note that this provides the total conductance of each neuron in the
38
+ layer's output. To obtain the breakdown of a neuron's conductance by input
39
+ features, utilize NeuronConductance instead, and provide the target
40
+ neuron index.
41
+ """
42
+
43
+ def __init__(
44
+ self,
45
+ forward_func: Callable,
46
+ layer: Module,
47
+ device_ids: Union[None, List[int]] = None,
48
+ ) -> None:
49
+ r"""
50
+ Args:
51
+
52
+ forward_func (callable): The forward function of the model or any
53
+ modification of it
54
+ layer (torch.nn.Module): Layer for which attributions are computed.
55
+ Output size of attribute matches this layer's input or
56
+ output dimensions, depending on whether we attribute to
57
+ the inputs or outputs of the layer, corresponding to
58
+ attribution of each neuron in the input or output of
59
+ this layer.
60
+ device_ids (list(int)): Device ID list, necessary only if forward_func
61
+ applies a DataParallel model. This allows reconstruction of
62
+ intermediate outputs from batched results across devices.
63
+ If forward_func is given as the DataParallel model itself,
64
+ then it is not necessary to provide this argument.
65
+ """
66
+ LayerAttribution.__init__(self, forward_func, layer, device_ids)
67
+ GradientAttribution.__init__(self, forward_func)
68
+
69
+ def has_convergence_delta(self) -> bool:
70
+ return True
71
+
72
+ @typing.overload
73
+ def attribute(
74
+ self,
75
+ inputs: Union[Tensor, Tuple[Tensor, ...]],
76
+ baselines: BaselineType = None,
77
+ target: TargetType = None,
78
+ additional_forward_args: Any = None,
79
+ n_steps: int = 50,
80
+ method: str = "gausslegendre",
81
+ internal_batch_size: Union[None, int] = None,
82
+ *,
83
+ return_convergence_delta: Literal[True],
84
+ attribute_to_layer_input: bool = False,
85
+ ) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]:
86
+ ...
87
+
88
+ @typing.overload
89
+ def attribute(
90
+ self,
91
+ inputs: Union[Tensor, Tuple[Tensor, ...]],
92
+ baselines: BaselineType = None,
93
+ target: TargetType = None,
94
+ additional_forward_args: Any = None,
95
+ n_steps: int = 50,
96
+ method: str = "gausslegendre",
97
+ internal_batch_size: Union[None, int] = None,
98
+ return_convergence_delta: Literal[False] = False,
99
+ attribute_to_layer_input: bool = False,
100
+ ) -> Union[Tensor, Tuple[Tensor, ...]]:
101
+ ...
102
+
103
+ @log_usage()
104
+ def attribute(
105
+ self,
106
+ inputs: Union[Tensor, Tuple[Tensor, ...]],
107
+ baselines: Union[
108
+ None, int, float, Tensor, Tuple[Union[int, float, Tensor], ...]
109
+ ] = None,
110
+ target: TargetType = None,
111
+ additional_forward_args: Any = None,
112
+ n_steps: int = 50,
113
+ method: str = "gausslegendre",
114
+ internal_batch_size: Union[None, int] = None,
115
+ return_convergence_delta: bool = False,
116
+ attribute_to_layer_input: bool = False,
117
+ ) -> Union[
118
+ Tensor, Tuple[Tensor, ...], Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]
119
+ ]:
120
+ r"""
121
+ Args:
122
+
123
+ inputs (tensor or tuple of tensors): Input for which layer
124
+ conductance is computed. If forward_func takes a single
125
+ tensor as input, a single input tensor should be provided.
126
+ If forward_func takes multiple tensors as input, a tuple
127
+ of the input tensors should be provided. It is assumed
128
+ that for all given input tensors, dimension 0 corresponds
129
+ to the number of examples, and if multiple input tensors
130
+ are provided, the examples must be aligned appropriately.
131
+ baselines (scalar, tensor, tuple of scalars or tensors, optional):
132
+ Baselines define the starting point from which integral
133
+ is computed and can be provided as:
134
+
135
+ - a single tensor, if inputs is a single tensor, with
136
+ exactly the same dimensions as inputs or the first
137
+ dimension is one and the remaining dimensions match
138
+ with inputs.
139
+
140
+ - a single scalar, if inputs is a single tensor, which will
141
+ be broadcasted for each input value in input tensor.
142
+
143
+ - a tuple of tensors or scalars, the baseline corresponding
144
+ to each tensor in the inputs' tuple can be:
145
+
146
+ - either a tensor with matching dimensions to
147
+ corresponding tensor in the inputs' tuple
148
+ or the first dimension is one and the remaining
149
+ dimensions match with the corresponding
150
+ input tensor.
151
+
152
+ - or a scalar, corresponding to a tensor in the
153
+ inputs' tuple. This scalar value is broadcasted
154
+ for corresponding input tensor.
155
+ In the cases when `baselines` is not provided, we internally
156
+ use zero scalar corresponding to each input tensor.
157
+
158
+ Default: None
159
+ target (int, tuple, tensor or list, optional): Output indices for
160
+ which gradients are computed (for classification cases,
161
+ this is usually the target class).
162
+ If the network returns a scalar value per example,
163
+ no target index is necessary.
164
+ For general 2D outputs, targets can be either:
165
+
166
+ - a single integer or a tensor containing a single
167
+ integer, which is applied to all input examples
168
+
169
+ - a list of integers or a 1D tensor, with length matching
170
+ the number of examples in inputs (dim 0). Each integer
171
+ is applied as the target for the corresponding example.
172
+
173
+ For outputs with > 2 dimensions, targets can be either:
174
+
175
+ - A single tuple, which contains #output_dims - 1
176
+ elements. This target index is applied to all examples.
177
+
178
+ - A list of tuples with length equal to the number of
179
+ examples in inputs (dim 0), and each tuple containing
180
+ #output_dims - 1 elements. Each tuple is applied as the
181
+ target for the corresponding example.
182
+
183
+ Default: None
184
+ additional_forward_args (any, optional): If the forward function
185
+ requires additional arguments other than the inputs for
186
+ which attributions should not be computed, this argument
187
+ can be provided. It must be either a single additional
188
+ argument of a Tensor or arbitrary (non-tuple) type or a
189
+ tuple containing multiple additional arguments including
190
+ tensors or any arbitrary python types. These arguments
191
+ are provided to forward_func in order following the
192
+ arguments in inputs.
193
+ For a tensor, the first dimension of the tensor must
194
+ correspond to the number of examples. It will be repeated
195
+ for each of `n_steps` along the integrated path.
196
+ For all other types, the given argument is used for
197
+ all forward evaluations.
198
+ Note that attributions are not computed with respect
199
+ to these arguments.
200
+ Default: None
201
+ n_steps (int, optional): The number of steps used by the approximation
202
+ method. Default: 50.
203
+ method (string, optional): Method for approximating the integral,
204
+ one of `riemann_right`, `riemann_left`, `riemann_middle`,
205
+ `riemann_trapezoid` or `gausslegendre`.
206
+ Default: `gausslegendre` if no method is provided.
207
+ internal_batch_size (int, optional): Divides total #steps * #examples
208
+ data points into chunks of size at most internal_batch_size,
209
+ which are computed (forward / backward passes)
210
+ sequentially. internal_batch_size must be at least equal to
211
+ 2 * #examples.
212
+ For DataParallel models, each batch is split among the
213
+ available devices, so evaluations on each available
214
+ device contain internal_batch_size / num_devices examples.
215
+ If internal_batch_size is None, then all evaluations are
216
+ processed in one batch.
217
+ Default: None
218
+ return_convergence_delta (bool, optional): Indicates whether to return
219
+ convergence delta or not. If `return_convergence_delta`
220
+ is set to True convergence delta will be returned in
221
+ a tuple following attributions.
222
+ Default: False
223
+ attribute_to_layer_input (bool, optional): Indicates whether to
224
+ compute the attribution with respect to the layer input
225
+ or output. If `attribute_to_layer_input` is set to True
226
+ then the attributions will be computed with respect to
227
+ layer inputs, otherwise it will be computed with respect
228
+ to layer outputs.
229
+ Note that currently it is assumed that either the input
230
+ or the output of internal layer, depending on whether we
231
+ attribute to the input or output, is a single tensor.
232
+ Support for multiple tensors will be added later.
233
+ Default: False
234
+
235
+ Returns:
236
+ **attributions** or 2-element tuple of **attributions**, **delta**:
237
+ - **attributions** (*tensor* or tuple of *tensors*):
238
+ Conductance of each neuron in given layer input or
239
+ output. Attributions will always be the same size as
240
+ the input or output of the given layer, depending on
241
+ whether we attribute to the inputs or outputs
242
+ of the layer which is decided by the input flag
243
+ `attribute_to_layer_input`.
244
+ Attributions are returned in a tuple if
245
+ the layer inputs / outputs contain multiple tensors,
246
+ otherwise a single tensor is returned.
247
+ - **delta** (*tensor*, returned if return_convergence_delta=True):
248
+ The difference between the total
249
+ approximated and true conductance.
250
+ This is computed using the property that the total sum of
251
+ forward_func(inputs) - forward_func(baselines) must equal
252
+ the total sum of the attributions.
253
+ Delta is calculated per example, meaning that the number of
254
+ elements in returned delta tensor is equal to the number of
255
+ examples in inputs.
256
+
257
+ Examples::
258
+
259
+ >>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
260
+ >>> # and returns an Nx10 tensor of class probabilities.
261
+ >>> # It contains an attribute conv1, which is an instance of nn.conv2d,
262
+ >>> # and the output of this layer has dimensions Nx12x32x32.
263
+ >>> net = ImageClassifier()
264
+ >>> layer_cond = LayerConductance(net, net.conv1)
265
+ >>> input = torch.randn(2, 3, 32, 32, requires_grad=True)
266
+ >>> # Computes layer conductance for class 3.
267
+ >>> # attribution size matches layer output, Nx12x32x32
268
+ >>> attribution = layer_cond.attribute(input, target=3)
269
+ """
270
+ inputs, baselines = _format_input_baseline(inputs, baselines)
271
+ _validate_input(inputs, baselines, n_steps, method)
272
+
273
+ num_examples = inputs[0].shape[0]
274
+ if internal_batch_size is not None:
275
+ num_examples = inputs[0].shape[0]
276
+ attrs = _batch_attribution(
277
+ self,
278
+ num_examples,
279
+ internal_batch_size,
280
+ n_steps + 1,
281
+ include_endpoint=True,
282
+ inputs=inputs,
283
+ baselines=baselines,
284
+ target=target,
285
+ additional_forward_args=additional_forward_args,
286
+ method=method,
287
+ attribute_to_layer_input=attribute_to_layer_input,
288
+ )
289
+
290
+ else:
291
+ attrs = self._attribute(
292
+ inputs=inputs,
293
+ baselines=baselines,
294
+ target=target,
295
+ additional_forward_args=additional_forward_args,
296
+ n_steps=n_steps,
297
+ method=method,
298
+ attribute_to_layer_input=attribute_to_layer_input,
299
+ )
300
+
301
+ is_layer_tuple = isinstance(attrs, tuple)
302
+ attributions = attrs if is_layer_tuple else (attrs,)
303
+
304
+ if return_convergence_delta:
305
+ start_point, end_point = baselines, inputs
306
+ delta = self.compute_convergence_delta(
307
+ attributions,
308
+ start_point,
309
+ end_point,
310
+ target=target,
311
+ additional_forward_args=additional_forward_args,
312
+ )
313
+ return _format_output(is_layer_tuple, attributions), delta
314
+ return _format_output(is_layer_tuple, attributions)
315
+
316
+ def _attribute(
317
+ self,
318
+ inputs: Tuple[Tensor, ...],
319
+ baselines: Tuple[Union[Tensor, int, float], ...],
320
+ target: TargetType = None,
321
+ additional_forward_args: Any = None,
322
+ n_steps: int = 50,
323
+ method: str = "gausslegendre",
324
+ attribute_to_layer_input: bool = False,
325
+ step_sizes_and_alphas: Union[None, Tuple[List[float], List[float]]] = None,
326
+ ) -> Union[Tensor, Tuple[Tensor, ...]]:
327
+ num_examples = inputs[0].shape[0]
328
+ if step_sizes_and_alphas is None:
329
+ # Retrieve scaling factors for specified approximation method
330
+ step_sizes_func, alphas_func = approximation_parameters(method)
331
+ alphas = alphas_func(n_steps + 1)
332
+ else:
333
+ _, alphas = step_sizes_and_alphas
334
+ # Compute scaled inputs from baseline to final input.
335
+ scaled_features_tpl = tuple(
336
+ torch.cat(
337
+ [baseline + alpha * (input - baseline) for alpha in alphas], dim=0
338
+ ).requires_grad_()
339
+ for input, baseline in zip(inputs, baselines)
340
+ )
341
+
342
+ additional_forward_args = _format_additional_forward_args(
343
+ additional_forward_args
344
+ )
345
+ # apply number of steps to additional forward args
346
+ # currently, number of steps is applied only to additional forward arguments
347
+ # that are nd-tensors. It is assumed that the first dimension is
348
+ # the number of batches.
349
+ # dim -> (#examples * #steps x additional_forward_args[0].shape[1:], ...)
350
+ input_additional_args = (
351
+ _expand_additional_forward_args(additional_forward_args, n_steps + 1)
352
+ if additional_forward_args is not None
353
+ else None
354
+ )
355
+ expanded_target = _expand_target(target, n_steps + 1)
356
+
357
+ # Conductance Gradients - Returns gradient of output with respect to
358
+ # hidden layer and hidden layer evaluated at each input.
359
+ (layer_gradients, layer_evals,) = compute_layer_gradients_and_eval(
360
+ forward_fn=self.forward_func,
361
+ layer=self.layer,
362
+ inputs=scaled_features_tpl,
363
+ additional_forward_args=input_additional_args,
364
+ target_ind=expanded_target,
365
+ device_ids=self.device_ids,
366
+ attribute_to_layer_input=attribute_to_layer_input,
367
+ )
368
+
369
+ # Compute differences between consecutive evaluations of layer_eval.
370
+ # This approximates the total input gradient of each step multiplied
371
+ # by the step size.
372
+ grad_diffs = tuple(
373
+ layer_eval[num_examples:] - layer_eval[:-num_examples]
374
+ for layer_eval in layer_evals
375
+ )
376
+
377
+ # Element-wise multiply gradient of output with respect to hidden layer
378
+ # and summed gradients with respect to input (chain rule) and sum
379
+ # across stepped inputs.
380
+ attributions = tuple(
381
+ _reshape_and_sum(
382
+ grad_diff * layer_gradient[:-num_examples],
383
+ n_steps,
384
+ num_examples,
385
+ layer_eval.shape[1:],
386
+ )
387
+ for layer_gradient, layer_eval, grad_diff in zip(
388
+ layer_gradients, layer_evals, grad_diffs
389
+ )
390
+ )
391
+ return _format_output(len(attributions) > 1, attributions)
392
+
393
+ @property
394
+ def multiplies_by_inputs(self):
395
+ return True
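To make the approximation performed by `_attribute` above concrete, the following is a minimal, self-contained sketch of the same Riemann-style conductance estimate written with plain PyTorch autograd. The toy two-layer model, tensor shapes, and step count are illustrative assumptions, not part of this diff.

import torch

torch.manual_seed(0)
lin1 = torch.nn.Linear(4, 3)   # hidden layer whose conductance we estimate
lin2 = torch.nn.Linear(3, 1)

x = torch.randn(2, 4)          # two examples
baseline = torch.zeros_like(x)
n_steps = 50
alphas = torch.linspace(0.0, 1.0, n_steps + 1)

# Evaluate the hidden layer and the model output at every point on the
# straight path from baseline to input (mirrors scaled_features_tpl above).
hidden_evals, hidden_grads = [], []
for a in alphas:
    hidden = lin1(baseline + a * (x - baseline))
    out = lin2(torch.relu(hidden)).sum()
    hidden_grads.append(torch.autograd.grad(out, hidden)[0])
    hidden_evals.append(hidden.detach())

# Conductance ~ sum over steps of (change in hidden activation) times
# (gradient of the output w.r.t. the hidden layer), as in grad_diffs above.
conductance = torch.zeros_like(hidden_evals[0])
for i in range(n_steps):
    conductance += (hidden_evals[i + 1] - hidden_evals[i]) * hidden_grads[i]

print(conductance.shape)  # torch.Size([2, 3]): one value per hidden neuron

Running the sketch prints a (2, 3) tensor, one conductance value per hidden neuron per example, matching the statement above that attributions have the size of the layer output.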
captum/attr/_core/layer/layer_deep_lift.py ADDED
@@ -0,0 +1,682 @@
1
+ #!/usr/bin/env python3
2
+ import typing
3
+ from typing import Any, Callable, cast, Sequence, Tuple, Union
4
+
5
+ import torch
6
+ from captum._utils.common import (
7
+ _expand_target,
8
+ _format_additional_forward_args,
9
+ _format_baseline,
10
+ _format_tensor_into_tuples,
11
+ ExpansionTypes,
12
+ )
13
+ from captum._utils.gradient import compute_layer_gradients_and_eval
14
+ from captum._utils.typing import (
15
+ BaselineType,
16
+ Literal,
17
+ TargetType,
18
+ TensorOrTupleOfTensorsGeneric,
19
+ )
20
+ from captum.attr._core.deep_lift import DeepLift, DeepLiftShap
21
+ from captum.attr._utils.attribution import LayerAttribution
22
+ from captum.attr._utils.common import (
23
+ _call_custom_attribution_func,
24
+ _compute_conv_delta_and_format_attrs,
25
+ _format_callable_baseline,
26
+ _tensorize_baseline,
27
+ _validate_input,
28
+ )
29
+ from captum.log import log_usage
30
+ from torch import Tensor
31
+ from torch.nn import Module
32
+
33
+
34
+ class LayerDeepLift(LayerAttribution, DeepLift):
35
+ r"""
36
+ Implements DeepLIFT algorithm for the layer based on the following paper:
37
+ Learning Important Features Through Propagating Activation Differences,
38
+ Avanti Shrikumar, et al.
39
+ https://arxiv.org/abs/1704.02685
40
+
41
+ and the gradient formulation proposed in:
42
+ Towards better understanding of gradient-based attribution methods for
43
+ deep neural networks, Marco Ancona, et al.
44
+ https://openreview.net/pdf?id=Sy21R9JAW
45
+
46
+ This implementation supports only Rescale rule. RevealCancel rule will
47
+ be supported in later releases.
48
+ Although DeepLIFT's (Rescale Rule) attribution quality is comparable with
49
+ Integrated Gradients, it runs significantly faster than Integrated
50
+ Gradients and is preferred for large datasets.
51
+
52
+ Currently we only support a limited number of non-linear activations
53
+ but the plan is to expand the list in the future.
54
+
55
+ Note: As we know, currently we cannot access the building blocks
56
+ of PyTorch's built-in LSTM, RNNs and GRUs such as Tanh and Sigmoid.
57
+ Nonetheless, it is possible to build custom LSTMs, RNNs and GRUs
58
+ with performance similar to built-in ones using TorchScript.
59
+ More details on how to build custom RNNs can be found here:
60
+ https://pytorch.org/blog/optimizing-cuda-rnn-with-torchscript/
61
+ """
62
+
63
+ def __init__(
64
+ self,
65
+ model: Module,
66
+ layer: Module,
67
+ multiply_by_inputs: bool = True,
68
+ ) -> None:
69
+ r"""
70
+ Args:
71
+
72
+ model (nn.Module): The reference to PyTorch model instance. Model cannot
73
+ contain any in-place nonlinear submodules; these are not
74
+ supported by the register_full_backward_hook PyTorch API
75
+ starting from PyTorch v1.9.
76
+ layer (torch.nn.Module): Layer for which attributions are computed.
77
+ The size and dimensionality of the attributions
78
+ corresponds to the size and dimensionality of the layer's
79
+ input or output depending on whether we attribute to the
80
+ inputs or outputs of the layer.
81
+ multiply_by_inputs (bool, optional): Indicates whether to factor
82
+ model inputs' multiplier in the final attribution scores.
83
+ In the literature this is also known as local vs global
84
+ attribution. If inputs' multiplier isn't factored in
85
+ then that type of attribution method is also called local
86
+ attribution. If it is, then that type of attribution
87
+ method is called global.
88
+ More details can be found here:
89
+ https://arxiv.org/abs/1711.06104
90
+
91
+ In case of Layer DeepLift, if `multiply_by_inputs`
92
+ is set to True, final sensitivity scores
93
+ are being multiplied by
94
+ layer activations for inputs - layer activations for baselines.
95
+ This flag applies only if `custom_attribution_func` is
96
+ set to None.
97
+ """
98
+ LayerAttribution.__init__(self, model, layer)
99
+ DeepLift.__init__(self, model)
100
+ self.model = model
101
+ self._multiply_by_inputs = multiply_by_inputs
102
+
103
+ # Ignoring mypy error for inconsistent signature with DeepLift
104
+ @typing.overload # type: ignore
105
+ def attribute(
106
+ self,
107
+ inputs: Union[Tensor, Tuple[Tensor, ...]],
108
+ baselines: BaselineType = None,
109
+ target: TargetType = None,
110
+ additional_forward_args: Any = None,
111
+ return_convergence_delta: Literal[False] = False,
112
+ attribute_to_layer_input: bool = False,
113
+ custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
114
+ ) -> Union[Tensor, Tuple[Tensor, ...]]:
115
+ ...
116
+
117
+ @typing.overload
118
+ def attribute(
119
+ self,
120
+ inputs: Union[Tensor, Tuple[Tensor, ...]],
121
+ baselines: BaselineType = None,
122
+ target: TargetType = None,
123
+ additional_forward_args: Any = None,
124
+ *,
125
+ return_convergence_delta: Literal[True],
126
+ attribute_to_layer_input: bool = False,
127
+ custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
128
+ ) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]:
129
+ ...
130
+
131
+ @log_usage()
132
+ def attribute(
133
+ self,
134
+ inputs: Union[Tensor, Tuple[Tensor, ...]],
135
+ baselines: BaselineType = None,
136
+ target: TargetType = None,
137
+ additional_forward_args: Any = None,
138
+ return_convergence_delta: bool = False,
139
+ attribute_to_layer_input: bool = False,
140
+ custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
141
+ ) -> Union[
142
+ Tensor, Tuple[Tensor, ...], Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]
143
+ ]:
144
+ r"""
145
+ Args:
146
+
147
+ inputs (tensor or tuple of tensors): Input for which layer
148
+ attributions are computed. If forward_func takes a
149
+ single tensor as input, a single input tensor should be
150
+ provided. If forward_func takes multiple tensors as input,
151
+ a tuple of the input tensors should be provided. It is
152
+ assumed that for all given input tensors, dimension 0
153
+ corresponds to the number of examples (aka batch size),
154
+ and if multiple input tensors are provided, the examples
155
+ must be aligned appropriately.
156
+ baselines (scalar, tensor, tuple of scalars or tensors, optional):
157
+ Baselines define reference samples that are compared with
158
+ the inputs. In order to assign attribution scores DeepLift
159
+ computes the differences between the inputs/outputs and
160
+ corresponding references.
161
+ Baselines can be provided as:
162
+
163
+ - a single tensor, if inputs is a single tensor, with
164
+ exactly the same dimensions as inputs or the first
165
+ dimension is one and the remaining dimensions match
166
+ with inputs.
167
+
168
+ - a single scalar, if inputs is a single tensor, which will
169
+ be broadcasted for each input value in input tensor.
170
+
171
+ - a tuple of tensors or scalars, the baseline corresponding
172
+ to each tensor in the inputs' tuple can be:
173
+
174
+ - either a tensor with matching dimensions to
175
+ corresponding tensor in the inputs' tuple
176
+ or the first dimension is one and the remaining
177
+ dimensions match with the corresponding
178
+ input tensor.
179
+
180
+ - or a scalar, corresponding to a tensor in the
181
+ inputs' tuple. This scalar value is broadcasted
182
+ for corresponding input tensor.
183
+ In the cases when `baselines` is not provided, we internally
184
+ use zero scalar corresponding to each input tensor.
185
+
186
+ Default: None
187
+ target (int, tuple, tensor or list, optional): Output indices for
188
+ which gradients are computed (for classification cases,
189
+ this is usually the target class).
190
+ If the network returns a scalar value per example,
191
+ no target index is necessary.
192
+ For general 2D outputs, targets can be either:
193
+
194
+ - a single integer or a tensor containing a single
195
+ integer, which is applied to all input examples
196
+
197
+ - a list of integers or a 1D tensor, with length matching
198
+ the number of examples in inputs (dim 0). Each integer
199
+ is applied as the target for the corresponding example.
200
+
201
+ For outputs with > 2 dimensions, targets can be either:
202
+
203
+ - A single tuple, which contains #output_dims - 1
204
+ elements. This target index is applied to all examples.
205
+
206
+ - A list of tuples with length equal to the number of
207
+ examples in inputs (dim 0), and each tuple containing
208
+ #output_dims - 1 elements. Each tuple is applied as the
209
+ target for the corresponding example.
210
+
211
+ Default: None
212
+ additional_forward_args (any, optional): If the forward function
213
+ requires additional arguments other than the inputs for
214
+ which attributions should not be computed, this argument
215
+ can be provided. It must be either a single additional
216
+ argument of a Tensor or arbitrary (non-tuple) type or a tuple
217
+ containing multiple additional arguments including tensors
218
+ or any arbitrary python types. These arguments are provided to
219
+ forward_func in order, following the arguments in inputs.
220
+ Note that attributions are not computed with respect
221
+ to these arguments.
222
+ Default: None
223
+ return_convergence_delta (bool, optional): Indicates whether to return
224
+ convergence delta or not. If `return_convergence_delta`
225
+ is set to True convergence delta will be returned in
226
+ a tuple following attributions.
227
+ Default: False
228
+ attribute_to_layer_input (bool, optional): Indicates whether to
229
+ compute the attribution with respect to the layer input
230
+ or output. If `attribute_to_layer_input` is set to True
231
+ then the attributions will be computed with respect to
232
+ layer input, otherwise it will be computed with respect
233
+ to layer output.
234
+ Note that currently it is assumed that either the input
235
+ or the output of internal layer, depending on whether we
236
+ attribute to the input or output, is a single tensor.
237
+ Support for multiple tensors will be added later.
238
+ Default: False
239
+ custom_attribution_func (callable, optional): A custom function for
240
+ computing final attribution scores. This function can take
241
+ at least one and at most three arguments with the
242
+ following signature:
243
+
244
+ - custom_attribution_func(multipliers)
245
+ - custom_attribution_func(multipliers, inputs)
246
+ - custom_attribution_func(multipliers, inputs, baselines)
247
+
248
+ In case this function is not provided, we use the default
249
+ logic defined as: multipliers * (inputs - baselines)
250
+ It is assumed that all input arguments, `multipliers`,
251
+ `inputs` and `baselines` are provided in tuples of same length.
252
+ `custom_attribution_func` returns a tuple of attribution
253
+ tensors that have the same length as the `inputs`.
254
+ Default: None
255
+
256
+ Returns:
257
+ **attributions** or 2-element tuple of **attributions**, **delta**:
258
+ - **attributions** (*tensor* or tuple of *tensors*):
259
+ Attribution score computed based on DeepLift's rescale rule with
260
+ respect to layer's inputs or outputs. Attributions will always be the
261
+ same size as the provided layer's inputs or outputs, depending on
262
+ whether we attribute to the inputs or outputs of the layer.
263
+ If the layer input / output is a single tensor, then
264
+ just a tensor is returned; if the layer input / output
265
+ has multiple tensors, then a corresponding tuple
266
+ of tensors is returned.
267
+ - **delta** (*tensor*, returned if return_convergence_delta=True):
268
+ This is computed using the property that the total sum of
269
+ forward_func(inputs) - forward_func(baselines) must equal the
270
+ total sum of the attributions computed based on DeepLift's
271
+ rescale rule.
272
+ Delta is calculated per example, meaning that the number of
273
+ elements in returned delta tensor is equal to the number of
274
+ of examples in input.
275
+ Note that the logic described for deltas is guaranteed
276
+ when the default logic for attribution computations is used,
277
+ meaning that the `custom_attribution_func=None`, otherwise
278
+ it is not guaranteed and depends on the specifics of the
279
+ `custom_attribution_func`.
280
+
281
+ Examples::
282
+
283
+ >>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
284
+ >>> # and returns an Nx10 tensor of class probabilities.
285
+ >>> net = ImageClassifier()
286
+ >>> # creates an instance of LayerDeepLift to interpret target
287
+ >>> # class 1 with respect to conv4 layer.
288
+ >>> dl = LayerDeepLift(net, net.conv4)
289
+ >>> input = torch.randn(1, 3, 32, 32, requires_grad=True)
290
+ >>> # Computes DeepLift attribution scores for conv4 layer and class 1.
291
+ >>> attribution = dl.attribute(input, target=1)
292
+ """
293
+ inputs = _format_tensor_into_tuples(inputs)
294
+ baselines = _format_baseline(baselines, inputs)
295
+ _validate_input(inputs, baselines)
296
+
297
+ baselines = _tensorize_baseline(inputs, baselines)
298
+
299
+ main_model_hooks = []
300
+ try:
301
+ main_model_hooks = self._hook_main_model()
302
+
303
+ self.model.apply(
304
+ lambda mod: self._register_hooks(
305
+ mod, attribute_to_layer_input=attribute_to_layer_input
306
+ )
307
+ )
308
+
309
+ additional_forward_args = _format_additional_forward_args(
310
+ additional_forward_args
311
+ )
312
+ expanded_target = _expand_target(
313
+ target, 2, expansion_type=ExpansionTypes.repeat
314
+ )
315
+ wrapped_forward_func = self._construct_forward_func(
316
+ self.model,
317
+ (inputs, baselines),
318
+ expanded_target,
319
+ additional_forward_args,
320
+ )
321
+
322
+ def chunk_output_fn(out: TensorOrTupleOfTensorsGeneric) -> Sequence:
323
+ if isinstance(out, Tensor):
324
+ return out.chunk(2)
325
+ return tuple(out_sub.chunk(2) for out_sub in out)
326
+
327
+ gradients, attrs = compute_layer_gradients_and_eval(
328
+ wrapped_forward_func,
329
+ self.layer,
330
+ inputs,
331
+ attribute_to_layer_input=attribute_to_layer_input,
332
+ output_fn=lambda out: chunk_output_fn(out),
333
+ )
334
+
335
+ attr_inputs = tuple(map(lambda attr: attr[0], attrs))
336
+ attr_baselines = tuple(map(lambda attr: attr[1], attrs))
337
+ gradients = tuple(map(lambda grad: grad[0], gradients))
338
+
339
+ if custom_attribution_func is None:
340
+ if self.multiplies_by_inputs:
341
+ attributions = tuple(
342
+ (input - baseline) * gradient
343
+ for input, baseline, gradient in zip(
344
+ attr_inputs, attr_baselines, gradients
345
+ )
346
+ )
347
+ else:
348
+ attributions = gradients
349
+ else:
350
+ attributions = _call_custom_attribution_func(
351
+ custom_attribution_func, gradients, attr_inputs, attr_baselines
352
+ )
353
+ finally:
354
+ # remove hooks from all activations
355
+ self._remove_hooks(main_model_hooks)
356
+
357
+ return _compute_conv_delta_and_format_attrs(
358
+ self,
359
+ return_convergence_delta,
360
+ attributions,
361
+ baselines,
362
+ inputs,
363
+ additional_forward_args,
364
+ target,
365
+ cast(Union[Literal[True], Literal[False]], len(attributions) > 1),
366
+ )
367
+
368
+ @property
369
+ def multiplies_by_inputs(self):
370
+ return self._multiply_by_inputs
371
+
372
+
373
+ class LayerDeepLiftShap(LayerDeepLift, DeepLiftShap):
374
+ r"""
375
+ Extends LayerDeepLift and DeepLiftShap algorithms and approximates SHAP
376
+ values for given input `layer`.
377
+ For each input sample - baseline pair it computes DeepLift attributions
378
+ with respect to inputs or outputs of given `layer` and averages the
379
+ resulting attributions across baselines. Whether to compute the attributions
380
+ with respect to the inputs or outputs of the layer is defined by the
381
+ input flag `attribute_to_layer_input`.
382
+ More details about the algorithm can be found here:
383
+
384
+ http://papers.nips.cc/paper/7062-a-unified-approach-to-interpreting-model-predictions.pdf
385
+
386
+ Note that the explanation model:
387
+ 1. Assumes that input features are independent of one another
388
+ 2. Is linear, meaning that the explanations are modeled through
389
+ the additive composition of feature effects.
390
+ Although it assumes a linear model for each explanation, the overall
391
+ model across multiple explanations can be complex and non-linear.
392
+ """
393
+
394
+ def __init__(
395
+ self,
396
+ model: Module,
397
+ layer: Module,
398
+ multiply_by_inputs: bool = True,
399
+ ) -> None:
400
+ r"""
401
+ Args:
402
+
403
+ model (nn.Module): The reference to PyTorch model instance. Model cannot
404
+ contain any in-place nonlinear submodules; these are not
405
+ supported by the register_full_backward_hook PyTorch API
406
+ starting from PyTorch v1.9.
407
+ layer (torch.nn.Module): Layer for which attributions are computed.
408
+ The size and dimensionality of the attributions
409
+ corresponds to the size and dimensionality of the layer's
410
+ input or output depending on whether we attribute to the
411
+ inputs or outputs of the layer.
412
+ multiply_by_inputs (bool, optional): Indicates whether to factor
413
+ model inputs' multiplier in the final attribution scores.
414
+ In the literature this is also known as local vs global
415
+ attribution. If inputs' multiplier isn't factored in
416
+ then that type of attribution method is also called local
417
+ attribution. If it is, then that type of attribution
418
+ method is called global.
419
+ More details can be found here:
420
+ https://arxiv.org/abs/1711.06104
421
+
422
+ In case of LayerDeepLiftShap, if `multiply_by_inputs`
423
+ is set to True, final sensitivity scores are being
424
+ multiplied by
425
+ layer activations for inputs - layer activations for baselines.
426
+ This flag applies only if `custom_attribution_func` is
427
+ set to None.
428
+ """
429
+ LayerDeepLift.__init__(self, model, layer)
430
+ DeepLiftShap.__init__(self, model, multiply_by_inputs)
431
+
432
+ # Ignoring mypy error for inconsistent signature with DeepLiftShap
433
+ @typing.overload # type: ignore
434
+ def attribute(
435
+ self,
436
+ inputs: Union[Tensor, Tuple[Tensor, ...]],
437
+ baselines: Union[
438
+ Tensor, Tuple[Tensor, ...], Callable[..., Union[Tensor, Tuple[Tensor, ...]]]
439
+ ],
440
+ target: TargetType = None,
441
+ additional_forward_args: Any = None,
442
+ return_convergence_delta: Literal[False] = False,
443
+ attribute_to_layer_input: bool = False,
444
+ custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
445
+ ) -> Union[Tensor, Tuple[Tensor, ...]]:
446
+ ...
447
+
448
+ @typing.overload
449
+ def attribute(
450
+ self,
451
+ inputs: Union[Tensor, Tuple[Tensor, ...]],
452
+ baselines: Union[
453
+ Tensor, Tuple[Tensor, ...], Callable[..., Union[Tensor, Tuple[Tensor, ...]]]
454
+ ],
455
+ target: TargetType = None,
456
+ additional_forward_args: Any = None,
457
+ *,
458
+ return_convergence_delta: Literal[True],
459
+ attribute_to_layer_input: bool = False,
460
+ custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
461
+ ) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]:
462
+ ...
463
+
464
+ @log_usage()
465
+ def attribute(
466
+ self,
467
+ inputs: Union[Tensor, Tuple[Tensor, ...]],
468
+ baselines: Union[
469
+ Tensor, Tuple[Tensor, ...], Callable[..., Union[Tensor, Tuple[Tensor, ...]]]
470
+ ],
471
+ target: TargetType = None,
472
+ additional_forward_args: Any = None,
473
+ return_convergence_delta: bool = False,
474
+ attribute_to_layer_input: bool = False,
475
+ custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
476
+ ) -> Union[
477
+ Tensor, Tuple[Tensor, ...], Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]
478
+ ]:
479
+ r"""
480
+ Args:
481
+
482
+ inputs (tensor or tuple of tensors): Input for which layer
483
+ attributions are computed. If forward_func takes a single
484
+ tensor as input, a single input tensor should be provided.
485
+ If forward_func takes multiple tensors as input, a tuple
486
+ of the input tensors should be provided. It is assumed
487
+ that for all given input tensors, dimension 0 corresponds
488
+ to the number of examples (aka batch size), and if
489
+ multiple input tensors are provided, the examples must
490
+ be aligned appropriately.
491
+ baselines (tensor, tuple of tensors, callable):
492
+ Baselines define reference samples that are compared with
493
+ the inputs. In order to assign attribution scores DeepLift
494
+ computes the differences between the inputs/outputs and
495
+ corresponding references. Baselines can be provided as:
496
+
497
+ - a single tensor, if inputs is a single tensor, with
498
+ the first dimension equal to the number of examples
499
+ in the baselines' distribution. The remaining dimensions
500
+ must match with input tensor's dimension starting from
501
+ the second dimension.
502
+
503
+ - a tuple of tensors, if inputs is a tuple of tensors,
504
+ with the first dimension of any tensor inside the tuple
505
+ equal to the number of examples in the baseline's
506
+ distribution. The remaining dimensions must match
507
+ the dimensions of the corresponding input tensor
508
+ starting from the second dimension.
509
+
510
+ - callable function, optionally takes `inputs` as an
511
+ argument and either returns a single tensor
512
+ or a tuple of those.
513
+
514
+ It is recommended that the number of samples in the baselines'
515
+ tensors is larger than one.
516
+ target (int, tuple, tensor or list, optional): Output indices for
517
+ which gradients are computed (for classification cases,
518
+ this is usually the target class).
519
+ If the network returns a scalar value per example,
520
+ no target index is necessary.
521
+ For general 2D outputs, targets can be either:
522
+
523
+ - a single integer or a tensor containing a single
524
+ integer, which is applied to all input examples
525
+
526
+ - a list of integers or a 1D tensor, with length matching
527
+ the number of examples in inputs (dim 0). Each integer
528
+ is applied as the target for the corresponding example.
529
+
530
+ For outputs with > 2 dimensions, targets can be either:
531
+
532
+ - A single tuple, which contains #output_dims - 1
533
+ elements. This target index is applied to all examples.
534
+
535
+ - A list of tuples with length equal to the number of
536
+ examples in inputs (dim 0), and each tuple containing
537
+ #output_dims - 1 elements. Each tuple is applied as the
538
+ target for the corresponding example.
539
+
540
+ Default: None
541
+ additional_forward_args (any, optional): If the forward function
542
+ requires additional arguments other than the inputs for
543
+ which attributions should not be computed, this argument
544
+ can be provided. It must be either a single additional
545
+ argument of a Tensor or arbitrary (non-tuple) type or a tuple
546
+ containing multiple additional arguments including tensors
547
+ or any arbitrary python types. These arguments are provided to
548
+ forward_func in order, following the arguments in inputs.
549
+ Note that attributions are not computed with respect
550
+ to these arguments.
551
+ Default: None
552
+ return_convergence_delta (bool, optional): Indicates whether to return
553
+ convergence delta or not. If `return_convergence_delta`
554
+ is set to True convergence delta will be returned in
555
+ a tuple following attributions.
556
+ Default: False
557
+ attribute_to_layer_input (bool, optional): Indicates whether to
558
+ compute the attributions with respect to the layer input
559
+ or output. If `attribute_to_layer_input` is set to True
560
+ then the attributions will be computed with respect to
561
+ layer inputs, otherwise it will be computed with respect
562
+ to layer outputs.
563
+ Note that currently it assumes that both the inputs and
564
+ outputs of internal layers are single tensors.
565
+ Support for multiple tensors will be added later.
566
+ Default: False
567
+ custom_attribution_func (callable, optional): A custom function for
568
+ computing final attribution scores. This function can take
569
+ at least one and at most three arguments with the
570
+ following signature:
571
+
572
+ - custom_attribution_func(multipliers)
573
+ - custom_attribution_func(multipliers, inputs)
574
+ - custom_attribution_func(multipliers, inputs, baselines)
575
+
576
+ In case this function is not provided, we use the default
577
+ logic defined as: multipliers * (inputs - baselines)
578
+ It is assumed that all input arguments, `multipliers`,
579
+ `inputs` and `baselines` are provided in tuples of same
580
+ length. `custom_attribution_func` returns a tuple of
581
+ attribution tensors that have the same length as the
582
+ `inputs`.
583
+ Default: None
584
+
585
+ Returns:
586
+ **attributions** or 2-element tuple of **attributions**, **delta**:
587
+ - **attributions** (*tensor* or tuple of *tensors*):
588
+ Attribution score computed based on DeepLift's rescale rule
589
+ with respect to layer's inputs or outputs. Attributions
590
+ will always be the same size as the provided layer's inputs
591
+ or outputs, depending on whether we attribute to the inputs
592
+ or outputs of the layer.
593
+ Attributions are returned in a tuple based on whether
594
+ the layer inputs / outputs are contained in a tuple
595
+ from a forward hook. For standard modules, inputs of
596
+ a single tensor are usually wrapped in a tuple, while
597
+ outputs of a single tensor are not.
598
+ - **delta** (*tensor*, returned if return_convergence_delta=True):
599
+ This is computed using the property that the
600
+ total sum of forward_func(inputs) - forward_func(baselines)
601
+ must be very close to the total sum of attributions
602
+ computed based on approximated SHAP values using
603
+ DeepLift's rescale rule.
604
+ Delta is calculated for each example input and baseline pair,
605
+ meaning that the number of elements in returned delta tensor
606
+ is equal to the
607
+ `number of examples in input` * `number of examples
608
+ in baseline`. The deltas are ordered in the first place by
609
+ input example, followed by the baseline.
610
+ Note that the logic described for deltas is guaranteed
611
+ when the default logic for attribution computations is used,
612
+ meaning that the `custom_attribution_func=None`, otherwise
613
+ it is not guaranteed and depends on the specifics of the
614
+ `custom_attribution_func`.
615
+ Examples::
616
+
617
+ >>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
618
+ >>> # and returns an Nx10 tensor of class probabilities.
619
+ >>> net = ImageClassifier()
620
+ >>> # creates an instance of LayerDeepLift to interpret target
621
+ >>> # class 1 with respect to conv4 layer.
622
+ >>> dl = LayerDeepLiftShap(net, net.conv4)
623
+ >>> input = torch.randn(2, 3, 32, 32, requires_grad=True)
624
+ >>> # Computes shap values using deeplift for class 3.
625
+ >>> attribution = dl.attribute(input, target=3)
626
+ """
627
+ inputs = _format_tensor_into_tuples(inputs)
628
+ baselines = _format_callable_baseline(baselines, inputs)
629
+
630
+ assert isinstance(baselines[0], torch.Tensor) and baselines[0].shape[0] > 1, (
631
+ "Baselines distribution has to be provided in form of a torch.Tensor"
632
+ " with more than one example but found: {}."
633
+ " If baselines are provided in shape of scalars or with a single"
634
+ " baseline example, `LayerDeepLift`"
635
+ " approach can be used instead.".format(baselines[0])
636
+ )
637
+
638
+ # batch sizes
639
+ inp_bsz = inputs[0].shape[0]
640
+ base_bsz = baselines[0].shape[0]
641
+
642
+ (
643
+ exp_inp,
644
+ exp_base,
645
+ exp_target,
646
+ exp_addit_args,
647
+ ) = DeepLiftShap._expand_inputs_baselines_targets(
648
+ self, baselines, inputs, target, additional_forward_args
649
+ )
650
+ attributions = LayerDeepLift.attribute.__wrapped__( # type: ignore
651
+ self,
652
+ exp_inp,
653
+ exp_base,
654
+ target=exp_target,
655
+ additional_forward_args=exp_addit_args,
656
+ return_convergence_delta=cast(
657
+ Literal[True, False], return_convergence_delta
658
+ ),
659
+ attribute_to_layer_input=attribute_to_layer_input,
660
+ custom_attribution_func=custom_attribution_func,
661
+ )
662
+ if return_convergence_delta:
663
+ attributions, delta = attributions
664
+ if isinstance(attributions, tuple):
665
+ attributions = tuple(
666
+ DeepLiftShap._compute_mean_across_baselines(
667
+ self, inp_bsz, base_bsz, cast(Tensor, attrib)
668
+ )
669
+ for attrib in attributions
670
+ )
671
+ else:
672
+ attributions = DeepLiftShap._compute_mean_across_baselines(
673
+ self, inp_bsz, base_bsz, attributions
674
+ )
675
+ if return_convergence_delta:
676
+ return attributions, delta
677
+ else:
678
+ return attributions
679
+
680
+ @property
681
+ def multiplies_by_inputs(self):
682
+ return self._multiply_by_inputs
captum/attr/_core/layer/layer_feature_ablation.py ADDED
@@ -0,0 +1,302 @@
1
+ #!/usr/bin/env python3
2
+ from typing import Any, Callable, List, Tuple, Union
3
+
4
+ import torch
5
+ from captum._utils.common import (
6
+ _extract_device,
7
+ _format_additional_forward_args,
8
+ _format_output,
9
+ _format_tensor_into_tuples,
10
+ _run_forward,
11
+ )
12
+ from captum._utils.gradient import _forward_layer_eval
13
+ from captum._utils.typing import BaselineType, TargetType
14
+ from captum.attr._core.feature_ablation import FeatureAblation
15
+ from captum.attr._utils.attribution import LayerAttribution, PerturbationAttribution
16
+ from captum.log import log_usage
17
+ from torch import Tensor
18
+ from torch.nn import Module
19
+ from torch.nn.parallel.scatter_gather import scatter
20
+
21
+
22
+ class LayerFeatureAblation(LayerAttribution, PerturbationAttribution):
23
+ r"""
24
+ A perturbation based approach to computing layer attribution, involving
25
+ replacing values in the input / output of a layer with a given baseline /
26
+ reference, and computing the difference in output. By default, each
27
+ neuron (scalar input / output value) within the layer is replaced
28
+ independently.
29
+ Passing a layer mask allows grouping neurons to be
30
+ ablated together.
31
+ Each neuron in the group will be given the same attribution value
32
+ equal to the change in target as a result of ablating the entire neuron
33
+ group.
34
+ """
35
+
36
+ def __init__(
37
+ self,
38
+ forward_func: Callable,
39
+ layer: Module,
40
+ device_ids: Union[None, List[int]] = None,
41
+ ) -> None:
42
+ r"""
43
+ Args:
44
+
45
+ forward_func (callable): The forward function of the model or any
46
+ modification of it
47
+ layer (torch.nn.Module): Layer for which attributions are computed.
48
+ Output size of attribute matches this layer's input or
49
+ output dimensions, depending on whether we attribute to
50
+ the inputs or outputs of the layer, corresponding to
51
+ attribution of each neuron in the input or output of
52
+ this layer.
53
+ device_ids (list(int)): Device ID list, necessary only if forward_func
54
+ applies a DataParallel model. This allows reconstruction of
55
+ intermediate outputs from batched results across devices.
56
+ If forward_func is given as the DataParallel model itself
57
+ (or otherwise has a device_ids attribute with the device
58
+ ID list), then it is not necessary to provide this
59
+ argument.
60
+ """
61
+ LayerAttribution.__init__(self, forward_func, layer, device_ids)
62
+ PerturbationAttribution.__init__(self, forward_func)
63
+
64
+ @log_usage()
65
+ def attribute(
66
+ self,
67
+ inputs: Union[Tensor, Tuple[Tensor, ...]],
68
+ layer_baselines: BaselineType = None,
69
+ target: TargetType = None,
70
+ additional_forward_args: Any = None,
71
+ layer_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
72
+ attribute_to_layer_input: bool = False,
73
+ perturbations_per_eval: int = 1,
74
+ ) -> Union[Tensor, Tuple[Tensor, ...]]:
75
+ r"""
76
+ Args:
77
+
78
+ inputs (tensor or tuple of tensors): Input for which layer
79
+ attributions are computed. If forward_func takes a single
80
+ tensor as input, a single input tensor should be provided.
81
+ If forward_func takes multiple tensors as input, a tuple
82
+ of the input tensors should be provided. It is assumed
83
+ that for all given input tensors, dimension 0 corresponds
84
+ to the number of examples, and if multiple input tensors
85
+ are provided, the examples must be aligned appropriately.
86
+ layer_baselines (scalar, tensor, tuple of scalars or tensors, optional):
87
+ Layer baselines define reference values which replace each
88
+ layer input / output value when ablated.
89
+ Layer baselines should be a single tensor with dimensions
90
+ matching the input / output of the target layer (or
91
+ broadcastable to match it), based
92
+ on whether we are attributing to the input or output
93
+ of the target layer.
94
+ In the cases when `baselines` is not provided, we internally
95
+ use zero as the baseline for each neuron.
96
+ Default: None
97
+ target (int, tuple, tensor or list, optional): Output indices for
98
+ which gradients are computed (for classification cases,
99
+ this is usually the target class).
100
+ If the network returns a scalar value per example,
101
+ no target index is necessary.
102
+ For general 2D outputs, targets can be either:
103
+
104
+ - a single integer or a tensor containing a single
105
+ integer, which is applied to all input examples
106
+
107
+ - a list of integers or a 1D tensor, with length matching
108
+ the number of examples in inputs (dim 0). Each integer
109
+ is applied as the target for the corresponding example.
110
+
111
+ For outputs with > 2 dimensions, targets can be either:
112
+
113
+ - A single tuple, which contains #output_dims - 1
114
+ elements. This target index is applied to all examples.
115
+
116
+ - A list of tuples with length equal to the number of
117
+ examples in inputs (dim 0), and each tuple containing
118
+ #output_dims - 1 elements. Each tuple is applied as the
119
+ target for the corresponding example.
120
+
121
+ Default: None
122
+ additional_forward_args (any, optional): If the forward function
123
+ requires additional arguments other than the inputs for
124
+ which attributions should not be computed, this argument
125
+ can be provided. It must be either a single additional
126
+ argument of a Tensor or arbitrary (non-tuple) type or a
127
+ tuple containing multiple additional arguments including
128
+ tensors or any arbitrary python types. These arguments
129
+ are provided to forward_func in order following the
130
+ arguments in inputs.
131
+ Note that attributions are not computed with respect
132
+ to these arguments.
133
+ Default: None
134
+ layer_mask (tensor or tuple of tensors, optional):
135
+ layer_mask defines a mask for the layer, grouping
136
+ elements of the layer input / output which should be
137
+ ablated together.
138
+ layer_mask should be a single tensor with dimensions
139
+ matching the input / output of the target layer (or
140
+ broadcastable to match it), based
141
+ on whether we are attributing to the input or output
142
+ of the target layer. layer_mask
143
+ should contain integers in the range 0 to num_groups
144
+ - 1, and all elements with the same value are
145
+ considered to be in the same group.
146
+ If None, then a layer mask is constructed which assigns
147
+ each neuron within the layer as a separate group, which
148
+ is ablated independently.
149
+ Default: None
150
+ attribute_to_layer_input (bool, optional): Indicates whether to
151
+ compute the attributions with respect to the layer input
152
+ or output. If `attribute_to_layer_input` is set to True
153
+ then the attributions will be computed with respect to
154
+ layer's inputs, otherwise it will be computed with respect
155
+ to layer's outputs.
156
+ Note that currently it is assumed that either the input
157
+ or the output of the layer, depending on whether we
158
+ attribute to the input or output, is a single tensor.
159
+ Support for multiple tensors will be added later.
160
+ Default: False
161
+ perturbations_per_eval (int, optional): Allows ablation of multiple
162
+ neuron (groups) to be processed simultaneously in one
163
+ call to forward_fn.
164
+ Each forward pass will contain a maximum of
165
+ perturbations_per_eval * #examples samples.
166
+ For DataParallel models, each batch is split among the
167
+ available devices, so evaluations on each available
168
+ device contain at most
169
+ (perturbations_per_eval * #examples) / num_devices
170
+ samples.
171
+ Default: 1
172
+
173
+ Returns:
174
+ *tensor* or tuple of *tensors* of **attributions**:
175
+ - **attributions** (*tensor* or tuple of *tensors*):
176
+ Attribution of each neuron in given layer input or
177
+ output. Attributions will always be the same size as
178
+ the input or output of the given layer, depending on
179
+ whether we attribute to the inputs or outputs
180
+ of the layer which is decided by the input flag
181
+ `attribute_to_layer_input`
182
+ Attributions are returned in a tuple if
183
+ the layer inputs / outputs contain multiple tensors,
184
+ otherwise a single tensor is returned.
185
+
186
+
187
+ Examples::
188
+
189
+ >>> # SimpleClassifier takes a single input tensor of size Nx4x4,
190
+ >>> # and returns an Nx3 tensor of class probabilities.
191
+ >>> # It contains an attribute conv1, which is an instance of nn.conv2d,
192
+ >>> # and the output of this layer has dimensions Nx12x3x3.
193
+ >>> net = SimpleClassifier()
194
+ >>> # Generating random input with size 2 x 4 x 4
195
+ >>> input = torch.randn(2, 4, 4)
196
+ >>> # Defining LayerFeatureAblation interpreter
197
+ >>> ablator = LayerFeatureAblation(net, net.conv1)
198
+ >>> # Computes ablation attribution, ablating each of the 108
199
+ >>> # neurons independently.
200
+ >>> attr = ablator.attribute(input, target=1)
201
+
202
+ >>> # Alternatively, we may want to ablate neurons in groups, e.g.
203
+ >>> # grouping all the layer outputs in the same row.
204
+ >>> # This can be done by creating a layer mask as follows, which
205
+ >>> # defines the groups of layer inputs / outputs, e.g.:
206
+ >>> # +---+---+---+
207
+ >>> # | 0 | 0 | 0 |
208
+ >>> # +---+---+---+
209
+ >>> # | 1 | 1 | 1 |
210
+ >>> # +---+---+---+
211
+ >>> # | 2 | 2 | 2 |
212
+ >>> # +---+---+---+
213
+ >>> # With this mask, all the 36 neurons in a row / channel are ablated
214
+ >>> # simultaneously, and the attribution for each neuron in the same
215
+ >>> # group (0 - 2) per example are the same.
216
+ >>> # The attributions can be calculated as follows:
217
+ >>> # layer mask has dimensions 1 x 3 x 3
218
+ >>> layer_mask = torch.tensor([[[0,0,0],[1,1,1],
219
+ >>> [2,2,2]]])
220
+ >>> attr = ablator.attribute(input, target=1,
221
+ >>> layer_mask=layer_mask)
222
+ """
223
+
224
+ def layer_forward_func(*args):
225
+ layer_length = args[-1]
226
+ layer_input = args[:layer_length]
227
+ original_inputs = args[layer_length:-1]
228
+
229
+ device_ids = self.device_ids
230
+ if device_ids is None:
231
+ device_ids = getattr(self.forward_func, "device_ids", None)
232
+
233
+ all_layer_inputs = {}
234
+ if device_ids is not None:
235
+ scattered_layer_input = scatter(layer_input, target_gpus=device_ids)
236
+ for device_tensors in scattered_layer_input:
237
+ all_layer_inputs[device_tensors[0].device] = device_tensors
238
+ else:
239
+ all_layer_inputs[layer_input[0].device] = layer_input
240
+
241
+ def forward_hook(module, inp, out=None):
242
+ device = _extract_device(module, inp, out)
243
+ is_layer_tuple = (
244
+ isinstance(out, tuple)
245
+ if out is not None
246
+ else isinstance(inp, tuple)
247
+ )
248
+ if device not in all_layer_inputs:
249
+ raise AssertionError(
250
+ "Layer input not placed on appropriate "
251
+ "device. If using a DataParallel model, either provide the "
252
+ "DataParallel model as forward_func or provide device ids"
253
+ " to the constructor."
254
+ )
255
+ if not is_layer_tuple:
256
+ return all_layer_inputs[device][0]
257
+ return all_layer_inputs[device]
258
+
259
+ hook = None
260
+ try:
261
+ if attribute_to_layer_input:
262
+ hook = self.layer.register_forward_pre_hook(forward_hook)
263
+ else:
264
+ hook = self.layer.register_forward_hook(forward_hook)
265
+ eval = _run_forward(self.forward_func, original_inputs, target=target)
266
+ finally:
267
+ if hook is not None:
268
+ hook.remove()
269
+ return eval
270
+
271
+ with torch.no_grad():
272
+ inputs = _format_tensor_into_tuples(inputs)
273
+ additional_forward_args = _format_additional_forward_args(
274
+ additional_forward_args
275
+ )
276
+ layer_eval = _forward_layer_eval(
277
+ self.forward_func,
278
+ inputs,
279
+ self.layer,
280
+ additional_forward_args,
281
+ device_ids=self.device_ids,
282
+ attribute_to_layer_input=attribute_to_layer_input,
283
+ )
284
+ layer_eval_len = (len(layer_eval),)
285
+ all_inputs = (
286
+ (inputs + additional_forward_args + layer_eval_len)
287
+ if additional_forward_args is not None
288
+ else inputs + layer_eval_len
289
+ )
290
+
291
+ ablator = FeatureAblation(layer_forward_func)
292
+
293
+ layer_attribs = ablator.attribute.__wrapped__(
294
+ ablator, # self
295
+ layer_eval,
296
+ baselines=layer_baselines,
297
+ additional_forward_args=all_inputs,
298
+ feature_mask=layer_mask,
299
+ perturbations_per_eval=perturbations_per_eval,
300
+ )
301
+ _attr = _format_output(len(layer_attribs) > 1, layer_attribs)
302
+ return _attr
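`layer_forward_func` above relies on a standard PyTorch behavior: if a forward hook returns a non-None value, that value replaces the layer's output for the rest of the forward pass. The following is a minimal sketch of that mechanism on a toy model; the replacement tensor is an assumption standing in for an ablated layer output.

import torch

lin1 = torch.nn.Linear(4, 3)
lin2 = torch.nn.Linear(3, 1)
replacement = torch.zeros(1, 3)  # stands in for an ablated layer output

def override_output(module, inp, out):
    return replacement  # a non-None return value replaces `out`

handle = lin1.register_forward_hook(override_output)
try:
    x = torch.randn(1, 4)
    out = lin2(torch.relu(lin1(x)))  # lin1's output is silently replaced
finally:
    handle.remove()

# With a zero hidden activation, the model output reduces to lin2's bias.
print(torch.allclose(out, lin2.bias.view(1, 1)))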
captum/attr/_core/layer/layer_gradient_shap.py ADDED
@@ -0,0 +1,474 @@
1
+ #!/usr/bin/env python3
2
+
3
+ import typing
4
+ from typing import Any, Callable, cast, List, Tuple, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ from captum._utils.gradient import _forward_layer_eval, compute_layer_gradients_and_eval
9
+ from captum._utils.typing import Literal, TargetType, TensorOrTupleOfTensorsGeneric
10
+ from captum.attr._core.gradient_shap import _scale_input
11
+ from captum.attr._core.noise_tunnel import NoiseTunnel
12
+ from captum.attr._utils.attribution import GradientAttribution, LayerAttribution
13
+ from captum.attr._utils.common import (
14
+ _compute_conv_delta_and_format_attrs,
15
+ _format_callable_baseline,
16
+ _format_input_baseline,
17
+ )
18
+ from captum.log import log_usage
19
+ from torch import Tensor
20
+ from torch.nn import Module
21
+
22
+
23
+ class LayerGradientShap(LayerAttribution, GradientAttribution):
24
+ r"""
25
+ Implements gradient SHAP for layer based on the implementation from SHAP's
26
+ primary author. For reference, please, view:
27
+
28
+ https://github.com/slundberg/shap\
29
+ #deep-learning-example-with-gradientexplainer-tensorflowkeraspytorch-models
30
+
31
+ A Unified Approach to Interpreting Model Predictions
32
+ http://papers.nips.cc/paper\
33
+ 7062-a-unified-approach-to-interpreting-model-predictions
34
+
35
+ GradientShap approximates SHAP values by computing the expectations of
36
+ gradients by randomly sampling from the distribution of baselines/references.
37
+ It adds white noise to each input sample `n_samples` times, selects a
38
+ random baseline from baselines' distribution and a random point along the
39
+ path between the baseline and the input, and computes the gradient of
40
+ outputs with respect to selected random points in chosen `layer`.
41
+ The final SHAP values represent the expected values of
42
+ `gradients * (layer_attr_inputs - layer_attr_baselines)`.
43
+
44
+ GradientShap makes an assumption that the input features are independent
45
+ and that the explanation model is linear, meaning that the explanations
46
+ are modeled through the additive composition of feature effects.
47
+ Under those assumptions, SHAP value can be approximated as the expectation
48
+ of gradients that are computed for randomly generated `n_samples` input
49
+ samples after adding gaussian noise `n_samples` times to each input for
50
+ different baselines/references.
51
+
52
+ In some sense it can be viewed as an approximation of integrated gradients
53
+ by computing the expectations of gradients for different baselines.
54
+
55
+ Current implementation uses Smoothgrad from `NoiseTunnel` in order to
56
+ randomly draw samples from the distribution of baselines, add noise to input
57
+ samples and compute the expectation (smoothgrad).
58
+ """
59
+
60
+ def __init__(
61
+ self,
62
+ forward_func: Callable,
63
+ layer: Module,
64
+ device_ids: Union[None, List[int]] = None,
65
+ multiply_by_inputs: bool = True,
66
+ ) -> None:
67
+ r"""
68
+ Args:
69
+
70
+ forward_func (callable): The forward function of the model or any
71
+ modification of it
72
+ layer (torch.nn.Module): Layer for which attributions are computed.
73
+ Output size of attribute matches this layer's input or
74
+ output dimensions, depending on whether we attribute to
75
+ the inputs or outputs of the layer, corresponding to
76
+ attribution of each neuron in the input or output of
77
+ this layer.
78
+ device_ids (list(int)): Device ID list, necessary only if forward_func
79
+ applies a DataParallel model. This allows reconstruction of
80
+ intermediate outputs from batched results across devices.
81
+ If forward_func is given as the DataParallel model itself,
82
+ then it is not necessary to provide this argument.
83
+ multiply_by_inputs (bool, optional): Indicates whether to factor
84
+ model inputs' multiplier in the final attribution scores.
85
+ In the literature this is also known as local vs global
86
+ attribution. If inputs' multiplier isn't factored in,
87
+ then this type of attribution method is also called local
88
+ attribution. If it is, then that type of attribution
89
+ method is called global.
90
+ More details can be found here:
91
+ https://arxiv.org/abs/1711.06104
92
+
93
+ In case of layer gradient shap, if `multiply_by_inputs`
94
+ is set to True, the sensitivity scores for scaled inputs
95
+ are being multiplied by
96
+ layer activations for inputs - layer activations for baselines.
97
+
98
+ """
99
+ LayerAttribution.__init__(self, forward_func, layer, device_ids)
100
+ GradientAttribution.__init__(self, forward_func)
101
+ self._multiply_by_inputs = multiply_by_inputs
102
+
103
+ @typing.overload
104
+ def attribute(
105
+ self,
106
+ inputs: TensorOrTupleOfTensorsGeneric,
107
+ baselines: Union[TensorOrTupleOfTensorsGeneric, Callable],
108
+ n_samples: int = 5,
109
+ stdevs: Union[float, Tuple[float, ...]] = 0.0,
110
+ target: TargetType = None,
111
+ additional_forward_args: Any = None,
112
+ *,
113
+ return_convergence_delta: Literal[True],
114
+ attribute_to_layer_input: bool = False,
115
+ ) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]:
116
+ ...
117
+
118
+ @typing.overload
119
+ def attribute(
120
+ self,
121
+ inputs: TensorOrTupleOfTensorsGeneric,
122
+ baselines: Union[TensorOrTupleOfTensorsGeneric, Callable],
123
+ n_samples: int = 5,
124
+ stdevs: Union[float, Tuple[float, ...]] = 0.0,
125
+ target: TargetType = None,
126
+ additional_forward_args: Any = None,
127
+ return_convergence_delta: Literal[False] = False,
128
+ attribute_to_layer_input: bool = False,
129
+ ) -> Union[Tensor, Tuple[Tensor, ...]]:
130
+ ...
131
+
132
+ @log_usage()
133
+ def attribute(
134
+ self,
135
+ inputs: TensorOrTupleOfTensorsGeneric,
136
+ baselines: Union[TensorOrTupleOfTensorsGeneric, Callable],
137
+ n_samples: int = 5,
138
+ stdevs: Union[float, Tuple[float, ...]] = 0.0,
139
+ target: TargetType = None,
140
+ additional_forward_args: Any = None,
141
+ return_convergence_delta: bool = False,
142
+ attribute_to_layer_input: bool = False,
143
+ ) -> Union[
144
+ Tensor, Tuple[Tensor, ...], Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]
145
+ ]:
146
+ r"""
147
+ Args:
148
+
149
+ inputs (tensor or tuple of tensors): Input which are used to compute
150
+ SHAP attribution values for a given `layer`. If `forward_func`
151
+ takes a single tensor as input, a single input tensor should
152
+ be provided.
153
+ If `forward_func` takes multiple tensors as input, a tuple
154
+ of the input tensors should be provided. It is assumed
155
+ that for all given input tensors, dimension 0 corresponds
156
+ to the number of examples, and if multiple input tensors
157
+ are provided, the examples must be aligned appropriately.
158
+ baselines (tensor, tuple of tensors, callable):
159
+ Baselines define the starting point from which expectation
160
+ is computed and can be provided as:
161
+
162
+ - a single tensor, if inputs is a single tensor, with
163
+ the first dimension equal to the number of examples
164
+ in the baselines' distribution. The remaining dimensions
165
+ must match with input tensor's dimension starting from
166
+ the second dimension.
167
+
168
+ - a tuple of tensors, if inputs is a tuple of tensors,
169
+ with the first dimension of any tensor inside the tuple
170
+ equal to the number of examples in the baseline's
171
+ distribution. The remaining dimensions must match
172
+ the dimensions of the corresponding input tensor
173
+ starting from the second dimension.
174
+
175
+ - callable function, optionally takes `inputs` as an
176
+ argument and either returns a single tensor
177
+ or a tuple of those.
178
+
179
+ It is recommended that the number of samples in the baselines'
180
+ tensors is larger than one.
181
+ n_samples (int, optional): The number of randomly generated examples
182
+ per sample in the input batch. Random examples are
183
+ generated by adding gaussian random noise to each sample.
184
+ Default: `5` if `n_samples` is not provided.
185
+ stdevs (float, or a tuple of floats, optional): The standard deviation
186
+ of gaussian noise with zero mean that is added to each
187
+ input in the batch. If `stdevs` is a single float value
188
+ then that same value is used for all inputs. If it is
189
+ a tuple, then it must have the same length as the inputs
190
+ tuple. In this case, each stdev value in the stdevs tuple
191
+ corresponds to the input with the same index in the inputs
192
+ tuple.
193
+ Default: 0.0
194
+ target (int, tuple, tensor or list, optional): Output indices for
195
+ which gradients are computed (for classification cases,
196
+ this is usually the target class).
197
+ If the network returns a scalar value per example,
198
+ no target index is necessary.
199
+ For general 2D outputs, targets can be either:
200
+
201
+ - a single integer or a tensor containing a single
202
+ integer, which is applied to all input examples
203
+
204
+ - a list of integers or a 1D tensor, with length matching
205
+ the number of examples in inputs (dim 0). Each integer
206
+ is applied as the target for the corresponding example.
207
+
208
+ For outputs with > 2 dimensions, targets can be either:
209
+
210
+ - A single tuple, which contains #output_dims - 1
211
+ elements. This target index is applied to all examples.
212
+
213
+ - A list of tuples with length equal to the number of
214
+ examples in inputs (dim 0), and each tuple containing
215
+ #output_dims - 1 elements. Each tuple is applied as the
216
+ target for the corresponding example.
217
+
218
+ Default: None
219
+ additional_forward_args (any, optional): If the forward function
220
+ requires additional arguments other than the inputs for
221
+ which attributions should not be computed, this argument
222
+ can be provided. It can contain a tuple of ND tensors or
223
+ any arbitrary python type of any shape.
224
+ In case of the ND tensor the first dimension of the
225
+ tensor must correspond to the batch size. It will be
226
+ repeated for each of the `n_samples` randomly generated
227
+ input samples.
228
+ Note that the attributions are not computed with respect
229
+ to these arguments.
230
+ Default: None
231
+ return_convergence_delta (bool, optional): Indicates whether to return
232
+ convergence delta or not. If `return_convergence_delta`
233
+ is set to True convergence delta will be returned in
234
+ a tuple following attributions.
235
+ Default: False
236
+ attribute_to_layer_input (bool, optional): Indicates whether to
237
+ compute the attribution with respect to the layer input
238
+ or output. If `attribute_to_layer_input` is set to True
239
+ then the attributions will be computed with respect to
240
+ layer input, otherwise it will be computed with respect
241
+ to layer output.
242
+ Note that currently it is assumed that either the input
243
+ or the output of internal layer, depending on whether we
244
+ attribute to the input or output, is a single tensor.
245
+ Support for multiple tensors will be added later.
246
+ Default: False
247
+ Returns:
248
+ **attributions** or 2-element tuple of **attributions**, **delta**:
249
+ - **attributions** (*tensor* or tuple of *tensors*):
250
+ Attribution score computed based on GradientSHAP with
251
+ respect to layer's input or output. Attributions will always
252
+ be the same size as the provided layer's inputs or outputs,
253
+ depending on whether we attribute to the inputs or outputs
254
+ of the layer.
255
+ Attributions are returned in a tuple if
256
+ the layer inputs / outputs contain multiple tensors,
257
+ otherwise a single tensor is returned.
258
+ - **delta** (*tensor*, returned if return_convergence_delta=True):
259
+ This is computed using the property that the total
260
+ sum of forward_func(inputs) - forward_func(baselines)
261
+ must be very close to the total sum of the attributions
262
+ based on layer gradient SHAP.
263
+ Delta is calculated for each example in the input after adding
264
+ `n_samples` times gaussian noise to each of them. Therefore,
265
+ the dimensionality of the deltas tensor is equal to the
266
+ `number of examples in the input` * `n_samples`
267
+ The deltas are ordered by each input example and `n_samples`
268
+ noisy samples generated for it.
269
+
270
+ Examples::
271
+
272
+ >>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
273
+ >>> # and returns an Nx10 tensor of class probabilities.
274
+ >>> net = ImageClassifier()
275
+ >>> layer_grad_shap = LayerGradientShap(net, net.linear1)
276
+ >>> input = torch.randn(3, 3, 32, 32, requires_grad=True)
277
+ >>> # choosing baselines randomly
278
+ >>> baselines = torch.randn(20, 3, 32, 32)
279
+ >>> # Computes gradient SHAP of output layer when target is equal
280
+ >>> # to 5 with respect to the layer linear1.
281
+ >>> # Attribution size matches the size of the linear1 layer
282
+ >>> attribution = layer_grad_shap.attribute(input, baselines,
283
+ target=5)
284
+
285
+ """
286
+ # since `baselines` is a distribution, we can generate it using a function
287
+ # rather than passing it as an input argument
288
+ baselines = _format_callable_baseline(baselines, inputs)
289
+ assert isinstance(baselines[0], torch.Tensor), (
290
+ "Baselines distribution has to be provided in a form "
291
+ "of a torch.Tensor {}.".format(baselines[0])
292
+ )
293
+
294
+ input_min_baseline_x_grad = LayerInputBaselineXGradient(
295
+ self.forward_func,
296
+ self.layer,
297
+ device_ids=self.device_ids,
298
+ multiply_by_inputs=self.multiplies_by_inputs,
299
+ )
300
+
301
+ nt = NoiseTunnel(input_min_baseline_x_grad)
302
+
303
+ attributions = nt.attribute.__wrapped__(
304
+ nt, # self
305
+ inputs,
306
+ nt_type="smoothgrad",
307
+ nt_samples=n_samples,
308
+ stdevs=stdevs,
309
+ draw_baseline_from_distrib=True,
310
+ baselines=baselines,
311
+ target=target,
312
+ additional_forward_args=additional_forward_args,
313
+ return_convergence_delta=return_convergence_delta,
314
+ attribute_to_layer_input=attribute_to_layer_input,
315
+ )
316
+
317
+ return attributions
318
+
319
+ def has_convergence_delta(self) -> bool:
320
+ return True
321
+
322
+ @property
323
+ def multiplies_by_inputs(self):
324
+ return self._multiply_by_inputs
325
+
326
+
327
+ class LayerInputBaselineXGradient(LayerAttribution, GradientAttribution):
328
+ def __init__(
329
+ self,
330
+ forward_func: Callable,
331
+ layer: Module,
332
+ device_ids: Union[None, List[int]] = None,
333
+ multiply_by_inputs: bool = True,
334
+ ) -> None:
335
+ r"""
336
+ Args:
337
+
338
+ forward_func (callable): The forward function of the model or any
339
+ modification of it
340
+ layer (torch.nn.Module): Layer for which attributions are computed.
341
+ Output size of attribute matches this layer's input or
342
+ output dimensions, depending on whether we attribute to
343
+ the inputs or outputs of the layer, corresponding to
344
+ attribution of each neuron in the input or output of
345
+ this layer.
346
+ device_ids (list(int)): Device ID list, necessary only if forward_func
347
+ applies a DataParallel model. This allows reconstruction of
348
+ intermediate outputs from batched results across devices.
349
+ If forward_func is given as the DataParallel model itself,
350
+ then it is not necessary to provide this argument.
351
+ multiply_by_inputs (bool, optional): Indicates whether to factor
352
+ model inputs' multiplier in the final attribution scores.
353
+ In the literature this is also known as local vs global
354
+ attribution. If inputs' multiplier isn't factored in,
355
+ then this type of attribution method is also called local
356
+ attribution. If it is, then that type of attribution
357
+ method is called global.
358
+ More details can be found here:
359
+ https://arxiv.org/abs/1711.06104
360
+
361
+ In case of layer input minus baseline x gradient,
362
+ if `multiply_by_inputs` is set to True, the sensitivity scores
363
+ for scaled inputs are being multiplied by
364
+ layer activations for inputs - layer activations for baselines.
365
+
366
+ """
367
+ LayerAttribution.__init__(self, forward_func, layer, device_ids)
368
+ GradientAttribution.__init__(self, forward_func)
369
+ self._multiply_by_inputs = multiply_by_inputs
370
+
371
+ @typing.overload
372
+ def attribute(
373
+ self,
374
+ inputs: Union[Tensor, Tuple[Tensor, ...]],
375
+ baselines: Union[Tensor, Tuple[Tensor, ...]],
376
+ target: TargetType = None,
377
+ additional_forward_args: Any = None,
378
+ return_convergence_delta: Literal[False] = False,
379
+ attribute_to_layer_input: bool = False,
380
+ ) -> Union[Tensor, Tuple[Tensor, ...]]:
381
+ ...
382
+
383
+ @typing.overload
384
+ def attribute(
385
+ self,
386
+ inputs: Union[Tensor, Tuple[Tensor, ...]],
387
+ baselines: Union[Tensor, Tuple[Tensor, ...]],
388
+ target: TargetType = None,
389
+ additional_forward_args: Any = None,
390
+ *,
391
+ return_convergence_delta: Literal[True],
392
+ attribute_to_layer_input: bool = False,
393
+ ) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]:
394
+ ...
395
+
396
+ @log_usage()
397
+ def attribute( # type: ignore
398
+ self,
399
+ inputs: Union[Tensor, Tuple[Tensor, ...]],
400
+ baselines: Union[Tensor, Tuple[Tensor, ...]],
401
+ target: TargetType = None,
402
+ additional_forward_args: Any = None,
403
+ return_convergence_delta: bool = False,
404
+ attribute_to_layer_input: bool = False,
405
+ ) -> Union[
406
+ Tensor, Tuple[Tensor, ...], Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]
407
+ ]:
408
+ inputs, baselines = _format_input_baseline(inputs, baselines)
409
+ rand_coefficient = torch.tensor(
410
+ np.random.uniform(0.0, 1.0, inputs[0].shape[0]),
411
+ device=inputs[0].device,
412
+ dtype=inputs[0].dtype,
413
+ )
414
+
415
+ input_baseline_scaled = tuple(
416
+ _scale_input(input, baseline, rand_coefficient)
417
+ for input, baseline in zip(inputs, baselines)
418
+ )
419
+ grads, _ = compute_layer_gradients_and_eval(
420
+ self.forward_func,
421
+ self.layer,
422
+ input_baseline_scaled,
423
+ target,
424
+ additional_forward_args,
425
+ device_ids=self.device_ids,
426
+ attribute_to_layer_input=attribute_to_layer_input,
427
+ )
428
+
429
+ attr_baselines = _forward_layer_eval(
430
+ self.forward_func,
431
+ baselines,
432
+ self.layer,
433
+ additional_forward_args=additional_forward_args,
434
+ device_ids=self.device_ids,
435
+ attribute_to_layer_input=attribute_to_layer_input,
436
+ )
437
+
438
+ attr_inputs = _forward_layer_eval(
439
+ self.forward_func,
440
+ inputs,
441
+ self.layer,
442
+ additional_forward_args=additional_forward_args,
443
+ device_ids=self.device_ids,
444
+ attribute_to_layer_input=attribute_to_layer_input,
445
+ )
446
+
447
+ if self.multiplies_by_inputs:
448
+ input_baseline_diffs = tuple(
449
+ input - baseline for input, baseline in zip(attr_inputs, attr_baselines)
450
+ )
451
+ attributions = tuple(
452
+ input_baseline_diff * grad
453
+ for input_baseline_diff, grad in zip(input_baseline_diffs, grads)
454
+ )
455
+ else:
456
+ attributions = grads
457
+
458
+ return _compute_conv_delta_and_format_attrs(
459
+ self,
460
+ return_convergence_delta,
461
+ attributions,
462
+ baselines,
463
+ inputs,
464
+ additional_forward_args,
465
+ target,
466
+ cast(Union[Literal[True], Literal[False]], len(attributions) > 1),
467
+ )
468
+
469
+ def has_convergence_delta(self) -> bool:
470
+ return True
471
+
472
+ @property
473
+ def multiplies_by_inputs(self):
474
+ return self._multiply_by_inputs
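As a quick orientation for the file above, here is a minimal usage sketch of LayerGradientShap. It is not part of the diff: the toy SimpleClassifier model, its fc1 layer, and all argument values are invented for illustration, and it assumes the vendored package exposes LayerGradientShap from `captum.attr` the same way upstream Captum does.

import torch
import torch.nn as nn
from captum.attr import LayerGradientShap  # assumed export path, as in upstream Captum

class SimpleClassifier(nn.Module):
    # toy model used only for this sketch
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 16)   # layer we attribute to
        self.fc2 = nn.Linear(16, 3)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

net = SimpleClassifier()
inputs = torch.randn(4, 8)        # 4 examples
baselines = torch.randn(20, 8)    # baseline distribution; more than one sample is recommended

lgs = LayerGradientShap(net, net.fc1)
# SHAP values with respect to fc1's output (shape 4 x 16); delta holds one value
# per input example per noisy sample (4 * n_samples entries).
attributions, delta = lgs.attribute(
    inputs,
    baselines,
    n_samples=10,
    stdevs=0.09,
    target=0,
    return_convergence_delta=True,
)

Note that the first dimension of the baseline tensor is the size of the baseline distribution, not the batch size; each noisy input draws its reference from that distribution, as described in the class docstring.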
captum/attr/_core/layer/layer_gradient_x_activation.py ADDED
@@ -0,0 +1,201 @@
1
+ #!/usr/bin/env python3
2
+ from typing import Any, Callable, List, Tuple, Union
3
+
4
+ from captum._utils.common import (
5
+ _format_additional_forward_args,
6
+ _format_output,
7
+ _format_tensor_into_tuples,
8
+ )
9
+ from captum._utils.gradient import compute_layer_gradients_and_eval
10
+ from captum._utils.typing import ModuleOrModuleList, TargetType
11
+ from captum.attr._utils.attribution import GradientAttribution, LayerAttribution
12
+ from captum.log import log_usage
13
+ from torch import Tensor
14
+ from torch.nn import Module
15
+
16
+
17
+ class LayerGradientXActivation(LayerAttribution, GradientAttribution):
18
+ r"""
19
+ Computes element-wise product of gradient and activation for selected
20
+ layer on given inputs.
21
+ """
22
+
23
+ def __init__(
24
+ self,
25
+ forward_func: Callable,
26
+ layer: ModuleOrModuleList,
27
+ device_ids: Union[None, List[int]] = None,
28
+ multiply_by_inputs: bool = True,
29
+ ) -> None:
30
+ r"""
31
+ Args:
32
+
33
+ forward_func (callable): The forward function of the model or any
34
+ modification of it
35
+ layer (torch.nn.Module or list(torch.nn.Module)): Layer or layers
36
+ for which attributions are computed.
37
+ Output size of attribute matches this layer's input or
38
+ output dimensions, depending on whether we attribute to
39
+ the inputs or outputs of the layer, corresponding to
40
+ attribution of each neuron in the input or output of
41
+ this layer. If multiple layers are provided, attributions
42
+ are returned as a list, each element corresponding to the
43
+ attributions of the corresponding layer.
44
+ device_ids (list(int)): Device ID list, necessary only if forward_func
45
+ applies a DataParallel model. This allows reconstruction of
46
+ intermediate outputs from batched results across devices.
47
+ If forward_func is given as the DataParallel model itself,
48
+ then it is not necessary to provide this argument.
49
+ multiply_by_inputs (bool, optional): Indicates whether to factor
50
+ model inputs' multiplier in the final attribution scores.
51
+ In the literature this is also known as local vs global
52
+ attribution. If inputs' multiplier isn't factored in,
53
+ then this type of attribution method is also called local
54
+ attribution. If it is, then that type of attribution
55
+ method is called global.
56
+ More details can be found here:
57
+ https://arxiv.org/abs/1711.06104
58
+
59
+ In case of layer gradient x activation, if `multiply_by_inputs`
60
+ is set to True, final sensitivity scores are being multiplied by
61
+ layer activations for inputs.
62
+
63
+ """
64
+ LayerAttribution.__init__(self, forward_func, layer, device_ids)
65
+ GradientAttribution.__init__(self, forward_func)
66
+ self._multiply_by_inputs = multiply_by_inputs
67
+
68
+ @property
69
+ def multiplies_by_inputs(self):
70
+ return self._multiply_by_inputs
71
+
72
+ @log_usage()
73
+ def attribute(
74
+ self,
75
+ inputs: Union[Tensor, Tuple[Tensor, ...]],
76
+ target: TargetType = None,
77
+ additional_forward_args: Any = None,
78
+ attribute_to_layer_input: bool = False,
79
+ ) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]:
80
+ r"""
81
+ Args:
82
+
83
+ inputs (tensor or tuple of tensors): Input for which attributions
84
+ are computed. If forward_func takes a single
85
+ tensor as input, a single input tensor should be provided.
86
+ If forward_func takes multiple tensors as input, a tuple
87
+ of the input tensors should be provided. It is assumed
88
+ that for all given input tensors, dimension 0 corresponds
89
+ to the number of examples, and if multiple input tensors
90
+ are provided, the examples must be aligned appropriately.
91
+ target (int, tuple, tensor or list, optional): Output indices for
92
+ which gradients are computed (for classification cases,
93
+ this is usually the target class).
94
+ If the network returns a scalar value per example,
95
+ no target index is necessary.
96
+ For general 2D outputs, targets can be either:
97
+
98
+ - a single integer or a tensor containing a single
99
+ integer, which is applied to all input examples
100
+
101
+ - a list of integers or a 1D tensor, with length matching
102
+ the number of examples in inputs (dim 0). Each integer
103
+ is applied as the target for the corresponding example.
104
+
105
+ For outputs with > 2 dimensions, targets can be either:
106
+
107
+ - A single tuple, which contains #output_dims - 1
108
+ elements. This target index is applied to all examples.
109
+
110
+ - A list of tuples with length equal to the number of
111
+ examples in inputs (dim 0), and each tuple containing
112
+ #output_dims - 1 elements. Each tuple is applied as the
113
+ target for the corresponding example.
114
+
115
+ Default: None
116
+ additional_forward_args (any, optional): If the forward function
117
+ requires additional arguments other than the inputs for
118
+ which attributions should not be computed, this argument
119
+ can be provided. It must be either a single additional
120
+ argument of a Tensor or arbitrary (non-tuple) type or a
121
+ tuple containing multiple additional arguments including
122
+ tensors or any arbitrary python types. These arguments
123
+ are provided to forward_func in order following the
124
+ arguments in inputs.
125
+ Note that attributions are not computed with respect
126
+ to these arguments.
127
+ Default: None
128
+ attribute_to_layer_input (bool, optional): Indicates whether to
129
+ compute the attribution with respect to the layer input
130
+ or output. If `attribute_to_layer_input` is set to True
131
+ then the attributions will be computed with respect to
132
+ layer input, otherwise it will be computed with respect
133
+ to layer output.
134
+ Default: False
135
+
136
+ Returns:
137
+ *tensor* or tuple of *tensors* or *list* of **attributions**:
138
+ - **attributions** (*tensor* or tuple of *tensors* or *list*):
139
+ Product of gradient and activation for each
140
+ neuron in given layer output.
141
+ Attributions will always be the same size as the
142
+ output of the given layer.
143
+ Attributions are returned in a tuple if
144
+ the layer inputs / outputs contain multiple tensors,
145
+ otherwise a single tensor is returned.
146
+ If multiple layers are provided, attributions
147
+ are returned as a list, each element corresponding to the
148
+ attributions of the corresponding layer.
149
+
150
+
151
+ Examples::
152
+
153
+ >>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
154
+ >>> # and returns an Nx10 tensor of class probabilities.
155
+ >>> # It contains an attribute conv1, which is an instance of nn.conv2d,
156
+ >>> # and the output of this layer has dimensions Nx12x32x32.
157
+ >>> net = ImageClassifier()
158
+ >>> layer_ga = LayerGradientXActivation(net, net.conv1)
159
+ >>> input = torch.randn(2, 3, 32, 32, requires_grad=True)
160
+ >>> # Computes layer activation x gradient for class 3.
161
+ >>> # attribution size matches layer output, Nx12x32x32
162
+ >>> attribution = layer_ga.attribute(input, 3)
163
+ """
164
+ inputs = _format_tensor_into_tuples(inputs)
165
+ additional_forward_args = _format_additional_forward_args(
166
+ additional_forward_args
167
+ )
168
+ # Returns gradient of output with respect to
169
+ # hidden layer and hidden layer evaluated at each input.
170
+ layer_gradients, layer_evals = compute_layer_gradients_and_eval(
171
+ self.forward_func,
172
+ self.layer,
173
+ inputs,
174
+ target,
175
+ additional_forward_args,
176
+ device_ids=self.device_ids,
177
+ attribute_to_layer_input=attribute_to_layer_input,
178
+ )
179
+ if isinstance(self.layer, Module):
180
+ return _format_output(
181
+ len(layer_evals) > 1,
182
+ self.multiply_gradient_acts(layer_gradients, layer_evals),
183
+ )
184
+ else:
185
+ return [
186
+ _format_output(
187
+ len(layer_evals[i]) > 1,
188
+ self.multiply_gradient_acts(layer_gradients[i], layer_evals[i]),
189
+ )
190
+ for i in range(len(self.layer))
191
+ ]
192
+
193
+ def multiply_gradient_acts(
194
+ self, gradients: Tuple[Tensor, ...], evals: Tuple[Tensor, ...]
195
+ ) -> Tuple[Tensor, ...]:
196
+ return tuple(
197
+ single_gradient * single_eval
198
+ if self.multiplies_by_inputs
199
+ else single_gradient
200
+ for single_gradient, single_eval in zip(gradients, evals)
201
+ )
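For orientation, a minimal usage sketch of LayerGradientXActivation mirroring the conv1 example in the docstring above. The TinyImageClassifier model is a made-up stand-in rather than anything defined in this repository, and the import path is assumed to match upstream Captum.

import torch
import torch.nn as nn
from captum.attr import LayerGradientXActivation  # assumed export path

class TinyImageClassifier(nn.Module):
    # toy Nx3x32x32 -> Nx10 classifier, assumed for illustration only
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 12, kernel_size=3, padding=1)  # output: Nx12x32x32
        self.head = nn.Linear(12 * 32 * 32, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        return self.head(x.flatten(1))

net = TinyImageClassifier()
inputs = torch.randn(2, 3, 32, 32)

layer_gxa = LayerGradientXActivation(net, net.conv1)
# gradient of the class-3 logit with respect to conv1's output, multiplied
# element-wise by conv1's activations; the result matches conv1's output shape, 2x12x32x32
attr = layer_gxa.attribute(inputs, target=3)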
captum/attr/_core/layer/layer_integrated_gradients.py ADDED
@@ -0,0 +1,528 @@
1
+ #!/usr/bin/env python3
2
+ import functools
3
+ import warnings
4
+ from typing import Any, Callable, List, overload, Tuple, Union
5
+
6
+ import torch
7
+ from captum._utils.common import (
8
+ _extract_device,
9
+ _format_additional_forward_args,
10
+ _format_outputs,
11
+ )
12
+ from captum._utils.gradient import _forward_layer_eval, _run_forward
13
+ from captum._utils.typing import BaselineType, Literal, ModuleOrModuleList, TargetType
14
+ from captum.attr._core.integrated_gradients import IntegratedGradients
15
+ from captum.attr._utils.attribution import GradientAttribution, LayerAttribution
16
+ from captum.attr._utils.common import (
17
+ _format_input_baseline,
18
+ _tensorize_baseline,
19
+ _validate_input,
20
+ )
21
+ from captum.log import log_usage
22
+ from torch import Tensor
23
+ from torch.nn.parallel.scatter_gather import scatter
24
+
25
+
26
+ class LayerIntegratedGradients(LayerAttribution, GradientAttribution):
27
+ r"""
28
+ Layer Integrated Gradients is a variant of Integrated Gradients that assigns
29
+ an importance score to layer inputs or outputs, depending on whether we
30
+ attribute to the former or to the latter one.
31
+
32
+ Integrated Gradients is an axiomatic model interpretability algorithm that
33
+ attributes / assigns an importance score to each input feature by approximating
34
+ the integral of gradients of the model's output with respect to the inputs
35
+ along the path (straight line) from given baselines / references to inputs.
36
+
37
+ Baselines can be provided as input arguments to attribute method.
38
+ To approximate the integral we can choose to use either a variant of
39
+ Riemann sum or Gauss-Legendre quadrature rule.
40
+
41
+ More details regarding the integrated gradients method can be found in the
42
+ original paper:
43
+ https://arxiv.org/abs/1703.01365
44
+
45
+ """
46
+
47
+ def __init__(
48
+ self,
49
+ forward_func: Callable,
50
+ layer: ModuleOrModuleList,
51
+ device_ids: Union[None, List[int]] = None,
52
+ multiply_by_inputs: bool = True,
53
+ ) -> None:
54
+ r"""
55
+ Args:
56
+ forward_func (callable): The forward function of the model or any
57
+ modification of it
58
+ layer (ModuleOrModuleList):
59
+ Layer or list of layers for which attributions are computed.
60
+ For each layer the output size of the attribute matches
61
+ this layer's input or output dimensions, depending on
62
+ whether we attribute to the inputs or outputs of the
63
+ layer, corresponding to the attribution of each neuron
64
+ in the input or output of this layer.
65
+
66
+ Please note that layers to attribute on cannot be
67
+ dependent on each other. That is, a subset of layers in
68
+ `layer` cannot produce the inputs for another layer.
69
+
70
+ For example, if your model is of a simple linked-list
71
+ based graph structure (think nn.Sequence), e.g. x -> l1
72
+ -> l2 -> l3 -> output. If you pass in any one of those
73
+ layers, you cannot pass in another due to the
74
+ dependence, e.g. if you pass in l2 you cannot pass in
75
+ l1 or l3.
76
+
77
+ device_ids (list(int)): Device ID list, necessary only if forward_func
78
+ applies a DataParallel model. This allows reconstruction of
79
+ intermediate outputs from batched results across devices.
80
+ If forward_func is given as the DataParallel model itself,
81
+ then it is not necessary to provide this argument.
82
+ multiply_by_inputs (bool, optional): Indicates whether to factor
83
+ model inputs' multiplier in the final attribution scores.
84
+ In the literature this is also known as local vs global
85
+ attribution. If inputs' multiplier isn't factored in,
86
+ then this type of attribution method is also called local
87
+ attribution. If it is, then that type of attribution
88
+ method is called global.
89
+ More details can be found here:
90
+ https://arxiv.org/abs/1711.06104
91
+
92
+ In case of layer integrated gradients, if `multiply_by_inputs`
93
+ is set to True, final sensitivity scores are being multiplied by
94
+ layer activations for inputs - layer activations for baselines.
95
+
96
+ """
97
+ LayerAttribution.__init__(self, forward_func, layer, device_ids=device_ids)
98
+ GradientAttribution.__init__(self, forward_func)
99
+ self.ig = IntegratedGradients(forward_func, multiply_by_inputs)
100
+
101
+ if isinstance(layer, list) and len(layer) > 1:
102
+ warnings.warn(
103
+ "Multiple layers provided. Please ensure that each layer is"
104
+ " **not** solely dependent on the outputs of"
105
+ " another layer. Please refer to the documentation for more"
106
+ " detail."
107
+ )
108
+
109
+ @overload
110
+ def attribute(
111
+ self,
112
+ inputs: Union[Tensor, Tuple[Tensor, ...]],
113
+ baselines: BaselineType,
114
+ target: TargetType,
115
+ additional_forward_args: Any,
116
+ n_steps: int,
117
+ method: str,
118
+ internal_batch_size: Union[None, int],
119
+ return_convergence_delta: Literal[False],
120
+ attribute_to_layer_input: bool,
121
+ ) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]:
122
+ ...
123
+
124
+ @overload
125
+ def attribute(
126
+ self,
127
+ inputs: Union[Tensor, Tuple[Tensor, ...]],
128
+ baselines: BaselineType,
129
+ target: TargetType,
130
+ additional_forward_args: Any,
131
+ n_steps: int,
132
+ method: str,
133
+ internal_batch_size: Union[None, int],
134
+ return_convergence_delta: Literal[True],
135
+ attribute_to_layer_input: bool,
136
+ ) -> Tuple[
137
+ Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]],
138
+ Tensor,
139
+ ]:
140
+ ...
141
+
142
+ @overload
143
+ def attribute(
144
+ self,
145
+ inputs: Union[Tensor, Tuple[Tensor, ...]],
146
+ baselines: BaselineType = None,
147
+ target: TargetType = None,
148
+ additional_forward_args: Any = None,
149
+ n_steps: int = 50,
150
+ method: str = "gausslegendre",
151
+ internal_batch_size: Union[None, int] = None,
152
+ return_convergence_delta: bool = False,
153
+ attribute_to_layer_input: bool = False,
154
+ ) -> Union[
155
+ Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]],
156
+ Tuple[
157
+ Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]],
158
+ Tensor,
159
+ ],
160
+ ]:
161
+ ...
162
+
163
+ @log_usage()
164
+ def attribute(
165
+ self,
166
+ inputs: Union[Tensor, Tuple[Tensor, ...]],
167
+ baselines: BaselineType = None,
168
+ target: TargetType = None,
169
+ additional_forward_args: Any = None,
170
+ n_steps: int = 50,
171
+ method: str = "gausslegendre",
172
+ internal_batch_size: Union[None, int] = None,
173
+ return_convergence_delta: bool = False,
174
+ attribute_to_layer_input: bool = False,
175
+ ) -> Union[
176
+ Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]],
177
+ Tuple[
178
+ Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]],
179
+ Tensor,
180
+ ],
181
+ ]:
182
+ r"""
183
+ This method attributes the output of the model with given target index
184
+ (in case it is provided, otherwise it assumes that output is a
185
+ scalar) to layer inputs or outputs of the model, depending on whether
186
+ `attribute_to_layer_input` is set to True or False, using the approach
187
+ described above.
188
+
189
+ In addition to that it also returns, if `return_convergence_delta` is
190
+ set to True, integral approximation delta based on the completeness
191
+ property of integrated gradients.
192
+
193
+ Args:
194
+
195
+ inputs (tensor or tuple of tensors): Input for which layer integrated
196
+ gradients are computed. If forward_func takes a single
197
+ tensor as input, a single input tensor should be provided.
198
+ If forward_func takes multiple tensors as input, a tuple
199
+ of the input tensors should be provided. It is assumed
200
+ that for all given input tensors, dimension 0 corresponds
201
+ to the number of examples, and if multiple input tensors
202
+ are provided, the examples must be aligned appropriately.
203
+ baselines (scalar, tensor, tuple of scalars or tensors, optional):
204
+ Baselines define the starting point from which integral
205
+ is computed and can be provided as:
206
+
207
+ - a single tensor, if inputs is a single tensor, with
208
+ exactly the same dimensions as inputs or the first
209
+ dimension is one and the remaining dimensions match
210
+ with inputs.
211
+
212
+ - a single scalar, if inputs is a single tensor, which will
213
+ be broadcasted for each input value in input tensor.
214
+
215
+ - a tuple of tensors or scalars, the baseline corresponding
216
+ to each tensor in the inputs' tuple can be:
217
+ - either a tensor with matching dimensions to
218
+ corresponding tensor in the inputs' tuple
219
+ or the first dimension is one and the remaining
220
+ dimensions match with the corresponding
221
+ input tensor.
222
+ - or a scalar, corresponding to a tensor in the
223
+ inputs' tuple. This scalar value is broadcasted
224
+ for corresponding input tensor.
225
+
226
+ In the cases when `baselines` is not provided, we internally
227
+ use zero scalar corresponding to each input tensor.
228
+
229
+ Default: None
230
+ target (int, tuple, tensor or list, optional): Output indices for
231
+ which gradients are computed (for classification cases,
232
+ this is usually the target class).
233
+ If the network returns a scalar value per example,
234
+ no target index is necessary.
235
+ For general 2D outputs, targets can be either:
236
+
237
+ - a single integer or a tensor containing a single
238
+ integer, which is applied to all input examples
239
+
240
+ - a list of integers or a 1D tensor, with length matching
241
+ the number of examples in inputs (dim 0). Each integer
242
+ is applied as the target for the corresponding example.
243
+
244
+ For outputs with > 2 dimensions, targets can be either:
245
+
246
+ - A single tuple, which contains #output_dims - 1
247
+ elements. This target index is applied to all examples.
248
+
249
+ - A list of tuples with length equal to the number of
250
+ examples in inputs (dim 0), and each tuple containing
251
+ #output_dims - 1 elements. Each tuple is applied as the
252
+ target for the corresponding example.
253
+
254
+ Default: None
255
+ additional_forward_args (any, optional): If the forward function
256
+ requires additional arguments other than the inputs for
257
+ which attributions should not be computed, this argument
258
+ can be provided. It must be either a single additional
259
+ argument of a Tensor or arbitrary (non-tuple) type or a
260
+ tuple containing multiple additional arguments including
261
+ tensors or any arbitrary python types. These arguments
262
+ are provided to forward_func in order following the
263
+ arguments in inputs.
264
+ For a tensor, the first dimension of the tensor must
265
+ correspond to the number of examples. It will be
266
+ repeated for each of `n_steps` along the integrated
267
+ path. For all other types, the given argument is used
268
+ for all forward evaluations.
269
+ Note that attributions are not computed with respect
270
+ to these arguments.
271
+ Default: None
272
+ n_steps (int, optional): The number of steps used by the approximation
273
+ method. Default: 50.
274
+ method (string, optional): Method for approximating the integral,
275
+ one of `riemann_right`, `riemann_left`, `riemann_middle`,
276
+ `riemann_trapezoid` or `gausslegendre`.
277
+ Default: `gausslegendre` if no method is provided.
278
+ internal_batch_size (int, optional): Divides total #steps * #examples
279
+ data points into chunks of size at most internal_batch_size,
280
+ which are computed (forward / backward passes)
281
+ sequentially. internal_batch_size must be at least equal to
282
+ #examples.
283
+ For DataParallel models, each batch is split among the
284
+ available devices, so evaluations on each available
285
+ device contain internal_batch_size / num_devices examples.
286
+ If internal_batch_size is None, then all evaluations are
287
+ processed in one batch.
288
+ Default: None
289
+ return_convergence_delta (bool, optional): Indicates whether to return
290
+ convergence delta or not. If `return_convergence_delta`
291
+ is set to True convergence delta will be returned in
292
+ a tuple following attributions.
293
+ Default: False
294
+ attribute_to_layer_input (bool, optional): Indicates whether to
295
+ compute the attribution with respect to the layer input
296
+ or output. If `attribute_to_layer_input` is set to True
297
+ then the attributions will be computed with respect to
298
+ layer input, otherwise it will be computed with respect
299
+ to layer output.
300
+ Note that currently it is assumed that either the input
301
+ or the output of internal layer, depending on whether we
302
+ attribute to the input or output, is a single tensor.
303
+ Support for multiple tensors will be added later.
304
+ Default: False
305
+ Returns:
306
+ **attributions** or 2-element tuple of **attributions**, **delta**:
307
+ - **attributions** (*tensor*, tuple of *tensors* or *list* of *tensors*):
308
+ Integrated gradients with respect to `layer`'s inputs or
309
+ outputs. Attributions will always be the same size and
310
+ dimensionality as the input or output of the given layer,
311
+ depending on whether we attribute to the inputs or outputs
312
+ of the layer which is decided by the input flag
313
+ `attribute_to_layer_input`.
314
+
315
+ For a single layer, attributions are returned in a tuple if
316
+ the layer inputs / outputs contain multiple tensors,
317
+ otherwise a single tensor is returned.
318
+
319
+ For multiple layers, attributions will always be
320
+ returned as a list. Each element in this list will be
321
+ equivalent to that of a single layer output, i.e. in the
322
+ case that one layer, in the given layers, inputs / outputs
323
+ multiple tensors: the corresponding output element will be
324
+ a tuple of tensors. The ordering of the outputs will be
325
+ the same order as the layers given in the constructor.
326
+ - **delta** (*tensor*, returned if return_convergence_delta=True):
327
+ The difference between the total approximated and true
328
+ integrated gradients. This is computed using the property
329
+ that the total sum of forward_func(inputs) -
330
+ forward_func(baselines) must equal the total sum of the
331
+ integrated gradient.
332
+ Delta is calculated per example, meaning that the number of
333
+ elements in the returned delta tensor is equal to the number
334
+ of examples in inputs.
335
+
336
+ Examples::
337
+
338
+ >>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
339
+ >>> # and returns an Nx10 tensor of class probabilities.
340
+ >>> # It contains an attribute conv1, which is an instance of nn.conv2d,
341
+ >>> # and the output of this layer has dimensions Nx12x32x32.
342
+ >>> net = ImageClassifier()
343
+ >>> lig = LayerIntegratedGradients(net, net.conv1)
344
+ >>> input = torch.randn(2, 3, 32, 32, requires_grad=True)
345
+ >>> # Computes layer integrated gradients for class 3.
346
+ >>> # attribution size matches layer output, Nx12x32x32
347
+ >>> attribution = lig.attribute(input, target=3)
348
+ """
349
+ inps, baselines = _format_input_baseline(inputs, baselines)
350
+ _validate_input(inps, baselines, n_steps, method)
351
+
352
+ baselines = _tensorize_baseline(inps, baselines)
353
+ additional_forward_args = _format_additional_forward_args(
354
+ additional_forward_args
355
+ )
356
+
357
+ def flatten_tuple(tup):
358
+ return tuple(
359
+ sum((list(x) if isinstance(x, (tuple, list)) else [x] for x in tup), [])
360
+ )
361
+
362
+ if self.device_ids is None:
363
+ self.device_ids = getattr(self.forward_func, "device_ids", None)
364
+
365
+ inputs_layer = _forward_layer_eval(
366
+ self.forward_func,
367
+ inps,
368
+ self.layer,
369
+ device_ids=self.device_ids,
370
+ additional_forward_args=additional_forward_args,
371
+ attribute_to_layer_input=attribute_to_layer_input,
372
+ )
373
+
374
+ # if we have one output
375
+ if not isinstance(self.layer, list):
376
+ inputs_layer = (inputs_layer,)
377
+
378
+ num_outputs = [1 if isinstance(x, Tensor) else len(x) for x in inputs_layer]
379
+ num_outputs_cumsum = torch.cumsum(
380
+ torch.IntTensor([0] + num_outputs), dim=0 # type: ignore
381
+ )
382
+ inputs_layer = flatten_tuple(inputs_layer)
383
+
384
+ baselines_layer = _forward_layer_eval(
385
+ self.forward_func,
386
+ baselines,
387
+ self.layer,
388
+ device_ids=self.device_ids,
389
+ additional_forward_args=additional_forward_args,
390
+ attribute_to_layer_input=attribute_to_layer_input,
391
+ )
392
+ baselines_layer = flatten_tuple(baselines_layer)
393
+
394
+ # inputs -> these inputs are scaled
395
+ def gradient_func(
396
+ forward_fn: Callable,
397
+ inputs: Union[Tensor, Tuple[Tensor, ...]],
398
+ target_ind: TargetType = None,
399
+ additional_forward_args: Any = None,
400
+ ) -> Tuple[Tensor, ...]:
401
+ if self.device_ids is None or len(self.device_ids) == 0:
402
+ scattered_inputs = (inputs,)
403
+ else:
404
+ # scatter method does not have a precise enough return type in its
405
+ # stub, so suppress the type warning.
406
+ scattered_inputs = scatter( # type:ignore
407
+ inputs, target_gpus=self.device_ids
408
+ )
409
+
410
+ scattered_inputs_dict = {
411
+ scattered_input[0].device: scattered_input
412
+ for scattered_input in scattered_inputs
413
+ }
414
+
415
+ with torch.autograd.set_grad_enabled(True):
416
+
417
+ def layer_forward_hook(
418
+ module, hook_inputs, hook_outputs=None, layer_idx=0
419
+ ):
420
+ device = _extract_device(module, hook_inputs, hook_outputs)
421
+ is_layer_tuple = (
422
+ isinstance(hook_outputs, tuple)
423
+ # hook_outputs is None if attribute_to_layer_input == True
424
+ if hook_outputs is not None
425
+ else isinstance(hook_inputs, tuple)
426
+ )
427
+
428
+ if is_layer_tuple:
429
+ return scattered_inputs_dict[device][
430
+ num_outputs_cumsum[layer_idx] : num_outputs_cumsum[
431
+ layer_idx + 1
432
+ ]
433
+ ]
434
+
435
+ return scattered_inputs_dict[device][num_outputs_cumsum[layer_idx]]
436
+
437
+ hooks = []
438
+ try:
439
+
440
+ layers = self.layer
441
+ if not isinstance(layers, list):
442
+ layers = [self.layer]
443
+
444
+ for layer_idx, layer in enumerate(layers):
445
+ hook = None
446
+ # TODO:
447
+ # Allow multiple attribute_to_layer_input flags for
448
+ # each layer, i.e. attribute_to_layer_input[layer_idx]
449
+ if attribute_to_layer_input:
450
+ hook = layer.register_forward_pre_hook(
451
+ functools.partial(
452
+ layer_forward_hook, layer_idx=layer_idx
453
+ )
454
+ )
455
+ else:
456
+ hook = layer.register_forward_hook(
457
+ functools.partial(
458
+ layer_forward_hook, layer_idx=layer_idx
459
+ )
460
+ )
461
+
462
+ hooks.append(hook)
463
+
464
+ output = _run_forward(
465
+ self.forward_func, tuple(), target_ind, additional_forward_args
466
+ )
467
+ finally:
468
+ for hook in hooks:
469
+ if hook is not None:
470
+ hook.remove()
471
+
472
+ assert output[0].numel() == 1, (
473
+ "Target not provided when necessary, cannot"
474
+ " take gradient with respect to multiple outputs."
475
+ )
476
+ # torch.unbind(forward_out) is a list of scalar tensor tuples and
477
+ # contains batch_size * #steps elements
478
+ grads = torch.autograd.grad(torch.unbind(output), inputs)
479
+ return grads
480
+
481
+ self.ig.gradient_func = gradient_func
482
+ all_inputs = (
483
+ (inps + additional_forward_args)
484
+ if additional_forward_args is not None
485
+ else inps
486
+ )
487
+
488
+ attributions = self.ig.attribute.__wrapped__( # type: ignore
489
+ self.ig, # self
490
+ inputs_layer,
491
+ baselines=baselines_layer,
492
+ target=target,
493
+ additional_forward_args=all_inputs,
494
+ n_steps=n_steps,
495
+ method=method,
496
+ internal_batch_size=internal_batch_size,
497
+ return_convergence_delta=False,
498
+ )
499
+
500
+ # handle multiple outputs
501
+ output: List[Tuple[Tensor, ...]] = [
502
+ tuple(
503
+ attributions[
504
+ int(num_outputs_cumsum[i]) : int(num_outputs_cumsum[i + 1])
505
+ ]
506
+ )
507
+ for i in range(len(num_outputs))
508
+ ]
509
+
510
+ if return_convergence_delta:
511
+ start_point, end_point = baselines, inps
512
+ # computes approximation error based on the completeness axiom
513
+ delta = self.compute_convergence_delta(
514
+ attributions,
515
+ start_point,
516
+ end_point,
517
+ additional_forward_args=additional_forward_args,
518
+ target=target,
519
+ )
520
+ return _format_outputs(isinstance(self.layer, list), output), delta
521
+ return _format_outputs(isinstance(self.layer, list), output)
522
+
523
+ def has_convergence_delta(self) -> bool:
524
+ return True
525
+
526
+ @property
527
+ def multiplies_by_inputs(self):
528
+ return self.ig.multiplies_by_inputs
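Closing the file, a minimal sketch of LayerIntegratedGradients applied to an embedding layer, the typical case where attributing to a layer rather than to the (integer) inputs is needed. The TinyTextClassifier model, the padding-token baseline, and the argument values are assumptions made for illustration; the import path is assumed to follow upstream Captum.

import torch
import torch.nn as nn
from captum.attr import LayerIntegratedGradients  # assumed export path

class TinyTextClassifier(nn.Module):
    # toy text classifier, assumed for illustration only
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(100, 16)
        self.fc = nn.Linear(16, 2)

    def forward(self, token_ids):
        # mean-pool the token embeddings, then classify
        return self.fc(self.embedding(token_ids).mean(dim=1))

net = TinyTextClassifier()
token_ids = torch.randint(0, 100, (4, 7))    # 4 sequences of 7 token ids
baseline_ids = torch.zeros_like(token_ids)   # e.g. an all-padding-token reference

lig = LayerIntegratedGradients(net, net.embedding)
# gradients are integrated along the straight line between the embedding of the
# baseline ids and the embedding of the real ids; attributions have the
# embedding output shape (4 x 7 x 16) and delta reports the completeness error
# per example.
attr, delta = lig.attribute(
    token_ids,
    baselines=baseline_ids,
    target=1,
    n_steps=25,
    return_convergence_delta=True,
)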