Spaces:
Build error
Build error
Commit
·
d61b9c7
1
Parent(s):
5f5c8d7
added strexp
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +224 -0
- attribution_ops.py +87 -0
- augmentation/blur.py +189 -0
- augmentation/camera.py +120 -0
- augmentation/frost/frost4.jpg +0 -0
- augmentation/frost/frost5.jpg +0 -0
- augmentation/frost/frost6.jpg +0 -0
- augmentation/geometry.py +233 -0
- augmentation/noise.py +94 -0
- augmentation/ops.py +87 -0
- augmentation/pattern.py +115 -0
- augmentation/process.py +123 -0
- augmentation/test.py +43 -0
- augmentation/warp.py +241 -0
- augmentation/weather.py +231 -0
- callbacks.py +360 -0
- captum/__init__.py +3 -0
- captum/_utils/__init__.py +0 -0
- captum/_utils/av.py +499 -0
- captum/_utils/common.py +679 -0
- captum/_utils/gradient.py +865 -0
- captum/_utils/models/__init__.py +25 -0
- captum/_utils/models/linear_model/__init__.py +23 -0
- captum/_utils/models/linear_model/model.py +341 -0
- captum/_utils/models/linear_model/train.py +364 -0
- captum/_utils/models/model.py +66 -0
- captum/_utils/progress.py +138 -0
- captum/_utils/sample_gradient.py +184 -0
- captum/_utils/typing.py +37 -0
- captum/attr/__init__.py +143 -0
- captum/attr/_core/__init__.py +0 -0
- captum/attr/_core/deep_lift.py +1151 -0
- captum/attr/_core/feature_ablation.py +591 -0
- captum/attr/_core/feature_permutation.py +305 -0
- captum/attr/_core/gradient_shap.py +414 -0
- captum/attr/_core/guided_backprop_deconvnet.py +322 -0
- captum/attr/_core/guided_grad_cam.py +226 -0
- captum/attr/_core/input_x_gradient.py +130 -0
- captum/attr/_core/integrated_gradients.py +390 -0
- captum/attr/_core/kernel_shap.py +348 -0
- captum/attr/_core/layer/__init__.py +0 -0
- captum/attr/_core/layer/grad_cam.py +217 -0
- captum/attr/_core/layer/internal_influence.py +309 -0
- captum/attr/_core/layer/layer_activation.py +136 -0
- captum/attr/_core/layer/layer_conductance.py +395 -0
- captum/attr/_core/layer/layer_deep_lift.py +682 -0
- captum/attr/_core/layer/layer_feature_ablation.py +302 -0
- captum/attr/_core/layer/layer_gradient_shap.py +474 -0
- captum/attr/_core/layer/layer_gradient_x_activation.py +201 -0
- captum/attr/_core/layer/layer_integrated_gradients.py +528 -0
.gitignore
ADDED
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
**/saved_models/*
|
2 |
+
**/data_lmdb_release/*
|
3 |
+
**/image_release/*
|
4 |
+
**/vitstr_base_patch*
|
5 |
+
**/result/*
|
6 |
+
**/results/*
|
7 |
+
**/oldData/
|
8 |
+
*.mdb
|
9 |
+
*.xlsx
|
10 |
+
*.pth
|
11 |
+
*.json
|
12 |
+
*.pkl
|
13 |
+
*.tar
|
14 |
+
*.ipynb
|
15 |
+
*.zip
|
16 |
+
*.eps
|
17 |
+
*.pdf
|
18 |
+
**/grcnn_straug/*
|
19 |
+
**/augmentation/results/*
|
20 |
+
**/tmp/*
|
21 |
+
*.sh
|
22 |
+
**/__pycache__
|
23 |
+
workdir/
|
24 |
+
.remote-sync.json
|
25 |
+
*.png
|
26 |
+
pretrained/
|
27 |
+
attributionImgs/
|
28 |
+
attributionImgsOld/
|
29 |
+
attrSelectivityOld/
|
30 |
+
|
31 |
+
### Linux ###
|
32 |
+
*~
|
33 |
+
|
34 |
+
# temporary files which can be created if a process still has a handle open of a deleted file
|
35 |
+
.fuse_hidden*
|
36 |
+
|
37 |
+
# KDE directory preferences
|
38 |
+
.directory
|
39 |
+
|
40 |
+
# Linux trash folder which might appear on any partition or disk
|
41 |
+
.Trash-*
|
42 |
+
|
43 |
+
# .nfs files are created when an open file is removed but is still being accessed
|
44 |
+
.nfs*
|
45 |
+
|
46 |
+
### OSX ###
|
47 |
+
# General
|
48 |
+
.DS_Store
|
49 |
+
.AppleDouble
|
50 |
+
.LSOverride
|
51 |
+
|
52 |
+
# Icon must end with two \r
|
53 |
+
Icon
|
54 |
+
|
55 |
+
# Thumbnails
|
56 |
+
._*
|
57 |
+
|
58 |
+
# Files that might appear in the root of a volume
|
59 |
+
.DocumentRevisions-V100
|
60 |
+
.fseventsd
|
61 |
+
.Spotlight-V100
|
62 |
+
.TemporaryItems
|
63 |
+
.Trashes
|
64 |
+
.VolumeIcon.icns
|
65 |
+
.com.apple.timemachine.donotpresent
|
66 |
+
|
67 |
+
# Directories potentially created on remote AFP share
|
68 |
+
.AppleDB
|
69 |
+
.AppleDesktop
|
70 |
+
Network Trash Folder
|
71 |
+
Temporary Items
|
72 |
+
.apdisk
|
73 |
+
|
74 |
+
### Python ###
|
75 |
+
# Byte-compiled / optimized / DLL files
|
76 |
+
__pycache__/
|
77 |
+
*.py[cod]
|
78 |
+
*$py.class
|
79 |
+
|
80 |
+
# C extensions
|
81 |
+
*.so
|
82 |
+
|
83 |
+
# Distribution / packaging
|
84 |
+
.Python
|
85 |
+
build/
|
86 |
+
develop-eggs/
|
87 |
+
dist/
|
88 |
+
downloads/
|
89 |
+
eggs/
|
90 |
+
.eggs/
|
91 |
+
lib/
|
92 |
+
lib64/
|
93 |
+
parts/
|
94 |
+
sdist/
|
95 |
+
var/
|
96 |
+
wheels/
|
97 |
+
*.egg-info/
|
98 |
+
.installed.cfg
|
99 |
+
*.egg
|
100 |
+
MANIFEST
|
101 |
+
|
102 |
+
# PyInstaller
|
103 |
+
# Usually these files are written by a python script from a template
|
104 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
105 |
+
*.manifest
|
106 |
+
*.spec
|
107 |
+
|
108 |
+
# Installer logs
|
109 |
+
pip-log.txt
|
110 |
+
pip-delete-this-directory.txt
|
111 |
+
|
112 |
+
# Unit test / coverage reports
|
113 |
+
htmlcov/
|
114 |
+
.tox/
|
115 |
+
.coverage
|
116 |
+
.coverage.*
|
117 |
+
.cache
|
118 |
+
nosetests.xml
|
119 |
+
coverage.xml
|
120 |
+
*.cover
|
121 |
+
.hypothesis/
|
122 |
+
.pytest_cache/
|
123 |
+
|
124 |
+
# Translations
|
125 |
+
*.mo
|
126 |
+
*.pot
|
127 |
+
|
128 |
+
# Django stuff:
|
129 |
+
*.log
|
130 |
+
local_settings.py
|
131 |
+
db.sqlite3
|
132 |
+
|
133 |
+
# Flask stuff:
|
134 |
+
instance/
|
135 |
+
.webassets-cache
|
136 |
+
|
137 |
+
# Scrapy stuff:
|
138 |
+
.scrapy
|
139 |
+
|
140 |
+
# Sphinx documentation
|
141 |
+
docs/_build/
|
142 |
+
|
143 |
+
# PyBuilder
|
144 |
+
target/
|
145 |
+
|
146 |
+
# Jupyter Notebook
|
147 |
+
.ipynb_checkpoints
|
148 |
+
|
149 |
+
# IPython
|
150 |
+
profile_default/
|
151 |
+
ipython_config.py
|
152 |
+
|
153 |
+
# pyenv
|
154 |
+
.python-version
|
155 |
+
|
156 |
+
# celery beat schedule file
|
157 |
+
celerybeat-schedule
|
158 |
+
|
159 |
+
# SageMath parsed files
|
160 |
+
*.sage.py
|
161 |
+
|
162 |
+
# Environments
|
163 |
+
.env
|
164 |
+
.venv
|
165 |
+
env/
|
166 |
+
venv/
|
167 |
+
ENV/
|
168 |
+
env.bak/
|
169 |
+
venv.bak/
|
170 |
+
|
171 |
+
# Spyder project settings
|
172 |
+
.spyderproject
|
173 |
+
.spyproject
|
174 |
+
|
175 |
+
# Rope project settings
|
176 |
+
.ropeproject
|
177 |
+
|
178 |
+
# mkdocs documentation
|
179 |
+
/site
|
180 |
+
|
181 |
+
# mypy
|
182 |
+
.mypy_cache/
|
183 |
+
.dmypy.json
|
184 |
+
dmypy.json
|
185 |
+
|
186 |
+
### Python Patch ###
|
187 |
+
.venv/
|
188 |
+
|
189 |
+
### Python.VirtualEnv Stack ###
|
190 |
+
# Virtualenv
|
191 |
+
# http://iamzed.com/2009/05/07/a-primer-on-virtualenv/
|
192 |
+
[Bb]in
|
193 |
+
[Ii]nclude
|
194 |
+
[Ll]ib
|
195 |
+
[Ll]ib64
|
196 |
+
[Ll]ocal
|
197 |
+
[Ss]cripts
|
198 |
+
pyvenv.cfg
|
199 |
+
pip-selfcheck.json
|
200 |
+
|
201 |
+
### Windows ###
|
202 |
+
# Windows thumbnail cache files
|
203 |
+
Thumbs.db
|
204 |
+
ehthumbs.db
|
205 |
+
ehthumbs_vista.db
|
206 |
+
|
207 |
+
# Dump file
|
208 |
+
*.stackdump
|
209 |
+
|
210 |
+
# Folder config file
|
211 |
+
[Dd]esktop.ini
|
212 |
+
|
213 |
+
# Recycle Bin used on file shares
|
214 |
+
$RECYCLE.BIN/
|
215 |
+
|
216 |
+
# Windows Installer files
|
217 |
+
*.cab
|
218 |
+
*.msi
|
219 |
+
*.msix
|
220 |
+
*.msm
|
221 |
+
*.msp
|
222 |
+
|
223 |
+
# Windows shortcuts
|
224 |
+
*.lnk
|
attribution_ops.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import pickle
|
3 |
+
from captum_improve_vitstr import rankedAttributionsBySegm
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
from skimage.color import gray2rgb
|
6 |
+
from captum.attr._utils.visualization import visualize_image_attr
|
7 |
+
import torch
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
def attr_one_dataset():
|
11 |
+
modelName = "vitstr"
|
12 |
+
datasetName = "IIIT5k_3000"
|
13 |
+
|
14 |
+
rootDir = f"/data/goo/strattr/attributionData/{modelName}/{datasetName}/"
|
15 |
+
attrOutputImgs = f"/data/goo/strattr/attributionDataImgs/{modelName}/{datasetName}/"
|
16 |
+
if not os.path.exists(attrOutputImgs):
|
17 |
+
os.makedirs(attrOutputImgs)
|
18 |
+
|
19 |
+
minNumber = 1000000
|
20 |
+
maxNumber = 0
|
21 |
+
# From a folder containing saved attribution pickle files, convert them into attribution images
|
22 |
+
for path, subdirs, files in os.walk(rootDir):
|
23 |
+
for name in files:
|
24 |
+
fullfilename = os.path.join(rootDir, name) # Value
|
25 |
+
# fullfilename: /data/goo/strattr/attributionData/trba/CUTE80/66_featablt.pkl
|
26 |
+
partfilename = fullfilename[fullfilename.rfind('/')+1:]
|
27 |
+
print("fullfilename: ", fullfilename)
|
28 |
+
imgNum = int(partfilename.split('_')[0])
|
29 |
+
attrImgName = partfilename.replace('.pkl', '.png')
|
30 |
+
minNumber = min(minNumber, imgNum)
|
31 |
+
maxNumber = max(maxNumber, imgNum)
|
32 |
+
with open(fullfilename, 'rb') as f:
|
33 |
+
pklData = pickle.load(f)
|
34 |
+
attributions = pklData['attribution']
|
35 |
+
segmDataNP = pklData['segmData']
|
36 |
+
origImgNP = pklData['origImg']
|
37 |
+
if np.isnan(attributions).any():
|
38 |
+
continue
|
39 |
+
attributions = torch.from_numpy(attributions)
|
40 |
+
rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP)
|
41 |
+
rankedAttr = rankedAttr.detach().cpu().numpy()[0][0]
|
42 |
+
rankedAttr = gray2rgb(rankedAttr)
|
43 |
+
mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map', cmap='RdYlGn')
|
44 |
+
mplotfig.savefig(attrOutputImgs + attrImgName)
|
45 |
+
mplotfig.clear()
|
46 |
+
plt.close(mplotfig)
|
47 |
+
|
48 |
+
def attr_all_dataset():
|
49 |
+
modelName = "vitstr"
|
50 |
+
|
51 |
+
datasetNameList = ['IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', 'IC13_857', 'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80']
|
52 |
+
|
53 |
+
for datasetName in datasetNameList:
|
54 |
+
rootDir = f"/data/goo/strattr/attributionData/{modelName}/{datasetName}/"
|
55 |
+
attrOutputImgs = f"/data/goo/strattr/attributionDataImgs/{modelName}/{datasetName}/"
|
56 |
+
if not os.path.exists(attrOutputImgs):
|
57 |
+
os.makedirs(attrOutputImgs)
|
58 |
+
|
59 |
+
minNumber = 1000000
|
60 |
+
maxNumber = 0
|
61 |
+
# From a folder containing saved attribution pickle files, convert them into attribution images
|
62 |
+
for path, subdirs, files in os.walk(rootDir):
|
63 |
+
for name in files:
|
64 |
+
fullfilename = os.path.join(rootDir, name) # Value
|
65 |
+
# fullfilename: /data/goo/strattr/attributionData/trba/CUTE80/66_featablt.pkl
|
66 |
+
partfilename = fullfilename[fullfilename.rfind('/')+1:]
|
67 |
+
imgNum = int(partfilename.split('_')[0])
|
68 |
+
attrImgName = partfilename.replace('.pkl', '.png')
|
69 |
+
minNumber = min(minNumber, imgNum)
|
70 |
+
maxNumber = max(maxNumber, imgNum)
|
71 |
+
with open(fullfilename, 'rb') as f:
|
72 |
+
pklData = pickle.load(f)
|
73 |
+
attributions = pklData['attribution']
|
74 |
+
segmDataNP = pklData['segmData']
|
75 |
+
origImgNP = pklData['origImg']
|
76 |
+
attributions = torch.from_numpy(attributions)
|
77 |
+
rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP)
|
78 |
+
rankedAttr = rankedAttr.detach().cpu().numpy()[0][0]
|
79 |
+
rankedAttr = gray2rgb(rankedAttr)
|
80 |
+
mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map', cmap='RdYlGn')
|
81 |
+
mplotfig.savefig(attrOutputImgs + attrImgName)
|
82 |
+
mplotfig.clear()
|
83 |
+
plt.close(mplotfig)
|
84 |
+
|
85 |
+
if __name__ == '__main__':
|
86 |
+
attr_one_dataset()
|
87 |
+
# attr_all_dataset()
|
augmentation/blur.py
ADDED
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
from PIL import Image, ImageOps
|
5 |
+
import torchvision.transforms as transforms
|
6 |
+
from wand.image import Image as WandImage
|
7 |
+
from scipy.ndimage import zoom as scizoom
|
8 |
+
from skimage.filters import gaussian
|
9 |
+
from wand.api import library as wandlibrary
|
10 |
+
from io import BytesIO
|
11 |
+
|
12 |
+
#from skimage import color
|
13 |
+
from .ops import MotionImage, clipped_zoom, disk, plasma_fractal
|
14 |
+
'''
|
15 |
+
PIL resize (W,H)
|
16 |
+
'''
|
17 |
+
class GaussianBlur:
|
18 |
+
def __init__(self):
|
19 |
+
pass
|
20 |
+
|
21 |
+
def __call__(self, img, mag=-1, prob=1.):
|
22 |
+
if np.random.uniform(0,1) > prob:
|
23 |
+
return img
|
24 |
+
|
25 |
+
W, H = img.size
|
26 |
+
#kernel = [(31,31)] prev 1 level only
|
27 |
+
kernel = (31, 31)
|
28 |
+
sigmas = [.5, 1, 2]
|
29 |
+
if mag<0 or mag>=len(kernel):
|
30 |
+
index = np.random.randint(0, len(sigmas))
|
31 |
+
else:
|
32 |
+
index = mag
|
33 |
+
|
34 |
+
sigma = sigmas[index]
|
35 |
+
return transforms.GaussianBlur(kernel_size=kernel, sigma=sigma)(img)
|
36 |
+
|
37 |
+
|
38 |
+
class DefocusBlur:
|
39 |
+
def __init__(self):
|
40 |
+
pass
|
41 |
+
|
42 |
+
def __call__(self, img, mag=-1, prob=1.):
|
43 |
+
if np.random.uniform(0,1) > prob:
|
44 |
+
return img
|
45 |
+
|
46 |
+
n_channels = len(img.getbands())
|
47 |
+
isgray = n_channels == 1
|
48 |
+
#c = [(3, 0.1), (4, 0.5), (6, 0.5), (8, 0.5), (10, 0.5)]
|
49 |
+
c = [(2, 0.1), (3, 0.1), (4, 0.1)] #, (6, 0.5)] #prev 2 levels only
|
50 |
+
if mag<0 or mag>=len(c):
|
51 |
+
index = np.random.randint(0, len(c))
|
52 |
+
else:
|
53 |
+
index = mag
|
54 |
+
c = c[index]
|
55 |
+
|
56 |
+
img = np.array(img) / 255.
|
57 |
+
if isgray:
|
58 |
+
img = np.expand_dims(img, axis=2)
|
59 |
+
img = np.repeat(img, 3, axis=2)
|
60 |
+
n_channels = 3
|
61 |
+
kernel = disk(radius=c[0], alias_blur=c[1])
|
62 |
+
|
63 |
+
channels = []
|
64 |
+
for d in range(n_channels):
|
65 |
+
channels.append(cv2.filter2D(img[:, :, d], -1, kernel))
|
66 |
+
channels = np.array(channels).transpose((1, 2, 0)) # 3x224x224 -> 224x224x3
|
67 |
+
|
68 |
+
#if isgray:
|
69 |
+
# img = img[:,:,0]
|
70 |
+
# img = np.squeeze(img)
|
71 |
+
|
72 |
+
img = np.clip(channels, 0, 1) * 255
|
73 |
+
img = Image.fromarray(img.astype(np.uint8))
|
74 |
+
if isgray:
|
75 |
+
img = ImageOps.grayscale(img)
|
76 |
+
|
77 |
+
return img
|
78 |
+
|
79 |
+
|
80 |
+
class MotionBlur:
|
81 |
+
def __init__(self):
|
82 |
+
pass
|
83 |
+
|
84 |
+
def __call__(self, img, mag=-1, prob=1.):
|
85 |
+
if np.random.uniform(0,1) > prob:
|
86 |
+
return img
|
87 |
+
|
88 |
+
n_channels = len(img.getbands())
|
89 |
+
isgray = n_channels == 1
|
90 |
+
#c = [(10, 3), (15, 5), (15, 8), (15, 12), (20, 15)]
|
91 |
+
c = [(10, 3), (12, 4), (14, 5)]
|
92 |
+
if mag<0 or mag>=len(c):
|
93 |
+
index = np.random.randint(0, len(c))
|
94 |
+
else:
|
95 |
+
index = mag
|
96 |
+
c = c[index]
|
97 |
+
|
98 |
+
output = BytesIO()
|
99 |
+
img.save(output, format='PNG')
|
100 |
+
img = MotionImage(blob=output.getvalue())
|
101 |
+
|
102 |
+
img.motion_blur(radius=c[0], sigma=c[1], angle=np.random.uniform(-45, 45))
|
103 |
+
img = cv2.imdecode(np.fromstring(img.make_blob(), np.uint8), cv2.IMREAD_UNCHANGED)
|
104 |
+
if len(img.shape) > 2:
|
105 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
106 |
+
|
107 |
+
img = Image.fromarray(img.astype(np.uint8))
|
108 |
+
|
109 |
+
if isgray:
|
110 |
+
img = ImageOps.grayscale(img)
|
111 |
+
|
112 |
+
return img
|
113 |
+
|
114 |
+
class GlassBlur:
|
115 |
+
def __init__(self):
|
116 |
+
pass
|
117 |
+
|
118 |
+
def __call__(self, img, mag=-1, prob=1.):
|
119 |
+
if np.random.uniform(0,1) > prob:
|
120 |
+
return img
|
121 |
+
|
122 |
+
W, H = img.size
|
123 |
+
#c = [(0.7, 1, 2), (0.9, 2, 1), (1, 2, 3), (1.1, 3, 2), (1.5, 4, 2)][severity - 1]
|
124 |
+
c = [(0.7, 1, 2), (0.75, 1, 2), (0.8, 1, 2)] #, (1, 2, 3)] #prev 2 levels only
|
125 |
+
if mag<0 or mag>=len(c):
|
126 |
+
index = np.random.randint(0, len(c))
|
127 |
+
else:
|
128 |
+
index = mag
|
129 |
+
|
130 |
+
c = c[index]
|
131 |
+
|
132 |
+
img = np.uint8(gaussian(np.array(img) / 255., sigma=c[0], multichannel=True) * 255)
|
133 |
+
|
134 |
+
# locally shuffle pixels
|
135 |
+
for i in range(c[2]):
|
136 |
+
for h in range(H - c[1], c[1], -1):
|
137 |
+
for w in range(W - c[1], c[1], -1):
|
138 |
+
dx, dy = np.random.randint(-c[1], c[1], size=(2,))
|
139 |
+
h_prime, w_prime = h + dy, w + dx
|
140 |
+
# swap
|
141 |
+
img[h, w], img[h_prime, w_prime] = img[h_prime, w_prime], img[h, w]
|
142 |
+
|
143 |
+
img = np.clip(gaussian(img / 255., sigma=c[0], multichannel=True), 0, 1) * 255
|
144 |
+
return Image.fromarray(img.astype(np.uint8))
|
145 |
+
|
146 |
+
|
147 |
+
class ZoomBlur:
|
148 |
+
def __init__(self):
|
149 |
+
pass
|
150 |
+
|
151 |
+
def __call__(self, img, mag=-1, prob=1.):
|
152 |
+
if np.random.uniform(0,1) > prob:
|
153 |
+
return img
|
154 |
+
|
155 |
+
W, H = img.size
|
156 |
+
c = [np.arange(1, 1.11, .01),
|
157 |
+
np.arange(1, 1.16, .01),
|
158 |
+
np.arange(1, 1.21, .02)]
|
159 |
+
if mag<0 or mag>=len(c):
|
160 |
+
index = np.random.randint(0, len(c))
|
161 |
+
else:
|
162 |
+
index = mag
|
163 |
+
|
164 |
+
c = c[index]
|
165 |
+
|
166 |
+
n_channels = len(img.getbands())
|
167 |
+
isgray = n_channels == 1
|
168 |
+
|
169 |
+
uint8_img = img
|
170 |
+
img = (np.array(img) / 255.).astype(np.float32)
|
171 |
+
|
172 |
+
out = np.zeros_like(img)
|
173 |
+
for zoom_factor in c:
|
174 |
+
ZW = int(W*zoom_factor)
|
175 |
+
ZH = int(H*zoom_factor)
|
176 |
+
zoom_img = uint8_img.resize((ZW, ZH), Image.BICUBIC)
|
177 |
+
x1 = (ZW - W) // 2
|
178 |
+
y1 = (ZH - H) // 2
|
179 |
+
x2 = x1 + W
|
180 |
+
y2 = y1 + H
|
181 |
+
zoom_img = zoom_img.crop((x1,y1,x2,y2))
|
182 |
+
out += (np.array(zoom_img) / 255.).astype(np.float32)
|
183 |
+
|
184 |
+
img = (img + out) / (len(c) + 1)
|
185 |
+
|
186 |
+
img = np.clip(img, 0, 1) * 255
|
187 |
+
img = Image.fromarray(img.astype(np.uint8))
|
188 |
+
|
189 |
+
return img
|
augmentation/camera.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
import skimage as sk
|
5 |
+
from PIL import Image, ImageOps
|
6 |
+
from io import BytesIO
|
7 |
+
|
8 |
+
from skimage import color
|
9 |
+
'''
|
10 |
+
PIL resize (W,H)
|
11 |
+
cv2 image is BGR
|
12 |
+
PIL image is RGB
|
13 |
+
'''
|
14 |
+
class Contrast:
|
15 |
+
def __init__(self):
|
16 |
+
pass
|
17 |
+
|
18 |
+
def __call__(self, img, mag=-1, prob=1.):
|
19 |
+
if np.random.uniform(0,1) > prob:
|
20 |
+
return img
|
21 |
+
|
22 |
+
#c = [0.4, .3, .2, .1, .05]
|
23 |
+
c = [0.4, .3, .2]
|
24 |
+
if mag<0 or mag>=len(c):
|
25 |
+
index = np.random.randint(0, len(c))
|
26 |
+
else:
|
27 |
+
index = mag
|
28 |
+
c = c[index]
|
29 |
+
img = np.array(img) / 255.
|
30 |
+
means = np.mean(img, axis=(0, 1), keepdims=True)
|
31 |
+
img = np.clip((img - means) * c + means, 0, 1) * 255
|
32 |
+
|
33 |
+
return Image.fromarray(img.astype(np.uint8))
|
34 |
+
|
35 |
+
|
36 |
+
class Brightness:
|
37 |
+
def __init__(self):
|
38 |
+
pass
|
39 |
+
|
40 |
+
def __call__(self, img, mag=-1, prob=1.):
|
41 |
+
if np.random.uniform(0,1) > prob:
|
42 |
+
return img
|
43 |
+
|
44 |
+
#W, H = img.size
|
45 |
+
#c = [.1, .2, .3, .4, .5]
|
46 |
+
c = [.1, .2, .3]
|
47 |
+
if mag<0 or mag>=len(c):
|
48 |
+
index = np.random.randint(0, len(c))
|
49 |
+
else:
|
50 |
+
index = mag
|
51 |
+
c = c[index]
|
52 |
+
|
53 |
+
n_channels = len(img.getbands())
|
54 |
+
isgray = n_channels == 1
|
55 |
+
|
56 |
+
img = np.array(img) / 255.
|
57 |
+
if isgray:
|
58 |
+
img = np.expand_dims(img, axis=2)
|
59 |
+
img = np.repeat(img, 3, axis=2)
|
60 |
+
|
61 |
+
img = sk.color.rgb2hsv(img)
|
62 |
+
img[:, :, 2] = np.clip(img[:, :, 2] + c, 0, 1)
|
63 |
+
img = sk.color.hsv2rgb(img)
|
64 |
+
|
65 |
+
#if isgray:
|
66 |
+
# img = img[:,:,0]
|
67 |
+
# img = np.squeeze(img)
|
68 |
+
|
69 |
+
img = np.clip(img, 0, 1) * 255
|
70 |
+
img = Image.fromarray(img.astype(np.uint8))
|
71 |
+
if isgray:
|
72 |
+
img = ImageOps.grayscale(img)
|
73 |
+
|
74 |
+
return img
|
75 |
+
#if isgray:
|
76 |
+
#if isgray:
|
77 |
+
# img = color.rgb2gray(img)
|
78 |
+
|
79 |
+
#return Image.fromarray(img.astype(np.uint8))
|
80 |
+
|
81 |
+
|
82 |
+
class JpegCompression:
|
83 |
+
def __init__(self):
|
84 |
+
pass
|
85 |
+
|
86 |
+
def __call__(self, img, mag=-1, prob=1.):
|
87 |
+
if np.random.uniform(0,1) > prob:
|
88 |
+
return img
|
89 |
+
|
90 |
+
#c = [25, 18, 15, 10, 7]
|
91 |
+
c = [25, 18, 15]
|
92 |
+
if mag<0 or mag>=len(c):
|
93 |
+
index = np.random.randint(0, len(c))
|
94 |
+
else:
|
95 |
+
index = mag
|
96 |
+
c = c[index]
|
97 |
+
output = BytesIO()
|
98 |
+
img.save(output, 'JPEG', quality=c)
|
99 |
+
return Image.open(output)
|
100 |
+
|
101 |
+
|
102 |
+
class Pixelate:
|
103 |
+
def __init__(self):
|
104 |
+
pass
|
105 |
+
|
106 |
+
def __call__(self, img, mag=-1, prob=1.):
|
107 |
+
if np.random.uniform(0,1) > prob:
|
108 |
+
return img
|
109 |
+
|
110 |
+
W, H = img.size
|
111 |
+
#c = [0.6, 0.5, 0.4, 0.3, 0.25]
|
112 |
+
c = [0.6, 0.5, 0.4]
|
113 |
+
if mag<0 or mag>=len(c):
|
114 |
+
index = np.random.randint(0, len(c))
|
115 |
+
else:
|
116 |
+
index = mag
|
117 |
+
c = c[index]
|
118 |
+
img = img.resize((int(W* c), int(H * c)), Image.BOX)
|
119 |
+
return img.resize((W, H), Image.BOX)
|
120 |
+
|
augmentation/frost/frost4.jpg
ADDED
![]() |
augmentation/frost/frost5.jpg
ADDED
![]() |
augmentation/frost/frost6.jpg
ADDED
![]() |
augmentation/geometry.py
ADDED
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
from PIL import Image, ImageOps
|
5 |
+
|
6 |
+
'''
|
7 |
+
PIL resize (W,H)
|
8 |
+
Torch resize is (H,W)
|
9 |
+
'''
|
10 |
+
class Shrink:
|
11 |
+
def __init__(self):
|
12 |
+
self.tps = cv2.createThinPlateSplineShapeTransformer()
|
13 |
+
self.translateXAbs = TranslateXAbs()
|
14 |
+
self.translateYAbs = TranslateYAbs()
|
15 |
+
|
16 |
+
def __call__(self, img, mag=-1, prob=1.):
|
17 |
+
if np.random.uniform(0,1) > prob:
|
18 |
+
return img
|
19 |
+
|
20 |
+
W, H = img.size
|
21 |
+
img = np.array(img)
|
22 |
+
srcpt = list()
|
23 |
+
dstpt = list()
|
24 |
+
|
25 |
+
W_33 = 0.33 * W
|
26 |
+
W_50 = 0.50 * W
|
27 |
+
W_66 = 0.66 * W
|
28 |
+
|
29 |
+
H_50 = 0.50 * H
|
30 |
+
|
31 |
+
P = 0
|
32 |
+
|
33 |
+
#frac = 0.4
|
34 |
+
|
35 |
+
b = [.2, .3, .4]
|
36 |
+
if mag<0 or mag>=len(b):
|
37 |
+
index = 0
|
38 |
+
else:
|
39 |
+
index = mag
|
40 |
+
frac = b[index]
|
41 |
+
|
42 |
+
# left-most
|
43 |
+
srcpt.append([P, P])
|
44 |
+
srcpt.append([P, H-P])
|
45 |
+
x = np.random.uniform(frac-.1, frac)*W_33
|
46 |
+
y = np.random.uniform(frac-.1, frac)*H_50
|
47 |
+
dstpt.append([P+x, P+y])
|
48 |
+
dstpt.append([P+x, H-P-y])
|
49 |
+
|
50 |
+
# 2nd left-most
|
51 |
+
srcpt.append([P+W_33, P])
|
52 |
+
srcpt.append([P+W_33, H-P])
|
53 |
+
dstpt.append([P+W_33, P+y])
|
54 |
+
dstpt.append([P+W_33, H-P-y])
|
55 |
+
|
56 |
+
# 3rd left-most
|
57 |
+
srcpt.append([P+W_66, P])
|
58 |
+
srcpt.append([P+W_66, H-P])
|
59 |
+
dstpt.append([P+W_66, P+y])
|
60 |
+
dstpt.append([P+W_66, H-P-y])
|
61 |
+
|
62 |
+
# right-most
|
63 |
+
srcpt.append([W-P, P])
|
64 |
+
srcpt.append([W-P, H-P])
|
65 |
+
dstpt.append([W-P-x, P+y])
|
66 |
+
dstpt.append([W-P-x, H-P-y])
|
67 |
+
|
68 |
+
N = len(dstpt)
|
69 |
+
matches = [cv2.DMatch(i, i, 0) for i in range(N)]
|
70 |
+
dst_shape = np.array(dstpt).reshape((-1, N, 2))
|
71 |
+
src_shape = np.array(srcpt).reshape((-1, N, 2))
|
72 |
+
self.tps.estimateTransformation(dst_shape, src_shape, matches)
|
73 |
+
img = self.tps.warpImage(img)
|
74 |
+
img = Image.fromarray(img)
|
75 |
+
|
76 |
+
if np.random.uniform(0, 1) < 0.5:
|
77 |
+
img = self.translateXAbs(img, val=x)
|
78 |
+
else:
|
79 |
+
img = self.translateYAbs(img, val=y)
|
80 |
+
|
81 |
+
return img
|
82 |
+
|
83 |
+
|
84 |
+
class Rotate:
|
85 |
+
def __init__(self, square_side=224):
|
86 |
+
self.side = square_side
|
87 |
+
|
88 |
+
def __call__(self, img, iscurve=False, mag=-1, prob=1.):
|
89 |
+
if np.random.uniform(0,1) > prob:
|
90 |
+
return img
|
91 |
+
|
92 |
+
W, H = img.size
|
93 |
+
|
94 |
+
if H!=self.side or W!=self.side:
|
95 |
+
img = img.resize((self.side, self.side), Image.BICUBIC)
|
96 |
+
|
97 |
+
b = [20., 40, 60]
|
98 |
+
if mag<0 or mag>=len(b):
|
99 |
+
index = 1
|
100 |
+
else:
|
101 |
+
index = mag
|
102 |
+
rotate_angle = b[index]
|
103 |
+
|
104 |
+
angle = np.random.uniform(rotate_angle-20, rotate_angle)
|
105 |
+
if np.random.uniform(0, 1) < 0.5:
|
106 |
+
angle = -angle
|
107 |
+
|
108 |
+
#angle = np.random.normal(loc=0., scale=rotate_angle)
|
109 |
+
#angle = min(angle, 2*rotate_angle)
|
110 |
+
#angle = max(angle, -2*rotate_angle)
|
111 |
+
|
112 |
+
expand = False if iscurve else True
|
113 |
+
img = img.rotate(angle=angle, resample=Image.BICUBIC, expand=expand)
|
114 |
+
img = img.resize((W, H), Image.BICUBIC)
|
115 |
+
|
116 |
+
return img
|
117 |
+
|
118 |
+
class Perspective:
|
119 |
+
def __init__(self):
|
120 |
+
pass
|
121 |
+
|
122 |
+
def __call__(self, img, mag=-1, prob=1.):
|
123 |
+
if np.random.uniform(0,1) > prob:
|
124 |
+
return img
|
125 |
+
|
126 |
+
W, H = img.size
|
127 |
+
|
128 |
+
# upper-left, upper-right, lower-left, lower-right
|
129 |
+
src = np.float32([[0, 0], [W, 0], [0, H], [W, H]])
|
130 |
+
#low = 0.3
|
131 |
+
|
132 |
+
b = [.1, .2, .3]
|
133 |
+
if mag<0 or mag>=len(b):
|
134 |
+
index = 2
|
135 |
+
else:
|
136 |
+
index = mag
|
137 |
+
low = b[index]
|
138 |
+
|
139 |
+
high = 1 - low
|
140 |
+
if np.random.uniform(0, 1) > 0.5:
|
141 |
+
toprightY = np.random.uniform(low, low+.1)*H
|
142 |
+
bottomrightY = np.random.uniform(high-.1, high)*H
|
143 |
+
dest = np.float32([[0, 0], [W, toprightY], [0, H], [W, bottomrightY]])
|
144 |
+
else:
|
145 |
+
topleftY = np.random.uniform(low, low+.1)*H
|
146 |
+
bottomleftY = np.random.uniform(high-.1, high)*H
|
147 |
+
dest = np.float32([[0, topleftY], [W, 0], [0, bottomleftY], [W, H]])
|
148 |
+
M = cv2.getPerspectiveTransform(src, dest)
|
149 |
+
img = np.array(img)
|
150 |
+
img = cv2.warpPerspective(img, M, (W, H) )
|
151 |
+
img = Image.fromarray(img)
|
152 |
+
|
153 |
+
return img
|
154 |
+
|
155 |
+
|
156 |
+
class TranslateX:
|
157 |
+
def __init__(self):
|
158 |
+
pass
|
159 |
+
|
160 |
+
def __call__(self, img, mag=-1, prob=1.):
|
161 |
+
if np.random.uniform(0,1) > prob:
|
162 |
+
return img
|
163 |
+
|
164 |
+
b = [.03, .06, .09]
|
165 |
+
if mag<0 or mag>=len(b):
|
166 |
+
index = 2
|
167 |
+
else:
|
168 |
+
index = mag
|
169 |
+
v = b[index]
|
170 |
+
v = np.random.uniform(v-0.03, v)
|
171 |
+
|
172 |
+
v = v * img.size[0]
|
173 |
+
if np.random.uniform(0,1) > 0.5:
|
174 |
+
v = -v
|
175 |
+
return img.transform(img.size, Image.AFFINE, (1, 0, v, 0, 1, 0))
|
176 |
+
|
177 |
+
|
178 |
+
class TranslateY:
|
179 |
+
def __init__(self):
|
180 |
+
pass
|
181 |
+
|
182 |
+
def __call__(self, img, mag=-1, prob=1.):
|
183 |
+
if np.random.uniform(0,1) > prob:
|
184 |
+
return img
|
185 |
+
|
186 |
+
b = [.07, .14, .21]
|
187 |
+
if mag<0 or mag>=len(b):
|
188 |
+
index = 2
|
189 |
+
else:
|
190 |
+
index = mag
|
191 |
+
v = b[index]
|
192 |
+
v = np.random.uniform(v-0.07, v)
|
193 |
+
|
194 |
+
v = v * img.size[1]
|
195 |
+
if np.random.uniform(0,1) > 0.5:
|
196 |
+
v = -v
|
197 |
+
return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, v))
|
198 |
+
|
199 |
+
|
200 |
+
class TranslateXAbs:
|
201 |
+
def __init__(self):
|
202 |
+
pass
|
203 |
+
|
204 |
+
def __call__(self, img, val=0, prob=1.):
|
205 |
+
if np.random.uniform(0,1) > prob:
|
206 |
+
return img
|
207 |
+
|
208 |
+
v = np.random.uniform(0, val)
|
209 |
+
|
210 |
+
if np.random.uniform(0,1) > 0.5:
|
211 |
+
v = -v
|
212 |
+
return img.transform(img.size, Image.AFFINE, (1, 0, v, 0, 1, 0))
|
213 |
+
|
214 |
+
|
215 |
+
class TranslateYAbs:
|
216 |
+
def __init__(self):
|
217 |
+
pass
|
218 |
+
|
219 |
+
def __call__(self, img, val=0, prob=1.):
|
220 |
+
if np.random.uniform(0,1) > prob:
|
221 |
+
return img
|
222 |
+
|
223 |
+
v = np.random.uniform(0, val)
|
224 |
+
|
225 |
+
if np.random.uniform(0,1) > 0.5:
|
226 |
+
v = -v
|
227 |
+
return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, v))
|
228 |
+
|
229 |
+
|
230 |
+
|
231 |
+
|
232 |
+
|
233 |
+
|
augmentation/noise.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import numpy as np
|
3 |
+
import skimage as sk
|
4 |
+
from PIL import Image
|
5 |
+
|
6 |
+
'''
|
7 |
+
PIL resize (W,H)
|
8 |
+
'''
|
9 |
+
class GaussianNoise:
|
10 |
+
def __init__(self):
|
11 |
+
pass
|
12 |
+
|
13 |
+
def __call__(self, img, mag=-1, prob=1.):
|
14 |
+
if np.random.uniform(0,1) > prob:
|
15 |
+
return img
|
16 |
+
|
17 |
+
W, H = img.size
|
18 |
+
#c = np.random.uniform(.08, .38)
|
19 |
+
b = [.08, 0.1, 0.12]
|
20 |
+
if mag<0 or mag>=len(b):
|
21 |
+
index = 0
|
22 |
+
else:
|
23 |
+
index = mag
|
24 |
+
a = b[index]
|
25 |
+
c = np.random.uniform(a, a+0.03)
|
26 |
+
img = np.array(img) / 255.
|
27 |
+
img = np.clip(img + np.random.normal(size=img.shape, scale=c), 0, 1) * 255
|
28 |
+
return Image.fromarray(img.astype(np.uint8))
|
29 |
+
|
30 |
+
|
31 |
+
class ShotNoise:
|
32 |
+
def __init__(self):
|
33 |
+
pass
|
34 |
+
|
35 |
+
def __call__(self, img, mag=-1, prob=1.):
|
36 |
+
if np.random.uniform(0,1) > prob:
|
37 |
+
return img
|
38 |
+
|
39 |
+
W, H = img.size
|
40 |
+
#c = np.random.uniform(3, 60)
|
41 |
+
b = [13, 8, 3]
|
42 |
+
if mag<0 or mag>=len(b):
|
43 |
+
index = 2
|
44 |
+
else:
|
45 |
+
index = mag
|
46 |
+
a = b[index]
|
47 |
+
c = np.random.uniform(a, a+7)
|
48 |
+
img = np.array(img) / 255.
|
49 |
+
img = np.clip(np.random.poisson(img * c) / float(c), 0, 1) * 255
|
50 |
+
return Image.fromarray(img.astype(np.uint8))
|
51 |
+
|
52 |
+
|
53 |
+
class ImpulseNoise:
|
54 |
+
def __init__(self):
|
55 |
+
pass
|
56 |
+
|
57 |
+
def __call__(self, img, mag=-1, prob=1.):
|
58 |
+
if np.random.uniform(0,1) > prob:
|
59 |
+
return img
|
60 |
+
|
61 |
+
W, H = img.size
|
62 |
+
#c = np.random.uniform(.03, .27)
|
63 |
+
b = [.03, .07, .11]
|
64 |
+
if mag<0 or mag>=len(b):
|
65 |
+
index = 0
|
66 |
+
else:
|
67 |
+
index = mag
|
68 |
+
a = b[index]
|
69 |
+
c = np.random.uniform(a, a+.04)
|
70 |
+
img = sk.util.random_noise(np.array(img) / 255., mode='s&p', amount=c) * 255
|
71 |
+
return Image.fromarray(img.astype(np.uint8))
|
72 |
+
|
73 |
+
|
74 |
+
class SpeckleNoise:
|
75 |
+
def __init__(self):
|
76 |
+
pass
|
77 |
+
|
78 |
+
def __call__(self, img, mag=-1, prob=1.):
|
79 |
+
if np.random.uniform(0,1) > prob:
|
80 |
+
return img
|
81 |
+
|
82 |
+
W, H = img.size
|
83 |
+
# c = np.random.uniform(.15, .6)
|
84 |
+
b = [.15, .2, .25]
|
85 |
+
if mag<0 or mag>=len(b):
|
86 |
+
index = 0
|
87 |
+
else:
|
88 |
+
index = mag
|
89 |
+
a = b[index]
|
90 |
+
c = np.random.uniform(a, a+.05)
|
91 |
+
img = np.array(img) / 255.
|
92 |
+
img = np.clip(img + img * np.random.normal(size=img.shape, scale=c), 0, 1) * 255
|
93 |
+
return Image.fromarray(img.astype(np.uint8))
|
94 |
+
|
augmentation/ops.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
from wand.image import Image as WandImage
|
5 |
+
from scipy.ndimage import zoom as scizoom
|
6 |
+
from wand.api import library as wandlibrary
|
7 |
+
|
8 |
+
class MotionImage(WandImage):
|
9 |
+
def motion_blur(self, radius=0.0, sigma=0.0, angle=0.0):
|
10 |
+
wandlibrary.MagickMotionBlurImage(self.wand, radius, sigma, angle)
|
11 |
+
|
12 |
+
def clipped_zoom(img, zoom_factor):
|
13 |
+
h = img.shape[1]
|
14 |
+
# ceil crop height(= crop width)
|
15 |
+
ch = int(np.ceil(h / float(zoom_factor)))
|
16 |
+
|
17 |
+
top = (h - ch) // 2
|
18 |
+
img = scizoom(img[top:top + ch, top:top + ch], (zoom_factor, zoom_factor, 1), order=1)
|
19 |
+
# trim off any extra pixels
|
20 |
+
trim_top = (img.shape[0] - h) // 2
|
21 |
+
|
22 |
+
return img[trim_top:trim_top + h, trim_top:trim_top + h]
|
23 |
+
|
24 |
+
def disk(radius, alias_blur=0.1, dtype=np.float32):
|
25 |
+
if radius <= 8:
|
26 |
+
L = np.arange(-8, 8 + 1)
|
27 |
+
ksize = (3, 3)
|
28 |
+
else:
|
29 |
+
L = np.arange(-radius, radius + 1)
|
30 |
+
ksize = (5, 5)
|
31 |
+
X, Y = np.meshgrid(L, L)
|
32 |
+
aliased_disk = np.array((X ** 2 + Y ** 2) <= radius ** 2, dtype=dtype)
|
33 |
+
aliased_disk /= np.sum(aliased_disk)
|
34 |
+
|
35 |
+
# supersample disk to antialias
|
36 |
+
return cv2.GaussianBlur(aliased_disk, ksize=ksize, sigmaX=alias_blur)
|
37 |
+
|
38 |
+
# modification of https://github.com/FLHerne/mapgen/blob/master/diamondsquare.py
|
39 |
+
def plasma_fractal(mapsize=256, wibbledecay=3):
|
40 |
+
"""
|
41 |
+
Generate a heightmap using diamond-square algorithm.
|
42 |
+
Return square 2d array, side length 'mapsize', of floats in range 0-255.
|
43 |
+
'mapsize' must be a power of two.
|
44 |
+
"""
|
45 |
+
assert (mapsize & (mapsize - 1) == 0)
|
46 |
+
maparray = np.empty((mapsize, mapsize), dtype=np.float_)
|
47 |
+
maparray[0, 0] = 0
|
48 |
+
stepsize = mapsize
|
49 |
+
wibble = 100
|
50 |
+
|
51 |
+
def wibbledmean(array):
|
52 |
+
return array / 4 + wibble * np.random.uniform(-wibble, wibble, array.shape)
|
53 |
+
|
54 |
+
def fillsquares():
|
55 |
+
"""For each square of points stepsize apart,
|
56 |
+
calculate middle value as mean of points + wibble"""
|
57 |
+
cornerref = maparray[0:mapsize:stepsize, 0:mapsize:stepsize]
|
58 |
+
squareaccum = cornerref + np.roll(cornerref, shift=-1, axis=0)
|
59 |
+
squareaccum += np.roll(squareaccum, shift=-1, axis=1)
|
60 |
+
maparray[stepsize // 2:mapsize:stepsize,
|
61 |
+
stepsize // 2:mapsize:stepsize] = wibbledmean(squareaccum)
|
62 |
+
|
63 |
+
def filldiamonds():
|
64 |
+
"""For each diamond of points stepsize apart,
|
65 |
+
calculate middle value as mean of points + wibble"""
|
66 |
+
mapsize = maparray.shape[0]
|
67 |
+
drgrid = maparray[stepsize // 2:mapsize:stepsize, stepsize // 2:mapsize:stepsize]
|
68 |
+
ulgrid = maparray[0:mapsize:stepsize, 0:mapsize:stepsize]
|
69 |
+
ldrsum = drgrid + np.roll(drgrid, 1, axis=0)
|
70 |
+
lulsum = ulgrid + np.roll(ulgrid, -1, axis=1)
|
71 |
+
ltsum = ldrsum + lulsum
|
72 |
+
maparray[0:mapsize:stepsize, stepsize // 2:mapsize:stepsize] = wibbledmean(ltsum)
|
73 |
+
tdrsum = drgrid + np.roll(drgrid, 1, axis=1)
|
74 |
+
tulsum = ulgrid + np.roll(ulgrid, -1, axis=0)
|
75 |
+
ttsum = tdrsum + tulsum
|
76 |
+
maparray[stepsize // 2:mapsize:stepsize, 0:mapsize:stepsize] = wibbledmean(ttsum)
|
77 |
+
|
78 |
+
while stepsize >= 2:
|
79 |
+
fillsquares()
|
80 |
+
filldiamonds()
|
81 |
+
stepsize //= 2
|
82 |
+
wibble /= wibbledecay
|
83 |
+
|
84 |
+
maparray -= maparray.min()
|
85 |
+
return maparray / maparray.max()
|
86 |
+
|
87 |
+
|
augmentation/pattern.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
from PIL import Image, ImageOps, ImageDraw
|
5 |
+
|
6 |
+
'''
|
7 |
+
PIL resize (W,H)
|
8 |
+
Torch resize is (H,W)
|
9 |
+
'''
|
10 |
+
class VGrid:
|
11 |
+
def __init__(self):
|
12 |
+
pass
|
13 |
+
|
14 |
+
def __call__(self, img, copy=True, max_width=4, mag=-1, prob=1.):
|
15 |
+
if np.random.uniform(0,1) > prob:
|
16 |
+
return img
|
17 |
+
|
18 |
+
if copy:
|
19 |
+
img = img.copy()
|
20 |
+
W, H = img.size
|
21 |
+
|
22 |
+
if mag<0 or mag>max_width:
|
23 |
+
line_width = np.random.randint(1, max_width)
|
24 |
+
image_stripe = np.random.randint(1, max_width)
|
25 |
+
else:
|
26 |
+
line_width = 1
|
27 |
+
image_stripe = 3 - mag
|
28 |
+
|
29 |
+
n_lines = W // (line_width + image_stripe) + 1
|
30 |
+
draw = ImageDraw.Draw(img)
|
31 |
+
for i in range(1, n_lines):
|
32 |
+
x = image_stripe*i + line_width*(i-1)
|
33 |
+
draw.line([(x,0), (x,H)], width=line_width, fill='black')
|
34 |
+
|
35 |
+
return img
|
36 |
+
|
37 |
+
class HGrid:
|
38 |
+
def __init__(self):
|
39 |
+
pass
|
40 |
+
|
41 |
+
def __call__(self, img, copy=True, max_width=4, mag=-1, prob=1.):
|
42 |
+
if np.random.uniform(0,1) > prob:
|
43 |
+
return img
|
44 |
+
|
45 |
+
if copy:
|
46 |
+
img = img.copy()
|
47 |
+
W, H = img.size
|
48 |
+
if mag<0 or mag>max_width:
|
49 |
+
line_width = np.random.randint(1, max_width)
|
50 |
+
image_stripe = np.random.randint(1, max_width)
|
51 |
+
else:
|
52 |
+
line_width = 1
|
53 |
+
image_stripe = 3 - mag
|
54 |
+
|
55 |
+
n_lines = H // (line_width + image_stripe) + 1
|
56 |
+
draw = ImageDraw.Draw(img)
|
57 |
+
for i in range(1, n_lines):
|
58 |
+
y = image_stripe*i + line_width*(i-1)
|
59 |
+
draw.line([(0,y), (W, y)], width=line_width, fill='black')
|
60 |
+
|
61 |
+
return img
|
62 |
+
|
63 |
+
class Grid:
|
64 |
+
def __init__(self):
|
65 |
+
pass
|
66 |
+
|
67 |
+
def __call__(self, img, mag=-1, prob=1.):
|
68 |
+
if np.random.uniform(0,1) > prob:
|
69 |
+
return img
|
70 |
+
|
71 |
+
img = VGrid()(img, copy=True, mag=mag)
|
72 |
+
img = HGrid()(img, copy=False, mag=mag)
|
73 |
+
return img
|
74 |
+
|
75 |
+
class RectGrid:
|
76 |
+
def __init__(self):
|
77 |
+
pass
|
78 |
+
|
79 |
+
def __call__(self, img, isellipse=False, mag=-1, prob=1.):
|
80 |
+
if np.random.uniform(0,1) > prob:
|
81 |
+
return img
|
82 |
+
|
83 |
+
img = img.copy()
|
84 |
+
W, H = img.size
|
85 |
+
line_width = 1
|
86 |
+
image_stripe = 3 - mag #np.random.randint(2, 6)
|
87 |
+
offset = 4 if isellipse else 1
|
88 |
+
n_lines = ((H//2) // (line_width + image_stripe)) + offset
|
89 |
+
draw = ImageDraw.Draw(img)
|
90 |
+
x_center = W // 2
|
91 |
+
y_center = H // 2
|
92 |
+
for i in range(1, n_lines):
|
93 |
+
dx = image_stripe*i + line_width*(i-1)
|
94 |
+
dy = image_stripe*i + line_width*(i-1)
|
95 |
+
x1 = x_center - (dx * W//H)
|
96 |
+
y1 = y_center - dy
|
97 |
+
x2 = x_center + (dx * W/H)
|
98 |
+
y2 = y_center + dy
|
99 |
+
if isellipse:
|
100 |
+
draw.ellipse([(x1,y1), (x2, y2)], width=line_width, outline='black')
|
101 |
+
else:
|
102 |
+
draw.rectangle([(x1,y1), (x2, y2)], width=line_width, outline='black')
|
103 |
+
|
104 |
+
return img
|
105 |
+
|
106 |
+
class EllipseGrid:
|
107 |
+
def __init__(self):
|
108 |
+
pass
|
109 |
+
|
110 |
+
def __call__(self, img, mag=-1, prob=1.):
|
111 |
+
if np.random.uniform(0,1) > prob:
|
112 |
+
return img
|
113 |
+
|
114 |
+
img = RectGrid()(img, isellipse=True, mag=mag, prob=prob)
|
115 |
+
return img
|
augmentation/process.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from PIL import Image
|
3 |
+
import PIL.ImageOps, PIL.ImageEnhance
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
class Posterize:
|
7 |
+
def __init__(self):
|
8 |
+
pass
|
9 |
+
|
10 |
+
def __call__(self, img, mag=-1, prob=1.):
|
11 |
+
if np.random.uniform(0,1) > prob:
|
12 |
+
return img
|
13 |
+
|
14 |
+
c = [1, 3, 6]
|
15 |
+
if mag<0 or mag>=len(c):
|
16 |
+
index = np.random.randint(0, len(c))
|
17 |
+
else:
|
18 |
+
index = mag
|
19 |
+
c = c[index]
|
20 |
+
bit = np.random.randint(c, c+2)
|
21 |
+
img = PIL.ImageOps.posterize(img, bit)
|
22 |
+
|
23 |
+
return img
|
24 |
+
|
25 |
+
|
26 |
+
class Solarize:
|
27 |
+
def __init__(self):
|
28 |
+
pass
|
29 |
+
|
30 |
+
def __call__(self, img, mag=-1, prob=1.):
|
31 |
+
if np.random.uniform(0,1) > prob:
|
32 |
+
return img
|
33 |
+
|
34 |
+
c = [64, 128, 192]
|
35 |
+
if mag<0 or mag>=len(c):
|
36 |
+
index = np.random.randint(0, len(c))
|
37 |
+
else:
|
38 |
+
index = mag
|
39 |
+
c = c[index]
|
40 |
+
thresh = np.random.randint(c, c+64)
|
41 |
+
img = PIL.ImageOps.solarize(img, thresh)
|
42 |
+
|
43 |
+
return img
|
44 |
+
|
45 |
+
class Invert:
|
46 |
+
def __init__(self):
|
47 |
+
pass
|
48 |
+
|
49 |
+
def __call__(self, img, mag=-1, prob=1.):
|
50 |
+
if np.random.uniform(0,1) > prob:
|
51 |
+
return img
|
52 |
+
|
53 |
+
img = PIL.ImageOps.invert(img)
|
54 |
+
|
55 |
+
return img
|
56 |
+
|
57 |
+
|
58 |
+
class Equalize:
|
59 |
+
def __init__(self):
|
60 |
+
pass
|
61 |
+
|
62 |
+
def __call__(self, img, mag=-1, prob=1.):
|
63 |
+
if np.random.uniform(0,1) > prob:
|
64 |
+
return img
|
65 |
+
|
66 |
+
mg = PIL.ImageOps.equalize(img)
|
67 |
+
|
68 |
+
return img
|
69 |
+
|
70 |
+
|
71 |
+
class AutoContrast:
|
72 |
+
def __init__(self):
|
73 |
+
pass
|
74 |
+
|
75 |
+
def __call__(self, img, mag=-1, prob=1.):
|
76 |
+
if np.random.uniform(0,1) > prob:
|
77 |
+
return img
|
78 |
+
|
79 |
+
mg = PIL.ImageOps.autocontrast(img)
|
80 |
+
|
81 |
+
return img
|
82 |
+
|
83 |
+
|
84 |
+
class Sharpness:
|
85 |
+
def __init__(self):
|
86 |
+
pass
|
87 |
+
|
88 |
+
def __call__(self, img, mag=-1, prob=1.):
|
89 |
+
if np.random.uniform(0,1) > prob:
|
90 |
+
return img
|
91 |
+
|
92 |
+
c = [.1, .7, 1.3]
|
93 |
+
if mag<0 or mag>=len(c):
|
94 |
+
index = np.random.randint(0, len(c))
|
95 |
+
else:
|
96 |
+
index = mag
|
97 |
+
c = c[index]
|
98 |
+
magnitude = np.random.uniform(c, c+.6)
|
99 |
+
img = PIL.ImageEnhance.Sharpness(img).enhance(magnitude)
|
100 |
+
|
101 |
+
return img
|
102 |
+
|
103 |
+
|
104 |
+
class Color:
|
105 |
+
def __init__(self):
|
106 |
+
pass
|
107 |
+
|
108 |
+
def __call__(self, img, mag=-1, prob=1.):
|
109 |
+
if np.random.uniform(0,1) > prob:
|
110 |
+
return img
|
111 |
+
|
112 |
+
c = [.1, .7, 1.3]
|
113 |
+
if mag<0 or mag>=len(c):
|
114 |
+
index = np.random.randint(0, len(c))
|
115 |
+
else:
|
116 |
+
index = mag
|
117 |
+
c = c[index]
|
118 |
+
magnitude = np.random.uniform(c, c+.6)
|
119 |
+
img = PIL.ImageEnhance.Color(img).enhance(magnitude)
|
120 |
+
|
121 |
+
return img
|
122 |
+
|
123 |
+
|
augmentation/test.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import os
|
3 |
+
import cv2
|
4 |
+
from warp import Curve, Distort, Stretch
|
5 |
+
from geometry import Rotate, Perspective, Shrink, TranslateX, TranslateY
|
6 |
+
from pattern import VGrid, HGrid, Grid, RectGrid, EllipseGrid
|
7 |
+
from noise import GaussianNoise, ShotNoise, ImpulseNoise, SpeckleNoise
|
8 |
+
from blur import GaussianBlur, DefocusBlur, MotionBlur, GlassBlur, ZoomBlur
|
9 |
+
from camera import Contrast, Brightness, JpegCompression, Pixelate
|
10 |
+
from weather import Fog, Snow, Frost, Rain, Shadow
|
11 |
+
from process import Posterize, Solarize, Invert, Equalize, AutoContrast, Sharpness, Color
|
12 |
+
|
13 |
+
from PIL import Image
|
14 |
+
import PIL.ImageOps
|
15 |
+
import numpy as np
|
16 |
+
import argparse
|
17 |
+
|
18 |
+
|
19 |
+
if __name__ == '__main__':
|
20 |
+
parser = argparse.ArgumentParser()
|
21 |
+
parser.add_argument('--image', default="images/delivery.png", help='Load image file')
|
22 |
+
parser.add_argument('--results', default="results", help='Load image file')
|
23 |
+
parser.add_argument('--gray', action='store_true', help='Convert to grayscale 1st')
|
24 |
+
opt = parser.parse_args()
|
25 |
+
os.makedirs(opt.results, exist_ok=True)
|
26 |
+
|
27 |
+
img = Image.open(opt.image)
|
28 |
+
img = img.resize( (100,32) )
|
29 |
+
ops = [Curve(), Rotate(), Perspective(), Distort(), Stretch(), Shrink(), TranslateX(), TranslateY(), VGrid(), HGrid(), Grid(), RectGrid(), EllipseGrid()]
|
30 |
+
ops.extend([GaussianNoise(), ShotNoise(), ImpulseNoise(), SpeckleNoise()])
|
31 |
+
ops.extend([GaussianBlur(), DefocusBlur(), MotionBlur(), GlassBlur(), ZoomBlur()])
|
32 |
+
ops.extend([Contrast(), Brightness(), JpegCompression(), Pixelate()])
|
33 |
+
ops.extend([Fog(), Snow(), Frost(), Rain(), Shadow()])
|
34 |
+
ops.extend([Posterize(), Solarize(), Invert(), Equalize(), AutoContrast(), Sharpness(), Color()])
|
35 |
+
for op in ops:
|
36 |
+
for mag in range(3):
|
37 |
+
filename = type(op).__name__ + "-" + str(mag) + ".png"
|
38 |
+
out_img = op(img, mag=mag)
|
39 |
+
if opt.gray:
|
40 |
+
out_img = PIL.ImageOps.grayscale(out_img)
|
41 |
+
out_img.save(os.path.join(opt.results, filename))
|
42 |
+
|
43 |
+
|
augmentation/warp.py
ADDED
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
from PIL import Image, ImageOps
|
5 |
+
|
6 |
+
'''
|
7 |
+
PIL resize (W,H)
|
8 |
+
Torch resize is (H,W)
|
9 |
+
'''
|
10 |
+
class Stretch:
|
11 |
+
def __init__(self):
|
12 |
+
self.tps = cv2.createThinPlateSplineShapeTransformer()
|
13 |
+
|
14 |
+
def __call__(self, img, mag=-1, prob=1.):
|
15 |
+
if np.random.uniform(0,1) > prob:
|
16 |
+
return img
|
17 |
+
|
18 |
+
W, H = img.size
|
19 |
+
img = np.array(img)
|
20 |
+
srcpt = list()
|
21 |
+
dstpt = list()
|
22 |
+
|
23 |
+
W_33 = 0.33 * W
|
24 |
+
W_50 = 0.50 * W
|
25 |
+
W_66 = 0.66 * W
|
26 |
+
|
27 |
+
H_50 = 0.50 * H
|
28 |
+
|
29 |
+
P = 0
|
30 |
+
#frac = 0.4
|
31 |
+
|
32 |
+
b = [.2, .3, .4]
|
33 |
+
if mag<0 or mag>=len(b):
|
34 |
+
index = len(b)-1
|
35 |
+
else:
|
36 |
+
index = mag
|
37 |
+
frac = b[index]
|
38 |
+
|
39 |
+
# left-most
|
40 |
+
srcpt.append([P, P])
|
41 |
+
srcpt.append([P, H-P])
|
42 |
+
srcpt.append([P, H_50])
|
43 |
+
x = np.random.uniform(0, frac)*W_33 #if np.random.uniform(0,1) > 0.5 else 0
|
44 |
+
dstpt.append([P+x, P])
|
45 |
+
dstpt.append([P+x, H-P])
|
46 |
+
dstpt.append([P+x, H_50])
|
47 |
+
|
48 |
+
# 2nd left-most
|
49 |
+
srcpt.append([P+W_33, P])
|
50 |
+
srcpt.append([P+W_33, H-P])
|
51 |
+
x = np.random.uniform(-frac, frac)*W_33
|
52 |
+
dstpt.append([P+W_33+x, P])
|
53 |
+
dstpt.append([P+W_33+x, H-P])
|
54 |
+
|
55 |
+
# 3rd left-most
|
56 |
+
srcpt.append([P+W_66, P])
|
57 |
+
srcpt.append([P+W_66, H-P])
|
58 |
+
x = np.random.uniform(-frac, frac)*W_33
|
59 |
+
dstpt.append([P+W_66+x, P])
|
60 |
+
dstpt.append([P+W_66+x, H-P])
|
61 |
+
|
62 |
+
# right-most
|
63 |
+
srcpt.append([W-P, P])
|
64 |
+
srcpt.append([W-P, H-P])
|
65 |
+
srcpt.append([W-P, H_50])
|
66 |
+
x = np.random.uniform(-frac, 0)*W_33 #if np.random.uniform(0,1) > 0.5 else 0
|
67 |
+
dstpt.append([W-P+x, P])
|
68 |
+
dstpt.append([W-P+x, H-P])
|
69 |
+
dstpt.append([W-P+x, H_50])
|
70 |
+
|
71 |
+
N = len(dstpt)
|
72 |
+
matches = [cv2.DMatch(i, i, 0) for i in range(N)]
|
73 |
+
dst_shape = np.array(dstpt).reshape((-1, N, 2))
|
74 |
+
src_shape = np.array(srcpt).reshape((-1, N, 2))
|
75 |
+
self.tps.estimateTransformation(dst_shape, src_shape, matches)
|
76 |
+
img = self.tps.warpImage(img)
|
77 |
+
img = Image.fromarray(img)
|
78 |
+
|
79 |
+
return img
|
80 |
+
|
81 |
+
|
82 |
+
class Distort:
|
83 |
+
def __init__(self):
|
84 |
+
self.tps = cv2.createThinPlateSplineShapeTransformer()
|
85 |
+
|
86 |
+
def __call__(self, img, mag=-1, prob=1.):
|
87 |
+
if np.random.uniform(0,1) > prob:
|
88 |
+
return img
|
89 |
+
|
90 |
+
W, H = img.size
|
91 |
+
img = np.array(img)
|
92 |
+
srcpt = list()
|
93 |
+
dstpt = list()
|
94 |
+
|
95 |
+
W_33 = 0.33 * W
|
96 |
+
W_50 = 0.50 * W
|
97 |
+
W_66 = 0.66 * W
|
98 |
+
|
99 |
+
H_50 = 0.50 * H
|
100 |
+
|
101 |
+
P = 0
|
102 |
+
#frac = 0.4
|
103 |
+
|
104 |
+
b = [.2, .3, .4]
|
105 |
+
if mag<0 or mag>=len(b):
|
106 |
+
index = len(b)-1
|
107 |
+
else:
|
108 |
+
index = mag
|
109 |
+
frac = b[index]
|
110 |
+
|
111 |
+
# top pts
|
112 |
+
srcpt.append([P, P])
|
113 |
+
x = np.random.uniform(0, frac)*W_33
|
114 |
+
y = np.random.uniform(0, frac)*H_50
|
115 |
+
dstpt.append([P+x, P+y])
|
116 |
+
|
117 |
+
srcpt.append([P+W_33, P])
|
118 |
+
x = np.random.uniform(-frac, frac)*W_33
|
119 |
+
y = np.random.uniform(0, frac)*H_50
|
120 |
+
dstpt.append([P+W_33+x, P+y])
|
121 |
+
|
122 |
+
srcpt.append([P+W_66, P])
|
123 |
+
x = np.random.uniform(-frac, frac)*W_33
|
124 |
+
y = np.random.uniform(0, frac)*H_50
|
125 |
+
dstpt.append([P+W_66+x, P+y])
|
126 |
+
|
127 |
+
srcpt.append([W-P, P])
|
128 |
+
x = np.random.uniform(-frac, 0)*W_33
|
129 |
+
y = np.random.uniform(0, frac)*H_50
|
130 |
+
dstpt.append([W-P+x, P+y])
|
131 |
+
|
132 |
+
# bottom pts
|
133 |
+
srcpt.append([P, H-P])
|
134 |
+
x = np.random.uniform(0, frac)*W_33
|
135 |
+
y = np.random.uniform(-frac, 0)*H_50
|
136 |
+
dstpt.append([P+x, H-P+y])
|
137 |
+
|
138 |
+
srcpt.append([P+W_33, H-P])
|
139 |
+
x = np.random.uniform(-frac, frac)*W_33
|
140 |
+
y = np.random.uniform(-frac, 0)*H_50
|
141 |
+
dstpt.append([P+W_33+x, H-P+y])
|
142 |
+
|
143 |
+
srcpt.append([P+W_66, H-P])
|
144 |
+
x = np.random.uniform(-frac, frac)*W_33
|
145 |
+
y = np.random.uniform(-frac, 0)*H_50
|
146 |
+
dstpt.append([P+W_66+x, H-P+y])
|
147 |
+
|
148 |
+
srcpt.append([W-P, H-P])
|
149 |
+
x = np.random.uniform(-frac, 0)*W_33
|
150 |
+
y = np.random.uniform(-frac, 0)*H_50
|
151 |
+
dstpt.append([W-P+x, H-P+y])
|
152 |
+
|
153 |
+
N = len(dstpt)
|
154 |
+
matches = [cv2.DMatch(i, i, 0) for i in range(N)]
|
155 |
+
dst_shape = np.array(dstpt).reshape((-1, N, 2))
|
156 |
+
src_shape = np.array(srcpt).reshape((-1, N, 2))
|
157 |
+
self.tps.estimateTransformation(dst_shape, src_shape, matches)
|
158 |
+
img = self.tps.warpImage(img)
|
159 |
+
img = Image.fromarray(img)
|
160 |
+
|
161 |
+
return img
|
162 |
+
|
163 |
+
|
164 |
+
class Curve:
|
165 |
+
def __init__(self, square_side=224):
|
166 |
+
self.tps = cv2.createThinPlateSplineShapeTransformer()
|
167 |
+
self.side = square_side
|
168 |
+
|
169 |
+
def __call__(self, img, mag=-1, prob=1.):
|
170 |
+
if np.random.uniform(0,1) > prob:
|
171 |
+
return img
|
172 |
+
|
173 |
+
W, H = img.size
|
174 |
+
|
175 |
+
if H!=self.side or W!=self.side:
|
176 |
+
img = img.resize((self.side, self.side), Image.BICUBIC)
|
177 |
+
|
178 |
+
isflip = np.random.uniform(0,1) > 0.5
|
179 |
+
if isflip:
|
180 |
+
img = ImageOps.flip(img)
|
181 |
+
#img = TF.vflip(img)
|
182 |
+
|
183 |
+
img = np.array(img)
|
184 |
+
w = self.side
|
185 |
+
h = self.side
|
186 |
+
w_25 = 0.25 * w
|
187 |
+
w_50 = 0.50 * w
|
188 |
+
w_75 = 0.75 * w
|
189 |
+
|
190 |
+
b = [1.1, .95, .8]
|
191 |
+
if mag<0 or mag>=len(b):
|
192 |
+
index = 0
|
193 |
+
else:
|
194 |
+
index = mag
|
195 |
+
rmin = b[index]
|
196 |
+
|
197 |
+
r = np.random.uniform(rmin, rmin+.1)*h
|
198 |
+
x1 = (r**2 - w_50**2)**0.5
|
199 |
+
h1 = r - x1
|
200 |
+
|
201 |
+
t = np.random.uniform(0.4, 0.5)*h
|
202 |
+
|
203 |
+
w2 = w_50*t/r
|
204 |
+
hi = x1*t/r
|
205 |
+
h2 = h1 + hi
|
206 |
+
|
207 |
+
sinb_2 = ((1 - x1/r)/2)**0.5
|
208 |
+
cosb_2 = ((1 + x1/r)/2)**0.5
|
209 |
+
w3 = w_50 - r*sinb_2
|
210 |
+
h3 = r - r*cosb_2
|
211 |
+
|
212 |
+
w4 = w_50 - (r-t)*sinb_2
|
213 |
+
h4 = r - (r-t)*cosb_2
|
214 |
+
|
215 |
+
w5 = 0.5*w2
|
216 |
+
h5 = h1 + 0.5*hi
|
217 |
+
h_50 = 0.50*h
|
218 |
+
|
219 |
+
srcpt = [(0,0 ), (w,0 ), (w_50,0), (0,h ), (w,h ), (w_25,0), (w_75,0 ), (w_50,h), (w_25,h), (w_75,h ), (0,h_50), (w,h_50 )]
|
220 |
+
dstpt = [(0,h1), (w,h1), (w_50,0), (w2,h2), (w-w2,h2), (w3, h3), (w-w3,h3), (w_50,t), (w4,h4 ), (w-w4,h4), (w5,h5 ), (w-w5,h5)]
|
221 |
+
|
222 |
+
N = len(dstpt)
|
223 |
+
matches = [cv2.DMatch(i, i, 0) for i in range(N)]
|
224 |
+
dst_shape = np.array(dstpt).reshape((-1, N, 2))
|
225 |
+
src_shape = np.array(srcpt).reshape((-1, N, 2))
|
226 |
+
self.tps.estimateTransformation(dst_shape, src_shape, matches)
|
227 |
+
img = self.tps.warpImage(img)
|
228 |
+
img = Image.fromarray(img)
|
229 |
+
|
230 |
+
if isflip:
|
231 |
+
#img = TF.vflip(img)
|
232 |
+
img = ImageOps.flip(img)
|
233 |
+
rect = (0, self.side//2, self.side, self.side)
|
234 |
+
else:
|
235 |
+
rect = (0, 0, self.side, self.side//2)
|
236 |
+
|
237 |
+
img = img.crop(rect)
|
238 |
+
img = img.resize((W, H), Image.BICUBIC)
|
239 |
+
return img
|
240 |
+
|
241 |
+
|
augmentation/weather.py
ADDED
@@ -0,0 +1,231 @@
1 |
+
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
import math
|
5 |
+
from PIL import Image, ImageOps, ImageDraw
|
6 |
+
from skimage import color
|
7 |
+
from pkg_resources import resource_filename
|
8 |
+
from io import BytesIO
|
9 |
+
from .ops import plasma_fractal, clipped_zoom, MotionImage
|
10 |
+
|
11 |
+
'''
|
12 |
+
PIL resize (W,H)
|
13 |
+
'''
|
14 |
+
class Fog:
|
15 |
+
def __init__(self):
|
16 |
+
pass
|
17 |
+
|
18 |
+
def __call__(self, img, mag=-1, prob=1.):
|
19 |
+
if np.random.uniform(0,1) > prob:
|
20 |
+
return img
|
21 |
+
|
22 |
+
W, H = img.size
|
23 |
+
c = [(1.5, 2), (2., 2), (2.5, 1.7)]
|
24 |
+
if mag<0 or mag>=len(c):
|
25 |
+
index = np.random.randint(0, len(c))
|
26 |
+
else:
|
27 |
+
index = mag
|
28 |
+
c = c[index]
|
29 |
+
|
30 |
+
n_channels = len(img.getbands())
|
31 |
+
isgray = n_channels == 1
|
32 |
+
|
33 |
+
img = np.array(img) / 255.
|
34 |
+
max_val = img.max()
|
35 |
+
fog = c[0] * plasma_fractal(wibbledecay=c[1])[:H, :W][..., np.newaxis]
|
36 |
+
#x += c[0] * plasma_fractal(wibbledecay=c[1])[:224, :224][..., np.newaxis]
|
37 |
+
#return np.clip(x * max_val / (max_val + c[0]), 0, 1) * 255
|
38 |
+
if isgray:
|
39 |
+
fog = np.squeeze(fog)
|
40 |
+
else:
|
41 |
+
fog = np.repeat(fog, 3, axis=2)
|
42 |
+
|
43 |
+
img += fog
|
44 |
+
img = np.clip(img * max_val / (max_val + c[0]), 0, 1) * 255
|
45 |
+
return Image.fromarray(img.astype(np.uint8))
|
46 |
+
|
47 |
+
|
48 |
+
class Frost:
|
49 |
+
def __init__(self):
|
50 |
+
pass
|
51 |
+
|
52 |
+
def __call__(self, img, mag=-1, prob=1.):
|
53 |
+
if np.random.uniform(0,1) > prob:
|
54 |
+
return img
|
55 |
+
|
56 |
+
W, H = img.size
|
57 |
+
c = [(1, 0.4), (0.8, 0.6), (0.7, 0.7)]
|
58 |
+
if mag<0 or mag>=len(c):
|
59 |
+
index = np.random.randint(0, len(c))
|
60 |
+
else:
|
61 |
+
index = mag
|
62 |
+
c = c[index]
|
63 |
+
|
64 |
+
filename = [resource_filename(__name__, 'frost/frost1.png'),
|
65 |
+
resource_filename(__name__, 'frost/frost2.png'),
|
66 |
+
resource_filename(__name__, 'frost/frost3.png'),
|
67 |
+
resource_filename(__name__, 'frost/frost4.jpg'),
|
68 |
+
resource_filename(__name__, 'frost/frost5.jpg'),
|
69 |
+
resource_filename(__name__, 'frost/frost6.jpg')]
|
70 |
+
index = np.random.randint(0, len(filename))
|
71 |
+
filename = filename[index]
|
72 |
+
frost = cv2.imread(filename)
|
73 |
+
#randomly crop and convert to rgb
|
74 |
+
x_start, y_start = np.random.randint(0, frost.shape[0] - H), np.random.randint(0, frost.shape[1] - W)
|
75 |
+
frost = frost[x_start:x_start + H, y_start:y_start + W][..., [2, 1, 0]]
|
76 |
+
|
77 |
+
n_channels = len(img.getbands())
|
78 |
+
isgray = n_channels == 1
|
79 |
+
|
80 |
+
img = np.array(img)
|
81 |
+
|
82 |
+
if isgray:
|
83 |
+
img = np.expand_dims(img, axis=2)
|
84 |
+
img = np.repeat(img, 3, axis=2)
|
85 |
+
|
86 |
+
img = img * c[0]
|
87 |
+
frost = frost * c[1]
|
88 |
+
img = np.clip(c[0] * img + c[1] * frost, 0, 255)
|
89 |
+
img = Image.fromarray(img.astype(np.uint8))
|
90 |
+
if isgray:
|
91 |
+
img = ImageOps.grayscale(img)
|
92 |
+
|
93 |
+
return img
|
94 |
+
|
95 |
+
class Snow:
|
96 |
+
def __init__(self):
|
97 |
+
pass
|
98 |
+
|
99 |
+
def __call__(self, img, mag=-1, prob=1.):
|
100 |
+
if np.random.uniform(0,1) > prob:
|
101 |
+
return img
|
102 |
+
|
103 |
+
W, H = img.size
|
104 |
+
c = [(0.1, 0.3, 3, 0.5, 10, 4, 0.8),
|
105 |
+
(0.2, 0.3, 2, 0.5, 12, 4, 0.7),
|
106 |
+
(0.55, 0.3, 4, 0.9, 12, 8, 0.7)]
|
107 |
+
if mag<0 or mag>=len(c):
|
108 |
+
index = np.random.randint(0, len(c))
|
109 |
+
else:
|
110 |
+
index = mag
|
111 |
+
c = c[index]
|
112 |
+
|
113 |
+
n_channels = len(img.getbands())
|
114 |
+
isgray = n_channels == 1
|
115 |
+
|
116 |
+
img = np.array(img, dtype=np.float32) / 255.
|
117 |
+
if isgray:
|
118 |
+
img = np.expand_dims(img, axis=2)
|
119 |
+
img = np.repeat(img, 3, axis=2)
|
120 |
+
|
121 |
+
snow_layer = np.random.normal(size=img.shape[:2], loc=c[0], scale=c[1]) # [:2] for monochrome
|
122 |
+
|
123 |
+
#snow_layer = clipped_zoom(snow_layer[..., np.newaxis], c[2])
|
124 |
+
snow_layer[snow_layer < c[3]] = 0
|
125 |
+
|
126 |
+
snow_layer = Image.fromarray((np.clip(snow_layer.squeeze(), 0, 1) * 255).astype(np.uint8), mode='L')
|
127 |
+
output = BytesIO()
|
128 |
+
snow_layer.save(output, format='PNG')
|
129 |
+
snow_layer = MotionImage(blob=output.getvalue())
|
130 |
+
|
131 |
+
snow_layer.motion_blur(radius=c[4], sigma=c[5], angle=np.random.uniform(-135, -45))
|
132 |
+
|
133 |
+
snow_layer = cv2.imdecode(np.frombuffer(snow_layer.make_blob(), np.uint8),
|
134 |
+
cv2.IMREAD_UNCHANGED) / 255.
|
135 |
+
|
136 |
+
#snow_layer = cv2.cvtColor(snow_layer, cv2.COLOR_BGR2RGB)
|
137 |
+
|
138 |
+
snow_layer = snow_layer[..., np.newaxis]
|
139 |
+
|
140 |
+
img = c[6] * img
|
141 |
+
gray_img = (1 - c[6]) * np.maximum(img, cv2.cvtColor(img, cv2.COLOR_RGB2GRAY).reshape(H, W, 1) * 1.5 + 0.5)
|
142 |
+
img += gray_img
|
143 |
+
img = np.clip(img + snow_layer + np.rot90(snow_layer, k=2), 0, 1) * 255
|
144 |
+
img = Image.fromarray(img.astype(np.uint8))
|
145 |
+
if isgray:
|
146 |
+
img = ImageOps.grayscale(img)
|
147 |
+
|
148 |
+
return img
|
149 |
+
|
150 |
+
class Rain:
|
151 |
+
def __init__(self):
|
152 |
+
pass
|
153 |
+
|
154 |
+
def __call__(self, img, mag=-1, prob=1.):
|
155 |
+
if np.random.uniform(0,1) > prob:
|
156 |
+
return img
|
157 |
+
|
158 |
+
img = img.copy()
|
159 |
+
W, H = img.size
|
160 |
+
n_channels = len(img.getbands())
|
161 |
+
isgray = n_channels == 1
|
162 |
+
line_width = np.random.randint(1, 2)
|
163 |
+
|
164 |
+
c =[50, 70, 90]
|
165 |
+
if mag<0 or mag>=len(c):
|
166 |
+
index = 0
|
167 |
+
else:
|
168 |
+
index = mag
|
169 |
+
c = c[index]
|
170 |
+
|
171 |
+
n_rains = np.random.randint(c, c+20)
|
172 |
+
slant = np.random.randint(-60, 60)
|
173 |
+
fillcolor = 200 if isgray else (200,200,200)
|
174 |
+
|
175 |
+
draw = ImageDraw.Draw(img)
|
176 |
+
for i in range(1, n_rains):
|
177 |
+
length = np.random.randint(5, 10)
|
178 |
+
x1 = np.random.randint(0, W-length)
|
179 |
+
y1 = np.random.randint(0, H-length)
|
180 |
+
x2 = x1 + length*math.sin(slant*math.pi/180.)
|
181 |
+
y2 = y1 + length*math.cos(slant*math.pi/180.)
|
182 |
+
x2 = int(x2)
|
183 |
+
y2 = int(y2)
|
184 |
+
draw.line([(x1,y1), (x2,y2)], width=line_width, fill=fillcolor)
|
185 |
+
|
186 |
+
return img
|
187 |
+
|
188 |
+
class Shadow:
|
189 |
+
def __init__(self):
|
190 |
+
pass
|
191 |
+
|
192 |
+
def __call__(self, img, mag=-1, prob=1.):
|
193 |
+
if np.random.uniform(0,1) > prob:
|
194 |
+
return img
|
195 |
+
|
196 |
+
#img = img.copy()
|
197 |
+
W, H = img.size
|
198 |
+
n_channels = len(img.getbands())
|
199 |
+
isgray = n_channels == 1
|
200 |
+
|
201 |
+
c =[64, 96, 128]
|
202 |
+
if mag<0 or mag>=len(c):
|
203 |
+
index = 0
|
204 |
+
else:
|
205 |
+
index = mag
|
206 |
+
c = c[index]
|
207 |
+
|
208 |
+
img = img.convert('RGBA')
|
209 |
+
overlay = Image.new('RGBA', img.size, (255,255,255,0))
|
210 |
+
draw = ImageDraw.Draw(overlay)
|
211 |
+
transparency = np.random.randint(c, c+32)
|
212 |
+
x1 = np.random.randint(0, W//2)
|
213 |
+
y1 = 0
|
214 |
+
|
215 |
+
x2 = np.random.randint(W//2, W)
|
216 |
+
y2 = 0
|
217 |
+
|
218 |
+
x3 = np.random.randint(W//2, W)
|
219 |
+
y3 = H - 1
|
220 |
+
|
221 |
+
x4 = np.random.randint(0, W//2)
|
222 |
+
y4 = H - 1
|
223 |
+
|
224 |
+
draw.polygon([(x1,y1), (x2,y2), (x3,y3), (x4,y4)], fill=(0,0,0,transparency))
|
225 |
+
|
226 |
+
img = Image.alpha_composite(img, overlay)
|
227 |
+
img = img.convert("RGB")
|
228 |
+
if isgray:
|
229 |
+
img = ImageOps.grayscale(img)
|
230 |
+
|
231 |
+
return img
|
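Note: a minimal sketch (not part of the commit) of how the weather corruptions above can be applied to an image. Fog, Rain and Shadow only need NumPy and Pillow; Frost additionally needs the bundled frost images resolved via pkg_resources, and Snow needs the Wand-backed MotionImage from augmentation.ops, so those two only work with the package data and ImageMagick installed. "word.png" is a placeholder input.

    import numpy as np
    from PIL import Image

    from augmentation.weather import Fog, Rain, Shadow

    np.random.seed(0)                       # make the random severity reproducible
    img = Image.open("word.png").convert("RGB")

    for aug in (Fog(), Rain(), Shadow()):
        # mag selects one of the three severity presets; mag=-1 picks one at random.
        out = aug(img, mag=1, prob=1.)
        out.save(f"word_{aug.__class__.__name__.lower()}.png")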
callbacks.py
ADDED
@@ -0,0 +1,360 @@
1 |
+
import logging
|
2 |
+
import shutil
|
3 |
+
import time
|
4 |
+
|
5 |
+
import editdistance as ed
|
6 |
+
import torchvision.utils as vutils
|
7 |
+
from fastai.callbacks.tensorboard import (LearnerTensorboardWriter,
|
8 |
+
SummaryWriter, TBWriteRequest,
|
9 |
+
asyncTBWriter)
|
10 |
+
from fastai.vision import *
|
11 |
+
from torch.nn.parallel import DistributedDataParallel
|
12 |
+
from torchvision import transforms
|
13 |
+
|
14 |
+
import dataset_abinet
|
15 |
+
from utils_abinet import CharsetMapper, Timer, blend_mask
|
16 |
+
|
17 |
+
|
18 |
+
class IterationCallback(LearnerTensorboardWriter):
|
19 |
+
"A `TrackerCallback` that monitors each iteration."
|
20 |
+
def __init__(self, learn:Learner, name:str='model', checpoint_keep_num=5,
|
21 |
+
show_iters:int=50, eval_iters:int=1000, save_iters:int=20000,
|
22 |
+
start_iters:int=0, stats_iters=20000):
|
23 |
+
#if self.learn.rank is not None: time.sleep(self.learn.rank) # keep all event files
|
24 |
+
super().__init__(learn, base_dir='.', name=learn.path, loss_iters=show_iters,
|
25 |
+
stats_iters=stats_iters, hist_iters=stats_iters)
|
26 |
+
self.name, self.bestname = Path(name).name, f'best-{Path(name).name}'
|
27 |
+
self.show_iters = show_iters
|
28 |
+
self.eval_iters = eval_iters
|
29 |
+
self.save_iters = save_iters
|
30 |
+
self.start_iters = start_iters
|
31 |
+
self.checpoint_keep_num = checpoint_keep_num
|
32 |
+
self.metrics_root = 'metrics/' # rewrite
|
33 |
+
self.timer = Timer()
|
34 |
+
self.host = self.learn.rank is None or self.learn.rank == 0
|
35 |
+
|
36 |
+
def _write_metrics(self, iteration:int, names:List[str], last_metrics:MetricsList)->None:
|
37 |
+
"Writes training metrics to Tensorboard."
|
38 |
+
for i, name in enumerate(names):
|
39 |
+
if last_metrics is None or len(last_metrics) < i+1: return
|
40 |
+
scalar_value = last_metrics[i]
|
41 |
+
self._write_scalar(name=name, scalar_value=scalar_value, iteration=iteration)
|
42 |
+
|
43 |
+
def _write_sub_loss(self, iteration:int, last_losses:dict)->None:
|
44 |
+
"Writes sub loss to Tensorboard."
|
45 |
+
for name, loss in last_losses.items():
|
46 |
+
scalar_value = to_np(loss)
|
47 |
+
tag = self.metrics_root + name
|
48 |
+
self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=iteration)
|
49 |
+
|
50 |
+
def _save(self, name):
|
51 |
+
if isinstance(self.learn.model, DistributedDataParallel):
|
52 |
+
tmp = self.learn.model
|
53 |
+
self.learn.model = self.learn.model.module
|
54 |
+
self.learn.save(name)
|
55 |
+
self.learn.model = tmp
|
56 |
+
else: self.learn.save(name)
|
57 |
+
|
58 |
+
def _validate(self, dl=None, callbacks=None, metrics=None, keeped_items=False):
|
59 |
+
"Validate on `dl` with potential `callbacks` and `metrics`."
|
60 |
+
dl = ifnone(dl, self.learn.data.valid_dl)
|
61 |
+
metrics = ifnone(metrics, self.learn.metrics)
|
62 |
+
cb_handler = CallbackHandler(ifnone(callbacks, []), metrics)
|
63 |
+
cb_handler.on_train_begin(1, None, metrics); cb_handler.on_epoch_begin()
|
64 |
+
if keeped_items: cb_handler.state_dict.update(dict(keeped_items=[]))
|
65 |
+
val_metrics = validate(self.learn.model, dl, self.loss_func, cb_handler)
|
66 |
+
cb_handler.on_epoch_end(val_metrics)
|
67 |
+
if keeped_items: return cb_handler.state_dict['keeped_items']
|
68 |
+
else: return cb_handler.state_dict['last_metrics']
|
69 |
+
|
70 |
+
def jump_to_epoch_iter(self, epoch:int, iteration:int)->None:
|
71 |
+
try:
|
72 |
+
self.learn.load(f'{self.name}_{epoch}_{iteration}', purge=False)
|
73 |
+
logging.info(f'Loaded {self.name}_{epoch}_{iteration}')
|
74 |
+
except: logging.info(f'Model {self.name}_{epoch}_{iteration} not found.')
|
75 |
+
|
76 |
+
def on_train_begin(self, n_epochs, **kwargs):
|
77 |
+
# TODO: can not write graph here
|
78 |
+
# super().on_train_begin(**kwargs)
|
79 |
+
self.best = -float('inf')
|
80 |
+
self.timer.tic()
|
81 |
+
if self.host:
|
82 |
+
checkpoint_path = self.learn.path/'checkpoint.yaml'
|
83 |
+
if checkpoint_path.exists():
|
84 |
+
os.remove(checkpoint_path)
|
85 |
+
open(checkpoint_path, 'w').close()
|
86 |
+
return {'skip_validate': True, 'iteration':self.start_iters} # disable default validate
|
87 |
+
|
88 |
+
def on_batch_begin(self, **kwargs:Any)->None:
|
89 |
+
self.timer.toc_data()
|
90 |
+
super().on_batch_begin(**kwargs)
|
91 |
+
|
92 |
+
def on_batch_end(self, iteration, epoch, last_loss, smooth_loss, train, **kwargs):
|
93 |
+
super().on_batch_end(last_loss, iteration, train, **kwargs)
|
94 |
+
if iteration == 0: return
|
95 |
+
|
96 |
+
if iteration % self.loss_iters == 0:
|
97 |
+
last_losses = self.learn.loss_func.last_losses
|
98 |
+
self._write_sub_loss(iteration=iteration, last_losses=last_losses)
|
99 |
+
self.tbwriter.add_scalar(tag=self.metrics_root + 'lr',
|
100 |
+
scalar_value=self.opt.lr, global_step=iteration)
|
101 |
+
|
102 |
+
if iteration % self.show_iters == 0:
|
103 |
+
log_str = f'epoch {epoch} iter {iteration}: loss = {last_loss:6.4f}, ' \
|
104 |
+
f'smooth loss = {smooth_loss:6.4f}'
|
105 |
+
logging.info(log_str)
|
106 |
+
# log_str = f'data time = {self.timer.data_diff:.4f}s, runing time = {self.timer.running_diff:.4f}s'
|
107 |
+
# logging.info(log_str)
|
108 |
+
|
109 |
+
if iteration % self.eval_iters == 0:
|
110 |
+
# TODO: or remove time to on_epoch_end
|
111 |
+
# 1. Record time
|
112 |
+
log_str = f'average data time = {self.timer.average_data_time():.4f}s, ' \
|
113 |
+
f'average running time = {self.timer.average_running_time():.4f}s'
|
114 |
+
logging.info(log_str)
|
115 |
+
|
116 |
+
# 2. Call validate
|
117 |
+
last_metrics = self._validate()
|
118 |
+
self.learn.model.train()
|
119 |
+
log_str = f'epoch {epoch} iter {iteration}: eval loss = {last_metrics[0]:6.4f}, ' \
|
120 |
+
f'ccr = {last_metrics[1]:6.4f}, cwr = {last_metrics[2]:6.4f}, ' \
|
121 |
+
f'ted = {last_metrics[3]:6.4f}, ned = {last_metrics[4]:6.4f}, ' \
|
122 |
+
f'ted/w = {last_metrics[5]:6.4f}, '
|
123 |
+
logging.info(log_str)
|
124 |
+
names = ['eval_loss', 'ccr', 'cwr', 'ted', 'ned', 'ted/w']
|
125 |
+
self._write_metrics(iteration, names, last_metrics)
|
126 |
+
|
127 |
+
# 3. Save best model
|
128 |
+
current = last_metrics[2]
|
129 |
+
if current is not None and current > self.best:
|
130 |
+
logging.info(f'Better model found at epoch {epoch}, '\
|
131 |
+
f'iter {iteration} with accuracy value: {current:6.4f}.')
|
132 |
+
self.best = current
|
133 |
+
self._save(f'{self.bestname}')
|
134 |
+
|
135 |
+
if iteration % self.save_iters == 0 and self.host:
|
136 |
+
logging.info(f'Save model {self.name}_{epoch}_{iteration}')
|
137 |
+
filename = f'{self.name}_{epoch}_{iteration}'
|
138 |
+
self._save(filename)
|
139 |
+
|
140 |
+
checkpoint_path = self.learn.path/'checkpoint.yaml'
|
141 |
+
if not checkpoint_path.exists():
|
142 |
+
open(checkpoint_path, 'w').close()
|
143 |
+
with open(checkpoint_path, 'r') as file:
|
144 |
+
checkpoints = yaml.load(file, Loader=yaml.FullLoader) or dict()
|
145 |
+
checkpoints['all_checkpoints'] = (
|
146 |
+
checkpoints.get('all_checkpoints') or list())
|
147 |
+
checkpoints['all_checkpoints'].insert(0, filename)
|
148 |
+
if len(checkpoints['all_checkpoints']) > self.checpoint_keep_num:
|
149 |
+
removed_checkpoint = checkpoints['all_checkpoints'].pop()
|
150 |
+
removed_checkpoint = self.learn.path/self.learn.model_dir/f'{removed_checkpoint}.pth'
|
151 |
+
os.remove(removed_checkpoint)
|
152 |
+
checkpoints['current_checkpoint'] = filename
|
153 |
+
with open(checkpoint_path, 'w') as file:
|
154 |
+
yaml.dump(checkpoints, file)
|
155 |
+
|
156 |
+
|
157 |
+
self.timer.toc_running()
|
158 |
+
|
159 |
+
def on_train_end(self, **kwargs):
|
160 |
+
#self.learn.load(f'{self.bestname}', purge=False)
|
161 |
+
pass
|
162 |
+
|
163 |
+
def on_epoch_end(self, last_metrics:MetricsList, iteration:int, **kwargs)->None:
|
164 |
+
self._write_embedding(iteration=iteration)
|
165 |
+
|
166 |
+
|
167 |
+
class TextAccuracy(Callback):
|
168 |
+
_names = ['ccr', 'cwr', 'ted', 'ned', 'ted/w']
|
169 |
+
def __init__(self, charset_path, max_length, case_sensitive, model_eval):
|
170 |
+
self.charset_path = charset_path
|
171 |
+
self.max_length = max_length
|
172 |
+
self.case_sensitive = case_sensitive
|
173 |
+
self.charset = CharsetMapper(charset_path, self.max_length)
|
174 |
+
self.names = self._names
|
175 |
+
|
176 |
+
self.model_eval = model_eval or 'alignment'
|
177 |
+
assert self.model_eval in ['vision', 'language', 'alignment']
|
178 |
+
|
179 |
+
def on_epoch_begin(self, **kwargs):
|
180 |
+
self.total_num_char = 0.
|
181 |
+
self.total_num_word = 0.
|
182 |
+
self.correct_num_char = 0.
|
183 |
+
self.correct_num_word = 0.
|
184 |
+
self.total_ed = 0.
|
185 |
+
self.total_ned = 0.
|
186 |
+
|
187 |
+
def _get_output(self, last_output):
|
188 |
+
if isinstance(last_output, (tuple, list)):
|
189 |
+
for res in last_output:
|
190 |
+
if res['name'] == self.model_eval: output = res
|
191 |
+
else: output = last_output
|
192 |
+
return output
|
193 |
+
|
194 |
+
def _update_output(self, last_output, items):
|
195 |
+
if isinstance(last_output, (tuple, list)):
|
196 |
+
for res in last_output:
|
197 |
+
if res['name'] == self.model_eval: res.update(items)
|
198 |
+
else: last_output.update(items)
|
199 |
+
return last_output
|
200 |
+
|
201 |
+
def on_batch_end(self, last_output, last_target, **kwargs):
|
202 |
+
output = self._get_output(last_output)
|
203 |
+
logits, pt_lengths = output['logits'], output['pt_lengths']
|
204 |
+
pt_text, pt_scores, pt_lengths_ = self.decode(logits)
|
205 |
+
assert (pt_lengths == pt_lengths_).all(), f'{pt_lengths} != {pt_lengths_} for {pt_text}'
|
206 |
+
last_output = self._update_output(last_output, {'pt_text':pt_text, 'pt_scores':pt_scores})
|
207 |
+
|
208 |
+
pt_text = [self.charset.trim(t) for t in pt_text]
|
209 |
+
label = last_target[0]
|
210 |
+
if label.dim() == 3: label = label.argmax(dim=-1) # one-hot label
|
211 |
+
gt_text = [self.charset.get_text(l, trim=True) for l in label]
|
212 |
+
|
213 |
+
for i in range(len(gt_text)):
|
214 |
+
if not self.case_sensitive:
|
215 |
+
gt_text[i], pt_text[i] = gt_text[i].lower(), pt_text[i].lower()
|
216 |
+
distance = ed.eval(gt_text[i], pt_text[i])
|
217 |
+
self.total_ed += distance
|
218 |
+
self.total_ned += float(distance) / max(len(gt_text[i]), 1)
|
219 |
+
|
220 |
+
if gt_text[i] == pt_text[i]:
|
221 |
+
self.correct_num_word += 1
|
222 |
+
self.total_num_word += 1
|
223 |
+
|
224 |
+
for j in range(min(len(gt_text[i]), len(pt_text[i]))):
|
225 |
+
if gt_text[i][j] == pt_text[i][j]:
|
226 |
+
self.correct_num_char += 1
|
227 |
+
self.total_num_char += len(gt_text[i])
|
228 |
+
|
229 |
+
return {'last_output': last_output}
|
230 |
+
|
231 |
+
def on_epoch_end(self, last_metrics, **kwargs):
|
232 |
+
mets = [self.correct_num_char / self.total_num_char,
|
233 |
+
self.correct_num_word / self.total_num_word,
|
234 |
+
self.total_ed,
|
235 |
+
self.total_ned,
|
236 |
+
self.total_ed / self.total_num_word]
|
237 |
+
return add_metrics(last_metrics, mets)
|
238 |
+
|
239 |
+
def decode(self, logit):
|
240 |
+
""" Greedy decode """
|
241 |
+
# TODO: test running time and decode on GPU
|
242 |
+
out = F.softmax(logit, dim=2)
|
243 |
+
pt_text, pt_scores, pt_lengths = [], [], []
|
244 |
+
for o in out:
|
245 |
+
text = self.charset.get_text(o.argmax(dim=1), padding=False, trim=False)
|
246 |
+
text = text.split(self.charset.null_char)[0] # end at end-token
|
247 |
+
pt_text.append(text)
|
248 |
+
pt_scores.append(o.max(dim=1)[0])
|
249 |
+
pt_lengths.append(min(len(text) + 1, self.max_length)) # one for end-token
|
250 |
+
pt_scores = torch.stack(pt_scores)
|
251 |
+
pt_lengths = pt_scores.new_tensor(pt_lengths, dtype=torch.long)
|
252 |
+
return pt_text, pt_scores, pt_lengths
|
253 |
+
|
254 |
+
|
255 |
+
class TopKTextAccuracy(TextAccuracy):
|
256 |
+
_names = ['ccr', 'cwr']
|
257 |
+
def __init__(self, k, charset_path, max_length, case_sensitive, model_eval):
|
258 |
+
self.k = k
|
259 |
+
self.charset_path = charset_path
|
260 |
+
self.max_length = max_length
|
261 |
+
self.case_sensitive = case_sensitive
|
262 |
+
self.charset = CharsetMapper(charset_path, self.max_length)
|
263 |
+
self.names = self._names
|
264 |
+
|
265 |
+
def on_epoch_begin(self, **kwargs):
|
266 |
+
self.total_num_char = 0.
|
267 |
+
self.total_num_word = 0.
|
268 |
+
self.correct_num_char = 0.
|
269 |
+
self.correct_num_word = 0.
|
270 |
+
|
271 |
+
def on_batch_end(self, last_output, last_target, **kwargs):
|
272 |
+
logits, pt_lengths = last_output['logits'], last_output['pt_lengths']
|
273 |
+
gt_labels, gt_lengths = last_target[:]
|
274 |
+
|
275 |
+
for logit, pt_length, label, length in zip(logits, pt_lengths, gt_labels, gt_lengths):
|
276 |
+
word_flag = True
|
277 |
+
for i in range(length):
|
278 |
+
char_logit = logit[i].topk(self.k)[1]
|
279 |
+
char_label = label[i].argmax(-1)
|
280 |
+
if char_label in char_logit: self.correct_num_char += 1
|
281 |
+
else: word_flag = False
|
282 |
+
self.total_num_char += 1
|
283 |
+
if pt_length == length and word_flag:
|
284 |
+
self.correct_num_word += 1
|
285 |
+
self.total_num_word += 1
|
286 |
+
|
287 |
+
def on_epoch_end(self, last_metrics, **kwargs):
|
288 |
+
mets = [self.correct_num_char / self.total_num_char,
|
289 |
+
self.correct_num_word / self.total_num_word,
|
290 |
+
0., 0., 0.]
|
291 |
+
return add_metrics(last_metrics, mets)
|
292 |
+
|
293 |
+
|
294 |
+
class DumpPrediction(LearnerCallback):
|
295 |
+
|
296 |
+
def __init__(self, learn, dataset, charset_path, model_eval, image_only=False, debug=False):
|
297 |
+
super().__init__(learn=learn)
|
298 |
+
self.debug = debug
|
299 |
+
self.model_eval = model_eval or 'alignment'
|
300 |
+
self.image_only = image_only
|
301 |
+
assert self.model_eval in ['vision', 'language', 'alignment']
|
302 |
+
|
303 |
+
self.dataset, self.root = dataset, Path(self.learn.path)/f'{dataset}-{self.model_eval}'
|
304 |
+
self.attn_root = self.root/'attn'
|
305 |
+
self.charset = CharsetMapper(charset_path)
|
306 |
+
if self.root.exists(): shutil.rmtree(self.root)
|
307 |
+
self.root.mkdir(), self.attn_root.mkdir()
|
308 |
+
|
309 |
+
self.pil = transforms.ToPILImage()
|
310 |
+
self.tensor = transforms.ToTensor()
|
311 |
+
size = self.learn.data.img_h, self.learn.data.img_w
|
312 |
+
self.resize = transforms.Resize(size=size, interpolation=0)
|
313 |
+
self.c = 0
|
314 |
+
|
315 |
+
def on_batch_end(self, last_input, last_output, last_target, **kwargs):
|
316 |
+
if isinstance(last_output, (tuple, list)):
|
317 |
+
for res in last_output:
|
318 |
+
if res['name'] == self.model_eval: pt_text = res['pt_text']
|
319 |
+
if res['name'] == 'vision': attn_scores = res['attn_scores'].detach().cpu()
|
320 |
+
if res['name'] == self.model_eval: logits = res['logits']
|
321 |
+
else:
|
322 |
+
pt_text = last_output['pt_text']
|
323 |
+
attn_scores = last_output['attn_scores'].detach().cpu()
|
324 |
+
logits = last_output['logits']
|
325 |
+
|
326 |
+
images = last_input[0] if isinstance(last_input, (tuple, list)) else last_input
|
327 |
+
images = images.detach().cpu()
|
328 |
+
pt_text = [self.charset.trim(t) for t in pt_text]
|
329 |
+
gt_label = last_target[0]
|
330 |
+
if gt_label.dim() == 3: gt_label = gt_label.argmax(dim=-1) # one-hot label
|
331 |
+
gt_text = [self.charset.get_text(l, trim=True) for l in gt_label]
|
332 |
+
|
333 |
+
prediction, false_prediction = [], []
|
334 |
+
for gt, pt, image, attn, logit in zip(gt_text, pt_text, images, attn_scores, logits):
|
335 |
+
prediction.append(f'{gt}\t{pt}\n')
|
336 |
+
if gt != pt:
|
337 |
+
if self.debug:
|
338 |
+
scores = torch.softmax(logit, dim=-1)[:max(len(pt), len(gt)) + 1]
|
339 |
+
logging.info(f'{self.c} gt {gt}, pt {pt}, logit {logit.shape}, scores {scores.topk(5, dim=-1)}')
|
340 |
+
false_prediction.append(f'{gt}\t{pt}\n')
|
341 |
+
|
342 |
+
image = self.learn.data.denorm(image)
|
343 |
+
if not self.image_only:
|
344 |
+
image_np = np.array(self.pil(image))
|
345 |
+
attn_pil = [self.pil(a) for a in attn[:, None, :, :]]
|
346 |
+
attn = [self.tensor(self.resize(a)).repeat(3, 1, 1) for a in attn_pil]
|
347 |
+
attn_sum = np.array([np.array(a) for a in attn_pil[:len(pt)]]).sum(axis=0)
|
348 |
+
blended_sum = self.tensor(blend_mask(image_np, attn_sum))
|
349 |
+
blended = [self.tensor(blend_mask(image_np, np.array(a))) for a in attn_pil]
|
350 |
+
save_image = torch.stack([image] + attn + [blended_sum] + blended)
|
351 |
+
save_image = save_image.view(2, -1, *save_image.shape[1:])
|
352 |
+
save_image = save_image.permute(1, 0, 2, 3, 4).flatten(0, 1)
|
353 |
+
vutils.save_image(save_image, self.attn_root/f'{self.c}_{gt}_{pt}.jpg',
|
354 |
+
nrow=2, normalize=True, scale_each=True)
|
355 |
+
else:
|
356 |
+
self.pil(image).save(self.attn_root/f'{self.c}_{gt}_{pt}.jpg')
|
357 |
+
self.c += 1
|
358 |
+
|
359 |
+
with open(self.root/f'{self.model_eval}.txt', 'a') as f: f.writelines(prediction)
|
360 |
+
with open(self.root/f'{self.model_eval}-false.txt', 'a') as f: f.writelines(false_prediction)
|
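Note: a standalone sketch (not part of the commit) of the counters that TextAccuracy above accumulates per epoch: character correct rate (ccr), word correct rate (cwr), total edit distance (ted), normalized edit distance (ned) and edit distance per word (ted/w). It only needs the editdistance package; the example strings are made up.

    import editdistance as ed

    def text_metrics(gt_texts, pt_texts):
        total_char = total_word = correct_char = correct_word = 0
        total_ed = 0
        total_ned = 0.0
        for gt, pt in zip(gt_texts, pt_texts):
            dist = ed.eval(gt, pt)
            total_ed += dist
            total_ned += dist / max(len(gt), 1)
            correct_word += int(gt == pt)
            total_word += 1
            # character matches are counted position-wise, as in TextAccuracy
            correct_char += sum(g == p for g, p in zip(gt, pt))
            total_char += len(gt)
        return {
            "ccr": correct_char / total_char,
            "cwr": correct_word / total_word,
            "ted": total_ed,
            "ned": total_ned,
            "ted/w": total_ed / total_word,
        }

    print(text_metrics(["hello", "world"], ["hello", "wor1d"]))
    # {'ccr': 0.9, 'cwr': 0.5, 'ted': 1, 'ned': 0.2, 'ted/w': 0.5}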
captum/__init__.py
ADDED
@@ -0,0 +1,3 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
__version__ = "0.5.0"
|
captum/_utils/__init__.py
ADDED
File without changes
|
captum/_utils/av.py
ADDED
@@ -0,0 +1,499 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import glob
|
4 |
+
import os
|
5 |
+
import re
|
6 |
+
import warnings
|
7 |
+
from typing import Any, List, Optional, Tuple, Union
|
8 |
+
|
9 |
+
import captum._utils.common as common
|
10 |
+
import torch
|
11 |
+
from captum.attr import LayerActivation
|
12 |
+
from torch import Tensor
|
13 |
+
from torch.nn import Module
|
14 |
+
from torch.utils.data import DataLoader, Dataset
|
15 |
+
|
16 |
+
|
17 |
+
class AV:
|
18 |
+
r"""
|
19 |
+
This class provides functionality to store and load activation vectors
|
20 |
+
generated for pre-defined neural network layers.
|
21 |
+
It also provides functionality to check if activation vectors already
|
22 |
+
exist in the manifold and other auxiliary functions.
|
23 |
+
|
24 |
+
This class also defines a torch `Dataset`, representing Activation Vectors,
|
25 |
+
which enables lazy access to activation vectors and layer stored in the manifold.
|
26 |
+
|
27 |
+
"""
|
28 |
+
|
29 |
+
r"""
|
30 |
+
The name of the subfolder in the manifold where the activation vectors
|
31 |
+
are stored.
|
32 |
+
"""
|
33 |
+
|
34 |
+
class AVDataset(Dataset):
|
35 |
+
r"""
|
36 |
+
This dataset enables access to activation vectors for a given `model` stored
|
37 |
+
under a pre-defined path.
|
38 |
+
The iterator of this dataset returns a batch of data tensors.
|
39 |
+
Additionally, subsets of the model activations can be loaded based on layer
|
40 |
+
or identifier or num_id (representing batch number in source dataset).
|
41 |
+
"""
|
42 |
+
|
43 |
+
def __init__(
|
44 |
+
self,
|
45 |
+
path: str,
|
46 |
+
model_id: str,
|
47 |
+
identifier: Optional[str] = None,
|
48 |
+
layer: Optional[str] = None,
|
49 |
+
num_id: Optional[str] = None,
|
50 |
+
):
|
51 |
+
r"""
|
52 |
+
Loads into memory the list of all activation file paths associated
|
53 |
+
with the input `model_id`.
|
54 |
+
|
55 |
+
Args:
|
56 |
+
path (str): The path where the activation vectors
|
57 |
+
for the `layer` are stored.
|
58 |
+
model_id (str): The name/version of the model for which layer
|
59 |
+
activations are being computed and stored.
|
60 |
+
identifier (str or None): An optional identifier for the layer
|
61 |
+
activations. Can be used to distinguish between activations for
|
62 |
+
different training batches.
|
63 |
+
layer (str or None): The layer for which the activation vectors
|
64 |
+
are computed.
|
65 |
+
num_id (str): An optional string representing the batch number for
|
66 |
+
which the activation vectors are computed
|
67 |
+
"""
|
68 |
+
|
69 |
+
self.av_filesearch = AV._construct_file_search(
|
70 |
+
path, model_id, identifier, layer, num_id
|
71 |
+
)
|
72 |
+
|
73 |
+
files = glob.glob(self.av_filesearch)
|
74 |
+
|
75 |
+
self.files = AV.sort_files(files)
|
76 |
+
|
77 |
+
def __getitem__(self, idx: int) -> Union[Tensor, Tuple[Tensor, ...]]:
|
78 |
+
assert idx < len(self.files), "Layer index is out of bounds!"
|
79 |
+
fl = self.files[idx]
|
80 |
+
av = torch.load(fl)
|
81 |
+
return av
|
82 |
+
|
83 |
+
def __len__(self):
|
84 |
+
return len(self.files)
|
85 |
+
|
86 |
+
AV_DIR_NAME: str = "av"
|
87 |
+
|
88 |
+
def __init__(self) -> None:
|
89 |
+
pass
|
90 |
+
|
91 |
+
@staticmethod
|
92 |
+
def _assemble_model_dir(path: str, model_id: str) -> str:
|
93 |
+
r"""
|
94 |
+
Returns a directory path for the given source path `path` and `model_id.`
|
95 |
+
This path is suffixed with the '/' delimiter.
|
96 |
+
"""
|
97 |
+
return "/".join([path, AV.AV_DIR_NAME, model_id, ""])
|
98 |
+
|
99 |
+
@staticmethod
|
100 |
+
def _assemble_file_path(source_dir: str, identifier: str, layer: str) -> str:
|
101 |
+
r"""
|
102 |
+
Returns a full filepath given a source directory, layer, and required
|
103 |
+
identifier. The source dir is not required to end with a "/" delimiter.
|
104 |
+
"""
|
105 |
+
if not source_dir.endswith("/"):
|
106 |
+
source_dir += "/"
|
107 |
+
|
108 |
+
filepath = os.path.join(source_dir, identifier)
|
109 |
+
|
110 |
+
filepath = os.path.join(filepath, layer)
|
111 |
+
|
112 |
+
return filepath
|
113 |
+
|
114 |
+
@staticmethod
|
115 |
+
def _construct_file_search(
|
116 |
+
source_dir: str,
|
117 |
+
model_id: str,
|
118 |
+
identifier: Optional[str] = None,
|
119 |
+
layer: Optional[str] = None,
|
120 |
+
num_id: Optional[str] = None,
|
121 |
+
) -> str:
|
122 |
+
r"""
|
123 |
+
Returns a search string that can be used by glob to search `source_dir/model_id`
|
124 |
+
for the desired layer/identifier pair. Leaving `layer` as None will search ids
|
125 |
+
over all layers, and leaving `identifier` as none will search layers over all
|
126 |
+
ids. Leaving both as none will return a path to glob for every activation.
|
127 |
+
Assumes identifier is always specified when saving activations, so that
|
128 |
+
activations live at source_dir/model_id/identifier/layer
|
129 |
+
(and never source_dir/model_id/layer)
|
130 |
+
"""
|
131 |
+
|
132 |
+
av_filesearch = AV._assemble_model_dir(source_dir, model_id)
|
133 |
+
|
134 |
+
av_filesearch = os.path.join(
|
135 |
+
av_filesearch, "*" if identifier is None else identifier
|
136 |
+
)
|
137 |
+
|
138 |
+
av_filesearch = os.path.join(av_filesearch, "*" if layer is None else layer)
|
139 |
+
|
140 |
+
av_filesearch = os.path.join(
|
141 |
+
av_filesearch, "*.pt" if num_id is None else "%s.pt" % num_id
|
142 |
+
)
|
143 |
+
|
144 |
+
return av_filesearch
|
145 |
+
|
146 |
+
@staticmethod
|
147 |
+
def exists(
|
148 |
+
path: str,
|
149 |
+
model_id: str,
|
150 |
+
identifier: Optional[str] = None,
|
151 |
+
layer: Optional[str] = None,
|
152 |
+
num_id: Optional[str] = None,
|
153 |
+
) -> bool:
|
154 |
+
r"""
|
155 |
+
Verifies whether the model + layer activations exist
|
156 |
+
under the path.
|
157 |
+
|
158 |
+
Args:
|
159 |
+
path (str): The path where the activation vectors
|
160 |
+
for the `model_id` are stored.
|
161 |
+
model_id (str): The name/version of the model for which layer activations
|
162 |
+
are being computed and stored.
|
163 |
+
identifier (str or None): An optional identifier for the layer activations.
|
164 |
+
Can be used to distinguish between activations for different
|
165 |
+
training batches. For example, the id could be a suffix composed of
|
166 |
+
a train/test label and numerical value, such as "-train-xxxxx".
|
167 |
+
The numerical id is often a monotonic sequence taken from datetime.
|
168 |
+
layer (str or None): The layer for which the activation vectors are
|
169 |
+
computed.
|
170 |
+
num_id (str): An optional string representing the batch number for which
|
171 |
+
the activation vectors are computed
|
172 |
+
|
173 |
+
Returns:
|
174 |
+
exists (bool): Indicating whether the activation vectors for the `layer`
|
175 |
+
and `identifier` (if provided) and num_id (if provided) were stored
|
176 |
+
in the manifold. If no `identifier` is provided, will return `True`
|
177 |
+
if any layer activation exists, whether it has an identifier or
|
178 |
+
not, and vice-versa.
|
179 |
+
"""
|
180 |
+
av_dir = AV._assemble_model_dir(path, model_id)
|
181 |
+
av_filesearch = AV._construct_file_search(
|
182 |
+
path, model_id, identifier, layer, num_id
|
183 |
+
)
|
184 |
+
return os.path.exists(av_dir) and len(glob.glob(av_filesearch)) > 0
|
185 |
+
|
186 |
+
@staticmethod
|
187 |
+
def save(
|
188 |
+
path: str,
|
189 |
+
model_id: str,
|
190 |
+
identifier: str,
|
191 |
+
layers: Union[str, List[str]],
|
192 |
+
act_tensors: Union[Tensor, List[Tensor]],
|
193 |
+
num_id: str,
|
194 |
+
) -> None:
|
195 |
+
r"""
|
196 |
+
Saves the activation vectors `act_tensor` for the
|
197 |
+
`layer` under the manifold `path`.
|
198 |
+
|
199 |
+
Args:
|
200 |
+
path (str): The path where the activation vectors
|
201 |
+
for the `layer` are stored.
|
202 |
+
model_id (str): The name/version of the model for which layer activations
|
203 |
+
are being computed and stored.
|
204 |
+
identifier (str or None): An optional identifier for the layer
|
205 |
+
activations. Can be used to distinguish between activations for
|
206 |
+
different training batches. For example, the identifier could be
|
207 |
+
a suffix composed of a train/test label and numerical value, such
|
208 |
+
as "-src-abc".
|
209 |
+
Additionally, (abc) could be a unique identifying number. For
|
210 |
+
example, it is automatically created in
|
211 |
+
AV.generate_dataset_activations from batch index.
|
212 |
+
It assumes identifier is same for all layers if a list of
|
213 |
+
`layers` is provided.
|
214 |
+
layers (str or List of str): The layer(s) for which the activation vectors
|
215 |
+
are computed.
|
216 |
+
act_tensors (Tensor or List of Tensor): A batch of activation vectors.
|
217 |
+
This must match the dimension of `layers`.
|
218 |
+
num_id (str): string representing the batch number for which the activation
|
219 |
+
vectors are computed
|
220 |
+
"""
|
221 |
+
if isinstance(layers, str):
|
222 |
+
layers = [layers]
|
223 |
+
if isinstance(act_tensors, Tensor):
|
224 |
+
act_tensors = [act_tensors]
|
225 |
+
|
226 |
+
if len(layers) != len(act_tensors):
|
227 |
+
raise ValueError("The dimension of `layers` and `act_tensors` must match!")
|
228 |
+
|
229 |
+
av_dir = AV._assemble_model_dir(path, model_id)
|
230 |
+
|
231 |
+
for i, layer in enumerate(layers):
|
232 |
+
av_save_fl_path = os.path.join(
|
233 |
+
AV._assemble_file_path(av_dir, identifier, layer), "%s.pt" % num_id
|
234 |
+
)
|
235 |
+
|
236 |
+
layer_dir = os.path.dirname(av_save_fl_path)
|
237 |
+
if not os.path.exists(layer_dir):
|
238 |
+
os.makedirs(layer_dir)
|
239 |
+
torch.save(act_tensors[i], av_save_fl_path)
|
240 |
+
|
241 |
+
@staticmethod
|
242 |
+
def load(
|
243 |
+
path: str,
|
244 |
+
model_id: str,
|
245 |
+
identifier: Optional[str] = None,
|
246 |
+
layer: Optional[str] = None,
|
247 |
+
num_id: Optional[str] = None,
|
248 |
+
) -> AVDataset:
|
249 |
+
r"""
|
250 |
+
Loads lazily the activation vectors for given `model_id` and
|
251 |
+
`layer` saved under the `path`.
|
252 |
+
|
253 |
+
Args:
|
254 |
+
path (str): The path where the activation vectors
|
255 |
+
for the `layer` are stored.
|
256 |
+
model_id (str): The name/version of the model for which layer activations
|
257 |
+
are being computed and stored.
|
258 |
+
identifier (str or None): An optional identifier for the layer
|
259 |
+
activations. Can be used to distinguish between activations for
|
260 |
+
different training batches.
|
261 |
+
layer (str or None): The layer for which the activation vectors
|
262 |
+
are computed.
|
263 |
+
num_id (str): An optional string representing the batch number for which
|
264 |
+
the activation vectors are computed
|
265 |
+
|
266 |
+
Returns:
|
267 |
+
dataset (AV.AVDataset): AV.AVDataset that allows to iterate
|
268 |
+
over the activation vectors for given layer, identifier (if
|
269 |
+
provided), num_id (if provided). Returning an AV.AVDataset as
|
270 |
+
opposed to a DataLoader constructed from it offers more
|
271 |
+
flexibility. Raises RuntimeError if activation vectors are not
|
272 |
+
found.
|
273 |
+
"""
|
274 |
+
|
275 |
+
av_save_dir = AV._assemble_model_dir(path, model_id)
|
276 |
+
|
277 |
+
if os.path.exists(av_save_dir):
|
278 |
+
avdataset = AV.AVDataset(path, model_id, identifier, layer, num_id)
|
279 |
+
return avdataset
|
280 |
+
else:
|
281 |
+
raise RuntimeError(
|
282 |
+
f"Activation vectors for model {model_id} was not found at path {path}"
|
283 |
+
)
|
284 |
+
|
285 |
+
@staticmethod
|
286 |
+
def _manage_loading_layers(
|
287 |
+
path: str,
|
288 |
+
model_id: str,
|
289 |
+
layers: Union[str, List[str]],
|
290 |
+
load_from_disk: bool = True,
|
291 |
+
identifier: Optional[str] = None,
|
292 |
+
num_id: Optional[str] = None,
|
293 |
+
) -> List[str]:
|
294 |
+
r"""
|
295 |
+
Returns unsaved layers, and deletes saved layers if load_from_disk is False.
|
296 |
+
|
297 |
+
Args:
|
298 |
+
path (str): The path where the activation vectors
|
299 |
+
for the `layer` are stored.
|
300 |
+
model_id (str): The name/version of the model for which layer activations
|
301 |
+
are being computed and stored.
|
302 |
+
layers (str or List of str): The layer(s) for which the activation vectors
|
303 |
+
are computed.
|
304 |
+
identifier (str or None): An optional identifier for the layer
|
305 |
+
activations. Can be used to distinguish between activations for
|
306 |
+
different training batches.
|
307 |
+
num_id (str): An optional string representing the batch number for which the
|
308 |
+
activation vectors are computed
|
309 |
+
|
310 |
+
Returns:
|
311 |
+
List of layer names for which activations should be generated
|
312 |
+
"""
|
313 |
+
|
314 |
+
layers = [layers] if isinstance(layers, str) else layers
|
315 |
+
unsaved_layers = []
|
316 |
+
|
317 |
+
if load_from_disk:
|
318 |
+
for layer in layers:
|
319 |
+
if not AV.exists(path, model_id, identifier, layer, num_id):
|
320 |
+
unsaved_layers.append(layer)
|
321 |
+
else:
|
322 |
+
unsaved_layers = layers
|
323 |
+
warnings.warn(
|
324 |
+
"Overwriting activations: load_from_disk is set to False. Removing all "
|
325 |
+
f"activations matching specified parameters {{path: {path}, "
|
326 |
+
f"model_id: {model_id}, layers: {layers}, identifier: {identifier}}} "
|
327 |
+
"before generating new activations."
|
328 |
+
)
|
329 |
+
for layer in layers:
|
330 |
+
files = glob.glob(
|
331 |
+
AV._construct_file_search(path, model_id, identifier, layer)
|
332 |
+
)
|
333 |
+
for filename in files:
|
334 |
+
os.remove(filename)
|
335 |
+
|
336 |
+
return unsaved_layers
|
337 |
+
|
338 |
+
@staticmethod
|
339 |
+
def _compute_and_save_activations(
|
340 |
+
path: str,
|
341 |
+
model: Module,
|
342 |
+
model_id: str,
|
343 |
+
layers: Union[str, List[str]],
|
344 |
+
inputs: Union[Tensor, Tuple[Tensor, ...]],
|
345 |
+
identifier: str,
|
346 |
+
num_id: str,
|
347 |
+
additional_forward_args: Any = None,
|
348 |
+
load_from_disk: bool = True,
|
349 |
+
) -> None:
|
350 |
+
r"""
|
351 |
+
Computes layer activations for the given inputs and specified `layers`
|
352 |
+
|
353 |
+
Args:
|
354 |
+
path (str): The path where the activation vectors
|
355 |
+
for the `layer` are stored.
|
356 |
+
model (torch.nn.Module): An instance of pytorch model. This model should
|
357 |
+
define all of its layers as attributes of the model.
|
358 |
+
model_id (str): The name/version of the model for which layer activations
|
359 |
+
are being computed and stored.
|
360 |
+
layers (str or List of str): The layer(s) for which the activation vectors
|
361 |
+
are computed.
|
362 |
+
inputs (tensor or tuple of tensors): Batch of examples for
|
363 |
+
which influential instances are computed. They are passed to the
|
364 |
+
input `model`. The first dimension in `inputs` tensor or tuple of
|
365 |
+
tensors corresponds to the batch size.
|
366 |
+
identifier (str or None): An optional identifier for the layer
|
367 |
+
activations. Can be used to distinguish between activations for
|
368 |
+
different training batches.
|
369 |
+
num_id (str): A required string representing the batch number for which the
|
370 |
+
activation vectors are computed
|
371 |
+
additional_forward_args (optional): Additional arguments that will be
|
372 |
+
passed to `model` after inputs.
|
373 |
+
Default: None
|
374 |
+
load_from_disk (bool): Forces function to regenerate activations if False.
|
375 |
+
Default: True
|
376 |
+
"""
|
377 |
+
unsaved_layers = AV._manage_loading_layers(
|
378 |
+
path,
|
379 |
+
model_id,
|
380 |
+
layers,
|
381 |
+
load_from_disk,
|
382 |
+
identifier,
|
383 |
+
num_id,
|
384 |
+
)
|
385 |
+
layer_modules = [
|
386 |
+
common._get_module_from_name(model, layer) for layer in unsaved_layers
|
387 |
+
]
|
388 |
+
if len(unsaved_layers) > 0:
|
389 |
+
layer_act = LayerActivation(model, layer_modules)
|
390 |
+
new_activations = layer_act.attribute.__wrapped__( # type: ignore
|
391 |
+
layer_act, inputs, additional_forward_args
|
392 |
+
)
|
393 |
+
AV.save(path, model_id, identifier, unsaved_layers, new_activations, num_id)
|
394 |
+
|
395 |
+
@staticmethod
|
396 |
+
def _unpack_data(data: Union[Any, Tuple[Any, Any]]) -> Any:
|
397 |
+
r"""
|
398 |
+
Helper to extract input from labels when getting items from a Dataset. Assumes
|
399 |
+
that data is either a single value, or a tuple containing two elements.
|
400 |
+
The input could itself be a Tuple containing multiple values. If your
|
401 |
+
dataset returns a Tuple with more than 2 elements, please reformat it such that
|
402 |
+
all inputs are formatted into a tuple stored at the first position.
|
403 |
+
"""
|
404 |
+
if isinstance(data, tuple) or isinstance(data, list):
|
405 |
+
data = data[0]
|
406 |
+
return data
|
407 |
+
|
408 |
+
r"""TODO:
|
409 |
+
1. Can propagate saving labels along with activations.
|
410 |
+
2. Use of additional_forward_args when sourcing from dataset?
|
411 |
+
"""
|
412 |
+
|
413 |
+
@staticmethod
|
414 |
+
def generate_dataset_activations(
|
415 |
+
path: str,
|
416 |
+
model: Module,
|
417 |
+
model_id: str,
|
418 |
+
layers: Union[str, List[str]],
|
419 |
+
dataloader: DataLoader,
|
420 |
+
identifier: str = "default",
|
421 |
+
load_from_disk: bool = True,
|
422 |
+
return_activations: bool = False,
|
423 |
+
) -> Optional[Union[AVDataset, List[AVDataset]]]:
|
424 |
+
r"""
|
425 |
+
Computes layer activations for a source dataset and specified `layers`. Assumes
|
426 |
+
that the dataset returns a single value, or a tuple containing two elements
|
427 |
+
(see AV._unpack_data).
|
428 |
+
|
429 |
+
Args:
|
430 |
+
path (str): The path where the activation vectors
|
431 |
+
for the `layer` are stored.
|
432 |
+
module (torch.nn.Module): An instance of pytorch model. This model should
|
433 |
+
define all of its layers as attributes of the model.
|
434 |
+
model_id (str): The name/version of the model for which layer activations
|
435 |
+
are being computed and stored.
|
436 |
+
layers (str or List of str): The layer(s) for which the activation vectors
|
437 |
+
are computed.
|
438 |
+
dataloader (torch.utils.data.DataLoader): DataLoader that yields Dataset
|
439 |
+
for which influential instances are computed. They are passed to
|
440 |
+
input `model`.
|
441 |
+
identifier (str or None): An identifier for the layer
|
442 |
+
activations. Can be used to distinguish between activations for
|
443 |
+
different training batches.
|
444 |
+
Default: "default"
|
445 |
+
load_from_disk (bool): Forces function to regenerate activations if False.
|
446 |
+
Default: True
|
447 |
+
return_activations (bool, optional): Whether to return the activations.
|
448 |
+
Default: False
|
449 |
+
Returns: If `return_activations == True`, returns a single `AVDataset` if
|
450 |
+
`layers` is a str, otherwise, a list of `AVDataset`s of the length
|
451 |
+
of `layers`, where each element corresponds to a layer. In either
|
452 |
+
case, `AVDataset`'s represent the activations for a single layer,
|
453 |
+
over the entire `dataloader`. If `return_activations == False`,
|
454 |
+
does not return anything.
|
455 |
+
|
456 |
+
"""
|
457 |
+
|
458 |
+
unsaved_layers = AV._manage_loading_layers(
|
459 |
+
path,
|
460 |
+
model_id,
|
461 |
+
layers,
|
462 |
+
load_from_disk,
|
463 |
+
identifier,
|
464 |
+
)
|
465 |
+
if len(unsaved_layers) > 0:
|
466 |
+
for i, data in enumerate(dataloader):
|
467 |
+
AV._compute_and_save_activations(
|
468 |
+
path,
|
469 |
+
model,
|
470 |
+
model_id,
|
471 |
+
layers,
|
472 |
+
AV._unpack_data(data),
|
473 |
+
identifier,
|
474 |
+
str(i),
|
475 |
+
)
|
476 |
+
|
477 |
+
if not return_activations:
|
478 |
+
return None
|
479 |
+
if isinstance(layers, str):
|
480 |
+
return AV.load(path, model_id, identifier, layers)
|
481 |
+
else:
|
482 |
+
return [AV.load(path, model_id, identifier, layer) for layer in layers]
|
483 |
+
|
484 |
+
@staticmethod
|
485 |
+
def sort_files(files: List[str]) -> List[str]:
|
486 |
+
r"""
|
487 |
+
Utility for sorting files based on natural sorting instead of the default
|
488 |
+
lexicographical sort.
|
489 |
+
"""
|
490 |
+
|
491 |
+
def split_alphanum(s):
|
492 |
+
r"""
|
493 |
+
Splits string into a list of strings and numbers
|
494 |
+
"z23a" -> ["z", 23, "a"]
|
495 |
+
"""
|
496 |
+
|
497 |
+
return [int(x) if x.isdigit() else x for x in re.split("([0-9]+)", s)]
|
498 |
+
|
499 |
+
return sorted(files, key=split_alphanum)
|
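Note: a minimal sketch (not part of the commit) of how the AV helper above can cache and reload layer activations. It assumes the vendored captum package in this repo is importable; the toy model, the layer name "0" (the first child of the Sequential) and the cache path are placeholders.

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from captum._utils.av import AV

    model = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.ReLU())
    loader = DataLoader(TensorDataset(torch.randn(32, 8)), batch_size=8)

    # Computes activations of layer "0" for every batch and stores them under
    # ./av_cache/av/toy-mlp/train/0/<batch>.pt
    AV.generate_dataset_activations(
        "./av_cache", model, "toy-mlp", "0", loader,
        identifier="train", load_from_disk=True,
    )

    # Lazily reload them as a dataset of per-batch activation tensors.
    acts = AV.load("./av_cache", "toy-mlp", identifier="train", layer="0")
    print(len(acts), acts[0].shape)   # 4 cached batches, each of shape (8, 4)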
captum/_utils/common.py
ADDED
@@ -0,0 +1,679 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
import typing
|
3 |
+
from enum import Enum
|
4 |
+
from functools import reduce
|
5 |
+
from inspect import signature
|
6 |
+
from typing import Any, Callable, cast, Dict, List, overload, Tuple, Union
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
from captum._utils.typing import (
|
11 |
+
BaselineType,
|
12 |
+
Literal,
|
13 |
+
TargetType,
|
14 |
+
TensorOrTupleOfTensorsGeneric,
|
15 |
+
TupleOrTensorOrBoolGeneric,
|
16 |
+
)
|
17 |
+
from torch import device, Tensor
|
18 |
+
from torch.nn import Module
|
19 |
+
|
20 |
+
|
21 |
+
class ExpansionTypes(Enum):
|
22 |
+
repeat = 1
|
23 |
+
repeat_interleave = 2
|
24 |
+
|
25 |
+
|
26 |
+
def safe_div(
|
27 |
+
numerator: Tensor,
|
28 |
+
denom: Union[Tensor, int, float],
|
29 |
+
default_denom: Union[Tensor, int, float] = 1.0,
|
30 |
+
) -> Tensor:
|
31 |
+
r"""
|
32 |
+
A simple utility function to perform `numerator / denom`
|
33 |
+
if the statement is undefined => result will be `numerator / default_denom`
|
34 |
+
"""
|
35 |
+
if isinstance(denom, (int, float)):
|
36 |
+
return numerator / (denom if denom != 0 else default_denom)
|
37 |
+
|
38 |
+
# convert default_denom to tensor if it is float
|
39 |
+
if not torch.is_tensor(default_denom):
|
40 |
+
default_denom = torch.tensor(
|
41 |
+
default_denom, dtype=denom.dtype, device=denom.device
|
42 |
+
)
|
43 |
+
|
44 |
+
return numerator / torch.where(denom != 0, denom, default_denom)
|
45 |
+
|
46 |
+
|
47 |
+
@typing.overload
|
48 |
+
def _is_tuple(inputs: Tensor) -> Literal[False]:
|
49 |
+
...
|
50 |
+
|
51 |
+
|
52 |
+
@typing.overload
|
53 |
+
def _is_tuple(inputs: Tuple[Tensor, ...]) -> Literal[True]:
|
54 |
+
...
|
55 |
+
|
56 |
+
|
57 |
+
def _is_tuple(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> bool:
|
58 |
+
return isinstance(inputs, tuple)
|
59 |
+
|
60 |
+
|
61 |
+
def _validate_target(num_samples: int, target: TargetType) -> None:
|
62 |
+
if isinstance(target, list) or (
|
63 |
+
isinstance(target, torch.Tensor) and torch.numel(target) > 1
|
64 |
+
):
|
65 |
+
assert num_samples == len(target), (
|
66 |
+
"The number of samples provided in the"
|
67 |
+
"input {} does not match with the number of targets. {}".format(
|
68 |
+
num_samples, len(target)
|
69 |
+
)
|
70 |
+
)
|
71 |
+
|
72 |
+
|
73 |
+
def _validate_input(
|
74 |
+
inputs: Tuple[Tensor, ...],
|
75 |
+
baselines: Tuple[Union[Tensor, int, float], ...],
|
76 |
+
draw_baseline_from_distrib: bool = False,
|
77 |
+
) -> None:
|
78 |
+
assert len(inputs) == len(baselines), (
|
79 |
+
"Input and baseline must have the same "
|
80 |
+
"dimensions, baseline has {} features whereas input has {}.".format(
|
81 |
+
len(baselines), len(inputs)
|
82 |
+
)
|
83 |
+
)
|
84 |
+
|
85 |
+
for input, baseline in zip(inputs, baselines):
|
86 |
+
if draw_baseline_from_distrib:
|
87 |
+
assert (
|
88 |
+
isinstance(baseline, (int, float))
|
89 |
+
or input.shape[1:] == baseline.shape[1:]
|
90 |
+
), (
|
91 |
+
"The samples in input and baseline batches must have"
|
92 |
+
" the same shape or the baseline corresponding to the"
|
93 |
+
" input tensor must be a scalar."
|
94 |
+
" Found baseline: {} and input: {} ".format(baseline, input)
|
95 |
+
)
|
96 |
+
else:
|
97 |
+
assert (
|
98 |
+
isinstance(baseline, (int, float))
|
99 |
+
or input.shape == baseline.shape
|
100 |
+
or baseline.shape[0] == 1
|
101 |
+
), (
|
102 |
+
"Baseline can be provided as a tensor for just one input and"
|
103 |
+
" broadcasted to the batch or input and baseline must have the"
|
104 |
+
" same shape or the baseline corresponding to each input tensor"
|
105 |
+
" must be a scalar. Found baseline: {} and input: {}".format(
|
106 |
+
baseline, input
|
107 |
+
)
|
108 |
+
)
|
109 |
+
|
110 |
+
|
111 |
+
def _zeros(inputs: Tuple[Tensor, ...]) -> Tuple[int, ...]:
|
112 |
+
r"""
|
113 |
+
Takes a tuple of tensors as input and returns a tuple that has the same
|
114 |
+
length as `inputs` with each element as the integer 0.
|
115 |
+
"""
|
116 |
+
return tuple(0 if input.dtype is not torch.bool else False for input in inputs)
|
117 |
+
|
118 |
+
|
119 |
+
def _format_baseline(
|
120 |
+
baselines: BaselineType, inputs: Tuple[Tensor, ...]
|
121 |
+
) -> Tuple[Union[Tensor, int, float], ...]:
|
122 |
+
if baselines is None:
|
123 |
+
return _zeros(inputs)
|
124 |
+
|
125 |
+
if not isinstance(baselines, tuple):
|
126 |
+
baselines = (baselines,)
|
127 |
+
|
128 |
+
for baseline in baselines:
|
129 |
+
assert isinstance(
|
130 |
+
baseline, (torch.Tensor, int, float)
|
131 |
+
), "baseline input argument must be either a torch.Tensor or a number \
|
132 |
+
however {} detected".format(
|
133 |
+
type(baseline)
|
134 |
+
)
|
135 |
+
|
136 |
+
return baselines
|
137 |
+
|
138 |
+
|
139 |
+
@overload
|
140 |
+
def _format_tensor_into_tuples(inputs: None) -> None:
|
141 |
+
...
|
142 |
+
|
143 |
+
|
144 |
+
@overload
|
145 |
+
def _format_tensor_into_tuples(
|
146 |
+
inputs: Union[Tensor, Tuple[Tensor, ...]]
|
147 |
+
) -> Tuple[Tensor, ...]:
|
148 |
+
...
|
149 |
+
|
150 |
+
|
151 |
+
def _format_tensor_into_tuples(
    inputs: Union[None, Tensor, Tuple[Tensor, ...]]
) -> Union[None, Tuple[Tensor, ...]]:
    if inputs is None:
        return None
    if not isinstance(inputs, tuple):
        assert isinstance(
            inputs, torch.Tensor
        ), "`inputs` must have type " "torch.Tensor but {} found: ".format(type(inputs))
        inputs = (inputs,)
    return inputs


def _format_inputs(inputs: Any, unpack_inputs: bool = True) -> Any:
    return (
        inputs
        if (isinstance(inputs, tuple) or isinstance(inputs, list)) and unpack_inputs
        else (inputs,)
    )


def _format_float_or_tensor_into_tuples(
    inputs: Union[float, Tensor, Tuple[Union[float, Tensor], ...]]
) -> Tuple[Union[float, Tensor], ...]:
    if not isinstance(inputs, tuple):
        assert isinstance(
            inputs, (torch.Tensor, float)
        ), "`inputs` must have type float or torch.Tensor but {} found: ".format(
            type(inputs)
        )
        inputs = (inputs,)
    return inputs


@overload
def _format_additional_forward_args(additional_forward_args: None) -> None:
    ...


@overload
def _format_additional_forward_args(
    additional_forward_args: Union[Tensor, Tuple]
) -> Tuple:
    ...


@overload
def _format_additional_forward_args(additional_forward_args: Any) -> Union[None, Tuple]:
    ...


def _format_additional_forward_args(additional_forward_args: Any) -> Union[None, Tuple]:
    if additional_forward_args is not None and not isinstance(
        additional_forward_args, tuple
    ):
        additional_forward_args = (additional_forward_args,)
    return additional_forward_args


def _expand_additional_forward_args(
    additional_forward_args: Any,
    n_steps: int,
    expansion_type: ExpansionTypes = ExpansionTypes.repeat,
) -> Union[None, Tuple]:
    def _expand_tensor_forward_arg(
        additional_forward_arg: Tensor,
        n_steps: int,
        expansion_type: ExpansionTypes = ExpansionTypes.repeat,
    ) -> Tensor:
        if len(additional_forward_arg.size()) == 0:
            return additional_forward_arg
        if expansion_type == ExpansionTypes.repeat:
            return torch.cat([additional_forward_arg] * n_steps, dim=0)
        elif expansion_type == ExpansionTypes.repeat_interleave:
            return additional_forward_arg.repeat_interleave(n_steps, dim=0)
        else:
            raise NotImplementedError(
                "Currently only `repeat` and `repeat_interleave`"
                " expansion_types are supported"
            )

    if additional_forward_args is None:
        return None

    return tuple(
        _expand_tensor_forward_arg(additional_forward_arg, n_steps, expansion_type)
        if isinstance(additional_forward_arg, torch.Tensor)
        else additional_forward_arg
        for additional_forward_arg in additional_forward_args
    )


def _expand_target(
    target: TargetType,
    n_steps: int,
    expansion_type: ExpansionTypes = ExpansionTypes.repeat,
) -> TargetType:
    if isinstance(target, list):
        if expansion_type == ExpansionTypes.repeat:
            return target * n_steps
        elif expansion_type == ExpansionTypes.repeat_interleave:
            expanded_target = []
            for i in target:
                expanded_target.extend([i] * n_steps)
            return cast(Union[List[Tuple[int, ...]], List[int]], expanded_target)
        else:
            raise NotImplementedError(
                "Currently only `repeat` and `repeat_interleave`"
                " expansion_types are supported"
            )

    elif isinstance(target, torch.Tensor) and torch.numel(target) > 1:
        if expansion_type == ExpansionTypes.repeat:
            return torch.cat([target] * n_steps, dim=0)
        elif expansion_type == ExpansionTypes.repeat_interleave:
            return target.repeat_interleave(n_steps, dim=0)
        else:
            raise NotImplementedError(
                "Currently only `repeat` and `repeat_interleave`"
                " expansion_types are supported"
            )

    return target

def _expand_feature_mask(
    feature_mask: Union[Tensor, Tuple[Tensor, ...]], n_samples: int
):
    is_feature_mask_tuple = _is_tuple(feature_mask)
    feature_mask = _format_tensor_into_tuples(feature_mask)
    feature_mask_new = tuple(
        feature_mask_elem.repeat_interleave(n_samples, dim=0)
        if feature_mask_elem.size(0) > 1
        else feature_mask_elem
        for feature_mask_elem in feature_mask
    )
    return _format_output(is_feature_mask_tuple, feature_mask_new)


def _expand_and_update_baselines(
    inputs: Tuple[Tensor, ...],
    n_samples: int,
    kwargs: dict,
    draw_baseline_from_distrib: bool = False,
):
    def get_random_baseline_indices(bsz, baseline):
        num_ref_samples = baseline.shape[0]
        return np.random.choice(num_ref_samples, n_samples * bsz).tolist()

    # expand baselines to match the sizes of input
    if "baselines" not in kwargs:
        return

    baselines = kwargs["baselines"]
    baselines = _format_baseline(baselines, inputs)
    _validate_input(
        inputs, baselines, draw_baseline_from_distrib=draw_baseline_from_distrib
    )

    if draw_baseline_from_distrib:
        bsz = inputs[0].shape[0]
        baselines = tuple(
            baseline[get_random_baseline_indices(bsz, baseline)]
            if isinstance(baseline, torch.Tensor)
            else baseline
            for baseline in baselines
        )
    else:
        baselines = tuple(
            baseline.repeat_interleave(n_samples, dim=0)
            if isinstance(baseline, torch.Tensor)
            and baseline.shape[0] == input.shape[0]
            and baseline.shape[0] > 1
            else baseline
            for input, baseline in zip(inputs, baselines)
        )
    # update kwargs with expanded baseline
    kwargs["baselines"] = baselines


def _expand_and_update_additional_forward_args(n_samples: int, kwargs: dict):
    if "additional_forward_args" not in kwargs:
        return
    additional_forward_args = kwargs["additional_forward_args"]
    additional_forward_args = _format_additional_forward_args(additional_forward_args)
    if additional_forward_args is None:
        return
    additional_forward_args = _expand_additional_forward_args(
        additional_forward_args,
        n_samples,
        expansion_type=ExpansionTypes.repeat_interleave,
    )
    # update kwargs with expanded additional forward args
    kwargs["additional_forward_args"] = additional_forward_args


def _expand_and_update_target(n_samples: int, kwargs: dict):
    if "target" not in kwargs:
        return
    target = kwargs["target"]
    target = _expand_target(
        target, n_samples, expansion_type=ExpansionTypes.repeat_interleave
    )
    # update kwargs with expanded target
    kwargs["target"] = target


def _expand_and_update_feature_mask(n_samples: int, kwargs: dict):
    if "feature_mask" not in kwargs:
        return

    feature_mask = kwargs["feature_mask"]
    if feature_mask is None:
        return

    feature_mask = _expand_feature_mask(feature_mask, n_samples)
    kwargs["feature_mask"] = feature_mask


@typing.overload
def _format_output(
    is_inputs_tuple: Literal[True], output: Tuple[Tensor, ...]
) -> Tuple[Tensor, ...]:
    ...


@typing.overload
def _format_output(
    is_inputs_tuple: Literal[False], output: Tuple[Tensor, ...]
) -> Tensor:
    ...


@typing.overload
def _format_output(
    is_inputs_tuple: bool, output: Tuple[Tensor, ...]
) -> Union[Tensor, Tuple[Tensor, ...]]:
    ...


def _format_output(
    is_inputs_tuple: bool, output: Tuple[Tensor, ...]
) -> Union[Tensor, Tuple[Tensor, ...]]:
    r"""
    In case input is a tensor and the output is returned in form of a
    tuple we take the first element of the output's tuple to match the
    same shape signatures of the inputs
    """
    assert isinstance(output, tuple), "Output must be in shape of a tuple"
    assert is_inputs_tuple or len(output) == 1, (
        "The input is a single tensor however the output isn't."
        " The number of output tensors is: {}".format(len(output))
    )
    return output if is_inputs_tuple else output[0]


@typing.overload
def _format_outputs(
    is_multiple_inputs: Literal[False], outputs: List[Tuple[Tensor, ...]]
) -> Union[Tensor, Tuple[Tensor, ...]]:
    ...


@typing.overload
def _format_outputs(
    is_multiple_inputs: Literal[True], outputs: List[Tuple[Tensor, ...]]
) -> List[Union[Tensor, Tuple[Tensor, ...]]]:
    ...


@typing.overload
def _format_outputs(
    is_multiple_inputs: bool, outputs: List[Tuple[Tensor, ...]]
) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]:
    ...


def _format_outputs(
    is_multiple_inputs: bool, outputs: List[Tuple[Tensor, ...]]
) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]:
    assert isinstance(outputs, list), "Outputs must be a list"
    assert is_multiple_inputs or len(outputs) == 1, (
        "outputs should contain multiple inputs or have a single output"
        f" however the number of outputs is: {len(outputs)}"
    )

    return (
        [_format_output(len(output) > 1, output) for output in outputs]
        if is_multiple_inputs
        else _format_output(len(outputs[0]) > 1, outputs[0])
    )

def _run_forward(
    forward_func: Callable,
    inputs: Any,
    target: TargetType = None,
    additional_forward_args: Any = None,
) -> Tensor:
    forward_func_args = signature(forward_func).parameters
    if len(forward_func_args) == 0:
        output = forward_func()
        return output if target is None else _select_targets(output, target)

    # make everything a tuple so that it is easy to unpack without
    # using if-statements
    inputs = _format_inputs(inputs)
    additional_forward_args = _format_additional_forward_args(additional_forward_args)

    output = forward_func(
        *(*inputs, *additional_forward_args)
        if additional_forward_args is not None
        else inputs
    )
    return _select_targets(output, target)


def _select_targets(output: Tensor, target: TargetType) -> Tensor:
    if target is None:
        return output

    num_examples = output.shape[0]
    dims = len(output.shape)
    device = output.device
    if isinstance(target, (int, tuple)):
        return _verify_select_column(output, target)
    elif isinstance(target, torch.Tensor):
        if torch.numel(target) == 1 and isinstance(target.item(), int):
            return _verify_select_column(output, cast(int, target.item()))
        elif len(target.shape) == 1 and torch.numel(target) == num_examples:
            assert dims == 2, "Output must be 2D to select tensor of targets."
            return torch.gather(output, 1, target.reshape(len(output), 1))
        else:
            raise AssertionError(
                "Tensor target dimension %r is not valid. %r"
                % (target.shape, output.shape)
            )
    elif isinstance(target, list):
        assert len(target) == num_examples, "Target list length does not match output!"
        if isinstance(target[0], int):
            assert dims == 2, "Output must be 2D to select tensor of targets."
            return torch.gather(
                output, 1, torch.tensor(target, device=device).reshape(len(output), 1)
            )
        elif isinstance(target[0], tuple):
            return torch.stack(
                [
                    output[(i,) + cast(Tuple, targ_elem)]
                    for i, targ_elem in enumerate(target)
                ]
            )
        else:
            raise AssertionError("Target element type in list is not valid.")
    else:
        raise AssertionError("Target type %r is not valid." % target)


def _contains_slice(target: Union[int, Tuple[Union[int, slice], ...]]) -> bool:
    if isinstance(target, tuple):
        for index in target:
            if isinstance(index, slice):
                return True
        return False
    return isinstance(target, slice)


def _verify_select_column(
    output: Tensor, target: Union[int, Tuple[Union[int, slice], ...]]
) -> Tensor:
    target = (target,) if isinstance(target, int) else target
    assert (
        len(target) <= len(output.shape) - 1
    ), "Cannot choose target column with output shape %r." % (output.shape,)
    return output[(slice(None), *target)]


def _verify_select_neuron(
    layer_output: Tuple[Tensor, ...],
    selector: Union[int, Tuple[Union[int, slice], ...], Callable],
) -> Tensor:
    if callable(selector):
        return selector(layer_output if len(layer_output) > 1 else layer_output[0])

    assert len(layer_output) == 1, (
        "Cannot select neuron index from layer with multiple tensors,"
        " consider providing a neuron selector function instead."
    )

    selected_neurons = _verify_select_column(layer_output[0], selector)
    if _contains_slice(selector):
        return selected_neurons.reshape(selected_neurons.shape[0], -1).sum(1)
    return selected_neurons

def _extract_device(
    module: Module,
    hook_inputs: Union[None, Tensor, Tuple[Tensor, ...]],
    hook_outputs: Union[None, Tensor, Tuple[Tensor, ...]],
) -> device:
    params = list(module.parameters())
    if (
        (hook_inputs is None or len(hook_inputs) == 0)
        and (hook_outputs is None or len(hook_outputs) == 0)
        and len(params) == 0
    ):
        raise RuntimeError(
            """Unable to extract device information for the module
            {}. Both inputs and outputs to the forward hook and
            `module.parameters()` are empty.
            The reason that the inputs to the forward hook are empty
            could be due to the fact that the arguments to that
            module {} are all named and are passed as named
            variables to its forward function.
            """.format(
                module, module
            )
        )
    if hook_inputs is not None and len(hook_inputs) > 0:
        return hook_inputs[0].device
    if hook_outputs is not None and len(hook_outputs) > 0:
        return hook_outputs[0].device

    return params[0].device


def _reduce_list(
    val_list: List[TupleOrTensorOrBoolGeneric],
    red_func: Callable[[List], Any] = torch.cat,
) -> TupleOrTensorOrBoolGeneric:
    """
    Applies reduction function to given list. If each element in the list is
    a Tensor, applies reduction function to all elements of the list, and returns
    the output Tensor / value. If each element is a boolean, apply any method (or).
    If each element is a tuple, applies reduction
    function to corresponding elements of each tuple in the list, and returns
    tuple of reduction function outputs with length matching the length of tuple
    val_list[0]. It is assumed that all tuples in the list have the same length
    and red_func can be applied to all elements in each corresponding position.
    """
    assert len(val_list) > 0, "Cannot reduce empty list!"
    if isinstance(val_list[0], torch.Tensor):
        first_device = val_list[0].device
        return red_func([elem.to(first_device) for elem in val_list])
    elif isinstance(val_list[0], bool):
        return any(val_list)
    elif isinstance(val_list[0], tuple):
        final_out = []
        for i in range(len(val_list[0])):
            final_out.append(
                _reduce_list([val_elem[i] for val_elem in val_list], red_func)
            )
    else:
        raise AssertionError(
            "Elements to be reduced can only be"
            " either Tensors or tuples containing Tensors."
        )
    return tuple(final_out)


def _sort_key_list(
    keys: List[device], device_ids: Union[None, List[int]] = None
) -> List[device]:
    """
    Sorts list of torch devices (keys) by given index list, device_ids. If keys
    contains only one device, then the list is returned unchanged. If keys
    contains a device for which the id is not contained in device_ids, then
    an error is returned. This method is used to identify the order of DataParallel
    batched devices, given the device ID ordering.
    """
    if len(keys) == 1:
        return keys
    id_dict: Dict[int, device] = {}
    assert device_ids is not None, "Device IDs must be provided with multiple devices."
    for key in keys:
        if key.index in id_dict:
            raise AssertionError("Duplicate CUDA Device ID identified in device list.")
        id_dict[key.index] = key

    out_list = [
        id_dict[device_id]
        for device_id in filter(lambda device_id: device_id in id_dict, device_ids)
    ]

    assert len(out_list) == len(keys), (
        "Given Device ID List does not match devices with computed tensors."
    )

    return out_list

def _flatten_tensor_or_tuple(inp: TensorOrTupleOfTensorsGeneric) -> Tensor:
    if isinstance(inp, Tensor):
        return inp.flatten()
    return torch.cat([single_inp.flatten() for single_inp in inp])


def _get_module_from_name(model: Module, layer_name: str) -> Any:
    r"""
    Returns the module (layer) object, given its (string) name
    in the model.

    Args:
        name (str): Module or nested modules name string in self.model

    Returns:
        The module (layer) in self.model.
    """

    return reduce(getattr, layer_name.split("."), model)


def _register_backward_hook(
    module: Module, hook: Callable, attr_obj: Any
) -> torch.utils.hooks.RemovableHandle:
    # Special case for supporting output attributions for neuron methods
    # This can be removed after deprecation of neuron output attributions
    # for NeuronDeepLift, NeuronDeconvolution, and NeuronGuidedBackprop
    # in v0.6.0
    if (
        hasattr(attr_obj, "skip_new_hook_layer")
        and attr_obj.skip_new_hook_layer == module
    ):
        return module.register_backward_hook(hook)

    if torch.__version__ >= "1.9":
        # Only supported for torch >= 1.9
        return module.register_full_backward_hook(hook)
    else:
        # Fallback for previous versions of PyTorch
        return module.register_backward_hook(hook)
captum/_utils/gradient.py
ADDED
@@ -0,0 +1,865 @@
#!/usr/bin/env python3
import threading
import typing
import warnings
from collections import defaultdict
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union

import torch
from captum._utils.common import (
    _reduce_list,
    _run_forward,
    _sort_key_list,
    _verify_select_neuron,
)
from captum._utils.sample_gradient import SampleGradientWrapper
from captum._utils.typing import (
    Literal,
    ModuleOrModuleList,
    TargetType,
    TensorOrTupleOfTensorsGeneric,
)
from torch import device, Tensor
from torch.nn import Module


def apply_gradient_requirements(
    inputs: Tuple[Tensor, ...], warn: bool = True
) -> List[bool]:
    """
    Iterates through tuple on input tensors and sets requires_grad to be true on
    each Tensor, and ensures all grads are set to zero. To ensure that the input
    is returned to its initial state, a list of flags representing whether or not
    a tensor originally required grad is returned.
    """
    assert isinstance(
        inputs, tuple
    ), "Inputs should be wrapped in a tuple prior to preparing for gradients"
    grad_required = []
    for index, input in enumerate(inputs):
        assert isinstance(input, torch.Tensor), "Given input is not a torch.Tensor"
        grad_required.append(input.requires_grad)
        inputs_dtype = input.dtype
        # Note: torch 1.2 doesn't support is_complex for dtype that's why we check
        # on the existence of is_complex method.
        if not inputs_dtype.is_floating_point and not (
            hasattr(inputs_dtype, "is_complex") and inputs_dtype.is_complex
        ):
            if warn:
                warnings.warn(
                    """Input Tensor %d has a dtype of %s.
                    Gradients cannot be activated
                    for these data types."""
                    % (index, str(inputs_dtype))
                )
        elif not input.requires_grad:
            if warn:
                warnings.warn(
                    "Input Tensor %d did not already require gradients, "
                    "required_grads has been set automatically." % index
                )
            input.requires_grad_()
    return grad_required


def undo_gradient_requirements(
    inputs: Tuple[Tensor, ...], grad_required: List[bool]
) -> None:
    """
    Iterates through list of tensors, zeros each gradient, and sets required
    grad to false if the corresponding index in grad_required is False.
    This method is used to undo the effects of prepare_gradient_inputs, making
    grads not required for any input tensor that did not initially require
    gradients.
    """

    assert isinstance(
        inputs, tuple
    ), "Inputs should be wrapped in a tuple prior to preparing for gradients."
    assert len(inputs) == len(
        grad_required
    ), "Input tuple length should match gradient mask."
    for index, input in enumerate(inputs):
        assert isinstance(input, torch.Tensor), "Given input is not a torch.Tensor"
        if not grad_required[index]:
            input.requires_grad_(False)


def compute_gradients(
    forward_fn: Callable,
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    target_ind: TargetType = None,
    additional_forward_args: Any = None,
) -> Tuple[Tensor, ...]:
    r"""
    Computes gradients of the output with respect to inputs for an
    arbitrary forward function.

    Args:

        forward_fn: forward function. This can be for example model's
                    forward function.
        input:      Input at which gradients are evaluated,
                    will be passed to forward_fn.
        target_ind: Index of the target class for which gradients
                    must be computed (classification only).
        additional_forward_args: Additional input arguments that forward
                    function requires. It takes an empty tuple (no additional
                    arguments) if no additional arguments are required
    """
    with torch.autograd.set_grad_enabled(True):
        # runs forward pass
        outputs = _run_forward(forward_fn, inputs, target_ind, additional_forward_args)
        assert outputs[0].numel() == 1, (
            "Target not provided when necessary, cannot"
            " take gradient with respect to multiple outputs."
        )
        # torch.unbind(forward_out) is a list of scalar tensor tuples and
        # contains batch_size * #steps elements
        grads = torch.autograd.grad(torch.unbind(outputs), inputs)
    return grads


def _neuron_gradients(
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    saved_layer: Dict[device, Tuple[Tensor, ...]],
    key_list: List[device],
    gradient_neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
) -> Tuple[Tensor, ...]:
    with torch.autograd.set_grad_enabled(True):
        gradient_tensors = []
        for key in key_list:
            current_out_tensor = _verify_select_neuron(
                saved_layer[key], gradient_neuron_selector
            )
            gradient_tensors.append(
                torch.autograd.grad(
                    torch.unbind(current_out_tensor)
                    if current_out_tensor.numel() > 1
                    else current_out_tensor,
                    inputs,
                )
            )
        _total_gradients = _reduce_list(gradient_tensors, sum)
    return _total_gradients

@typing.overload
def _forward_layer_eval(
    forward_fn: Callable,
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    layer: Module,
    additional_forward_args: Any = None,
    device_ids: Union[None, List[int]] = None,
    attribute_to_layer_input: bool = False,
    grad_enabled: bool = False,
) -> Tuple[Tensor, ...]:
    ...


@typing.overload
def _forward_layer_eval(
    forward_fn: Callable,
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    layer: List[Module],
    additional_forward_args: Any = None,
    device_ids: Union[None, List[int]] = None,
    attribute_to_layer_input: bool = False,
    grad_enabled: bool = False,
) -> List[Tuple[Tensor, ...]]:
    ...


def _forward_layer_eval(
    forward_fn: Callable,
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    layer: ModuleOrModuleList,
    additional_forward_args: Any = None,
    device_ids: Union[None, List[int]] = None,
    attribute_to_layer_input: bool = False,
    grad_enabled: bool = False,
) -> Union[Tuple[Tensor, ...], List[Tuple[Tensor, ...]]]:
    return _forward_layer_eval_with_neuron_grads(
        forward_fn,
        inputs,
        layer,
        additional_forward_args=additional_forward_args,
        gradient_neuron_selector=None,
        grad_enabled=grad_enabled,
        device_ids=device_ids,
        attribute_to_layer_input=attribute_to_layer_input,
    )


@typing.overload
def _forward_layer_distributed_eval(
    forward_fn: Callable,
    inputs: Any,
    layer: ModuleOrModuleList,
    target_ind: TargetType = None,
    additional_forward_args: Any = None,
    attribute_to_layer_input: bool = False,
    forward_hook_with_return: Literal[False] = False,
    require_layer_grads: bool = False,
) -> Dict[Module, Dict[device, Tuple[Tensor, ...]]]:
    ...


@typing.overload
def _forward_layer_distributed_eval(
    forward_fn: Callable,
    inputs: Any,
    layer: ModuleOrModuleList,
    target_ind: TargetType = None,
    additional_forward_args: Any = None,
    attribute_to_layer_input: bool = False,
    *,
    forward_hook_with_return: Literal[True],
    require_layer_grads: bool = False,
) -> Tuple[Dict[Module, Dict[device, Tuple[Tensor, ...]]], Tensor]:
    ...


def _forward_layer_distributed_eval(
    forward_fn: Callable,
    inputs: Any,
    layer: ModuleOrModuleList,
    target_ind: TargetType = None,
    additional_forward_args: Any = None,
    attribute_to_layer_input: bool = False,
    forward_hook_with_return: bool = False,
    require_layer_grads: bool = False,
) -> Union[
    Tuple[Dict[Module, Dict[device, Tuple[Tensor, ...]]], Tensor],
    Dict[Module, Dict[device, Tuple[Tensor, ...]]],
]:
    r"""
    A helper function that allows setting a hook on the model's `layer`, runs the
    forward pass and returns intermediate layer results, stored in a dictionary,
    and optionally also the output of the forward function. The keys in the
    dictionary are the device ids and the values are corresponding intermediate layer
    results, either the inputs or the outputs of the layer depending on whether we set
    `attribute_to_layer_input` to True or False.
    This is especially useful when we execute forward pass in a distributed setting,
    using `DataParallel`s for example.
    """
    saved_layer: Dict[Module, Dict[device, Tuple[Tensor, ...]]] = defaultdict(dict)
    lock = threading.Lock()
    all_layers: List[Module] = [layer] if isinstance(layer, Module) else layer

    # Set a forward hook on specified module and run forward pass to
    # get layer output tensor(s).
    # For DataParallel models, each partition adds entry to dictionary
    # with key as device and value as corresponding Tensor.
    def hook_wrapper(original_module):
        def forward_hook(module, inp, out=None):
            eval_tsrs = inp if attribute_to_layer_input else out
            is_eval_tuple = isinstance(eval_tsrs, tuple)

            if not is_eval_tuple:
                eval_tsrs = (eval_tsrs,)
            if require_layer_grads:
                apply_gradient_requirements(eval_tsrs, warn=False)
            with lock:
                nonlocal saved_layer
                # Note that cloning behaviour of `eval_tsr` is different
                # when `forward_hook_with_return` is set to True. This is because
                # otherwise `backward()` on the last output layer won't execute.
                if forward_hook_with_return:
                    saved_layer[original_module][eval_tsrs[0].device] = eval_tsrs
                    eval_tsrs_to_return = tuple(
                        eval_tsr.clone() for eval_tsr in eval_tsrs
                    )
                    if not is_eval_tuple:
                        eval_tsrs_to_return = eval_tsrs_to_return[0]
                    return eval_tsrs_to_return
                else:
                    saved_layer[original_module][eval_tsrs[0].device] = tuple(
                        eval_tsr.clone() for eval_tsr in eval_tsrs
                    )

        return forward_hook

    all_hooks = []
    try:
        for single_layer in all_layers:
            if attribute_to_layer_input:
                all_hooks.append(
                    single_layer.register_forward_pre_hook(hook_wrapper(single_layer))
                )
            else:
                all_hooks.append(
                    single_layer.register_forward_hook(hook_wrapper(single_layer))
                )
        output = _run_forward(
            forward_fn,
            inputs,
            target=target_ind,
            additional_forward_args=additional_forward_args,
        )
    finally:
        for hook in all_hooks:
            hook.remove()

    if len(saved_layer) == 0:
        raise AssertionError("Forward hook did not obtain any outputs for given layer")

    if forward_hook_with_return:
        return saved_layer, output
    return saved_layer


def _gather_distributed_tensors(
    saved_layer: Dict[device, Tuple[Tensor, ...]],
    device_ids: Union[None, List[int]] = None,
    key_list: Union[None, List[device]] = None,
) -> Tuple[Tensor, ...]:
    r"""
    A helper function to concatenate intermediate layer results stored on
    different devices in `saved_layer`. `saved_layer` is a dictionary that
    contains `device_id` as a key and intermediate layer results (either
    the input or the output of the layer) stored on the device corresponding to
    the key.
    `key_list` is a list of devices in appropriate ordering for concatenation
    and if not provided, keys are sorted based on device ids.

    If only one key exists (standard model), key list simply has one element.
    """
    if key_list is None:
        key_list = _sort_key_list(list(saved_layer.keys()), device_ids)
    return _reduce_list([saved_layer[device_id] for device_id in key_list])


def _extract_device_ids(
    forward_fn: Callable,
    saved_layer: Dict[Module, Dict[device, Tuple[Tensor, ...]]],
    device_ids: Union[None, List[int]],
) -> Union[None, List[int]]:
    r"""
    A helper function to extract device_ids from `forward_function` in case it is
    provided as part of a `DataParallel` model or if it is accessible from
    `forward_fn`.
    In case input device_ids is not None, this function returns that value.
    """
    # Multiple devices / keys implies a DataParallel model, so we look for
    # device IDs if given or available from forward function
    # (DataParallel model object).
    if (
        max(len(saved_layer[single_layer]) for single_layer in saved_layer) > 1
        and device_ids is None
    ):
        if (
            hasattr(forward_fn, "device_ids")
            and cast(Any, forward_fn).device_ids is not None
        ):
            device_ids = cast(Any, forward_fn).device_ids
        else:
            raise AssertionError(
                "Layer tensors are saved on multiple devices, however unable to access"
                " device ID list from the `forward_fn`. Device ID list must be"
                " accessible from `forward_fn`. For example, they can be retrieved"
                " if `forward_fn` is a model of type `DataParallel`. It is used"
                " for identifying device batch ordering."
            )
    return device_ids

@typing.overload
def _forward_layer_eval_with_neuron_grads(
    forward_fn: Callable,
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    layer: Module,
    additional_forward_args: Any = None,
    *,
    gradient_neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
    grad_enabled: bool = False,
    device_ids: Union[None, List[int]] = None,
    attribute_to_layer_input: bool = False,
) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]:
    ...


@typing.overload
def _forward_layer_eval_with_neuron_grads(
    forward_fn: Callable,
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    layer: Module,
    additional_forward_args: Any = None,
    gradient_neuron_selector: None = None,
    grad_enabled: bool = False,
    device_ids: Union[None, List[int]] = None,
    attribute_to_layer_input: bool = False,
) -> Tuple[Tensor, ...]:
    ...


@typing.overload
def _forward_layer_eval_with_neuron_grads(
    forward_fn: Callable,
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    layer: List[Module],
    additional_forward_args: Any = None,
    gradient_neuron_selector: None = None,
    grad_enabled: bool = False,
    device_ids: Union[None, List[int]] = None,
    attribute_to_layer_input: bool = False,
) -> List[Tuple[Tensor, ...]]:
    ...


def _forward_layer_eval_with_neuron_grads(
    forward_fn: Callable,
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    layer: ModuleOrModuleList,
    additional_forward_args: Any = None,
    gradient_neuron_selector: Union[
        None, int, Tuple[Union[int, slice], ...], Callable
    ] = None,
    grad_enabled: bool = False,
    device_ids: Union[None, List[int]] = None,
    attribute_to_layer_input: bool = False,
) -> Union[
    Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]],
    Tuple[Tensor, ...],
    List[Tuple[Tensor, ...]],
]:
    """
    This method computes forward evaluation for a particular layer using a
    forward hook. If a gradient_neuron_selector is provided, then gradients with
    respect to that neuron in the layer output are also returned.

    These functionalities are combined due to the behavior of DataParallel models
    with hooks, in which hooks are executed once per device. We need to internally
    combine the separated tensors from devices by concatenating based on device_ids.
    Any necessary gradients must be taken with respect to each independent batched
    tensor, so the gradients are computed and combined appropriately.

    More information regarding the behavior of forward hooks with DataParallel models
    can be found in the PyTorch data parallel documentation. We maintain the separate
    evals in a dictionary protected by a lock, analogous to the gather implementation
    for the core PyTorch DataParallel implementation.
    """
    grad_enabled = True if gradient_neuron_selector is not None else grad_enabled

    with torch.autograd.set_grad_enabled(grad_enabled):
        saved_layer = _forward_layer_distributed_eval(
            forward_fn,
            inputs,
            layer,
            additional_forward_args=additional_forward_args,
            attribute_to_layer_input=attribute_to_layer_input,
        )
    device_ids = _extract_device_ids(forward_fn, saved_layer, device_ids)
    # Identifies correct device ordering based on device ids.
    # key_list is a list of devices in appropriate ordering for concatenation.
    # If only one key exists (standard model), key list simply has one element.
    key_list = _sort_key_list(list(next(iter(saved_layer.values())).keys()), device_ids)
    if gradient_neuron_selector is not None:
        assert isinstance(
            layer, Module
        ), "Cannot compute neuron gradients for multiple layers simultaneously!"
        inp_grads = _neuron_gradients(
            inputs, saved_layer[layer], key_list, gradient_neuron_selector
        )
        return (
            _gather_distributed_tensors(saved_layer[layer], key_list=key_list),
            inp_grads,
        )
    else:
        if isinstance(layer, Module):
            return _gather_distributed_tensors(saved_layer[layer], key_list=key_list)
        else:
            return [
                _gather_distributed_tensors(saved_layer[curr_layer], key_list=key_list)
                for curr_layer in layer
            ]

@typing.overload
def compute_layer_gradients_and_eval(
    forward_fn: Callable,
    layer: Module,
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    target_ind: TargetType = None,
    additional_forward_args: Any = None,
    *,
    gradient_neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
    device_ids: Union[None, List[int]] = None,
    attribute_to_layer_input: bool = False,
    output_fn: Union[None, Callable] = None,
) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...], Tuple[Tensor, ...]]:
    ...


@typing.overload
def compute_layer_gradients_and_eval(
    forward_fn: Callable,
    layer: List[Module],
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    target_ind: TargetType = None,
    additional_forward_args: Any = None,
    gradient_neuron_selector: None = None,
    device_ids: Union[None, List[int]] = None,
    attribute_to_layer_input: bool = False,
    output_fn: Union[None, Callable] = None,
) -> Tuple[List[Tuple[Tensor, ...]], List[Tuple[Tensor, ...]]]:
    ...


@typing.overload
def compute_layer_gradients_and_eval(
    forward_fn: Callable,
    layer: Module,
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    target_ind: TargetType = None,
    additional_forward_args: Any = None,
    gradient_neuron_selector: None = None,
    device_ids: Union[None, List[int]] = None,
    attribute_to_layer_input: bool = False,
    output_fn: Union[None, Callable] = None,
) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]:
    ...


def compute_layer_gradients_and_eval(
    forward_fn: Callable,
    layer: ModuleOrModuleList,
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    target_ind: TargetType = None,
    additional_forward_args: Any = None,
    gradient_neuron_selector: Union[
        None, int, Tuple[Union[int, slice], ...], Callable
    ] = None,
    device_ids: Union[None, List[int]] = None,
    attribute_to_layer_input: bool = False,
    output_fn: Union[None, Callable] = None,
) -> Union[
    Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]],
    Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...], Tuple[Tensor, ...]],
    Tuple[List[Tuple[Tensor, ...]], List[Tuple[Tensor, ...]]],
]:
    r"""
    Computes gradients of the output with respect to a given layer as well
    as the output evaluation of the layer for an arbitrary forward function
    and given input.

    For data parallel models, hooks are executed once per device, so we
    need to internally combine the separated tensors from devices by
    concatenating based on device_ids. Any necessary gradients must be taken
    with respect to each independent batched tensor, so the gradients are
    computed and combined appropriately.

    More information regarding the behavior of forward hooks with DataParallel
    models can be found in the PyTorch data parallel documentation. We maintain
    the separate inputs in a dictionary protected by a lock, analogous to the
    gather implementation for the core PyTorch DataParallel implementation.

    NOTE: To properly handle inplace operations, a clone of the layer output
    is stored. This structure inhibits execution of a backward hook on the last
    module for the layer output when computing the gradient with respect to
    the input, since we store an intermediate clone, as
    opposed to the true module output. If backward module hooks are necessary
    for the final module when computing input gradients, utilize
    _forward_layer_eval_with_neuron_grads instead.

    Args:

        forward_fn: forward function. This can be for example model's
                    forward function.
        layer:      Layer for which gradients / output will be evaluated.
        inputs:     Input at which gradients are evaluated,
                    will be passed to forward_fn.
        target_ind: Index of the target class for which gradients
                    must be computed (classification only).
        output_fn:  An optional function that is applied to the layer inputs or
                    outputs depending whether the `attribute_to_layer_input` is
                    set to `True` or `False`
        args:       Additional input arguments that forward function requires.
                    It takes an empty tuple (no additional arguments) if no
                    additional arguments are required


    Returns:
        2-element tuple of **gradients**, **evals**:
        - **gradients**:
            Gradients of output with respect to target layer output.
        - **evals**:
            Target layer output for given input.
    """
    with torch.autograd.set_grad_enabled(True):
        # saved_layer is a dictionary mapping device to a tuple of
        # layer evaluations on that device.
        saved_layer, output = _forward_layer_distributed_eval(
            forward_fn,
            inputs,
            layer,
            target_ind=target_ind,
            additional_forward_args=additional_forward_args,
            attribute_to_layer_input=attribute_to_layer_input,
            forward_hook_with_return=True,
            require_layer_grads=True,
        )
        assert output[0].numel() == 1, (
            "Target not provided when necessary, cannot"
            " take gradient with respect to multiple outputs."
        )

        device_ids = _extract_device_ids(forward_fn, saved_layer, device_ids)

        # Identifies correct device ordering based on device ids.
        # key_list is a list of devices in appropriate ordering for concatenation.
        # If only one key exists (standard model), key list simply has one element.
        key_list = _sort_key_list(
            list(next(iter(saved_layer.values())).keys()), device_ids
        )
        all_outputs: Union[Tuple[Tensor, ...], List[Tuple[Tensor, ...]]]
        if isinstance(layer, Module):
            all_outputs = _reduce_list(
                [
                    saved_layer[layer][device_id]
                    if output_fn is None
                    else output_fn(saved_layer[layer][device_id])
                    for device_id in key_list
                ]
            )
        else:
            all_outputs = [
                _reduce_list(
                    [
                        saved_layer[single_layer][device_id]
                        if output_fn is None
                        else output_fn(saved_layer[single_layer][device_id])
                        for device_id in key_list
                    ]
                )
                for single_layer in layer
            ]
        all_layers: List[Module] = [layer] if isinstance(layer, Module) else layer
        grad_inputs = tuple(
            layer_tensor
            for single_layer in all_layers
            for device_id in key_list
            for layer_tensor in saved_layer[single_layer][device_id]
        )
        saved_grads = torch.autograd.grad(torch.unbind(output), grad_inputs)

        offset = 0
        all_grads: List[Tuple[Tensor, ...]] = []
        for single_layer in all_layers:
            num_tensors = len(next(iter(saved_layer[single_layer].values())))
            curr_saved_grads = [
                saved_grads[i : i + num_tensors]
                for i in range(
                    offset, offset + len(key_list) * num_tensors, num_tensors
                )
            ]
            offset += len(key_list) * num_tensors
            if output_fn is not None:
                curr_saved_grads = [
                    output_fn(curr_saved_grad) for curr_saved_grad in curr_saved_grads
                ]

            all_grads.append(_reduce_list(curr_saved_grads))

        layer_grads: Union[Tuple[Tensor, ...], List[Tuple[Tensor, ...]]]
        layer_grads = all_grads
        if isinstance(layer, Module):
            layer_grads = all_grads[0]

        if gradient_neuron_selector is not None:
            assert isinstance(
                layer, Module
            ), "Cannot compute neuron gradients for multiple layers simultaneously!"
            inp_grads = _neuron_gradients(
                inputs, saved_layer[layer], key_list, gradient_neuron_selector
            )
            return (
                cast(Tuple[Tensor, ...], layer_grads),
                cast(Tuple[Tensor, ...], all_outputs),
                inp_grads,
            )
    return layer_grads, all_outputs  # type: ignore


def construct_neuron_grad_fn(
    layer: Module,
    neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
    device_ids: Union[None, List[int]] = None,
    attribute_to_neuron_input: bool = False,
) -> Callable:
    def grad_fn(
        forward_fn: Callable,
        inputs: TensorOrTupleOfTensorsGeneric,
        target_ind: TargetType = None,
        additional_forward_args: Any = None,
    ) -> Tuple[Tensor, ...]:
        _, grads = _forward_layer_eval_with_neuron_grads(
            forward_fn,
            inputs,
            layer,
            additional_forward_args,
            gradient_neuron_selector=neuron_selector,
            device_ids=device_ids,
            attribute_to_layer_input=attribute_to_neuron_input,
        )
        return grads

    return grad_fn

def _compute_jacobian_wrt_params(
|
711 |
+
model: Module,
|
712 |
+
inputs: Tuple[Any, ...],
|
713 |
+
labels: Optional[Tensor] = None,
|
714 |
+
loss_fn: Optional[Union[Module, Callable]] = None,
|
715 |
+
) -> Tuple[Tensor, ...]:
|
716 |
+
r"""
|
717 |
+
Computes the Jacobian of a batch of test examples given a model, and optional
|
718 |
+
loss function and target labels. This method is equivalent to calculating the
|
719 |
+
gradient for every individual example in the minibatch.
|
720 |
+
|
721 |
+
Args:
|
722 |
+
model (torch.nn.Module): The trainable model providing the forward pass
|
723 |
+
inputs (tuple of Any): The minibatch for which the forward pass is computed.
|
724 |
+
It is unpacked before passing to `model`, so it must be a tuple. The
|
725 |
+
individual elements of `inputs` can be anything.
|
726 |
+
labels (Tensor or None): Labels for input if computing a loss function.
|
727 |
+
loss_fn (torch.nn.Module or Callable or None): The loss function. If a library
|
728 |
+
defined loss function is provided, it would be expected to be a
|
729 |
+
torch.nn.Module. If a custom loss is provided, it can be either type,
|
730 |
+
but must behave as a library loss function would if `reduction='none'`.
|
731 |
+
|
732 |
+
Returns:
|
733 |
+
grads (Tuple of Tensor): Returns the Jacobian for the minibatch as a
|
734 |
+
tuple of gradients corresponding to the tuple of trainable parameters
|
735 |
+
returned by `model.parameters()`. Each object grads[i] references to the
|
736 |
+
gradients for the parameters in the i-th trainable layer of the model.
|
737 |
+
Each grads[i] object is a tensor with the gradients for the `inputs`
|
738 |
+
batch. For example, grads[i][j] would reference the gradients for the
|
739 |
+
parameters of the i-th layer, for the j-th member of the minibatch.
|
740 |
+
"""
|
741 |
+
with torch.autograd.set_grad_enabled(True):
|
742 |
+
out = model(*inputs)
|
743 |
+
assert out.dim() != 0, "Please ensure model output has at least one dimension."
|
744 |
+
|
745 |
+
if labels is not None and loss_fn is not None:
|
746 |
+
loss = loss_fn(out, labels)
|
747 |
+
if hasattr(loss_fn, "reduction"):
|
748 |
+
msg0 = "Please ensure loss_fn.reduction is set to `none`"
|
749 |
+
assert loss_fn.reduction == "none", msg0 # type: ignore
|
750 |
+
else:
|
751 |
+
msg1 = (
|
752 |
+
"Loss function is applying a reduction. Please ensure "
|
753 |
+
f"Output shape: {out.shape} and Loss shape: {loss.shape} "
|
754 |
+
"are matching."
|
755 |
+
)
|
756 |
+
assert loss.dim() != 0, msg1
|
757 |
+
assert out.shape[0] == loss.shape[0], msg1
|
758 |
+
out = loss
|
759 |
+
|
760 |
+
grads_list = [
|
761 |
+
torch.autograd.grad(
|
762 |
+
outputs=out[i],
|
763 |
+
inputs=model.parameters(), # type: ignore
|
764 |
+
grad_outputs=torch.ones_like(out[i]),
|
765 |
+
retain_graph=True,
|
766 |
+
)
|
767 |
+
for i in range(out.shape[0])
|
768 |
+
]
|
769 |
+
|
770 |
+
grads = tuple([torch.stack(x) for x in zip(*grads_list)])
|
771 |
+
|
772 |
+
return tuple(grads)
|
773 |
+
|
774 |
+
|
775 |
+
def _compute_jacobian_wrt_params_with_sample_wise_trick(
|
776 |
+
model: Module,
|
777 |
+
inputs: Tuple[Any, ...],
|
778 |
+
labels: Optional[Tensor] = None,
|
779 |
+
loss_fn: Optional[Union[Module, Callable]] = None,
|
780 |
+
reduction_type: Optional[str] = "sum",
|
781 |
+
) -> Tuple[Any, ...]:
|
782 |
+
r"""
|
783 |
+
Computes the Jacobian of a batch of test examples given a model, and optional
|
784 |
+
loss function and target labels. This method uses sample-wise gradients per
|
785 |
+
batch trick to fully vectorize the Jacobian calculation. Currently, only
|
786 |
+
linear and conv2d layers are supported.
|
787 |
+
|
788 |
+
User must `add_hooks(model)` before calling this function.
|
789 |
+
|
790 |
+
Args:
|
791 |
+
model (torch.nn.Module): The trainable model providing the forward pass
|
792 |
+
inputs (tuple of Any): The minibatch for which the forward pass is computed.
|
793 |
+
It is unpacked before passing to `model`, so it must be a tuple. The
|
794 |
+
individual elements of `inputs` can be anything.
|
795 |
+
labels (Tensor or None): Labels for input if computing a loss function.
|
796 |
+
loss_fn (torch.nn.Module or Callable or None): The loss function. If a library
|
797 |
+
defined loss function is provided, it would be expected to be a
|
798 |
+
torch.nn.Module. If a custom loss is provided, it can be either type,
|
799 |
+
but must behave as a library loss function would if `reduction='sum'` or
|
800 |
+
`reduction='mean'`.
|
801 |
+
reduction_type (str): The type of reduction applied. If a loss_fn is passed,
|
802 |
+
this should match `loss_fn.reduction`. Else if gradients are being
|
803 |
+
computed on direct model outputs (scores), then 'sum' should be used.
|
804 |
+
Defaults to 'sum'.
|
805 |
+
|
806 |
+
Returns:
|
807 |
+
grads (Tuple of Tensor): Returns the Jacobian for the minibatch as a
|
808 |
+
tuple of gradients corresponding to the tuple of trainable parameters
|
809 |
+
returned by `model.parameters()`. Each object grads[i] references the
|
810 |
+
gradients for the parameters in the i-th trainable layer of the model.
|
811 |
+
Each grads[i] object is a tensor with the gradients for the `inputs`
|
812 |
+
batch. For example, grads[i][j] would reference the gradients for the
|
813 |
+
parameters of the i-th layer, for the j-th member of the minibatch.
|
814 |
+
"""
|
815 |
+
with torch.autograd.set_grad_enabled(True):
|
816 |
+
sample_grad_wrapper = SampleGradientWrapper(model)
|
817 |
+
try:
|
818 |
+
sample_grad_wrapper.add_hooks()
|
819 |
+
|
820 |
+
out = model(*inputs)
|
821 |
+
assert (
|
822 |
+
out.dim() != 0
|
823 |
+
), "Please ensure model output has at least one dimension."
|
824 |
+
|
825 |
+
if labels is not None and loss_fn is not None:
|
826 |
+
loss = loss_fn(out, labels)
|
827 |
+
# TODO: allow loss_fn to be Callable
|
828 |
+
if isinstance(loss_fn, Module) and hasattr(loss_fn, "reduction"):
|
829 |
+
msg0 = (
|
830 |
+
"Please ensure that loss_fn.reduction is set to `sum` or `mean`"
|
831 |
+
)
|
832 |
+
|
833 |
+
assert loss_fn.reduction != "none", msg0
|
834 |
+
msg1 = (
|
835 |
+
f"loss_fn.reduction ({loss_fn.reduction}) does not match"
|
836 |
+
f"reduction type ({reduction_type}). Please ensure they are"
|
837 |
+
" matching."
|
838 |
+
)
|
839 |
+
assert loss_fn.reduction == reduction_type, msg1
|
840 |
+
msg2 = (
|
841 |
+
"Please ensure custom loss function is applying either a "
|
842 |
+
"sum or mean reduction."
|
843 |
+
)
|
844 |
+
assert out.shape != loss.shape, msg2
|
845 |
+
|
846 |
+
if reduction_type != "sum" and reduction_type != "mean":
|
847 |
+
raise ValueError(
|
848 |
+
f"{reduction_type} is not a valid value for reduction_type. "
|
849 |
+
"Must be either 'sum' or 'mean'."
|
850 |
+
)
|
851 |
+
out = loss
|
852 |
+
|
853 |
+
sample_grad_wrapper.compute_param_sample_gradients(
|
854 |
+
out, loss_mode=reduction_type
|
855 |
+
)
|
856 |
+
|
857 |
+
grads = tuple(
|
858 |
+
param.sample_grad # type: ignore
|
859 |
+
for param in model.parameters()
|
860 |
+
if hasattr(param, "sample_grad")
|
861 |
+
)
|
862 |
+
finally:
|
863 |
+
sample_grad_wrapper.remove_hooks()
|
864 |
+
|
865 |
+
return grads
|
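The two Jacobian helpers above both return per-example parameter gradients; the following is a minimal usage sketch for the plain-autograd variant (not part of this commit, with a made-up toy model and data), assuming the module is importable as captum._utils.gradient:

import torch
import torch.nn as nn
from captum._utils.gradient import _compute_jacobian_wrt_params

model = nn.Linear(3, 1)
inputs = (torch.randn(4, 3),)            # minibatch of 4 examples, passed as a tuple
labels = torch.randn(4, 1)
loss_fn = nn.MSELoss(reduction="none")   # must stay unreduced for this helper

grads = _compute_jacobian_wrt_params(model, inputs, labels, loss_fn)
# grads[0] has shape (4, 1, 3) (per-example weight gradients) and
# grads[1] has shape (4, 1) (per-example bias gradients): one gradient
# per trainable parameter tensor per minibatch member.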
captum/_utils/models/__init__.py
ADDED
@@ -0,0 +1,25 @@
1 |
+
from captum._utils.models.linear_model import (
|
2 |
+
LinearModel,
|
3 |
+
SGDLasso,
|
4 |
+
SGDLinearModel,
|
5 |
+
SGDLinearRegression,
|
6 |
+
SGDRidge,
|
7 |
+
SkLearnLasso,
|
8 |
+
SkLearnLinearModel,
|
9 |
+
SkLearnLinearRegression,
|
10 |
+
SkLearnRidge,
|
11 |
+
)
|
12 |
+
from captum._utils.models.model import Model
|
13 |
+
|
14 |
+
__all__ = [
|
15 |
+
"Model",
|
16 |
+
"LinearModel",
|
17 |
+
"SGDLinearModel",
|
18 |
+
"SGDLasso",
|
19 |
+
"SGDRidge",
|
20 |
+
"SGDLinearRegression",
|
21 |
+
"SkLearnLinearModel",
|
22 |
+
"SkLearnLasso",
|
23 |
+
"SkLearnRidge",
|
24 |
+
"SkLearnLinearRegression",
|
25 |
+
]
|
captum/_utils/models/linear_model/__init__.py
ADDED
@@ -0,0 +1,23 @@
1 |
+
from captum._utils.models.linear_model.model import (
|
2 |
+
LinearModel,
|
3 |
+
SGDLasso,
|
4 |
+
SGDLinearModel,
|
5 |
+
SGDLinearRegression,
|
6 |
+
SGDRidge,
|
7 |
+
SkLearnLasso,
|
8 |
+
SkLearnLinearModel,
|
9 |
+
SkLearnLinearRegression,
|
10 |
+
SkLearnRidge,
|
11 |
+
)
|
12 |
+
|
13 |
+
__all__ = [
|
14 |
+
"LinearModel",
|
15 |
+
"SGDLinearModel",
|
16 |
+
"SGDLasso",
|
17 |
+
"SGDRidge",
|
18 |
+
"SGDLinearRegression",
|
19 |
+
"SkLearnLinearModel",
|
20 |
+
"SkLearnLasso",
|
21 |
+
"SkLearnRidge",
|
22 |
+
"SkLearnLinearRegression",
|
23 |
+
]
|
captum/_utils/models/linear_model/model.py
ADDED
@@ -0,0 +1,341 @@
1 |
+
from typing import Callable, cast, List, Optional
|
2 |
+
|
3 |
+
import torch.nn as nn
|
4 |
+
from captum._utils.models.model import Model
|
5 |
+
from torch import Tensor
|
6 |
+
from torch.utils.data import DataLoader
|
7 |
+
|
8 |
+
|
9 |
+
class LinearModel(nn.Module, Model):
|
10 |
+
SUPPORTED_NORMS: List[Optional[str]] = [None, "batch_norm", "layer_norm"]
|
11 |
+
|
12 |
+
def __init__(self, train_fn: Callable, **kwargs) -> None:
|
13 |
+
r"""
|
14 |
+
Constructs a linear model with a training function and additional
|
15 |
+
construction arguments that will be sent to
|
16 |
+
`self._construct_model_params` after a `self.fit` is called. Please note
|
17 |
+
that this assumes the `self.train_fn` will call
|
18 |
+
`self._construct_model_params`.
|
19 |
+
|
20 |
+
Please note that this is an experimental feature.
|
21 |
+
|
22 |
+
Args:
|
23 |
+
train_fn (callable)
|
24 |
+
The function to train with. See
|
25 |
+
`captum._utils.models.linear_model.train.sgd_train_linear_model`
|
26 |
+
and
|
27 |
+
`captum._utils.models.linear_model.train.sklearn_train_linear_model`
|
28 |
+
for examples
|
29 |
+
kwargs
|
30 |
+
Any additional keyword arguments to send to
|
31 |
+
`self._construct_model_params` once a `self.fit` is called.
|
32 |
+
"""
|
33 |
+
super().__init__()
|
34 |
+
|
35 |
+
self.norm: Optional[nn.Module] = None
|
36 |
+
self.linear: Optional[nn.Linear] = None
|
37 |
+
self.train_fn = train_fn
|
38 |
+
self.construct_kwargs = kwargs
|
39 |
+
|
40 |
+
def _construct_model_params(
|
41 |
+
self,
|
42 |
+
in_features: Optional[int] = None,
|
43 |
+
out_features: Optional[int] = None,
|
44 |
+
norm_type: Optional[str] = None,
|
45 |
+
affine_norm: bool = False,
|
46 |
+
bias: bool = True,
|
47 |
+
weight_values: Optional[Tensor] = None,
|
48 |
+
bias_value: Optional[Tensor] = None,
|
49 |
+
classes: Optional[Tensor] = None,
|
50 |
+
):
|
51 |
+
r"""
|
52 |
+
Lazily initializes a linear model. This will be called for you in a
|
53 |
+
train method.
|
54 |
+
|
55 |
+
Args:
|
56 |
+
in_features (int):
|
57 |
+
The number of input features
|
58 |
+
output_features (int):
|
59 |
+
The number of output features.
|
60 |
+
norm_type (str, optional):
|
61 |
+
The type of normalization that can occur. Please assign this
|
62 |
+
to one of `PyTorchLinearModel.SUPPORTED_NORMS`.
|
63 |
+
affine_norm (bool):
|
64 |
+
Whether or not to learn an affine transformation of the
|
65 |
+
normalization parameters used.
|
66 |
+
bias (bool):
|
67 |
+
Whether to add a bias term. Not needed if normalized input.
|
68 |
+
weight_values (tensor, optional):
|
69 |
+
The values to initialize the linear model with. This must be a
|
70 |
+
1D or 2D tensor, and of the form `(num_outputs, num_features)` or
|
71 |
+
`(num_features,)`. Additionally, if this is provided you need not
|
72 |
+
provide `in_features` or `out_features`.
|
73 |
+
bias_value (tensor, optional):
|
74 |
+
The bias value to initialize the model with.
|
75 |
+
classes (tensor, optional):
|
76 |
+
The list of prediction classes supported by the model in case it
|
77 |
+
performs classification. In case of regression it is set to None.
|
78 |
+
Default: None
|
79 |
+
"""
|
80 |
+
if norm_type not in LinearModel.SUPPORTED_NORMS:
|
81 |
+
raise ValueError(
|
82 |
+
f"{norm_type} not supported. Please use {LinearModel.SUPPORTED_NORMS}"
|
83 |
+
)
|
84 |
+
|
85 |
+
if weight_values is not None:
|
86 |
+
in_features = weight_values.shape[-1]
|
87 |
+
out_features = (
|
88 |
+
1 if len(weight_values.shape) == 1 else weight_values.shape[0]
|
89 |
+
)
|
90 |
+
|
91 |
+
if in_features is None or out_features is None:
|
92 |
+
raise ValueError(
|
93 |
+
"Please provide `in_features` and `out_features` or `weight_values`"
|
94 |
+
)
|
95 |
+
|
96 |
+
if norm_type == "batch_norm":
|
97 |
+
self.norm = nn.BatchNorm1d(in_features, eps=1e-8, affine=affine_norm)
|
98 |
+
elif norm_type == "layer_norm":
|
99 |
+
self.norm = nn.LayerNorm(
|
100 |
+
in_features, eps=1e-8, elementwise_affine=affine_norm
|
101 |
+
)
|
102 |
+
else:
|
103 |
+
self.norm = None
|
104 |
+
|
105 |
+
self.linear = nn.Linear(in_features, out_features, bias=bias)
|
106 |
+
|
107 |
+
if weight_values is not None:
|
108 |
+
self.linear.weight.data = weight_values
|
109 |
+
|
110 |
+
if bias_value is not None:
|
111 |
+
if not bias:
|
112 |
+
raise ValueError("`bias_value` is not None and bias is False")
|
113 |
+
|
114 |
+
self.linear.bias.data = bias_value
|
115 |
+
|
116 |
+
if classes is not None:
|
117 |
+
self.linear.classes = classes
|
118 |
+
|
119 |
+
def fit(self, train_data: DataLoader, **kwargs):
|
120 |
+
r"""
|
121 |
+
Calls `self.train_fn`
|
122 |
+
"""
|
123 |
+
return self.train_fn(
|
124 |
+
self,
|
125 |
+
dataloader=train_data,
|
126 |
+
construct_kwargs=self.construct_kwargs,
|
127 |
+
**kwargs,
|
128 |
+
)
|
129 |
+
|
130 |
+
def forward(self, x: Tensor) -> Tensor:
|
131 |
+
assert self.linear is not None
|
132 |
+
if self.norm is not None:
|
133 |
+
x = self.norm(x)
|
134 |
+
return self.linear(x)
|
135 |
+
|
136 |
+
def representation(self) -> Tensor:
|
137 |
+
r"""
|
138 |
+
Returns a tensor which describes the hyper-plane input space. This does
|
139 |
+
not include the bias. For bias/intercept, please use `self.bias`
|
140 |
+
"""
|
141 |
+
assert self.linear is not None
|
142 |
+
return self.linear.weight.detach()
|
143 |
+
|
144 |
+
def bias(self) -> Optional[Tensor]:
|
145 |
+
r"""
|
146 |
+
Returns the bias of the linear model
|
147 |
+
"""
|
148 |
+
if self.linear is None or self.linear.bias is None:
|
149 |
+
return None
|
150 |
+
return self.linear.bias.detach()
|
151 |
+
|
152 |
+
def classes(self) -> Optional[Tensor]:
|
153 |
+
if self.linear is None or self.linear.classes is None:
|
154 |
+
return None
|
155 |
+
return cast(Tensor, self.linear.classes).detach()
|
156 |
+
|
157 |
+
|
158 |
+
class SGDLinearModel(LinearModel):
|
159 |
+
def __init__(self, **kwargs) -> None:
|
160 |
+
r"""
|
161 |
+
Factory class. Construct a `LinearModel` with the
|
162 |
+
`sgd_train_linear_model` as the train method
|
163 |
+
|
164 |
+
Args:
|
165 |
+
kwargs
|
166 |
+
Arguments send to `self._construct_model_params` after
|
167 |
+
`self.fit` is called. Please refer to that method for parameter
|
168 |
+
documentation.
|
169 |
+
"""
|
170 |
+
# avoid cycles
|
171 |
+
from captum._utils.models.linear_model.train import sgd_train_linear_model
|
172 |
+
|
173 |
+
super().__init__(train_fn=sgd_train_linear_model, **kwargs)
|
174 |
+
|
175 |
+
|
176 |
+
class SGDLasso(SGDLinearModel):
|
177 |
+
def __init__(self, **kwargs) -> None:
|
178 |
+
r"""
|
179 |
+
Factory class to train a `LinearModel` with SGD
|
180 |
+
(`sgd_train_linear_model`) whilst setting appropriate parameters to
|
181 |
+
optimize for lasso regression loss. This optimizes L2 loss + alpha * L1
|
182 |
+
regularization.
|
183 |
+
|
184 |
+
Please note that with SGD it is not guaranteed that weights will
|
185 |
+
converge to 0.
|
186 |
+
"""
|
187 |
+
super().__init__(**kwargs)
|
188 |
+
|
189 |
+
def fit(self, train_data: DataLoader, **kwargs):
|
190 |
+
# avoid cycles
|
191 |
+
from captum._utils.models.linear_model.train import l2_loss
|
192 |
+
|
193 |
+
return super().fit(train_data=train_data, loss_fn=l2_loss, reg_term=1, **kwargs)
|
194 |
+
|
195 |
+
|
196 |
+
class SGDRidge(SGDLinearModel):
|
197 |
+
def __init__(self, **kwargs) -> None:
|
198 |
+
r"""
|
199 |
+
Factory class to train a `LinearModel` with SGD
|
200 |
+
(`sgd_train_linear_model`) whilst setting appropriate parameters to
|
201 |
+
optimize for ridge regression loss. This optimizes L2 loss + alpha *
|
202 |
+
L2 regularization.
|
203 |
+
"""
|
204 |
+
super().__init__(**kwargs)
|
205 |
+
|
206 |
+
def fit(self, train_data: DataLoader, **kwargs):
|
207 |
+
# avoid cycles
|
208 |
+
from captum._utils.models.linear_model.train import l2_loss
|
209 |
+
|
210 |
+
return super().fit(train_data=train_data, loss_fn=l2_loss, reg_term=2, **kwargs)
|
211 |
+
|
212 |
+
|
213 |
+
class SGDLinearRegression(SGDLinearModel):
|
214 |
+
def __init__(self, **kwargs) -> None:
|
215 |
+
r"""
|
216 |
+
Factory class to train a `LinearModel` with SGD
|
217 |
+
(`sgd_train_linear_model`). For linear regression this assigns the loss
|
218 |
+
to L2 and no regularization.
|
219 |
+
"""
|
220 |
+
super().__init__(**kwargs)
|
221 |
+
|
222 |
+
def fit(self, train_data: DataLoader, **kwargs):
|
223 |
+
# avoid cycles
|
224 |
+
from captum._utils.models.linear_model.train import l2_loss
|
225 |
+
|
226 |
+
return super().fit(
|
227 |
+
train_data=train_data, loss_fn=l2_loss, reg_term=None, **kwargs
|
228 |
+
)
|
229 |
+
|
230 |
+
|
231 |
+
class SkLearnLinearModel(LinearModel):
|
232 |
+
def __init__(self, sklearn_module: str, **kwargs) -> None:
|
233 |
+
r"""
|
234 |
+
Factory class to construct a `LinearModel` with sklearn training method.
|
235 |
+
|
236 |
+
Please note that this assumes:
|
237 |
+
|
238 |
+
0. You have sklearn and numpy installed
|
239 |
+
1. The dataset can fit into memory
|
240 |
+
|
241 |
+
SkLearn support does introduce some slight overhead as we convert the
|
242 |
+
tensors to numpy and then convert the resulting trained model to a
|
243 |
+
`LinearModel` object. However, this conversion should be negligible.
|
244 |
+
|
245 |
+
Args:
|
246 |
+
sklearn_module
|
247 |
+
The module under sklearn to construct and use for training, e.g.
|
248 |
+
use "svm.LinearSVC" for an SVM or "linear_model.Lasso" for Lasso.
|
249 |
+
|
250 |
+
There are factory classes defined for you for common use cases,
|
251 |
+
such as `SkLearnLasso`.
|
252 |
+
kwargs
|
253 |
+
The kwargs to pass to the construction of the sklearn model
|
254 |
+
"""
|
255 |
+
# avoid cycles
|
256 |
+
from captum._utils.models.linear_model.train import sklearn_train_linear_model
|
257 |
+
|
258 |
+
super().__init__(train_fn=sklearn_train_linear_model, **kwargs)
|
259 |
+
|
260 |
+
self.sklearn_module = sklearn_module
|
261 |
+
|
262 |
+
def fit(self, train_data: DataLoader, **kwargs):
|
263 |
+
r"""
|
264 |
+
Args:
|
265 |
+
train_data
|
266 |
+
Train data to use
|
267 |
+
kwargs
|
268 |
+
Arguments to feed to `.fit` method for sklearn
|
269 |
+
"""
|
270 |
+
return super().fit(
|
271 |
+
train_data=train_data, sklearn_trainer=self.sklearn_module, **kwargs
|
272 |
+
)
|
273 |
+
|
274 |
+
|
275 |
+
class SkLearnLasso(SkLearnLinearModel):
|
276 |
+
def __init__(self, **kwargs) -> None:
|
277 |
+
r"""
|
278 |
+
Factory class. Trains a `LinearModel` model with
|
279 |
+
`sklearn.linear_model.Lasso`. You will need sklearn version >= 0.23 to
|
280 |
+
support sample weights.
|
281 |
+
"""
|
282 |
+
super().__init__(sklearn_module="linear_model.Lasso", **kwargs)
|
283 |
+
|
284 |
+
def fit(self, train_data: DataLoader, **kwargs):
|
285 |
+
return super().fit(train_data=train_data, **kwargs)
|
286 |
+
|
287 |
+
|
288 |
+
class SkLearnRidge(SkLearnLinearModel):
|
289 |
+
def __init__(self, **kwargs) -> None:
|
290 |
+
r"""
|
291 |
+
Factory class. Trains a model with `sklearn.linear_model.Ridge`.
|
292 |
+
|
293 |
+
Any arguments provided to the sklearn constructor can be provided
|
294 |
+
as kwargs here.
|
295 |
+
"""
|
296 |
+
super().__init__(sklearn_module="linear_model.Ridge", **kwargs)
|
297 |
+
|
298 |
+
def fit(self, train_data: DataLoader, **kwargs):
|
299 |
+
return super().fit(train_data=train_data, **kwargs)
|
300 |
+
|
301 |
+
|
302 |
+
class SkLearnLinearRegression(SkLearnLinearModel):
|
303 |
+
def __init__(self, **kwargs) -> None:
|
304 |
+
r"""
|
305 |
+
Factory class. Trains a model with `sklearn.linear_model.LinearRegression`.
|
306 |
+
|
307 |
+
Any arguments provided to the sklearn constructor can be provided
|
308 |
+
as kwargs here.
|
309 |
+
"""
|
310 |
+
super().__init__(sklearn_module="linear_model.LinearRegression", **kwargs)
|
311 |
+
|
312 |
+
def fit(self, train_data: DataLoader, **kwargs):
|
313 |
+
return super().fit(train_data=train_data, **kwargs)
|
314 |
+
|
315 |
+
|
316 |
+
class SkLearnLogisticRegression(SkLearnLinearModel):
|
317 |
+
def __init__(self, **kwargs) -> None:
|
318 |
+
r"""
|
319 |
+
Factory class. Trains a model with `sklearn.linear_model.LogisticRegression`.
|
320 |
+
|
321 |
+
Any arguments provided to the sklearn constructor can be provided
|
322 |
+
as kwargs here.
|
323 |
+
"""
|
324 |
+
super().__init__(sklearn_module="linear_model.LogisticRegression", **kwargs)
|
325 |
+
|
326 |
+
def fit(self, train_data: DataLoader, **kwargs):
|
327 |
+
return super().fit(train_data=train_data, **kwargs)
|
328 |
+
|
329 |
+
|
330 |
+
class SkLearnSGDClassifier(SkLearnLinearModel):
|
331 |
+
def __init__(self, **kwargs) -> None:
|
332 |
+
r"""
|
333 |
+
Factory class. Trains a model with `sklearn.linear_model.SGDClassifier`.
|
334 |
+
|
335 |
+
Any arguments provided to the sklearn constructor can be provided
|
336 |
+
as kwargs here.
|
337 |
+
"""
|
338 |
+
super().__init__(sklearn_module="linear_model.SGDClassifier", **kwargs)
|
339 |
+
|
340 |
+
def fit(self, train_data: DataLoader, **kwargs):
|
341 |
+
return super().fit(train_data=train_data, **kwargs)
|
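A usage sketch for the sklearn-backed factories defined above (not part of this commit; assumes sklearn and numpy are installed, and the toy data is made up):

import torch
from torch.utils.data import DataLoader, TensorDataset
from captum._utils.models.linear_model import SkLearnLasso

X = torch.randn(100, 5)
y = X @ torch.randn(5) + 0.1 * torch.randn(100)
loader = DataLoader(TensorDataset(X, y), batch_size=32)

lasso = SkLearnLasso(alpha=0.01)     # constructor kwargs are forwarded to sklearn.linear_model.Lasso
stats = lasso.fit(loader)            # delegates to sklearn_train_linear_model
print(lasso.representation().shape)  # torch.Size([1, 5]): the learned hyper-plane weights
print(lasso.bias())                  # fitted intercept as a tensor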
captum/_utils/models/linear_model/train.py
ADDED
@@ -0,0 +1,364 @@
1 |
+
import time
|
2 |
+
import warnings
|
3 |
+
from typing import Any, Callable, Dict, List, Optional
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from captum._utils.models.linear_model.model import LinearModel
|
8 |
+
from torch.utils.data import DataLoader
|
9 |
+
|
10 |
+
|
11 |
+
def l2_loss(x1, x2, weights=None):
|
12 |
+
if weights is None:
|
13 |
+
return torch.mean((x1 - x2) ** 2) / 2.0
|
14 |
+
else:
|
15 |
+
return torch.sum((weights / weights.norm(p=1)) * ((x1 - x2) ** 2)) / 2.0
|
16 |
+
|
17 |
+
|
18 |
+
def sgd_train_linear_model(
|
19 |
+
model: LinearModel,
|
20 |
+
dataloader: DataLoader,
|
21 |
+
construct_kwargs: Dict[str, Any],
|
22 |
+
max_epoch: int = 100,
|
23 |
+
reduce_lr: bool = True,
|
24 |
+
initial_lr: float = 0.01,
|
25 |
+
alpha: float = 1.0,
|
26 |
+
loss_fn: Callable = l2_loss,
|
27 |
+
reg_term: Optional[int] = 1,
|
28 |
+
patience: int = 10,
|
29 |
+
threshold: float = 1e-4,
|
30 |
+
running_loss_window: Optional[int] = None,
|
31 |
+
device: Optional[str] = None,
|
32 |
+
init_scheme: str = "zeros",
|
33 |
+
debug: bool = False,
|
34 |
+
) -> Dict[str, float]:
|
35 |
+
r"""
|
36 |
+
Trains a linear model with SGD. This will continue to iterate your
|
37 |
+
dataloader until we converged to a solution or alternatively until we have
|
38 |
+
exhausted `max_epoch`.
|
39 |
+
|
40 |
+
Convergence is defined by the loss not changing by `threshold` amount for
|
41 |
+
`patience` number of iterations.
|
42 |
+
|
43 |
+
Args:
|
44 |
+
model
|
45 |
+
The model to train
|
46 |
+
dataloader
|
47 |
+
The data to train it with. We will assume the dataloader produces
|
48 |
+
either pairs or triples of the form (x, y) or (x, y, w). Where x and
|
49 |
+
y are typical pairs for supervised learning and w is a weight
|
50 |
+
vector.
|
51 |
+
|
52 |
+
We will call `model._construct_model_params` with construct_kwargs
|
53 |
+
and the input features set to `x.shape[1]` (`x.shape[0]` corresponds
|
54 |
+
to the batch size). We assume that `len(x.shape) == 2`, i.e. the
|
55 |
+
tensor is flat. The number of output features will be set to
|
56 |
+
y.shape[1] or 1 (if `len(y.shape) == 1`); we require `len(y.shape)
|
57 |
+
<= 2`.
|
58 |
+
max_epoch
|
59 |
+
The maximum number of epochs to exhaust
|
60 |
+
reduce_lr
|
61 |
+
Whether or not to reduce the learning rate as iterations progress.
|
62 |
+
Halves the learning rate when the training loss does not move. This
|
63 |
+
uses torch.optim.lr_scheduler.ReduceLROnPlateau and uses the
|
64 |
+
parameters `patience` and `threshold`
|
65 |
+
initial_lr
|
66 |
+
The initial learning rate to use.
|
67 |
+
alpha
|
68 |
+
A constant for the regularization term.
|
69 |
+
loss_fn
|
70 |
+
The loss to optimise for. This must accept three parameters:
|
71 |
+
x1 (predicted), x2 (labels) and a weight vector
|
72 |
+
reg_term
|
73 |
+
Regularization is defined by the `reg_term` norm of the weights.
|
74 |
+
Please use `None` if you do not wish to use regularization.
|
75 |
+
patience
|
76 |
+
Defines the number of iterations in a row the loss must remain
|
77 |
+
within `threshold` in order to be classified as converged.
|
78 |
+
threshold
|
79 |
+
Threshold for convergence detection.
|
80 |
+
running_loss_window
|
81 |
+
Used to report the training loss once we have finished training and
|
82 |
+
to determine when we have converged (along with reducing the
|
83 |
+
learning rate).
|
84 |
+
|
85 |
+
The reported training loss will take the last `running_loss_window`
|
86 |
+
iterations and average them.
|
87 |
+
|
88 |
+
If `None` we will approximate this to be the number of examples in
|
89 |
+
an epoch.
|
90 |
+
init_scheme
|
91 |
+
Initialization to use prior to training the linear model.
|
92 |
+
device
|
93 |
+
The device to send the model and data to. If None then no `.to` call
|
94 |
+
will be used.
|
95 |
+
debug
|
96 |
+
Whether to print the loss, learning rate per iteration
|
97 |
+
|
98 |
+
Returns
|
99 |
+
This will return the final training loss (averaged with
|
100 |
+
`running_loss_window`)
|
101 |
+
"""
|
102 |
+
|
103 |
+
loss_window: List[torch.Tensor] = []
|
104 |
+
min_avg_loss = None
|
105 |
+
convergence_counter = 0
|
106 |
+
converged = False
|
107 |
+
|
108 |
+
def get_point(datapoint):
|
109 |
+
if len(datapoint) == 2:
|
110 |
+
x, y = datapoint
|
111 |
+
w = None
|
112 |
+
else:
|
113 |
+
x, y, w = datapoint
|
114 |
+
|
115 |
+
if device is not None:
|
116 |
+
x = x.to(device)
|
117 |
+
y = y.to(device)
|
118 |
+
if w is not None:
|
119 |
+
w = w.to(device)
|
120 |
+
|
121 |
+
return x, y, w
|
122 |
+
|
123 |
+
# get a point and construct the model
|
124 |
+
data_iter = iter(dataloader)
|
125 |
+
x, y, w = get_point(next(data_iter))
|
126 |
+
|
127 |
+
model._construct_model_params(
|
128 |
+
in_features=x.shape[1],
|
129 |
+
out_features=y.shape[1] if len(y.shape) == 2 else 1,
|
130 |
+
**construct_kwargs,
|
131 |
+
)
|
132 |
+
model.train()
|
133 |
+
|
134 |
+
assert model.linear is not None
|
135 |
+
|
136 |
+
if init_scheme is not None:
|
137 |
+
assert init_scheme in ["xavier", "zeros"]
|
138 |
+
|
139 |
+
with torch.no_grad():
|
140 |
+
if init_scheme == "xavier":
|
141 |
+
torch.nn.init.xavier_uniform_(model.linear.weight)
|
142 |
+
else:
|
143 |
+
model.linear.weight.zero_()
|
144 |
+
|
145 |
+
if model.linear.bias is not None:
|
146 |
+
model.linear.bias.zero_()
|
147 |
+
|
148 |
+
optim = torch.optim.SGD(model.parameters(), lr=initial_lr)
|
149 |
+
if reduce_lr:
|
150 |
+
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
151 |
+
optim, factor=0.5, patience=patience, threshold=threshold
|
152 |
+
)
|
153 |
+
|
154 |
+
t1 = time.time()
|
155 |
+
epoch = 0
|
156 |
+
i = 0
|
157 |
+
while epoch < max_epoch:
|
158 |
+
while True: # for x, y, w in dataloader
|
159 |
+
if running_loss_window is None:
|
160 |
+
running_loss_window = x.shape[0] * len(dataloader)
|
161 |
+
|
162 |
+
y = y.view(x.shape[0], -1)
|
163 |
+
if w is not None:
|
164 |
+
w = w.view(x.shape[0], -1)
|
165 |
+
|
166 |
+
i += 1
|
167 |
+
|
168 |
+
out = model(x)
|
169 |
+
|
170 |
+
loss = loss_fn(y, out, w)
|
171 |
+
if reg_term is not None:
|
172 |
+
reg = torch.norm(model.linear.weight, p=reg_term)
|
173 |
+
loss += reg.sum() * alpha
|
174 |
+
|
175 |
+
if len(loss_window) >= running_loss_window:
|
176 |
+
loss_window = loss_window[1:]
|
177 |
+
loss_window.append(loss.clone().detach())
|
178 |
+
assert len(loss_window) <= running_loss_window
|
179 |
+
|
180 |
+
average_loss = torch.mean(torch.stack(loss_window))
|
181 |
+
if min_avg_loss is not None:
|
182 |
+
# if we haven't improved by at least `threshold`
|
183 |
+
if average_loss > min_avg_loss or torch.isclose(
|
184 |
+
min_avg_loss, average_loss, atol=threshold
|
185 |
+
):
|
186 |
+
convergence_counter += 1
|
187 |
+
if convergence_counter >= patience:
|
188 |
+
converged = True
|
189 |
+
break
|
190 |
+
else:
|
191 |
+
convergence_counter = 0
|
192 |
+
if min_avg_loss is None or min_avg_loss >= average_loss:
|
193 |
+
min_avg_loss = average_loss.clone()
|
194 |
+
|
195 |
+
if debug:
|
196 |
+
print(
|
197 |
+
f"lr={optim.param_groups[0]['lr']}, Loss={loss},"
|
198 |
+
+ "Aloss={average_loss}, min_avg_loss={min_avg_loss}"
|
199 |
+
)
|
200 |
+
|
201 |
+
loss.backward()
|
202 |
+
|
203 |
+
optim.step()
|
204 |
+
model.zero_grad()
|
205 |
+
if reduce_lr:
|
206 |
+
scheduler.step(average_loss)
|
207 |
+
|
208 |
+
temp = next(data_iter, None)
|
209 |
+
if temp is None:
|
210 |
+
break
|
211 |
+
x, y, w = get_point(temp)
|
212 |
+
|
213 |
+
if converged:
|
214 |
+
break
|
215 |
+
|
216 |
+
epoch += 1
|
217 |
+
data_iter = iter(dataloader)
|
218 |
+
x, y, w = get_point(next(data_iter))
|
219 |
+
|
220 |
+
t2 = time.time()
|
221 |
+
return {
|
222 |
+
"train_time": t2 - t1,
|
223 |
+
"train_loss": torch.mean(torch.stack(loss_window)).item(),
|
224 |
+
"train_iter": i,
|
225 |
+
"train_epoch": epoch,
|
226 |
+
}
|
227 |
+
|
228 |
+
|
229 |
+
class NormLayer(nn.Module):
|
230 |
+
def __init__(self, mean, std, n=None, eps=1e-8) -> None:
|
231 |
+
super().__init__()
|
232 |
+
self.mean = mean
|
233 |
+
self.std = std
|
234 |
+
self.eps = eps
|
235 |
+
|
236 |
+
def forward(self, x):
|
237 |
+
return (x - self.mean) / (self.std + self.eps)
|
238 |
+
|
239 |
+
|
240 |
+
def sklearn_train_linear_model(
|
241 |
+
model: LinearModel,
|
242 |
+
dataloader: DataLoader,
|
243 |
+
construct_kwargs: Dict[str, Any],
|
244 |
+
sklearn_trainer: str = "Lasso",
|
245 |
+
norm_input: bool = False,
|
246 |
+
**fit_kwargs,
|
247 |
+
):
|
248 |
+
r"""
|
249 |
+
Alternative method to train with sklearn. This does introduce some slight
|
250 |
+
overhead as we convert the tensors to numpy and then convert the resulting
|
251 |
+
trained model to a `LinearModel` object. However, this conversion
|
252 |
+
should be negligible.
|
253 |
+
|
254 |
+
Please note that this assumes:
|
255 |
+
|
256 |
+
0. You have sklearn and numpy installed
|
257 |
+
1. The dataset can fit into memory
|
258 |
+
|
259 |
+
Args:
|
260 |
+
model
|
261 |
+
The model to train.
|
262 |
+
dataloader
|
263 |
+
The data to use. This will be exhausted and converted to numpy
|
264 |
+
arrays. Therefore please do not feed an infinite dataloader.
|
265 |
+
norm_input
|
266 |
+
Whether or not to normalize the input
|
267 |
+
sklearn_trainer
|
268 |
+
The sklearn model to use to train the model. Please refer to
|
269 |
+
sklearn.linear_model for a list of modules to use.
|
270 |
+
construct_kwargs
|
271 |
+
Additional arguments provided to the `sklearn_trainer` constructor
|
272 |
+
fit_kwargs
|
273 |
+
Other arguments to send to `sklearn_trainer`'s `.fit` method
|
274 |
+
"""
|
275 |
+
from functools import reduce
|
276 |
+
|
277 |
+
try:
|
278 |
+
import numpy as np
|
279 |
+
except ImportError:
|
280 |
+
raise ValueError("numpy is not available. Please install numpy.")
|
281 |
+
|
282 |
+
try:
|
283 |
+
import sklearn
|
284 |
+
import sklearn.linear_model
|
285 |
+
import sklearn.svm
|
286 |
+
except ImportError:
|
287 |
+
raise ValueError("sklearn is not available. Please install sklearn >= 0.23")
|
288 |
+
|
289 |
+
if not sklearn.__version__ >= "0.23.0":
|
290 |
+
warnings.warn(
|
291 |
+
"Must have sklearn version 0.23.0 or higher to use "
|
292 |
+
"sample_weight in Lasso regression."
|
293 |
+
)
|
294 |
+
|
295 |
+
num_batches = 0
|
296 |
+
xs, ys, ws = [], [], []
|
297 |
+
for data in dataloader:
|
298 |
+
if len(data) == 3:
|
299 |
+
x, y, w = data
|
300 |
+
else:
|
301 |
+
assert len(data) == 2
|
302 |
+
x, y = data
|
303 |
+
w = None
|
304 |
+
|
305 |
+
xs.append(x.cpu().numpy())
|
306 |
+
ys.append(y.cpu().numpy())
|
307 |
+
if w is not None:
|
308 |
+
ws.append(w.cpu().numpy())
|
309 |
+
num_batches += 1
|
310 |
+
|
311 |
+
x = np.concatenate(xs, axis=0)
|
312 |
+
y = np.concatenate(ys, axis=0)
|
313 |
+
if len(ws) > 0:
|
314 |
+
w = np.concatenate(ws, axis=0)
|
315 |
+
else:
|
316 |
+
w = None
|
317 |
+
|
318 |
+
if norm_input:
|
319 |
+
mean, std = x.mean(0), x.std(0)
|
320 |
+
x -= mean
|
321 |
+
x /= std
|
322 |
+
|
323 |
+
t1 = time.time()
|
324 |
+
sklearn_model = reduce(
|
325 |
+
lambda val, el: getattr(val, el), [sklearn] + sklearn_trainer.split(".")
|
326 |
+
)(**construct_kwargs)
|
327 |
+
try:
|
328 |
+
sklearn_model.fit(x, y, sample_weight=w, **fit_kwargs)
|
329 |
+
except TypeError:
|
330 |
+
sklearn_model.fit(x, y, **fit_kwargs)
|
331 |
+
warnings.warn(
|
332 |
+
"Sample weight is not supported for the provided linear model!"
|
333 |
+
" Trained model without weighting inputs. For Lasso, please"
|
334 |
+
" upgrade sklearn to a version >= 0.23.0."
|
335 |
+
)
|
336 |
+
|
337 |
+
t2 = time.time()
|
338 |
+
|
339 |
+
# Convert weights to pytorch
|
340 |
+
classes = (
|
341 |
+
torch.IntTensor(sklearn_model.classes_)
|
342 |
+
if hasattr(sklearn_model, "classes_")
|
343 |
+
else None
|
344 |
+
)
|
345 |
+
|
346 |
+
# extract model device
|
347 |
+
device = model.device if hasattr(model, "device") else "cpu"
|
348 |
+
|
349 |
+
num_outputs = sklearn_model.coef_.shape[0] if sklearn_model.coef_.ndim > 1 else 1
|
350 |
+
weight_values = torch.FloatTensor(sklearn_model.coef_).to(device) # type: ignore
|
351 |
+
bias_values = torch.FloatTensor([sklearn_model.intercept_]).to( # type: ignore
|
352 |
+
device # type: ignore
|
353 |
+
) # type: ignore
|
354 |
+
model._construct_model_params(
|
355 |
+
norm_type=None,
|
356 |
+
weight_values=weight_values.view(num_outputs, -1),
|
357 |
+
bias_value=bias_values.squeeze().unsqueeze(0),
|
358 |
+
classes=classes,
|
359 |
+
)
|
360 |
+
|
361 |
+
if norm_input:
|
362 |
+
model.norm = NormLayer(mean, std)
|
363 |
+
|
364 |
+
return {"train_time": t2 - t1}
|
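A corresponding sketch for the pure-PyTorch path (not part of this commit; toy data): SGDLasso wires l2_loss plus an L1 penalty into sgd_train_linear_model.

import torch
from torch.utils.data import DataLoader, TensorDataset
from captum._utils.models.linear_model import SGDLasso

X = torch.randn(100, 5)
y = X @ torch.randn(5)
loader = DataLoader(TensorDataset(X, y), batch_size=32)

model = SGDLasso()
stats = model.fit(loader, max_epoch=50, initial_lr=0.1, alpha=0.01)
print(stats["train_loss"], stats["train_epoch"])  # running-window loss and epochs used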
captum/_utils/models/model.py
ADDED
@@ -0,0 +1,66 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
from abc import ABC, abstractmethod
|
4 |
+
from typing import Dict, Optional, Union
|
5 |
+
|
6 |
+
from captum._utils.typing import TensorOrTupleOfTensorsGeneric
|
7 |
+
from torch import Tensor
|
8 |
+
from torch.utils.data import DataLoader
|
9 |
+
|
10 |
+
|
11 |
+
class Model(ABC):
|
12 |
+
r"""
|
13 |
+
Abstract Class to describe the interface of a trainable model to be used
|
14 |
+
within the algorithms of captum.
|
15 |
+
|
16 |
+
Please note that this is an experimental feature.
|
17 |
+
"""
|
18 |
+
|
19 |
+
@abstractmethod
|
20 |
+
def fit(
|
21 |
+
self, train_data: DataLoader, **kwargs
|
22 |
+
) -> Optional[Dict[str, Union[int, float, Tensor]]]:
|
23 |
+
r"""
|
24 |
+
Override this method to actually train your model.
|
25 |
+
|
26 |
+
The specification of the dataloader will be supplied by the algorithm
|
27 |
+
you are using within captum. This will likely be a supervised learning
|
28 |
+
task, thus you should expect batched (x, y) pairs or (x, y, w) triples.
|
29 |
+
|
30 |
+
Args:
|
31 |
+
train_data (DataLoader):
|
32 |
+
The data to train on
|
33 |
+
|
34 |
+
Returns:
|
35 |
+
Optional statistics about training, e.g. iterations it took to
|
36 |
+
train, training loss, etc.
|
37 |
+
"""
|
38 |
+
pass
|
39 |
+
|
40 |
+
@abstractmethod
|
41 |
+
def representation(self) -> Tensor:
|
42 |
+
r"""
|
43 |
+
Returns the underlying representation of the interpretable model. For a
|
44 |
+
linear model this is simply a tensor (the concatenation of weights
|
45 |
+
and bias). For something slightly more complicated, such as a decision
|
46 |
+
tree, this could be the nodes of a decision tree.
|
47 |
+
|
48 |
+
Returns:
|
49 |
+
A Tensor describing the representation of the model.
|
50 |
+
"""
|
51 |
+
pass
|
52 |
+
|
53 |
+
@abstractmethod
|
54 |
+
def __call__(
|
55 |
+
self, x: TensorOrTupleOfTensorsGeneric
|
56 |
+
) -> TensorOrTupleOfTensorsGeneric:
|
57 |
+
r"""
|
58 |
+
Predicts with the interpretable model.
|
59 |
+
|
60 |
+
Args:
|
61 |
+
x (TensorOrTupleOfTensorsGeneric)
|
62 |
+
A batched input of tensor(s) to the model to predict
|
63 |
+
Returns:
|
64 |
+
The prediction of the input as a TensorOrTupleOfTensorsGeneric.
|
65 |
+
"""
|
66 |
+
pass
|
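The abstract interface above is what LinearModel implements; a toy sketch of a custom implementation (not part of this commit, names and behavior are made up for illustration):

import torch
from torch.utils.data import DataLoader
from captum._utils.models.model import Model

class MeanPredictor(Model):
    # Toy interpretable model: always predicts the mean of the training targets.
    def __init__(self) -> None:
        self.mean = torch.zeros(1)

    def fit(self, train_data: DataLoader, **kwargs):
        ys = torch.cat([y.float().view(-1) for _, y in train_data])
        self.mean = ys.mean().view(1)
        return {"train_loss": float(((ys - self.mean) ** 2).mean())}

    def representation(self) -> torch.Tensor:
        return self.mean

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return self.mean.expand(x.shape[0], 1)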
captum/_utils/progress.py
ADDED
@@ -0,0 +1,138 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import sys
|
4 |
+
import warnings
|
5 |
+
from time import time
|
6 |
+
from typing import cast, Iterable, Sized, TextIO
|
7 |
+
|
8 |
+
try:
|
9 |
+
from tqdm import tqdm
|
10 |
+
except ImportError:
|
11 |
+
tqdm = None
|
12 |
+
|
13 |
+
|
14 |
+
class DisableErrorIOWrapper(object):
|
15 |
+
def __init__(self, wrapped: TextIO):
|
16 |
+
"""
|
17 |
+
The wrapper around a TextIO object to ignore write errors like tqdm
|
18 |
+
https://github.com/tqdm/tqdm/blob/bcce20f771a16cb8e4ac5cc5b2307374a2c0e535/tqdm/utils.py#L131
|
19 |
+
"""
|
20 |
+
self._wrapped = wrapped
|
21 |
+
|
22 |
+
def __getattr__(self, name):
|
23 |
+
return getattr(self._wrapped, name)
|
24 |
+
|
25 |
+
@staticmethod
|
26 |
+
def _wrapped_run(func, *args, **kwargs):
|
27 |
+
try:
|
28 |
+
return func(*args, **kwargs)
|
29 |
+
except OSError as e:
|
30 |
+
if e.errno != 5:
|
31 |
+
raise
|
32 |
+
except ValueError as e:
|
33 |
+
if "closed" not in str(e):
|
34 |
+
raise
|
35 |
+
|
36 |
+
def write(self, *args, **kwargs):
|
37 |
+
return self._wrapped_run(self._wrapped.write, *args, **kwargs)
|
38 |
+
|
39 |
+
def flush(self, *args, **kwargs):
|
40 |
+
return self._wrapped_run(self._wrapped.flush, *args, **kwargs)
|
41 |
+
|
42 |
+
|
43 |
+
class SimpleProgress:
|
44 |
+
def __init__(
|
45 |
+
self,
|
46 |
+
iterable: Iterable = None,
|
47 |
+
desc: str = None,
|
48 |
+
total: int = None,
|
49 |
+
file: TextIO = None,
|
50 |
+
mininterval: float = 0.5,
|
51 |
+
):
|
52 |
+
"""
|
53 |
+
Simple progress output used when tqdm is unavailable.
|
54 |
+
Same as tqdm, output to stderr channel
|
55 |
+
"""
|
56 |
+
self.cur = 0
|
57 |
+
|
58 |
+
self.iterable = iterable
|
59 |
+
self.total = total
|
60 |
+
if total is None and hasattr(iterable, "__len__"):
|
61 |
+
self.total = len(cast(Sized, iterable))
|
62 |
+
|
63 |
+
self.desc = desc
|
64 |
+
|
65 |
+
file = DisableErrorIOWrapper(file if file else sys.stderr)
|
66 |
+
cast(TextIO, file)
|
67 |
+
self.file = file
|
68 |
+
|
69 |
+
self.mininterval = mininterval
|
70 |
+
self.last_print_t = 0.0
|
71 |
+
self.closed = False
|
72 |
+
|
73 |
+
def __iter__(self):
|
74 |
+
if self.closed or not self.iterable:
|
75 |
+
return
|
76 |
+
self._refresh()
|
77 |
+
for it in self.iterable:
|
78 |
+
yield it
|
79 |
+
self.update()
|
80 |
+
self.close()
|
81 |
+
|
82 |
+
def _refresh(self):
|
83 |
+
progress_str = self.desc + ": " if self.desc else ""
|
84 |
+
if self.total:
|
85 |
+
# e.g., progress: 60% 3/5
|
86 |
+
progress_str += f"{100 * self.cur // self.total}% {self.cur}/{self.total}"
|
87 |
+
else:
|
88 |
+
# e.g., progress: .....
|
89 |
+
progress_str += "." * self.cur
|
90 |
+
|
91 |
+
print("\r" + progress_str, end="", file=self.file)
|
92 |
+
|
93 |
+
def update(self, amount: int = 1):
|
94 |
+
if self.closed:
|
95 |
+
return
|
96 |
+
self.cur += amount
|
97 |
+
|
98 |
+
cur_t = time()
|
99 |
+
if cur_t - self.last_print_t >= self.mininterval:
|
100 |
+
self._refresh()
|
101 |
+
self.last_print_t = cur_t
|
102 |
+
|
103 |
+
def close(self):
|
104 |
+
if not self.closed:
|
105 |
+
self._refresh()
|
106 |
+
print(file=self.file) # end with new line
|
107 |
+
self.closed = True
|
108 |
+
|
109 |
+
|
110 |
+
def progress(
|
111 |
+
iterable: Iterable = None,
|
112 |
+
desc: str = None,
|
113 |
+
total: int = None,
|
114 |
+
use_tqdm=True,
|
115 |
+
file: TextIO = None,
|
116 |
+
mininterval: float = 0.5,
|
117 |
+
**kwargs,
|
118 |
+
):
|
119 |
+
# Try to use tqdm if possible. Fall back to simple progress print
|
120 |
+
if tqdm and use_tqdm:
|
121 |
+
return tqdm(
|
122 |
+
iterable,
|
123 |
+
desc=desc,
|
124 |
+
total=total,
|
125 |
+
file=file,
|
126 |
+
mininterval=mininterval,
|
127 |
+
**kwargs,
|
128 |
+
)
|
129 |
+
else:
|
130 |
+
if not tqdm and use_tqdm:
|
131 |
+
warnings.warn(
|
132 |
+
"Tried to show progress with tqdm "
|
133 |
+
"but tqdm is not installed. "
|
134 |
+
"Fall back to simply print out the progress."
|
135 |
+
)
|
136 |
+
return SimpleProgress(
|
137 |
+
iterable, desc=desc, total=total, file=file, mininterval=mininterval
|
138 |
+
)
|
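A sketch of the intended call pattern (not part of this commit): progress() wraps any iterable, preferring tqdm and falling back to SimpleProgress when tqdm is unavailable.

from captum._utils.progress import progress

total = 0
for step in progress(range(100), desc="evaluating", total=100):
    total += step  # stand-in for real per-batch work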
captum/_utils/sample_gradient.py
ADDED
@@ -0,0 +1,184 @@
1 |
+
from collections import defaultdict
|
2 |
+
from enum import Enum
|
3 |
+
from typing import cast, Iterable, Tuple, Union
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from captum._utils.common import _format_tensor_into_tuples, _register_backward_hook
|
7 |
+
from torch import Tensor
|
8 |
+
from torch.nn import Module
|
9 |
+
|
10 |
+
|
11 |
+
def _reset_sample_grads(module: Module):
|
12 |
+
module.weight.sample_grad = 0 # type: ignore
|
13 |
+
if module.bias is not None:
|
14 |
+
module.bias.sample_grad = 0 # type: ignore
|
15 |
+
|
16 |
+
|
17 |
+
def linear_param_grads(
|
18 |
+
module: Module, activation: Tensor, gradient_out: Tensor, reset: bool = False
|
19 |
+
) -> None:
|
20 |
+
r"""
|
21 |
+
Computes parameter gradients per sample for nn.Linear module, given module
|
22 |
+
input activations and output gradients.
|
23 |
+
|
24 |
+
Gradients are accumulated in the sample_grad attribute of each parameter
|
25 |
+
(weight and bias). If reset = True, any current sample_grad values are reset,
|
26 |
+
otherwise computed gradients are accumulated and added to the existing
|
27 |
+
stored gradients.
|
28 |
+
|
29 |
+
Inputs with more than 2 dimensions are only supported with torch 1.8 or later
|
30 |
+
"""
|
31 |
+
if reset:
|
32 |
+
_reset_sample_grads(module)
|
33 |
+
|
34 |
+
module.weight.sample_grad += torch.einsum( # type: ignore
|
35 |
+
"n...i,n...j->nij", gradient_out, activation
|
36 |
+
)
|
37 |
+
if module.bias is not None:
|
38 |
+
module.bias.sample_grad += torch.einsum( # type: ignore
|
39 |
+
"n...i->ni", gradient_out
|
40 |
+
)
|
41 |
+
|
42 |
+
|
43 |
+
def conv2d_param_grads(
|
44 |
+
module: Module, activation: Tensor, gradient_out: Tensor, reset: bool = False
|
45 |
+
) -> None:
|
46 |
+
r"""
|
47 |
+
Computes parameter gradients per sample for nn.Conv2d module, given module
|
48 |
+
input activations and output gradients.
|
49 |
+
|
50 |
+
nn.Conv2d modules with padding set to a string option ('same' or 'valid') are
|
51 |
+
currently unsupported.
|
52 |
+
|
53 |
+
Gradients are accumulated in the sample_grad attribute of each parameter
|
54 |
+
(weight and bias). If reset = True, any current sample_grad values are reset,
|
55 |
+
otherwise computed gradients are accumulated and added to the existing
|
56 |
+
stored gradients.
|
57 |
+
"""
|
58 |
+
if reset:
|
59 |
+
_reset_sample_grads(module)
|
60 |
+
|
61 |
+
batch_size = cast(int, activation.shape[0])
|
62 |
+
unfolded_act = torch.nn.functional.unfold(
|
63 |
+
activation,
|
64 |
+
cast(Union[int, Tuple[int, ...]], module.kernel_size),
|
65 |
+
dilation=cast(Union[int, Tuple[int, ...]], module.dilation),
|
66 |
+
padding=cast(Union[int, Tuple[int, ...]], module.padding),
|
67 |
+
stride=cast(Union[int, Tuple[int, ...]], module.stride),
|
68 |
+
)
|
69 |
+
reshaped_grad = gradient_out.reshape(batch_size, -1, unfolded_act.shape[-1])
|
70 |
+
grad1 = torch.einsum("ijk,ilk->ijl", reshaped_grad, unfolded_act)
|
71 |
+
shape = [batch_size] + list(cast(Iterable[int], module.weight.shape))
|
72 |
+
module.weight.sample_grad += grad1.reshape(shape) # type: ignore
|
73 |
+
if module.bias is not None:
|
74 |
+
module.bias.sample_grad += torch.sum(reshaped_grad, dim=2) # type: ignore
|
75 |
+
|
76 |
+
|
77 |
+
SUPPORTED_MODULES = {
|
78 |
+
torch.nn.Conv2d: conv2d_param_grads,
|
79 |
+
torch.nn.Linear: linear_param_grads,
|
80 |
+
}
|
81 |
+
|
82 |
+
|
83 |
+
class LossMode(Enum):
|
84 |
+
SUM = 0
|
85 |
+
MEAN = 1
|
86 |
+
|
87 |
+
|
88 |
+
class SampleGradientWrapper:
|
89 |
+
r"""
|
90 |
+
Wrapper which allows computing sample-wise gradients in a single backward pass.
|
91 |
+
|
92 |
+
This is accomplished by adding hooks to capture activations and output
|
93 |
+
gradients for supported modules, and using these activations and gradients
|
94 |
+
to compute the parameter gradients per-sample.
|
95 |
+
|
96 |
+
Currently, only nn.Linear and nn.Conv2d modules are supported.
|
97 |
+
|
98 |
+
Similar reference implementations of sample-based gradients include:
|
99 |
+
- https://github.com/cybertronai/autograd-hacks
|
100 |
+
- https://github.com/pytorch/opacus/tree/main/opacus/grad_sample
|
101 |
+
"""
|
102 |
+
|
103 |
+
def __init__(self, model):
|
104 |
+
self.model = model
|
105 |
+
self.hooks_added = False
|
106 |
+
self.activation_dict = defaultdict(list)
|
107 |
+
self.gradient_dict = defaultdict(list)
|
108 |
+
self.forward_hooks = []
|
109 |
+
self.backward_hooks = []
|
110 |
+
|
111 |
+
def add_hooks(self):
|
112 |
+
self.hooks_added = True
|
113 |
+
self.model.apply(self._register_module_hooks)
|
114 |
+
|
115 |
+
def _register_module_hooks(self, module: torch.nn.Module):
|
116 |
+
if isinstance(module, tuple(SUPPORTED_MODULES.keys())):
|
117 |
+
self.forward_hooks.append(
|
118 |
+
module.register_forward_hook(self._forward_hook_fn)
|
119 |
+
)
|
120 |
+
self.backward_hooks.append(
|
121 |
+
_register_backward_hook(module, self._backward_hook_fn, None)
|
122 |
+
)
|
123 |
+
|
124 |
+
def _forward_hook_fn(
|
125 |
+
self,
|
126 |
+
module: Module,
|
127 |
+
module_input: Union[Tensor, Tuple[Tensor, ...]],
|
128 |
+
module_output: Union[Tensor, Tuple[Tensor, ...]],
|
129 |
+
):
|
130 |
+
inp_tuple = _format_tensor_into_tuples(module_input)
|
131 |
+
self.activation_dict[module].append(inp_tuple[0].clone().detach())
|
132 |
+
|
133 |
+
def _backward_hook_fn(
|
134 |
+
self,
|
135 |
+
module: Module,
|
136 |
+
grad_input: Union[Tensor, Tuple[Tensor, ...]],
|
137 |
+
grad_output: Union[Tensor, Tuple[Tensor, ...]],
|
138 |
+
):
|
139 |
+
grad_output_tuple = _format_tensor_into_tuples(grad_output)
|
140 |
+
self.gradient_dict[module].append(grad_output_tuple[0].clone().detach())
|
141 |
+
|
142 |
+
def remove_hooks(self):
|
143 |
+
self.hooks_added = False
|
144 |
+
|
145 |
+
for hook in self.forward_hooks:
|
146 |
+
hook.remove()
|
147 |
+
|
148 |
+
for hook in self.backward_hooks:
|
149 |
+
hook.remove()
|
150 |
+
|
151 |
+
self.forward_hooks = []
|
152 |
+
self.backward_hooks = []
|
153 |
+
|
154 |
+
def _reset(self):
|
155 |
+
self.activation_dict = defaultdict(list)
|
156 |
+
self.gradient_dict = defaultdict(list)
|
157 |
+
|
158 |
+
def compute_param_sample_gradients(self, loss_blob, loss_mode="mean"):
|
159 |
+
assert (
|
160 |
+
loss_mode.upper() in LossMode.__members__
|
161 |
+
), f"Provided loss mode {loss_mode} is not valid"
|
162 |
+
mode = LossMode[loss_mode.upper()]
|
163 |
+
|
164 |
+
self.model.zero_grad()
|
165 |
+
loss_blob.backward(gradient=torch.ones_like(loss_blob))
|
166 |
+
|
167 |
+
for module in self.gradient_dict:
|
168 |
+
sample_grad_fn = SUPPORTED_MODULES[type(module)]
|
169 |
+
activations = self.activation_dict[module]
|
170 |
+
gradients = self.gradient_dict[module]
|
171 |
+
assert len(activations) == len(gradients), (
|
172 |
+
"Number of saved activations do not match number of saved gradients."
|
173 |
+
" This may occur if multiple forward passes are run without calling"
|
174 |
+
" reset or computing param gradients."
|
175 |
+
)
|
176 |
+
# Reversing grads since when a module is used multiple times,
|
177 |
+
# the activations will be aligned with the reverse order of the gradients,
|
178 |
+
# since the order is reversed in backprop.
|
179 |
+
for i, (act, grad) in enumerate(
|
180 |
+
zip(activations, list(reversed(gradients)))
|
181 |
+
):
|
182 |
+
mult = 1 if mode is LossMode.SUM else act.shape[0]
|
183 |
+
sample_grad_fn(module, act, grad * mult, reset=(i == 0))
|
184 |
+
self._reset()
|
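A sketch of the wrapper's lifecycle (not part of this commit; toy network), mirroring what _compute_jacobian_wrt_params_with_sample_wise_trick does internally:

import torch
import torch.nn as nn
from captum._utils.sample_gradient import SampleGradientWrapper

model = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 1))
wrapper = SampleGradientWrapper(model)
wrapper.add_hooks()
try:
    out = model(torch.randn(8, 3))
    loss = (out ** 2).sum()
    wrapper.compute_param_sample_gradients(loss, loss_mode="sum")
    # each supported parameter now carries per-sample gradients with a leading batch dim
    print(model[0].weight.sample_grad.shape)  # torch.Size([8, 4, 3])
finally:
    wrapper.remove_hooks()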
captum/_utils/typing.py
ADDED
@@ -0,0 +1,37 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
from typing import List, Tuple, TYPE_CHECKING, TypeVar, Union
|
4 |
+
|
5 |
+
from torch import Tensor
|
6 |
+
from torch.nn import Module
|
7 |
+
|
8 |
+
if TYPE_CHECKING:
|
9 |
+
import sys
|
10 |
+
|
11 |
+
if sys.version_info >= (3, 8):
|
12 |
+
from typing import Literal # noqa: F401
|
13 |
+
else:
|
14 |
+
from typing_extensions import Literal # noqa: F401
|
15 |
+
else:
|
16 |
+
Literal = {True: bool, False: bool, (True, False): bool}
|
17 |
+
|
18 |
+
TensorOrTupleOfTensorsGeneric = TypeVar(
|
19 |
+
"TensorOrTupleOfTensorsGeneric", Tensor, Tuple[Tensor, ...]
|
20 |
+
)
|
21 |
+
TupleOrTensorOrBoolGeneric = TypeVar("TupleOrTensorOrBoolGeneric", Tuple, Tensor, bool)
|
22 |
+
ModuleOrModuleList = TypeVar("ModuleOrModuleList", Module, List[Module])
|
23 |
+
TargetType = Union[None, int, Tuple[int, ...], Tensor, List[Tuple[int, ...]], List[int]]
|
24 |
+
BaselineType = Union[None, Tensor, int, float, Tuple[Union[Tensor, int, float], ...]]
|
25 |
+
|
26 |
+
TensorLikeList1D = List[float]
|
27 |
+
TensorLikeList2D = List[TensorLikeList1D]
|
28 |
+
TensorLikeList3D = List[TensorLikeList2D]
|
29 |
+
TensorLikeList4D = List[TensorLikeList3D]
|
30 |
+
TensorLikeList5D = List[TensorLikeList4D]
|
31 |
+
TensorLikeList = Union[
|
32 |
+
TensorLikeList1D,
|
33 |
+
TensorLikeList2D,
|
34 |
+
TensorLikeList3D,
|
35 |
+
TensorLikeList4D,
|
36 |
+
TensorLikeList5D,
|
37 |
+
]
|
captum/attr/__init__.py
ADDED
@@ -0,0 +1,143 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
from captum.attr._core.deep_lift import DeepLift, DeepLiftShap # noqa
|
3 |
+
from captum.attr._core.feature_ablation import FeatureAblation # noqa
|
4 |
+
from captum.attr._core.feature_permutation import FeaturePermutation # noqa
|
5 |
+
from captum.attr._core.gradient_shap import GradientShap # noqa
|
6 |
+
from captum.attr._core.guided_backprop_deconvnet import ( # noqa
|
7 |
+
Deconvolution,
|
8 |
+
GuidedBackprop,
|
9 |
+
)
|
10 |
+
from captum.attr._core.guided_grad_cam import GuidedGradCam # noqa
|
11 |
+
from captum.attr._core.input_x_gradient import InputXGradient # noqa
|
12 |
+
from captum.attr._core.integrated_gradients import IntegratedGradients # noqa
|
13 |
+
from captum.attr._core.kernel_shap import KernelShap # noqa
|
14 |
+
from captum.attr._core.layer.grad_cam import LayerGradCam # noqa
|
15 |
+
from captum.attr._core.layer.internal_influence import InternalInfluence # noqa
|
16 |
+
from captum.attr._core.layer.layer_activation import LayerActivation # noqa
|
17 |
+
from captum.attr._core.layer.layer_conductance import LayerConductance # noqa
|
18 |
+
from captum.attr._core.layer.layer_deep_lift import ( # noqa
|
19 |
+
LayerDeepLift,
|
20 |
+
LayerDeepLiftShap,
|
21 |
+
)
|
22 |
+
from captum.attr._core.layer.layer_feature_ablation import LayerFeatureAblation  # noqa
from captum.attr._core.layer.layer_gradient_shap import LayerGradientShap  # noqa
from captum.attr._core.layer.layer_gradient_x_activation import (  # noqa
    LayerGradientXActivation,
)
from captum.attr._core.layer.layer_integrated_gradients import (  # noqa
    LayerIntegratedGradients,
)
from captum.attr._core.layer.layer_lrp import LayerLRP  # noqa
from captum.attr._core.lime import Lime, LimeBase  # noqa
from captum.attr._core.lrp import LRP  # noqa
from captum.attr._core.neuron.neuron_conductance import NeuronConductance  # noqa
from captum.attr._core.neuron.neuron_deep_lift import (  # noqa
    NeuronDeepLift,
    NeuronDeepLiftShap,
)
from captum.attr._core.neuron.neuron_feature_ablation import (  # noqa
    NeuronFeatureAblation,
)
from captum.attr._core.neuron.neuron_gradient import NeuronGradient  # noqa
from captum.attr._core.neuron.neuron_gradient_shap import NeuronGradientShap  # noqa
from captum.attr._core.neuron.neuron_guided_backprop_deconvnet import (  # noqa
    NeuronDeconvolution,
    NeuronGuidedBackprop,
)
from captum.attr._core.neuron.neuron_integrated_gradients import (  # noqa
    NeuronIntegratedGradients,
)
from captum.attr._core.noise_tunnel import NoiseTunnel  # noqa
from captum.attr._core.occlusion import Occlusion  # noqa
from captum.attr._core.saliency import Saliency  # noqa
from captum.attr._core.shapley_value import ShapleyValues, ShapleyValueSampling  # noqa
from captum.attr._models.base import (  # noqa
    configure_interpretable_embedding_layer,
    InterpretableEmbeddingBase,
    remove_interpretable_embedding_layer,
    TokenReferenceBase,
)
from captum.attr._utils import visualization  # noqa
from captum.attr._utils.attribution import (  # noqa
    Attribution,
    GradientAttribution,
    LayerAttribution,
    NeuronAttribution,
    PerturbationAttribution,
)
from captum.attr._utils.class_summarizer import ClassSummarizer
from captum.attr._utils.stat import (
    CommonStats,
    Count,
    Max,
    Mean,
    Min,
    MSE,
    StdDev,
    Sum,
    Var,
)
from captum.attr._utils.summarizer import Summarizer

__all__ = [
    "Attribution",
    "GradientAttribution",
    "PerturbationAttribution",
    "NeuronAttribution",
    "LayerAttribution",
    "IntegratedGradients",
    "DeepLift",
    "DeepLiftShap",
    "InputXGradient",
    "Saliency",
    "GuidedBackprop",
    "Deconvolution",
    "GuidedGradCam",
    "FeatureAblation",
    "FeaturePermutation",
    "Occlusion",
    "ShapleyValueSampling",
    "ShapleyValues",
    "LimeBase",
    "Lime",
    "LRP",
    "KernelShap",
    "LayerConductance",
    "LayerGradientXActivation",
    "LayerActivation",
    "LayerFeatureAblation",
    "InternalInfluence",
    "LayerGradCam",
    "LayerDeepLift",
    "LayerDeepLiftShap",
    "LayerGradientShap",
    "LayerIntegratedGradients",
    "LayerLRP",
    "NeuronConductance",
    "NeuronFeatureAblation",
    "NeuronGradient",
    "NeuronIntegratedGradients",
    "NeuronDeepLift",
    "NeuronDeepLiftShap",
    "NeuronGradientShap",
    "NeuronDeconvolution",
    "NeuronGuidedBackprop",
    "NoiseTunnel",
    "GradientShap",
    "InterpretableEmbeddingBase",
    "TokenReferenceBase",
    "visualization",
    "configure_interpretable_embedding_layer",
    "remove_interpretable_embedding_layer",
    "Summarizer",
    "CommonStats",
    "ClassSummarizer",
    "Mean",
    "StdDev",
    "MSE",
    "Var",
    "Min",
    "Max",
    "Sum",
    "Count",
]
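
The block above only re-exports captum's attribution classes under the `captum.attr` namespace. As a quick, illustrative sketch of how this vendored copy is meant to be called (the toy `nn.Sequential` model, tensor shapes, and target index below are made-up assumptions, not part of this commit):

import torch
import torch.nn as nn
from captum.attr import DeepLift, IntegratedGradients

# Toy two-layer classifier, used only for illustration.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2)).eval()

inputs = torch.randn(4, 8, requires_grad=True)   # batch of 4 examples
baselines = torch.zeros_like(inputs)             # zero reference points

ig_attr = IntegratedGradients(model).attribute(inputs, baselines=baselines, target=1)
dl_attr, delta = DeepLift(model).attribute(
    inputs, baselines=baselines, target=1, return_convergence_delta=True
)
print(ig_attr.shape, dl_attr.shape, delta.shape)  # (4, 8), (4, 8), and a per-example delta
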
captum/attr/_core/__init__.py
ADDED
File without changes
captum/attr/_core/deep_lift.py
ADDED
@@ -0,0 +1,1151 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
import typing
|
3 |
+
import warnings
|
4 |
+
from typing import Any, Callable, cast, List, Tuple, Union
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from captum._utils.common import (
|
10 |
+
_expand_additional_forward_args,
|
11 |
+
_expand_target,
|
12 |
+
_format_additional_forward_args,
|
13 |
+
_format_baseline,
|
14 |
+
_format_output,
|
15 |
+
_format_tensor_into_tuples,
|
16 |
+
_is_tuple,
|
17 |
+
_register_backward_hook,
|
18 |
+
_run_forward,
|
19 |
+
_select_targets,
|
20 |
+
ExpansionTypes,
|
21 |
+
)
|
22 |
+
from captum._utils.gradient import (
|
23 |
+
apply_gradient_requirements,
|
24 |
+
undo_gradient_requirements,
|
25 |
+
)
|
26 |
+
from captum._utils.typing import (
|
27 |
+
BaselineType,
|
28 |
+
Literal,
|
29 |
+
TargetType,
|
30 |
+
TensorOrTupleOfTensorsGeneric,
|
31 |
+
)
|
32 |
+
from captum.attr._utils.attribution import GradientAttribution
|
33 |
+
from captum.attr._utils.common import (
|
34 |
+
_call_custom_attribution_func,
|
35 |
+
_compute_conv_delta_and_format_attrs,
|
36 |
+
_format_callable_baseline,
|
37 |
+
_tensorize_baseline,
|
38 |
+
_validate_input,
|
39 |
+
)
|
40 |
+
from captum.log import log_usage
|
41 |
+
from torch import Tensor
|
42 |
+
from torch.nn import Module
|
43 |
+
from torch.utils.hooks import RemovableHandle
|
44 |
+
|
45 |
+
|
46 |
+
# Check if module backward hook can safely be used for the module that produced
|
47 |
+
# this inputs / outputs mapping
|
48 |
+
def _check_valid_module(inputs_grad_fn, outputs) -> bool:
|
49 |
+
def is_output_cloned(output_fn, input_grad_fn) -> bool:
|
50 |
+
"""
|
51 |
+
Checks if the output has been cloned. This happens especially in case of
|
52 |
+
layer deeplift.
|
53 |
+
"""
|
54 |
+
return (
|
55 |
+
output_fn[0].next_functions is not None
|
56 |
+
and output_fn[0].next_functions[0][0] == input_grad_fn
|
57 |
+
)
|
58 |
+
|
59 |
+
curr_fn = outputs.grad_fn
|
60 |
+
first_next = curr_fn.next_functions[0]
|
61 |
+
try:
|
62 |
+
# if `inputs` in the input to the network then the grad_fn is None and
|
63 |
+
# for that input backward_hook isn't computed. That's the reason why we
|
64 |
+
# need to check on `inputs_grad_fns[first_next[1]]` being None.
|
65 |
+
return (
|
66 |
+
inputs_grad_fn is None
|
67 |
+
or first_next[0] == inputs_grad_fn
|
68 |
+
or is_output_cloned(first_next, inputs_grad_fn)
|
69 |
+
)
|
70 |
+
except IndexError:
|
71 |
+
return False
|
72 |
+
|
73 |
+
|
74 |
+
class DeepLift(GradientAttribution):
|
75 |
+
r"""
|
76 |
+
Implements DeepLIFT algorithm based on the following paper:
|
77 |
+
Learning Important Features Through Propagating Activation Differences,
|
78 |
+
Avanti Shrikumar, et. al.
|
79 |
+
https://arxiv.org/abs/1704.02685
|
80 |
+
|
81 |
+
and the gradient formulation proposed in:
|
82 |
+
Towards better understanding of gradient-based attribution methods for
|
83 |
+
deep neural networks, Marco Ancona, et.al.
|
84 |
+
https://openreview.net/pdf?id=Sy21R9JAW
|
85 |
+
|
86 |
+
This implementation supports only Rescale rule. RevealCancel rule will
|
87 |
+
be supported in later releases.
|
88 |
+
In addition to that, in order to keep the implementation cleaner, DeepLIFT
|
89 |
+
for internal neurons and layers extends current implementation and is
|
90 |
+
implemented separately in LayerDeepLift and NeuronDeepLift.
|
91 |
+
Although DeepLIFT's (Rescale Rule) attribution quality is comparable with
|
92 |
+
Integrated Gradients, it runs significantly faster than Integrated
|
93 |
+
Gradients and is preferred for large datasets.
|
94 |
+
|
95 |
+
Currently we only support a limited number of non-linear activations
|
96 |
+
but the plan is to expand the list in the future.
|
97 |
+
|
98 |
+
Note: Currently we cannot access the building blocks
|
99 |
+
of PyTorch's built-in LSTM, RNNs and GRUs such as Tanh and Sigmoid.
|
100 |
+
Nonetheless, it is possible to build custom LSTMs, RNNs and GRUs
|
101 |
+
with performance similar to built-in ones using TorchScript.
|
102 |
+
More details on how to build custom RNNs can be found here:
|
103 |
+
https://pytorch.org/blog/optimizing-cuda-rnn-with-torchscript/
|
104 |
+
"""
|
105 |
+
|
106 |
+
def __init__(
|
107 |
+
self,
|
108 |
+
model: Module,
|
109 |
+
multiply_by_inputs: bool = True,
|
110 |
+
eps: float = 1e-10,
|
111 |
+
) -> None:
|
112 |
+
r"""
|
113 |
+
Args:
|
114 |
+
|
115 |
+
model (nn.Module): The reference to PyTorch model instance. Model cannot
|
116 |
+
contain any in-place nonlinear submodules; these are not
|
117 |
+
supported by the register_full_backward_hook PyTorch API
|
118 |
+
starting from PyTorch v1.9.
|
119 |
+
multiply_by_inputs (bool, optional): Indicates whether to factor
|
120 |
+
model inputs' multiplier in the final attribution scores.
|
121 |
+
In the literature this is also known as local vs global
|
122 |
+
attribution. If inputs' multiplier isn't factored in
|
123 |
+
then that type of attribution method is also called local
|
124 |
+
attribution. If it is, then that type of attribution
|
125 |
+
method is called global.
|
126 |
+
More details can be found here:
|
127 |
+
https://arxiv.org/abs/1711.06104
|
128 |
+
|
129 |
+
In case of DeepLift, if `multiply_by_inputs`
|
130 |
+
is set to True, final sensitivity scores
|
131 |
+
are being multiplied by (inputs - baselines).
|
132 |
+
This flag applies only if `custom_attribution_func` is
|
133 |
+
set to None.
|
134 |
+
|
135 |
+
eps (float, optional): A value at which to consider output/input change
|
136 |
+
significant when computing the gradients for non-linear layers.
|
137 |
+
This is useful to adjust, depending on your model's bit depth,
|
138 |
+
to avoid numerical issues during the gradient computation.
|
139 |
+
Default: 1e-10
|
140 |
+
"""
|
141 |
+
GradientAttribution.__init__(self, model)
|
142 |
+
self.model = model
|
143 |
+
self.eps = eps
|
144 |
+
self.forward_handles: List[RemovableHandle] = []
|
145 |
+
self.backward_handles: List[RemovableHandle] = []
|
146 |
+
self._multiply_by_inputs = multiply_by_inputs
|
147 |
+
|
148 |
+
@typing.overload
|
149 |
+
def attribute(
|
150 |
+
self,
|
151 |
+
inputs: TensorOrTupleOfTensorsGeneric,
|
152 |
+
baselines: BaselineType = None,
|
153 |
+
target: TargetType = None,
|
154 |
+
additional_forward_args: Any = None,
|
155 |
+
return_convergence_delta: Literal[False] = False,
|
156 |
+
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
|
157 |
+
) -> TensorOrTupleOfTensorsGeneric:
|
158 |
+
...
|
159 |
+
|
160 |
+
@typing.overload
|
161 |
+
def attribute(
|
162 |
+
self,
|
163 |
+
inputs: TensorOrTupleOfTensorsGeneric,
|
164 |
+
baselines: BaselineType = None,
|
165 |
+
target: TargetType = None,
|
166 |
+
additional_forward_args: Any = None,
|
167 |
+
*,
|
168 |
+
return_convergence_delta: Literal[True],
|
169 |
+
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
|
170 |
+
) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]:
|
171 |
+
...
|
172 |
+
|
173 |
+
@log_usage()
|
174 |
+
def attribute( # type: ignore
|
175 |
+
self,
|
176 |
+
inputs: TensorOrTupleOfTensorsGeneric,
|
177 |
+
baselines: BaselineType = None,
|
178 |
+
target: TargetType = None,
|
179 |
+
additional_forward_args: Any = None,
|
180 |
+
return_convergence_delta: bool = False,
|
181 |
+
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
|
182 |
+
) -> Union[
|
183 |
+
TensorOrTupleOfTensorsGeneric, Tuple[TensorOrTupleOfTensorsGeneric, Tensor]
|
184 |
+
]:
|
185 |
+
r"""
|
186 |
+
Args:
|
187 |
+
|
188 |
+
inputs (tensor or tuple of tensors): Input for which
|
189 |
+
attributions are computed. If forward_func takes a single
|
190 |
+
tensor as input, a single input tensor should be provided.
|
191 |
+
If forward_func takes multiple tensors as input, a tuple
|
192 |
+
of the input tensors should be provided. It is assumed
|
193 |
+
that for all given input tensors, dimension 0 corresponds
|
194 |
+
to the number of examples (aka batch size), and if
|
195 |
+
multiple input tensors are provided, the examples must
|
196 |
+
be aligned appropriately.
|
197 |
+
baselines (scalar, tensor, tuple of scalars or tensors, optional):
|
198 |
+
Baselines define reference samples that are compared with
|
199 |
+
the inputs. In order to assign attribution scores DeepLift
|
200 |
+
computes the differences between the inputs/outputs and
|
201 |
+
corresponding references.
|
202 |
+
Baselines can be provided as:
|
203 |
+
|
204 |
+
- a single tensor, if inputs is a single tensor, with
|
205 |
+
exactly the same dimensions as inputs or the first
|
206 |
+
dimension is one and the remaining dimensions match
|
207 |
+
with inputs.
|
208 |
+
|
209 |
+
- a single scalar, if inputs is a single tensor, which will
|
210 |
+
be broadcasted for each input value in input tensor.
|
211 |
+
|
212 |
+
- a tuple of tensors or scalars, the baseline corresponding
|
213 |
+
to each tensor in the inputs' tuple can be:
|
214 |
+
|
215 |
+
- either a tensor with matching dimensions to
|
216 |
+
corresponding tensor in the inputs' tuple
|
217 |
+
or the first dimension is one and the remaining
|
218 |
+
dimensions match with the corresponding
|
219 |
+
input tensor.
|
220 |
+
|
221 |
+
- or a scalar, corresponding to a tensor in the
|
222 |
+
inputs' tuple. This scalar value is broadcasted
|
223 |
+
for corresponding input tensor.
|
224 |
+
|
225 |
+
In the cases when `baselines` is not provided, we internally
|
226 |
+
use zero scalar corresponding to each input tensor.
|
227 |
+
|
228 |
+
Default: None
|
229 |
+
target (int, tuple, tensor or list, optional): Output indices for
|
230 |
+
which gradients are computed (for classification cases,
|
231 |
+
this is usually the target class).
|
232 |
+
If the network returns a scalar value per example,
|
233 |
+
no target index is necessary.
|
234 |
+
For general 2D outputs, targets can be either:
|
235 |
+
|
236 |
+
- a single integer or a tensor containing a single
|
237 |
+
integer, which is applied to all input examples
|
238 |
+
|
239 |
+
- a list of integers or a 1D tensor, with length matching
|
240 |
+
the number of examples in inputs (dim 0). Each integer
|
241 |
+
is applied as the target for the corresponding example.
|
242 |
+
|
243 |
+
For outputs with > 2 dimensions, targets can be either:
|
244 |
+
|
245 |
+
- A single tuple, which contains #output_dims - 1
|
246 |
+
elements. This target index is applied to all examples.
|
247 |
+
|
248 |
+
- A list of tuples with length equal to the number of
|
249 |
+
examples in inputs (dim 0), and each tuple containing
|
250 |
+
#output_dims - 1 elements. Each tuple is applied as the
|
251 |
+
target for the corresponding example.
|
252 |
+
|
253 |
+
Default: None
|
254 |
+
additional_forward_args (any, optional): If the forward function
|
255 |
+
requires additional arguments other than the inputs for
|
256 |
+
which attributions should not be computed, this argument
|
257 |
+
can be provided. It must be either a single additional
|
258 |
+
argument of a Tensor or arbitrary (non-tuple) type or a tuple
|
259 |
+
containing multiple additional arguments including tensors
|
260 |
+
or any arbitrary python types. These arguments are provided to
|
261 |
+
forward_func in order, following the arguments in inputs.
|
262 |
+
Note that attributions are not computed with respect
|
263 |
+
to these arguments.
|
264 |
+
Default: None
|
265 |
+
return_convergence_delta (bool, optional): Indicates whether to return
|
266 |
+
convergence delta or not. If `return_convergence_delta`
|
267 |
+
is set to True convergence delta will be returned in
|
268 |
+
a tuple following attributions.
|
269 |
+
Default: False
|
270 |
+
custom_attribution_func (callable, optional): A custom function for
|
271 |
+
computing final attribution scores. This function can take
|
272 |
+
at least one and at most three arguments with the
|
273 |
+
following signature:
|
274 |
+
|
275 |
+
- custom_attribution_func(multipliers)
|
276 |
+
- custom_attribution_func(multipliers, inputs)
|
277 |
+
- custom_attribution_func(multipliers, inputs, baselines)
|
278 |
+
|
279 |
+
In case this function is not provided, we use the default
|
280 |
+
logic defined as: multipliers * (inputs - baselines)
|
281 |
+
It is assumed that all input arguments, `multipliers`,
|
282 |
+
`inputs` and `baselines` are provided in tuples of same
|
283 |
+
length. `custom_attribution_func` returns a tuple of
|
284 |
+
attribution tensors that have the same length as the
|
285 |
+
`inputs`.
|
286 |
+
|
287 |
+
Default: None
|
288 |
+
|
289 |
+
Returns:
|
290 |
+
**attributions** or 2-element tuple of **attributions**, **delta**:
|
291 |
+
- **attributions** (*tensor* or tuple of *tensors*):
|
292 |
+
Attribution score computed based on DeepLift rescale rule with respect
|
293 |
+
to each input feature. Attributions will always be
|
294 |
+
the same size as the provided inputs, with each value
|
295 |
+
providing the attribution of the corresponding input index.
|
296 |
+
If a single tensor is provided as inputs, a single tensor is
|
297 |
+
returned. If a tuple is provided for inputs, a tuple of
|
298 |
+
corresponding sized tensors is returned.
|
299 |
+
- **delta** (*tensor*, returned if return_convergence_delta=True):
|
300 |
+
This is computed using the property that
|
301 |
+
the total sum of forward_func(inputs) - forward_func(baselines)
|
302 |
+
must equal the total sum of the attributions computed
|
303 |
+
based on DeepLift's rescale rule.
|
304 |
+
Delta is calculated per example, meaning that the number of
|
305 |
+
elements in returned delta tensor is equal to the number of
|
306 |
+
examples in input.
|
307 |
+
Note that the logic described for deltas is guaranteed when the
|
308 |
+
default logic for attribution computations is used, meaning that the
|
309 |
+
`custom_attribution_func=None`, otherwise it is not guaranteed and
|
310 |
+
depends on the specifics of the `custom_attribution_func`.
|
311 |
+
|
312 |
+
Examples::
|
313 |
+
|
314 |
+
>>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
|
315 |
+
>>> # and returns an Nx10 tensor of class probabilities.
|
316 |
+
>>> net = ImageClassifier()
|
317 |
+
>>> dl = DeepLift(net)
|
318 |
+
>>> input = torch.randn(2, 3, 32, 32, requires_grad=True)
|
319 |
+
>>> # Computes deeplift attribution scores for class 3.
|
320 |
+
>>> attribution = dl.attribute(input, target=3)
|
321 |
+
"""
|
322 |
+
|
323 |
+
# Keeps track whether original input is a tuple or not before
|
324 |
+
# converting it into a tuple.
|
325 |
+
is_inputs_tuple = _is_tuple(inputs)
|
326 |
+
|
327 |
+
inputs = _format_tensor_into_tuples(inputs)
|
328 |
+
baselines = _format_baseline(baselines, inputs)
|
329 |
+
|
330 |
+
gradient_mask = apply_gradient_requirements(inputs)
|
331 |
+
|
332 |
+
_validate_input(inputs, baselines)
|
333 |
+
|
334 |
+
# set hooks for baselines
|
335 |
+
warnings.warn(
|
336 |
+
"""Setting forward, backward hooks and attributes on non-linear
|
337 |
+
activations. The hooks and attributes will be removed
|
338 |
+
after the attribution is finished"""
|
339 |
+
)
|
340 |
+
baselines = _tensorize_baseline(inputs, baselines)
|
341 |
+
main_model_hooks = []
|
342 |
+
try:
|
343 |
+
main_model_hooks = self._hook_main_model()
|
344 |
+
|
345 |
+
self.model.apply(self._register_hooks)
|
346 |
+
|
347 |
+
additional_forward_args = _format_additional_forward_args(
|
348 |
+
additional_forward_args
|
349 |
+
)
|
350 |
+
|
351 |
+
expanded_target = _expand_target(
|
352 |
+
target, 2, expansion_type=ExpansionTypes.repeat
|
353 |
+
)
|
354 |
+
|
355 |
+
wrapped_forward_func = self._construct_forward_func(
|
356 |
+
self.model,
|
357 |
+
(inputs, baselines),
|
358 |
+
expanded_target,
|
359 |
+
additional_forward_args,
|
360 |
+
)
|
361 |
+
gradients = self.gradient_func(wrapped_forward_func, inputs)
|
362 |
+
if custom_attribution_func is None:
|
363 |
+
if self.multiplies_by_inputs:
|
364 |
+
attributions = tuple(
|
365 |
+
(input - baseline) * gradient
|
366 |
+
for input, baseline, gradient in zip(
|
367 |
+
inputs, baselines, gradients
|
368 |
+
)
|
369 |
+
)
|
370 |
+
else:
|
371 |
+
attributions = gradients
|
372 |
+
else:
|
373 |
+
attributions = _call_custom_attribution_func(
|
374 |
+
custom_attribution_func, gradients, inputs, baselines
|
375 |
+
)
|
376 |
+
finally:
|
377 |
+
# Even if any error is raised, remove all hooks before raising
|
378 |
+
self._remove_hooks(main_model_hooks)
|
379 |
+
|
380 |
+
undo_gradient_requirements(inputs, gradient_mask)
|
381 |
+
return _compute_conv_delta_and_format_attrs(
|
382 |
+
self,
|
383 |
+
return_convergence_delta,
|
384 |
+
attributions,
|
385 |
+
baselines,
|
386 |
+
inputs,
|
387 |
+
additional_forward_args,
|
388 |
+
target,
|
389 |
+
is_inputs_tuple,
|
390 |
+
)
|
391 |
+
|
392 |
+
def _construct_forward_func(
|
393 |
+
self,
|
394 |
+
forward_func: Callable,
|
395 |
+
inputs: Tuple,
|
396 |
+
target: TargetType = None,
|
397 |
+
additional_forward_args: Any = None,
|
398 |
+
) -> Callable:
|
399 |
+
def forward_fn():
|
400 |
+
model_out = _run_forward(
|
401 |
+
forward_func, inputs, None, additional_forward_args
|
402 |
+
)
|
403 |
+
return _select_targets(
|
404 |
+
torch.cat((model_out[:, 0], model_out[:, 1])), target
|
405 |
+
)
|
406 |
+
|
407 |
+
if hasattr(forward_func, "device_ids"):
|
408 |
+
forward_fn.device_ids = forward_func.device_ids # type: ignore
|
409 |
+
return forward_fn
|
410 |
+
|
411 |
+
def _is_non_linear(self, module: Module) -> bool:
|
412 |
+
return type(module) in SUPPORTED_NON_LINEAR.keys()
|
413 |
+
|
414 |
+
def _forward_pre_hook_ref(
|
415 |
+
self, module: Module, inputs: Union[Tensor, Tuple[Tensor, ...]]
|
416 |
+
) -> None:
|
417 |
+
inputs = _format_tensor_into_tuples(inputs)
|
418 |
+
module.input_ref = tuple( # type: ignore
|
419 |
+
input.clone().detach() for input in inputs
|
420 |
+
)
|
421 |
+
|
422 |
+
def _forward_pre_hook(
|
423 |
+
self, module: Module, inputs: Union[Tensor, Tuple[Tensor, ...]]
|
424 |
+
) -> None:
|
425 |
+
"""
|
426 |
+
For the modules that perform in-place operations such as ReLUs, we cannot
|
427 |
+
use inputs from forward hooks. This is because in that case inputs
|
428 |
+
and outputs are the same. We need to access the inputs in pre-hooks and
|
429 |
+
set necessary hooks on inputs there.
|
430 |
+
"""
|
431 |
+
inputs = _format_tensor_into_tuples(inputs)
|
432 |
+
module.input = inputs[0].clone().detach()
|
433 |
+
module.input_grad_fns = inputs[0].grad_fn # type: ignore
|
434 |
+
|
435 |
+
def tensor_backward_hook(grad):
|
436 |
+
if module.saved_grad is None:
|
437 |
+
raise RuntimeError(
|
438 |
+
"""Module {} was detected as not supporting correctly module
|
439 |
+
backward hook. You should modify your hook to ignore the given
|
440 |
+
grad_inputs (recompute them by hand if needed) and save the
|
441 |
+
newly computed grad_inputs in module.saved_grad. See MaxPool1d
|
442 |
+
as an example.""".format(
|
443 |
+
module
|
444 |
+
)
|
445 |
+
)
|
446 |
+
return module.saved_grad
|
447 |
+
|
448 |
+
# the hook is set by default but it will be used only for
|
449 |
+
# failure cases and will be removed otherwise
|
450 |
+
handle = inputs[0].register_hook(tensor_backward_hook)
|
451 |
+
module.input_hook = handle
|
452 |
+
|
453 |
+
def _forward_hook(
|
454 |
+
self,
|
455 |
+
module: Module,
|
456 |
+
inputs: Union[Tensor, Tuple[Tensor, ...]],
|
457 |
+
outputs: Union[Tensor, Tuple[Tensor, ...]],
|
458 |
+
) -> None:
|
459 |
+
r"""
|
460 |
+
we need a forward hook to access and detach the inputs and
|
461 |
+
outputs of a neuron
|
462 |
+
"""
|
463 |
+
outputs = _format_tensor_into_tuples(outputs)
|
464 |
+
module.output = outputs[0].clone().detach()
|
465 |
+
if not _check_valid_module(module.input_grad_fns, outputs[0]):
|
466 |
+
warnings.warn(
|
467 |
+
"""An invalid module {} is detected. Saved gradients will
|
468 |
+
be used as the gradients of the module's input tensor.
|
469 |
+
See MaxPool1d as an example.""".format(
|
470 |
+
module
|
471 |
+
)
|
472 |
+
)
|
473 |
+
module.is_invalid = True # type: ignore
|
474 |
+
module.saved_grad = None # type: ignore
|
475 |
+
self.forward_handles.append(cast(RemovableHandle, module.input_hook))
|
476 |
+
else:
|
477 |
+
module.is_invalid = False # type: ignore
|
478 |
+
# removing the hook if there is no failure case
|
479 |
+
cast(RemovableHandle, module.input_hook).remove()
|
480 |
+
del module.input_hook
|
481 |
+
del module.input_grad_fns
|
482 |
+
|
483 |
+
def _backward_hook(
|
484 |
+
self,
|
485 |
+
module: Module,
|
486 |
+
grad_input: Union[Tensor, Tuple[Tensor, ...]],
|
487 |
+
grad_output: Union[Tensor, Tuple[Tensor, ...]],
|
488 |
+
):
|
489 |
+
r"""
|
490 |
+
`grad_input` is the gradient of the neuron with respect to its input
|
491 |
+
`grad_output` is the gradient of the neuron with respect to its output
|
492 |
+
we can override `grad_input` according to the chain rule with:
|
493 |
+
`grad_output` * delta_out / delta_in.
|
494 |
+
|
495 |
+
"""
|
496 |
+
# before accessing the attributes from the module we want
|
497 |
+
# to ensure that the properties exist, if not, then it is
|
498 |
+
# likely that the module is being reused.
|
499 |
+
attr_criteria = self.satisfies_attribute_criteria(module)
|
500 |
+
if not attr_criteria:
|
501 |
+
raise RuntimeError(
|
502 |
+
"A Module {} was detected that does not contain some of "
|
503 |
+
"the input/output attributes that are required for DeepLift "
|
504 |
+
"computations. This can occur, for example, if "
|
505 |
+
"your module is being used more than once in the network."
|
506 |
+
"Please, ensure that module is being used only once in the "
|
507 |
+
"network.".format(module)
|
508 |
+
)
|
509 |
+
multipliers = tuple(
|
510 |
+
SUPPORTED_NON_LINEAR[type(module)](
|
511 |
+
module,
|
512 |
+
module.input,
|
513 |
+
module.output,
|
514 |
+
grad_input,
|
515 |
+
grad_output,
|
516 |
+
eps=self.eps,
|
517 |
+
)
|
518 |
+
)
|
519 |
+
# remove all the properties that we set for the inputs and output
|
520 |
+
del module.input
|
521 |
+
del module.output
|
522 |
+
|
523 |
+
return multipliers
|
524 |
+
|
525 |
+
def satisfies_attribute_criteria(self, module: Module) -> bool:
|
526 |
+
return hasattr(module, "input") and hasattr(module, "output")
|
527 |
+
|
528 |
+
def _can_register_hook(self, module: Module) -> bool:
|
529 |
+
# TODO find a better way of checking if a module is a container or not
|
530 |
+
module_fullname = str(type(module))
|
531 |
+
has_already_hooks = len(module._backward_hooks) > 0 # type: ignore
|
532 |
+
return not (
|
533 |
+
"nn.modules.container" in module_fullname
|
534 |
+
or has_already_hooks
|
535 |
+
or not self._is_non_linear(module)
|
536 |
+
)
|
537 |
+
|
538 |
+
def _register_hooks(
|
539 |
+
self, module: Module, attribute_to_layer_input: bool = True
|
540 |
+
) -> None:
|
541 |
+
if not self._can_register_hook(module) or (
|
542 |
+
not attribute_to_layer_input and module is self.layer # type: ignore
|
543 |
+
):
|
544 |
+
return
|
545 |
+
# adds forward hook to leaf nodes that are non-linear
|
546 |
+
forward_handle = module.register_forward_hook(self._forward_hook)
|
547 |
+
pre_forward_handle = module.register_forward_pre_hook(self._forward_pre_hook)
|
548 |
+
backward_handle = _register_backward_hook(module, self._backward_hook, self)
|
549 |
+
self.forward_handles.append(forward_handle)
|
550 |
+
self.forward_handles.append(pre_forward_handle)
|
551 |
+
self.backward_handles.append(backward_handle)
|
552 |
+
|
553 |
+
def _remove_hooks(self, extra_hooks_to_remove: List[RemovableHandle]) -> None:
|
554 |
+
for handle in extra_hooks_to_remove:
|
555 |
+
handle.remove()
|
556 |
+
for forward_handle in self.forward_handles:
|
557 |
+
forward_handle.remove()
|
558 |
+
for backward_handle in self.backward_handles:
|
559 |
+
backward_handle.remove()
|
560 |
+
|
561 |
+
def _hook_main_model(self) -> List[RemovableHandle]:
|
562 |
+
def pre_hook(module: Module, baseline_inputs_add_args: Tuple) -> Tuple:
|
563 |
+
inputs = baseline_inputs_add_args[0]
|
564 |
+
baselines = baseline_inputs_add_args[1]
|
565 |
+
additional_args = None
|
566 |
+
if len(baseline_inputs_add_args) > 2:
|
567 |
+
additional_args = baseline_inputs_add_args[2:]
|
568 |
+
|
569 |
+
baseline_input_tsr = tuple(
|
570 |
+
torch.cat([input, baseline])
|
571 |
+
for input, baseline in zip(inputs, baselines)
|
572 |
+
)
|
573 |
+
if additional_args is not None:
|
574 |
+
expanded_additional_args = cast(
|
575 |
+
Tuple,
|
576 |
+
_expand_additional_forward_args(
|
577 |
+
additional_args, 2, ExpansionTypes.repeat
|
578 |
+
),
|
579 |
+
)
|
580 |
+
return (*baseline_input_tsr, *expanded_additional_args)
|
581 |
+
return baseline_input_tsr
|
582 |
+
|
583 |
+
def forward_hook(module: Module, inputs: Tuple, outputs: Tensor):
|
584 |
+
return torch.stack(torch.chunk(outputs, 2), dim=1)
|
585 |
+
|
586 |
+
if isinstance(
|
587 |
+
self.model, (nn.DataParallel, nn.parallel.DistributedDataParallel)
|
588 |
+
):
|
589 |
+
return [
|
590 |
+
self.model.module.register_forward_pre_hook(pre_hook), # type: ignore
|
591 |
+
self.model.module.register_forward_hook(forward_hook),
|
592 |
+
] # type: ignore
|
593 |
+
else:
|
594 |
+
return [
|
595 |
+
self.model.register_forward_pre_hook(pre_hook), # type: ignore
|
596 |
+
self.model.register_forward_hook(forward_hook),
|
597 |
+
] # type: ignore
|
598 |
+
|
599 |
+
def has_convergence_delta(self) -> bool:
|
600 |
+
return True
|
601 |
+
|
602 |
+
@property
|
603 |
+
def multiplies_by_inputs(self):
|
604 |
+
return self._multiply_by_inputs
|
605 |
+
|
606 |
+
|
607 |
+
class DeepLiftShap(DeepLift):
|
608 |
+
r"""
|
609 |
+
Extends DeepLift algorithm and approximates SHAP values using Deeplift.
|
610 |
+
For each input sample it computes DeepLift attribution with respect to
|
611 |
+
each baseline and averages resulting attributions.
|
612 |
+
More details about the algorithm can be found here:
|
613 |
+
|
614 |
+
http://papers.nips.cc/paper/7062-a-unified-approach-to-interpreting-model-predictions.pdf
|
615 |
+
|
616 |
+
Note that the explanation model:
|
617 |
+
1. Assumes that input features are independent of one another
|
618 |
+
2. Is linear, meaning that the explanations are modeled through
|
619 |
+
the additive composition of feature effects.
|
620 |
+
Although it assumes a linear model for each explanation, the overall
|
621 |
+
model across multiple explanations can be complex and non-linear.
|
622 |
+
"""
|
623 |
+
|
624 |
+
def __init__(self, model: Module, multiply_by_inputs: bool = True) -> None:
|
625 |
+
r"""
|
626 |
+
Args:
|
627 |
+
|
628 |
+
model (nn.Module): The reference to PyTorch model instance. Model cannot
|
629 |
+
contain any in-place nonlinear submodules; these are not
|
630 |
+
supported by the register_full_backward_hook PyTorch API.
|
631 |
+
multiply_by_inputs (bool, optional): Indicates whether to factor
|
632 |
+
model inputs' multiplier in the final attribution scores.
|
633 |
+
In the literature this is also known as local vs global
|
634 |
+
attribution. If inputs' multiplier isn't factored in
|
635 |
+
then that type of attribution method is also called local
|
636 |
+
attribution. If it is, then that type of attribution
|
637 |
+
method is called global.
|
638 |
+
More details can be found here:
|
639 |
+
https://arxiv.org/abs/1711.06104
|
640 |
+
|
641 |
+
In case of DeepLiftShap, if `multiply_by_inputs`
|
642 |
+
is set to True, final sensitivity scores
|
643 |
+
are being multiplied by (inputs - baselines).
|
644 |
+
This flag applies only if `custom_attribution_func` is
|
645 |
+
set to None.
|
646 |
+
"""
|
647 |
+
DeepLift.__init__(self, model, multiply_by_inputs=multiply_by_inputs)
|
648 |
+
|
649 |
+
# There's a mismatch between the signatures of DeepLift.attribute and
|
650 |
+
# DeepLiftShap.attribute, so we ignore typing here
|
651 |
+
@typing.overload # type: ignore
|
652 |
+
def attribute(
|
653 |
+
self,
|
654 |
+
inputs: TensorOrTupleOfTensorsGeneric,
|
655 |
+
baselines: Union[
|
656 |
+
TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric]
|
657 |
+
],
|
658 |
+
target: TargetType = None,
|
659 |
+
additional_forward_args: Any = None,
|
660 |
+
return_convergence_delta: Literal[False] = False,
|
661 |
+
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
|
662 |
+
) -> TensorOrTupleOfTensorsGeneric:
|
663 |
+
...
|
664 |
+
|
665 |
+
@typing.overload
|
666 |
+
def attribute(
|
667 |
+
self,
|
668 |
+
inputs: TensorOrTupleOfTensorsGeneric,
|
669 |
+
baselines: Union[
|
670 |
+
TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric]
|
671 |
+
],
|
672 |
+
target: TargetType = None,
|
673 |
+
additional_forward_args: Any = None,
|
674 |
+
*,
|
675 |
+
return_convergence_delta: Literal[True],
|
676 |
+
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
|
677 |
+
) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]:
|
678 |
+
...
|
679 |
+
|
680 |
+
@log_usage()
|
681 |
+
def attribute( # type: ignore
|
682 |
+
self,
|
683 |
+
inputs: TensorOrTupleOfTensorsGeneric,
|
684 |
+
baselines: Union[
|
685 |
+
TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric]
|
686 |
+
],
|
687 |
+
target: TargetType = None,
|
688 |
+
additional_forward_args: Any = None,
|
689 |
+
return_convergence_delta: bool = False,
|
690 |
+
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
|
691 |
+
) -> Union[
|
692 |
+
TensorOrTupleOfTensorsGeneric, Tuple[TensorOrTupleOfTensorsGeneric, Tensor]
|
693 |
+
]:
|
694 |
+
r"""
|
695 |
+
Args:
|
696 |
+
|
697 |
+
inputs (tensor or tuple of tensors): Input for which
|
698 |
+
attributions are computed. If forward_func takes a single
|
699 |
+
tensor as input, a single input tensor should be provided.
|
700 |
+
If forward_func takes multiple tensors as input, a tuple
|
701 |
+
of the input tensors should be provided. It is assumed
|
702 |
+
that for all given input tensors, dimension 0 corresponds
|
703 |
+
to the number of examples (aka batch size), and if
|
704 |
+
multiple input tensors are provided, the examples must
|
705 |
+
be aligned appropriately.
|
706 |
+
baselines (tensor, tuple of tensors, callable):
|
707 |
+
Baselines define reference samples that are compared with
|
708 |
+
the inputs. In order to assign attribution scores DeepLift
|
709 |
+
computes the differences between the inputs/outputs and
|
710 |
+
corresponding references. Baselines can be provided as:
|
711 |
+
|
712 |
+
- a single tensor, if inputs is a single tensor, with
|
713 |
+
the first dimension equal to the number of examples
|
714 |
+
in the baselines' distribution. The remaining dimensions
|
715 |
+
must match with input tensor's dimension starting from
|
716 |
+
the second dimension.
|
717 |
+
|
718 |
+
- a tuple of tensors, if inputs is a tuple of tensors,
|
719 |
+
with the first dimension of any tensor inside the tuple
|
720 |
+
equal to the number of examples in the baseline's
|
721 |
+
distribution. The remaining dimensions must match
|
722 |
+
the dimensions of the corresponding input tensor
|
723 |
+
starting from the second dimension.
|
724 |
+
|
725 |
+
- callable function, optionally takes `inputs` as an
|
726 |
+
argument and either returns a single tensor
|
727 |
+
or a tuple of those.
|
728 |
+
|
729 |
+
It is recommended that the number of samples in the baselines'
|
730 |
+
tensors is larger than one.
|
731 |
+
target (int, tuple, tensor or list, optional): Output indices for
|
732 |
+
which gradients are computed (for classification cases,
|
733 |
+
this is usually the target class).
|
734 |
+
If the network returns a scalar value per example,
|
735 |
+
no target index is necessary.
|
736 |
+
For general 2D outputs, targets can be either:
|
737 |
+
|
738 |
+
- a single integer or a tensor containing a single
|
739 |
+
integer, which is applied to all input examples
|
740 |
+
|
741 |
+
- a list of integers or a 1D tensor, with length matching
|
742 |
+
the number of examples in inputs (dim 0). Each integer
|
743 |
+
is applied as the target for the corresponding example.
|
744 |
+
|
745 |
+
For outputs with > 2 dimensions, targets can be either:
|
746 |
+
|
747 |
+
- A single tuple, which contains #output_dims - 1
|
748 |
+
elements. This target index is applied to all examples.
|
749 |
+
|
750 |
+
- A list of tuples with length equal to the number of
|
751 |
+
examples in inputs (dim 0), and each tuple containing
|
752 |
+
#output_dims - 1 elements. Each tuple is applied as the
|
753 |
+
target for the corresponding example.
|
754 |
+
|
755 |
+
Default: None
|
756 |
+
additional_forward_args (any, optional): If the forward function
|
757 |
+
requires additional arguments other than the inputs for
|
758 |
+
which attributions should not be computed, this argument
|
759 |
+
can be provided. It must be either a single additional
|
760 |
+
argument of a Tensor or arbitrary (non-tuple) type or a tuple
|
761 |
+
containing multiple additional arguments including tensors
|
762 |
+
or any arbitrary python types. These arguments are provided to
|
763 |
+
forward_func in order, following the arguments in inputs.
|
764 |
+
Note that attributions are not computed with respect
|
765 |
+
to these arguments.
|
766 |
+
Default: None
|
767 |
+
return_convergence_delta (bool, optional): Indicates whether to return
|
768 |
+
convergence delta or not. If `return_convergence_delta`
|
769 |
+
is set to True convergence delta will be returned in
|
770 |
+
a tuple following attributions.
|
771 |
+
Default: False
|
772 |
+
custom_attribution_func (callable, optional): A custom function for
|
773 |
+
computing final attribution scores. This function can take
|
774 |
+
at least one and at most three arguments with the
|
775 |
+
following signature:
|
776 |
+
|
777 |
+
- custom_attribution_func(multipliers)
|
778 |
+
- custom_attribution_func(multipliers, inputs)
|
779 |
+
- custom_attribution_func(multipliers, inputs, baselines)
|
780 |
+
|
781 |
+
In case this function is not provided we use the default
|
782 |
+
logic defined as: multipliers * (inputs - baselines)
|
783 |
+
It is assumed that all input arguments, `multipliers`,
|
784 |
+
`inputs` and `baselines` are provided in tuples of same
|
785 |
+
length. `custom_attribution_func` returns a tuple of
|
786 |
+
attribution tensors that have the same length as the
|
787 |
+
`inputs`.
|
788 |
+
Default: None
|
789 |
+
|
790 |
+
Returns:
|
791 |
+
**attributions** or 2-element tuple of **attributions**, **delta**:
|
792 |
+
- **attributions** (*tensor* or tuple of *tensors*):
|
793 |
+
Attribution score computed based on DeepLift rescale rule with
|
794 |
+
respect to each input feature. Attributions will always be
|
795 |
+
the same size as the provided inputs, with each value
|
796 |
+
providing the attribution of the corresponding input index.
|
797 |
+
If a single tensor is provided as inputs, a single tensor is
|
798 |
+
returned. If a tuple is provided for inputs, a tuple of
|
799 |
+
corresponding sized tensors is returned.
|
800 |
+
- **delta** (*tensor*, returned if return_convergence_delta=True):
|
801 |
+
This is computed using the property that the
|
802 |
+
total sum of forward_func(inputs) - forward_func(baselines)
|
803 |
+
must be very close to the total sum of attributions
|
804 |
+
computed based on approximated SHAP values using
|
805 |
+
Deeplift's rescale rule.
|
806 |
+
Delta is calculated for each example input and baseline pair,
|
807 |
+
meaning that the number of elements in returned delta tensor
|
808 |
+
is equal to the
|
809 |
+
`number of examples in input` * `number of examples
|
810 |
+
in baseline`. The deltas are ordered in the first place by
|
811 |
+
input example, followed by the baseline.
|
812 |
+
Note that the logic described for deltas is guaranteed
|
813 |
+
when the default logic for attribution computations is used,
|
814 |
+
meaning that the `custom_attribution_func=None`, otherwise
|
815 |
+
it is not guaranteed and depends on the specifics of the
|
816 |
+
`custom_attribution_func`.
|
817 |
+
|
818 |
+
Examples::
|
819 |
+
|
820 |
+
>>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
|
821 |
+
>>> # and returns an Nx10 tensor of class probabilities.
|
822 |
+
>>> net = ImageClassifier()
|
823 |
+
>>> dl = DeepLiftShap(net)
|
824 |
+
>>> input = torch.randn(2, 3, 32, 32, requires_grad=True)
|
825 |
+
>>> # Computes shap values using deeplift for class 3.
|
826 |
+
>>> attribution = dl.attribute(input, target=3)
|
827 |
+
"""
|
828 |
+
baselines = _format_callable_baseline(baselines, inputs)
|
829 |
+
|
830 |
+
assert isinstance(baselines[0], torch.Tensor) and baselines[0].shape[0] > 1, (
|
831 |
+
"Baselines distribution has to be provided in form of a torch.Tensor"
|
832 |
+
" with more than one example but found: {}."
|
833 |
+
" If baselines are provided in shape of scalars or with a single"
|
834 |
+
" baseline example, `DeepLift`"
|
835 |
+
" approach can be used instead.".format(baselines[0])
|
836 |
+
)
|
837 |
+
|
838 |
+
# Keeps track whether original input is a tuple or not before
|
839 |
+
# converting it into a tuple.
|
840 |
+
is_inputs_tuple = _is_tuple(inputs)
|
841 |
+
|
842 |
+
inputs = _format_tensor_into_tuples(inputs)
|
843 |
+
|
844 |
+
# batch sizes
|
845 |
+
inp_bsz = inputs[0].shape[0]
|
846 |
+
base_bsz = baselines[0].shape[0]
|
847 |
+
|
848 |
+
(
|
849 |
+
exp_inp,
|
850 |
+
exp_base,
|
851 |
+
exp_tgt,
|
852 |
+
exp_addit_args,
|
853 |
+
) = self._expand_inputs_baselines_targets(
|
854 |
+
baselines, inputs, target, additional_forward_args
|
855 |
+
)
|
856 |
+
attributions = super().attribute.__wrapped__( # type: ignore
|
857 |
+
self,
|
858 |
+
exp_inp,
|
859 |
+
exp_base,
|
860 |
+
target=exp_tgt,
|
861 |
+
additional_forward_args=exp_addit_args,
|
862 |
+
return_convergence_delta=cast(
|
863 |
+
Literal[True, False], return_convergence_delta
|
864 |
+
),
|
865 |
+
custom_attribution_func=custom_attribution_func,
|
866 |
+
)
|
867 |
+
if return_convergence_delta:
|
868 |
+
attributions, delta = cast(Tuple[Tuple[Tensor, ...], Tensor], attributions)
|
869 |
+
|
870 |
+
attributions = tuple(
|
871 |
+
self._compute_mean_across_baselines(
|
872 |
+
inp_bsz, base_bsz, cast(Tensor, attribution)
|
873 |
+
)
|
874 |
+
for attribution in attributions
|
875 |
+
)
|
876 |
+
|
877 |
+
if return_convergence_delta:
|
878 |
+
return _format_output(is_inputs_tuple, attributions), delta
|
879 |
+
else:
|
880 |
+
return _format_output(is_inputs_tuple, attributions)
|
881 |
+
|
882 |
+
def _expand_inputs_baselines_targets(
|
883 |
+
self,
|
884 |
+
baselines: Tuple[Tensor, ...],
|
885 |
+
inputs: Tuple[Tensor, ...],
|
886 |
+
target: TargetType,
|
887 |
+
additional_forward_args: Any,
|
888 |
+
) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...], TargetType, Any]:
|
889 |
+
inp_bsz = inputs[0].shape[0]
|
890 |
+
base_bsz = baselines[0].shape[0]
|
891 |
+
|
892 |
+
expanded_inputs = tuple(
|
893 |
+
[
|
894 |
+
input.repeat_interleave(base_bsz, dim=0).requires_grad_()
|
895 |
+
for input in inputs
|
896 |
+
]
|
897 |
+
)
|
898 |
+
expanded_baselines = tuple(
|
899 |
+
[
|
900 |
+
baseline.repeat(
|
901 |
+
(inp_bsz,) + tuple([1] * (len(baseline.shape) - 1))
|
902 |
+
).requires_grad_()
|
903 |
+
for baseline in baselines
|
904 |
+
]
|
905 |
+
)
|
906 |
+
expanded_target = _expand_target(
|
907 |
+
target, base_bsz, expansion_type=ExpansionTypes.repeat_interleave
|
908 |
+
)
|
909 |
+
input_additional_args = (
|
910 |
+
_expand_additional_forward_args(
|
911 |
+
additional_forward_args,
|
912 |
+
base_bsz,
|
913 |
+
expansion_type=ExpansionTypes.repeat_interleave,
|
914 |
+
)
|
915 |
+
if additional_forward_args is not None
|
916 |
+
else None
|
917 |
+
)
|
918 |
+
return (
|
919 |
+
expanded_inputs,
|
920 |
+
expanded_baselines,
|
921 |
+
expanded_target,
|
922 |
+
input_additional_args,
|
923 |
+
)
|
924 |
+
|
925 |
+
def _compute_mean_across_baselines(
|
926 |
+
self, inp_bsz: int, base_bsz: int, attribution: Tensor
|
927 |
+
) -> Tensor:
|
928 |
+
# Average for multiple references
|
929 |
+
attr_shape: Tuple = (inp_bsz, base_bsz)
|
930 |
+
if len(attribution.shape) > 1:
|
931 |
+
attr_shape += attribution.shape[1:]
|
932 |
+
return torch.mean(attribution.view(attr_shape), dim=1, keepdim=False)
|
933 |
+
|
934 |
+
|
935 |
+
def nonlinear(
|
936 |
+
module: Module,
|
937 |
+
inputs: Tensor,
|
938 |
+
outputs: Tensor,
|
939 |
+
grad_input: Tensor,
|
940 |
+
grad_output: Tensor,
|
941 |
+
eps: float = 1e-10,
|
942 |
+
):
|
943 |
+
r"""
|
944 |
+
grad_input: (dLoss / dprev_layer_out, dLoss / wij, dLoss / bij)
|
945 |
+
grad_output: (dLoss / dlayer_out)
|
946 |
+
https://github.com/pytorch/pytorch/issues/12331
|
947 |
+
"""
|
948 |
+
delta_in, delta_out = _compute_diffs(inputs, outputs)
|
949 |
+
|
950 |
+
new_grad_inp = list(grad_input)
|
951 |
+
|
952 |
+
# supported non-linear modules take only single tensor as input hence accessing
|
953 |
+
# only the first element in `grad_input` and `grad_output`
|
954 |
+
new_grad_inp[0] = torch.where(
|
955 |
+
abs(delta_in) < eps, new_grad_inp[0], grad_output[0] * delta_out / delta_in
|
956 |
+
)
|
957 |
+
|
958 |
+
# If the module is invalid, save the newly computed gradients
|
959 |
+
# The original_grad_input will be overridden later in the Tensor hook
|
960 |
+
if module.is_invalid:
|
961 |
+
module.saved_grad = new_grad_inp[0]
|
962 |
+
return new_grad_inp
|
963 |
+
|
964 |
+
|
965 |
+
def softmax(
|
966 |
+
module: Module,
|
967 |
+
inputs: Tensor,
|
968 |
+
outputs: Tensor,
|
969 |
+
grad_input: Tensor,
|
970 |
+
grad_output: Tensor,
|
971 |
+
eps: float = 1e-10,
|
972 |
+
):
|
973 |
+
delta_in, delta_out = _compute_diffs(inputs, outputs)
|
974 |
+
|
975 |
+
new_grad_inp = list(grad_input)
|
976 |
+
grad_input_unnorm = torch.where(
|
977 |
+
abs(delta_in) < eps, new_grad_inp[0], grad_output[0] * delta_out / delta_in
|
978 |
+
)
|
979 |
+
# normalizing
|
980 |
+
n = grad_input[0].numel()
|
981 |
+
|
982 |
+
# updating only the first half
|
983 |
+
new_grad_inp[0] = grad_input_unnorm - grad_input_unnorm.sum() * 1 / n
|
984 |
+
return new_grad_inp
|
985 |
+
|
986 |
+
|
987 |
+
def maxpool1d(
|
988 |
+
module: Module,
|
989 |
+
inputs: Tensor,
|
990 |
+
outputs: Tensor,
|
991 |
+
grad_input: Tensor,
|
992 |
+
grad_output: Tensor,
|
993 |
+
eps: float = 1e-10,
|
994 |
+
):
|
995 |
+
return maxpool(
|
996 |
+
module,
|
997 |
+
F.max_pool1d,
|
998 |
+
F.max_unpool1d,
|
999 |
+
inputs,
|
1000 |
+
outputs,
|
1001 |
+
grad_input,
|
1002 |
+
grad_output,
|
1003 |
+
eps=eps,
|
1004 |
+
)
|
1005 |
+
|
1006 |
+
|
1007 |
+
def maxpool2d(
|
1008 |
+
module: Module,
|
1009 |
+
inputs: Tensor,
|
1010 |
+
outputs: Tensor,
|
1011 |
+
grad_input: Tensor,
|
1012 |
+
grad_output: Tensor,
|
1013 |
+
eps: float = 1e-10,
|
1014 |
+
):
|
1015 |
+
return maxpool(
|
1016 |
+
module,
|
1017 |
+
F.max_pool2d,
|
1018 |
+
F.max_unpool2d,
|
1019 |
+
inputs,
|
1020 |
+
outputs,
|
1021 |
+
grad_input,
|
1022 |
+
grad_output,
|
1023 |
+
eps=eps,
|
1024 |
+
)
|
1025 |
+
|
1026 |
+
|
1027 |
+
def maxpool3d(
|
1028 |
+
module: Module, inputs, outputs, grad_input, grad_output, eps: float = 1e-10
|
1029 |
+
):
|
1030 |
+
return maxpool(
|
1031 |
+
module,
|
1032 |
+
F.max_pool3d,
|
1033 |
+
F.max_unpool3d,
|
1034 |
+
inputs,
|
1035 |
+
outputs,
|
1036 |
+
grad_input,
|
1037 |
+
grad_output,
|
1038 |
+
eps=eps,
|
1039 |
+
)
|
1040 |
+
|
1041 |
+
|
1042 |
+
def maxpool(
|
1043 |
+
module: Module,
|
1044 |
+
pool_func: Callable,
|
1045 |
+
unpool_func: Callable,
|
1046 |
+
inputs,
|
1047 |
+
outputs,
|
1048 |
+
grad_input,
|
1049 |
+
grad_output,
|
1050 |
+
eps: float = 1e-10,
|
1051 |
+
):
|
1052 |
+
with torch.no_grad():
|
1053 |
+
input, input_ref = inputs.chunk(2)
|
1054 |
+
output, output_ref = outputs.chunk(2)
|
1055 |
+
|
1056 |
+
delta_in = input - input_ref
|
1057 |
+
delta_in = torch.cat(2 * [delta_in])
|
1058 |
+
# Extracts cross maximum between the outputs of maxpool for the
|
1059 |
+
# actual inputs and its corresponding references. In case the delta outputs
|
1060 |
+
# for the references are larger the method relies on the references and
|
1061 |
+
# corresponding gradients to compute the multiplies and contributions.
|
1062 |
+
delta_out_xmax = torch.max(output, output_ref)
|
1063 |
+
delta_out = torch.cat([delta_out_xmax - output_ref, output - delta_out_xmax])
|
1064 |
+
|
1065 |
+
_, indices = pool_func(
|
1066 |
+
module.input,
|
1067 |
+
module.kernel_size,
|
1068 |
+
module.stride,
|
1069 |
+
module.padding,
|
1070 |
+
module.dilation,
|
1071 |
+
module.ceil_mode,
|
1072 |
+
True,
|
1073 |
+
)
|
1074 |
+
grad_output_updated = grad_output[0]
|
1075 |
+
unpool_grad_out_delta, unpool_grad_out_ref_delta = torch.chunk(
|
1076 |
+
unpool_func(
|
1077 |
+
grad_output_updated * delta_out,
|
1078 |
+
indices,
|
1079 |
+
module.kernel_size,
|
1080 |
+
module.stride,
|
1081 |
+
module.padding,
|
1082 |
+
list(cast(torch.Size, module.input.shape)),
|
1083 |
+
),
|
1084 |
+
2,
|
1085 |
+
)
|
1086 |
+
|
1087 |
+
unpool_grad_out_delta = unpool_grad_out_delta + unpool_grad_out_ref_delta
|
1088 |
+
unpool_grad_out_delta = torch.cat(2 * [unpool_grad_out_delta])
|
1089 |
+
|
1090 |
+
# If the module is invalid, we need to recompute the grad_input
|
1091 |
+
if module.is_invalid:
|
1092 |
+
original_grad_input = grad_input
|
1093 |
+
grad_input = (
|
1094 |
+
unpool_func(
|
1095 |
+
grad_output_updated,
|
1096 |
+
indices,
|
1097 |
+
module.kernel_size,
|
1098 |
+
module.stride,
|
1099 |
+
module.padding,
|
1100 |
+
list(cast(torch.Size, module.input.shape)),
|
1101 |
+
),
|
1102 |
+
)
|
1103 |
+
if grad_input[0].shape != inputs.shape:
|
1104 |
+
raise AssertionError(
|
1105 |
+
"A problem occurred during maxpool modul's backward pass. "
|
1106 |
+
"The gradients with respect to inputs include only a "
|
1107 |
+
"subset of inputs. More details about this issue can "
|
1108 |
+
"be found here: "
|
1109 |
+
"https://pytorch.org/docs/stable/"
|
1110 |
+
"nn.html#torch.nn.Module.register_backward_hook "
|
1111 |
+
"This can happen for example if you attribute to the outputs of a "
|
1112 |
+
"MaxPool. As a workaround, please, attribute to the inputs of "
|
1113 |
+
"the following layer."
|
1114 |
+
)
|
1115 |
+
|
1116 |
+
new_grad_inp = torch.where(
|
1117 |
+
abs(delta_in) < eps, grad_input[0], unpool_grad_out_delta / delta_in
|
1118 |
+
)
|
1119 |
+
# If the module is invalid, save the newly computed gradients
|
1120 |
+
# The original_grad_input will be overridden later in the Tensor hook
|
1121 |
+
if module.is_invalid:
|
1122 |
+
module.saved_grad = new_grad_inp
|
1123 |
+
return original_grad_input
|
1124 |
+
else:
|
1125 |
+
return (new_grad_inp,)
|
1126 |
+
|
1127 |
+
|
1128 |
+
def _compute_diffs(inputs: Tensor, outputs: Tensor) -> Tuple[Tensor, Tensor]:
|
1129 |
+
input, input_ref = inputs.chunk(2)
|
1130 |
+
# if the model is a single non-linear module and we apply Rescale rule on it
|
1131 |
+
# we might not be able to perform chunk-ing because the output of the module is
|
1132 |
+
# usually being replaced by model output.
|
1133 |
+
output, output_ref = outputs.chunk(2)
|
1134 |
+
delta_in = input - input_ref
|
1135 |
+
delta_out = output - output_ref
|
1136 |
+
|
1137 |
+
return torch.cat(2 * [delta_in]), torch.cat(2 * [delta_out])
|
1138 |
+
|
1139 |
+
|
1140 |
+
SUPPORTED_NON_LINEAR = {
|
1141 |
+
nn.ReLU: nonlinear,
|
1142 |
+
nn.ELU: nonlinear,
|
1143 |
+
nn.LeakyReLU: nonlinear,
|
1144 |
+
nn.Sigmoid: nonlinear,
|
1145 |
+
nn.Tanh: nonlinear,
|
1146 |
+
nn.Softplus: nonlinear,
|
1147 |
+
nn.MaxPool1d: maxpool1d,
|
1148 |
+
nn.MaxPool2d: maxpool2d,
|
1149 |
+
nn.MaxPool3d: maxpool3d,
|
1150 |
+
nn.Softmax: softmax,
|
1151 |
+
}
|
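
The SUPPORTED_NON_LINEAR table above maps each supported activation and pooling module to the backward-hook override that applies DeepLIFT's Rescale rule: wherever |delta_in| is at least `eps`, the ordinary local gradient is replaced by the multiplier delta_out / delta_in. A minimal numeric sketch of that substitution on a standalone ReLU (the example tensors and the upstream gradient of ones are illustrative assumptions, not taken from this file):

import torch

eps = 1e-10
inp = torch.tensor([2.0, -1.0])        # actual input
ref = torch.tensor([-1.0, 3.0])        # baseline / reference input
out, out_ref = torch.relu(inp), torch.relu(ref)

delta_in = inp - ref                   # [ 3.0, -4.0]
delta_out = out - out_ref              # [ 2.0, -3.0]
grad_out = torch.ones_like(inp)        # upstream gradient, assumed to be 1
plain_grad = (inp > 0).float()         # ordinary ReLU gradient: [1., 0.]

# Rescale rule: use delta_out / delta_in unless delta_in is numerically zero.
multiplier = torch.where(
    delta_in.abs() < eps, plain_grad, grad_out * delta_out / delta_in
)
print(multiplier)                      # tensor([0.6667, 0.7500]), not the plain [1., 0.]
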
captum/attr/_core/feature_ablation.py
ADDED
@@ -0,0 +1,591 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import math
|
4 |
+
from typing import Any, Callable, cast, Tuple, Union
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from captum._utils.common import (
|
8 |
+
_expand_additional_forward_args,
|
9 |
+
_expand_target,
|
10 |
+
_format_additional_forward_args,
|
11 |
+
_format_output,
|
12 |
+
_format_tensor_into_tuples,
|
13 |
+
_is_tuple,
|
14 |
+
_run_forward,
|
15 |
+
)
|
16 |
+
from captum._utils.progress import progress
|
17 |
+
from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
|
18 |
+
from captum.attr._utils.attribution import PerturbationAttribution
|
19 |
+
from captum.attr._utils.common import _format_input_baseline
|
20 |
+
from captum.log import log_usage
|
21 |
+
from torch import dtype, Tensor
|
22 |
+
|
23 |
+
|
24 |
+
class FeatureAblation(PerturbationAttribution):
|
25 |
+
r"""
|
26 |
+
A perturbation based approach to computing attribution, involving
|
27 |
+
replacing each input feature with a given baseline / reference, and
|
28 |
+
computing the difference in output. By default, each scalar value within
|
29 |
+
each input tensor is taken as a feature and replaced independently. Passing
|
30 |
+
a feature mask, allows grouping features to be ablated together. This can
|
31 |
+
be used in cases such as images, where an entire segment or region
|
32 |
+
can be ablated, measuring the importance of the segment (feature group).
|
33 |
+
Each input scalar in the group will be given the same attribution value
|
34 |
+
equal to the change in target as a result of ablating the entire feature
|
35 |
+
group.
|
36 |
+
|
37 |
+
The forward function can either return a scalar per example or a tensor
|
38 |
+
of a fixed size (or a scalar value) for the full batch, i.e. the
|
39 |
+
output does not grow as the batch size increases. If the output is fixed
|
40 |
+
we consider this model to be an "aggregation" of the inputs. In the fixed
|
41 |
+
sized output mode we require `perturbations_per_eval == 1` and the
|
42 |
+
`feature_mask` to be either `None` or for all of them to have 1 as their
|
43 |
+
first dimension (i.e. the feature mask must apply to all inputs).
|
44 |
+
"""
|
45 |
+
|
46 |
+
def __init__(self, forward_func: Callable) -> None:
|
47 |
+
r"""
|
48 |
+
Args:
|
49 |
+
|
50 |
+
forward_func (callable): The forward function of the model or
|
51 |
+
any modification of it
|
52 |
+
"""
|
53 |
+
PerturbationAttribution.__init__(self, forward_func)
|
54 |
+
self.use_weights = False
|
55 |
+
|
56 |
+
@log_usage()
|
57 |
+
def attribute(
|
58 |
+
self,
|
59 |
+
inputs: TensorOrTupleOfTensorsGeneric,
|
60 |
+
baselines: BaselineType = None,
|
61 |
+
target: TargetType = None,
|
62 |
+
additional_forward_args: Any = None,
|
63 |
+
feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
|
64 |
+
perturbations_per_eval: int = 1,
|
65 |
+
show_progress: bool = False,
|
66 |
+
**kwargs: Any,
|
67 |
+
) -> TensorOrTupleOfTensorsGeneric:
|
68 |
+
r"""
|
69 |
+
Args:
|
70 |
+
|
71 |
+
inputs (tensor or tuple of tensors): Input for which ablation
|
72 |
+
attributions are computed. If forward_func takes a single
|
73 |
+
tensor as input, a single input tensor should be provided.
|
74 |
+
If forward_func takes multiple tensors as input, a tuple
|
75 |
+
of the input tensors should be provided. It is assumed
|
76 |
+
that for all given input tensors, dimension 0 corresponds
|
77 |
+
to the number of examples (aka batch size), and if
|
78 |
+
multiple input tensors are provided, the examples must
|
79 |
+
be aligned appropriately.
|
80 |
+
baselines (scalar, tensor, tuple of scalars or tensors, optional):
|
81 |
+
Baselines define reference value which replaces each
|
82 |
+
feature when ablated.
|
83 |
+
Baselines can be provided as:
|
84 |
+
|
85 |
+
- a single tensor, if inputs is a single tensor, with
|
86 |
+
exactly the same dimensions as inputs or
|
87 |
+
broadcastable to match the dimensions of inputs
|
88 |
+
|
89 |
+
- a single scalar, if inputs is a single tensor, which will
|
90 |
+
be broadcasted for each input value in input tensor.
|
91 |
+
|
92 |
+
- a tuple of tensors or scalars, the baseline corresponding
|
93 |
+
to each tensor in the inputs' tuple can be:
|
94 |
+
|
95 |
+
- either a tensor with matching dimensions to
|
96 |
+
corresponding tensor in the inputs' tuple
|
97 |
+
or the first dimension is one and the remaining
|
98 |
+
dimensions match with the corresponding
|
99 |
+
input tensor.
|
100 |
+
|
101 |
+
- or a scalar, corresponding to a tensor in the
|
102 |
+
inputs' tuple. This scalar value is broadcasted
|
103 |
+
for corresponding input tensor.
|
104 |
+
In the cases when `baselines` is not provided, we internally
|
105 |
+
use zero scalar corresponding to each input tensor.
|
106 |
+
Default: None
|
107 |
+
target (int, tuple, tensor or list, optional): Output indices for
|
108 |
+
which gradients are computed (for classification cases,
|
109 |
+
this is usually the target class).
|
110 |
+
If the network returns a scalar value per example,
|
111 |
+
no target index is necessary.
|
112 |
+
For general 2D outputs, targets can be either:
|
113 |
+
|
114 |
+
- a single integer or a tensor containing a single
|
115 |
+
integer, which is applied to all input examples
|
116 |
+
|
117 |
+
- a list of integers or a 1D tensor, with length matching
|
118 |
+
the number of examples in inputs (dim 0). Each integer
|
119 |
+
is applied as the target for the corresponding example.
|
120 |
+
|
121 |
+
For outputs with > 2 dimensions, targets can be either:
|
122 |
+
|
123 |
+
- A single tuple, which contains #output_dims - 1
|
124 |
+
elements. This target index is applied to all examples.
|
125 |
+
|
126 |
+
- A list of tuples with length equal to the number of
|
127 |
+
examples in inputs (dim 0), and each tuple containing
|
128 |
+
#output_dims - 1 elements. Each tuple is applied as the
|
129 |
+
target for the corresponding example.
|
130 |
+
|
131 |
+
Default: None
|
132 |
+
additional_forward_args (any, optional): If the forward function
|
133 |
+
requires additional arguments other than the inputs for
|
134 |
+
which attributions should not be computed, this argument
|
135 |
+
can be provided. It must be either a single additional
|
136 |
+
argument of a Tensor or arbitrary (non-tuple) type or a
|
137 |
+
tuple containing multiple additional arguments including
|
138 |
+
tensors or any arbitrary python types. These arguments
|
139 |
+
are provided to forward_func in order following the
|
140 |
+
arguments in inputs.
|
141 |
+
For a tensor, the first dimension of the tensor must
|
142 |
+
correspond to the number of examples. For all other types,
|
143 |
+
the given argument is used for all forward evaluations.
|
144 |
+
Note that attributions are not computed with respect
|
145 |
+
to these arguments.
|
146 |
+
Default: None
|
147 |
+
feature_mask (tensor or tuple of tensors, optional):
|
148 |
+
feature_mask defines a mask for the input, grouping
|
149 |
+
features which should be ablated together. feature_mask
|
150 |
+
should contain the same number of tensors as inputs.
|
151 |
+
Each tensor should
|
152 |
+
be the same size as the corresponding input or
|
153 |
+
broadcastable to match the input tensor. Each tensor
|
154 |
+
should contain integers in the range 0 to num_features
|
155 |
+
- 1, and indices corresponding to the same feature should
|
156 |
+
have the same value.
|
157 |
+
Note that features within each input tensor are ablated
|
158 |
+
independently (not across tensors).
|
159 |
+
If the forward function returns a single scalar per batch,
|
160 |
+
we enforce that the first dimension of each mask must be 1,
|
161 |
+
since attributions are returned batch-wise rather than per
|
162 |
+
example, so the attributions must correspond to the
|
163 |
+
same features (indices) in each input example.
|
164 |
+
If None, then a feature mask is constructed which assigns
|
165 |
+
each scalar within a tensor as a separate feature, which
|
166 |
+
is ablated independently.
|
167 |
+
Default: None
|
168 |
+
perturbations_per_eval (int, optional): Allows ablation of multiple
|
169 |
+
features to be processed simultaneously in one call to
|
170 |
+
forward_fn.
|
171 |
+
Each forward pass will contain a maximum of
|
172 |
+
perturbations_per_eval * #examples samples.
|
173 |
+
For DataParallel models, each batch is split among the
|
174 |
+
available devices, so evaluations on each available
|
175 |
+
device contain at most
|
176 |
+
(perturbations_per_eval * #examples) / num_devices
|
177 |
+
samples.
|
178 |
+
If the forward function's number of outputs does not
|
179 |
+
change as the batch size grows (e.g. if it outputs a
|
180 |
+
scalar value), you must set perturbations_per_eval to 1
|
181 |
+
and use a single feature mask to describe the features
|
182 |
+
for all examples in the batch.
|
183 |
+
Default: 1
|
184 |
+
show_progress (bool, optional): Displays the progress of computation.
|
185 |
+
It will try to use tqdm if available for advanced features
|
186 |
+
(e.g. time estimation). Otherwise, it will fall back to
|
187 |
+
a simple output of progress.
|
188 |
+
Default: False
|
189 |
+
**kwargs (Any, optional): Any additional arguments used by child
|
190 |
+
classes of FeatureAblation (such as Occlusion) to construct
|
191 |
+
ablations. These arguments are ignored when using
|
192 |
+
FeatureAblation directly.
|
193 |
+
Default: None
|
194 |
+
|
195 |
+
Returns:
|
196 |
+
*tensor* or tuple of *tensors* of **attributions**:
|
197 |
+
- **attributions** (*tensor* or tuple of *tensors*):
|
198 |
+
The attributions with respect to each input feature.
|
199 |
+
If the forward function returns
|
200 |
+
a scalar value per example, attributions will be
|
201 |
+
the same size as the provided inputs, with each value
|
202 |
+
providing the attribution of the corresponding input index.
|
203 |
+
If the forward function returns a scalar per batch, then
|
204 |
+
attribution tensor(s) will have first dimension 1 and
|
205 |
+
the remaining dimensions will match the input.
|
206 |
+
If a single tensor is provided as inputs, a single tensor is
|
207 |
+
returned. If a tuple of tensors is provided for inputs, a
|
208 |
+
tuple of corresponding sized tensors is returned.
|
209 |
+
|
210 |
+
|
211 |
+
Examples::
|
212 |
+
|
213 |
+
>>> # SimpleClassifier takes a single input tensor of size Nx4x4,
|
214 |
+
>>> # and returns an Nx3 tensor of class probabilities.
|
215 |
+
>>> net = SimpleClassifier()
|
216 |
+
>>> # Generating random input with size 2 x 4 x 4
|
217 |
+
>>> input = torch.randn(2, 4, 4)
|
218 |
+
>>> # Defining FeatureAblation interpreter
|
219 |
+
>>> ablator = FeatureAblation(net)
|
220 |
+
>>> # Computes ablation attribution, ablating each of the 16
|
221 |
+
>>> # scalar input independently.
|
222 |
+
>>> attr = ablator.attribute(input, target=1)
|
223 |
+
|
224 |
+
>>> # Alternatively, we may want to ablate features in groups, e.g.
|
225 |
+
>>> # grouping each 2x2 square of the inputs and ablating them together.
|
226 |
+
>>> # This can be done by creating a feature mask as follows, which
|
227 |
+
>>> # defines the feature groups, e.g.:
|
228 |
+
>>> # +---+---+---+---+
|
229 |
+
>>> # | 0 | 0 | 1 | 1 |
|
230 |
+
>>> # +---+---+---+---+
|
231 |
+
>>> # | 0 | 0 | 1 | 1 |
|
232 |
+
>>> # +---+---+---+---+
|
233 |
+
>>> # | 2 | 2 | 3 | 3 |
|
234 |
+
>>> # +---+---+---+---+
|
235 |
+
>>> # | 2 | 2 | 3 | 3 |
|
236 |
+
>>> # +---+---+---+---+
|
237 |
+
>>> # With this mask, all inputs with the same value are ablated
|
238 |
+
>>> # simultaneously, and the attribution for each input in the same
|
239 |
+
>>> # group (0, 1, 2, and 3) per example are the same.
|
240 |
+
>>> # The attributions can be calculated as follows:
|
241 |
+
>>> # feature mask has dimensions 1 x 4 x 4
|
242 |
+
>>> feature_mask = torch.tensor([[[0,0,1,1],[0,0,1,1],
|
243 |
+
>>> [2,2,3,3],[2,2,3,3]]])
|
244 |
+
>>> attr = ablator.attribute(input, target=1, feature_mask=feature_mask)
|
245 |
+
"""
|
246 |
+
# Keeps track whether original input is a tuple or not before
|
247 |
+
# converting it into a tuple.
|
248 |
+
is_inputs_tuple = _is_tuple(inputs)
|
249 |
+
inputs, baselines = _format_input_baseline(inputs, baselines)
|
250 |
+
additional_forward_args = _format_additional_forward_args(
|
251 |
+
additional_forward_args
|
252 |
+
)
|
253 |
+
num_examples = inputs[0].shape[0]
|
254 |
+
feature_mask = (
|
255 |
+
_format_tensor_into_tuples(feature_mask)
|
256 |
+
if feature_mask is not None
|
257 |
+
else None
|
258 |
+
)
|
259 |
+
assert (
|
260 |
+
isinstance(perturbations_per_eval, int) and perturbations_per_eval >= 1
|
261 |
+
), "Perturbations per evaluation must be an integer and at least 1."
|
262 |
+
with torch.no_grad():
|
263 |
+
if show_progress:
|
264 |
+
feature_counts = self._get_feature_counts(
|
265 |
+
inputs, feature_mask, **kwargs
|
266 |
+
)
|
267 |
+
total_forwards = (
|
268 |
+
sum(
|
269 |
+
math.ceil(count / perturbations_per_eval)
|
270 |
+
for count in feature_counts
|
271 |
+
)
|
272 |
+
+ 1
|
273 |
+
) # add 1 for the initial eval
|
274 |
+
attr_progress = progress(
|
275 |
+
desc=f"{self.get_name()} attribution", total=total_forwards
|
276 |
+
)
|
277 |
+
attr_progress.update(0)
|
278 |
+
|
279 |
+
# Computes initial evaluation with all features, which is compared
|
280 |
+
# to each ablated result.
|
281 |
+
initial_eval = _run_forward(
|
282 |
+
self.forward_func, inputs, target, additional_forward_args
|
283 |
+
)
|
284 |
+
|
285 |
+
if show_progress:
|
286 |
+
attr_progress.update()
|
287 |
+
|
288 |
+
agg_output_mode = FeatureAblation._find_output_mode(
|
289 |
+
perturbations_per_eval, feature_mask
|
290 |
+
)
|
291 |
+
|
292 |
+
# get as a 2D tensor (if it is not a scalar)
|
293 |
+
if isinstance(initial_eval, torch.Tensor):
|
294 |
+
initial_eval = initial_eval.reshape(1, -1)
|
295 |
+
num_outputs = initial_eval.shape[1]
|
296 |
+
else:
|
297 |
+
num_outputs = 1
|
298 |
+
|
299 |
+
if not agg_output_mode:
|
300 |
+
assert (
|
301 |
+
isinstance(initial_eval, torch.Tensor)
|
302 |
+
and num_outputs == num_examples
|
303 |
+
), (
|
304 |
+
"expected output of `forward_func` to have "
|
305 |
+
+ "`batch_size` elements for perturbations_per_eval > 1 "
|
306 |
+
+ "and all feature_mask.shape[0] > 1"
|
307 |
+
)
|
308 |
+
|
309 |
+
# Initialize attribution totals and counts
|
310 |
+
attrib_type = cast(
|
311 |
+
dtype,
|
312 |
+
initial_eval.dtype
|
313 |
+
if isinstance(initial_eval, Tensor)
|
314 |
+
else type(initial_eval),
|
315 |
+
)
|
316 |
+
|
317 |
+
total_attrib = [
|
318 |
+
torch.zeros(
|
319 |
+
(num_outputs,) + input.shape[1:],
|
320 |
+
dtype=attrib_type,
|
321 |
+
device=input.device,
|
322 |
+
)
|
323 |
+
for input in inputs
|
324 |
+
]
|
325 |
+
|
326 |
+
# Weights are used in cases where ablations may be overlapping.
|
327 |
+
if self.use_weights:
|
328 |
+
weights = [
|
329 |
+
torch.zeros(
|
330 |
+
(num_outputs,) + input.shape[1:], device=input.device
|
331 |
+
).float()
|
332 |
+
for input in inputs
|
333 |
+
]
|
334 |
+
|
335 |
+
# Iterate through each feature tensor for ablation
|
336 |
+
for i in range(len(inputs)):
|
337 |
+
# Skip any empty input tensors
|
338 |
+
if torch.numel(inputs[i]) == 0:
|
339 |
+
continue
|
340 |
+
|
341 |
+
for (
|
342 |
+
current_inputs,
|
343 |
+
current_add_args,
|
344 |
+
current_target,
|
345 |
+
current_mask,
|
346 |
+
) in self._ith_input_ablation_generator(
|
347 |
+
i,
|
348 |
+
inputs,
|
349 |
+
additional_forward_args,
|
350 |
+
target,
|
351 |
+
baselines,
|
352 |
+
feature_mask,
|
353 |
+
perturbations_per_eval,
|
354 |
+
**kwargs,
|
355 |
+
):
|
356 |
+
# modified_eval dimensions: 1D tensor with length
|
357 |
+
# equal to #num_examples * #features in batch
|
358 |
+
modified_eval = _run_forward(
|
359 |
+
self.forward_func,
|
360 |
+
current_inputs,
|
361 |
+
current_target,
|
362 |
+
current_add_args,
|
363 |
+
)
|
364 |
+
|
365 |
+
if show_progress:
|
366 |
+
attr_progress.update()
|
367 |
+
|
368 |
+
# eval_diff contains 1 more dimension than inputs. This adds extra
|
369 |
+
# dimensions of 1 to make the tensor broadcastable with the inputs
|
370 |
+
# tensor.
|
371 |
+
if not isinstance(modified_eval, torch.Tensor):
|
372 |
+
eval_diff = initial_eval - modified_eval
|
373 |
+
else:
|
374 |
+
if not agg_output_mode:
|
375 |
+
assert (
|
376 |
+
modified_eval.numel() == current_inputs[0].shape[0]
|
377 |
+
), """expected output of forward_func to grow with
|
378 |
+
batch_size. If this is not the case for your model
|
379 |
+
please set perturbations_per_eval = 1"""
|
380 |
+
|
381 |
+
eval_diff = (
|
382 |
+
initial_eval - modified_eval.reshape((-1, num_outputs))
|
383 |
+
).reshape((-1, num_outputs) + (len(inputs[i].shape) - 1) * (1,))
|
384 |
+
eval_diff = eval_diff.to(total_attrib[i].device)
|
385 |
+
if self.use_weights:
|
386 |
+
weights[i] += current_mask.float().sum(dim=0)
|
387 |
+
total_attrib[i] += (eval_diff * current_mask.to(attrib_type)).sum(
|
388 |
+
dim=0
|
389 |
+
)
|
390 |
+
|
391 |
+
if show_progress:
|
392 |
+
attr_progress.close()
|
393 |
+
|
394 |
+
# Divide total attributions by counts and return formatted attributions
|
395 |
+
if self.use_weights:
|
396 |
+
attrib = tuple(
|
397 |
+
single_attrib.float() / weight
|
398 |
+
for single_attrib, weight in zip(total_attrib, weights)
|
399 |
+
)
|
400 |
+
else:
|
401 |
+
attrib = tuple(total_attrib)
|
402 |
+
_result = _format_output(is_inputs_tuple, attrib)
|
403 |
+
return _result
|
404 |
+
|
405 |
+
def _ith_input_ablation_generator(
|
406 |
+
self,
|
407 |
+
i,
|
408 |
+
inputs,
|
409 |
+
additional_args,
|
410 |
+
target,
|
411 |
+
baselines,
|
412 |
+
input_mask,
|
413 |
+
perturbations_per_eval,
|
414 |
+
**kwargs,
|
415 |
+
):
|
416 |
+
"""
|
417 |
+
This method returns a generator of ablation perturbations of the i-th input.
|
418 |
+
|
419 |
+
Returns:
|
420 |
+
ablation_iter (generator): yields each perturbation to be evaluated
|
421 |
+
as a tuple (inputs, additional_forward_args, targets, mask).
|
422 |
+
"""
|
423 |
+
extra_args = {}
|
424 |
+
for key, value in kwargs.items():
|
425 |
+
# For any tuple argument in kwargs, we choose index i of the tuple.
|
426 |
+
if isinstance(value, tuple):
|
427 |
+
extra_args[key] = value[i]
|
428 |
+
else:
|
429 |
+
extra_args[key] = value
|
430 |
+
|
431 |
+
input_mask = input_mask[i] if input_mask is not None else None
|
432 |
+
min_feature, num_features, input_mask = self._get_feature_range_and_mask(
|
433 |
+
inputs[i], input_mask, **extra_args
|
434 |
+
)
|
435 |
+
num_examples = inputs[0].shape[0]
|
436 |
+
perturbations_per_eval = min(perturbations_per_eval, num_features)
|
437 |
+
baseline = baselines[i] if isinstance(baselines, tuple) else baselines
|
438 |
+
if isinstance(baseline, torch.Tensor):
|
439 |
+
baseline = baseline.reshape((1,) + baseline.shape)
|
440 |
+
|
441 |
+
if perturbations_per_eval > 1:
|
442 |
+
# Repeat features and additional args for batch size.
|
443 |
+
all_features_repeated = [
|
444 |
+
torch.cat([inputs[j]] * perturbations_per_eval, dim=0)
|
445 |
+
for j in range(len(inputs))
|
446 |
+
]
|
447 |
+
additional_args_repeated = (
|
448 |
+
_expand_additional_forward_args(additional_args, perturbations_per_eval)
|
449 |
+
if additional_args is not None
|
450 |
+
else None
|
451 |
+
)
|
452 |
+
target_repeated = _expand_target(target, perturbations_per_eval)
|
453 |
+
else:
|
454 |
+
all_features_repeated = list(inputs)
|
455 |
+
additional_args_repeated = additional_args
|
456 |
+
target_repeated = target
|
457 |
+
|
458 |
+
num_features_processed = min_feature
|
459 |
+
while num_features_processed < num_features:
|
460 |
+
current_num_ablated_features = min(
|
461 |
+
perturbations_per_eval, num_features - num_features_processed
|
462 |
+
)
|
463 |
+
|
464 |
+
# Store appropriate inputs and additional args based on batch size.
|
465 |
+
if current_num_ablated_features != perturbations_per_eval:
|
466 |
+
current_features = [
|
467 |
+
feature_repeated[0 : current_num_ablated_features * num_examples]
|
468 |
+
for feature_repeated in all_features_repeated
|
469 |
+
]
|
470 |
+
current_additional_args = (
|
471 |
+
_expand_additional_forward_args(
|
472 |
+
additional_args, current_num_ablated_features
|
473 |
+
)
|
474 |
+
if additional_args is not None
|
475 |
+
else None
|
476 |
+
)
|
477 |
+
current_target = _expand_target(target, current_num_ablated_features)
|
478 |
+
else:
|
479 |
+
current_features = all_features_repeated
|
480 |
+
current_additional_args = additional_args_repeated
|
481 |
+
current_target = target_repeated
|
482 |
+
|
483 |
+
# Store existing tensor before modifying
|
484 |
+
original_tensor = current_features[i]
|
485 |
+
# Construct ablated batch for features in range num_features_processed
|
486 |
+
# to num_features_processed + current_num_ablated_features and return
|
487 |
+
# mask with same size as ablated batch. ablated_features has dimension
|
488 |
+
# (current_num_ablated_features, num_examples, inputs[i].shape[1:])
|
489 |
+
# Note that in the case of sparse tensors, the second dimension
|
490 |
+
may not necessarily be num_examples and will match the first
|
491 |
+
# dimension of this tensor.
|
492 |
+
current_reshaped = current_features[i].reshape(
|
493 |
+
(current_num_ablated_features, -1) + current_features[i].shape[1:]
|
494 |
+
)
|
495 |
+
|
496 |
+
ablated_features, current_mask = self._construct_ablated_input(
|
497 |
+
current_reshaped,
|
498 |
+
input_mask,
|
499 |
+
baseline,
|
500 |
+
num_features_processed,
|
501 |
+
num_features_processed + current_num_ablated_features,
|
502 |
+
**extra_args,
|
503 |
+
)
|
504 |
+
|
505 |
+
# current_features[i] has dimension
|
506 |
+
# (current_num_ablated_features * num_examples, inputs[i].shape[1:]),
|
507 |
+
# which can be provided to the model as input.
|
508 |
+
current_features[i] = ablated_features.reshape(
|
509 |
+
(-1,) + ablated_features.shape[2:]
|
510 |
+
)
|
511 |
+
yield tuple(
|
512 |
+
current_features
|
513 |
+
), current_additional_args, current_target, current_mask
|
514 |
+
# Replace existing tensor at index i.
|
515 |
+
current_features[i] = original_tensor
|
516 |
+
num_features_processed += current_num_ablated_features
|
517 |
+
|
518 |
+
def _construct_ablated_input(
|
519 |
+
self, expanded_input, input_mask, baseline, start_feature, end_feature, **kwargs
|
520 |
+
):
|
521 |
+
r"""
|
522 |
+
Ablates given expanded_input tensor with given feature mask, feature range,
|
523 |
+
and baselines. expanded_input shape is (`num_features`, `num_examples`, ...)
|
524 |
+
with remaining dimensions corresponding to remaining original tensor
|
525 |
+
dimensions and `num_features` = `end_feature` - `start_feature`.
|
526 |
+
input_mask has same number of dimensions as original input tensor (one less
|
527 |
+
than `expanded_input`), and can have first dimension either 1, applying same
|
528 |
+
feature mask to all examples, or `num_examples`. baseline is expected to
|
529 |
+
be broadcastable to match `expanded_input`.
|
530 |
+
|
531 |
+
This method returns the ablated input tensor, which has the same
|
532 |
+
dimensionality as `expanded_input` as well as the corresponding mask with
|
533 |
+
either the same dimensionality as `expanded_input` or second dimension
|
534 |
+
being 1. This mask contains 1s in locations which have been ablated (and
|
535 |
+
thus counted towards ablations for that feature) and 0s otherwise.
|
536 |
+
"""
|
537 |
+
current_mask = torch.stack(
|
538 |
+
[input_mask == j for j in range(start_feature, end_feature)], dim=0
|
539 |
+
).long()
|
540 |
+
ablated_tensor = (
|
541 |
+
expanded_input * (1 - current_mask).to(expanded_input.dtype)
|
542 |
+
) + (baseline * current_mask.to(expanded_input.dtype))
|
543 |
+
return ablated_tensor, current_mask
|
544 |
+
|
545 |
+
def _get_feature_range_and_mask(self, input, input_mask, **kwargs):
|
546 |
+
if input_mask is None:
|
547 |
+
# Obtain feature mask for selected input tensor, matches size of
|
548 |
+
# 1 input example, (1 x inputs[i].shape[1:])
|
549 |
+
input_mask = torch.reshape(
|
550 |
+
torch.arange(torch.numel(input[0]), device=input.device),
|
551 |
+
input[0:1].shape,
|
552 |
+
).long()
|
553 |
+
return (
|
554 |
+
torch.min(input_mask).item(),
|
555 |
+
torch.max(input_mask).item() + 1,
|
556 |
+
input_mask,
|
557 |
+
)
|
558 |
+
|
559 |
+
def _get_feature_counts(self, inputs, feature_mask, **kwargs):
|
560 |
+
"""return the numbers of input features"""
|
561 |
+
if not feature_mask:
|
562 |
+
return tuple(inp[0].numel() if inp.numel() else 0 for inp in inputs)
|
563 |
+
|
564 |
+
return tuple(
|
565 |
+
(mask.max() - mask.min()).item() + 1
|
566 |
+
if mask is not None
|
567 |
+
else (inp[0].numel() if inp.numel() else 0)
|
568 |
+
for inp, mask in zip(inputs, feature_mask)
|
569 |
+
)
|
570 |
+
|
571 |
+
@staticmethod
|
572 |
+
def _find_output_mode(
|
573 |
+
perturbations_per_eval: int,
|
574 |
+
feature_mask: Union[None, TensorOrTupleOfTensorsGeneric],
|
575 |
+
) -> bool:
|
576 |
+
"""
|
577 |
+
Returns True if the output mode is "aggregation output mode"
|
578 |
+
|
579 |
+
Aggregation output mode is defined as: when there is no 1:1 correspondence
|
580 |
+
between `num_examples` (`batch_size`) and the number of outputs your model
|
581 |
+
produces, i.e. the model output does not grow in size as the input becomes
|
582 |
+
larger.
|
583 |
+
|
584 |
+
We assume this is the case if `perturbations_per_eval == 1`
|
585 |
+
and your feature mask is None or is associated to all
|
586 |
+
examples in a batch (fm.shape[0] == 1 for all fm in feature_mask).
|
587 |
+
"""
|
588 |
+
return perturbations_per_eval == 1 and (
|
589 |
+
feature_mask is None
|
590 |
+
or all(len(sm.shape) == 0 or sm.shape[0] == 1 for sm in feature_mask)
|
591 |
+
)
|
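For quick reference, below is a minimal usage sketch of the FeatureAblation class added above. The SimpleClassifier module is a hypothetical stand-in (any model mapping an N x 4 x 4 input to N x 3 class scores would do); only the attribute() calls reflect the API defined in this file.

import torch
import torch.nn as nn
from captum.attr import FeatureAblation

# Hypothetical stand-in classifier (not part of this commit): flattens an
# N x 4 x 4 input and maps it to N x 3 class scores.
class SimpleClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 3)

    def forward(self, x):
        return self.fc(x.flatten(start_dim=1))

net = SimpleClassifier()
inputs = torch.randn(2, 4, 4)
ablator = FeatureAblation(net)

# Ablate each of the 16 input scalars independently against the default zero baseline.
attr = ablator.attribute(inputs, target=1)

# Group features so each 2x2 block is ablated together; the mask has shape 1 x 4 x 4.
# Since the model output grows with batch size, perturbations_per_eval > 1 is allowed.
feature_mask = torch.tensor([[[0, 0, 1, 1],
                              [0, 0, 1, 1],
                              [2, 2, 3, 3],
                              [2, 2, 3, 3]]])
attr_grouped = ablator.attribute(
    inputs, target=1, feature_mask=feature_mask, perturbations_per_eval=2
)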
captum/attr/_core/feature_permutation.py
ADDED
@@ -0,0 +1,305 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
from typing import Any, Callable, Tuple, Union
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric
|
6 |
+
from captum.attr._core.feature_ablation import FeatureAblation
|
7 |
+
from captum.log import log_usage
|
8 |
+
from torch import Tensor
|
9 |
+
|
10 |
+
|
11 |
+
def _permute_feature(x: Tensor, feature_mask: Tensor) -> Tensor:
|
12 |
+
n = x.size(0)
|
13 |
+
assert n > 1, "cannot permute features with batch_size = 1"
|
14 |
+
|
15 |
+
perm = torch.randperm(n)
|
16 |
+
no_perm = torch.arange(n)
|
17 |
+
while (perm == no_perm).all():
|
18 |
+
perm = torch.randperm(n)
|
19 |
+
|
20 |
+
return (x[perm] * feature_mask.to(dtype=x.dtype)) + (
|
21 |
+
x * feature_mask.bitwise_not().to(dtype=x.dtype)
|
22 |
+
)
|
23 |
+
|
24 |
+
|
25 |
+
class FeaturePermutation(FeatureAblation):
|
26 |
+
r"""
|
27 |
+
A perturbation based approach to compute attribution, which
|
28 |
+
takes each input feature, permutes the feature values within a batch,
|
29 |
+
and computes the difference between original and shuffled outputs for
|
30 |
+
the given batch. This difference signifies the feature importance
|
31 |
+
for the permuted feature.
|
32 |
+
|
33 |
+
Example pseudocode for the algorithm is as follows::
|
34 |
+
|
35 |
+
perm_feature_importance(batch):
|
36 |
+
importance = dict()
|
37 |
+
baseline_error = error_metric(model(batch), batch_labels)
|
38 |
+
for each feature:
|
39 |
+
permute this feature across the batch
|
40 |
+
error = error_metric(model(permuted_batch), batch_labels)
|
41 |
+
importance[feature] = baseline_error - error
|
42 |
+
"un-permute" the feature across the batch
|
43 |
+
|
44 |
+
return importance
|
45 |
+
|
46 |
+
It should be noted that the `error_metric` must be called in the
|
47 |
+
`forward_func`. You do not need to have an error metric, e.g. you
|
48 |
+
could simply return the logits (the model output), but this may or may
|
49 |
+
not provide a meaningful attribution.
|
50 |
+
|
51 |
+
This method, unlike other attribution methods, requires a batch
|
52 |
+
of examples to compute attributions and cannot be performed on a single example.
|
53 |
+
|
54 |
+
By default, each scalar value within
|
55 |
+
each input tensor is taken as a feature and shuffled independently. Passing
|
56 |
+
a feature mask allows grouping features to be shuffled together.
|
57 |
+
Each input scalar in the group will be given the same attribution value
|
58 |
+
equal to the change in target as a result of shuffling the entire feature
|
59 |
+
group.
|
60 |
+
|
61 |
+
The forward function can either return a scalar per example, or a single
|
62 |
+
scalar for the full batch. If a single scalar is returned for the batch,
|
63 |
+
`perturbations_per_eval` must be 1, and the returned attributions will have
|
64 |
+
first dimension 1, corresponding to feature importance across all
|
65 |
+
examples in the batch.
|
66 |
+
|
67 |
+
More information can be found in the permutation feature
|
68 |
+
importance algorithm description here:
|
69 |
+
https://christophm.github.io/interpretable-ml-book/feature-importance.html
|
70 |
+
"""
|
71 |
+
|
72 |
+
def __init__(
|
73 |
+
self, forward_func: Callable, perm_func: Callable = _permute_feature
|
74 |
+
) -> None:
|
75 |
+
r"""
|
76 |
+
Args:
|
77 |
+
|
78 |
+
forward_func (callable): The forward function of the model or
|
79 |
+
any modification of it
|
80 |
+
perm_func (callable, optional): A function that accepts a batch of
|
81 |
+
inputs and a feature mask, and "permutes" the feature using
|
82 |
+
feature mask across the batch. This defaults to a function
|
83 |
+
which applies a random permutation; this argument only needs
|
84 |
+
to be provided if a custom permutation behavior is desired.
|
85 |
+
Default: `_permute_feature`
|
86 |
+
"""
|
87 |
+
FeatureAblation.__init__(self, forward_func=forward_func)
|
88 |
+
self.perm_func = perm_func
|
89 |
+
|
90 |
+
# suppressing error caused by the child class not having a matching
|
91 |
+
# signature to the parent
|
92 |
+
@log_usage()
|
93 |
+
def attribute( # type: ignore
|
94 |
+
self,
|
95 |
+
inputs: TensorOrTupleOfTensorsGeneric,
|
96 |
+
target: TargetType = None,
|
97 |
+
additional_forward_args: Any = None,
|
98 |
+
feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None,
|
99 |
+
perturbations_per_eval: int = 1,
|
100 |
+
show_progress: bool = False,
|
101 |
+
**kwargs: Any,
|
102 |
+
) -> TensorOrTupleOfTensorsGeneric:
|
103 |
+
r"""
|
104 |
+
This function is almost equivalent to `FeatureAblation.attribute`. The
|
105 |
+
main difference is the way ablated examples are generated. Specifically
|
106 |
+
they are generated through the `perm_func`, as we set the baselines for
|
107 |
+
`FeatureAblation.attribute` to None.
|
108 |
+
|
109 |
+
|
110 |
+
Args:
|
111 |
+
inputs (tensor or tuple of tensors): Input for which
|
112 |
+
permutation attributions are computed. If
|
113 |
+
forward_func takes a single tensor as input, a
|
114 |
+
single input tensor should be provided. If
|
115 |
+
forward_func takes multiple tensors as input, a
|
116 |
+
tuple of the input tensors should be provided. It is
|
117 |
+
assumed that for all given input tensors, dimension
|
118 |
+
0 corresponds to the number of examples (aka batch
|
119 |
+
size), and if multiple input tensors are provided,
|
120 |
+
the examples must be aligned appropriately.
|
121 |
+
target (int, tuple, tensor or list, optional): Output indices for
|
122 |
+
which difference is computed (for classification cases,
|
123 |
+
this is usually the target class).
|
124 |
+
If the network returns a scalar value per example,
|
125 |
+
no target index is necessary.
|
126 |
+
For general 2D outputs, targets can be either:
|
127 |
+
|
128 |
+
- a single integer or a tensor containing a single
|
129 |
+
integer, which is applied to all input examples
|
130 |
+
|
131 |
+
- a list of integers or a 1D tensor, with length matching
|
132 |
+
the number of examples in inputs (dim 0). Each integer
|
133 |
+
is applied as the target for the corresponding example.
|
134 |
+
|
135 |
+
For outputs with > 2 dimensions, targets can be either:
|
136 |
+
|
137 |
+
- A single tuple, which contains #output_dims - 1
|
138 |
+
elements. This target index is applied to all examples.
|
139 |
+
|
140 |
+
- A list of tuples with length equal to the number of
|
141 |
+
examples in inputs (dim 0), and each tuple containing
|
142 |
+
#output_dims - 1 elements. Each tuple is applied as the
|
143 |
+
target for the corresponding example.
|
144 |
+
|
145 |
+
Default: None
|
146 |
+
additional_forward_args (any, optional): If the forward function
|
147 |
+
requires additional arguments other than the inputs for
|
148 |
+
which attributions should not be computed, this argument
|
149 |
+
can be provided. It must be either a single additional
|
150 |
+
argument of a Tensor or arbitrary (non-tuple) type or a
|
151 |
+
tuple containing multiple additional arguments including
|
152 |
+
tensors or any arbitrary python types. These arguments
|
153 |
+
are provided to forward_func in order following the
|
154 |
+
arguments in inputs.
|
155 |
+
For a tensor, the first dimension of the tensor must
|
156 |
+
correspond to the number of examples. For all other types,
|
157 |
+
the given argument is used for all forward evaluations.
|
158 |
+
Note that attributions are not computed with respect
|
159 |
+
to these arguments.
|
160 |
+
Default: None
|
161 |
+
feature_mask (tensor or tuple of tensors, optional):
|
162 |
+
feature_mask defines a mask for the input, grouping
|
163 |
+
features which should be ablated together. feature_mask
|
164 |
+
should contain the same number of tensors as inputs.
|
165 |
+
Each tensor should be the same size as the
|
166 |
+
corresponding input or broadcastable to match the
|
167 |
+
input tensor. Each tensor should contain integers in
|
168 |
+
the range 0 to num_features - 1, and indices
|
169 |
+
corresponding to the same feature should have the
|
170 |
+
same value. Note that features within each input
|
171 |
+
tensor are ablated independently (not across
|
172 |
+
tensors).
|
173 |
+
|
174 |
+
The first dimension of each mask must be 1, as we require
|
175 |
+
to have the same group of features for each input sample.
|
176 |
+
|
177 |
+
If None, then a feature mask is constructed which assigns
|
178 |
+
each scalar within a tensor as a separate feature, which
|
179 |
+
is permuted independently.
|
180 |
+
Default: None
|
181 |
+
perturbations_per_eval (int, optional): Allows permutations
|
182 |
+
of multiple features to be processed simultaneously
|
183 |
+
in one call to forward_fn. Each forward pass will
|
184 |
+
contain a maximum of perturbations_per_eval * #examples
|
185 |
+
samples. For DataParallel models, each batch is
|
186 |
+
split among the available devices, so evaluations on
|
187 |
+
each available device contain at most
|
188 |
+
(perturbations_per_eval * #examples) / num_devices
|
189 |
+
samples.
|
190 |
+
If the forward function returns a single scalar per batch,
|
191 |
+
perturbations_per_eval must be set to 1.
|
192 |
+
Default: 1
|
193 |
+
show_progress (bool, optional): Displays the progress of computation.
|
194 |
+
It will try to use tqdm if available for advanced features
|
195 |
+
(e.g. time estimation). Otherwise, it will fallback to
|
196 |
+
a simple output of progress.
|
197 |
+
Default: False
|
198 |
+
**kwargs (Any, optional): Any additional arguments used by child
|
199 |
+
classes of FeatureAblation (such as Occlusion) to construct
|
200 |
+
ablations. These arguments are ignored when using
|
201 |
+
FeatureAblation directly.
|
202 |
+
Default: None
|
203 |
+
|
204 |
+
Returns:
|
205 |
+
*tensor* or tuple of *tensors* of **attributions**:
|
206 |
+
- **attributions** (*tensor* or tuple of *tensors*):
|
207 |
+
The attributions with respect to each input feature.
|
208 |
+
If the forward function returns
|
209 |
+
a scalar value per example, attributions will be
|
210 |
+
the same size as the provided inputs, with each value
|
211 |
+
providing the attribution of the corresponding input index.
|
212 |
+
If the forward function returns a scalar per batch, then
|
213 |
+
attribution tensor(s) will have first dimension 1 and
|
214 |
+
the remaining dimensions will match the input.
|
215 |
+
If a single tensor is provided as inputs, a single tensor is
|
216 |
+
returned. If a tuple of tensors is provided for inputs,
|
217 |
+
a tuple of corresponding sized tensors is returned.
|
218 |
+
|
219 |
+
|
220 |
+
Examples::
|
221 |
+
|
222 |
+
>>> # SimpleClassifier takes a single input tensor of size Nx4x4,
|
223 |
+
>>> # and returns an Nx3 tensor of class probabilities.
|
224 |
+
>>> net = SimpleClassifier()
|
225 |
+
>>> # Generating random input with size 10 x 4 x 4
|
226 |
+
>>> input = torch.randn(10, 4, 4)
|
227 |
+
>>> # Defining FeaturePermutation interpreter
|
228 |
+
>>> feature_perm = FeaturePermutation(net)
|
229 |
+
>>> # Computes permutation attribution, shuffling each of the 16
|
230 |
+
>>> # scalar input independently.
|
231 |
+
>>> attr = feature_perm.attribute(input, target=1)
|
232 |
+
|
233 |
+
>>> # Alternatively, we may want to permute features in groups, e.g.
|
234 |
+
>>> # grouping each 2x2 square of the inputs and shuffling them together.
|
235 |
+
>>> # This can be done by creating a feature mask as follows, which
|
236 |
+
>>> # defines the feature groups, e.g.:
|
237 |
+
>>> # +---+---+---+---+
|
238 |
+
>>> # | 0 | 0 | 1 | 1 |
|
239 |
+
>>> # +---+---+---+---+
|
240 |
+
>>> # | 0 | 0 | 1 | 1 |
|
241 |
+
>>> # +---+---+---+---+
|
242 |
+
>>> # | 2 | 2 | 3 | 3 |
|
243 |
+
>>> # +---+---+---+---+
|
244 |
+
>>> # | 2 | 2 | 3 | 3 |
|
245 |
+
>>> # +---+---+---+---+
|
246 |
+
>>> # With this mask, all inputs with the same value are shuffled
|
247 |
+
>>> # simultaneously, and the attribution for each input in the same
|
248 |
+
>>> # group (0, 1, 2, and 3) per example are the same.
|
249 |
+
>>> # The attributions can be calculated as follows:
|
250 |
+
>>> # feature mask has dimensions 1 x 4 x 4
|
251 |
+
>>> feature_mask = torch.tensor([[[0,0,1,1],[0,0,1,1],
|
252 |
+
>>> [2,2,3,3],[2,2,3,3]]])
|
253 |
+
>>> attr = feature_perm.attribute(input, target=1,
|
254 |
+
>>> feature_mask=feature_mask)
|
255 |
+
"""
|
256 |
+
return FeatureAblation.attribute.__wrapped__(
|
257 |
+
self,
|
258 |
+
inputs,
|
259 |
+
baselines=None,
|
260 |
+
target=target,
|
261 |
+
additional_forward_args=additional_forward_args,
|
262 |
+
feature_mask=feature_mask,
|
263 |
+
perturbations_per_eval=perturbations_per_eval,
|
264 |
+
show_progress=show_progress,
|
265 |
+
**kwargs,
|
266 |
+
)
|
267 |
+
|
268 |
+
def _construct_ablated_input(
|
269 |
+
self,
|
270 |
+
expanded_input: Tensor,
|
271 |
+
input_mask: Tensor,
|
272 |
+
baseline: Union[int, float, Tensor],
|
273 |
+
start_feature: int,
|
274 |
+
end_feature: int,
|
275 |
+
**kwargs: Any,
|
276 |
+
) -> Tuple[Tensor, Tensor]:
|
277 |
+
r"""
|
278 |
+
This function permutes the features of `expanded_input` with a given
|
279 |
+
feature mask and feature range. Permutation occurs via calling
|
280 |
+
`self.perm_func` across each batch within `expanded_input`. As with
|
281 |
+
`FeatureAblation._construct_ablated_input`:
|
282 |
+
- `expanded_input.shape = (num_features, num_examples, ...)`
|
283 |
+
- `num_features = end_feature - start_feature` (i.e. start and end is a
|
284 |
+
half-closed interval)
|
285 |
+
- `input_mask` is a tensor of the same shape as one input, which
|
286 |
+
describes the locations of each feature via their "index"
|
287 |
+
|
288 |
+
Since `baselines` is set to None for `FeatureAblation.attribute`, this
|
289 |
+
will be the zero tensor, however, it is not used.
|
290 |
+
"""
|
291 |
+
assert input_mask.shape[0] == 1, (
|
292 |
+
"input_mask.shape[0] != 1: pass in one mask in order to permute"
|
293 |
+
"the same features for each input"
|
294 |
+
)
|
295 |
+
current_mask = torch.stack(
|
296 |
+
[input_mask == j for j in range(start_feature, end_feature)], dim=0
|
297 |
+
).bool()
|
298 |
+
|
299 |
+
output = torch.stack(
|
300 |
+
[
|
301 |
+
self.perm_func(x, mask.squeeze(0))
|
302 |
+
for x, mask in zip(expanded_input, current_mask)
|
303 |
+
]
|
304 |
+
)
|
305 |
+
return output, current_mask
|
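A minimal usage sketch of FeaturePermutation follows. The SimpleClassifier module and the roll_feature permutation function are hypothetical, included only to illustrate the default behavior and the perm_func hook; the attribute() signature matches the file above.

import torch
import torch.nn as nn
from captum.attr import FeaturePermutation

# Hypothetical stand-in model: N x 4 x 4 inputs -> N x 3 class scores.
class SimpleClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 3)

    def forward(self, x):
        return self.fc(x.flatten(start_dim=1))

net = SimpleClassifier()
# Permutation attribution needs a batch (batch_size > 1) to shuffle across.
inputs = torch.randn(10, 4, 4)

feature_perm = FeaturePermutation(net)
attr = feature_perm.attribute(inputs, target=1)

# Hypothetical custom perm_func: rolls the batch by one position instead of
# drawing a random permutation, only where the boolean feature_mask is set.
def roll_feature(x, feature_mask):
    mask = feature_mask.to(dtype=x.dtype)
    return x.roll(1, dims=0) * mask + x * (1 - mask)

rolling_perm = FeaturePermutation(net, perm_func=roll_feature)
attr_rolled = rolling_perm.attribute(inputs, target=1)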
captum/attr/_core/gradient_shap.py
ADDED
@@ -0,0 +1,414 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
import typing
|
3 |
+
from typing import Any, Callable, Tuple, Union
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from captum._utils.common import _is_tuple
|
8 |
+
from captum._utils.typing import (
|
9 |
+
BaselineType,
|
10 |
+
Literal,
|
11 |
+
TargetType,
|
12 |
+
Tensor,
|
13 |
+
TensorOrTupleOfTensorsGeneric,
|
14 |
+
)
|
15 |
+
from captum.attr._core.noise_tunnel import NoiseTunnel
|
16 |
+
from captum.attr._utils.attribution import GradientAttribution
|
17 |
+
from captum.attr._utils.common import (
|
18 |
+
_compute_conv_delta_and_format_attrs,
|
19 |
+
_format_callable_baseline,
|
20 |
+
_format_input_baseline,
|
21 |
+
)
|
22 |
+
from captum.log import log_usage
|
23 |
+
|
24 |
+
|
25 |
+
class GradientShap(GradientAttribution):
|
26 |
+
r"""
|
27 |
+
Implements gradient SHAP based on the implementation from SHAP's primary
|
28 |
+
author. For reference, please view the original
|
29 |
+
`implementation
|
30 |
+
<https://github.com/slundberg/shap#deep-learning-example-with-gradientexplainer-tensorflowkeraspytorch-models>`_
|
31 |
+
and the paper: `A Unified Approach to Interpreting Model Predictions
|
32 |
+
<https://papers.nips.cc/paper/7062-a-unified-approach-to-interpreting-model-predictions>`_
|
33 |
+
|
34 |
+
GradientShap approximates SHAP values by computing the expectations of
|
35 |
+
gradients by randomly sampling from the distribution of baselines/references.
|
36 |
+
It adds white noise to each input sample `n_samples` times, selects a
|
37 |
+
random baseline from baselines' distribution and a random point along the
|
38 |
+
path between the baseline and the input, and computes the gradient of outputs
|
39 |
+
with respect to those selected random points. The final SHAP values represent
|
40 |
+
the expected values of gradients * (inputs - baselines).
|
41 |
+
|
42 |
+
GradientShap makes an assumption that the input features are independent
|
43 |
+
and that the explanation model is linear, meaning that the explanations
|
44 |
+
are modeled through the additive composition of feature effects.
|
45 |
+
Under those assumptions, SHAP value can be approximated as the expectation
|
46 |
+
of gradients that are computed for randomly generated `n_samples` input
|
47 |
+
samples after adding gaussian noise `n_samples` times to each input for
|
48 |
+
different baselines/references.
|
49 |
+
|
50 |
+
In some sense it can be viewed as an approximation of integrated gradients
|
51 |
+
by computing the expectations of gradients for different baselines.
|
52 |
+
|
53 |
+
Current implementation uses Smoothgrad from `NoiseTunnel` in order to
|
54 |
+
randomly draw samples from the distribution of baselines, add noise to input
|
55 |
+
samples and compute the expectation (smoothgrad).
|
56 |
+
"""
|
57 |
+
|
58 |
+
def __init__(self, forward_func: Callable, multiply_by_inputs: bool = True) -> None:
|
59 |
+
r"""
|
60 |
+
Args:
|
61 |
+
|
62 |
+
forward_func (function): The forward function of the model or
|
63 |
+
any modification of it.
|
64 |
+
multiply_by_inputs (bool, optional): Indicates whether to factor
|
65 |
+
model inputs' multiplier in the final attribution scores.
|
66 |
+
In the literature this is also known as local vs global
|
67 |
+
attribution. If inputs' multiplier isn't factored in
|
68 |
+
then this type of attribution method is also called local
|
69 |
+
attribution. If it is, then that type of attribution
|
70 |
+
method is called global.
|
71 |
+
More details can be found here:
|
72 |
+
https://arxiv.org/abs/1711.06104
|
73 |
+
|
74 |
+
In case of gradient shap, if `multiply_by_inputs`
|
75 |
+
is set to True, the sensitivity scores of scaled inputs
|
76 |
+
are being multiplied by (inputs - baselines).
|
77 |
+
"""
|
78 |
+
GradientAttribution.__init__(self, forward_func)
|
79 |
+
self._multiply_by_inputs = multiply_by_inputs
|
80 |
+
|
81 |
+
@typing.overload
|
82 |
+
def attribute(
|
83 |
+
self,
|
84 |
+
inputs: TensorOrTupleOfTensorsGeneric,
|
85 |
+
baselines: Union[
|
86 |
+
TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric]
|
87 |
+
],
|
88 |
+
n_samples: int = 5,
|
89 |
+
stdevs: Union[float, Tuple[float, ...]] = 0.0,
|
90 |
+
target: TargetType = None,
|
91 |
+
additional_forward_args: Any = None,
|
92 |
+
*,
|
93 |
+
return_convergence_delta: Literal[True],
|
94 |
+
) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]:
|
95 |
+
...
|
96 |
+
|
97 |
+
@typing.overload
|
98 |
+
def attribute(
|
99 |
+
self,
|
100 |
+
inputs: TensorOrTupleOfTensorsGeneric,
|
101 |
+
baselines: Union[
|
102 |
+
TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric]
|
103 |
+
],
|
104 |
+
n_samples: int = 5,
|
105 |
+
stdevs: Union[float, Tuple[float, ...]] = 0.0,
|
106 |
+
target: TargetType = None,
|
107 |
+
additional_forward_args: Any = None,
|
108 |
+
return_convergence_delta: Literal[False] = False,
|
109 |
+
) -> TensorOrTupleOfTensorsGeneric:
|
110 |
+
...
|
111 |
+
|
112 |
+
@log_usage()
|
113 |
+
def attribute(
|
114 |
+
self,
|
115 |
+
inputs: TensorOrTupleOfTensorsGeneric,
|
116 |
+
baselines: Union[
|
117 |
+
TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric]
|
118 |
+
],
|
119 |
+
n_samples: int = 5,
|
120 |
+
stdevs: Union[float, Tuple[float, ...]] = 0.0,
|
121 |
+
target: TargetType = None,
|
122 |
+
additional_forward_args: Any = None,
|
123 |
+
return_convergence_delta: bool = False,
|
124 |
+
) -> Union[
|
125 |
+
TensorOrTupleOfTensorsGeneric, Tuple[TensorOrTupleOfTensorsGeneric, Tensor]
|
126 |
+
]:
|
127 |
+
r"""
|
128 |
+
Args:
|
129 |
+
|
130 |
+
inputs (tensor or tuple of tensors): Input for which SHAP attribution
|
131 |
+
values are computed. If `forward_func` takes a single
|
132 |
+
tensor as input, a single input tensor should be provided.
|
133 |
+
If `forward_func` takes multiple tensors as input, a tuple
|
134 |
+
of the input tensors should be provided. It is assumed
|
135 |
+
that for all given input tensors, dimension 0 corresponds
|
136 |
+
to the number of examples, and if multiple input tensors
|
137 |
+
are provided, the examples must be aligned appropriately.
|
138 |
+
baselines (tensor, tuple of tensors, callable):
|
139 |
+
Baselines define the starting point from which expectation
|
140 |
+
is computed and can be provided as:
|
141 |
+
|
142 |
+
- a single tensor, if inputs is a single tensor, with
|
143 |
+
the first dimension equal to the number of examples
|
144 |
+
in the baselines' distribution. The remaining dimensions
|
145 |
+
must match with input tensor's dimension starting from
|
146 |
+
the second dimension.
|
147 |
+
|
148 |
+
- a tuple of tensors, if inputs is a tuple of tensors,
|
149 |
+
with the first dimension of any tensor inside the tuple
|
150 |
+
equal to the number of examples in the baseline's
|
151 |
+
distribution. The remaining dimensions must match
|
152 |
+
the dimensions of the corresponding input tensor
|
153 |
+
starting from the second dimension.
|
154 |
+
|
155 |
+
- callable function, optionally takes `inputs` as an
|
156 |
+
argument and either returns a single tensor
|
157 |
+
or a tuple of those.
|
158 |
+
|
159 |
+
It is recommended that the number of samples in the baselines'
|
160 |
+
tensors is larger than one.
|
161 |
+
n_samples (int, optional): The number of randomly generated examples
|
162 |
+
per sample in the input batch. Random examples are
|
163 |
+
generated by adding gaussian random noise to each sample.
|
164 |
+
Default: `5` if `n_samples` is not provided.
|
165 |
+
stdevs (float, or a tuple of floats, optional): The standard deviation
|
166 |
+
of gaussian noise with zero mean that is added to each
|
167 |
+
input in the batch. If `stdevs` is a single float value
|
168 |
+
then that same value is used for all inputs. If it is
|
169 |
+
a tuple, then it must have the same length as the inputs
|
170 |
+
tuple. In this case, each stdev value in the stdevs tuple
|
171 |
+
corresponds to the input with the same index in the inputs
|
172 |
+
tuple.
|
173 |
+
Default: 0.0
|
174 |
+
target (int, tuple, tensor or list, optional): Output indices for
|
175 |
+
which gradients are computed (for classification cases,
|
176 |
+
this is usually the target class).
|
177 |
+
If the network returns a scalar value per example,
|
178 |
+
no target index is necessary.
|
179 |
+
For general 2D outputs, targets can be either:
|
180 |
+
|
181 |
+
- a single integer or a tensor containing a single
|
182 |
+
integer, which is applied to all input examples
|
183 |
+
|
184 |
+
- a list of integers or a 1D tensor, with length matching
|
185 |
+
the number of examples in inputs (dim 0). Each integer
|
186 |
+
is applied as the target for the corresponding example.
|
187 |
+
|
188 |
+
For outputs with > 2 dimensions, targets can be either:
|
189 |
+
|
190 |
+
- A single tuple, which contains #output_dims - 1
|
191 |
+
elements. This target index is applied to all examples.
|
192 |
+
|
193 |
+
- A list of tuples with length equal to the number of
|
194 |
+
examples in inputs (dim 0), and each tuple containing
|
195 |
+
#output_dims - 1 elements. Each tuple is applied as the
|
196 |
+
target for the corresponding example.
|
197 |
+
|
198 |
+
Default: None
|
199 |
+
additional_forward_args (any, optional): If the forward function
|
200 |
+
requires additional arguments other than the inputs for
|
201 |
+
which attributions should not be computed, this argument
|
202 |
+
can be provided. It can contain a tuple of ND tensors or
|
203 |
+
any arbitrary python type of any shape.
|
204 |
+
In case of the ND tensor the first dimension of the
|
205 |
+
tensor must correspond to the batch size. It will be
|
206 |
+
repeated for each `n_steps` for each randomly generated
|
207 |
+
input sample.
|
208 |
+
Note that the gradients are not computed with respect
|
209 |
+
to these arguments.
|
210 |
+
Default: None
|
211 |
+
return_convergence_delta (bool, optional): Indicates whether to return
|
212 |
+
convergence delta or not. If `return_convergence_delta`
|
213 |
+
is set to True convergence delta will be returned in
|
214 |
+
a tuple following attributions.
|
215 |
+
Default: False
|
216 |
+
Returns:
|
217 |
+
**attributions** or 2-element tuple of **attributions**, **delta**:
|
218 |
+
- **attributions** (*tensor* or tuple of *tensors*):
|
219 |
+
Attribution score computed based on GradientSHAP with respect
|
220 |
+
to each input feature. Attributions will always be
|
221 |
+
the same size as the provided inputs, with each value
|
222 |
+
providing the attribution of the corresponding input index.
|
223 |
+
If a single tensor is provided as inputs, a single tensor is
|
224 |
+
returned. If a tuple is provided for inputs, a tuple of
|
225 |
+
corresponding sized tensors is returned.
|
226 |
+
- **delta** (*tensor*, returned if return_convergence_delta=True):
|
227 |
+
This is computed using the property that the total
|
228 |
+
sum of forward_func(inputs) - forward_func(baselines)
|
229 |
+
must be very close to the total sum of the attributions
|
230 |
+
based on GradientSHAP.
|
231 |
+
Delta is calculated for each example in the input after adding
|
232 |
+
`n_samples` times gaussian noise to each of them. Therefore,
|
233 |
+
the dimensionality of the deltas tensor is equal to the
|
234 |
+
`number of examples in the input` * `n_samples`
|
235 |
+
The deltas are ordered by each input example and `n_samples`
|
236 |
+
noisy samples generated for it.
|
237 |
+
|
238 |
+
Examples::
|
239 |
+
|
240 |
+
>>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
|
241 |
+
>>> # and returns an Nx10 tensor of class probabilities.
|
242 |
+
>>> net = ImageClassifier()
|
243 |
+
>>> gradient_shap = GradientShap(net)
|
244 |
+
>>> input = torch.randn(3, 3, 32, 32, requires_grad=True)
|
245 |
+
>>> # choosing baselines randomly
|
246 |
+
>>> baselines = torch.randn(20, 3, 32, 32)
|
247 |
+
>>> # Computes gradient shap for the input
|
248 |
+
>>> # Attribution size matches input size: 3x3x32x32
|
249 |
+
>>> attribution = gradient_shap.attribute(input, baselines,
|
250 |
+
target=5)
|
251 |
+
|
252 |
+
"""
|
253 |
+
# since `baselines` is a distribution, we can generate it using a function
|
254 |
+
# rather than passing it as an input argument
|
255 |
+
baselines = _format_callable_baseline(baselines, inputs)
|
256 |
+
assert isinstance(baselines[0], torch.Tensor), (
|
257 |
+
"Baselines distribution has to be provided in a form "
|
258 |
+
"of a torch.Tensor {}.".format(baselines[0])
|
259 |
+
)
|
260 |
+
|
261 |
+
input_min_baseline_x_grad = InputBaselineXGradient(
|
262 |
+
self.forward_func, self.multiplies_by_inputs
|
263 |
+
)
|
264 |
+
input_min_baseline_x_grad.gradient_func = self.gradient_func
|
265 |
+
|
266 |
+
nt = NoiseTunnel(input_min_baseline_x_grad)
|
267 |
+
|
268 |
+
# NOTE: using attribute.__wrapped__ to not log
|
269 |
+
attributions = nt.attribute.__wrapped__(
|
270 |
+
nt, # self
|
271 |
+
inputs,
|
272 |
+
nt_type="smoothgrad",
|
273 |
+
nt_samples=n_samples,
|
274 |
+
stdevs=stdevs,
|
275 |
+
draw_baseline_from_distrib=True,
|
276 |
+
baselines=baselines,
|
277 |
+
target=target,
|
278 |
+
additional_forward_args=additional_forward_args,
|
279 |
+
return_convergence_delta=return_convergence_delta,
|
280 |
+
)
|
281 |
+
|
282 |
+
return attributions
|
283 |
+
|
284 |
+
def has_convergence_delta(self) -> bool:
|
285 |
+
return True
|
286 |
+
|
287 |
+
@property
|
288 |
+
def multiplies_by_inputs(self):
|
289 |
+
return self._multiply_by_inputs
|
290 |
+
|
291 |
+
|
292 |
+
class InputBaselineXGradient(GradientAttribution):
|
293 |
+
def __init__(self, forward_func: Callable, multiply_by_inputs=True) -> None:
|
294 |
+
r"""
|
295 |
+
Args:
|
296 |
+
|
297 |
+
forward_func (function): The forward function of the model or
|
298 |
+
any modification of it
|
299 |
+
multiply_by_inputs (bool, optional): Indicates whether to factor
|
300 |
+
model inputs' multiplier in the final attribution scores.
|
301 |
+
In the literature this is also known as local vs global
|
302 |
+
attribution. If inputs' multiplier isn't factored in
|
303 |
+
then this type of attribution method is also called local
|
304 |
+
attribution. If it is, then that type of attribution
|
305 |
+
method is called global.
|
306 |
+
More details can be found here:
|
307 |
+
https://arxiv.org/abs/1711.06104
|
308 |
+
|
309 |
+
In case of gradient shap, if `multiply_by_inputs`
|
310 |
+
is set to True, the sensitivity scores of scaled inputs
|
311 |
+
are being multiplied by (inputs - baselines).
|
312 |
+
|
313 |
+
"""
|
314 |
+
GradientAttribution.__init__(self, forward_func)
|
315 |
+
self._multiply_by_inputs = multiply_by_inputs
|
316 |
+
|
317 |
+
@typing.overload
|
318 |
+
def attribute(
|
319 |
+
self,
|
320 |
+
inputs: TensorOrTupleOfTensorsGeneric,
|
321 |
+
baselines: BaselineType = None,
|
322 |
+
target: TargetType = None,
|
323 |
+
additional_forward_args: Any = None,
|
324 |
+
*,
|
325 |
+
return_convergence_delta: Literal[True],
|
326 |
+
) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]:
|
327 |
+
...
|
328 |
+
|
329 |
+
@typing.overload
|
330 |
+
def attribute(
|
331 |
+
self,
|
332 |
+
inputs: TensorOrTupleOfTensorsGeneric,
|
333 |
+
baselines: BaselineType = None,
|
334 |
+
target: TargetType = None,
|
335 |
+
additional_forward_args: Any = None,
|
336 |
+
return_convergence_delta: Literal[False] = False,
|
337 |
+
) -> TensorOrTupleOfTensorsGeneric:
|
338 |
+
...
|
339 |
+
|
340 |
+
@log_usage()
|
341 |
+
def attribute( # type: ignore
|
342 |
+
self,
|
343 |
+
inputs: TensorOrTupleOfTensorsGeneric,
|
344 |
+
baselines: BaselineType = None,
|
345 |
+
target: TargetType = None,
|
346 |
+
additional_forward_args: Any = None,
|
347 |
+
return_convergence_delta: bool = False,
|
348 |
+
) -> Union[
|
349 |
+
TensorOrTupleOfTensorsGeneric, Tuple[TensorOrTupleOfTensorsGeneric, Tensor]
|
350 |
+
]:
|
351 |
+
# Keeps track whether original input is a tuple or not before
|
352 |
+
# converting it into a tuple.
|
353 |
+
is_inputs_tuple = _is_tuple(inputs)
|
354 |
+
inputs, baselines = _format_input_baseline(inputs, baselines)
|
355 |
+
|
356 |
+
rand_coefficient = torch.tensor(
|
357 |
+
np.random.uniform(0.0, 1.0, inputs[0].shape[0]),
|
358 |
+
device=inputs[0].device,
|
359 |
+
dtype=inputs[0].dtype,
|
360 |
+
)
|
361 |
+
|
362 |
+
input_baseline_scaled = tuple(
|
363 |
+
_scale_input(input, baseline, rand_coefficient)
|
364 |
+
for input, baseline in zip(inputs, baselines)
|
365 |
+
)
|
366 |
+
grads = self.gradient_func(
|
367 |
+
self.forward_func, input_baseline_scaled, target, additional_forward_args
|
368 |
+
)
|
369 |
+
|
370 |
+
if self.multiplies_by_inputs:
|
371 |
+
input_baseline_diffs = tuple(
|
372 |
+
input - baseline for input, baseline in zip(inputs, baselines)
|
373 |
+
)
|
374 |
+
attributions = tuple(
|
375 |
+
input_baseline_diff * grad
|
376 |
+
for input_baseline_diff, grad in zip(input_baseline_diffs, grads)
|
377 |
+
)
|
378 |
+
else:
|
379 |
+
attributions = grads
|
380 |
+
|
381 |
+
return _compute_conv_delta_and_format_attrs(
|
382 |
+
self,
|
383 |
+
return_convergence_delta,
|
384 |
+
attributions,
|
385 |
+
baselines,
|
386 |
+
inputs,
|
387 |
+
additional_forward_args,
|
388 |
+
target,
|
389 |
+
is_inputs_tuple,
|
390 |
+
)
|
391 |
+
|
392 |
+
def has_convergence_delta(self) -> bool:
|
393 |
+
return True
|
394 |
+
|
395 |
+
@property
|
396 |
+
def multiplies_by_inputs(self):
|
397 |
+
return self._multiply_by_inputs
|
398 |
+
|
399 |
+
|
400 |
+
def _scale_input(
|
401 |
+
input: Tensor, baseline: Union[Tensor, int, float], rand_coefficient: Tensor
|
402 |
+
) -> Tensor:
|
403 |
+
# batch size
|
404 |
+
bsz = input.shape[0]
|
405 |
+
inp_shape_wo_bsz = input.shape[1:]
|
406 |
+
inp_shape = (bsz,) + tuple([1] * len(inp_shape_wo_bsz))
|
407 |
+
|
408 |
+
# expand and reshape the indices
|
409 |
+
rand_coefficient = rand_coefficient.view(inp_shape)
|
410 |
+
|
411 |
+
input_baseline_scaled = (
|
412 |
+
rand_coefficient * input + (1.0 - rand_coefficient) * baseline
|
413 |
+
).requires_grad_()
|
414 |
+
return input_baseline_scaled
|
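The random scaling in `_scale_input` above is the heart of GradientShap's sampling step: each example is moved a random fraction of the way from a sampled baseline toward the input before gradients are taken, and the "global" variant multiplies the gradient by (input - baseline). A minimal plain-PyTorch sketch of that step, with toy shapes and names that are illustrative only and not part of the committed file:

import torch

# Toy setup: a batch of 4 inputs and 4 baselines sampled from a distribution.
inputs = torch.randn(4, 3)
baselines = torch.randn(4, 3)

# One uniform coefficient per example, reshaped to broadcast over features,
# mirroring the rand_coefficient / _scale_input logic in the file above.
alpha = torch.rand(4).view(4, 1)
scaled = alpha * inputs + (1.0 - alpha) * baselines
scaled.requires_grad_()

# Gradients of a simple scalar forward function at the scaled points.
out = (scaled ** 2).sum()
out.backward()

# "Global" attribution multiplies the gradient by (input - baseline).
attribution = (inputs - baselines) * scaled.grad
print(attribution.shape)  # torch.Size([4, 3])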
captum/attr/_core/guided_backprop_deconvnet.py
ADDED
@@ -0,0 +1,322 @@
#!/usr/bin/env python3
import warnings
from typing import Any, List, Tuple, Union

import torch
import torch.nn.functional as F
from captum._utils.common import (
    _format_output,
    _format_tensor_into_tuples,
    _is_tuple,
    _register_backward_hook,
)
from captum._utils.gradient import (
    apply_gradient_requirements,
    undo_gradient_requirements,
)
from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric
from captum.attr._utils.attribution import GradientAttribution
from captum.log import log_usage
from torch import Tensor
from torch.nn import Module
from torch.utils.hooks import RemovableHandle


class ModifiedReluGradientAttribution(GradientAttribution):
    def __init__(self, model: Module, use_relu_grad_output: bool = False) -> None:
        r"""
        Args:

            model (nn.Module): The reference to PyTorch model instance.
        """
        GradientAttribution.__init__(self, model)
        self.model = model
        self.backward_hooks: List[RemovableHandle] = []
        self.use_relu_grad_output = use_relu_grad_output
        assert isinstance(self.model, torch.nn.Module), (
            "Given model must be an instance of torch.nn.Module to properly hook"
            " ReLU layers."
        )

    @log_usage()
    def attribute(
        self,
        inputs: TensorOrTupleOfTensorsGeneric,
        target: TargetType = None,
        additional_forward_args: Any = None,
    ) -> TensorOrTupleOfTensorsGeneric:
        r"""
        Computes attribution by overriding relu gradients. Based on constructor
        flag use_relu_grad_output, performs either GuidedBackpropagation if False
        and Deconvolution if True. This class is the parent class of both these
        methods, more information on usage can be found in the docstrings for each
        implementing class.
        """

        # Keeps track whether original input is a tuple or not before
        # converting it into a tuple.
        is_inputs_tuple = _is_tuple(inputs)

        inputs = _format_tensor_into_tuples(inputs)
        gradient_mask = apply_gradient_requirements(inputs)

        # set hooks for overriding ReLU gradients
        warnings.warn(
            "Setting backward hooks on ReLU activations."
            "The hooks will be removed after the attribution is finished"
        )
        try:
            self.model.apply(self._register_hooks)

            gradients = self.gradient_func(
                self.forward_func, inputs, target, additional_forward_args
            )
        finally:
            self._remove_hooks()

        undo_gradient_requirements(inputs, gradient_mask)
        return _format_output(is_inputs_tuple, gradients)

    def _register_hooks(self, module: Module):
        if isinstance(module, torch.nn.ReLU):
            hook = _register_backward_hook(module, self._backward_hook, self)
            self.backward_hooks.append(hook)

    def _backward_hook(
        self,
        module: Module,
        grad_input: Union[Tensor, Tuple[Tensor, ...]],
        grad_output: Union[Tensor, Tuple[Tensor, ...]],
    ):
        to_override_grads = grad_output if self.use_relu_grad_output else grad_input
        if isinstance(to_override_grads, tuple):
            return tuple(
                F.relu(to_override_grad) for to_override_grad in to_override_grads
            )
        else:
            return F.relu(to_override_grads)

    def _remove_hooks(self):
        for hook in self.backward_hooks:
            hook.remove()


class GuidedBackprop(ModifiedReluGradientAttribution):
    r"""
    Computes attribution using guided backpropagation. Guided backpropagation
    computes the gradient of the target output with respect to the input,
    but gradients of ReLU functions are overridden so that only
    non-negative gradients are backpropagated.

    More details regarding the guided backpropagation algorithm can be found
    in the original paper here:
    https://arxiv.org/abs/1412.6806

    Warning: Ensure that all ReLU operations in the forward function of the
    given model are performed using a module (nn.module.ReLU).
    If nn.functional.ReLU is used, gradients are not overridden appropriately.
    """

    def __init__(self, model: Module) -> None:
        r"""
        Args:

            model (nn.Module): The reference to PyTorch model instance. Model cannot
                        contain any in-place ReLU submodules; these are not
                        supported by the register_full_backward_hook PyTorch API.
        """
        ModifiedReluGradientAttribution.__init__(
            self, model, use_relu_grad_output=False
        )

    @log_usage()
    def attribute(
        self,
        inputs: TensorOrTupleOfTensorsGeneric,
        target: TargetType = None,
        additional_forward_args: Any = None,
    ) -> TensorOrTupleOfTensorsGeneric:
        r"""
        Args:

            inputs (tensor or tuple of tensors): Input for which
                        attributions are computed. If forward_func takes a single
                        tensor as input, a single input tensor should be provided.
                        If forward_func takes multiple tensors as input, a tuple
                        of the input tensors should be provided. It is assumed
                        that for all given input tensors, dimension 0 corresponds
                        to the number of examples (aka batch size), and if
                        multiple input tensors are provided, the examples must
                        be aligned appropriately.
            target (int, tuple, tensor or list, optional): Output indices for
                        which gradients are computed (for classification cases,
                        this is usually the target class).
                        If the network returns a scalar value per example,
                        no target index is necessary.
                        For general 2D outputs, targets can be either:

                        - a single integer or a tensor containing a single
                          integer, which is applied to all input examples

                        - a list of integers or a 1D tensor, with length matching
                          the number of examples in inputs (dim 0). Each integer
                          is applied as the target for the corresponding example.

                        For outputs with > 2 dimensions, targets can be either:

                        - A single tuple, which contains #output_dims - 1
                          elements. This target index is applied to all examples.

                        - A list of tuples with length equal to the number of
                          examples in inputs (dim 0), and each tuple containing
                          #output_dims - 1 elements. Each tuple is applied as the
                          target for the corresponding example.

                        Default: None
            additional_forward_args (any, optional): If the forward function
                        requires additional arguments other than the inputs for
                        which attributions should not be computed, this argument
                        can be provided. It must be either a single additional
                        argument of a Tensor or arbitrary (non-tuple) type or a tuple
                        containing multiple additional arguments including tensors
                        or any arbitrary python types. These arguments are provided to
                        forward_func in order, following the arguments in inputs.
                        Note that attributions are not computed with respect
                        to these arguments.
                        Default: None

        Returns:
                *tensor* or tuple of *tensors* of **attributions**:
                - **attributions** (*tensor* or tuple of *tensors*):
                            The guided backprop gradients with respect to each
                            input feature. Attributions will always
                            be the same size as the provided inputs, with each value
                            providing the attribution of the corresponding input index.
                            If a single tensor is provided as inputs, a single tensor is
                            returned. If a tuple is provided for inputs, a tuple of
                            corresponding sized tensors is returned.

        Examples::

            >>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
            >>> # and returns an Nx10 tensor of class probabilities.
            >>> net = ImageClassifier()
            >>> gbp = GuidedBackprop(net)
            >>> input = torch.randn(2, 3, 32, 32, requires_grad=True)
            >>> # Computes Guided Backprop attribution scores for class 3.
            >>> attribution = gbp.attribute(input, target=3)
        """
        return super().attribute.__wrapped__(
            self, inputs, target, additional_forward_args
        )


class Deconvolution(ModifiedReluGradientAttribution):
    r"""
    Computes attribution using deconvolution. Deconvolution
    computes the gradient of the target output with respect to the input,
    but gradients of ReLU functions are overridden so that the gradient
    of the ReLU input is simply computed taking ReLU of the output gradient,
    essentially only propagating non-negative gradients (without
    dependence on the sign of the ReLU input).

    More details regarding the deconvolution algorithm can be found
    in these papers:
    https://arxiv.org/abs/1311.2901
    https://link.springer.com/chapter/10.1007/978-3-319-46466-4_8

    Warning: Ensure that all ReLU operations in the forward function of the
    given model are performed using a module (nn.module.ReLU).
    If nn.functional.ReLU is used, gradients are not overridden appropriately.
    """

    def __init__(self, model: Module) -> None:
        r"""
        Args:

            model (nn.Module): The reference to PyTorch model instance. Model cannot
                        contain any in-place ReLU submodules; these are not
                        supported by the register_full_backward_hook PyTorch API.
        """
        ModifiedReluGradientAttribution.__init__(self, model, use_relu_grad_output=True)

    @log_usage()
    def attribute(
        self,
        inputs: TensorOrTupleOfTensorsGeneric,
        target: TargetType = None,
        additional_forward_args: Any = None,
    ) -> TensorOrTupleOfTensorsGeneric:
        r"""
        Args:

            inputs (tensor or tuple of tensors): Input for which
                        attributions are computed. If forward_func takes a single
                        tensor as input, a single input tensor should be provided.
                        If forward_func takes multiple tensors as input, a tuple
                        of the input tensors should be provided. It is assumed
                        that for all given input tensors, dimension 0 corresponds
                        to the number of examples (aka batch size), and if
                        multiple input tensors are provided, the examples must
                        be aligned appropriately.
            target (int, tuple, tensor or list, optional): Output indices for
                        which gradients are computed (for classification cases,
                        this is usually the target class).
                        If the network returns a scalar value per example,
                        no target index is necessary.
                        For general 2D outputs, targets can be either:

                        - a single integer or a tensor containing a single
                          integer, which is applied to all input examples

                        - a list of integers or a 1D tensor, with length matching
                          the number of examples in inputs (dim 0). Each integer
                          is applied as the target for the corresponding example.

                        For outputs with > 2 dimensions, targets can be either:

                        - A single tuple, which contains #output_dims - 1
                          elements. This target index is applied to all examples.

                        - A list of tuples with length equal to the number of
                          examples in inputs (dim 0), and each tuple containing
                          #output_dims - 1 elements. Each tuple is applied as the
                          target for the corresponding example.

                        Default: None
            additional_forward_args (any, optional): If the forward function
                        requires additional arguments other than the inputs for
                        which attributions should not be computed, this argument
                        can be provided. It must be either a single additional
                        argument of a Tensor or arbitrary (non-tuple) type or a tuple
                        containing multiple additional arguments including tensors
                        or any arbitrary python types. These arguments are provided to
                        forward_func in order, following the arguments in inputs.
                        Note that attributions are not computed with respect
                        to these arguments.
                        Default: None

        Returns:
                *tensor* or tuple of *tensors* of **attributions**:
                - **attributions** (*tensor* or tuple of *tensors*):
                            The deconvolution attributions with respect to each
                            input feature. Attributions will always
                            be the same size as the provided inputs, with each value
                            providing the attribution of the corresponding input index.
                            If a single tensor is provided as inputs, a single tensor is
                            returned. If a tuple is provided for inputs, a tuple of
                            corresponding sized tensors is returned.

        Examples::

            >>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
            >>> # and returns an Nx10 tensor of class probabilities.
            >>> net = ImageClassifier()
            >>> deconv = Deconvolution(net)
            >>> input = torch.randn(2, 3, 32, 32, requires_grad=True)
            >>> # Computes Deconvolution attribution scores for class 3.
            >>> attribution = deconv.attribute(input, target=3)
        """
        return super().attribute.__wrapped__(
            self, inputs, target, additional_forward_args
        )
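The two subclasses above differ only in which gradient the ReLU hook clamps: guided backpropagation clamps the gradient flowing back to the ReLU input (so it is also masked by the sign of the pre-activation), while deconvolution clamps the incoming output gradient regardless of the ReLU input's sign. A small numerical sketch of the three rules for one ReLU, using illustrative values only:

import torch
import torch.nn.functional as F

# Pre-activation values and an upstream gradient for a single ReLU.
x = torch.tensor([-1.0, 2.0, 3.0, -4.0])
grad_out = torch.tensor([0.5, -0.5, 1.0, 1.0])

# Plain backprop through ReLU: gradient only flows where x > 0.
plain = grad_out * (x > 0).float()

# Guided backprop: additionally clamp that gradient to be non-negative.
guided = F.relu(plain)

# Deconvolution: clamp the upstream gradient only, ignoring the sign of x.
deconv = F.relu(grad_out)

print(plain)   # tensor([ 0.0000, -0.5000,  1.0000,  0.0000])
print(guided)  # tensor([0.0000, 0.0000, 1.0000, 0.0000])
print(deconv)  # tensor([0.5000, 0.0000, 1.0000, 1.0000])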
captum/attr/_core/guided_grad_cam.py
ADDED
@@ -0,0 +1,226 @@
#!/usr/bin/env python3
import warnings
from typing import Any, List, Union

import torch
from captum._utils.common import _format_output, _format_tensor_into_tuples, _is_tuple
from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric
from captum.attr._core.guided_backprop_deconvnet import GuidedBackprop
from captum.attr._core.layer.grad_cam import LayerGradCam
from captum.attr._utils.attribution import GradientAttribution, LayerAttribution
from captum.log import log_usage
from torch import Tensor
from torch.nn import Module


class GuidedGradCam(GradientAttribution):
    r"""
    Computes element-wise product of guided backpropagation attributions
    with upsampled (non-negative) GradCAM attributions.
    GradCAM attributions are computed with respect to the layer
    provided in the constructor, and attributions
    are upsampled to match the input size. GradCAM is designed for
    convolutional neural networks, and is usually applied to the last
    convolutional layer.

    Note that if multiple input tensors are provided, attributions for
    each input tensor are computed by upsampling the GradCAM
    attributions to match that input's dimensions. If interpolation is
    not possible for the input tensor dimensions and interpolation mode,
    then an empty tensor is returned in the attributions for the
    corresponding position of that input tensor. This can occur if the
    input tensor does not have the same number of dimensions as the chosen
    layer's output or is not either 3D, 4D or 5D.

    Note that attributions are only meaningful for input tensors
    which are spatially alligned with the chosen layer, e.g. an input
    image tensor for a convolutional layer.

    More details regarding GuidedGradCAM can be found in the original
    GradCAM paper here:
    https://arxiv.org/pdf/1610.02391.pdf

    Warning: Ensure that all ReLU operations in the forward function of the
    given model are performed using a module (nn.module.ReLU).
    If nn.functional.ReLU is used, gradients are not overridden appropriately.
    """

    def __init__(
        self, model: Module, layer: Module, device_ids: Union[None, List[int]] = None
    ) -> None:
        r"""
        Args:

            model (nn.Module): The reference to PyTorch model instance. Model cannot
                        contain any in-place ReLU submodules; these are not
                        supported by the register_full_backward_hook PyTorch API
                        starting from PyTorch v1.9.
            layer (torch.nn.Module): Layer for which GradCAM attributions are computed.
                        Currently, only layers with a single tensor output are
                        supported.
            device_ids (list(int)): Device ID list, necessary only if forward_func
                        applies a DataParallel model. This allows reconstruction of
                        intermediate outputs from batched results across devices.
                        If forward_func is given as the DataParallel model itself,
                        then it is not necessary to provide this argument.
        """
        GradientAttribution.__init__(self, model)
        self.grad_cam = LayerGradCam(model, layer, device_ids)
        self.guided_backprop = GuidedBackprop(model)

    @log_usage()
    def attribute(
        self,
        inputs: TensorOrTupleOfTensorsGeneric,
        target: TargetType = None,
        additional_forward_args: Any = None,
        interpolate_mode: str = "nearest",
        attribute_to_layer_input: bool = False,
    ) -> TensorOrTupleOfTensorsGeneric:
        r"""
        Args:

            inputs (tensor or tuple of tensors): Input for which attributions
                        are computed. If forward_func takes a single
                        tensor as input, a single input tensor should be provided.
                        If forward_func takes multiple tensors as input, a tuple
                        of the input tensors should be provided. It is assumed
                        that for all given input tensors, dimension 0 corresponds
                        to the number of examples, and if multiple input tensors
                        are provided, the examples must be aligned appropriately.
            target (int, tuple, tensor or list, optional): Output indices for
                        which gradients are computed (for classification cases,
                        this is usually the target class).
                        If the network returns a scalar value per example,
                        no target index is necessary.
                        For general 2D outputs, targets can be either:

                        - a single integer or a tensor containing a single
                          integer, which is applied to all input examples

                        - a list of integers or a 1D tensor, with length matching
                          the number of examples in inputs (dim 0). Each integer
                          is applied as the target for the corresponding example.

                        For outputs with > 2 dimensions, targets can be either:

                        - A single tuple, which contains #output_dims - 1
                          elements. This target index is applied to all examples.

                        - A list of tuples with length equal to the number of
                          examples in inputs (dim 0), and each tuple containing
                          #output_dims - 1 elements. Each tuple is applied as the
                          target for the corresponding example.

                        Default: None
            additional_forward_args (any, optional): If the forward function
                        requires additional arguments other than the inputs for
                        which attributions should not be computed, this argument
                        can be provided. It must be either a single additional
                        argument of a Tensor or arbitrary (non-tuple) type or a
                        tuple containing multiple additional arguments including
                        tensors or any arbitrary python types. These arguments
                        are provided to forward_func in order following the
                        arguments in inputs.
                        Note that attributions are not computed with respect
                        to these arguments.
                        Default: None
            interpolate_mode (str, optional): Method for interpolation, which
                        must be a valid input interpolation mode for
                        torch.nn.functional. These methods are
                        "nearest", "area", "linear" (3D-only), "bilinear"
                        (4D-only), "bicubic" (4D-only), "trilinear" (5D-only)
                        based on the number of dimensions of the chosen layer
                        output (which must also match the number of
                        dimensions for the input tensor). Note that
                        the original GradCAM paper uses "bilinear"
                        interpolation, but we default to "nearest" for
                        applicability to any of 3D, 4D or 5D tensors.
                        Default: "nearest"
            attribute_to_layer_input (bool, optional): Indicates whether to
                        compute the attribution with respect to the layer input
                        or output in `LayerGradCam`.
                        If `attribute_to_layer_input` is set to True
                        then the attributions will be computed with respect to
                        layer inputs, otherwise it will be computed with respect
                        to layer outputs.
                        Note that currently it is assumed that either the input
                        or the output of internal layer, depending on whether we
                        attribute to the input or output, is a single tensor.
                        Support for multiple tensors will be added later.
                        Default: False

        Returns:
                *tensor* of **attributions**:
                - **attributions** (*tensor*):
                            Element-wise product of (upsampled) GradCAM
                            and Guided Backprop attributions.
                            If a single tensor is provided as inputs, a single tensor is
                            returned. If a tuple is provided for inputs, a tuple of
                            corresponding sized tensors is returned.
                            Attributions will be the same size as the provided inputs,
                            with each value providing the attribution of the
                            corresponding input index.
                            If the GradCAM attributions cannot be upsampled to the shape
                            of a given input tensor, None is returned in the corresponding
                            index position.


        Examples::

            >>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
            >>> # and returns an Nx10 tensor of class probabilities.
            >>> # It contains an attribute conv4, which is an instance of nn.conv2d,
            >>> # and the output of this layer has dimensions Nx50x8x8.
            >>> # It is the last convolution layer, which is the recommended
            >>> # use case for GuidedGradCAM.
            >>> net = ImageClassifier()
            >>> guided_gc = GuidedGradCam(net, net.conv4)
            >>> input = torch.randn(2, 3, 32, 32, requires_grad=True)
            >>> # Computes guided GradCAM attributions for class 3.
            >>> # attribution size matches input size, Nx3x32x32
            >>> attribution = guided_gc.attribute(input, 3)
        """
        is_inputs_tuple = _is_tuple(inputs)
        inputs = _format_tensor_into_tuples(inputs)
        grad_cam_attr = self.grad_cam.attribute.__wrapped__(
            self.grad_cam,  # self
            inputs=inputs,
            target=target,
            additional_forward_args=additional_forward_args,
            attribute_to_layer_input=attribute_to_layer_input,
            relu_attributions=True,
        )
        if isinstance(grad_cam_attr, tuple):
            assert len(grad_cam_attr) == 1, (
                "GuidedGradCAM attributions for layer with multiple inputs / "
                "outputs is not supported."
            )
            grad_cam_attr = grad_cam_attr[0]

        guided_backprop_attr = self.guided_backprop.attribute.__wrapped__(
            self.guided_backprop,  # self
            inputs=inputs,
            target=target,
            additional_forward_args=additional_forward_args,
        )
        output_attr: List[Tensor] = []
        for i in range(len(inputs)):
            try:
                output_attr.append(
                    guided_backprop_attr[i]
                    * LayerAttribution.interpolate(
                        grad_cam_attr,
                        inputs[i].shape[2:],
                        interpolate_mode=interpolate_mode,
                    )
                )
            except Exception:
                warnings.warn(
                    "Couldn't appropriately interpolate GradCAM attributions for some "
                    "input tensors, returning empty tensor for corresponding "
                    "attributions."
                )
                output_attr.append(torch.empty(0))

        return _format_output(is_inputs_tuple, tuple(output_attr))
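The combination step above reduces to an upsample-and-multiply of two attribution maps: a coarse, non-negative GradCAM map is interpolated to the input's spatial size and multiplied element-wise with the guided backpropagation gradients. A minimal sketch with hypothetical shapes (an 8x8 layer map and 3x32x32 pixel gradients for a batch of 2, all names illustrative):

import torch
import torch.nn.functional as F

# Hypothetical maps: coarse GradCAM output (already ReLU'd) and per-pixel gradients.
grad_cam = torch.rand(2, 1, 8, 8)        # non-negative coarse map
guided_bp = torch.randn(2, 3, 32, 32)    # guided backprop gradients

# Upsample the coarse map to the input's spatial size, then combine element-wise.
upsampled = F.interpolate(grad_cam, size=(32, 32), mode="nearest")
guided_grad_cam = guided_bp * upsampled  # broadcasts over the channel dimension

print(guided_grad_cam.shape)  # torch.Size([2, 3, 32, 32])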
captum/attr/_core/input_x_gradient.py
ADDED
@@ -0,0 +1,130 @@
#!/usr/bin/env python3
from typing import Any, Callable

from captum._utils.common import _format_output, _format_tensor_into_tuples, _is_tuple
from captum._utils.gradient import (
    apply_gradient_requirements,
    undo_gradient_requirements,
)
from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric
from captum.attr._utils.attribution import GradientAttribution
from captum.log import log_usage


class InputXGradient(GradientAttribution):
    r"""
    A baseline approach for computing the attribution. It multiplies input with
    the gradient with respect to input.
    https://arxiv.org/abs/1605.01713
    """

    def __init__(self, forward_func: Callable) -> None:
        r"""
        Args:

            forward_func (callable): The forward function of the model or any
                        modification of it
        """
        GradientAttribution.__init__(self, forward_func)

    @log_usage()
    def attribute(
        self,
        inputs: TensorOrTupleOfTensorsGeneric,
        target: TargetType = None,
        additional_forward_args: Any = None,
    ) -> TensorOrTupleOfTensorsGeneric:
        r"""
        Args:

            inputs (tensor or tuple of tensors): Input for which
                        attributions are computed. If forward_func takes a single
                        tensor as input, a single input tensor should be provided.
                        If forward_func takes multiple tensors as input, a tuple
                        of the input tensors should be provided. It is assumed
                        that for all given input tensors, dimension 0 corresponds
                        to the number of examples (aka batch size), and if
                        multiple input tensors are provided, the examples must
                        be aligned appropriately.
            target (int, tuple, tensor or list, optional): Output indices for
                        which gradients are computed (for classification cases,
                        this is usually the target class).
                        If the network returns a scalar value per example,
                        no target index is necessary.
                        For general 2D outputs, targets can be either:

                        - a single integer or a tensor containing a single
                          integer, which is applied to all input examples

                        - a list of integers or a 1D tensor, with length matching
                          the number of examples in inputs (dim 0). Each integer
                          is applied as the target for the corresponding example.

                        For outputs with > 2 dimensions, targets can be either:

                        - A single tuple, which contains #output_dims - 1
                          elements. This target index is applied to all examples.

                        - A list of tuples with length equal to the number of
                          examples in inputs (dim 0), and each tuple containing
                          #output_dims - 1 elements. Each tuple is applied as the
                          target for the corresponding example.

                        Default: None
            additional_forward_args (any, optional): If the forward function
                        requires additional arguments other than the inputs for
                        which attributions should not be computed, this argument
                        can be provided. It must be either a single additional
                        argument of a Tensor or arbitrary (non-tuple) type or a tuple
                        containing multiple additional arguments including tensors
                        or any arbitrary python types. These arguments are provided to
                        forward_func in order following the arguments in inputs.
                        Note that attributions are not computed with respect
                        to these arguments.
                        Default: None

        Returns:
                *tensor* or tuple of *tensors* of **attributions**:
                - **attributions** (*tensor* or tuple of *tensors*):
                            The input x gradient with
                            respect to each input feature. Attributions will always be
                            the same size as the provided inputs, with each value
                            providing the attribution of the corresponding input index.
                            If a single tensor is provided as inputs, a single tensor is
                            returned. If a tuple is provided for inputs, a tuple of
                            corresponding sized tensors is returned.


        Examples::

            >>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
            >>> # and returns an Nx10 tensor of class probabilities.
            >>> net = ImageClassifier()
            >>> # Generating random input with size 2x3x3x32
            >>> input = torch.randn(2, 3, 32, 32, requires_grad=True)
            >>> # Defining InputXGradient interpreter
            >>> input_x_gradient = InputXGradient(net)
            >>> # Computes inputXgradient for class 4.
            >>> attribution = input_x_gradient.attribute(input, target=4)
        """
        # Keeps track whether original input is a tuple or not before
        # converting it into a tuple.
        is_inputs_tuple = _is_tuple(inputs)

        inputs = _format_tensor_into_tuples(inputs)
        gradient_mask = apply_gradient_requirements(inputs)

        gradients = self.gradient_func(
            self.forward_func, inputs, target, additional_forward_args
        )

        attributions = tuple(
            input * gradient for input, gradient in zip(inputs, gradients)
        )

        undo_gradient_requirements(inputs, gradient_mask)
        return _format_output(is_inputs_tuple, attributions)

    @property
    def multiplies_by_inputs(self):
        return True
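Since InputXGradient is simply the element-wise product of the input with the gradient of the chosen output index, the same quantity can be reproduced in a few lines of plain PyTorch. A minimal sketch with a made-up linear forward function (the model, seed, and target index are illustrative, not part of the committed file):

import torch

# A tiny stand-in model; any differentiable forward function would do.
torch.manual_seed(0)
weight = torch.randn(10, 3 * 32 * 32)

def forward_func(x):
    return x.flatten(1) @ weight.t()  # N x 10 "class scores"

inp = torch.randn(2, 3, 32, 32, requires_grad=True)
target = 4

# Gradient of the chosen output index w.r.t. the input, then element-wise product.
scores = forward_func(inp)
scores[:, target].sum().backward()
attribution = inp * inp.grad

print(attribution.shape)  # torch.Size([2, 3, 32, 32])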
captum/attr/_core/integrated_gradients.py
ADDED
@@ -0,0 +1,390 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
import typing
|
3 |
+
from typing import Any, Callable, List, Tuple, Union
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from captum._utils.common import (
|
7 |
+
_expand_additional_forward_args,
|
8 |
+
_expand_target,
|
9 |
+
_format_additional_forward_args,
|
10 |
+
_format_output,
|
11 |
+
_is_tuple,
|
12 |
+
)
|
13 |
+
from captum._utils.typing import (
|
14 |
+
BaselineType,
|
15 |
+
Literal,
|
16 |
+
TargetType,
|
17 |
+
TensorOrTupleOfTensorsGeneric,
|
18 |
+
)
|
19 |
+
from captum.attr._utils.approximation_methods import approximation_parameters
|
20 |
+
from captum.attr._utils.attribution import GradientAttribution
|
21 |
+
from captum.attr._utils.batching import _batch_attribution
|
22 |
+
from captum.attr._utils.common import (
|
23 |
+
_format_input_baseline,
|
24 |
+
_reshape_and_sum,
|
25 |
+
_validate_input,
|
26 |
+
)
|
27 |
+
from captum.log import log_usage
|
28 |
+
from torch import Tensor
|
29 |
+
|
30 |
+
|
31 |
+
class IntegratedGradients(GradientAttribution):
|
32 |
+
r"""
|
33 |
+
Integrated Gradients is an axiomatic model interpretability algorithm that
|
34 |
+
assigns an importance score to each input feature by approximating the
|
35 |
+
integral of gradients of the model's output with respect to the inputs
|
36 |
+
along the path (straight line) from given baselines / references to inputs.
|
37 |
+
|
38 |
+
Baselines can be provided as input arguments to attribute method.
|
39 |
+
To approximate the integral we can choose to use either a variant of
|
40 |
+
Riemann sum or Gauss-Legendre quadrature rule.
|
41 |
+
|
42 |
+
More details regarding the integrated gradients method can be found in the
|
43 |
+
original paper:
|
44 |
+
https://arxiv.org/abs/1703.01365
|
45 |
+
|
46 |
+
"""
|
47 |
+
|
48 |
+
def __init__(
|
49 |
+
self,
|
50 |
+
forward_func: Callable,
|
51 |
+
multiply_by_inputs: bool = True,
|
52 |
+
) -> None:
|
53 |
+
r"""
|
54 |
+
Args:
|
55 |
+
|
56 |
+
forward_func (callable): The forward function of the model or any
|
57 |
+
modification of it
|
58 |
+
multiply_by_inputs (bool, optional): Indicates whether to factor
|
59 |
+
model inputs' multiplier in the final attribution scores.
|
60 |
+
In the literature this is also known as local vs global
|
61 |
+
attribution. If inputs' multiplier isn't factored in,
|
62 |
+
then that type of attribution method is also called local
|
63 |
+
attribution. If it is, then that type of attribution
|
64 |
+
method is called global.
|
65 |
+
More detailed can be found here:
|
66 |
+
https://arxiv.org/abs/1711.06104
|
67 |
+
|
68 |
+
In case of integrated gradients, if `multiply_by_inputs`
|
69 |
+
is set to True, final sensitivity scores are being multiplied by
|
70 |
+
(inputs - baselines).
|
71 |
+
"""
|
72 |
+
GradientAttribution.__init__(self, forward_func)
|
73 |
+
self._multiply_by_inputs = multiply_by_inputs
|
74 |
+
|
75 |
+
# The following overloaded method signatures correspond to the case where
|
76 |
+
# return_convergence_delta is False, then only attributions are returned,
|
77 |
+
# and when return_convergence_delta is True, the return type is
|
78 |
+
# a tuple with both attributions and deltas.
|
79 |
+
@typing.overload
|
80 |
+
def attribute(
|
81 |
+
self,
|
82 |
+
inputs: TensorOrTupleOfTensorsGeneric,
|
83 |
+
baselines: BaselineType = None,
|
84 |
+
target: TargetType = None,
|
85 |
+
additional_forward_args: Any = None,
|
86 |
+
n_steps: int = 50,
|
87 |
+
method: str = "gausslegendre",
|
88 |
+
internal_batch_size: Union[None, int] = None,
|
89 |
+
return_convergence_delta: Literal[False] = False,
|
90 |
+
) -> TensorOrTupleOfTensorsGeneric:
|
91 |
+
...
|
92 |
+
|
93 |
+
@typing.overload
|
94 |
+
def attribute(
|
95 |
+
self,
|
96 |
+
inputs: TensorOrTupleOfTensorsGeneric,
|
97 |
+
baselines: BaselineType = None,
|
98 |
+
target: TargetType = None,
|
99 |
+
additional_forward_args: Any = None,
|
100 |
+
n_steps: int = 50,
|
101 |
+
method: str = "gausslegendre",
|
102 |
+
internal_batch_size: Union[None, int] = None,
|
103 |
+
*,
|
104 |
+
return_convergence_delta: Literal[True],
|
105 |
+
) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]:
|
106 |
+
...
|
107 |
+
|
108 |
+
@log_usage()
|
109 |
+
def attribute( # type: ignore
|
110 |
+
self,
|
111 |
+
inputs: TensorOrTupleOfTensorsGeneric,
|
112 |
+
baselines: BaselineType = None,
|
113 |
+
target: TargetType = None,
|
114 |
+
additional_forward_args: Any = None,
|
115 |
+
n_steps: int = 50,
|
116 |
+
method: str = "gausslegendre",
|
117 |
+
internal_batch_size: Union[None, int] = None,
|
118 |
+
return_convergence_delta: bool = False,
|
119 |
+
) -> Union[
|
120 |
+
TensorOrTupleOfTensorsGeneric, Tuple[TensorOrTupleOfTensorsGeneric, Tensor]
|
121 |
+
]:
|
122 |
+
r"""
|
123 |
+
This method attributes the output of the model with given target index
|
124 |
+
(in case it is provided, otherwise it assumes that output is a
|
125 |
+
scalar) to the inputs of the model using the approach described above.
|
126 |
+
|
127 |
+
In addition to that it also returns, if `return_convergence_delta` is
|
128 |
+
set to True, integral approximation delta based on the completeness
|
129 |
+
property of integrated gradients.
|
130 |
+
|
131 |
+
Args:
|
132 |
+
|
133 |
+
inputs (tensor or tuple of tensors): Input for which integrated
|
134 |
+
gradients are computed. If forward_func takes a single
|
135 |
+
tensor as input, a single input tensor should be provided.
|
136 |
+
If forward_func takes multiple tensors as input, a tuple
|
137 |
+
of the input tensors should be provided. It is assumed
|
138 |
+
that for all given input tensors, dimension 0 corresponds
|
139 |
+
to the number of examples, and if multiple input tensors
|
140 |
+
are provided, the examples must be aligned appropriately.
|
141 |
+
baselines (scalar, tensor, tuple of scalars or tensors, optional):
|
142 |
+
Baselines define the starting point from which integral
|
143 |
+
is computed and can be provided as:
|
144 |
+
|
145 |
+
- a single tensor, if inputs is a single tensor, with
|
146 |
+
exactly the same dimensions as inputs or the first
|
147 |
+
dimension is one and the remaining dimensions match
|
148 |
+
with inputs.
|
149 |
+
|
150 |
+
- a single scalar, if inputs is a single tensor, which will
|
151 |
+
be broadcasted for each input value in input tensor.
|
152 |
+
|
153 |
+
- a tuple of tensors or scalars, the baseline corresponding
|
154 |
+
to each tensor in the inputs' tuple can be:
|
155 |
+
|
156 |
+
- either a tensor with matching dimensions to
|
157 |
+
corresponding tensor in the inputs' tuple
|
158 |
+
or the first dimension is one and the remaining
|
159 |
+
dimensions match with the corresponding
|
160 |
+
input tensor.
|
161 |
+
|
162 |
+
- or a scalar, corresponding to a tensor in the
|
163 |
+
inputs' tuple. This scalar value is broadcasted
|
164 |
+
for corresponding input tensor.
|
165 |
+
In the cases when `baselines` is not provided, we internally
|
166 |
+
use zero scalar corresponding to each input tensor.
|
167 |
+
|
168 |
+
Default: None
|
169 |
+
target (int, tuple, tensor or list, optional): Output indices for
|
170 |
+
which gradients are computed (for classification cases,
|
171 |
+
this is usually the target class).
|
172 |
+
If the network returns a scalar value per example,
|
173 |
+
no target index is necessary.
|
174 |
+
For general 2D outputs, targets can be either:
|
175 |
+
|
176 |
+
- a single integer or a tensor containing a single
|
177 |
+
integer, which is applied to all input examples
|
178 |
+
|
179 |
+
- a list of integers or a 1D tensor, with length matching
|
180 |
+
the number of examples in inputs (dim 0). Each integer
|
181 |
+
is applied as the target for the corresponding example.
|
182 |
+
|
183 |
+
For outputs with > 2 dimensions, targets can be either:
|
184 |
+
|
185 |
+
- A single tuple, which contains #output_dims - 1
|
186 |
+
elements. This target index is applied to all examples.
|
187 |
+
|
188 |
+
- A list of tuples with length equal to the number of
|
189 |
+
examples in inputs (dim 0), and each tuple containing
|
190 |
+
#output_dims - 1 elements. Each tuple is applied as the
|
191 |
+
target for the corresponding example.
|
192 |
+
|
193 |
+
Default: None
|
194 |
+
additional_forward_args (any, optional): If the forward function
|
195 |
+
requires additional arguments other than the inputs for
|
196 |
+
which attributions should not be computed, this argument
|
197 |
+
can be provided. It must be either a single additional
|
198 |
+
argument of a Tensor or arbitrary (non-tuple) type or a
|
199 |
+
tuple containing multiple additional arguments including
|
200 |
+
tensors or any arbitrary python types. These arguments
|
201 |
+
are provided to forward_func in order following the
|
202 |
+
arguments in inputs.
|
203 |
+
For a tensor, the first dimension of the tensor must
|
204 |
+
correspond to the number of examples. It will be
|
205 |
+
repeated for each of `n_steps` along the integrated
|
206 |
+
path. For all other types, the given argument is used
|
207 |
+
for all forward evaluations.
|
208 |
+
Note that attributions are not computed with respect
|
209 |
+
to these arguments.
|
210 |
+
Default: None
|
211 |
+
n_steps (int, optional): The number of steps used by the approximation
|
212 |
+
method. Default: 50.
|
213 |
+
method (string, optional): Method for approximating the integral,
|
214 |
+
one of `riemann_right`, `riemann_left`, `riemann_middle`,
|
215 |
+
`riemann_trapezoid` or `gausslegendre`.
|
216 |
+
Default: `gausslegendre` if no method is provided.
|
217 |
+
internal_batch_size (int, optional): Divides total #steps * #examples
|
218 |
+
data points into chunks of size at most internal_batch_size,
|
219 |
+
which are computed (forward / backward passes)
|
220 |
+
sequentially. internal_batch_size must be at least equal to
|
221 |
+
#examples.
|
222 |
+
For DataParallel models, each batch is split among the
|
223 |
+
available devices, so evaluations on each available
|
224 |
+
device contain internal_batch_size / num_devices examples.
|
225 |
+
If internal_batch_size is None, then all evaluations are
|
226 |
+
processed in one batch.
|
227 |
+
Default: None
|
228 |
+
return_convergence_delta (bool, optional): Indicates whether to return
|
229 |
+
convergence delta or not. If `return_convergence_delta`
|
230 |
+
is set to True convergence delta will be returned in
|
231 |
+
a tuple following attributions.
|
232 |
+
Default: False
|
233 |
+
Returns:
|
234 |
+
**attributions** or 2-element tuple of **attributions**, **delta**:
|
235 |
+
- **attributions** (*tensor* or tuple of *tensors*):
|
236 |
+
Integrated gradients with respect to each input feature.
|
237 |
+
attributions will always be the same size as the provided
|
238 |
+
inputs, with each value providing the attribution of the
|
239 |
+
corresponding input index.
|
240 |
+
If a single tensor is provided as inputs, a single tensor is
|
241 |
+
returned. If a tuple is provided for inputs, a tuple of
|
242 |
+
corresponding sized tensors is returned.
|
243 |
+
- **delta** (*tensor*, returned if return_convergence_delta=True):
|
244 |
+
The difference between the total approximated and true
|
245 |
+
integrated gradients. This is computed using the property
|
246 |
+
that the total sum of forward_func(inputs) -
|
247 |
+
forward_func(baselines) must equal the total sum of the
|
248 |
+
integrated gradient.
|
249 |
+
Delta is calculated per example, meaning that the number of
|
250 |
+
elements in returned delta tensor is equal to the number of
|
251 |
+
of examples in inputs.
|
252 |
+
|
253 |
+
Examples::
|
254 |
+
|
255 |
+
>>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
|
256 |
+
>>> # and returns an Nx10 tensor of class probabilities.
|
257 |
+
>>> net = ImageClassifier()
|
258 |
+
>>> ig = IntegratedGradients(net)
|
259 |
+
>>> input = torch.randn(2, 3, 32, 32, requires_grad=True)
|
260 |
+
>>> # Computes integrated gradients for class 3.
|
261 |
+
>>> attribution = ig.attribute(input, target=3)
|
262 |
+
"""
|
263 |
+
# Keeps track whether original input is a tuple or not before
|
264 |
+
# converting it into a tuple.
|
265 |
+
is_inputs_tuple = _is_tuple(inputs)
|
266 |
+
|
267 |
+
inputs, baselines = _format_input_baseline(inputs, baselines)
|
268 |
+
|
269 |
+
_validate_input(inputs, baselines, n_steps, method)
|
270 |
+
|
271 |
+
if internal_batch_size is not None:
|
272 |
+
num_examples = inputs[0].shape[0]
|
273 |
+
attributions = _batch_attribution(
|
274 |
+
self,
|
275 |
+
num_examples,
|
276 |
+
internal_batch_size,
|
277 |
+
n_steps,
|
278 |
+
inputs=inputs,
|
279 |
+
baselines=baselines,
|
280 |
+
target=target,
|
281 |
+
additional_forward_args=additional_forward_args,
|
282 |
+
method=method,
|
283 |
+
)
|
284 |
+
else:
|
285 |
+
attributions = self._attribute(
|
286 |
+
inputs=inputs,
|
287 |
+
baselines=baselines,
|
288 |
+
target=target,
|
289 |
+
additional_forward_args=additional_forward_args,
|
290 |
+
n_steps=n_steps,
|
291 |
+
method=method,
|
292 |
+
)
|
293 |
+
|
294 |
+
if return_convergence_delta:
|
295 |
+
start_point, end_point = baselines, inputs
|
296 |
+
# computes approximation error based on the completeness axiom
|
297 |
+
delta = self.compute_convergence_delta(
|
298 |
+
attributions,
|
299 |
+
start_point,
|
300 |
+
end_point,
|
301 |
+
additional_forward_args=additional_forward_args,
|
302 |
+
target=target,
|
303 |
+
)
|
304 |
+
return _format_output(is_inputs_tuple, attributions), delta
|
305 |
+
return _format_output(is_inputs_tuple, attributions)
|
306 |
+
|
307 |
+
def _attribute(
|
308 |
+
self,
|
309 |
+
inputs: Tuple[Tensor, ...],
|
310 |
+
baselines: Tuple[Union[Tensor, int, float], ...],
|
311 |
+
target: TargetType = None,
|
312 |
+
additional_forward_args: Any = None,
|
313 |
+
n_steps: int = 50,
|
314 |
+
method: str = "gausslegendre",
|
315 |
+
step_sizes_and_alphas: Union[None, Tuple[List[float], List[float]]] = None,
|
316 |
+
) -> Tuple[Tensor, ...]:
|
317 |
+
if step_sizes_and_alphas is None:
|
318 |
+
# retrieve step size and scaling factor for specified
|
319 |
+
# approximation method
|
320 |
+
step_sizes_func, alphas_func = approximation_parameters(method)
|
321 |
+
step_sizes, alphas = step_sizes_func(n_steps), alphas_func(n_steps)
|
322 |
+
else:
|
323 |
+
step_sizes, alphas = step_sizes_and_alphas
|
324 |
+
|
325 |
+
# scale features and compute gradients. (batch size is abbreviated as bsz)
|
326 |
+
# scaled_features' dim -> (bsz * #steps x inputs[0].shape[1:], ...)
|
327 |
+
scaled_features_tpl = tuple(
|
328 |
+
torch.cat(
|
329 |
+
[baseline + alpha * (input - baseline) for alpha in alphas], dim=0
|
330 |
+
).requires_grad_()
|
331 |
+
for input, baseline in zip(inputs, baselines)
|
332 |
+
)
|
333 |
+
|
334 |
+
additional_forward_args = _format_additional_forward_args(
|
335 |
+
additional_forward_args
|
336 |
+
)
|
337 |
+
# apply number of steps to additional forward args
|
338 |
+
# currently, number of steps is applied only to additional forward arguments
|
339 |
+
# that are nd-tensors. It is assumed that the first dimension is
|
340 |
+
# the number of batches.
|
341 |
+
# dim -> (bsz * #steps x additional_forward_args[0].shape[1:], ...)
|
342 |
+
input_additional_args = (
|
343 |
+
_expand_additional_forward_args(additional_forward_args, n_steps)
|
344 |
+
if additional_forward_args is not None
|
345 |
+
else None
|
346 |
+
)
|
347 |
+
expanded_target = _expand_target(target, n_steps)
|
348 |
+
|
349 |
+
# grads: dim -> (bsz * #steps x inputs[0].shape[1:], ...)
|
350 |
+
grads = self.gradient_func(
|
351 |
+
forward_fn=self.forward_func,
|
352 |
+
inputs=scaled_features_tpl,
|
353 |
+
target_ind=expanded_target,
|
354 |
+
additional_forward_args=input_additional_args,
|
355 |
+
)
|
356 |
+
|
357 |
+
# flattening grads so that we can multilpy it with step-size
|
358 |
+
# calling contiguous to avoid `memory whole` problems
|
359 |
+
scaled_grads = [
|
360 |
+
grad.contiguous().view(n_steps, -1)
|
361 |
+
* torch.tensor(step_sizes).view(n_steps, 1).to(grad.device)
|
362 |
+
for grad in grads
|
363 |
+
]
|
364 |
+
|
365 |
+
# aggregates across all steps for each tensor in the input tuple
|
366 |
+
# total_grads has the same dimensionality as inputs
|
367 |
+
total_grads = tuple(
|
368 |
+
_reshape_and_sum(
|
369 |
+
scaled_grad, n_steps, grad.shape[0] // n_steps, grad.shape[1:]
|
370 |
+
)
|
371 |
+
for (scaled_grad, grad) in zip(scaled_grads, grads)
|
372 |
+
)
|
373 |
+
|
374 |
+
# computes attribution for each tensor in input tuple
|
375 |
+
# attributions has the same dimensionality as inputs
|
376 |
+
if not self.multiplies_by_inputs:
|
377 |
+
attributions = total_grads
|
378 |
+
else:
|
379 |
+
attributions = tuple(
|
380 |
+
total_grad * (input - baseline)
|
381 |
+
for total_grad, input, baseline in zip(total_grads, inputs, baselines)
|
382 |
+
)
|
383 |
+
return attributions
|
384 |
+
|
385 |
+
def has_convergence_delta(self) -> bool:
|
386 |
+
return True
|
387 |
+
|
388 |
+
@property
|
389 |
+
def multiplies_by_inputs(self):
|
390 |
+
return self._multiply_by_inputs
|
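The integrated_gradients file above approximates a path integral of gradients along the straight line from a baseline to the input, scaling inputs by a sequence of alphas and summing step-weighted gradients. A minimal hand-rolled sketch of the same idea with a Riemann-middle rule on a toy two-feature function (function, values, and step count are illustrative), showing the completeness property that attributions sum approximately to f(input) - f(baseline):

import torch

# Simple differentiable function of two features; baseline at zero.
def f(x):
    return (x[:, 0] ** 2 + 3.0 * x[:, 1]).sum()

inp = torch.tensor([[2.0, 1.0]])
baseline = torch.zeros_like(inp)
n_steps = 50

# Riemann-middle approximation of the path integral of gradients.
alphas = (torch.arange(n_steps, dtype=torch.float32) + 0.5) / n_steps
total_grad = torch.zeros_like(inp)
for a in alphas:
    point = (baseline + a * (inp - baseline)).requires_grad_()
    grad = torch.autograd.grad(f(point), point)[0]
    total_grad += grad / n_steps

attribution = (inp - baseline) * total_grad
# Completeness check: attributions should sum to f(inp) - f(baseline) = 7.
print(attribution, attribution.sum())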
captum/attr/_core/kernel_shap.py
ADDED
@@ -0,0 +1,348 @@
#!/usr/bin/env python3

from typing import Any, Callable, Generator, Tuple, Union

import torch
from captum._utils.models.linear_model import SkLearnLinearRegression
from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
from captum.attr._core.lime import construct_feature_mask, Lime
from captum.attr._utils.common import _format_input_baseline
from captum.log import log_usage
from torch import Tensor
from torch.distributions.categorical import Categorical


class KernelShap(Lime):
    r"""
    Kernel SHAP is a method that uses the LIME framework to compute
    Shapley Values. Setting the loss function, weighting kernel and
    regularization terms appropriately in the LIME framework allows
    theoretically obtaining Shapley Values more efficiently than
    directly computing Shapley Values.

    More information regarding this method and proof of equivalence
    can be found in the original paper here:
    https://arxiv.org/abs/1705.07874
    """

    def __init__(self, forward_func: Callable) -> None:
        r"""
        Args:

            forward_func (callable): The forward function of the model or
                        any modification of it
        """
        Lime.__init__(
            self,
            forward_func,
            interpretable_model=SkLearnLinearRegression(),
            similarity_func=self.kernel_shap_similarity_kernel,
            perturb_func=self.kernel_shap_perturb_generator,
        )
        self.inf_weight = 1000000.0

    @log_usage()
    def attribute(  # type: ignore
        self,
        inputs: TensorOrTupleOfTensorsGeneric,
        baselines: BaselineType = None,
        target: TargetType = None,
        additional_forward_args: Any = None,
        feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
        n_samples: int = 25,
        perturbations_per_eval: int = 1,
        return_input_shape: bool = True,
        show_progress: bool = False,
    ) -> TensorOrTupleOfTensorsGeneric:
        r"""
        This method attributes the output of the model with given target index
        (in case it is provided, otherwise it assumes that output is a
        scalar) to the inputs of the model using the approach described above,
        training an interpretable model based on KernelSHAP and returning a
        representation of the interpretable model.

        It is recommended to only provide a single example as input (tensors
        with first dimension or batch size = 1). This is because LIME / KernelShap
        is generally used for sample-based interpretability, training a separate
        interpretable model to explain a model's prediction on each individual example.

        A batch of inputs can also be provided as inputs, similar to
        other perturbation-based attribution methods. In this case, if forward_fn
        returns a scalar per example, attributions will be computed for each
        example independently, with a separate interpretable model trained for each
        example. Note that the provided similarity and perturbation functions will
        be given each example separately (first dimension = 1) in this case.
        If forward_fn returns a scalar per batch (e.g. loss), attributions will
        still be computed using a single interpretable model for the full batch.
        In this case, similarity and perturbation functions will be provided the
        same original input containing the full batch.

        The number of interpretable features is determined from the provided
        feature mask, or if none is provided, from the default feature mask,
        which considers each scalar input as a separate feature. It is
        generally recommended to provide a feature mask which groups features
        into a small number of interpretable features / components (e.g.
        superpixels in images).

        Args:

            inputs (tensor or tuple of tensors): Input for which KernelShap
                        is computed. If forward_func takes a single
                        tensor as input, a single input tensor should be provided.
                        If forward_func takes multiple tensors as input, a tuple
                        of the input tensors should be provided. It is assumed
                        that for all given input tensors, dimension 0 corresponds
                        to the number of examples, and if multiple input tensors
                        are provided, the examples must be aligned appropriately.
            baselines (scalar, tensor, tuple of scalars or tensors, optional):
                        Baselines define the reference value which replaces each
                        feature when the corresponding interpretable feature
                        is set to 0.
                        Baselines can be provided as:

                        - a single tensor, if inputs is a single tensor, with
                          exactly the same dimensions as inputs or the first
                          dimension is one and the remaining dimensions match
                          with inputs.

                        - a single scalar, if inputs is a single tensor, which will
                          be broadcasted for each input value in input tensor.

                        - a tuple of tensors or scalars, the baseline corresponding
                          to each tensor in the inputs' tuple can be:

                          - either a tensor with matching dimensions to
                            corresponding tensor in the inputs' tuple
                            or the first dimension is one and the remaining
                            dimensions match with the corresponding
                            input tensor.

                          - or a scalar, corresponding to a tensor in the
                            inputs' tuple. This scalar value is broadcasted
                            for corresponding input tensor.

                        In the cases when `baselines` is not provided, we internally
                        use zero scalar corresponding to each input tensor.
                        Default: None
            target (int, tuple, tensor or list, optional): Output indices for
                        which the surrogate model is trained
                        (for classification cases,
                        this is usually the target class).
                        If the network returns a scalar value per example,
                        no target index is necessary.
                        For general 2D outputs, targets can be either:

                        - a single integer or a tensor containing a single
                          integer, which is applied to all input examples

                        - a list of integers or a 1D tensor, with length matching
                          the number of examples in inputs (dim 0). Each integer
                          is applied as the target for the corresponding example.

                        For outputs with > 2 dimensions, targets can be either:

                        - A single tuple, which contains #output_dims - 1
                          elements. This target index is applied to all examples.

                        - A list of tuples with length equal to the number of
                          examples in inputs (dim 0), and each tuple containing
                          #output_dims - 1 elements. Each tuple is applied as the
                          target for the corresponding example.

                        Default: None
            additional_forward_args (any, optional): If the forward function
                        requires additional arguments other than the inputs for
                        which attributions should not be computed, this argument
                        can be provided. It must be either a single additional
                        argument of a Tensor or arbitrary (non-tuple) type or a
                        tuple containing multiple additional arguments including
                        tensors or any arbitrary python types. These arguments
                        are provided to forward_func in order following the
                        arguments in inputs.
                        For a tensor, the first dimension of the tensor must
                        correspond to the number of examples. For all other
                        types, the given argument is used for all forward
                        evaluations.
                        Note that attributions are not computed with respect
                        to these arguments.
                        Default: None
            feature_mask (tensor or tuple of tensors, optional):
                        feature_mask defines a mask for the input, grouping
                        features which correspond to the same
                        interpretable feature. feature_mask
                        should contain the same number of tensors as inputs.
                        Each tensor should
                        be the same size as the corresponding input or
                        broadcastable to match the input tensor. Values across
                        all tensors should be integers in the range 0 to
                        num_interp_features - 1, and indices corresponding to the
                        same feature should have the same value.
                        Note that features are grouped across tensors
                        (unlike feature ablation and occlusion), so
                        if the same index is used in different tensors, those
                        features are still grouped and added simultaneously.
                        If None, then a feature mask is constructed which assigns
                        each scalar within a tensor as a separate feature.
                        Default: None
            n_samples (int, optional): The number of samples of the original
                        model used to train the surrogate interpretable model.
                        Default: `25` if `n_samples` is not provided.
            perturbations_per_eval (int, optional): Allows multiple samples
                        to be processed simultaneously in one call to forward_fn.
                        Each forward pass will contain a maximum of
                        perturbations_per_eval * #examples samples.
                        For DataParallel models, each batch is split among the
                        available devices, so evaluations on each available
                        device contain at most
                        (perturbations_per_eval * #examples) / num_devices
                        samples.
                        If the forward function returns a single scalar per batch,
                        perturbations_per_eval must be set to 1.
                        Default: 1
            return_input_shape (bool, optional): Determines whether the returned
                        tensor(s) only contain the coefficients for each
                        interpretable feature from the trained surrogate model, or
                        whether the returned attributions match the input shape.
                        When return_input_shape is True, the return type of attribute
                        matches the input shape, with each element containing the
                        coefficient of the corresponding interpretable feature.
                        All elements with the same value in the feature mask
                        will contain the same coefficient in the returned
                        attributions. If return_input_shape is False, a 1D
                        tensor is returned, containing only the coefficients
                        of the trained interpretable model, with length
                        num_interp_features.
            show_progress (bool, optional): Displays the progress of computation.
                        It will try to use tqdm if available for advanced features
                        (e.g. time estimation). Otherwise, it will fall back to
                        a simple output of progress.
                        Default: False

        Returns:
            *tensor* or tuple of *tensors* of **attributions**:
            - **attributions** (*tensor* or tuple of *tensors*):
                        The attributions with respect to each input feature.
                        If return_input_shape = True, attributions will be
                        the same size as the provided inputs, with each value
                        providing the coefficient of the corresponding
                        interpretable feature.
                        If return_input_shape is False, a 1D
                        tensor is returned, containing only the coefficients
                        of the trained interpretable model, with length
                        num_interp_features.

        Examples::
            >>> # SimpleClassifier takes a single input tensor of size Nx4x4,
            >>> # and returns an Nx3 tensor of class probabilities.
            >>> net = SimpleClassifier()

            >>> # Generating random input with size 1 x 4 x 4
            >>> input = torch.randn(1, 4, 4)

            >>> # Defining KernelShap interpreter
            >>> ks = KernelShap(net)
            >>> # Computes attribution, with each of the 4 x 4 = 16
            >>> # features as a separate interpretable feature
            >>> attr = ks.attribute(input, target=1, n_samples=200)

            >>> # Alternatively, we can group each 2x2 square of the inputs
            >>> # as one 'interpretable' feature and perturb them together.
            >>> # This can be done by creating a feature mask as follows, which
            >>> # defines the feature groups, e.g.:
            >>> # +---+---+---+---+
            >>> # | 0 | 0 | 1 | 1 |
            >>> # +---+---+---+---+
            >>> # | 0 | 0 | 1 | 1 |
            >>> # +---+---+---+---+
            >>> # | 2 | 2 | 3 | 3 |
            >>> # +---+---+---+---+
            >>> # | 2 | 2 | 3 | 3 |
            >>> # +---+---+---+---+
            >>> # With this mask, all inputs with the same value are set to their
            >>> # baseline value, when the corresponding binary interpretable
            >>> # feature is set to 0.
            >>> # The attributions can be calculated as follows:
            >>> # feature mask has dimensions 1 x 4 x 4
            >>> feature_mask = torch.tensor([[[0,0,1,1],[0,0,1,1],
            >>>                               [2,2,3,3],[2,2,3,3]]])

            >>> # Computes KernelSHAP attributions with feature mask.
            >>> attr = ks.attribute(input, target=1, feature_mask=feature_mask)
        """
        formatted_inputs, baselines = _format_input_baseline(inputs, baselines)
        feature_mask, num_interp_features = construct_feature_mask(
            feature_mask, formatted_inputs
        )
        num_features_list = torch.arange(num_interp_features, dtype=torch.float)
        denom = num_features_list * (num_interp_features - num_features_list)
        probs = (num_interp_features - 1) / denom
        probs[0] = 0.0
        return self._attribute_kwargs(
            inputs=inputs,
            baselines=baselines,
            target=target,
            additional_forward_args=additional_forward_args,
            feature_mask=feature_mask,
            n_samples=n_samples,
            perturbations_per_eval=perturbations_per_eval,
            return_input_shape=return_input_shape,
            num_select_distribution=Categorical(probs),
            show_progress=show_progress,
        )

    def kernel_shap_similarity_kernel(
        self, _, __, interpretable_sample: Tensor, **kwargs
    ) -> Tensor:
        assert (
            "num_interp_features" in kwargs
        ), "Must provide num_interp_features to use default similarity kernel"
        num_selected_features = int(interpretable_sample.sum(dim=1).item())
        num_features = kwargs["num_interp_features"]
        if num_selected_features == 0 or num_selected_features == num_features:
            # The weight should theoretically be infinite when
            # num_selected_features = 0 or num_features,
            # enforcing that the trained linear model satisfies the
            # end-point criteria. In practice, it is sufficient to
            # make this weight substantially larger, so it is set
            # to 1000000 (all other weights are 1).
            similarities = self.inf_weight
        else:
            similarities = 1.0
        return torch.tensor([similarities])

    def kernel_shap_perturb_generator(
        self, original_inp: Union[Tensor, Tuple[Tensor, ...]], **kwargs
    ) -> Generator[Tensor, None, None]:
        r"""
        Perturbations are sampled by the following process:
        - Choose k (number of selected features), based on the distribution
          p(k) = (M - 1) / (k * (M - k))
          where M is the total number of features in the interpretable space
        - Randomly select a binary vector with k ones, each sample is equally
          likely. This is done by generating a random vector of normal
          values and thresholding based on the top k elements.

        Since there are M choose k vectors with k ones, this weighted sampling
        is equivalent to applying the Shapley kernel for the sample weight,
        defined as:
        k(M, k) = (M - 1) / (k * (M - k) * (M choose k))
        """
        assert (
            "num_select_distribution" in kwargs and "num_interp_features" in kwargs
        ), (
            "num_select_distribution and num_interp_features are necessary"
            " to use kernel_shap_perturb_func"
        )
        if isinstance(original_inp, Tensor):
            device = original_inp.device
        else:
            device = original_inp[0].device
        num_features = kwargs["num_interp_features"]
        yield torch.ones(1, num_features, device=device, dtype=torch.long)
        yield torch.zeros(1, num_features, device=device, dtype=torch.long)
        while True:
            num_selected_features = kwargs["num_select_distribution"].sample()
            rand_vals = torch.randn(1, num_features)
            threshold = torch.kthvalue(
                rand_vals, num_features - num_selected_features
            ).values.item()
            yield (rand_vals > threshold).to(device=device).long()
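As a quick end-to-end check of the class above, a minimal usage sketch could look like the following. The toy model and tensor sizes are hypothetical (not part of this commit), and the SkLearn surrogate requires scikit-learn to be installed.

import torch
import torch.nn as nn
from captum.attr import KernelShap

# hypothetical toy classifier: 16 inputs -> 3 class scores
net = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 3))
ks = KernelShap(net)

inp = torch.randn(1, 16)
# one interpretable feature per scalar input, 200 perturbed samples
attr = ks.attribute(inp, target=1, n_samples=200)
print(attr.shape)  # torch.Size([1, 16]) since return_input_shape defaults to True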
captum/attr/_core/layer/__init__.py
ADDED
File without changes
captum/attr/_core/layer/grad_cam.py
ADDED
@@ -0,0 +1,217 @@
#!/usr/bin/env python3
from typing import Any, Callable, List, Tuple, Union

import torch
import torch.nn.functional as F
from captum._utils.common import (
    _format_additional_forward_args,
    _format_output,
    _format_tensor_into_tuples,
)
from captum._utils.gradient import compute_layer_gradients_and_eval
from captum._utils.typing import TargetType
from captum.attr._utils.attribution import GradientAttribution, LayerAttribution
from captum.log import log_usage
from torch import Tensor
from torch.nn import Module


class LayerGradCam(LayerAttribution, GradientAttribution):
    r"""
    Computes GradCAM attribution for chosen layer. GradCAM is designed for
    convolutional neural networks, and is usually applied to the last
    convolutional layer.

    GradCAM computes the gradients of the target output with respect to
    the given layer, averages for each output channel (dimension 2 of
    output), and multiplies the average gradient for each channel by the
    layer activations. The results are summed over all channels.

    Note that in the original GradCAM algorithm described in the paper,
    ReLU is applied to the output, returning only non-negative attributions.
    For providing more flexibility to the user, we choose to not perform the
    ReLU internally by default and return the sign information. To match the
    original GradCAM algorithm, it is necessary to pass the parameter
    relu_attributions=True to apply ReLU on the final
    attributions or alternatively only visualize the positive attributions.

    Note: this procedure sums over the second dimension (# of channels),
    so the output of GradCAM attributions will have a second
    dimension of 1, but all other dimensions will match that of the layer
    output.

    GradCAM attributions are generally upsampled and can be viewed as a
    mask to the input, since a convolutional layer output generally
    matches the input image spatially. This upsampling can be performed
    using LayerAttribution.interpolate, as shown in the example below.

    More details regarding the GradCAM method can be found in the
    original paper here:
    https://arxiv.org/pdf/1610.02391.pdf
    """

    def __init__(
        self,
        forward_func: Callable,
        layer: Module,
        device_ids: Union[None, List[int]] = None,
    ) -> None:
        r"""
        Args:

            forward_func (callable): The forward function of the model or any
                        modification of it
            layer (torch.nn.Module): Layer for which attributions are computed.
                        Output size of attribute matches this layer's output
                        dimensions, except for dimension 2, which will be 1,
                        since GradCAM sums over channels.
            device_ids (list(int)): Device ID list, necessary only if forward_func
                        applies a DataParallel model. This allows reconstruction of
                        intermediate outputs from batched results across devices.
                        If forward_func is given as the DataParallel model itself,
                        then it is not necessary to provide this argument.
        """
        LayerAttribution.__init__(self, forward_func, layer, device_ids)
        GradientAttribution.__init__(self, forward_func)

    @log_usage()
    def attribute(
        self,
        inputs: Union[Tensor, Tuple[Tensor, ...]],
        target: TargetType = None,
        additional_forward_args: Any = None,
        attribute_to_layer_input: bool = False,
        relu_attributions: bool = False,
    ) -> Union[Tensor, Tuple[Tensor, ...]]:
        r"""
        Args:

            inputs (tensor or tuple of tensors): Input for which attributions
                        are computed. If forward_func takes a single
                        tensor as input, a single input tensor should be provided.
                        If forward_func takes multiple tensors as input, a tuple
                        of the input tensors should be provided. It is assumed
                        that for all given input tensors, dimension 0 corresponds
                        to the number of examples, and if multiple input tensors
                        are provided, the examples must be aligned appropriately.
            target (int, tuple, tensor or list, optional): Output indices for
                        which gradients are computed (for classification cases,
                        this is usually the target class).
                        If the network returns a scalar value per example,
                        no target index is necessary.
                        For general 2D outputs, targets can be either:

                        - a single integer or a tensor containing a single
                          integer, which is applied to all input examples

                        - a list of integers or a 1D tensor, with length matching
                          the number of examples in inputs (dim 0). Each integer
                          is applied as the target for the corresponding example.

                        For outputs with > 2 dimensions, targets can be either:

                        - A single tuple, which contains #output_dims - 1
                          elements. This target index is applied to all examples.

                        - A list of tuples with length equal to the number of
                          examples in inputs (dim 0), and each tuple containing
                          #output_dims - 1 elements. Each tuple is applied as the
                          target for the corresponding example.

                        Default: None
            additional_forward_args (any, optional): If the forward function
                        requires additional arguments other than the inputs for
                        which attributions should not be computed, this argument
                        can be provided. It must be either a single additional
                        argument of a Tensor or arbitrary (non-tuple) type or a
                        tuple containing multiple additional arguments including
                        tensors or any arbitrary python types. These arguments
                        are provided to forward_func in order following the
                        arguments in inputs.
                        Note that attributions are not computed with respect
                        to these arguments.
                        Default: None
            attribute_to_layer_input (bool, optional): Indicates whether to
                        compute the attributions with respect to the layer input
                        or output. If `attribute_to_layer_input` is set to True
                        then the attributions will be computed with respect to the
                        layer input, otherwise it will be computed with respect
                        to layer output.
                        Note that currently it is assumed that either the input
                        or the outputs of internal layers, depending on whether we
                        attribute to the input or output, are single tensors.
                        Support for multiple tensors will be added later.
                        Default: False
            relu_attributions (bool, optional): Indicates whether to
                        apply a ReLU operation on the final attribution,
                        returning only non-negative attributions. Setting this
                        flag to True matches the original GradCAM algorithm,
                        otherwise, by default, both positive and negative
                        attributions are returned.
                        Default: False

        Returns:
            *tensor* or tuple of *tensors* of **attributions**:
            - **attributions** (*tensor* or tuple of *tensors*):
                        Attributions based on GradCAM method.
                        Attributions will be the same size as the
                        output of the given layer, except for dimension 2,
                        which will be 1 due to summing over channels.
                        Attributions are returned in a tuple if
                        the layer inputs / outputs contain multiple tensors,
                        otherwise a single tensor is returned.
        Examples::

            >>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
            >>> # and returns an Nx10 tensor of class probabilities.
            >>> # It contains a layer conv4, which is an instance of nn.Conv2d,
            >>> # and the output of this layer has dimensions Nx50x8x8.
            >>> # It is the last convolution layer, which is the recommended
            >>> # use case for GradCAM.
            >>> net = ImageClassifier()
            >>> layer_gc = LayerGradCam(net, net.conv4)
            >>> input = torch.randn(2, 3, 32, 32, requires_grad=True)
            >>> # Computes layer GradCAM for class 3.
            >>> # attribution size matches layer output except for the channel
            >>> # dimension, so dimensions of attr would be Nx1x8x8.
            >>> attr = layer_gc.attribute(input, 3)
            >>> # GradCAM attributions are often upsampled and viewed as a
            >>> # mask to the input, since the convolutional layer output
            >>> # spatially matches the original input image.
            >>> # This can be done with LayerAttribution's interpolate method.
            >>> upsampled_attr = LayerAttribution.interpolate(attr, (32, 32))
        """
        inputs = _format_tensor_into_tuples(inputs)
        additional_forward_args = _format_additional_forward_args(
            additional_forward_args
        )
        # Returns gradient of output with respect to
        # hidden layer and hidden layer evaluated at each input.
        layer_gradients, layer_evals = compute_layer_gradients_and_eval(
            self.forward_func,
            self.layer,
            inputs,
            target,
            additional_forward_args,
            device_ids=self.device_ids,
            attribute_to_layer_input=attribute_to_layer_input,
        )

        summed_grads = tuple(
            torch.mean(
                layer_grad,
                dim=tuple(x for x in range(2, len(layer_grad.shape))),
                keepdim=True,
            )
            if len(layer_grad.shape) > 2
            else layer_grad
            for layer_grad in layer_gradients
        )

        scaled_acts = tuple(
            torch.sum(summed_grad * layer_eval, dim=1, keepdim=True)
            for summed_grad, layer_eval in zip(summed_grads, layer_evals)
        )
        if relu_attributions:
            scaled_acts = tuple(F.relu(scaled_act) for scaled_act in scaled_acts)
        return _format_output(len(scaled_acts) > 1, scaled_acts)
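A runnable usage sketch of the class above, with a hypothetical small CNN (the model definition and shapes are illustrative assumptions, not part of this commit):

import torch
import torch.nn as nn
from captum.attr import LayerAttribution, LayerGradCam

# hypothetical CNN: Nx3x32x32 -> Nx10 class scores
net = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
    nn.Conv2d(8, 16, 3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 10),
)
layer_gc = LayerGradCam(net, net[2])      # attribute to the last conv layer
inp = torch.randn(2, 3, 32, 32)

attr = layer_gc.attribute(inp, target=3)  # channels summed -> shape 2x1x32x32
# upsample the attribution map to the input resolution for visualization
upsampled = LayerAttribution.interpolate(attr, (32, 32))
print(attr.shape, upsampled.shape)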
captum/attr/_core/layer/internal_influence.py
ADDED
@@ -0,0 +1,309 @@
#!/usr/bin/env python3
from typing import Any, Callable, List, Tuple, Union

import torch
from captum._utils.common import (
    _expand_additional_forward_args,
    _expand_target,
    _format_additional_forward_args,
    _format_output,
)
from captum._utils.gradient import compute_layer_gradients_and_eval
from captum._utils.typing import BaselineType, TargetType
from captum.attr._utils.approximation_methods import approximation_parameters
from captum.attr._utils.attribution import GradientAttribution, LayerAttribution
from captum.attr._utils.batching import _batch_attribution
from captum.attr._utils.common import (
    _format_input_baseline,
    _reshape_and_sum,
    _validate_input,
)
from captum.log import log_usage
from torch import Tensor
from torch.nn import Module


class InternalInfluence(LayerAttribution, GradientAttribution):
    r"""
    Computes internal influence by approximating the integral of gradients
    for a particular layer along the path from a baseline input to the
    given input.
    If no baseline is provided, the default baseline is the zero tensor.
    More details on this approach can be found here:
    https://arxiv.org/pdf/1802.03788.pdf

    Note that this method is similar to applying integrated gradients and
    taking the layer as input, integrating the gradient of the layer with
    respect to the output.
    """

    def __init__(
        self,
        forward_func: Callable,
        layer: Module,
        device_ids: Union[None, List[int]] = None,
    ) -> None:
        r"""
        Args:

            forward_func (callable): The forward function of the model or any
                        modification of it
            layer (torch.nn.Module): Layer for which attributions are computed.
                        Output size of attribute matches this layer's input or
                        output dimensions, depending on whether we attribute to
                        the inputs or outputs of the layer, corresponding to
                        attribution of each neuron in the input or output of
                        this layer.
            device_ids (list(int)): Device ID list, necessary only if forward_func
                        applies a DataParallel model. This allows reconstruction of
                        intermediate outputs from batched results across devices.
                        If forward_func is given as the DataParallel model itself,
                        then it is not necessary to provide this argument.
        """
        LayerAttribution.__init__(self, forward_func, layer, device_ids)
        GradientAttribution.__init__(self, forward_func)

    @log_usage()
    def attribute(
        self,
        inputs: Union[Tensor, Tuple[Tensor, ...]],
        baselines: BaselineType = None,
        target: TargetType = None,
        additional_forward_args: Any = None,
        n_steps: int = 50,
        method: str = "gausslegendre",
        internal_batch_size: Union[None, int] = None,
        attribute_to_layer_input: bool = False,
    ) -> Union[Tensor, Tuple[Tensor, ...]]:
        r"""
        Args:

            inputs (tensor or tuple of tensors): Input for which internal
                        influence is computed. If forward_func takes a single
                        tensor as input, a single input tensor should be provided.
                        If forward_func takes multiple tensors as input, a tuple
                        of the input tensors should be provided. It is assumed
                        that for all given input tensors, dimension 0 corresponds
                        to the number of examples, and if multiple input tensors
                        are provided, the examples must be aligned appropriately.
            baselines (scalar, tensor, tuple of scalars or tensors, optional):
                        Baselines define a starting point from which integral
                        is computed and can be provided as:

                        - a single tensor, if inputs is a single tensor, with
                          exactly the same dimensions as inputs or the first
                          dimension is one and the remaining dimensions match
                          with inputs.

                        - a single scalar, if inputs is a single tensor, which will
                          be broadcasted for each input value in input tensor.

                        - a tuple of tensors or scalars, the baseline corresponding
                          to each tensor in the inputs' tuple can be:

                          - either a tensor with matching dimensions to
                            corresponding tensor in the inputs' tuple
                            or the first dimension is one and the remaining
                            dimensions match with the corresponding
                            input tensor.

                          - or a scalar, corresponding to a tensor in the
                            inputs' tuple. This scalar value is broadcasted
                            for corresponding input tensor.

                        In the cases when `baselines` is not provided, we internally
                        use zero scalar corresponding to each input tensor.

                        Default: None
            target (int, tuple, tensor or list, optional): Output indices for
                        which gradients are computed (for classification cases,
                        this is usually the target class).
                        If the network returns a scalar value per example,
                        no target index is necessary.
                        For general 2D outputs, targets can be either:

                        - a single integer or a tensor containing a single
                          integer, which is applied to all input examples

                        - a list of integers or a 1D tensor, with length matching
                          the number of examples in inputs (dim 0). Each integer
                          is applied as the target for the corresponding example.

                        For outputs with > 2 dimensions, targets can be either:

                        - A single tuple, which contains #output_dims - 1
                          elements. This target index is applied to all examples.

                        - A list of tuples with length equal to the number of
                          examples in inputs (dim 0), and each tuple containing
                          #output_dims - 1 elements. Each tuple is applied as the
                          target for the corresponding example.

                        Default: None
            additional_forward_args (any, optional): If the forward function
                        requires additional arguments other than the inputs for
                        which attributions should not be computed, this argument
                        can be provided. It must be either a single additional
                        argument of a Tensor or arbitrary (non-tuple) type or a
                        tuple containing multiple additional arguments including
                        tensors or any arbitrary python types. These arguments
                        are provided to forward_func in order following the
                        arguments in inputs.
                        For a tensor, the first dimension of the tensor must
                        correspond to the number of examples. It will be
                        repeated for each of `n_steps` along the integrated
                        path. For all other types, the given argument is used
                        for all forward evaluations.
                        Note that attributions are not computed with respect
                        to these arguments.
                        Default: None
            n_steps (int, optional): The number of steps used by the approximation
                        method. Default: 50.
            method (string, optional): Method for approximating the integral,
                        one of `riemann_right`, `riemann_left`, `riemann_middle`,
                        `riemann_trapezoid` or `gausslegendre`.
                        Default: `gausslegendre` if no method is provided.
            internal_batch_size (int, optional): Divides total #steps * #examples
                        data points into chunks of size at most internal_batch_size,
                        which are computed (forward / backward passes)
                        sequentially. internal_batch_size must be at least equal to
                        #examples.
                        For DataParallel models, each batch is split among the
                        available devices, so evaluations on each available
                        device contain internal_batch_size / num_devices examples.
                        If internal_batch_size is None, then all evaluations
                        are processed in one batch.
                        Default: None
            attribute_to_layer_input (bool, optional): Indicates whether to
                        compute the attribution with respect to the layer input
                        or output. If `attribute_to_layer_input` is set to True
                        then the attributions will be computed with respect to
                        layer inputs, otherwise it will be computed with respect
                        to layer outputs.
                        Note that currently it is assumed that either the input
                        or the output of internal layer, depending on whether we
                        attribute to the input or output, is a single tensor.
                        Support for multiple tensors will be added later.
                        Default: False

        Returns:
            *tensor* or tuple of *tensors* of **attributions**:
            - **attributions** (*tensor* or tuple of *tensors*):
                        Internal influence of each neuron in given
                        layer output. Attributions will always be the same size
                        as the output or input of the given layer depending on
                        whether `attribute_to_layer_input` is set to `False` or
                        `True` respectively.
                        Attributions are returned in a tuple if
                        the layer inputs / outputs contain multiple tensors,
                        otherwise a single tensor is returned.

        Examples::

            >>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
            >>> # and returns an Nx10 tensor of class probabilities.
            >>> # It contains an attribute conv1, which is an instance of nn.Conv2d,
            >>> # and the output of this layer has dimensions Nx12x32x32.
            >>> net = ImageClassifier()
            >>> layer_int_inf = InternalInfluence(net, net.conv1)
            >>> input = torch.randn(2, 3, 32, 32, requires_grad=True)
            >>> # Computes layer internal influence.
            >>> # attribution size matches layer output, Nx12x32x32
            >>> attribution = layer_int_inf.attribute(input)
        """
        inputs, baselines = _format_input_baseline(inputs, baselines)
        _validate_input(inputs, baselines, n_steps, method)
        if internal_batch_size is not None:
            num_examples = inputs[0].shape[0]
            attrs = _batch_attribution(
                self,
                num_examples,
                internal_batch_size,
                n_steps,
                inputs=inputs,
                baselines=baselines,
                target=target,
                additional_forward_args=additional_forward_args,
                method=method,
                attribute_to_layer_input=attribute_to_layer_input,
            )
        else:
            attrs = self._attribute(
                inputs=inputs,
                baselines=baselines,
                target=target,
                additional_forward_args=additional_forward_args,
                n_steps=n_steps,
                method=method,
                attribute_to_layer_input=attribute_to_layer_input,
            )

        return attrs

    def _attribute(
        self,
        inputs: Tuple[Tensor, ...],
        baselines: Tuple[Union[Tensor, int, float], ...],
        target: TargetType = None,
        additional_forward_args: Any = None,
        n_steps: int = 50,
        method: str = "gausslegendre",
        attribute_to_layer_input: bool = False,
        step_sizes_and_alphas: Union[None, Tuple[List[float], List[float]]] = None,
    ) -> Union[Tensor, Tuple[Tensor, ...]]:
        if step_sizes_and_alphas is None:
            # retrieve step size and scaling factor for specified approximation method
            step_sizes_func, alphas_func = approximation_parameters(method)
            step_sizes, alphas = step_sizes_func(n_steps), alphas_func(n_steps)
        else:
            step_sizes, alphas = step_sizes_and_alphas

        # Compute scaled inputs from baseline to final input.
        scaled_features_tpl = tuple(
            torch.cat(
                [baseline + alpha * (input - baseline) for alpha in alphas], dim=0
            ).requires_grad_()
            for input, baseline in zip(inputs, baselines)
        )

        additional_forward_args = _format_additional_forward_args(
            additional_forward_args
        )
        # apply number of steps to additional forward args
        # currently, number of steps is applied only to additional forward arguments
        # that are nd-tensors. It is assumed that the first dimension is
        # the number of batches.
        # dim -> (bsz * #steps x additional_forward_args[0].shape[1:], ...)
        input_additional_args = (
            _expand_additional_forward_args(additional_forward_args, n_steps)
            if additional_forward_args is not None
            else None
        )
        expanded_target = _expand_target(target, n_steps)

        # Returns gradient of output with respect to hidden layer.
        layer_gradients, _ = compute_layer_gradients_and_eval(
            forward_fn=self.forward_func,
            layer=self.layer,
            inputs=scaled_features_tpl,
            target_ind=expanded_target,
            additional_forward_args=input_additional_args,
            device_ids=self.device_ids,
            attribute_to_layer_input=attribute_to_layer_input,
        )
        # flattening grads so that we can multiply them with the step sizes
        # calling contiguous to avoid `memory hole` problems
        scaled_grads = tuple(
            layer_grad.contiguous().view(n_steps, -1)
            * torch.tensor(step_sizes).view(n_steps, 1).to(layer_grad.device)
            for layer_grad in layer_gradients
        )

        # aggregates across all steps for each tensor in the input tuple
        attrs = tuple(
            _reshape_and_sum(
                scaled_grad, n_steps, inputs[0].shape[0], layer_grad.shape[1:]
            )
            for scaled_grad, layer_grad in zip(scaled_grads, layer_gradients)
        )
        return _format_output(len(attrs) > 1, attrs)
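A short usage sketch for the class above, under an assumed toy model with a named conv layer (the model and shapes below are illustrative only):

import torch
import torch.nn as nn
from captum.attr import InternalInfluence

# hypothetical classifier exposing a conv1 attribute
class SmallNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 12, 3, padding=1)
        self.head = nn.Sequential(nn.Flatten(), nn.Linear(12 * 32 * 32, 10))

    def forward(self, x):
        return self.head(torch.relu(self.conv1(x)))

net = SmallNet()
int_inf = InternalInfluence(net, net.conv1)
inp = torch.randn(2, 3, 32, 32)

# integrate layer gradients over 20 steps from the default zero baseline
attr = int_inf.attribute(inp, target=0, n_steps=20)
print(attr.shape)  # torch.Size([2, 12, 32, 32]), matching the conv1 output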
captum/attr/_core/layer/layer_activation.py
ADDED
@@ -0,0 +1,136 @@
#!/usr/bin/env python3
from typing import Any, Callable, List, Tuple, Union

import torch
from captum._utils.common import _format_output
from captum._utils.gradient import _forward_layer_eval
from captum._utils.typing import ModuleOrModuleList
from captum.attr._utils.attribution import LayerAttribution
from captum.log import log_usage
from torch import Tensor
from torch.nn import Module


class LayerActivation(LayerAttribution):
    r"""
    Computes activation of selected layer for given input.
    """

    def __init__(
        self,
        forward_func: Callable,
        layer: ModuleOrModuleList,
        device_ids: Union[None, List[int]] = None,
    ) -> None:
        r"""
        Args:

            forward_func (callable): The forward function of the model or any
                        modification of it
            layer (torch.nn.Module or list(torch.nn.Module)): Layer or layers
                        for which attributions are computed.
                        Output size of attribute matches this layer's input or
                        output dimensions, depending on whether we attribute to
                        the inputs or outputs of the layer, corresponding to
                        attribution of each neuron in the input or output of
                        this layer. If multiple layers are provided, attributions
                        are returned as a list, each element corresponding to the
                        activations of the corresponding layer.
            device_ids (list(int)): Device ID list, necessary only if forward_func
                        applies a DataParallel model. This allows reconstruction of
                        intermediate outputs from batched results across devices.
                        If forward_func is given as the DataParallel model itself,
                        then it is not necessary to provide this argument.
        """
        LayerAttribution.__init__(self, forward_func, layer, device_ids)

    @log_usage()
    def attribute(
        self,
        inputs: Union[Tensor, Tuple[Tensor, ...]],
        additional_forward_args: Any = None,
        attribute_to_layer_input: bool = False,
    ) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]:
        r"""
        Args:

            inputs (tensor or tuple of tensors): Input for which layer
                        activation is computed. If forward_func takes a single
                        tensor as input, a single input tensor should be provided.
                        If forward_func takes multiple tensors as input, a tuple
                        of the input tensors should be provided. It is assumed
                        that for all given input tensors, dimension 0 corresponds
                        to the number of examples, and if multiple input tensors
                        are provided, the examples must be aligned appropriately.
            additional_forward_args (any, optional): If the forward function
                        requires additional arguments other than the inputs for
                        which attributions should not be computed, this argument
                        can be provided. It must be either a single additional
                        argument of a Tensor or arbitrary (non-tuple) type or a
                        tuple containing multiple additional arguments including
                        tensors or any arbitrary python types. These arguments
                        are provided to forward_func in order following the
                        arguments in inputs.
                        Note that attributions are not computed with respect
                        to these arguments.
                        Default: None
            attribute_to_layer_input (bool, optional): Indicates whether to
                        compute the attribution with respect to the layer input
                        or output. If `attribute_to_layer_input` is set to True
                        then the attributions will be computed with respect to
                        layer input, otherwise it will be computed with respect
                        to layer output.
                        Note that currently it is assumed that either the input
                        or the output of internal layer, depending on whether we
                        attribute to the input or output, is a single tensor.
                        Support for multiple tensors will be added later.
                        Default: False

        Returns:
            *tensor* or tuple of *tensors* or *list* of **attributions**:
            - **attributions** (*tensor* or tuple of *tensors* or *list*):
                        Activation of each neuron in given layer output.
                        Attributions will always be the same size as the
                        output of the given layer.
                        Attributions are returned in a tuple if
                        the layer inputs / outputs contain multiple tensors,
                        otherwise a single tensor is returned.
                        If multiple layers are provided, attributions
                        are returned as a list, each element corresponding to the
                        activations of the corresponding layer.

        Examples::

            >>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
            >>> # and returns an Nx10 tensor of class probabilities.
            >>> # It contains an attribute conv1, which is an instance of nn.Conv2d,
            >>> # and the output of this layer has dimensions Nx12x32x32.
            >>> net = ImageClassifier()
            >>> layer_act = LayerActivation(net, net.conv1)
            >>> input = torch.randn(2, 3, 32, 32, requires_grad=True)
            >>> # Computes layer activation.
            >>> # attribution is layer output, with size Nx12x32x32
            >>> attribution = layer_act.attribute(input)
        """
        with torch.no_grad():
            layer_eval = _forward_layer_eval(
                self.forward_func,
                inputs,
                self.layer,
                additional_forward_args,
                device_ids=self.device_ids,
                attribute_to_layer_input=attribute_to_layer_input,
            )
        if isinstance(self.layer, Module):
            return _format_output(len(layer_eval) > 1, layer_eval)
        else:
            return [
                _format_output(len(single_layer_eval) > 1, single_layer_eval)
                for single_layer_eval in layer_eval
            ]

    @property
    def multiplies_by_inputs(self):
        return True
captum/attr/_core/layer/layer_conductance.py
ADDED
@@ -0,0 +1,395 @@
#!/usr/bin/env python3
import typing
from typing import Any, Callable, List, Tuple, Union

import torch
from captum._utils.common import (
    _expand_additional_forward_args,
    _expand_target,
    _format_additional_forward_args,
    _format_output,
)
from captum._utils.gradient import compute_layer_gradients_and_eval
from captum._utils.typing import BaselineType, Literal, TargetType
from captum.attr._utils.approximation_methods import approximation_parameters
from captum.attr._utils.attribution import GradientAttribution, LayerAttribution
from captum.attr._utils.batching import _batch_attribution
from captum.attr._utils.common import (
    _format_input_baseline,
    _reshape_and_sum,
    _validate_input,
)
from captum.log import log_usage
from torch import Tensor
from torch.nn import Module


class LayerConductance(LayerAttribution, GradientAttribution):
    r"""
    Computes conductance with respect to the given layer. The
    returned output is in the shape of the layer's output, showing the total
    conductance of each hidden layer neuron.

    The details of the approach can be found here:
    https://arxiv.org/abs/1805.12233
    https://arxiv.org/pdf/1807.09946.pdf

    Note that this provides the total conductance of each neuron in the
    layer's output. To obtain the breakdown of a neuron's conductance by input
    features, utilize NeuronConductance instead, and provide the target
    neuron index.
    """

    def __init__(
        self,
        forward_func: Callable,
        layer: Module,
        device_ids: Union[None, List[int]] = None,
    ) -> None:
        r"""
        Args:

            forward_func (callable): The forward function of the model or any
                        modification of it
            layer (torch.nn.Module): Layer for which attributions are computed.
                        Output size of attribute matches this layer's input or
                        output dimensions, depending on whether we attribute to
                        the inputs or outputs of the layer, corresponding to
                        attribution of each neuron in the input or output of
                        this layer.
            device_ids (list(int)): Device ID list, necessary only if forward_func
                        applies a DataParallel model. This allows reconstruction of
                        intermediate outputs from batched results across devices.
                        If forward_func is given as the DataParallel model itself,
                        then it is not necessary to provide this argument.
        """
        LayerAttribution.__init__(self, forward_func, layer, device_ids)
        GradientAttribution.__init__(self, forward_func)

    def has_convergence_delta(self) -> bool:
        return True

    @typing.overload
    def attribute(
        self,
        inputs: Union[Tensor, Tuple[Tensor, ...]],
        baselines: BaselineType = None,
        target: TargetType = None,
        additional_forward_args: Any = None,
        n_steps: int = 50,
        method: str = "gausslegendre",
        internal_batch_size: Union[None, int] = None,
        *,
        return_convergence_delta: Literal[True],
        attribute_to_layer_input: bool = False,
    ) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]:
        ...

    @typing.overload
    def attribute(
        self,
        inputs: Union[Tensor, Tuple[Tensor, ...]],
        baselines: BaselineType = None,
        target: TargetType = None,
        additional_forward_args: Any = None,
        n_steps: int = 50,
        method: str = "gausslegendre",
        internal_batch_size: Union[None, int] = None,
        return_convergence_delta: Literal[False] = False,
        attribute_to_layer_input: bool = False,
    ) -> Union[Tensor, Tuple[Tensor, ...]]:
        ...

    @log_usage()
    def attribute(
        self,
        inputs: Union[Tensor, Tuple[Tensor, ...]],
        baselines: Union[
            None, int, float, Tensor, Tuple[Union[int, float, Tensor], ...]
        ] = None,
        target: TargetType = None,
        additional_forward_args: Any = None,
        n_steps: int = 50,
        method: str = "gausslegendre",
        internal_batch_size: Union[None, int] = None,
        return_convergence_delta: bool = False,
        attribute_to_layer_input: bool = False,
    ) -> Union[
        Tensor, Tuple[Tensor, ...], Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]
    ]:
        r"""
        Args:

            inputs (tensor or tuple of tensors): Input for which layer
                        conductance is computed. If forward_func takes a single
                        tensor as input, a single input tensor should be provided.
                        If forward_func takes multiple tensors as input, a tuple
                        of the input tensors should be provided. It is assumed
                        that for all given input tensors, dimension 0 corresponds
                        to the number of examples, and if multiple input tensors
                        are provided, the examples must be aligned appropriately.
            baselines (scalar, tensor, tuple of scalars or tensors, optional):
                        Baselines define the starting point from which integral
                        is computed and can be provided as:

                        - a single tensor, if inputs is a single tensor, with
                          exactly the same dimensions as inputs or the first
                          dimension is one and the remaining dimensions match
                          with inputs.

                        - a single scalar, if inputs is a single tensor, which will
                          be broadcasted for each input value in input tensor.

                        - a tuple of tensors or scalars, the baseline corresponding
                          to each tensor in the inputs' tuple can be:

                          - either a tensor with matching dimensions to
                            corresponding tensor in the inputs' tuple
                            or the first dimension is one and the remaining
                            dimensions match with the corresponding
                            input tensor.

                          - or a scalar, corresponding to a tensor in the
                            inputs' tuple. This scalar value is broadcasted
                            for corresponding input tensor.

                        In the cases when `baselines` is not provided, we internally
                        use zero scalar corresponding to each input tensor.

                        Default: None
            target (int, tuple, tensor or list, optional): Output indices for
                        which gradients are computed (for classification cases,
                        this is usually the target class).
                        If the network returns a scalar value per example,
                        no target index is necessary.
                        For general 2D outputs, targets can be either:

                        - a single integer or a tensor containing a single
                          integer, which is applied to all input examples

                        - a list of integers or a 1D tensor, with length matching
                          the number of examples in inputs (dim 0). Each integer
                          is applied as the target for the corresponding example.

                        For outputs with > 2 dimensions, targets can be either:

                        - A single tuple, which contains #output_dims - 1
                          elements. This target index is applied to all examples.

                        - A list of tuples with length equal to the number of
                          examples in inputs (dim 0), and each tuple containing
                          #output_dims - 1 elements. Each tuple is applied as the
                          target for the corresponding example.

                        Default: None
            additional_forward_args (any, optional): If the forward function
                        requires additional arguments other than the inputs for
                        which attributions should not be computed, this argument
|
187 |
+
can be provided. It must be either a single additional
|
188 |
+
argument of a Tensor or arbitrary (non-tuple) type or a
|
189 |
+
tuple containing multiple additional arguments including
|
190 |
+
tensors or any arbitrary python types. These arguments
|
191 |
+
are provided to forward_func in order following the
|
192 |
+
arguments in inputs.
|
193 |
+
For a tensor, the first dimension of the tensor must
|
194 |
+
correspond to the number of examples. It will be repeated
|
195 |
+
for each of `n_steps` along the integrated path.
|
196 |
+
For all other types, the given argument is used for
|
197 |
+
all forward evaluations.
|
198 |
+
Note that attributions are not computed with respect
|
199 |
+
to these arguments.
|
200 |
+
Default: None
|
201 |
+
n_steps (int, optional): The number of steps used by the approximation
|
202 |
+
method. Default: 50.
|
203 |
+
method (string, optional): Method for approximating the integral,
|
204 |
+
one of `riemann_right`, `riemann_left`, `riemann_middle`,
|
205 |
+
`riemann_trapezoid` or `gausslegendre`.
|
206 |
+
Default: `gausslegendre` if no method is provided.
|
207 |
+
internal_batch_size (int, optional): Divides total #steps * #examples
|
208 |
+
data points into chunks of size at most internal_batch_size,
|
209 |
+
which are computed (forward / backward passes)
|
210 |
+
sequentially. internal_batch_size must be at least equal to
|
211 |
+
2 * #examples.
|
212 |
+
For DataParallel models, each batch is split among the
|
213 |
+
available devices, so evaluations on each available
|
214 |
+
device contain internal_batch_size / num_devices examples.
|
215 |
+
If internal_batch_size is None, then all evaluations are
|
216 |
+
processed in one batch.
|
217 |
+
Default: None
|
218 |
+
return_convergence_delta (bool, optional): Indicates whether to return
|
219 |
+
convergence delta or not. If `return_convergence_delta`
|
220 |
+
is set to True convergence delta will be returned in
|
221 |
+
a tuple following attributions.
|
222 |
+
Default: False
|
223 |
+
attribute_to_layer_input (bool, optional): Indicates whether to
|
224 |
+
compute the attribution with respect to the layer input
|
225 |
+
or output. If `attribute_to_layer_input` is set to True
|
226 |
+
then the attributions will be computed with respect to
|
227 |
+
layer inputs, otherwise it will be computed with respect
|
228 |
+
to layer outputs.
|
229 |
+
Note that currently it is assumed that either the input
|
230 |
+
or the output of internal layer, depending on whether we
|
231 |
+
attribute to the input or output, is a single tensor.
|
232 |
+
Support for multiple tensors will be added later.
|
233 |
+
Default: False
|
234 |
+
|
235 |
+
Returns:
|
236 |
+
**attributions** or 2-element tuple of **attributions**, **delta**:
|
237 |
+
- **attributions** (*tensor* or tuple of *tensors*):
|
238 |
+
Conductance of each neuron in given layer input or
|
239 |
+
output. Attributions will always be the same size as
|
240 |
+
the input or output of the given layer, depending on
|
241 |
+
whether we attribute to the inputs or outputs
|
242 |
+
of the layer which is decided by the input flag
|
243 |
+
`attribute_to_layer_input`.
|
244 |
+
Attributions are returned in a tuple if
|
245 |
+
the layer inputs / outputs contain multiple tensors,
|
246 |
+
otherwise a single tensor is returned.
|
247 |
+
- **delta** (*tensor*, returned if return_convergence_delta=True):
|
248 |
+
The difference between the total
|
249 |
+
approximated and true conductance.
|
250 |
+
This is computed using the property that the total sum of
|
251 |
+
forward_func(inputs) - forward_func(baselines) must equal
|
252 |
+
the total sum of the attributions.
|
253 |
+
Delta is calculated per example, meaning that the number of
|
254 |
+
elements in returned delta tensor is equal to the number of
|
255 |
+
of examples in inputs.
|
256 |
+
|
257 |
+
Examples::
|
258 |
+
|
259 |
+
>>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
|
260 |
+
>>> # and returns an Nx10 tensor of class probabilities.
|
261 |
+
>>> # It contains an attribute conv1, which is an instance of nn.conv2d,
|
262 |
+
>>> # and the output of this layer has dimensions Nx12x32x32.
|
263 |
+
>>> net = ImageClassifier()
|
264 |
+
>>> layer_cond = LayerConductance(net, net.conv1)
|
265 |
+
>>> input = torch.randn(2, 3, 32, 32, requires_grad=True)
|
266 |
+
>>> # Computes layer conductance for class 3.
|
267 |
+
>>> # attribution size matches layer output, Nx12x32x32
|
268 |
+
>>> attribution = layer_cond.attribute(input, target=3)
|
269 |
+
"""
|
270 |
+
inputs, baselines = _format_input_baseline(inputs, baselines)
|
271 |
+
_validate_input(inputs, baselines, n_steps, method)
|
272 |
+
|
273 |
+
num_examples = inputs[0].shape[0]
|
274 |
+
if internal_batch_size is not None:
|
275 |
+
num_examples = inputs[0].shape[0]
|
276 |
+
attrs = _batch_attribution(
|
277 |
+
self,
|
278 |
+
num_examples,
|
279 |
+
internal_batch_size,
|
280 |
+
n_steps + 1,
|
281 |
+
include_endpoint=True,
|
282 |
+
inputs=inputs,
|
283 |
+
baselines=baselines,
|
284 |
+
target=target,
|
285 |
+
additional_forward_args=additional_forward_args,
|
286 |
+
method=method,
|
287 |
+
attribute_to_layer_input=attribute_to_layer_input,
|
288 |
+
)
|
289 |
+
|
290 |
+
else:
|
291 |
+
attrs = self._attribute(
|
292 |
+
inputs=inputs,
|
293 |
+
baselines=baselines,
|
294 |
+
target=target,
|
295 |
+
additional_forward_args=additional_forward_args,
|
296 |
+
n_steps=n_steps,
|
297 |
+
method=method,
|
298 |
+
attribute_to_layer_input=attribute_to_layer_input,
|
299 |
+
)
|
300 |
+
|
301 |
+
is_layer_tuple = isinstance(attrs, tuple)
|
302 |
+
attributions = attrs if is_layer_tuple else (attrs,)
|
303 |
+
|
304 |
+
if return_convergence_delta:
|
305 |
+
start_point, end_point = baselines, inputs
|
306 |
+
delta = self.compute_convergence_delta(
|
307 |
+
attributions,
|
308 |
+
start_point,
|
309 |
+
end_point,
|
310 |
+
target=target,
|
311 |
+
additional_forward_args=additional_forward_args,
|
312 |
+
)
|
313 |
+
return _format_output(is_layer_tuple, attributions), delta
|
314 |
+
return _format_output(is_layer_tuple, attributions)
|
315 |
+
|
316 |
+
def _attribute(
|
317 |
+
self,
|
318 |
+
inputs: Tuple[Tensor, ...],
|
319 |
+
baselines: Tuple[Union[Tensor, int, float], ...],
|
320 |
+
target: TargetType = None,
|
321 |
+
additional_forward_args: Any = None,
|
322 |
+
n_steps: int = 50,
|
323 |
+
method: str = "gausslegendre",
|
324 |
+
attribute_to_layer_input: bool = False,
|
325 |
+
step_sizes_and_alphas: Union[None, Tuple[List[float], List[float]]] = None,
|
326 |
+
) -> Union[Tensor, Tuple[Tensor, ...]]:
|
327 |
+
num_examples = inputs[0].shape[0]
|
328 |
+
if step_sizes_and_alphas is None:
|
329 |
+
# Retrieve scaling factors for specified approximation method
|
330 |
+
step_sizes_func, alphas_func = approximation_parameters(method)
|
331 |
+
alphas = alphas_func(n_steps + 1)
|
332 |
+
else:
|
333 |
+
_, alphas = step_sizes_and_alphas
|
334 |
+
# Compute scaled inputs from baseline to final input.
|
335 |
+
scaled_features_tpl = tuple(
|
336 |
+
torch.cat(
|
337 |
+
[baseline + alpha * (input - baseline) for alpha in alphas], dim=0
|
338 |
+
).requires_grad_()
|
339 |
+
for input, baseline in zip(inputs, baselines)
|
340 |
+
)
|
341 |
+
|
342 |
+
additional_forward_args = _format_additional_forward_args(
|
343 |
+
additional_forward_args
|
344 |
+
)
|
345 |
+
# apply number of steps to additional forward args
|
346 |
+
# currently, number of steps is applied only to additional forward arguments
|
347 |
+
# that are nd-tensors. It is assumed that the first dimension is
|
348 |
+
# the number of batches.
|
349 |
+
# dim -> (#examples * #steps x additional_forward_args[0].shape[1:], ...)
|
350 |
+
input_additional_args = (
|
351 |
+
_expand_additional_forward_args(additional_forward_args, n_steps + 1)
|
352 |
+
if additional_forward_args is not None
|
353 |
+
else None
|
354 |
+
)
|
355 |
+
expanded_target = _expand_target(target, n_steps + 1)
|
356 |
+
|
357 |
+
# Conductance Gradients - Returns gradient of output with respect to
|
358 |
+
# hidden layer and hidden layer evaluated at each input.
|
359 |
+
(layer_gradients, layer_evals,) = compute_layer_gradients_and_eval(
|
360 |
+
forward_fn=self.forward_func,
|
361 |
+
layer=self.layer,
|
362 |
+
inputs=scaled_features_tpl,
|
363 |
+
additional_forward_args=input_additional_args,
|
364 |
+
target_ind=expanded_target,
|
365 |
+
device_ids=self.device_ids,
|
366 |
+
attribute_to_layer_input=attribute_to_layer_input,
|
367 |
+
)
|
368 |
+
|
369 |
+
# Compute differences between consecutive evaluations of layer_eval.
|
370 |
+
# This approximates the total input gradient of each step multiplied
|
371 |
+
# by the step size.
|
372 |
+
grad_diffs = tuple(
|
373 |
+
layer_eval[num_examples:] - layer_eval[:-num_examples]
|
374 |
+
for layer_eval in layer_evals
|
375 |
+
)
|
376 |
+
|
377 |
+
# Element-wise multiply gradient of output with respect to hidden layer
|
378 |
+
# and summed gradients with respect to input (chain rule) and sum
|
379 |
+
# across stepped inputs.
|
380 |
+
attributions = tuple(
|
381 |
+
_reshape_and_sum(
|
382 |
+
grad_diff * layer_gradient[:-num_examples],
|
383 |
+
n_steps,
|
384 |
+
num_examples,
|
385 |
+
layer_eval.shape[1:],
|
386 |
+
)
|
387 |
+
for layer_gradient, layer_eval, grad_diff in zip(
|
388 |
+
layer_gradients, layer_evals, grad_diffs
|
389 |
+
)
|
390 |
+
)
|
391 |
+
return _format_output(len(attributions) > 1, attributions)
|
392 |
+
|
393 |
+
@property
|
394 |
+
def multiplies_by_inputs(self):
|
395 |
+
return True
|
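A minimal usage sketch for the LayerConductance class added above. It is not part of the commit; the toy two-layer model, tensor shapes, and parameter values are illustrative assumptions only.

import torch
import torch.nn as nn
from captum.attr import LayerConductance

# Hypothetical toy model, used only to illustrate the API.
class ToyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(8, 16)
        self.relu = nn.ReLU()
        self.lin2 = nn.Linear(16, 3)

    def forward(self, x):
        return self.lin2(self.relu(self.lin1(x)))

net = ToyNet()
inputs = torch.randn(4, 8, requires_grad=True)

# Conductance of each neuron in lin1's output for target class 0,
# returning the per-example convergence delta alongside the attributions.
layer_cond = LayerConductance(net, net.lin1)
attributions, delta = layer_cond.attribute(
    inputs,
    baselines=0.0,
    target=0,
    n_steps=32,
    return_convergence_delta=True,
)
print(attributions.shape)  # torch.Size([4, 16]) - matches lin1's output
print(delta.abs().max())   # magnitude of the approximation error per batch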
captum/attr/_core/layer/layer_deep_lift.py
ADDED
@@ -0,0 +1,682 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
import typing
|
3 |
+
from typing import Any, Callable, cast, Sequence, Tuple, Union
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from captum._utils.common import (
|
7 |
+
_expand_target,
|
8 |
+
_format_additional_forward_args,
|
9 |
+
_format_baseline,
|
10 |
+
_format_tensor_into_tuples,
|
11 |
+
ExpansionTypes,
|
12 |
+
)
|
13 |
+
from captum._utils.gradient import compute_layer_gradients_and_eval
|
14 |
+
from captum._utils.typing import (
|
15 |
+
BaselineType,
|
16 |
+
Literal,
|
17 |
+
TargetType,
|
18 |
+
TensorOrTupleOfTensorsGeneric,
|
19 |
+
)
|
20 |
+
from captum.attr._core.deep_lift import DeepLift, DeepLiftShap
|
21 |
+
from captum.attr._utils.attribution import LayerAttribution
|
22 |
+
from captum.attr._utils.common import (
|
23 |
+
_call_custom_attribution_func,
|
24 |
+
_compute_conv_delta_and_format_attrs,
|
25 |
+
_format_callable_baseline,
|
26 |
+
_tensorize_baseline,
|
27 |
+
_validate_input,
|
28 |
+
)
|
29 |
+
from captum.log import log_usage
|
30 |
+
from torch import Tensor
|
31 |
+
from torch.nn import Module
|
32 |
+
|
33 |
+
|
34 |
+
class LayerDeepLift(LayerAttribution, DeepLift):
|
35 |
+
r"""
|
36 |
+
Implements DeepLIFT algorithm for the layer based on the following paper:
|
37 |
+
Learning Important Features Through Propagating Activation Differences,
|
38 |
+
Avanti Shrikumar, et. al.
|
39 |
+
https://arxiv.org/abs/1704.02685
|
40 |
+
|
41 |
+
and the gradient formulation proposed in:
|
42 |
+
Towards better understanding of gradient-based attribution methods for
|
43 |
+
deep neural networks, Marco Ancona, et.al.
|
44 |
+
https://openreview.net/pdf?id=Sy21R9JAW
|
45 |
+
|
46 |
+
This implementation supports only Rescale rule. RevealCancel rule will
|
47 |
+
be supported in later releases.
|
48 |
+
Although DeepLIFT's(Rescale Rule) attribution quality is comparable with
|
49 |
+
Integrated Gradients, it runs significantly faster than Integrated
|
50 |
+
Gradients and is preferred for large datasets.
|
51 |
+
|
52 |
+
Currently we only support a limited number of non-linear activations
|
53 |
+
but the plan is to expand the list in the future.
|
54 |
+
|
55 |
+
Note: As we know, currently we cannot access the building blocks,
|
56 |
+
of PyTorch's built-in LSTM, RNNs and GRUs such as Tanh and Sigmoid.
|
57 |
+
Nonetheless, it is possible to build custom LSTMs, RNNS and GRUs
|
58 |
+
with performance similar to built-in ones using TorchScript.
|
59 |
+
More details on how to build custom RNNs can be found here:
|
60 |
+
https://pytorch.org/blog/optimizing-cuda-rnn-with-torchscript/
|
61 |
+
"""
|
62 |
+
|
63 |
+
def __init__(
|
64 |
+
self,
|
65 |
+
model: Module,
|
66 |
+
layer: Module,
|
67 |
+
multiply_by_inputs: bool = True,
|
68 |
+
) -> None:
|
69 |
+
r"""
|
70 |
+
Args:
|
71 |
+
|
72 |
+
model (nn.Module): The reference to PyTorch model instance. Model cannot
|
73 |
+
contain any in-place nonlinear submodules; these are not
|
74 |
+
supported by the register_full_backward_hook PyTorch API
|
75 |
+
starting from PyTorch v1.9.
|
76 |
+
layer (torch.nn.Module): Layer for which attributions are computed.
|
77 |
+
The size and dimensionality of the attributions
|
78 |
+
corresponds to the size and dimensionality of the layer's
|
79 |
+
input or output depending on whether we attribute to the
|
80 |
+
inputs or outputs of the layer.
|
81 |
+
multiply_by_inputs (bool, optional): Indicates whether to factor
|
82 |
+
model inputs' multiplier in the final attribution scores.
|
83 |
+
In the literature this is also known as local vs global
|
84 |
+
attribution. If inputs' multiplier isn't factored in
|
85 |
+
then that type of attribution method is also called local
|
86 |
+
attribution. If it is, then that type of attribution
|
87 |
+
method is called global.
|
88 |
+
More detailed can be found here:
|
89 |
+
https://arxiv.org/abs/1711.06104
|
90 |
+
|
91 |
+
In case of Layer DeepLift, if `multiply_by_inputs`
|
92 |
+
is set to True, final sensitivity scores
|
93 |
+
are being multiplied by
|
94 |
+
layer activations for inputs - layer activations for baselines.
|
95 |
+
This flag applies only if `custom_attribution_func` is
|
96 |
+
set to None.
|
97 |
+
"""
|
98 |
+
LayerAttribution.__init__(self, model, layer)
|
99 |
+
DeepLift.__init__(self, model)
|
100 |
+
self.model = model
|
101 |
+
self._multiply_by_inputs = multiply_by_inputs
|
102 |
+
|
103 |
+
# Ignoring mypy error for inconsistent signature with DeepLift
|
104 |
+
@typing.overload # type: ignore
|
105 |
+
def attribute(
|
106 |
+
self,
|
107 |
+
inputs: Union[Tensor, Tuple[Tensor, ...]],
|
108 |
+
baselines: BaselineType = None,
|
109 |
+
target: TargetType = None,
|
110 |
+
additional_forward_args: Any = None,
|
111 |
+
return_convergence_delta: Literal[False] = False,
|
112 |
+
attribute_to_layer_input: bool = False,
|
113 |
+
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
|
114 |
+
) -> Union[Tensor, Tuple[Tensor, ...]]:
|
115 |
+
...
|
116 |
+
|
117 |
+
@typing.overload
|
118 |
+
def attribute(
|
119 |
+
self,
|
120 |
+
inputs: Union[Tensor, Tuple[Tensor, ...]],
|
121 |
+
baselines: BaselineType = None,
|
122 |
+
target: TargetType = None,
|
123 |
+
additional_forward_args: Any = None,
|
124 |
+
*,
|
125 |
+
return_convergence_delta: Literal[True],
|
126 |
+
attribute_to_layer_input: bool = False,
|
127 |
+
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
|
128 |
+
) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]:
|
129 |
+
...
|
130 |
+
|
131 |
+
@log_usage()
|
132 |
+
def attribute(
|
133 |
+
self,
|
134 |
+
inputs: Union[Tensor, Tuple[Tensor, ...]],
|
135 |
+
baselines: BaselineType = None,
|
136 |
+
target: TargetType = None,
|
137 |
+
additional_forward_args: Any = None,
|
138 |
+
return_convergence_delta: bool = False,
|
139 |
+
attribute_to_layer_input: bool = False,
|
140 |
+
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
|
141 |
+
) -> Union[
|
142 |
+
Tensor, Tuple[Tensor, ...], Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]
|
143 |
+
]:
|
144 |
+
r"""
|
145 |
+
Args:
|
146 |
+
|
147 |
+
inputs (tensor or tuple of tensors): Input for which layer
|
148 |
+
attributions are computed. If forward_func takes a
|
149 |
+
single tensor as input, a single input tensor should be
|
150 |
+
provided. If forward_func takes multiple tensors as input,
|
151 |
+
a tuple of the input tensors should be provided. It is
|
152 |
+
assumed that for all given input tensors, dimension 0
|
153 |
+
corresponds to the number of examples (aka batch size),
|
154 |
+
and if multiple input tensors are provided, the examples
|
155 |
+
must be aligned appropriately.
|
156 |
+
baselines (scalar, tensor, tuple of scalars or tensors, optional):
|
157 |
+
Baselines define reference samples that are compared with
|
158 |
+
the inputs. In order to assign attribution scores DeepLift
|
159 |
+
computes the differences between the inputs/outputs and
|
160 |
+
corresponding references.
|
161 |
+
Baselines can be provided as:
|
162 |
+
|
163 |
+
- a single tensor, if inputs is a single tensor, with
|
164 |
+
exactly the same dimensions as inputs or the first
|
165 |
+
dimension is one and the remaining dimensions match
|
166 |
+
with inputs.
|
167 |
+
|
168 |
+
- a single scalar, if inputs is a single tensor, which will
|
169 |
+
be broadcasted for each input value in input tensor.
|
170 |
+
|
171 |
+
- a tuple of tensors or scalars, the baseline corresponding
|
172 |
+
to each tensor in the inputs' tuple can be:
|
173 |
+
|
174 |
+
- either a tensor with matching dimensions to
|
175 |
+
corresponding tensor in the inputs' tuple
|
176 |
+
or the first dimension is one and the remaining
|
177 |
+
dimensions match with the corresponding
|
178 |
+
input tensor.
|
179 |
+
|
180 |
+
- or a scalar, corresponding to a tensor in the
|
181 |
+
inputs' tuple. This scalar value is broadcasted
|
182 |
+
for corresponding input tensor.
|
183 |
+
In the cases when `baselines` is not provided, we internally
|
184 |
+
use zero scalar corresponding to each input tensor.
|
185 |
+
|
186 |
+
Default: None
|
187 |
+
target (int, tuple, tensor or list, optional): Output indices for
|
188 |
+
which gradients are computed (for classification cases,
|
189 |
+
this is usually the target class).
|
190 |
+
If the network returns a scalar value per example,
|
191 |
+
no target index is necessary.
|
192 |
+
For general 2D outputs, targets can be either:
|
193 |
+
|
194 |
+
- a single integer or a tensor containing a single
|
195 |
+
integer, which is applied to all input examples
|
196 |
+
|
197 |
+
- a list of integers or a 1D tensor, with length matching
|
198 |
+
the number of examples in inputs (dim 0). Each integer
|
199 |
+
is applied as the target for the corresponding example.
|
200 |
+
|
201 |
+
For outputs with > 2 dimensions, targets can be either:
|
202 |
+
|
203 |
+
- A single tuple, which contains #output_dims - 1
|
204 |
+
elements. This target index is applied to all examples.
|
205 |
+
|
206 |
+
- A list of tuples with length equal to the number of
|
207 |
+
examples in inputs (dim 0), and each tuple containing
|
208 |
+
#output_dims - 1 elements. Each tuple is applied as the
|
209 |
+
target for the corresponding example.
|
210 |
+
|
211 |
+
Default: None
|
212 |
+
additional_forward_args (any, optional): If the forward function
|
213 |
+
requires additional arguments other than the inputs for
|
214 |
+
which attributions should not be computed, this argument
|
215 |
+
can be provided. It must be either a single additional
|
216 |
+
argument of a Tensor or arbitrary (non-tuple) type or a tuple
|
217 |
+
containing multiple additional arguments including tensors
|
218 |
+
or any arbitrary python types. These arguments are provided to
|
219 |
+
forward_func in order, following the arguments in inputs.
|
220 |
+
Note that attributions are not computed with respect
|
221 |
+
to these arguments.
|
222 |
+
Default: None
|
223 |
+
return_convergence_delta (bool, optional): Indicates whether to return
|
224 |
+
convergence delta or not. If `return_convergence_delta`
|
225 |
+
is set to True convergence delta will be returned in
|
226 |
+
a tuple following attributions.
|
227 |
+
Default: False
|
228 |
+
attribute_to_layer_input (bool, optional): Indicates whether to
|
229 |
+
compute the attribution with respect to the layer input
|
230 |
+
or output. If `attribute_to_layer_input` is set to True
|
231 |
+
then the attributions will be computed with respect to
|
232 |
+
layer input, otherwise it will be computed with respect
|
233 |
+
to layer output.
|
234 |
+
Note that currently it is assumed that either the input
|
235 |
+
or the output of internal layer, depending on whether we
|
236 |
+
attribute to the input or output, is a single tensor.
|
237 |
+
Support for multiple tensors will be added later.
|
238 |
+
Default: False
|
239 |
+
custom_attribution_func (callable, optional): A custom function for
|
240 |
+
computing final attribution scores. This function can take
|
241 |
+
at least one and at most three arguments with the
|
242 |
+
following signature:
|
243 |
+
|
244 |
+
- custom_attribution_func(multipliers)
|
245 |
+
- custom_attribution_func(multipliers, inputs)
|
246 |
+
- custom_attribution_func(multipliers, inputs, baselines)
|
247 |
+
|
248 |
+
In case this function is not provided, we use the default
|
249 |
+
logic defined as: multipliers * (inputs - baselines)
|
250 |
+
It is assumed that all input arguments, `multipliers`,
|
251 |
+
`inputs` and `baselines` are provided in tuples of same length.
|
252 |
+
`custom_attribution_func` returns a tuple of attribution
|
253 |
+
tensors that have the same length as the `inputs`.
|
254 |
+
Default: None
|
255 |
+
|
256 |
+
Returns:
|
257 |
+
**attributions** or 2-element tuple of **attributions**, **delta**:
|
258 |
+
- **attributions** (*tensor* or tuple of *tensors*):
|
259 |
+
Attribution score computed based on DeepLift's rescale rule with
|
260 |
+
respect to layer's inputs or outputs. Attributions will always be the
|
261 |
+
same size as the provided layer's inputs or outputs, depending on
|
262 |
+
whether we attribute to the inputs or outputs of the layer.
|
263 |
+
If the layer input / output is a single tensor, then
|
264 |
+
just a tensor is returned; if the layer input / output
|
265 |
+
has multiple tensors, then a corresponding tuple
|
266 |
+
of tensors is returned.
|
267 |
+
- **delta** (*tensor*, returned if return_convergence_delta=True):
|
268 |
+
This is computed using the property that the total sum of
|
269 |
+
forward_func(inputs) - forward_func(baselines) must equal the
|
270 |
+
total sum of the attributions computed based on DeepLift's
|
271 |
+
rescale rule.
|
272 |
+
Delta is calculated per example, meaning that the number of
|
273 |
+
elements in returned delta tensor is equal to the number of
|
274 |
+
of examples in input.
|
275 |
+
Note that the logic described for deltas is guaranteed
|
276 |
+
when the default logic for attribution computations is used,
|
277 |
+
meaning that the `custom_attribution_func=None`, otherwise
|
278 |
+
it is not guaranteed and depends on the specifics of the
|
279 |
+
`custom_attribution_func`.
|
280 |
+
|
281 |
+
Examples::
|
282 |
+
|
283 |
+
>>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
|
284 |
+
>>> # and returns an Nx10 tensor of class probabilities.
|
285 |
+
>>> net = ImageClassifier()
|
286 |
+
>>> # creates an instance of LayerDeepLift to interpret target
|
287 |
+
>>> # class 1 with respect to conv4 layer.
|
288 |
+
>>> dl = LayerDeepLift(net, net.conv4)
|
289 |
+
>>> input = torch.randn(1, 3, 32, 32, requires_grad=True)
|
290 |
+
>>> # Computes deeplift attribution scores for conv4 layer and class 3.
|
291 |
+
>>> attribution = dl.attribute(input, target=1)
|
292 |
+
"""
|
293 |
+
inputs = _format_tensor_into_tuples(inputs)
|
294 |
+
baselines = _format_baseline(baselines, inputs)
|
295 |
+
_validate_input(inputs, baselines)
|
296 |
+
|
297 |
+
baselines = _tensorize_baseline(inputs, baselines)
|
298 |
+
|
299 |
+
main_model_hooks = []
|
300 |
+
try:
|
301 |
+
main_model_hooks = self._hook_main_model()
|
302 |
+
|
303 |
+
self.model.apply(
|
304 |
+
lambda mod: self._register_hooks(
|
305 |
+
mod, attribute_to_layer_input=attribute_to_layer_input
|
306 |
+
)
|
307 |
+
)
|
308 |
+
|
309 |
+
additional_forward_args = _format_additional_forward_args(
|
310 |
+
additional_forward_args
|
311 |
+
)
|
312 |
+
expanded_target = _expand_target(
|
313 |
+
target, 2, expansion_type=ExpansionTypes.repeat
|
314 |
+
)
|
315 |
+
wrapped_forward_func = self._construct_forward_func(
|
316 |
+
self.model,
|
317 |
+
(inputs, baselines),
|
318 |
+
expanded_target,
|
319 |
+
additional_forward_args,
|
320 |
+
)
|
321 |
+
|
322 |
+
def chunk_output_fn(out: TensorOrTupleOfTensorsGeneric) -> Sequence:
|
323 |
+
if isinstance(out, Tensor):
|
324 |
+
return out.chunk(2)
|
325 |
+
return tuple(out_sub.chunk(2) for out_sub in out)
|
326 |
+
|
327 |
+
gradients, attrs = compute_layer_gradients_and_eval(
|
328 |
+
wrapped_forward_func,
|
329 |
+
self.layer,
|
330 |
+
inputs,
|
331 |
+
attribute_to_layer_input=attribute_to_layer_input,
|
332 |
+
output_fn=lambda out: chunk_output_fn(out),
|
333 |
+
)
|
334 |
+
|
335 |
+
attr_inputs = tuple(map(lambda attr: attr[0], attrs))
|
336 |
+
attr_baselines = tuple(map(lambda attr: attr[1], attrs))
|
337 |
+
gradients = tuple(map(lambda grad: grad[0], gradients))
|
338 |
+
|
339 |
+
if custom_attribution_func is None:
|
340 |
+
if self.multiplies_by_inputs:
|
341 |
+
attributions = tuple(
|
342 |
+
(input - baseline) * gradient
|
343 |
+
for input, baseline, gradient in zip(
|
344 |
+
attr_inputs, attr_baselines, gradients
|
345 |
+
)
|
346 |
+
)
|
347 |
+
else:
|
348 |
+
attributions = gradients
|
349 |
+
else:
|
350 |
+
attributions = _call_custom_attribution_func(
|
351 |
+
custom_attribution_func, gradients, attr_inputs, attr_baselines
|
352 |
+
)
|
353 |
+
finally:
|
354 |
+
# remove hooks from all activations
|
355 |
+
self._remove_hooks(main_model_hooks)
|
356 |
+
|
357 |
+
return _compute_conv_delta_and_format_attrs(
|
358 |
+
self,
|
359 |
+
return_convergence_delta,
|
360 |
+
attributions,
|
361 |
+
baselines,
|
362 |
+
inputs,
|
363 |
+
additional_forward_args,
|
364 |
+
target,
|
365 |
+
cast(Union[Literal[True], Literal[False]], len(attributions) > 1),
|
366 |
+
)
|
367 |
+
|
368 |
+
@property
|
369 |
+
def multiplies_by_inputs(self):
|
370 |
+
return self._multiply_by_inputs
|
371 |
+
|
372 |
+
|
373 |
+
class LayerDeepLiftShap(LayerDeepLift, DeepLiftShap):
|
374 |
+
r"""
|
375 |
+
Extends LayerDeepLift and DeepLiftShap algorithms and approximates SHAP
|
376 |
+
values for given input `layer`.
|
377 |
+
For each input sample - baseline pair it computes DeepLift attributions
|
378 |
+
with respect to inputs or outputs of given `layer` averages
|
379 |
+
resulting attributions across baselines. Whether to compute the attributions
|
380 |
+
with respect to the inputs or outputs of the layer is defined by the
|
381 |
+
input flag `attribute_to_layer_input`.
|
382 |
+
More details about the algorithm can be found here:
|
383 |
+
|
384 |
+
http://papers.nips.cc/paper/7062-a-unified-approach-to-interpreting-model-predictions.pdf
|
385 |
+
|
386 |
+
Note that the explanation model:
|
387 |
+
1. Assumes that input features are independent of one another
|
388 |
+
2. Is linear, meaning that the explanations are modeled through
|
389 |
+
the additive composition of feature effects.
|
390 |
+
Although, it assumes a linear model for each explanation, the overall
|
391 |
+
model across multiple explanations can be complex and non-linear.
|
392 |
+
"""
|
393 |
+
|
394 |
+
def __init__(
|
395 |
+
self,
|
396 |
+
model: Module,
|
397 |
+
layer: Module,
|
398 |
+
multiply_by_inputs: bool = True,
|
399 |
+
) -> None:
|
400 |
+
r"""
|
401 |
+
Args:
|
402 |
+
|
403 |
+
model (nn.Module): The reference to PyTorch model instance. Model cannot
|
404 |
+
contain any in-place nonlinear submodules; these are not
|
405 |
+
supported by the register_full_backward_hook PyTorch API
|
406 |
+
starting from PyTorch v1.9.
|
407 |
+
layer (torch.nn.Module): Layer for which attributions are computed.
|
408 |
+
The size and dimensionality of the attributions
|
409 |
+
corresponds to the size and dimensionality of the layer's
|
410 |
+
input or output depending on whether we attribute to the
|
411 |
+
inputs or outputs of the layer.
|
412 |
+
multiply_by_inputs (bool, optional): Indicates whether to factor
|
413 |
+
model inputs' multiplier in the final attribution scores.
|
414 |
+
In the literature this is also known as local vs global
|
415 |
+
attribution. If inputs' multiplier isn't factored in
|
416 |
+
then that type of attribution method is also called local
|
417 |
+
attribution. If it is, then that type of attribution
|
418 |
+
method is called global.
|
419 |
+
More detailed can be found here:
|
420 |
+
https://arxiv.org/abs/1711.06104
|
421 |
+
|
422 |
+
In case of LayerDeepLiftShap, if `multiply_by_inputs`
|
423 |
+
is set to True, final sensitivity scores are being
|
424 |
+
multiplied by
|
425 |
+
layer activations for inputs - layer activations for baselines
|
426 |
+
This flag applies only if `custom_attribution_func` is
|
427 |
+
set to None.
|
428 |
+
"""
|
429 |
+
LayerDeepLift.__init__(self, model, layer)
|
430 |
+
DeepLiftShap.__init__(self, model, multiply_by_inputs)
|
431 |
+
|
432 |
+
# Ignoring mypy error for inconsistent signature with DeepLiftShap
|
433 |
+
@typing.overload # type: ignore
|
434 |
+
def attribute(
|
435 |
+
self,
|
436 |
+
inputs: Union[Tensor, Tuple[Tensor, ...]],
|
437 |
+
baselines: Union[
|
438 |
+
Tensor, Tuple[Tensor, ...], Callable[..., Union[Tensor, Tuple[Tensor, ...]]]
|
439 |
+
],
|
440 |
+
target: TargetType = None,
|
441 |
+
additional_forward_args: Any = None,
|
442 |
+
return_convergence_delta: Literal[False] = False,
|
443 |
+
attribute_to_layer_input: bool = False,
|
444 |
+
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
|
445 |
+
) -> Union[Tensor, Tuple[Tensor, ...]]:
|
446 |
+
...
|
447 |
+
|
448 |
+
@typing.overload
|
449 |
+
def attribute(
|
450 |
+
self,
|
451 |
+
inputs: Union[Tensor, Tuple[Tensor, ...]],
|
452 |
+
baselines: Union[
|
453 |
+
Tensor, Tuple[Tensor, ...], Callable[..., Union[Tensor, Tuple[Tensor, ...]]]
|
454 |
+
],
|
455 |
+
target: TargetType = None,
|
456 |
+
additional_forward_args: Any = None,
|
457 |
+
*,
|
458 |
+
return_convergence_delta: Literal[True],
|
459 |
+
attribute_to_layer_input: bool = False,
|
460 |
+
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
|
461 |
+
) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]:
|
462 |
+
...
|
463 |
+
|
464 |
+
@log_usage()
|
465 |
+
def attribute(
|
466 |
+
self,
|
467 |
+
inputs: Union[Tensor, Tuple[Tensor, ...]],
|
468 |
+
baselines: Union[
|
469 |
+
Tensor, Tuple[Tensor, ...], Callable[..., Union[Tensor, Tuple[Tensor, ...]]]
|
470 |
+
],
|
471 |
+
target: TargetType = None,
|
472 |
+
additional_forward_args: Any = None,
|
473 |
+
return_convergence_delta: bool = False,
|
474 |
+
attribute_to_layer_input: bool = False,
|
475 |
+
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
|
476 |
+
) -> Union[
|
477 |
+
Tensor, Tuple[Tensor, ...], Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]
|
478 |
+
]:
|
479 |
+
r"""
|
480 |
+
Args:
|
481 |
+
|
482 |
+
inputs (tensor or tuple of tensors): Input for which layer
|
483 |
+
attributions are computed. If forward_func takes a single
|
484 |
+
tensor as input, a single input tensor should be provided.
|
485 |
+
If forward_func takes multiple tensors as input, a tuple
|
486 |
+
of the input tensors should be provided. It is assumed
|
487 |
+
that for all given input tensors, dimension 0 corresponds
|
488 |
+
to the number of examples (aka batch size), and if
|
489 |
+
multiple input tensors are provided, the examples must
|
490 |
+
be aligned appropriately.
|
491 |
+
baselines (tensor, tuple of tensors, callable):
|
492 |
+
Baselines define reference samples that are compared with
|
493 |
+
the inputs. In order to assign attribution scores DeepLift
|
494 |
+
computes the differences between the inputs/outputs and
|
495 |
+
corresponding references. Baselines can be provided as:
|
496 |
+
|
497 |
+
- a single tensor, if inputs is a single tensor, with
|
498 |
+
the first dimension equal to the number of examples
|
499 |
+
in the baselines' distribution. The remaining dimensions
|
500 |
+
must match with input tensor's dimension starting from
|
501 |
+
the second dimension.
|
502 |
+
|
503 |
+
- a tuple of tensors, if inputs is a tuple of tensors,
|
504 |
+
with the first dimension of any tensor inside the tuple
|
505 |
+
equal to the number of examples in the baseline's
|
506 |
+
distribution. The remaining dimensions must match
|
507 |
+
the dimensions of the corresponding input tensor
|
508 |
+
starting from the second dimension.
|
509 |
+
|
510 |
+
- callable function, optionally takes `inputs` as an
|
511 |
+
argument and either returns a single tensor
|
512 |
+
or a tuple of those.
|
513 |
+
|
514 |
+
It is recommended that the number of samples in the baselines'
|
515 |
+
tensors is larger than one.
|
516 |
+
target (int, tuple, tensor or list, optional): Output indices for
|
517 |
+
which gradients are computed (for classification cases,
|
518 |
+
this is usually the target class).
|
519 |
+
If the network returns a scalar value per example,
|
520 |
+
no target index is necessary.
|
521 |
+
For general 2D outputs, targets can be either:
|
522 |
+
|
523 |
+
- a single integer or a tensor containing a single
|
524 |
+
integer, which is applied to all input examples
|
525 |
+
|
526 |
+
- a list of integers or a 1D tensor, with length matching
|
527 |
+
the number of examples in inputs (dim 0). Each integer
|
528 |
+
is applied as the target for the corresponding example.
|
529 |
+
|
530 |
+
For outputs with > 2 dimensions, targets can be either:
|
531 |
+
|
532 |
+
- A single tuple, which contains #output_dims - 1
|
533 |
+
elements. This target index is applied to all examples.
|
534 |
+
|
535 |
+
- A list of tuples with length equal to the number of
|
536 |
+
examples in inputs (dim 0), and each tuple containing
|
537 |
+
#output_dims - 1 elements. Each tuple is applied as the
|
538 |
+
target for the corresponding example.
|
539 |
+
|
540 |
+
Default: None
|
541 |
+
additional_forward_args (any, optional): If the forward function
|
542 |
+
requires additional arguments other than the inputs for
|
543 |
+
which attributions should not be computed, this argument
|
544 |
+
can be provided. It must be either a single additional
|
545 |
+
argument of a Tensor or arbitrary (non-tuple) type or a tuple
|
546 |
+
containing multiple additional arguments including tensors
|
547 |
+
or any arbitrary python types. These arguments are provided to
|
548 |
+
forward_func in order, following the arguments in inputs.
|
549 |
+
Note that attributions are not computed with respect
|
550 |
+
to these arguments.
|
551 |
+
Default: None
|
552 |
+
return_convergence_delta (bool, optional): Indicates whether to return
|
553 |
+
convergence delta or not. If `return_convergence_delta`
|
554 |
+
is set to True convergence delta will be returned in
|
555 |
+
a tuple following attributions.
|
556 |
+
Default: False
|
557 |
+
attribute_to_layer_input (bool, optional): Indicates whether to
|
558 |
+
compute the attributions with respect to the layer input
|
559 |
+
or output. If `attribute_to_layer_input` is set to True
|
560 |
+
then the attributions will be computed with respect to
|
561 |
+
layer inputs, otherwise it will be computed with respect
|
562 |
+
to layer outputs.
|
563 |
+
Note that currently it assumes that both the inputs and
|
564 |
+
outputs of internal layers are single tensors.
|
565 |
+
Support for multiple tensors will be added later.
|
566 |
+
Default: False
|
567 |
+
custom_attribution_func (callable, optional): A custom function for
|
568 |
+
computing final attribution scores. This function can take
|
569 |
+
at least one and at most three arguments with the
|
570 |
+
following signature:
|
571 |
+
|
572 |
+
- custom_attribution_func(multipliers)
|
573 |
+
- custom_attribution_func(multipliers, inputs)
|
574 |
+
- custom_attribution_func(multipliers, inputs, baselines)
|
575 |
+
|
576 |
+
In case this function is not provided, we use the default
|
577 |
+
logic defined as: multipliers * (inputs - baselines)
|
578 |
+
It is assumed that all input arguments, `multipliers`,
|
579 |
+
`inputs` and `baselines` are provided in tuples of same
|
580 |
+
length. `custom_attribution_func` returns a tuple of
|
581 |
+
attribution tensors that have the same length as the
|
582 |
+
`inputs`.
|
583 |
+
Default: None
|
584 |
+
|
585 |
+
Returns:
|
586 |
+
**attributions** or 2-element tuple of **attributions**, **delta**:
|
587 |
+
- **attributions** (*tensor* or tuple of *tensors*):
|
588 |
+
Attribution score computed based on DeepLift's rescale rule
|
589 |
+
with respect to layer's inputs or outputs. Attributions
|
590 |
+
will always be the same size as the provided layer's inputs
|
591 |
+
or outputs, depending on whether we attribute to the inputs
|
592 |
+
or outputs of the layer.
|
593 |
+
Attributions are returned in a tuple based on whether
|
594 |
+
the layer inputs / outputs are contained in a tuple
|
595 |
+
from a forward hook. For standard modules, inputs of
|
596 |
+
a single tensor are usually wrapped in a tuple, while
|
597 |
+
outputs of a single tensor are not.
|
598 |
+
- **delta** (*tensor*, returned if return_convergence_delta=True):
|
599 |
+
This is computed using the property that the
|
600 |
+
total sum of forward_func(inputs) - forward_func(baselines)
|
601 |
+
must be very close to the total sum of attributions
|
602 |
+
computed based on approximated SHAP values using
|
603 |
+
DeepLift's rescale rule.
|
604 |
+
Delta is calculated for each example input and baseline pair,
|
605 |
+
meaning that the number of elements in returned delta tensor
|
606 |
+
is equal to the
|
607 |
+
`number of examples in input` * `number of examples
|
608 |
+
in baseline`. The deltas are ordered in the first place by
|
609 |
+
input example, followed by the baseline.
|
610 |
+
Note that the logic described for deltas is guaranteed
|
611 |
+
when the default logic for attribution computations is used,
|
612 |
+
meaning that the `custom_attribution_func=None`, otherwise
|
613 |
+
it is not guaranteed and depends on the specifics of the
|
614 |
+
`custom_attribution_func`.
|
615 |
+
Examples::
|
616 |
+
|
617 |
+
>>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
|
618 |
+
>>> # and returns an Nx10 tensor of class probabilities.
|
619 |
+
>>> net = ImageClassifier()
|
620 |
+
>>> # creates an instance of LayerDeepLift to interpret target
|
621 |
+
>>> # class 1 with respect to conv4 layer.
|
622 |
+
>>> dl = LayerDeepLiftShap(net, net.conv4)
|
623 |
+
>>> input = torch.randn(2, 3, 32, 32, requires_grad=True)
|
624 |
+
>>> # Computes shap values using deeplift for class 3.
|
625 |
+
>>> attribution = dl.attribute(input, target=3)
|
626 |
+
"""
|
627 |
+
inputs = _format_tensor_into_tuples(inputs)
|
628 |
+
baselines = _format_callable_baseline(baselines, inputs)
|
629 |
+
|
630 |
+
assert isinstance(baselines[0], torch.Tensor) and baselines[0].shape[0] > 1, (
|
631 |
+
"Baselines distribution has to be provided in form of a torch.Tensor"
|
632 |
+
" with more than one example but found: {}."
|
633 |
+
" If baselines are provided in shape of scalars or with a single"
|
634 |
+
" baseline example, `LayerDeepLift`"
|
635 |
+
" approach can be used instead.".format(baselines[0])
|
636 |
+
)
|
637 |
+
|
638 |
+
# batch sizes
|
639 |
+
inp_bsz = inputs[0].shape[0]
|
640 |
+
base_bsz = baselines[0].shape[0]
|
641 |
+
|
642 |
+
(
|
643 |
+
exp_inp,
|
644 |
+
exp_base,
|
645 |
+
exp_target,
|
646 |
+
exp_addit_args,
|
647 |
+
) = DeepLiftShap._expand_inputs_baselines_targets(
|
648 |
+
self, baselines, inputs, target, additional_forward_args
|
649 |
+
)
|
650 |
+
attributions = LayerDeepLift.attribute.__wrapped__( # type: ignore
|
651 |
+
self,
|
652 |
+
exp_inp,
|
653 |
+
exp_base,
|
654 |
+
target=exp_target,
|
655 |
+
additional_forward_args=exp_addit_args,
|
656 |
+
return_convergence_delta=cast(
|
657 |
+
Literal[True, False], return_convergence_delta
|
658 |
+
),
|
659 |
+
attribute_to_layer_input=attribute_to_layer_input,
|
660 |
+
custom_attribution_func=custom_attribution_func,
|
661 |
+
)
|
662 |
+
if return_convergence_delta:
|
663 |
+
attributions, delta = attributions
|
664 |
+
if isinstance(attributions, tuple):
|
665 |
+
attributions = tuple(
|
666 |
+
DeepLiftShap._compute_mean_across_baselines(
|
667 |
+
self, inp_bsz, base_bsz, cast(Tensor, attrib)
|
668 |
+
)
|
669 |
+
for attrib in attributions
|
670 |
+
)
|
671 |
+
else:
|
672 |
+
attributions = DeepLiftShap._compute_mean_across_baselines(
|
673 |
+
self, inp_bsz, base_bsz, attributions
|
674 |
+
)
|
675 |
+
if return_convergence_delta:
|
676 |
+
return attributions, delta
|
677 |
+
else:
|
678 |
+
return attributions
|
679 |
+
|
680 |
+
@property
|
681 |
+
def multiplies_by_inputs(self):
|
682 |
+
return self._multiply_by_inputs
|
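A minimal usage sketch for the LayerDeepLift and LayerDeepLiftShap classes added in layer_deep_lift.py above. It is not part of the commit; the toy model, tensor shapes, and baseline distribution are illustrative assumptions only.

import torch
import torch.nn as nn
from captum.attr import LayerDeepLift, LayerDeepLiftShap

# Hypothetical toy model, used only to illustrate the API
# (ReLU is not in-place, as required by these attribution methods).
class ToyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(8, 16)
        self.relu = nn.ReLU()
        self.lin2 = nn.Linear(16, 3)

    def forward(self, x):
        return self.lin2(self.relu(self.lin1(x)))

net = ToyNet()
inputs = torch.randn(4, 8, requires_grad=True)

# LayerDeepLift with a single zero baseline; attributions match lin1's output shape.
dl = LayerDeepLift(net, net.lin1)
attr = dl.attribute(inputs, baselines=torch.zeros_like(inputs), target=0)
print(attr.shape)  # torch.Size([4, 16])

# LayerDeepLiftShap with a baseline distribution (more than one example),
# attributions are averaged across baselines for each input example.
dl_shap = LayerDeepLiftShap(net, net.lin1)
baseline_dist = torch.randn(10, 8) * 0.01
attr_shap = dl_shap.attribute(inputs, baselines=baseline_dist, target=0)
print(attr_shap.shape)  # torch.Size([4, 16])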
captum/attr/_core/layer/layer_feature_ablation.py
ADDED
@@ -0,0 +1,302 @@
#!/usr/bin/env python3
from typing import Any, Callable, List, Tuple, Union

import torch
from captum._utils.common import (
    _extract_device,
    _format_additional_forward_args,
    _format_output,
    _format_tensor_into_tuples,
    _run_forward,
)
from captum._utils.gradient import _forward_layer_eval
from captum._utils.typing import BaselineType, TargetType
from captum.attr._core.feature_ablation import FeatureAblation
from captum.attr._utils.attribution import LayerAttribution, PerturbationAttribution
from captum.log import log_usage
from torch import Tensor
from torch.nn import Module
from torch.nn.parallel.scatter_gather import scatter


class LayerFeatureAblation(LayerAttribution, PerturbationAttribution):
    r"""
    A perturbation based approach to computing layer attribution, involving
    replacing values in the input / output of a layer with a given baseline /
    reference, and computing the difference in output. By default, each
    neuron (scalar input / output value) within the layer is replaced
    independently.
    Passing a layer mask allows grouping neurons to be
    ablated together.
    Each neuron in the group will be given the same attribution value
    equal to the change in target as a result of ablating the entire neuron
    group.
    """

    def __init__(
        self,
        forward_func: Callable,
        layer: Module,
        device_ids: Union[None, List[int]] = None,
    ) -> None:
        r"""
        Args:

            forward_func (callable): The forward function of the model or any
                        modification of it
            layer (torch.nn.Module): Layer for which attributions are computed.
                        Output size of attribute matches this layer's input or
                        output dimensions, depending on whether we attribute to
                        the inputs or outputs of the layer, corresponding to
                        attribution of each neuron in the input or output of
                        this layer.
            device_ids (list(int)): Device ID list, necessary only if forward_func
                        applies a DataParallel model. This allows reconstruction of
                        intermediate outputs from batched results across devices.
                        If forward_func is given as the DataParallel model itself
                        (or otherwise has a device_ids attribute with the device
                        ID list), then it is not necessary to provide this
                        argument.
        """
        LayerAttribution.__init__(self, forward_func, layer, device_ids)
        PerturbationAttribution.__init__(self, forward_func)

    @log_usage()
    def attribute(
        self,
        inputs: Union[Tensor, Tuple[Tensor, ...]],
        layer_baselines: BaselineType = None,
        target: TargetType = None,
        additional_forward_args: Any = None,
        layer_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
        attribute_to_layer_input: bool = False,
        perturbations_per_eval: int = 1,
    ) -> Union[Tensor, Tuple[Tensor, ...]]:
        r"""
        Args:

            inputs (tensor or tuple of tensors): Input for which layer
                        attributions are computed. If forward_func takes a single
                        tensor as input, a single input tensor should be provided.
                        If forward_func takes multiple tensors as input, a tuple
                        of the input tensors should be provided. It is assumed
                        that for all given input tensors, dimension 0 corresponds
                        to the number of examples, and if multiple input tensors
                        are provided, the examples must be aligned appropriately.
            layer_baselines (scalar, tensor, tuple of scalars or tensors, optional):
                        Layer baselines define reference values which replace each
                        layer input / output value when ablated.
                        Layer baselines should be a single tensor with dimensions
                        matching the input / output of the target layer (or
                        broadcastable to match it), based
                        on whether we are attributing to the input or output
                        of the target layer.
                        In the cases when `baselines` is not provided, we internally
                        use zero as the baseline for each neuron.
                        Default: None
            target (int, tuple, tensor or list, optional): Output indices for
                        which gradients are computed (for classification cases,
                        this is usually the target class).
                        If the network returns a scalar value per example,
                        no target index is necessary.
                        For general 2D outputs, targets can be either:

                        - a single integer or a tensor containing a single
                          integer, which is applied to all input examples

                        - a list of integers or a 1D tensor, with length matching
                          the number of examples in inputs (dim 0). Each integer
                          is applied as the target for the corresponding example.

                        For outputs with > 2 dimensions, targets can be either:

                        - A single tuple, which contains #output_dims - 1
                          elements. This target index is applied to all examples.

                        - A list of tuples with length equal to the number of
                          examples in inputs (dim 0), and each tuple containing
                          #output_dims - 1 elements. Each tuple is applied as the
                          target for the corresponding example.

                        Default: None
            additional_forward_args (any, optional): If the forward function
                        requires additional arguments other than the inputs for
                        which attributions should not be computed, this argument
                        can be provided. It must be either a single additional
                        argument of a Tensor or arbitrary (non-tuple) type or a
                        tuple containing multiple additional arguments including
                        tensors or any arbitrary python types. These arguments
                        are provided to forward_func in order following the
                        arguments in inputs.
                        Note that attributions are not computed with respect
                        to these arguments.
                        Default: None
            layer_mask (tensor or tuple of tensors, optional):
                        layer_mask defines a mask for the layer, grouping
                        elements of the layer input / output which should be
                        ablated together.
                        layer_mask should be a single tensor with dimensions
                        matching the input / output of the target layer (or
                        broadcastable to match it), based
                        on whether we are attributing to the input or output
                        of the target layer. layer_mask
                        should contain integers in the range 0 to num_groups
                        - 1, and all elements with the same value are
                        considered to be in the same group.
                        If None, then a layer mask is constructed which assigns
                        each neuron within the layer as a separate group, which
                        is ablated independently.
                        Default: None
            attribute_to_layer_input (bool, optional): Indicates whether to
                        compute the attributions with respect to the layer input
                        or output. If `attribute_to_layer_input` is set to True
                        then the attributions will be computed with respect to
                        layer's inputs, otherwise it will be computed with respect
                        to layer's outputs.
                        Note that currently it is assumed that either the input
                        or the output of the layer, depending on whether we
                        attribute to the input or output, is a single tensor.
                        Support for multiple tensors will be added later.
                        Default: False
            perturbations_per_eval (int, optional): Allows ablation of multiple
                        neuron (groups) to be processed simultaneously in one
                        call to forward_fn.
                        Each forward pass will contain a maximum of
                        perturbations_per_eval * #examples samples.
                        For DataParallel models, each batch is split among the
                        available devices, so evaluations on each available
                        device contain at most
                        (perturbations_per_eval * #examples) / num_devices
                        samples.
                        Default: 1

        Returns:
            *tensor* or tuple of *tensors* of **attributions**:
            - **attributions** (*tensor* or tuple of *tensors*):
                        Attribution of each neuron in given layer input or
                        output. Attributions will always be the same size as
                        the input or output of the given layer, depending on
                        whether we attribute to the inputs or outputs
                        of the layer which is decided by the input flag
                        `attribute_to_layer_input`
                        Attributions are returned in a tuple if
                        the layer inputs / outputs contain multiple tensors,
                        otherwise a single tensor is returned.


        Examples::

            >>> # SimpleClassifier takes a single input tensor of size Nx4x4,
            >>> # and returns an Nx3 tensor of class probabilities.
            >>> # It contains an attribute conv1, which is an instance of nn.conv2d,
            >>> # and the output of this layer has dimensions Nx12x3x3.
            >>> net = SimpleClassifier()
            >>> # Generating random input with size 2 x 4 x 4
            >>> input = torch.randn(2, 4, 4)
            >>> # Defining LayerFeatureAblation interpreter
            >>> ablator = LayerFeatureAblation(net, net.conv1)
            >>> # Computes ablation attribution, ablating each of the 108
            >>> # neurons independently.
            >>> attr = ablator.attribute(input, target=1)

            >>> # Alternatively, we may want to ablate neurons in groups, e.g.
            >>> # grouping all the layer outputs in the same row.
            >>> # This can be done by creating a layer mask as follows, which
            >>> # defines the groups of layer inputs / outputs, e.g.:
            >>> # +---+---+---+
            >>> # | 0 | 0 | 0 |
            >>> # +---+---+---+
            >>> # | 1 | 1 | 1 |
            >>> # +---+---+---+
            >>> # | 2 | 2 | 2 |
            >>> # +---+---+---+
            >>> # With this mask, all the 36 neurons in a row / channel are ablated
            >>> # simultaneously, and the attribution for each neuron in the same
            >>> # group (0 - 2) per example are the same.
            >>> # The attributions can be calculated as follows:
            >>> # layer mask has dimensions 1 x 3 x 3
            >>> layer_mask = torch.tensor([[[0,0,0],[1,1,1],
            >>>                             [2,2,2]]])
            >>> attr = ablator.attribute(input, target=1,
            >>>                          layer_mask=layer_mask)
        """

        def layer_forward_func(*args):
            layer_length = args[-1]
            layer_input = args[:layer_length]
            original_inputs = args[layer_length:-1]

            device_ids = self.device_ids
            if device_ids is None:
                device_ids = getattr(self.forward_func, "device_ids", None)

            all_layer_inputs = {}
            if device_ids is not None:
                scattered_layer_input = scatter(layer_input, target_gpus=device_ids)
                for device_tensors in scattered_layer_input:
                    all_layer_inputs[device_tensors[0].device] = device_tensors
            else:
                all_layer_inputs[layer_input[0].device] = layer_input

            def forward_hook(module, inp, out=None):
                device = _extract_device(module, inp, out)
                is_layer_tuple = (
                    isinstance(out, tuple)
                    if out is not None
                    else isinstance(inp, tuple)
                )
                if device not in all_layer_inputs:
                    raise AssertionError(
|
250 |
+
"Layer input not placed on appropriate "
|
251 |
+
"device. If using a DataParallel model, either provide the "
|
252 |
+
"DataParallel model as forward_func or provide device ids"
|
253 |
+
" to the constructor."
|
254 |
+
)
|
255 |
+
if not is_layer_tuple:
|
256 |
+
return all_layer_inputs[device][0]
|
257 |
+
return all_layer_inputs[device]
|
258 |
+
|
259 |
+
hook = None
|
260 |
+
try:
|
261 |
+
if attribute_to_layer_input:
|
262 |
+
hook = self.layer.register_forward_pre_hook(forward_hook)
|
263 |
+
else:
|
264 |
+
hook = self.layer.register_forward_hook(forward_hook)
|
265 |
+
eval = _run_forward(self.forward_func, original_inputs, target=target)
|
266 |
+
finally:
|
267 |
+
if hook is not None:
|
268 |
+
hook.remove()
|
269 |
+
return eval
|
270 |
+
|
271 |
+
with torch.no_grad():
|
272 |
+
inputs = _format_tensor_into_tuples(inputs)
|
273 |
+
additional_forward_args = _format_additional_forward_args(
|
274 |
+
additional_forward_args
|
275 |
+
)
|
276 |
+
layer_eval = _forward_layer_eval(
|
277 |
+
self.forward_func,
|
278 |
+
inputs,
|
279 |
+
self.layer,
|
280 |
+
additional_forward_args,
|
281 |
+
device_ids=self.device_ids,
|
282 |
+
attribute_to_layer_input=attribute_to_layer_input,
|
283 |
+
)
|
284 |
+
layer_eval_len = (len(layer_eval),)
|
285 |
+
all_inputs = (
|
286 |
+
(inputs + additional_forward_args + layer_eval_len)
|
287 |
+
if additional_forward_args is not None
|
288 |
+
else inputs + layer_eval_len
|
289 |
+
)
|
290 |
+
|
291 |
+
ablator = FeatureAblation(layer_forward_func)
|
292 |
+
|
293 |
+
layer_attribs = ablator.attribute.__wrapped__(
|
294 |
+
ablator, # self
|
295 |
+
layer_eval,
|
296 |
+
baselines=layer_baselines,
|
297 |
+
additional_forward_args=all_inputs,
|
298 |
+
feature_mask=layer_mask,
|
299 |
+
perturbations_per_eval=perturbations_per_eval,
|
300 |
+
)
|
301 |
+
_attr = _format_output(len(layer_attribs) > 1, layer_attribs)
|
302 |
+
return _attr
|
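The grouped-ablation behaviour documented above can be exercised end to end. The sketch below is illustrative only and is not part of this commit: the small SimpleClassifier module is an assumption standing in for the model referenced in the docstring (fed Nx1x4x4 inputs so that conv1 produces the Nx12x3x3 output the example describes).

import torch
import torch.nn as nn
from captum.attr import LayerFeatureAblation

class SimpleClassifier(nn.Module):
    # Assumed toy model: conv1 maps Nx1x4x4 -> Nx12x3x3, then a linear head -> Nx3.
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 12, kernel_size=2)
        self.fc = nn.Linear(12 * 3 * 3, 3)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        return torch.softmax(self.fc(x.flatten(1)), dim=1)

net = SimpleClassifier()
inp = torch.randn(2, 1, 4, 4)
ablator = LayerFeatureAblation(net, net.conv1)

# Ablate each of the 12 * 3 * 3 = 108 conv1 output neurons independently.
attr_per_neuron = ablator.attribute(inp, target=1)

# Group conv1 outputs by row: the 1x3x3 mask broadcasts across the 12 channels,
# so all neurons sharing a row index are ablated together and receive the same
# attribution value per example.
layer_mask = torch.tensor([[[0, 0, 0], [1, 1, 1], [2, 2, 2]]])
attr_per_row = ablator.attribute(inp, target=1, layer_mask=layer_mask)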
captum/attr/_core/layer/layer_gradient_shap.py
ADDED
@@ -0,0 +1,474 @@
#!/usr/bin/env python3

import typing
from typing import Any, Callable, cast, List, Tuple, Union

import numpy as np
import torch
from captum._utils.gradient import _forward_layer_eval, compute_layer_gradients_and_eval
from captum._utils.typing import Literal, TargetType, TensorOrTupleOfTensorsGeneric
from captum.attr._core.gradient_shap import _scale_input
from captum.attr._core.noise_tunnel import NoiseTunnel
from captum.attr._utils.attribution import GradientAttribution, LayerAttribution
from captum.attr._utils.common import (
    _compute_conv_delta_and_format_attrs,
    _format_callable_baseline,
    _format_input_baseline,
)
from captum.log import log_usage
from torch import Tensor
from torch.nn import Module


class LayerGradientShap(LayerAttribution, GradientAttribution):
    r"""
    Implements gradient SHAP for layer based on the implementation from SHAP's
    primary author. For reference, please, view:

    https://github.com/slundberg/shap\
    #deep-learning-example-with-gradientexplainer-tensorflowkeraspytorch-models

    A Unified Approach to Interpreting Model Predictions
    http://papers.nips.cc/paper\
    7062-a-unified-approach-to-interpreting-model-predictions

    GradientShap approximates SHAP values by computing the expectations of
    gradients by randomly sampling from the distribution of baselines/references.
    It adds white noise to each input sample `n_samples` times, selects a
    random baseline from baselines' distribution and a random point along the
    path between the baseline and the input, and computes the gradient of
    outputs with respect to selected random points in chosen `layer`.
    The final SHAP values represent the expected values of
    `gradients * (layer_attr_inputs - layer_attr_baselines)`.

    GradientShap makes an assumption that the input features are independent
    and that the explanation model is linear, meaning that the explanations
    are modeled through the additive composition of feature effects.
    Under those assumptions, SHAP value can be approximated as the expectation
    of gradients that are computed for randomly generated `n_samples` input
    samples after adding gaussian noise `n_samples` times to each input for
    different baselines/references.

    In some sense it can be viewed as an approximation of integrated gradients
    by computing the expectations of gradients for different baselines.

    Current implementation uses Smoothgrad from `NoiseTunnel` in order to
    randomly draw samples from the distribution of baselines, add noise to input
    samples and compute the expectation (smoothgrad).
    """

    def __init__(
        self,
        forward_func: Callable,
        layer: Module,
        device_ids: Union[None, List[int]] = None,
        multiply_by_inputs: bool = True,
    ) -> None:
        r"""
        Args:

            forward_func (callable): The forward function of the model or any
                modification of it
            layer (torch.nn.Module): Layer for which attributions are computed.
                Output size of attribute matches this layer's input or
                output dimensions, depending on whether we attribute to
                the inputs or outputs of the layer, corresponding to
                attribution of each neuron in the input or output of
                this layer.
            device_ids (list(int)): Device ID list, necessary only if forward_func
                applies a DataParallel model. This allows reconstruction of
                intermediate outputs from batched results across devices.
                If forward_func is given as the DataParallel model itself,
                then it is not necessary to provide this argument.
            multiply_by_inputs (bool, optional): Indicates whether to factor
                model inputs' multiplier in the final attribution scores.
                In the literature this is also known as local vs global
                attribution. If inputs' multiplier isn't factored in,
                then this type of attribution method is also called local
                attribution. If it is, then that type of attribution
                method is called global.
                More details can be found here:
                https://arxiv.org/abs/1711.06104

                In case of layer gradient shap, if `multiply_by_inputs`
                is set to True, the sensitivity scores for scaled inputs
                are being multiplied by
                layer activations for inputs - layer activations for baselines.

        """
        LayerAttribution.__init__(self, forward_func, layer, device_ids)
        GradientAttribution.__init__(self, forward_func)
        self._multiply_by_inputs = multiply_by_inputs

    @typing.overload
    def attribute(
        self,
        inputs: TensorOrTupleOfTensorsGeneric,
        baselines: Union[TensorOrTupleOfTensorsGeneric, Callable],
        n_samples: int = 5,
        stdevs: Union[float, Tuple[float, ...]] = 0.0,
        target: TargetType = None,
        additional_forward_args: Any = None,
        *,
        return_convergence_delta: Literal[True],
        attribute_to_layer_input: bool = False,
    ) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]:
        ...

    @typing.overload
    def attribute(
        self,
        inputs: TensorOrTupleOfTensorsGeneric,
        baselines: Union[TensorOrTupleOfTensorsGeneric, Callable],
        n_samples: int = 5,
        stdevs: Union[float, Tuple[float, ...]] = 0.0,
        target: TargetType = None,
        additional_forward_args: Any = None,
        return_convergence_delta: Literal[False] = False,
        attribute_to_layer_input: bool = False,
    ) -> Union[Tensor, Tuple[Tensor, ...]]:
        ...

    @log_usage()
    def attribute(
        self,
        inputs: TensorOrTupleOfTensorsGeneric,
        baselines: Union[TensorOrTupleOfTensorsGeneric, Callable],
        n_samples: int = 5,
        stdevs: Union[float, Tuple[float, ...]] = 0.0,
        target: TargetType = None,
        additional_forward_args: Any = None,
        return_convergence_delta: bool = False,
        attribute_to_layer_input: bool = False,
    ) -> Union[
        Tensor, Tuple[Tensor, ...], Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]
    ]:
        r"""
        Args:

            inputs (tensor or tuple of tensors): Input which are used to compute
                SHAP attribution values for a given `layer`. If `forward_func`
                takes a single tensor as input, a single input tensor should
                be provided.
                If `forward_func` takes multiple tensors as input, a tuple
                of the input tensors should be provided. It is assumed
                that for all given input tensors, dimension 0 corresponds
                to the number of examples, and if multiple input tensors
                are provided, the examples must be aligned appropriately.
            baselines (tensor, tuple of tensors, callable):
                Baselines define the starting point from which expectation
                is computed and can be provided as:

                - a single tensor, if inputs is a single tensor, with
                  the first dimension equal to the number of examples
                  in the baselines' distribution. The remaining dimensions
                  must match with input tensor's dimension starting from
                  the second dimension.

                - a tuple of tensors, if inputs is a tuple of tensors,
                  with the first dimension of any tensor inside the tuple
                  equal to the number of examples in the baseline's
                  distribution. The remaining dimensions must match
                  the dimensions of the corresponding input tensor
                  starting from the second dimension.

                - callable function, optionally takes `inputs` as an
                  argument and either returns a single tensor
                  or a tuple of those.

                It is recommended that the number of samples in the baselines'
                tensors is larger than one.
            n_samples (int, optional): The number of randomly generated examples
                per sample in the input batch. Random examples are
                generated by adding gaussian random noise to each sample.
                Default: `5` if `n_samples` is not provided.
            stdevs (float, or a tuple of floats optional): The standard deviation
                of gaussian noise with zero mean that is added to each
                input in the batch. If `stdevs` is a single float value
                then that same value is used for all inputs. If it is
                a tuple, then it must have the same length as the inputs
                tuple. In this case, each stdev value in the stdevs tuple
                corresponds to the input with the same index in the inputs
                tuple.
                Default: 0.0
            target (int, tuple, tensor or list, optional): Output indices for
                which gradients are computed (for classification cases,
                this is usually the target class).
                If the network returns a scalar value per example,
                no target index is necessary.
                For general 2D outputs, targets can be either:

                - a single integer or a tensor containing a single
                  integer, which is applied to all input examples

                - a list of integers or a 1D tensor, with length matching
                  the number of examples in inputs (dim 0). Each integer
                  is applied as the target for the corresponding example.

                For outputs with > 2 dimensions, targets can be either:

                - A single tuple, which contains #output_dims - 1
                  elements. This target index is applied to all examples.

                - A list of tuples with length equal to the number of
                  examples in inputs (dim 0), and each tuple containing
                  #output_dims - 1 elements. Each tuple is applied as the
                  target for the corresponding example.

                Default: None
            additional_forward_args (any, optional): If the forward function
                requires additional arguments other than the inputs for
                which attributions should not be computed, this argument
                can be provided. It can contain a tuple of ND tensors or
                any arbitrary python type of any shape.
                In case of the ND tensor the first dimension of the
                tensor must correspond to the batch size. It will be
                repeated for each `n_steps` for each randomly generated
                input sample.
                Note that the attributions are not computed with respect
                to these arguments.
                Default: None
            return_convergence_delta (bool, optional): Indicates whether to return
                convergence delta or not. If `return_convergence_delta`
                is set to True convergence delta will be returned in
                a tuple following attributions.
                Default: False
            attribute_to_layer_input (bool, optional): Indicates whether to
                compute the attribution with respect to the layer input
                or output. If `attribute_to_layer_input` is set to True
                then the attributions will be computed with respect to
                layer input, otherwise it will be computed with respect
                to layer output.
                Note that currently it is assumed that either the input
                or the output of internal layer, depending on whether we
                attribute to the input or output, is a single tensor.
                Support for multiple tensors will be added later.
                Default: False
        Returns:
            **attributions** or 2-element tuple of **attributions**, **delta**:
            - **attributions** (*tensor* or tuple of *tensors*):
                Attribution score computed based on GradientSHAP with
                respect to layer's input or output. Attributions will always
                be the same size as the provided layer's inputs or outputs,
                depending on whether we attribute to the inputs or outputs
                of the layer.
                Attributions are returned in a tuple if
                the layer inputs / outputs contain multiple tensors,
                otherwise a single tensor is returned.
            - **delta** (*tensor*, returned if return_convergence_delta=True):
                This is computed using the property that the total
                sum of forward_func(inputs) - forward_func(baselines)
                must be very close to the total sum of the attributions
                based on layer gradient SHAP.
                Delta is calculated for each example in the input after adding
                `n_samples` times gaussian noise to each of them. Therefore,
                the dimensionality of the deltas tensor is equal to the
                `number of examples in the input` * `n_samples`
                The deltas are ordered by each input example and `n_samples`
                noisy samples generated for it.

        Examples::

            >>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
            >>> # and returns an Nx10 tensor of class probabilities.
            >>> net = ImageClassifier()
            >>> layer_grad_shap = LayerGradientShap(net, net.linear1)
            >>> input = torch.randn(3, 3, 32, 32, requires_grad=True)
            >>> # choosing baselines randomly
            >>> baselines = torch.randn(20, 3, 32, 32)
            >>> # Computes gradient SHAP of output layer when target is equal
            >>> # to 5 with respect to the layer linear1.
            >>> # Attribution size matches the size of the linear1 layer
            >>> attribution = layer_grad_shap.attribute(input, baselines,
                                                        target=5)

        """
        # since `baselines` is a distribution, we can generate it using a function
        # rather than passing it as an input argument
        baselines = _format_callable_baseline(baselines, inputs)
        assert isinstance(baselines[0], torch.Tensor), (
            "Baselines distribution has to be provided in a form "
            "of a torch.Tensor {}.".format(baselines[0])
        )

        input_min_baseline_x_grad = LayerInputBaselineXGradient(
            self.forward_func,
            self.layer,
            device_ids=self.device_ids,
            multiply_by_inputs=self.multiplies_by_inputs,
        )

        nt = NoiseTunnel(input_min_baseline_x_grad)

        attributions = nt.attribute.__wrapped__(
            nt,  # self
            inputs,
            nt_type="smoothgrad",
            nt_samples=n_samples,
            stdevs=stdevs,
            draw_baseline_from_distrib=True,
            baselines=baselines,
            target=target,
            additional_forward_args=additional_forward_args,
            return_convergence_delta=return_convergence_delta,
            attribute_to_layer_input=attribute_to_layer_input,
        )

        return attributions

    def has_convergence_delta(self) -> bool:
        return True

    @property
    def multiplies_by_inputs(self):
        return self._multiply_by_inputs


class LayerInputBaselineXGradient(LayerAttribution, GradientAttribution):
    def __init__(
        self,
        forward_func: Callable,
        layer: Module,
        device_ids: Union[None, List[int]] = None,
        multiply_by_inputs: bool = True,
    ) -> None:
        r"""
        Args:

            forward_func (callable): The forward function of the model or any
                modification of it
            layer (torch.nn.Module): Layer for which attributions are computed.
                Output size of attribute matches this layer's input or
                output dimensions, depending on whether we attribute to
                the inputs or outputs of the layer, corresponding to
                attribution of each neuron in the input or output of
                this layer.
            device_ids (list(int)): Device ID list, necessary only if forward_func
                applies a DataParallel model. This allows reconstruction of
                intermediate outputs from batched results across devices.
                If forward_func is given as the DataParallel model itself,
                then it is not necessary to provide this argument.
            multiply_by_inputs (bool, optional): Indicates whether to factor
                model inputs' multiplier in the final attribution scores.
                In the literature this is also known as local vs global
                attribution. If inputs' multiplier isn't factored in,
                then this type of attribution method is also called local
                attribution. If it is, then that type of attribution
                method is called global.
                More details can be found here:
                https://arxiv.org/abs/1711.06104

                In case of layer input minus baseline x gradient,
                if `multiply_by_inputs` is set to True, the sensitivity scores
                for scaled inputs are being multiplied by
                layer activations for inputs - layer activations for baselines.

        """
        LayerAttribution.__init__(self, forward_func, layer, device_ids)
        GradientAttribution.__init__(self, forward_func)
        self._multiply_by_inputs = multiply_by_inputs

    @typing.overload
    def attribute(
        self,
        inputs: Union[Tensor, Tuple[Tensor, ...]],
        baselines: Union[Tensor, Tuple[Tensor, ...]],
        target: TargetType = None,
        additional_forward_args: Any = None,
        return_convergence_delta: Literal[False] = False,
        attribute_to_layer_input: bool = False,
    ) -> Union[Tensor, Tuple[Tensor, ...]]:
        ...

    @typing.overload
    def attribute(
        self,
        inputs: Union[Tensor, Tuple[Tensor, ...]],
        baselines: Union[Tensor, Tuple[Tensor, ...]],
        target: TargetType = None,
        additional_forward_args: Any = None,
        *,
        return_convergence_delta: Literal[True],
        attribute_to_layer_input: bool = False,
    ) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]:
        ...

    @log_usage()
    def attribute(  # type: ignore
        self,
        inputs: Union[Tensor, Tuple[Tensor, ...]],
        baselines: Union[Tensor, Tuple[Tensor, ...]],
        target: TargetType = None,
        additional_forward_args: Any = None,
        return_convergence_delta: bool = False,
        attribute_to_layer_input: bool = False,
    ) -> Union[
        Tensor, Tuple[Tensor, ...], Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]
    ]:
        inputs, baselines = _format_input_baseline(inputs, baselines)
        rand_coefficient = torch.tensor(
            np.random.uniform(0.0, 1.0, inputs[0].shape[0]),
            device=inputs[0].device,
            dtype=inputs[0].dtype,
        )

        input_baseline_scaled = tuple(
            _scale_input(input, baseline, rand_coefficient)
            for input, baseline in zip(inputs, baselines)
        )
        grads, _ = compute_layer_gradients_and_eval(
            self.forward_func,
            self.layer,
            input_baseline_scaled,
            target,
            additional_forward_args,
            device_ids=self.device_ids,
            attribute_to_layer_input=attribute_to_layer_input,
        )

        attr_baselines = _forward_layer_eval(
            self.forward_func,
            baselines,
            self.layer,
            additional_forward_args=additional_forward_args,
            device_ids=self.device_ids,
            attribute_to_layer_input=attribute_to_layer_input,
        )

        attr_inputs = _forward_layer_eval(
            self.forward_func,
            inputs,
            self.layer,
            additional_forward_args=additional_forward_args,
            device_ids=self.device_ids,
            attribute_to_layer_input=attribute_to_layer_input,
        )

        if self.multiplies_by_inputs:
            input_baseline_diffs = tuple(
                input - baseline for input, baseline in zip(attr_inputs, attr_baselines)
            )
            attributions = tuple(
                input_baseline_diff * grad
                for input_baseline_diff, grad in zip(input_baseline_diffs, grads)
            )
        else:
            attributions = grads

        return _compute_conv_delta_and_format_attrs(
            self,
            return_convergence_delta,
            attributions,
            baselines,
            inputs,
            additional_forward_args,
            target,
            cast(Union[Literal[True], Literal[False]], len(attributions) > 1),
        )

    def has_convergence_delta(self) -> bool:
        return True

    @property
    def multiplies_by_inputs(self):
        return self._multiply_by_inputs
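A compact usage sketch for the class added above; the TinyNet model, tensor sizes, and parameter values are illustrative assumptions rather than part of this commit. It mirrors the docstring example: baselines are drawn from a small distribution and the attribution is computed at linear1, so its shape matches that layer's output.

import torch
import torch.nn as nn
from captum.attr import LayerGradientShap

class TinyNet(nn.Module):
    # Assumed toy model with a named linear1 layer to attribute to.
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(8, 16)
        self.linear2 = nn.Linear(16, 10)

    def forward(self, x):
        return self.linear2(torch.relu(self.linear1(x)))

net = TinyNet()
inputs = torch.randn(3, 8)
baselines = torch.randn(20, 8)  # a distribution of baselines, more than one sample

layer_gs = LayerGradientShap(net, net.linear1)
# Attribution shape matches the linear1 output (3 x 16); delta has
# 3 * n_samples entries, one per noisy sample per input example.
attr, delta = layer_gs.attribute(
    inputs,
    baselines,
    n_samples=10,
    stdevs=0.1,
    target=5,
    return_convergence_delta=True,
)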
captum/attr/_core/layer/layer_gradient_x_activation.py
ADDED
@@ -0,0 +1,201 @@
#!/usr/bin/env python3
from typing import Any, Callable, List, Tuple, Union

from captum._utils.common import (
    _format_additional_forward_args,
    _format_output,
    _format_tensor_into_tuples,
)
from captum._utils.gradient import compute_layer_gradients_and_eval
from captum._utils.typing import ModuleOrModuleList, TargetType
from captum.attr._utils.attribution import GradientAttribution, LayerAttribution
from captum.log import log_usage
from torch import Tensor
from torch.nn import Module


class LayerGradientXActivation(LayerAttribution, GradientAttribution):
    r"""
    Computes element-wise product of gradient and activation for selected
    layer on given inputs.
    """

    def __init__(
        self,
        forward_func: Callable,
        layer: ModuleOrModuleList,
        device_ids: Union[None, List[int]] = None,
        multiply_by_inputs: bool = True,
    ) -> None:
        r"""
        Args:

            forward_func (callable): The forward function of the model or any
                modification of it
            layer (torch.nn.Module or list(torch.nn.Module)): Layer or layers
                for which attributions are computed.
                Output size of attribute matches this layer's input or
                output dimensions, depending on whether we attribute to
                the inputs or outputs of the layer, corresponding to
                attribution of each neuron in the input or output of
                this layer. If multiple layers are provided, attributions
                are returned as a list, each element corresponding to the
                attributions of the corresponding layer.
            device_ids (list(int)): Device ID list, necessary only if forward_func
                applies a DataParallel model. This allows reconstruction of
                intermediate outputs from batched results across devices.
                If forward_func is given as the DataParallel model itself,
                then it is not necessary to provide this argument.
            multiply_by_inputs (bool, optional): Indicates whether to factor
                model inputs' multiplier in the final attribution scores.
                In the literature this is also known as local vs global
                attribution. If inputs' multiplier isn't factored in,
                then this type of attribution method is also called local
                attribution. If it is, then that type of attribution
                method is called global.
                More details can be found here:
                https://arxiv.org/abs/1711.06104

                In case of layer gradient x activation, if `multiply_by_inputs`
                is set to True, final sensitivity scores are being multiplied by
                layer activations for inputs.

        """
        LayerAttribution.__init__(self, forward_func, layer, device_ids)
        GradientAttribution.__init__(self, forward_func)
        self._multiply_by_inputs = multiply_by_inputs

    @property
    def multiplies_by_inputs(self):
        return self._multiply_by_inputs

    @log_usage()
    def attribute(
        self,
        inputs: Union[Tensor, Tuple[Tensor, ...]],
        target: TargetType = None,
        additional_forward_args: Any = None,
        attribute_to_layer_input: bool = False,
    ) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]:
        r"""
        Args:

            inputs (tensor or tuple of tensors): Input for which attributions
                are computed. If forward_func takes a single
                tensor as input, a single input tensor should be provided.
                If forward_func takes multiple tensors as input, a tuple
                of the input tensors should be provided. It is assumed
                that for all given input tensors, dimension 0 corresponds
                to the number of examples, and if multiple input tensors
                are provided, the examples must be aligned appropriately.
            target (int, tuple, tensor or list, optional): Output indices for
                which gradients are computed (for classification cases,
                this is usually the target class).
                If the network returns a scalar value per example,
                no target index is necessary.
                For general 2D outputs, targets can be either:

                - a single integer or a tensor containing a single
                  integer, which is applied to all input examples

                - a list of integers or a 1D tensor, with length matching
                  the number of examples in inputs (dim 0). Each integer
                  is applied as the target for the corresponding example.

                For outputs with > 2 dimensions, targets can be either:

                - A single tuple, which contains #output_dims - 1
                  elements. This target index is applied to all examples.

                - A list of tuples with length equal to the number of
                  examples in inputs (dim 0), and each tuple containing
                  #output_dims - 1 elements. Each tuple is applied as the
                  target for the corresponding example.

                Default: None
            additional_forward_args (any, optional): If the forward function
                requires additional arguments other than the inputs for
                which attributions should not be computed, this argument
                can be provided. It must be either a single additional
                argument of a Tensor or arbitrary (non-tuple) type or a
                tuple containing multiple additional arguments including
                tensors or any arbitrary python types. These arguments
                are provided to forward_func in order following the
                arguments in inputs.
                Note that attributions are not computed with respect
                to these arguments.
                Default: None
            attribute_to_layer_input (bool, optional): Indicates whether to
                compute the attribution with respect to the layer input
                or output. If `attribute_to_layer_input` is set to True
                then the attributions will be computed with respect to
                layer input, otherwise it will be computed with respect
                to layer output.
                Default: False

        Returns:
            *tensor* or tuple of *tensors* or *list* of **attributions**:
            - **attributions** (*tensor* or tuple of *tensors* or *list*):
                Product of gradient and activation for each
                neuron in given layer output.
                Attributions will always be the same size as the
                output of the given layer.
                Attributions are returned in a tuple if
                the layer inputs / outputs contain multiple tensors,
                otherwise a single tensor is returned.
                If multiple layers are provided, attributions
                are returned as a list, each element corresponding to the
                activations of the corresponding layer.


        Examples::

            >>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
            >>> # and returns an Nx10 tensor of class probabilities.
            >>> # It contains an attribute conv1, which is an instance of nn.conv2d,
            >>> # and the output of this layer has dimensions Nx12x32x32.
            >>> net = ImageClassifier()
            >>> layer_ga = LayerGradientXActivation(net, net.conv1)
            >>> input = torch.randn(2, 3, 32, 32, requires_grad=True)
            >>> # Computes layer activation x gradient for class 3.
            >>> # attribution size matches layer output, Nx12x32x32
            >>> attribution = layer_ga.attribute(input, 3)
        """
        inputs = _format_tensor_into_tuples(inputs)
        additional_forward_args = _format_additional_forward_args(
            additional_forward_args
        )
        # Returns gradient of output with respect to
        # hidden layer and hidden layer evaluated at each input.
        layer_gradients, layer_evals = compute_layer_gradients_and_eval(
            self.forward_func,
            self.layer,
            inputs,
            target,
            additional_forward_args,
            device_ids=self.device_ids,
            attribute_to_layer_input=attribute_to_layer_input,
        )
        if isinstance(self.layer, Module):
            return _format_output(
                len(layer_evals) > 1,
                self.multiply_gradient_acts(layer_gradients, layer_evals),
            )
        else:
            return [
                _format_output(
                    len(layer_evals[i]) > 1,
                    self.multiply_gradient_acts(layer_gradients[i], layer_evals[i]),
                )
                for i in range(len(self.layer))
            ]

    def multiply_gradient_acts(
        self, gradients: Tuple[Tensor, ...], evals: Tuple[Tensor, ...]
    ) -> Tuple[Tensor, ...]:
        return tuple(
            single_gradient * single_eval
            if self.multiplies_by_inputs
            else single_gradient
            for single_gradient, single_eval in zip(gradients, evals)
        )
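A short sketch of the single-layer and multi-layer call patterns described above; the nn.Sequential model and sizes are illustrative assumptions, not part of this commit.

import torch
import torch.nn as nn
from captum.attr import LayerGradientXActivation

net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 10))
inputs = torch.randn(4, 8)

# Single layer: a 4 x 16 tensor, gradient * activation at the first Linear.
lga = LayerGradientXActivation(net, net[0])
attr = lga.attribute(inputs, target=3)

# List of layers: attributions come back as a list, one entry per layer,
# in the same order as the layers passed to the constructor.
lga_multi = LayerGradientXActivation(net, [net[0], net[2]])
attr_first, attr_last = lga_multi.attribute(inputs, target=3)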
captum/attr/_core/layer/layer_integrated_gradients.py
ADDED
@@ -0,0 +1,528 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
import functools
|
3 |
+
import warnings
|
4 |
+
from typing import Any, Callable, List, overload, Tuple, Union
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from captum._utils.common import (
|
8 |
+
_extract_device,
|
9 |
+
_format_additional_forward_args,
|
10 |
+
_format_outputs,
|
11 |
+
)
|
12 |
+
from captum._utils.gradient import _forward_layer_eval, _run_forward
|
13 |
+
from captum._utils.typing import BaselineType, Literal, ModuleOrModuleList, TargetType
|
14 |
+
from captum.attr._core.integrated_gradients import IntegratedGradients
|
15 |
+
from captum.attr._utils.attribution import GradientAttribution, LayerAttribution
|
16 |
+
from captum.attr._utils.common import (
|
17 |
+
_format_input_baseline,
|
18 |
+
_tensorize_baseline,
|
19 |
+
_validate_input,
|
20 |
+
)
|
21 |
+
from captum.log import log_usage
|
22 |
+
from torch import Tensor
|
23 |
+
from torch.nn.parallel.scatter_gather import scatter
|
24 |
+
|
25 |
+
|
26 |
+
class LayerIntegratedGradients(LayerAttribution, GradientAttribution):
|
27 |
+
r"""
|
28 |
+
Layer Integrated Gradients is a variant of Integrated Gradients that assigns
|
29 |
+
an importance score to layer inputs or outputs, depending on whether we
|
30 |
+
attribute to the former or to the latter one.
|
31 |
+
|
32 |
+
Integrated Gradients is an axiomatic model interpretability algorithm that
|
33 |
+
attributes / assigns an importance score to each input feature by approximating
|
34 |
+
the integral of gradients of the model's output with respect to the inputs
|
35 |
+
along the path (straight line) from given baselines / references to inputs.
|
36 |
+
|
37 |
+
Baselines can be provided as input arguments to attribute method.
|
38 |
+
To approximate the integral we can choose to use either a variant of
|
39 |
+
Riemann sum or Gauss-Legendre quadrature rule.
|
40 |
+
|
41 |
+
More details regarding the integrated gradients method can be found in the
|
42 |
+
original paper:
|
43 |
+
https://arxiv.org/abs/1703.01365
|
44 |
+
|
45 |
+
"""
|
46 |
+
|
47 |
+
def __init__(
|
48 |
+
self,
|
49 |
+
forward_func: Callable,
|
50 |
+
layer: ModuleOrModuleList,
|
51 |
+
device_ids: Union[None, List[int]] = None,
|
52 |
+
multiply_by_inputs: bool = True,
|
53 |
+
) -> None:
|
54 |
+
r"""
|
55 |
+
Args:
|
56 |
+
forward_func (callable): The forward function of the model or any
|
57 |
+
modification of it
|
58 |
+
layer (ModuleOrModuleList):
|
59 |
+
Layer or list of layers for which attributions are computed.
|
60 |
+
For each layer the output size of the attribute matches
|
61 |
+
this layer's input or output dimensions, depending on
|
62 |
+
whether we attribute to the inputs or outputs of the
|
63 |
+
layer, corresponding to the attribution of each neuron
|
64 |
+
in the input or output of this layer.
|
65 |
+
|
66 |
+
Please note that layers to attribute on cannot be
|
67 |
+
dependent on each other. That is, a subset of layers in
|
68 |
+
`layer` cannot produce the inputs for another layer.
|
69 |
+
|
70 |
+
For example, if your model is of a simple linked-list
|
71 |
+
based graph structure (think nn.Sequence), e.g. x -> l1
|
72 |
+
-> l2 -> l3 -> output. If you pass in any one of those
|
73 |
+
layers, you cannot pass in another due to the
|
74 |
+
dependence, e.g. if you pass in l2 you cannot pass in
|
75 |
+
l1 or l3.
|
76 |
+
|
77 |
+
device_ids (list(int)): Device ID list, necessary only if forward_func
|
78 |
+
applies a DataParallel model. This allows reconstruction of
|
79 |
+
intermediate outputs from batched results across devices.
|
80 |
+
If forward_func is given as the DataParallel model itself,
|
81 |
+
then it is not necessary to provide this argument.
|
82 |
+
multiply_by_inputs (bool, optional): Indicates whether to factor
|
83 |
+
model inputs' multiplier in the final attribution scores.
|
84 |
+
In the literature this is also known as local vs global
|
85 |
+
attribution. If inputs' multiplier isn't factored in,
|
86 |
+
then this type of attribution method is also called local
|
87 |
+
attribution. If it is, then that type of attribution
|
88 |
+
method is called global.
|
89 |
+
More detailed can be found here:
|
90 |
+
https://arxiv.org/abs/1711.06104
|
91 |
+
|
92 |
+
In case of layer integrated gradients, if `multiply_by_inputs`
|
93 |
+
is set to True, final sensitivity scores are being multiplied by
|
94 |
+
layer activations for inputs - layer activations for baselines.
|
95 |
+
|
96 |
+
"""
|
97 |
+
LayerAttribution.__init__(self, forward_func, layer, device_ids=device_ids)
|
98 |
+
GradientAttribution.__init__(self, forward_func)
|
99 |
+
self.ig = IntegratedGradients(forward_func, multiply_by_inputs)
|
100 |
+
|
101 |
+
if isinstance(layer, list) and len(layer) > 1:
|
102 |
+
warnings.warn(
|
103 |
+
"Multiple layers provided. Please ensure that each layer is"
|
104 |
+
"**not** solely solely dependent on the outputs of"
|
105 |
+
"another layer. Please refer to the documentation for more"
|
106 |
+
"detail."
|
107 |
+
)
|
108 |
+
|
109 |
+
@overload
|
110 |
+
def attribute(
|
111 |
+
self,
|
112 |
+
inputs: Union[Tensor, Tuple[Tensor, ...]],
|
113 |
+
baselines: BaselineType,
|
114 |
+
target: TargetType,
|
115 |
+
additional_forward_args: Any,
|
116 |
+
n_steps: int,
|
117 |
+
method: str,
|
118 |
+
internal_batch_size: Union[None, int],
|
119 |
+
return_convergence_delta: Literal[False],
|
120 |
+
attribute_to_layer_input: bool,
|
121 |
+
) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]:
|
122 |
+
...
|
123 |
+
|
124 |
+
@overload
|
125 |
+
def attribute(
|
126 |
+
self,
|
127 |
+
inputs: Union[Tensor, Tuple[Tensor, ...]],
|
128 |
+
baselines: BaselineType,
|
129 |
+
target: TargetType,
|
130 |
+
additional_forward_args: Any,
|
131 |
+
n_steps: int,
|
132 |
+
method: str,
|
133 |
+
internal_batch_size: Union[None, int],
|
134 |
+
return_convergence_delta: Literal[True],
|
135 |
+
attribute_to_layer_input: bool,
|
136 |
+
) -> Tuple[
|
137 |
+
Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]],
|
138 |
+
Tensor,
|
139 |
+
]:
|
140 |
+
...
|
141 |
+
|
142 |
+
@overload
|
143 |
+
def attribute(
|
144 |
+
self,
|
145 |
+
inputs: Union[Tensor, Tuple[Tensor, ...]],
|
146 |
+
baselines: BaselineType = None,
|
147 |
+
target: TargetType = None,
|
148 |
+
additional_forward_args: Any = None,
|
149 |
+
n_steps: int = 50,
|
150 |
+
method: str = "gausslegendre",
|
151 |
+
internal_batch_size: Union[None, int] = None,
|
152 |
+
return_convergence_delta: bool = False,
|
153 |
+
attribute_to_layer_input: bool = False,
|
154 |
+
) -> Union[
|
155 |
+
Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]],
|
156 |
+
Tuple[
|
157 |
+
Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]],
|
158 |
+
Tensor,
|
159 |
+
],
|
160 |
+
]:
|
161 |
+
...
|
162 |
+
|
163 |
+
@log_usage()
|
164 |
+
def attribute(
|
165 |
+
self,
|
166 |
+
inputs: Union[Tensor, Tuple[Tensor, ...]],
|
167 |
+
baselines: BaselineType = None,
|
168 |
+
target: TargetType = None,
|
169 |
+
additional_forward_args: Any = None,
|
170 |
+
n_steps: int = 50,
|
171 |
+
method: str = "gausslegendre",
|
172 |
+
internal_batch_size: Union[None, int] = None,
|
173 |
+
return_convergence_delta: bool = False,
|
174 |
+
attribute_to_layer_input: bool = False,
|
175 |
+
) -> Union[
|
176 |
+
Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]],
|
177 |
+
Tuple[
|
178 |
+
Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]],
|
179 |
+
Tensor,
|
180 |
+
],
|
181 |
+
]:
|
182 |
+
r"""
|
183 |
+
This method attributes the output of the model with given target index
|
184 |
+
(in case it is provided, otherwise it assumes that output is a
|
185 |
+
scalar) to layer inputs or outputs of the model, depending on whether
|
186 |
+
`attribute_to_layer_input` is set to True or False, using the approach
|
187 |
+
described above.
|
188 |
+
|
189 |
+
In addition to that it also returns, if `return_convergence_delta` is
|
190 |
+
set to True, integral approximation delta based on the completeness
|
191 |
+
property of integrated gradients.
|
192 |
+
|
193 |
+
Args:
|
194 |
+
|
195 |
+
inputs (tensor or tuple of tensors): Input for which layer integrated
|
196 |
+
gradients are computed. If forward_func takes a single
|
197 |
+
tensor as input, a single input tensor should be provided.
|
198 |
+
If forward_func takes multiple tensors as input, a tuple
|
199 |
+
of the input tensors should be provided. It is assumed
|
200 |
+
that for all given input tensors, dimension 0 corresponds
|
201 |
+
to the number of examples, and if multiple input tensors
|
202 |
+
are provided, the examples must be aligned appropriately.
|
203 |
+
baselines (scalar, tensor, tuple of scalars or tensors, optional):
|
204 |
+
Baselines define the starting point from which integral
|
205 |
+
is computed and can be provided as:
|
206 |
+
|
207 |
+
- a single tensor, if inputs is a single tensor, with
|
208 |
+
exactly the same dimensions as inputs or the first
|
209 |
+
dimension is one and the remaining dimensions match
|
210 |
+
with inputs.
|
211 |
+
|
212 |
+
- a single scalar, if inputs is a single tensor, which will
|
213 |
+
be broadcasted for each input value in input tensor.
|
214 |
+
|
215 |
+
- a tuple of tensors or scalars, the baseline corresponding
|
216 |
+
to each tensor in the inputs' tuple can be:
|
217 |
+
- either a tensor with matching dimensions to
|
218 |
+
corresponding tensor in the inputs' tuple
|
219 |
+
or the first dimension is one and the remaining
|
220 |
+
dimensions match with the corresponding
|
221 |
+
input tensor.
|
222 |
+
- or a scalar, corresponding to a tensor in the
|
223 |
+
inputs' tuple. This scalar value is broadcasted
|
224 |
+
for corresponding input tensor.
|
225 |
+
|
226 |
+
In the cases when `baselines` is not provided, we internally
|
227 |
+
use zero scalar corresponding to each input tensor.
|
228 |
+
|
229 |
+
Default: None
|
230 |
+
target (int, tuple, tensor or list, optional): Output indices for
|
231 |
+
which gradients are computed (for classification cases,
|
232 |
+
this is usually the target class).
|
233 |
+
If the network returns a scalar value per example,
|
234 |
+
no target index is necessary.
|
235 |
+
For general 2D outputs, targets can be either:
|
236 |
+
|
237 |
+
- a single integer or a tensor containing a single
|
238 |
+
integer, which is applied to all input examples
|
239 |
+
|
240 |
+
- a list of integers or a 1D tensor, with length matching
|
241 |
+
the number of examples in inputs (dim 0). Each integer
|
242 |
+
is applied as the target for the corresponding example.
|
243 |
+
|
244 |
+
For outputs with > 2 dimensions, targets can be either:
|
245 |
+
|
246 |
+
- A single tuple, which contains #output_dims - 1
|
247 |
+
elements. This target index is applied to all examples.
|
248 |
+
|
249 |
+
- A list of tuples with length equal to the number of
|
250 |
+
examples in inputs (dim 0), and each tuple containing
|
251 |
+
#output_dims - 1 elements. Each tuple is applied as the
|
252 |
+
target for the corresponding example.
|
253 |
+
|
254 |
+
Default: None
|
255 |
+
additional_forward_args (any, optional): If the forward function
|
256 |
+
requires additional arguments other than the inputs for
|
257 |
+
which attributions should not be computed, this argument
|
258 |
+
can be provided. It must be either a single additional
|
259 |
+
argument of a Tensor or arbitrary (non-tuple) type or a
|
260 |
+
tuple containing multiple additional arguments including
|
261 |
+
tensors or any arbitrary python types. These arguments
|
262 |
+
are provided to forward_func in order following the
|
263 |
+
arguments in inputs.
|
264 |
+
For a tensor, the first dimension of the tensor must
|
265 |
+
correspond to the number of examples. It will be
|
266 |
+
repeated for each of `n_steps` along the integrated
|
267 |
+
path. For all other types, the given argument is used
|
268 |
+
for all forward evaluations.
|
269 |
+
Note that attributions are not computed with respect
|
270 |
+
to these arguments.
|
271 |
+
Default: None
|
272 |
+
n_steps (int, optional): The number of steps used by the approximation
|
273 |
+
method. Default: 50.
|
274 |
+
method (string, optional): Method for approximating the integral,
|
275 |
+
one of `riemann_right`, `riemann_left`, `riemann_middle`,
|
276 |
+
`riemann_trapezoid` or `gausslegendre`.
|
277 |
+
Default: `gausslegendre` if no method is provided.
|
278 |
+
internal_batch_size (int, optional): Divides total #steps * #examples
|
279 |
+
data points into chunks of size at most internal_batch_size,
|
280 |
+
which are computed (forward / backward passes)
|
281 |
+
sequentially. internal_batch_size must be at least equal to
|
282 |
+
#examples.
|
283 |
+
For DataParallel models, each batch is split among the
|
284 |
+
available devices, so evaluations on each available
|
285 |
+
device contain internal_batch_size / num_devices examples.
|
286 |
+
If internal_batch_size is None, then all evaluations are
|
287 |
+
processed in one batch.
|
288 |
+
Default: None
|
289 |
+
return_convergence_delta (bool, optional): Indicates whether to return
|
290 |
+
convergence delta or not. If `return_convergence_delta`
|
291 |
+
is set to True convergence delta will be returned in
|
292 |
+
a tuple following attributions.
|
293 |
+
Default: False
|
294 |
+
attribute_to_layer_input (bool, optional): Indicates whether to
|
295 |
+
compute the attribution with respect to the layer input
|
296 |
+
or output. If `attribute_to_layer_input` is set to True
|
297 |
+
then the attributions will be computed with respect to
|
298 |
+
layer input, otherwise it will be computed with respect
|
299 |
+
to layer output.
|
300 |
+
Note that currently it is assumed that either the input
|
301 |
+
or the output of internal layer, depending on whether we
|
302 |
+
attribute to the input or output, is a single tensor.
|
303 |
+
Support for multiple tensors will be added later.
|
304 |
+
Default: False
|
305 |
+
Returns:
|
306 |
+
**attributions** or 2-element tuple of **attributions**, **delta**:
|
307 |
+
- **attributions** (*tensor*, tuple of *tensors* or tuple of *tensors*):
|
308 |
+
Integrated gradients with respect to `layer`'s inputs or
|
309 |
+
outputs. Attributions will always be the same size and
|
310 |
+
dimensionality as the input or output of the given layer,
|
311 |
+
depending on whether we attribute to the inputs or outputs
|
312 |
+
of the layer which is decided by the input flag
|
313 |
+
`attribute_to_layer_input`.
|
314 |
+
|
315 |
+
For a single layer, attributions are returned in a tuple if
|
316 |
+
the layer inputs / outputs contain multiple tensors,
|
317 |
+
otherwise a single tensor is returned.
|
318 |
+
|
319 |
+
For multiple layers, attributions will always be
|
320 |
+
returned as a list. Each element in this list will be
|
321 |
+
equivalent to that of a single layer output, i.e. in the
|
322 |
+
case that one layer, in the given layers, inputs / outputs
|
323 |
+
multiple tensors: the corresponding output element will be
|
324 |
+
a tuple of tensors. The ordering of the outputs will be
|
325 |
+
the same order as the layers given in the constructor.
|
326 |
+
- **delta** (*tensor*, returned if return_convergence_delta=True):
|
327 |
+
The difference between the total approximated and true
|
328 |
+
integrated gradients. This is computed using the property
|
329 |
+
that the total sum of forward_func(inputs) -
|
330 |
+
forward_func(baselines) must equal the total sum of the
|
331 |
+
integrated gradient.
|
332 |
+
Delta is calculated per example, meaning that the number of
|
333 |
+
elements in returned delta tensor is equal to the number of
|
334 |
+
of examples in inputs.
|
335 |
+
|
336 |
+
Examples::
|
337 |
+
|
338 |
+
>>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
|
339 |
+
>>> # and returns an Nx10 tensor of class probabilities.
|
340 |
+
>>> # It contains an attribute conv1, which is an instance of nn.Conv2d,
|
341 |
+
>>> # and the output of this layer has dimensions Nx12x32x32.
|
342 |
+
>>> net = ImageClassifier()
|
343 |
+
>>> lig = LayerIntegratedGradients(net, net.conv1)
|
344 |
+
>>> input = torch.randn(2, 3, 32, 32, requires_grad=True)
|
345 |
+
>>> # Computes layer integrated gradients for class 3.
|
346 |
+
>>> # attribution size matches layer output, Nx12x32x32
|
347 |
+
>>> attribution = lig.attribute(input, target=3)
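>>> # A hedged extension of the same hypothetical setup: with 2 examples and
>>> # n_steps=50 there are 100 scaled evaluations, processed sequentially in
>>> # chunks of at most 20, and the completeness delta is returned per example.
>>> attribution, delta = lig.attribute(input, target=3, n_steps=50, internal_batch_size=20, return_convergence_delta=True)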
|
348 |
+
"""
|
349 |
+
inps, baselines = _format_input_baseline(inputs, baselines)
|
350 |
+
_validate_input(inps, baselines, n_steps, method)
|
351 |
+
|
352 |
+
baselines = _tensorize_baseline(inps, baselines)
|
353 |
+
additional_forward_args = _format_additional_forward_args(
|
354 |
+
additional_forward_args
|
355 |
+
)
|
356 |
+
|
357 |
+
def flatten_tuple(tup):
|
358 |
+
return tuple(
|
359 |
+
sum((list(x) if isinstance(x, (tuple, list)) else [x] for x in tup), [])
|
360 |
+
)
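# flatten_tuple removes one level of nesting, e.g. with hypothetical tensors
# t1..t3: flatten_tuple(((t1, t2), t3)) == (t1, t2, t3)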
|
361 |
+
|
362 |
+
if self.device_ids is None:
|
363 |
+
self.device_ids = getattr(self.forward_func, "device_ids", None)
|
364 |
+
|
365 |
+
inputs_layer = _forward_layer_eval(
|
366 |
+
self.forward_func,
|
367 |
+
inps,
|
368 |
+
self.layer,
|
369 |
+
device_ids=self.device_ids,
|
370 |
+
additional_forward_args=additional_forward_args,
|
371 |
+
attribute_to_layer_input=attribute_to_layer_input,
|
372 |
+
)
|
373 |
+
|
374 |
+
# if we have one output
|
375 |
+
if not isinstance(self.layer, list):
|
376 |
+
inputs_layer = (inputs_layer,)
|
377 |
+
|
378 |
+
num_outputs = [1 if isinstance(x, Tensor) else len(x) for x in inputs_layer]
|
379 |
+
num_outputs_cumsum = torch.cumsum(
|
380 |
+
torch.IntTensor([0] + num_outputs), dim=0 # type: ignore
|
381 |
+
)
|
382 |
+
inputs_layer = flatten_tuple(inputs_layer)
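# num_outputs_cumsum records where each layer's tensors live in the flattened
# tuple: e.g. layers yielding 1 and 2 tensors give num_outputs = [1, 2] and
# num_outputs_cumsum = [0, 1, 3], so layer i owns the slice [cumsum[i]:cumsum[i + 1]].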
|
383 |
+
|
384 |
+
baselines_layer = _forward_layer_eval(
|
385 |
+
self.forward_func,
|
386 |
+
baselines,
|
387 |
+
self.layer,
|
388 |
+
device_ids=self.device_ids,
|
389 |
+
additional_forward_args=additional_forward_args,
|
390 |
+
attribute_to_layer_input=attribute_to_layer_input,
|
391 |
+
)
|
392 |
+
baselines_layer = flatten_tuple(baselines_layer)
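# At this point the layer activations of the real inputs and of the baselines
# play the role of "inputs" and "baselines" for the wrapped IntegratedGradients
# (self.ig); gradient_func below re-injects the interpolated activations into
# the network through forward (pre-)hooks on the target layer(s).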
|
393 |
+
|
394 |
+
# the `inputs` received by gradient_func are the scaled (interpolated) layer activations
|
395 |
+
def gradient_func(
|
396 |
+
forward_fn: Callable,
|
397 |
+
inputs: Union[Tensor, Tuple[Tensor, ...]],
|
398 |
+
target_ind: TargetType = None,
|
399 |
+
additional_forward_args: Any = None,
|
400 |
+
) -> Tuple[Tensor, ...]:
|
401 |
+
if self.device_ids is None or len(self.device_ids) == 0:
|
402 |
+
scattered_inputs = (inputs,)
|
403 |
+
else:
|
404 |
+
# scatter method does not have a precise enough return type in its
|
405 |
+
# stub, so suppress the type warning.
|
406 |
+
scattered_inputs = scatter( # type:ignore
|
407 |
+
inputs, target_gpus=self.device_ids
|
408 |
+
)
|
409 |
+
|
410 |
+
scattered_inputs_dict = {
|
411 |
+
scattered_input[0].device: scattered_input
|
412 |
+
for scattered_input in scattered_inputs
|
413 |
+
}
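# For DataParallel models the interpolated activations are scattered across
# self.device_ids and keyed by device, so each replica's hook can look up the
# chunk that lives on its own device.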
|
414 |
+
|
415 |
+
with torch.autograd.set_grad_enabled(True):
|
416 |
+
|
417 |
+
def layer_forward_hook(
|
418 |
+
module, hook_inputs, hook_outputs=None, layer_idx=0
|
419 |
+
):
|
420 |
+
device = _extract_device(module, hook_inputs, hook_outputs)
|
421 |
+
is_layer_tuple = (
|
422 |
+
isinstance(hook_outputs, tuple)
|
423 |
+
# hook_outputs is None if attribute_to_layer_input == True
|
424 |
+
if hook_outputs is not None
|
425 |
+
else isinstance(hook_inputs, tuple)
|
426 |
+
)
|
427 |
+
|
428 |
+
if is_layer_tuple:
|
429 |
+
return scattered_inputs_dict[device][
|
430 |
+
num_outputs_cumsum[layer_idx] : num_outputs_cumsum[
|
431 |
+
layer_idx + 1
|
432 |
+
]
|
433 |
+
]
|
434 |
+
|
435 |
+
return scattered_inputs_dict[device][num_outputs_cumsum[layer_idx]]
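# The hook swaps the layer's input (pre-hook) or output (forward hook) for this
# layer's slice of the scaled activations, so the rest of the forward pass sees
# the interpolated values instead of what the layer actually computed.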
|
436 |
+
|
437 |
+
hooks = []
|
438 |
+
try:
|
439 |
+
|
440 |
+
layers = self.layer
|
441 |
+
if not isinstance(layers, list):
|
442 |
+
layers = [self.layer]
|
443 |
+
|
444 |
+
for layer_idx, layer in enumerate(layers):
|
445 |
+
hook = None
|
446 |
+
# TODO:
|
447 |
+
# Allow multiple attribute_to_layer_input flags for
|
448 |
+
# each layer, i.e. attribute_to_layer_input[layer_idx]
|
449 |
+
if attribute_to_layer_input:
|
450 |
+
hook = layer.register_forward_pre_hook(
|
451 |
+
functools.partial(
|
452 |
+
layer_forward_hook, layer_idx=layer_idx
|
453 |
+
)
|
454 |
+
)
|
455 |
+
else:
|
456 |
+
hook = layer.register_forward_hook(
|
457 |
+
functools.partial(
|
458 |
+
layer_forward_hook, layer_idx=layer_idx
|
459 |
+
)
|
460 |
+
)
|
461 |
+
|
462 |
+
hooks.append(hook)
|
463 |
+
|
464 |
+
output = _run_forward(
|
465 |
+
self.forward_func, tuple(), target_ind, additional_forward_args
|
466 |
+
)
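# `inputs` is an empty tuple here: the expanded original model inputs travel in
# additional_forward_args, while the hooks above substitute the scaled
# activations at the chosen layer(s).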
|
467 |
+
finally:
|
468 |
+
for hook in hooks:
|
469 |
+
if hook is not None:
|
470 |
+
hook.remove()
|
471 |
+
|
472 |
+
assert output[0].numel() == 1, (
|
473 |
+
"Target not provided when necessary, cannot"
|
474 |
+
" take gradient with respect to multiple outputs."
|
475 |
+
)
|
476 |
+
# torch.unbind(output) is a tuple of scalar tensors and
|
477 |
+
# contains batch_size * #steps elements
|
478 |
+
grads = torch.autograd.grad(torch.unbind(output), inputs)
|
479 |
+
return grads
|
480 |
+
|
481 |
+
self.ig.gradient_func = gradient_func
|
482 |
+
all_inputs = (
|
483 |
+
(inps + additional_forward_args)
|
484 |
+
if additional_forward_args is not None
|
485 |
+
else inps
|
486 |
+
)
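# The original inputs are folded into additional_forward_args because, from
# self.ig's perspective, the tensors being integrated are the layer activations;
# the network still needs its real inputs to run the forward pass.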
|
487 |
+
|
488 |
+
attributions = self.ig.attribute.__wrapped__( # type: ignore
|
489 |
+
self.ig, # self
|
490 |
+
inputs_layer,
|
491 |
+
baselines=baselines_layer,
|
492 |
+
target=target,
|
493 |
+
additional_forward_args=all_inputs,
|
494 |
+
n_steps=n_steps,
|
495 |
+
method=method,
|
496 |
+
internal_batch_size=internal_batch_size,
|
497 |
+
return_convergence_delta=False,
|
498 |
+
)
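# attribute.__wrapped__ calls the undecorated IntegratedGradients.attribute
# (bypassing its decorator), which is why self.ig is passed explicitly as the
# `self` argument.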
|
499 |
+
|
500 |
+
# handle multiple outputs
|
501 |
+
output: List[Tuple[Tensor, ...]] = [
|
502 |
+
tuple(
|
503 |
+
attributions[
|
504 |
+
int(num_outputs_cumsum[i]) : int(num_outputs_cumsum[i + 1])
|
505 |
+
]
|
506 |
+
)
|
507 |
+
for i in range(len(num_outputs))
|
508 |
+
]
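# Undo the earlier flattening: slice the flat attribution tuple back into one
# tuple per layer using num_outputs_cumsum, preserving constructor order.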
|
509 |
+
|
510 |
+
if return_convergence_delta:
|
511 |
+
start_point, end_point = baselines, inps
|
512 |
+
# computes approximation error based on the completeness axiom
|
513 |
+
delta = self.compute_convergence_delta(
|
514 |
+
attributions,
|
515 |
+
start_point,
|
516 |
+
end_point,
|
517 |
+
additional_forward_args=additional_forward_args,
|
518 |
+
target=target,
|
519 |
+
)
|
520 |
+
return _format_outputs(isinstance(self.layer, list), output), delta
|
521 |
+
return _format_outputs(isinstance(self.layer, list), output)
|
522 |
+
|
523 |
+
def has_convergence_delta(self) -> bool:
|
524 |
+
return True
|
525 |
+
|
526 |
+
@property
|
527 |
+
def multiplies_by_inputs(self):
|
528 |
+
return self.ig.multiplies_by_inputs
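A brief, hedged end-to-end sketch of the multi-layer path described in the docstring above; the model, layer names, and shapes are illustrative only and do not come from this repository:

    net = ImageClassifier()
    lig = LayerIntegratedGradients(net, [net.conv1, net.conv2])
    input = torch.randn(2, 3, 32, 32)
    # attributions is a list with one entry per layer, in constructor order;
    # delta has one element per example.
    attributions, delta = lig.attribute(
        input, target=3, n_steps=25, return_convergence_delta=True
    )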
|