diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..5aba95a3b936cfe51dc58f7ba8ee7338139ef6bc --- /dev/null +++ b/.gitignore @@ -0,0 +1,224 @@ +**/saved_models/* +**/data_lmdb_release/* +**/image_release/* +**/vitstr_base_patch* +**/result/* +**/results/* +**/oldData/ +*.mdb +*.xlsx +*.pth +*.json +*.pkl +*.tar +*.ipynb +*.zip +*.eps +*.pdf +**/grcnn_straug/* +**/augmentation/results/* +**/tmp/* +*.sh +**/__pycache__ +workdir/ +.remote-sync.json +*.png +pretrained/ +attributionImgs/ +attributionImgsOld/ +attrSelectivityOld/ + +### Linux ### +*~ + +# temporary files which can be created if a process still has a handle open of a deleted file +.fuse_hidden* + +# KDE directory preferences +.directory + +# Linux trash folder which might appear on any partition or disk +.Trash-* + +# .nfs files are created when an open file is removed but is still being accessed +.nfs* + +### OSX ### +# General +.DS_Store +.AppleDouble +.LSOverride + +# Icon must end with two \r +Icon + +# Thumbnails +._* + +# Files that might appear in the root of a volume +.DocumentRevisions-V100 +.fseventsd +.Spotlight-V100 +.TemporaryItems +.Trashes +.VolumeIcon.icns +.com.apple.timemachine.donotpresent + +# Directories potentially created on remote AFP share +.AppleDB +.AppleDesktop +Network Trash Folder +Temporary Items +.apdisk + +### Python ### +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +### Python Patch ### +.venv/ + +### Python.VirtualEnv Stack ### +# Virtualenv +# http://iamzed.com/2009/05/07/a-primer-on-virtualenv/ +[Bb]in +[Ii]nclude +[Ll]ib +[Ll]ib64 +[Ll]ocal +[Ss]cripts +pyvenv.cfg +pip-selfcheck.json + +### Windows ### +# Windows thumbnail cache files +Thumbs.db +ehthumbs.db +ehthumbs_vista.db + +# Dump file +*.stackdump + +# Folder config file +[Dd]esktop.ini + +# Recycle Bin used on file shares +$RECYCLE.BIN/ + +# Windows Installer files +*.cab +*.msi +*.msix +*.msm +*.msp + +# Windows shortcuts +*.lnk diff --git a/attribution_ops.py b/attribution_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..1cd903c293f261b0b8e95c8be926e271ed88e241 --- /dev/null +++ b/attribution_ops.py @@ -0,0 +1,87 @@ +import os +import pickle +from captum_improve_vitstr import rankedAttributionsBySegm +import matplotlib.pyplot as plt +from skimage.color import gray2rgb +from captum.attr._utils.visualization import visualize_image_attr +import torch +import numpy as np + +def attr_one_dataset(): + modelName = "vitstr" + datasetName = "IIIT5k_3000" + + rootDir = f"/data/goo/strattr/attributionData/{modelName}/{datasetName}/" + attrOutputImgs = f"/data/goo/strattr/attributionDataImgs/{modelName}/{datasetName}/" + if not os.path.exists(attrOutputImgs): + os.makedirs(attrOutputImgs) + + minNumber = 1000000 + maxNumber = 0 + # From a folder containing saved attribution pickle files, convert them into attribution images + for path, subdirs, files in os.walk(rootDir): + for name in files: + fullfilename = os.path.join(rootDir, name) # Value + # fullfilename: /data/goo/strattr/attributionData/trba/CUTE80/66_featablt.pkl + partfilename = fullfilename[fullfilename.rfind('/')+1:] + print("fullfilename: ", fullfilename) + imgNum = int(partfilename.split('_')[0]) + attrImgName = partfilename.replace('.pkl', '.png') + minNumber = min(minNumber, imgNum) + maxNumber = max(maxNumber, imgNum) + with open(fullfilename, 'rb') as f: + pklData = pickle.load(f) + attributions = pklData['attribution'] + segmDataNP = pklData['segmData'] + origImgNP = pklData['origImg'] + if np.isnan(attributions).any(): + continue + attributions = torch.from_numpy(attributions) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map', cmap='RdYlGn') + mplotfig.savefig(attrOutputImgs + attrImgName) + mplotfig.clear() + plt.close(mplotfig) + +def attr_all_dataset(): + modelName = "vitstr" + + datasetNameList = 
['IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', 'IC13_857', 'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80'] + + for datasetName in datasetNameList: + rootDir = f"/data/goo/strattr/attributionData/{modelName}/{datasetName}/" + attrOutputImgs = f"/data/goo/strattr/attributionDataImgs/{modelName}/{datasetName}/" + if not os.path.exists(attrOutputImgs): + os.makedirs(attrOutputImgs) + + minNumber = 1000000 + maxNumber = 0 + # From a folder containing saved attribution pickle files, convert them into attribution images + for path, subdirs, files in os.walk(rootDir): + for name in files: + fullfilename = os.path.join(rootDir, name) # Value + # fullfilename: /data/goo/strattr/attributionData/trba/CUTE80/66_featablt.pkl + partfilename = fullfilename[fullfilename.rfind('/')+1:] + imgNum = int(partfilename.split('_')[0]) + attrImgName = partfilename.replace('.pkl', '.png') + minNumber = min(minNumber, imgNum) + maxNumber = max(maxNumber, imgNum) + with open(fullfilename, 'rb') as f: + pklData = pickle.load(f) + attributions = pklData['attribution'] + segmDataNP = pklData['segmData'] + origImgNP = pklData['origImg'] + attributions = torch.from_numpy(attributions) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map', cmap='RdYlGn') + mplotfig.savefig(attrOutputImgs + attrImgName) + mplotfig.clear() + plt.close(mplotfig) + +if __name__ == '__main__': + attr_one_dataset() + # attr_all_dataset() diff --git a/augmentation/blur.py b/augmentation/blur.py new file mode 100644 index 0000000000000000000000000000000000000000..e09b1387169a9752a74d34e2e9e74ed00ba2c052 --- /dev/null +++ b/augmentation/blur.py @@ -0,0 +1,189 @@ + +import cv2 +import numpy as np +from PIL import Image, ImageOps +import torchvision.transforms as transforms +from wand.image import Image as WandImage +from scipy.ndimage import zoom as scizoom +from skimage.filters import gaussian +from wand.api import library as wandlibrary +from io import BytesIO + +#from skimage import color +from .ops import MotionImage, clipped_zoom, disk, plasma_fractal +''' + PIL resize (W,H) +''' +class GaussianBlur: + def __init__(self): + pass + + def __call__(self, img, mag=-1, prob=1.): + if np.random.uniform(0,1) > prob: + return img + + W, H = img.size + #kernel = [(31,31)] prev 1 level only + kernel = (31, 31) + sigmas = [.5, 1, 2] + if mag<0 or mag>=len(kernel): + index = np.random.randint(0, len(sigmas)) + else: + index = mag + + sigma = sigmas[index] + return transforms.GaussianBlur(kernel_size=kernel, sigma=sigma)(img) + + +class DefocusBlur: + def __init__(self): + pass + + def __call__(self, img, mag=-1, prob=1.): + if np.random.uniform(0,1) > prob: + return img + + n_channels = len(img.getbands()) + isgray = n_channels == 1 + #c = [(3, 0.1), (4, 0.5), (6, 0.5), (8, 0.5), (10, 0.5)] + c = [(2, 0.1), (3, 0.1), (4, 0.1)] #, (6, 0.5)] #prev 2 levels only + if mag<0 or mag>=len(c): + index = np.random.randint(0, len(c)) + else: + index = mag + c = c[index] + + img = np.array(img) / 255. 
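
All of the augmentation classes introduced in this diff share one calling convention: `op(img, mag=-1, prob=1.)`, where `prob` gates whether the op fires at all and `mag` selects a severity level (out-of-range values fall back to a random or default level, depending on the op). A minimal sketch of chaining them, using a hypothetical `apply_ops` helper that is not part of this diff:

```python
def apply_ops(img, ops, mag=-1, prob=0.5):
    """Apply each augmentation op in turn; each op rolls `prob` internally."""
    for op in ops:
        img = op(img, mag=mag, prob=prob)
    return img

# Example usage (assuming the augmentation package in this diff is importable):
# from PIL import Image
# from augmentation.blur import GaussianBlur, DefocusBlur
# pil_img = Image.open('word.png').convert('RGB').resize((100, 32))
# pil_img = apply_ops(pil_img, [GaussianBlur(), DefocusBlur()], mag=1, prob=1.0)
```
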
+ if isgray: + img = np.expand_dims(img, axis=2) + img = np.repeat(img, 3, axis=2) + n_channels = 3 + kernel = disk(radius=c[0], alias_blur=c[1]) + + channels = [] + for d in range(n_channels): + channels.append(cv2.filter2D(img[:, :, d], -1, kernel)) + channels = np.array(channels).transpose((1, 2, 0)) # 3x224x224 -> 224x224x3 + + #if isgray: + # img = img[:,:,0] + # img = np.squeeze(img) + + img = np.clip(channels, 0, 1) * 255 + img = Image.fromarray(img.astype(np.uint8)) + if isgray: + img = ImageOps.grayscale(img) + + return img + + +class MotionBlur: + def __init__(self): + pass + + def __call__(self, img, mag=-1, prob=1.): + if np.random.uniform(0,1) > prob: + return img + + n_channels = len(img.getbands()) + isgray = n_channels == 1 + #c = [(10, 3), (15, 5), (15, 8), (15, 12), (20, 15)] + c = [(10, 3), (12, 4), (14, 5)] + if mag<0 or mag>=len(c): + index = np.random.randint(0, len(c)) + else: + index = mag + c = c[index] + + output = BytesIO() + img.save(output, format='PNG') + img = MotionImage(blob=output.getvalue()) + + img.motion_blur(radius=c[0], sigma=c[1], angle=np.random.uniform(-45, 45)) + img = cv2.imdecode(np.fromstring(img.make_blob(), np.uint8), cv2.IMREAD_UNCHANGED) + if len(img.shape) > 2: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + img = Image.fromarray(img.astype(np.uint8)) + + if isgray: + img = ImageOps.grayscale(img) + + return img + +class GlassBlur: + def __init__(self): + pass + + def __call__(self, img, mag=-1, prob=1.): + if np.random.uniform(0,1) > prob: + return img + + W, H = img.size + #c = [(0.7, 1, 2), (0.9, 2, 1), (1, 2, 3), (1.1, 3, 2), (1.5, 4, 2)][severity - 1] + c = [(0.7, 1, 2), (0.75, 1, 2), (0.8, 1, 2)] #, (1, 2, 3)] #prev 2 levels only + if mag<0 or mag>=len(c): + index = np.random.randint(0, len(c)) + else: + index = mag + + c = c[index] + + img = np.uint8(gaussian(np.array(img) / 255., sigma=c[0], multichannel=True) * 255) + + # locally shuffle pixels + for i in range(c[2]): + for h in range(H - c[1], c[1], -1): + for w in range(W - c[1], c[1], -1): + dx, dy = np.random.randint(-c[1], c[1], size=(2,)) + h_prime, w_prime = h + dy, w + dx + # swap + img[h, w], img[h_prime, w_prime] = img[h_prime, w_prime], img[h, w] + + img = np.clip(gaussian(img / 255., sigma=c[0], multichannel=True), 0, 1) * 255 + return Image.fromarray(img.astype(np.uint8)) + + +class ZoomBlur: + def __init__(self): + pass + + def __call__(self, img, mag=-1, prob=1.): + if np.random.uniform(0,1) > prob: + return img + + W, H = img.size + c = [np.arange(1, 1.11, .01), + np.arange(1, 1.16, .01), + np.arange(1, 1.21, .02)] + if mag<0 or mag>=len(c): + index = np.random.randint(0, len(c)) + else: + index = mag + + c = c[index] + + n_channels = len(img.getbands()) + isgray = n_channels == 1 + + uint8_img = img + img = (np.array(img) / 255.).astype(np.float32) + + out = np.zeros_like(img) + for zoom_factor in c: + ZW = int(W*zoom_factor) + ZH = int(H*zoom_factor) + zoom_img = uint8_img.resize((ZW, ZH), Image.BICUBIC) + x1 = (ZW - W) // 2 + y1 = (ZH - H) // 2 + x2 = x1 + W + y2 = y1 + H + zoom_img = zoom_img.crop((x1,y1,x2,y2)) + out += (np.array(zoom_img) / 255.).astype(np.float32) + + img = (img + out) / (len(c) + 1) + + img = np.clip(img, 0, 1) * 255 + img = Image.fromarray(img.astype(np.uint8)) + + return img diff --git a/augmentation/camera.py b/augmentation/camera.py new file mode 100644 index 0000000000000000000000000000000000000000..0adc925f4a38bbf8d59eeafb99c2e4530316582e --- /dev/null +++ b/augmentation/camera.py @@ -0,0 +1,120 @@ + +import cv2 +import numpy as np 
+import skimage as sk +from PIL import Image, ImageOps +from io import BytesIO + +from skimage import color +''' + PIL resize (W,H) + cv2 image is BGR + PIL image is RGB +''' +class Contrast: + def __init__(self): + pass + + def __call__(self, img, mag=-1, prob=1.): + if np.random.uniform(0,1) > prob: + return img + + #c = [0.4, .3, .2, .1, .05] + c = [0.4, .3, .2] + if mag<0 or mag>=len(c): + index = np.random.randint(0, len(c)) + else: + index = mag + c = c[index] + img = np.array(img) / 255. + means = np.mean(img, axis=(0, 1), keepdims=True) + img = np.clip((img - means) * c + means, 0, 1) * 255 + + return Image.fromarray(img.astype(np.uint8)) + + +class Brightness: + def __init__(self): + pass + + def __call__(self, img, mag=-1, prob=1.): + if np.random.uniform(0,1) > prob: + return img + + #W, H = img.size + #c = [.1, .2, .3, .4, .5] + c = [.1, .2, .3] + if mag<0 or mag>=len(c): + index = np.random.randint(0, len(c)) + else: + index = mag + c = c[index] + + n_channels = len(img.getbands()) + isgray = n_channels == 1 + + img = np.array(img) / 255. + if isgray: + img = np.expand_dims(img, axis=2) + img = np.repeat(img, 3, axis=2) + + img = sk.color.rgb2hsv(img) + img[:, :, 2] = np.clip(img[:, :, 2] + c, 0, 1) + img = sk.color.hsv2rgb(img) + + #if isgray: + # img = img[:,:,0] + # img = np.squeeze(img) + + img = np.clip(img, 0, 1) * 255 + img = Image.fromarray(img.astype(np.uint8)) + if isgray: + img = ImageOps.grayscale(img) + + return img + #if isgray: + #if isgray: + # img = color.rgb2gray(img) + + #return Image.fromarray(img.astype(np.uint8)) + + +class JpegCompression: + def __init__(self): + pass + + def __call__(self, img, mag=-1, prob=1.): + if np.random.uniform(0,1) > prob: + return img + + #c = [25, 18, 15, 10, 7] + c = [25, 18, 15] + if mag<0 or mag>=len(c): + index = np.random.randint(0, len(c)) + else: + index = mag + c = c[index] + output = BytesIO() + img.save(output, 'JPEG', quality=c) + return Image.open(output) + + +class Pixelate: + def __init__(self): + pass + + def __call__(self, img, mag=-1, prob=1.): + if np.random.uniform(0,1) > prob: + return img + + W, H = img.size + #c = [0.6, 0.5, 0.4, 0.3, 0.25] + c = [0.6, 0.5, 0.4] + if mag<0 or mag>=len(c): + index = np.random.randint(0, len(c)) + else: + index = mag + c = c[index] + img = img.resize((int(W* c), int(H * c)), Image.BOX) + return img.resize((W, H), Image.BOX) + diff --git a/augmentation/frost/frost4.jpg b/augmentation/frost/frost4.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f8b0c413176d70150b593e029d84b4a88c21dd4b Binary files /dev/null and b/augmentation/frost/frost4.jpg differ diff --git a/augmentation/frost/frost5.jpg b/augmentation/frost/frost5.jpg new file mode 100644 index 0000000000000000000000000000000000000000..95dc9056926d8201df760535f9bb9112f012e862 Binary files /dev/null and b/augmentation/frost/frost5.jpg differ diff --git a/augmentation/frost/frost6.jpg b/augmentation/frost/frost6.jpg new file mode 100644 index 0000000000000000000000000000000000000000..14e5d58e762a5d0808df9fa6494fd6d78ee4409b Binary files /dev/null and b/augmentation/frost/frost6.jpg differ diff --git a/augmentation/geometry.py b/augmentation/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..50ef811186e0c92165cf6f55b1da723c5dedc9e0 --- /dev/null +++ b/augmentation/geometry.py @@ -0,0 +1,233 @@ + +import cv2 +import numpy as np +from PIL import Image, ImageOps + +''' + PIL resize (W,H) + Torch resize is (H,W) +''' +class Shrink: + def __init__(self): + self.tps = 
cv2.createThinPlateSplineShapeTransformer() + self.translateXAbs = TranslateXAbs() + self.translateYAbs = TranslateYAbs() + + def __call__(self, img, mag=-1, prob=1.): + if np.random.uniform(0,1) > prob: + return img + + W, H = img.size + img = np.array(img) + srcpt = list() + dstpt = list() + + W_33 = 0.33 * W + W_50 = 0.50 * W + W_66 = 0.66 * W + + H_50 = 0.50 * H + + P = 0 + + #frac = 0.4 + + b = [.2, .3, .4] + if mag<0 or mag>=len(b): + index = 0 + else: + index = mag + frac = b[index] + + # left-most + srcpt.append([P, P]) + srcpt.append([P, H-P]) + x = np.random.uniform(frac-.1, frac)*W_33 + y = np.random.uniform(frac-.1, frac)*H_50 + dstpt.append([P+x, P+y]) + dstpt.append([P+x, H-P-y]) + + # 2nd left-most + srcpt.append([P+W_33, P]) + srcpt.append([P+W_33, H-P]) + dstpt.append([P+W_33, P+y]) + dstpt.append([P+W_33, H-P-y]) + + # 3rd left-most + srcpt.append([P+W_66, P]) + srcpt.append([P+W_66, H-P]) + dstpt.append([P+W_66, P+y]) + dstpt.append([P+W_66, H-P-y]) + + # right-most + srcpt.append([W-P, P]) + srcpt.append([W-P, H-P]) + dstpt.append([W-P-x, P+y]) + dstpt.append([W-P-x, H-P-y]) + + N = len(dstpt) + matches = [cv2.DMatch(i, i, 0) for i in range(N)] + dst_shape = np.array(dstpt).reshape((-1, N, 2)) + src_shape = np.array(srcpt).reshape((-1, N, 2)) + self.tps.estimateTransformation(dst_shape, src_shape, matches) + img = self.tps.warpImage(img) + img = Image.fromarray(img) + + if np.random.uniform(0, 1) < 0.5: + img = self.translateXAbs(img, val=x) + else: + img = self.translateYAbs(img, val=y) + + return img + + +class Rotate: + def __init__(self, square_side=224): + self.side = square_side + + def __call__(self, img, iscurve=False, mag=-1, prob=1.): + if np.random.uniform(0,1) > prob: + return img + + W, H = img.size + + if H!=self.side or W!=self.side: + img = img.resize((self.side, self.side), Image.BICUBIC) + + b = [20., 40, 60] + if mag<0 or mag>=len(b): + index = 1 + else: + index = mag + rotate_angle = b[index] + + angle = np.random.uniform(rotate_angle-20, rotate_angle) + if np.random.uniform(0, 1) < 0.5: + angle = -angle + + #angle = np.random.normal(loc=0., scale=rotate_angle) + #angle = min(angle, 2*rotate_angle) + #angle = max(angle, -2*rotate_angle) + + expand = False if iscurve else True + img = img.rotate(angle=angle, resample=Image.BICUBIC, expand=expand) + img = img.resize((W, H), Image.BICUBIC) + + return img + +class Perspective: + def __init__(self): + pass + + def __call__(self, img, mag=-1, prob=1.): + if np.random.uniform(0,1) > prob: + return img + + W, H = img.size + + # upper-left, upper-right, lower-left, lower-right + src = np.float32([[0, 0], [W, 0], [0, H], [W, H]]) + #low = 0.3 + + b = [.1, .2, .3] + if mag<0 or mag>=len(b): + index = 2 + else: + index = mag + low = b[index] + + high = 1 - low + if np.random.uniform(0, 1) > 0.5: + toprightY = np.random.uniform(low, low+.1)*H + bottomrightY = np.random.uniform(high-.1, high)*H + dest = np.float32([[0, 0], [W, toprightY], [0, H], [W, bottomrightY]]) + else: + topleftY = np.random.uniform(low, low+.1)*H + bottomleftY = np.random.uniform(high-.1, high)*H + dest = np.float32([[0, topleftY], [W, 0], [0, bottomleftY], [W, H]]) + M = cv2.getPerspectiveTransform(src, dest) + img = np.array(img) + img = cv2.warpPerspective(img, M, (W, H) ) + img = Image.fromarray(img) + + return img + + +class TranslateX: + def __init__(self): + pass + + def __call__(self, img, mag=-1, prob=1.): + if np.random.uniform(0,1) > prob: + return img + + b = [.03, .06, .09] + if mag<0 or mag>=len(b): + index = 2 + else: + index 
= mag + v = b[index] + v = np.random.uniform(v-0.03, v) + + v = v * img.size[0] + if np.random.uniform(0,1) > 0.5: + v = -v + return img.transform(img.size, Image.AFFINE, (1, 0, v, 0, 1, 0)) + + +class TranslateY: + def __init__(self): + pass + + def __call__(self, img, mag=-1, prob=1.): + if np.random.uniform(0,1) > prob: + return img + + b = [.07, .14, .21] + if mag<0 or mag>=len(b): + index = 2 + else: + index = mag + v = b[index] + v = np.random.uniform(v-0.07, v) + + v = v * img.size[1] + if np.random.uniform(0,1) > 0.5: + v = -v + return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, v)) + + +class TranslateXAbs: + def __init__(self): + pass + + def __call__(self, img, val=0, prob=1.): + if np.random.uniform(0,1) > prob: + return img + + v = np.random.uniform(0, val) + + if np.random.uniform(0,1) > 0.5: + v = -v + return img.transform(img.size, Image.AFFINE, (1, 0, v, 0, 1, 0)) + + +class TranslateYAbs: + def __init__(self): + pass + + def __call__(self, img, val=0, prob=1.): + if np.random.uniform(0,1) > prob: + return img + + v = np.random.uniform(0, val) + + if np.random.uniform(0,1) > 0.5: + v = -v + return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, v)) + + + + + + diff --git a/augmentation/noise.py b/augmentation/noise.py new file mode 100644 index 0000000000000000000000000000000000000000..b4a83f3fc324fd8157b50e7ff171ac5a9a6dcd27 --- /dev/null +++ b/augmentation/noise.py @@ -0,0 +1,94 @@ + +import numpy as np +import skimage as sk +from PIL import Image + +''' + PIL resize (W,H) +''' +class GaussianNoise: + def __init__(self): + pass + + def __call__(self, img, mag=-1, prob=1.): + if np.random.uniform(0,1) > prob: + return img + + W, H = img.size + #c = np.random.uniform(.08, .38) + b = [.08, 0.1, 0.12] + if mag<0 or mag>=len(b): + index = 0 + else: + index = mag + a = b[index] + c = np.random.uniform(a, a+0.03) + img = np.array(img) / 255. + img = np.clip(img + np.random.normal(size=img.shape, scale=c), 0, 1) * 255 + return Image.fromarray(img.astype(np.uint8)) + + +class ShotNoise: + def __init__(self): + pass + + def __call__(self, img, mag=-1, prob=1.): + if np.random.uniform(0,1) > prob: + return img + + W, H = img.size + #c = np.random.uniform(3, 60) + b = [13, 8, 3] + if mag<0 or mag>=len(b): + index = 2 + else: + index = mag + a = b[index] + c = np.random.uniform(a, a+7) + img = np.array(img) / 255. + img = np.clip(np.random.poisson(img * c) / float(c), 0, 1) * 255 + return Image.fromarray(img.astype(np.uint8)) + + +class ImpulseNoise: + def __init__(self): + pass + + def __call__(self, img, mag=-1, prob=1.): + if np.random.uniform(0,1) > prob: + return img + + W, H = img.size + #c = np.random.uniform(.03, .27) + b = [.03, .07, .11] + if mag<0 or mag>=len(b): + index = 0 + else: + index = mag + a = b[index] + c = np.random.uniform(a, a+.04) + img = sk.util.random_noise(np.array(img) / 255., mode='s&p', amount=c) * 255 + return Image.fromarray(img.astype(np.uint8)) + + +class SpeckleNoise: + def __init__(self): + pass + + def __call__(self, img, mag=-1, prob=1.): + if np.random.uniform(0,1) > prob: + return img + + W, H = img.size + # c = np.random.uniform(.15, .6) + b = [.15, .2, .25] + if mag<0 or mag>=len(b): + index = 0 + else: + index = mag + a = b[index] + c = np.random.uniform(a, a+.05) + img = np.array(img) / 255. 
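
For reference, the four noise models above reduce to short NumPy expressions on a [0, 1] image. A small illustrative sketch with a toy array and arbitrary scales, mirroring the formulas used by the classes:

```python
import numpy as np
import skimage as sk

x = np.random.rand(32, 100)  # toy grayscale image in [0, 1]

gaussian = np.clip(x + np.random.normal(size=x.shape, scale=0.1), 0, 1)      # additive
shot     = np.clip(np.random.poisson(x * 13) / 13.0, 0, 1)                   # Poisson counts
impulse  = sk.util.random_noise(x, mode='s&p', amount=0.05)                  # salt & pepper
speckle  = np.clip(x + x * np.random.normal(size=x.shape, scale=0.2), 0, 1)  # multiplicative
```
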
+ img = np.clip(img + img * np.random.normal(size=img.shape, scale=c), 0, 1) * 255 + return Image.fromarray(img.astype(np.uint8)) + diff --git a/augmentation/ops.py b/augmentation/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..276e58cd494a91285662b147cfd63b5a152ca662 --- /dev/null +++ b/augmentation/ops.py @@ -0,0 +1,87 @@ + +import cv2 +import numpy as np +from wand.image import Image as WandImage +from scipy.ndimage import zoom as scizoom +from wand.api import library as wandlibrary + +class MotionImage(WandImage): + def motion_blur(self, radius=0.0, sigma=0.0, angle=0.0): + wandlibrary.MagickMotionBlurImage(self.wand, radius, sigma, angle) + +def clipped_zoom(img, zoom_factor): + h = img.shape[1] + # ceil crop height(= crop width) + ch = int(np.ceil(h / float(zoom_factor))) + + top = (h - ch) // 2 + img = scizoom(img[top:top + ch, top:top + ch], (zoom_factor, zoom_factor, 1), order=1) + # trim off any extra pixels + trim_top = (img.shape[0] - h) // 2 + + return img[trim_top:trim_top + h, trim_top:trim_top + h] + +def disk(radius, alias_blur=0.1, dtype=np.float32): + if radius <= 8: + L = np.arange(-8, 8 + 1) + ksize = (3, 3) + else: + L = np.arange(-radius, radius + 1) + ksize = (5, 5) + X, Y = np.meshgrid(L, L) + aliased_disk = np.array((X ** 2 + Y ** 2) <= radius ** 2, dtype=dtype) + aliased_disk /= np.sum(aliased_disk) + + # supersample disk to antialias + return cv2.GaussianBlur(aliased_disk, ksize=ksize, sigmaX=alias_blur) + +# modification of https://github.com/FLHerne/mapgen/blob/master/diamondsquare.py +def plasma_fractal(mapsize=256, wibbledecay=3): + """ + Generate a heightmap using diamond-square algorithm. + Return square 2d array, side length 'mapsize', of floats in range 0-255. + 'mapsize' must be a power of two. 
+ """ + assert (mapsize & (mapsize - 1) == 0) + maparray = np.empty((mapsize, mapsize), dtype=np.float_) + maparray[0, 0] = 0 + stepsize = mapsize + wibble = 100 + + def wibbledmean(array): + return array / 4 + wibble * np.random.uniform(-wibble, wibble, array.shape) + + def fillsquares(): + """For each square of points stepsize apart, + calculate middle value as mean of points + wibble""" + cornerref = maparray[0:mapsize:stepsize, 0:mapsize:stepsize] + squareaccum = cornerref + np.roll(cornerref, shift=-1, axis=0) + squareaccum += np.roll(squareaccum, shift=-1, axis=1) + maparray[stepsize // 2:mapsize:stepsize, + stepsize // 2:mapsize:stepsize] = wibbledmean(squareaccum) + + def filldiamonds(): + """For each diamond of points stepsize apart, + calculate middle value as mean of points + wibble""" + mapsize = maparray.shape[0] + drgrid = maparray[stepsize // 2:mapsize:stepsize, stepsize // 2:mapsize:stepsize] + ulgrid = maparray[0:mapsize:stepsize, 0:mapsize:stepsize] + ldrsum = drgrid + np.roll(drgrid, 1, axis=0) + lulsum = ulgrid + np.roll(ulgrid, -1, axis=1) + ltsum = ldrsum + lulsum + maparray[0:mapsize:stepsize, stepsize // 2:mapsize:stepsize] = wibbledmean(ltsum) + tdrsum = drgrid + np.roll(drgrid, 1, axis=1) + tulsum = ulgrid + np.roll(ulgrid, -1, axis=0) + ttsum = tdrsum + tulsum + maparray[stepsize // 2:mapsize:stepsize, 0:mapsize:stepsize] = wibbledmean(ttsum) + + while stepsize >= 2: + fillsquares() + filldiamonds() + stepsize //= 2 + wibble /= wibbledecay + + maparray -= maparray.min() + return maparray / maparray.max() + + diff --git a/augmentation/pattern.py b/augmentation/pattern.py new file mode 100644 index 0000000000000000000000000000000000000000..3bbf8e2782f67b5ad59245b8a4a6ad88b8dc49e0 --- /dev/null +++ b/augmentation/pattern.py @@ -0,0 +1,115 @@ + +import cv2 +import numpy as np +from PIL import Image, ImageOps, ImageDraw + +''' + PIL resize (W,H) + Torch resize is (H,W) +''' +class VGrid: + def __init__(self): + pass + + def __call__(self, img, copy=True, max_width=4, mag=-1, prob=1.): + if np.random.uniform(0,1) > prob: + return img + + if copy: + img = img.copy() + W, H = img.size + + if mag<0 or mag>max_width: + line_width = np.random.randint(1, max_width) + image_stripe = np.random.randint(1, max_width) + else: + line_width = 1 + image_stripe = 3 - mag + + n_lines = W // (line_width + image_stripe) + 1 + draw = ImageDraw.Draw(img) + for i in range(1, n_lines): + x = image_stripe*i + line_width*(i-1) + draw.line([(x,0), (x,H)], width=line_width, fill='black') + + return img + +class HGrid: + def __init__(self): + pass + + def __call__(self, img, copy=True, max_width=4, mag=-1, prob=1.): + if np.random.uniform(0,1) > prob: + return img + + if copy: + img = img.copy() + W, H = img.size + if mag<0 or mag>max_width: + line_width = np.random.randint(1, max_width) + image_stripe = np.random.randint(1, max_width) + else: + line_width = 1 + image_stripe = 3 - mag + + n_lines = H // (line_width + image_stripe) + 1 + draw = ImageDraw.Draw(img) + for i in range(1, n_lines): + y = image_stripe*i + line_width*(i-1) + draw.line([(0,y), (W, y)], width=line_width, fill='black') + + return img + +class Grid: + def __init__(self): + pass + + def __call__(self, img, mag=-1, prob=1.): + if np.random.uniform(0,1) > prob: + return img + + img = VGrid()(img, copy=True, mag=mag) + img = HGrid()(img, copy=False, mag=mag) + return img + +class RectGrid: + def __init__(self): + pass + + def __call__(self, img, isellipse=False, mag=-1, prob=1.): + if np.random.uniform(0,1) > prob: + return img 
+ + img = img.copy() + W, H = img.size + line_width = 1 + image_stripe = 3 - mag #np.random.randint(2, 6) + offset = 4 if isellipse else 1 + n_lines = ((H//2) // (line_width + image_stripe)) + offset + draw = ImageDraw.Draw(img) + x_center = W // 2 + y_center = H // 2 + for i in range(1, n_lines): + dx = image_stripe*i + line_width*(i-1) + dy = image_stripe*i + line_width*(i-1) + x1 = x_center - (dx * W//H) + y1 = y_center - dy + x2 = x_center + (dx * W/H) + y2 = y_center + dy + if isellipse: + draw.ellipse([(x1,y1), (x2, y2)], width=line_width, outline='black') + else: + draw.rectangle([(x1,y1), (x2, y2)], width=line_width, outline='black') + + return img + +class EllipseGrid: + def __init__(self): + pass + + def __call__(self, img, mag=-1, prob=1.): + if np.random.uniform(0,1) > prob: + return img + + img = RectGrid()(img, isellipse=True, mag=mag, prob=prob) + return img diff --git a/augmentation/process.py b/augmentation/process.py new file mode 100644 index 0000000000000000000000000000000000000000..370b0483520eb92b15afe5cf9ffc5f4ad0d4871c --- /dev/null +++ b/augmentation/process.py @@ -0,0 +1,123 @@ + +from PIL import Image +import PIL.ImageOps, PIL.ImageEnhance +import numpy as np + +class Posterize: + def __init__(self): + pass + + def __call__(self, img, mag=-1, prob=1.): + if np.random.uniform(0,1) > prob: + return img + + c = [1, 3, 6] + if mag<0 or mag>=len(c): + index = np.random.randint(0, len(c)) + else: + index = mag + c = c[index] + bit = np.random.randint(c, c+2) + img = PIL.ImageOps.posterize(img, bit) + + return img + + +class Solarize: + def __init__(self): + pass + + def __call__(self, img, mag=-1, prob=1.): + if np.random.uniform(0,1) > prob: + return img + + c = [64, 128, 192] + if mag<0 or mag>=len(c): + index = np.random.randint(0, len(c)) + else: + index = mag + c = c[index] + thresh = np.random.randint(c, c+64) + img = PIL.ImageOps.solarize(img, thresh) + + return img + +class Invert: + def __init__(self): + pass + + def __call__(self, img, mag=-1, prob=1.): + if np.random.uniform(0,1) > prob: + return img + + img = PIL.ImageOps.invert(img) + + return img + + +class Equalize: + def __init__(self): + pass + + def __call__(self, img, mag=-1, prob=1.): + if np.random.uniform(0,1) > prob: + return img + + mg = PIL.ImageOps.equalize(img) + + return img + + +class AutoContrast: + def __init__(self): + pass + + def __call__(self, img, mag=-1, prob=1.): + if np.random.uniform(0,1) > prob: + return img + + mg = PIL.ImageOps.autocontrast(img) + + return img + + +class Sharpness: + def __init__(self): + pass + + def __call__(self, img, mag=-1, prob=1.): + if np.random.uniform(0,1) > prob: + return img + + c = [.1, .7, 1.3] + if mag<0 or mag>=len(c): + index = np.random.randint(0, len(c)) + else: + index = mag + c = c[index] + magnitude = np.random.uniform(c, c+.6) + img = PIL.ImageEnhance.Sharpness(img).enhance(magnitude) + + return img + + +class Color: + def __init__(self): + pass + + def __call__(self, img, mag=-1, prob=1.): + if np.random.uniform(0,1) > prob: + return img + + c = [.1, .7, 1.3] + if mag<0 or mag>=len(c): + index = np.random.randint(0, len(c)) + else: + index = mag + c = c[index] + magnitude = np.random.uniform(c, c+.6) + img = PIL.ImageEnhance.Color(img).enhance(magnitude) + + return img + + diff --git a/augmentation/test.py b/augmentation/test.py new file mode 100644 index 0000000000000000000000000000000000000000..0ad1ded2df8fa873e815e2ba6f5f80fa63384057 --- /dev/null +++ b/augmentation/test.py @@ -0,0 +1,43 @@ + +import os +import cv2 +from warp import 
Curve, Distort, Stretch +from geometry import Rotate, Perspective, Shrink, TranslateX, TranslateY +from pattern import VGrid, HGrid, Grid, RectGrid, EllipseGrid +from noise import GaussianNoise, ShotNoise, ImpulseNoise, SpeckleNoise +from blur import GaussianBlur, DefocusBlur, MotionBlur, GlassBlur, ZoomBlur +from camera import Contrast, Brightness, JpegCompression, Pixelate +from weather import Fog, Snow, Frost, Rain, Shadow +from process import Posterize, Solarize, Invert, Equalize, AutoContrast, Sharpness, Color + +from PIL import Image +import PIL.ImageOps +import numpy as np +import argparse + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--image', default="images/delivery.png", help='Load image file') + parser.add_argument('--results', default="results", help='Load image file') + parser.add_argument('--gray', action='store_true', help='Convert to grayscale 1st') + opt = parser.parse_args() + os.makedirs(opt.results, exist_ok=True) + + img = Image.open(opt.image) + img = img.resize( (100,32) ) + ops = [Curve(), Rotate(), Perspective(), Distort(), Stretch(), Shrink(), TranslateX(), TranslateY(), VGrid(), HGrid(), Grid(), RectGrid(), EllipseGrid()] + ops.extend([GaussianNoise(), ShotNoise(), ImpulseNoise(), SpeckleNoise()]) + ops.extend([GaussianBlur(), DefocusBlur(), MotionBlur(), GlassBlur(), ZoomBlur()]) + ops.extend([Contrast(), Brightness(), JpegCompression(), Pixelate()]) + ops.extend([Fog(), Snow(), Frost(), Rain(), Shadow()]) + ops.extend([Posterize(), Solarize(), Invert(), Equalize(), AutoContrast(), Sharpness(), Color()]) + for op in ops: + for mag in range(3): + filename = type(op).__name__ + "-" + str(mag) + ".png" + out_img = op(img, mag=mag) + if opt.gray: + out_img = PIL.ImageOps.grayscale(out_img) + out_img.save(os.path.join(opt.results, filename)) + + diff --git a/augmentation/warp.py b/augmentation/warp.py new file mode 100644 index 0000000000000000000000000000000000000000..e185cf9f0ee700b2ae4defeeb427e136cd210911 --- /dev/null +++ b/augmentation/warp.py @@ -0,0 +1,241 @@ + +import cv2 +import numpy as np +from PIL import Image, ImageOps + +''' + PIL resize (W,H) + Torch resize is (H,W) +''' +class Stretch: + def __init__(self): + self.tps = cv2.createThinPlateSplineShapeTransformer() + + def __call__(self, img, mag=-1, prob=1.): + if np.random.uniform(0,1) > prob: + return img + + W, H = img.size + img = np.array(img) + srcpt = list() + dstpt = list() + + W_33 = 0.33 * W + W_50 = 0.50 * W + W_66 = 0.66 * W + + H_50 = 0.50 * H + + P = 0 + #frac = 0.4 + + b = [.2, .3, .4] + if mag<0 or mag>=len(b): + index = len(b)-1 + else: + index = mag + frac = b[index] + + # left-most + srcpt.append([P, P]) + srcpt.append([P, H-P]) + srcpt.append([P, H_50]) + x = np.random.uniform(0, frac)*W_33 #if np.random.uniform(0,1) > 0.5 else 0 + dstpt.append([P+x, P]) + dstpt.append([P+x, H-P]) + dstpt.append([P+x, H_50]) + + # 2nd left-most + srcpt.append([P+W_33, P]) + srcpt.append([P+W_33, H-P]) + x = np.random.uniform(-frac, frac)*W_33 + dstpt.append([P+W_33+x, P]) + dstpt.append([P+W_33+x, H-P]) + + # 3rd left-most + srcpt.append([P+W_66, P]) + srcpt.append([P+W_66, H-P]) + x = np.random.uniform(-frac, frac)*W_33 + dstpt.append([P+W_66+x, P]) + dstpt.append([P+W_66+x, H-P]) + + # right-most + srcpt.append([W-P, P]) + srcpt.append([W-P, H-P]) + srcpt.append([W-P, H_50]) + x = np.random.uniform(-frac, 0)*W_33 #if np.random.uniform(0,1) > 0.5 else 0 + dstpt.append([W-P+x, P]) + dstpt.append([W-P+x, H-P]) + dstpt.append([W-P+x, H_50]) + + N = 
len(dstpt) + matches = [cv2.DMatch(i, i, 0) for i in range(N)] + dst_shape = np.array(dstpt).reshape((-1, N, 2)) + src_shape = np.array(srcpt).reshape((-1, N, 2)) + self.tps.estimateTransformation(dst_shape, src_shape, matches) + img = self.tps.warpImage(img) + img = Image.fromarray(img) + + return img + + +class Distort: + def __init__(self): + self.tps = cv2.createThinPlateSplineShapeTransformer() + + def __call__(self, img, mag=-1, prob=1.): + if np.random.uniform(0,1) > prob: + return img + + W, H = img.size + img = np.array(img) + srcpt = list() + dstpt = list() + + W_33 = 0.33 * W + W_50 = 0.50 * W + W_66 = 0.66 * W + + H_50 = 0.50 * H + + P = 0 + #frac = 0.4 + + b = [.2, .3, .4] + if mag<0 or mag>=len(b): + index = len(b)-1 + else: + index = mag + frac = b[index] + + # top pts + srcpt.append([P, P]) + x = np.random.uniform(0, frac)*W_33 + y = np.random.uniform(0, frac)*H_50 + dstpt.append([P+x, P+y]) + + srcpt.append([P+W_33, P]) + x = np.random.uniform(-frac, frac)*W_33 + y = np.random.uniform(0, frac)*H_50 + dstpt.append([P+W_33+x, P+y]) + + srcpt.append([P+W_66, P]) + x = np.random.uniform(-frac, frac)*W_33 + y = np.random.uniform(0, frac)*H_50 + dstpt.append([P+W_66+x, P+y]) + + srcpt.append([W-P, P]) + x = np.random.uniform(-frac, 0)*W_33 + y = np.random.uniform(0, frac)*H_50 + dstpt.append([W-P+x, P+y]) + + # bottom pts + srcpt.append([P, H-P]) + x = np.random.uniform(0, frac)*W_33 + y = np.random.uniform(-frac, 0)*H_50 + dstpt.append([P+x, H-P+y]) + + srcpt.append([P+W_33, H-P]) + x = np.random.uniform(-frac, frac)*W_33 + y = np.random.uniform(-frac, 0)*H_50 + dstpt.append([P+W_33+x, H-P+y]) + + srcpt.append([P+W_66, H-P]) + x = np.random.uniform(-frac, frac)*W_33 + y = np.random.uniform(-frac, 0)*H_50 + dstpt.append([P+W_66+x, H-P+y]) + + srcpt.append([W-P, H-P]) + x = np.random.uniform(-frac, 0)*W_33 + y = np.random.uniform(-frac, 0)*H_50 + dstpt.append([W-P+x, H-P+y]) + + N = len(dstpt) + matches = [cv2.DMatch(i, i, 0) for i in range(N)] + dst_shape = np.array(dstpt).reshape((-1, N, 2)) + src_shape = np.array(srcpt).reshape((-1, N, 2)) + self.tps.estimateTransformation(dst_shape, src_shape, matches) + img = self.tps.warpImage(img) + img = Image.fromarray(img) + + return img + + +class Curve: + def __init__(self, square_side=224): + self.tps = cv2.createThinPlateSplineShapeTransformer() + self.side = square_side + + def __call__(self, img, mag=-1, prob=1.): + if np.random.uniform(0,1) > prob: + return img + + W, H = img.size + + if H!=self.side or W!=self.side: + img = img.resize((self.side, self.side), Image.BICUBIC) + + isflip = np.random.uniform(0,1) > 0.5 + if isflip: + img = ImageOps.flip(img) + #img = TF.vflip(img) + + img = np.array(img) + w = self.side + h = self.side + w_25 = 0.25 * w + w_50 = 0.50 * w + w_75 = 0.75 * w + + b = [1.1, .95, .8] + if mag<0 or mag>=len(b): + index = 0 + else: + index = mag + rmin = b[index] + + r = np.random.uniform(rmin, rmin+.1)*h + x1 = (r**2 - w_50**2)**0.5 + h1 = r - x1 + + t = np.random.uniform(0.4, 0.5)*h + + w2 = w_50*t/r + hi = x1*t/r + h2 = h1 + hi + + sinb_2 = ((1 - x1/r)/2)**0.5 + cosb_2 = ((1 + x1/r)/2)**0.5 + w3 = w_50 - r*sinb_2 + h3 = r - r*cosb_2 + + w4 = w_50 - (r-t)*sinb_2 + h4 = r - (r-t)*cosb_2 + + w5 = 0.5*w2 + h5 = h1 + 0.5*hi + h_50 = 0.50*h + + srcpt = [(0,0 ), (w,0 ), (w_50,0), (0,h ), (w,h ), (w_25,0), (w_75,0 ), (w_50,h), (w_25,h), (w_75,h ), (0,h_50), (w,h_50 )] + dstpt = [(0,h1), (w,h1), (w_50,0), (w2,h2), (w-w2,h2), (w3, h3), (w-w3,h3), (w_50,t), (w4,h4 ), (w-w4,h4), (w5,h5 ), (w-w5,h5)] + + N = 
len(dstpt) + matches = [cv2.DMatch(i, i, 0) for i in range(N)] + dst_shape = np.array(dstpt).reshape((-1, N, 2)) + src_shape = np.array(srcpt).reshape((-1, N, 2)) + self.tps.estimateTransformation(dst_shape, src_shape, matches) + img = self.tps.warpImage(img) + img = Image.fromarray(img) + + if isflip: + #img = TF.vflip(img) + img = ImageOps.flip(img) + rect = (0, self.side//2, self.side, self.side) + else: + rect = (0, 0, self.side, self.side//2) + + img = img.crop(rect) + img = img.resize((W, H), Image.BICUBIC) + return img + + diff --git a/augmentation/weather.py b/augmentation/weather.py new file mode 100644 index 0000000000000000000000000000000000000000..6fbc88718c12e7dee1cceed2c6dd819ecf152f1e --- /dev/null +++ b/augmentation/weather.py @@ -0,0 +1,231 @@ + +import cv2 +import numpy as np +import math +from PIL import Image, ImageOps, ImageDraw +from skimage import color +from pkg_resources import resource_filename +from io import BytesIO +from .ops import plasma_fractal, clipped_zoom, MotionImage + +''' + PIL resize (W,H) +''' +class Fog: + def __init__(self): + pass + + def __call__(self, img, mag=-1, prob=1.): + if np.random.uniform(0,1) > prob: + return img + + W, H = img.size + c = [(1.5, 2), (2., 2), (2.5, 1.7)] + if mag<0 or mag>=len(c): + index = np.random.randint(0, len(c)) + else: + index = mag + c = c[index] + + n_channels = len(img.getbands()) + isgray = n_channels == 1 + + img = np.array(img) / 255. + max_val = img.max() + fog = c[0] * plasma_fractal(wibbledecay=c[1])[:H, :W][..., np.newaxis] + #x += c[0] * plasma_fractal(wibbledecay=c[1])[:224, :224][..., np.newaxis] + #return np.clip(x * max_val / (max_val + c[0]), 0, 1) * 255 + if isgray: + fog = np.squeeze(fog) + else: + fog = np.repeat(fog, 3, axis=2) + + img += fog + img = np.clip(img * max_val / (max_val + c[0]), 0, 1) * 255 + return Image.fromarray(img.astype(np.uint8)) + + +class Frost: + def __init__(self): + pass + + def __call__(self, img, mag=-1, prob=1.): + if np.random.uniform(0,1) > prob: + return img + + W, H = img.size + c = [(1, 0.4), (0.8, 0.6), (0.7, 0.7)] + if mag<0 or mag>=len(c): + index = np.random.randint(0, len(c)) + else: + index = mag + c = c[index] + + filename = [resource_filename(__name__, 'frost/frost1.png'), + resource_filename(__name__, 'frost/frost2.png'), + resource_filename(__name__, 'frost/frost3.png'), + resource_filename(__name__, 'frost/frost4.jpg'), + resource_filename(__name__, 'frost/frost5.jpg'), + resource_filename(__name__, 'frost/frost6.jpg')] + index = np.random.randint(0, len(filename)) + filename = filename[index] + frost = cv2.imread(filename) + #randomly crop and convert to rgb + x_start, y_start = np.random.randint(0, frost.shape[0] - H), np.random.randint(0, frost.shape[1] - W) + frost = frost[x_start:x_start + H, y_start:y_start + W][..., [2, 1, 0]] + + n_channels = len(img.getbands()) + isgray = n_channels == 1 + + img = np.array(img) + + if isgray: + img = np.expand_dims(img, axis=2) + img = np.repeat(img, 3, axis=2) + + img = img * c[0] + frost = frost * c[1] + img = np.clip(c[0] * img + c[1] * frost, 0, 255) + img = Image.fromarray(img.astype(np.uint8)) + if isgray: + img = ImageOps.grayscale(img) + + return img + +class Snow: + def __init__(self): + pass + + def __call__(self, img, mag=-1, prob=1.): + if np.random.uniform(0,1) > prob: + return img + + W, H = img.size + c = [(0.1, 0.3, 3, 0.5, 10, 4, 0.8), + (0.2, 0.3, 2, 0.5, 12, 4, 0.7), + (0.55, 0.3, 4, 0.9, 12, 8, 0.7)] + if mag<0 or mag>=len(c): + index = np.random.randint(0, len(c)) + else: + index = mag + 
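
The `Stretch`, `Distort`, and `Curve` warps above (and `Shrink` in `geometry.py`) all drive OpenCV's thin-plate-spline transformer through the same three calls: build identity `DMatch` pairs, estimate the transformation from control points reshaped to `(1, N, 2)`, then warp the image. A stripped-down, self-contained sketch with made-up control points, following the same call order as the diff:

```python
import cv2
import numpy as np

tps = cv2.createThinPlateSplineShapeTransformer()

# Four made-up control points on a 100x32 canvas: pull the top-left corner inward.
src = np.array([[0, 0], [100, 0], [0, 32], [100, 32]], dtype=np.float32)
dst = np.array([[8, 4], [100, 0], [0, 32], [100, 32]], dtype=np.float32)

N = len(src)
matches = [cv2.DMatch(i, i, 0) for i in range(N)]
tps.estimateTransformation(dst.reshape(1, N, 2), src.reshape(1, N, 2), matches)

img = np.full((32, 100, 3), 255, dtype=np.uint8)  # toy white image (H, W, C)
warped = tps.warpImage(img)                       # warp using the estimated mapping
print(warped.shape)
```
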
c = c[index] + + n_channels = len(img.getbands()) + isgray = n_channels == 1 + + img = np.array(img, dtype=np.float32) / 255. + if isgray: + img = np.expand_dims(img, axis=2) + img = np.repeat(img, 3, axis=2) + + snow_layer = np.random.normal(size=img.shape[:2], loc=c[0], scale=c[1]) # [:2] for monochrome + + #snow_layer = clipped_zoom(snow_layer[..., np.newaxis], c[2]) + snow_layer[snow_layer < c[3]] = 0 + + snow_layer = Image.fromarray((np.clip(snow_layer.squeeze(), 0, 1) * 255).astype(np.uint8), mode='L') + output = BytesIO() + snow_layer.save(output, format='PNG') + snow_layer = MotionImage(blob=output.getvalue()) + + snow_layer.motion_blur(radius=c[4], sigma=c[5], angle=np.random.uniform(-135, -45)) + + snow_layer = cv2.imdecode(np.fromstring(snow_layer.make_blob(), np.uint8), + cv2.IMREAD_UNCHANGED) / 255. + + #snow_layer = cv2.cvtColor(snow_layer, cv2.COLOR_BGR2RGB) + + snow_layer = snow_layer[..., np.newaxis] + + img = c[6] * img + gray_img = (1 - c[6]) * np.maximum(img, cv2.cvtColor(img, cv2.COLOR_RGB2GRAY).reshape(H, W, 1) * 1.5 + 0.5) + img += gray_img + img = np.clip(img + snow_layer + np.rot90(snow_layer, k=2), 0, 1) * 255 + img = Image.fromarray(img.astype(np.uint8)) + if isgray: + img = ImageOps.grayscale(img) + + return img + +class Rain: + def __init__(self): + pass + + def __call__(self, img, mag=-1, prob=1.): + if np.random.uniform(0,1) > prob: + return img + + img = img.copy() + W, H = img.size + n_channels = len(img.getbands()) + isgray = n_channels == 1 + line_width = np.random.randint(1, 2) + + c =[50, 70, 90] + if mag<0 or mag>=len(c): + index = 0 + else: + index = mag + c = c[index] + + n_rains = np.random.randint(c, c+20) + slant = np.random.randint(-60, 60) + fillcolor = 200 if isgray else (200,200,200) + + draw = ImageDraw.Draw(img) + for i in range(1, n_rains): + length = np.random.randint(5, 10) + x1 = np.random.randint(0, W-length) + y1 = np.random.randint(0, H-length) + x2 = x1 + length*math.sin(slant*math.pi/180.) + y2 = y1 + length*math.cos(slant*math.pi/180.) 
+ x2 = int(x2) + y2 = int(y2) + draw.line([(x1,y1), (x2,y2)], width=line_width, fill=fillcolor) + + return img + +class Shadow: + def __init__(self): + pass + + def __call__(self, img, mag=-1, prob=1.): + if np.random.uniform(0,1) > prob: + return img + + #img = img.copy() + W, H = img.size + n_channels = len(img.getbands()) + isgray = n_channels == 1 + + c =[64, 96, 128] + if mag<0 or mag>=len(c): + index = 0 + else: + index = mag + c = c[index] + + img = img.convert('RGBA') + overlay = Image.new('RGBA', img.size, (255,255,255,0)) + draw = ImageDraw.Draw(overlay) + transparency = np.random.randint(c, c+32) + x1 = np.random.randint(0, W//2) + y1 = 0 + + x2 = np.random.randint(W//2, W) + y2 = 0 + + x3 = np.random.randint(W//2, W) + y3 = H - 1 + + x4 = np.random.randint(0, W//2) + y4 = H - 1 + + draw.polygon([(x1,y1), (x2,y2), (x3,y3), (x4,y4)], fill=(0,0,0,transparency)) + + img = Image.alpha_composite(img, overlay) + img = img.convert("RGB") + if isgray: + img = ImageOps.grayscale(img) + + return img diff --git a/callbacks.py b/callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..22da46b4b8d58c4596ac6accd96fc92e6a343c87 --- /dev/null +++ b/callbacks.py @@ -0,0 +1,360 @@ +import logging +import shutil +import time + +import editdistance as ed +import torchvision.utils as vutils +from fastai.callbacks.tensorboard import (LearnerTensorboardWriter, + SummaryWriter, TBWriteRequest, + asyncTBWriter) +from fastai.vision import * +from torch.nn.parallel import DistributedDataParallel +from torchvision import transforms + +import dataset_abinet +from utils_abinet import CharsetMapper, Timer, blend_mask + + +class IterationCallback(LearnerTensorboardWriter): + "A `TrackerCallback` that monitor in each iteration." + def __init__(self, learn:Learner, name:str='model', checpoint_keep_num=5, + show_iters:int=50, eval_iters:int=1000, save_iters:int=20000, + start_iters:int=0, stats_iters=20000): + #if self.learn.rank is not None: time.sleep(self.learn.rank) # keep all event files + super().__init__(learn, base_dir='.', name=learn.path, loss_iters=show_iters, + stats_iters=stats_iters, hist_iters=stats_iters) + self.name, self.bestname = Path(name).name, f'best-{Path(name).name}' + self.show_iters = show_iters + self.eval_iters = eval_iters + self.save_iters = save_iters + self.start_iters = start_iters + self.checpoint_keep_num = checpoint_keep_num + self.metrics_root = 'metrics/' # rewrite + self.timer = Timer() + self.host = self.learn.rank is None or self.learn.rank == 0 + + def _write_metrics(self, iteration:int, names:List[str], last_metrics:MetricsList)->None: + "Writes training metrics to Tensorboard." + for i, name in enumerate(names): + if last_metrics is None or len(last_metrics) < i+1: return + scalar_value = last_metrics[i] + self._write_scalar(name=name, scalar_value=scalar_value, iteration=iteration) + + def _write_sub_loss(self, iteration:int, last_losses:dict)->None: + "Writes sub loss to Tensorboard." 
+ for name, loss in last_losses.items(): + scalar_value = to_np(loss) + tag = self.metrics_root + name + self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=iteration) + + def _save(self, name): + if isinstance(self.learn.model, DistributedDataParallel): + tmp = self.learn.model + self.learn.model = self.learn.model.module + self.learn.save(name) + self.learn.model = tmp + else: self.learn.save(name) + + def _validate(self, dl=None, callbacks=None, metrics=None, keeped_items=False): + "Validate on `dl` with potential `callbacks` and `metrics`." + dl = ifnone(dl, self.learn.data.valid_dl) + metrics = ifnone(metrics, self.learn.metrics) + cb_handler = CallbackHandler(ifnone(callbacks, []), metrics) + cb_handler.on_train_begin(1, None, metrics); cb_handler.on_epoch_begin() + if keeped_items: cb_handler.state_dict.update(dict(keeped_items=[])) + val_metrics = validate(self.learn.model, dl, self.loss_func, cb_handler) + cb_handler.on_epoch_end(val_metrics) + if keeped_items: return cb_handler.state_dict['keeped_items'] + else: return cb_handler.state_dict['last_metrics'] + + def jump_to_epoch_iter(self, epoch:int, iteration:int)->None: + try: + self.learn.load(f'{self.name}_{epoch}_{iteration}', purge=False) + logging.info(f'Loaded {self.name}_{epoch}_{iteration}') + except: logging.info(f'Model {self.name}_{epoch}_{iteration} not found.') + + def on_train_begin(self, n_epochs, **kwargs): + # TODO: can not write graph here + # super().on_train_begin(**kwargs) + self.best = -float('inf') + self.timer.tic() + if self.host: + checkpoint_path = self.learn.path/'checkpoint.yaml' + if checkpoint_path.exists(): + os.remove(checkpoint_path) + open(checkpoint_path, 'w').close() + return {'skip_validate': True, 'iteration':self.start_iters} # disable default validate + + def on_batch_begin(self, **kwargs:Any)->None: + self.timer.toc_data() + super().on_batch_begin(**kwargs) + + def on_batch_end(self, iteration, epoch, last_loss, smooth_loss, train, **kwargs): + super().on_batch_end(last_loss, iteration, train, **kwargs) + if iteration == 0: return + + if iteration % self.loss_iters == 0: + last_losses = self.learn.loss_func.last_losses + self._write_sub_loss(iteration=iteration, last_losses=last_losses) + self.tbwriter.add_scalar(tag=self.metrics_root + 'lr', + scalar_value=self.opt.lr, global_step=iteration) + + if iteration % self.show_iters == 0: + log_str = f'epoch {epoch} iter {iteration}: loss = {last_loss:6.4f}, ' \ + f'smooth loss = {smooth_loss:6.4f}' + logging.info(log_str) + # log_str = f'data time = {self.timer.data_diff:.4f}s, runing time = {self.timer.running_diff:.4f}s' + # logging.info(log_str) + + if iteration % self.eval_iters == 0: + # TODO: or remove time to on_epoch_end + # 1. Record time + log_str = f'average data time = {self.timer.average_data_time():.4f}s, ' \ + f'average running time = {self.timer.average_running_time():.4f}s' + logging.info(log_str) + + # 2. Call validate + last_metrics = self._validate() + self.learn.model.train() + log_str = f'epoch {epoch} iter {iteration}: eval loss = {last_metrics[0]:6.4f}, ' \ + f'ccr = {last_metrics[1]:6.4f}, cwr = {last_metrics[2]:6.4f}, ' \ + f'ted = {last_metrics[3]:6.4f}, ned = {last_metrics[4]:6.4f}, ' \ + f'ted/w = {last_metrics[5]:6.4f}, ' + logging.info(log_str) + names = ['eval_loss', 'ccr', 'cwr', 'ted', 'ned', 'ted/w'] + self._write_metrics(iteration, names, last_metrics) + + # 3. 
Save best model + current = last_metrics[2] + if current is not None and current > self.best: + logging.info(f'Better model found at epoch {epoch}, '\ + f'iter {iteration} with accuracy value: {current:6.4f}.') + self.best = current + self._save(f'{self.bestname}') + + if iteration % self.save_iters == 0 and self.host: + logging.info(f'Save model {self.name}_{epoch}_{iteration}') + filename = f'{self.name}_{epoch}_{iteration}' + self._save(filename) + + checkpoint_path = self.learn.path/'checkpoint.yaml' + if not checkpoint_path.exists(): + open(checkpoint_path, 'w').close() + with open(checkpoint_path, 'r') as file: + checkpoints = yaml.load(file, Loader=yaml.FullLoader) or dict() + checkpoints['all_checkpoints'] = ( + checkpoints.get('all_checkpoints') or list()) + checkpoints['all_checkpoints'].insert(0, filename) + if len(checkpoints['all_checkpoints']) > self.checpoint_keep_num: + removed_checkpoint = checkpoints['all_checkpoints'].pop() + removed_checkpoint = self.learn.path/self.learn.model_dir/f'{removed_checkpoint}.pth' + os.remove(removed_checkpoint) + checkpoints['current_checkpoint'] = filename + with open(checkpoint_path, 'w') as file: + yaml.dump(checkpoints, file) + + + self.timer.toc_running() + + def on_train_end(self, **kwargs): + #self.learn.load(f'{self.bestname}', purge=False) + pass + + def on_epoch_end(self, last_metrics:MetricsList, iteration:int, **kwargs)->None: + self._write_embedding(iteration=iteration) + + +class TextAccuracy(Callback): + _names = ['ccr', 'cwr', 'ted', 'ned', 'ted/w'] + def __init__(self, charset_path, max_length, case_sensitive, model_eval): + self.charset_path = charset_path + self.max_length = max_length + self.case_sensitive = case_sensitive + self.charset = CharsetMapper(charset_path, self.max_length) + self.names = self._names + + self.model_eval = model_eval or 'alignment' + assert self.model_eval in ['vision', 'language', 'alignment'] + + def on_epoch_begin(self, **kwargs): + self.total_num_char = 0. + self.total_num_word = 0. + self.correct_num_char = 0. + self.correct_num_word = 0. + self.total_ed = 0. + self.total_ned = 0. 
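
The checkpoint bookkeeping in `IterationCallback.on_batch_end` above (the `checkpoint.yaml` read-modify-write that keeps the `checpoint_keep_num` most recent saves) boils down to the following standalone sketch; paths and names here are illustrative:

```python
import os
import yaml

def rotate_checkpoints(checkpoint_yaml, new_name, model_dir, keep_num=5):
    """Record `new_name` as current and prune the oldest saved .pth beyond keep_num."""
    checkpoints = {}
    if os.path.exists(checkpoint_yaml):
        with open(checkpoint_yaml) as f:
            checkpoints = yaml.load(f, Loader=yaml.FullLoader) or {}
    all_ckpts = checkpoints.get('all_checkpoints') or []
    all_ckpts.insert(0, new_name)               # newest first
    if len(all_ckpts) > keep_num:
        oldest = all_ckpts.pop()                # drop the oldest entry
        path = os.path.join(model_dir, f'{oldest}.pth')
        if os.path.exists(path):
            os.remove(path)
    checkpoints['all_checkpoints'] = all_ckpts
    checkpoints['current_checkpoint'] = new_name
    with open(checkpoint_yaml, 'w') as f:
        yaml.dump(checkpoints, f)
```
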
+ + def _get_output(self, last_output): + if isinstance(last_output, (tuple, list)): + for res in last_output: + if res['name'] == self.model_eval: output = res + else: output = last_output + return output + + def _update_output(self, last_output, items): + if isinstance(last_output, (tuple, list)): + for res in last_output: + if res['name'] == self.model_eval: res.update(items) + else: last_output.update(items) + return last_output + + def on_batch_end(self, last_output, last_target, **kwargs): + output = self._get_output(last_output) + logits, pt_lengths = output['logits'], output['pt_lengths'] + pt_text, pt_scores, pt_lengths_ = self.decode(logits) + assert (pt_lengths == pt_lengths_).all(), f'{pt_lengths} != {pt_lengths_} for {pt_text}' + last_output = self._update_output(last_output, {'pt_text':pt_text, 'pt_scores':pt_scores}) + + pt_text = [self.charset.trim(t) for t in pt_text] + label = last_target[0] + if label.dim() == 3: label = label.argmax(dim=-1) # one-hot label + gt_text = [self.charset.get_text(l, trim=True) for l in label] + + for i in range(len(gt_text)): + if not self.case_sensitive: + gt_text[i], pt_text[i] = gt_text[i].lower(), pt_text[i].lower() + distance = ed.eval(gt_text[i], pt_text[i]) + self.total_ed += distance + self.total_ned += float(distance) / max(len(gt_text[i]), 1) + + if gt_text[i] == pt_text[i]: + self.correct_num_word += 1 + self.total_num_word += 1 + + for j in range(min(len(gt_text[i]), len(pt_text[i]))): + if gt_text[i][j] == pt_text[i][j]: + self.correct_num_char += 1 + self.total_num_char += len(gt_text[i]) + + return {'last_output': last_output} + + def on_epoch_end(self, last_metrics, **kwargs): + mets = [self.correct_num_char / self.total_num_char, + self.correct_num_word / self.total_num_word, + self.total_ed, + self.total_ned, + self.total_ed / self.total_num_word] + return add_metrics(last_metrics, mets) + + def decode(self, logit): + """ Greed decode """ + # TODO: test running time and decode on GPU + out = F.softmax(logit, dim=2) + pt_text, pt_scores, pt_lengths = [], [], [] + for o in out: + text = self.charset.get_text(o.argmax(dim=1), padding=False, trim=False) + text = text.split(self.charset.null_char)[0] # end at end-token + pt_text.append(text) + pt_scores.append(o.max(dim=1)[0]) + pt_lengths.append(min(len(text) + 1, self.max_length)) # one for end-token + pt_scores = torch.stack(pt_scores) + pt_lengths = pt_scores.new_tensor(pt_lengths, dtype=torch.long) + return pt_text, pt_scores, pt_lengths + + +class TopKTextAccuracy(TextAccuracy): + _names = ['ccr', 'cwr'] + def __init__(self, k, charset_path, max_length, case_sensitive, model_eval): + self.k = k + self.charset_path = charset_path + self.max_length = max_length + self.case_sensitive = case_sensitive + self.charset = CharsetMapper(charset_path, self.max_length) + self.names = self._names + + def on_epoch_begin(self, **kwargs): + self.total_num_char = 0. + self.total_num_word = 0. + self.correct_num_char = 0. + self.correct_num_word = 0. 
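
`TextAccuracy` above accumulates character accuracy (ccr), word accuracy (cwr), and raw/normalized edit distance (ted/ned) per batch. A small worked example of those formulas, using the same `editdistance` package as the callback and toy ground-truth/prediction pairs:

```python
import editdistance as ed

gt_texts = ['hello', 'world']
pt_texts = ['hello', 'w0rld']

correct_char = sum(
    sum(g == p for g, p in zip(gt, pt)) for gt, pt in zip(gt_texts, pt_texts)
)
total_char = sum(len(gt) for gt in gt_texts)
correct_word = sum(gt == pt for gt, pt in zip(gt_texts, pt_texts))
ted = sum(ed.eval(gt, pt) for gt, pt in zip(gt_texts, pt_texts))
ned = sum(ed.eval(gt, pt) / max(len(gt), 1) for gt, pt in zip(gt_texts, pt_texts))

print(correct_char / total_char)     # ccr = 9 / 10 = 0.9
print(correct_word / len(gt_texts))  # cwr = 1 / 2  = 0.5
print(ted, ned)                      # ted = 1, ned = 0.2
```
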
+ + def on_batch_end(self, last_output, last_target, **kwargs): + logits, pt_lengths = last_output['logits'], last_output['pt_lengths'] + gt_labels, gt_lengths = last_target[:] + + for logit, pt_length, label, length in zip(logits, pt_lengths, gt_labels, gt_lengths): + word_flag = True + for i in range(length): + char_logit = logit[i].topk(self.k)[1] + char_label = label[i].argmax(-1) + if char_label in char_logit: self.correct_num_char += 1 + else: word_flag = False + self.total_num_char += 1 + if pt_length == length and word_flag: + self.correct_num_word += 1 + self.total_num_word += 1 + + def on_epoch_end(self, last_metrics, **kwargs): + mets = [self.correct_num_char / self.total_num_char, + self.correct_num_word / self.total_num_word, + 0., 0., 0.] + return add_metrics(last_metrics, mets) + + +class DumpPrediction(LearnerCallback): + + def __init__(self, learn, dataset, charset_path, model_eval, image_only=False, debug=False): + super().__init__(learn=learn) + self.debug = debug + self.model_eval = model_eval or 'alignment' + self.image_only = image_only + assert self.model_eval in ['vision', 'language', 'alignment'] + + self.dataset, self.root = dataset, Path(self.learn.path)/f'{dataset}-{self.model_eval}' + self.attn_root = self.root/'attn' + self.charset = CharsetMapper(charset_path) + if self.root.exists(): shutil.rmtree(self.root) + self.root.mkdir(), self.attn_root.mkdir() + + self.pil = transforms.ToPILImage() + self.tensor = transforms.ToTensor() + size = self.learn.data.img_h, self.learn.data.img_w + self.resize = transforms.Resize(size=size, interpolation=0) + self.c = 0 + + def on_batch_end(self, last_input, last_output, last_target, **kwargs): + if isinstance(last_output, (tuple, list)): + for res in last_output: + if res['name'] == self.model_eval: pt_text = res['pt_text'] + if res['name'] == 'vision': attn_scores = res['attn_scores'].detach().cpu() + if res['name'] == self.model_eval: logits = res['logits'] + else: + pt_text = last_output['pt_text'] + attn_scores = last_output['attn_scores'].detach().cpu() + logits = last_output['logits'] + + images = last_input[0] if isinstance(last_input, (tuple, list)) else last_input + images = images.detach().cpu() + pt_text = [self.charset.trim(t) for t in pt_text] + gt_label = last_target[0] + if gt_label.dim() == 3: gt_label = gt_label.argmax(dim=-1) # one-hot label + gt_text = [self.charset.get_text(l, trim=True) for l in gt_label] + + prediction, false_prediction = [], [] + for gt, pt, image, attn, logit in zip(gt_text, pt_text, images, attn_scores, logits): + prediction.append(f'{gt}\t{pt}\n') + if gt != pt: + if self.debug: + scores = torch.softmax(logit, dim=-1)[:max(len(pt), len(gt)) + 1] + logging.info(f'{self.c} gt {gt}, pt {pt}, logit {logit.shape}, scores {scores.topk(5, dim=-1)}') + false_prediction.append(f'{gt}\t{pt}\n') + + image = self.learn.data.denorm(image) + if not self.image_only: + image_np = np.array(self.pil(image)) + attn_pil = [self.pil(a) for a in attn[:, None, :, :]] + attn = [self.tensor(self.resize(a)).repeat(3, 1, 1) for a in attn_pil] + attn_sum = np.array([np.array(a) for a in attn_pil[:len(pt)]]).sum(axis=0) + blended_sum = self.tensor(blend_mask(image_np, attn_sum)) + blended = [self.tensor(blend_mask(image_np, np.array(a))) for a in attn_pil] + save_image = torch.stack([image] + attn + [blended_sum] + blended) + save_image = save_image.view(2, -1, *save_image.shape[1:]) + save_image = save_image.permute(1, 0, 2, 3, 4).flatten(0, 1) + vutils.save_image(save_image, 
self.attn_root/f'{self.c}_{gt}_{pt}.jpg', + nrow=2, normalize=True, scale_each=True) + else: + self.pil(image).save(self.attn_root/f'{self.c}_{gt}_{pt}.jpg') + self.c += 1 + + with open(self.root/f'{self.model_eval}.txt', 'a') as f: f.writelines(prediction) + with open(self.root/f'{self.model_eval}-false.txt', 'a') as f: f.writelines(false_prediction) diff --git a/captum/__init__.py b/captum/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..24b3fae72787c76ec1e628b24663fe5d8749b8c7 --- /dev/null +++ b/captum/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python3 + +__version__ = "0.5.0" diff --git a/captum/_utils/__init__.py b/captum/_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/captum/_utils/av.py b/captum/_utils/av.py new file mode 100644 index 0000000000000000000000000000000000000000..f3b235dd8d02619b68523d0d187a16cc90219920 --- /dev/null +++ b/captum/_utils/av.py @@ -0,0 +1,499 @@ +#!/usr/bin/env python3 + +import glob +import os +import re +import warnings +from typing import Any, List, Optional, Tuple, Union + +import captum._utils.common as common +import torch +from captum.attr import LayerActivation +from torch import Tensor +from torch.nn import Module +from torch.utils.data import DataLoader, Dataset + + +class AV: + r""" + This class provides functionality to store and load activation vectors + generated for pre-defined neural network layers. + It also provides functionality to check if activation vectors already + exist in the manifold and other auxiliary functions. + + This class also defines a torch `Dataset`, representing Activation Vectors, + which enables lazy access to activation vectors and layer stored in the manifold. + + """ + + r""" + The name of the subfolder in the manifold where the activation vectors + are stored. + """ + + class AVDataset(Dataset): + r""" + This dataset enables access to activation vectors for a given `model` stored + under a pre-defined path. + The iterator of this dataset returns a batch of data tensors. + Additionally, subsets of the model activations can be loaded based on layer + or identifier or num_id (representing batch number in source dataset). + """ + + def __init__( + self, + path: str, + model_id: str, + identifier: Optional[str] = None, + layer: Optional[str] = None, + num_id: Optional[str] = None, + ): + r""" + Loads into memory the list of all activation file paths associated + with the input `model_id`. + + Args: + path (str): The path where the activation vectors + for the `layer` are stored. + model_id (str): The name/version of the model for which layer + activations are being computed and stored. + identifier (str or None): An optional identifier for the layer + activations. Can be used to distinguish between activations for + different training batches. + layer (str or None): The layer for which the activation vectors + are computed. + num_id (str): An optional string representing the batch number for + which the activation vectors are computed + """ + + self.av_filesearch = AV._construct_file_search( + path, model_id, identifier, layer, num_id + ) + + files = glob.glob(self.av_filesearch) + + self.files = AV.sort_files(files) + + def __getitem__(self, idx: int) -> Union[Tensor, Tuple[Tensor, ...]]: + assert idx < len(self.files), "Layer index is out of bounds!" 
+ fl = self.files[idx] + av = torch.load(fl) + return av + + def __len__(self): + return len(self.files) + + AV_DIR_NAME: str = "av" + + def __init__(self) -> None: + pass + + @staticmethod + def _assemble_model_dir(path: str, model_id: str) -> str: + r""" + Returns a directory path for the given source path `path` and `model_id.` + This path is suffixed with the '/' delimiter. + """ + return "/".join([path, AV.AV_DIR_NAME, model_id, ""]) + + @staticmethod + def _assemble_file_path(source_dir: str, identifier: str, layer: str) -> str: + r""" + Returns a full filepath given a source directory, layer, and required + identifier. The source dir is not required to end with a "/" delimiter. + """ + if not source_dir.endswith("/"): + source_dir += "/" + + filepath = os.path.join(source_dir, identifier) + + filepath = os.path.join(filepath, layer) + + return filepath + + @staticmethod + def _construct_file_search( + source_dir: str, + model_id: str, + identifier: Optional[str] = None, + layer: Optional[str] = None, + num_id: Optional[str] = None, + ) -> str: + r""" + Returns a search string that can be used by glob to search `source_dir/model_id` + for the desired layer/identifier pair. Leaving `layer` as None will search ids + over all layers, and leaving `identifier` as none will search layers over all + ids. Leaving both as none will return a path to glob for every activation. + Assumes identifier is always specified when saving activations, so that + activations live at source_dir/model_id/identifier/layer + (and never source_dir/model_id/layer) + """ + + av_filesearch = AV._assemble_model_dir(source_dir, model_id) + + av_filesearch = os.path.join( + av_filesearch, "*" if identifier is None else identifier + ) + + av_filesearch = os.path.join(av_filesearch, "*" if layer is None else layer) + + av_filesearch = os.path.join( + av_filesearch, "*.pt" if num_id is None else "%s.pt" % num_id + ) + + return av_filesearch + + @staticmethod + def exists( + path: str, + model_id: str, + identifier: Optional[str] = None, + layer: Optional[str] = None, + num_id: Optional[str] = None, + ) -> bool: + r""" + Verifies whether the model + layer activations exist + under the path. + + Args: + path (str): The path where the activation vectors + for the `model_id` are stored. + model_id (str): The name/version of the model for which layer activations + are being computed and stored. + identifier (str or None): An optional identifier for the layer activations. + Can be used to distinguish between activations for different + training batches. For example, the id could be a suffix composed of + a train/test label and numerical value, such as "-train-xxxxx". + The numerical id is often a monotonic sequence taken from datetime. + layer (str or None): The layer for which the activation vectors are + computed. + num_id (str): An optional string representing the batch number for which + the activation vectors are computed + + Returns: + exists (bool): Indicating whether the activation vectors for the `layer` + and `identifier` (if provided) and num_id (if provided) were stored + in the manifold. If no `identifier` is provided, will return `True` + if any layer activation exists, whether it has an identifier or + not, and vice-versa. 
+ """ + av_dir = AV._assemble_model_dir(path, model_id) + av_filesearch = AV._construct_file_search( + path, model_id, identifier, layer, num_id + ) + return os.path.exists(av_dir) and len(glob.glob(av_filesearch)) > 0 + + @staticmethod + def save( + path: str, + model_id: str, + identifier: str, + layers: Union[str, List[str]], + act_tensors: Union[Tensor, List[Tensor]], + num_id: str, + ) -> None: + r""" + Saves the activation vectors `act_tensor` for the + `layer` under the manifold `path`. + + Args: + path (str): The path where the activation vectors + for the `layer` are stored. + model_id (str): The name/version of the model for which layer activations + are being computed and stored. + identifier (str or None): An optional identifier for the layer + activations. Can be used to distinguish between activations for + different training batches. For example, the identifier could be + a suffix composed of a train/test label and numerical value, such + as "-src-abc". + Additionally, (abc) could be a unique identifying number. For + example, it is automatically created in + AV.generate_dataset_activations from batch index. + It assumes identifier is same for all layers if a list of + `layers` is provided. + layers (str or List of str): The layer(s) for which the activation vectors + are computed. + act_tensors (Tensor or List of Tensor): A batch of activation vectors. + This must match the dimension of `layers`. + num_id (str): string representing the batch number for which the activation + vectors are computed + """ + if isinstance(layers, str): + layers = [layers] + if isinstance(act_tensors, Tensor): + act_tensors = [act_tensors] + + if len(layers) != len(act_tensors): + raise ValueError("The dimension of `layers` and `act_tensors` must match!") + + av_dir = AV._assemble_model_dir(path, model_id) + + for i, layer in enumerate(layers): + av_save_fl_path = os.path.join( + AV._assemble_file_path(av_dir, identifier, layer), "%s.pt" % num_id + ) + + layer_dir = os.path.dirname(av_save_fl_path) + if not os.path.exists(layer_dir): + os.makedirs(layer_dir) + torch.save(act_tensors[i], av_save_fl_path) + + @staticmethod + def load( + path: str, + model_id: str, + identifier: Optional[str] = None, + layer: Optional[str] = None, + num_id: Optional[str] = None, + ) -> AVDataset: + r""" + Loads lazily the activation vectors for given `model_id` and + `layer` saved under the `path`. + + Args: + path (str): The path where the activation vectors + for the `layer` are stored. + model_id (str): The name/version of the model for which layer activations + are being computed and stored. + identifier (str or None): An optional identifier for the layer + activations. Can be used to distinguish between activations for + different training batches. + layer (str or None): The layer for which the activation vectors + are computed. + num_id (str): An optional string representing the batch number for which + the activation vectors are computed + + Returns: + dataset (AV.AVDataset): AV.AVDataset that allows to iterate + over the activation vectors for given layer, identifier (if + provided), num_id (if provided). Returning an AV.AVDataset as + opposed to a DataLoader constructed from it offers more + flexibility. Raises RuntimeError if activation vectors are not + found. 
+ """ + + av_save_dir = AV._assemble_model_dir(path, model_id) + + if os.path.exists(av_save_dir): + avdataset = AV.AVDataset(path, model_id, identifier, layer, num_id) + return avdataset + else: + raise RuntimeError( + f"Activation vectors for model {model_id} was not found at path {path}" + ) + + @staticmethod + def _manage_loading_layers( + path: str, + model_id: str, + layers: Union[str, List[str]], + load_from_disk: bool = True, + identifier: Optional[str] = None, + num_id: Optional[str] = None, + ) -> List[str]: + r""" + Returns unsaved layers, and deletes saved layers if load_from_disk is False. + + Args: + path (str): The path where the activation vectors + for the `layer` are stored. + model_id (str): The name/version of the model for which layer activations + are being computed and stored. + layers (str or List of str): The layer(s) for which the activation vectors + are computed. + identifier (str or None): An optional identifier for the layer + activations. Can be used to distinguish between activations for + different training batches. + num_id (str): An optional string representing the batch number for which the + activation vectors are computed + + Returns: + List of layer names for which activations should be generated + """ + + layers = [layers] if isinstance(layers, str) else layers + unsaved_layers = [] + + if load_from_disk: + for layer in layers: + if not AV.exists(path, model_id, identifier, layer, num_id): + unsaved_layers.append(layer) + else: + unsaved_layers = layers + warnings.warn( + "Overwriting activations: load_from_disk is set to False. Removing all " + f"activations matching specified parameters {{path: {path}, " + f"model_id: {model_id}, layers: {layers}, identifier: {identifier}}} " + "before generating new activations." + ) + for layer in layers: + files = glob.glob( + AV._construct_file_search(path, model_id, identifier, layer) + ) + for filename in files: + os.remove(filename) + + return unsaved_layers + + @staticmethod + def _compute_and_save_activations( + path: str, + model: Module, + model_id: str, + layers: Union[str, List[str]], + inputs: Union[Tensor, Tuple[Tensor, ...]], + identifier: str, + num_id: str, + additional_forward_args: Any = None, + load_from_disk: bool = True, + ) -> None: + r""" + Computes layer activations for the given inputs and specified `layers` + + Args: + path (str): The path where the activation vectors + for the `layer` are stored. + model (torch.nn.Module): An instance of pytorch model. This model should + define all of its layers as attributes of the model. + model_id (str): The name/version of the model for which layer activations + are being computed and stored. + layers (str or List of str): The layer(s) for which the activation vectors + are computed. + inputs (tensor or tuple of tensors): Batch of examples for + which influential instances are computed. They are passed to the + input `model`. The first dimension in `inputs` tensor or tuple of + tensors corresponds to the batch size. + identifier (str or None): An optional identifier for the layer + activations. Can be used to distinguish between activations for + different training batches. + num_id (str): An required string representing the batch number for which the + activation vectors are computed + additional_forward_args (optional): Additional arguments that will be + passed to `model` after inputs. + Default: None + load_from_disk (bool): Forces function to regenerate activations if False. 
+ Default: True + """ + unsaved_layers = AV._manage_loading_layers( + path, + model_id, + layers, + load_from_disk, + identifier, + num_id, + ) + layer_modules = [ + common._get_module_from_name(model, layer) for layer in unsaved_layers + ] + if len(unsaved_layers) > 0: + layer_act = LayerActivation(model, layer_modules) + new_activations = layer_act.attribute.__wrapped__( # type: ignore + layer_act, inputs, additional_forward_args + ) + AV.save(path, model_id, identifier, unsaved_layers, new_activations, num_id) + + @staticmethod + def _unpack_data(data: Union[Any, Tuple[Any, Any]]) -> Any: + r""" + Helper to extract input from labels when getting items from a Dataset. Assumes + that data is either a single value, or a tuple containing two elements. + The input could itself be a Tuple containing multiple values. If your + dataset returns a Tuple with more than 2 elements, please reformat it such that + all inputs are formatted into a tuple stored at the first position. + """ + if isinstance(data, tuple) or isinstance(data, list): + data = data[0] + return data + + r"""TODO: + 1. Can propagate saving labels along with activations. + 2. Use of additional_forward_args when sourcing from dataset? + """ + + @staticmethod + def generate_dataset_activations( + path: str, + model: Module, + model_id: str, + layers: Union[str, List[str]], + dataloader: DataLoader, + identifier: str = "default", + load_from_disk: bool = True, + return_activations: bool = False, + ) -> Optional[Union[AVDataset, List[AVDataset]]]: + r""" + Computes layer activations for a source dataset and specified `layers`. Assumes + that the dataset returns a single value, or a tuple containing two elements + (see AV._unpack_data). + + Args: + path (str): The path where the activation vectors + for the `layer` are stored. + module (torch.nn.Module): An instance of pytorch model. This model should + define all of its layers as attributes of the model. + model_id (str): The name/version of the model for which layer activations + are being computed and stored. + layers (str or List of str): The layer(s) for which the activation vectors + are computed. + dataloader (torch.utils.data.DataLoader): DataLoader that yields Dataset + for which influential instances are computed. They are passed to + input `model`. + identifier (str or None): An identifier for the layer + activations. Can be used to distinguish between activations for + different training batches. + Default: "default" + load_from_disk (bool): Forces function to regenerate activations if False. + Default: True + return_activations (bool, optional): Whether to return the activations. + Default: False + Returns: If `return_activations == True`, returns a single `AVDataset` if + `layers` is a str, otherwise, a list of `AVDataset`s of the length + of `layers`, where each element corresponds to a layer. In either + case, `AVDataset`'s represent the activations for a single layer, + over the entire `dataloader`. If `return_activations == False`, + does not return anything. 
+ + """ + + unsaved_layers = AV._manage_loading_layers( + path, + model_id, + layers, + load_from_disk, + identifier, + ) + if len(unsaved_layers) > 0: + for i, data in enumerate(dataloader): + AV._compute_and_save_activations( + path, + model, + model_id, + layers, + AV._unpack_data(data), + identifier, + str(i), + ) + + if not return_activations: + return None + if isinstance(layers, str): + return AV.load(path, model_id, identifier, layers) + else: + return [AV.load(path, model_id, identifier, layer) for layer in layers] + + @staticmethod + def sort_files(files: List[str]) -> List[str]: + r""" + Utility for sorting files based on natural sorting instead of the default + lexigraphical sort. + """ + + def split_alphanum(s): + r""" + Splits string into a list of strings and numbers + "z23a" -> ["z", 23, "a"] + """ + + return [int(x) if x.isdigit() else x for x in re.split("([0-9]+)", s)] + + return sorted(files, key=split_alphanum) diff --git a/captum/_utils/common.py b/captum/_utils/common.py new file mode 100644 index 0000000000000000000000000000000000000000..6db0727024e5a2e2d056c0665b7ab010b8c8b0da --- /dev/null +++ b/captum/_utils/common.py @@ -0,0 +1,679 @@ +#!/usr/bin/env python3 +import typing +from enum import Enum +from functools import reduce +from inspect import signature +from typing import Any, Callable, cast, Dict, List, overload, Tuple, Union + +import numpy as np +import torch +from captum._utils.typing import ( + BaselineType, + Literal, + TargetType, + TensorOrTupleOfTensorsGeneric, + TupleOrTensorOrBoolGeneric, +) +from torch import device, Tensor +from torch.nn import Module + + +class ExpansionTypes(Enum): + repeat = 1 + repeat_interleave = 2 + + +def safe_div( + numerator: Tensor, + denom: Union[Tensor, int, float], + default_denom: Union[Tensor, int, float] = 1.0, +) -> Tensor: + r""" + A simple utility function to perform `numerator / denom` + if the statement is undefined => result will be `numerator / default_denorm` + """ + if isinstance(denom, (int, float)): + return numerator / (denom if denom != 0 else default_denom) + + # convert default_denom to tensor if it is float + if not torch.is_tensor(default_denom): + default_denom = torch.tensor( + default_denom, dtype=denom.dtype, device=denom.device + ) + + return numerator / torch.where(denom != 0, denom, default_denom) + + +@typing.overload +def _is_tuple(inputs: Tensor) -> Literal[False]: + ... + + +@typing.overload +def _is_tuple(inputs: Tuple[Tensor, ...]) -> Literal[True]: + ... + + +def _is_tuple(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> bool: + return isinstance(inputs, tuple) + + +def _validate_target(num_samples: int, target: TargetType) -> None: + if isinstance(target, list) or ( + isinstance(target, torch.Tensor) and torch.numel(target) > 1 + ): + assert num_samples == len(target), ( + "The number of samples provied in the" + "input {} does not match with the number of targets. 
{}".format( + num_samples, len(target) + ) + ) + + +def _validate_input( + inputs: Tuple[Tensor, ...], + baselines: Tuple[Union[Tensor, int, float], ...], + draw_baseline_from_distrib: bool = False, +) -> None: + assert len(inputs) == len(baselines), ( + "Input and baseline must have the same " + "dimensions, baseline has {} features whereas input has {}.".format( + len(baselines), len(inputs) + ) + ) + + for input, baseline in zip(inputs, baselines): + if draw_baseline_from_distrib: + assert ( + isinstance(baseline, (int, float)) + or input.shape[1:] == baseline.shape[1:] + ), ( + "The samples in input and baseline batches must have" + " the same shape or the baseline corresponding to the" + " input tensor must be a scalar." + " Found baseline: {} and input: {} ".format(baseline, input) + ) + else: + assert ( + isinstance(baseline, (int, float)) + or input.shape == baseline.shape + or baseline.shape[0] == 1 + ), ( + "Baseline can be provided as a tensor for just one input and" + " broadcasted to the batch or input and baseline must have the" + " same shape or the baseline corresponding to each input tensor" + " must be a scalar. Found baseline: {} and input: {}".format( + baseline, input + ) + ) + + +def _zeros(inputs: Tuple[Tensor, ...]) -> Tuple[int, ...]: + r""" + Takes a tuple of tensors as input and returns a tuple that has the same + length as `inputs` with each element as the integer 0. + """ + return tuple(0 if input.dtype is not torch.bool else False for input in inputs) + + +def _format_baseline( + baselines: BaselineType, inputs: Tuple[Tensor, ...] +) -> Tuple[Union[Tensor, int, float], ...]: + if baselines is None: + return _zeros(inputs) + + if not isinstance(baselines, tuple): + baselines = (baselines,) + + for baseline in baselines: + assert isinstance( + baseline, (torch.Tensor, int, float) + ), "baseline input argument must be either a torch.Tensor or a number \ + however {} detected".format( + type(baseline) + ) + + return baselines + + +@overload +def _format_tensor_into_tuples(inputs: None) -> None: + ... + + +@overload +def _format_tensor_into_tuples( + inputs: Union[Tensor, Tuple[Tensor, ...]] +) -> Tuple[Tensor, ...]: + ... + + +def _format_tensor_into_tuples( + inputs: Union[None, Tensor, Tuple[Tensor, ...]] +) -> Union[None, Tuple[Tensor, ...]]: + if inputs is None: + return None + if not isinstance(inputs, tuple): + assert isinstance( + inputs, torch.Tensor + ), "`inputs` must have type " "torch.Tensor but {} found: ".format(type(inputs)) + inputs = (inputs,) + return inputs + + +def _format_inputs(inputs: Any, unpack_inputs: bool = True) -> Any: + return ( + inputs + if (isinstance(inputs, tuple) or isinstance(inputs, list)) and unpack_inputs + else (inputs,) + ) + + +def _format_float_or_tensor_into_tuples( + inputs: Union[float, Tensor, Tuple[Union[float, Tensor], ...]] +) -> Tuple[Union[float, Tensor], ...]: + if not isinstance(inputs, tuple): + assert isinstance( + inputs, (torch.Tensor, float) + ), "`inputs` must have type float or torch.Tensor but {} found: ".format( + type(inputs) + ) + inputs = (inputs,) + return inputs + + +@overload +def _format_additional_forward_args(additional_forward_args: None) -> None: + ... + + +@overload +def _format_additional_forward_args( + additional_forward_args: Union[Tensor, Tuple] +) -> Tuple: + ... + + +@overload +def _format_additional_forward_args(additional_forward_args: Any) -> Union[None, Tuple]: + ... 
+ + +def _format_additional_forward_args(additional_forward_args: Any) -> Union[None, Tuple]: + if additional_forward_args is not None and not isinstance( + additional_forward_args, tuple + ): + additional_forward_args = (additional_forward_args,) + return additional_forward_args + + +def _expand_additional_forward_args( + additional_forward_args: Any, + n_steps: int, + expansion_type: ExpansionTypes = ExpansionTypes.repeat, +) -> Union[None, Tuple]: + def _expand_tensor_forward_arg( + additional_forward_arg: Tensor, + n_steps: int, + expansion_type: ExpansionTypes = ExpansionTypes.repeat, + ) -> Tensor: + if len(additional_forward_arg.size()) == 0: + return additional_forward_arg + if expansion_type == ExpansionTypes.repeat: + return torch.cat([additional_forward_arg] * n_steps, dim=0) + elif expansion_type == ExpansionTypes.repeat_interleave: + return additional_forward_arg.repeat_interleave(n_steps, dim=0) + else: + raise NotImplementedError( + "Currently only `repeat` and `repeat_interleave`" + " expansion_types are supported" + ) + + if additional_forward_args is None: + return None + + return tuple( + _expand_tensor_forward_arg(additional_forward_arg, n_steps, expansion_type) + if isinstance(additional_forward_arg, torch.Tensor) + else additional_forward_arg + for additional_forward_arg in additional_forward_args + ) + + +def _expand_target( + target: TargetType, + n_steps: int, + expansion_type: ExpansionTypes = ExpansionTypes.repeat, +) -> TargetType: + if isinstance(target, list): + if expansion_type == ExpansionTypes.repeat: + return target * n_steps + elif expansion_type == ExpansionTypes.repeat_interleave: + expanded_target = [] + for i in target: + expanded_target.extend([i] * n_steps) + return cast(Union[List[Tuple[int, ...]], List[int]], expanded_target) + else: + raise NotImplementedError( + "Currently only `repeat` and `repeat_interleave`" + " expansion_types are supported" + ) + + elif isinstance(target, torch.Tensor) and torch.numel(target) > 1: + if expansion_type == ExpansionTypes.repeat: + return torch.cat([target] * n_steps, dim=0) + elif expansion_type == ExpansionTypes.repeat_interleave: + return target.repeat_interleave(n_steps, dim=0) + else: + raise NotImplementedError( + "Currently only `repeat` and `repeat_interleave`" + " expansion_types are supported" + ) + + return target + + +def _expand_feature_mask( + feature_mask: Union[Tensor, Tuple[Tensor, ...]], n_samples: int +): + is_feature_mask_tuple = _is_tuple(feature_mask) + feature_mask = _format_tensor_into_tuples(feature_mask) + feature_mask_new = tuple( + feature_mask_elem.repeat_interleave(n_samples, dim=0) + if feature_mask_elem.size(0) > 1 + else feature_mask_elem + for feature_mask_elem in feature_mask + ) + return _format_output(is_feature_mask_tuple, feature_mask_new) + + +def _expand_and_update_baselines( + inputs: Tuple[Tensor, ...], + n_samples: int, + kwargs: dict, + draw_baseline_from_distrib: bool = False, +): + def get_random_baseline_indices(bsz, baseline): + num_ref_samples = baseline.shape[0] + return np.random.choice(num_ref_samples, n_samples * bsz).tolist() + + # expand baselines to match the sizes of input + if "baselines" not in kwargs: + return + + baselines = kwargs["baselines"] + baselines = _format_baseline(baselines, inputs) + _validate_input( + inputs, baselines, draw_baseline_from_distrib=draw_baseline_from_distrib + ) + + if draw_baseline_from_distrib: + bsz = inputs[0].shape[0] + baselines = tuple( + baseline[get_random_baseline_indices(bsz, baseline)] + if 
isinstance(baseline, torch.Tensor) + else baseline + for baseline in baselines + ) + else: + baselines = tuple( + baseline.repeat_interleave(n_samples, dim=0) + if isinstance(baseline, torch.Tensor) + and baseline.shape[0] == input.shape[0] + and baseline.shape[0] > 1 + else baseline + for input, baseline in zip(inputs, baselines) + ) + # update kwargs with expanded baseline + kwargs["baselines"] = baselines + + +def _expand_and_update_additional_forward_args(n_samples: int, kwargs: dict): + if "additional_forward_args" not in kwargs: + return + additional_forward_args = kwargs["additional_forward_args"] + additional_forward_args = _format_additional_forward_args(additional_forward_args) + if additional_forward_args is None: + return + additional_forward_args = _expand_additional_forward_args( + additional_forward_args, + n_samples, + expansion_type=ExpansionTypes.repeat_interleave, + ) + # update kwargs with expanded baseline + kwargs["additional_forward_args"] = additional_forward_args + + +def _expand_and_update_target(n_samples: int, kwargs: dict): + if "target" not in kwargs: + return + target = kwargs["target"] + target = _expand_target( + target, n_samples, expansion_type=ExpansionTypes.repeat_interleave + ) + # update kwargs with expanded baseline + kwargs["target"] = target + + +def _expand_and_update_feature_mask(n_samples: int, kwargs: dict): + if "feature_mask" not in kwargs: + return + + feature_mask = kwargs["feature_mask"] + if feature_mask is None: + return + + feature_mask = _expand_feature_mask(feature_mask, n_samples) + kwargs["feature_mask"] = feature_mask + + +@typing.overload +def _format_output( + is_inputs_tuple: Literal[True], output: Tuple[Tensor, ...] +) -> Tuple[Tensor, ...]: + ... + + +@typing.overload +def _format_output( + is_inputs_tuple: Literal[False], output: Tuple[Tensor, ...] +) -> Tensor: + ... + + +@typing.overload +def _format_output( + is_inputs_tuple: bool, output: Tuple[Tensor, ...] +) -> Union[Tensor, Tuple[Tensor, ...]]: + ... + + +def _format_output( + is_inputs_tuple: bool, output: Tuple[Tensor, ...] +) -> Union[Tensor, Tuple[Tensor, ...]]: + r""" + In case input is a tensor and the output is returned in form of a + tuple we take the first element of the output's tuple to match the + same shape signatues of the inputs + """ + assert isinstance(output, tuple), "Output must be in shape of a tuple" + assert is_inputs_tuple or len(output) == 1, ( + "The input is a single tensor however the output isn't." + "The number of output tensors is: {}".format(len(output)) + ) + return output if is_inputs_tuple else output[0] + + +@typing.overload +def _format_outputs( + is_multiple_inputs: Literal[False], outputs: List[Tuple[Tensor, ...]] +) -> Union[Tensor, Tuple[Tensor, ...]]: + ... + + +@typing.overload +def _format_outputs( + is_multiple_inputs: Literal[True], outputs: List[Tuple[Tensor, ...]] +) -> List[Union[Tensor, Tuple[Tensor, ...]]]: + ... + + +@typing.overload +def _format_outputs( + is_multiple_inputs: bool, outputs: List[Tuple[Tensor, ...]] +) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]: + ... 
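+
+# Behaviour sketch for these output-formatting helpers (t, t1, t2 stand for
+# arbitrary placeholder tensors):
+#   _format_output(is_inputs_tuple=False, output=(t,))     -> t
+#   _format_output(is_inputs_tuple=True,  output=(t1, t2)) -> (t1, t2)
+#   _format_outputs(is_multiple_inputs=False, outputs=[(t,)])        -> t
+#   _format_outputs(is_multiple_inputs=True,  outputs=[(t,), (t1,)]) -> [t, t1]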
+ + +def _format_outputs( + is_multiple_inputs: bool, outputs: List[Tuple[Tensor, ...]] +) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]: + assert isinstance(outputs, list), "Outputs must be a list" + assert is_multiple_inputs or len(outputs) == 1, ( + "outputs should contain multiple inputs or have a single output" + f"however the number of outputs is: {len(outputs)}" + ) + + return ( + [_format_output(len(output) > 1, output) for output in outputs] + if is_multiple_inputs + else _format_output(len(outputs[0]) > 1, outputs[0]) + ) + + +def _run_forward( + forward_func: Callable, + inputs: Any, + target: TargetType = None, + additional_forward_args: Any = None, +) -> Tensor: + forward_func_args = signature(forward_func).parameters + if len(forward_func_args) == 0: + output = forward_func() + return output if target is None else _select_targets(output, target) + + # make everything a tuple so that it is easy to unpack without + # using if-statements + inputs = _format_inputs(inputs) + additional_forward_args = _format_additional_forward_args(additional_forward_args) + + output = forward_func( + *(*inputs, *additional_forward_args) + if additional_forward_args is not None + else inputs + ) + return _select_targets(output, target) + + +def _select_targets(output: Tensor, target: TargetType) -> Tensor: + if target is None: + return output + + num_examples = output.shape[0] + dims = len(output.shape) + device = output.device + if isinstance(target, (int, tuple)): + return _verify_select_column(output, target) + elif isinstance(target, torch.Tensor): + if torch.numel(target) == 1 and isinstance(target.item(), int): + return _verify_select_column(output, cast(int, target.item())) + elif len(target.shape) == 1 and torch.numel(target) == num_examples: + assert dims == 2, "Output must be 2D to select tensor of targets." + return torch.gather(output, 1, target.reshape(len(output), 1)) + else: + raise AssertionError( + "Tensor target dimension %r is not valid. %r" + % (target.shape, output.shape) + ) + elif isinstance(target, list): + assert len(target) == num_examples, "Target list length does not match output!" + if isinstance(target[0], int): + assert dims == 2, "Output must be 2D to select tensor of targets." + return torch.gather( + output, 1, torch.tensor(target, device=device).reshape(len(output), 1) + ) + elif isinstance(target[0], tuple): + return torch.stack( + [ + output[(i,) + cast(Tuple, targ_elem)] + for i, targ_elem in enumerate(target) + ] + ) + else: + raise AssertionError("Target element type in list is not valid.") + else: + raise AssertionError("Target type %r is not valid." % target) + + +def _contains_slice(target: Union[int, Tuple[Union[int, slice], ...]]) -> bool: + if isinstance(target, tuple): + for index in target: + if isinstance(index, slice): + return True + return False + return isinstance(target, slice) + + +def _verify_select_column( + output: Tensor, target: Union[int, Tuple[Union[int, slice], ...]] +) -> Tensor: + target = (target,) if isinstance(target, int) else target + assert ( + len(target) <= len(output.shape) - 1 + ), "Cannot choose target column with output shape %r." 
% (output.shape,) + return output[(slice(None), *target)] + + +def _verify_select_neuron( + layer_output: Tuple[Tensor, ...], + selector: Union[int, Tuple[Union[int, slice], ...], Callable], +) -> Tensor: + if callable(selector): + return selector(layer_output if len(layer_output) > 1 else layer_output[0]) + + assert len(layer_output) == 1, ( + "Cannot select neuron index from layer with multiple tensors," + "consider providing a neuron selector function instead." + ) + + selected_neurons = _verify_select_column(layer_output[0], selector) + if _contains_slice(selector): + return selected_neurons.reshape(selected_neurons.shape[0], -1).sum(1) + return selected_neurons + + +def _extract_device( + module: Module, + hook_inputs: Union[None, Tensor, Tuple[Tensor, ...]], + hook_outputs: Union[None, Tensor, Tuple[Tensor, ...]], +) -> device: + params = list(module.parameters()) + if ( + (hook_inputs is None or len(hook_inputs) == 0) + and (hook_outputs is None or len(hook_outputs) == 0) + and len(params) == 0 + ): + raise RuntimeError( + """Unable to extract device information for the module + {}. Both inputs and outputs to the forward hook and + `module.parameters()` are empty. + The reason that the inputs to the forward hook are empty + could be due to the fact that the arguments to that + module {} are all named and are passed as named + variables to its forward function. + """.format( + module, module + ) + ) + if hook_inputs is not None and len(hook_inputs) > 0: + return hook_inputs[0].device + if hook_outputs is not None and len(hook_outputs) > 0: + return hook_outputs[0].device + + return params[0].device + + +def _reduce_list( + val_list: List[TupleOrTensorOrBoolGeneric], + red_func: Callable[[List], Any] = torch.cat, +) -> TupleOrTensorOrBoolGeneric: + """ + Applies reduction function to given list. If each element in the list is + a Tensor, applies reduction function to all elements of the list, and returns + the output Tensor / value. If each element is a boolean, apply any method (or). + If each element is a tuple, applies reduction + function to corresponding elements of each tuple in the list, and returns + tuple of reduction function outputs with length matching the length of tuple + val_list[0]. It is assumed that all tuples in the list have the same length + and red_func can be applied to all elements in each corresponding position. + """ + assert len(val_list) > 0, "Cannot reduce empty list!" + if isinstance(val_list[0], torch.Tensor): + first_device = val_list[0].device + return red_func([elem.to(first_device) for elem in val_list]) + elif isinstance(val_list[0], bool): + return any(val_list) + elif isinstance(val_list[0], tuple): + final_out = [] + for i in range(len(val_list[0])): + final_out.append( + _reduce_list([val_elem[i] for val_elem in val_list], red_func) + ) + else: + raise AssertionError( + "Elements to be reduced can only be" + "either Tensors or tuples containing Tensors." + ) + return tuple(final_out) + + +def _sort_key_list( + keys: List[device], device_ids: Union[None, List[int]] = None +) -> List[device]: + """ + Sorts list of torch devices (keys) by given index list, device_ids. If keys + contains only one device, then the list is returned unchanged. If keys + contains a device for which the id is not contained in device_ids, then + an error is returned. This method is used to identify the order of DataParallel + batched devices, given the device ID ordering. 
+ """ + if len(keys) == 1: + return keys + id_dict: Dict[int, device] = {} + assert device_ids is not None, "Device IDs must be provided with multiple devices." + for key in keys: + if key.index in id_dict: + raise AssertionError("Duplicate CUDA Device ID identified in device list.") + id_dict[key.index] = key + + out_list = [ + id_dict[device_id] + for device_id in filter(lambda device_id: device_id in id_dict, device_ids) + ] + + assert len(out_list) == len(keys), "Given Device ID List does not match" + "devices with computed tensors." + + return out_list + + +def _flatten_tensor_or_tuple(inp: TensorOrTupleOfTensorsGeneric) -> Tensor: + if isinstance(inp, Tensor): + return inp.flatten() + return torch.cat([single_inp.flatten() for single_inp in inp]) + + +def _get_module_from_name(model: Module, layer_name: str) -> Any: + r""" + Returns the module (layer) object, given its (string) name + in the model. + + Args: + name (str): Module or nested modules name string in self.model + + Returns: + The module (layer) in self.model. + """ + + return reduce(getattr, layer_name.split("."), model) + + +def _register_backward_hook( + module: Module, hook: Callable, attr_obj: Any +) -> torch.utils.hooks.RemovableHandle: + # Special case for supporting output attributions for neuron methods + # This can be removed after deprecation of neuron output attributions + # for NeuronDeepLift, NeuronDeconvolution, and NeuronGuidedBackprop + # in v0.6.0 + if ( + hasattr(attr_obj, "skip_new_hook_layer") + and attr_obj.skip_new_hook_layer == module + ): + return module.register_backward_hook(hook) + + if torch.__version__ >= "1.9": + # Only supported for torch >= 1.9 + return module.register_full_backward_hook(hook) + else: + # Fallback for previous versions of PyTorch + return module.register_backward_hook(hook) diff --git a/captum/_utils/gradient.py b/captum/_utils/gradient.py new file mode 100644 index 0000000000000000000000000000000000000000..a15157d8d7ba3fe47ce91dd51a26034b89e7b2d7 --- /dev/null +++ b/captum/_utils/gradient.py @@ -0,0 +1,865 @@ +#!/usr/bin/env python3 +import threading +import typing +import warnings +from collections import defaultdict +from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union + +import torch +from captum._utils.common import ( + _reduce_list, + _run_forward, + _sort_key_list, + _verify_select_neuron, +) +from captum._utils.sample_gradient import SampleGradientWrapper +from captum._utils.typing import ( + Literal, + ModuleOrModuleList, + TargetType, + TensorOrTupleOfTensorsGeneric, +) +from torch import device, Tensor +from torch.nn import Module + + +def apply_gradient_requirements( + inputs: Tuple[Tensor, ...], warn: bool = True +) -> List[bool]: + """ + Iterates through tuple on input tensors and sets requires_grad to be true on + each Tensor, and ensures all grads are set to zero. To ensure that the input + is returned to its initial state, a list of flags representing whether or not + a tensor originally required grad is returned. + """ + assert isinstance( + inputs, tuple + ), "Inputs should be wrapped in a tuple prior to preparing for gradients" + grad_required = [] + for index, input in enumerate(inputs): + assert isinstance(input, torch.Tensor), "Given input is not a torch.Tensor" + grad_required.append(input.requires_grad) + inputs_dtype = input.dtype + # Note: torch 1.2 doesn't support is_complex for dtype that's why we check + # on the existance of is_complex method. 
+ if not inputs_dtype.is_floating_point and not ( + hasattr(inputs_dtype, "is_complex") and inputs_dtype.is_complex + ): + if warn: + warnings.warn( + """Input Tensor %d has a dtype of %s. + Gradients cannot be activated + for these data types.""" + % (index, str(inputs_dtype)) + ) + elif not input.requires_grad: + if warn: + warnings.warn( + "Input Tensor %d did not already require gradients, " + "required_grads has been set automatically." % index + ) + input.requires_grad_() + return grad_required + + +def undo_gradient_requirements( + inputs: Tuple[Tensor, ...], grad_required: List[bool] +) -> None: + """ + Iterates through list of tensors, zeros each gradient, and sets required + grad to false if the corresponding index in grad_required is False. + This method is used to undo the effects of prepare_gradient_inputs, making + grads not required for any input tensor that did not initially require + gradients. + """ + + assert isinstance( + inputs, tuple + ), "Inputs should be wrapped in a tuple prior to preparing for gradients." + assert len(inputs) == len( + grad_required + ), "Input tuple length should match gradient mask." + for index, input in enumerate(inputs): + assert isinstance(input, torch.Tensor), "Given input is not a torch.Tensor" + if not grad_required[index]: + input.requires_grad_(False) + + +def compute_gradients( + forward_fn: Callable, + inputs: Union[Tensor, Tuple[Tensor, ...]], + target_ind: TargetType = None, + additional_forward_args: Any = None, +) -> Tuple[Tensor, ...]: + r""" + Computes gradients of the output with respect to inputs for an + arbitrary forward function. + + Args: + + forward_fn: forward function. This can be for example model's + forward function. + input: Input at which gradients are evaluated, + will be passed to forward_fn. + target_ind: Index of the target class for which gradients + must be computed (classification only). + additional_forward_args: Additional input arguments that forward + function requires. It takes an empty tuple (no additional + arguments) if no additional arguments are required + """ + with torch.autograd.set_grad_enabled(True): + # runs forward pass + outputs = _run_forward(forward_fn, inputs, target_ind, additional_forward_args) + assert outputs[0].numel() == 1, ( + "Target not provided when necessary, cannot" + " take gradient with respect to multiple outputs." 
+ ) + # torch.unbind(forward_out) is a list of scalar tensor tuples and + # contains batch_size * #steps elements + grads = torch.autograd.grad(torch.unbind(outputs), inputs) + return grads + + +def _neuron_gradients( + inputs: Union[Tensor, Tuple[Tensor, ...]], + saved_layer: Dict[device, Tuple[Tensor, ...]], + key_list: List[device], + gradient_neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable], +) -> Tuple[Tensor, ...]: + with torch.autograd.set_grad_enabled(True): + gradient_tensors = [] + for key in key_list: + current_out_tensor = _verify_select_neuron( + saved_layer[key], gradient_neuron_selector + ) + gradient_tensors.append( + torch.autograd.grad( + torch.unbind(current_out_tensor) + if current_out_tensor.numel() > 1 + else current_out_tensor, + inputs, + ) + ) + _total_gradients = _reduce_list(gradient_tensors, sum) + return _total_gradients + + +@typing.overload +def _forward_layer_eval( + forward_fn: Callable, + inputs: Union[Tensor, Tuple[Tensor, ...]], + layer: Module, + additional_forward_args: Any = None, + device_ids: Union[None, List[int]] = None, + attribute_to_layer_input: bool = False, + grad_enabled: bool = False, +) -> Tuple[Tensor, ...]: + ... + + +@typing.overload +def _forward_layer_eval( + forward_fn: Callable, + inputs: Union[Tensor, Tuple[Tensor, ...]], + layer: List[Module], + additional_forward_args: Any = None, + device_ids: Union[None, List[int]] = None, + attribute_to_layer_input: bool = False, + grad_enabled: bool = False, +) -> List[Tuple[Tensor, ...]]: + ... + + +def _forward_layer_eval( + forward_fn: Callable, + inputs: Union[Tensor, Tuple[Tensor, ...]], + layer: ModuleOrModuleList, + additional_forward_args: Any = None, + device_ids: Union[None, List[int]] = None, + attribute_to_layer_input: bool = False, + grad_enabled: bool = False, +) -> Union[Tuple[Tensor, ...], List[Tuple[Tensor, ...]]]: + return _forward_layer_eval_with_neuron_grads( + forward_fn, + inputs, + layer, + additional_forward_args=additional_forward_args, + gradient_neuron_selector=None, + grad_enabled=grad_enabled, + device_ids=device_ids, + attribute_to_layer_input=attribute_to_layer_input, + ) + + +@typing.overload +def _forward_layer_distributed_eval( + forward_fn: Callable, + inputs: Any, + layer: ModuleOrModuleList, + target_ind: TargetType = None, + additional_forward_args: Any = None, + attribute_to_layer_input: bool = False, + forward_hook_with_return: Literal[False] = False, + require_layer_grads: bool = False, +) -> Dict[Module, Dict[device, Tuple[Tensor, ...]]]: + ... + + +@typing.overload +def _forward_layer_distributed_eval( + forward_fn: Callable, + inputs: Any, + layer: ModuleOrModuleList, + target_ind: TargetType = None, + additional_forward_args: Any = None, + attribute_to_layer_input: bool = False, + *, + forward_hook_with_return: Literal[True], + require_layer_grads: bool = False, +) -> Tuple[Dict[Module, Dict[device, Tuple[Tensor, ...]]], Tensor]: + ... 
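+
+# Descriptive note for the implementation that follows: when
+# `attribute_to_layer_input` is True a forward *pre*-hook is registered, so the
+# layer's inputs are captured; otherwise a regular forward hook captures the
+# layer's outputs. With `forward_hook_with_return=True` the function also returns
+# the overall forward-pass output alongside the per-device activation dictionary.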
+ + +def _forward_layer_distributed_eval( + forward_fn: Callable, + inputs: Any, + layer: ModuleOrModuleList, + target_ind: TargetType = None, + additional_forward_args: Any = None, + attribute_to_layer_input: bool = False, + forward_hook_with_return: bool = False, + require_layer_grads: bool = False, +) -> Union[ + Tuple[Dict[Module, Dict[device, Tuple[Tensor, ...]]], Tensor], + Dict[Module, Dict[device, Tuple[Tensor, ...]]], +]: + r""" + A helper function that allows to set a hook on model's `layer`, run the forward + pass and returns intermediate layer results, stored in a dictionary, + and optionally also the output of the forward function. The keys in the + dictionary are the device ids and the values are corresponding intermediate layer + results, either the inputs or the outputs of the layer depending on whether we set + `attribute_to_layer_input` to True or False. + This is especially useful when we execute forward pass in a distributed setting, + using `DataParallel`s for example. + """ + saved_layer: Dict[Module, Dict[device, Tuple[Tensor, ...]]] = defaultdict(dict) + lock = threading.Lock() + all_layers: List[Module] = [layer] if isinstance(layer, Module) else layer + + # Set a forward hook on specified module and run forward pass to + # get layer output tensor(s). + # For DataParallel models, each partition adds entry to dictionary + # with key as device and value as corresponding Tensor. + def hook_wrapper(original_module): + def forward_hook(module, inp, out=None): + eval_tsrs = inp if attribute_to_layer_input else out + is_eval_tuple = isinstance(eval_tsrs, tuple) + + if not is_eval_tuple: + eval_tsrs = (eval_tsrs,) + if require_layer_grads: + apply_gradient_requirements(eval_tsrs, warn=False) + with lock: + nonlocal saved_layer + # Note that cloning behaviour of `eval_tsr` is different + # when `forward_hook_with_return` is set to True. This is because + # otherwise `backward()` on the last output layer won't execute. + if forward_hook_with_return: + saved_layer[original_module][eval_tsrs[0].device] = eval_tsrs + eval_tsrs_to_return = tuple( + eval_tsr.clone() for eval_tsr in eval_tsrs + ) + if not is_eval_tuple: + eval_tsrs_to_return = eval_tsrs_to_return[0] + return eval_tsrs_to_return + else: + saved_layer[original_module][eval_tsrs[0].device] = tuple( + eval_tsr.clone() for eval_tsr in eval_tsrs + ) + + return forward_hook + + all_hooks = [] + try: + for single_layer in all_layers: + if attribute_to_layer_input: + all_hooks.append( + single_layer.register_forward_pre_hook(hook_wrapper(single_layer)) + ) + else: + all_hooks.append( + single_layer.register_forward_hook(hook_wrapper(single_layer)) + ) + output = _run_forward( + forward_fn, + inputs, + target=target_ind, + additional_forward_args=additional_forward_args, + ) + finally: + for hook in all_hooks: + hook.remove() + + if len(saved_layer) == 0: + raise AssertionError("Forward hook did not obtain any outputs for given layer") + + if forward_hook_with_return: + return saved_layer, output + return saved_layer + + +def _gather_distributed_tensors( + saved_layer: Dict[device, Tuple[Tensor, ...]], + device_ids: Union[None, List[int]] = None, + key_list: Union[None, List[device]] = None, +) -> Tuple[Tensor, ...]: + r""" + A helper function to concatenate intermediate layer results stored on + different devices in `saved_layer`. `saved_layer` is a dictionary that + contains `device_id` as a key and intermediate layer results (either + the input or the output of the layer) stored on the device corresponding to + the key. 
+ `key_list` is a list of devices in appropriate ordering for concatenation + and if not provided, keys are sorted based on device ids. + + If only one key exists (standard model), key list simply has one element. + """ + if key_list is None: + key_list = _sort_key_list(list(saved_layer.keys()), device_ids) + return _reduce_list([saved_layer[device_id] for device_id in key_list]) + + +def _extract_device_ids( + forward_fn: Callable, + saved_layer: Dict[Module, Dict[device, Tuple[Tensor, ...]]], + device_ids: Union[None, List[int]], +) -> Union[None, List[int]]: + r""" + A helper function to extract device_ids from `forward_function` in case it is + provided as part of a `DataParallel` model or if is accessible from + `forward_fn`. + In case input device_ids is not None, this function returns that value. + """ + # Multiple devices / keys implies a DataParallel model, so we look for + # device IDs if given or available from forward function + # (DataParallel model object). + if ( + max(len(saved_layer[single_layer]) for single_layer in saved_layer) > 1 + and device_ids is None + ): + if ( + hasattr(forward_fn, "device_ids") + and cast(Any, forward_fn).device_ids is not None + ): + device_ids = cast(Any, forward_fn).device_ids + else: + raise AssertionError( + "Layer tensors are saved on multiple devices, however unable to access" + " device ID list from the `forward_fn`. Device ID list must be" + " accessible from `forward_fn`. For example, they can be retrieved" + " if `forward_fn` is a model of type `DataParallel`. It is used" + " for identifying device batch ordering." + ) + return device_ids + + +@typing.overload +def _forward_layer_eval_with_neuron_grads( + forward_fn: Callable, + inputs: Union[Tensor, Tuple[Tensor, ...]], + layer: Module, + additional_forward_args: Any = None, + *, + gradient_neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable], + grad_enabled: bool = False, + device_ids: Union[None, List[int]] = None, + attribute_to_layer_input: bool = False, +) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]: + ... + + +@typing.overload +def _forward_layer_eval_with_neuron_grads( + forward_fn: Callable, + inputs: Union[Tensor, Tuple[Tensor, ...]], + layer: Module, + additional_forward_args: Any = None, + gradient_neuron_selector: None = None, + grad_enabled: bool = False, + device_ids: Union[None, List[int]] = None, + attribute_to_layer_input: bool = False, +) -> Tuple[Tensor, ...]: + ... + + +@typing.overload +def _forward_layer_eval_with_neuron_grads( + forward_fn: Callable, + inputs: Union[Tensor, Tuple[Tensor, ...]], + layer: List[Module], + additional_forward_args: Any = None, + gradient_neuron_selector: None = None, + grad_enabled: bool = False, + device_ids: Union[None, List[int]] = None, + attribute_to_layer_input: bool = False, +) -> List[Tuple[Tensor, ...]]: + ... + + +def _forward_layer_eval_with_neuron_grads( + forward_fn: Callable, + inputs: Union[Tensor, Tuple[Tensor, ...]], + layer: ModuleOrModuleList, + additional_forward_args: Any = None, + gradient_neuron_selector: Union[ + None, int, Tuple[Union[int, slice], ...], Callable + ] = None, + grad_enabled: bool = False, + device_ids: Union[None, List[int]] = None, + attribute_to_layer_input: bool = False, +) -> Union[ + Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]], + Tuple[Tensor, ...], + List[Tuple[Tensor, ...]], +]: + """ + This method computes forward evaluation for a particular layer using a + forward hook. 
If a gradient_neuron_selector is provided, then gradients with + respect to that neuron in the layer output are also returned. + + These functionalities are combined due to the behavior of DataParallel models + with hooks, in which hooks are executed once per device. We need to internally + combine the separated tensors from devices by concatenating based on device_ids. + Any necessary gradients must be taken with respect to each independent batched + tensor, so the gradients are computed and combined appropriately. + + More information regarding the behavior of forward hooks with DataParallel models + can be found in the PyTorch data parallel documentation. We maintain the separate + evals in a dictionary protected by a lock, analogous to the gather implementation + for the core PyTorch DataParallel implementation. + """ + grad_enabled = True if gradient_neuron_selector is not None else grad_enabled + + with torch.autograd.set_grad_enabled(grad_enabled): + saved_layer = _forward_layer_distributed_eval( + forward_fn, + inputs, + layer, + additional_forward_args=additional_forward_args, + attribute_to_layer_input=attribute_to_layer_input, + ) + device_ids = _extract_device_ids(forward_fn, saved_layer, device_ids) + # Identifies correct device ordering based on device ids. + # key_list is a list of devices in appropriate ordering for concatenation. + # If only one key exists (standard model), key list simply has one element. + key_list = _sort_key_list(list(next(iter(saved_layer.values())).keys()), device_ids) + if gradient_neuron_selector is not None: + assert isinstance( + layer, Module + ), "Cannot compute neuron gradients for multiple layers simultaneously!" + inp_grads = _neuron_gradients( + inputs, saved_layer[layer], key_list, gradient_neuron_selector + ) + return ( + _gather_distributed_tensors(saved_layer[layer], key_list=key_list), + inp_grads, + ) + else: + if isinstance(layer, Module): + return _gather_distributed_tensors(saved_layer[layer], key_list=key_list) + else: + return [ + _gather_distributed_tensors(saved_layer[curr_layer], key_list=key_list) + for curr_layer in layer + ] + + +@typing.overload +def compute_layer_gradients_and_eval( + forward_fn: Callable, + layer: Module, + inputs: Union[Tensor, Tuple[Tensor, ...]], + target_ind: TargetType = None, + additional_forward_args: Any = None, + *, + gradient_neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable], + device_ids: Union[None, List[int]] = None, + attribute_to_layer_input: bool = False, + output_fn: Union[None, Callable] = None, +) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...], Tuple[Tensor, ...]]: + ... + + +@typing.overload +def compute_layer_gradients_and_eval( + forward_fn: Callable, + layer: List[Module], + inputs: Union[Tensor, Tuple[Tensor, ...]], + target_ind: TargetType = None, + additional_forward_args: Any = None, + gradient_neuron_selector: None = None, + device_ids: Union[None, List[int]] = None, + attribute_to_layer_input: bool = False, + output_fn: Union[None, Callable] = None, +) -> Tuple[List[Tuple[Tensor, ...]], List[Tuple[Tensor, ...]]]: + ... 
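+
+# Overload summary (descriptive comment): with a neuron selector the call returns
+# (layer_grads, layer_evals, input_grads); with a list of layers it returns
+# per-layer lists of gradients and evaluations; with a single layer and no neuron
+# selector it returns just (layer_grads, layer_evals).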
+ + +@typing.overload +def compute_layer_gradients_and_eval( + forward_fn: Callable, + layer: Module, + inputs: Union[Tensor, Tuple[Tensor, ...]], + target_ind: TargetType = None, + additional_forward_args: Any = None, + gradient_neuron_selector: None = None, + device_ids: Union[None, List[int]] = None, + attribute_to_layer_input: bool = False, + output_fn: Union[None, Callable] = None, +) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]: + ... + + +def compute_layer_gradients_and_eval( + forward_fn: Callable, + layer: ModuleOrModuleList, + inputs: Union[Tensor, Tuple[Tensor, ...]], + target_ind: TargetType = None, + additional_forward_args: Any = None, + gradient_neuron_selector: Union[ + None, int, Tuple[Union[int, slice], ...], Callable + ] = None, + device_ids: Union[None, List[int]] = None, + attribute_to_layer_input: bool = False, + output_fn: Union[None, Callable] = None, +) -> Union[ + Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]], + Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...], Tuple[Tensor, ...]], + Tuple[List[Tuple[Tensor, ...]], List[Tuple[Tensor, ...]]], +]: + r""" + Computes gradients of the output with respect to a given layer as well + as the output evaluation of the layer for an arbitrary forward function + and given input. + + For data parallel models, hooks are executed once per device ,so we + need to internally combine the separated tensors from devices by + concatenating based on device_ids. Any necessary gradients must be taken + with respect to each independent batched tensor, so the gradients are + computed and combined appropriately. + + More information regarding the behavior of forward hooks with DataParallel + models can be found in the PyTorch data parallel documentation. We maintain + the separate inputs in a dictionary protected by a lock, analogous to the + gather implementation for the core PyTorch DataParallel implementation. + + NOTE: To properly handle inplace operations, a clone of the layer output + is stored. This structure inhibits execution of a backward hook on the last + module for the layer output when computing the gradient with respect to + the input, since we store an intermediate clone, as + opposed to the true module output. If backward module hooks are necessary + for the final module when computing input gradients, utilize + _forward_layer_eval_with_neuron_grads instead. + + Args: + + forward_fn: forward function. This can be for example model's + forward function. + layer: Layer for which gradients / output will be evaluated. + inputs: Input at which gradients are evaluated, + will be passed to forward_fn. + target_ind: Index of the target class for which gradients + must be computed (classification only). + output_fn: An optional function that is applied to the layer inputs or + outputs depending whether the `attribute_to_layer_input` is + set to `True` or `False` + args: Additional input arguments that forward function requires. + It takes an empty tuple (no additional arguments) if no + additional arguments are required + + + Returns: + 2-element tuple of **gradients**, **evals**: + - **gradients**: + Gradients of output with respect to target layer output. + - **evals**: + Target layer output for given input. + """ + with torch.autograd.set_grad_enabled(True): + # saved_layer is a dictionary mapping device to a tuple of + # layer evaluations on that device. 
+ saved_layer, output = _forward_layer_distributed_eval( + forward_fn, + inputs, + layer, + target_ind=target_ind, + additional_forward_args=additional_forward_args, + attribute_to_layer_input=attribute_to_layer_input, + forward_hook_with_return=True, + require_layer_grads=True, + ) + assert output[0].numel() == 1, ( + "Target not provided when necessary, cannot" + " take gradient with respect to multiple outputs." + ) + + device_ids = _extract_device_ids(forward_fn, saved_layer, device_ids) + + # Identifies correct device ordering based on device ids. + # key_list is a list of devices in appropriate ordering for concatenation. + # If only one key exists (standard model), key list simply has one element. + key_list = _sort_key_list( + list(next(iter(saved_layer.values())).keys()), device_ids + ) + all_outputs: Union[Tuple[Tensor, ...], List[Tuple[Tensor, ...]]] + if isinstance(layer, Module): + all_outputs = _reduce_list( + [ + saved_layer[layer][device_id] + if output_fn is None + else output_fn(saved_layer[layer][device_id]) + for device_id in key_list + ] + ) + else: + all_outputs = [ + _reduce_list( + [ + saved_layer[single_layer][device_id] + if output_fn is None + else output_fn(saved_layer[single_layer][device_id]) + for device_id in key_list + ] + ) + for single_layer in layer + ] + all_layers: List[Module] = [layer] if isinstance(layer, Module) else layer + grad_inputs = tuple( + layer_tensor + for single_layer in all_layers + for device_id in key_list + for layer_tensor in saved_layer[single_layer][device_id] + ) + saved_grads = torch.autograd.grad(torch.unbind(output), grad_inputs) + + offset = 0 + all_grads: List[Tuple[Tensor, ...]] = [] + for single_layer in all_layers: + num_tensors = len(next(iter(saved_layer[single_layer].values()))) + curr_saved_grads = [ + saved_grads[i : i + num_tensors] + for i in range( + offset, offset + len(key_list) * num_tensors, num_tensors + ) + ] + offset += len(key_list) * num_tensors + if output_fn is not None: + curr_saved_grads = [ + output_fn(curr_saved_grad) for curr_saved_grad in curr_saved_grads + ] + + all_grads.append(_reduce_list(curr_saved_grads)) + + layer_grads: Union[Tuple[Tensor, ...], List[Tuple[Tensor, ...]]] + layer_grads = all_grads + if isinstance(layer, Module): + layer_grads = all_grads[0] + + if gradient_neuron_selector is not None: + assert isinstance( + layer, Module + ), "Cannot compute neuron gradients for multiple layers simultaneously!" 
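+            # inp_grads are the gradients of the selected neuron (taken from the
+            # saved layer tensors) with respect to the original inputs, gathered
+            # across devices in key_list order.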
+ inp_grads = _neuron_gradients( + inputs, saved_layer[layer], key_list, gradient_neuron_selector + ) + return ( + cast(Tuple[Tensor, ...], layer_grads), + cast(Tuple[Tensor, ...], all_outputs), + inp_grads, + ) + return layer_grads, all_outputs # type: ignore + + +def construct_neuron_grad_fn( + layer: Module, + neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable], + device_ids: Union[None, List[int]] = None, + attribute_to_neuron_input: bool = False, +) -> Callable: + def grad_fn( + forward_fn: Callable, + inputs: TensorOrTupleOfTensorsGeneric, + target_ind: TargetType = None, + additional_forward_args: Any = None, + ) -> Tuple[Tensor, ...]: + _, grads = _forward_layer_eval_with_neuron_grads( + forward_fn, + inputs, + layer, + additional_forward_args, + gradient_neuron_selector=neuron_selector, + device_ids=device_ids, + attribute_to_layer_input=attribute_to_neuron_input, + ) + return grads + + return grad_fn + + +def _compute_jacobian_wrt_params( + model: Module, + inputs: Tuple[Any, ...], + labels: Optional[Tensor] = None, + loss_fn: Optional[Union[Module, Callable]] = None, +) -> Tuple[Tensor, ...]: + r""" + Computes the Jacobian of a batch of test examples given a model, and optional + loss function and target labels. This method is equivalent to calculating the + gradient for every individual example in the minibatch. + + Args: + model (torch.nn.Module): The trainable model providing the forward pass + inputs (tuple of Any): The minibatch for which the forward pass is computed. + It is unpacked before passing to `model`, so it must be a tuple. The + individual elements of `inputs` can be anything. + labels (Tensor or None): Labels for input if computing a loss function. + loss_fn (torch.nn.Module or Callable or None): The loss function. If a library + defined loss function is provided, it would be expected to be a + torch.nn.Module. If a custom loss is provided, it can be either type, + but must behave as a library loss function would if `reduction='none'`. + + Returns: + grads (Tuple of Tensor): Returns the Jacobian for the minibatch as a + tuple of gradients corresponding to the tuple of trainable parameters + returned by `model.parameters()`. Each object grads[i] references to the + gradients for the parameters in the i-th trainable layer of the model. + Each grads[i] object is a tensor with the gradients for the `inputs` + batch. For example, grads[i][j] would reference the gradients for the + parameters of the i-th layer, for the j-th member of the minibatch. + """ + with torch.autograd.set_grad_enabled(True): + out = model(*inputs) + assert out.dim() != 0, "Please ensure model output has at least one dimension." + + if labels is not None and loss_fn is not None: + loss = loss_fn(out, labels) + if hasattr(loss_fn, "reduction"): + msg0 = "Please ensure loss_fn.reduction is set to `none`" + assert loss_fn.reduction == "none", msg0 # type: ignore + else: + msg1 = ( + "Loss function is applying a reduction. Please ensure " + f"Output shape: {out.shape} and Loss shape: {loss.shape} " + "are matching." 
+ ) + assert loss.dim() != 0, msg1 + assert out.shape[0] == loss.shape[0], msg1 + out = loss + + grads_list = [ + torch.autograd.grad( + outputs=out[i], + inputs=model.parameters(), # type: ignore + grad_outputs=torch.ones_like(out[i]), + retain_graph=True, + ) + for i in range(out.shape[0]) + ] + + grads = tuple([torch.stack(x) for x in zip(*grads_list)]) + + return tuple(grads) + + +def _compute_jacobian_wrt_params_with_sample_wise_trick( + model: Module, + inputs: Tuple[Any, ...], + labels: Optional[Tensor] = None, + loss_fn: Optional[Union[Module, Callable]] = None, + reduction_type: Optional[str] = "sum", +) -> Tuple[Any, ...]: + r""" + Computes the Jacobian of a batch of test examples given a model, and optional + loss function and target labels. This method uses sample-wise gradients per + batch trick to fully vectorize the Jacobian calculation. Currently, only + linear and conv2d layers are supported. + + User must `add_hooks(model)` before calling this function. + + Args: + model (torch.nn.Module): The trainable model providing the forward pass + inputs (tuple of Any): The minibatch for which the forward pass is computed. + It is unpacked before passing to `model`, so it must be a tuple. The + individual elements of `inputs` can be anything. + labels (Tensor or None): Labels for input if computing a loss function. + loss_fn (torch.nn.Module or Callable or None): The loss function. If a library + defined loss function is provided, it would be expected to be a + torch.nn.Module. If a custom loss is provided, it can be either type, + but must behave as a library loss function would if `reduction='sum'` or + `reduction='mean'`. + reduction_type (str): The type of reduction applied. If a loss_fn is passed, + this should match `loss_fn.reduction`. Else if gradients are being + computed on direct model outputs (scores), then 'sum' should be used. + Defaults to 'sum'. + + Returns: + grads (Tuple of Tensor): Returns the Jacobian for the minibatch as a + tuple of gradients corresponding to the tuple of trainable parameters + returned by `model.parameters()`. Each object grads[i] references to the + gradients for the parameters in the i-th trainable layer of the model. + Each grads[i] object is a tensor with the gradients for the `inputs` + batch. For example, grads[i][j] would reference the gradients for the + parameters of the i-th layer, for the j-th member of the minibatch. + """ + with torch.autograd.set_grad_enabled(True): + sample_grad_wrapper = SampleGradientWrapper(model) + try: + sample_grad_wrapper.add_hooks() + + out = model(*inputs) + assert ( + out.dim() != 0 + ), "Please ensure model output has at least one dimension." + + if labels is not None and loss_fn is not None: + loss = loss_fn(out, labels) + # TODO: allow loss_fn to be Callable + if isinstance(loss_fn, Module) and hasattr(loss_fn, "reduction"): + msg0 = ( + "Please ensure that loss_fn.reduction is set to `sum` or `mean`" + ) + + assert loss_fn.reduction != "none", msg0 + msg1 = ( + f"loss_fn.reduction ({loss_fn.reduction}) does not match" + f"reduction type ({reduction_type}). Please ensure they are" + " matching." + ) + assert loss_fn.reduction == reduction_type, msg1 + msg2 = ( + "Please ensure custom loss function is applying either a " + "sum or mean reduction." + ) + assert out.shape != loss.shape, msg2 + + if reduction_type != "sum" and reduction_type != "mean": + raise ValueError( + f"{reduction_type} is not a valid value for reduction_type. " + "Must be either 'sum' or 'mean'." 
+ ) + out = loss + + sample_grad_wrapper.compute_param_sample_gradients( + out, loss_mode=reduction_type + ) + + grads = tuple( + param.sample_grad # type: ignore + for param in model.parameters() + if hasattr(param, "sample_grad") + ) + finally: + sample_grad_wrapper.remove_hooks() + + return grads diff --git a/captum/_utils/models/__init__.py b/captum/_utils/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5ebcee2e47a19d79dcc73a897d56754f3b78d35d --- /dev/null +++ b/captum/_utils/models/__init__.py @@ -0,0 +1,25 @@ +from captum._utils.models.linear_model import ( + LinearModel, + SGDLasso, + SGDLinearModel, + SGDLinearRegression, + SGDRidge, + SkLearnLasso, + SkLearnLinearModel, + SkLearnLinearRegression, + SkLearnRidge, +) +from captum._utils.models.model import Model + +__all__ = [ + "Model", + "LinearModel", + "SGDLinearModel", + "SGDLasso", + "SGDRidge", + "SGDLinearRegression", + "SkLearnLinearModel", + "SkLearnLasso", + "SkLearnRidge", + "SkLearnLinearRegression", +] diff --git a/captum/_utils/models/linear_model/__init__.py b/captum/_utils/models/linear_model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4f50d2146a9bbda32ff6f0779bf80b141823dec --- /dev/null +++ b/captum/_utils/models/linear_model/__init__.py @@ -0,0 +1,23 @@ +from captum._utils.models.linear_model.model import ( + LinearModel, + SGDLasso, + SGDLinearModel, + SGDLinearRegression, + SGDRidge, + SkLearnLasso, + SkLearnLinearModel, + SkLearnLinearRegression, + SkLearnRidge, +) + +__all__ = [ + "LinearModel", + "SGDLinearModel", + "SGDLasso", + "SGDRidge", + "SGDLinearRegression", + "SkLearnLinearModel", + "SkLearnLasso", + "SkLearnRidge", + "SkLearnLinearRegression", +] diff --git a/captum/_utils/models/linear_model/model.py b/captum/_utils/models/linear_model/model.py new file mode 100644 index 0000000000000000000000000000000000000000..bfffdbf38a4c5193c54a476bd00f44e47dc164c9 --- /dev/null +++ b/captum/_utils/models/linear_model/model.py @@ -0,0 +1,341 @@ +from typing import Callable, cast, List, Optional + +import torch.nn as nn +from captum._utils.models.model import Model +from torch import Tensor +from torch.utils.data import DataLoader + + +class LinearModel(nn.Module, Model): + SUPPORTED_NORMS: List[Optional[str]] = [None, "batch_norm", "layer_norm"] + + def __init__(self, train_fn: Callable, **kwargs) -> None: + r""" + Constructs a linear model with a training function and additional + construction arguments that will be sent to + `self._construct_model_params` after a `self.fit` is called. Please note + that this assumes the `self.train_fn` will call + `self._construct_model_params`. + + Please note that this is an experimental feature. + + Args: + train_fn (callable) + The function to train with. See + `captum._utils.models.linear_model.train.sgd_train_linear_model` + and + `captum._utils.models.linear_model.train.sklearn_train_linear_model` + for examples + kwargs + Any additional keyword arguments to send to + `self._construct_model_params` once a `self.fit` is called. 
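+
+        Example (a minimal sketch; assumes a DataLoader `train_loader` yielding
+        (x, y) batches of flat feature tensors)::
+
+            >>> from captum._utils.models.linear_model.train import sgd_train_linear_model
+            >>> model = LinearModel(train_fn=sgd_train_linear_model)
+            >>> stats = model.fit(train_loader)  # returns training statistics
+            >>> weights = model.representation()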
+ """ + super().__init__() + + self.norm: Optional[nn.Module] = None + self.linear: Optional[nn.Linear] = None + self.train_fn = train_fn + self.construct_kwargs = kwargs + + def _construct_model_params( + self, + in_features: Optional[int] = None, + out_features: Optional[int] = None, + norm_type: Optional[str] = None, + affine_norm: bool = False, + bias: bool = True, + weight_values: Optional[Tensor] = None, + bias_value: Optional[Tensor] = None, + classes: Optional[Tensor] = None, + ): + r""" + Lazily initializes a linear model. This will be called for you in a + train method. + + Args: + in_features (int): + The number of input features + output_features (int): + The number of output features. + norm_type (str, optional): + The type of normalization that can occur. Please assign this + to one of `PyTorchLinearModel.SUPPORTED_NORMS`. + affine_norm (bool): + Whether or not to learn an affine transformation of the + normalization parameters used. + bias (bool): + Whether to add a bias term. Not needed if normalized input. + weight_values (tensor, optional): + The values to initialize the linear model with. This must be a + 1D or 2D tensor, and of the form `(num_outputs, num_features)` or + `(num_features,)`. Additionally, if this is provided you need not + to provide `in_features` or `out_features`. + bias_value (tensor, optional): + The bias value to initialize the model with. + classes (tensor, optional): + The list of prediction classes supported by the model in case it + performs classificaton. In case of regression it is set to None. + Default: None + """ + if norm_type not in LinearModel.SUPPORTED_NORMS: + raise ValueError( + f"{norm_type} not supported. Please use {LinearModel.SUPPORTED_NORMS}" + ) + + if weight_values is not None: + in_features = weight_values.shape[-1] + out_features = ( + 1 if len(weight_values.shape) == 1 else weight_values.shape[0] + ) + + if in_features is None or out_features is None: + raise ValueError( + "Please provide `in_features` and `out_features` or `weight_values`" + ) + + if norm_type == "batch_norm": + self.norm = nn.BatchNorm1d(in_features, eps=1e-8, affine=affine_norm) + elif norm_type == "layer_norm": + self.norm = nn.LayerNorm( + in_features, eps=1e-8, elementwise_affine=affine_norm + ) + else: + self.norm = None + + self.linear = nn.Linear(in_features, out_features, bias=bias) + + if weight_values is not None: + self.linear.weight.data = weight_values + + if bias_value is not None: + if not bias: + raise ValueError("`bias_value` is not None and bias is False") + + self.linear.bias.data = bias_value + + if classes is not None: + self.linear.classes = classes + + def fit(self, train_data: DataLoader, **kwargs): + r""" + Calls `self.train_fn` + """ + return self.train_fn( + self, + dataloader=train_data, + construct_kwargs=self.construct_kwargs, + **kwargs, + ) + + def forward(self, x: Tensor) -> Tensor: + assert self.linear is not None + if self.norm is not None: + x = self.norm(x) + return self.linear(x) + + def representation(self) -> Tensor: + r""" + Returns a tensor which describes the hyper-plane input space. This does + not include the bias. 
For bias/intercept, please use `self.bias` + """ + assert self.linear is not None + return self.linear.weight.detach() + + def bias(self) -> Optional[Tensor]: + r""" + Returns the bias of the linear model + """ + if self.linear is None or self.linear.bias is None: + return None + return self.linear.bias.detach() + + def classes(self) -> Optional[Tensor]: + if self.linear is None or self.linear.classes is None: + return None + return cast(Tensor, self.linear.classes).detach() + + +class SGDLinearModel(LinearModel): + def __init__(self, **kwargs) -> None: + r""" + Factory class. Construct a a `LinearModel` with the + `sgd_train_linear_model` as the train method + + Args: + kwargs + Arguments send to `self._construct_model_params` after + `self.fit` is called. Please refer to that method for parameter + documentation. + """ + # avoid cycles + from captum._utils.models.linear_model.train import sgd_train_linear_model + + super().__init__(train_fn=sgd_train_linear_model, **kwargs) + + +class SGDLasso(SGDLinearModel): + def __init__(self, **kwargs) -> None: + r""" + Factory class to train a `LinearModel` with SGD + (`sgd_train_linear_model`) whilst setting appropriate parameters to + optimize for ridge regression loss. This optimizes L2 loss + alpha * L1 + regularization. + + Please note that with SGD it is not guaranteed that weights will + converge to 0. + """ + super().__init__(**kwargs) + + def fit(self, train_data: DataLoader, **kwargs): + # avoid cycles + from captum._utils.models.linear_model.train import l2_loss + + return super().fit(train_data=train_data, loss_fn=l2_loss, reg_term=1, **kwargs) + + +class SGDRidge(SGDLinearModel): + def __init__(self, **kwargs) -> None: + r""" + Factory class to train a `LinearModel` with SGD + (`sgd_train_linear_model`) whilst setting appropriate parameters to + optimize for ridge regression loss. This optimizes L2 loss + alpha * + L2 regularization. + """ + super().__init__(**kwargs) + + def fit(self, train_data: DataLoader, **kwargs): + # avoid cycles + from captum._utils.models.linear_model.train import l2_loss + + return super().fit(train_data=train_data, loss_fn=l2_loss, reg_term=2, **kwargs) + + +class SGDLinearRegression(SGDLinearModel): + def __init__(self, **kwargs) -> None: + r""" + Factory class to train a `LinearModel` with SGD + (`sgd_train_linear_model`). For linear regression this assigns the loss + to L2 and no regularization. + """ + super().__init__(**kwargs) + + def fit(self, train_data: DataLoader, **kwargs): + # avoid cycles + from captum._utils.models.linear_model.train import l2_loss + + return super().fit( + train_data=train_data, loss_fn=l2_loss, reg_term=None, **kwargs + ) + + +class SkLearnLinearModel(LinearModel): + def __init__(self, sklearn_module: str, **kwargs) -> None: + r""" + Factory class to construct a `LinearModel` with sklearn training method. + + Please note that this assumes: + + 0. You have sklearn and numpy installed + 1. The dataset can fit into memory + + SkLearn support does introduce some slight overhead as we convert the + tensors to numpy and then convert the resulting trained model to a + `LinearModel` object. However, this conversion should be negligible. + + Args: + sklearn_module + The module under sklearn to construct and use for training, e.g. + use "svm.LinearSVC" for an SVM or "linear_model.Lasso" for Lasso. + + There are factory classes defined for you for common use cases, + such as `SkLearnLasso`. 
+ kwargs + The kwargs to pass to the construction of the sklearn model + """ + # avoid cycles + from captum._utils.models.linear_model.train import sklearn_train_linear_model + + super().__init__(train_fn=sklearn_train_linear_model, **kwargs) + + self.sklearn_module = sklearn_module + + def fit(self, train_data: DataLoader, **kwargs): + r""" + Args: + train_data + Train data to use + kwargs + Arguments to feed to `.fit` method for sklearn + """ + return super().fit( + train_data=train_data, sklearn_trainer=self.sklearn_module, **kwargs + ) + + +class SkLearnLasso(SkLearnLinearModel): + def __init__(self, **kwargs) -> None: + r""" + Factory class. Trains a `LinearModel` model with + `sklearn.linear_model.Lasso`. You will need sklearn version >= 0.23 to + support sample weights. + """ + super().__init__(sklearn_module="linear_model.Lasso", **kwargs) + + def fit(self, train_data: DataLoader, **kwargs): + return super().fit(train_data=train_data, **kwargs) + + +class SkLearnRidge(SkLearnLinearModel): + def __init__(self, **kwargs) -> None: + r""" + Factory class. Trains a model with `sklearn.linear_model.Ridge`. + + Any arguments provided to the sklearn constructor can be provided + as kwargs here. + """ + super().__init__(sklearn_module="linear_model.Ridge", **kwargs) + + def fit(self, train_data: DataLoader, **kwargs): + return super().fit(train_data=train_data, **kwargs) + + +class SkLearnLinearRegression(SkLearnLinearModel): + def __init__(self, **kwargs) -> None: + r""" + Factory class. Trains a model with `sklearn.linear_model.LinearRegression`. + + Any arguments provided to the sklearn constructor can be provided + as kwargs here. + """ + super().__init__(sklearn_module="linear_model.LinearRegression", **kwargs) + + def fit(self, train_data: DataLoader, **kwargs): + return super().fit(train_data=train_data, **kwargs) + + +class SkLearnLogisticRegression(SkLearnLinearModel): + def __init__(self, **kwargs) -> None: + r""" + Factory class. Trains a model with `sklearn.linear_model.LogisticRegression`. + + Any arguments provided to the sklearn constructor can be provided + as kwargs here. + """ + super().__init__(sklearn_module="linear_model.LogisticRegression", **kwargs) + + def fit(self, train_data: DataLoader, **kwargs): + return super().fit(train_data=train_data, **kwargs) + + +class SkLearnSGDClassifier(SkLearnLinearModel): + def __init__(self, **kwargs) -> None: + r""" + Factory class. Trains a model with `sklearn.linear_model.SGDClassifier(`. + + Any arguments provided to the sklearn constructor can be provided + as kwargs here. 
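+
+        Example (a minimal sketch; assumes sklearn is installed, `train_loader`
+        is a DataLoader of (x, y) batches, and `x_batch` is a 2D float tensor)::
+
+            >>> model = SkLearnSGDClassifier(max_iter=1000)
+            >>> model.fit(train_loader)
+            >>> scores = model(x_batch)  # linear decision scores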
+ """ + super().__init__(sklearn_module="linear_model.SGDClassifier", **kwargs) + + def fit(self, train_data: DataLoader, **kwargs): + return super().fit(train_data=train_data, **kwargs) diff --git a/captum/_utils/models/linear_model/train.py b/captum/_utils/models/linear_model/train.py new file mode 100644 index 0000000000000000000000000000000000000000..aaf8a2e4bf3fd0ee4493f465a7ab9e68ffc7cb66 --- /dev/null +++ b/captum/_utils/models/linear_model/train.py @@ -0,0 +1,364 @@ +import time +import warnings +from typing import Any, Callable, Dict, List, Optional + +import torch +import torch.nn as nn +from captum._utils.models.linear_model.model import LinearModel +from torch.utils.data import DataLoader + + +def l2_loss(x1, x2, weights=None): + if weights is None: + return torch.mean((x1 - x2) ** 2) / 2.0 + else: + return torch.sum((weights / weights.norm(p=1)) * ((x1 - x2) ** 2)) / 2.0 + + +def sgd_train_linear_model( + model: LinearModel, + dataloader: DataLoader, + construct_kwargs: Dict[str, Any], + max_epoch: int = 100, + reduce_lr: bool = True, + initial_lr: float = 0.01, + alpha: float = 1.0, + loss_fn: Callable = l2_loss, + reg_term: Optional[int] = 1, + patience: int = 10, + threshold: float = 1e-4, + running_loss_window: Optional[int] = None, + device: Optional[str] = None, + init_scheme: str = "zeros", + debug: bool = False, +) -> Dict[str, float]: + r""" + Trains a linear model with SGD. This will continue to iterate your + dataloader until we converged to a solution or alternatively until we have + exhausted `max_epoch`. + + Convergence is defined by the loss not changing by `threshold` amount for + `patience` number of iterations. + + Args: + model + The model to train + dataloader + The data to train it with. We will assume the dataloader produces + either pairs or triples of the form (x, y) or (x, y, w). Where x and + y are typical pairs for supervised learning and w is a weight + vector. + + We will call `model._construct_model_params` with construct_kwargs + and the input features set to `x.shape[1]` (`x.shape[0]` corresponds + to the batch size). We assume that `len(x.shape) == 2`, i.e. the + tensor is flat. The number of output features will be set to + y.shape[1] or 1 (if `len(y.shape) == 1`); we require `len(y.shape) + <= 2`. + max_epoch + The maximum number of epochs to exhaust + reduce_lr + Whether or not to reduce the learning rate as iterations progress. + Halves the learning rate when the training loss does not move. This + uses torch.optim.lr_scheduler.ReduceLROnPlateau and uses the + parameters `patience` and `threshold` + initial_lr + The initial learning rate to use. + alpha + A constant for the regularization term. + loss_fn + The loss to optimise for. This must accept three parameters: + x1 (predicted), x2 (labels) and a weight vector + reg_term + Regularization is defined by the `reg_term` norm of the weights. + Please use `None` if you do not wish to use regularization. + patience + Defines the number of iterations in a row the loss must remain + within `threshold` in order to be classified as converged. + threshold + Threshold for convergence detection. + running_loss_window + Used to report the training loss once we have finished training and + to determine when we have converged (along with reducing the + learning rate). + + The reported training loss will take the last `running_loss_window` + iterations and average them. + + If `None` we will approximate this to be the number of examples in + an epoch. 
+ init_scheme + Initialization to use prior to training the linear model. + device + The device to send the model and data to. If None then no `.to` call + will be used. + debug + Whether to print the loss, learning rate per iteration + + Returns + This will return the final training loss (averaged with + `running_loss_window`) + """ + + loss_window: List[torch.Tensor] = [] + min_avg_loss = None + convergence_counter = 0 + converged = False + + def get_point(datapoint): + if len(datapoint) == 2: + x, y = datapoint + w = None + else: + x, y, w = datapoint + + if device is not None: + x = x.to(device) + y = y.to(device) + if w is not None: + w = w.to(device) + + return x, y, w + + # get a point and construct the model + data_iter = iter(dataloader) + x, y, w = get_point(next(data_iter)) + + model._construct_model_params( + in_features=x.shape[1], + out_features=y.shape[1] if len(y.shape) == 2 else 1, + **construct_kwargs, + ) + model.train() + + assert model.linear is not None + + if init_scheme is not None: + assert init_scheme in ["xavier", "zeros"] + + with torch.no_grad(): + if init_scheme == "xavier": + torch.nn.init.xavier_uniform_(model.linear.weight) + else: + model.linear.weight.zero_() + + if model.linear.bias is not None: + model.linear.bias.zero_() + + optim = torch.optim.SGD(model.parameters(), lr=initial_lr) + if reduce_lr: + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optim, factor=0.5, patience=patience, threshold=threshold + ) + + t1 = time.time() + epoch = 0 + i = 0 + while epoch < max_epoch: + while True: # for x, y, w in dataloader + if running_loss_window is None: + running_loss_window = x.shape[0] * len(dataloader) + + y = y.view(x.shape[0], -1) + if w is not None: + w = w.view(x.shape[0], -1) + + i += 1 + + out = model(x) + + loss = loss_fn(y, out, w) + if reg_term is not None: + reg = torch.norm(model.linear.weight, p=reg_term) + loss += reg.sum() * alpha + + if len(loss_window) >= running_loss_window: + loss_window = loss_window[1:] + loss_window.append(loss.clone().detach()) + assert len(loss_window) <= running_loss_window + + average_loss = torch.mean(torch.stack(loss_window)) + if min_avg_loss is not None: + # if we haven't improved by at least `threshold` + if average_loss > min_avg_loss or torch.isclose( + min_avg_loss, average_loss, atol=threshold + ): + convergence_counter += 1 + if convergence_counter >= patience: + converged = True + break + else: + convergence_counter = 0 + if min_avg_loss is None or min_avg_loss >= average_loss: + min_avg_loss = average_loss.clone() + + if debug: + print( + f"lr={optim.param_groups[0]['lr']}, Loss={loss}," + + "Aloss={average_loss}, min_avg_loss={min_avg_loss}" + ) + + loss.backward() + + optim.step() + model.zero_grad() + if scheduler: + scheduler.step(average_loss) + + temp = next(data_iter, None) + if temp is None: + break + x, y, w = get_point(temp) + + if converged: + break + + epoch += 1 + data_iter = iter(dataloader) + x, y, w = get_point(next(data_iter)) + + t2 = time.time() + return { + "train_time": t2 - t1, + "train_loss": torch.mean(torch.stack(loss_window)).item(), + "train_iter": i, + "train_epoch": epoch, + } + + +class NormLayer(nn.Module): + def __init__(self, mean, std, n=None, eps=1e-8) -> None: + super().__init__() + self.mean = mean + self.std = std + self.eps = eps + + def forward(self, x): + return (x - self.mean) / (self.std + self.eps) + + +def sklearn_train_linear_model( + model: LinearModel, + dataloader: DataLoader, + construct_kwargs: Dict[str, Any], + sklearn_trainer: str = 
"Lasso", + norm_input: bool = False, + **fit_kwargs, +): + r""" + Alternative method to train with sklearn. This does introduce some slight + overhead as we convert the tensors to numpy and then convert the resulting + trained model to a `LinearModel` object. However, this conversion + should be negligible. + + Please note that this assumes: + + 0. You have sklearn and numpy installed + 1. The dataset can fit into memory + + Args + model + The model to train. + dataloader + The data to use. This will be exhausted and converted to numpy + arrays. Therefore please do not feed an infinite dataloader. + norm_input + Whether or not to normalize the input + sklearn_trainer + The sklearn model to use to train the model. Please refer to + sklearn.linear_model for a list of modules to use. + construct_kwargs + Additional arguments provided to the `sklearn_trainer` constructor + fit_kwargs + Other arguments to send to `sklearn_trainer`'s `.fit` method + """ + from functools import reduce + + try: + import numpy as np + except ImportError: + raise ValueError("numpy is not available. Please install numpy.") + + try: + import sklearn + import sklearn.linear_model + import sklearn.svm + except ImportError: + raise ValueError("sklearn is not available. Please install sklearn >= 0.23") + + if not sklearn.__version__ >= "0.23.0": + warnings.warn( + "Must have sklearn version 0.23.0 or higher to use " + "sample_weight in Lasso regression." + ) + + num_batches = 0 + xs, ys, ws = [], [], [] + for data in dataloader: + if len(data) == 3: + x, y, w = data + else: + assert len(data) == 2 + x, y = data + w = None + + xs.append(x.cpu().numpy()) + ys.append(y.cpu().numpy()) + if w is not None: + ws.append(w.cpu().numpy()) + num_batches += 1 + + x = np.concatenate(xs, axis=0) + y = np.concatenate(ys, axis=0) + if len(ws) > 0: + w = np.concatenate(ws, axis=0) + else: + w = None + + if norm_input: + mean, std = x.mean(0), x.std(0) + x -= mean + x /= std + + t1 = time.time() + sklearn_model = reduce( + lambda val, el: getattr(val, el), [sklearn] + sklearn_trainer.split(".") + )(**construct_kwargs) + try: + sklearn_model.fit(x, y, sample_weight=w, **fit_kwargs) + except TypeError: + sklearn_model.fit(x, y, **fit_kwargs) + warnings.warn( + "Sample weight is not supported for the provided linear model!" + " Trained model without weighting inputs. For Lasso, please" + " upgrade sklearn to a version >= 0.23.0." 
+ ) + + t2 = time.time() + + # Convert weights to pytorch + classes = ( + torch.IntTensor(sklearn_model.classes_) + if hasattr(sklearn_model, "classes_") + else None + ) + + # extract model device + device = model.device if hasattr(model, "device") else "cpu" + + num_outputs = sklearn_model.coef_.shape[0] if sklearn_model.coef_.ndim > 1 else 1 + weight_values = torch.FloatTensor(sklearn_model.coef_).to(device) # type: ignore + bias_values = torch.FloatTensor([sklearn_model.intercept_]).to( # type: ignore + device # type: ignore + ) # type: ignore + model._construct_model_params( + norm_type=None, + weight_values=weight_values.view(num_outputs, -1), + bias_value=bias_values.squeeze().unsqueeze(0), + classes=classes, + ) + + if norm_input: + model.norm = NormLayer(mean, std) + + return {"train_time": t2 - t1} diff --git a/captum/_utils/models/model.py b/captum/_utils/models/model.py new file mode 100644 index 0000000000000000000000000000000000000000..9e8a98db0471952b8dd1f6feeeee53363594a0e6 --- /dev/null +++ b/captum/_utils/models/model.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python3 + +from abc import ABC, abstractmethod +from typing import Dict, Optional, Union + +from captum._utils.typing import TensorOrTupleOfTensorsGeneric +from torch import Tensor +from torch.utils.data import DataLoader + + +class Model(ABC): + r""" + Abstract Class to describe the interface of a trainable model to be used + within the algorithms of captum. + + Please note that this is an experimental feature. + """ + + @abstractmethod + def fit( + self, train_data: DataLoader, **kwargs + ) -> Optional[Dict[str, Union[int, float, Tensor]]]: + r""" + Override this method to actually train your model. + + The specification of the dataloader will be supplied by the algorithm + you are using within captum. This will likely be a supervised learning + task, thus you should expect batched (x, y) pairs or (x, y, w) triples. + + Args: + train_data (DataLoader): + The data to train on + + Returns: + Optional statistics about training, e.g. iterations it took to + train, training loss, etc. + """ + pass + + @abstractmethod + def representation(self) -> Tensor: + r""" + Returns the underlying representation of the interpretable model. For a + linear model this is simply a tensor (the concatenation of weights + and bias). For something slightly more complicated, such as a decision + tree, this could be the nodes of a decision tree. + + Returns: + A Tensor describing the representation of the model. + """ + pass + + @abstractmethod + def __call__( + self, x: TensorOrTupleOfTensorsGeneric + ) -> TensorOrTupleOfTensorsGeneric: + r""" + Predicts with the interpretable model. + + Args: + x (TensorOrTupleOfTensorsGeneric) + A batched input of tensor(s) to the model to predict + Returns: + The prediction of the input as a TensorOrTupleOfTensorsGeneric. 
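+
+            Example (illustrative sketch using one of the concrete linear models
+            defined in this package; `train_loader` and `x_batch` are assumed)::
+
+                >>> model = SkLearnLasso(alpha=1.0)
+                >>> model.fit(train_loader)
+                >>> out = model(x_batch)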
+ """ + pass diff --git a/captum/_utils/progress.py b/captum/_utils/progress.py new file mode 100644 index 0000000000000000000000000000000000000000..88cb07e83f749c95e9055ed7b1583d44cce8f9a8 --- /dev/null +++ b/captum/_utils/progress.py @@ -0,0 +1,138 @@ +#!/usr/bin/env python3 + +import sys +import warnings +from time import time +from typing import cast, Iterable, Sized, TextIO + +try: + from tqdm import tqdm +except ImportError: + tqdm = None + + +class DisableErrorIOWrapper(object): + def __init__(self, wrapped: TextIO): + """ + The wrapper around a TextIO object to ignore write errors like tqdm + https://github.com/tqdm/tqdm/blob/bcce20f771a16cb8e4ac5cc5b2307374a2c0e535/tqdm/utils.py#L131 + """ + self._wrapped = wrapped + + def __getattr__(self, name): + return getattr(self._wrapped, name) + + @staticmethod + def _wrapped_run(func, *args, **kwargs): + try: + return func(*args, **kwargs) + except OSError as e: + if e.errno != 5: + raise + except ValueError as e: + if "closed" not in str(e): + raise + + def write(self, *args, **kwargs): + return self._wrapped_run(self._wrapped.write, *args, **kwargs) + + def flush(self, *args, **kwargs): + return self._wrapped_run(self._wrapped.flush, *args, **kwargs) + + +class SimpleProgress: + def __init__( + self, + iterable: Iterable = None, + desc: str = None, + total: int = None, + file: TextIO = None, + mininterval: float = 0.5, + ): + """ + Simple progress output used when tqdm is unavailable. + Same as tqdm, output to stderr channel + """ + self.cur = 0 + + self.iterable = iterable + self.total = total + if total is None and hasattr(iterable, "__len__"): + self.total = len(cast(Sized, iterable)) + + self.desc = desc + + file = DisableErrorIOWrapper(file if file else sys.stderr) + cast(TextIO, file) + self.file = file + + self.mininterval = mininterval + self.last_print_t = 0.0 + self.closed = False + + def __iter__(self): + if self.closed or not self.iterable: + return + self._refresh() + for it in self.iterable: + yield it + self.update() + self.close() + + def _refresh(self): + progress_str = self.desc + ": " if self.desc else "" + if self.total: + # e.g., progress: 60% 3/5 + progress_str += f"{100 * self.cur // self.total}% {self.cur}/{self.total}" + else: + # e.g., progress: ..... + progress_str += "." * self.cur + + print("\r" + progress_str, end="", file=self.file) + + def update(self, amount: int = 1): + if self.closed: + return + self.cur += amount + + cur_t = time() + if cur_t - self.last_print_t >= self.mininterval: + self._refresh() + self.last_print_t = cur_t + + def close(self): + if not self.closed: + self._refresh() + print(file=self.file) # end with new line + self.closed = True + + +def progress( + iterable: Iterable = None, + desc: str = None, + total: int = None, + use_tqdm=True, + file: TextIO = None, + mininterval: float = 0.5, + **kwargs, +): + # Try to use tqdm is possible. Fall back to simple progress print + if tqdm and use_tqdm: + return tqdm( + iterable, + desc=desc, + total=total, + file=file, + mininterval=mininterval, + **kwargs, + ) + else: + if not tqdm and use_tqdm: + warnings.warn( + "Tried to show progress with tqdm " + "but tqdm is not installed. " + "Fall back to simply print out the progress." 
+ ) + return SimpleProgress( + iterable, desc=desc, total=total, file=file, mininterval=mininterval + ) diff --git a/captum/_utils/sample_gradient.py b/captum/_utils/sample_gradient.py new file mode 100644 index 0000000000000000000000000000000000000000..694b2c012179e695f9cb0ae41fb5b0447ec4849e --- /dev/null +++ b/captum/_utils/sample_gradient.py @@ -0,0 +1,184 @@ +from collections import defaultdict +from enum import Enum +from typing import cast, Iterable, Tuple, Union + +import torch +from captum._utils.common import _format_tensor_into_tuples, _register_backward_hook +from torch import Tensor +from torch.nn import Module + + +def _reset_sample_grads(module: Module): + module.weight.sample_grad = 0 # type: ignore + if module.bias is not None: + module.bias.sample_grad = 0 # type: ignore + + +def linear_param_grads( + module: Module, activation: Tensor, gradient_out: Tensor, reset: bool = False +) -> None: + r""" + Computes parameter gradients per sample for nn.Linear module, given module + input activations and output gradients. + + Gradients are accumulated in the sample_grad attribute of each parameter + (weight and bias). If reset = True, any current sample_grad values are reset, + otherwise computed gradients are accumulated and added to the existing + stored gradients. + + Inputs with more than 2 dimensions are only supported with torch 1.8 or later + """ + if reset: + _reset_sample_grads(module) + + module.weight.sample_grad += torch.einsum( # type: ignore + "n...i,n...j->nij", gradient_out, activation + ) + if module.bias is not None: + module.bias.sample_grad += torch.einsum( # type: ignore + "n...i->ni", gradient_out + ) + + +def conv2d_param_grads( + module: Module, activation: Tensor, gradient_out: Tensor, reset: bool = False +) -> None: + r""" + Computes parameter gradients per sample for nn.Conv2d module, given module + input activations and output gradients. + + nn.Conv2d modules with padding set to a string option ('same' or 'valid') are + currently unsupported. + + Gradients are accumulated in the sample_grad attribute of each parameter + (weight and bias). If reset = True, any current sample_grad values are reset, + otherwise computed gradients are accumulated and added to the existing + stored gradients. + """ + if reset: + _reset_sample_grads(module) + + batch_size = cast(int, activation.shape[0]) + unfolded_act = torch.nn.functional.unfold( + activation, + cast(Union[int, Tuple[int, ...]], module.kernel_size), + dilation=cast(Union[int, Tuple[int, ...]], module.dilation), + padding=cast(Union[int, Tuple[int, ...]], module.padding), + stride=cast(Union[int, Tuple[int, ...]], module.stride), + ) + reshaped_grad = gradient_out.reshape(batch_size, -1, unfolded_act.shape[-1]) + grad1 = torch.einsum("ijk,ilk->ijl", reshaped_grad, unfolded_act) + shape = [batch_size] + list(cast(Iterable[int], module.weight.shape)) + module.weight.sample_grad += grad1.reshape(shape) # type: ignore + if module.bias is not None: + module.bias.sample_grad += torch.sum(reshaped_grad, dim=2) # type: ignore + + +SUPPORTED_MODULES = { + torch.nn.Conv2d: conv2d_param_grads, + torch.nn.Linear: linear_param_grads, +} + + +class LossMode(Enum): + SUM = 0 + MEAN = 1 + + +class SampleGradientWrapper: + r""" + Wrapper which allows computing sample-wise gradients in a single backward pass. + + This is accomplished by adding hooks to capture activations and output + gradients for supported modules, and using these activations and gradients + to compute the parameter gradients per-sample. 
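+
+    Example (a minimal sketch; `model`, `x`, `y` and a reduction='sum' loss
+    function `loss_fn` are assumed)::
+
+        >>> wrapper = SampleGradientWrapper(model)
+        >>> wrapper.add_hooks()
+        >>> loss = loss_fn(model(x), y)
+        >>> wrapper.compute_param_sample_gradients(loss, loss_mode="sum")
+        >>> per_sample = model.fc.weight.sample_grad  # `fc`: a hypothetical nn.Linear
+        >>> wrapper.remove_hooks()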
+ + Currently, only nn.Linear and nn.Conv2d modules are supported. + + Similar reference implementations of sample-based gradients include: + - https://github.com/cybertronai/autograd-hacks + - https://github.com/pytorch/opacus/tree/main/opacus/grad_sample + """ + + def __init__(self, model): + self.model = model + self.hooks_added = False + self.activation_dict = defaultdict(list) + self.gradient_dict = defaultdict(list) + self.forward_hooks = [] + self.backward_hooks = [] + + def add_hooks(self): + self.hooks_added = True + self.model.apply(self._register_module_hooks) + + def _register_module_hooks(self, module: torch.nn.Module): + if isinstance(module, tuple(SUPPORTED_MODULES.keys())): + self.forward_hooks.append( + module.register_forward_hook(self._forward_hook_fn) + ) + self.backward_hooks.append( + _register_backward_hook(module, self._backward_hook_fn, None) + ) + + def _forward_hook_fn( + self, + module: Module, + module_input: Union[Tensor, Tuple[Tensor, ...]], + module_output: Union[Tensor, Tuple[Tensor, ...]], + ): + inp_tuple = _format_tensor_into_tuples(module_input) + self.activation_dict[module].append(inp_tuple[0].clone().detach()) + + def _backward_hook_fn( + self, + module: Module, + grad_input: Union[Tensor, Tuple[Tensor, ...]], + grad_output: Union[Tensor, Tuple[Tensor, ...]], + ): + grad_output_tuple = _format_tensor_into_tuples(grad_output) + self.gradient_dict[module].append(grad_output_tuple[0].clone().detach()) + + def remove_hooks(self): + self.hooks_added = False + + for hook in self.forward_hooks: + hook.remove() + + for hook in self.backward_hooks: + hook.remove() + + self.forward_hooks = [] + self.backward_hooks = [] + + def _reset(self): + self.activation_dict = defaultdict(list) + self.gradient_dict = defaultdict(list) + + def compute_param_sample_gradients(self, loss_blob, loss_mode="mean"): + assert ( + loss_mode.upper() in LossMode.__members__ + ), f"Provided loss mode {loss_mode} is not valid" + mode = LossMode[loss_mode.upper()] + + self.model.zero_grad() + loss_blob.backward(gradient=torch.ones_like(loss_blob)) + + for module in self.gradient_dict: + sample_grad_fn = SUPPORTED_MODULES[type(module)] + activations = self.activation_dict[module] + gradients = self.gradient_dict[module] + assert len(activations) == len(gradients), ( + "Number of saved activations do not match number of saved gradients." + " This may occur if multiple forward passes are run without calling" + " reset or computing param gradients." + ) + # Reversing grads since when a module is used multiple times, + # the activations will be aligned with the reverse order of the gradients, + # since the order is reversed in backprop. 
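+            # For a mean-reduced loss the upstream gradients carry an implicit
+            # 1/batch_size factor, so each per-sample gradient is rescaled by the
+            # batch size (act.shape[0]) below; for a sum-reduced loss no rescaling
+            # is needed.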
+ for i, (act, grad) in enumerate( + zip(activations, list(reversed(gradients))) + ): + mult = 1 if mode is LossMode.SUM else act.shape[0] + sample_grad_fn(module, act, grad * mult, reset=(i == 0)) + self._reset() diff --git a/captum/_utils/typing.py b/captum/_utils/typing.py new file mode 100644 index 0000000000000000000000000000000000000000..89ea6af0483281e3069aa58907ccf64e40fb00ce --- /dev/null +++ b/captum/_utils/typing.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 + +from typing import List, Tuple, TYPE_CHECKING, TypeVar, Union + +from torch import Tensor +from torch.nn import Module + +if TYPE_CHECKING: + import sys + + if sys.version_info >= (3, 8): + from typing import Literal # noqa: F401 + else: + from typing_extensions import Literal # noqa: F401 +else: + Literal = {True: bool, False: bool, (True, False): bool} + +TensorOrTupleOfTensorsGeneric = TypeVar( + "TensorOrTupleOfTensorsGeneric", Tensor, Tuple[Tensor, ...] +) +TupleOrTensorOrBoolGeneric = TypeVar("TupleOrTensorOrBoolGeneric", Tuple, Tensor, bool) +ModuleOrModuleList = TypeVar("ModuleOrModuleList", Module, List[Module]) +TargetType = Union[None, int, Tuple[int, ...], Tensor, List[Tuple[int, ...]], List[int]] +BaselineType = Union[None, Tensor, int, float, Tuple[Union[Tensor, int, float], ...]] + +TensorLikeList1D = List[float] +TensorLikeList2D = List[TensorLikeList1D] +TensorLikeList3D = List[TensorLikeList2D] +TensorLikeList4D = List[TensorLikeList3D] +TensorLikeList5D = List[TensorLikeList4D] +TensorLikeList = Union[ + TensorLikeList1D, + TensorLikeList2D, + TensorLikeList3D, + TensorLikeList4D, + TensorLikeList5D, +] diff --git a/captum/attr/__init__.py b/captum/attr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b942230a1002c6bd723889ee33a6e87efd78e83 --- /dev/null +++ b/captum/attr/__init__.py @@ -0,0 +1,143 @@ +#!/usr/bin/env python3 +from captum.attr._core.deep_lift import DeepLift, DeepLiftShap # noqa +from captum.attr._core.feature_ablation import FeatureAblation # noqa +from captum.attr._core.feature_permutation import FeaturePermutation # noqa +from captum.attr._core.gradient_shap import GradientShap # noqa +from captum.attr._core.guided_backprop_deconvnet import ( # noqa + Deconvolution, + GuidedBackprop, +) +from captum.attr._core.guided_grad_cam import GuidedGradCam # noqa +from captum.attr._core.input_x_gradient import InputXGradient # noqa +from captum.attr._core.integrated_gradients import IntegratedGradients # noqa +from captum.attr._core.kernel_shap import KernelShap # noqa +from captum.attr._core.layer.grad_cam import LayerGradCam # noqa +from captum.attr._core.layer.internal_influence import InternalInfluence # noqa +from captum.attr._core.layer.layer_activation import LayerActivation # noqa +from captum.attr._core.layer.layer_conductance import LayerConductance # noqa +from captum.attr._core.layer.layer_deep_lift import ( # noqa + LayerDeepLift, + LayerDeepLiftShap, +) +from captum.attr._core.layer.layer_feature_ablation import LayerFeatureAblation # noqa +from captum.attr._core.layer.layer_gradient_shap import LayerGradientShap # noqa +from captum.attr._core.layer.layer_gradient_x_activation import ( # noqa + LayerGradientXActivation, +) +from captum.attr._core.layer.layer_integrated_gradients import ( # noqa + LayerIntegratedGradients, +) +from captum.attr._core.layer.layer_lrp import LayerLRP # noqa +from captum.attr._core.lime import Lime, LimeBase # noqa +from captum.attr._core.lrp import LRP # noqa +from captum.attr._core.neuron.neuron_conductance import 
NeuronConductance # noqa +from captum.attr._core.neuron.neuron_deep_lift import ( # noqa + NeuronDeepLift, + NeuronDeepLiftShap, +) +from captum.attr._core.neuron.neuron_feature_ablation import ( # noqa + NeuronFeatureAblation, +) +from captum.attr._core.neuron.neuron_gradient import NeuronGradient # noqa +from captum.attr._core.neuron.neuron_gradient_shap import NeuronGradientShap # noqa +from captum.attr._core.neuron.neuron_guided_backprop_deconvnet import ( # noqa + NeuronDeconvolution, + NeuronGuidedBackprop, +) +from captum.attr._core.neuron.neuron_integrated_gradients import ( # noqa + NeuronIntegratedGradients, +) +from captum.attr._core.noise_tunnel import NoiseTunnel # noqa +from captum.attr._core.occlusion import Occlusion # noqa +from captum.attr._core.saliency import Saliency # noqa +from captum.attr._core.shapley_value import ShapleyValues, ShapleyValueSampling # noqa +from captum.attr._models.base import ( # noqa + configure_interpretable_embedding_layer, + InterpretableEmbeddingBase, + remove_interpretable_embedding_layer, + TokenReferenceBase, +) +from captum.attr._utils import visualization # noqa +from captum.attr._utils.attribution import ( # noqa # noqa # noqa # noqa # noqa + Attribution, + GradientAttribution, + LayerAttribution, + NeuronAttribution, + PerturbationAttribution, +) +from captum.attr._utils.class_summarizer import ClassSummarizer +from captum.attr._utils.stat import ( + CommonStats, + Count, + Max, + Mean, + Min, + MSE, + StdDev, + Sum, + Var, +) +from captum.attr._utils.summarizer import Summarizer + +__all__ = [ + "Attribution", + "GradientAttribution", + "PerturbationAttribution", + "NeuronAttribution", + "LayerAttribution", + "IntegratedGradients", + "DeepLift", + "DeepLiftShap", + "InputXGradient", + "Saliency", + "GuidedBackprop", + "Deconvolution", + "GuidedGradCam", + "FeatureAblation", + "FeaturePermutation", + "Occlusion", + "ShapleyValueSampling", + "ShapleyValues", + "LimeBase", + "Lime", + "LRP", + "KernelShap", + "LayerConductance", + "LayerGradientXActivation", + "LayerActivation", + "LayerFeatureAblation", + "InternalInfluence", + "LayerGradCam", + "LayerDeepLift", + "LayerDeepLiftShap", + "LayerGradientShap", + "LayerIntegratedGradients", + "LayerLRP", + "NeuronConductance", + "NeuronFeatureAblation", + "NeuronGradient", + "NeuronIntegratedGradients", + "NeuronDeepLift", + "NeuronDeepLiftShap", + "NeuronGradientShap", + "NeuronDeconvolution", + "NeuronGuidedBackprop", + "NoiseTunnel", + "GradientShap", + "InterpretableEmbeddingBase", + "TokenReferenceBase", + "visualization", + "configure_interpretable_embedding_layer", + "remove_interpretable_embedding_layer", + "Summarizer", + "CommonStats", + "ClassSummarizer", + "Mean", + "StdDev", + "MSE", + "Var", + "Min", + "Max", + "Sum", + "Count", +] diff --git a/captum/attr/_core/__init__.py b/captum/attr/_core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/captum/attr/_core/deep_lift.py b/captum/attr/_core/deep_lift.py new file mode 100644 index 0000000000000000000000000000000000000000..251e68dc23a09dcad5d9edc6863a6346aa7c9e3e --- /dev/null +++ b/captum/attr/_core/deep_lift.py @@ -0,0 +1,1151 @@ +#!/usr/bin/env python3 +import typing +import warnings +from typing import Any, Callable, cast, List, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from captum._utils.common import ( + _expand_additional_forward_args, + _expand_target, + _format_additional_forward_args, + 
_format_baseline, + _format_output, + _format_tensor_into_tuples, + _is_tuple, + _register_backward_hook, + _run_forward, + _select_targets, + ExpansionTypes, +) +from captum._utils.gradient import ( + apply_gradient_requirements, + undo_gradient_requirements, +) +from captum._utils.typing import ( + BaselineType, + Literal, + TargetType, + TensorOrTupleOfTensorsGeneric, +) +from captum.attr._utils.attribution import GradientAttribution +from captum.attr._utils.common import ( + _call_custom_attribution_func, + _compute_conv_delta_and_format_attrs, + _format_callable_baseline, + _tensorize_baseline, + _validate_input, +) +from captum.log import log_usage +from torch import Tensor +from torch.nn import Module +from torch.utils.hooks import RemovableHandle + + +# Check if module backward hook can safely be used for the module that produced +# this inputs / outputs mapping +def _check_valid_module(inputs_grad_fn, outputs) -> bool: + def is_output_cloned(output_fn, input_grad_fn) -> bool: + """ + Checks if the output has been cloned. This happens especially in case of + layer deeplift. + """ + return ( + output_fn[0].next_functions is not None + and output_fn[0].next_functions[0][0] == input_grad_fn + ) + + curr_fn = outputs.grad_fn + first_next = curr_fn.next_functions[0] + try: + # if `inputs` in the input to the network then the grad_fn is None and + # for that input backward_hook isn't computed. That's the reason why we + # need to check on `inputs_grad_fns[first_next[1]]` being None. + return ( + inputs_grad_fn is None + or first_next[0] == inputs_grad_fn + or is_output_cloned(first_next, inputs_grad_fn) + ) + except IndexError: + return False + + +class DeepLift(GradientAttribution): + r""" + Implements DeepLIFT algorithm based on the following paper: + Learning Important Features Through Propagating Activation Differences, + Avanti Shrikumar, et. al. + https://arxiv.org/abs/1704.02685 + + and the gradient formulation proposed in: + Towards better understanding of gradient-based attribution methods for + deep neural networks, Marco Ancona, et.al. + https://openreview.net/pdf?id=Sy21R9JAW + + This implementation supports only Rescale rule. RevealCancel rule will + be supported in later releases. + In addition to that, in order to keep the implementation cleaner, DeepLIFT + for internal neurons and layers extends current implementation and is + implemented separately in LayerDeepLift and NeuronDeepLift. + Although DeepLIFT's(Rescale Rule) attribution quality is comparable with + Integrated Gradients, it runs significantly faster than Integrated + Gradients and is preferred for large datasets. + + Currently we only support a limited number of non-linear activations + but the plan is to expand the list in the future. + + Note: As we know, currently we cannot access the building blocks, + of PyTorch's built-in LSTM, RNNs and GRUs such as Tanh and Sigmoid. + Nonetheless, it is possible to build custom LSTMs, RNNS and GRUs + with performance similar to built-in ones using TorchScript. + More details on how to build custom RNNs can be found here: + https://pytorch.org/blog/optimizing-cuda-rnn-with-torchscript/ + """ + + def __init__( + self, + model: Module, + multiply_by_inputs: bool = True, + eps: float = 1e-10, + ) -> None: + r""" + Args: + + model (nn.Module): The reference to PyTorch model instance. Model cannot + contain any in-place nonlinear submodules; these are not + supported by the register_full_backward_hook PyTorch API + starting from PyTorch v1.9. 
+ multiply_by_inputs (bool, optional): Indicates whether to factor + model inputs' multiplier in the final attribution scores. + In the literature this is also known as local vs global + attribution. If inputs' multiplier isn't factored in + then that type of attribution method is also called local + attribution. If it is, then that type of attribution + method is called global. + More detailed can be found here: + https://arxiv.org/abs/1711.06104 + + In case of DeepLift, if `multiply_by_inputs` + is set to True, final sensitivity scores + are being multiplied by (inputs - baselines). + This flag applies only if `custom_attribution_func` is + set to None. + + eps (float, optional): A value at which to consider output/input change + significant when computing the gradients for non-linear layers. + This is useful to adjust, depending on your model's bit depth, + to avoid numerical issues during the gradient computation. + Default: 1e-10 + """ + GradientAttribution.__init__(self, model) + self.model = model + self.eps = eps + self.forward_handles: List[RemovableHandle] = [] + self.backward_handles: List[RemovableHandle] = [] + self._multiply_by_inputs = multiply_by_inputs + + @typing.overload + def attribute( + self, + inputs: TensorOrTupleOfTensorsGeneric, + baselines: BaselineType = None, + target: TargetType = None, + additional_forward_args: Any = None, + return_convergence_delta: Literal[False] = False, + custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None, + ) -> TensorOrTupleOfTensorsGeneric: + ... + + @typing.overload + def attribute( + self, + inputs: TensorOrTupleOfTensorsGeneric, + baselines: BaselineType = None, + target: TargetType = None, + additional_forward_args: Any = None, + *, + return_convergence_delta: Literal[True], + custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None, + ) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: + ... + + @log_usage() + def attribute( # type: ignore + self, + inputs: TensorOrTupleOfTensorsGeneric, + baselines: BaselineType = None, + target: TargetType = None, + additional_forward_args: Any = None, + return_convergence_delta: bool = False, + custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None, + ) -> Union[ + TensorOrTupleOfTensorsGeneric, Tuple[TensorOrTupleOfTensorsGeneric, Tensor] + ]: + r""" + Args: + + inputs (tensor or tuple of tensors): Input for which + attributions are computed. If forward_func takes a single + tensor as input, a single input tensor should be provided. + If forward_func takes multiple tensors as input, a tuple + of the input tensors should be provided. It is assumed + that for all given input tensors, dimension 0 corresponds + to the number of examples (aka batch size), and if + multiple input tensors are provided, the examples must + be aligned appropriately. + baselines (scalar, tensor, tuple of scalars or tensors, optional): + Baselines define reference samples that are compared with + the inputs. In order to assign attribution scores DeepLift + computes the differences between the inputs/outputs and + corresponding references. + Baselines can be provided as: + + - a single tensor, if inputs is a single tensor, with + exactly the same dimensions as inputs or the first + dimension is one and the remaining dimensions match + with inputs. + + - a single scalar, if inputs is a single tensor, which will + be broadcasted for each input value in input tensor. 
+ + - a tuple of tensors or scalars, the baseline corresponding + to each tensor in the inputs' tuple can be: + + - either a tensor with matching dimensions to + corresponding tensor in the inputs' tuple + or the first dimension is one and the remaining + dimensions match with the corresponding + input tensor. + + - or a scalar, corresponding to a tensor in the + inputs' tuple. This scalar value is broadcasted + for corresponding input tensor. + + In the cases when `baselines` is not provided, we internally + use zero scalar corresponding to each input tensor. + + Default: None + target (int, tuple, tensor or list, optional): Output indices for + which gradients are computed (for classification cases, + this is usually the target class). + If the network returns a scalar value per example, + no target index is necessary. + For general 2D outputs, targets can be either: + + - a single integer or a tensor containing a single + integer, which is applied to all input examples + + - a list of integers or a 1D tensor, with length matching + the number of examples in inputs (dim 0). Each integer + is applied as the target for the corresponding example. + + For outputs with > 2 dimensions, targets can be either: + + - A single tuple, which contains #output_dims - 1 + elements. This target index is applied to all examples. + + - A list of tuples with length equal to the number of + examples in inputs (dim 0), and each tuple containing + #output_dims - 1 elements. Each tuple is applied as the + target for the corresponding example. + + Default: None + additional_forward_args (any, optional): If the forward function + requires additional arguments other than the inputs for + which attributions should not be computed, this argument + can be provided. It must be either a single additional + argument of a Tensor or arbitrary (non-tuple) type or a tuple + containing multiple additional arguments including tensors + or any arbitrary python types. These arguments are provided to + forward_func in order, following the arguments in inputs. + Note that attributions are not computed with respect + to these arguments. + Default: None + return_convergence_delta (bool, optional): Indicates whether to return + convergence delta or not. If `return_convergence_delta` + is set to True convergence delta will be returned in + a tuple following attributions. + Default: False + custom_attribution_func (callable, optional): A custom function for + computing final attribution scores. This function can take + at least one and at most three arguments with the + following signature: + + - custom_attribution_func(multipliers) + - custom_attribution_func(multipliers, inputs) + - custom_attribution_func(multipliers, inputs, baselines) + + In case this function is not provided, we use the default + logic defined as: multipliers * (inputs - baselines) + It is assumed that all input arguments, `multipliers`, + `inputs` and `baselines` are provided in tuples of same + length. `custom_attribution_func` returns a tuple of + attribution tensors that have the same length as the + `inputs`. + + Default: None + + Returns: + **attributions** or 2-element tuple of **attributions**, **delta**: + - **attributions** (*tensor* or tuple of *tensors*): + Attribution score computed based on DeepLift rescale rule with respect + to each input feature. Attributions will always be + the same size as the provided inputs, with each value + providing the attribution of the corresponding input index. 
+ If a single tensor is provided as inputs, a single tensor is + returned. If a tuple is provided for inputs, a tuple of + corresponding sized tensors is returned. + - **delta** (*tensor*, returned if return_convergence_delta=True): + This is computed using the property that + the total sum of forward_func(inputs) - forward_func(baselines) + must equal the total sum of the attributions computed + based on DeepLift's rescale rule. + Delta is calculated per example, meaning that the number of + elements in returned delta tensor is equal to the number of + of examples in input. + Note that the logic described for deltas is guaranteed when the + default logic for attribution computations is used, meaning that the + `custom_attribution_func=None`, otherwise it is not guaranteed and + depends on the specifics of the `custom_attribution_func`. + + Examples:: + + >>> # ImageClassifier takes a single input tensor of images Nx3x32x32, + >>> # and returns an Nx10 tensor of class probabilities. + >>> net = ImageClassifier() + >>> dl = DeepLift(net) + >>> input = torch.randn(2, 3, 32, 32, requires_grad=True) + >>> # Computes deeplift attribution scores for class 3. + >>> attribution = dl.attribute(input, target=3) + """ + + # Keeps track whether original input is a tuple or not before + # converting it into a tuple. + is_inputs_tuple = _is_tuple(inputs) + + inputs = _format_tensor_into_tuples(inputs) + baselines = _format_baseline(baselines, inputs) + + gradient_mask = apply_gradient_requirements(inputs) + + _validate_input(inputs, baselines) + + # set hooks for baselines + warnings.warn( + """Setting forward, backward hooks and attributes on non-linear + activations. The hooks and attributes will be removed + after the attribution is finished""" + ) + baselines = _tensorize_baseline(inputs, baselines) + main_model_hooks = [] + try: + main_model_hooks = self._hook_main_model() + + self.model.apply(self._register_hooks) + + additional_forward_args = _format_additional_forward_args( + additional_forward_args + ) + + expanded_target = _expand_target( + target, 2, expansion_type=ExpansionTypes.repeat + ) + + wrapped_forward_func = self._construct_forward_func( + self.model, + (inputs, baselines), + expanded_target, + additional_forward_args, + ) + gradients = self.gradient_func(wrapped_forward_func, inputs) + if custom_attribution_func is None: + if self.multiplies_by_inputs: + attributions = tuple( + (input - baseline) * gradient + for input, baseline, gradient in zip( + inputs, baselines, gradients + ) + ) + else: + attributions = gradients + else: + attributions = _call_custom_attribution_func( + custom_attribution_func, gradients, inputs, baselines + ) + finally: + # Even if any error is raised, remove all hooks before raising + self._remove_hooks(main_model_hooks) + + undo_gradient_requirements(inputs, gradient_mask) + return _compute_conv_delta_and_format_attrs( + self, + return_convergence_delta, + attributions, + baselines, + inputs, + additional_forward_args, + target, + is_inputs_tuple, + ) + + def _construct_forward_func( + self, + forward_func: Callable, + inputs: Tuple, + target: TargetType = None, + additional_forward_args: Any = None, + ) -> Callable: + def forward_fn(): + model_out = _run_forward( + forward_func, inputs, None, additional_forward_args + ) + return _select_targets( + torch.cat((model_out[:, 0], model_out[:, 1])), target + ) + + if hasattr(forward_func, "device_ids"): + forward_fn.device_ids = forward_func.device_ids # type: ignore + return forward_fn + + def 
_is_non_linear(self, module: Module) -> bool: + return type(module) in SUPPORTED_NON_LINEAR.keys() + + def _forward_pre_hook_ref( + self, module: Module, inputs: Union[Tensor, Tuple[Tensor, ...]] + ) -> None: + inputs = _format_tensor_into_tuples(inputs) + module.input_ref = tuple( # type: ignore + input.clone().detach() for input in inputs + ) + + def _forward_pre_hook( + self, module: Module, inputs: Union[Tensor, Tuple[Tensor, ...]] + ) -> None: + """ + For the modules that perform in-place operations such as ReLUs, we cannot + use inputs from forward hooks. This is because in that case inputs + and outputs are the same. We need access the inputs in pre-hooks and + set necessary hooks on inputs there. + """ + inputs = _format_tensor_into_tuples(inputs) + module.input = inputs[0].clone().detach() + module.input_grad_fns = inputs[0].grad_fn # type: ignore + + def tensor_backward_hook(grad): + if module.saved_grad is None: + raise RuntimeError( + """Module {} was detected as not supporting correctly module + backward hook. You should modify your hook to ignore the given + grad_inputs (recompute them by hand if needed) and save the + newly computed grad_inputs in module.saved_grad. See MaxPool1d + as an example.""".format( + module + ) + ) + return module.saved_grad + + # the hook is set by default but it will be used only for + # failure cases and will be removed otherwise + handle = inputs[0].register_hook(tensor_backward_hook) + module.input_hook = handle + + def _forward_hook( + self, + module: Module, + inputs: Union[Tensor, Tuple[Tensor, ...]], + outputs: Union[Tensor, Tuple[Tensor, ...]], + ) -> None: + r""" + we need forward hook to access and detach the inputs and + outputs of a neuron + """ + outputs = _format_tensor_into_tuples(outputs) + module.output = outputs[0].clone().detach() + if not _check_valid_module(module.input_grad_fns, outputs[0]): + warnings.warn( + """An invalid module {} is detected. Saved gradients will + be used as the gradients of the module's input tensor. + See MaxPool1d as an example.""".format( + module + ) + ) + module.is_invalid = True # type: ignore + module.saved_grad = None # type: ignore + self.forward_handles.append(cast(RemovableHandle, module.input_hook)) + else: + module.is_invalid = False # type: ignore + # removing the hook if there is no failure case + cast(RemovableHandle, module.input_hook).remove() + del module.input_hook + del module.input_grad_fns + + def _backward_hook( + self, + module: Module, + grad_input: Union[Tensor, Tuple[Tensor, ...]], + grad_output: Union[Tensor, Tuple[Tensor, ...]], + ): + r""" + `grad_input` is the gradient of the neuron with respect to its input + `grad_output` is the gradient of the neuron with respect to its output + we can override `grad_input` according to chain rule with. + `grad_output` * delta_out / delta_in. + + """ + # before accessing the attributes from the module we want + # to ensure that the properties exist, if not, then it is + # likely that the module is being reused. + attr_criteria = self.satisfies_attribute_criteria(module) + if not attr_criteria: + raise RuntimeError( + "A Module {} was detected that does not contain some of " + "the input/output attributes that are required for DeepLift " + "computations. This can occur, for example, if " + "your module is being used more than once in the network." 
+ "Please, ensure that module is being used only once in the " + "network.".format(module) + ) + multipliers = tuple( + SUPPORTED_NON_LINEAR[type(module)]( + module, + module.input, + module.output, + grad_input, + grad_output, + eps=self.eps, + ) + ) + # remove all the properies that we set for the inputs and output + del module.input + del module.output + + return multipliers + + def satisfies_attribute_criteria(self, module: Module) -> bool: + return hasattr(module, "input") and hasattr(module, "output") + + def _can_register_hook(self, module: Module) -> bool: + # TODO find a better way of checking if a module is a container or not + module_fullname = str(type(module)) + has_already_hooks = len(module._backward_hooks) > 0 # type: ignore + return not ( + "nn.modules.container" in module_fullname + or has_already_hooks + or not self._is_non_linear(module) + ) + + def _register_hooks( + self, module: Module, attribute_to_layer_input: bool = True + ) -> None: + if not self._can_register_hook(module) or ( + not attribute_to_layer_input and module is self.layer # type: ignore + ): + return + # adds forward hook to leaf nodes that are non-linear + forward_handle = module.register_forward_hook(self._forward_hook) + pre_forward_handle = module.register_forward_pre_hook(self._forward_pre_hook) + backward_handle = _register_backward_hook(module, self._backward_hook, self) + self.forward_handles.append(forward_handle) + self.forward_handles.append(pre_forward_handle) + self.backward_handles.append(backward_handle) + + def _remove_hooks(self, extra_hooks_to_remove: List[RemovableHandle]) -> None: + for handle in extra_hooks_to_remove: + handle.remove() + for forward_handle in self.forward_handles: + forward_handle.remove() + for backward_handle in self.backward_handles: + backward_handle.remove() + + def _hook_main_model(self) -> List[RemovableHandle]: + def pre_hook(module: Module, baseline_inputs_add_args: Tuple) -> Tuple: + inputs = baseline_inputs_add_args[0] + baselines = baseline_inputs_add_args[1] + additional_args = None + if len(baseline_inputs_add_args) > 2: + additional_args = baseline_inputs_add_args[2:] + + baseline_input_tsr = tuple( + torch.cat([input, baseline]) + for input, baseline in zip(inputs, baselines) + ) + if additional_args is not None: + expanded_additional_args = cast( + Tuple, + _expand_additional_forward_args( + additional_args, 2, ExpansionTypes.repeat + ), + ) + return (*baseline_input_tsr, *expanded_additional_args) + return baseline_input_tsr + + def forward_hook(module: Module, inputs: Tuple, outputs: Tensor): + return torch.stack(torch.chunk(outputs, 2), dim=1) + + if isinstance( + self.model, (nn.DataParallel, nn.parallel.DistributedDataParallel) + ): + return [ + self.model.module.register_forward_pre_hook(pre_hook), # type: ignore + self.model.module.register_forward_hook(forward_hook), + ] # type: ignore + else: + return [ + self.model.register_forward_pre_hook(pre_hook), # type: ignore + self.model.register_forward_hook(forward_hook), + ] # type: ignore + + def has_convergence_delta(self) -> bool: + return True + + @property + def multiplies_by_inputs(self): + return self._multiply_by_inputs + + +class DeepLiftShap(DeepLift): + r""" + Extends DeepLift algorithm and approximates SHAP values using Deeplift. + For each input sample it computes DeepLift attribution with respect to + each baseline and averages resulting attributions. 
+ More details about the algorithm can be found here: + + http://papers.nips.cc/paper/7062-a-unified-approach-to-interpreting-model-predictions.pdf + + Note that the explanation model: + 1. Assumes that input features are independent of one another + 2. Is linear, meaning that the explanations are modeled through + the additive composition of feature effects. + Although, it assumes a linear model for each explanation, the overall + model across multiple explanations can be complex and non-linear. + """ + + def __init__(self, model: Module, multiply_by_inputs: bool = True) -> None: + r""" + Args: + + model (nn.Module): The reference to PyTorch model instance. Model cannot + contain any in-place nonlinear submodules; these are not + supported by the register_full_backward_hook PyTorch API. + multiply_by_inputs (bool, optional): Indicates whether to factor + model inputs' multiplier in the final attribution scores. + In the literature this is also known as local vs global + attribution. If inputs' multiplier isn't factored in + then that type of attribution method is also called local + attribution. If it is, then that type of attribution + method is called global. + More detailed can be found here: + https://arxiv.org/abs/1711.06104 + + In case of DeepLiftShap, if `multiply_by_inputs` + is set to True, final sensitivity scores + are being multiplied by (inputs - baselines). + This flag applies only if `custom_attribution_func` is + set to None. + """ + DeepLift.__init__(self, model, multiply_by_inputs=multiply_by_inputs) + + # There's a mismatch between the signatures of DeepLift.attribute and + # DeepLiftShap.attribute, so we ignore typing here + @typing.overload # type: ignore + def attribute( + self, + inputs: TensorOrTupleOfTensorsGeneric, + baselines: Union[ + TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric] + ], + target: TargetType = None, + additional_forward_args: Any = None, + return_convergence_delta: Literal[False] = False, + custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None, + ) -> TensorOrTupleOfTensorsGeneric: + ... + + @typing.overload + def attribute( + self, + inputs: TensorOrTupleOfTensorsGeneric, + baselines: Union[ + TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric] + ], + target: TargetType = None, + additional_forward_args: Any = None, + *, + return_convergence_delta: Literal[True], + custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None, + ) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: + ... + + @log_usage() + def attribute( # type: ignore + self, + inputs: TensorOrTupleOfTensorsGeneric, + baselines: Union[ + TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric] + ], + target: TargetType = None, + additional_forward_args: Any = None, + return_convergence_delta: bool = False, + custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None, + ) -> Union[ + TensorOrTupleOfTensorsGeneric, Tuple[TensorOrTupleOfTensorsGeneric, Tensor] + ]: + r""" + Args: + + inputs (tensor or tuple of tensors): Input for which + attributions are computed. If forward_func takes a single + tensor as input, a single input tensor should be provided. + If forward_func takes multiple tensors as input, a tuple + of the input tensors should be provided. 
It is assumed + that for all given input tensors, dimension 0 corresponds + to the number of examples (aka batch size), and if + multiple input tensors are provided, the examples must + be aligned appropriately. + baselines (tensor, tuple of tensors, callable): + Baselines define reference samples that are compared with + the inputs. In order to assign attribution scores DeepLift + computes the differences between the inputs/outputs and + corresponding references. Baselines can be provided as: + + - a single tensor, if inputs is a single tensor, with + the first dimension equal to the number of examples + in the baselines' distribution. The remaining dimensions + must match with input tensor's dimension starting from + the second dimension. + + - a tuple of tensors, if inputs is a tuple of tensors, + with the first dimension of any tensor inside the tuple + equal to the number of examples in the baseline's + distribution. The remaining dimensions must match + the dimensions of the corresponding input tensor + starting from the second dimension. + + - callable function, optionally takes `inputs` as an + argument and either returns a single tensor + or a tuple of those. + + It is recommended that the number of samples in the baselines' + tensors is larger than one. + target (int, tuple, tensor or list, optional): Output indices for + which gradients are computed (for classification cases, + this is usually the target class). + If the network returns a scalar value per example, + no target index is necessary. + For general 2D outputs, targets can be either: + + - a single integer or a tensor containing a single + integer, which is applied to all input examples + + - a list of integers or a 1D tensor, with length matching + the number of examples in inputs (dim 0). Each integer + is applied as the target for the corresponding example. + + For outputs with > 2 dimensions, targets can be either: + + - A single tuple, which contains #output_dims - 1 + elements. This target index is applied to all examples. + + - A list of tuples with length equal to the number of + examples in inputs (dim 0), and each tuple containing + #output_dims - 1 elements. Each tuple is applied as the + target for the corresponding example. + + Default: None + additional_forward_args (any, optional): If the forward function + requires additional arguments other than the inputs for + which attributions should not be computed, this argument + can be provided. It must be either a single additional + argument of a Tensor or arbitrary (non-tuple) type or a tuple + containing multiple additional arguments including tensors + or any arbitrary python types. These arguments are provided to + forward_func in order, following the arguments in inputs. + Note that attributions are not computed with respect + to these arguments. + Default: None + return_convergence_delta (bool, optional): Indicates whether to return + convergence delta or not. If `return_convergence_delta` + is set to True convergence delta will be returned in + a tuple following attributions. + Default: False + custom_attribution_func (callable, optional): A custom function for + computing final attribution scores. 
This function can take + at least one and at most three arguments with the + following signature: + + - custom_attribution_func(multipliers) + - custom_attribution_func(multipliers, inputs) + - custom_attribution_func(multipliers, inputs, baselines) + + In case this function is not provided we use the default + logic defined as: multipliers * (inputs - baselines) + It is assumed that all input arguments, `multipliers`, + `inputs` and `baselines` are provided in tuples of same + length. `custom_attribution_func` returns a tuple of + attribution tensors that have the same length as the + `inputs`. + Default: None + + Returns: + **attributions** or 2-element tuple of **attributions**, **delta**: + - **attributions** (*tensor* or tuple of *tensors*): + Attribution score computed based on DeepLift rescale rule with + respect to each input feature. Attributions will always be + the same size as the provided inputs, with each value + providing the attribution of the corresponding input index. + If a single tensor is provided as inputs, a single tensor is + returned. If a tuple is provided for inputs, a tuple of + corresponding sized tensors is returned. + - **delta** (*tensor*, returned if return_convergence_delta=True): + This is computed using the property that the + total sum of forward_func(inputs) - forward_func(baselines) + must be very close to the total sum of attributions + computed based on approximated SHAP values using + Deeplift's rescale rule. + Delta is calculated for each example input and baseline pair, + meaning that the number of elements in returned delta tensor + is equal to the + `number of examples in input` * `number of examples + in baseline`. The deltas are ordered in the first place by + input example, followed by the baseline. + Note that the logic described for deltas is guaranteed + when the default logic for attribution computations is used, + meaning that the `custom_attribution_func=None`, otherwise + it is not guaranteed and depends on the specifics of the + `custom_attribution_func`. + + Examples:: + + >>> # ImageClassifier takes a single input tensor of images Nx3x32x32, + >>> # and returns an Nx10 tensor of class probabilities. + >>> net = ImageClassifier() + >>> dl = DeepLiftShap(net) + >>> input = torch.randn(2, 3, 32, 32, requires_grad=True) + >>> # Computes shap values using deeplift for class 3. + >>> attribution = dl.attribute(input, target=3) + """ + baselines = _format_callable_baseline(baselines, inputs) + + assert isinstance(baselines[0], torch.Tensor) and baselines[0].shape[0] > 1, ( + "Baselines distribution has to be provided in form of a torch.Tensor" + " with more than one example but found: {}." + " If baselines are provided in shape of scalars or with a single" + " baseline example, `DeepLift`" + " approach can be used instead.".format(baselines[0]) + ) + + # Keeps track whether original input is a tuple or not before + # converting it into a tuple. 
+ is_inputs_tuple = _is_tuple(inputs) + + inputs = _format_tensor_into_tuples(inputs) + + # batch sizes + inp_bsz = inputs[0].shape[0] + base_bsz = baselines[0].shape[0] + + ( + exp_inp, + exp_base, + exp_tgt, + exp_addit_args, + ) = self._expand_inputs_baselines_targets( + baselines, inputs, target, additional_forward_args + ) + attributions = super().attribute.__wrapped__( # type: ignore + self, + exp_inp, + exp_base, + target=exp_tgt, + additional_forward_args=exp_addit_args, + return_convergence_delta=cast( + Literal[True, False], return_convergence_delta + ), + custom_attribution_func=custom_attribution_func, + ) + if return_convergence_delta: + attributions, delta = cast(Tuple[Tuple[Tensor, ...], Tensor], attributions) + + attributions = tuple( + self._compute_mean_across_baselines( + inp_bsz, base_bsz, cast(Tensor, attribution) + ) + for attribution in attributions + ) + + if return_convergence_delta: + return _format_output(is_inputs_tuple, attributions), delta + else: + return _format_output(is_inputs_tuple, attributions) + + def _expand_inputs_baselines_targets( + self, + baselines: Tuple[Tensor, ...], + inputs: Tuple[Tensor, ...], + target: TargetType, + additional_forward_args: Any, + ) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...], TargetType, Any]: + inp_bsz = inputs[0].shape[0] + base_bsz = baselines[0].shape[0] + + expanded_inputs = tuple( + [ + input.repeat_interleave(base_bsz, dim=0).requires_grad_() + for input in inputs + ] + ) + expanded_baselines = tuple( + [ + baseline.repeat( + (inp_bsz,) + tuple([1] * (len(baseline.shape) - 1)) + ).requires_grad_() + for baseline in baselines + ] + ) + expanded_target = _expand_target( + target, base_bsz, expansion_type=ExpansionTypes.repeat_interleave + ) + input_additional_args = ( + _expand_additional_forward_args( + additional_forward_args, + base_bsz, + expansion_type=ExpansionTypes.repeat_interleave, + ) + if additional_forward_args is not None + else None + ) + return ( + expanded_inputs, + expanded_baselines, + expanded_target, + input_additional_args, + ) + + def _compute_mean_across_baselines( + self, inp_bsz: int, base_bsz: int, attribution: Tensor + ) -> Tensor: + # Average for multiple references + attr_shape: Tuple = (inp_bsz, base_bsz) + if len(attribution.shape) > 1: + attr_shape += attribution.shape[1:] + return torch.mean(attribution.view(attr_shape), dim=1, keepdim=False) + + +def nonlinear( + module: Module, + inputs: Tensor, + outputs: Tensor, + grad_input: Tensor, + grad_output: Tensor, + eps: float = 1e-10, +): + r""" + grad_input: (dLoss / dprev_layer_out, dLoss / wij, dLoss / bij) + grad_output: (dLoss / dlayer_out) + https://github.com/pytorch/pytorch/issues/12331 + """ + delta_in, delta_out = _compute_diffs(inputs, outputs) + + new_grad_inp = list(grad_input) + + # supported non-linear modules take only single tensor as input hence accessing + # only the first element in `grad_input` and `grad_output` + new_grad_inp[0] = torch.where( + abs(delta_in) < eps, new_grad_inp[0], grad_output[0] * delta_out / delta_in + ) + + # If the module is invalid, save the newly computed gradients + # The original_grad_input will be overridden later in the Tensor hook + if module.is_invalid: + module.saved_grad = new_grad_inp[0] + return new_grad_inp + + +def softmax( + module: Module, + inputs: Tensor, + outputs: Tensor, + grad_input: Tensor, + grad_output: Tensor, + eps: float = 1e-10, +): + delta_in, delta_out = _compute_diffs(inputs, outputs) + + new_grad_inp = list(grad_input) + grad_input_unnorm = torch.where( + 
abs(delta_in) < eps, new_grad_inp[0], grad_output[0] * delta_out / delta_in + ) + # normalizing + n = grad_input[0].numel() + + # updating only the first half + new_grad_inp[0] = grad_input_unnorm - grad_input_unnorm.sum() * 1 / n + return new_grad_inp + + +def maxpool1d( + module: Module, + inputs: Tensor, + outputs: Tensor, + grad_input: Tensor, + grad_output: Tensor, + eps: float = 1e-10, +): + return maxpool( + module, + F.max_pool1d, + F.max_unpool1d, + inputs, + outputs, + grad_input, + grad_output, + eps=eps, + ) + + +def maxpool2d( + module: Module, + inputs: Tensor, + outputs: Tensor, + grad_input: Tensor, + grad_output: Tensor, + eps: float = 1e-10, +): + return maxpool( + module, + F.max_pool2d, + F.max_unpool2d, + inputs, + outputs, + grad_input, + grad_output, + eps=eps, + ) + + +def maxpool3d( + module: Module, inputs, outputs, grad_input, grad_output, eps: float = 1e-10 +): + return maxpool( + module, + F.max_pool3d, + F.max_unpool3d, + inputs, + outputs, + grad_input, + grad_output, + eps=eps, + ) + + +def maxpool( + module: Module, + pool_func: Callable, + unpool_func: Callable, + inputs, + outputs, + grad_input, + grad_output, + eps: float = 1e-10, +): + with torch.no_grad(): + input, input_ref = inputs.chunk(2) + output, output_ref = outputs.chunk(2) + + delta_in = input - input_ref + delta_in = torch.cat(2 * [delta_in]) + # Extracts cross maximum between the outputs of maxpool for the + # actual inputs and its corresponding references. In case the delta outputs + # for the references are larger the method relies on the references and + # corresponding gradients to compute the multiplies and contributions. + delta_out_xmax = torch.max(output, output_ref) + delta_out = torch.cat([delta_out_xmax - output_ref, output - delta_out_xmax]) + + _, indices = pool_func( + module.input, + module.kernel_size, + module.stride, + module.padding, + module.dilation, + module.ceil_mode, + True, + ) + grad_output_updated = grad_output[0] + unpool_grad_out_delta, unpool_grad_out_ref_delta = torch.chunk( + unpool_func( + grad_output_updated * delta_out, + indices, + module.kernel_size, + module.stride, + module.padding, + list(cast(torch.Size, module.input.shape)), + ), + 2, + ) + + unpool_grad_out_delta = unpool_grad_out_delta + unpool_grad_out_ref_delta + unpool_grad_out_delta = torch.cat(2 * [unpool_grad_out_delta]) + + # If the module is invalid, we need to recompute the grad_input + if module.is_invalid: + original_grad_input = grad_input + grad_input = ( + unpool_func( + grad_output_updated, + indices, + module.kernel_size, + module.stride, + module.padding, + list(cast(torch.Size, module.input.shape)), + ), + ) + if grad_input[0].shape != inputs.shape: + raise AssertionError( + "A problem occurred during maxpool modul's backward pass. " + "The gradients with respect to inputs include only a " + "subset of inputs. More details about this issue can " + "be found here: " + "https://pytorch.org/docs/stable/" + "nn.html#torch.nn.Module.register_backward_hook " + "This can happen for example if you attribute to the outputs of a " + "MaxPool. As a workaround, please, attribute to the inputs of " + "the following layer." 
+ ) + + new_grad_inp = torch.where( + abs(delta_in) < eps, grad_input[0], unpool_grad_out_delta / delta_in + ) + # If the module is invalid, save the newly computed gradients + # The original_grad_input will be overridden later in the Tensor hook + if module.is_invalid: + module.saved_grad = new_grad_inp + return original_grad_input + else: + return (new_grad_inp,) + + +def _compute_diffs(inputs: Tensor, outputs: Tensor) -> Tuple[Tensor, Tensor]: + input, input_ref = inputs.chunk(2) + # if the model is a single non-linear module and we apply Rescale rule on it + # we might not be able to perform chunk-ing because the output of the module is + # usually being replaced by model output. + output, output_ref = outputs.chunk(2) + delta_in = input - input_ref + delta_out = output - output_ref + + return torch.cat(2 * [delta_in]), torch.cat(2 * [delta_out]) + + +SUPPORTED_NON_LINEAR = { + nn.ReLU: nonlinear, + nn.ELU: nonlinear, + nn.LeakyReLU: nonlinear, + nn.Sigmoid: nonlinear, + nn.Tanh: nonlinear, + nn.Softplus: nonlinear, + nn.MaxPool1d: maxpool1d, + nn.MaxPool2d: maxpool2d, + nn.MaxPool3d: maxpool3d, + nn.Softmax: softmax, +} diff --git a/captum/attr/_core/feature_ablation.py b/captum/attr/_core/feature_ablation.py new file mode 100644 index 0000000000000000000000000000000000000000..fd0007fc7563ecde64d578f8a216bf0e7ec66baa --- /dev/null +++ b/captum/attr/_core/feature_ablation.py @@ -0,0 +1,591 @@ +#!/usr/bin/env python3 + +import math +from typing import Any, Callable, cast, Tuple, Union + +import torch +from captum._utils.common import ( + _expand_additional_forward_args, + _expand_target, + _format_additional_forward_args, + _format_output, + _format_tensor_into_tuples, + _is_tuple, + _run_forward, +) +from captum._utils.progress import progress +from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric +from captum.attr._utils.attribution import PerturbationAttribution +from captum.attr._utils.common import _format_input_baseline +from captum.log import log_usage +from torch import dtype, Tensor + + +class FeatureAblation(PerturbationAttribution): + r""" + A perturbation based approach to computing attribution, involving + replacing each input feature with a given baseline / reference, and + computing the difference in output. By default, each scalar value within + each input tensor is taken as a feature and replaced independently. Passing + a feature mask, allows grouping features to be ablated together. This can + be used in cases such as images, where an entire segment or region + can be ablated, measuring the importance of the segment (feature group). + Each input scalar in the group will be given the same attribution value + equal to the change in target as a result of ablating the entire feature + group. + + The forward function can either return a scalar per example or a tensor + of a fixed sized tensor (or scalar value) for the full batch, i.e. the + output does not grow as the batch size increase. If the output is fixed + we consider this model to be an "aggregation" of the inputs. In the fixed + sized output mode we require `perturbations_per_eval == 1` and the + `feature_mask` to be either `None` or for all of them to have 1 as their + first dimension (i.e. a feature mask requires to be applied to all inputs). 
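A short sketch of this fixed-size ("aggregation") output mode follows; the toy linear model and the batch-level scoring function are assumptions for illustration, not part of the patch. When the forward function collapses the whole batch to one scalar, `perturbations_per_eval` must stay at 1, each feature mask must have a first dimension of 1, and the returned attributions also have a first dimension of 1:

import torch
import torch.nn as nn
from captum.attr import FeatureAblation

model = nn.Linear(4, 1)

def batch_score(x):
    # One scalar for the entire batch -> aggregation output mode.
    return model(x).sum()

inputs = torch.randn(8, 4)
feature_mask = torch.tensor([[0, 0, 1, 1]])  # first dim 1: same groups for every example

ablator = FeatureAblation(batch_score)
attr = ablator.attribute(inputs, feature_mask=feature_mask, perturbations_per_eval=1)
print(attr.shape)  # torch.Size([1, 4]): one attribution row shared by the batch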
+ """ + + def __init__(self, forward_func: Callable) -> None: + r""" + Args: + + forward_func (callable): The forward function of the model or + any modification of it + """ + PerturbationAttribution.__init__(self, forward_func) + self.use_weights = False + + @log_usage() + def attribute( + self, + inputs: TensorOrTupleOfTensorsGeneric, + baselines: BaselineType = None, + target: TargetType = None, + additional_forward_args: Any = None, + feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None, + perturbations_per_eval: int = 1, + show_progress: bool = False, + **kwargs: Any, + ) -> TensorOrTupleOfTensorsGeneric: + r""" + Args: + + inputs (tensor or tuple of tensors): Input for which ablation + attributions are computed. If forward_func takes a single + tensor as input, a single input tensor should be provided. + If forward_func takes multiple tensors as input, a tuple + of the input tensors should be provided. It is assumed + that for all given input tensors, dimension 0 corresponds + to the number of examples (aka batch size), and if + multiple input tensors are provided, the examples must + be aligned appropriately. + baselines (scalar, tensor, tuple of scalars or tensors, optional): + Baselines define reference value which replaces each + feature when ablated. + Baselines can be provided as: + + - a single tensor, if inputs is a single tensor, with + exactly the same dimensions as inputs or + broadcastable to match the dimensions of inputs + + - a single scalar, if inputs is a single tensor, which will + be broadcasted for each input value in input tensor. + + - a tuple of tensors or scalars, the baseline corresponding + to each tensor in the inputs' tuple can be: + + - either a tensor with matching dimensions to + corresponding tensor in the inputs' tuple + or the first dimension is one and the remaining + dimensions match with the corresponding + input tensor. + + - or a scalar, corresponding to a tensor in the + inputs' tuple. This scalar value is broadcasted + for corresponding input tensor. + In the cases when `baselines` is not provided, we internally + use zero scalar corresponding to each input tensor. + Default: None + target (int, tuple, tensor or list, optional): Output indices for + which gradients are computed (for classification cases, + this is usually the target class). + If the network returns a scalar value per example, + no target index is necessary. + For general 2D outputs, targets can be either: + + - a single integer or a tensor containing a single + integer, which is applied to all input examples + + - a list of integers or a 1D tensor, with length matching + the number of examples in inputs (dim 0). Each integer + is applied as the target for the corresponding example. + + For outputs with > 2 dimensions, targets can be either: + + - A single tuple, which contains #output_dims - 1 + elements. This target index is applied to all examples. + + - A list of tuples with length equal to the number of + examples in inputs (dim 0), and each tuple containing + #output_dims - 1 elements. Each tuple is applied as the + target for the corresponding example. + + Default: None + additional_forward_args (any, optional): If the forward function + requires additional arguments other than the inputs for + which attributions should not be computed, this argument + can be provided. It must be either a single additional + argument of a Tensor or arbitrary (non-tuple) type or a + tuple containing multiple additional arguments including + tensors or any arbitrary python types. 
These arguments + are provided to forward_func in order following the + arguments in inputs. + For a tensor, the first dimension of the tensor must + correspond to the number of examples. For all other types, + the given argument is used for all forward evaluations. + Note that attributions are not computed with respect + to these arguments. + Default: None + feature_mask (tensor or tuple of tensors, optional): + feature_mask defines a mask for the input, grouping + features which should be ablated together. feature_mask + should contain the same number of tensors as inputs. + Each tensor should + be the same size as the corresponding input or + broadcastable to match the input tensor. Each tensor + should contain integers in the range 0 to num_features + - 1, and indices corresponding to the same feature should + have the same value. + Note that features within each input tensor are ablated + independently (not across tensors). + If the forward function returns a single scalar per batch, + we enforce that the first dimension of each mask must be 1, + since attributions are returned batch-wise rather than per + example, so the attributions must correspond to the + same features (indices) in each input example. + If None, then a feature mask is constructed which assigns + each scalar within a tensor as a separate feature, which + is ablated independently. + Default: None + perturbations_per_eval (int, optional): Allows ablation of multiple + features to be processed simultaneously in one call to + forward_fn. + Each forward pass will contain a maximum of + perturbations_per_eval * #examples samples. + For DataParallel models, each batch is split among the + available devices, so evaluations on each available + device contain at most + (perturbations_per_eval * #examples) / num_devices + samples. + If the forward function's number of outputs does not + change as the batch size grows (e.g. if it outputs a + scalar value), you must set perturbations_per_eval to 1 + and use a single feature mask to describe the features + for all examples in the batch. + Default: 1 + show_progress (bool, optional): Displays the progress of computation. + It will try to use tqdm if available for advanced features + (e.g. time estimation). Otherwise, it will fallback to + a simple output of progress. + Default: False + **kwargs (Any, optional): Any additional arguments used by child + classes of FeatureAblation (such as Occlusion) to construct + ablations. These arguments are ignored when using + FeatureAblation directly. + Default: None + + Returns: + *tensor* or tuple of *tensors* of **attributions**: + - **attributions** (*tensor* or tuple of *tensors*): + The attributions with respect to each input feature. + If the forward function returns + a scalar value per example, attributions will be + the same size as the provided inputs, with each value + providing the attribution of the corresponding input index. + If the forward function returns a scalar per batch, then + attribution tensor(s) will have first dimension 1 and + the remaining dimensions will match the input. + If a single tensor is provided as inputs, a single tensor is + returned. If a tuple of tensors is provided for inputs, a + tuple of corresponding sized tensors is returned. + + + Examples:: + + >>> # SimpleClassifier takes a single input tensor of size Nx4x4, + >>> # and returns an Nx3 tensor of class probabilities. 
+ >>> net = SimpleClassifier() + >>> # Generating random input with size 2 x 4 x 4 + >>> input = torch.randn(2, 4, 4) + >>> # Defining FeatureAblation interpreter + >>> ablator = FeatureAblation(net) + >>> # Computes ablation attribution, ablating each of the 16 + >>> # scalar input independently. + >>> attr = ablator.attribute(input, target=1) + + >>> # Alternatively, we may want to ablate features in groups, e.g. + >>> # grouping each 2x2 square of the inputs and ablating them together. + >>> # This can be done by creating a feature mask as follows, which + >>> # defines the feature groups, e.g.: + >>> # +---+---+---+---+ + >>> # | 0 | 0 | 1 | 1 | + >>> # +---+---+---+---+ + >>> # | 0 | 0 | 1 | 1 | + >>> # +---+---+---+---+ + >>> # | 2 | 2 | 3 | 3 | + >>> # +---+---+---+---+ + >>> # | 2 | 2 | 3 | 3 | + >>> # +---+---+---+---+ + >>> # With this mask, all inputs with the same value are ablated + >>> # simultaneously, and the attribution for each input in the same + >>> # group (0, 1, 2, and 3) per example are the same. + >>> # The attributions can be calculated as follows: + >>> # feature mask has dimensions 1 x 4 x 4 + >>> feature_mask = torch.tensor([[[0,0,1,1],[0,0,1,1], + >>> [2,2,3,3],[2,2,3,3]]]) + >>> attr = ablator.attribute(input, target=1, feature_mask=feature_mask) + """ + # Keeps track whether original input is a tuple or not before + # converting it into a tuple. + is_inputs_tuple = _is_tuple(inputs) + inputs, baselines = _format_input_baseline(inputs, baselines) + additional_forward_args = _format_additional_forward_args( + additional_forward_args + ) + num_examples = inputs[0].shape[0] + feature_mask = ( + _format_tensor_into_tuples(feature_mask) + if feature_mask is not None + else None + ) + assert ( + isinstance(perturbations_per_eval, int) and perturbations_per_eval >= 1 + ), "Perturbations per evaluation must be an integer and at least 1." + with torch.no_grad(): + if show_progress: + feature_counts = self._get_feature_counts( + inputs, feature_mask, **kwargs + ) + total_forwards = ( + sum( + math.ceil(count / perturbations_per_eval) + for count in feature_counts + ) + + 1 + ) # add 1 for the initial eval + attr_progress = progress( + desc=f"{self.get_name()} attribution", total=total_forwards + ) + attr_progress.update(0) + + # Computes initial evaluation with all features, which is compared + # to each ablated result. + initial_eval = _run_forward( + self.forward_func, inputs, target, additional_forward_args + ) + + if show_progress: + attr_progress.update() + + agg_output_mode = FeatureAblation._find_output_mode( + perturbations_per_eval, feature_mask + ) + + # get as a 2D tensor (if it is not a scalar) + if isinstance(initial_eval, torch.Tensor): + initial_eval = initial_eval.reshape(1, -1) + num_outputs = initial_eval.shape[1] + else: + num_outputs = 1 + + if not agg_output_mode: + assert ( + isinstance(initial_eval, torch.Tensor) + and num_outputs == num_examples + ), ( + "expected output of `forward_func` to have " + + "`batch_size` elements for perturbations_per_eval > 1 " + + "and all feature_mask.shape[0] > 1" + ) + + # Initialize attribution totals and counts + attrib_type = cast( + dtype, + initial_eval.dtype + if isinstance(initial_eval, Tensor) + else type(initial_eval), + ) + + total_attrib = [ + torch.zeros( + (num_outputs,) + input.shape[1:], + dtype=attrib_type, + device=input.device, + ) + for input in inputs + ] + + # Weights are used in cases where ablations may be overlapping. 
+ if self.use_weights: + weights = [ + torch.zeros( + (num_outputs,) + input.shape[1:], device=input.device + ).float() + for input in inputs + ] + + # Iterate through each feature tensor for ablation + for i in range(len(inputs)): + # Skip any empty input tensors + if torch.numel(inputs[i]) == 0: + continue + + for ( + current_inputs, + current_add_args, + current_target, + current_mask, + ) in self._ith_input_ablation_generator( + i, + inputs, + additional_forward_args, + target, + baselines, + feature_mask, + perturbations_per_eval, + **kwargs, + ): + # modified_eval dimensions: 1D tensor with length + # equal to #num_examples * #features in batch + modified_eval = _run_forward( + self.forward_func, + current_inputs, + current_target, + current_add_args, + ) + + if show_progress: + attr_progress.update() + + # (contains 1 more dimension than inputs). This adds extra + # dimensions of 1 to make the tensor broadcastable with the inputs + # tensor. + if not isinstance(modified_eval, torch.Tensor): + eval_diff = initial_eval - modified_eval + else: + if not agg_output_mode: + assert ( + modified_eval.numel() == current_inputs[0].shape[0] + ), """expected output of forward_func to grow with + batch_size. If this is not the case for your model + please set perturbations_per_eval = 1""" + + eval_diff = ( + initial_eval - modified_eval.reshape((-1, num_outputs)) + ).reshape((-1, num_outputs) + (len(inputs[i].shape) - 1) * (1,)) + eval_diff = eval_diff.to(total_attrib[i].device) + if self.use_weights: + weights[i] += current_mask.float().sum(dim=0) + total_attrib[i] += (eval_diff * current_mask.to(attrib_type)).sum( + dim=0 + ) + + if show_progress: + attr_progress.close() + + # Divide total attributions by counts and return formatted attributions + if self.use_weights: + attrib = tuple( + single_attrib.float() / weight + for single_attrib, weight in zip(total_attrib, weights) + ) + else: + attrib = tuple(total_attrib) + _result = _format_output(is_inputs_tuple, attrib) + return _result + + def _ith_input_ablation_generator( + self, + i, + inputs, + additional_args, + target, + baselines, + input_mask, + perturbations_per_eval, + **kwargs, + ): + """ + This method return an generator of ablation perturbations of the i-th input + + Returns: + ablation_iter (generator): yields each perturbation to be evaluated + as a tuple (inputs, additional_forward_args, targets, mask). + """ + extra_args = {} + for key, value in kwargs.items(): + # For any tuple argument in kwargs, we choose index i of the tuple. + if isinstance(value, tuple): + extra_args[key] = value[i] + else: + extra_args[key] = value + + input_mask = input_mask[i] if input_mask is not None else None + min_feature, num_features, input_mask = self._get_feature_range_and_mask( + inputs[i], input_mask, **extra_args + ) + num_examples = inputs[0].shape[0] + perturbations_per_eval = min(perturbations_per_eval, num_features) + baseline = baselines[i] if isinstance(baselines, tuple) else baselines + if isinstance(baseline, torch.Tensor): + baseline = baseline.reshape((1,) + baseline.shape) + + if perturbations_per_eval > 1: + # Repeat features and additional args for batch size. 
+ all_features_repeated = [ + torch.cat([inputs[j]] * perturbations_per_eval, dim=0) + for j in range(len(inputs)) + ] + additional_args_repeated = ( + _expand_additional_forward_args(additional_args, perturbations_per_eval) + if additional_args is not None + else None + ) + target_repeated = _expand_target(target, perturbations_per_eval) + else: + all_features_repeated = list(inputs) + additional_args_repeated = additional_args + target_repeated = target + + num_features_processed = min_feature + while num_features_processed < num_features: + current_num_ablated_features = min( + perturbations_per_eval, num_features - num_features_processed + ) + + # Store appropriate inputs and additional args based on batch size. + if current_num_ablated_features != perturbations_per_eval: + current_features = [ + feature_repeated[0 : current_num_ablated_features * num_examples] + for feature_repeated in all_features_repeated + ] + current_additional_args = ( + _expand_additional_forward_args( + additional_args, current_num_ablated_features + ) + if additional_args is not None + else None + ) + current_target = _expand_target(target, current_num_ablated_features) + else: + current_features = all_features_repeated + current_additional_args = additional_args_repeated + current_target = target_repeated + + # Store existing tensor before modifying + original_tensor = current_features[i] + # Construct ablated batch for features in range num_features_processed + # to num_features_processed + current_num_ablated_features and return + # mask with same size as ablated batch. ablated_features has dimension + # (current_num_ablated_features, num_examples, inputs[i].shape[1:]) + # Note that in the case of sparse tensors, the second dimension + # may not necessarilly be num_examples and will match the first + # dimension of this tensor. + current_reshaped = current_features[i].reshape( + (current_num_ablated_features, -1) + current_features[i].shape[1:] + ) + + ablated_features, current_mask = self._construct_ablated_input( + current_reshaped, + input_mask, + baseline, + num_features_processed, + num_features_processed + current_num_ablated_features, + **extra_args, + ) + + # current_features[i] has dimension + # (current_num_ablated_features * num_examples, inputs[i].shape[1:]), + # which can be provided to the model as input. + current_features[i] = ablated_features.reshape( + (-1,) + ablated_features.shape[2:] + ) + yield tuple( + current_features + ), current_additional_args, current_target, current_mask + # Replace existing tensor at index i. + current_features[i] = original_tensor + num_features_processed += current_num_ablated_features + + def _construct_ablated_input( + self, expanded_input, input_mask, baseline, start_feature, end_feature, **kwargs + ): + r""" + Ablates given expanded_input tensor with given feature mask, feature range, + and baselines. expanded_input shape is (`num_features`, `num_examples`, ...) + with remaining dimensions corresponding to remaining original tensor + dimensions and `num_features` = `end_feature` - `start_feature`. + input_mask has same number of dimensions as original input tensor (one less + than `expanded_input`), and can have first dimension either 1, applying same + feature mask to all examples, or `num_examples`. baseline is expected to + be broadcastable to match `expanded_input`. 
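A tiny numeric trace of this ablation rule may help; the shapes and values below are made up for illustration and mirror the mask construction that follows:

import torch

# expanded_input: (num_features = 2, num_examples = 1, 3 scalars per example)
expanded_input = torch.tensor([[[1., 2., 3.]],
                               [[1., 2., 3.]]])
input_mask = torch.tensor([[0, 0, 1]])  # scalars 0, 1 -> group 0; scalar 2 -> group 1
baseline = 0.0

current_mask = torch.stack(
    [input_mask == j for j in range(0, 2)], dim=0
).long()                                 # shape (2, 1, 3)
ablated = (expanded_input * (1 - current_mask).to(expanded_input.dtype)
           + baseline * current_mask.to(expanded_input.dtype))
# ablated[0] ablates group 0 -> [[0., 0., 3.]]
# ablated[1] ablates group 1 -> [[1., 2., 0.]]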
+ + This method returns the ablated input tensor, which has the same + dimensionality as `expanded_input` as well as the corresponding mask with + either the same dimensionality as `expanded_input` or second dimension + being 1. This mask contains 1s in locations which have been ablated (and + thus counted towards ablations for that feature) and 0s otherwise. + """ + current_mask = torch.stack( + [input_mask == j for j in range(start_feature, end_feature)], dim=0 + ).long() + ablated_tensor = ( + expanded_input * (1 - current_mask).to(expanded_input.dtype) + ) + (baseline * current_mask.to(expanded_input.dtype)) + return ablated_tensor, current_mask + + def _get_feature_range_and_mask(self, input, input_mask, **kwargs): + if input_mask is None: + # Obtain feature mask for selected input tensor, matches size of + # 1 input example, (1 x inputs[i].shape[1:]) + input_mask = torch.reshape( + torch.arange(torch.numel(input[0]), device=input.device), + input[0:1].shape, + ).long() + return ( + torch.min(input_mask).item(), + torch.max(input_mask).item() + 1, + input_mask, + ) + + def _get_feature_counts(self, inputs, feature_mask, **kwargs): + """return the numbers of input features""" + if not feature_mask: + return tuple(inp[0].numel() if inp.numel() else 0 for inp in inputs) + + return tuple( + (mask.max() - mask.min()).item() + 1 + if mask is not None + else (inp[0].numel() if inp.numel() else 0) + for inp, mask in zip(inputs, feature_mask) + ) + + @staticmethod + def _find_output_mode( + perturbations_per_eval: int, + feature_mask: Union[None, TensorOrTupleOfTensorsGeneric], + ) -> bool: + """ + Returns True if the output mode is "aggregation output mode" + + Aggregation output mode is defined as: when there is no 1:1 correspondence + with the `num_examples` (`batch_size`) and the amount of outputs your model + produces, i.e. the model output does not grow in size as the input becomes + larger. + + We assume this is the case if `perturbations_per_eval == 1` + and your feature mask is None or is associated to all + examples in a batch (fm.shape[0] == 1 for all fm in feature_mask). + """ + return perturbations_per_eval == 1 and ( + feature_mask is None + or all(len(sm.shape) == 0 or sm.shape[0] == 1 for sm in feature_mask) + ) diff --git a/captum/attr/_core/feature_permutation.py b/captum/attr/_core/feature_permutation.py new file mode 100644 index 0000000000000000000000000000000000000000..544ff16ac611c757563f9e3bd73acac75fcf2e32 --- /dev/null +++ b/captum/attr/_core/feature_permutation.py @@ -0,0 +1,305 @@ +#!/usr/bin/env python3 +from typing import Any, Callable, Tuple, Union + +import torch +from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric +from captum.attr._core.feature_ablation import FeatureAblation +from captum.log import log_usage +from torch import Tensor + + +def _permute_feature(x: Tensor, feature_mask: Tensor) -> Tensor: + n = x.size(0) + assert n > 1, "cannot permute features with batch_size = 1" + + perm = torch.randperm(n) + no_perm = torch.arange(n) + while (perm == no_perm).all(): + perm = torch.randperm(n) + + return (x[perm] * feature_mask.to(dtype=x.dtype)) + ( + x * feature_mask.bitwise_not().to(dtype=x.dtype) + ) + + +class FeaturePermutation(FeatureAblation): + r""" + A perturbation based approach to compute attribution, which + takes each input feature, permutes the feature values within a batch, + and computes the difference between original and shuffled outputs for + the given batch. 
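The shuffling itself is delegated to a `perm_func` (the `_permute_feature` default above); a custom two-argument function can be supplied instead. A minimal sketch, assuming the mask handed to `perm_func` can be treated as boolean and that a deterministic roll is an acceptable stand-in for a random permutation:

import torch
import torch.nn as nn
from captum.attr import FeaturePermutation

def roll_feature(x: torch.Tensor, feature_mask: torch.Tensor) -> torch.Tensor:
    # Shift the masked feature's values by one position within the batch,
    # leaving every other position untouched.
    shifted = torch.roll(x, shifts=1, dims=0)
    return torch.where(feature_mask.bool(), shifted, x)

net = nn.Linear(4, 3)              # toy model, assumed for illustration
inputs = torch.randn(16, 4)        # a batch with more than one example is required
fp = FeaturePermutation(net, perm_func=roll_feature)
attr = fp.attribute(inputs, target=1)
print(attr.shape)  # torch.Size([16, 4]): same shape as inputs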
This difference signifies the feature importance + for the permuted feature. + + Example pseudocode for the algorithm is as follows:: + + perm_feature_importance(batch): + importance = dict() + baseline_error = error_metric(model(batch), batch_labels) + for each feature: + permute this feature across the batch + error = error_metric(model(permuted_batch), batch_labels) + importance[feature] = baseline_error - error + "un-permute" the feature across the batch + + return importance + + It should be noted that the `error_metric` must be called in the + `forward_func`. You do not need to have an error metric, e.g. you + could simply return the logits (the model output), but this may or may + not provide a meaningful attribution. + + This method, unlike other attribution methods, requires a batch + of examples to compute attributions and cannot be performed on a single example. + + By default, each scalar value within + each input tensor is taken as a feature and shuffled independently. Passing + a feature mask, allows grouping features to be shuffled together. + Each input scalar in the group will be given the same attribution value + equal to the change in target as a result of shuffling the entire feature + group. + + The forward function can either return a scalar per example, or a single + scalar for the full batch. If a single scalar is returned for the batch, + `perturbations_per_eval` must be 1, and the returned attributions will have + first dimension 1, corresponding to feature importance across all + examples in the batch. + + More information can be found in the permutation feature + importance algorithm description here: + https://christophm.github.io/interpretable-ml-book/feature-importance.html + """ + + def __init__( + self, forward_func: Callable, perm_func: Callable = _permute_feature + ) -> None: + r""" + Args: + + forward_func (callable): The forward function of the model or + any modification of it + perm_func (callable, optional): A function that accepts a batch of + inputs and a feature mask, and "permutes" the feature using + feature mask across the batch. This defaults to a function + which applies a random permutation, this argument only needs + to be provided if a custom permutation behavior is desired. + Default: `_permute_feature` + """ + FeatureAblation.__init__(self, forward_func=forward_func) + self.perm_func = perm_func + + # suppressing error caused by the child class not having a matching + # signature to the parent + @log_usage() + def attribute( # type: ignore + self, + inputs: TensorOrTupleOfTensorsGeneric, + target: TargetType = None, + additional_forward_args: Any = None, + feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None, + perturbations_per_eval: int = 1, + show_progress: bool = False, + **kwargs: Any, + ) -> TensorOrTupleOfTensorsGeneric: + r""" + This function is almost equivalent to `FeatureAblation.attribute`. The + main difference is the way ablated examples are generated. Specifically + they are generated through the `perm_func`, as we set the baselines for + `FeatureAblation.attribute` to None. + + + Args: + inputs (tensor or tuple of tensors): Input for which + permutation attributions are computed. If + forward_func takes a single tensor as input, a + single input tensor should be provided. If + forward_func takes multiple tensors as input, a + tuple of the input tensors should be provided. 
It is + assumed that for all given input tensors, dimension + 0 corresponds to the number of examples (aka batch + size), and if multiple input tensors are provided, + the examples must be aligned appropriately. + target (int, tuple, tensor or list, optional): Output indices for + which difference is computed (for classification cases, + this is usually the target class). + If the network returns a scalar value per example, + no target index is necessary. + For general 2D outputs, targets can be either: + + - a single integer or a tensor containing a single + integer, which is applied to all input examples + + - a list of integers or a 1D tensor, with length matching + the number of examples in inputs (dim 0). Each integer + is applied as the target for the corresponding example. + + For outputs with > 2 dimensions, targets can be either: + + - A single tuple, which contains #output_dims - 1 + elements. This target index is applied to all examples. + + - A list of tuples with length equal to the number of + examples in inputs (dim 0), and each tuple containing + #output_dims - 1 elements. Each tuple is applied as the + target for the corresponding example. + + Default: None + additional_forward_args (any, optional): If the forward function + requires additional arguments other than the inputs for + which attributions should not be computed, this argument + can be provided. It must be either a single additional + argument of a Tensor or arbitrary (non-tuple) type or a + tuple containing multiple additional arguments including + tensors or any arbitrary python types. These arguments + are provided to forward_func in order following the + arguments in inputs. + For a tensor, the first dimension of the tensor must + correspond to the number of examples. For all other types, + the given argument is used for all forward evaluations. + Note that attributions are not computed with respect + to these arguments. + Default: None + feature_mask (tensor or tuple of tensors, optional): + feature_mask defines a mask for the input, grouping + features which should be ablated together. feature_mask + should contain the same number of tensors as inputs. + Each tensor should be the same size as the + corresponding input or broadcastable to match the + input tensor. Each tensor should contain integers in + the range 0 to num_features - 1, and indices + corresponding to the same feature should have the + same value. Note that features within each input + tensor are ablated independently (not across + tensors). + + The first dimension of each mask must be 1, as we require + to have the same group of features for each input sample. + + If None, then a feature mask is constructed which assigns + each scalar within a tensor as a separate feature, which + is permuted independently. + Default: None + perturbations_per_eval (int, optional): Allows permutations + of multiple features to be processed simultaneously + in one call to forward_fn. Each forward pass will + contain a maximum of perturbations_per_eval * #examples + samples. For DataParallel models, each batch is + split among the available devices, so evaluations on + each available device contain at most + (perturbations_per_eval * #examples) / num_devices + samples. + If the forward function returns a single scalar per batch, + perturbations_per_eval must be set to 1. + Default: 1 + show_progress (bool, optional): Displays the progress of computation. + It will try to use tqdm if available for advanced features + (e.g. time estimation). 
Otherwise, it will fallback to + a simple output of progress. + Default: False + **kwargs (Any, optional): Any additional arguments used by child + classes of FeatureAblation (such as Occlusion) to construct + ablations. These arguments are ignored when using + FeatureAblation directly. + Default: None + + Returns: + *tensor* or tuple of *tensors* of **attributions**: + - **attributions** (*tensor* or tuple of *tensors*): + The attributions with respect to each input feature. + If the forward function returns + a scalar value per example, attributions will be + the same size as the provided inputs, with each value + providing the attribution of the corresponding input index. + If the forward function returns a scalar per batch, then + attribution tensor(s) will have first dimension 1 and + the remaining dimensions will match the input. + If a single tensor is provided as inputs, a single tensor is + returned. If a tuple of tensors is provided for inputs, + a tuple of corresponding sized tensors is returned. + + + Examples:: + + >>> # SimpleClassifier takes a single input tensor of size Nx4x4, + >>> # and returns an Nx3 tensor of class probabilities. + >>> net = SimpleClassifier() + >>> # Generating random input with size 10 x 4 x 4 + >>> input = torch.randn(10, 4, 4) + >>> # Defining FeaturePermutation interpreter + >>> feature_perm = FeaturePermutation(net) + >>> # Computes permutation attribution, shuffling each of the 16 + >>> # scalar input independently. + >>> attr = feature_perm.attribute(input, target=1) + + >>> # Alternatively, we may want to permute features in groups, e.g. + >>> # grouping each 2x2 square of the inputs and shuffling them together. + >>> # This can be done by creating a feature mask as follows, which + >>> # defines the feature groups, e.g.: + >>> # +---+---+---+---+ + >>> # | 0 | 0 | 1 | 1 | + >>> # +---+---+---+---+ + >>> # | 0 | 0 | 1 | 1 | + >>> # +---+---+---+---+ + >>> # | 2 | 2 | 3 | 3 | + >>> # +---+---+---+---+ + >>> # | 2 | 2 | 3 | 3 | + >>> # +---+---+---+---+ + >>> # With this mask, all inputs with the same value are shuffled + >>> # simultaneously, and the attribution for each input in the same + >>> # group (0, 1, 2, and 3) per example are the same. + >>> # The attributions can be calculated as follows: + >>> # feature mask has dimensions 1 x 4 x 4 + >>> feature_mask = torch.tensor([[[0,0,1,1],[0,0,1,1], + >>> [2,2,3,3],[2,2,3,3]]]) + >>> attr = feature_perm.attribute(input, target=1, + >>> feature_mask=feature_mask) + """ + return FeatureAblation.attribute.__wrapped__( + self, + inputs, + baselines=None, + target=target, + additional_forward_args=additional_forward_args, + feature_mask=feature_mask, + perturbations_per_eval=perturbations_per_eval, + show_progress=show_progress, + **kwargs, + ) + + def _construct_ablated_input( + self, + expanded_input: Tensor, + input_mask: Tensor, + baseline: Union[int, float, Tensor], + start_feature: int, + end_feature: int, + **kwargs: Any, + ) -> Tuple[Tensor, Tensor]: + r""" + This function permutes the features of `expanded_input` with a given + feature mask and feature range. Permutation occurs via calling + `self.perm_func` across each batch within `expanded_input`. As with + `FeatureAblation._construct_ablated_input`: + - `expanded_input.shape = (num_features, num_examples, ...)` + - `num_features = end_feature - start_feature` (i.e. 
start and end is a + half-closed interval) + - `input_mask` is a tensor of the same shape as one input, which + describes the locations of each feature via their "index" + + Since `baselines` is set to None for `FeatureAblation.attribute, this + will be the zero tensor, however, it is not used. + """ + assert input_mask.shape[0] == 1, ( + "input_mask.shape[0] != 1: pass in one mask in order to permute" + "the same features for each input" + ) + current_mask = torch.stack( + [input_mask == j for j in range(start_feature, end_feature)], dim=0 + ).bool() + + output = torch.stack( + [ + self.perm_func(x, mask.squeeze(0)) + for x, mask in zip(expanded_input, current_mask) + ] + ) + return output, current_mask diff --git a/captum/attr/_core/gradient_shap.py b/captum/attr/_core/gradient_shap.py new file mode 100644 index 0000000000000000000000000000000000000000..57d5e909af3b7d362759b6774b7db2cd6fda8b12 --- /dev/null +++ b/captum/attr/_core/gradient_shap.py @@ -0,0 +1,414 @@ +#!/usr/bin/env python3 +import typing +from typing import Any, Callable, Tuple, Union + +import numpy as np +import torch +from captum._utils.common import _is_tuple +from captum._utils.typing import ( + BaselineType, + Literal, + TargetType, + Tensor, + TensorOrTupleOfTensorsGeneric, +) +from captum.attr._core.noise_tunnel import NoiseTunnel +from captum.attr._utils.attribution import GradientAttribution +from captum.attr._utils.common import ( + _compute_conv_delta_and_format_attrs, + _format_callable_baseline, + _format_input_baseline, +) +from captum.log import log_usage + + +class GradientShap(GradientAttribution): + r""" + Implements gradient SHAP based on the implementation from SHAP's primary + author. For reference, please view the original + `implementation + `_ + and the paper: `A Unified Approach to Interpreting Model Predictions + `_ + + GradientShap approximates SHAP values by computing the expectations of + gradients by randomly sampling from the distribution of baselines/references. + It adds white noise to each input sample `n_samples` times, selects a + random baseline from baselines' distribution and a random point along the + path between the baseline and the input, and computes the gradient of outputs + with respect to those selected random points. The final SHAP values represent + the expected values of gradients * (inputs - baselines). + + GradientShap makes an assumption that the input features are independent + and that the explanation model is linear, meaning that the explanations + are modeled through the additive composition of feature effects. + Under those assumptions, SHAP value can be approximated as the expectation + of gradients that are computed for randomly generated `n_samples` input + samples after adding gaussian noise `n_samples` times to each input for + different baselines/references. + + In some sense it can be viewed as an approximation of integrated gradients + by computing the expectations of gradients for different baselines. + + Current implementation uses Smoothgrad from `NoiseTunnel` in order to + randomly draw samples from the distribution of baselines, add noise to input + samples and compute the expectation (smoothgrad). + """ + + def __init__(self, forward_func: Callable, multiply_by_inputs: bool = True) -> None: + r""" + Args: + + forward_func (function): The forward function of the model or + any modification of it. + multiply_by_inputs (bool, optional): Indicates whether to factor + model inputs' multiplier in the final attribution scores. 
+ In the literature this is also known as local vs global + attribution. If inputs' multiplier isn't factored in + then this type of attribution method is also called local + attribution. If it is, then that type of attribution + method is called global. + More detailed can be found here: + https://arxiv.org/abs/1711.06104 + + In case of gradient shap, if `multiply_by_inputs` + is set to True, the sensitivity scores of scaled inputs + are being multiplied by (inputs - baselines). + """ + GradientAttribution.__init__(self, forward_func) + self._multiply_by_inputs = multiply_by_inputs + + @typing.overload + def attribute( + self, + inputs: TensorOrTupleOfTensorsGeneric, + baselines: Union[ + TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric] + ], + n_samples: int = 5, + stdevs: Union[float, Tuple[float, ...]] = 0.0, + target: TargetType = None, + additional_forward_args: Any = None, + *, + return_convergence_delta: Literal[True], + ) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: + ... + + @typing.overload + def attribute( + self, + inputs: TensorOrTupleOfTensorsGeneric, + baselines: Union[ + TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric] + ], + n_samples: int = 5, + stdevs: Union[float, Tuple[float, ...]] = 0.0, + target: TargetType = None, + additional_forward_args: Any = None, + return_convergence_delta: Literal[False] = False, + ) -> TensorOrTupleOfTensorsGeneric: + ... + + @log_usage() + def attribute( + self, + inputs: TensorOrTupleOfTensorsGeneric, + baselines: Union[ + TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric] + ], + n_samples: int = 5, + stdevs: Union[float, Tuple[float, ...]] = 0.0, + target: TargetType = None, + additional_forward_args: Any = None, + return_convergence_delta: bool = False, + ) -> Union[ + TensorOrTupleOfTensorsGeneric, Tuple[TensorOrTupleOfTensorsGeneric, Tensor] + ]: + r""" + Args: + + inputs (tensor or tuple of tensors): Input for which SHAP attribution + values are computed. If `forward_func` takes a single + tensor as input, a single input tensor should be provided. + If `forward_func` takes multiple tensors as input, a tuple + of the input tensors should be provided. It is assumed + that for all given input tensors, dimension 0 corresponds + to the number of examples, and if multiple input tensors + are provided, the examples must be aligned appropriately. + baselines (tensor, tuple of tensors, callable): + Baselines define the starting point from which expectation + is computed and can be provided as: + + - a single tensor, if inputs is a single tensor, with + the first dimension equal to the number of examples + in the baselines' distribution. The remaining dimensions + must match with input tensor's dimension starting from + the second dimension. + + - a tuple of tensors, if inputs is a tuple of tensors, + with the first dimension of any tensor inside the tuple + equal to the number of examples in the baseline's + distribution. The remaining dimensions must match + the dimensions of the corresponding input tensor + starting from the second dimension. + + - callable function, optionally takes `inputs` as an + argument and either returns a single tensor + or a tuple of those. + + It is recommended that the number of samples in the baselines' + tensors is larger than one. + n_samples (int, optional): The number of randomly generated examples + per sample in the input batch. Random examples are + generated by adding gaussian random noise to each sample. 
+ Default: `5` if `n_samples` is not provided. + stdevs (float, or a tuple of floats optional): The standard deviation + of gaussian noise with zero mean that is added to each + input in the batch. If `stdevs` is a single float value + then that same value is used for all inputs. If it is + a tuple, then it must have the same length as the inputs + tuple. In this case, each stdev value in the stdevs tuple + corresponds to the input with the same index in the inputs + tuple. + Default: 0.0 + target (int, tuple, tensor or list, optional): Output indices for + which gradients are computed (for classification cases, + this is usually the target class). + If the network returns a scalar value per example, + no target index is necessary. + For general 2D outputs, targets can be either: + + - a single integer or a tensor containing a single + integer, which is applied to all input examples + + - a list of integers or a 1D tensor, with length matching + the number of examples in inputs (dim 0). Each integer + is applied as the target for the corresponding example. + + For outputs with > 2 dimensions, targets can be either: + + - A single tuple, which contains #output_dims - 1 + elements. This target index is applied to all examples. + + - A list of tuples with length equal to the number of + examples in inputs (dim 0), and each tuple containing + #output_dims - 1 elements. Each tuple is applied as the + target for the corresponding example. + + Default: None + additional_forward_args (any, optional): If the forward function + requires additional arguments other than the inputs for + which attributions should not be computed, this argument + can be provided. It can contain a tuple of ND tensors or + any arbitrary python type of any shape. + In case of the ND tensor the first dimension of the + tensor must correspond to the batch size. It will be + repeated for each `n_steps` for each randomly generated + input sample. + Note that the gradients are not computed with respect + to these arguments. + Default: None + return_convergence_delta (bool, optional): Indicates whether to return + convergence delta or not. If `return_convergence_delta` + is set to True convergence delta will be returned in + a tuple following attributions. + Default: False + Returns: + **attributions** or 2-element tuple of **attributions**, **delta**: + - **attributions** (*tensor* or tuple of *tensors*): + Attribution score computed based on GradientSHAP with respect + to each input feature. Attributions will always be + the same size as the provided inputs, with each value + providing the attribution of the corresponding input index. + If a single tensor is provided as inputs, a single tensor is + returned. If a tuple is provided for inputs, a tuple of + corresponding sized tensors is returned. + - **delta** (*tensor*, returned if return_convergence_delta=True): + This is computed using the property that the total + sum of forward_func(inputs) - forward_func(baselines) + must be very close to the total sum of the attributions + based on GradientSHAP. + Delta is calculated for each example in the input after adding + `n_samples` times gaussian noise to each of them. Therefore, + the dimensionality of the deltas tensor is equal to the + `number of examples in the input` * `n_samples` + The deltas are ordered by each input example and `n_samples` + noisy samples generated for it. + + Examples:: + + >>> # ImageClassifier takes a single input tensor of images Nx3x32x32, + >>> # and returns an Nx10 tensor of class probabilities. 
+ >>> net = ImageClassifier() + >>> gradient_shap = GradientShap(net) + >>> input = torch.randn(3, 3, 32, 32, requires_grad=True) + >>> # choosing baselines randomly + >>> baselines = torch.randn(20, 3, 32, 32) + >>> # Computes gradient shap for the input + >>> # Attribution size matches input size: 3x3x32x32 + >>> attribution = gradient_shap.attribute(input, baselines, + target=5) + + """ + # since `baselines` is a distribution, we can generate it using a function + # rather than passing it as an input argument + baselines = _format_callable_baseline(baselines, inputs) + assert isinstance(baselines[0], torch.Tensor), ( + "Baselines distribution has to be provided in a form " + "of a torch.Tensor {}.".format(baselines[0]) + ) + + input_min_baseline_x_grad = InputBaselineXGradient( + self.forward_func, self.multiplies_by_inputs + ) + input_min_baseline_x_grad.gradient_func = self.gradient_func + + nt = NoiseTunnel(input_min_baseline_x_grad) + + # NOTE: using attribute.__wrapped__ to not log + attributions = nt.attribute.__wrapped__( + nt, # self + inputs, + nt_type="smoothgrad", + nt_samples=n_samples, + stdevs=stdevs, + draw_baseline_from_distrib=True, + baselines=baselines, + target=target, + additional_forward_args=additional_forward_args, + return_convergence_delta=return_convergence_delta, + ) + + return attributions + + def has_convergence_delta(self) -> bool: + return True + + @property + def multiplies_by_inputs(self): + return self._multiply_by_inputs + + +class InputBaselineXGradient(GradientAttribution): + def __init__(self, forward_func: Callable, multiply_by_inputs=True) -> None: + r""" + Args: + + forward_func (function): The forward function of the model or + any modification of it + multiply_by_inputs (bool, optional): Indicates whether to factor + model inputs' multiplier in the final attribution scores. + In the literature this is also known as local vs global + attribution. If inputs' multiplier isn't factored in + then this type of attribution method is also called local + attribution. If it is, then that type of attribution + method is called global. + More detailed can be found here: + https://arxiv.org/abs/1711.06104 + + In case of gradient shap, if `multiply_by_inputs` + is set to True, the sensitivity scores of scaled inputs + are being multiplied by (inputs - baselines). + + """ + GradientAttribution.__init__(self, forward_func) + self._multiply_by_inputs = multiply_by_inputs + + @typing.overload + def attribute( + self, + inputs: TensorOrTupleOfTensorsGeneric, + baselines: BaselineType = None, + target: TargetType = None, + additional_forward_args: Any = None, + *, + return_convergence_delta: Literal[True], + ) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: + ... + + @typing.overload + def attribute( + self, + inputs: TensorOrTupleOfTensorsGeneric, + baselines: BaselineType = None, + target: TargetType = None, + additional_forward_args: Any = None, + return_convergence_delta: Literal[False] = False, + ) -> TensorOrTupleOfTensorsGeneric: + ... + + @log_usage() + def attribute( # type: ignore + self, + inputs: TensorOrTupleOfTensorsGeneric, + baselines: BaselineType = None, + target: TargetType = None, + additional_forward_args: Any = None, + return_convergence_delta: bool = False, + ) -> Union[ + TensorOrTupleOfTensorsGeneric, Tuple[TensorOrTupleOfTensorsGeneric, Tensor] + ]: + # Keeps track whether original input is a tuple or not before + # converting it into a tuple. 
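+        # Overview of the steps below (descriptive comment): draw one uniform
+        # random coefficient per example, form the interpolated point
+        # coeff * input + (1 - coeff) * baseline, evaluate the gradients of
+        # the forward function at that point and, when multiplies_by_inputs
+        # is True, rescale them by (input - baseline). Averaging many such
+        # samples (done via NoiseTunnel in GradientShap.attribute)
+        # approximates the expectation of gradients * (inputs - baselines)
+        # described in the GradientShap docstring.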
+ is_inputs_tuple = _is_tuple(inputs) + inputs, baselines = _format_input_baseline(inputs, baselines) + + rand_coefficient = torch.tensor( + np.random.uniform(0.0, 1.0, inputs[0].shape[0]), + device=inputs[0].device, + dtype=inputs[0].dtype, + ) + + input_baseline_scaled = tuple( + _scale_input(input, baseline, rand_coefficient) + for input, baseline in zip(inputs, baselines) + ) + grads = self.gradient_func( + self.forward_func, input_baseline_scaled, target, additional_forward_args + ) + + if self.multiplies_by_inputs: + input_baseline_diffs = tuple( + input - baseline for input, baseline in zip(inputs, baselines) + ) + attributions = tuple( + input_baseline_diff * grad + for input_baseline_diff, grad in zip(input_baseline_diffs, grads) + ) + else: + attributions = grads + + return _compute_conv_delta_and_format_attrs( + self, + return_convergence_delta, + attributions, + baselines, + inputs, + additional_forward_args, + target, + is_inputs_tuple, + ) + + def has_convergence_delta(self) -> bool: + return True + + @property + def multiplies_by_inputs(self): + return self._multiply_by_inputs + + +def _scale_input( + input: Tensor, baseline: Union[Tensor, int, float], rand_coefficient: Tensor +) -> Tensor: + # batch size + bsz = input.shape[0] + inp_shape_wo_bsz = input.shape[1:] + inp_shape = (bsz,) + tuple([1] * len(inp_shape_wo_bsz)) + + # expand and reshape the indices + rand_coefficient = rand_coefficient.view(inp_shape) + + input_baseline_scaled = ( + rand_coefficient * input + (1.0 - rand_coefficient) * baseline + ).requires_grad_() + return input_baseline_scaled diff --git a/captum/attr/_core/guided_backprop_deconvnet.py b/captum/attr/_core/guided_backprop_deconvnet.py new file mode 100644 index 0000000000000000000000000000000000000000..e1953ed5b94476e2ade87ce8418fb60dec9033a6 --- /dev/null +++ b/captum/attr/_core/guided_backprop_deconvnet.py @@ -0,0 +1,322 @@ +#!/usr/bin/env python3 +import warnings +from typing import Any, List, Tuple, Union + +import torch +import torch.nn.functional as F +from captum._utils.common import ( + _format_output, + _format_tensor_into_tuples, + _is_tuple, + _register_backward_hook, +) +from captum._utils.gradient import ( + apply_gradient_requirements, + undo_gradient_requirements, +) +from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric +from captum.attr._utils.attribution import GradientAttribution +from captum.log import log_usage +from torch import Tensor +from torch.nn import Module +from torch.utils.hooks import RemovableHandle + + +class ModifiedReluGradientAttribution(GradientAttribution): + def __init__(self, model: Module, use_relu_grad_output: bool = False) -> None: + r""" + Args: + + model (nn.Module): The reference to PyTorch model instance. + """ + GradientAttribution.__init__(self, model) + self.model = model + self.backward_hooks: List[RemovableHandle] = [] + self.use_relu_grad_output = use_relu_grad_output + assert isinstance(self.model, torch.nn.Module), ( + "Given model must be an instance of torch.nn.Module to properly hook" + " ReLU layers." + ) + + @log_usage() + def attribute( + self, + inputs: TensorOrTupleOfTensorsGeneric, + target: TargetType = None, + additional_forward_args: Any = None, + ) -> TensorOrTupleOfTensorsGeneric: + r""" + Computes attribution by overriding relu gradients. Based on constructor + flag use_relu_grad_output, performs either GuidedBackpropagation if False + and Deconvolution if True. 
This class is the parent class of both these + methods, more information on usage can be found in the docstrings for each + implementing class. + """ + + # Keeps track whether original input is a tuple or not before + # converting it into a tuple. + is_inputs_tuple = _is_tuple(inputs) + + inputs = _format_tensor_into_tuples(inputs) + gradient_mask = apply_gradient_requirements(inputs) + + # set hooks for overriding ReLU gradients + warnings.warn( + "Setting backward hooks on ReLU activations." + "The hooks will be removed after the attribution is finished" + ) + try: + self.model.apply(self._register_hooks) + + gradients = self.gradient_func( + self.forward_func, inputs, target, additional_forward_args + ) + finally: + self._remove_hooks() + + undo_gradient_requirements(inputs, gradient_mask) + return _format_output(is_inputs_tuple, gradients) + + def _register_hooks(self, module: Module): + if isinstance(module, torch.nn.ReLU): + hook = _register_backward_hook(module, self._backward_hook, self) + self.backward_hooks.append(hook) + + def _backward_hook( + self, + module: Module, + grad_input: Union[Tensor, Tuple[Tensor, ...]], + grad_output: Union[Tensor, Tuple[Tensor, ...]], + ): + to_override_grads = grad_output if self.use_relu_grad_output else grad_input + if isinstance(to_override_grads, tuple): + return tuple( + F.relu(to_override_grad) for to_override_grad in to_override_grads + ) + else: + return F.relu(to_override_grads) + + def _remove_hooks(self): + for hook in self.backward_hooks: + hook.remove() + + +class GuidedBackprop(ModifiedReluGradientAttribution): + r""" + Computes attribution using guided backpropagation. Guided backpropagation + computes the gradient of the target output with respect to the input, + but gradients of ReLU functions are overridden so that only + non-negative gradients are backpropagated. + + More details regarding the guided backpropagation algorithm can be found + in the original paper here: + https://arxiv.org/abs/1412.6806 + + Warning: Ensure that all ReLU operations in the forward function of the + given model are performed using a module (nn.module.ReLU). + If nn.functional.ReLU is used, gradients are not overridden appropriately. + """ + + def __init__(self, model: Module) -> None: + r""" + Args: + + model (nn.Module): The reference to PyTorch model instance. Model cannot + contain any in-place ReLU submodules; these are not + supported by the register_full_backward_hook PyTorch API. + """ + ModifiedReluGradientAttribution.__init__( + self, model, use_relu_grad_output=False + ) + + @log_usage() + def attribute( + self, + inputs: TensorOrTupleOfTensorsGeneric, + target: TargetType = None, + additional_forward_args: Any = None, + ) -> TensorOrTupleOfTensorsGeneric: + r""" + Args: + + inputs (tensor or tuple of tensors): Input for which + attributions are computed. If forward_func takes a single + tensor as input, a single input tensor should be provided. + If forward_func takes multiple tensors as input, a tuple + of the input tensors should be provided. It is assumed + that for all given input tensors, dimension 0 corresponds + to the number of examples (aka batch size), and if + multiple input tensors are provided, the examples must + be aligned appropriately. + target (int, tuple, tensor or list, optional): Output indices for + which gradients are computed (for classification cases, + this is usually the target class). + If the network returns a scalar value per example, + no target index is necessary. 
+ For general 2D outputs, targets can be either: + + - a single integer or a tensor containing a single + integer, which is applied to all input examples + + - a list of integers or a 1D tensor, with length matching + the number of examples in inputs (dim 0). Each integer + is applied as the target for the corresponding example. + + For outputs with > 2 dimensions, targets can be either: + + - A single tuple, which contains #output_dims - 1 + elements. This target index is applied to all examples. + + - A list of tuples with length equal to the number of + examples in inputs (dim 0), and each tuple containing + #output_dims - 1 elements. Each tuple is applied as the + target for the corresponding example. + + Default: None + additional_forward_args (any, optional): If the forward function + requires additional arguments other than the inputs for + which attributions should not be computed, this argument + can be provided. It must be either a single additional + argument of a Tensor or arbitrary (non-tuple) type or a tuple + containing multiple additional arguments including tensors + or any arbitrary python types. These arguments are provided to + forward_func in order, following the arguments in inputs. + Note that attributions are not computed with respect + to these arguments. + Default: None + + Returns: + *tensor* or tuple of *tensors* of **attributions**: + - **attributions** (*tensor* or tuple of *tensors*): + The guided backprop gradients with respect to each + input feature. Attributions will always + be the same size as the provided inputs, with each value + providing the attribution of the corresponding input index. + If a single tensor is provided as inputs, a single tensor is + returned. If a tuple is provided for inputs, a tuple of + corresponding sized tensors is returned. + + Examples:: + + >>> # ImageClassifier takes a single input tensor of images Nx3x32x32, + >>> # and returns an Nx10 tensor of class probabilities. + >>> net = ImageClassifier() + >>> gbp = GuidedBackprop(net) + >>> input = torch.randn(2, 3, 32, 32, requires_grad=True) + >>> # Computes Guided Backprop attribution scores for class 3. + >>> attribution = gbp.attribute(input, target=3) + """ + return super().attribute.__wrapped__( + self, inputs, target, additional_forward_args + ) + + +class Deconvolution(ModifiedReluGradientAttribution): + r""" + Computes attribution using deconvolution. Deconvolution + computes the gradient of the target output with respect to the input, + but gradients of ReLU functions are overridden so that the gradient + of the ReLU input is simply computed taking ReLU of the output gradient, + essentially only propagating non-negative gradients (without + dependence on the sign of the ReLU input). + + More details regarding the deconvolution algorithm can be found + in these papers: + https://arxiv.org/abs/1311.2901 + https://link.springer.com/chapter/10.1007/978-3-319-46466-4_8 + + Warning: Ensure that all ReLU operations in the forward function of the + given model are performed using a module (nn.module.ReLU). + If nn.functional.ReLU is used, gradients are not overridden appropriately. + """ + + def __init__(self, model: Module) -> None: + r""" + Args: + + model (nn.Module): The reference to PyTorch model instance. Model cannot + contain any in-place ReLU submodules; these are not + supported by the register_full_backward_hook PyTorch API. 
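+
+            A model whose ReLUs are created with ``inplace=True`` can usually
+            be made compatible by switching that flag off on the existing
+            modules before constructing this object, for example::
+
+                >>> for module in model.modules():
+                >>>     if isinstance(module, torch.nn.ReLU):
+                >>>         module.inplace = False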
+ """ + ModifiedReluGradientAttribution.__init__(self, model, use_relu_grad_output=True) + + @log_usage() + def attribute( + self, + inputs: TensorOrTupleOfTensorsGeneric, + target: TargetType = None, + additional_forward_args: Any = None, + ) -> TensorOrTupleOfTensorsGeneric: + r""" + Args: + + inputs (tensor or tuple of tensors): Input for which + attributions are computed. If forward_func takes a single + tensor as input, a single input tensor should be provided. + If forward_func takes multiple tensors as input, a tuple + of the input tensors should be provided. It is assumed + that for all given input tensors, dimension 0 corresponds + to the number of examples (aka batch size), and if + multiple input tensors are provided, the examples must + be aligned appropriately. + target (int, tuple, tensor or list, optional): Output indices for + which gradients are computed (for classification cases, + this is usually the target class). + If the network returns a scalar value per example, + no target index is necessary. + For general 2D outputs, targets can be either: + + - a single integer or a tensor containing a single + integer, which is applied to all input examples + + - a list of integers or a 1D tensor, with length matching + the number of examples in inputs (dim 0). Each integer + is applied as the target for the corresponding example. + + For outputs with > 2 dimensions, targets can be either: + + - A single tuple, which contains #output_dims - 1 + elements. This target index is applied to all examples. + + - A list of tuples with length equal to the number of + examples in inputs (dim 0), and each tuple containing + #output_dims - 1 elements. Each tuple is applied as the + target for the corresponding example. + + Default: None + additional_forward_args (any, optional): If the forward function + requires additional arguments other than the inputs for + which attributions should not be computed, this argument + can be provided. It must be either a single additional + argument of a Tensor or arbitrary (non-tuple) type or a tuple + containing multiple additional arguments including tensors + or any arbitrary python types. These arguments are provided to + forward_func in order, following the arguments in inputs. + Note that attributions are not computed with respect + to these arguments. + Default: None + + Returns: + *tensor* or tuple of *tensors* of **attributions**: + - **attributions** (*tensor* or tuple of *tensors*): + The deconvolution attributions with respect to each + input feature. Attributions will always + be the same size as the provided inputs, with each value + providing the attribution of the corresponding input index. + If a single tensor is provided as inputs, a single tensor is + returned. If a tuple is provided for inputs, a tuple of + corresponding sized tensors is returned. + + Examples:: + + >>> # ImageClassifier takes a single input tensor of images Nx3x32x32, + >>> # and returns an Nx10 tensor of class probabilities. + >>> net = ImageClassifier() + >>> deconv = Deconvolution(net) + >>> input = torch.randn(2, 3, 32, 32, requires_grad=True) + >>> # Computes Deconvolution attribution scores for class 3. 
+ >>> attribution = deconv.attribute(input, target=3) + """ + return super().attribute.__wrapped__( + self, inputs, target, additional_forward_args + ) diff --git a/captum/attr/_core/guided_grad_cam.py b/captum/attr/_core/guided_grad_cam.py new file mode 100644 index 0000000000000000000000000000000000000000..f6e29c4b2941ddaf49545a320042d13f6f33ba9a --- /dev/null +++ b/captum/attr/_core/guided_grad_cam.py @@ -0,0 +1,226 @@ +#!/usr/bin/env python3 +import warnings +from typing import Any, List, Union + +import torch +from captum._utils.common import _format_output, _format_tensor_into_tuples, _is_tuple +from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric +from captum.attr._core.guided_backprop_deconvnet import GuidedBackprop +from captum.attr._core.layer.grad_cam import LayerGradCam +from captum.attr._utils.attribution import GradientAttribution, LayerAttribution +from captum.log import log_usage +from torch import Tensor +from torch.nn import Module + + +class GuidedGradCam(GradientAttribution): + r""" + Computes element-wise product of guided backpropagation attributions + with upsampled (non-negative) GradCAM attributions. + GradCAM attributions are computed with respect to the layer + provided in the constructor, and attributions + are upsampled to match the input size. GradCAM is designed for + convolutional neural networks, and is usually applied to the last + convolutional layer. + + Note that if multiple input tensors are provided, attributions for + each input tensor are computed by upsampling the GradCAM + attributions to match that input's dimensions. If interpolation is + not possible for the input tensor dimensions and interpolation mode, + then an empty tensor is returned in the attributions for the + corresponding position of that input tensor. This can occur if the + input tensor does not have the same number of dimensions as the chosen + layer's output or is not either 3D, 4D or 5D. + + Note that attributions are only meaningful for input tensors + which are spatially alligned with the chosen layer, e.g. an input + image tensor for a convolutional layer. + + More details regarding GuidedGradCAM can be found in the original + GradCAM paper here: + https://arxiv.org/pdf/1610.02391.pdf + + Warning: Ensure that all ReLU operations in the forward function of the + given model are performed using a module (nn.module.ReLU). + If nn.functional.ReLU is used, gradients are not overridden appropriately. + """ + + def __init__( + self, model: Module, layer: Module, device_ids: Union[None, List[int]] = None + ) -> None: + r""" + Args: + + model (nn.Module): The reference to PyTorch model instance. Model cannot + contain any in-place ReLU submodules; these are not + supported by the register_full_backward_hook PyTorch API + starting from PyTorch v1.9. + layer (torch.nn.Module): Layer for which GradCAM attributions are computed. + Currently, only layers with a single tensor output are + supported. + device_ids (list(int)): Device ID list, necessary only if forward_func + applies a DataParallel model. This allows reconstruction of + intermediate outputs from batched results across devices. + If forward_func is given as the DataParallel model itself, + then it is not necessary to provide this argument. 
+ """ + GradientAttribution.__init__(self, model) + self.grad_cam = LayerGradCam(model, layer, device_ids) + self.guided_backprop = GuidedBackprop(model) + + @log_usage() + def attribute( + self, + inputs: TensorOrTupleOfTensorsGeneric, + target: TargetType = None, + additional_forward_args: Any = None, + interpolate_mode: str = "nearest", + attribute_to_layer_input: bool = False, + ) -> TensorOrTupleOfTensorsGeneric: + r""" + Args: + + inputs (tensor or tuple of tensors): Input for which attributions + are computed. If forward_func takes a single + tensor as input, a single input tensor should be provided. + If forward_func takes multiple tensors as input, a tuple + of the input tensors should be provided. It is assumed + that for all given input tensors, dimension 0 corresponds + to the number of examples, and if multiple input tensors + are provided, the examples must be aligned appropriately. + target (int, tuple, tensor or list, optional): Output indices for + which gradients are computed (for classification cases, + this is usually the target class). + If the network returns a scalar value per example, + no target index is necessary. + For general 2D outputs, targets can be either: + + - a single integer or a tensor containing a single + integer, which is applied to all input examples + + - a list of integers or a 1D tensor, with length matching + the number of examples in inputs (dim 0). Each integer + is applied as the target for the corresponding example. + + For outputs with > 2 dimensions, targets can be either: + + - A single tuple, which contains #output_dims - 1 + elements. This target index is applied to all examples. + + - A list of tuples with length equal to the number of + examples in inputs (dim 0), and each tuple containing + #output_dims - 1 elements. Each tuple is applied as the + target for the corresponding example. + + Default: None + additional_forward_args (any, optional): If the forward function + requires additional arguments other than the inputs for + which attributions should not be computed, this argument + can be provided. It must be either a single additional + argument of a Tensor or arbitrary (non-tuple) type or a + tuple containing multiple additional arguments including + tensors or any arbitrary python types. These arguments + are provided to forward_func in order following the + arguments in inputs. + Note that attributions are not computed with respect + to these arguments. + Default: None + interpolate_mode (str, optional): Method for interpolation, which + must be a valid input interpolation mode for + torch.nn.functional. These methods are + "nearest", "area", "linear" (3D-only), "bilinear" + (4D-only), "bicubic" (4D-only), "trilinear" (5D-only) + based on the number of dimensions of the chosen layer + output (which must also match the number of + dimensions for the input tensor). Note that + the original GradCAM paper uses "bilinear" + interpolation, but we default to "nearest" for + applicability to any of 3D, 4D or 5D tensors. + Default: "nearest" + attribute_to_layer_input (bool, optional): Indicates whether to + compute the attribution with respect to the layer input + or output in `LayerGradCam`. + If `attribute_to_layer_input` is set to True + then the attributions will be computed with respect to + layer inputs, otherwise it will be computed with respect + to layer outputs. 
+ Note that currently it is assumed that either the input + or the output of internal layer, depending on whether we + attribute to the input or output, is a single tensor. + Support for multiple tensors will be added later. + Default: False + + Returns: + *tensor* of **attributions**: + - **attributions** (*tensor*): + Element-wise product of (upsampled) GradCAM + and Guided Backprop attributions. + If a single tensor is provided as inputs, a single tensor is + returned. If a tuple is provided for inputs, a tuple of + corresponding sized tensors is returned. + Attributions will be the same size as the provided inputs, + with each value providing the attribution of the + corresponding input index. + If the GradCAM attributions cannot be upsampled to the shape + of a given input tensor, None is returned in the corresponding + index position. + + + Examples:: + + >>> # ImageClassifier takes a single input tensor of images Nx3x32x32, + >>> # and returns an Nx10 tensor of class probabilities. + >>> # It contains an attribute conv4, which is an instance of nn.conv2d, + >>> # and the output of this layer has dimensions Nx50x8x8. + >>> # It is the last convolution layer, which is the recommended + >>> # use case for GuidedGradCAM. + >>> net = ImageClassifier() + >>> guided_gc = GuidedGradCam(net, net.conv4) + >>> input = torch.randn(2, 3, 32, 32, requires_grad=True) + >>> # Computes guided GradCAM attributions for class 3. + >>> # attribution size matches input size, Nx3x32x32 + >>> attribution = guided_gc.attribute(input, 3) + """ + is_inputs_tuple = _is_tuple(inputs) + inputs = _format_tensor_into_tuples(inputs) + grad_cam_attr = self.grad_cam.attribute.__wrapped__( + self.grad_cam, # self + inputs=inputs, + target=target, + additional_forward_args=additional_forward_args, + attribute_to_layer_input=attribute_to_layer_input, + relu_attributions=True, + ) + if isinstance(grad_cam_attr, tuple): + assert len(grad_cam_attr) == 1, ( + "GuidedGradCAM attributions for layer with multiple inputs / " + "outputs is not supported." + ) + grad_cam_attr = grad_cam_attr[0] + + guided_backprop_attr = self.guided_backprop.attribute.__wrapped__( + self.guided_backprop, # self + inputs=inputs, + target=target, + additional_forward_args=additional_forward_args, + ) + output_attr: List[Tensor] = [] + for i in range(len(inputs)): + try: + output_attr.append( + guided_backprop_attr[i] + * LayerAttribution.interpolate( + grad_cam_attr, + inputs[i].shape[2:], + interpolate_mode=interpolate_mode, + ) + ) + except Exception: + warnings.warn( + "Couldn't appropriately interpolate GradCAM attributions for some " + "input tensors, returning empty tensor for corresponding " + "attributions." 
+ ) + output_attr.append(torch.empty(0)) + + return _format_output(is_inputs_tuple, tuple(output_attr)) diff --git a/captum/attr/_core/input_x_gradient.py b/captum/attr/_core/input_x_gradient.py new file mode 100644 index 0000000000000000000000000000000000000000..7817466013541a209a2c915a097eaa97db1db705 --- /dev/null +++ b/captum/attr/_core/input_x_gradient.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python3 +from typing import Any, Callable + +from captum._utils.common import _format_output, _format_tensor_into_tuples, _is_tuple +from captum._utils.gradient import ( + apply_gradient_requirements, + undo_gradient_requirements, +) +from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric +from captum.attr._utils.attribution import GradientAttribution +from captum.log import log_usage + + +class InputXGradient(GradientAttribution): + r""" + A baseline approach for computing the attribution. It multiplies input with + the gradient with respect to input. + https://arxiv.org/abs/1605.01713 + """ + + def __init__(self, forward_func: Callable) -> None: + r""" + Args: + + forward_func (callable): The forward function of the model or any + modification of it + """ + GradientAttribution.__init__(self, forward_func) + + @log_usage() + def attribute( + self, + inputs: TensorOrTupleOfTensorsGeneric, + target: TargetType = None, + additional_forward_args: Any = None, + ) -> TensorOrTupleOfTensorsGeneric: + r""" + Args: + + inputs (tensor or tuple of tensors): Input for which + attributions are computed. If forward_func takes a single + tensor as input, a single input tensor should be provided. + If forward_func takes multiple tensors as input, a tuple + of the input tensors should be provided. It is assumed + that for all given input tensors, dimension 0 corresponds + to the number of examples (aka batch size), and if + multiple input tensors are provided, the examples must + be aligned appropriately. + target (int, tuple, tensor or list, optional): Output indices for + which gradients are computed (for classification cases, + this is usually the target class). + If the network returns a scalar value per example, + no target index is necessary. + For general 2D outputs, targets can be either: + + - a single integer or a tensor containing a single + integer, which is applied to all input examples + + - a list of integers or a 1D tensor, with length matching + the number of examples in inputs (dim 0). Each integer + is applied as the target for the corresponding example. + + For outputs with > 2 dimensions, targets can be either: + + - A single tuple, which contains #output_dims - 1 + elements. This target index is applied to all examples. + + - A list of tuples with length equal to the number of + examples in inputs (dim 0), and each tuple containing + #output_dims - 1 elements. Each tuple is applied as the + target for the corresponding example. + + Default: None + additional_forward_args (any, optional): If the forward function + requires additional arguments other than the inputs for + which attributions should not be computed, this argument + can be provided. It must be either a single additional + argument of a Tensor or arbitrary (non-tuple) type or a tuple + containing multiple additional arguments including tensors + or any arbitrary python types. These arguments are provided to + forward_func in order following the arguments in inputs. + Note that attributions are not computed with respect + to these arguments. 
+ Default: None + + Returns: + *tensor* or tuple of *tensors* of **attributions**: + - **attributions** (*tensor* or tuple of *tensors*): + The input x gradient with + respect to each input feature. Attributions will always be + the same size as the provided inputs, with each value + providing the attribution of the corresponding input index. + If a single tensor is provided as inputs, a single tensor is + returned. If a tuple is provided for inputs, a tuple of + corresponding sized tensors is returned. + + + Examples:: + + >>> # ImageClassifier takes a single input tensor of images Nx3x32x32, + >>> # and returns an Nx10 tensor of class probabilities. + >>> net = ImageClassifier() + >>> # Generating random input with size 2x3x3x32 + >>> input = torch.randn(2, 3, 32, 32, requires_grad=True) + >>> # Defining InputXGradient interpreter + >>> input_x_gradient = InputXGradient(net) + >>> # Computes inputXgradient for class 4. + >>> attribution = input_x_gradient.attribute(input, target=4) + """ + # Keeps track whether original input is a tuple or not before + # converting it into a tuple. + is_inputs_tuple = _is_tuple(inputs) + + inputs = _format_tensor_into_tuples(inputs) + gradient_mask = apply_gradient_requirements(inputs) + + gradients = self.gradient_func( + self.forward_func, inputs, target, additional_forward_args + ) + + attributions = tuple( + input * gradient for input, gradient in zip(inputs, gradients) + ) + + undo_gradient_requirements(inputs, gradient_mask) + return _format_output(is_inputs_tuple, attributions) + + @property + def multiplies_by_inputs(self): + return True diff --git a/captum/attr/_core/integrated_gradients.py b/captum/attr/_core/integrated_gradients.py new file mode 100644 index 0000000000000000000000000000000000000000..e96a826c325cdc6786c7069165ce6bd40cd112a3 --- /dev/null +++ b/captum/attr/_core/integrated_gradients.py @@ -0,0 +1,390 @@ +#!/usr/bin/env python3 +import typing +from typing import Any, Callable, List, Tuple, Union + +import torch +from captum._utils.common import ( + _expand_additional_forward_args, + _expand_target, + _format_additional_forward_args, + _format_output, + _is_tuple, +) +from captum._utils.typing import ( + BaselineType, + Literal, + TargetType, + TensorOrTupleOfTensorsGeneric, +) +from captum.attr._utils.approximation_methods import approximation_parameters +from captum.attr._utils.attribution import GradientAttribution +from captum.attr._utils.batching import _batch_attribution +from captum.attr._utils.common import ( + _format_input_baseline, + _reshape_and_sum, + _validate_input, +) +from captum.log import log_usage +from torch import Tensor + + +class IntegratedGradients(GradientAttribution): + r""" + Integrated Gradients is an axiomatic model interpretability algorithm that + assigns an importance score to each input feature by approximating the + integral of gradients of the model's output with respect to the inputs + along the path (straight line) from given baselines / references to inputs. + + Baselines can be provided as input arguments to attribute method. + To approximate the integral we can choose to use either a variant of + Riemann sum or Gauss-Legendre quadrature rule. 
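+
+    When `multiply_by_inputs` is True (the default), the attribution of the
+    i-th input feature approximates
+
+    .. math::
+        (x_i - b_i) \int_0^1
+        \frac{\partial F(b + \alpha (x - b))}{\partial x_i} \, d\alpha
+
+    where ``x`` is the input, ``b`` the baseline and ``F`` the forward
+    function; the integral is discretized at `n_steps` points chosen by the
+    selected approximation method.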
+ + More details regarding the integrated gradients method can be found in the + original paper: + https://arxiv.org/abs/1703.01365 + + """ + + def __init__( + self, + forward_func: Callable, + multiply_by_inputs: bool = True, + ) -> None: + r""" + Args: + + forward_func (callable): The forward function of the model or any + modification of it + multiply_by_inputs (bool, optional): Indicates whether to factor + model inputs' multiplier in the final attribution scores. + In the literature this is also known as local vs global + attribution. If inputs' multiplier isn't factored in, + then that type of attribution method is also called local + attribution. If it is, then that type of attribution + method is called global. + More detailed can be found here: + https://arxiv.org/abs/1711.06104 + + In case of integrated gradients, if `multiply_by_inputs` + is set to True, final sensitivity scores are being multiplied by + (inputs - baselines). + """ + GradientAttribution.__init__(self, forward_func) + self._multiply_by_inputs = multiply_by_inputs + + # The following overloaded method signatures correspond to the case where + # return_convergence_delta is False, then only attributions are returned, + # and when return_convergence_delta is True, the return type is + # a tuple with both attributions and deltas. + @typing.overload + def attribute( + self, + inputs: TensorOrTupleOfTensorsGeneric, + baselines: BaselineType = None, + target: TargetType = None, + additional_forward_args: Any = None, + n_steps: int = 50, + method: str = "gausslegendre", + internal_batch_size: Union[None, int] = None, + return_convergence_delta: Literal[False] = False, + ) -> TensorOrTupleOfTensorsGeneric: + ... + + @typing.overload + def attribute( + self, + inputs: TensorOrTupleOfTensorsGeneric, + baselines: BaselineType = None, + target: TargetType = None, + additional_forward_args: Any = None, + n_steps: int = 50, + method: str = "gausslegendre", + internal_batch_size: Union[None, int] = None, + *, + return_convergence_delta: Literal[True], + ) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: + ... + + @log_usage() + def attribute( # type: ignore + self, + inputs: TensorOrTupleOfTensorsGeneric, + baselines: BaselineType = None, + target: TargetType = None, + additional_forward_args: Any = None, + n_steps: int = 50, + method: str = "gausslegendre", + internal_batch_size: Union[None, int] = None, + return_convergence_delta: bool = False, + ) -> Union[ + TensorOrTupleOfTensorsGeneric, Tuple[TensorOrTupleOfTensorsGeneric, Tensor] + ]: + r""" + This method attributes the output of the model with given target index + (in case it is provided, otherwise it assumes that output is a + scalar) to the inputs of the model using the approach described above. + + In addition to that it also returns, if `return_convergence_delta` is + set to True, integral approximation delta based on the completeness + property of integrated gradients. + + Args: + + inputs (tensor or tuple of tensors): Input for which integrated + gradients are computed. If forward_func takes a single + tensor as input, a single input tensor should be provided. + If forward_func takes multiple tensors as input, a tuple + of the input tensors should be provided. It is assumed + that for all given input tensors, dimension 0 corresponds + to the number of examples, and if multiple input tensors + are provided, the examples must be aligned appropriately. 
+ baselines (scalar, tensor, tuple of scalars or tensors, optional): + Baselines define the starting point from which integral + is computed and can be provided as: + + - a single tensor, if inputs is a single tensor, with + exactly the same dimensions as inputs or the first + dimension is one and the remaining dimensions match + with inputs. + + - a single scalar, if inputs is a single tensor, which will + be broadcasted for each input value in input tensor. + + - a tuple of tensors or scalars, the baseline corresponding + to each tensor in the inputs' tuple can be: + + - either a tensor with matching dimensions to + corresponding tensor in the inputs' tuple + or the first dimension is one and the remaining + dimensions match with the corresponding + input tensor. + + - or a scalar, corresponding to a tensor in the + inputs' tuple. This scalar value is broadcasted + for corresponding input tensor. + In the cases when `baselines` is not provided, we internally + use zero scalar corresponding to each input tensor. + + Default: None + target (int, tuple, tensor or list, optional): Output indices for + which gradients are computed (for classification cases, + this is usually the target class). + If the network returns a scalar value per example, + no target index is necessary. + For general 2D outputs, targets can be either: + + - a single integer or a tensor containing a single + integer, which is applied to all input examples + + - a list of integers or a 1D tensor, with length matching + the number of examples in inputs (dim 0). Each integer + is applied as the target for the corresponding example. + + For outputs with > 2 dimensions, targets can be either: + + - A single tuple, which contains #output_dims - 1 + elements. This target index is applied to all examples. + + - A list of tuples with length equal to the number of + examples in inputs (dim 0), and each tuple containing + #output_dims - 1 elements. Each tuple is applied as the + target for the corresponding example. + + Default: None + additional_forward_args (any, optional): If the forward function + requires additional arguments other than the inputs for + which attributions should not be computed, this argument + can be provided. It must be either a single additional + argument of a Tensor or arbitrary (non-tuple) type or a + tuple containing multiple additional arguments including + tensors or any arbitrary python types. These arguments + are provided to forward_func in order following the + arguments in inputs. + For a tensor, the first dimension of the tensor must + correspond to the number of examples. It will be + repeated for each of `n_steps` along the integrated + path. For all other types, the given argument is used + for all forward evaluations. + Note that attributions are not computed with respect + to these arguments. + Default: None + n_steps (int, optional): The number of steps used by the approximation + method. Default: 50. + method (string, optional): Method for approximating the integral, + one of `riemann_right`, `riemann_left`, `riemann_middle`, + `riemann_trapezoid` or `gausslegendre`. + Default: `gausslegendre` if no method is provided. + internal_batch_size (int, optional): Divides total #steps * #examples + data points into chunks of size at most internal_batch_size, + which are computed (forward / backward passes) + sequentially. internal_batch_size must be at least equal to + #examples. 
+ For DataParallel models, each batch is split among the + available devices, so evaluations on each available + device contain internal_batch_size / num_devices examples. + If internal_batch_size is None, then all evaluations are + processed in one batch. + Default: None + return_convergence_delta (bool, optional): Indicates whether to return + convergence delta or not. If `return_convergence_delta` + is set to True convergence delta will be returned in + a tuple following attributions. + Default: False + Returns: + **attributions** or 2-element tuple of **attributions**, **delta**: + - **attributions** (*tensor* or tuple of *tensors*): + Integrated gradients with respect to each input feature. + attributions will always be the same size as the provided + inputs, with each value providing the attribution of the + corresponding input index. + If a single tensor is provided as inputs, a single tensor is + returned. If a tuple is provided for inputs, a tuple of + corresponding sized tensors is returned. + - **delta** (*tensor*, returned if return_convergence_delta=True): + The difference between the total approximated and true + integrated gradients. This is computed using the property + that the total sum of forward_func(inputs) - + forward_func(baselines) must equal the total sum of the + integrated gradient. + Delta is calculated per example, meaning that the number of + elements in returned delta tensor is equal to the number of + of examples in inputs. + + Examples:: + + >>> # ImageClassifier takes a single input tensor of images Nx3x32x32, + >>> # and returns an Nx10 tensor of class probabilities. + >>> net = ImageClassifier() + >>> ig = IntegratedGradients(net) + >>> input = torch.randn(2, 3, 32, 32, requires_grad=True) + >>> # Computes integrated gradients for class 3. + >>> attribution = ig.attribute(input, target=3) + """ + # Keeps track whether original input is a tuple or not before + # converting it into a tuple. 
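+ # A minimal usage sketch of the convergence check (illustrative only, reusing
+ # the hypothetical ImageClassifier / input from the docstring example above;
+ # not part of the algorithm below):
+ # >>> ig = IntegratedGradients(net)
+ # >>> attributions, delta = ig.attribute(
+ # >>>     input, target=3, n_steps=200, return_convergence_delta=True
+ # >>> )
+ # >>> # delta holds one value per example; by the completeness property its
+ # >>> # magnitude should shrink as n_steps grows.
+ # >>> print(delta.abs().max())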
+ is_inputs_tuple = _is_tuple(inputs) + + inputs, baselines = _format_input_baseline(inputs, baselines) + + _validate_input(inputs, baselines, n_steps, method) + + if internal_batch_size is not None: + num_examples = inputs[0].shape[0] + attributions = _batch_attribution( + self, + num_examples, + internal_batch_size, + n_steps, + inputs=inputs, + baselines=baselines, + target=target, + additional_forward_args=additional_forward_args, + method=method, + ) + else: + attributions = self._attribute( + inputs=inputs, + baselines=baselines, + target=target, + additional_forward_args=additional_forward_args, + n_steps=n_steps, + method=method, + ) + + if return_convergence_delta: + start_point, end_point = baselines, inputs + # computes approximation error based on the completeness axiom + delta = self.compute_convergence_delta( + attributions, + start_point, + end_point, + additional_forward_args=additional_forward_args, + target=target, + ) + return _format_output(is_inputs_tuple, attributions), delta + return _format_output(is_inputs_tuple, attributions) + + def _attribute( + self, + inputs: Tuple[Tensor, ...], + baselines: Tuple[Union[Tensor, int, float], ...], + target: TargetType = None, + additional_forward_args: Any = None, + n_steps: int = 50, + method: str = "gausslegendre", + step_sizes_and_alphas: Union[None, Tuple[List[float], List[float]]] = None, + ) -> Tuple[Tensor, ...]: + if step_sizes_and_alphas is None: + # retrieve step size and scaling factor for specified + # approximation method + step_sizes_func, alphas_func = approximation_parameters(method) + step_sizes, alphas = step_sizes_func(n_steps), alphas_func(n_steps) + else: + step_sizes, alphas = step_sizes_and_alphas + + # scale features and compute gradients. (batch size is abbreviated as bsz) + # scaled_features' dim -> (bsz * #steps x inputs[0].shape[1:], ...) + scaled_features_tpl = tuple( + torch.cat( + [baseline + alpha * (input - baseline) for alpha in alphas], dim=0 + ).requires_grad_() + for input, baseline in zip(inputs, baselines) + ) + + additional_forward_args = _format_additional_forward_args( + additional_forward_args + ) + # apply number of steps to additional forward args + # currently, number of steps is applied only to additional forward arguments + # that are nd-tensors. It is assumed that the first dimension is + # the number of batches. + # dim -> (bsz * #steps x additional_forward_args[0].shape[1:], ...) + input_additional_args = ( + _expand_additional_forward_args(additional_forward_args, n_steps) + if additional_forward_args is not None + else None + ) + expanded_target = _expand_target(target, n_steps) + + # grads: dim -> (bsz * #steps x inputs[0].shape[1:], ...) 
+ grads = self.gradient_func(
+ forward_fn=self.forward_func,
+ inputs=scaled_features_tpl,
+ target_ind=expanded_target,
+ additional_forward_args=input_additional_args,
+ )
+
+ # flattening grads so that we can multiply them by the step size
+ # calling contiguous to avoid memory layout problems
+ scaled_grads = [
+ grad.contiguous().view(n_steps, -1)
+ * torch.tensor(step_sizes).view(n_steps, 1).to(grad.device)
+ for grad in grads
+ ]
+
+ # aggregates across all steps for each tensor in the input tuple
+ # total_grads has the same dimensionality as inputs
+ total_grads = tuple(
+ _reshape_and_sum(
+ scaled_grad, n_steps, grad.shape[0] // n_steps, grad.shape[1:]
+ )
+ for (scaled_grad, grad) in zip(scaled_grads, grads)
+ )
+
+ # computes attribution for each tensor in input tuple
+ # attributions has the same dimensionality as inputs
+ if not self.multiplies_by_inputs:
+ attributions = total_grads
+ else:
+ attributions = tuple(
+ total_grad * (input - baseline)
+ for total_grad, input, baseline in zip(total_grads, inputs, baselines)
+ )
+ return attributions
+
+ def has_convergence_delta(self) -> bool:
+ return True
+
+ @property
+ def multiplies_by_inputs(self):
+ return self._multiply_by_inputs
diff --git a/captum/attr/_core/kernel_shap.py b/captum/attr/_core/kernel_shap.py
new file mode 100644
index 0000000000000000000000000000000000000000..2826b30dfe70f3f72b34ed00ddeed6c7c027141a
--- /dev/null
+++ b/captum/attr/_core/kernel_shap.py
@@ -0,0 +1,348 @@
+#!/usr/bin/env python3
+
+from typing import Any, Callable, Generator, Tuple, Union
+
+import torch
+from captum._utils.models.linear_model import SkLearnLinearRegression
+from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
+from captum.attr._core.lime import construct_feature_mask, Lime
+from captum.attr._utils.common import _format_input_baseline
+from captum.log import log_usage
+from torch import Tensor
+from torch.distributions.categorical import Categorical
+
+
+class KernelShap(Lime):
+ r"""
+ Kernel SHAP is a method that uses the LIME framework to compute
+ Shapley Values. Setting the loss function, weighting kernel and
+ regularization terms appropriately in the LIME framework allows
+ theoretically obtaining Shapley Values more efficiently than
+ directly computing Shapley Values.
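+
+ Concretely, this implementation pairs a plain linear regression
+ surrogate (`SkLearnLinearRegression`) with Shapley-kernel sample
+ weighting and the coalition sampler defined in
+ `kernel_shap_perturb_generator` below.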
+ + More information regarding this method and proof of equivalence + can be found in the original paper here: + https://arxiv.org/abs/1705.07874 + """ + + def __init__(self, forward_func: Callable) -> None: + r""" + Args: + + forward_func (callable): The forward function of the model or + any modification of it + """ + Lime.__init__( + self, + forward_func, + interpretable_model=SkLearnLinearRegression(), + similarity_func=self.kernel_shap_similarity_kernel, + perturb_func=self.kernel_shap_perturb_generator, + ) + self.inf_weight = 1000000.0 + + @log_usage() + def attribute( # type: ignore + self, + inputs: TensorOrTupleOfTensorsGeneric, + baselines: BaselineType = None, + target: TargetType = None, + additional_forward_args: Any = None, + feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None, + n_samples: int = 25, + perturbations_per_eval: int = 1, + return_input_shape: bool = True, + show_progress: bool = False, + ) -> TensorOrTupleOfTensorsGeneric: + r""" + This method attributes the output of the model with given target index + (in case it is provided, otherwise it assumes that output is a + scalar) to the inputs of the model using the approach described above, + training an interpretable model based on KernelSHAP and returning a + representation of the interpretable model. + + It is recommended to only provide a single example as input (tensors + with first dimension or batch size = 1). This is because LIME / KernelShap + is generally used for sample-based interpretability, training a separate + interpretable model to explain a model's prediction on each individual example. + + A batch of inputs can also be provided as inputs, similar to + other perturbation-based attribution methods. In this case, if forward_fn + returns a scalar per example, attributions will be computed for each + example independently, with a separate interpretable model trained for each + example. Note that provided similarity and perturbation functions will be + provided each example separately (first dimension = 1) in this case. + If forward_fn returns a scalar per batch (e.g. loss), attributions will + still be computed using a single interpretable model for the full batch. + In this case, similarity and perturbation functions will be provided the + same original input containing the full batch. + + The number of interpretable features is determined from the provided + feature mask, or if none is provided, from the default feature mask, + which considers each scalar input as a separate feature. It is + generally recommended to provide a feature mask which groups features + into a small number of interpretable features / components (e.g. + superpixels in images). + + Args: + + inputs (tensor or tuple of tensors): Input for which KernelShap + is computed. If forward_func takes a single + tensor as input, a single input tensor should be provided. + If forward_func takes multiple tensors as input, a tuple + of the input tensors should be provided. It is assumed + that for all given input tensors, dimension 0 corresponds + to the number of examples, and if multiple input tensors + are provided, the examples must be aligned appropriately. + baselines (scalar, tensor, tuple of scalars or tensors, optional): + Baselines define the reference value which replaces each + feature when the corresponding interpretable feature + is set to 0. 
+ Baselines can be provided as: + + - a single tensor, if inputs is a single tensor, with + exactly the same dimensions as inputs or the first + dimension is one and the remaining dimensions match + with inputs. + + - a single scalar, if inputs is a single tensor, which will + be broadcasted for each input value in input tensor. + + - a tuple of tensors or scalars, the baseline corresponding + to each tensor in the inputs' tuple can be: + + - either a tensor with matching dimensions to + corresponding tensor in the inputs' tuple + or the first dimension is one and the remaining + dimensions match with the corresponding + input tensor. + + - or a scalar, corresponding to a tensor in the + inputs' tuple. This scalar value is broadcasted + for corresponding input tensor. + In the cases when `baselines` is not provided, we internally + use zero scalar corresponding to each input tensor. + Default: None + target (int, tuple, tensor or list, optional): Output indices for + which surrogate model is trained + (for classification cases, + this is usually the target class). + If the network returns a scalar value per example, + no target index is necessary. + For general 2D outputs, targets can be either: + + - a single integer or a tensor containing a single + integer, which is applied to all input examples + + - a list of integers or a 1D tensor, with length matching + the number of examples in inputs (dim 0). Each integer + is applied as the target for the corresponding example. + + For outputs with > 2 dimensions, targets can be either: + + - A single tuple, which contains #output_dims - 1 + elements. This target index is applied to all examples. + + - A list of tuples with length equal to the number of + examples in inputs (dim 0), and each tuple containing + #output_dims - 1 elements. Each tuple is applied as the + target for the corresponding example. + + Default: None + additional_forward_args (any, optional): If the forward function + requires additional arguments other than the inputs for + which attributions should not be computed, this argument + can be provided. It must be either a single additional + argument of a Tensor or arbitrary (non-tuple) type or a + tuple containing multiple additional arguments including + tensors or any arbitrary python types. These arguments + are provided to forward_func in order following the + arguments in inputs. + For a tensor, the first dimension of the tensor must + correspond to the number of examples. It will be + repeated for each of `n_steps` along the integrated + path. For all other types, the given argument is used + for all forward evaluations. + Note that attributions are not computed with respect + to these arguments. + Default: None + feature_mask (tensor or tuple of tensors, optional): + feature_mask defines a mask for the input, grouping + features which correspond to the same + interpretable feature. feature_mask + should contain the same number of tensors as inputs. + Each tensor should + be the same size as the corresponding input or + broadcastable to match the input tensor. Values across + all tensors should be integers in the range 0 to + num_interp_features - 1, and indices corresponding to the + same feature should have the same value. + Note that features are grouped across tensors + (unlike feature ablation and occlusion), so + if the same index is used in different tensors, those + features are still grouped and added simultaneously. 
+ If None, then a feature mask is constructed which assigns + each scalar within a tensor as a separate feature. + Default: None + n_samples (int, optional): The number of samples of the original + model used to train the surrogate interpretable model. + Default: `50` if `n_samples` is not provided. + perturbations_per_eval (int, optional): Allows multiple samples + to be processed simultaneously in one call to forward_fn. + Each forward pass will contain a maximum of + perturbations_per_eval * #examples samples. + For DataParallel models, each batch is split among the + available devices, so evaluations on each available + device contain at most + (perturbations_per_eval * #examples) / num_devices + samples. + If the forward function returns a single scalar per batch, + perturbations_per_eval must be set to 1. + Default: 1 + return_input_shape (bool, optional): Determines whether the returned + tensor(s) only contain the coefficients for each interp- + retable feature from the trained surrogate model, or + whether the returned attributions match the input shape. + When return_input_shape is True, the return type of attribute + matches the input shape, with each element containing the + coefficient of the corresponding interpretable feature. + All elements with the same value in the feature mask + will contain the same coefficient in the returned + attributions. If return_input_shape is False, a 1D + tensor is returned, containing only the coefficients + of the trained interpretable model, with length + num_interp_features. + show_progress (bool, optional): Displays the progress of computation. + It will try to use tqdm if available for advanced features + (e.g. time estimation). Otherwise, it will fallback to + a simple output of progress. + Default: False + + Returns: + *tensor* or tuple of *tensors* of **attributions**: + - **attributions** (*tensor* or tuple of *tensors*): + The attributions with respect to each input feature. + If return_input_shape = True, attributions will be + the same size as the provided inputs, with each value + providing the coefficient of the corresponding + interpretale feature. + If return_input_shape is False, a 1D + tensor is returned, containing only the coefficients + of the trained interpreatable models, with length + num_interp_features. + Examples:: + >>> # SimpleClassifier takes a single input tensor of size Nx4x4, + >>> # and returns an Nx3 tensor of class probabilities. + >>> net = SimpleClassifier() + + >>> # Generating random input with size 1 x 4 x 4 + >>> input = torch.randn(1, 4, 4) + + >>> # Defining KernelShap interpreter + >>> ks = KernelShap(net) + >>> # Computes attribution, with each of the 4 x 4 = 16 + >>> # features as a separate interpretable feature + >>> attr = ks.attribute(input, target=1, n_samples=200) + + >>> # Alternatively, we can group each 2x2 square of the inputs + >>> # as one 'interpretable' feature and perturb them together. + >>> # This can be done by creating a feature mask as follows, which + >>> # defines the feature groups, e.g.: + >>> # +---+---+---+---+ + >>> # | 0 | 0 | 1 | 1 | + >>> # +---+---+---+---+ + >>> # | 0 | 0 | 1 | 1 | + >>> # +---+---+---+---+ + >>> # | 2 | 2 | 3 | 3 | + >>> # +---+---+---+---+ + >>> # | 2 | 2 | 3 | 3 | + >>> # +---+---+---+---+ + >>> # With this mask, all inputs with the same value are set to their + >>> # baseline value, when the corresponding binary interpretable + >>> # feature is set to 0. 
+ >>> # The attributions can be calculated as follows: + >>> # feature mask has dimensions 1 x 4 x 4 + >>> feature_mask = torch.tensor([[[0,0,1,1],[0,0,1,1], + >>> [2,2,3,3],[2,2,3,3]]]) + + >>> # Computes KernelSHAP attributions with feature mask. + >>> attr = ks.attribute(input, target=1, feature_mask=feature_mask) + """ + formatted_inputs, baselines = _format_input_baseline(inputs, baselines) + feature_mask, num_interp_features = construct_feature_mask( + feature_mask, formatted_inputs + ) + num_features_list = torch.arange(num_interp_features, dtype=torch.float) + denom = num_features_list * (num_interp_features - num_features_list) + probs = (num_interp_features - 1) / denom + probs[0] = 0.0 + return self._attribute_kwargs( + inputs=inputs, + baselines=baselines, + target=target, + additional_forward_args=additional_forward_args, + feature_mask=feature_mask, + n_samples=n_samples, + perturbations_per_eval=perturbations_per_eval, + return_input_shape=return_input_shape, + num_select_distribution=Categorical(probs), + show_progress=show_progress, + ) + + def kernel_shap_similarity_kernel( + self, _, __, interpretable_sample: Tensor, **kwargs + ) -> Tensor: + assert ( + "num_interp_features" in kwargs + ), "Must provide num_interp_features to use default similarity kernel" + num_selected_features = int(interpretable_sample.sum(dim=1).item()) + num_features = kwargs["num_interp_features"] + if num_selected_features == 0 or num_selected_features == num_features: + # weight should be theoretically infinite when + # num_selected_features = 0 or num_features + # enforcing that trained linear model must satisfy + # end-point criteria. In practice, it is sufficient to + # make this weight substantially larger so setting this + # weight to 1000000 (all other weights are 1). + similarities = self.inf_weight + else: + similarities = 1.0 + return torch.tensor([similarities]) + + def kernel_shap_perturb_generator( + self, original_inp: Union[Tensor, Tuple[Tensor, ...]], **kwargs + ) -> Generator[Tensor, None, None]: + r""" + Perturbations are sampled by the following process: + - Choose k (number of selected features), based on the distribution + p(k) = (M - 1) / (k * (M - k)) + where M is the total number of features in the interpretable space + - Randomly select a binary vector with k ones, each sample is equally + likely. This is done by generating a random vector of normal + values and thresholding based on the top k elements. 
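+
+ As a small worked illustration (numbers chosen purely for exposition),
+ with M = 4 interpretable features the unnormalized weights
+ p(k) = (M - 1) / (k * (M - k)) for k = 1, 2, 3 are 1.0, 0.75 and 1.0,
+ so very small and very large coalitions are sampled slightly more often
+ than mid-sized ones.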
+ + Since there are M choose k vectors with k ones, this weighted sampling + is equivalent to applying the Shapley kernel for the sample weight, + defined as: + k(M, k) = (M - 1) / (k * (M - k) * (M choose k)) + """ + assert ( + "num_select_distribution" in kwargs and "num_interp_features" in kwargs + ), ( + "num_select_distribution and num_interp_features are necessary" + " to use kernel_shap_perturb_func" + ) + if isinstance(original_inp, Tensor): + device = original_inp.device + else: + device = original_inp[0].device + num_features = kwargs["num_interp_features"] + yield torch.ones(1, num_features, device=device, dtype=torch.long) + yield torch.zeros(1, num_features, device=device, dtype=torch.long) + while True: + num_selected_features = kwargs["num_select_distribution"].sample() + rand_vals = torch.randn(1, num_features) + threshold = torch.kthvalue( + rand_vals, num_features - num_selected_features + ).values.item() + yield (rand_vals > threshold).to(device=device).long() diff --git a/captum/attr/_core/layer/__init__.py b/captum/attr/_core/layer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/captum/attr/_core/layer/grad_cam.py b/captum/attr/_core/layer/grad_cam.py new file mode 100644 index 0000000000000000000000000000000000000000..c6504091492d801e49fda56d27888bf6549e1eda --- /dev/null +++ b/captum/attr/_core/layer/grad_cam.py @@ -0,0 +1,217 @@ +#!/usr/bin/env python3 +from typing import Any, Callable, List, Tuple, Union + +import torch +import torch.nn.functional as F +from captum._utils.common import ( + _format_additional_forward_args, + _format_output, + _format_tensor_into_tuples, +) +from captum._utils.gradient import compute_layer_gradients_and_eval +from captum._utils.typing import TargetType +from captum.attr._utils.attribution import GradientAttribution, LayerAttribution +from captum.log import log_usage +from torch import Tensor +from torch.nn import Module + + +class LayerGradCam(LayerAttribution, GradientAttribution): + r""" + Computes GradCAM attribution for chosen layer. GradCAM is designed for + convolutional neural networks, and is usually applied to the last + convolutional layer. + + GradCAM computes the gradients of the target output with respect to + the given layer, averages for each output channel (dimension 2 of + output), and multiplies the average gradient for each channel by the + layer activations. The results are summed over all channels. + + Note that in the original GradCAM algorithm described in the paper, + ReLU is applied to the output, returning only non-negative attributions. + For providing more flexibility to the user, we choose to not perform the + ReLU internally by default and return the sign information. To match the + original GradCAM algorithm, it is necessary to pass the parameter + relu_attributions=True to apply ReLU on the final + attributions or alternatively only visualize the positive attributions. + + Note: this procedure sums over the second dimension (# of channels), + so the output of GradCAM attributions will have a second + dimension of 1, but all other dimensions will match that of the layer + output. + + GradCAM attributions are generally upsampled and can be viewed as a + mask to the input, since a convolutional layer output generally + matches the input image spatially. This upsampling can be performed + using LayerAttribution.interpolate, as shown in the example below. 
+ + More details regarding the GradCAM method can be found in the + original paper here: + https://arxiv.org/pdf/1610.02391.pdf + """ + + def __init__( + self, + forward_func: Callable, + layer: Module, + device_ids: Union[None, List[int]] = None, + ) -> None: + r""" + Args: + + forward_func (callable): The forward function of the model or any + modification of it + layer (torch.nn.Module): Layer for which attributions are computed. + Output size of attribute matches this layer's output + dimensions, except for dimension 2, which will be 1, + since GradCAM sums over channels. + device_ids (list(int)): Device ID list, necessary only if forward_func + applies a DataParallel model. This allows reconstruction of + intermediate outputs from batched results across devices. + If forward_func is given as the DataParallel model itself, + then it is not necessary to provide this argument. + """ + LayerAttribution.__init__(self, forward_func, layer, device_ids) + GradientAttribution.__init__(self, forward_func) + + @log_usage() + def attribute( + self, + inputs: Union[Tensor, Tuple[Tensor, ...]], + target: TargetType = None, + additional_forward_args: Any = None, + attribute_to_layer_input: bool = False, + relu_attributions: bool = False, + ) -> Union[Tensor, Tuple[Tensor, ...]]: + r""" + Args: + + inputs (tensor or tuple of tensors): Input for which attributions + are computed. If forward_func takes a single + tensor as input, a single input tensor should be provided. + If forward_func takes multiple tensors as input, a tuple + of the input tensors should be provided. It is assumed + that for all given input tensors, dimension 0 corresponds + to the number of examples, and if multiple input tensors + are provided, the examples must be aligned appropriately. + target (int, tuple, tensor or list, optional): Output indices for + which gradients are computed (for classification cases, + this is usually the target class). + If the network returns a scalar value per example, + no target index is necessary. + For general 2D outputs, targets can be either: + + - a single integer or a tensor containing a single + integer, which is applied to all input examples + + - a list of integers or a 1D tensor, with length matching + the number of examples in inputs (dim 0). Each integer + is applied as the target for the corresponding example. + + For outputs with > 2 dimensions, targets can be either: + + - A single tuple, which contains #output_dims - 1 + elements. This target index is applied to all examples. + + - A list of tuples with length equal to the number of + examples in inputs (dim 0), and each tuple containing + #output_dims - 1 elements. Each tuple is applied as the + target for the corresponding example. + + Default: None + additional_forward_args (any, optional): If the forward function + requires additional arguments other than the inputs for + which attributions should not be computed, this argument + can be provided. It must be either a single additional + argument of a Tensor or arbitrary (non-tuple) type or a + tuple containing multiple additional arguments including + tensors or any arbitrary python types. These arguments + are provided to forward_func in order following the + arguments in inputs. + Note that attributions are not computed with respect + to these arguments. + Default: None + attribute_to_layer_input (bool, optional): Indicates whether to + compute the attributions with respect to the layer input + or output. 
If `attribute_to_layer_input` is set to True + then the attributions will be computed with respect to the + layer input, otherwise it will be computed with respect + to layer output. + Note that currently it is assumed that either the input + or the outputs of internal layers, depending on whether we + attribute to the input or output, are single tensors. + Support for multiple tensors will be added later. + Default: False + relu_attributions (bool, optional): Indicates whether to + apply a ReLU operation on the final attribution, + returning only non-negative attributions. Setting this + flag to True matches the original GradCAM algorithm, + otherwise, by default, both positive and negative + attributions are returned. + Default: False + + Returns: + *tensor* or tuple of *tensors* of **attributions**: + - **attributions** (*tensor* or tuple of *tensors*): + Attributions based on GradCAM method. + Attributions will be the same size as the + output of the given layer, except for dimension 2, + which will be 1 due to summing over channels. + Attributions are returned in a tuple if + the layer inputs / outputs contain multiple tensors, + otherwise a single tensor is returned. + Examples:: + + >>> # ImageClassifier takes a single input tensor of images Nx3x32x32, + >>> # and returns an Nx10 tensor of class probabilities. + >>> # It contains a layer conv4, which is an instance of nn.conv2d, + >>> # and the output of this layer has dimensions Nx50x8x8. + >>> # It is the last convolution layer, which is the recommended + >>> # use case for GradCAM. + >>> net = ImageClassifier() + >>> layer_gc = LayerGradCam(net, net.conv4) + >>> input = torch.randn(2, 3, 32, 32, requires_grad=True) + >>> # Computes layer GradCAM for class 3. + >>> # attribution size matches layer output except for dimension + >>> # 1, so dimensions of attr would be Nx1x8x8. + >>> attr = layer_gc.attribute(input, 3) + >>> # GradCAM attributions are often upsampled and viewed as a + >>> # mask to the input, since the convolutional layer output + >>> # spatially matches the original input image. + >>> # This can be done with LayerAttribution's interpolate method. + >>> upsampled_attr = LayerAttribution.interpolate(attr, (32, 32)) + """ + inputs = _format_tensor_into_tuples(inputs) + additional_forward_args = _format_additional_forward_args( + additional_forward_args + ) + # Returns gradient of output with respect to + # hidden layer and hidden layer evaluated at each input. 
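+ # A minimal usage sketch (illustrative only, reusing the hypothetical
+ # ImageClassifier / conv4 setup from the docstring example above):
+ # >>> attr = layer_gc.attribute(input, target=3, relu_attributions=True)
+ # >>> # Upsample the Nx1x8x8 map to the input resolution for visualization.
+ # >>> upsampled_attr = LayerAttribution.interpolate(attr, (32, 32))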
+ layer_gradients, layer_evals = compute_layer_gradients_and_eval( + self.forward_func, + self.layer, + inputs, + target, + additional_forward_args, + device_ids=self.device_ids, + attribute_to_layer_input=attribute_to_layer_input, + ) + + summed_grads = tuple( + torch.mean( + layer_grad, + dim=tuple(x for x in range(2, len(layer_grad.shape))), + keepdim=True, + ) + if len(layer_grad.shape) > 2 + else layer_grad + for layer_grad in layer_gradients + ) + + scaled_acts = tuple( + torch.sum(summed_grad * layer_eval, dim=1, keepdim=True) + for summed_grad, layer_eval in zip(summed_grads, layer_evals) + ) + if relu_attributions: + scaled_acts = tuple(F.relu(scaled_act) for scaled_act in scaled_acts) + return _format_output(len(scaled_acts) > 1, scaled_acts) diff --git a/captum/attr/_core/layer/internal_influence.py b/captum/attr/_core/layer/internal_influence.py new file mode 100644 index 0000000000000000000000000000000000000000..8976fe7344982e1a7bd46d981f39b5da507e04d6 --- /dev/null +++ b/captum/attr/_core/layer/internal_influence.py @@ -0,0 +1,309 @@ +#!/usr/bin/env python3 +from typing import Any, Callable, List, Tuple, Union + +import torch +from captum._utils.common import ( + _expand_additional_forward_args, + _expand_target, + _format_additional_forward_args, + _format_output, +) +from captum._utils.gradient import compute_layer_gradients_and_eval +from captum._utils.typing import BaselineType, TargetType +from captum.attr._utils.approximation_methods import approximation_parameters +from captum.attr._utils.attribution import GradientAttribution, LayerAttribution +from captum.attr._utils.batching import _batch_attribution +from captum.attr._utils.common import ( + _format_input_baseline, + _reshape_and_sum, + _validate_input, +) +from captum.log import log_usage +from torch import Tensor +from torch.nn import Module + + +class InternalInfluence(LayerAttribution, GradientAttribution): + r""" + Computes internal influence by approximating the integral of gradients + for a particular layer along the path from a baseline input to the + given input. + If no baseline is provided, the default baseline is the zero tensor. + More details on this approach can be found here: + https://arxiv.org/pdf/1802.03788.pdf + + Note that this method is similar to applying integrated gradients and + taking the layer as input, integrating the gradient of the layer with + respect to the output. + """ + + def __init__( + self, + forward_func: Callable, + layer: Module, + device_ids: Union[None, List[int]] = None, + ) -> None: + r""" + Args: + + forward_func (callable): The forward function of the model or any + modification of it + layer (torch.nn.Module): Layer for which attributions are computed. + Output size of attribute matches this layer's input or + output dimensions, depending on whether we attribute to + the inputs or outputs of the layer, corresponding to + attribution of each neuron in the input or output of + this layer. + device_ids (list(int)): Device ID list, necessary only if forward_func + applies a DataParallel model. This allows reconstruction of + intermediate outputs from batched results across devices. + If forward_func is given as the DataParallel model itself, + then it is not necessary to provide this argument. 
+ """ + LayerAttribution.__init__(self, forward_func, layer, device_ids) + GradientAttribution.__init__(self, forward_func) + + @log_usage() + def attribute( + self, + inputs: Union[Tensor, Tuple[Tensor, ...]], + baselines: BaselineType = None, + target: TargetType = None, + additional_forward_args: Any = None, + n_steps: int = 50, + method: str = "gausslegendre", + internal_batch_size: Union[None, int] = None, + attribute_to_layer_input: bool = False, + ) -> Union[Tensor, Tuple[Tensor, ...]]: + r""" + Args: + + inputs (tensor or tuple of tensors): Input for which internal + influence is computed. If forward_func takes a single + tensor as input, a single input tensor should be provided. + If forward_func takes multiple tensors as input, a tuple + of the input tensors should be provided. It is assumed + that for all given input tensors, dimension 0 corresponds + to the number of examples, and if multiple input tensors + are provided, the examples must be aligned appropriately. + baselines scalar, tensor, tuple of scalars or tensors, optional): + Baselines define a starting point from which integral + is computed and can be provided as: + + - a single tensor, if inputs is a single tensor, with + exactly the same dimensions as inputs or the first + dimension is one and the remaining dimensions match + with inputs. + + - a single scalar, if inputs is a single tensor, which will + be broadcasted for each input value in input tensor. + + - a tuple of tensors or scalars, the baseline corresponding + to each tensor in the inputs' tuple can be: + + - either a tensor with matching dimensions to + corresponding tensor in the inputs' tuple + or the first dimension is one and the remaining + dimensions match with the corresponding + input tensor. + + - or a scalar, corresponding to a tensor in the + inputs' tuple. This scalar value is broadcasted + for corresponding input tensor. + + In the cases when `baselines` is not provided, we internally + use zero scalar corresponding to each input tensor. + + Default: None + target (int, tuple, tensor or list, optional): Output indices for + which gradients are computed (for classification cases, + this is usually the target class). + If the network returns a scalar value per example, + no target index is necessary. + For general 2D outputs, targets can be either: + + - a single integer or a tensor containing a single + integer, which is applied to all input examples + + - a list of integers or a 1D tensor, with length matching + the number of examples in inputs (dim 0). Each integer + is applied as the target for the corresponding example. + + For outputs with > 2 dimensions, targets can be either: + + - A single tuple, which contains #output_dims - 1 + elements. This target index is applied to all examples. + + - A list of tuples with length equal to the number of + examples in inputs (dim 0), and each tuple containing + #output_dims - 1 elements. Each tuple is applied as the + target for the corresponding example. + + Default: None + additional_forward_args (any, optional): If the forward function + requires additional arguments other than the inputs for + which attributions should not be computed, this argument + can be provided. It must be either a single additional + argument of a Tensor or arbitrary (non-tuple) type or a + tuple containing multiple additional arguments including + tensors or any arbitrary python types. These arguments + are provided to forward_func in order following the + arguments in inputs. 
+ For a tensor, the first dimension of the tensor must + correspond to the number of examples. It will be + repeated for each of `n_steps` along the integrated + path. For all other types, the given argument is used + for all forward evaluations. + Note that attributions are not computed with respect + to these arguments. + Default: None + n_steps (int, optional): The number of steps used by the approximation + method. Default: 50. + method (string, optional): Method for approximating the integral, + one of `riemann_right`, `riemann_left`, `riemann_middle`, + `riemann_trapezoid` or `gausslegendre`. + Default: `gausslegendre` if no method is provided. + internal_batch_size (int, optional): Divides total #steps * #examples + data points into chunks of size at most internal_batch_size, + which are computed (forward / backward passes) + sequentially. internal_batch_size must be at least equal to + #examples. + For DataParallel models, each batch is split among the + available devices, so evaluations on each available + device contain internal_batch_size / num_devices examples. + If internal_batch_size is None, then all evaluations + are processed in one batch. + Default: None + attribute_to_layer_input (bool, optional): Indicates whether to + compute the attribution with respect to the layer input + or output. If `attribute_to_layer_input` is set to True + then the attributions will be computed with respect to + layer inputs, otherwise it will be computed with respect + to layer outputs. + Note that currently it is assumed that either the input + or the output of internal layer, depending on whether we + attribute to the input or output, is a single tensor. + Support for multiple tensors will be added later. + Default: False + + Returns: + *tensor* or tuple of *tensors* of **attributions**: + - **attributions** (*tensor* or tuple of *tensors*): + Internal influence of each neuron in given + layer output. Attributions will always be the same size + as the output or input of the given layer depending on + whether `attribute_to_layer_input` is set to `False` or + `True`respectively. + Attributions are returned in a tuple if + the layer inputs / outputs contain multiple tensors, + otherwise a single tensor is returned. + + Examples:: + + >>> # ImageClassifier takes a single input tensor of images Nx3x32x32, + >>> # and returns an Nx10 tensor of class probabilities. + >>> # It contains an attribute conv1, which is an instance of nn.conv2d, + >>> # and the output of this layer has dimensions Nx12x32x32. + >>> net = ImageClassifier() + >>> layer_int_inf = InternalInfluence(net, net.conv1) + >>> input = torch.randn(2, 3, 32, 32, requires_grad=True) + >>> # Computes layer internal influence. 
+ >>> # attribution size matches layer output, Nx12x32x32 + >>> attribution = layer_int_inf.attribute(input) + """ + inputs, baselines = _format_input_baseline(inputs, baselines) + _validate_input(inputs, baselines, n_steps, method) + if internal_batch_size is not None: + num_examples = inputs[0].shape[0] + attrs = _batch_attribution( + self, + num_examples, + internal_batch_size, + n_steps, + inputs=inputs, + baselines=baselines, + target=target, + additional_forward_args=additional_forward_args, + method=method, + attribute_to_layer_input=attribute_to_layer_input, + ) + else: + attrs = self._attribute( + inputs=inputs, + baselines=baselines, + target=target, + additional_forward_args=additional_forward_args, + n_steps=n_steps, + method=method, + attribute_to_layer_input=attribute_to_layer_input, + ) + + return attrs + + def _attribute( + self, + inputs: Tuple[Tensor, ...], + baselines: Tuple[Union[Tensor, int, float], ...], + target: TargetType = None, + additional_forward_args: Any = None, + n_steps: int = 50, + method: str = "gausslegendre", + attribute_to_layer_input: bool = False, + step_sizes_and_alphas: Union[None, Tuple[List[float], List[float]]] = None, + ) -> Union[Tensor, Tuple[Tensor, ...]]: + if step_sizes_and_alphas is None: + # retrieve step size and scaling factor for specified approximation method + step_sizes_func, alphas_func = approximation_parameters(method) + step_sizes, alphas = step_sizes_func(n_steps), alphas_func(n_steps) + else: + step_sizes, alphas = step_sizes_and_alphas + + # Compute scaled inputs from baseline to final input. + scaled_features_tpl = tuple( + torch.cat( + [baseline + alpha * (input - baseline) for alpha in alphas], dim=0 + ).requires_grad_() + for input, baseline in zip(inputs, baselines) + ) + + additional_forward_args = _format_additional_forward_args( + additional_forward_args + ) + # apply number of steps to additional forward args + # currently, number of steps is applied only to additional forward arguments + # that are nd-tensors. It is assumed that the first dimension is + # the number of batches. + # dim -> (bsz * #steps x additional_forward_args[0].shape[1:], ...) + input_additional_args = ( + _expand_additional_forward_args(additional_forward_args, n_steps) + if additional_forward_args is not None + else None + ) + expanded_target = _expand_target(target, n_steps) + + # Returns gradient of output with respect to hidden layer. 
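+ # A minimal usage sketch (illustrative only, extending the docstring example
+ # above with the documented optional arguments):
+ # >>> attribution = layer_int_inf.attribute(
+ # >>>     input,
+ # >>>     baselines=input * 0,
+ # >>>     target=3,
+ # >>>     n_steps=100,
+ # >>>     internal_batch_size=32,
+ # >>> )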
+ layer_gradients, _ = compute_layer_gradients_and_eval( + forward_fn=self.forward_func, + layer=self.layer, + inputs=scaled_features_tpl, + target_ind=expanded_target, + additional_forward_args=input_additional_args, + device_ids=self.device_ids, + attribute_to_layer_input=attribute_to_layer_input, + ) + # flattening grads so that we can multiply it with step-size + # calling contiguous to avoid `memory whole` problems + scaled_grads = tuple( + layer_grad.contiguous().view(n_steps, -1) + * torch.tensor(step_sizes).view(n_steps, 1).to(layer_grad.device) + for layer_grad in layer_gradients + ) + + # aggregates across all steps for each tensor in the input tuple + attrs = tuple( + _reshape_and_sum( + scaled_grad, n_steps, inputs[0].shape[0], layer_grad.shape[1:] + ) + for scaled_grad, layer_grad in zip(scaled_grads, layer_gradients) + ) + return _format_output(len(attrs) > 1, attrs) diff --git a/captum/attr/_core/layer/layer_activation.py b/captum/attr/_core/layer/layer_activation.py new file mode 100644 index 0000000000000000000000000000000000000000..86c511706b520600437878e8601f3876eaf7f3d4 --- /dev/null +++ b/captum/attr/_core/layer/layer_activation.py @@ -0,0 +1,136 @@ +#!/usr/bin/env python3 +from typing import Any, Callable, List, Tuple, Union + +import torch +from captum._utils.common import _format_output +from captum._utils.gradient import _forward_layer_eval +from captum._utils.typing import ModuleOrModuleList +from captum.attr._utils.attribution import LayerAttribution +from captum.log import log_usage +from torch import Tensor +from torch.nn import Module + + +class LayerActivation(LayerAttribution): + r""" + Computes activation of selected layer for given input. + """ + + def __init__( + self, + forward_func: Callable, + layer: ModuleOrModuleList, + device_ids: Union[None, List[int]] = None, + ) -> None: + r""" + Args: + + forward_func (callable): The forward function of the model or any + modification of it + layer (torch.nn.Module or list(torch.nn.Module)): Layer or layers + for which attributions are computed. + Output size of attribute matches this layer's input or + output dimensions, depending on whether we attribute to + the inputs or outputs of the layer, corresponding to + attribution of each neuron in the input or output of + this layer. If multiple layers are provided, attributions + are returned as a list, each element corresponding to the + activations of the corresponding layer. + device_ids (list(int)): Device ID list, necessary only if forward_func + applies a DataParallel model. This allows reconstruction of + intermediate outputs from batched results across devices. + If forward_func is given as the DataParallel model itself, + then it is not necessary to provide this argument. + """ + LayerAttribution.__init__(self, forward_func, layer, device_ids) + + @log_usage() + def attribute( + self, + inputs: Union[Tensor, Tuple[Tensor, ...]], + additional_forward_args: Any = None, + attribute_to_layer_input: bool = False, + ) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]: + r""" + Args: + + inputs (tensor or tuple of tensors): Input for which layer + activation is computed. If forward_func takes a single + tensor as input, a single input tensor should be provided. + If forward_func takes multiple tensors as input, a tuple + of the input tensors should be provided. 
It is assumed + that for all given input tensors, dimension 0 corresponds + to the number of examples, and if multiple input tensors + are provided, the examples must be aligned appropriately. + additional_forward_args (any, optional): If the forward function + requires additional arguments other than the inputs for + which attributions should not be computed, this argument + can be provided. It must be either a single additional + argument of a Tensor or arbitrary (non-tuple) type or a + tuple containing multiple additional arguments including + tensors or any arbitrary python types. These arguments + are provided to forward_func in order following the + arguments in inputs. + Note that attributions are not computed with respect + to these arguments. + Default: None + attribute_to_layer_input (bool, optional): Indicates whether to + compute the attribution with respect to the layer input + or output. If `attribute_to_layer_input` is set to True + then the attributions will be computed with respect to + layer input, otherwise it will be computed with respect + to layer output. + Note that currently it is assumed that either the input + or the output of internal layer, depending on whether we + attribute to the input or output, is a single tensor. + Support for multiple tensors will be added later. + Default: False + + Returns: + *tensor* or tuple of *tensors* or *list* of **attributions**: + - **attributions** (*tensor* or tuple of *tensors* or *list*): + Activation of each neuron in given layer output. + Attributions will always be the same size as the + output of the given layer. + Attributions are returned in a tuple if + the layer inputs / outputs contain multiple tensors, + otherwise a single tensor is returned. + If multiple layers are provided, attributions + are returned as a list, each element corresponding to the + activations of the corresponding layer. + + + + Examples:: + + >>> # ImageClassifier takes a single input tensor of images Nx3x32x32, + >>> # and returns an Nx10 tensor of class probabilities. + >>> # It contains an attribute conv1, which is an instance of nn.conv2d, + >>> # and the output of this layer has dimensions Nx12x32x32. + >>> net = ImageClassifier() + >>> layer_act = LayerActivation(net, net.conv1) + >>> input = torch.randn(2, 3, 32, 32, requires_grad=True) + >>> # Computes layer activation. 
+ >>> # attribution is layer output, with size Nx12x32x32
+ >>> attribution = layer_act.attribute(input)
+ """
+ with torch.no_grad():
+ layer_eval = _forward_layer_eval(
+ self.forward_func,
+ inputs,
+ self.layer,
+ additional_forward_args,
+ device_ids=self.device_ids,
+ attribute_to_layer_input=attribute_to_layer_input,
+ )
+ if isinstance(self.layer, Module):
+ return _format_output(len(layer_eval) > 1, layer_eval)
+ else:
+ return [
+ _format_output(len(single_layer_eval) > 1, single_layer_eval)
+ for single_layer_eval in layer_eval
+ ]
+
+ @property
+ def multiplies_by_inputs(self):
+ return True
diff --git a/captum/attr/_core/layer/layer_conductance.py b/captum/attr/_core/layer/layer_conductance.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d76569c10fd71ab97ae21052f87b77ee81b0ef4
--- /dev/null
+++ b/captum/attr/_core/layer/layer_conductance.py
@@ -0,0 +1,395 @@
+#!/usr/bin/env python3
+import typing
+from typing import Any, Callable, List, Tuple, Union
+
+import torch
+from captum._utils.common import (
+ _expand_additional_forward_args,
+ _expand_target,
+ _format_additional_forward_args,
+ _format_output,
+)
+from captum._utils.gradient import compute_layer_gradients_and_eval
+from captum._utils.typing import BaselineType, Literal, TargetType
+from captum.attr._utils.approximation_methods import approximation_parameters
+from captum.attr._utils.attribution import GradientAttribution, LayerAttribution
+from captum.attr._utils.batching import _batch_attribution
+from captum.attr._utils.common import (
+ _format_input_baseline,
+ _reshape_and_sum,
+ _validate_input,
+)
+from captum.log import log_usage
+from torch import Tensor
+from torch.nn import Module
+
+
+class LayerConductance(LayerAttribution, GradientAttribution):
+ r"""
+ Computes conductance with respect to the given layer. The
+ returned output is in the shape of the layer's output, showing the total
+ conductance of each hidden layer neuron.
+
+ The details of the approach can be found here:
+ https://arxiv.org/abs/1805.12233
+ https://arxiv.org/pdf/1807.09946.pdf
+
+ Note that this provides the total conductance of each neuron in the
+ layer's output. To obtain the breakdown of a neuron's conductance by input
+ features, utilize NeuronConductance instead, and provide the target
+ neuron index.
+ """
+
+ def __init__(
+ self,
+ forward_func: Callable,
+ layer: Module,
+ device_ids: Union[None, List[int]] = None,
+ ) -> None:
+ r"""
+ Args:
+
+ forward_func (callable): The forward function of the model or any
+ modification of it
+ layer (torch.nn.Module): Layer for which attributions are computed.
+ Output size of attribute matches this layer's input or
+ output dimensions, depending on whether we attribute to
+ the inputs or outputs of the layer, corresponding to
+ attribution of each neuron in the input or output of
+ this layer.
+ device_ids (list(int)): Device ID list, necessary only if forward_func
+ applies a DataParallel model. This allows reconstruction of
+ intermediate outputs from batched results across devices.
+ If forward_func is given as the DataParallel model itself,
+ then it is not necessary to provide this argument.
+ """ + LayerAttribution.__init__(self, forward_func, layer, device_ids) + GradientAttribution.__init__(self, forward_func) + + def has_convergence_delta(self) -> bool: + return True + + @typing.overload + def attribute( + self, + inputs: Union[Tensor, Tuple[Tensor, ...]], + baselines: BaselineType = None, + target: TargetType = None, + additional_forward_args: Any = None, + n_steps: int = 50, + method: str = "gausslegendre", + internal_batch_size: Union[None, int] = None, + *, + return_convergence_delta: Literal[True], + attribute_to_layer_input: bool = False, + ) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: + ... + + @typing.overload + def attribute( + self, + inputs: Union[Tensor, Tuple[Tensor, ...]], + baselines: BaselineType = None, + target: TargetType = None, + additional_forward_args: Any = None, + n_steps: int = 50, + method: str = "gausslegendre", + internal_batch_size: Union[None, int] = None, + return_convergence_delta: Literal[False] = False, + attribute_to_layer_input: bool = False, + ) -> Union[Tensor, Tuple[Tensor, ...]]: + ... + + @log_usage() + def attribute( + self, + inputs: Union[Tensor, Tuple[Tensor, ...]], + baselines: Union[ + None, int, float, Tensor, Tuple[Union[int, float, Tensor], ...] + ] = None, + target: TargetType = None, + additional_forward_args: Any = None, + n_steps: int = 50, + method: str = "gausslegendre", + internal_batch_size: Union[None, int] = None, + return_convergence_delta: bool = False, + attribute_to_layer_input: bool = False, + ) -> Union[ + Tensor, Tuple[Tensor, ...], Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor] + ]: + r""" + Args: + + inputs (tensor or tuple of tensors): Input for which layer + conductance is computed. If forward_func takes a single + tensor as input, a single input tensor should be provided. + If forward_func takes multiple tensors as input, a tuple + of the input tensors should be provided. It is assumed + that for all given input tensors, dimension 0 corresponds + to the number of examples, and if multiple input tensors + are provided, the examples must be aligned appropriately. + baselines (scalar, tensor, tuple of scalars or tensors, optional): + Baselines define the starting point from which integral + is computed and can be provided as: + + - a single tensor, if inputs is a single tensor, with + exactly the same dimensions as inputs or the first + dimension is one and the remaining dimensions match + with inputs. + + - a single scalar, if inputs is a single tensor, which will + be broadcasted for each input value in input tensor. + + - a tuple of tensors or scalars, the baseline corresponding + to each tensor in the inputs' tuple can be: + + - either a tensor with matching dimensions to + corresponding tensor in the inputs' tuple + or the first dimension is one and the remaining + dimensions match with the corresponding + input tensor. + + - or a scalar, corresponding to a tensor in the + inputs' tuple. This scalar value is broadcasted + for corresponding input tensor. + In the cases when `baselines` is not provided, we internally + use zero scalar corresponding to each input tensor. + + Default: None + target (int, tuple, tensor or list, optional): Output indices for + which gradients are computed (for classification cases, + this is usually the target class). + If the network returns a scalar value per example, + no target index is necessary. 
+ For general 2D outputs, targets can be either: + + - a single integer or a tensor containing a single + integer, which is applied to all input examples + + - a list of integers or a 1D tensor, with length matching + the number of examples in inputs (dim 0). Each integer + is applied as the target for the corresponding example. + + For outputs with > 2 dimensions, targets can be either: + + - A single tuple, which contains #output_dims - 1 + elements. This target index is applied to all examples. + + - A list of tuples with length equal to the number of + examples in inputs (dim 0), and each tuple containing + #output_dims - 1 elements. Each tuple is applied as the + target for the corresponding example. + + Default: None + additional_forward_args (any, optional): If the forward function + requires additional arguments other than the inputs for + which attributions should not be computed, this argument + can be provided. It must be either a single additional + argument of a Tensor or arbitrary (non-tuple) type or a + tuple containing multiple additional arguments including + tensors or any arbitrary python types. These arguments + are provided to forward_func in order following the + arguments in inputs. + For a tensor, the first dimension of the tensor must + correspond to the number of examples. It will be repeated + for each of `n_steps` along the integrated path. + For all other types, the given argument is used for + all forward evaluations. + Note that attributions are not computed with respect + to these arguments. + Default: None + n_steps (int, optional): The number of steps used by the approximation + method. Default: 50. + method (string, optional): Method for approximating the integral, + one of `riemann_right`, `riemann_left`, `riemann_middle`, + `riemann_trapezoid` or `gausslegendre`. + Default: `gausslegendre` if no method is provided. + internal_batch_size (int, optional): Divides total #steps * #examples + data points into chunks of size at most internal_batch_size, + which are computed (forward / backward passes) + sequentially. internal_batch_size must be at least equal to + 2 * #examples. + For DataParallel models, each batch is split among the + available devices, so evaluations on each available + device contain internal_batch_size / num_devices examples. + If internal_batch_size is None, then all evaluations are + processed in one batch. + Default: None + return_convergence_delta (bool, optional): Indicates whether to return + convergence delta or not. If `return_convergence_delta` + is set to True convergence delta will be returned in + a tuple following attributions. + Default: False + attribute_to_layer_input (bool, optional): Indicates whether to + compute the attribution with respect to the layer input + or output. If `attribute_to_layer_input` is set to True + then the attributions will be computed with respect to + layer inputs, otherwise it will be computed with respect + to layer outputs. + Note that currently it is assumed that either the input + or the output of internal layer, depending on whether we + attribute to the input or output, is a single tensor. + Support for multiple tensors will be added later. + Default: False + + Returns: + **attributions** or 2-element tuple of **attributions**, **delta**: + - **attributions** (*tensor* or tuple of *tensors*): + Conductance of each neuron in given layer input or + output. 
Attributions will always be the same size as + the input or output of the given layer, depending on + whether we attribute to the inputs or outputs + of the layer which is decided by the input flag + `attribute_to_layer_input`. + Attributions are returned in a tuple if + the layer inputs / outputs contain multiple tensors, + otherwise a single tensor is returned. + - **delta** (*tensor*, returned if return_convergence_delta=True): + The difference between the total + approximated and true conductance. + This is computed using the property that the total sum of + forward_func(inputs) - forward_func(baselines) must equal + the total sum of the attributions. + Delta is calculated per example, meaning that the number of + elements in returned delta tensor is equal to the number of + of examples in inputs. + + Examples:: + + >>> # ImageClassifier takes a single input tensor of images Nx3x32x32, + >>> # and returns an Nx10 tensor of class probabilities. + >>> # It contains an attribute conv1, which is an instance of nn.conv2d, + >>> # and the output of this layer has dimensions Nx12x32x32. + >>> net = ImageClassifier() + >>> layer_cond = LayerConductance(net, net.conv1) + >>> input = torch.randn(2, 3, 32, 32, requires_grad=True) + >>> # Computes layer conductance for class 3. + >>> # attribution size matches layer output, Nx12x32x32 + >>> attribution = layer_cond.attribute(input, target=3) + """ + inputs, baselines = _format_input_baseline(inputs, baselines) + _validate_input(inputs, baselines, n_steps, method) + + num_examples = inputs[0].shape[0] + if internal_batch_size is not None: + num_examples = inputs[0].shape[0] + attrs = _batch_attribution( + self, + num_examples, + internal_batch_size, + n_steps + 1, + include_endpoint=True, + inputs=inputs, + baselines=baselines, + target=target, + additional_forward_args=additional_forward_args, + method=method, + attribute_to_layer_input=attribute_to_layer_input, + ) + + else: + attrs = self._attribute( + inputs=inputs, + baselines=baselines, + target=target, + additional_forward_args=additional_forward_args, + n_steps=n_steps, + method=method, + attribute_to_layer_input=attribute_to_layer_input, + ) + + is_layer_tuple = isinstance(attrs, tuple) + attributions = attrs if is_layer_tuple else (attrs,) + + if return_convergence_delta: + start_point, end_point = baselines, inputs + delta = self.compute_convergence_delta( + attributions, + start_point, + end_point, + target=target, + additional_forward_args=additional_forward_args, + ) + return _format_output(is_layer_tuple, attributions), delta + return _format_output(is_layer_tuple, attributions) + + def _attribute( + self, + inputs: Tuple[Tensor, ...], + baselines: Tuple[Union[Tensor, int, float], ...], + target: TargetType = None, + additional_forward_args: Any = None, + n_steps: int = 50, + method: str = "gausslegendre", + attribute_to_layer_input: bool = False, + step_sizes_and_alphas: Union[None, Tuple[List[float], List[float]]] = None, + ) -> Union[Tensor, Tuple[Tensor, ...]]: + num_examples = inputs[0].shape[0] + if step_sizes_and_alphas is None: + # Retrieve scaling factors for specified approximation method + step_sizes_func, alphas_func = approximation_parameters(method) + alphas = alphas_func(n_steps + 1) + else: + _, alphas = step_sizes_and_alphas + # Compute scaled inputs from baseline to final input. 
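+ # A minimal usage sketch of the delta check (illustrative only, reusing the
+ # hypothetical ImageClassifier / conv1 setup from the docstring example above):
+ # >>> attribution, delta = layer_cond.attribute(
+ # >>>     input, target=3, return_convergence_delta=True
+ # >>> )
+ # >>> # One delta value per example; large magnitudes suggest raising n_steps.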
+ scaled_features_tpl = tuple( + torch.cat( + [baseline + alpha * (input - baseline) for alpha in alphas], dim=0 + ).requires_grad_() + for input, baseline in zip(inputs, baselines) + ) + + additional_forward_args = _format_additional_forward_args( + additional_forward_args + ) + # apply number of steps to additional forward args + # currently, number of steps is applied only to additional forward arguments + # that are nd-tensors. It is assumed that the first dimension is + # the number of batches. + # dim -> (#examples * #steps x additional_forward_args[0].shape[1:], ...) + input_additional_args = ( + _expand_additional_forward_args(additional_forward_args, n_steps + 1) + if additional_forward_args is not None + else None + ) + expanded_target = _expand_target(target, n_steps + 1) + + # Conductance Gradients - Returns gradient of output with respect to + # hidden layer and hidden layer evaluated at each input. + (layer_gradients, layer_evals,) = compute_layer_gradients_and_eval( + forward_fn=self.forward_func, + layer=self.layer, + inputs=scaled_features_tpl, + additional_forward_args=input_additional_args, + target_ind=expanded_target, + device_ids=self.device_ids, + attribute_to_layer_input=attribute_to_layer_input, + ) + + # Compute differences between consecutive evaluations of layer_eval. + # This approximates the total input gradient of each step multiplied + # by the step size. + grad_diffs = tuple( + layer_eval[num_examples:] - layer_eval[:-num_examples] + for layer_eval in layer_evals + ) + + # Element-wise multiply gradient of output with respect to hidden layer + # and summed gradients with respect to input (chain rule) and sum + # across stepped inputs. + attributions = tuple( + _reshape_and_sum( + grad_diff * layer_gradient[:-num_examples], + n_steps, + num_examples, + layer_eval.shape[1:], + ) + for layer_gradient, layer_eval, grad_diff in zip( + layer_gradients, layer_evals, grad_diffs + ) + ) + return _format_output(len(attributions) > 1, attributions) + + @property + def multiplies_by_inputs(self): + return True diff --git a/captum/attr/_core/layer/layer_deep_lift.py b/captum/attr/_core/layer/layer_deep_lift.py new file mode 100644 index 0000000000000000000000000000000000000000..71a8e9eb299bca59fbb2d485513a5d6acc1e9a47 --- /dev/null +++ b/captum/attr/_core/layer/layer_deep_lift.py @@ -0,0 +1,682 @@ +#!/usr/bin/env python3 +import typing +from typing import Any, Callable, cast, Sequence, Tuple, Union + +import torch +from captum._utils.common import ( + _expand_target, + _format_additional_forward_args, + _format_baseline, + _format_tensor_into_tuples, + ExpansionTypes, +) +from captum._utils.gradient import compute_layer_gradients_and_eval +from captum._utils.typing import ( + BaselineType, + Literal, + TargetType, + TensorOrTupleOfTensorsGeneric, +) +from captum.attr._core.deep_lift import DeepLift, DeepLiftShap +from captum.attr._utils.attribution import LayerAttribution +from captum.attr._utils.common import ( + _call_custom_attribution_func, + _compute_conv_delta_and_format_attrs, + _format_callable_baseline, + _tensorize_baseline, + _validate_input, +) +from captum.log import log_usage +from torch import Tensor +from torch.nn import Module + + +class LayerDeepLift(LayerAttribution, DeepLift): + r""" + Implements DeepLIFT algorithm for the layer based on the following paper: + Learning Important Features Through Propagating Activation Differences, + Avanti Shrikumar, et. al. 
+ https://arxiv.org/abs/1704.02685 + + and the gradient formulation proposed in: + Towards better understanding of gradient-based attribution methods for + deep neural networks, Marco Ancona, et.al. + https://openreview.net/pdf?id=Sy21R9JAW + + This implementation supports only Rescale rule. RevealCancel rule will + be supported in later releases. + Although DeepLIFT's(Rescale Rule) attribution quality is comparable with + Integrated Gradients, it runs significantly faster than Integrated + Gradients and is preferred for large datasets. + + Currently we only support a limited number of non-linear activations + but the plan is to expand the list in the future. + + Note: As we know, currently we cannot access the building blocks, + of PyTorch's built-in LSTM, RNNs and GRUs such as Tanh and Sigmoid. + Nonetheless, it is possible to build custom LSTMs, RNNS and GRUs + with performance similar to built-in ones using TorchScript. + More details on how to build custom RNNs can be found here: + https://pytorch.org/blog/optimizing-cuda-rnn-with-torchscript/ + """ + + def __init__( + self, + model: Module, + layer: Module, + multiply_by_inputs: bool = True, + ) -> None: + r""" + Args: + + model (nn.Module): The reference to PyTorch model instance. Model cannot + contain any in-place nonlinear submodules; these are not + supported by the register_full_backward_hook PyTorch API + starting from PyTorch v1.9. + layer (torch.nn.Module): Layer for which attributions are computed. + The size and dimensionality of the attributions + corresponds to the size and dimensionality of the layer's + input or output depending on whether we attribute to the + inputs or outputs of the layer. + multiply_by_inputs (bool, optional): Indicates whether to factor + model inputs' multiplier in the final attribution scores. + In the literature this is also known as local vs global + attribution. If inputs' multiplier isn't factored in + then that type of attribution method is also called local + attribution. If it is, then that type of attribution + method is called global. + More detailed can be found here: + https://arxiv.org/abs/1711.06104 + + In case of Layer DeepLift, if `multiply_by_inputs` + is set to True, final sensitivity scores + are being multiplied by + layer activations for inputs - layer activations for baselines. + This flag applies only if `custom_attribution_func` is + set to None. + """ + LayerAttribution.__init__(self, model, layer) + DeepLift.__init__(self, model) + self.model = model + self._multiply_by_inputs = multiply_by_inputs + + # Ignoring mypy error for inconsistent signature with DeepLift + @typing.overload # type: ignore + def attribute( + self, + inputs: Union[Tensor, Tuple[Tensor, ...]], + baselines: BaselineType = None, + target: TargetType = None, + additional_forward_args: Any = None, + return_convergence_delta: Literal[False] = False, + attribute_to_layer_input: bool = False, + custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None, + ) -> Union[Tensor, Tuple[Tensor, ...]]: + ... + + @typing.overload + def attribute( + self, + inputs: Union[Tensor, Tuple[Tensor, ...]], + baselines: BaselineType = None, + target: TargetType = None, + additional_forward_args: Any = None, + *, + return_convergence_delta: Literal[True], + attribute_to_layer_input: bool = False, + custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None, + ) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: + ... 
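+    # The two overloads above exist only to narrow the return type for static
+    # type checkers: with return_convergence_delta=False the call returns the
+    # attributions alone, and with return_convergence_delta=True it returns an
+    # (attributions, delta) tuple. A brief sketch, reusing the hypothetical
+    # ImageClassifier from the docstring example below:
+    #
+    #     >>> net = ImageClassifier()
+    #     >>> dl = LayerDeepLift(net, net.conv4)
+    #     >>> input = torch.randn(1, 3, 32, 32, requires_grad=True)
+    #     >>> attr = dl.attribute(input, target=1)
+    #     >>> attr, delta = dl.attribute(
+    #     >>>     input, target=1, return_convergence_delta=True
+    #     >>> )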
+ + @log_usage() + def attribute( + self, + inputs: Union[Tensor, Tuple[Tensor, ...]], + baselines: BaselineType = None, + target: TargetType = None, + additional_forward_args: Any = None, + return_convergence_delta: bool = False, + attribute_to_layer_input: bool = False, + custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None, + ) -> Union[ + Tensor, Tuple[Tensor, ...], Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor] + ]: + r""" + Args: + + inputs (tensor or tuple of tensors): Input for which layer + attributions are computed. If forward_func takes a + single tensor as input, a single input tensor should be + provided. If forward_func takes multiple tensors as input, + a tuple of the input tensors should be provided. It is + assumed that for all given input tensors, dimension 0 + corresponds to the number of examples (aka batch size), + and if multiple input tensors are provided, the examples + must be aligned appropriately. + baselines (scalar, tensor, tuple of scalars or tensors, optional): + Baselines define reference samples that are compared with + the inputs. In order to assign attribution scores DeepLift + computes the differences between the inputs/outputs and + corresponding references. + Baselines can be provided as: + + - a single tensor, if inputs is a single tensor, with + exactly the same dimensions as inputs or the first + dimension is one and the remaining dimensions match + with inputs. + + - a single scalar, if inputs is a single tensor, which will + be broadcasted for each input value in input tensor. + + - a tuple of tensors or scalars, the baseline corresponding + to each tensor in the inputs' tuple can be: + + - either a tensor with matching dimensions to + corresponding tensor in the inputs' tuple + or the first dimension is one and the remaining + dimensions match with the corresponding + input tensor. + + - or a scalar, corresponding to a tensor in the + inputs' tuple. This scalar value is broadcasted + for corresponding input tensor. + In the cases when `baselines` is not provided, we internally + use zero scalar corresponding to each input tensor. + + Default: None + target (int, tuple, tensor or list, optional): Output indices for + which gradients are computed (for classification cases, + this is usually the target class). + If the network returns a scalar value per example, + no target index is necessary. + For general 2D outputs, targets can be either: + + - a single integer or a tensor containing a single + integer, which is applied to all input examples + + - a list of integers or a 1D tensor, with length matching + the number of examples in inputs (dim 0). Each integer + is applied as the target for the corresponding example. + + For outputs with > 2 dimensions, targets can be either: + + - A single tuple, which contains #output_dims - 1 + elements. This target index is applied to all examples. + + - A list of tuples with length equal to the number of + examples in inputs (dim 0), and each tuple containing + #output_dims - 1 elements. Each tuple is applied as the + target for the corresponding example. + + Default: None + additional_forward_args (any, optional): If the forward function + requires additional arguments other than the inputs for + which attributions should not be computed, this argument + can be provided. It must be either a single additional + argument of a Tensor or arbitrary (non-tuple) type or a tuple + containing multiple additional arguments including tensors + or any arbitrary python types. 
These arguments are provided to + forward_func in order, following the arguments in inputs. + Note that attributions are not computed with respect + to these arguments. + Default: None + return_convergence_delta (bool, optional): Indicates whether to return + convergence delta or not. If `return_convergence_delta` + is set to True convergence delta will be returned in + a tuple following attributions. + Default: False + attribute_to_layer_input (bool, optional): Indicates whether to + compute the attribution with respect to the layer input + or output. If `attribute_to_layer_input` is set to True + then the attributions will be computed with respect to + layer input, otherwise it will be computed with respect + to layer output. + Note that currently it is assumed that either the input + or the output of internal layer, depending on whether we + attribute to the input or output, is a single tensor. + Support for multiple tensors will be added later. + Default: False + custom_attribution_func (callable, optional): A custom function for + computing final attribution scores. This function can take + at least one and at most three arguments with the + following signature: + + - custom_attribution_func(multipliers) + - custom_attribution_func(multipliers, inputs) + - custom_attribution_func(multipliers, inputs, baselines) + + In case this function is not provided, we use the default + logic defined as: multipliers * (inputs - baselines) + It is assumed that all input arguments, `multipliers`, + `inputs` and `baselines` are provided in tuples of same length. + `custom_attribution_func` returns a tuple of attribution + tensors that have the same length as the `inputs`. + Default: None + + Returns: + **attributions** or 2-element tuple of **attributions**, **delta**: + - **attributions** (*tensor* or tuple of *tensors*): + Attribution score computed based on DeepLift's rescale rule with + respect to layer's inputs or outputs. Attributions will always be the + same size as the provided layer's inputs or outputs, depending on + whether we attribute to the inputs or outputs of the layer. + If the layer input / output is a single tensor, then + just a tensor is returned; if the layer input / output + has multiple tensors, then a corresponding tuple + of tensors is returned. + - **delta** (*tensor*, returned if return_convergence_delta=True): + This is computed using the property that the total sum of + forward_func(inputs) - forward_func(baselines) must equal the + total sum of the attributions computed based on DeepLift's + rescale rule. + Delta is calculated per example, meaning that the number of + elements in returned delta tensor is equal to the number of + of examples in input. + Note that the logic described for deltas is guaranteed + when the default logic for attribution computations is used, + meaning that the `custom_attribution_func=None`, otherwise + it is not guaranteed and depends on the specifics of the + `custom_attribution_func`. + + Examples:: + + >>> # ImageClassifier takes a single input tensor of images Nx3x32x32, + >>> # and returns an Nx10 tensor of class probabilities. + >>> net = ImageClassifier() + >>> # creates an instance of LayerDeepLift to interpret target + >>> # class 1 with respect to conv4 layer. + >>> dl = LayerDeepLift(net, net.conv4) + >>> input = torch.randn(1, 3, 32, 32, requires_grad=True) + >>> # Computes deeplift attribution scores for conv4 layer and class 3. 
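+            >>> # (target picks the class index to explain; use target=3 for
+            >>> # class 3, or target=1 as in the call below.)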
+ >>> attribution = dl.attribute(input, target=1) + """ + inputs = _format_tensor_into_tuples(inputs) + baselines = _format_baseline(baselines, inputs) + _validate_input(inputs, baselines) + + baselines = _tensorize_baseline(inputs, baselines) + + main_model_hooks = [] + try: + main_model_hooks = self._hook_main_model() + + self.model.apply( + lambda mod: self._register_hooks( + mod, attribute_to_layer_input=attribute_to_layer_input + ) + ) + + additional_forward_args = _format_additional_forward_args( + additional_forward_args + ) + expanded_target = _expand_target( + target, 2, expansion_type=ExpansionTypes.repeat + ) + wrapped_forward_func = self._construct_forward_func( + self.model, + (inputs, baselines), + expanded_target, + additional_forward_args, + ) + + def chunk_output_fn(out: TensorOrTupleOfTensorsGeneric) -> Sequence: + if isinstance(out, Tensor): + return out.chunk(2) + return tuple(out_sub.chunk(2) for out_sub in out) + + gradients, attrs = compute_layer_gradients_and_eval( + wrapped_forward_func, + self.layer, + inputs, + attribute_to_layer_input=attribute_to_layer_input, + output_fn=lambda out: chunk_output_fn(out), + ) + + attr_inputs = tuple(map(lambda attr: attr[0], attrs)) + attr_baselines = tuple(map(lambda attr: attr[1], attrs)) + gradients = tuple(map(lambda grad: grad[0], gradients)) + + if custom_attribution_func is None: + if self.multiplies_by_inputs: + attributions = tuple( + (input - baseline) * gradient + for input, baseline, gradient in zip( + attr_inputs, attr_baselines, gradients + ) + ) + else: + attributions = gradients + else: + attributions = _call_custom_attribution_func( + custom_attribution_func, gradients, attr_inputs, attr_baselines + ) + finally: + # remove hooks from all activations + self._remove_hooks(main_model_hooks) + + return _compute_conv_delta_and_format_attrs( + self, + return_convergence_delta, + attributions, + baselines, + inputs, + additional_forward_args, + target, + cast(Union[Literal[True], Literal[False]], len(attributions) > 1), + ) + + @property + def multiplies_by_inputs(self): + return self._multiply_by_inputs + + +class LayerDeepLiftShap(LayerDeepLift, DeepLiftShap): + r""" + Extends LayerDeepLift and DeepLiftShap algorithms and approximates SHAP + values for given input `layer`. + For each input sample - baseline pair it computes DeepLift attributions + with respect to inputs or outputs of given `layer` averages + resulting attributions across baselines. Whether to compute the attributions + with respect to the inputs or outputs of the layer is defined by the + input flag `attribute_to_layer_input`. + More details about the algorithm can be found here: + + http://papers.nips.cc/paper/7062-a-unified-approach-to-interpreting-model-predictions.pdf + + Note that the explanation model: + 1. Assumes that input features are independent of one another + 2. Is linear, meaning that the explanations are modeled through + the additive composition of feature effects. + Although, it assumes a linear model for each explanation, the overall + model across multiple explanations can be complex and non-linear. + """ + + def __init__( + self, + model: Module, + layer: Module, + multiply_by_inputs: bool = True, + ) -> None: + r""" + Args: + + model (nn.Module): The reference to PyTorch model instance. Model cannot + contain any in-place nonlinear submodules; these are not + supported by the register_full_backward_hook PyTorch API + starting from PyTorch v1.9. + layer (torch.nn.Module): Layer for which attributions are computed. 
+ The size and dimensionality of the attributions + corresponds to the size and dimensionality of the layer's + input or output depending on whether we attribute to the + inputs or outputs of the layer. + multiply_by_inputs (bool, optional): Indicates whether to factor + model inputs' multiplier in the final attribution scores. + In the literature this is also known as local vs global + attribution. If inputs' multiplier isn't factored in + then that type of attribution method is also called local + attribution. If it is, then that type of attribution + method is called global. + More detailed can be found here: + https://arxiv.org/abs/1711.06104 + + In case of LayerDeepLiftShap, if `multiply_by_inputs` + is set to True, final sensitivity scores are being + multiplied by + layer activations for inputs - layer activations for baselines + This flag applies only if `custom_attribution_func` is + set to None. + """ + LayerDeepLift.__init__(self, model, layer) + DeepLiftShap.__init__(self, model, multiply_by_inputs) + + # Ignoring mypy error for inconsistent signature with DeepLiftShap + @typing.overload # type: ignore + def attribute( + self, + inputs: Union[Tensor, Tuple[Tensor, ...]], + baselines: Union[ + Tensor, Tuple[Tensor, ...], Callable[..., Union[Tensor, Tuple[Tensor, ...]]] + ], + target: TargetType = None, + additional_forward_args: Any = None, + return_convergence_delta: Literal[False] = False, + attribute_to_layer_input: bool = False, + custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None, + ) -> Union[Tensor, Tuple[Tensor, ...]]: + ... + + @typing.overload + def attribute( + self, + inputs: Union[Tensor, Tuple[Tensor, ...]], + baselines: Union[ + Tensor, Tuple[Tensor, ...], Callable[..., Union[Tensor, Tuple[Tensor, ...]]] + ], + target: TargetType = None, + additional_forward_args: Any = None, + *, + return_convergence_delta: Literal[True], + attribute_to_layer_input: bool = False, + custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None, + ) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: + ... + + @log_usage() + def attribute( + self, + inputs: Union[Tensor, Tuple[Tensor, ...]], + baselines: Union[ + Tensor, Tuple[Tensor, ...], Callable[..., Union[Tensor, Tuple[Tensor, ...]]] + ], + target: TargetType = None, + additional_forward_args: Any = None, + return_convergence_delta: bool = False, + attribute_to_layer_input: bool = False, + custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None, + ) -> Union[ + Tensor, Tuple[Tensor, ...], Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor] + ]: + r""" + Args: + + inputs (tensor or tuple of tensors): Input for which layer + attributions are computed. If forward_func takes a single + tensor as input, a single input tensor should be provided. + If forward_func takes multiple tensors as input, a tuple + of the input tensors should be provided. It is assumed + that for all given input tensors, dimension 0 corresponds + to the number of examples (aka batch size), and if + multiple input tensors are provided, the examples must + be aligned appropriately. + baselines (tensor, tuple of tensors, callable): + Baselines define reference samples that are compared with + the inputs. In order to assign attribution scores DeepLift + computes the differences between the inputs/outputs and + corresponding references. 
Baselines can be provided as: + + - a single tensor, if inputs is a single tensor, with + the first dimension equal to the number of examples + in the baselines' distribution. The remaining dimensions + must match with input tensor's dimension starting from + the second dimension. + + - a tuple of tensors, if inputs is a tuple of tensors, + with the first dimension of any tensor inside the tuple + equal to the number of examples in the baseline's + distribution. The remaining dimensions must match + the dimensions of the corresponding input tensor + starting from the second dimension. + + - callable function, optionally takes `inputs` as an + argument and either returns a single tensor + or a tuple of those. + + It is recommended that the number of samples in the baselines' + tensors is larger than one. + target (int, tuple, tensor or list, optional): Output indices for + which gradients are computed (for classification cases, + this is usually the target class). + If the network returns a scalar value per example, + no target index is necessary. + For general 2D outputs, targets can be either: + + - a single integer or a tensor containing a single + integer, which is applied to all input examples + + - a list of integers or a 1D tensor, with length matching + the number of examples in inputs (dim 0). Each integer + is applied as the target for the corresponding example. + + For outputs with > 2 dimensions, targets can be either: + + - A single tuple, which contains #output_dims - 1 + elements. This target index is applied to all examples. + + - A list of tuples with length equal to the number of + examples in inputs (dim 0), and each tuple containing + #output_dims - 1 elements. Each tuple is applied as the + target for the corresponding example. + + Default: None + additional_forward_args (any, optional): If the forward function + requires additional arguments other than the inputs for + which attributions should not be computed, this argument + can be provided. It must be either a single additional + argument of a Tensor or arbitrary (non-tuple) type or a tuple + containing multiple additional arguments including tensors + or any arbitrary python types. These arguments are provided to + forward_func in order, following the arguments in inputs. + Note that attributions are not computed with respect + to these arguments. + Default: None + return_convergence_delta (bool, optional): Indicates whether to return + convergence delta or not. If `return_convergence_delta` + is set to True convergence delta will be returned in + a tuple following attributions. + Default: False + attribute_to_layer_input (bool, optional): Indicates whether to + compute the attributions with respect to the layer input + or output. If `attribute_to_layer_input` is set to True + then the attributions will be computed with respect to + layer inputs, otherwise it will be computed with respect + to layer outputs. + Note that currently it assumes that both the inputs and + outputs of internal layers are single tensors. + Support for multiple tensors will be added later. + Default: False + custom_attribution_func (callable, optional): A custom function for + computing final attribution scores. 
This function can take + at least one and at most three arguments with the + following signature: + + - custom_attribution_func(multipliers) + - custom_attribution_func(multipliers, inputs) + - custom_attribution_func(multipliers, inputs, baselines) + + In case this function is not provided, we use the default + logic defined as: multipliers * (inputs - baselines) + It is assumed that all input arguments, `multipliers`, + `inputs` and `baselines` are provided in tuples of same + length. `custom_attribution_func` returns a tuple of + attribution tensors that have the same length as the + `inputs`. + Default: None + + Returns: + **attributions** or 2-element tuple of **attributions**, **delta**: + - **attributions** (*tensor* or tuple of *tensors*): + Attribution score computed based on DeepLift's rescale rule + with respect to layer's inputs or outputs. Attributions + will always be the same size as the provided layer's inputs + or outputs, depending on whether we attribute to the inputs + or outputs of the layer. + Attributions are returned in a tuple based on whether + the layer inputs / outputs are contained in a tuple + from a forward hook. For standard modules, inputs of + a single tensor are usually wrapped in a tuple, while + outputs of a single tensor are not. + - **delta** (*tensor*, returned if return_convergence_delta=True): + This is computed using the property that the + total sum of forward_func(inputs) - forward_func(baselines) + must be very close to the total sum of attributions + computed based on approximated SHAP values using + DeepLift's rescale rule. + Delta is calculated for each example input and baseline pair, + meaning that the number of elements in returned delta tensor + is equal to the + `number of examples in input` * `number of examples + in baseline`. The deltas are ordered in the first place by + input example, followed by the baseline. + Note that the logic described for deltas is guaranteed + when the default logic for attribution computations is used, + meaning that the `custom_attribution_func=None`, otherwise + it is not guaranteed and depends on the specifics of the + `custom_attribution_func`. + Examples:: + + >>> # ImageClassifier takes a single input tensor of images Nx3x32x32, + >>> # and returns an Nx10 tensor of class probabilities. + >>> net = ImageClassifier() + >>> # creates an instance of LayerDeepLift to interpret target + >>> # class 1 with respect to conv4 layer. + >>> dl = LayerDeepLiftShap(net, net.conv4) + >>> input = torch.randn(2, 3, 32, 32, requires_grad=True) + >>> # Computes shap values using deeplift for class 3. + >>> attribution = dl.attribute(input, target=3) + """ + inputs = _format_tensor_into_tuples(inputs) + baselines = _format_callable_baseline(baselines, inputs) + + assert isinstance(baselines[0], torch.Tensor) and baselines[0].shape[0] > 1, ( + "Baselines distribution has to be provided in form of a torch.Tensor" + " with more than one example but found: {}." 
+ " If baselines are provided in shape of scalars or with a single" + " baseline example, `LayerDeepLift`" + " approach can be used instead.".format(baselines[0]) + ) + + # batch sizes + inp_bsz = inputs[0].shape[0] + base_bsz = baselines[0].shape[0] + + ( + exp_inp, + exp_base, + exp_target, + exp_addit_args, + ) = DeepLiftShap._expand_inputs_baselines_targets( + self, baselines, inputs, target, additional_forward_args + ) + attributions = LayerDeepLift.attribute.__wrapped__( # type: ignore + self, + exp_inp, + exp_base, + target=exp_target, + additional_forward_args=exp_addit_args, + return_convergence_delta=cast( + Literal[True, False], return_convergence_delta + ), + attribute_to_layer_input=attribute_to_layer_input, + custom_attribution_func=custom_attribution_func, + ) + if return_convergence_delta: + attributions, delta = attributions + if isinstance(attributions, tuple): + attributions = tuple( + DeepLiftShap._compute_mean_across_baselines( + self, inp_bsz, base_bsz, cast(Tensor, attrib) + ) + for attrib in attributions + ) + else: + attributions = DeepLiftShap._compute_mean_across_baselines( + self, inp_bsz, base_bsz, attributions + ) + if return_convergence_delta: + return attributions, delta + else: + return attributions + + @property + def multiplies_by_inputs(self): + return self._multiply_by_inputs diff --git a/captum/attr/_core/layer/layer_feature_ablation.py b/captum/attr/_core/layer/layer_feature_ablation.py new file mode 100644 index 0000000000000000000000000000000000000000..75ac885eac3e5e1e923283ce081edacbebda741d --- /dev/null +++ b/captum/attr/_core/layer/layer_feature_ablation.py @@ -0,0 +1,302 @@ +#!/usr/bin/env python3 +from typing import Any, Callable, List, Tuple, Union + +import torch +from captum._utils.common import ( + _extract_device, + _format_additional_forward_args, + _format_output, + _format_tensor_into_tuples, + _run_forward, +) +from captum._utils.gradient import _forward_layer_eval +from captum._utils.typing import BaselineType, TargetType +from captum.attr._core.feature_ablation import FeatureAblation +from captum.attr._utils.attribution import LayerAttribution, PerturbationAttribution +from captum.log import log_usage +from torch import Tensor +from torch.nn import Module +from torch.nn.parallel.scatter_gather import scatter + + +class LayerFeatureAblation(LayerAttribution, PerturbationAttribution): + r""" + A perturbation based approach to computing layer attribution, involving + replacing values in the input / output of a layer with a given baseline / + reference, and computing the difference in output. By default, each + neuron (scalar input / output value) within the layer is replaced + independently. + Passing a layer mask allows grouping neurons to be + ablated together. + Each neuron in the group will be given the same attribution value + equal to the change in target as a result of ablating the entire neuron + group. + """ + + def __init__( + self, + forward_func: Callable, + layer: Module, + device_ids: Union[None, List[int]] = None, + ) -> None: + r""" + Args: + + forward_func (callable): The forward function of the model or any + modification of it + layer (torch.nn.Module): Layer for which attributions are computed. + Output size of attribute matches this layer's input or + output dimensions, depending on whether we attribute to + the inputs or outputs of the layer, corresponding to + attribution of each neuron in the input or output of + this layer. 
+ device_ids (list(int)): Device ID list, necessary only if forward_func + applies a DataParallel model. This allows reconstruction of + intermediate outputs from batched results across devices. + If forward_func is given as the DataParallel model itself + (or otherwise has a device_ids attribute with the device + ID list), then it is not necessary to provide this + argument. + """ + LayerAttribution.__init__(self, forward_func, layer, device_ids) + PerturbationAttribution.__init__(self, forward_func) + + @log_usage() + def attribute( + self, + inputs: Union[Tensor, Tuple[Tensor, ...]], + layer_baselines: BaselineType = None, + target: TargetType = None, + additional_forward_args: Any = None, + layer_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None, + attribute_to_layer_input: bool = False, + perturbations_per_eval: int = 1, + ) -> Union[Tensor, Tuple[Tensor, ...]]: + r""" + Args: + + inputs (tensor or tuple of tensors): Input for which layer + attributions are computed. If forward_func takes a single + tensor as input, a single input tensor should be provided. + If forward_func takes multiple tensors as input, a tuple + of the input tensors should be provided. It is assumed + that for all given input tensors, dimension 0 corresponds + to the number of examples, and if multiple input tensors + are provided, the examples must be aligned appropriately. + layer_baselines (scalar, tensor, tuple of scalars or tensors, optional): + Layer baselines define reference values which replace each + layer input / output value when ablated. + Layer baselines should be a single tensor with dimensions + matching the input / output of the target layer (or + broadcastable to match it), based + on whether we are attributing to the input or output + of the target layer. + In the cases when `baselines` is not provided, we internally + use zero as the baseline for each neuron. + Default: None + target (int, tuple, tensor or list, optional): Output indices for + which gradients are computed (for classification cases, + this is usually the target class). + If the network returns a scalar value per example, + no target index is necessary. + For general 2D outputs, targets can be either: + + - a single integer or a tensor containing a single + integer, which is applied to all input examples + + - a list of integers or a 1D tensor, with length matching + the number of examples in inputs (dim 0). Each integer + is applied as the target for the corresponding example. + + For outputs with > 2 dimensions, targets can be either: + + - A single tuple, which contains #output_dims - 1 + elements. This target index is applied to all examples. + + - A list of tuples with length equal to the number of + examples in inputs (dim 0), and each tuple containing + #output_dims - 1 elements. Each tuple is applied as the + target for the corresponding example. + + Default: None + additional_forward_args (any, optional): If the forward function + requires additional arguments other than the inputs for + which attributions should not be computed, this argument + can be provided. It must be either a single additional + argument of a Tensor or arbitrary (non-tuple) type or a + tuple containing multiple additional arguments including + tensors or any arbitrary python types. These arguments + are provided to forward_func in order following the + arguments in inputs. + Note that attributions are not computed with respect + to these arguments. 
+ Default: None + layer_mask (tensor or tuple of tensors, optional): + layer_mask defines a mask for the layer, grouping + elements of the layer input / output which should be + ablated together. + layer_mask should be a single tensor with dimensions + matching the input / output of the target layer (or + broadcastable to match it), based + on whether we are attributing to the input or output + of the target layer. layer_mask + should contain integers in the range 0 to num_groups + - 1, and all elements with the same value are + considered to be in the same group. + If None, then a layer mask is constructed which assigns + each neuron within the layer as a separate group, which + is ablated independently. + Default: None + attribute_to_layer_input (bool, optional): Indicates whether to + compute the attributions with respect to the layer input + or output. If `attribute_to_layer_input` is set to True + then the attributions will be computed with respect to + layer's inputs, otherwise it will be computed with respect + to layer's outputs. + Note that currently it is assumed that either the input + or the output of the layer, depending on whether we + attribute to the input or output, is a single tensor. + Support for multiple tensors will be added later. + Default: False + perturbations_per_eval (int, optional): Allows ablation of multiple + neuron (groups) to be processed simultaneously in one + call to forward_fn. + Each forward pass will contain a maximum of + perturbations_per_eval * #examples samples. + For DataParallel models, each batch is split among the + available devices, so evaluations on each available + device contain at most + (perturbations_per_eval * #examples) / num_devices + samples. + Default: 1 + + Returns: + *tensor* or tuple of *tensors* of **attributions**: + - **attributions** (*tensor* or tuple of *tensors*): + Attribution of each neuron in given layer input or + output. Attributions will always be the same size as + the input or output of the given layer, depending on + whether we attribute to the inputs or outputs + of the layer which is decided by the input flag + `attribute_to_layer_input` + Attributions are returned in a tuple if + the layer inputs / outputs contain multiple tensors, + otherwise a single tensor is returned. + + + Examples:: + + >>> # SimpleClassifier takes a single input tensor of size Nx4x4, + >>> # and returns an Nx3 tensor of class probabilities. + >>> # It contains an attribute conv1, which is an instance of nn.conv2d, + >>> # and the output of this layer has dimensions Nx12x3x3. + >>> net = SimpleClassifier() + >>> # Generating random input with size 2 x 4 x 4 + >>> input = torch.randn(2, 4, 4) + >>> # Defining LayerFeatureAblation interpreter + >>> ablator = LayerFeatureAblation(net, net.conv1) + >>> # Computes ablation attribution, ablating each of the 108 + >>> # neurons independently. + >>> attr = ablator.attribute(input, target=1) + + >>> # Alternatively, we may want to ablate neurons in groups, e.g. + >>> # grouping all the layer outputs in the same row. + >>> # This can be done by creating a layer mask as follows, which + >>> # defines the groups of layer inputs / outouts, e.g.: + >>> # +---+---+---+ + >>> # | 0 | 0 | 0 | + >>> # +---+---+---+ + >>> # | 1 | 1 | 1 | + >>> # +---+---+---+ + >>> # | 2 | 2 | 2 | + >>> # +---+---+---+ + >>> # With this mask, all the 36 neurons in a row / channel are ablated + >>> # simultaneously, and the attribution for each neuron in the same + >>> # group (0 - 2) per example are the same. 
+ >>> # The attributions can be calculated as follows: + >>> # layer mask has dimensions 1 x 3 x 3 + >>> layer_mask = torch.tensor([[[0,0,0],[1,1,1], + >>> [2,2,2]]]) + >>> attr = ablator.attribute(input, target=1, + >>> layer_mask=layer_mask) + """ + + def layer_forward_func(*args): + layer_length = args[-1] + layer_input = args[:layer_length] + original_inputs = args[layer_length:-1] + + device_ids = self.device_ids + if device_ids is None: + device_ids = getattr(self.forward_func, "device_ids", None) + + all_layer_inputs = {} + if device_ids is not None: + scattered_layer_input = scatter(layer_input, target_gpus=device_ids) + for device_tensors in scattered_layer_input: + all_layer_inputs[device_tensors[0].device] = device_tensors + else: + all_layer_inputs[layer_input[0].device] = layer_input + + def forward_hook(module, inp, out=None): + device = _extract_device(module, inp, out) + is_layer_tuple = ( + isinstance(out, tuple) + if out is not None + else isinstance(inp, tuple) + ) + if device not in all_layer_inputs: + raise AssertionError( + "Layer input not placed on appropriate " + "device. If using a DataParallel model, either provide the " + "DataParallel model as forward_func or provide device ids" + " to the constructor." + ) + if not is_layer_tuple: + return all_layer_inputs[device][0] + return all_layer_inputs[device] + + hook = None + try: + if attribute_to_layer_input: + hook = self.layer.register_forward_pre_hook(forward_hook) + else: + hook = self.layer.register_forward_hook(forward_hook) + eval = _run_forward(self.forward_func, original_inputs, target=target) + finally: + if hook is not None: + hook.remove() + return eval + + with torch.no_grad(): + inputs = _format_tensor_into_tuples(inputs) + additional_forward_args = _format_additional_forward_args( + additional_forward_args + ) + layer_eval = _forward_layer_eval( + self.forward_func, + inputs, + self.layer, + additional_forward_args, + device_ids=self.device_ids, + attribute_to_layer_input=attribute_to_layer_input, + ) + layer_eval_len = (len(layer_eval),) + all_inputs = ( + (inputs + additional_forward_args + layer_eval_len) + if additional_forward_args is not None + else inputs + layer_eval_len + ) + + ablator = FeatureAblation(layer_forward_func) + + layer_attribs = ablator.attribute.__wrapped__( + ablator, # self + layer_eval, + baselines=layer_baselines, + additional_forward_args=all_inputs, + feature_mask=layer_mask, + perturbations_per_eval=perturbations_per_eval, + ) + _attr = _format_output(len(layer_attribs) > 1, layer_attribs) + return _attr diff --git a/captum/attr/_core/layer/layer_gradient_shap.py b/captum/attr/_core/layer/layer_gradient_shap.py new file mode 100644 index 0000000000000000000000000000000000000000..9473475cdf5dbaf260c0dd82a70a35224da60558 --- /dev/null +++ b/captum/attr/_core/layer/layer_gradient_shap.py @@ -0,0 +1,474 @@ +#!/usr/bin/env python3 + +import typing +from typing import Any, Callable, cast, List, Tuple, Union + +import numpy as np +import torch +from captum._utils.gradient import _forward_layer_eval, compute_layer_gradients_and_eval +from captum._utils.typing import Literal, TargetType, TensorOrTupleOfTensorsGeneric +from captum.attr._core.gradient_shap import _scale_input +from captum.attr._core.noise_tunnel import NoiseTunnel +from captum.attr._utils.attribution import GradientAttribution, LayerAttribution +from captum.attr._utils.common import ( + _compute_conv_delta_and_format_attrs, + _format_callable_baseline, + _format_input_baseline, +) +from captum.log import log_usage 
+from torch import Tensor +from torch.nn import Module + + +class LayerGradientShap(LayerAttribution, GradientAttribution): + r""" + Implements gradient SHAP for layer based on the implementation from SHAP's + primary author. For reference, please, view: + + https://github.com/slundberg/shap\ + #deep-learning-example-with-gradientexplainer-tensorflowkeraspytorch-models + + A Unified Approach to Interpreting Model Predictions + http://papers.nips.cc/paper\ + 7062-a-unified-approach-to-interpreting-model-predictions + + GradientShap approximates SHAP values by computing the expectations of + gradients by randomly sampling from the distribution of baselines/references. + It adds white noise to each input sample `n_samples` times, selects a + random baseline from baselines' distribution and a random point along the + path between the baseline and the input, and computes the gradient of + outputs with respect to selected random points in chosen `layer`. + The final SHAP values represent the expected values of + `gradients * (layer_attr_inputs - layer_attr_baselines)`. + + GradientShap makes an assumption that the input features are independent + and that the explanation model is linear, meaning that the explanations + are modeled through the additive composition of feature effects. + Under those assumptions, SHAP value can be approximated as the expectation + of gradients that are computed for randomly generated `n_samples` input + samples after adding gaussian noise `n_samples` times to each input for + different baselines/references. + + In some sense it can be viewed as an approximation of integrated gradients + by computing the expectations of gradients for different baselines. + + Current implementation uses Smoothgrad from `NoiseTunnel` in order to + randomly draw samples from the distribution of baselines, add noise to input + samples and compute the expectation (smoothgrad). + """ + + def __init__( + self, + forward_func: Callable, + layer: Module, + device_ids: Union[None, List[int]] = None, + multiply_by_inputs: bool = True, + ) -> None: + r""" + Args: + + forward_func (callable): The forward function of the model or any + modification of it + layer (torch.nn.Module): Layer for which attributions are computed. + Output size of attribute matches this layer's input or + output dimensions, depending on whether we attribute to + the inputs or outputs of the layer, corresponding to + attribution of each neuron in the input or output of + this layer. + device_ids (list(int)): Device ID list, necessary only if forward_func + applies a DataParallel model. This allows reconstruction of + intermediate outputs from batched results across devices. + If forward_func is given as the DataParallel model itself, + then it is not necessary to provide this argument. + multiply_by_inputs (bool, optional): Indicates whether to factor + model inputs' multiplier in the final attribution scores. + In the literature this is also known as local vs global + attribution. If inputs' multiplier isn't factored in, + then this type of attribution method is also called local + attribution. If it is, then that type of attribution + method is called global. + More detailed can be found here: + https://arxiv.org/abs/1711.06104 + + In case of layer gradient shap, if `multiply_by_inputs` + is set to True, the sensitivity scores for scaled inputs + are being multiplied by + layer activations for inputs - layer activations for baselines. 
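+
+                        For example, constructing the attributor as
+                        `LayerGradientShap(net, net.linear1, multiply_by_inputs=False)`
+                        would return the expected gradients alone, without
+                        multiplying by the difference in layer activations.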
+ + """ + LayerAttribution.__init__(self, forward_func, layer, device_ids) + GradientAttribution.__init__(self, forward_func) + self._multiply_by_inputs = multiply_by_inputs + + @typing.overload + def attribute( + self, + inputs: TensorOrTupleOfTensorsGeneric, + baselines: Union[TensorOrTupleOfTensorsGeneric, Callable], + n_samples: int = 5, + stdevs: Union[float, Tuple[float, ...]] = 0.0, + target: TargetType = None, + additional_forward_args: Any = None, + *, + return_convergence_delta: Literal[True], + attribute_to_layer_input: bool = False, + ) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: + ... + + @typing.overload + def attribute( + self, + inputs: TensorOrTupleOfTensorsGeneric, + baselines: Union[TensorOrTupleOfTensorsGeneric, Callable], + n_samples: int = 5, + stdevs: Union[float, Tuple[float, ...]] = 0.0, + target: TargetType = None, + additional_forward_args: Any = None, + return_convergence_delta: Literal[False] = False, + attribute_to_layer_input: bool = False, + ) -> Union[Tensor, Tuple[Tensor, ...]]: + ... + + @log_usage() + def attribute( + self, + inputs: TensorOrTupleOfTensorsGeneric, + baselines: Union[TensorOrTupleOfTensorsGeneric, Callable], + n_samples: int = 5, + stdevs: Union[float, Tuple[float, ...]] = 0.0, + target: TargetType = None, + additional_forward_args: Any = None, + return_convergence_delta: bool = False, + attribute_to_layer_input: bool = False, + ) -> Union[ + Tensor, Tuple[Tensor, ...], Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor] + ]: + r""" + Args: + + inputs (tensor or tuple of tensors): Input which are used to compute + SHAP attribution values for a given `layer`. If `forward_func` + takes a single tensor as input, a single input tensor should + be provided. + If `forward_func` takes multiple tensors as input, a tuple + of the input tensors should be provided. It is assumed + that for all given input tensors, dimension 0 corresponds + to the number of examples, and if multiple input tensors + are provided, the examples must be aligned appropriately. + baselines (tensor, tuple of tensors, callable): + Baselines define the starting point from which expectation + is computed and can be provided as: + + - a single tensor, if inputs is a single tensor, with + the first dimension equal to the number of examples + in the baselines' distribution. The remaining dimensions + must match with input tensor's dimension starting from + the second dimension. + + - a tuple of tensors, if inputs is a tuple of tensors, + with the first dimension of any tensor inside the tuple + equal to the number of examples in the baseline's + distribution. The remaining dimensions must match + the dimensions of the corresponding input tensor + starting from the second dimension. + + - callable function, optionally takes `inputs` as an + argument and either returns a single tensor + or a tuple of those. + + It is recommended that the number of samples in the baselines' + tensors is larger than one. + n_samples (int, optional): The number of randomly generated examples + per sample in the input batch. Random examples are + generated by adding gaussian random noise to each sample. + Default: `5` if `n_samples` is not provided. + stdevs (float, or a tuple of floats optional): The standard deviation + of gaussian noise with zero mean that is added to each + input in the batch. If `stdevs` is a single float value + then that same value is used for all inputs. If it is + a tuple, then it must have the same length as the inputs + tuple. 
In this case, each stdev value in the stdevs tuple + corresponds to the input with the same index in the inputs + tuple. + Default: 0.0 + target (int, tuple, tensor or list, optional): Output indices for + which gradients are computed (for classification cases, + this is usually the target class). + If the network returns a scalar value per example, + no target index is necessary. + For general 2D outputs, targets can be either: + + - a single integer or a tensor containing a single + integer, which is applied to all input examples + + - a list of integers or a 1D tensor, with length matching + the number of examples in inputs (dim 0). Each integer + is applied as the target for the corresponding example. + + For outputs with > 2 dimensions, targets can be either: + + - A single tuple, which contains #output_dims - 1 + elements. This target index is applied to all examples. + + - A list of tuples with length equal to the number of + examples in inputs (dim 0), and each tuple containing + #output_dims - 1 elements. Each tuple is applied as the + target for the corresponding example. + + Default: None + additional_forward_args (any, optional): If the forward function + requires additional arguments other than the inputs for + which attributions should not be computed, this argument + can be provided. It can contain a tuple of ND tensors or + any arbitrary python type of any shape. + In case of the ND tensor the first dimension of the + tensor must correspond to the batch size. It will be + repeated for each `n_steps` for each randomly generated + input sample. + Note that the attributions are not computed with respect + to these arguments. + Default: None + return_convergence_delta (bool, optional): Indicates whether to return + convergence delta or not. If `return_convergence_delta` + is set to True convergence delta will be returned in + a tuple following attributions. + Default: False + attribute_to_layer_input (bool, optional): Indicates whether to + compute the attribution with respect to the layer input + or output. If `attribute_to_layer_input` is set to True + then the attributions will be computed with respect to + layer input, otherwise it will be computed with respect + to layer output. + Note that currently it is assumed that either the input + or the output of internal layer, depending on whether we + attribute to the input or output, is a single tensor. + Support for multiple tensors will be added later. + Default: False + Returns: + **attributions** or 2-element tuple of **attributions**, **delta**: + - **attributions** (*tensor* or tuple of *tensors*): + Attribution score computed based on GradientSHAP with + respect to layer's input or output. Attributions will always + be the same size as the provided layer's inputs or outputs, + depending on whether we attribute to the inputs or outputs + of the layer. + Attributions are returned in a tuple if + the layer inputs / outputs contain multiple tensors, + otherwise a single tensor is returned. + - **delta** (*tensor*, returned if return_convergence_delta=True): + This is computed using the property that the total + sum of forward_func(inputs) - forward_func(baselines) + must be very close to the total sum of the attributions + based on layer gradient SHAP. + Delta is calculated for each example in the input after adding + `n_samples` times gaussian noise to each of them. 
Therefore, + the dimensionality of the deltas tensor is equal to the + `number of examples in the input` * `n_samples` + The deltas are ordered by each input example and `n_samples` + noisy samples generated for it. + + Examples:: + + >>> # ImageClassifier takes a single input tensor of images Nx3x32x32, + >>> # and returns an Nx10 tensor of class probabilities. + >>> net = ImageClassifier() + >>> layer_grad_shap = LayerGradientShap(net, net.linear1) + >>> input = torch.randn(3, 3, 32, 32, requires_grad=True) + >>> # choosing baselines randomly + >>> baselines = torch.randn(20, 3, 32, 32) + >>> # Computes gradient SHAP of output layer when target is equal + >>> # to 0 with respect to the layer linear1. + >>> # Attribution size matches to the size of the linear1 layer + >>> attribution = layer_grad_shap.attribute(input, baselines, + target=5) + + """ + # since `baselines` is a distribution, we can generate it using a function + # rather than passing it as an input argument + baselines = _format_callable_baseline(baselines, inputs) + assert isinstance(baselines[0], torch.Tensor), ( + "Baselines distribution has to be provided in a form " + "of a torch.Tensor {}.".format(baselines[0]) + ) + + input_min_baseline_x_grad = LayerInputBaselineXGradient( + self.forward_func, + self.layer, + device_ids=self.device_ids, + multiply_by_inputs=self.multiplies_by_inputs, + ) + + nt = NoiseTunnel(input_min_baseline_x_grad) + + attributions = nt.attribute.__wrapped__( + nt, # self + inputs, + nt_type="smoothgrad", + nt_samples=n_samples, + stdevs=stdevs, + draw_baseline_from_distrib=True, + baselines=baselines, + target=target, + additional_forward_args=additional_forward_args, + return_convergence_delta=return_convergence_delta, + attribute_to_layer_input=attribute_to_layer_input, + ) + + return attributions + + def has_convergence_delta(self) -> bool: + return True + + @property + def multiplies_by_inputs(self): + return self._multiply_by_inputs + + +class LayerInputBaselineXGradient(LayerAttribution, GradientAttribution): + def __init__( + self, + forward_func: Callable, + layer: Module, + device_ids: Union[None, List[int]] = None, + multiply_by_inputs: bool = True, + ) -> None: + r""" + Args: + + forward_func (callable): The forward function of the model or any + modification of it + layer (torch.nn.Module): Layer for which attributions are computed. + Output size of attribute matches this layer's input or + output dimensions, depending on whether we attribute to + the inputs or outputs of the layer, corresponding to + attribution of each neuron in the input or output of + this layer. + device_ids (list(int)): Device ID list, necessary only if forward_func + applies a DataParallel model. This allows reconstruction of + intermediate outputs from batched results across devices. + If forward_func is given as the DataParallel model itself, + then it is not necessary to provide this argument. + multiply_by_inputs (bool, optional): Indicates whether to factor + model inputs' multiplier in the final attribution scores. + In the literature this is also known as local vs global + attribution. If inputs' multiplier isn't factored in, + then this type of attribution method is also called local + attribution. If it is, then that type of attribution + method is called global. 
+ More detailed can be found here: + https://arxiv.org/abs/1711.06104 + + In case of layer input minus baseline x gradient, + if `multiply_by_inputs` is set to True, the sensitivity scores + for scaled inputs are being multiplied by + layer activations for inputs - layer activations for baselines. + + """ + LayerAttribution.__init__(self, forward_func, layer, device_ids) + GradientAttribution.__init__(self, forward_func) + self._multiply_by_inputs = multiply_by_inputs + + @typing.overload + def attribute( + self, + inputs: Union[Tensor, Tuple[Tensor, ...]], + baselines: Union[Tensor, Tuple[Tensor, ...]], + target: TargetType = None, + additional_forward_args: Any = None, + return_convergence_delta: Literal[False] = False, + attribute_to_layer_input: bool = False, + ) -> Union[Tensor, Tuple[Tensor, ...]]: + ... + + @typing.overload + def attribute( + self, + inputs: Union[Tensor, Tuple[Tensor, ...]], + baselines: Union[Tensor, Tuple[Tensor, ...]], + target: TargetType = None, + additional_forward_args: Any = None, + *, + return_convergence_delta: Literal[True], + attribute_to_layer_input: bool = False, + ) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: + ... + + @log_usage() + def attribute( # type: ignore + self, + inputs: Union[Tensor, Tuple[Tensor, ...]], + baselines: Union[Tensor, Tuple[Tensor, ...]], + target: TargetType = None, + additional_forward_args: Any = None, + return_convergence_delta: bool = False, + attribute_to_layer_input: bool = False, + ) -> Union[ + Tensor, Tuple[Tensor, ...], Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor] + ]: + inputs, baselines = _format_input_baseline(inputs, baselines) + rand_coefficient = torch.tensor( + np.random.uniform(0.0, 1.0, inputs[0].shape[0]), + device=inputs[0].device, + dtype=inputs[0].dtype, + ) + + input_baseline_scaled = tuple( + _scale_input(input, baseline, rand_coefficient) + for input, baseline in zip(inputs, baselines) + ) + grads, _ = compute_layer_gradients_and_eval( + self.forward_func, + self.layer, + input_baseline_scaled, + target, + additional_forward_args, + device_ids=self.device_ids, + attribute_to_layer_input=attribute_to_layer_input, + ) + + attr_baselines = _forward_layer_eval( + self.forward_func, + baselines, + self.layer, + additional_forward_args=additional_forward_args, + device_ids=self.device_ids, + attribute_to_layer_input=attribute_to_layer_input, + ) + + attr_inputs = _forward_layer_eval( + self.forward_func, + inputs, + self.layer, + additional_forward_args=additional_forward_args, + device_ids=self.device_ids, + attribute_to_layer_input=attribute_to_layer_input, + ) + + if self.multiplies_by_inputs: + input_baseline_diffs = tuple( + input - baseline for input, baseline in zip(attr_inputs, attr_baselines) + ) + attributions = tuple( + input_baseline_diff * grad + for input_baseline_diff, grad in zip(input_baseline_diffs, grads) + ) + else: + attributions = grads + + return _compute_conv_delta_and_format_attrs( + self, + return_convergence_delta, + attributions, + baselines, + inputs, + additional_forward_args, + target, + cast(Union[Literal[True], Literal[False]], len(attributions) > 1), + ) + + def has_convergence_delta(self) -> bool: + return True + + @property + def multiplies_by_inputs(self): + return self._multiply_by_inputs diff --git a/captum/attr/_core/layer/layer_gradient_x_activation.py b/captum/attr/_core/layer/layer_gradient_x_activation.py new file mode 100644 index 0000000000000000000000000000000000000000..a63a5d7abe90d01f58e3baab9997e50c71c3caef --- /dev/null +++ 
b/captum/attr/_core/layer/layer_gradient_x_activation.py @@ -0,0 +1,201 @@ +#!/usr/bin/env python3 +from typing import Any, Callable, List, Tuple, Union + +from captum._utils.common import ( + _format_additional_forward_args, + _format_output, + _format_tensor_into_tuples, +) +from captum._utils.gradient import compute_layer_gradients_and_eval +from captum._utils.typing import ModuleOrModuleList, TargetType +from captum.attr._utils.attribution import GradientAttribution, LayerAttribution +from captum.log import log_usage +from torch import Tensor +from torch.nn import Module + + +class LayerGradientXActivation(LayerAttribution, GradientAttribution): + r""" + Computes element-wise product of gradient and activation for selected + layer on given inputs. + """ + + def __init__( + self, + forward_func: Callable, + layer: ModuleOrModuleList, + device_ids: Union[None, List[int]] = None, + multiply_by_inputs: bool = True, + ) -> None: + r""" + Args: + + forward_func (callable): The forward function of the model or any + modification of it + layer (torch.nn.Module or list(torch.nn.Module)): Layer or layers + for which attributions are computed. + Output size of attribute matches this layer's input or + output dimensions, depending on whether we attribute to + the inputs or outputs of the layer, corresponding to + attribution of each neuron in the input or output of + this layer. If multiple layers are provided, attributions + are returned as a list, each element corresponding to the + attributions of the corresponding layer. + device_ids (list(int)): Device ID list, necessary only if forward_func + applies a DataParallel model. This allows reconstruction of + intermediate outputs from batched results across devices. + If forward_func is given as the DataParallel model itself, + then it is not necessary to provide this argument. + multiply_by_inputs (bool, optional): Indicates whether to factor + model inputs' multiplier in the final attribution scores. + In the literature this is also known as local vs global + attribution. If inputs' multiplier isn't factored in, + then this type of attribution method is also called local + attribution. If it is, then that type of attribution + method is called global. + More detailed can be found here: + https://arxiv.org/abs/1711.06104 + + In case of layer gradient x activation, if `multiply_by_inputs` + is set to True, final sensitivity scores are being multiplied by + layer activations for inputs. + + """ + LayerAttribution.__init__(self, forward_func, layer, device_ids) + GradientAttribution.__init__(self, forward_func) + self._multiply_by_inputs = multiply_by_inputs + + @property + def multiplies_by_inputs(self): + return self._multiply_by_inputs + + @log_usage() + def attribute( + self, + inputs: Union[Tensor, Tuple[Tensor, ...]], + target: TargetType = None, + additional_forward_args: Any = None, + attribute_to_layer_input: bool = False, + ) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]: + r""" + Args: + + inputs (tensor or tuple of tensors): Input for which attributions + are computed. If forward_func takes a single + tensor as input, a single input tensor should be provided. + If forward_func takes multiple tensors as input, a tuple + of the input tensors should be provided. It is assumed + that for all given input tensors, dimension 0 corresponds + to the number of examples, and if multiple input tensors + are provided, the examples must be aligned appropriately. 
+ target (int, tuple, tensor or list, optional): Output indices for + which gradients are computed (for classification cases, + this is usually the target class). + If the network returns a scalar value per example, + no target index is necessary. + For general 2D outputs, targets can be either: + + - a single integer or a tensor containing a single + integer, which is applied to all input examples + + - a list of integers or a 1D tensor, with length matching + the number of examples in inputs (dim 0). Each integer + is applied as the target for the corresponding example. + + For outputs with > 2 dimensions, targets can be either: + + - A single tuple, which contains #output_dims - 1 + elements. This target index is applied to all examples. + + - A list of tuples with length equal to the number of + examples in inputs (dim 0), and each tuple containing + #output_dims - 1 elements. Each tuple is applied as the + target for the corresponding example. + + Default: None + additional_forward_args (any, optional): If the forward function + requires additional arguments other than the inputs for + which attributions should not be computed, this argument + can be provided. It must be either a single additional + argument of a Tensor or arbitrary (non-tuple) type or a + tuple containing multiple additional arguments including + tensors or any arbitrary python types. These arguments + are provided to forward_func in order following the + arguments in inputs. + Note that attributions are not computed with respect + to these arguments. + Default: None + attribute_to_layer_input (bool, optional): Indicates whether to + compute the attribution with respect to the layer input + or output. If `attribute_to_layer_input` is set to True + then the attributions will be computed with respect to + layer input, otherwise it will be computed with respect + to layer output. + Default: False + + Returns: + *tensor* or tuple of *tensors* or *list* of **attributions**: + - **attributions** (*tensor* or tuple of *tensors* or *list*): + Product of gradient and activation for each + neuron in given layer output. + Attributions will always be the same size as the + output of the given layer. + Attributions are returned in a tuple if + the layer inputs / outputs contain multiple tensors, + otherwise a single tensor is returned. + If multiple layers are provided, attributions + are returned as a list, each element corresponding to the + activations of the corresponding layer. + + + Examples:: + + >>> # ImageClassifier takes a single input tensor of images Nx3x32x32, + >>> # and returns an Nx10 tensor of class probabilities. + >>> # It contains an attribute conv1, which is an instance of nn.conv2d, + >>> # and the output of this layer has dimensions Nx12x32x32. + >>> net = ImageClassifier() + >>> layer_ga = LayerGradientXActivation(net, net.conv1) + >>> input = torch.randn(2, 3, 32, 32, requires_grad=True) + >>> # Computes layer activation x gradient for class 3. + >>> # attribution size matches layer output, Nx12x32x32 + >>> attribution = layer_ga.attribute(input, 3) + """ + inputs = _format_tensor_into_tuples(inputs) + additional_forward_args = _format_additional_forward_args( + additional_forward_args + ) + # Returns gradient of output with respect to + # hidden layer and hidden layer evaluated at each input. 
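As a concrete illustration of the gradient-times-activation product computed from the call just below, here is a hedged, self-contained sketch; `SmallNet` and the tensor sizes are stand-ins, not part of this patch. It also shows the list-of-layers form, which returns one attribution entry per layer in the order given.

import torch
import torch.nn as nn
from captum.attr import LayerGradientXActivation

class SmallNet(nn.Module):
    """Stand-in model: conv1 -> relu -> flatten -> fc."""
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 4, kernel_size=3, padding=1)
        self.fc = nn.Linear(4 * 8 * 8, 10)

    def forward(self, x):
        return self.fc(torch.relu(self.conv1(x)).flatten(1))

net = SmallNet()
x = torch.randn(2, 3, 8, 8)

# Single layer: one tensor shaped like that layer's output.
single = LayerGradientXActivation(net, net.conv1).attribute(x, target=3)
print(single.shape)              # torch.Size([2, 4, 8, 8])

# List of layers: a list with one entry per layer, in the given order.
multi = LayerGradientXActivation(net, [net.conv1, net.fc]).attribute(x, target=3)
print([a.shape for a in multi])  # [torch.Size([2, 4, 8, 8]), torch.Size([2, 10])]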
+ layer_gradients, layer_evals = compute_layer_gradients_and_eval( + self.forward_func, + self.layer, + inputs, + target, + additional_forward_args, + device_ids=self.device_ids, + attribute_to_layer_input=attribute_to_layer_input, + ) + if isinstance(self.layer, Module): + return _format_output( + len(layer_evals) > 1, + self.multiply_gradient_acts(layer_gradients, layer_evals), + ) + else: + return [ + _format_output( + len(layer_evals[i]) > 1, + self.multiply_gradient_acts(layer_gradients[i], layer_evals[i]), + ) + for i in range(len(self.layer)) + ] + + def multiply_gradient_acts( + self, gradients: Tuple[Tensor, ...], evals: Tuple[Tensor, ...] + ) -> Tuple[Tensor, ...]: + return tuple( + single_gradient * single_eval + if self.multiplies_by_inputs + else single_gradient + for single_gradient, single_eval in zip(gradients, evals) + ) diff --git a/captum/attr/_core/layer/layer_integrated_gradients.py b/captum/attr/_core/layer/layer_integrated_gradients.py new file mode 100644 index 0000000000000000000000000000000000000000..2e769a56586d5a14c3e3009e0e65d41fdd302611 --- /dev/null +++ b/captum/attr/_core/layer/layer_integrated_gradients.py @@ -0,0 +1,528 @@ +#!/usr/bin/env python3 +import functools +import warnings +from typing import Any, Callable, List, overload, Tuple, Union + +import torch +from captum._utils.common import ( + _extract_device, + _format_additional_forward_args, + _format_outputs, +) +from captum._utils.gradient import _forward_layer_eval, _run_forward +from captum._utils.typing import BaselineType, Literal, ModuleOrModuleList, TargetType +from captum.attr._core.integrated_gradients import IntegratedGradients +from captum.attr._utils.attribution import GradientAttribution, LayerAttribution +from captum.attr._utils.common import ( + _format_input_baseline, + _tensorize_baseline, + _validate_input, +) +from captum.log import log_usage +from torch import Tensor +from torch.nn.parallel.scatter_gather import scatter + + +class LayerIntegratedGradients(LayerAttribution, GradientAttribution): + r""" + Layer Integrated Gradients is a variant of Integrated Gradients that assigns + an importance score to layer inputs or outputs, depending on whether we + attribute to the former or to the latter one. + + Integrated Gradients is an axiomatic model interpretability algorithm that + attributes / assigns an importance score to each input feature by approximating + the integral of gradients of the model's output with respect to the inputs + along the path (straight line) from given baselines / references to inputs. + + Baselines can be provided as input arguments to attribute method. + To approximate the integral we can choose to use either a variant of + Riemann sum or Gauss-Legendre quadrature rule. + + More details regarding the integrated gradients method can be found in the + original paper: + https://arxiv.org/abs/1703.01365 + + """ + + def __init__( + self, + forward_func: Callable, + layer: ModuleOrModuleList, + device_ids: Union[None, List[int]] = None, + multiply_by_inputs: bool = True, + ) -> None: + r""" + Args: + forward_func (callable): The forward function of the model or any + modification of it + layer (ModuleOrModuleList): + Layer or list of layers for which attributions are computed. + For each layer the output size of the attribute matches + this layer's input or output dimensions, depending on + whether we attribute to the inputs or outputs of the + layer, corresponding to the attribution of each neuron + in the input or output of this layer. 
+ + Please note that layers to attribute on cannot be + dependent on each other. That is, a subset of layers in + `layer` cannot produce the inputs for another layer. + + For example, if your model is of a simple linked-list + based graph structure (think nn.Sequence), e.g. x -> l1 + -> l2 -> l3 -> output. If you pass in any one of those + layers, you cannot pass in another due to the + dependence, e.g. if you pass in l2 you cannot pass in + l1 or l3. + + device_ids (list(int)): Device ID list, necessary only if forward_func + applies a DataParallel model. This allows reconstruction of + intermediate outputs from batched results across devices. + If forward_func is given as the DataParallel model itself, + then it is not necessary to provide this argument. + multiply_by_inputs (bool, optional): Indicates whether to factor + model inputs' multiplier in the final attribution scores. + In the literature this is also known as local vs global + attribution. If inputs' multiplier isn't factored in, + then this type of attribution method is also called local + attribution. If it is, then that type of attribution + method is called global. + More detailed can be found here: + https://arxiv.org/abs/1711.06104 + + In case of layer integrated gradients, if `multiply_by_inputs` + is set to True, final sensitivity scores are being multiplied by + layer activations for inputs - layer activations for baselines. + + """ + LayerAttribution.__init__(self, forward_func, layer, device_ids=device_ids) + GradientAttribution.__init__(self, forward_func) + self.ig = IntegratedGradients(forward_func, multiply_by_inputs) + + if isinstance(layer, list) and len(layer) > 1: + warnings.warn( + "Multiple layers provided. Please ensure that each layer is" + "**not** solely solely dependent on the outputs of" + "another layer. Please refer to the documentation for more" + "detail." + ) + + @overload + def attribute( + self, + inputs: Union[Tensor, Tuple[Tensor, ...]], + baselines: BaselineType, + target: TargetType, + additional_forward_args: Any, + n_steps: int, + method: str, + internal_batch_size: Union[None, int], + return_convergence_delta: Literal[False], + attribute_to_layer_input: bool, + ) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]: + ... + + @overload + def attribute( + self, + inputs: Union[Tensor, Tuple[Tensor, ...]], + baselines: BaselineType, + target: TargetType, + additional_forward_args: Any, + n_steps: int, + method: str, + internal_batch_size: Union[None, int], + return_convergence_delta: Literal[True], + attribute_to_layer_input: bool, + ) -> Tuple[ + Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]], + Tensor, + ]: + ... + + @overload + def attribute( + self, + inputs: Union[Tensor, Tuple[Tensor, ...]], + baselines: BaselineType = None, + target: TargetType = None, + additional_forward_args: Any = None, + n_steps: int = 50, + method: str = "gausslegendre", + internal_batch_size: Union[None, int] = None, + return_convergence_delta: bool = False, + attribute_to_layer_input: bool = False, + ) -> Union[ + Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]], + Tuple[ + Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]], + Tensor, + ], + ]: + ... 
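The overloads above only pin down the return types; the full implementation follows. Before it, a minimal usage sketch may help: a common reason to reach for LayerIntegratedGradients is a model whose raw inputs are integer token ids, so plain IntegratedGradients cannot interpolate over the inputs, but the embedding layer's output can be attributed. The `TextNet` below, its vocabulary size, and the baseline choice are illustrative assumptions, not part of this patch.

import torch
import torch.nn as nn
from captum.attr import LayerIntegratedGradients

class TextNet(nn.Module):
    """Stand-in text classifier: token ids -> embedding -> mean -> linear."""
    def __init__(self, vocab=100, dim=16, classes=4):
        super().__init__()
        self.embedding = nn.Embedding(vocab, dim)
        self.fc = nn.Linear(dim, classes)

    def forward(self, tokens):
        return self.fc(self.embedding(tokens).mean(dim=1))

net = TextNet()
tokens = torch.randint(0, 100, (2, 7))   # 2 examples, 7 tokens each
baseline = torch.zeros_like(tokens)      # e.g. an all-<pad> sequence (id 0)

lig = LayerIntegratedGradients(net, net.embedding)
attr, delta = lig.attribute(
    tokens, baselines=baseline, target=1,
    n_steps=25, return_convergence_delta=True,
)
print(attr.shape)   # torch.Size([2, 7, 16]) -- matches the embedding output
print(delta.shape)  # torch.Size([2])        -- one delta per input example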
+ + @log_usage() + def attribute( + self, + inputs: Union[Tensor, Tuple[Tensor, ...]], + baselines: BaselineType = None, + target: TargetType = None, + additional_forward_args: Any = None, + n_steps: int = 50, + method: str = "gausslegendre", + internal_batch_size: Union[None, int] = None, + return_convergence_delta: bool = False, + attribute_to_layer_input: bool = False, + ) -> Union[ + Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]], + Tuple[ + Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]], + Tensor, + ], + ]: + r""" + This method attributes the output of the model with given target index + (in case it is provided, otherwise it assumes that output is a + scalar) to layer inputs or outputs of the model, depending on whether + `attribute_to_layer_input` is set to True or False, using the approach + described above. + + In addition to that it also returns, if `return_convergence_delta` is + set to True, integral approximation delta based on the completeness + property of integrated gradients. + + Args: + + inputs (tensor or tuple of tensors): Input for which layer integrated + gradients are computed. If forward_func takes a single + tensor as input, a single input tensor should be provided. + If forward_func takes multiple tensors as input, a tuple + of the input tensors should be provided. It is assumed + that for all given input tensors, dimension 0 corresponds + to the number of examples, and if multiple input tensors + are provided, the examples must be aligned appropriately. + baselines (scalar, tensor, tuple of scalars or tensors, optional): + Baselines define the starting point from which integral + is computed and can be provided as: + + - a single tensor, if inputs is a single tensor, with + exactly the same dimensions as inputs or the first + dimension is one and the remaining dimensions match + with inputs. + + - a single scalar, if inputs is a single tensor, which will + be broadcasted for each input value in input tensor. + + - a tuple of tensors or scalars, the baseline corresponding + to each tensor in the inputs' tuple can be: + - either a tensor with matching dimensions to + corresponding tensor in the inputs' tuple + or the first dimension is one and the remaining + dimensions match with the corresponding + input tensor. + - or a scalar, corresponding to a tensor in the + inputs' tuple. This scalar value is broadcasted + for corresponding input tensor. + + In the cases when `baselines` is not provided, we internally + use zero scalar corresponding to each input tensor. + + Default: None + target (int, tuple, tensor or list, optional): Output indices for + which gradients are computed (for classification cases, + this is usually the target class). + If the network returns a scalar value per example, + no target index is necessary. + For general 2D outputs, targets can be either: + + - a single integer or a tensor containing a single + integer, which is applied to all input examples + + - a list of integers or a 1D tensor, with length matching + the number of examples in inputs (dim 0). Each integer + is applied as the target for the corresponding example. + + For outputs with > 2 dimensions, targets can be either: + + - A single tuple, which contains #output_dims - 1 + elements. This target index is applied to all examples. + + - A list of tuples with length equal to the number of + examples in inputs (dim 0), and each tuple containing + #output_dims - 1 elements. 
Each tuple is applied as the + target for the corresponding example. + + Default: None + additional_forward_args (any, optional): If the forward function + requires additional arguments other than the inputs for + which attributions should not be computed, this argument + can be provided. It must be either a single additional + argument of a Tensor or arbitrary (non-tuple) type or a + tuple containing multiple additional arguments including + tensors or any arbitrary python types. These arguments + are provided to forward_func in order following the + arguments in inputs. + For a tensor, the first dimension of the tensor must + correspond to the number of examples. It will be + repeated for each of `n_steps` along the integrated + path. For all other types, the given argument is used + for all forward evaluations. + Note that attributions are not computed with respect + to these arguments. + Default: None + n_steps (int, optional): The number of steps used by the approximation + method. Default: 50. + method (string, optional): Method for approximating the integral, + one of `riemann_right`, `riemann_left`, `riemann_middle`, + `riemann_trapezoid` or `gausslegendre`. + Default: `gausslegendre` if no method is provided. + internal_batch_size (int, optional): Divides total #steps * #examples + data points into chunks of size at most internal_batch_size, + which are computed (forward / backward passes) + sequentially. internal_batch_size must be at least equal to + #examples. + For DataParallel models, each batch is split among the + available devices, so evaluations on each available + device contain internal_batch_size / num_devices examples. + If internal_batch_size is None, then all evaluations are + processed in one batch. + Default: None + return_convergence_delta (bool, optional): Indicates whether to return + convergence delta or not. If `return_convergence_delta` + is set to True convergence delta will be returned in + a tuple following attributions. + Default: False + attribute_to_layer_input (bool, optional): Indicates whether to + compute the attribution with respect to the layer input + or output. If `attribute_to_layer_input` is set to True + then the attributions will be computed with respect to + layer input, otherwise it will be computed with respect + to layer output. + Note that currently it is assumed that either the input + or the output of internal layer, depending on whether we + attribute to the input or output, is a single tensor. + Support for multiple tensors will be added later. + Default: False + Returns: + **attributions** or 2-element tuple of **attributions**, **delta**: + - **attributions** (*tensor*, tuple of *tensors* or tuple of *tensors*): + Integrated gradients with respect to `layer`'s inputs or + outputs. Attributions will always be the same size and + dimensionality as the input or output of the given layer, + depending on whether we attribute to the inputs or outputs + of the layer which is decided by the input flag + `attribute_to_layer_input`. + + For a single layer, attributions are returned in a tuple if + the layer inputs / outputs contain multiple tensors, + otherwise a single tensor is returned. + + For multiple layers, attributions will always be + returned as a list. Each element in this list will be + equivalent to that of a single layer output, i.e. in the + case that one layer, in the given layers, inputs / outputs + multiple tensors: the corresponding output element will be + a tuple of tensors. 
The ordering of the outputs will be + the same order as the layers given in the constructor. + - **delta** (*tensor*, returned if return_convergence_delta=True): + The difference between the total approximated and true + integrated gradients. This is computed using the property + that the total sum of forward_func(inputs) - + forward_func(baselines) must equal the total sum of the + integrated gradient. + Delta is calculated per example, meaning that the number of + elements in returned delta tensor is equal to the number of + of examples in inputs. + + Examples:: + + >>> # ImageClassifier takes a single input tensor of images Nx3x32x32, + >>> # and returns an Nx10 tensor of class probabilities. + >>> # It contains an attribute conv1, which is an instance of nn.conv2d, + >>> # and the output of this layer has dimensions Nx12x32x32. + >>> net = ImageClassifier() + >>> lig = LayerIntegratedGradients(net, net.conv1) + >>> input = torch.randn(2, 3, 32, 32, requires_grad=True) + >>> # Computes layer integrated gradients for class 3. + >>> # attribution size matches layer output, Nx12x32x32 + >>> attribution = lig.attribute(input, target=3) + """ + inps, baselines = _format_input_baseline(inputs, baselines) + _validate_input(inps, baselines, n_steps, method) + + baselines = _tensorize_baseline(inps, baselines) + additional_forward_args = _format_additional_forward_args( + additional_forward_args + ) + + def flatten_tuple(tup): + return tuple( + sum((list(x) if isinstance(x, (tuple, list)) else [x] for x in tup), []) + ) + + if self.device_ids is None: + self.device_ids = getattr(self.forward_func, "device_ids", None) + + inputs_layer = _forward_layer_eval( + self.forward_func, + inps, + self.layer, + device_ids=self.device_ids, + additional_forward_args=additional_forward_args, + attribute_to_layer_input=attribute_to_layer_input, + ) + + # if we have one output + if not isinstance(self.layer, list): + inputs_layer = (inputs_layer,) + + num_outputs = [1 if isinstance(x, Tensor) else len(x) for x in inputs_layer] + num_outputs_cumsum = torch.cumsum( + torch.IntTensor([0] + num_outputs), dim=0 # type: ignore + ) + inputs_layer = flatten_tuple(inputs_layer) + + baselines_layer = _forward_layer_eval( + self.forward_func, + baselines, + self.layer, + device_ids=self.device_ids, + additional_forward_args=additional_forward_args, + attribute_to_layer_input=attribute_to_layer_input, + ) + baselines_layer = flatten_tuple(baselines_layer) + + # inputs -> these inputs are scaled + def gradient_func( + forward_fn: Callable, + inputs: Union[Tensor, Tuple[Tensor, ...]], + target_ind: TargetType = None, + additional_forward_args: Any = None, + ) -> Tuple[Tensor, ...]: + if self.device_ids is None or len(self.device_ids) == 0: + scattered_inputs = (inputs,) + else: + # scatter method does not have a precise enough return type in its + # stub, so suppress the type warning. 
+ scattered_inputs = scatter( # type:ignore + inputs, target_gpus=self.device_ids + ) + + scattered_inputs_dict = { + scattered_input[0].device: scattered_input + for scattered_input in scattered_inputs + } + + with torch.autograd.set_grad_enabled(True): + + def layer_forward_hook( + module, hook_inputs, hook_outputs=None, layer_idx=0 + ): + device = _extract_device(module, hook_inputs, hook_outputs) + is_layer_tuple = ( + isinstance(hook_outputs, tuple) + # hook_outputs is None if attribute_to_layer_input == True + if hook_outputs is not None + else isinstance(hook_inputs, tuple) + ) + + if is_layer_tuple: + return scattered_inputs_dict[device][ + num_outputs_cumsum[layer_idx] : num_outputs_cumsum[ + layer_idx + 1 + ] + ] + + return scattered_inputs_dict[device][num_outputs_cumsum[layer_idx]] + + hooks = [] + try: + + layers = self.layer + if not isinstance(layers, list): + layers = [self.layer] + + for layer_idx, layer in enumerate(layers): + hook = None + # TODO: + # Allow multiple attribute_to_layer_input flags for + # each layer, i.e. attribute_to_layer_input[layer_idx] + if attribute_to_layer_input: + hook = layer.register_forward_pre_hook( + functools.partial( + layer_forward_hook, layer_idx=layer_idx + ) + ) + else: + hook = layer.register_forward_hook( + functools.partial( + layer_forward_hook, layer_idx=layer_idx + ) + ) + + hooks.append(hook) + + output = _run_forward( + self.forward_func, tuple(), target_ind, additional_forward_args + ) + finally: + for hook in hooks: + if hook is not None: + hook.remove() + + assert output[0].numel() == 1, ( + "Target not provided when necessary, cannot" + " take gradient with respect to multiple outputs." + ) + # torch.unbind(forward_out) is a list of scalar tensor tuples and + # contains batch_size * #steps elements + grads = torch.autograd.grad(torch.unbind(output), inputs) + return grads + + self.ig.gradient_func = gradient_func + all_inputs = ( + (inps + additional_forward_args) + if additional_forward_args is not None + else inps + ) + + attributions = self.ig.attribute.__wrapped__( # type: ignore + self.ig, # self + inputs_layer, + baselines=baselines_layer, + target=target, + additional_forward_args=all_inputs, + n_steps=n_steps, + method=method, + internal_batch_size=internal_batch_size, + return_convergence_delta=False, + ) + + # handle multiple outputs + output: List[Tuple[Tensor, ...]] = [ + tuple( + attributions[ + int(num_outputs_cumsum[i]) : int(num_outputs_cumsum[i + 1]) + ] + ) + for i in range(len(num_outputs)) + ] + + if return_convergence_delta: + start_point, end_point = baselines, inps + # computes approximation error based on the completeness axiom + delta = self.compute_convergence_delta( + attributions, + start_point, + end_point, + additional_forward_args=additional_forward_args, + target=target, + ) + return _format_outputs(isinstance(self.layer, list), output), delta + return _format_outputs(isinstance(self.layer, list), output) + + def has_convergence_delta(self) -> bool: + return True + + @property + def multiplies_by_inputs(self): + return self.ig.multiplies_by_inputs diff --git a/captum/attr/_core/layer/layer_lrp.py b/captum/attr/_core/layer/layer_lrp.py new file mode 100644 index 0000000000000000000000000000000000000000..e72bbbaddc8f3bcd65bf1ed482551ea12ed50a27 --- /dev/null +++ b/captum/attr/_core/layer/layer_lrp.py @@ -0,0 +1,293 @@ +#!/usr/bin/env python3 +import typing +from typing import Any, cast, List, Tuple, Union + +from captum._utils.common import ( + _format_tensor_into_tuples, + _reduce_list, + 
_sort_key_list, +) +from captum._utils.gradient import ( + apply_gradient_requirements, + compute_gradients, + undo_gradient_requirements, +) +from captum._utils.typing import ( + Literal, + ModuleOrModuleList, + TargetType, + TensorOrTupleOfTensorsGeneric, +) +from captum.attr._core.lrp import LRP +from captum.attr._utils.attribution import LayerAttribution +from torch import Tensor +from torch.nn import Module + + +class LayerLRP(LRP, LayerAttribution): + r""" + Layer-wise relevance propagation is based on a backward propagation + mechanism applied sequentially to all layers of the model. Here, the + model output score represents the initial relevance which is decomposed + into values for each neuron of the underlying layers. The decomposition + is defined by rules that are chosen for each layer, involving its weights + and activations. Details on the model can be found in the original paper + [https://doi.org/10.1371/journal.pone.0130140]. The implementation is + inspired by the tutorial of the same group + [https://doi.org/10.1016/j.dsp.2017.10.011] and the publication by + Ancona et al. [https://openreview.net/forum?id=Sy21R9JAW]. + """ + + def __init__(self, model: Module, layer: ModuleOrModuleList) -> None: + """ + Args: + + model (module): The forward function of the model or + any modification of it. Custom rules for a given layer need to + be defined as attribute + `module.rule` and need to be of type PropagationRule. + Model cannot contain any in-place nonlinear submodules; + these are not supported by the register_full_backward_hook + PyTorch API starting from PyTorch v1.9. + + + layer (torch.nn.Module or list(torch.nn.Module)): Layer or layers + for which attributions are computed. + The size and dimensionality of the attributions + corresponds to the size and dimensionality of the layer's + input or output depending on whether we attribute to the + inputs or outputs of the layer. If value is None, the + relevance for all layers is returned in attribution. + """ + LayerAttribution.__init__(self, model, layer) + LRP.__init__(self, model) + if hasattr(self.model, "device_ids"): + self.device_ids = cast(List[int], self.model.device_ids) + + @typing.overload # type: ignore + def attribute( + self, + inputs: TensorOrTupleOfTensorsGeneric, + target: TargetType = None, + additional_forward_args: Any = None, + return_convergence_delta: Literal[False] = False, + attribute_to_layer_input: bool = False, + verbose: bool = False, + ) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]: + ... + + @typing.overload + def attribute( + self, + inputs: TensorOrTupleOfTensorsGeneric, + target: TargetType = None, + additional_forward_args: Any = None, + *, + return_convergence_delta: Literal[True], + attribute_to_layer_input: bool = False, + verbose: bool = False, + ) -> Tuple[ + Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]], + Union[Tensor, List[Tensor]], + ]: + ... + + def attribute( + self, + inputs: TensorOrTupleOfTensorsGeneric, + target: TargetType = None, + additional_forward_args: Any = None, + return_convergence_delta: bool = False, + attribute_to_layer_input: bool = False, + verbose: bool = False, + ) -> Union[ + Tensor, + Tuple[Tensor, ...], + List[Union[Tensor, Tuple[Tensor, ...]]], + Tuple[ + Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]], + Union[Tensor, List[Tensor]], + ], + ]: + r""" + + Args: + inputs (tensor or tuple of tensors): Input for which relevance is + propagated. 
+ If forward_func takes a single + tensor as input, a single input tensor should be provided. + If forward_func takes multiple tensors as input, a tuple + of the input tensors should be provided. It is assumed + that for all given input tensors, dimension 0 corresponds + to the number of examples, and if multiple input tensors + are provided, the examples must be aligned appropriately. + target (int, tuple, tensor or list, optional): Output indices for + which gradients are computed (for classification cases, + this is usually the target class). + If the network returns a scalar value per example, + no target index is necessary. + For general 2D outputs, targets can be either: + + - a single integer or a tensor containing a single + integer, which is applied to all input examples + + - a list of integers or a 1D tensor, with length matching + the number of examples in inputs (dim 0). Each integer + is applied as the target for the corresponding example. + + For outputs with > 2 dimensions, targets can be either: + + - A single tuple, which contains #output_dims - 1 + elements. This target index is applied to all examples. + + - A list of tuples with length equal to the number of + examples in inputs (dim 0), and each tuple containing + #output_dims - 1 elements. Each tuple is applied as the + target for the corresponding example. + + Default: None + additional_forward_args (tuple, optional): If the forward function + requires additional arguments other than the inputs for + which attributions should not be computed, this argument + can be provided. It must be either a single additional + argument of a Tensor or arbitrary (non-tuple) type or a tuple + containing multiple additional arguments including tensors + or any arbitrary python types. These arguments are provided to + forward_func in order, following the arguments in inputs. + Note that attributions are not computed with respect + to these arguments. + Default: None + + return_convergence_delta (bool, optional): Indicates whether to return + convergence delta or not. If `return_convergence_delta` + is set to True convergence delta will be returned in + a tuple following attributions. + Default: False + + attribute_to_layer_input (bool, optional): Indicates whether to + compute the attribution with respect to the layer input + or output. If `attribute_to_layer_input` is set to True + then the attributions will be computed with respect to + layer input, otherwise it will be computed with respect + to layer output. + + verbose (bool, optional): Indicates whether information on application + of rules is printed during propagation. + Default: False + + Returns: + *tensor* or tuple of *tensors* of **attributions** or 2-element tuple of + **attributions**, **delta** or lists of **attributions** and **delta**: + - **attributions** (*tensor* or tuple of *tensors*): + The propagated relevance values with respect to each + input feature. Attributions will always + be the same size as the provided inputs, with each value + providing the attribution of the corresponding input index. + If a single tensor is provided as inputs, a single tensor is + returned. If a tuple is provided for inputs, a tuple of + corresponding sized tensors is returned. The sum of attributions + is one and not corresponding to the prediction score as in other + implementations. If attributions for all layers are returned + (layer=None) a list of tensors or tuples of tensors is returned + with entries for each layer. 
+ - **delta** (*tensor* or list of *tensors* + returned if return_convergence_delta=True): + Delta is calculated per example, meaning that the number of + elements in returned delta tensor is equal to the number of + of examples in input. + If attributions for all layers are returned (layer=None) a list + of tensors is returned with entries for + each layer. + Examples:: + + >>> # ImageClassifier takes a single input tensor of images Nx3x32x32, + >>> # and returns an Nx10 tensor of class probabilities. It has one + >>> # Conv2D and a ReLU layer. + >>> net = ImageClassifier() + >>> lrp = LRP(net, net.conv1) + >>> input = torch.randn(3, 3, 32, 32) + >>> # Attribution size matches input size: 3x3x32x32 + >>> attribution = lrp.attribute(input, target=5) + + """ + self.verbose = verbose + self._original_state_dict = self.model.state_dict() + self.layers = [] + self._get_layers(self.model) + self._check_and_attach_rules() + self.attribute_to_layer_input = attribute_to_layer_input + self.backward_handles = [] + self.forward_handles = [] + + inputs = _format_tensor_into_tuples(inputs) + gradient_mask = apply_gradient_requirements(inputs) + + try: + # 1. Forward pass + output = self._compute_output_and_change_weights( + inputs, target, additional_forward_args + ) + self._register_forward_hooks() + # 2. Forward pass + backward pass + _ = compute_gradients( + self._forward_fn_wrapper, inputs, target, additional_forward_args + ) + relevances = self._get_output_relevance(output) + finally: + self._restore_model() + undo_gradient_requirements(inputs, gradient_mask) + + if return_convergence_delta: + delta: Union[Tensor, List[Tensor]] + if isinstance(self.layer, list): + delta = [] + for relevance_layer in relevances: + delta.append( + self.compute_convergence_delta(relevance_layer, output) + ) + else: + delta = self.compute_convergence_delta( + cast(Tuple[Tensor, ...], relevances), output + ) + return relevances, delta # type: ignore + else: + return relevances # type: ignore + + def _get_single_output_relevance(self, layer, output): + if self.attribute_to_layer_input: + normalized_relevances = layer.rule.relevance_input + else: + normalized_relevances = layer.rule.relevance_output + key_list = _sort_key_list(list(normalized_relevances.keys()), self.device_ids) + normalized_relevances = _reduce_list( + [normalized_relevances[device_id] for device_id in key_list] + ) + + if isinstance(normalized_relevances, tuple): + return tuple( + normalized_relevance + * output.reshape((-1,) + (1,) * (normalized_relevance.dim() - 1)) + for normalized_relevance in normalized_relevances + ) + else: + return normalized_relevances * output.reshape( + (-1,) + (1,) * (normalized_relevances.dim() - 1) + ) + + def _get_output_relevance(self, output): + if isinstance(self.layer, list): + relevances = [] + for layer in self.layer: + relevances.append(self._get_single_output_relevance(layer, output)) + return relevances + else: + return self._get_single_output_relevance(self.layer, output) + + @staticmethod + def _convert_list_to_tuple( + relevances: Union[List[Any], Tuple[Any, ...]] + ) -> Tuple[Any, ...]: + if isinstance(relevances, list): + return tuple(relevances) + else: + return relevances diff --git a/captum/attr/_core/lime.py b/captum/attr/_core/lime.py new file mode 100644 index 0000000000000000000000000000000000000000..479fc1502271f96fd742c8a824de60ca540e6cac --- /dev/null +++ b/captum/attr/_core/lime.py @@ -0,0 +1,1242 @@ +#!/usr/bin/env python3 +import inspect +import math +import typing +import warnings +from typing 
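To round out the LayerLRP class defined above, a minimal end-to-end sketch; the small MLP and its sizes are stand-ins, not part of this patch. Only Linear and non-inplace ReLU modules are used so that every leaf module has a propagation rule.

import torch
import torch.nn as nn
from captum.attr import LayerLRP

model = nn.Sequential(
    nn.Linear(8, 6),
    nn.ReLU(),
    nn.Linear(6, 3),
)
x = torch.randn(4, 8)

layer_lrp = LayerLRP(model, model[0])
relevance, delta = layer_lrp.attribute(x, target=2, return_convergence_delta=True)
print(relevance.shape)  # torch.Size([4, 6]) -- relevance at the first Linear's output
print(delta.shape)      # torch.Size([4])    -- one delta per example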
import Any, Callable, cast, List, Optional, Tuple, Union + +import torch +from captum._utils.common import ( + _expand_additional_forward_args, + _expand_target, + _flatten_tensor_or_tuple, + _format_output, + _format_tensor_into_tuples, + _is_tuple, + _reduce_list, + _run_forward, +) +from captum._utils.models.linear_model import SkLearnLasso +from captum._utils.models.model import Model +from captum._utils.progress import progress +from captum._utils.typing import ( + BaselineType, + Literal, + TargetType, + TensorOrTupleOfTensorsGeneric, +) +from captum.attr._utils.attribution import PerturbationAttribution +from captum.attr._utils.batching import _batch_example_iterator +from captum.attr._utils.common import ( + _construct_default_feature_mask, + _format_input_baseline, +) +from captum.log import log_usage +from torch import Tensor +from torch.nn import CosineSimilarity +from torch.utils.data import DataLoader, TensorDataset + + +class LimeBase(PerturbationAttribution): + r""" + Lime is an interpretability method that trains an interpretable surrogate model + by sampling points around a specified input example and using model evaluations + at these points to train a simpler interpretable 'surrogate' model, such as a + linear model. + + LimeBase provides a generic framework to train a surrogate interpretable model. + This differs from most other attribution methods, since the method returns a + representation of the interpretable model (e.g. coefficients of the linear model). + For a similar interface to other perturbation-based attribution methods, please use + the Lime child class, which defines specific transformations for the interpretable + model. + + LimeBase allows sampling points in either the interpretable space or the original + input space to train the surrogate model. The interpretable space is a feature + vector used to train the surrogate interpretable model; this feature space is often + of smaller dimensionality than the original feature space in order for the surrogate + model to be more interpretable. + + If sampling in the interpretable space, a transformation function must be provided + to define how a vector sampled in the interpretable space can be transformed into + an example in the original input space. If sampling in the original input space, a + transformation function must be provided to define how the input can be transformed + into its interpretable vector representation. + + More details regarding LIME can be found in the original paper: + https://arxiv.org/abs/1602.04938 + """ + + def __init__( + self, + forward_func: Callable, + interpretable_model: Model, + similarity_func: Callable, + perturb_func: Callable, + perturb_interpretable_space: bool, + from_interp_rep_transform: Optional[Callable], + to_interp_rep_transform: Optional[Callable], + ) -> None: + r""" + + Args: + + + forward_func (callable): The forward function of the model or any + modification of it. If a batch is provided as input for + attribution, it is expected that forward_func returns a scalar + representing the entire batch. + interpretable_model (Model): Model object to train interpretable model. 
+ A Model object provides a `fit` method to train the model, + given a dataloader, with batches containing three tensors: + + - interpretable_inputs: Tensor + [2D num_samples x num_interp_features], + - expected_outputs: Tensor [1D num_samples], + - weights: Tensor [1D num_samples] + + The model object must also provide a `representation` method to + access the appropriate coefficients or representation of the + interpretable model after fitting. + Some predefined interpretable linear models are provided in + captum._utils.models.linear_model including wrappers around + SkLearn linear models as well as SGD-based PyTorch linear + models. + + Note that calling fit multiple times should retrain the + interpretable model, each attribution call reuses + the same given interpretable model object. + similarity_func (callable): Function which takes a single sample + along with its corresponding interpretable representation + and returns the weight of the interpretable sample for + training interpretable model. Weight is generally + determined based on similarity to the original input. + The original paper refers to this as a similarity kernel. + + The expected signature of this callable is: + + >>> similarity_func( + >>> original_input: Tensor or tuple of Tensors, + >>> perturbed_input: Tensor or tuple of Tensors, + >>> perturbed_interpretable_input: + >>> Tensor [2D 1 x num_interp_features], + >>> **kwargs: Any + >>> ) -> float or Tensor containing float scalar + + perturbed_input and original_input will be the same type and + contain tensors of the same shape (regardless of whether or not + the sampling function returns inputs in the interpretable + space). original_input is the same as the input provided + when calling attribute. + + All kwargs passed to the attribute method are + provided as keyword arguments (kwargs) to this callable. + perturb_func (callable): Function which returns a single + sampled input, generally a perturbation of the original + input, which is used to train the interpretable surrogate + model. Function can return samples in either + the original input space (matching type and tensor shapes + of original input) or in the interpretable input space, + which is a vector containing the intepretable features. + Alternatively, this function can return a generator + yielding samples to train the interpretable surrogate + model, and n_samples perturbations will be sampled + from this generator. + + The expected signature of this callable is: + + >>> perturb_func( + >>> original_input: Tensor or tuple of Tensors, + >>> **kwargs: Any + >>> ) -> Tensor or tuple of Tensors or + >>> generator yielding tensor or tuple of Tensors + + All kwargs passed to the attribute method are + provided as keyword arguments (kwargs) to this callable. + + Returned sampled input should match the input type (Tensor + or Tuple of Tensor and corresponding shapes) if + perturb_interpretable_space = False. If + perturb_interpretable_space = True, the return type should + be a single tensor of shape 1 x num_interp_features, + corresponding to the representation of the + sample to train the interpretable model. + + All kwargs passed to the attribute method are + provided as keyword arguments (kwargs) to this callable. + perturb_interpretable_space (bool): Indicates whether + perturb_func returns a sample in the interpretable space + (tensor of shape 1 x num_interp_features) or a sample + in the original space, matching the format of the original + input. 
Once sampled, inputs can be converted to / from + the interpretable representation with either + to_interp_rep_transform or from_interp_rep_transform. + from_interp_rep_transform (callable): Function which takes a + single sampled interpretable representation (tensor + of shape 1 x num_interp_features) and returns + the corresponding representation in the input space + (matching shapes of original input to attribute). + + This argument is necessary if perturb_interpretable_space + is True, otherwise None can be provided for this argument. + + The expected signature of this callable is: + + >>> from_interp_rep_transform( + >>> curr_sample: Tensor [2D 1 x num_interp_features] + >>> original_input: Tensor or Tuple of Tensors, + >>> **kwargs: Any + >>> ) -> Tensor or tuple of Tensors + + Returned sampled input should match the type of original_input + and corresponding tensor shapes. + + All kwargs passed to the attribute method are + provided as keyword arguments (kwargs) to this callable. + + to_interp_rep_transform (callable): Function which takes a + sample in the original input space and converts to + its interpretable representation (tensor + of shape 1 x num_interp_features). + + This argument is necessary if perturb_interpretable_space + is False, otherwise None can be provided for this argument. + + The expected signature of this callable is: + + >>> to_interp_rep_transform( + >>> curr_sample: Tensor or Tuple of Tensors, + >>> original_input: Tensor or Tuple of Tensors, + >>> **kwargs: Any + >>> ) -> Tensor [2D 1 x num_interp_features] + + curr_sample will match the type of original_input + and corresponding tensor shapes. + + All kwargs passed to the attribute method are + provided as keyword arguments (kwargs) to this callable. + """ + PerturbationAttribution.__init__(self, forward_func) + self.interpretable_model = interpretable_model + self.similarity_func = similarity_func + self.perturb_func = perturb_func + self.perturb_interpretable_space = perturb_interpretable_space + self.from_interp_rep_transform = from_interp_rep_transform + self.to_interp_rep_transform = to_interp_rep_transform + + if self.perturb_interpretable_space: + assert ( + self.from_interp_rep_transform is not None + ), "Must provide transform from interpretable space to original input space" + " when sampling from interpretable space." + else: + assert ( + self.to_interp_rep_transform is not None + ), "Must provide transform from original input space to interpretable space" + + @log_usage() + def attribute( + self, + inputs: TensorOrTupleOfTensorsGeneric, + target: TargetType = None, + additional_forward_args: Any = None, + n_samples: int = 50, + perturbations_per_eval: int = 1, + show_progress: bool = False, + **kwargs, + ) -> Tensor: + r""" + This method attributes the output of the model with given target index + (in case it is provided, otherwise it assumes that output is a + scalar) to the inputs of the model using the approach described above. + It trains an interpretable model and returns a representation of the + interpretable model. + + It is recommended to only provide a single example as input (tensors + with first dimension or batch size = 1). This is because LIME is generally + used for sample-based interpretability, training a separate interpretable + model to explain a model's prediction on each individual example. + + A batch of inputs can be provided as inputs only if forward_func + returns a single value per batch (e.g. loss). 
+ The interpretable feature representation should still have shape + 1 x num_interp_features, corresponding to the interpretable + representation for the full batch, and perturbations_per_eval + must be set to 1. + + Args: + + inputs (tensor or tuple of tensors): Input for which LIME + is computed. If forward_func takes a single + tensor as input, a single input tensor should be provided. + If forward_func takes multiple tensors as input, a tuple + of the input tensors should be provided. It is assumed + that for all given input tensors, dimension 0 corresponds + to the number of examples, and if multiple input tensors + are provided, the examples must be aligned appropriately. + target (int, tuple, tensor or list, optional): Output indices for + which surrogate model is trained + (for classification cases, + this is usually the target class). + If the network returns a scalar value per example, + no target index is necessary. + For general 2D outputs, targets can be either: + + - a single integer or a tensor containing a single + integer, which is applied to all input examples + + - a list of integers or a 1D tensor, with length matching + the number of examples in inputs (dim 0). Each integer + is applied as the target for the corresponding example. + + For outputs with > 2 dimensions, targets can be either: + + - A single tuple, which contains #output_dims - 1 + elements. This target index is applied to all examples. + + - A list of tuples with length equal to the number of + examples in inputs (dim 0), and each tuple containing + #output_dims - 1 elements. Each tuple is applied as the + target for the corresponding example. + + Default: None + additional_forward_args (any, optional): If the forward function + requires additional arguments other than the inputs for + which attributions should not be computed, this argument + can be provided. It must be either a single additional + argument of a Tensor or arbitrary (non-tuple) type or a + tuple containing multiple additional arguments including + tensors or any arbitrary python types. These arguments + are provided to forward_func in order following the + arguments in inputs. + For a tensor, the first dimension of the tensor must + correspond to the number of examples. For all other types, + the given argument is used for all forward evaluations. + Note that attributions are not computed with respect + to these arguments. + Default: None + n_samples (int, optional): The number of samples of the original + model used to train the surrogate interpretable model. + Default: `50` if `n_samples` is not provided. + perturbations_per_eval (int, optional): Allows multiple samples + to be processed simultaneously in one call to forward_fn. + Each forward pass will contain a maximum of + perturbations_per_eval * #examples samples. + For DataParallel models, each batch is split among the + available devices, so evaluations on each available + device contain at most + (perturbations_per_eval * #examples) / num_devices + samples. + If the forward function returns a single scalar per batch, + perturbations_per_eval must be set to 1. + Default: 1 + show_progress (bool, optional): Displays the progress of computation. + It will try to use tqdm if available for advanced features + (e.g. time estimation). Otherwise, it will fallback to + a simple output of progress. + Default: False + **kwargs (Any, optional): Any additional arguments necessary for + sampling and transformation functions (provided to + constructor). 
+ Default: None + + Returns: + **interpretable model representation**: + - **interpretable model representation* (*Any*): + A representation of the interpretable model trained. The return + type matches the return type of train_interpretable_model_func. + For example, this could contain coefficients of a + linear surrogate model. + + Examples:: + + >>> # SimpleClassifier takes a single input tensor of + >>> # float features with size N x 5, + >>> # and returns an Nx3 tensor of class probabilities. + >>> net = SimpleClassifier() + >>> + >>> # We will train an interpretable model with the same + >>> # features by simply sampling with added Gaussian noise + >>> # to the inputs and training a model to predict the + >>> # score of the target class. + >>> + >>> # For interpretable model training, we will use sklearn + >>> # linear model in this example. We have provided wrappers + >>> # around sklearn linear models to fit the Model interface. + >>> # Any arguments provided to the sklearn constructor can also + >>> # be provided to the wrapper, e.g.: + >>> # SkLearnLinearModel("linear_model.Ridge", alpha=2.0) + >>> from captum._utils.models.linear_model import SkLearnLinearModel + >>> + >>> + >>> # Define similarity kernel (exponential kernel based on L2 norm) + >>> def similarity_kernel( + >>> original_input: Tensor, + >>> perturbed_input: Tensor, + >>> perturbed_interpretable_input: Tensor, + >>> **kwargs)->Tensor: + >>> # kernel_width will be provided to attribute as a kwarg + >>> kernel_width = kwargs["kernel_width"] + >>> l2_dist = torch.norm(original_input - perturbed_input) + >>> return torch.exp(- (l2_dist**2) / (kernel_width**2)) + >>> + >>> + >>> # Define sampling function + >>> # This function samples in original input space + >>> def perturb_func( + >>> original_input: Tensor, + >>> **kwargs)->Tensor: + >>> return original_input + torch.randn_like(original_input) + >>> + >>> # For this example, we are setting the interpretable input to + >>> # match the model input, so the to_interp_rep_transform + >>> # function simply returns the input. In most cases, the interpretable + >>> # input will be different and may have a smaller feature set, so + >>> # an appropriate transformation function should be provided. + >>> + >>> def to_interp_transform(curr_sample, original_inp, + >>> **kwargs): + >>> return curr_sample + >>> + >>> # Generating random input with size 1 x 5 + >>> input = torch.randn(1, 5) + >>> # Defining LimeBase interpreter + >>> lime_attr = LimeBase(net, + SkLearnLinearModel("linear_model.Ridge"), + similarity_func=similarity_kernel, + perturb_func=perturb_func, + perturb_interpretable_space=False, + from_interp_rep_transform=None, + to_interp_rep_transform=to_interp_transform) + >>> # Computes interpretable model, returning coefficients of linear + >>> # model. 
+ >>> attr_coefs = lime_attr.attribute(input, target=1, kernel_width=1.1) + """ + with torch.no_grad(): + inp_tensor = ( + cast(Tensor, inputs) if isinstance(inputs, Tensor) else inputs[0] + ) + device = inp_tensor.device + + interpretable_inps = [] + similarities = [] + outputs = [] + + curr_model_inputs = [] + expanded_additional_args = None + expanded_target = None + perturb_generator = None + if inspect.isgeneratorfunction(self.perturb_func): + perturb_generator = self.perturb_func(inputs, **kwargs) + + if show_progress: + attr_progress = progress( + total=math.ceil(n_samples / perturbations_per_eval), + desc=f"{self.get_name()} attribution", + ) + attr_progress.update(0) + + batch_count = 0 + for _ in range(n_samples): + if perturb_generator: + try: + curr_sample = next(perturb_generator) + except StopIteration: + warnings.warn( + "Generator completed prior to given n_samples iterations!" + ) + break + else: + curr_sample = self.perturb_func(inputs, **kwargs) + batch_count += 1 + if self.perturb_interpretable_space: + interpretable_inps.append(curr_sample) + curr_model_inputs.append( + self.from_interp_rep_transform( # type: ignore + curr_sample, inputs, **kwargs + ) + ) + else: + curr_model_inputs.append(curr_sample) + interpretable_inps.append( + self.to_interp_rep_transform( # type: ignore + curr_sample, inputs, **kwargs + ) + ) + curr_sim = self.similarity_func( + inputs, curr_model_inputs[-1], interpretable_inps[-1], **kwargs + ) + similarities.append( + curr_sim.flatten() + if isinstance(curr_sim, Tensor) + else torch.tensor([curr_sim], device=device) + ) + + if len(curr_model_inputs) == perturbations_per_eval: + if expanded_additional_args is None: + expanded_additional_args = _expand_additional_forward_args( + additional_forward_args, len(curr_model_inputs) + ) + if expanded_target is None: + expanded_target = _expand_target(target, len(curr_model_inputs)) + + model_out = self._evaluate_batch( + curr_model_inputs, + expanded_target, + expanded_additional_args, + device, + ) + + if show_progress: + attr_progress.update() + + outputs.append(model_out) + + curr_model_inputs = [] + + if len(curr_model_inputs) > 0: + expanded_additional_args = _expand_additional_forward_args( + additional_forward_args, len(curr_model_inputs) + ) + expanded_target = _expand_target(target, len(curr_model_inputs)) + model_out = self._evaluate_batch( + curr_model_inputs, + expanded_target, + expanded_additional_args, + device, + ) + if show_progress: + attr_progress.update() + outputs.append(model_out) + + if show_progress: + attr_progress.close() + + combined_interp_inps = torch.cat(interpretable_inps).double() + combined_outputs = ( + torch.cat(outputs) + if len(outputs[0].shape) > 0 + else torch.stack(outputs) + ).double() + combined_sim = ( + torch.cat(similarities) + if len(similarities[0].shape) > 0 + else torch.stack(similarities) + ).double() + dataset = TensorDataset( + combined_interp_inps, combined_outputs, combined_sim + ) + self.interpretable_model.fit(DataLoader(dataset, batch_size=batch_count)) + return self.interpretable_model.representation() + + def _evaluate_batch( + self, + curr_model_inputs: List[TensorOrTupleOfTensorsGeneric], + expanded_target: TargetType, + expanded_additional_args: Any, + device: torch.device, + ): + model_out = _run_forward( + self.forward_func, + _reduce_list(curr_model_inputs), + expanded_target, + expanded_additional_args, + ) + if isinstance(model_out, Tensor): + assert model_out.numel() == len(curr_model_inputs), ( + "Number of outputs is not appropriate, 
must return " + "one output per perturbed input" + ) + if isinstance(model_out, Tensor): + return model_out.flatten() + return torch.tensor([model_out], device=device) + + def has_convergence_delta(self) -> bool: + return False + + @property + def multiplies_by_inputs(self): + return False + + +# Default transformations and methods +# for Lime child implementation. + + +def default_from_interp_rep_transform(curr_sample, original_inputs, **kwargs): + assert ( + "feature_mask" in kwargs + ), "Must provide feature_mask to use default interpretable representation transform" + assert ( + "baselines" in kwargs + ), "Must provide baselines to use default interpretable representation transfrom" + feature_mask = kwargs["feature_mask"] + if isinstance(feature_mask, Tensor): + binary_mask = curr_sample[0][feature_mask].bool() + return ( + binary_mask.to(original_inputs.dtype) * original_inputs + + (~binary_mask).to(original_inputs.dtype) * kwargs["baselines"] + ) + else: + binary_mask = tuple( + curr_sample[0][feature_mask[j]].bool() for j in range(len(feature_mask)) + ) + return tuple( + binary_mask[j].to(original_inputs[j].dtype) * original_inputs[j] + + (~binary_mask[j]).to(original_inputs[j].dtype) * kwargs["baselines"][j] + for j in range(len(feature_mask)) + ) + + +def get_exp_kernel_similarity_function( + distance_mode: str = "cosine", kernel_width: float = 1.0 +) -> Callable: + r""" + This method constructs an appropriate similarity function to compute + weights for perturbed sample in LIME. Distance between the original + and perturbed inputs is computed based on the provided distance mode, + and the distance is passed through an exponential kernel with given + kernel width to convert to a range between 0 and 1. + + The callable returned can be provided as the similarity_fn for + Lime or LimeBase. + + Args: + + distance_mode (str, optional): Distance mode can be either "cosine" or + "euclidean" corresponding to either cosine distance + or Euclidean distance respectively. Distance is computed + by flattening the original inputs and perturbed inputs + (concatenating tuples of inputs if necessary) and computing + distances between the resulting vectors. + Default: "cosine" + kernel_width (float, optional): + Kernel width for exponential kernel applied to distance. + Default: 1.0 + + Returns: + + *Callable*: + - **similarity_fn** (*Callable*): + Similarity function. This callable can be provided as the + similarity_fn for Lime or LimeBase. 
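+
+    A minimal usage sketch (illustrative only; net stands in for whatever
+    model / forward function is passed to Lime)::
+
+        >>> # Replace the default cosine kernel (width 1.0) with a Euclidean
+        >>> # kernel and a wider kernel width.
+        >>> similarity_fn = get_exp_kernel_similarity_function(
+        >>>     distance_mode="euclidean", kernel_width=750.0
+        >>> )
+        >>> lime = Lime(net, similarity_func=similarity_fn)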
+ """ + + def default_exp_kernel(original_inp, perturbed_inp, __, **kwargs): + flattened_original_inp = _flatten_tensor_or_tuple(original_inp).float() + flattened_perturbed_inp = _flatten_tensor_or_tuple(perturbed_inp).float() + if distance_mode == "cosine": + cos_sim = CosineSimilarity(dim=0) + distance = 1 - cos_sim(flattened_original_inp, flattened_perturbed_inp) + elif distance_mode == "euclidean": + distance = torch.norm(flattened_original_inp - flattened_perturbed_inp) + else: + raise ValueError("distance_mode must be either cosine or euclidean.") + return math.exp(-1 * (distance ** 2) / (2 * (kernel_width ** 2))) + + return default_exp_kernel + + +def default_perturb_func(original_inp, **kwargs): + assert ( + "num_interp_features" in kwargs + ), "Must provide num_interp_features to use default interpretable sampling function" + if isinstance(original_inp, Tensor): + device = original_inp.device + else: + device = original_inp[0].device + + probs = torch.ones(1, kwargs["num_interp_features"]) * 0.5 + return torch.bernoulli(probs).to(device=device).long() + + +def construct_feature_mask(feature_mask, formatted_inputs): + if feature_mask is None: + feature_mask, num_interp_features = _construct_default_feature_mask( + formatted_inputs + ) + else: + feature_mask = _format_tensor_into_tuples(feature_mask) + min_interp_features = int( + min( + torch.min(single_mask).item() + for single_mask in feature_mask + if single_mask.numel() + ) + ) + if min_interp_features != 0: + warnings.warn( + "Minimum element in feature mask is not 0, shifting indices to" + " start at 0." + ) + feature_mask = tuple( + single_mask - min_interp_features for single_mask in feature_mask + ) + + num_interp_features = int( + max( + torch.max(single_mask).item() + for single_mask in feature_mask + if single_mask.numel() + ) + + 1 + ) + return feature_mask, num_interp_features + + +class Lime(LimeBase): + r""" + Lime is an interpretability method that trains an interpretable surrogate model + by sampling points around a specified input example and using model evaluations + at these points to train a simpler interpretable 'surrogate' model, such as a + linear model. + + Lime provides a more specific implementation than LimeBase in order to expose + a consistent API with other perturbation-based algorithms. For more general + use of the LIME framework, consider using the LimeBase class directly and + defining custom sampling and transformation to / from interpretable + representation functions. + + Lime assumes that the interpretable representation is a binary vector, + corresponding to some elements in the input being set to their baseline value + if the corresponding binary interpretable feature value is 0 or being set + to the original input value if the corresponding binary interpretable + feature value is 1. Input values can be grouped to correspond to the same + binary interpretable feature using a feature mask provided when calling + attribute, similar to other perturbation-based attribution methods. + + One example of this setting is when applying Lime to an image classifier. + Pixels in an image can be grouped into super-pixels or segments, which + correspond to interpretable features, provided as a feature_mask when + calling attribute. Sampled binary vectors convey whether a super-pixel + is on (retains the original input values) or off (set to the corresponding + baseline value, e.g. black image). 
An interpretable linear model is trained + with input being the binary vectors and outputs as the corresponding scores + of the image classifier with the appropriate super-pixels masked based on the + binary vector. Coefficients of the trained surrogate + linear model convey the importance of each super-pixel. + + More details regarding LIME can be found in the original paper: + https://arxiv.org/abs/1602.04938 + """ + + def __init__( + self, + forward_func: Callable, + interpretable_model: Optional[Model] = None, + similarity_func: Optional[Callable] = None, + perturb_func: Optional[Callable] = None, + ) -> None: + r""" + + Args: + + + forward_func (callable): The forward function of the model or any + modification of it + interpretable_model (optional, Model): Model object to train + interpretable model. + + This argument is optional and defaults to SkLearnLasso(alpha=0.01), + which is a wrapper around the Lasso linear model in SkLearn. + This requires having sklearn version >= 0.23 available. + + Other predefined interpretable linear models are provided in + captum._utils.models.linear_model. + + Alternatively, a custom model object must provide a `fit` method to + train the model, given a dataloader, with batches containing + three tensors: + + - interpretable_inputs: Tensor + [2D num_samples x num_interp_features], + - expected_outputs: Tensor [1D num_samples], + - weights: Tensor [1D num_samples] + + The model object must also provide a `representation` method to + access the appropriate coefficients or representation of the + interpretable model after fitting. + + Note that calling fit multiple times should retrain the + interpretable model, each attribution call reuses + the same given interpretable model object. + similarity_func (optional, callable): Function which takes a single sample + along with its corresponding interpretable representation + and returns the weight of the interpretable sample for + training the interpretable model. + This is often referred to as a similarity kernel. + + This argument is optional and defaults to a function which + applies an exponential kernel to the consine distance between + the original input and perturbed input, with a kernel width + of 1.0. + + A similarity function applying an exponential + kernel to cosine / euclidean distances can be constructed + using the provided get_exp_kernel_similarity_function in + captum.attr._core.lime. + + Alternately, a custom callable can also be provided. + The expected signature of this callable is: + + >>> def similarity_func( + >>> original_input: Tensor or tuple of Tensors, + >>> perturbed_input: Tensor or tuple of Tensors, + >>> perturbed_interpretable_input: + >>> Tensor [2D 1 x num_interp_features], + >>> **kwargs: Any + >>> ) -> float or Tensor containing float scalar + + perturbed_input and original_input will be the same type and + contain tensors of the same shape, with original_input + being the same as the input provided when calling attribute. + + kwargs includes baselines, feature_mask, num_interp_features + (integer, determined from feature mask). + perturb_func (optional, callable): Function which returns a single + sampled input, which is a binary vector of length + num_interp_features, or a generator of such tensors. + + This function is optional, the default function returns + a binary vector where each element is selected + independently and uniformly at random. 
Custom + logic for selecting sampled binary vectors can + be implemented by providing a function with the + following expected signature: + + >>> perturb_func( + >>> original_input: Tensor or tuple of Tensors, + >>> **kwargs: Any + >>> ) -> Tensor [Binary 2D Tensor 1 x num_interp_features] + >>> or generator yielding such tensors + + kwargs includes baselines, feature_mask, num_interp_features + (integer, determined from feature mask). + + """ + if interpretable_model is None: + interpretable_model = SkLearnLasso(alpha=0.01) + + if similarity_func is None: + similarity_func = get_exp_kernel_similarity_function() + + if perturb_func is None: + perturb_func = default_perturb_func + + LimeBase.__init__( + self, + forward_func, + interpretable_model, + similarity_func, + perturb_func, + True, + default_from_interp_rep_transform, + None, + ) + + @log_usage() + def attribute( # type: ignore + self, + inputs: TensorOrTupleOfTensorsGeneric, + baselines: BaselineType = None, + target: TargetType = None, + additional_forward_args: Any = None, + feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None, + n_samples: int = 50, + perturbations_per_eval: int = 1, + return_input_shape: bool = True, + show_progress: bool = False, + ) -> TensorOrTupleOfTensorsGeneric: + r""" + This method attributes the output of the model with given target index + (in case it is provided, otherwise it assumes that output is a + scalar) to the inputs of the model using the approach described above, + training an interpretable model and returning a representation of the + interpretable model. + + It is recommended to only provide a single example as input (tensors + with first dimension or batch size = 1). This is because LIME is generally + used for sample-based interpretability, training a separate interpretable + model to explain a model's prediction on each individual example. + + A batch of inputs can also be provided as inputs, similar to + other perturbation-based attribution methods. In this case, if forward_fn + returns a scalar per example, attributions will be computed for each + example independently, with a separate interpretable model trained for each + example. Note that provided similarity and perturbation functions will be + provided each example separately (first dimension = 1) in this case. + If forward_fn returns a scalar per batch (e.g. loss), attributions will + still be computed using a single interpretable model for the full batch. + In this case, similarity and perturbation functions will be provided the + same original input containing the full batch. + + The number of interpretable features is determined from the provided + feature mask, or if none is provided, from the default feature mask, + which considers each scalar input as a separate feature. It is + generally recommended to provide a feature mask which groups features + into a small number of interpretable features / components (e.g. + superpixels in images). + + Args: + + inputs (tensor or tuple of tensors): Input for which LIME + is computed. If forward_func takes a single + tensor as input, a single input tensor should be provided. + If forward_func takes multiple tensors as input, a tuple + of the input tensors should be provided. It is assumed + that for all given input tensors, dimension 0 corresponds + to the number of examples, and if multiple input tensors + are provided, the examples must be aligned appropriately. 
+ baselines (scalar, tensor, tuple of scalars or tensors, optional): + Baselines define reference value which replaces each + feature when the corresponding interpretable feature + is set to 0. + Baselines can be provided as: + + - a single tensor, if inputs is a single tensor, with + exactly the same dimensions as inputs or the first + dimension is one and the remaining dimensions match + with inputs. + + - a single scalar, if inputs is a single tensor, which will + be broadcasted for each input value in input tensor. + + - a tuple of tensors or scalars, the baseline corresponding + to each tensor in the inputs' tuple can be: + + - either a tensor with matching dimensions to + corresponding tensor in the inputs' tuple + or the first dimension is one and the remaining + dimensions match with the corresponding + input tensor. + + - or a scalar, corresponding to a tensor in the + inputs' tuple. This scalar value is broadcasted + for corresponding input tensor. + In the cases when `baselines` is not provided, we internally + use zero scalar corresponding to each input tensor. + Default: None + target (int, tuple, tensor or list, optional): Output indices for + which surrogate model is trained + (for classification cases, + this is usually the target class). + If the network returns a scalar value per example, + no target index is necessary. + For general 2D outputs, targets can be either: + + - a single integer or a tensor containing a single + integer, which is applied to all input examples + + - a list of integers or a 1D tensor, with length matching + the number of examples in inputs (dim 0). Each integer + is applied as the target for the corresponding example. + + For outputs with > 2 dimensions, targets can be either: + + - A single tuple, which contains #output_dims - 1 + elements. This target index is applied to all examples. + + - A list of tuples with length equal to the number of + examples in inputs (dim 0), and each tuple containing + #output_dims - 1 elements. Each tuple is applied as the + target for the corresponding example. + + Default: None + additional_forward_args (any, optional): If the forward function + requires additional arguments other than the inputs for + which attributions should not be computed, this argument + can be provided. It must be either a single additional + argument of a Tensor or arbitrary (non-tuple) type or a + tuple containing multiple additional arguments including + tensors or any arbitrary python types. These arguments + are provided to forward_func in order following the + arguments in inputs. + For a tensor, the first dimension of the tensor must + correspond to the number of examples. It will be + repeated for each of `n_steps` along the integrated + path. For all other types, the given argument is used + for all forward evaluations. + Note that attributions are not computed with respect + to these arguments. + Default: None + feature_mask (tensor or tuple of tensors, optional): + feature_mask defines a mask for the input, grouping + features which correspond to the same + interpretable feature. feature_mask + should contain the same number of tensors as inputs. + Each tensor should + be the same size as the corresponding input or + broadcastable to match the input tensor. Values across + all tensors should be integers in the range 0 to + num_interp_features - 1, and indices corresponding to the + same feature should have the same value. 
+ Note that features are grouped across tensors + (unlike feature ablation and occlusion), so + if the same index is used in different tensors, those + features are still grouped and added simultaneously. + If None, then a feature mask is constructed which assigns + each scalar within a tensor as a separate feature. + Default: None + n_samples (int, optional): The number of samples of the original + model used to train the surrogate interpretable model. + Default: `50` if `n_samples` is not provided. + perturbations_per_eval (int, optional): Allows multiple samples + to be processed simultaneously in one call to forward_fn. + Each forward pass will contain a maximum of + perturbations_per_eval * #examples samples. + For DataParallel models, each batch is split among the + available devices, so evaluations on each available + device contain at most + (perturbations_per_eval * #examples) / num_devices + samples. + If the forward function returns a single scalar per batch, + perturbations_per_eval must be set to 1. + Default: 1 + return_input_shape (bool, optional): Determines whether the returned + tensor(s) only contain the coefficients for each interp- + retable feature from the trained surrogate model, or + whether the returned attributions match the input shape. + When return_input_shape is True, the return type of attribute + matches the input shape, with each element containing the + coefficient of the corresponding interpretale feature. + All elements with the same value in the feature mask + will contain the same coefficient in the returned + attributions. If return_input_shape is False, a 1D + tensor is returned, containing only the coefficients + of the trained interpreatable models, with length + num_interp_features. + show_progress (bool, optional): Displays the progress of computation. + It will try to use tqdm if available for advanced features + (e.g. time estimation). Otherwise, it will fallback to + a simple output of progress. + Default: False + + Returns: + *tensor* or tuple of *tensors* of **attributions**: + - **attributions** (*tensor* or tuple of *tensors*): + The attributions with respect to each input feature. + If return_input_shape = True, attributions will be + the same size as the provided inputs, with each value + providing the coefficient of the corresponding + interpretale feature. + If return_input_shape is False, a 1D + tensor is returned, containing only the coefficients + of the trained interpreatable models, with length + num_interp_features. + Examples:: + + >>> # SimpleClassifier takes a single input tensor of size Nx4x4, + >>> # and returns an Nx3 tensor of class probabilities. + >>> net = SimpleClassifier() + + >>> # Generating random input with size 1 x 4 x 4 + >>> input = torch.randn(1, 4, 4) + + >>> # Defining Lime interpreter + >>> lime = Lime(net) + >>> # Computes attribution, with each of the 4 x 4 = 16 + >>> # features as a separate interpretable feature + >>> attr = lime.attribute(input, target=1, n_samples=200) + + >>> # Alternatively, we can group each 2x2 square of the inputs + >>> # as one 'interpretable' feature and perturb them together. 
+ >>> # This can be done by creating a feature mask as follows, which + >>> # defines the feature groups, e.g.: + >>> # +---+---+---+---+ + >>> # | 0 | 0 | 1 | 1 | + >>> # +---+---+---+---+ + >>> # | 0 | 0 | 1 | 1 | + >>> # +---+---+---+---+ + >>> # | 2 | 2 | 3 | 3 | + >>> # +---+---+---+---+ + >>> # | 2 | 2 | 3 | 3 | + >>> # +---+---+---+---+ + >>> # With this mask, all inputs with the same value are set to their + >>> # baseline value, when the corresponding binary interpretable + >>> # feature is set to 0. + >>> # The attributions can be calculated as follows: + >>> # feature mask has dimensions 1 x 4 x 4 + >>> feature_mask = torch.tensor([[[0,0,1,1],[0,0,1,1], + >>> [2,2,3,3],[2,2,3,3]]]) + + >>> # Computes interpretable model and returning attributions + >>> # matching input shape. + >>> attr = lime.attribute(input, target=1, feature_mask=feature_mask) + """ + return self._attribute_kwargs( + inputs=inputs, + baselines=baselines, + target=target, + additional_forward_args=additional_forward_args, + feature_mask=feature_mask, + n_samples=n_samples, + perturbations_per_eval=perturbations_per_eval, + return_input_shape=return_input_shape, + show_progress=show_progress, + ) + + def _attribute_kwargs( # type: ignore + self, + inputs: TensorOrTupleOfTensorsGeneric, + baselines: BaselineType = None, + target: TargetType = None, + additional_forward_args: Any = None, + feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None, + n_samples: int = 25, + perturbations_per_eval: int = 1, + return_input_shape: bool = True, + show_progress: bool = False, + **kwargs, + ) -> TensorOrTupleOfTensorsGeneric: + is_inputs_tuple = _is_tuple(inputs) + formatted_inputs, baselines = _format_input_baseline(inputs, baselines) + bsz = formatted_inputs[0].shape[0] + + feature_mask, num_interp_features = construct_feature_mask( + feature_mask, formatted_inputs + ) + + if num_interp_features > 10000: + warnings.warn( + "Attempting to construct interpretable model with > 10000 features." + "This can be very slow or lead to OOM issues. Please provide a feature" + "mask which groups input features to reduce the number of interpretable" + "features. " + ) + + coefs: Tensor + if bsz > 1: + test_output = _run_forward( + self.forward_func, inputs, target, additional_forward_args + ) + if isinstance(test_output, Tensor) and torch.numel(test_output) > 1: + if torch.numel(test_output) == bsz: + warnings.warn( + "You are providing multiple inputs for Lime / Kernel SHAP " + "attributions. This trains a separate interpretable model " + "for each example, which can be time consuming. It is " + "recommended to compute attributions for one example at a time." 
+ ) + output_list = [] + for ( + curr_inps, + curr_target, + curr_additional_args, + curr_baselines, + curr_feature_mask, + ) in _batch_example_iterator( + bsz, + formatted_inputs, + target, + additional_forward_args, + baselines, + feature_mask, + ): + coefs = super().attribute.__wrapped__( + self, + inputs=curr_inps if is_inputs_tuple else curr_inps[0], + target=curr_target, + additional_forward_args=curr_additional_args, + n_samples=n_samples, + perturbations_per_eval=perturbations_per_eval, + baselines=curr_baselines + if is_inputs_tuple + else curr_baselines[0], + feature_mask=curr_feature_mask + if is_inputs_tuple + else curr_feature_mask[0], + num_interp_features=num_interp_features, + show_progress=show_progress, + **kwargs, + ) + if return_input_shape: + output_list.append( + self._convert_output_shape( + curr_inps, + curr_feature_mask, + coefs, + num_interp_features, + is_inputs_tuple, + ) + ) + else: + output_list.append(coefs.reshape(1, -1)) # type: ignore + + return _reduce_list(output_list) + else: + raise AssertionError( + "Invalid number of outputs, forward function should return a" + "scalar per example or a scalar per input batch." + ) + else: + assert perturbations_per_eval == 1, ( + "Perturbations per eval must be 1 when forward function" + "returns single value per batch!" + ) + + coefs = super().attribute.__wrapped__( + self, + inputs=inputs, + target=target, + additional_forward_args=additional_forward_args, + n_samples=n_samples, + perturbations_per_eval=perturbations_per_eval, + baselines=baselines if is_inputs_tuple else baselines[0], + feature_mask=feature_mask if is_inputs_tuple else feature_mask[0], + num_interp_features=num_interp_features, + show_progress=show_progress, + **kwargs, + ) + if return_input_shape: + return self._convert_output_shape( + formatted_inputs, + feature_mask, + coefs, + num_interp_features, + is_inputs_tuple, + ) + else: + return coefs + + @typing.overload + def _convert_output_shape( + self, + formatted_inp: Tuple[Tensor, ...], + feature_mask: Tuple[Tensor, ...], + coefs: Tensor, + num_interp_features: int, + is_inputs_tuple: Literal[True], + ) -> Tuple[Tensor, ...]: + ... + + @typing.overload + def _convert_output_shape( + self, + formatted_inp: Tuple[Tensor, ...], + feature_mask: Tuple[Tensor, ...], + coefs: Tensor, + num_interp_features: int, + is_inputs_tuple: Literal[False], + ) -> Tensor: + ... 
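+
+    # The implementation below (shared by both typing overloads above) maps the
+    # 1D coefficient vector returned by the surrogate model back onto the input
+    # shape: every input position whose feature_mask value equals a given
+    # interpretable feature index receives that feature's coefficient.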
+ + def _convert_output_shape( + self, + formatted_inp: Tuple[Tensor, ...], + feature_mask: Tuple[Tensor, ...], + coefs: Tensor, + num_interp_features: int, + is_inputs_tuple: bool, + ) -> Union[Tensor, Tuple[Tensor, ...]]: + coefs = coefs.flatten() + attr = [ + torch.zeros_like(single_inp, dtype=torch.float) + for single_inp in formatted_inp + ] + for tensor_ind in range(len(formatted_inp)): + for single_feature in range(num_interp_features): + attr[tensor_ind] += ( + coefs[single_feature].item() + * (feature_mask[tensor_ind] == single_feature).float() + ) + return _format_output(is_inputs_tuple, tuple(attr)) diff --git a/captum/attr/_core/lrp.py b/captum/attr/_core/lrp.py new file mode 100644 index 0000000000000000000000000000000000000000..e11d0b85445dd52310f0d131cf2738fa8a447c1f --- /dev/null +++ b/captum/attr/_core/lrp.py @@ -0,0 +1,422 @@ +#!/usr/bin/env python3 + +import typing +from collections import defaultdict +from typing import Any, cast, List, Tuple, Union + +import torch.nn as nn +from captum._utils.common import ( + _format_output, + _format_tensor_into_tuples, + _is_tuple, + _register_backward_hook, + _run_forward, +) +from captum._utils.gradient import ( + apply_gradient_requirements, + undo_gradient_requirements, +) +from captum._utils.typing import Literal, TargetType, TensorOrTupleOfTensorsGeneric +from captum.attr._utils.attribution import GradientAttribution +from captum.attr._utils.common import _sum_rows +from captum.attr._utils.custom_modules import Addition_Module +from captum.attr._utils.lrp_rules import EpsilonRule, PropagationRule +from captum.log import log_usage +from torch import Tensor +from torch.nn import Module +from torch.utils.hooks import RemovableHandle + + +class LRP(GradientAttribution): + r""" + Layer-wise relevance propagation is based on a backward propagation + mechanism applied sequentially to all layers of the model. Here, the + model output score represents the initial relevance which is decomposed + into values for each neuron of the underlying layers. The decomposition + is defined by rules that are chosen for each layer, involving its weights + and activations. Details on the model can be found in the original paper + [https://doi.org/10.1371/journal.pone.0130140]. The implementation is + inspired by the tutorial of the same group + [https://doi.org/10.1016/j.dsp.2017.10.011] and the publication by + Ancona et al. [https://openreview.net/forum?id=Sy21R9JAW]. + """ + + def __init__(self, model: Module) -> None: + r""" + Args: + + model (module): The forward function of the model or any modification of + it. Custom rules for a given layer need to be defined as attribute + `module.rule` and need to be of type PropagationRule. If no rule is + specified for a layer, a pre-defined default rule for the module type + is used. Model cannot contain any in-place nonlinear submodules; + these are not supported by the register_full_backward_hook + PyTorch API starting from PyTorch v1.9. + + """ + GradientAttribution.__init__(self, model) + self.model = model + self._check_rules() + + @property + def multiplies_by_inputs(self) -> bool: + return True + + @typing.overload + def attribute( + self, + inputs: TensorOrTupleOfTensorsGeneric, + target: TargetType = None, + additional_forward_args: Any = None, + return_convergence_delta: Literal[False] = False, + verbose: bool = False, + ) -> TensorOrTupleOfTensorsGeneric: + ... 
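+
+    # These attribute overloads only narrow the return type for static type
+    # checkers: with return_convergence_delta=True a (attributions, delta)
+    # tuple is returned, otherwise just the attributions.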
+ + @typing.overload + def attribute( + self, + inputs: TensorOrTupleOfTensorsGeneric, + target: TargetType = None, + additional_forward_args: Any = None, + *, + return_convergence_delta: Literal[True], + verbose: bool = False, + ) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: + ... + + @log_usage() + def attribute( + self, + inputs: TensorOrTupleOfTensorsGeneric, + target: TargetType = None, + additional_forward_args: Any = None, + return_convergence_delta: bool = False, + verbose: bool = False, + ) -> Union[ + TensorOrTupleOfTensorsGeneric, Tuple[TensorOrTupleOfTensorsGeneric, Tensor] + ]: + r""" + Args: + inputs (tensor or tuple of tensors): Input for which relevance is + propagated. If forward_func takes a single + tensor as input, a single input tensor should be provided. + If forward_func takes multiple tensors as input, a tuple + of the input tensors should be provided. It is assumed + that for all given input tensors, dimension 0 corresponds + to the number of examples, and if multiple input tensors + are provided, the examples must be aligned appropriately. + target (int, tuple, tensor or list, optional): Output indices for + which gradients are computed (for classification cases, + this is usually the target class). + If the network returns a scalar value per example, + no target index is necessary. + For general 2D outputs, targets can be either: + + - a single integer or a tensor containing a single + integer, which is applied to all input examples + + - a list of integers or a 1D tensor, with length matching + the number of examples in inputs (dim 0). Each integer + is applied as the target for the corresponding example. + + For outputs with > 2 dimensions, targets can be either: + + - A single tuple, which contains #output_dims - 1 + elements. This target index is applied to all examples. + + - A list of tuples with length equal to the number of + examples in inputs (dim 0), and each tuple containing + #output_dims - 1 elements. Each tuple is applied as the + target for the corresponding example. + + Default: None + additional_forward_args (tuple, optional): If the forward function + requires additional arguments other than the inputs for + which attributions should not be computed, this argument + can be provided. It must be either a single additional + argument of a Tensor or arbitrary (non-tuple) type or a tuple + containing multiple additional arguments including tensors + or any arbitrary python types. These arguments are provided to + forward_func in order, following the arguments in inputs. + Note that attributions are not computed with respect + to these arguments. + Default: None + + return_convergence_delta (bool, optional): Indicates whether to return + convergence delta or not. If `return_convergence_delta` + is set to True convergence delta will be returned in + a tuple following attributions. + Default: False + + verbose (bool, optional): Indicates whether information on application + of rules is printed during propagation. + + Returns: + *tensor* or tuple of *tensors* of **attributions** + or 2-element tuple of **attributions**, **delta**:: + - **attributions** (*tensor* or tuple of *tensors*): + The propagated relevance values with respect to each + input feature. The values are normalized by the output score + value (sum(relevance)=1). To obtain values comparable to other + methods or implementations these values need to be multiplied + by the output score. 
Attributions will always + be the same size as the provided inputs, with each value + providing the attribution of the corresponding input index. + If a single tensor is provided as inputs, a single tensor is + returned. If a tuple is provided for inputs, a tuple of + corresponding sized tensors is returned. The sum of attributions + is one and not corresponding to the prediction score as in other + implementations. + - **delta** (*tensor*, returned if return_convergence_delta=True): + Delta is calculated per example, meaning that the number of + elements in returned delta tensor is equal to the number of + of examples in the inputs. + Examples:: + + >>> # ImageClassifier takes a single input tensor of images Nx3x32x32, + >>> # and returns an Nx10 tensor of class probabilities. It has one + >>> # Conv2D and a ReLU layer. + >>> net = ImageClassifier() + >>> lrp = LRP(net) + >>> input = torch.randn(3, 3, 32, 32) + >>> # Attribution size matches input size: 3x3x32x32 + >>> attribution = lrp.attribute(input, target=5) + + """ + self.verbose = verbose + self._original_state_dict = self.model.state_dict() + self.layers: List[Module] = [] + self._get_layers(self.model) + self._check_and_attach_rules() + self.backward_handles: List[RemovableHandle] = [] + self.forward_handles: List[RemovableHandle] = [] + + is_inputs_tuple = _is_tuple(inputs) + inputs = _format_tensor_into_tuples(inputs) + gradient_mask = apply_gradient_requirements(inputs) + + try: + # 1. Forward pass: Change weights of layers according to selected rules. + output = self._compute_output_and_change_weights( + inputs, target, additional_forward_args + ) + # 2. Forward pass + backward pass: Register hooks to configure relevance + # propagation and execute back-propagation. + self._register_forward_hooks() + normalized_relevances = self.gradient_func( + self._forward_fn_wrapper, inputs, target, additional_forward_args + ) + relevances = tuple( + normalized_relevance + * output.reshape((-1,) + (1,) * (normalized_relevance.dim() - 1)) + for normalized_relevance in normalized_relevances + ) + finally: + self._restore_model() + + undo_gradient_requirements(inputs, gradient_mask) + + if return_convergence_delta: + return ( + _format_output(is_inputs_tuple, relevances), + self.compute_convergence_delta(relevances, output), + ) + else: + return _format_output(is_inputs_tuple, relevances) # type: ignore + + def has_convergence_delta(self) -> bool: + return True + + def compute_convergence_delta( + self, attributions: Union[Tensor, Tuple[Tensor, ...]], output: Tensor + ) -> Tensor: + """ + Here, we use the completeness property of LRP: The relevance is conserved + during the propagation through the models' layers. Therefore, the difference + between the sum of attribution (relevance) values and model output is taken as + the convergence delta. It should be zero for functional attribution. However, + when rules with an epsilon value are used for stability reasons, relevance is + absorbed during propagation and the convergence delta is non-zero. + + Args: + + attributions (tensor or tuple of tensors): Attribution scores that + are precomputed by an attribution algorithm. + Attributions can be provided in form of a single tensor + or a tuple of those. It is assumed that attribution + tensor's dimension 0 corresponds to the number of + examples, and if multiple input tensors are provided, + the examples must be aligned appropriately. 
+ + output (tensor with single element): The output value with respect to which + the attribution values are computed. This value corresponds to + the target score of a classification model. + + Returns: + *tensor*: + - **delta** Difference of relevance in output layer and input layer. + """ + if isinstance(attributions, tuple): + for attr in attributions: + summed_attr = cast( + Tensor, sum(_sum_rows(attr) for attr in attributions) + ) + else: + summed_attr = _sum_rows(attributions) + return output.flatten() - summed_attr.flatten() + + def _get_layers(self, model: Module) -> None: + for layer in model.children(): + if len(list(layer.children())) == 0: + self.layers.append(layer) + else: + self._get_layers(layer) + + def _check_and_attach_rules(self) -> None: + for layer in self.layers: + if hasattr(layer, "rule"): + layer.activations = {} # type: ignore + layer.rule.relevance_input = defaultdict(list) # type: ignore + layer.rule.relevance_output = {} # type: ignore + pass + elif type(layer) in SUPPORTED_LAYERS_WITH_RULES.keys(): + layer.activations = {} # type: ignore + layer.rule = SUPPORTED_LAYERS_WITH_RULES[type(layer)]() # type: ignore + layer.rule.relevance_input = defaultdict(list) # type: ignore + layer.rule.relevance_output = {} # type: ignore + elif type(layer) in SUPPORTED_NON_LINEAR_LAYERS: + layer.rule = None # type: ignore + else: + raise TypeError( + ( + f"Module of type {type(layer)} has no rule defined and no" + "default rule exists for this module type. Please, set a rule" + "explicitly for this module and assure that it is appropriate" + "for this type of layer." + ) + ) + + def _check_rules(self) -> None: + for module in self.model.modules(): + if hasattr(module, "rule"): + if ( + not isinstance(module.rule, PropagationRule) + and module.rule is not None + ): + raise TypeError( + ( + f"Please select propagation rules inherited from class " + f"PropagationRule for module: {module}" + ) + ) + + def _register_forward_hooks(self) -> None: + for layer in self.layers: + if type(layer) in SUPPORTED_NON_LINEAR_LAYERS: + backward_handle = _register_backward_hook( + layer, PropagationRule.backward_hook_activation, self + ) + self.backward_handles.append(backward_handle) + else: + forward_handle = layer.register_forward_hook( + layer.rule.forward_hook # type: ignore + ) + self.forward_handles.append(forward_handle) + if self.verbose: + print(f"Applied {layer.rule} on layer {layer}") + + def _register_weight_hooks(self) -> None: + for layer in self.layers: + if layer.rule is not None: + forward_handle = layer.register_forward_hook( + layer.rule.forward_hook_weights # type: ignore + ) + self.forward_handles.append(forward_handle) + + def _register_pre_hooks(self) -> None: + for layer in self.layers: + if layer.rule is not None: + forward_handle = layer.register_forward_pre_hook( + layer.rule.forward_pre_hook_activations # type: ignore + ) + self.forward_handles.append(forward_handle) + + def _compute_output_and_change_weights( + self, + inputs: Tuple[Tensor, ...], + target: TargetType, + additional_forward_args: Any, + ) -> Tensor: + try: + self._register_weight_hooks() + output = _run_forward(self.model, inputs, target, additional_forward_args) + finally: + self._remove_forward_hooks() + # Register pre_hooks that pass the initial activations from before weight + # adjustments as inputs to the layers with adjusted weights. This procedure + # is important for graph generation in the 2nd forward pass. 
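+        # Note: the weight adjustments themselves are temporary; attribute()
+        # restores the original parameters afterwards via _restore_model(),
+        # which reloads self._original_state_dict.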
+ self._register_pre_hooks() + return output + + def _remove_forward_hooks(self) -> None: + for forward_handle in self.forward_handles: + forward_handle.remove() + + def _remove_backward_hooks(self) -> None: + for backward_handle in self.backward_handles: + backward_handle.remove() + for layer in self.layers: + if hasattr(layer.rule, "_handle_input_hooks"): + for handle in layer.rule._handle_input_hooks: # type: ignore + handle.remove() + if hasattr(layer.rule, "_handle_output_hook"): + layer.rule._handle_output_hook.remove() # type: ignore + + def _remove_rules(self) -> None: + for layer in self.layers: + if hasattr(layer, "rule"): + del layer.rule + + def _clear_properties(self) -> None: + for layer in self.layers: + if hasattr(layer, "activation"): + del layer.activation + + def _restore_state(self) -> None: + self.model.load_state_dict(self._original_state_dict) # type: ignore + + def _restore_model(self) -> None: + self._restore_state() + self._remove_backward_hooks() + self._remove_forward_hooks() + self._remove_rules() + self._clear_properties() + + def _forward_fn_wrapper(self, *inputs: Tensor) -> Tensor: + """ + Wraps a forward function with addition of zero as a workaround to + https://github.com/pytorch/pytorch/issues/35802 discussed in + https://github.com/pytorch/captum/issues/143#issuecomment-611750044 + + #TODO: Remove when bugs are fixed + """ + adjusted_inputs = tuple( + input + 0 if input is not None else input for input in inputs + ) + return self.model(*adjusted_inputs) + + +SUPPORTED_LAYERS_WITH_RULES = { + nn.MaxPool1d: EpsilonRule, + nn.MaxPool2d: EpsilonRule, + nn.MaxPool3d: EpsilonRule, + nn.Conv2d: EpsilonRule, + nn.AvgPool2d: EpsilonRule, + nn.AdaptiveAvgPool2d: EpsilonRule, + nn.Linear: EpsilonRule, + nn.BatchNorm2d: EpsilonRule, + Addition_Module: EpsilonRule, +} + +SUPPORTED_NON_LINEAR_LAYERS = [nn.ReLU, nn.Dropout, nn.Tanh] diff --git a/captum/attr/_core/neuron/__init__.py b/captum/attr/_core/neuron/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/captum/attr/_core/neuron/neuron_conductance.py b/captum/attr/_core/neuron/neuron_conductance.py new file mode 100644 index 0000000000000000000000000000000000000000..dec6b39b018a6a1ab1e95afd50882456a56a4636 --- /dev/null +++ b/captum/attr/_core/neuron/neuron_conductance.py @@ -0,0 +1,410 @@ +#!/usr/bin/env python3 +import warnings +from typing import Any, Callable, List, Tuple, Union + +import torch +from captum._utils.common import ( + _expand_additional_forward_args, + _expand_target, + _format_additional_forward_args, + _format_output, + _is_tuple, + _verify_select_neuron, +) +from captum._utils.gradient import compute_layer_gradients_and_eval +from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric +from captum.attr._utils.approximation_methods import approximation_parameters +from captum.attr._utils.attribution import GradientAttribution, NeuronAttribution +from captum.attr._utils.batching import _batch_attribution +from captum.attr._utils.common import ( + _format_input_baseline, + _reshape_and_sum, + _validate_input, +) +from captum.log import log_usage +from torch import Tensor +from torch.nn import Module + + +class NeuronConductance(NeuronAttribution, GradientAttribution): + r""" + Computes conductance with respect to particular hidden neuron. 
The + returned output is in the shape of the input, showing the attribution + / conductance of each input feature to the selected hidden layer neuron. + The details of the approach can be found here: + https://arxiv.org/abs/1805.12233 + """ + + def __init__( + self, + forward_func: Callable, + layer: Module, + device_ids: Union[None, List[int]] = None, + multiply_by_inputs: bool = True, + ) -> None: + r""" + Args: + + forward_func (callable): The forward function of the model or any + modification of it + layer (torch.nn.Module): Layer for which neuron attributions are computed. + Attributions for a particular neuron in the input or output + of this layer are computed using the argument neuron_selector + in the attribute method. + Currently, only layers with a single tensor input or output + are supported. + layer (torch.nn.Module): Layer for which attributions are computed. + Output size of attribute matches this layer's input or + output dimensions, depending on whether we attribute to + the inputs or outputs of the layer, corresponding to + attribution of each neuron in the input or output of + this layer. + Currently, it is assumed that the inputs or the outputs + of the layer, depending on which one is used for + attribution, can only be a single tensor. + device_ids (list(int)): Device ID list, necessary only if forward_func + applies a DataParallel model. This allows reconstruction of + intermediate outputs from batched results across devices. + If forward_func is given as the DataParallel model itself, + then it is not necessary to provide this argument. + multiply_by_inputs (bool, optional): Indicates whether to factor + model inputs' multiplier in the final attribution scores. + In the literature this is also known as local vs global + attribution. If inputs' multiplier isn't factored in + then that type of attribution method is also called local + attribution. If it is, then that type of attribution + method is called global. + More detailed can be found here: + https://arxiv.org/abs/1711.06104 + + In case of Neuron Conductance, + if `multiply_by_inputs` is set to True, final + sensitivity scores are being multiplied + by (inputs - baselines). + + """ + NeuronAttribution.__init__(self, forward_func, layer, device_ids) + GradientAttribution.__init__(self, forward_func) + self._multiply_by_inputs = multiply_by_inputs + + @log_usage() + def attribute( + self, + inputs: TensorOrTupleOfTensorsGeneric, + neuron_selector: Union[int, Tuple[int, ...], Callable], + baselines: BaselineType = None, + target: TargetType = None, + additional_forward_args: Any = None, + n_steps: int = 50, + method: str = "riemann_trapezoid", + internal_batch_size: Union[None, int] = None, + attribute_to_neuron_input: bool = False, + ) -> TensorOrTupleOfTensorsGeneric: + r""" + Args: + + inputs (tensor or tuple of tensors): Input for which neuron + conductance is computed. If forward_func takes a single + tensor as input, a single input tensor should be provided. + If forward_func takes multiple tensors as input, a tuple + of the input tensors should be provided. It is assumed + that for all given input tensors, dimension 0 corresponds + to the number of examples, and if multiple input tensors + are provided, the examples must be aligned appropriately. + neuron_selector (int, callable, or tuple of ints or slices): + Selector for neuron + in given layer for which attribution is desired. + Neuron selector can be provided as: + + - a single integer, if the layer output is 2D. 
This integer + selects the appropriate neuron column in the layer input + or output + + - a tuple of integers. Length of this + tuple must be one less than the number of dimensions + in the input / output of the given layer (since + dimension 0 corresponds to number of examples). + This can be used as long as the layer input / output + is a single tensor. + + - a callable, which should + take the target layer as input (single tensor or tuple + if multiple tensors are in layer) and return a selected + neuron - output shape should be 1D with length equal to + batch_size (one scalar per input example) + + NOTE: Callables applicable for neuron conductance are + less general than those of other methods and should + NOT aggregate values of the layer, only return a specific + output. This option should only be used in cases where the + layer input / output is a tuple of tensors, where the other + options would not suffice. This limitation is necessary since + neuron conductance, unlike other neuron methods, also utilizes + the gradient of output with respect to the intermedite neuron, + which cannot be computed for aggregations of multiple + intemediate neurons. + baselines (scalar, tensor, tuple of scalars or tensors, optional): + Baselines define the starting point from which integral + is computed and can be provided as: + + - a single tensor, if inputs is a single tensor, with + exactly the same dimensions as inputs or the first + dimension is one and the remaining dimensions match + with inputs. + + - a single scalar, if inputs is a single tensor, which will + be broadcasted for each input value in input tensor. + + - a tuple of tensors or scalars, the baseline corresponding + to each tensor in the inputs' tuple can be: + + - either a tensor with matching dimensions to + corresponding tensor in the inputs' tuple + or the first dimension is one and the remaining + dimensions match with the corresponding + input tensor. + + - or a scalar, corresponding to a tensor in the + inputs' tuple. This scalar value is broadcasted + for corresponding input tensor. + + In the cases when `baselines` is not provided, we internally + use zero scalar corresponding to each input tensor. + + Default: None + target (int, tuple, tensor or list, optional): Output indices for + which gradients are computed (for classification cases, + this is usually the target class). + If the network returns a scalar value per example, + no target index is necessary. + For general 2D outputs, targets can be either: + + - a single integer or a tensor containing a single + integer, which is applied to all input examples + + - a list of integers or a 1D tensor, with length matching + the number of examples in inputs (dim 0). Each integer + is applied as the target for the corresponding example. + + For outputs with > 2 dimensions, targets can be either: + + - A single tuple, which contains #output_dims - 1 + elements. This target index is applied to all examples. + + - A list of tuples with length equal to the number of + examples in inputs (dim 0), and each tuple containing + #output_dims - 1 elements. Each tuple is applied as the + target for the corresponding example. + + Default: None + additional_forward_args (any, optional): If the forward function + requires additional arguments other than the inputs for + which attributions should not be computed, this argument + can be provided. 
It must be either a single additional + argument of a Tensor or arbitrary (non-tuple) type or a + tuple containing multiple additional arguments including + tensors or any arbitrary python types. These arguments + are provided to forward_func in order following the + arguments in inputs. + For a tensor, the first dimension of the tensor must + correspond to the number of examples. It will be + repeated for each of `n_steps` along the integrated + path. For all other types, the given argument is used + for all forward evaluations. + Note that attributions are not computed with respect + to these arguments. + Default: None + n_steps (int, optional): The number of steps used by the approximation + method. Default: 50. + method (string, optional): Method for approximating the integral, + one of `riemann_right`, `riemann_left`, `riemann_middle`, + `riemann_trapezoid` or `gausslegendre`. + Default: `gausslegendre` if no method is provided. + internal_batch_size (int, optional): Divides total #steps * #examples + data points into chunks of size at most internal_batch_size, + which are computed (forward / backward passes) + sequentially. internal_batch_size must be at least equal to + #examples. + For DataParallel models, each batch is split among the + available devices, so evaluations on each available + device contain internal_batch_size / num_devices examples. + If internal_batch_size is None, then all evaluations are + processed in one batch. + Default: None + attribute_to_neuron_input (bool, optional): Indicates whether to + compute the attributions with respect to the neuron input + or output. If `attribute_to_neuron_input` is set to True + then the attributions will be computed with respect to + neuron's inputs, otherwise it will be computed with respect + to neuron's outputs. + Note that currently it is assumed that either the input + or the output of internal neuron, depending on whether we + attribute to the input or output, is a single tensor. + Support for multiple tensors will be added later. + Default: False + + Returns: + *tensor* or tuple of *tensors* of **attributions**: + - **attributions** (*tensor* or tuple of *tensors*): + Conductance for + particular neuron with respect to each input feature. + Attributions will always be the same size as the provided + inputs, with each value providing the attribution of the + corresponding input index. + If a single tensor is provided as inputs, a single tensor is + returned. If a tuple is provided for inputs, a tuple of + corresponding sized tensors is returned. + + Examples:: + + >>> # ImageClassifier takes a single input tensor of images Nx3x32x32, + >>> # and returns an Nx10 tensor of class probabilities. + >>> # It contains an attribute conv1, which is an instance of nn.conv2d, + >>> # and the output of this layer has dimensions Nx12x32x32. + >>> net = ImageClassifier() + >>> neuron_cond = NeuronConductance(net, net.conv1) + >>> input = torch.randn(2, 3, 32, 32, requires_grad=True) + >>> # To compute neuron attribution, we need to provide the neuron + >>> # index for which attribution is desired. Since the layer output + >>> # is Nx12x32x32, we need a tuple in the form (0..11,0..31,0..31) + >>> # which indexes a particular neuron in the layer output. + >>> # Computes neuron conductance for neuron with + >>> # index (4,1,2). + >>> attribution = neuron_cond.attribute(input, (4,1,2)) + """ + if callable(neuron_selector): + warnings.warn( + "The neuron_selector provided is a callable. 
Please ensure that this" + " function only selects neurons from the given layer; aggregating" + " or performing other operations on the tensor may lead to inaccurate" + " results." + ) + is_inputs_tuple = _is_tuple(inputs) + + inputs, baselines = _format_input_baseline(inputs, baselines) + _validate_input(inputs, baselines, n_steps, method) + + num_examples = inputs[0].shape[0] + + if internal_batch_size is not None: + num_examples = inputs[0].shape[0] + attrs = _batch_attribution( + self, + num_examples, + internal_batch_size, + n_steps, + inputs=inputs, + baselines=baselines, + neuron_selector=neuron_selector, + target=target, + additional_forward_args=additional_forward_args, + method=method, + attribute_to_neuron_input=attribute_to_neuron_input, + ) + else: + attrs = self._attribute( + inputs=inputs, + neuron_selector=neuron_selector, + baselines=baselines, + target=target, + additional_forward_args=additional_forward_args, + n_steps=n_steps, + method=method, + attribute_to_neuron_input=attribute_to_neuron_input, + ) + return _format_output(is_inputs_tuple, attrs) + + def _attribute( + self, + inputs: Tuple[Tensor, ...], + neuron_selector: Union[int, Tuple[int, ...], Callable], + baselines: Tuple[Union[Tensor, int, float], ...], + target: TargetType = None, + additional_forward_args: Any = None, + n_steps: int = 50, + method: str = "riemann_trapezoid", + attribute_to_neuron_input: bool = False, + step_sizes_and_alphas: Union[None, Tuple[List[float], List[float]]] = None, + ) -> Tuple[Tensor, ...]: + + num_examples = inputs[0].shape[0] + total_batch = num_examples * n_steps + + if step_sizes_and_alphas is None: + # retrieve step size and scaling factor for specified approximation method + step_sizes_func, alphas_func = approximation_parameters(method) + step_sizes, alphas = step_sizes_func(n_steps), alphas_func(n_steps) + else: + step_sizes, alphas = step_sizes_and_alphas + + # Compute scaled inputs from baseline to final input. + scaled_features_tpl = tuple( + torch.cat( + [baseline + alpha * (input - baseline) for alpha in alphas], dim=0 + ).requires_grad_() + for input, baseline in zip(inputs, baselines) + ) + + additional_forward_args = _format_additional_forward_args( + additional_forward_args + ) + # apply number of steps to additional forward args + # currently, number of steps is applied only to additional forward arguments + # that are nd-tensors. It is assumed that the first dimension is + # the number of batches. + # dim -> (#examples * #steps x additional_forward_args[0].shape[1:], ...) + input_additional_args = ( + _expand_additional_forward_args(additional_forward_args, n_steps) + if additional_forward_args is not None + else None + ) + expanded_target = _expand_target(target, n_steps) + + # Conductance Gradients - Returns gradient of output with respect to + # hidden layer and hidden layer evaluated at each input. + layer_gradients, layer_eval, input_grads = compute_layer_gradients_and_eval( + forward_fn=self.forward_func, + layer=self.layer, + inputs=scaled_features_tpl, + target_ind=expanded_target, + additional_forward_args=input_additional_args, + gradient_neuron_selector=neuron_selector, + device_ids=self.device_ids, + attribute_to_layer_input=attribute_to_neuron_input, + ) + + mid_grads = _verify_select_neuron(layer_gradients, neuron_selector) + scaled_input_gradients = tuple( + input_grad + * mid_grads.reshape((total_batch,) + (1,) * (len(input_grad.shape) - 1)) + for input_grad in input_grads + ) + + # Mutliplies by appropriate step size. 
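+        # (step_sizes comes from approximation_parameters(method) above and
+        # holds one quadrature weight per integration step.)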
+ scaled_grads = tuple( + scaled_input_gradient.contiguous().view(n_steps, -1) + * torch.tensor(step_sizes).view(n_steps, 1).to(scaled_input_gradient.device) + for scaled_input_gradient in scaled_input_gradients + ) + + # Aggregates across all steps for each tensor in the input tuple + total_grads = tuple( + _reshape_and_sum(scaled_grad, n_steps, num_examples, input_grad.shape[1:]) + for (scaled_grad, input_grad) in zip(scaled_grads, input_grads) + ) + + if self.multiplies_by_inputs: + # computes attribution for each tensor in input tuple + # attributions has the same dimensionality as inputs + attributions = tuple( + total_grad * (input - baseline) + for total_grad, input, baseline in zip(total_grads, inputs, baselines) + ) + else: + attributions = total_grads + + return attributions + + @property + def multiplies_by_inputs(self): + return self._multiply_by_inputs diff --git a/captum/attr/_core/neuron/neuron_deep_lift.py b/captum/attr/_core/neuron/neuron_deep_lift.py new file mode 100644 index 0000000000000000000000000000000000000000..aff216d37abc4e8f98ff334d62831eaae01650cc --- /dev/null +++ b/captum/attr/_core/neuron/neuron_deep_lift.py @@ -0,0 +1,499 @@ +#!/usr/bin/env python3 +import warnings +from typing import Any, Callable, cast, Tuple, Union + +from captum._utils.gradient import construct_neuron_grad_fn +from captum._utils.typing import BaselineType, TensorOrTupleOfTensorsGeneric +from captum.attr._core.deep_lift import DeepLift, DeepLiftShap +from captum.attr._utils.attribution import GradientAttribution, NeuronAttribution +from captum.log import log_usage +from torch import Tensor +from torch.nn import Module + + +class NeuronDeepLift(NeuronAttribution, GradientAttribution): + r""" + Implements DeepLIFT algorithm for the neuron based on the following paper: + Learning Important Features Through Propagating Activation Differences, + Avanti Shrikumar, et. al. + https://arxiv.org/abs/1704.02685 + + and the gradient formulation proposed in: + Towards better understanding of gradient-based attribution methods for + deep neural networks, Marco Ancona, et.al. + https://openreview.net/pdf?id=Sy21R9JAW + + This implementation supports only Rescale rule. RevealCancel rule will + be supported in later releases. + Although DeepLIFT's(Rescale Rule) attribution quality is comparable with + Integrated Gradients, it runs significantly faster than Integrated + Gradients and is preferred for large datasets. + + Currently we only support a limited number of non-linear activations + but the plan is to expand the list in the future. + + Note: As we know, currently we cannot access the building blocks, + of PyTorch's built-in LSTM, RNNs and GRUs such as Tanh and Sigmoid. + Nonetheless, it is possible to build custom LSTMs, RNNS and GRUs + with performance similar to built-in ones using TorchScript. + More details on how to build custom RNNs can be found here: + https://pytorch.org/blog/optimizing-cuda-rnn-with-torchscript/ + """ + + def __init__( + self, model: Module, layer: Module, multiply_by_inputs: bool = True + ) -> None: + r""" + Args: + + model (nn.Module): The reference to PyTorch model instance. Model cannot + contain any in-place nonlinear submodules; these are not + supported by the register_full_backward_hook PyTorch API + starting from PyTorch v1.9. + layer (torch.nn.Module): Layer for which neuron attributions are computed. + Attributions for a particular neuron for the input or output + of this layer are computed using the argument neuron_selector + in the attribute method. 
+ Currently, it is assumed that the inputs or the outputs + of the layer, depending on which one is used for + attribution, can only be a single tensor. + multiply_by_inputs (bool, optional): Indicates whether to factor + model inputs' multiplier in the final attribution scores. + In the literature this is also known as local vs global + attribution. If inputs' multiplier isn't factored in + then that type of attribution method is also called local + attribution. If it is, then that type of attribution + method is called global. + More detailed can be found here: + https://arxiv.org/abs/1711.06104 + + In case of Neuron DeepLift, if `multiply_by_inputs` + is set to True, final sensitivity scores + are being multiplied by (inputs - baselines). + This flag applies only if `custom_attribution_func` is + set to None. + """ + NeuronAttribution.__init__(self, model, layer) + GradientAttribution.__init__(self, model) + self._multiply_by_inputs = multiply_by_inputs + + @log_usage() + def attribute( + self, + inputs: TensorOrTupleOfTensorsGeneric, + neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable], + baselines: BaselineType = None, + additional_forward_args: Any = None, + attribute_to_neuron_input: bool = False, + custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None, + ) -> TensorOrTupleOfTensorsGeneric: + r""" + Args: + + inputs (tensor or tuple of tensors): Input for which layer + attributions are computed. If forward_func takes a + single tensor as input, a single input tensor should be + provided. If forward_func takes multiple tensors as input, + a tuple of the input tensors should be provided. It is + assumed that for all given input tensors, dimension 0 + corresponds to the number of examples (aka batch size), + and if multiple input tensors are provided, the examples + must be aligned appropriately. + neuron_selector (int, callable, or tuple of ints or slices): + Selector for neuron + in given layer for which attribution is desired. + Neuron selector can be provided as: + + - a single integer, if the layer output is 2D. This integer + selects the appropriate neuron column in the layer input + or output + + - a tuple of integers or slice objects. Length of this + tuple must be one less than the number of dimensions + in the input / output of the given layer (since + dimension 0 corresponds to number of examples). + The elements of the tuple can be either integers or + slice objects (slice object allows indexing a + range of neurons rather individual ones). + + If any of the tuple elements is a slice object, the + indexed output tensor is used for attribution. Note + that specifying a slice of a tensor would amount to + computing the attribution of the sum of the specified + neurons, and not the individual neurons independantly. + + - a callable, which should + take the target layer as input (single tensor or tuple + if multiple tensors are in layer) and return a neuron or + aggregate of the layer's neurons for attribution. + For example, this function could return the + sum of the neurons in the layer or sum of neurons with + activations in a particular range. It is expected that + this function returns either a tensor with one element + or a 1D tensor with length equal to batch_size (one scalar + per input example) + + baselines (scalar, tensor, tuple of scalars or tensors, optional): + Baselines define reference samples that are compared with + the inputs. 
In order to assign attribution scores DeepLift + computes the differences between the inputs/outputs and + corresponding references. + Baselines can be provided as: + + - a single tensor, if inputs is a single tensor, with + exactly the same dimensions as inputs or the first + dimension is one and the remaining dimensions match + with inputs. + + - a single scalar, if inputs is a single tensor, which will + be broadcasted for each input value in input tensor. + + - a tuple of tensors or scalars, the baseline corresponding + to each tensor in the inputs' tuple can be: + + - either a tensor with matching dimensions to + corresponding tensor in the inputs' tuple + or the first dimension is one and the remaining + dimensions match with the corresponding + input tensor. + + - or a scalar, corresponding to a tensor in the + inputs' tuple. This scalar value is broadcasted + for corresponding input tensor. + + In the cases when `baselines` is not provided, we internally + use zero scalar corresponding to each input tensor. + + Default: None + additional_forward_args (any, optional): If the forward function + requires additional arguments other than the inputs for + which attributions should not be computed, this argument + can be provided. It must be either a single additional + argument of a Tensor or arbitrary (non-tuple) type or a tuple + containing multiple additional arguments including tensors + or any arbitrary python types. These arguments are provided + to forward_func in order, following the arguments in inputs. + Note that attributions are not computed with respect + to these arguments. + Default: None + attribute_to_neuron_input (bool, optional): Indicates whether to + compute the attributions with respect to the neuron input + or output. If `attribute_to_neuron_input` is set to True + then the attributions will be computed with respect to + neuron's inputs, otherwise it will be computed with respect + to neuron's outputs. + Note that currently it is assumed that either the input + or the output of internal neuron, depending on whether we + attribute to the input or output, is a single tensor. + Support for multiple tensors will be added later. + Default: False + custom_attribution_func (callable, optional): A custom function for + computing final attribution scores. This function can take + at least one and at most three arguments with the + following signature: + + - custom_attribution_func(multipliers) + - custom_attribution_func(multipliers, inputs) + - custom_attribution_func(multipliers, inputs, baselines) + + In case this function is not provided, we use the default + logic defined as: multipliers * (inputs - baselines) + It is assumed that all input arguments, `multipliers`, + `inputs` and `baselines` are provided in tuples of same + length. `custom_attribution_func` returns a tuple of + attribution tensors that have the same length as the + `inputs`. + Default: None + + Returns: + **attributions** or 2-element tuple of **attributions**, **delta**: + - **attributions** (*tensor* or tuple of *tensors*): + Computes attributions using Deeplift's rescale rule for + particular neuron with respect to each input feature. + Attributions will always be the same size as the provided + inputs, with each value providing the attribution of the + corresponding input index. + If a single tensor is provided as inputs, a single tensor is + returned. If a tuple is provided for inputs, a tuple of + corresponding sized tensors is returned. 
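+
+            For instance, a `custom_attribution_func` equivalent to the default
+            rescale-rule logic could be written as follows (illustrative sketch
+            only, not part of the library)::
+
+                def rescale_attr(multipliers, inputs, baselines):
+                    return tuple(
+                        m * (i - b)
+                        for m, i, b in zip(multipliers, inputs, baselines)
+                    )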
+ + Examples:: + + >>> # ImageClassifier takes a single input tensor of images Nx3x32x32, + >>> # and returns an Nx10 tensor of class probabilities. + >>> net = ImageClassifier() + >>> # creates an instance of LayerDeepLift to interpret target + >>> # class 1 with respect to conv4 layer. + >>> dl = NeuronDeepLift(net, net.conv4) + >>> input = torch.randn(1, 3, 32, 32, requires_grad=True) + >>> # Computes deeplift attribution scores for conv4 layer and neuron + >>> # index (4,1,2). + >>> attribution = dl.attribute(input, (4,1,2)) + """ + dl = DeepLift(cast(Module, self.forward_func), self.multiplies_by_inputs) + if not attribute_to_neuron_input: + warnings.warn( + "Attribution to neuron output is no longer supported for" + " NeuronDeepLift and will be deprecated in Captum" + " 0.6.0 due to changes in PyTorch's full backward hook" + " behavior. To obtain attributions for a neuron's" + " output, please attribute with respect to the next layer's input" + ) + dl.skip_new_hook_layer = self.layer # type: ignore + else: + dl.skip_new_hook_layer = None # type: ignore + dl.gradient_func = construct_neuron_grad_fn( + self.layer, + neuron_selector, + attribute_to_neuron_input=attribute_to_neuron_input, + ) + + # NOTE: using __wrapped__ to not log + return dl.attribute.__wrapped__( # type: ignore + dl, # self + inputs, + baselines, + additional_forward_args=additional_forward_args, + custom_attribution_func=custom_attribution_func, + ) + + @property + def multiplies_by_inputs(self): + return self._multiply_by_inputs + + +class NeuronDeepLiftShap(NeuronAttribution, GradientAttribution): + r""" + Extends NeuronAttribution and uses LayerDeepLiftShap algorithms and + approximates SHAP values for given input `layer` and `neuron_selector`. + For each input sample - baseline pair it computes DeepLift attributions + with respect to inputs or outputs of given `layer` and `neuron_selector` + averages resulting attributions across baselines. Whether to compute the + attributions with respect to the inputs or outputs of the layer is defined + by the input flag `attribute_to_layer_input`. + More details about the algorithm can be found here: + + http://papers.nips.cc/paper/7062-a-unified-approach-to-interpreting-model-predictions.pdf + + Note that the explanation model: + 1. Assumes that input features are independent of one another + 2. Is linear, meaning that the explanations are modeled through + the additive composition of feature effects. + Although, it assumes a linear model for each explanation, the overall + model across multiple explanations can be complex and non-linear. + """ + + def __init__( + self, model: Module, layer: Module, multiply_by_inputs: bool = True + ) -> None: + r""" + Args: + + model (nn.Module): The reference to PyTorch model instance. Model cannot + contain any in-place nonlinear submodules; these are not + supported by the register_full_backward_hook PyTorch API + starting from PyTorch v1.9. + layer (torch.nn.Module): Layer for which neuron attributions are computed. + Attributions for a particular neuron for the input or output + of this layer are computed using the argument neuron_selector + in the attribute method. + Currently, only layers with a single tensor input and output + are supported. + multiply_by_inputs (bool, optional): Indicates whether to factor + model inputs' multiplier in the final attribution scores. + In the literature this is also known as local vs global + attribution. 
If inputs' multiplier isn't factored in + then that type of attribution method is also called local + attribution. If it is, then that type of attribution + method is called global. + More detailed can be found here: + https://arxiv.org/abs/1711.06104 + + In case of Neuron DeepLift Shap, if `multiply_by_inputs` + is set to True, final sensitivity scores + are being multiplied by (inputs - baselines). + This flag applies only if `custom_attribution_func` is + set to None. + """ + NeuronAttribution.__init__(self, model, layer) + GradientAttribution.__init__(self, model) + self._multiply_by_inputs = multiply_by_inputs + + @log_usage() + def attribute( + self, + inputs: TensorOrTupleOfTensorsGeneric, + neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable], + baselines: Union[ + TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric] + ], + additional_forward_args: Any = None, + attribute_to_neuron_input: bool = False, + custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None, + ) -> TensorOrTupleOfTensorsGeneric: + r""" + Args: + + inputs (tensor or tuple of tensors): Input for which layer + attributions are computed. If forward_func takes a + single tensor as input, a single input tensor should be + provided. If forward_func takes multiple tensors as input, + a tuple of the input tensors should be provided. It is + assumed that for all given input tensors, dimension 0 + corresponds to the number of examples (aka batch size), + and if multiple input tensors are provided, the examples + must be aligned appropriately. + neuron_selector (int, callable, or tuple of ints or slices): + Selector for neuron + in given layer for which attribution is desired. + Neuron selector can be provided as: + + - a single integer, if the layer output is 2D. This integer + selects the appropriate neuron column in the layer input + or output + + - a tuple of integers or slice objects. Length of this + tuple must be one less than the number of dimensions + in the input / output of the given layer (since + dimension 0 corresponds to number of examples). + The elements of the tuple can be either integers or + slice objects (slice object allows indexing a + range of neurons rather individual ones). + + If any of the tuple elements is a slice object, the + indexed output tensor is used for attribution. Note + that specifying a slice of a tensor would amount to + computing the attribution of the sum of the specified + neurons, and not the individual neurons independantly. + + - a callable, which should + take the target layer as input (single tensor or tuple + if multiple tensors are in layer) and return a neuron or + aggregate of the layer's neurons for attribution. + For example, this function could return the + sum of the neurons in the layer or sum of neurons with + activations in a particular range. It is expected that + this function returns either a tensor with one element + or a 1D tensor with length equal to batch_size (one scalar + per input example) + baselines (tensor, tuple of tensors, callable): + Baselines define reference samples that are compared with + the inputs. In order to assign attribution scores DeepLift + computes the differences between the inputs/outputs and + corresponding references. Baselines can be provided as: + + - a single tensor, if inputs is a single tensor, with + the first dimension equal to the number of examples + in the baselines' distribution. 
The remaining dimensions + must match with input tensor's dimension starting from + the second dimension. + + - a tuple of tensors, if inputs is a tuple of tensors, + with the first dimension of any tensor inside the tuple + equal to the number of examples in the baseline's + distribution. The remaining dimensions must match + the dimensions of the corresponding input tensor + starting from the second dimension. + + - callable function, optionally takes `inputs` as an + argument and either returns a single tensor + or a tuple of those. + + It is recommended that the number of samples in the baselines' + tensors is larger than one. + additional_forward_args (any, optional): If the forward function + requires additional arguments other than the inputs for + which attributions should not be computed, this argument + can be provided. It must be either a single additional + argument of a Tensor or arbitrary (non-tuple) type or a tuple + containing multiple additional arguments including tensors + or any arbitrary python types. These arguments are provided + to forward_func in order, following the arguments in inputs. + Note that attributions are not computed with respect + to these arguments. + Default: None + attribute_to_neuron_input (bool, optional): Indicates whether to + compute the attributions with respect to the neuron input + or output. If `attribute_to_neuron_input` is set to True + then the attributions will be computed with respect to + neuron's inputs, otherwise it will be computed with respect + to neuron's outputs. + Note that currently it is assumed that either the input + or the output of internal neuron, depending on whether we + attribute to the input or output, is a single tensor. + Support for multiple tensors will be added later. + Default: False + custom_attribution_func (callable, optional): A custom function for + computing final attribution scores. This function can take + at least one and at most three arguments with the + following signature: + + - custom_attribution_func(multipliers) + - custom_attribution_func(multipliers, inputs) + - custom_attribution_func(multipliers, inputs, baselines) + + In case this function is not provided, we use the default + logic defined as: multipliers * (inputs - baselines) + It is assumed that all input arguments, `multipliers`, + `inputs` and `baselines` are provided in tuples of same + length. `custom_attribution_func` returns a tuple of + attribution tensors that have the same length as the + `inputs`. + Default: None + + Returns: + **attributions** or 2-element tuple of **attributions**, **delta**: + - **attributions** (*tensor* or tuple of *tensors*): + Computes attributions using Deeplift's rescale rule for + particular neuron with respect to each input feature. + Attributions will always be the same size as the provided + inputs, with each value providing the attribution of the + corresponding input index. + If a single tensor is provided as inputs, a single tensor is + returned. If a tuple is provided for inputs, a tuple of + corresponding sized tensors is returned. + + Examples:: + + >>> # ImageClassifier takes a single input tensor of images Nx3x32x32, + >>> # and returns an Nx10 tensor of class probabilities. + >>> net = ImageClassifier() + >>> # creates an instance of LayerDeepLift to interpret target + >>> # class 1 with respect to conv4 layer. 
+ >>> dl = NeuronDeepLiftShap(net, net.conv4) + >>> input = torch.randn(1, 3, 32, 32, requires_grad=True) + >>> # Computes deeplift attribution scores for conv4 layer and neuron + >>> # index (4,1,2). + >>> attribution = dl.attribute(input, (4,1,2)) + """ + + dl = DeepLiftShap(cast(Module, self.forward_func), self.multiplies_by_inputs) + if not attribute_to_neuron_input: + warnings.warn( + "Attribution to neuron output is no longer supported for" + " NeuronDeepLiftShap and will be deprecated in Captum" + " 0.6.0 due to changes in PyTorch's full backward hook" + " behavior. To obtain attributions for a neuron's" + " output, please attribute with respect to the next layer's input" + ) + dl.skip_new_hook_layer = self.layer # type: ignore + else: + dl.skip_new_hook_layer = None # type: ignore + dl.gradient_func = construct_neuron_grad_fn( + self.layer, + neuron_selector, + attribute_to_neuron_input=attribute_to_neuron_input, + ) + + # NOTE: using __wrapped__ to not log + return dl.attribute.__wrapped__( # type: ignore + dl, # self + inputs, + baselines, + additional_forward_args=additional_forward_args, + custom_attribution_func=custom_attribution_func, + ) + + @property + def multiplies_by_inputs(self): + return self._multiply_by_inputs diff --git a/captum/attr/_core/neuron/neuron_feature_ablation.py b/captum/attr/_core/neuron/neuron_feature_ablation.py new file mode 100644 index 0000000000000000000000000000000000000000..d706f71cb469fac47614079c66cdcb8441aa8f5d --- /dev/null +++ b/captum/attr/_core/neuron/neuron_feature_ablation.py @@ -0,0 +1,267 @@ +#!/usr/bin/env python3 +from typing import Any, Callable, List, Tuple, Union + +import torch +from captum._utils.common import _verify_select_neuron +from captum._utils.gradient import _forward_layer_eval +from captum._utils.typing import BaselineType, TensorOrTupleOfTensorsGeneric +from captum.attr._core.feature_ablation import FeatureAblation +from captum.attr._utils.attribution import NeuronAttribution, PerturbationAttribution +from captum.log import log_usage +from torch.nn import Module + + +class NeuronFeatureAblation(NeuronAttribution, PerturbationAttribution): + r""" + A perturbation based approach to computing neuron attribution, + involving replacing each input feature with a given baseline / + reference, and computing the difference in the neuron's input / output. + By default, each scalar value within + each input tensor is taken as a feature and replaced independently. Passing + a feature mask, allows grouping features to be ablated together. This can + be used in cases such as images, where an entire segment or region + can be ablated, measuring the importance of the segment (feature group). + Each input scalar in the group will be given the same attribution value + equal to the change in target as a result of ablating the entire feature + group. + """ + + def __init__( + self, + forward_func: Callable, + layer: Module, + device_ids: Union[None, List[int]] = None, + ) -> None: + r""" + Args: + + forward_func (callable): The forward function of the model or any + modification of it + layer (torch.nn.Module): Layer for which attributions are computed. + Attributions for a particular neuron in the input or output + of this layer are computed using the argument neuron_selector + in the attribute method. + Currently, it is assumed that the inputs or the outputs + of the layer, depending on which one is used for + attribution, can only be a single tensor. 
+ device_ids (list(int)): Device ID list, necessary only if forward_func + applies a DataParallel model. This allows reconstruction of + intermediate outputs from batched results across devices. + If forward_func is given as the DataParallel model itself, + then it is not necessary to provide this argument. + """ + NeuronAttribution.__init__(self, forward_func, layer, device_ids) + PerturbationAttribution.__init__(self, forward_func) + + @log_usage() + def attribute( + self, + inputs: TensorOrTupleOfTensorsGeneric, + neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable], + baselines: BaselineType = None, + additional_forward_args: Any = None, + feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None, + attribute_to_neuron_input: bool = False, + perturbations_per_eval: int = 1, + ) -> TensorOrTupleOfTensorsGeneric: + r""" + Args: + + inputs (tensor or tuple of tensors): Input for which neuron + attributions are computed. If forward_func takes a single + tensor as input, a single input tensor should be provided. + If forward_func takes multiple tensors as input, a tuple + of the input tensors should be provided. It is assumed + that for all given input tensors, dimension 0 corresponds + to the number of examples, and if multiple input tensors + are provided, the examples must be aligned appropriately. + neuron_selector (int, callable, or tuple of ints or slices): + Selector for neuron + in given layer for which attribution is desired. + Neuron selector can be provided as: + + - a single integer, if the layer output is 2D. This integer + selects the appropriate neuron column in the layer input + or output + + - a tuple of integers or slice objects. Length of this + tuple must be one less than the number of dimensions + in the input / output of the given layer (since + dimension 0 corresponds to number of examples). + The elements of the tuple can be either integers or + slice objects (slice object allows indexing a + range of neurons rather individual ones). + + If any of the tuple elements is a slice object, the + indexed output tensor is used for attribution. Note + that specifying a slice of a tensor would amount to + computing the attribution of the sum of the specified + neurons, and not the individual neurons independantly. + + - a callable, which should + take the target layer as input (single tensor or tuple + if multiple tensors are in layer) and return a neuron or + aggregate of the layer's neurons for attribution. + For example, this function could return the + sum of the neurons in the layer or sum of neurons with + activations in a particular range. It is expected that + this function returns either a tensor with one element + or a 1D tensor with length equal to batch_size (one scalar + per input example) + baselines (scalar, tensor, tuple of scalars or tensors, optional): + Baselines define reference value which replaces each + feature when ablated. + Baselines can be provided as: + + - a single tensor, if inputs is a single tensor, with + exactly the same dimensions as inputs or + broadcastable to match the dimensions of inputs + + - a single scalar, if inputs is a single tensor, which will + be broadcasted for each input value in input tensor. + + - a tuple of tensors or scalars, the baseline corresponding + to each tensor in the inputs' tuple can be: + + - either a tensor with matching dimensions to + corresponding tensor in the inputs' tuple + or the first dimension is one and the remaining + dimensions match with the corresponding + input tensor. 
+ + - or a scalar, corresponding to a tensor in the + inputs' tuple. This scalar value is broadcasted + for corresponding input tensor. + In the cases when `baselines` is not provided, we internally + use zero scalar corresponding to each input tensor. + Default: None + additional_forward_args (any, optional): If the forward function + requires additional arguments other than the inputs for + which attributions should not be computed, this argument + can be provided. It must be either a single additional + argument of a Tensor or arbitrary (non-tuple) type or a + tuple containing multiple additional arguments including + tensors or any arbitrary python types. These arguments + are provided to forward_func in order following the + arguments in inputs. + Note that attributions are not computed with respect + to these arguments. + Default: None + feature_mask (tensor or tuple of tensors, optional): + feature_mask defines a mask for the input, grouping + features which should be ablated together. feature_mask + should contain the same number of tensors as inputs. + Each tensor should + be the same size as the corresponding input or + broadcastable to match the input tensor. Each tensor + should contain integers in the range 0 to num_features + - 1, and indices corresponding to the same feature should + have the same value. + Note that features within each input tensor are ablated + independently (not across tensors). + If None, then a feature mask is constructed which assigns + each scalar within a tensor as a separate feature, which + is ablated independently. + Default: None + attribute_to_neuron_input (bool, optional): Indicates whether to + compute the attributions with respect to the neuron input + or output. If `attribute_to_neuron_input` is set to True + then the attributions will be computed with respect to + neuron's inputs, otherwise it will be computed with respect + to neuron's outputs. + Note that currently it is assumed that either the input + or the output of internal neurons, depending on whether we + attribute to the input or output, is a single tensor. + Support for multiple tensors will be added later. + Default: False + perturbations_per_eval (int, optional): Allows ablation of multiple + features to be processed simultaneously in one call to + forward_fn. + Each forward pass will contain a maximum of + perturbations_per_eval * #examples samples. + For DataParallel models, each batch is split among the + available devices, so evaluations on each available + device contain at most + (perturbations_per_eval * #examples) / num_devices + samples. + Default: 1 + + Returns: + *tensor* or tuple of *tensors* of **attributions**: + - **attributions** (*tensor* or tuple of *tensors*): + Attributions of particular neuron with respect to each input + feature. Attributions will always be the same size as the + provided inputs, with each value providing the attribution + of the corresponding input index. + If a single tensor is provided as inputs, a single tensor is + returned. If a tuple is provided for inputs, a tuple of + corresponding sized tensors is returned. + + Examples:: + + >>> # SimpleClassifier takes a single input tensor of size Nx4x4, + >>> # and returns an Nx3 tensor of class probabilities. + >>> # It contains an attribute conv1, which is an instance of nn.conv2d, + >>> # and the output of this layer has dimensions Nx12x3x3. 
+ >>> net = SimpleClassifier() + >>> # Generating random input with size 2 x 4 x 4 + >>> input = torch.randn(2, 4, 4) + >>> # Defining NeuronFeatureAblation interpreter + >>> ablator = NeuronFeatureAblation(net, net.conv1) + >>> # To compute neuron attribution, we need to provide the neuron + >>> # index for which attribution is desired. Since the layer output + >>> # is Nx12x3x3, we need a tuple in the form (0..11,0..2,0..2) + >>> # which indexes a particular neuron in the layer output. + >>> # For this example, we choose the index (4,1,2). + >>> # Computes neuron gradient for neuron with + >>> # index (4,1,2). + >>> # Computes ablation attribution, ablating each of the 16 + >>> # scalar inputs independently. + >>> attr = ablator.attribute(input, neuron_selector=(4,1,2)) + + >>> # Alternatively, we may want to ablate features in groups, e.g. + >>> # grouping each 2x2 square of the inputs and ablating them together. + >>> # This can be done by creating a feature mask as follows, which + >>> # defines the feature groups, e.g.: + >>> # +---+---+---+---+ + >>> # | 0 | 0 | 1 | 1 | + >>> # +---+---+---+---+ + >>> # | 0 | 0 | 1 | 1 | + >>> # +---+---+---+---+ + >>> # | 2 | 2 | 3 | 3 | + >>> # +---+---+---+---+ + >>> # | 2 | 2 | 3 | 3 | + >>> # +---+---+---+---+ + >>> # With this mask, all inputs with the same value are ablated + >>> # simultaneously, and the attribution for each input in the same + >>> # group (0, 1, 2, and 3) per example are the same. + >>> # The attributions can be calculated as follows: + >>> # feature mask has dimensions 1 x 4 x 4 + >>> feature_mask = torch.tensor([[[0,0,1,1],[0,0,1,1], + >>> [2,2,3,3],[2,2,3,3]]]) + >>> attr = ablator.attribute(input, neuron_selector=(4,1,2), + >>> feature_mask=feature_mask) + """ + + def neuron_forward_func(*args: Any): + with torch.no_grad(): + layer_eval = _forward_layer_eval( + self.forward_func, + args, + self.layer, + device_ids=self.device_ids, + attribute_to_layer_input=attribute_to_neuron_input, + ) + return _verify_select_neuron(layer_eval, neuron_selector) + + ablator = FeatureAblation(neuron_forward_func) + + # NOTE: using __wrapped__ to not log + return ablator.attribute.__wrapped__( + ablator, # self + inputs, + baselines=baselines, + additional_forward_args=additional_forward_args, + feature_mask=feature_mask, + perturbations_per_eval=perturbations_per_eval, + ) diff --git a/captum/attr/_core/neuron/neuron_gradient.py b/captum/attr/_core/neuron/neuron_gradient.py new file mode 100644 index 0000000000000000000000000000000000000000..5292990bbf017e32c8d51fbca673ee1d65d991d2 --- /dev/null +++ b/captum/attr/_core/neuron/neuron_gradient.py @@ -0,0 +1,179 @@ +#!/usr/bin/env python3 +from typing import Any, Callable, List, Tuple, Union + +from captum._utils.common import ( + _format_additional_forward_args, + _format_output, + _format_tensor_into_tuples, + _is_tuple, +) +from captum._utils.gradient import ( + _forward_layer_eval_with_neuron_grads, + apply_gradient_requirements, + undo_gradient_requirements, +) +from captum._utils.typing import TensorOrTupleOfTensorsGeneric +from captum.attr._utils.attribution import GradientAttribution, NeuronAttribution +from captum.log import log_usage +from torch.nn import Module + + +class NeuronGradient(NeuronAttribution, GradientAttribution): + r""" + Computes the gradient of the output of a particular neuron with + respect to the inputs of the network. 
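+
+    Schematically, for a selected neuron `n` in the given layer and an input
+    `x`, the returned attribution is the plain gradient `d layer(x)[n] / d x`,
+    evaluated at the chosen neuron rather than at a model output.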
+ """ + + def __init__( + self, + forward_func: Callable, + layer: Module, + device_ids: Union[None, List[int]] = None, + ) -> None: + r""" + Args: + + forward_func (callable): The forward function of the model or any + modification of it + layer (torch.nn.Module): Layer for which attributions are computed. + Output size of attribute matches this layer's input or + output dimensions, depending on whether we attribute to + the inputs or outputs of the layer, corresponding to + attribution of each neuron in the input or output of + this layer. + Currently, it is assumed that the inputs or the outputs + of the layer, depending on which one is used for + attribution, can only be a single tensor. + device_ids (list(int)): Device ID list, necessary only if forward_func + applies a DataParallel model. This allows reconstruction of + intermediate outputs from batched results across devices. + If forward_func is given as the DataParallel model itself, + then it is not necessary to provide this argument. + """ + NeuronAttribution.__init__(self, forward_func, layer, device_ids) + GradientAttribution.__init__(self, forward_func) + + @log_usage() + def attribute( + self, + inputs: TensorOrTupleOfTensorsGeneric, + neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable], + additional_forward_args: Any = None, + attribute_to_neuron_input: bool = False, + ) -> TensorOrTupleOfTensorsGeneric: + r""" + Args: + + inputs (tensor or tuple of tensors): Input for which neuron + gradients are computed. If forward_func takes a single + tensor as input, a single input tensor should be provided. + If forward_func takes multiple tensors as input, a tuple + of the input tensors should be provided. It is assumed + that for all given input tensors, dimension 0 corresponds + to the number of examples, and if multiple input tensors + are provided, the examples must be aligned appropriately. + neuron_selector (int, callable, or tuple of ints or slices): + Selector for neuron + in given layer for which attribution is desired. + Neuron selector can be provided as: + + - a single integer, if the layer output is 2D. This integer + selects the appropriate neuron column in the layer input + or output + + - a tuple of integers or slice objects. Length of this + tuple must be one less than the number of dimensions + in the input / output of the given layer (since + dimension 0 corresponds to number of examples). + The elements of the tuple can be either integers or + slice objects (slice object allows indexing a + range of neurons rather individual ones). + + If any of the tuple elements is a slice object, the + indexed output tensor is used for attribution. Note + that specifying a slice of a tensor would amount to + computing the attribution of the sum of the specified + neurons, and not the individual neurons independantly. + + - a callable, which should + take the target layer as input (single tensor or tuple + if multiple tensors are in layer) and return a neuron or + aggregate of the layer's neurons for attribution. + For example, this function could return the + sum of the neurons in the layer or sum of neurons with + activations in a particular range. It is expected that + this function returns either a tensor with one element + or a 1D tensor with length equal to batch_size (one scalar + per input example) + additional_forward_args (any, optional): If the forward function + requires additional arguments other than the inputs for + which attributions should not be computed, this argument + can be provided. 
It must be either a single additional + argument of a Tensor or arbitrary (non-tuple) type or a + tuple containing multiple additional arguments including + tensors or any arbitrary python types. These arguments + are provided to forward_func in order following the + arguments in inputs. + Note that attributions are not computed with respect + to these arguments. + Default: None + attribute_to_neuron_input (bool, optional): Indicates whether to + compute the attributions with respect to the neuron input + or output. If `attribute_to_neuron_input` is set to True + then the attributions will be computed with respect to + neuron's inputs, otherwise it will be computed with respect + to neuron's outputs. + Note that currently it is assumed that either the input + or the output of internal neurons, depending on whether we + attribute to the input or output, is a single tensor. + Support for multiple tensors will be added later. + Default: False + + Returns: + *tensor* or tuple of *tensors* of **attributions**: + - **attributions** (*tensor* or tuple of *tensors*): + Gradients of particular neuron with respect to each input + feature. Attributions will always be the same size as the + provided inputs, with each value providing the attribution + of the corresponding input index. + If a single tensor is provided as inputs, a single tensor is + returned. If a tuple is provided for inputs, a tuple of + corresponding sized tensors is returned. + + Examples:: + + >>> # ImageClassifier takes a single input tensor of images Nx3x32x32, + >>> # and returns an Nx10 tensor of class probabilities. + >>> # It contains an attribute conv1, which is an instance of nn.conv2d, + >>> # and the output of this layer has dimensions Nx12x32x32. + >>> net = ImageClassifier() + >>> neuron_ig = NeuronGradient(net, net.conv1) + >>> input = torch.randn(2, 3, 32, 32, requires_grad=True) + >>> # To compute neuron attribution, we need to provide the neuron + >>> # index for which attribution is desired. Since the layer output + >>> # is Nx12x32x32, we need a tuple in the form (0..11,0..31,0..31) + >>> # which indexes a particular neuron in the layer output. + >>> # For this example, we choose the index (4,1,2). + >>> # Computes neuron gradient for neuron with + >>> # index (4,1,2). 
+ >>> attribution = neuron_ig.attribute(input, (4,1,2)) + """ + is_inputs_tuple = _is_tuple(inputs) + inputs = _format_tensor_into_tuples(inputs) + additional_forward_args = _format_additional_forward_args( + additional_forward_args + ) + gradient_mask = apply_gradient_requirements(inputs) + + _, input_grads = _forward_layer_eval_with_neuron_grads( + self.forward_func, + inputs, + self.layer, + additional_forward_args, + gradient_neuron_selector=neuron_selector, + device_ids=self.device_ids, + attribute_to_layer_input=attribute_to_neuron_input, + ) + + undo_gradient_requirements(inputs, gradient_mask) + return _format_output(is_inputs_tuple, input_grads) diff --git a/captum/attr/_core/neuron/neuron_gradient_shap.py b/captum/attr/_core/neuron/neuron_gradient_shap.py new file mode 100644 index 0000000000000000000000000000000000000000..42a543b50ddd5b556ec6125be5ca098062e454a0 --- /dev/null +++ b/captum/attr/_core/neuron/neuron_gradient_shap.py @@ -0,0 +1,257 @@ +#!/usr/bin/env python3 +from typing import Any, Callable, List, Tuple, Union + +from captum._utils.gradient import construct_neuron_grad_fn +from captum._utils.typing import TensorOrTupleOfTensorsGeneric +from captum.attr._core.gradient_shap import GradientShap +from captum.attr._utils.attribution import GradientAttribution, NeuronAttribution +from captum.log import log_usage +from torch.nn import Module + + +class NeuronGradientShap(NeuronAttribution, GradientAttribution): + r""" + Implements gradient SHAP for a neuron in a hidden layer based on the + implementation from SHAP's primary author. For reference, please, view: + + https://github.com/slundberg/shap\ + #deep-learning-example-with-gradientexplainer-tensorflowkeraspytorch-models + + A Unified Approach to Interpreting Model Predictions + http://papers.nips.cc/paper\ + 7062-a-unified-approach-to-interpreting-model-predictions + + GradientShap approximates SHAP values by computing the expectations of + gradients by randomly sampling from the distribution of baselines/references. + It adds white noise to each input sample `n_samples` times, selects a + random baseline from baselines' distribution and a random point along the + path between the baseline and the input, and computes the gradient of the + neuron with index `neuron_selector` with respect to those selected random + points. The final SHAP values represent the expected values of + `gradients * (inputs - baselines)`. + + GradientShap makes an assumption that the input features are independent + and that the explanation model is linear, meaning that the explanations + are modeled through the additive composition of feature effects. + Under those assumptions, SHAP value can be approximated as the expectation + of gradients that are computed for randomly generated `n_samples` input + samples after adding gaussian noise `n_samples` times to each input for + different baselines/references. + + In some sense it can be viewed as an approximation of integrated gradients + by computing the expectations of gradients for different baselines. + + Current implementation uses Smoothgrad from `NoiseTunnel` in order to + randomly draw samples from the distribution of baselines, add noise to input + samples and compute the expectation (smoothgrad). 
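+
+    A rough sketch of the estimator described above (illustrative pseudocode
+    only, not the actual implementation)::
+
+        # repeated n_samples times for every example in the batch:
+        #   x_noisy  = input + gaussian_noise(stdevs)
+        #   baseline = random draw from the baselines distribution
+        #   x_rand   = baseline + uniform(0, 1) * (x_noisy - baseline)
+        #   grad     = d(selected neuron output) / d(x_rand)
+        # attribution = mean over draws of grad * (input - baseline)
+        #               (the product applies when multiply_by_inputs is True)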
+ """ + + def __init__( + self, + forward_func: Callable, + layer: Module, + device_ids: Union[None, List[int]] = None, + multiply_by_inputs: bool = True, + ) -> None: + r""" + Args: + + forward_func (callable): The forward function of the model or any + modification of it + layer (torch.nn.Module): Layer for which neuron attributions are computed. + The output size of the attribute method matches the + dimensions of the inputs or ouputs of the neuron with + index `neuron_selector` in this layer, depending on whether + we attribute to the inputs or outputs of the neuron. + Currently, it is assumed that the inputs or the outputs + of the neurons in this layer, depending on which one is + used for attribution, can only be a single tensor. + device_ids (list(int)): Device ID list, necessary only if forward_func + applies a DataParallel model. This allows reconstruction of + intermediate outputs from batched results across devices. + If forward_func is given as the DataParallel model itself, + then it is not necessary to provide this argument. + multiply_by_inputs (bool, optional): Indicates whether to factor + model inputs' multiplier in the final attribution scores. + In the literature this is also known as local vs global + attribution. If inputs' multiplier isn't factored in + then that type of attribution method is also called local + attribution. If it is, then that type of attribution + method is called global. + More detailed can be found here: + https://arxiv.org/abs/1711.06104 + + In case of Neuron Gradient SHAP, + if `multiply_by_inputs` is set to True, the + sensitivity scores for scaled inputs are + being multiplied by (inputs - baselines). + """ + NeuronAttribution.__init__(self, forward_func, layer, device_ids) + GradientAttribution.__init__(self, forward_func) + self._multiply_by_inputs = multiply_by_inputs + + @log_usage() + def attribute( + self, + inputs: TensorOrTupleOfTensorsGeneric, + neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable], + baselines: Union[ + TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric] + ], + n_samples: int = 5, + stdevs: float = 0.0, + additional_forward_args: Any = None, + attribute_to_neuron_input: bool = False, + ) -> TensorOrTupleOfTensorsGeneric: + r""" + Args: + + inputs (tensor or tuple of tensors): Input for which SHAP attribution + values are computed. If `forward_func` takes a single + tensor as input, a single input tensor should be provided. + If `forward_func` takes multiple tensors as input, a tuple + of the input tensors should be provided. It is assumed + that for all given input tensors, dimension 0 corresponds + to the number of examples, and if multiple input tensors + are provided, the examples must be aligned appropriately. + neuron_selector (int, callable, or tuple of ints or slices): + Selector for neuron + in given layer for which attribution is desired. + Neuron selector can be provided as: + + - a single integer, if the layer output is 2D. This integer + selects the appropriate neuron column in the layer input + or output + + - a tuple of integers or slice objects. Length of this + tuple must be one less than the number of dimensions + in the input / output of the given layer (since + dimension 0 corresponds to number of examples). + The elements of the tuple can be either integers or + slice objects (slice object allows indexing a + range of neurons rather individual ones). + + If any of the tuple elements is a slice object, the + indexed output tensor is used for attribution. 
Note + that specifying a slice of a tensor would amount to + computing the attribution of the sum of the specified + neurons, and not the individual neurons independantly. + + - a callable, which should + take the target layer as input (single tensor or tuple + if multiple tensors are in layer) and return a neuron or + aggregate of the layer's neurons for attribution. + For example, this function could return the + sum of the neurons in the layer or sum of neurons with + activations in a particular range. It is expected that + this function returns either a tensor with one element + or a 1D tensor with length equal to batch_size (one scalar + per input example) + baselines (tensor, tuple of tensors, callable): + Baselines define the starting point from which expectation + is computed and can be provided as: + + - a single tensor, if inputs is a single tensor, with + the first dimension equal to the number of examples + in the baselines' distribution. The remaining dimensions + must match with input tensor's dimension starting from + the second dimension. + + - a tuple of tensors, if inputs is a tuple of tensors, + with the first dimension of any tensor inside the tuple + equal to the number of examples in the baseline's + distribution. The remaining dimensions must match + the dimensions of the corresponding input tensor + starting from the second dimension. + + - callable function, optionally takes `inputs` as an + argument and either returns a single tensor + or a tuple of those. + + It is recommended that the number of samples in the baselines' + tensors is larger than one. + n_samples (int, optional): The number of randomly generated examples + per sample in the input batch. Random examples are + generated by adding gaussian random noise to each sample. + Default: `5` if `n_samples` is not provided. + stdevs (float, or a tuple of floats optional): The standard deviation + of gaussian noise with zero mean that is added to each + input in the batch. If `stdevs` is a single float value + then that same value is used for all inputs. If it is + a tuple, then it must have the same length as the inputs + tuple. In this case, each stdev value in the stdevs tuple + corresponds to the input with the same index in the inputs + tuple. + Default: 0.0 + additional_forward_args (any, optional): If the forward function + requires additional arguments other than the inputs for + which attributions should not be computed, this argument + can be provided. It can contain a tuple of ND tensors or + any arbitrary python type of any shape. + In case of the ND tensor the first dimension of the + tensor must correspond to the batch size. It will be + repeated for each `n_steps` for each randomly generated + input sample. + Note that the gradients are not computed with respect + to these arguments. + Default: None + attribute_to_neuron_input (bool, optional): Indicates whether to + compute the attributions with respect to the neuron input + or output. If `attribute_to_neuron_input` is set to True + then the attributions will be computed with respect to + neuron's inputs, otherwise it will be computed with respect + to neuron's outputs. + Note that currently it is assumed that either the input + or the output of internal neuron, depending on whether we + attribute to the input or output, is a single tensor. + Support for multiple tensors will be added later. 
+ Default: False + + Returns: + *tensor* or tuple of *tensors* of **attributions**: + - **attributions** (*tensor* or tuple of *tensors*): + Attribution score computed based on GradientSHAP with respect + to each input feature. Attributions will always be + the same size as the provided inputs, with each value + providing the attribution of the corresponding input index. + If a single tensor is provided as inputs, a single tensor is + returned. If a tuple is provided for inputs, a tuple of + corresponding sized tensors is returned. + + Examples:: + + >>> # ImageClassifier takes a single input tensor of images Nx3x32x32, + >>> # and returns an Nx10 tensor of class probabilities. + >>> net = ImageClassifier() + >>> neuron_grad_shap = NeuronGradientShap(net, net.linear2) + >>> input = torch.randn(3, 3, 32, 32, requires_grad=True) + >>> # choosing baselines randomly + >>> baselines = torch.randn(20, 3, 32, 32) + >>> # Computes gradient SHAP of first neuron in linear2 layer + >>> # with respect to the inputs of the network. + >>> # Attribution size matches input size: 3x3x32x32 + >>> attribution = neuron_grad_shap.attribute(input, neuron_selector=0, + >>> baselines=baselines) + + """ + gs = GradientShap(self.forward_func, self.multiplies_by_inputs) + gs.gradient_func = construct_neuron_grad_fn( + self.layer, + neuron_selector, + self.device_ids, + attribute_to_neuron_input=attribute_to_neuron_input, + ) + + # NOTE: using __wrapped__ to not log + return gs.attribute.__wrapped__( # type: ignore + gs, # self + inputs, + baselines, + n_samples=n_samples, + stdevs=stdevs, + additional_forward_args=additional_forward_args, + ) + + @property + def multiplies_by_inputs(self): + return self._multiply_by_inputs diff --git a/captum/attr/_core/neuron/neuron_guided_backprop_deconvnet.py b/captum/attr/_core/neuron/neuron_guided_backprop_deconvnet.py new file mode 100644 index 0000000000000000000000000000000000000000..7c69aed87ac42dd049b06d80a706d34c344969c4 --- /dev/null +++ b/captum/attr/_core/neuron/neuron_guided_backprop_deconvnet.py @@ -0,0 +1,353 @@ +#!/usr/bin/env python3 +import warnings +from typing import Any, Callable, List, Tuple, Union + +from captum._utils.gradient import construct_neuron_grad_fn +from captum._utils.typing import TensorOrTupleOfTensorsGeneric +from captum.attr._core.guided_backprop_deconvnet import Deconvolution, GuidedBackprop +from captum.attr._utils.attribution import GradientAttribution, NeuronAttribution +from captum.log import log_usage +from torch.nn import Module + + +class NeuronDeconvolution(NeuronAttribution, GradientAttribution): + r""" + Computes attribution of the given neuron using deconvolution. + Deconvolution computes the gradient of the target output with + respect to the input, but gradients of ReLU functions are overridden so + that the gradient of the ReLU input is simply computed taking ReLU of + the output gradient, essentially only propagating non-negative gradients + (without dependence on the sign of the ReLU input). + + More details regarding the deconvolution algorithm can be found + in these papers: + https://arxiv.org/abs/1311.2901 + https://link.springer.com/chapter/10.1007/978-3-319-46466-4_8 + + Warning: Ensure that all ReLU operations in the forward function of the + given model are performed using a module (nn.module.ReLU). + If nn.functional.ReLU is used, gradients are not overridden appropriately.
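+
+    Schematically, the overridden backward rule for every ReLU module is
+    (illustrative pseudocode, not the actual hook implementation)::
+
+        # standard ReLU backward:       grad_in = grad_out * (forward_input > 0)
+        # deconvolution ReLU backward:  grad_in = relu(grad_out)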
+ """ + + def __init__( + self, model: Module, layer: Module, device_ids: Union[None, List[int]] = None + ) -> None: + r""" + Args: + + model (nn.Module): The reference to PyTorch model instance. Model cannot + contain any in-place ReLU submodules; these are not + supported by the register_full_backward_hook PyTorch API + starting from PyTorch v1.9. + layer (Module): Layer for which attributions are computed. + Output size of attribute matches this layer's input or + output dimensions, depending on whether we attribute to + the inputs or outputs of the layer, corresponding to + attribution of each neuron in the input or output of + this layer. + Currently, it is assumed that the inputs or the outputs + of the layer, depending on which one is used for + attribution, can only be a single tensor. + device_ids (list(int)): Device ID list, necessary only if forward_func + applies a DataParallel model. This allows reconstruction of + intermediate outputs from batched results across devices. + If forward_func is given as the DataParallel model itself, + then it is not necessary to provide this argument. + """ + NeuronAttribution.__init__(self, model, layer, device_ids) + GradientAttribution.__init__(self, model) + self.deconv = Deconvolution(model) + + @log_usage() + def attribute( + self, + inputs: TensorOrTupleOfTensorsGeneric, + neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable], + additional_forward_args: Any = None, + attribute_to_neuron_input: bool = False, + ) -> TensorOrTupleOfTensorsGeneric: + r""" + Args: + + inputs (tensor or tuple of tensors): Input for which + attributions are computed. If forward_func takes a single + tensor as input, a single input tensor should be provided. + If forward_func takes multiple tensors as input, a tuple + of the input tensors should be provided. It is assumed + that for all given input tensors, dimension 0 corresponds + to the number of examples (aka batch size), and if + multiple input tensors are provided, the examples must + be aligned appropriately. + neuron_selector (int, callable, or tuple of ints or slices): + Selector for neuron + in given layer for which attribution is desired. + Neuron selector can be provided as: + + - a single integer, if the layer output is 2D. This integer + selects the appropriate neuron column in the layer input + or output + + - a tuple of integers or slice objects. Length of this + tuple must be one less than the number of dimensions + in the input / output of the given layer (since + dimension 0 corresponds to number of examples). + The elements of the tuple can be either integers or + slice objects (slice object allows indexing a + range of neurons rather individual ones). + + If any of the tuple elements is a slice object, the + indexed output tensor is used for attribution. Note + that specifying a slice of a tensor would amount to + computing the attribution of the sum of the specified + neurons, and not the individual neurons independantly. + + - a callable, which should + take the target layer as input (single tensor or tuple + if multiple tensors are in layer) and return a neuron or + aggregate of the layer's neurons for attribution. + For example, this function could return the + sum of the neurons in the layer or sum of neurons with + activations in a particular range. 
It is expected that + this function returns either a tensor with one element + or a 1D tensor with length equal to batch_size (one scalar + per input example) + additional_forward_args (any, optional): If the forward function + requires additional arguments other than the inputs for + which attributions should not be computed, this argument + can be provided. It must be either a single additional + argument of a Tensor or arbitrary (non-tuple) type or a tuple + containing multiple additional arguments including tensors + or any arbitrary python types. These arguments are provided to + forward_func in order, following the arguments in inputs. + Note that attributions are not computed with respect + to these arguments. + Default: None + attribute_to_neuron_input (bool, optional): Indicates whether to + compute the attributions with respect to the neuron input + or output. If `attribute_to_neuron_input` is set to True + then the attributions will be computed with respect to + neuron's inputs, otherwise it will be computed with respect + to neuron's outputs. + Note that currently it is assumed that either the input + or the output of internal neuron, depending on whether we + attribute to the input or output, is a single tensor. + Support for multiple tensors will be added later. + Default: False + Returns: + *tensor* or tuple of *tensors* of **attributions**: + - **attributions** (*tensor* or tuple of *tensors*): + Deconvolution attribution of + particular neuron with respect to each input feature. + Attributions will always be the same size as the provided + inputs, with each value providing the attribution of the + corresponding input index. + If a single tensor is provided as inputs, a single tensor is + returned. If a tuple is provided for inputs, a tuple of + corresponding sized tensors is returned. + + Examples:: + + >>> # ImageClassifier takes a single input tensor of images Nx3x32x32, + >>> # and returns an Nx10 tensor of class probabilities. + >>> # It contains an attribute conv1, which is an instance of nn.conv2d, + >>> # and the output of this layer has dimensions Nx12x32x32. + >>> net = ImageClassifier() + >>> neuron_deconv = NeuronDeconvolution(net, net.conv1) + >>> input = torch.randn(2, 3, 32, 32, requires_grad=True) + >>> # To compute neuron attribution, we need to provide the neuron + >>> # index for which attribution is desired. Since the layer output + >>> # is Nx12x32x32, we need a tuple in the form (0..11,0..31,0..31) + >>> # which indexes a particular neuron in the layer output. + >>> # For this example, we choose the index (4,1,2). + >>> # Computes neuron deconvolution for neuron with + >>> # index (4,1,2). + >>> attribution = neuron_deconv.attribute(input, (4,1,2)) + """ + if not attribute_to_neuron_input: + warnings.warn( + "Attribution to neuron output is no longer supported for" + " NeuronDeconvolution and will be deprecated in Captum" + " 0.6.0 due to changes in PyTorch's full backward hook" + " behavior. 
To obtain attributions for a neuron's" + " output, please attribute with respect to the next layer's input" + ) + self.deconv.skip_new_hook_layer = self.layer # type: ignore + else: + self.deconv.skip_new_hook_layer = None # type: ignore + + self.deconv.gradient_func = construct_neuron_grad_fn( + self.layer, neuron_selector, self.device_ids, attribute_to_neuron_input + ) + + # NOTE: using __wrapped__ to not log + return self.deconv.attribute.__wrapped__( + self.deconv, inputs, None, additional_forward_args + ) + + +class NeuronGuidedBackprop(NeuronAttribution, GradientAttribution): + r""" + Computes attribution of the given neuron using guided backpropagation. + Guided backpropagation computes the gradient of the target neuron + with respect to the input, but gradients of ReLU functions are overridden + so that only non-negative gradients are backpropagated. + + More details regarding the guided backpropagation algorithm can be found + in the original paper here: + https://arxiv.org/abs/1412.6806 + + Warning: Ensure that all ReLU operations in the forward function of the + given model are performed using a module (nn.module.ReLU). + If nn.functional.ReLU is used, gradients are not overridden appropriately. + """ + + def __init__( + self, model: Module, layer: Module, device_ids: Union[None, List[int]] = None + ) -> None: + r""" + Args: + + model (nn.Module): The reference to PyTorch model instance. Model cannot + contain any in-place ReLU submodules; these are not + supported by the register_full_backward_hook PyTorch API + starting from PyTorch v1.9. + layer (Module): Layer for which neuron attributions are computed. + Attributions for a particular neuron in the output of + this layer are computed using the argument neuron_selector + in the attribute method. + Currently, only layers with a single tensor output are + supported. + device_ids (list(int)): Device ID list, necessary only if forward_func + applies a DataParallel model. This allows reconstruction of + intermediate outputs from batched results across devices. + If forward_func is given as the DataParallel model itself, + then it is not necessary to provide this argument. + """ + NeuronAttribution.__init__(self, model, layer, device_ids) + GradientAttribution.__init__(self, model) + self.guided_backprop = GuidedBackprop(model) + + @log_usage() + def attribute( + self, + inputs: TensorOrTupleOfTensorsGeneric, + neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable], + additional_forward_args: Any = None, + attribute_to_neuron_input: bool = False, + ) -> TensorOrTupleOfTensorsGeneric: + r""" + Args: + + inputs (tensor or tuple of tensors): Input for which + attributions are computed. If forward_func takes a single + tensor as input, a single input tensor should be provided. + If forward_func takes multiple tensors as input, a tuple + of the input tensors should be provided. It is assumed + that for all given input tensors, dimension 0 corresponds + to the number of examples (aka batch size), and if + multiple input tensors are provided, the examples must + be aligned appropriately. + neuron_selector (int, callable, or tuple of ints or slices): + Selector for neuron + in given layer for which attribution is desired. + Neuron selector can be provided as: + + - a single integer, if the layer output is 2D. This integer + selects the appropriate neuron column in the layer input + or output + + - a tuple of integers or slice objects. 
Length of this + tuple must be one less than the number of dimensions + in the input / output of the given layer (since + dimension 0 corresponds to number of examples). + The elements of the tuple can be either integers or + slice objects (slice object allows indexing a + range of neurons rather individual ones). + + If any of the tuple elements is a slice object, the + indexed output tensor is used for attribution. Note + that specifying a slice of a tensor would amount to + computing the attribution of the sum of the specified + neurons, and not the individual neurons independantly. + + - a callable, which should + take the target layer as input (single tensor or tuple + if multiple tensors are in layer) and return a neuron or + aggregate of the layer's neurons for attribution. + For example, this function could return the + sum of the neurons in the layer or sum of neurons with + activations in a particular range. It is expected that + this function returns either a tensor with one element + or a 1D tensor with length equal to batch_size (one scalar + per input example) + additional_forward_args (any, optional): If the forward function + requires additional arguments other than the inputs for + which attributions should not be computed, this argument + can be provided. It must be either a single additional + argument of a Tensor or arbitrary (non-tuple) type or a tuple + containing multiple additional arguments including tensors + or any arbitrary python types. These arguments are provided to + forward_func in order, following the arguments in inputs. + Note that attributions are not computed with respect + to these arguments. + Default: None + attribute_to_neuron_input (bool, optional): Indicates whether to + compute the attributions with respect to the neuron input + or output. If `attribute_to_neuron_input` is set to True + then the attributions will be computed with respect to + neuron's inputs, otherwise it will be computed with respect + to neuron's outputs. + Note that currently it is assumed that either the input + or the output of internal neurons, depending on whether we + attribute to the input or output, is a single tensor. + Support for multiple tensors will be added later. + Default: False + Returns: + *tensor* or tuple of *tensors* of **attributions**: + - **attributions** (*tensor* or tuple of *tensors*): + Guided backprop attribution of + particular neuron with respect to each input feature. + Attributions will always be the same size as the provided + inputs, with each value providing the attribution of the + corresponding input index. + If a single tensor is provided as inputs, a single tensor is + returned. If a tuple is provided for inputs, a tuple of + corresponding sized tensors is returned. + + Examples:: + + >>> # ImageClassifier takes a single input tensor of images Nx3x32x32, + >>> # and returns an Nx10 tensor of class probabilities. + >>> # It contains an attribute conv1, which is an instance of nn.conv2d, + >>> # and the output of this layer has dimensions Nx12x32x32. + >>> net = ImageClassifier() + >>> neuron_gb = NeuronGuidedBackprop(net, net.conv1) + >>> input = torch.randn(2, 3, 32, 32, requires_grad=True) + >>> # To compute neuron attribution, we need to provide the neuron + >>> # index for which attribution is desired. Since the layer output + >>> # is Nx12x32x32, we need a tuple in the form (0..11,0..31,0..31) + >>> # which indexes a particular neuron in the layer output. + >>> # For this example, we choose the index (4,1,2). 
+ >>> # Computes neuron guided backpropagation for neuron with + >>> # index (4,1,2). + >>> attribution = neuron_gb.attribute(input, (4,1,2)) + """ + if not attribute_to_neuron_input: + warnings.warn( + "Attribution to neuron output is no longer supported for" + " NeuronGuidedBackprop and will be deprecated in Captum" + " 0.6.0 due to changes in PyTorch's full backward hook" + " behavior. To obtain attributions for a neuron's" + " output, please attribute with respect to the next layer's input" + ) + self.guided_backprop.skip_new_hook_layer = self.layer # type: ignore + else: + self.guided_backprop.skip_new_hook_layer = None # type: ignore + + self.guided_backprop.gradient_func = construct_neuron_grad_fn( + self.layer, neuron_selector, self.device_ids, attribute_to_neuron_input + ) + # NOTE: using __wrapped__ to not log + return self.guided_backprop.attribute.__wrapped__( + self.guided_backprop, inputs, None, additional_forward_args + ) diff --git a/captum/attr/_core/neuron/neuron_integrated_gradients.py b/captum/attr/_core/neuron/neuron_integrated_gradients.py new file mode 100644 index 0000000000000000000000000000000000000000..f67aec7e7eeea37bcd4de61241e23f31dfce7f19 --- /dev/null +++ b/captum/attr/_core/neuron/neuron_integrated_gradients.py @@ -0,0 +1,252 @@ +#!/usr/bin/env python3 +from typing import Any, Callable, List, Tuple, Union + +from captum._utils.gradient import construct_neuron_grad_fn +from captum._utils.typing import TensorOrTupleOfTensorsGeneric +from captum.attr._core.integrated_gradients import IntegratedGradients +from captum.attr._utils.attribution import GradientAttribution, NeuronAttribution +from captum.log import log_usage +from torch import Tensor +from torch.nn import Module + + +class NeuronIntegratedGradients(NeuronAttribution, GradientAttribution): + r""" + Approximates the integral of gradients for a particular neuron + along the path from a baseline input to the given input. + If no baseline is provided, the default baseline is the zero tensor. + More details regarding the integrated gradient method can be found in the + original paper here: + https://arxiv.org/abs/1703.01365 + + Note that this method is equivalent to applying integrated gradients + where the output is the output of the identified neuron. + """ + + def __init__( + self, + forward_func: Callable, + layer: Module, + device_ids: Union[None, List[int]] = None, + multiply_by_inputs: bool = True, + ) -> None: + r""" + Args: + + forward_func (callable): The forward function of the model or any + modification of it + layer (torch.nn.Module): Layer for which attributions are computed. + Output size of attribute matches this layer's input or + output dimensions, depending on whether we attribute to + the inputs or outputs of the layer, corresponding to + attribution of each neuron in the input or output of + this layer. + Currently, it is assumed that the inputs or the outputs + of the layer, depending on which one is used for + attribution, can only be a single tensor. + device_ids (list(int)): Device ID list, necessary only if forward_func + applies a DataParallel model. This allows reconstruction of + intermediate outputs from batched results across devices. + If forward_func is given as the DataParallel model itself, + then it is not necessary to provide this argument. + multiply_by_inputs (bool, optional): Indicates whether to factor + model inputs' multiplier in the final attribution scores. + In the literature this is also known as local vs global + attribution. 
If inputs' multiplier isn't factored in + then that type of attribution method is also called local + attribution. If it is, then that type of attribution + method is called global. + More detailed can be found here: + https://arxiv.org/abs/1711.06104 + + In case of Neuron Integrated Gradients, + if `multiply_by_inputs` is set to True, final + sensitivity scores are being multiplied + by (inputs - baselines). + + """ + NeuronAttribution.__init__(self, forward_func, layer, device_ids) + GradientAttribution.__init__(self, forward_func) + self._multiply_by_inputs = multiply_by_inputs + + @log_usage() + def attribute( + self, + inputs: TensorOrTupleOfTensorsGeneric, + neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable], + baselines: Union[None, Tensor, Tuple[Tensor, ...]] = None, + additional_forward_args: Any = None, + n_steps: int = 50, + method: str = "gausslegendre", + internal_batch_size: Union[None, int] = None, + attribute_to_neuron_input: bool = False, + ) -> TensorOrTupleOfTensorsGeneric: + r""" + Args: + + inputs (tensor or tuple of tensors): Input for which neuron integrated + gradients are computed. If forward_func takes a single + tensor as input, a single input tensor should be provided. + If forward_func takes multiple tensors as input, a tuple + of the input tensors should be provided. It is assumed + that for all given input tensors, dimension 0 corresponds + to the number of examples, and if multiple input tensors + are provided, the examples must be aligned appropriately. + neuron_selector (int, callable, or tuple of ints or slices): + Selector for neuron + in given layer for which attribution is desired. + Neuron selector can be provided as: + + - a single integer, if the layer output is 2D. This integer + selects the appropriate neuron column in the layer input + or output + + - a tuple of integers or slice objects. Length of this + tuple must be one less than the number of dimensions + in the input / output of the given layer (since + dimension 0 corresponds to number of examples). + The elements of the tuple can be either integers or + slice objects (slice object allows indexing a + range of neurons rather individual ones). + + If any of the tuple elements is a slice object, the + indexed output tensor is used for attribution. Note + that specifying a slice of a tensor would amount to + computing the attribution of the sum of the specified + neurons, and not the individual neurons independantly. + + - a callable, which should + take the target layer as input (single tensor or tuple + if multiple tensors are in layer) and return a neuron or + aggregate of the layer's neurons for attribution. + For example, this function could return the + sum of the neurons in the layer or sum of neurons with + activations in a particular range. It is expected that + this function returns either a tensor with one element + or a 1D tensor with length equal to batch_size (one scalar + per input example) + baselines (scalar, tensor, tuple of scalars or tensors, optional): + Baselines define the starting point from which integral + is computed. + Baselines can be provided as: + + - a single tensor, if inputs is a single tensor, with + exactly the same dimensions as inputs or the first + dimension is one and the remaining dimensions match + with inputs. + + - a single scalar, if inputs is a single tensor, which will + be broadcasted for each input value in input tensor. 
+ + - a tuple of tensors or scalars, the baseline corresponding + to each tensor in the inputs' tuple can be: + + - either a tensor with matching dimensions to + corresponding tensor in the inputs' tuple + or the first dimension is one and the remaining + dimensions match with the corresponding + input tensor. + + - or a scalar, corresponding to a tensor in the + inputs' tuple. This scalar value is broadcasted + for corresponding input tensor. + + In the cases when `baselines` is not provided, we internally + use zero scalar corresponding to each input tensor. + + Default: None + additional_forward_args (any, optional): If the forward function + requires additional arguments other than the inputs for + which attributions should not be computed, this argument + can be provided. It must be either a single additional + argument of a Tensor or arbitrary (non-tuple) type or a + tuple containing multiple additional arguments including + tensors or any arbitrary python types. These arguments + are provided to forward_func in order following the + arguments in inputs. + For a tensor, the first dimension of the tensor must + correspond to the number of examples. It will be + repeated for each of `n_steps` along the integrated + path. For all other types, the given argument is used + for all forward evaluations. + Note that attributions are not computed with respect + to these arguments. + Default: None + n_steps (int, optional): The number of steps used by the approximation + method. Default: 50. + method (string, optional): Method for approximating the integral, + one of `riemann_right`, `riemann_left`, `riemann_middle`, + `riemann_trapezoid` or `gausslegendre`. + Default: `gausslegendre` if no method is provided. + internal_batch_size (int, optional): Divides total #steps * #examples + data points into chunks of size at most internal_batch_size, + which are computed (forward / backward passes) + sequentially. internal_batch_size must be at least equal to + #examples. + For DataParallel models, each batch is split among the + available devices, so evaluations on each available + device contain internal_batch_size / num_devices examples. + If internal_batch_size is None, then all evaluations are + processed in one batch. + Default: None + attribute_to_neuron_input (bool, optional): Indicates whether to + compute the attributions with respect to the neuron input + or output. If `attribute_to_neuron_input` is set to True + then the attributions will be computed with respect to + neuron's inputs, otherwise it will be computed with respect + to neuron's outputs. + Note that currently it is assumed that either the input + or the output of internal neuron, depending on whether we + attribute to the input or output, is a single tensor. + Support for multiple tensors will be added later. + Default: False + + Returns: + *tensor* or tuple of *tensors* of **attributions**: + - **attributions** (*tensor* or tuple of *tensors*): + Integrated gradients for particular neuron with + respect to each input feature. + Attributions will always be the same size as the provided + inputs, with each value providing the attribution of the + corresponding input index. + If a single tensor is provided as inputs, a single tensor is + returned. If a tuple is provided for inputs, a tuple of + corresponding sized tensors is returned. + + Examples:: + + >>> # ImageClassifier takes a single input tensor of images Nx3x32x32, + >>> # and returns an Nx10 tensor of class probabilities. 
+ >>> # It contains an attribute conv1, which is an instance of nn.conv2d, + >>> # and the output of this layer has dimensions Nx12x32x32. + >>> net = ImageClassifier() + >>> neuron_ig = NeuronIntegratedGradients(net, net.conv1) + >>> input = torch.randn(2, 3, 32, 32, requires_grad=True) + >>> # To compute neuron attribution, we need to provide the neuron + >>> # index for which attribution is desired. Since the layer output + >>> # is Nx12x32x32, we need a tuple in the form (0..11,0..31,0..31) + >>> # which indexes a particular neuron in the layer output. + >>> # For this example, we choose the index (4,1,2). + >>> # Computes neuron integrated gradients for neuron with + >>> # index (4,1,2). + >>> attribution = neuron_ig.attribute(input, (4,1,2)) + """ + ig = IntegratedGradients(self.forward_func, self.multiplies_by_inputs) + ig.gradient_func = construct_neuron_grad_fn( + self.layer, neuron_selector, self.device_ids, attribute_to_neuron_input + ) + # NOTE: using __wrapped__ to not log + # Return only attributions and not delta + return ig.attribute.__wrapped__( # type: ignore + ig, # self + inputs, + baselines, + additional_forward_args=additional_forward_args, + n_steps=n_steps, + method=method, + internal_batch_size=internal_batch_size, + ) + + @property + def multiplies_by_inputs(self): + return self._multiply_by_inputs diff --git a/captum/attr/_core/noise_tunnel.py b/captum/attr/_core/noise_tunnel.py new file mode 100644 index 0000000000000000000000000000000000000000..2e87532b358625fdd611146555fbfe50c531d9b0 --- /dev/null +++ b/captum/attr/_core/noise_tunnel.py @@ -0,0 +1,450 @@ +#!/usr/bin/env python3 +from enum import Enum +from typing import Any, cast, List, Tuple, Union + +import torch +from captum._utils.common import ( + _expand_and_update_additional_forward_args, + _expand_and_update_baselines, + _expand_and_update_feature_mask, + _expand_and_update_target, + _format_output, + _format_tensor_into_tuples, + _is_tuple, +) +from captum._utils.typing import TensorOrTupleOfTensorsGeneric +from captum.attr._utils.attribution import Attribution, GradientAttribution +from captum.attr._utils.common import _validate_noise_tunnel_type +from captum.log import log_usage +from torch import Tensor + + +class NoiseTunnelType(Enum): + smoothgrad = 1 + smoothgrad_sq = 2 + vargrad = 3 + + +SUPPORTED_NOISE_TUNNEL_TYPES = list(NoiseTunnelType.__members__.keys()) + + +class NoiseTunnel(Attribution): + r""" + Adds gaussian noise to each input in the batch `nt_samples` times + and applies the given attribution algorithm to each of the samples. + The attributions of the samples are combined based on the given noise + tunnel type (nt_type): + If nt_type is `smoothgrad`, the mean of the sampled attributions is + returned. This approximates smoothing the given attribution method + with a Gaussian Kernel. + If nt_type is `smoothgrad_sq`, the mean of the squared sample attributions + is returned. + If nt_type is `vargrad`, the variance of the sample attributions is + returned. + + More details about adding noise can be found in the following papers: + https://arxiv.org/abs/1810.03292 + https://arxiv.org/abs/1810.03307 + https://arxiv.org/abs/1706.03825 + https://arxiv.org/pdf/1806.10758 + This method currently also supports batches of multiple examples input, + however it can be computationally expensive depending on the model, + the dimensionality of the data and execution environment. + It is assumed that the batch size is the first dimension of input tensors. 
+ """ + + def __init__(self, attribution_method: Attribution) -> None: + r""" + Args: + attribution_method (Attribution): An instance of any attribution algorithm + of type `Attribution`. E.g. Integrated Gradients, + Conductance or Saliency. + """ + self.attribution_method = attribution_method + self.is_delta_supported = self.attribution_method.has_convergence_delta() + self._multiply_by_inputs = self.attribution_method.multiplies_by_inputs + self.is_gradient_method = isinstance( + self.attribution_method, GradientAttribution + ) + Attribution.__init__(self, self.attribution_method.forward_func) + + @property + def multiplies_by_inputs(self): + return self._multiply_by_inputs + + @log_usage() + def attribute( + self, + inputs: Union[Tensor, Tuple[Tensor, ...]], + nt_type: str = "smoothgrad", + nt_samples: int = 5, + nt_samples_batch_size: int = None, + stdevs: Union[float, Tuple[float, ...]] = 1.0, + draw_baseline_from_distrib: bool = False, + **kwargs: Any, + ) -> Union[ + Union[ + Tensor, + Tuple[Tensor, Tensor], + Tuple[Tensor, ...], + Tuple[Tuple[Tensor, ...], Tensor], + ] + ]: + r""" + Args: + + inputs (tensor or tuple of tensors): Input for which integrated + gradients are computed. If forward_func takes a single + tensor as input, a single input tensor should be provided. + If forward_func takes multiple tensors as input, a tuple + of the input tensors should be provided. It is assumed + that for all given input tensors, dimension 0 corresponds + to the number of examples, and if multiple input tensors + are provided, the examples must be aligned appropriately. + nt_type (string, optional): Smoothing type of the attributions. + `smoothgrad`, `smoothgrad_sq` or `vargrad` + Default: `smoothgrad` if `type` is not provided. + nt_samples (int, optional): The number of randomly generated examples + per sample in the input batch. Random examples are + generated by adding gaussian random noise to each sample. + Default: `5` if `nt_samples` is not provided. + nt_samples_batch_size (int, optional): The number of the `nt_samples` + that will be processed together. With the help + of this parameter we can avoid out of memory situation and + reduce the number of randomly generated examples per sample + in each batch. + Default: None if `nt_samples_batch_size` is not provided. In + this case all `nt_samples` will be processed together. + stdevs (float, or a tuple of floats optional): The standard deviation + of gaussian noise with zero mean that is added to each + input in the batch. If `stdevs` is a single float value + then that same value is used for all inputs. If it is + a tuple, then it must have the same length as the inputs + tuple. In this case, each stdev value in the stdevs tuple + corresponds to the input with the same index in the inputs + tuple. + Default: `1.0` if `stdevs` is not provided. + draw_baseline_from_distrib (bool, optional): Indicates whether to + randomly draw baseline samples from the `baselines` + distribution provided as an input tensor. + Default: False + **kwargs (Any, optional): Contains a list of arguments that are passed + to `attribution_method` attribution algorithm. + Any additional arguments that should be used for the + chosen attribution method should be included here. + For instance, such arguments include + `additional_forward_args` and `baselines`. + + Returns: + **attributions** or 2-element tuple of **attributions**, **delta**: + - **attributions** (*tensor* or tuple of *tensors*): + Attribution with + respect to each input feature. 
attributions will always be + the same size as the provided inputs, with each value + providing the attribution of the corresponding input index. + If a single tensor is provided as inputs, a single tensor is + returned. If a tuple is provided for inputs, a tuple of + corresponding sized tensors is returned. + - **delta** (*float*, returned if return_convergence_delta=True): + Approximation error computed by the + attribution algorithm. Not all attribution algorithms + return delta value. It is computed only for some + algorithms, e.g. integrated gradients. + Delta is computed for each input in the batch + and represents the arithmetic mean + across all `nt_samples` perturbed tensors for that input. + + + Examples:: + + >>> # ImageClassifier takes a single input tensor of images Nx3x32x32, + >>> # and returns an Nx10 tensor of class probabilities. + >>> net = ImageClassifier() + >>> ig = IntegratedGradients(net) + >>> input = torch.randn(2, 3, 32, 32, requires_grad=True) + >>> # Creates noise tunnel + >>> nt = NoiseTunnel(ig) + >>> # Generates 10 perturbed input tensors per image. + >>> # Computes integrated gradients for class 3 for each generated + >>> # input and averages attributions accros all 10 + >>> # perturbed inputs per image + >>> attribution = nt.attribute(input, nt_type='smoothgrad', + >>> nt_samples=10, target=3) + """ + + def add_noise_to_inputs(nt_samples_partition: int) -> Tuple[Tensor, ...]: + if isinstance(stdevs, tuple): + assert len(stdevs) == len(inputs), ( + "The number of input tensors " + "in {} must be equal to the number of stdevs values {}".format( + len(inputs), len(stdevs) + ) + ) + else: + assert isinstance( + stdevs, float + ), "stdevs must be type float. " "Given: {}".format(type(stdevs)) + stdevs_ = (stdevs,) * len(inputs) + return tuple( + add_noise_to_input(input, stdev, nt_samples_partition).requires_grad_() + if self.is_gradient_method + else add_noise_to_input(input, stdev, nt_samples_partition) + for (input, stdev) in zip(inputs, stdevs_) + ) + + def add_noise_to_input( + input: Tensor, stdev: float, nt_samples_partition: int + ) -> Tensor: + # batch size + bsz = input.shape[0] + + # expand input size by the number of drawn samples + input_expanded_size = (bsz * nt_samples_partition,) + input.shape[1:] + + # expand stdev for the shape of the input and number of drawn samples + stdev_expanded = torch.tensor(stdev, device=input.device).repeat( + input_expanded_size + ) + + # draws `np.prod(input_expanded_size)` samples from normal distribution + # with given input parametrization + # FIXME it look like it is very difficult to make torch.normal + # deterministic this needs an investigation + noise = torch.normal(0, stdev_expanded) + return input.repeat_interleave(nt_samples_partition, dim=0) + noise + + def update_sum_attribution_and_sq( + sum_attribution: List[Tensor], + sum_attribution_sq: List[Tensor], + attribution: Tensor, + i: int, + nt_samples_batch_size_inter: int, + ) -> None: + bsz = attribution.shape[0] // nt_samples_batch_size_inter + attribution_shape = cast( + Tuple[int, ...], (bsz, nt_samples_batch_size_inter) + ) + if len(attribution.shape) > 1: + attribution_shape += cast(Tuple[int, ...], tuple(attribution.shape[1:])) + + attribution = attribution.view(attribution_shape) + current_attribution_sum = attribution.sum(dim=1, keepdim=False) + current_attribution_sq = torch.sum(attribution ** 2, dim=1, keepdim=False) + + sum_attribution[i] = ( + current_attribution_sum + if not isinstance(sum_attribution[i], torch.Tensor) + else 
sum_attribution[i] + current_attribution_sum + ) + sum_attribution_sq[i] = ( + current_attribution_sq + if not isinstance(sum_attribution_sq[i], torch.Tensor) + else sum_attribution_sq[i] + current_attribution_sq + ) + + def compute_partial_attribution( + inputs_with_noise_partition: Tuple[Tensor, ...], kwargs_partition: Any + ) -> Tuple[Tuple[Tensor, ...], bool, Union[None, Tensor]]: + # smoothgrad_Attr(x) = 1 / n * sum(Attr(x + N(0, sigma^2)) + # NOTE: using __wrapped__ such that it does not log the inner logs + + attributions = attr_func.__wrapped__( # type: ignore + self.attribution_method, # self + inputs_with_noise_partition + if is_inputs_tuple + else inputs_with_noise_partition[0], + **kwargs_partition, + ) + delta = None + + if self.is_delta_supported and return_convergence_delta: + attributions, delta = attributions + + is_attrib_tuple = _is_tuple(attributions) + attributions = _format_tensor_into_tuples(attributions) + + return ( + cast(Tuple[Tensor, ...], attributions), + cast(bool, is_attrib_tuple), + delta, + ) + + def expand_partial(nt_samples_partition: int, kwargs_partial: dict) -> None: + # if the algorithm supports targets, baselines and/or + # additional_forward_args they will be expanded based + # on the nt_samples_partition and corresponding kwargs + # variables will be updated accordingly + _expand_and_update_additional_forward_args( + nt_samples_partition, kwargs_partial + ) + _expand_and_update_target(nt_samples_partition, kwargs_partial) + _expand_and_update_baselines( + cast(Tuple[Tensor, ...], inputs), + nt_samples_partition, + kwargs_partial, + draw_baseline_from_distrib=draw_baseline_from_distrib, + ) + _expand_and_update_feature_mask(nt_samples_partition, kwargs_partial) + + def compute_smoothing( + expected_attributions: Tuple[Union[Tensor], ...], + expected_attributions_sq: Tuple[Union[Tensor], ...], + ) -> Tuple[Tensor, ...]: + if NoiseTunnelType[nt_type] == NoiseTunnelType.smoothgrad: + return expected_attributions + + if NoiseTunnelType[nt_type] == NoiseTunnelType.smoothgrad_sq: + return expected_attributions_sq + + vargrad = tuple( + expected_attribution_sq - expected_attribution * expected_attribution + for expected_attribution, expected_attribution_sq in zip( + expected_attributions, expected_attributions_sq + ) + ) + + return cast(Tuple[Tensor, ...], vargrad) + + def update_partial_attribution_and_delta( + attributions_partial: Tuple[Tensor, ...], + delta_partial: Tensor, + sum_attributions: List[Tensor], + sum_attributions_sq: List[Tensor], + delta_partial_list: List[Tensor], + nt_samples_partial: int, + ) -> None: + for i, attribution_partial in enumerate(attributions_partial): + update_sum_attribution_and_sq( + sum_attributions, + sum_attributions_sq, + attribution_partial, + i, + nt_samples_partial, + ) + if self.is_delta_supported and return_convergence_delta: + delta_partial_list.append(delta_partial) + + return_convergence_delta: bool + return_convergence_delta = ( + "return_convergence_delta" in kwargs and kwargs["return_convergence_delta"] + ) + with torch.no_grad(): + nt_samples_batch_size = ( + nt_samples + if nt_samples_batch_size is None + else min(nt_samples, nt_samples_batch_size) + ) + + nt_samples_partition = nt_samples // nt_samples_batch_size + + # Keeps track whether original input is a tuple or not before + # converting it into a tuple. 
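Before the main loop below, a note on the bookkeeping set up by the helpers above: instead of stacking all `nt_samples` attributions at once, the loop keeps a running sum and a running sum of squares per input tensor and processes the samples in `nt_samples // nt_samples_batch_size` full chunks plus one remainder chunk. A hedged sketch of the same arithmetic with made-up shapes and variable names (`running_sum`, `running_sum_sq`, `chunk` are not the class's own identifiers):

import torch

nt_samples, nt_samples_batch_size = 12, 5
full_chunks, remainder = divmod(nt_samples, nt_samples_batch_size)  # 2 chunks of 5, remainder 2

running_sum = torch.zeros(2, 3, 32, 32)      # one accumulator per input tensor
running_sum_sq = torch.zeros(2, 3, 32, 32)
for chunk in [nt_samples_batch_size] * full_chunks + ([remainder] if remainder else []):
    # stand-in for compute_partial_attribution over `chunk` noisy copies per example
    attrs = torch.randn(chunk, 2, 3, 32, 32)
    running_sum += attrs.sum(dim=0)
    running_sum_sq += (attrs ** 2).sum(dim=0)

expected = running_sum / nt_samples          # smoothgrad
expected_sq = running_sum_sq / nt_samples    # smoothgrad_sq
vargrad = expected_sq - expected ** 2        # vargrad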
+ is_inputs_tuple = isinstance(inputs, tuple) + + inputs = _format_tensor_into_tuples(inputs) # type: ignore + + _validate_noise_tunnel_type(nt_type, SUPPORTED_NOISE_TUNNEL_TYPES) + + kwargs_copy = kwargs.copy() + expand_partial(nt_samples_batch_size, kwargs_copy) + + attr_func = self.attribution_method.attribute + + sum_attributions: List[Union[None, Tensor]] = [] + sum_attributions_sq: List[Union[None, Tensor]] = [] + delta_partial_list: List[Tensor] = [] + + for _ in range(nt_samples_partition): + inputs_with_noise = add_noise_to_inputs(nt_samples_batch_size) + ( + attributions_partial, + is_attrib_tuple, + delta_partial, + ) = compute_partial_attribution(inputs_with_noise, kwargs_copy) + + if len(sum_attributions) == 0: + sum_attributions = [None] * len(attributions_partial) + sum_attributions_sq = [None] * len(attributions_partial) + + update_partial_attribution_and_delta( + cast(Tuple[Tensor, ...], attributions_partial), + cast(Tensor, delta_partial), + cast(List[Tensor], sum_attributions), + cast(List[Tensor], sum_attributions_sq), + delta_partial_list, + nt_samples_batch_size, + ) + + nt_samples_remaining = ( + nt_samples - nt_samples_partition * nt_samples_batch_size + ) + if nt_samples_remaining > 0: + inputs_with_noise = add_noise_to_inputs(nt_samples_remaining) + expand_partial(nt_samples_remaining, kwargs) + ( + attributions_partial, + is_attrib_tuple, + delta_partial, + ) = compute_partial_attribution(inputs_with_noise, kwargs) + + update_partial_attribution_and_delta( + cast(Tuple[Tensor, ...], attributions_partial), + cast(Tensor, delta_partial), + cast(List[Tensor], sum_attributions), + cast(List[Tensor], sum_attributions_sq), + delta_partial_list, + nt_samples_remaining, + ) + + expected_attributions = tuple( + [ + cast(Tensor, sum_attribution) * 1 / nt_samples + for sum_attribution in sum_attributions + ] + ) + expected_attributions_sq = tuple( + [ + cast(Tensor, sum_attribution_sq) * 1 / nt_samples + for sum_attribution_sq in sum_attributions_sq + ] + ) + attributions = compute_smoothing( + cast(Tuple[Tensor, ...], expected_attributions), + cast(Tuple[Tensor, ...], expected_attributions_sq), + ) + + delta = None + if self.is_delta_supported and return_convergence_delta: + delta = torch.cat(delta_partial_list, dim=0) + + return self._apply_checks_and_return_attributions( + attributions, is_attrib_tuple, return_convergence_delta, delta + ) + + def _apply_checks_and_return_attributions( + self, + attributions: Tuple[Tensor, ...], + is_attrib_tuple: bool, + return_convergence_delta: bool, + delta: Union[None, Tensor], + ) -> Union[ + TensorOrTupleOfTensorsGeneric, Tuple[TensorOrTupleOfTensorsGeneric, Tensor] + ]: + attributions = _format_output(is_attrib_tuple, attributions) + + ret = ( + (attributions, cast(Tensor, delta)) + if self.is_delta_supported and return_convergence_delta + else attributions + ) + ret = cast( + Union[ + TensorOrTupleOfTensorsGeneric, + Tuple[TensorOrTupleOfTensorsGeneric, Tensor], + ], + ret, + ) + return ret + + def has_convergence_delta(self) -> bool: + return self.is_delta_supported diff --git a/captum/attr/_core/occlusion.py b/captum/attr/_core/occlusion.py new file mode 100644 index 0000000000000000000000000000000000000000..de148693fa13a7ca0927dd29ef91341ca5ae0958 --- /dev/null +++ b/captum/attr/_core/occlusion.py @@ -0,0 +1,378 @@ +#!/usr/bin/env python3 +from typing import Any, Callable, Tuple, Union + +import numpy as np +import torch +from captum._utils.common import _format_tensor_into_tuples +from captum._utils.typing import 
BaselineType, TargetType, TensorOrTupleOfTensorsGeneric +from captum.attr._core.feature_ablation import FeatureAblation +from captum.attr._utils.common import ( + _format_and_verify_sliding_window_shapes, + _format_and_verify_strides, +) +from captum.log import log_usage +from torch import Tensor + + +class Occlusion(FeatureAblation): + r""" + A perturbation based approach to compute attribution, involving + replacing each contiguous rectangular region with a given baseline / + reference, and computing the difference in output. For features located + in multiple regions (hyperrectangles), the corresponding output differences + are averaged to compute the attribution for that feature. + + The first patch is applied with the corner aligned with all indices 0, + and strides are applied until the entire dimension range is covered. Note + that this may cause the final patch applied in a direction to be cut-off + and thus smaller than the target occlusion shape. + + More details regarding the occlusion (or grey-box / sliding window) + method can be found in the original paper and in the DeepExplain + implementation. + https://arxiv.org/abs/1311.2901 + https://github.com/marcoancona/DeepExplain/blob/master/deepexplain\ + /tensorflow/methods.py#L401 + """ + + def __init__(self, forward_func: Callable) -> None: + r""" + Args: + + forward_func (callable): The forward function of the model or + any modification of it + """ + FeatureAblation.__init__(self, forward_func) + self.use_weights = True + + @log_usage() + def attribute( # type: ignore + self, + inputs: TensorOrTupleOfTensorsGeneric, + sliding_window_shapes: Union[Tuple[int, ...], Tuple[Tuple[int, ...], ...]], + strides: Union[ + None, int, Tuple[int, ...], Tuple[Union[int, Tuple[int, ...]], ...] + ] = None, + baselines: BaselineType = None, + target: TargetType = None, + additional_forward_args: Any = None, + perturbations_per_eval: int = 1, + show_progress: bool = False, + ) -> TensorOrTupleOfTensorsGeneric: + r""" + Args: + + inputs (tensor or tuple of tensors): Input for which occlusion + attributions are computed. If forward_func takes a single + tensor as input, a single input tensor should be provided. + If forward_func takes multiple tensors as input, a tuple + of the input tensors should be provided. It is assumed + that for all given input tensors, dimension 0 corresponds + to the number of examples (aka batch size), and if + multiple input tensors are provided, the examples must + be aligned appropriately. + sliding_window_shapes (tuple or tuple of tuples): Shape of patch + (hyperrectangle) to occlude each input. For a single + input tensor, this must be a tuple of length equal to the + number of dimensions of the input tensor - 1, defining + the dimensions of the patch. If the input tensor is 1-d, + this should be an empty tuple. For multiple input tensors, + this must be a tuple containing one tuple for each input + tensor defining the dimensions of the patch for that + input tensor, as described for the single tensor case. + strides (int or tuple or tuple of ints or tuple of tuples, optional): + This defines the step by which the occlusion hyperrectangle + should be shifted by in each direction for each iteration. + For a single tensor input, this can be either a single + integer, which is used as the step size in each direction, + or a tuple of integers matching the number of dimensions + in the occlusion shape, defining the step size in the + corresponding dimension. 
For multiple tensor inputs, this + can be either a tuple of integers, one for each input + tensor (used for all dimensions of the corresponding + tensor), or a tuple of tuples, providing the stride per + dimension for each tensor. + To ensure that all inputs are covered by at least one + sliding window, the stride for any dimension must be + <= the corresponding sliding window dimension if the + sliding window dimension is less than the input + dimension. + If None is provided, a stride of 1 is used for each + dimension of each input tensor. + Default: None + baselines (scalar, tensor, tuple of scalars or tensors, optional): + Baselines define reference value which replaces each + feature when occluded. + Baselines can be provided as: + + - a single tensor, if inputs is a single tensor, with + exactly the same dimensions as inputs or + broadcastable to match the dimensions of inputs + + - a single scalar, if inputs is a single tensor, which will + be broadcasted for each input value in input tensor. + + - a tuple of tensors or scalars, the baseline corresponding + to each tensor in the inputs' tuple can be: + + - either a tensor with matching dimensions to + corresponding tensor in the inputs' tuple + or the first dimension is one and the remaining + dimensions match with the corresponding + input tensor. + + - or a scalar, corresponding to a tensor in the + inputs' tuple. This scalar value is broadcasted + for corresponding input tensor. + In the cases when `baselines` is not provided, we internally + use zero scalar corresponding to each input tensor. + Default: None + target (int, tuple, tensor or list, optional): Output indices for + which difference is computed (for classification cases, + this is usually the target class). + If the network returns a scalar value per example, + no target index is necessary. + For general 2D outputs, targets can be either: + + - a single integer or a tensor containing a single + integer, which is applied to all input examples + + - a list of integers or a 1D tensor, with length matching + the number of examples in inputs (dim 0). Each integer + is applied as the target for the corresponding example. + + For outputs with > 2 dimensions, targets can be either: + + - A single tuple, which contains #output_dims - 1 + elements. This target index is applied to all examples. + + - A list of tuples with length equal to the number of + examples in inputs (dim 0), and each tuple containing + #output_dims - 1 elements. Each tuple is applied as the + target for the corresponding example. + + Default: None + additional_forward_args (any, optional): If the forward function + requires additional arguments other than the inputs for + which attributions should not be computed, this argument + can be provided. It must be either a single additional + argument of a Tensor or arbitrary (non-tuple) type or a + tuple containing multiple additional arguments including + tensors or any arbitrary python types. These arguments + are provided to forward_func in order following the + arguments in inputs. + For a tensor, the first dimension of the tensor must + correspond to the number of examples. For all other types, + the given argument is used for all forward evaluations. + Note that attributions are not computed with respect + to these arguments. + Default: None + perturbations_per_eval (int, optional): Allows multiple occlusions + to be included in one batch (one call to forward_fn). + By default, perturbations_per_eval is 1, so each occlusion + is processed individually. 
+ Each forward pass will contain a maximum of + perturbations_per_eval * #examples samples. + For DataParallel models, each batch is split among the + available devices, so evaluations on each available + device contain at most + (perturbations_per_eval * #examples) / num_devices + samples. + Default: 1 + show_progress (bool, optional): Displays the progress of computation. + It will try to use tqdm if available for advanced features + (e.g. time estimation). Otherwise, it will fallback to + a simple output of progress. + Default: False + + Returns: + *tensor* or tuple of *tensors* of **attributions**: + - **attributions** (*tensor* or tuple of *tensors*): + The attributions with respect to each input feature. + Attributions will always be + the same size as the provided inputs, with each value + providing the attribution of the corresponding input index. + If a single tensor is provided as inputs, a single tensor is + returned. If a tuple is provided for inputs, a tuple of + corresponding sized tensors is returned. + + + Examples:: + + >>> # SimpleClassifier takes a single input tensor of size Nx4x4, + >>> # and returns an Nx3 tensor of class probabilities. + >>> net = SimpleClassifier() + >>> # Generating random input with size 2 x 4 x 4 + >>> input = torch.randn(2, 4, 4) + >>> # Defining Occlusion interpreter + >>> ablator = Occlusion(net) + >>> # Computes occlusion attribution, ablating each 3x3 patch, + >>> # shifting in each direction by the default of 1. + >>> attr = ablator.attribute(input, target=1, sliding_window_shapes=(3,3)) + """ + formatted_inputs = _format_tensor_into_tuples(inputs) + + # Formatting strides + strides = _format_and_verify_strides(strides, formatted_inputs) + + # Formatting sliding window shapes + sliding_window_shapes = _format_and_verify_sliding_window_shapes( + sliding_window_shapes, formatted_inputs + ) + + # Construct tensors from sliding window shapes + sliding_window_tensors = tuple( + torch.ones(window_shape, device=formatted_inputs[i].device) + for i, window_shape in enumerate(sliding_window_shapes) + ) + + # Construct counts, defining number of steps to make of occlusion block in + # each dimension. + shift_counts = [] + for i, inp in enumerate(formatted_inputs): + current_shape = np.subtract(inp.shape[1:], sliding_window_shapes[i]) + # Verify sliding window doesn't exceed input dimensions. + assert (np.array(current_shape) >= 0).all(), ( + "Sliding window dimensions {} cannot exceed input dimensions" "{}." + ).format(sliding_window_shapes[i], tuple(inp.shape[1:])) + # Stride cannot be larger than sliding window for any dimension where + # the sliding window doesn't cover the entire input. + assert np.logical_or( + np.array(current_shape) == 0, + np.array(strides[i]) <= sliding_window_shapes[i], + ).all(), ( + "Stride dimension {} cannot be larger than sliding window " + "shape dimension {}." 
+ ).format( + strides[i], sliding_window_shapes[i] + ) + shift_counts.append( + tuple( + np.add(np.ceil(np.divide(current_shape, strides[i])).astype(int), 1) + ) + ) + + # Use ablation attribute method + return super().attribute.__wrapped__( + self, + inputs, + baselines=baselines, + target=target, + additional_forward_args=additional_forward_args, + perturbations_per_eval=perturbations_per_eval, + sliding_window_tensors=sliding_window_tensors, + shift_counts=tuple(shift_counts), + strides=strides, + show_progress=show_progress, + ) + + def _construct_ablated_input( + self, + expanded_input: Tensor, + input_mask: Union[None, Tensor], + baseline: Union[Tensor, int, float], + start_feature: int, + end_feature: int, + **kwargs: Any, + ) -> Tuple[Tensor, Tensor]: + r""" + Ablates given expanded_input tensor with given feature mask, feature range, + and baselines, and any additional arguments. + expanded_input shape is (num_features, num_examples, ...) + with remaining dimensions corresponding to remaining original tensor + dimensions and num_features = end_feature - start_feature. + + input_mask is None for occlusion, and the mask is constructed + using sliding_window_tensors, strides, and shift counts, which are provided in + kwargs. baseline is expected to + be broadcastable to match expanded_input. + + This method returns the ablated input tensor, which has the same + dimensionality as expanded_input as well as the corresponding mask with + either the same dimensionality as expanded_input or second dimension + being 1. This mask contains 1s in locations which have been ablated (and + thus counted towards ablations for that feature) and 0s otherwise. + """ + input_mask = torch.stack( + [ + self._occlusion_mask( + expanded_input, + j, + kwargs["sliding_window_tensors"], + kwargs["strides"], + kwargs["shift_counts"], + ) + for j in range(start_feature, end_feature) + ], + dim=0, + ).long() + ablated_tensor = ( + expanded_input + * ( + torch.ones(1, dtype=torch.long, device=expanded_input.device) + - input_mask + ).to(expanded_input.dtype) + ) + (baseline * input_mask.to(expanded_input.dtype)) + return ablated_tensor, input_mask + + def _occlusion_mask( + self, + expanded_input: Tensor, + ablated_feature_num: int, + sliding_window_tsr: Tensor, + strides: Union[int, Tuple[int, ...]], + shift_counts: Tuple[int, ...], + ) -> Tensor: + """ + This constructs the current occlusion mask, which is the appropriate + shift of the sliding window tensor based on the ablated feature number. + The feature number ranges between 0 and the product of the shift counts + (# of times the sliding window should be shifted in each dimension). + + First, the ablated feature number is converted to the number of steps in + each dimension from the origin, based on shift counts. This procedure + is similar to a base conversion, with the position values equal to shift + counts. The feature number is first taken modulo shift_counts[0] to + get the number of shifts in the first dimension (each shift + by shift_count[0]), and then divided by shift_count[0]. + The procedure is then continued for each element of shift_count. This + computes the total shift in each direction for the sliding window. + + We then need to compute the padding required after the window in each + dimension, which is equal to the total input dimension minus the sliding + window dimension minus the (left) shift amount. 
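A worked example of the index-to-offset conversion this docstring describes may help. This is a hedged sketch with made-up shapes and a hypothetical helper name (`feature_num_to_offsets`); the method itself then pads a window of ones with torch.nn.functional.pad, as the rest of the docstring explains.

import math

def feature_num_to_offsets(feature_num, shift_counts, strides):
    # Base conversion, least-significant dimension first, scaled by the stride.
    offsets = []
    for shift_count, stride in zip(shift_counts, strides):
        offsets.append((feature_num % shift_count) * stride)
        feature_num //= shift_count
    return offsets

# Single 1x6x6 input, sliding window (1, 3, 3), strides (1, 2, 2):
input_dims, window, strides = (1, 6, 6), (1, 3, 3), (1, 2, 2)
shift_counts = tuple(
    math.ceil((d - w) / s) + 1 for d, w, s in zip(input_dims, window, strides)
)                                        # (1, 3, 3) -> 9 occlusion positions overall
print(feature_num_to_offsets(5, shift_counts, strides))  # [0, 4, 2]: channel 0, row 4, col 2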
We construct the + array pad_values which contains the left and right pad values for each + dimension, in reverse order of dimensions, starting from the last one. + + Once these padding values are computed, we pad the sliding window tensor + of 1s with 0s appropriately, which is the corresponding mask, + and the result will match the input shape. + """ + remaining_total = ablated_feature_num + current_index = [] + for i, shift_count in enumerate(shift_counts): + stride = strides[i] if isinstance(strides, tuple) else strides + current_index.append((remaining_total % shift_count) * stride) + remaining_total = remaining_total // shift_count + + remaining_padding = np.subtract( + expanded_input.shape[2:], np.add(current_index, sliding_window_tsr.shape) + ) + pad_values = [ + val for pair in zip(remaining_padding, current_index) for val in pair + ] + pad_values.reverse() + padded_tensor = torch.nn.functional.pad( + sliding_window_tsr, tuple(pad_values) # type: ignore + ) + return padded_tensor.reshape((1,) + padded_tensor.shape) + + def _get_feature_range_and_mask( + self, input: Tensor, input_mask: Tensor, **kwargs: Any + ) -> Tuple[int, int, None]: + feature_max = np.prod(kwargs["shift_counts"]) + return 0, feature_max, None + + def _get_feature_counts(self, inputs, feature_mask, **kwargs): + """return the numbers of possible input features""" + return tuple(np.prod(counts).astype(int) for counts in kwargs["shift_counts"]) diff --git a/captum/attr/_core/saliency.py b/captum/attr/_core/saliency.py new file mode 100644 index 0000000000000000000000000000000000000000..7e2aeed5cdd44486c3f122ee7f73855084911623 --- /dev/null +++ b/captum/attr/_core/saliency.py @@ -0,0 +1,138 @@ +#!/usr/bin/env python3 + +from typing import Any, Callable + +import torch +from captum._utils.common import _format_output, _format_tensor_into_tuples, _is_tuple +from captum._utils.gradient import ( + apply_gradient_requirements, + undo_gradient_requirements, +) +from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric +from captum.attr._utils.attribution import GradientAttribution +from captum.log import log_usage + + +class Saliency(GradientAttribution): + r""" + A baseline approach for computing input attribution. It returns + the gradients with respect to inputs. If `abs` is set to True, which is + the default, the absolute value of the gradients is returned. + + More details about the approach can be found in the following paper: + https://arxiv.org/pdf/1312.6034.pdf + """ + + def __init__(self, forward_func: Callable) -> None: + r""" + Args: + + forward_func (callable): The forward function of the model or + any modification of it + """ + GradientAttribution.__init__(self, forward_func) + + @log_usage() + def attribute( + self, + inputs: TensorOrTupleOfTensorsGeneric, + target: TargetType = None, + abs: bool = True, + additional_forward_args: Any = None, + ) -> TensorOrTupleOfTensorsGeneric: + r""" + Args: + + inputs (tensor or tuple of tensors): Input for which integrated + gradients are computed. If forward_func takes a single + tensor as input, a single input tensor should be provided. + If forward_func takes multiple tensors as input, a tuple + of the input tensors should be provided. It is assumed + that for all given input tensors, dimension 0 corresponds + to the number of examples (aka batch size), and if + multiple input tensors are provided, the examples must + be aligned appropriately. 
+            target (int, tuple, tensor or list, optional): Output indices for
+                which gradients are computed (for classification cases,
+                this is usually the target class).
+                If the network returns a scalar value per example,
+                no target index is necessary.
+                For general 2D outputs, targets can be either:
+
+                - a single integer or a tensor containing a single
+                  integer, which is applied to all input examples
+
+                - a list of integers or a 1D tensor, with length matching
+                  the number of examples in inputs (dim 0). Each integer
+                  is applied as the target for the corresponding example.
+
+                For outputs with > 2 dimensions, targets can be either:
+
+                - A single tuple, which contains #output_dims - 1
+                  elements. This target index is applied to all examples.
+
+                - A list of tuples with length equal to the number of
+                  examples in inputs (dim 0), and each tuple containing
+                  #output_dims - 1 elements. Each tuple is applied as the
+                  target for the corresponding example.
+
+                Default: None
+            abs (bool, optional): Returns absolute value of gradients if set
+                to True, otherwise returns the (signed) gradients if
+                False.
+                Default: True
+            additional_forward_args (any, optional): If the forward function
+                requires additional arguments other than the inputs for
+                which attributions should not be computed, this argument
+                can be provided. It must be either a single additional
+                argument of a Tensor or arbitrary (non-tuple) type or a
+                tuple containing multiple additional arguments including
+                tensors or any arbitrary python types. These arguments
+                are provided to forward_func in order following the
+                arguments in inputs.
+                Note that attributions are not computed with respect
+                to these arguments.
+                Default: None
+
+        Returns:
+            *tensor* or tuple of *tensors* of **attributions**:
+            - **attributions** (*tensor* or tuple of *tensors*):
+                The gradients with respect to each input feature.
+                Attributions will always be
+                the same size as the provided inputs, with each value
+                providing the attribution of the corresponding input index.
+                If a single tensor is provided as inputs, a single tensor is
+                returned. If a tuple is provided for inputs, a tuple of
+                corresponding sized tensors is returned.
+
+
+        Examples::
+
+            >>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
+            >>> # and returns an Nx10 tensor of class probabilities.
+            >>> net = ImageClassifier()
+            >>> # Generating random input with size 2x3x32x32
+            >>> input = torch.randn(2, 3, 32, 32, requires_grad=True)
+            >>> # Defining Saliency interpreter
+            >>> saliency = Saliency(net)
+            >>> # Computes saliency maps for class 3.
+            >>> attribution = saliency.attribute(input, target=3)
+        """
+        # Keeps track whether original input is a tuple or not before
+        # converting it into a tuple.
+        is_inputs_tuple = _is_tuple(inputs)
+
+        inputs = _format_tensor_into_tuples(inputs)
+        gradient_mask = apply_gradient_requirements(inputs)
+
+        # No need to format additional_forward_args here.
+ # They are being formated in the `_run_forward` function in `common.py` + gradients = self.gradient_func( + self.forward_func, inputs, target, additional_forward_args + ) + if abs: + attributions = tuple(torch.abs(gradient) for gradient in gradients) + else: + attributions = gradients + undo_gradient_requirements(inputs, gradient_mask) + return _format_output(is_inputs_tuple, attributions) diff --git a/captum/attr/_core/shapley_value.py b/captum/attr/_core/shapley_value.py new file mode 100644 index 0000000000000000000000000000000000000000..a55d58a10a71caf0475a57140e694b1f0d68831a --- /dev/null +++ b/captum/attr/_core/shapley_value.py @@ -0,0 +1,767 @@ +#!/usr/bin/env python3 + +import itertools +import math +import warnings +from typing import Any, Callable, Iterable, Sequence, Tuple, Union +import sys + +import torch +from captum._utils.common import ( + _expand_additional_forward_args, + _expand_target, + _format_additional_forward_args, + _format_output, + _format_tensor_into_tuples, + _is_tuple, + _run_forward, +) +from captum._utils.progress import progress +from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric +from captum.attr._utils.attribution import PerturbationAttribution +from captum.attr._utils.common import ( + _construct_default_feature_mask, + _find_output_mode_and_verify, + _format_input_baseline, + _tensorize_baseline, +) +from captum.log import log_usage +from torch import Tensor + + +def _all_perm_generator(num_features: int, num_samples: int) -> Iterable[Sequence[int]]: + for perm in itertools.permutations(range(num_features)): + yield perm + + +def _perm_generator(num_features: int, num_samples: int) -> Iterable[Sequence[int]]: + for _ in range(num_samples): + yield torch.randperm(num_features).tolist() + + +class ShapleyValueSampling(PerturbationAttribution): + """ + A perturbation based approach to compute attribution, based on the concept + of Shapley Values from cooperative game theory. This method involves taking + a random permutation of the input features and adding them one-by-one to the + given baseline. The output difference after adding each feature corresponds + to its attribution, and these difference are averaged when repeating this + process n_samples times, each time choosing a new random permutation of + the input features. + + By default, each scalar value within + the input tensors are taken as a feature and added independently. Passing + a feature mask, allows grouping features to be added together. This can + be used in cases such as images, where an entire segment or region + can be grouped together, measuring the importance of the segment + (feature group). Each input scalar in the group will be given the same + attribution value equal to the change in output as a result of adding back + the entire feature group. + + More details regarding Shapley Value sampling can be found in these papers: + https://www.sciencedirect.com/science/article/pii/S0305054808000804 + https://pdfs.semanticscholar.org/7715/bb1070691455d1fcfc6346ff458dbca77b2c.pdf + """ + + def __init__(self, forward_func: Callable) -> None: + r""" + Args: + + forward_func (callable): The forward function of the model or + any modification of it. The forward function can either + return a scalar per example, or a single scalar for the + full batch. 
If a single scalar is returned for the batch, + `perturbations_per_eval` must be 1, and the returned + attributions will have first dimension 1, corresponding to + feature importance across all examples in the batch. + """ + PerturbationAttribution.__init__(self, forward_func) + self.permutation_generator = _perm_generator + + @log_usage() + def attribute( + self, + inputs: TensorOrTupleOfTensorsGeneric, + baselines: BaselineType = None, + target: TargetType = None, + additional_forward_args: Any = None, + feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None, + n_samples: int = 25, + perturbations_per_eval: int = 1, + show_progress: bool = False, + ) -> TensorOrTupleOfTensorsGeneric: + r""" + NOTE: The feature_mask argument differs from other perturbation based + methods, since feature indices can overlap across tensors. See the + description of the feature_mask argument below for more details. + + Args: + + inputs (tensor or tuple of tensors): Input for which Shapley value + sampling attributions are computed. If forward_func takes + a single tensor as input, a single input tensor should + be provided. + If forward_func takes multiple tensors as input, a tuple + of the input tensors should be provided. It is assumed + that for all given input tensors, dimension 0 corresponds + to the number of examples (aka batch size), and if + multiple input tensors are provided, the examples must + be aligned appropriately. + baselines (scalar, tensor, tuple of scalars or tensors, optional): + Baselines define reference value which replaces each + feature when ablated. + Baselines can be provided as: + + - a single tensor, if inputs is a single tensor, with + exactly the same dimensions as inputs or the first + dimension is one and the remaining dimensions match + with inputs. + + - a single scalar, if inputs is a single tensor, which will + be broadcasted for each input value in input tensor. + + - a tuple of tensors or scalars, the baseline corresponding + to each tensor in the inputs' tuple can be: + + - either a tensor with matching dimensions to + corresponding tensor in the inputs' tuple + or the first dimension is one and the remaining + dimensions match with the corresponding + input tensor. + + - or a scalar, corresponding to a tensor in the + inputs' tuple. This scalar value is broadcasted + for corresponding input tensor. + In the cases when `baselines` is not provided, we internally + use zero scalar corresponding to each input tensor. + Default: None + target (int, tuple, tensor or list, optional): Output indices for + which difference is computed (for classification cases, + this is usually the target class). + If the network returns a scalar value per example, + no target index is necessary. + For general 2D outputs, targets can be either: + + - a single integer or a tensor containing a single + integer, which is applied to all input examples + + - a list of integers or a 1D tensor, with length matching + the number of examples in inputs (dim 0). Each integer + is applied as the target for the corresponding example. + + For outputs with > 2 dimensions, targets can be either: + + - A single tuple, which contains #output_dims - 1 + elements. This target index is applied to all examples. + + - A list of tuples with length equal to the number of + examples in inputs (dim 0), and each tuple containing + #output_dims - 1 elements. Each tuple is applied as the + target for the corresponding example. 
+ + Default: None + additional_forward_args (any, optional): If the forward function + requires additional arguments other than the inputs for + which attributions should not be computed, this argument + can be provided. It must be either a single additional + argument of a Tensor or arbitrary (non-tuple) type or a + tuple containing multiple additional arguments including + tensors or any arbitrary python types. These arguments + are provided to forward_func in order following the + arguments in inputs. + For a tensor, the first dimension of the tensor must + correspond to the number of examples. For all other types, + the given argument is used for all forward evaluations. + Note that attributions are not computed with respect + to these arguments. + Default: None + feature_mask (tensor or tuple of tensors, optional): + feature_mask defines a mask for the input, grouping + features which should be added together. feature_mask + should contain the same number of tensors as inputs. + Each tensor should + be the same size as the corresponding input or + broadcastable to match the input tensor. Values across + all tensors should be integers in the range 0 to + num_features - 1, and indices corresponding to the same + feature should have the same value. + Note that features are grouped across tensors + (unlike feature ablation and occlusion), so + if the same index is used in different tensors, those + features are still grouped and added simultaneously. + If the forward function returns a single scalar per batch, + we enforce that the first dimension of each mask must be 1, + since attributions are returned batch-wise rather than per + example, so the attributions must correspond to the + same features (indices) in each input example. + If None, then a feature mask is constructed which assigns + each scalar within a tensor as a separate feature + Default: None + n_samples (int, optional): The number of feature permutations + tested. + Default: `25` if `n_samples` is not provided. + perturbations_per_eval (int, optional): Allows multiple ablations + to be processed simultaneously in one call to forward_fn. + Each forward pass will contain a maximum of + perturbations_per_eval * #examples samples. + For DataParallel models, each batch is split among the + available devices, so evaluations on each available + device contain at most + (perturbations_per_eval * #examples) / num_devices + samples. + If the forward function returns a single scalar per batch, + perturbations_per_eval must be set to 1. + Default: 1 + show_progress (bool, optional): Displays the progress of computation. + It will try to use tqdm if available for advanced features + (e.g. time estimation). Otherwise, it will fallback to + a simple output of progress. + Default: False + + Returns: + *tensor* or tuple of *tensors* of **attributions**: + - **attributions** (*tensor* or tuple of *tensors*): + The attributions with respect to each input feature. + If the forward function returns + a scalar value per example, attributions will be + the same size as the provided inputs, with each value + providing the attribution of the corresponding input index. + If the forward function returns a scalar per batch, then + attribution tensor(s) will have first dimension 1 and + the remaining dimensions will match the input. + If a single tensor is provided as inputs, a single tensor is + returned. If a tuple is provided for inputs, a tuple of + corresponding sized tensors is returned. 
+ + + Examples:: + + >>> # SimpleClassifier takes a single input tensor of size Nx4x4, + >>> # and returns an Nx3 tensor of class probabilities. + >>> net = SimpleClassifier() + >>> # Generating random input with size 2 x 4 x 4 + >>> input = torch.randn(2, 4, 4) + >>> # Defining ShapleyValueSampling interpreter + >>> svs = ShapleyValueSampling(net) + >>> # Computes attribution, taking random orderings + >>> # of the 16 features and computing the output change when adding + >>> # each feature. We average over 200 trials (random permutations). + >>> attr = svs.attribute(input, target=1, n_samples=200) + + >>> # Alternatively, we may want to add features in groups, e.g. + >>> # grouping each 2x2 square of the inputs and adding them together. + >>> # This can be done by creating a feature mask as follows, which + >>> # defines the feature groups, e.g.: + >>> # +---+---+---+---+ + >>> # | 0 | 0 | 1 | 1 | + >>> # +---+---+---+---+ + >>> # | 0 | 0 | 1 | 1 | + >>> # +---+---+---+---+ + >>> # | 2 | 2 | 3 | 3 | + >>> # +---+---+---+---+ + >>> # | 2 | 2 | 3 | 3 | + >>> # +---+---+---+---+ + >>> # With this mask, all inputs with the same value are added + >>> # together, and the attribution for each input in the same + >>> # group (0, 1, 2, and 3) per example are the same. + >>> # The attributions can be calculated as follows: + >>> # feature mask has dimensions 1 x 4 x 4 + >>> feature_mask = torch.tensor([[[0,0,1,1],[0,0,1,1], + >>> [2,2,3,3],[2,2,3,3]]]) + >>> attr = svs.attribute(input, target=1, feature_mask=feature_mask) + """ + # Keeps track whether original input is a tuple or not before + # converting it into a tuple. + is_inputs_tuple = _is_tuple(inputs) + inputs, baselines = _format_input_baseline(inputs, baselines) + additional_forward_args = _format_additional_forward_args( + additional_forward_args + ) + feature_mask = ( + _format_tensor_into_tuples(feature_mask) + if feature_mask is not None + else None + ) + assert ( + isinstance(perturbations_per_eval, int) and perturbations_per_eval >= 1 + ), "Ablations per evaluation must be at least 1." 
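+        # High-level flow of the computation below: starting from the
+        # baselines, feature groups are re-introduced one at a time following
+        # a random permutation; the change in the forward output after
+        # re-introducing each group is its marginal contribution for that
+        # permutation, and the contributions are averaged over the sampled
+        # permutations to estimate the Shapley values.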
+ + with torch.no_grad(): + baselines = _tensorize_baseline(inputs, baselines) + num_examples = inputs[0].shape[0] + + if feature_mask is None: + feature_mask, total_features = _construct_default_feature_mask(inputs) + else: + total_features = int( + max(torch.max(single_mask).item() for single_mask in feature_mask) + + 1 + ) + + if show_progress: + attr_progress = progress( + desc=f"{self.get_name()} attribution", + total=self._get_n_evaluations( + total_features, n_samples, perturbations_per_eval + ) + + 1, # add 1 for the initial eval + ) + attr_progress.update(0) + + initial_eval = _run_forward( + self.forward_func, baselines, target, additional_forward_args + ) + + if show_progress: + attr_progress.update() + + agg_output_mode = _find_output_mode_and_verify( + initial_eval, num_examples, perturbations_per_eval, feature_mask + ) + # print("agg_output_mode: ", agg_output_mode) # Single boolean False + + # Initialize attribution totals and counts + total_attrib = [ + torch.zeros_like( + input[0:1] if agg_output_mode else input, dtype=torch.float + ) + for input in inputs + ] + # print("total_features: ", total_features) # Total unique instance segmentations + # print("total_attrib len: ", len(total_attrib)) 1 + # print("total_attrib shape: ", total_attrib[0].shape) # Same as input (1,1,224,224) + # print("total_attrib min: ", total_attrib[0].min()) 0 + # print("total_attrib max: ", total_attrib[0].max()) 0 + + iter_count = 0 + # Iterate for number of samples, generate a permutation of the features + # and evalute the incremental increase for each feature. + for feature_permutation in self.permutation_generator( + total_features, n_samples + ): + iter_count += 1 + prev_results = initial_eval + for ( + current_inputs, + current_add_args, + current_target, + current_masks, + ) in self._perturbation_generator( + inputs, + additional_forward_args, + target, + baselines, + feature_mask, + feature_permutation, + perturbations_per_eval, + ): + if sum(torch.sum(mask).item() for mask in current_masks) == 0: + warnings.warn( + "Feature mask is missing some integers between 0 and " + "num_features, for optimal performance, make sure each" + " consecutive integer corresponds to a feature." + ) + # modified_eval dimensions: 1D tensor with length + # equal to #num_examples * #features in batch + modified_eval = _run_forward( + self.forward_func, + current_inputs, + current_target, + current_add_args, + ) + if show_progress: + attr_progress.update() + # print("current_masks len: ", len(current_masks)) 1 + # print("current_masks[0] shape: ", current_masks[0].shape) # 1, 1, 1, 224, 224 + # print("current_masks unique: ", torch.unique(current_masks[0])) tensor([False, True] + # print("modified_eval shape: ", modified_eval.shape) # 1-dim (1) + # print("modified_eval: ", modified_eval) tensor([0.2161] + # print("num_examples: ", num_examples) 1 + # sys.exit() + + if agg_output_mode: + eval_diff = modified_eval - prev_results + prev_results = modified_eval + else: + all_eval = torch.cat((prev_results, modified_eval), dim=0) + # print("all_eval shape: ", all_eval.shape) 2 + eval_diff = all_eval[num_examples:] - all_eval[:-num_examples] + # print("all_eval: ", all_eval) + # print("eval_diff: ", eval_diff) # if 1-dim, modified_eval - prev_results (minus) + prev_results = all_eval[-num_examples:] + for j in range(len(total_attrib)): + current_eval_diff = eval_diff + if not agg_output_mode: + # current_eval_diff dimensions: + # (#features in batch, #num_examples, 1,.. 
1) + # (contains 1 more dimension than inputs). This adds extra + # dimensions of 1 to make the tensor broadcastable with the + # inputs tensor. + current_eval_diff = current_eval_diff.reshape( + (-1, num_examples) + (len(inputs[j].shape) - 1) * (1,) + ) + total_attrib[j] += ( + current_eval_diff * current_masks[j].float() + ).sum(dim=0) # Sum of all masks(0,1) X eval diff + + if show_progress: + attr_progress.close() + + # Divide total attributions by number of random permutations and return + # formatted attributions. + attrib = tuple( + tensor_attrib_total / iter_count for tensor_attrib_total in total_attrib + ) + formatted_attr = _format_output(is_inputs_tuple, attrib) + return formatted_attr + + def _perturbation_generator( + self, + inputs: Tuple[Tensor, ...], + additional_args: Any, + target: TargetType, + baselines: Tuple[Tensor, ...], + input_masks: TensorOrTupleOfTensorsGeneric, + feature_permutation: Sequence[int], + perturbations_per_eval: int, + ) -> Iterable[Tuple[Tuple[Tensor, ...], Any, TargetType, Tuple[Tensor, ...]]]: + """ + This method is a generator which yields each perturbation to be evaluated + including inputs, additional_forward_args, targets, and mask. + """ + # current_tensors starts at baselines and includes each additional feature as + # added based on the permutation order. + current_tensors = baselines + current_tensors_list = [] + current_mask_list = [] + + # Compute repeated additional args and targets + additional_args_repeated = ( + _expand_additional_forward_args(additional_args, perturbations_per_eval) + if additional_args is not None + else None + ) + target_repeated = _expand_target(target, perturbations_per_eval) + for i in range(len(feature_permutation)): + current_tensors = tuple( + current * (~(mask == feature_permutation[i])).to(current.dtype) + + input * (mask == feature_permutation[i]).to(input.dtype) + for input, current, mask in zip(inputs, current_tensors, input_masks) + ) + current_tensors_list.append(current_tensors) + current_mask_list.append( + tuple(mask == feature_permutation[i] for mask in input_masks) + ) + if len(current_tensors_list) == perturbations_per_eval: + combined_inputs = tuple( + torch.cat(aligned_tensors, dim=0) + for aligned_tensors in zip(*current_tensors_list) + ) + combined_masks = tuple( + torch.stack(aligned_masks, dim=0) + for aligned_masks in zip(*current_mask_list) + ) + yield ( + combined_inputs, + additional_args_repeated, + target_repeated, + combined_masks, + ) + current_tensors_list = [] + current_mask_list = [] + + # Create batch with remaining evaluations, may not be a complete batch + # (= perturbations_per_eval) + if len(current_tensors_list) != 0: + additional_args_repeated = ( + _expand_additional_forward_args( + additional_args, len(current_tensors_list) + ) + if additional_args is not None + else None + ) + target_repeated = _expand_target(target, len(current_tensors_list)) + combined_inputs = tuple( + torch.cat(aligned_tensors, dim=0) + for aligned_tensors in zip(*current_tensors_list) + ) + combined_masks = tuple( + torch.stack(aligned_masks, dim=0) + for aligned_masks in zip(*current_mask_list) + ) + yield ( + combined_inputs, + additional_args_repeated, + target_repeated, + combined_masks, + ) + + def _get_n_evaluations(self, total_features, n_samples, perturbations_per_eval): + """return the total number of forward evaluations needed""" + return math.ceil(total_features / perturbations_per_eval) * n_samples + + +class ShapleyValues(ShapleyValueSampling): + """ + A perturbation based approach to 
compute attribution, based on the concept + of Shapley Values from cooperative game theory. This method involves taking + each permutation of the input features and adding them one-by-one to the + given baseline. The output difference after adding each feature corresponds + to its attribution, and these difference are averaged over all possible + random permutations of the input features. + + By default, each scalar value within + the input tensors are taken as a feature and added independently. Passing + a feature mask, allows grouping features to be added together. This can + be used in cases such as images, where an entire segment or region + can be grouped together, measuring the importance of the segment + (feature group). Each input scalar in the group will be given the same + attribution value equal to the change in output as a result of adding back + the entire feature group. + + More details regarding Shapley Values can be found in these papers: + https://apps.dtic.mil/dtic/tr/fulltext/u2/604084.pdf + https://www.sciencedirect.com/science/article/pii/S0305054808000804 + https://pdfs.semanticscholar.org/7715/bb1070691455d1fcfc6346ff458dbca77b2c.pdf + + NOTE: The method implemented here is very computationally intensive, and + should only be used with a very small number of features (e.g. < 7). + This implementation simply extends ShapleyValueSampling and + evaluates all permutations, leading to a total of n * n! evaluations for n + features. Shapley values can alternatively be computed with only 2^n + evaluations, and we plan to add this approach in the future. + """ + + def __init__(self, forward_func: Callable) -> None: + r""" + Args: + + forward_func (callable): The forward function of the model or + any modification of it. The forward function can either + return a scalar per example, or a single scalar for the + full batch. If a single scalar is returned for the batch, + `perturbations_per_eval` must be 1, and the returned + attributions will have first dimension 1, corresponding to + feature importance across all examples in the batch. + """ + ShapleyValueSampling.__init__(self, forward_func) + self.permutation_generator = _all_perm_generator + + @log_usage() + def attribute( + self, + inputs: TensorOrTupleOfTensorsGeneric, + baselines: BaselineType = None, + target: TargetType = None, + additional_forward_args: Any = None, + feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None, + perturbations_per_eval: int = 1, + show_progress: bool = False, + ) -> TensorOrTupleOfTensorsGeneric: + r""" + NOTE: The feature_mask argument differs from other perturbation based + methods, since feature indices can overlap across tensors. See the + description of the feature_mask argument below for more details. + + Args: + + inputs (tensor or tuple of tensors): Input for which Shapley value + sampling attributions are computed. If forward_func takes + a single tensor as input, a single input tensor should + be provided. + If forward_func takes multiple tensors as input, a tuple + of the input tensors should be provided. It is assumed + that for all given input tensors, dimension 0 corresponds + to the number of examples (aka batch size), and if + multiple input tensors are provided, the examples must + be aligned appropriately. + baselines (scalar, tensor, tuple of scalars or tensors, optional): + Baselines define reference value which replaces each + feature when ablated. 
+ Baselines can be provided as: + + - a single tensor, if inputs is a single tensor, with + exactly the same dimensions as inputs or the first + dimension is one and the remaining dimensions match + with inputs. + + - a single scalar, if inputs is a single tensor, which will + be broadcasted for each input value in input tensor. + + - a tuple of tensors or scalars, the baseline corresponding + to each tensor in the inputs' tuple can be: + + - either a tensor with matching dimensions to + corresponding tensor in the inputs' tuple + or the first dimension is one and the remaining + dimensions match with the corresponding + input tensor. + + - or a scalar, corresponding to a tensor in the + inputs' tuple. This scalar value is broadcasted + for corresponding input tensor. + In the cases when `baselines` is not provided, we internally + use zero scalar corresponding to each input tensor. + Default: None + target (int, tuple, tensor or list, optional): Output indices for + which difference is computed (for classification cases, + this is usually the target class). + If the network returns a scalar value per example, + no target index is necessary. + For general 2D outputs, targets can be either: + + - a single integer or a tensor containing a single + integer, which is applied to all input examples + + - a list of integers or a 1D tensor, with length matching + the number of examples in inputs (dim 0). Each integer + is applied as the target for the corresponding example. + + For outputs with > 2 dimensions, targets can be either: + + - A single tuple, which contains #output_dims - 1 + elements. This target index is applied to all examples. + + - A list of tuples with length equal to the number of + examples in inputs (dim 0), and each tuple containing + #output_dims - 1 elements. Each tuple is applied as the + target for the corresponding example. + + Default: None + additional_forward_args (any, optional): If the forward function + requires additional arguments other than the inputs for + which attributions should not be computed, this argument + can be provided. It must be either a single additional + argument of a Tensor or arbitrary (non-tuple) type or a + tuple containing multiple additional arguments including + tensors or any arbitrary python types. These arguments + are provided to forward_func in order following the + arguments in inputs. + For a tensor, the first dimension of the tensor must + correspond to the number of examples. For all other types, + the given argument is used for all forward evaluations. + Note that attributions are not computed with respect + to these arguments. + Default: None + feature_mask (tensor or tuple of tensors, optional): + feature_mask defines a mask for the input, grouping + features which should be added together. feature_mask + should contain the same number of tensors as inputs. + Each tensor should + be the same size as the corresponding input or + broadcastable to match the input tensor. Values across + all tensors should be integers in the range 0 to + num_features - 1, and indices corresponding to the same + feature should have the same value. + Note that features are grouped across tensors + (unlike feature ablation and occlusion), so + if the same index is used in different tensors, those + features are still grouped and added simultaneously. 
+ If the forward function returns a single scalar per batch, + we enforce that the first dimension of each mask must be 1, + since attributions are returned batch-wise rather than per + example, so the attributions must correspond to the + same features (indices) in each input example. + If None, then a feature mask is constructed which assigns + each scalar within a tensor as a separate feature + Default: None + perturbations_per_eval (int, optional): Allows multiple ablations + to be processed simultaneously in one call to forward_fn. + Each forward pass will contain a maximum of + perturbations_per_eval * #examples samples. + For DataParallel models, each batch is split among the + available devices, so evaluations on each available + device contain at most + (perturbations_per_eval * #examples) / num_devices + samples. + If the forward function returns a single scalar per batch, + perturbations_per_eval must be set to 1. + Default: 1 + show_progress (bool, optional): Displays the progress of computation. + It will try to use tqdm if available for advanced features + (e.g. time estimation). Otherwise, it will fallback to + a simple output of progress. + Default: False + Returns: + *tensor* or tuple of *tensors* of **attributions**: + - **attributions** (*tensor* or tuple of *tensors*): + The attributions with respect to each input feature. + If the forward function returns + a scalar value per example, attributions will be + the same size as the provided inputs, with each value + providing the attribution of the corresponding input index. + If the forward function returns a scalar per batch, then + attribution tensor(s) will have first dimension 1 and + the remaining dimensions will match the input. + If a single tensor is provided as inputs, a single tensor is + returned. If a tuple is provided for inputs, a tuple of + corresponding sized tensors is returned. + + + Examples:: + + >>> # SimpleClassifier takes a single input tensor of size Nx4x4, + >>> # and returns an Nx3 tensor of class probabilities. + >>> net = SimpleClassifier() + >>> # Generating random input with size 2 x 4 x 4 + >>> input = torch.randn(2, 4, 4) + + >>> # We may want to add features in groups, e.g. + >>> # grouping each 2x2 square of the inputs and adding them together. + >>> # This can be done by creating a feature mask as follows, which + >>> # defines the feature groups, e.g.: + >>> # +---+---+---+---+ + >>> # | 0 | 0 | 1 | 1 | + >>> # +---+---+---+---+ + >>> # | 0 | 0 | 1 | 1 | + >>> # +---+---+---+---+ + >>> # | 2 | 2 | 3 | 3 | + >>> # +---+---+---+---+ + >>> # | 2 | 2 | 3 | 3 | + >>> # +---+---+---+---+ + >>> # With this mask, all inputs with the same value are added + >>> # together, and the attribution for each input in the same + >>> # group (0, 1, 2, and 3) per example are the same. + >>> # The attributions can be calculated as follows: + >>> # feature mask has dimensions 1 x 4 x 4 + >>> feature_mask = torch.tensor([[[0,0,1,1],[0,0,1,1], + >>> [2,2,3,3],[2,2,3,3]]]) + + >>> # With only 4 features, it is feasible to compute exact + >>> # Shapley Values. 
These can be computed as follows: + >>> sv = ShapleyValues(net) + >>> attr = sv.attribute(input, target=1, feature_mask=feature_mask) + """ + if feature_mask is None: + total_features = sum( + torch.numel(inp[0]) for inp in _format_tensor_into_tuples(inputs) + ) + else: + total_features = ( + int(max(torch.max(single_mask).item() for single_mask in feature_mask)) + + 1 + ) + + if total_features >= 10: + warnings.warn( + "You are attempting to compute Shapley Values with at least 10 " + "features, which will likely be very computationally expensive." + "Consider using Shapley Value Sampling instead." + ) + + return super().attribute.__wrapped__( + self, + inputs=inputs, + baselines=baselines, + target=target, + additional_forward_args=additional_forward_args, + feature_mask=feature_mask, + perturbations_per_eval=perturbations_per_eval, + show_progress=show_progress, + ) + + def _get_n_evaluations(self, total_features, n_samples, perturbations_per_eval): + """return the total number of forward evaluations needed""" + return math.ceil(total_features / perturbations_per_eval) * math.factorial( + total_features + ) diff --git a/captum/attr/_models/__init__.py b/captum/attr/_models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/captum/attr/_models/base.py b/captum/attr/_models/base.py new file mode 100644 index 0000000000000000000000000000000000000000..d57646c0da799a247f42931f127834c1b38882c7 --- /dev/null +++ b/captum/attr/_models/base.py @@ -0,0 +1,245 @@ +#!/usr/bin/env python3 + +import warnings +from functools import reduce + +import torch +from torch.nn import Module + + +class InterpretableEmbeddingBase(Module): + r""" + Since some embedding vectors, e.g. word are created and assigned in + the embedding layers of Pytorch models we need a way to access + those layers, generate the embeddings and subtract the baseline. + To do so, we separate embedding layers from the model, compute the + embeddings separately and do all operations needed outside of the model. + The original embedding layer is being replaced by + `InterpretableEmbeddingBase` layer which passes already + precomputed embedding vectors to the layers below. + """ + + def __init__(self, embedding, full_name) -> None: + Module.__init__(self) + self.num_embeddings = getattr(embedding, "num_embeddings", None) + self.embedding_dim = getattr(embedding, "embedding_dim", None) + + self.embedding = embedding + self.full_name = full_name + + def forward(self, *inputs, **kwargs): + r""" + The forward function of a wrapper embedding layer that takes and returns + embedding layer. It allows embeddings to be created outside of the model + and passes them seamlessly to the preceding layers of the model. + + Args: + + *inputs (Any, optional): A sequence of inputs arguments that the + forward function takes. Since forward functions can take any + type and number of arguments, this will ensure that we can + execute the forward pass using interpretable embedding layer. + Note that if inputs are specified, it is assumed that the first + argument is the embedding tensor generated using the + `self.embedding` layer using all input arguments provided in + `inputs` and `kwargs`. + **kwargs (Any, optional): Similar to `inputs` we want to make sure + that our forward pass supports arbitrary number and type of + key-value arguments. 
If `inputs` is not provided, `kwargs` must + be provided and the first argument corresponds to the embedding + tensor generated using the `self.embedding`. Note that we make + here an assumption here that `kwargs` is an ordered dict which + is new in python 3.6 and is not guaranteed that it will + consistently remain that way in the newer versions. In case + current implementation doesn't work for special use cases, + it is encouraged to override `InterpretableEmbeddingBase` and + address those specifics in descendant classes. + + Returns: + + embedding_tensor (Tensor): + Returns a tensor which is the same as first argument passed + to the forward function. + It passes pre-computed embedding tensors to lower layers + without any modifications. + """ + assert len(inputs) > 0 or len(kwargs) > 0, ( + "No input arguments are provided to `InterpretableEmbeddingBase`." + "Input embedding tensor has to be provided as first argument to forward " + "function either through inputs argument or kwargs." + ) + return inputs[0] if len(inputs) > 0 else list(kwargs.values())[0] + + def indices_to_embeddings(self, *input, **kwargs): + r""" + Maps indices to corresponding embedding vectors. E.g. word embeddings + + Args: + + *input (Any, Optional): This can be a tensor(s) of input indices or any + other variable necessary to comput the embeddings. A typical + example of input indices are word or token indices. + **kwargs (Any, optional): Similar to `input` this can be any sequence + of key-value arguments necessary to compute final embedding + tensor. + Returns: + + tensor: + A tensor of word embeddings corresponding to the + indices specified in the input + """ + return self.embedding(*input, **kwargs) + + +class TokenReferenceBase: + r""" + A base class for creating reference (aka baseline) tensor for a sequence of + tokens. A typical example of such token is `PAD`. Users need to provide the + index of the reference token in the vocabulary as an argument to + `TokenReferenceBase` class. + """ + + def __init__(self, reference_token_idx=0) -> None: + self.reference_token_idx = reference_token_idx + + def generate_reference(self, sequence_length, device): + r""" + Generated reference tensor of given `sequence_length` using + `reference_token_idx`. + + Args: + sequence_length (int): The length of the reference sequence + device (torch.device): The device on which the reference tensor will + be created. + Returns: + + tensor: + A sequence of reference token with shape: + [sequence_length] + """ + return torch.tensor([self.reference_token_idx] * sequence_length, device=device) + + +def _get_deep_layer_name(obj, layer_names): + r""" + Traverses through the layer names that are separated by + dot in order to access the embedding layer. + """ + return reduce(getattr, layer_names.split("."), obj) + + +def _set_deep_layer_value(obj, layer_names, value): + r""" + Traverses through the layer names that are separated by + dot in order to access the embedding layer and update its value. + """ + layer_names = layer_names.split(".") + setattr(reduce(getattr, layer_names[:-1], obj), layer_names[-1], value) + + +def configure_interpretable_embedding_layer(model, embedding_layer_name="embedding"): + r""" + This method wraps model's embedding layer with an interpretable embedding + layer that allows us to access the embeddings through their indices. + + Args: + + model (torch.nn.Model): An instance of PyTorch model that contains embeddings. 
+ embedding_layer_name (str, optional): The name of the embedding layer + in the `model` that we would like to make interpretable. + + Returns: + + interpretable_emb (tensor): An instance of `InterpretableEmbeddingBase` + embedding layer that wraps model's embedding layer that is being + accessed through `embedding_layer_name`. + + Examples:: + + >>> # Let's assume that we have a DocumentClassifier model that + >>> # has a word embedding layer named 'embedding'. + >>> # To make that layer interpretable we need to execute the + >>> # following command: + >>> net = DocumentClassifier() + >>> interpretable_emb = configure_interpretable_embedding_layer(net, + >>> 'embedding') + >>> # then we can use interpretable embedding to convert our + >>> # word indices into embeddings. + >>> # Let's assume that we have the following word indices + >>> input_indices = torch.tensor([1, 0, 2]) + >>> # we can access word embeddings for those indices with the command + >>> # line stated below. + >>> input_emb = interpretable_emb.indices_to_embeddings(input_indices) + >>> # Let's assume that we want to apply integrated gradients to + >>> # our model and that target attribution class is 3 + >>> ig = IntegratedGradients(net) + >>> attribution = ig.attribute(input_emb, target=3) + >>> # after we finish the interpretation we need to remove + >>> # interpretable embedding layer with the following command: + >>> remove_interpretable_embedding_layer(net, interpretable_emb) + + """ + embedding_layer = _get_deep_layer_name(model, embedding_layer_name) + assert ( + embedding_layer.__class__ is not InterpretableEmbeddingBase + ), "InterpretableEmbeddingBase has already been configured for layer {}".format( + embedding_layer_name + ) + warnings.warn( + "In order to make embedding layers more interpretable they will " + "be replaced with an interpretable embedding layer which wraps the " + "original embedding layer and takes word embedding vectors as inputs of " + "the forward function. This allows us to generate baselines for word " + "embeddings and compute attributions for each embedding dimension. " + "The original embedding layer must be set " + "back by calling `remove_interpretable_embedding_layer` function " + "after model interpretation is finished. " + ) + interpretable_emb = InterpretableEmbeddingBase( + embedding_layer, embedding_layer_name + ) + _set_deep_layer_value(model, embedding_layer_name, interpretable_emb) + return interpretable_emb + + +def remove_interpretable_embedding_layer(model, interpretable_emb): + r""" + Removes interpretable embedding layer and sets back original + embedding layer in the model. + + Args: + + model (torch.nn.Module): An instance of PyTorch model that contains embeddings + interpretable_emb (tensor): An instance of `InterpretableEmbeddingBase` + that was originally created in + `configure_interpretable_embedding_layer` function and has + to be removed after interpretation is finished. + + Examples:: + + >>> # Let's assume that we have a DocumentClassifier model that + >>> # has a word embedding layer named 'embedding'. + >>> # To make that layer interpretable we need to execute the + >>> # following command: + >>> net = DocumentClassifier() + >>> interpretable_emb = configure_interpretable_embedding_layer(net, + >>> 'embedding') + >>> # then we can use interpretable embedding to convert our + >>> # word indices into embeddings. 
+ >>> # Let's assume that we have the following word indices + >>> input_indices = torch.tensor([1, 0, 2]) + >>> # we can access word embeddings for those indices with the command + >>> # line stated below. + >>> input_emb = interpretable_emb.indices_to_embeddings(input_indices) + >>> # Let's assume that we want to apply integrated gradients to + >>> # our model and that target attribution class is 3 + >>> ig = IntegratedGradients(net) + >>> attribution = ig.attribute(input_emb, target=3) + >>> # after we finish the interpretation we need to remove + >>> # interpretable embedding layer with the following command: + >>> remove_interpretable_embedding_layer(net, interpretable_emb) + + """ + _set_deep_layer_value( + model, interpretable_emb.full_name, interpretable_emb.embedding + ) diff --git a/captum/attr/_models/pytext.py b/captum/attr/_models/pytext.py new file mode 100644 index 0000000000000000000000000000000000000000..f5e6af3a0462286b36db520599226fa29250e113 --- /dev/null +++ b/captum/attr/_models/pytext.py @@ -0,0 +1,282 @@ +#!/usr/bin/env python3 +from collections import defaultdict + +import torch +from pytext.models.embeddings.dict_embedding import DictEmbedding +from pytext.models.embeddings.word_embedding import WordEmbedding +from pytext.models.model import EmbeddingBase, EmbeddingList + + +class PyTextInterpretableEmbedding(EmbeddingBase): + r""" + In PyText DocNN models we need a way to access word embedding layers, + generate the embeddings and subtract the baseline. + To do so, we separate embedding layers from the model, compute the embeddings + separately and do all operations needed outside of the model. + The original embedding layer is being replaced by `PyTextInterpretableEmbedding` + layer which passes precomputed embedding vectors to lower layers. + """ + + def __init__(self, embeddings) -> None: + self.embedding_dims = [embedding.embedding_dim for embedding in embeddings] + super().__init__(sum(self.embedding_dims)) + self.embeddings = embeddings + + def forward(self, input): + r""" + The forward pass of embedding layer. This can be for the text or any + type of embedding. + + Args + + input: Input embeddings tensor + + Return + + output: Output tensor is the same as input. It passes through + the embedding tensors to lower layers without any + modifications + """ + return input + + def get_attribution_map(self, attributions): + r""" + After attribution scores are computed for an input embedding vector + we need to split it up into attribution sub tensors for each + feature type: word, dict and other types + + TODO: we can potentally also output tuples of attributions. This might be + a better option. We'll work on this in a separate diff. + + Args + + attributions: A tensor that contains attribution values for each input + field. 
It usually has the same dimensions as the input + tensor + + Return + + attribution_map: A dictionary of feature_type and attribution values + + """ + begin = 0 + attribution_map = defaultdict() + for embedding, embedding_size in zip(self.embeddings, self.embedding_dims): + end = begin + embedding_size + if isinstance(embedding, WordEmbedding): + attribution_map["word"] = attributions[:, :, begin:end] + elif isinstance(embedding, DictEmbedding): + attribution_map["dict"] = attributions[:, :, begin:end] + else: + raise NotImplementedError( + "Currently only word and dict " "embeddings are supported" + ) + begin = end + + return attribution_map + + +class BaselineGenerator: + r""" + This is an example input baseline generator for DocNN model which uses + word and dict features. + """ + PAD = "" + + def __init__(self, model, data_handler, device) -> None: + self.model = model + self.data_handler = data_handler + if "dict_feat" in data_handler.features: + self.vocab_dict = data_handler.features["dict_feat"].vocab + if "word_feat" in data_handler.features: + self.vocab_word = data_handler.features["word_feat"].vocab + + self.baseline_single_word_feature = self._generate_baseline_single_word_feature( + device + ) + self.baseline_single_dict_feature = self._generate_baseline_single_dict_feature( + device + ) + + def generate_baseline(self, integ_grads_embeddings, seq_length): + r""" + Generates baseline for input word and dict features. In the future we + will extend it to support char and other features as well. + This baseline is entirely based on the `` token. + + Args + + integ_grads_embeddings: A reference to integrated gradients embedding + layer + seq_length: The length of each sequence which depends on batch size + + Return + baseline: A tuple of feature baselines + Each feature type has a corresponding baseline tensor + in the tuple. + Currently only Dict and Word feature types are supported + """ + baseline = [] + for embedding in integ_grads_embeddings.embeddings: + if isinstance(embedding, WordEmbedding): + baseline.append(self._generate_word_baseline(seq_length)) + elif isinstance(embedding, DictEmbedding): + baseline.append(self._generate_dict_baseline(seq_length)) + else: + raise NotImplementedError( + "Currently only word and dict " "embeddings are supported" + ) + return tuple(baseline) + + def _generate_baseline_single_word_feature(self, device): + return ( + torch.tensor( + [self.vocab_word.stoi[self.PAD] if hasattr(self, "vocab_word") else 0] + ) + .unsqueeze(0) + .to(device) + ) + + def _generate_baseline_single_dict_feature(self, device): + r"""Generate dict features based on Assistant's case study by using + sia_transformer: + fbcode/assistant/sia/transformer/sia_transformer.py + sia_transformer generates dict features in a special gazetter format + See `fbsource/fbcode/pytext/models/embeddings/dict_embedding.py` + + It generates word dict feature embeddings for each word token. 
+ + The output of SIATransformer after running it on `` token + looks as following: + OutputRecord(tokens=['<', 'pad', '>'], + token_ranges=[(0, 1), (1, 4), (4, 5)], + gazetteer_feats=['', '', ''], + gazetteer_feat_lengths=[1, 1, 1], + gazetteer_feat_weights=[0.0, 0.0, 0.0], + characters=[['<', '', ''], + ['p', 'a', 'd'], ['>', '', '']], + pretrained_token_embedding=[ ], dense_feats=None) + """ + gazetteer_feats = [self.PAD, self.PAD, self.PAD] + gazetteer_feat_lengths = [1, 1, 1] + gazetteer_feat_weights = [0.0, 0.0, 0.0] + gazetteer_feat_id = ( + torch.tensor( + [ + self.vocab_dict.stoi[gazetteer_feat] + if hasattr(self, "vocab_dict") + else 0 + for gazetteer_feat in gazetteer_feats + ] + ) + .unsqueeze(0) + .to(device) + ) + gazetteer_feat_weights = ( + torch.tensor(gazetteer_feat_weights).unsqueeze(0).to(device) + ) + gazetteer_feat_lengths = ( + torch.tensor(gazetteer_feat_lengths).to(device).view(1, -1)[:, 1] + ) + + return (gazetteer_feat_id, gazetteer_feat_weights, gazetteer_feat_lengths) + + def _generate_word_baseline(self, seq_length): + return self.baseline_single_word_feature.repeat(1, seq_length) + + def _generate_dict_baseline(self, seq_length): + return ( + self.baseline_single_dict_feature[0].repeat(1, seq_length), + self.baseline_single_dict_feature[1].repeat(1, seq_length), + self.baseline_single_dict_feature[2].repeat(1, seq_length), + ) + + +def configure_task_integ_grads_embeddings(task): + r""" + Wraps Pytext's DocNN model embedding with `IntegratedGradientsEmbedding` for + a given input task. + IntegratedGradientsEmbedding allows to perform baseline related operations + + Args + + task: DocNN task reference + + Returns + + integrated_gradients_embedding_lst: The embedding layer which contains + IntegratedGradientsEmbedding as a wrapper over the original + embeddings of the model + + """ + integrated_gradients_embedding_lst = configure_model_integ_grads_embeddings( + task.model + ) + task.model.embedding = integrated_gradients_embedding_lst + return integrated_gradients_embedding_lst[0] + + +def configure_model_integ_grads_embeddings(model): + r""" + Wraps Pytext's DocNN model embedding with `IntegratedGradientsEmbedding` + IntegratedGradientsEmbedding allows to perform baseline related operations + + Args + + model: a reference to DocModel + + Returns + + integrated_gradients_embedding_lst: The embedding layer which contains + IntegratedGradientsEmbedding as a wrapper over the original + embeddings of the model + + """ + embeddings = model.embedding + integrated_gradients_embedding = PyTextInterpretableEmbedding(embeddings) + return EmbeddingList([integrated_gradients_embedding], False) + + +def reshape_word_features(word_features): + r""" + Creates one-sample batch for word features for sanity check purposes + + Args + + word_features: A tensor of diemnsions #words x #embeddings + + Return + + word_features: A tensor of dimensions 1 x #words x #embeddings + + """ + return word_features.unsqueeze(0) + + +def reshape_dict_features( + dict_feature_id_batch, dict_weight_batch, dict_seq_len_batch, seq_length, idx +): + r""" + Creates one-sample batch for dict features for sanity check purposes + It reads and reshapes id, weight and seq_length feature arrays for given + input index `idx` from the input batch + + Args + + dict_feature_id_batch: The batch tensor for ids + dict_weight_matrix: The batch tensor for weights + dict_seq_len_matrix: The batch tensor for sequence length + seq_length: The number of tokens per sequence + idx: The index of sample in the batch + + 
Return + + dict_feature_ids: A tensor of dimensions [ bsz x # dict feature embeddings] + dict_feature_weights: [ bsz x # dict feature embeddings] + dict_feature_lens: [ bsz * seq_length ] + + """ + dict_feature_ids = dict_feature_id_batch[idx].unsqueeze(0) + dict_feature_weights = dict_weight_batch[idx].unsqueeze(0) + dict_feature_lens = dict_seq_len_batch[idx].unsqueeze(0) + return (dict_feature_ids, dict_feature_weights, dict_feature_lens) diff --git a/captum/attr/_utils/__init__.py b/captum/attr/_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/captum/attr/_utils/approximation_methods.py b/captum/attr/_utils/approximation_methods.py new file mode 100644 index 0000000000000000000000000000000000000000..9d63e90c1acf7a17899cd252130078e8380728e3 --- /dev/null +++ b/captum/attr/_utils/approximation_methods.py @@ -0,0 +1,138 @@ +#!/usr/bin/env python3 +from enum import Enum +from typing import Callable, List, Tuple + +import torch + + +class Riemann(Enum): + left = 1 + right = 2 + middle = 3 + trapezoid = 4 + + +SUPPORTED_RIEMANN_METHODS = [ + "riemann_left", + "riemann_right", + "riemann_middle", + "riemann_trapezoid", +] + +SUPPORTED_METHODS = SUPPORTED_RIEMANN_METHODS + ["gausslegendre"] + + +def approximation_parameters( + method: str, +) -> Tuple[Callable[[int], List[float]], Callable[[int], List[float]]]: + r"""Retrieves parameters for the input approximation `method` + + Args: + method: The name of the approximation method. Currently only `riemann` + and gauss legendre are + """ + if method in SUPPORTED_RIEMANN_METHODS: + return riemann_builders(method=Riemann[method.split("_")[-1]]) + if method == "gausslegendre": + return gauss_legendre_builders() + raise ValueError("Invalid integral approximation method name: {}".format(method)) + + +def riemann_builders( + method: Riemann = Riemann.trapezoid, +) -> Tuple[Callable[[int], List[float]], Callable[[int], List[float]]]: + r"""Step sizes are identical and alphas are scaled in [0, 1] + + Args: + + n: The number of integration steps + method: `left`, `right`, `middle` and `trapezoid` riemann + + Returns: + 2-element tuple of **step_sizes**, **alphas**: + - **step_sizes** (*callable*): + `step_sizes` takes the number of steps as an + input argument and returns an array of steps sizes which + sum is smaller than or equal to one. + + - **alphas** (*callable*): + `alphas` takes the number of steps as an input argument + and returns the multipliers/coefficients for the inputs + of integrand in the range of [0, 1] + + """ + + def step_sizes(n: int) -> List[float]: + assert n > 1, "The number of steps has to be larger than one" + deltas = [1 / n] * n + if method == Riemann.trapezoid: + deltas[0] /= 2 + deltas[-1] /= 2 + return deltas + + def alphas(n: int) -> List[float]: + assert n > 1, "The number of steps has to be larger than one" + if method == Riemann.trapezoid: + return torch.linspace(0, 1, n).tolist() + elif method == Riemann.left: + return torch.linspace(0, 1 - 1 / n, n).tolist() + elif method == Riemann.middle: + return torch.linspace(1 / (2 * n), 1 - 1 / (2 * n), n).tolist() + elif method == Riemann.right: + return torch.linspace(1 / n, 1, n).tolist() + else: + raise AssertionError("Provided Reimann approximation method is not valid.") + # This is not a standard riemann method but in many cases it + # leades to faster approaximation. 
Test cases for small number of steps + # do not make sense but for larger number of steps the approximation is + # better therefore leaving this option available + # if method == 'riemann_include_endpoints': + # return [i / (n - 1) for i in range(n)] + + return step_sizes, alphas + + +def gauss_legendre_builders() -> Tuple[ + Callable[[int], List[float]], Callable[[int], List[float]] +]: + r"""Numpy's `np.polynomial.legendre` function helps to compute step sizes + and alpha coefficients using gauss-legendre quadrature rule. + Since numpy returns the integration parameters in different scales we need to + rescale them to adjust to the desired scale. + + Gauss Legendre quadrature rule for approximating the integrals was originally + proposed by [Xue Feng and her intern Hauroun Habeeb] + (https://research.fb.com/people/feng-xue/). + + Args: + + n (int): The number of integration steps + + Returns: + 2-element tuple of **step_sizes**, **alphas**: + - **step_sizes** (*callable*): + `step_sizes` takes the number of steps as an + input argument and returns an array of steps sizes which + sum is smaller than or equal to one. + + - **alphas** (*callable*): + `alphas` takes the number of steps as an input argument + and returns the multipliers/coefficients for the inputs + of integrand in the range of [0, 1] + + """ + + # allow using riemann even without np + import numpy as np + + def step_sizes(n: int) -> List[float]: + assert n > 0, "The number of steps has to be larger than zero" + # Scaling from 2 to 1 + return list(0.5 * np.polynomial.legendre.leggauss(n)[1]) + + def alphas(n: int) -> List[float]: + assert n > 0, "The number of steps has to be larger than zero" + # Scaling from [-1, 1] to [0, 1] + return list(0.5 * (1 + np.polynomial.legendre.leggauss(n)[0])) + + return step_sizes, alphas diff --git a/captum/attr/_utils/attribution.py b/captum/attr/_utils/attribution.py new file mode 100644 index 0000000000000000000000000000000000000000..f4b6e9d35c25211bacb341679dcb4e701cc68606 --- /dev/null +++ b/captum/attr/_utils/attribution.py @@ -0,0 +1,477 @@ +#!/usr/bin/env python3 +from typing import Any, Callable, cast, Generic, List, Tuple, Type, Union + +import torch +import torch.nn.functional as F +from captum._utils.common import ( + _format_additional_forward_args, + _format_tensor_into_tuples, + _run_forward, + _validate_target, +) +from captum._utils.gradient import compute_gradients +from captum._utils.typing import ModuleOrModuleList, TargetType +from captum.attr._utils.common import ( + _format_input_baseline, + _sum_rows, + _tensorize_baseline, + _validate_input, +) +from captum.log import log_usage +from torch import Tensor +from torch.nn import Module + + +class Attribution: + r""" + All attribution algorithms extend this class. It enforces its child classes + to extend and override core `attribute` method. + """ + + def __init__(self, forward_func: Callable) -> None: + r""" + Args: + forward_func (callable or torch.nn.Module): This can either be an instance + of pytorch model or any modification of model's forward + function. + """ + self.forward_func = forward_func + + attribute: Callable + r""" + This method computes and returns the attribution values for each input tensor. + Deriving classes are responsible for implementing its logic accordingly. + + Specific attribution algorithms that extend this class take relevant + arguments. + + Args: + + inputs (tensor or tuple of tensors): Input for which attribution + is computed. 
It can be provided as a single tensor or + a tuple of multiple tensors. If multiple input tensors + are provided, the batch sizes must be aligned accross all + tensors. + + + Returns: + + *tensor* or tuple of *tensors* of **attributions**: + - **attributions** (*tensor* or tuple of *tensors*): + Attribution values for each + input tensor. The `attributions` have the same shape and + dimensionality as the inputs. + If a single tensor is provided as inputs, a single tensor + is returned. If a tuple is provided for inputs, a tuple of + corresponding sized tensors is returned. + + """ + + @property + def multiplies_by_inputs(self): + return False + + def has_convergence_delta(self) -> bool: + r""" + This method informs the user whether the attribution algorithm provides + a convergence delta (aka an approximation error) or not. Convergence + delta may serve as a proxy of correctness of attribution algorithm's + approximation. If deriving attribution class provides a + `compute_convergence_delta` method, it should + override both `compute_convergence_delta` and `has_convergence_delta` methods. + + Returns: + bool: + Returns whether the attribution algorithm + provides a convergence delta (aka approximation error) or not. + + """ + return False + + compute_convergence_delta: Callable + r""" + The attribution algorithms which derive `Attribution` class and provide + convergence delta (aka approximation error) should implement this method. + Convergence delta can be computed based on certain properties of the + attribution alogrithms. + + Args: + + attributions (tensor or tuple of tensors): Attribution scores that + are precomputed by an attribution algorithm. + Attributions can be provided in form of a single tensor + or a tuple of those. It is assumed that attribution + tensor's dimension 0 corresponds to the number of + examples, and if multiple input tensors are provided, + the examples must be aligned appropriately. + *args (optional): Additonal arguments that are used by the + sub-classes depending on the specific implementation + of `compute_convergence_delta`. + + Returns: + + *tensor* of **deltas**: + - **deltas** (*tensor*): + Depending on specific implementaion of + sub-classes, convergence delta can be returned per + sample in form of a tensor or it can be aggregated + across multuple samples and returned in form of a + single floating point tensor. + """ + + @classmethod + def get_name(cls: Type["Attribution"]) -> str: + r""" + Create readable class name by inserting a space before any capital + characters besides the very first. + + Returns: + str: a readable class name + Example: + for a class called IntegratedGradients, we return the string + 'Integrated Gradients' + """ + return "".join( + [ + char if char.islower() or idx == 0 else " " + char + for idx, char in enumerate(cls.__name__) + ] + ) + + +class GradientAttribution(Attribution): + r""" + All gradient based attribution algorithms extend this class. It requires a + forward function, which most commonly is the forward function of the model + that we want to interpret or the model itself. + """ + + def __init__(self, forward_func: Callable) -> None: + r""" + Args: + + forward_func (callable or torch.nn.Module): This can either be an instance + of pytorch model or any modification of model's forward + function. 
+ """ + Attribution.__init__(self, forward_func) + self.gradient_func = compute_gradients + + @log_usage() + def compute_convergence_delta( + self, + attributions: Union[Tensor, Tuple[Tensor, ...]], + start_point: Union[ + None, int, float, Tensor, Tuple[Union[int, float, Tensor], ...] + ], + end_point: Union[Tensor, Tuple[Tensor, ...]], + target: TargetType = None, + additional_forward_args: Any = None, + ) -> Tensor: + r""" + Here we provide a specific implementation for `compute_convergence_delta` + which is based on a common property among gradient-based attribution algorithms. + In the literature sometimes it is also called completeness axiom. Completeness + axiom states that the sum of the attribution must be equal to the differences of + NN Models's function at its end and start points. In other words: + sum(attributions) - (F(end_point) - F(start_point)) is close to zero. + Returned delta of this method is defined as above stated difference. + + This implementation assumes that both the `start_point` and `end_point` have + the same shape and dimensionality. It also assumes that the target must have + the same number of examples as the `start_point` and the `end_point` in case + it is provided in form of a list or a non-singleton tensor. + + Args: + + attributions (tensor or tuple of tensors): Precomputed attribution + scores. The user can compute those using any attribution + algorithm. It is assumed the the shape and the + dimensionality of attributions must match the shape and + the dimensionality of `start_point` and `end_point`. + It also assumes that the attribution tensor's + dimension 0 corresponds to the number of + examples, and if multiple input tensors are provided, + the examples must be aligned appropriately. + start_point (tensor or tuple of tensors, optional): `start_point` + is passed as an input to model's forward function. It + is the starting point of attributions' approximation. + It is assumed that both `start_point` and `end_point` + have the same shape and dimensionality. + end_point (tensor or tuple of tensors): `end_point` + is passed as an input to model's forward function. It + is the end point of attributions' approximation. + It is assumed that both `start_point` and `end_point` + have the same shape and dimensionality. + target (int, tuple, tensor or list, optional): Output indices for + which gradients are computed (for classification cases, + this is usually the target class). + If the network returns a scalar value per example, + no target index is necessary. + For general 2D outputs, targets can be either: + + - a single integer or a tensor containing a single + integer, which is applied to all input examples + + - a list of integers or a 1D tensor, with length matching + the number of examples in inputs (dim 0). Each integer + is applied as the target for the corresponding example. + + For outputs with > 2 dimensions, targets can be either: + + - A single tuple, which contains #output_dims - 1 + elements. This target index is applied to all examples. + + - A list of tuples with length equal to the number of + examples in inputs (dim 0), and each tuple containing + #output_dims - 1 elements. Each tuple is applied as the + target for the corresponding example. + + Default: None + additional_forward_args (any, optional): If the forward function + requires additional arguments other than the inputs for + which attributions should not be computed, this argument + can be provided. 
It must be either a single additional + argument of a Tensor or arbitrary (non-tuple) type or a + tuple containing multiple additional arguments including + tensors or any arbitrary python types. These arguments + are provided to forward_func in order following the + arguments in inputs. + For a tensor, the first dimension of the tensor must + correspond to the number of examples. + `additional_forward_args` is used both for `start_point` + and `end_point` when computing the forward pass. + Default: None + + Returns: + + *tensor* of **deltas**: + - **deltas** (*tensor*): + This implementation returns convergence delta per + sample. Deriving sub-classes may do any type of aggregation + of those values, if necessary. + """ + end_point, start_point = _format_input_baseline(end_point, start_point) + additional_forward_args = _format_additional_forward_args( + additional_forward_args + ) + # tensorizing start_point in case it is a scalar or one example baseline + # If the batch size is large we could potentially also tensorize only one + # sample and expand the output to the rest of the elements in the batch + start_point = _tensorize_baseline(end_point, start_point) + + attributions = _format_tensor_into_tuples(attributions) + + # verify that the attributions and end_point match on 1st dimension + for attribution, end_point_tnsr in zip(attributions, end_point): + assert end_point_tnsr.shape[0] == attribution.shape[0], ( + "Attributions tensor and the end_point must match on the first" + " dimension but found attribution: {} and end_point: {}".format( + attribution.shape[0], end_point_tnsr.shape[0] + ) + ) + + num_samples = end_point[0].shape[0] + _validate_input(end_point, start_point) + _validate_target(num_samples, target) + + with torch.no_grad(): + start_out_sum = _sum_rows( + _run_forward( + self.forward_func, start_point, target, additional_forward_args + ) + ) + + end_out_sum = _sum_rows( + _run_forward( + self.forward_func, end_point, target, additional_forward_args + ) + ) + row_sums = [_sum_rows(attribution) for attribution in attributions] + attr_sum = torch.stack( + [cast(Tensor, sum(row_sum)) for row_sum in zip(*row_sums)] + ) + _delta = attr_sum - (end_out_sum - start_out_sum) + return _delta + + +class PerturbationAttribution(Attribution): + r""" + All perturbation based attribution algorithms extend this class. It requires a + forward function, which most commonly is the forward function of the model + that we want to interpret or the model itself. + """ + + def __init__(self, forward_func: Callable) -> None: + r""" + Args: + + forward_func (callable or torch.nn.Module): This can either be an instance + of pytorch model or any modification of model's forward + function. + """ + Attribution.__init__(self, forward_func) + + @property + def multiplies_by_inputs(self): + return True + + +class InternalAttribution(Attribution, Generic[ModuleOrModuleList]): + layer: ModuleOrModuleList + r""" + Shared base class for LayerAttrubution and NeuronAttribution, + attribution types that require a model and a particular layer. + """ + + def __init__( + self, + forward_func: Callable, + layer: ModuleOrModuleList, + device_ids: Union[None, List[int]] = None, + ) -> None: + r""" + Args: + + forward_func (callable or torch.nn.Module): This can either be an instance + of pytorch model or any modification of model's forward + function. + layer (torch.nn.Module): Layer for which output attributions are computed. + Output size of attribute matches that of layer output. 
+ device_ids (list(int)): Device ID list, necessary only if forward_func + applies a DataParallel model, which allows reconstruction of + intermediate outputs from batched results across devices. + If forward_func is given as the DataParallel model itself, + then it is not necessary to provide this argument. + """ + Attribution.__init__(self, forward_func) + self.layer = layer + self.device_ids = device_ids + + +class LayerAttribution(InternalAttribution): + r""" + Layer attribution provides attribution values for the given layer, quanitfying + the importance of each neuron within the given layer's output. The output + attribution of calling attribute on a LayerAttribution object always matches + the size of the layer output. + """ + + def __init__( + self, + forward_func: Callable, + layer: ModuleOrModuleList, + device_ids: Union[None, List[int]] = None, + ) -> None: + r""" + Args: + + forward_func (callable or torch.nn.Module): This can either be an instance + of pytorch model or any modification of model's forward + function. + layer (torch.nn.Module): Layer for which output attributions are computed. + Output size of attribute matches that of layer output. + device_ids (list(int)): Device ID list, necessary only if forward_func + applies a DataParallel model, which allows reconstruction of + intermediate outputs from batched results across devices. + If forward_func is given as the DataParallel model itself, + then it is not necessary to provide this argument. + """ + InternalAttribution.__init__(self, forward_func, layer, device_ids) + + @staticmethod + def interpolate( + layer_attribution: Tensor, + interpolate_dims: Union[int, Tuple[int, ...]], + interpolate_mode: str = "nearest", + ) -> Tensor: + r""" + Interpolates given 3D, 4D or 5D layer attribution to given dimensions. + This is often utilized to upsample the attribution of a convolutional layer + to the size of an input, which allows visualizing in the input space. + + Args: + + layer_attribution (torch.Tensor): Tensor of given layer attributions. + interpolate_dims (int or tuple): Upsampled dimensions. The + number of elements must be the number of dimensions + of layer_attribution - 2, since the first dimension + corresponds to number of examples and the second is + assumed to correspond to the number of channels. + interpolate_mode (str): Method for interpolation, which + must be a valid input interpolation mode for + torch.nn.functional. These methods are + "nearest", "area", "linear" (3D-only), "bilinear" + (4D-only), "bicubic" (4D-only), "trilinear" (5D-only) + based on the number of dimensions of the given layer + attribution. + + Returns: + *tensor* of upsampled **attributions**: + - **attributions** (*tensor*): + Upsampled layer attributions with first 2 dimensions matching + slayer_attribution and remaining dimensions given by + interpolate_dims. + """ + return F.interpolate(layer_attribution, interpolate_dims, mode=interpolate_mode) + + +class NeuronAttribution(InternalAttribution): + r""" + Neuron attribution provides input attribution for a given neuron, quanitfying + the importance of each input feature in the activation of a particular neuron. + Calling attribute on a NeuronAttribution object requires also providing + the index of the neuron in the output of the given layer for which attributions + are required. + The output attribution of calling attribute on a NeuronAttribution object + always matches the size of the input. 
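The `LayerAttribution.interpolate` helper above is typically used to upsample a convolutional layer's attribution to input resolution for visualization. A minimal sketch, with the layer-output shape and target size invented for illustration:

```python
# Hedged sketch: upsampling a 4D layer attribution with LayerAttribution.interpolate.
import torch
from captum.attr import LayerAttribution

# Pretend this came from a layer attribution method on a conv feature map:
# 2 examples, 16 channels, 7x7 spatial resolution.
layer_attr = torch.randn(2, 16, 7, 7)

# Upsample to a 28x28 input resolution; "bilinear" is valid for 4D tensors.
upsampled = LayerAttribution.interpolate(
    layer_attr, (28, 28), interpolate_mode="bilinear"
)
print(upsampled.shape)  # torch.Size([2, 16, 28, 28])
```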
+ """ + + def __init__( + self, + forward_func: Callable, + layer: Module, + device_ids: Union[None, List[int]] = None, + ) -> None: + r""" + Args: + + forward_func (callable or torch.nn.Module): This can either be an instance + of pytorch model or any modification of model's forward + function. + layer (torch.nn.Module): Layer for which output attributions are computed. + Output size of attribute matches that of layer output. + device_ids (list(int)): Device ID list, necessary only if forward_func + applies a DataParallel model, which allows reconstruction of + intermediate outputs from batched results across devices. + If forward_func is given as the DataParallel model itself, + then it is not necessary to provide this argument. + """ + InternalAttribution.__init__(self, forward_func, layer, device_ids) + + attribute: Callable + r""" + This method computes and returns the neuron attribution values for each + input tensor. Deriving classes are responsible for implementing + its logic accordingly. + + Specific attribution algorithms that extend this class take relevant + arguments. + + Args: + + inputs: A single high dimensional input tensor or a tuple of them. + neuron_selector (int or tuple): Tuple providing index of neuron in output + of given layer for which attribution is desired. Length of + this tuple must be one less than the number of + dimensions in the output of the given layer (since + dimension 0 corresponds to number of examples). + + Returns: + + *tensor* or tuple of *tensors* of **attributions**: + - **attributions** (*tensor* or tuple of *tensors*): + Attribution values for + each input vector. The `attributions` have the + dimensionality of inputs. + """ diff --git a/captum/attr/_utils/batching.py b/captum/attr/_utils/batching.py new file mode 100644 index 0000000000000000000000000000000000000000..611517b3f96280a36a522709e2182f812814095f --- /dev/null +++ b/captum/attr/_utils/batching.py @@ -0,0 +1,222 @@ +#!/usr/bin/env python3 +import typing +import warnings +from typing import Any, Callable, Iterator, Tuple, Union + +import torch +from captum._utils.common import ( + _format_additional_forward_args, + _format_output, + _format_tensor_into_tuples, + _reduce_list, +) +from captum._utils.typing import ( + TargetType, + TensorOrTupleOfTensorsGeneric, + TupleOrTensorOrBoolGeneric, +) +from captum.attr._utils.approximation_methods import approximation_parameters +from torch import Tensor + + +def _batch_attribution( + attr_method, + num_examples, + internal_batch_size, + n_steps, + include_endpoint=False, + **kwargs, +): + """ + This method applies internal batching to given attribution method, dividing + the total steps into batches and running each independently and sequentially, + adding each result to compute the total attribution. + + Step sizes and alphas are spliced for each batch and passed explicitly for each + call to _attribute. + + kwargs include all argument necessary to pass to each attribute call, except + for n_steps, which is computed based on the number of steps for the batch. + + include_endpoint ensures that one step overlaps between each batch, which + is necessary for some methods, particularly LayerConductance. + """ + if internal_batch_size < num_examples: + warnings.warn( + "Internal batch size cannot be less than the number of input examples. " + "Defaulting to internal batch size of %d equal to the number of examples." 
+ % num_examples + ) + # Number of steps for each batch + step_count = max(1, internal_batch_size // num_examples) + if include_endpoint: + if step_count < 2: + step_count = 2 + warnings.warn( + "This method computes finite differences between evaluations at " + "consecutive steps, so internal batch size must be at least twice " + "the number of examples. Defaulting to internal batch size of %d" + " equal to twice the number of examples." % (2 * num_examples) + ) + + total_attr = None + cumulative_steps = 0 + step_sizes_func, alphas_func = approximation_parameters(kwargs["method"]) + full_step_sizes = step_sizes_func(n_steps) + full_alphas = alphas_func(n_steps) + + while cumulative_steps < n_steps: + start_step = cumulative_steps + end_step = min(start_step + step_count, n_steps) + batch_steps = end_step - start_step + + if include_endpoint: + batch_steps -= 1 + + step_sizes = full_step_sizes[start_step:end_step] + alphas = full_alphas[start_step:end_step] + current_attr = attr_method._attribute( + **kwargs, n_steps=batch_steps, step_sizes_and_alphas=(step_sizes, alphas) + ) + + if total_attr is None: + total_attr = current_attr + else: + if isinstance(total_attr, Tensor): + total_attr = total_attr + current_attr.detach() + else: + total_attr = tuple( + current.detach() + prev_total + for current, prev_total in zip(current_attr, total_attr) + ) + if include_endpoint and end_step < n_steps: + cumulative_steps = end_step - 1 + else: + cumulative_steps = end_step + return total_attr + + +@typing.overload +def _tuple_splice_range(inputs: None, start: int, end: int) -> None: + ... + + +@typing.overload +def _tuple_splice_range(inputs: Tuple, start: int, end: int) -> Tuple: + ... + + +def _tuple_splice_range( + inputs: Union[None, Tuple], start: int, end: int +) -> Union[None, Tuple]: + """ + Splices each tensor element of given tuple (inputs) from range start + (inclusive) to end (non-inclusive) on its first dimension. If element + is not a Tensor, it is left unchanged. It is assumed that all tensor elements + have the same first dimension (corresponding to number of examples). + The returned value is a tuple with the same length as inputs, with Tensors + spliced appropriately. + """ + assert start < end, "Start point must precede end point for batch splicing." + if inputs is None: + return None + return tuple( + inp[start:end] if isinstance(inp, torch.Tensor) else inp for inp in inputs + ) + + +def _batched_generator( + inputs: TensorOrTupleOfTensorsGeneric, + additional_forward_args: Any = None, + target_ind: TargetType = None, + internal_batch_size: Union[None, int] = None, +) -> Iterator[Tuple[Tuple[Tensor, ...], Any, TargetType]]: + """ + Returns a generator which returns corresponding chunks of size internal_batch_size + for both inputs and additional_forward_args. If batch size is None, + generator only includes original inputs and additional args. + """ + assert internal_batch_size is None or ( + isinstance(internal_batch_size, int) and internal_batch_size > 0 + ), "Batch size must be greater than 0." 
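A small sketch (illustrative values only) of the two pieces above: `_tuple_splice_range` slices only the tensor members of a tuple, and `_batch_attribution` budgets `internal_batch_size // num_examples` integration steps per forward/backward pass.

```python
# Hedged sketch of the batching helpers above; numbers are arbitrary.
import torch
from captum.attr._utils.batching import _tuple_splice_range

inputs = (torch.arange(12).reshape(6, 2), "not_a_tensor", torch.ones(6, 3))
sliced = _tuple_splice_range(inputs, 2, 4)
print(sliced[0])        # rows 2 and 3 of the first tensor
print(sliced[1])        # non-tensor elements pass through unchanged
print(sliced[2].shape)  # torch.Size([2, 3])

# Step budgeting used by _batch_attribution: with 6 examples and an internal
# batch size of 30, each pass covers 30 // 6 = 5 steps of the approximation.
num_examples, internal_batch_size, n_steps = 6, 30, 50
steps_per_batch = max(1, internal_batch_size // num_examples)
print(steps_per_batch, -(-n_steps // steps_per_batch))  # 5 steps/batch, 10 batches
```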
+ inputs = _format_tensor_into_tuples(inputs) + additional_forward_args = _format_additional_forward_args(additional_forward_args) + num_examples = inputs[0].shape[0] + # TODO Reconsider this check if _batched_generator is used for non gradient-based + # attribution algorithms + if not (inputs[0] * 1).requires_grad: + warnings.warn( + """It looks like that the attribution for a gradient-based method is + computed in a `torch.no_grad` block or perhaps the inputs have no + requires_grad.""" + ) + if internal_batch_size is None: + yield inputs, additional_forward_args, target_ind + else: + for current_total in range(0, num_examples, internal_batch_size): + with torch.autograd.set_grad_enabled(True): + inputs_splice = _tuple_splice_range( + inputs, current_total, current_total + internal_batch_size + ) + yield inputs_splice, _tuple_splice_range( + additional_forward_args, + current_total, + current_total + internal_batch_size, + ), target_ind[ + current_total : current_total + internal_batch_size + ] if isinstance( + target_ind, list + ) or ( + isinstance(target_ind, torch.Tensor) and target_ind.numel() > 1 + ) else target_ind + + +def _batched_operator( + operator: Callable[..., TupleOrTensorOrBoolGeneric], + inputs: TensorOrTupleOfTensorsGeneric, + additional_forward_args: Any = None, + target_ind: TargetType = None, + internal_batch_size: Union[None, int] = None, + **kwargs: Any, +) -> TupleOrTensorOrBoolGeneric: + """ + Batches the operation of the given operator, applying the given batch size + to inputs and additional forward arguments, and returning the concatenation + of the results of each batch. + """ + all_outputs = [ + operator( + inputs=input, + additional_forward_args=additional, + target_ind=target, + **kwargs, + ) + for input, additional, target in _batched_generator( + inputs, additional_forward_args, target_ind, internal_batch_size + ) + ] + return _reduce_list(all_outputs) + + +def _select_example(curr_arg: Any, index: int, bsz: int) -> Any: + if curr_arg is None: + return None + is_tuple = isinstance(curr_arg, tuple) + if not is_tuple: + curr_arg = (curr_arg,) + selected_arg = [] + for i in range(len(curr_arg)): + if isinstance(curr_arg[i], (Tensor, list)) and len(curr_arg[i]) == bsz: + selected_arg.append(curr_arg[i][index : index + 1]) + else: + selected_arg.append(curr_arg[i]) + return _format_output(is_tuple, tuple(selected_arg)) + + +def _batch_example_iterator(bsz: int, *args) -> Iterator: + """ + Batches the provided argument. + """ + for i in range(bsz): + curr_args = [_select_example(args[j], i, bsz) for j in range(len(args))] + yield tuple(curr_args) diff --git a/captum/attr/_utils/class_summarizer.py b/captum/attr/_utils/class_summarizer.py new file mode 100644 index 0000000000000000000000000000000000000000..24857118665e140e672fc6cb4d7821d3dc69c694 --- /dev/null +++ b/captum/attr/_utils/class_summarizer.py @@ -0,0 +1,94 @@ +#!/usr/bin/env python3 +from collections import defaultdict +from typing import Any, Dict, List, Optional, Union + +from captum._utils.common import _format_tensor_into_tuples +from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric +from captum.attr._utils.stat import Stat +from captum.attr._utils.summarizer import Summarizer +from captum.log import log_usage +from torch import Tensor + + +class ClassSummarizer(Summarizer): + r""" + Used to keep track of summaries for associated classes. The + classes/labels can be of any type that are supported by `dict`. + + This also keeps track of an aggregate of all class summaries. 
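To see how `_batched_generator` chunks inputs and per-example targets, here is a minimal sketch; the sizes and targets are invented for illustration.

```python
# Hedged sketch: iterating the chunks produced by _batched_generator.
import torch
from captum.attr._utils.batching import _batched_generator

inputs = torch.randn(10, 4, requires_grad=True)  # requires_grad avoids the warning
targets = torch.arange(10) % 3                   # one target per example

for inp, extra_args, tgt in _batched_generator(
    inputs, additional_forward_args=None, target_ind=targets, internal_batch_size=4
):
    # Inputs are always yielded as a tuple of tensors; targets are sliced to match.
    print(inp[0].shape, tgt.shape)
# -> (4, 4) (4,)   (4, 4) (4,)   (2, 4) (2,)
```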
+ """ + + @log_usage() + def __init__(self, stats: List[Stat]) -> None: + Summarizer.__init__.__wrapped__(self, stats) + self.summaries: Dict[Any, Summarizer] = defaultdict( + lambda: Summarizer(stats=stats) + ) + + def update( # type: ignore + self, + x: TensorOrTupleOfTensorsGeneric, + labels: TargetType = None, + ): + r""" + Updates the stats of the summarizer, optionally associated to classes. + + This accepts either a single tensor to summarise or a tuple of tensors. + + Args: + x (Tensor or Tuple[Tensor, ...]): + The input tensor to be summarised. The first + dimension of this input must be associated to + the batch size of the inputs. + labels (int, tuple, tensor or list, optional): + The associated labels for `x`. If Any, we + assume `labels` represents the label for all inputs in `x`. + + If this is None we simply aggregate the total summary. + """ + if labels is None: + super().update(x) + return + + x = _format_tensor_into_tuples(x) + + num_labels = 1 + + labels_typed: Union[List[Any], Tensor] + if isinstance(labels, list) or isinstance(labels, Tensor): + labels_typed = labels + num_labels = len(labels) # = labels.size(0) if tensor + else: + labels_typed = [labels] + + # mypy doesn't realise I have made the int a list + if len(labels_typed) > 1: + for x_i in x: + assert x_i.size(0) == num_labels, ( + "batch size does not equal amount of labels; " + "please ensure length of labels is equal to 1 " + "or to the `batch_size` corresponding to the " + "number of examples in the input(s)" + ) + + batch_size = x[0].size(0) + + for i in range(batch_size): + tensors_to_summarize = tuple(tensor[i] for tensor in x) + tensors_to_summarize_copy = tuple(tensor[i].clone() for tensor in x) + label = labels_typed[0] if len(labels_typed) == 1 else labels_typed[i] + + self.summaries[label].update(tensors_to_summarize) + super().update(tensors_to_summarize_copy) + + @property + def class_summaries( + self, + ) -> Dict[ + Any, Union[None, Dict[str, Optional[Tensor]], List[Dict[str, Optional[Tensor]]]] + ]: + r""" + Returns: + The summaries for each class. + """ + return {key: value.summary for key, value in self.summaries.items()} diff --git a/captum/attr/_utils/common.py b/captum/attr/_utils/common.py new file mode 100644 index 0000000000000000000000000000000000000000..34979764be1d1124ba3432361c2ec8063073444c --- /dev/null +++ b/captum/attr/_utils/common.py @@ -0,0 +1,372 @@ +#!/usr/bin/env python3 +import typing +from inspect import signature +from typing import Any, Callable, List, Tuple, TYPE_CHECKING, Union + +import torch +from captum._utils.common import ( + _format_baseline, + _format_output, + _format_tensor_into_tuples, + _validate_input as _validate_input_basic, +) +from captum._utils.typing import ( + BaselineType, + Literal, + TargetType, + TensorOrTupleOfTensorsGeneric, +) +from captum.attr._utils.approximation_methods import SUPPORTED_METHODS +from torch import Tensor + +if TYPE_CHECKING: + from captum.attr._utils.attribution import GradientAttribution + + +def _sum_rows(input: Tensor) -> Tensor: + return input.reshape(input.shape[0], -1).sum(1) + + +def _validate_target(num_samples: int, target: TargetType) -> None: + if isinstance(target, list) or ( + isinstance(target, torch.Tensor) and torch.numel(target) > 1 + ): + assert num_samples == len(target), ( + "The number of samples provied in the" + "input {} does not match with the number of targets. 
{}".format( + num_samples, len(target) + ) + ) + + +def _validate_input( + inputs: Tuple[Tensor, ...], + baselines: Tuple[Union[Tensor, int, float], ...], + n_steps: int = 50, + method: str = "riemann_trapezoid", + draw_baseline_from_distrib: bool = False, +) -> None: + _validate_input_basic(inputs, baselines, draw_baseline_from_distrib) + assert ( + n_steps >= 0 + ), "The number of steps must be a positive integer. " "Given: {}".format(n_steps) + + assert ( + method in SUPPORTED_METHODS + ), "Approximation method must be one for the following {}. " "Given {}".format( + SUPPORTED_METHODS, method + ) + + +def _validate_noise_tunnel_type( + nt_type: str, supported_noise_tunnel_types: List[str] +) -> None: + assert nt_type in supported_noise_tunnel_types, ( + "Noise types must be either `smoothgrad`, `smoothgrad_sq` or `vargrad`. " + "Given {}".format(nt_type) + ) + + +@typing.overload +def _format_input_baseline( + inputs: Union[Tensor, Tuple[Tensor, ...]], + baselines: Union[Tensor, Tuple[Tensor, ...]], +) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]: + ... + + +@typing.overload +def _format_input_baseline( + inputs: Union[Tensor, Tuple[Tensor, ...]], baselines: BaselineType +) -> Tuple[Tuple[Tensor, ...], Tuple[Union[Tensor, int, float], ...]]: + ... + + +def _format_input_baseline( + inputs: Union[Tensor, Tuple[Tensor, ...]], baselines: BaselineType +) -> Tuple[Tuple[Tensor, ...], Tuple[Union[Tensor, int, float], ...]]: + inputs = _format_tensor_into_tuples(inputs) + baselines = _format_baseline(baselines, inputs) + return inputs, baselines + + +# This function can potentially be merged with the `format_baseline` function +# however, since currently not all algorithms support baselines of type +# callable this will be kept in a separate function. +@typing.overload +def _format_callable_baseline( + baselines: Union[ + None, + Callable[..., Union[Tensor, Tuple[Tensor, ...]]], + Tensor, + Tuple[Tensor, ...], + ], + inputs: Union[Tensor, Tuple[Tensor, ...]], +) -> Tuple[Tensor, ...]: + ... + + +@typing.overload +def _format_callable_baseline( + baselines: Union[ + None, + Callable[..., Union[Tensor, Tuple[Tensor, ...]]], + Tensor, + int, + float, + Tuple[Union[Tensor, int, float], ...], + ], + inputs: Union[Tensor, Tuple[Tensor, ...]], +) -> Tuple[Union[Tensor, int, float], ...]: + ... + + +def _format_callable_baseline( + baselines: Union[ + None, + Callable[..., Union[Tensor, Tuple[Tensor, ...]]], + Tensor, + int, + float, + Tuple[Union[Tensor, int, float], ...], + ], + inputs: Union[Tensor, Tuple[Tensor, ...]], +) -> Tuple[Union[Tensor, int, float], ...]: + if callable(baselines): + # Note: this assumes that if baselines is a function and if it takes + # arguments, then the first argument is the `inputs`. 
+ # This can be expanded in the future with better type checks + baseline_parameters = signature(baselines).parameters + if len(baseline_parameters) == 0: + baselines = baselines() + else: + baselines = baselines(inputs) + return _format_baseline(baselines, _format_tensor_into_tuples(inputs)) + + +def _format_and_verify_strides( + strides: Union[None, int, Tuple[int, ...], Tuple[Union[int, Tuple[int, ...]], ...]], + inputs: Tuple[Tensor, ...], +) -> Tuple[Union[int, Tuple[int, ...]], ...]: + # Formats strides, which are necessary for occlusion + # Assumes inputs are already formatted (in tuple) + if strides is None: + strides = tuple(1 for input in inputs) + if len(inputs) == 1 and not (isinstance(strides, tuple) and len(strides) == 1): + strides = (strides,) # type: ignore + assert isinstance(strides, tuple) and len(strides) == len( + inputs + ), "Strides must be provided for each input tensor." + for i in range(len(inputs)): + assert isinstance(strides[i], int) or ( + isinstance(strides[i], tuple) + and len(strides[i]) == len(inputs[i].shape) - 1 # type: ignore + ), ( + "Stride for input index {} is {}, which is invalid for input with " + "shape {}. It must be either an int or a tuple with length equal to " + "len(input_shape) - 1." + ).format( + i, strides[i], inputs[i].shape + ) + + return strides + + +def _format_and_verify_sliding_window_shapes( + sliding_window_shapes: Union[Tuple[int, ...], Tuple[Tuple[int, ...], ...]], + inputs: Tuple[Tensor, ...], +) -> Tuple[Tuple[int, ...], ...]: + # Formats shapes of sliding windows, which is necessary for occlusion + # Assumes inputs is already formatted (in tuple) + if isinstance(sliding_window_shapes[0], int): + sliding_window_shapes = (sliding_window_shapes,) # type: ignore + sliding_window_shapes: Tuple[Tuple[int, ...], ...] + assert len(sliding_window_shapes) == len( + inputs + ), "Must provide sliding window dimensions for each input tensor." + for i in range(len(inputs)): + assert ( + isinstance(sliding_window_shapes[i], tuple) + and len(sliding_window_shapes[i]) == len(inputs[i].shape) - 1 + ), ( + "Occlusion shape for input index {} is {} but should be a tuple with " + "{} dimensions." + ).format( + i, sliding_window_shapes[i], len(inputs[i].shape) - 1 + ) + return sliding_window_shapes + + +@typing.overload +def _compute_conv_delta_and_format_attrs( + attr_algo: "GradientAttribution", + return_convergence_delta: bool, + attributions: Tuple[Tensor, ...], + start_point: Union[int, float, Tensor, Tuple[Union[int, float, Tensor], ...]], + end_point: Union[Tensor, Tuple[Tensor, ...]], + additional_forward_args: Any, + target: TargetType, + is_inputs_tuple: Literal[False] = False, +) -> Union[Tensor, Tuple[Tensor, Tensor]]: + ... + + +@typing.overload +def _compute_conv_delta_and_format_attrs( + attr_algo: "GradientAttribution", + return_convergence_delta: bool, + attributions: Tuple[Tensor, ...], + start_point: Union[int, float, Tensor, Tuple[Union[int, float, Tensor], ...]], + end_point: Union[Tensor, Tuple[Tensor, ...]], + additional_forward_args: Any, + target: TargetType, + is_inputs_tuple: Literal[True], +) -> Union[Tuple[Tensor, ...], Tuple[Tuple[Tensor, ...], Tensor]]: + ... + + +# FIXME: GradientAttribution is provided as a string due to a circular import. +# This should be fixed when common is refactored into separate files. 
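A small sketch showing how the two occlusion helpers above normalize user-supplied strides and sliding-window shapes for a single image input; the shapes are arbitrary.

```python
# Hedged sketch: normalizing occlusion arguments with the helpers above.
import torch
from captum.attr._utils.common import (
    _format_and_verify_sliding_window_shapes,
    _format_and_verify_strides,
)

inputs = (torch.randn(2, 3, 32, 32),)  # inputs already formatted as a tuple

# One entry per input, each with len(input.shape) - 1 elements
# (the batch dimension is excluded).
strides = _format_and_verify_strides((3, 8, 8), inputs)
windows = _format_and_verify_sliding_window_shapes((3, 4, 4), inputs)
print(strides)  # ((3, 8, 8),)
print(windows)  # ((3, 4, 4),)
```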
+def _compute_conv_delta_and_format_attrs( + attr_algo: "GradientAttribution", + return_convergence_delta: bool, + attributions: Tuple[Tensor, ...], + start_point: Union[int, float, Tensor, Tuple[Union[int, float, Tensor], ...]], + end_point: Union[Tensor, Tuple[Tensor, ...]], + additional_forward_args: Any, + target: TargetType, + is_inputs_tuple: bool = False, +) -> Union[ + Tensor, Tuple[Tensor, ...], Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor] +]: + if return_convergence_delta: + # computes convergence error + delta = attr_algo.compute_convergence_delta( + attributions, + start_point, + end_point, + additional_forward_args=additional_forward_args, + target=target, + ) + return _format_output(is_inputs_tuple, attributions), delta + else: + return _format_output(is_inputs_tuple, attributions) + + +def _tensorize_baseline( + inputs: Tuple[Tensor, ...], baselines: Tuple[Union[int, float, Tensor], ...] +) -> Tuple[Tensor, ...]: + def _tensorize_single_baseline(baseline, input): + if isinstance(baseline, (int, float)): + return torch.full_like(input, baseline) + if input.shape[0] > baseline.shape[0] and baseline.shape[0] == 1: + return torch.cat([baseline] * input.shape[0]) + return baseline + + assert isinstance(inputs, tuple) and isinstance(baselines, tuple), ( + "inputs and baselines must" + "have tuple type but found baselines: {} and inputs: {}".format( + type(baselines), type(inputs) + ) + ) + return tuple( + _tensorize_single_baseline(baseline, input) + for baseline, input in zip(baselines, inputs) + ) + + +def _reshape_and_sum( + tensor_input: Tensor, num_steps: int, num_examples: int, layer_size: Tuple[int, ...] +) -> Tensor: + # Used for attribution methods which perform integration + # Sums across integration steps by reshaping tensor to + # (num_steps, num_examples, (layer_size)) and summing over + # dimension 0. Returns a tensor of size (num_examples, (layer_size)) + return torch.sum( + tensor_input.reshape((num_steps, num_examples) + layer_size), dim=0 + ) + + +def _call_custom_attribution_func( + custom_attribution_func: Callable[..., Tuple[Tensor, ...]], + multipliers: Tuple[Tensor, ...], + inputs: Tuple[Tensor, ...], + baselines: Tuple[Tensor, ...], +) -> Tuple[Tensor, ...]: + assert callable(custom_attribution_func), ( + "`custom_attribution_func`" + " must be a callable function but {} provided".format( + type(custom_attribution_func) + ) + ) + custom_attr_func_params = signature(custom_attribution_func).parameters + + if len(custom_attr_func_params) == 1: + return custom_attribution_func(multipliers) + elif len(custom_attr_func_params) == 2: + return custom_attribution_func(multipliers, inputs) + elif len(custom_attr_func_params) == 3: + return custom_attribution_func(multipliers, inputs, baselines) + else: + raise AssertionError( + "`custom_attribution_func` must take at least one and at most 3 arguments." + ) + + +def _find_output_mode_and_verify( + initial_eval: Union[int, float, Tensor], + num_examples: int, + perturbations_per_eval: int, + feature_mask: Union[None, TensorOrTupleOfTensorsGeneric], +) -> bool: + """ + This method identifies whether the model outputs a single output for a batch + (agg_output_mode = True) or whether it outputs a single output per example + (agg_output_mode = False) and returns agg_output_mode. The method also + verifies that perturbations_per_eval is 1 in the case that agg_output_mode is True + and also verifies that the first dimension of each feature mask if the model + returns a single output for a batch. 
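Two of the helpers above in action, with invented shapes: `_tensorize_baseline` expands scalar or single-example baselines to the input's batch size, and `_reshape_and_sum` collapses per-step values into one attribution per example.

```python
# Hedged sketch of _tensorize_baseline and _reshape_and_sum; shapes are arbitrary.
import torch
from captum.attr._utils.common import _reshape_and_sum, _tensorize_baseline

inputs = (torch.randn(4, 3),)

# A scalar baseline becomes a full tensor; a 1-example baseline is repeated.
print(_tensorize_baseline(inputs, (0.0,))[0].shape)                # (4, 3)
print(_tensorize_baseline(inputs, (torch.zeros(1, 3),))[0].shape)  # (4, 3)

# 10 integration steps x 4 examples x feature shape (3,), summed over steps.
per_step = torch.randn(10 * 4, 3)
summed = _reshape_and_sum(per_step, num_steps=10, num_examples=4, layer_size=(3,))
print(summed.shape)  # torch.Size([4, 3])
```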
+ """ + if isinstance(initial_eval, (int, float)) or ( + isinstance(initial_eval, torch.Tensor) + and ( + len(initial_eval.shape) == 0 + or (num_examples > 1 and initial_eval.numel() == 1) + ) + ): + agg_output_mode = True + assert ( + perturbations_per_eval == 1 + ), "Cannot have perturbations_per_eval > 1 when function returns scalar." + if feature_mask is not None: + for single_mask in feature_mask: + assert single_mask.shape[0] == 1, ( + "Cannot provide different masks for each example when function " + "returns a scalar." + ) + else: + agg_output_mode = False + assert ( + isinstance(initial_eval, torch.Tensor) and initial_eval[0].numel() == 1 + ), "Target should identify a single element in the model output." + return agg_output_mode + + +def _construct_default_feature_mask( + inputs: Tuple[Tensor, ...] +) -> Tuple[Tuple[Tensor, ...], int]: + feature_mask = [] + current_num_features = 0 + for i in range(len(inputs)): + num_features = torch.numel(inputs[i][0]) + feature_mask.append( + current_num_features + + torch.reshape( + torch.arange(num_features, device=inputs[i].device), + inputs[i][0:1].shape, + ) + ) + current_num_features += num_features + total_features = current_num_features + feature_mask = tuple(feature_mask) + return feature_mask, total_features diff --git a/captum/attr/_utils/custom_modules.py b/captum/attr/_utils/custom_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..8dea72054f2b4a83536d7e5d8f74200d3cb0c805 --- /dev/null +++ b/captum/attr/_utils/custom_modules.py @@ -0,0 +1,14 @@ +#!/usr/bin/env python3 +import torch.nn as nn + + +class Addition_Module(nn.Module): + """Custom addition module that uses multiple inputs to assure correct relevance + propagation. Any addition in a forward function needs to be replaced with the + module before using LRP.""" + + def __init__(self) -> None: + super().__init__() + + def forward(self, x1, x2): + return x1 + x2 diff --git a/captum/attr/_utils/input_layer_wrapper.py b/captum/attr/_utils/input_layer_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..402319fb4380f27937d1a76eae387c0734d1b66d --- /dev/null +++ b/captum/attr/_utils/input_layer_wrapper.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python3 + +import inspect +from typing import Any + +import torch.nn as nn + + +class InputIdentity(nn.Module): + def __init__(self, input_name: str) -> None: + r""" + The identity operation + + Args: + input_name (str) + The name of the input this layer is associated to. For debugging + purposes. + """ + super().__init__() + self.input_name = input_name + + def forward(self, x): + return x + + +class ModelInputWrapper(nn.Module): + def __init__(self, module_to_wrap: nn.Module) -> None: + r""" + This is a convenience class. This wraps a model via first feeding the + model's inputs to separate layers (one for each input) and then feeding + the (unmodified) inputs to the underlying model (`module_to_wrap`). Each + input is fed through an `InputIdentity` layer/module. This class does + not change how you feed inputs to your model, so feel free to use your + model as you normally would. + + To access a wrapped input layer, simply access it via the `input_maps` + ModuleDict, e.g. to get the corresponding module for input "x", simply + provide/write `my_wrapped_module.input_maps["x"]` + + This is done such that one can use layer attribution methods on inputs. + Which should allow you to use mix layers with inputs with these + attribution methods. 
This is especially useful multimodal models which + input discrete features (mapped to embeddings, such as text) and regular + continuous feature vectors. + + Notes: + - Since inputs are mapped with the identity, attributing to the + input/feature can be done with either the input or output of the + layer, e.g. attributing to an input/feature doesn't depend on whether + attribute_to_layer_input is True or False for + LayerIntegratedGradients. + - Please refer to the multimodal tutorial or unit tests + (test/attr/test_layer_wrapper.py) for an example. + + Args: + module_to_wrap (nn.Module): + The model/module you want to wrap + """ + super().__init__() + self.module = module_to_wrap + + # ignore self + self.arg_name_list = inspect.getfullargspec(module_to_wrap.forward).args[1:] + self.input_maps = nn.ModuleDict( + {arg_name: InputIdentity(arg_name) for arg_name in self.arg_name_list} + ) + + def forward(self, *args, **kwargs) -> Any: + args = list(args) + for idx, (arg_name, arg) in enumerate(zip(self.arg_name_list, args)): + args[idx] = self.input_maps[arg_name](arg) + + for arg_name in kwargs.keys(): + kwargs[arg_name] = self.input_maps[arg_name](kwargs[arg_name]) + + return self.module(*tuple(args), **kwargs) diff --git a/captum/attr/_utils/lrp_rules.py b/captum/attr/_utils/lrp_rules.py new file mode 100644 index 0000000000000000000000000000000000000000..edacdef0044c53d2bc56dde29daf62c89c736de9 --- /dev/null +++ b/captum/attr/_utils/lrp_rules.py @@ -0,0 +1,176 @@ +#!/usr/bin/env python3 + +from abc import ABC, abstractmethod + +import torch + +from ..._utils.common import _format_tensor_into_tuples + + +class PropagationRule(ABC): + """ + Base class for all propagation rule classes, also called Z-Rule. + STABILITY_FACTOR is used to assure that no zero divison occurs. 
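A minimal sketch of `ModelInputWrapper` on a throwaway two-input model; the model and argument names below are assumptions for illustration only.

```python
# Hedged sketch: routing each input through a named InputIdentity layer.
import torch
import torch.nn as nn
from captum.attr._utils.input_layer_wrapper import ModelInputWrapper

class TwoInputNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(6, 2)

    def forward(self, x, y):
        return self.fc(torch.cat([x, y], dim=1))

wrapped = ModelInputWrapper(TwoInputNet())
out = wrapped(torch.randn(3, 4), y=torch.randn(3, 2))  # use it like the original model
print(out.shape)                # torch.Size([3, 2])
print(wrapped.input_maps["x"])  # the InputIdentity layer attached to input "x"
```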
+ """ + + STABILITY_FACTOR = 1e-9 + + def forward_hook(self, module, inputs, outputs): + """Register backward hooks on input and output + tensors of linear layers in the model.""" + inputs = _format_tensor_into_tuples(inputs) + self._has_single_input = len(inputs) == 1 + self._handle_input_hooks = [] + for input in inputs: + if not hasattr(input, "hook_registered"): + input_hook = self._create_backward_hook_input(input.data) + self._handle_input_hooks.append(input.register_hook(input_hook)) + input.hook_registered = True + output_hook = self._create_backward_hook_output(outputs.data) + self._handle_output_hook = outputs.register_hook(output_hook) + return outputs.clone() + + @staticmethod + def backward_hook_activation(module, grad_input, grad_output): + """Backward hook to propagate relevance over non-linear activations.""" + if ( + isinstance(grad_input, tuple) + and isinstance(grad_output, tuple) + and len(grad_input) > len(grad_output) + ): + # Adds any additional elements of grad_input if applicable + # This occurs when registering a backward hook on nn.Dropout + # modules, which has an additional element of None in + # grad_input + return grad_output + grad_input[len(grad_output) :] + return grad_output + + def _create_backward_hook_input(self, inputs): + def _backward_hook_input(grad): + relevance = grad * inputs + device = grad.device + if self._has_single_input: + self.relevance_input[device] = relevance.data + else: + self.relevance_input[device].append(relevance.data) + return relevance + + return _backward_hook_input + + def _create_backward_hook_output(self, outputs): + def _backward_hook_output(grad): + sign = torch.sign(outputs) + sign[sign == 0] = 1 + relevance = grad / (outputs + sign * self.STABILITY_FACTOR) + self.relevance_output[grad.device] = grad.data + return relevance + + return _backward_hook_output + + def forward_hook_weights(self, module, inputs, outputs): + """Save initial activations a_j before modules are changed""" + device = inputs[0].device if isinstance(inputs, tuple) else inputs.device + if hasattr(module, "activations") and device in module.activations: + raise RuntimeError( + "Module {} is being used more than once in the network, which " + "is not supported by LRP. " + "Please ensure that module is being used only once in the " + "network.".format(module) + ) + module.activations[device] = tuple(input.data for input in inputs) + self._manipulate_weights(module, inputs, outputs) + + @abstractmethod + def _manipulate_weights(self, module, inputs, outputs): + raise NotImplementedError + + def forward_pre_hook_activations(self, module, inputs): + """Pass initial activations to graph generation pass""" + device = inputs[0].device if isinstance(inputs, tuple) else inputs.device + for input, activation in zip(inputs, module.activations[device]): + input.data = activation + return inputs + + +class EpsilonRule(PropagationRule): + """ + Rule for relevance propagation using a small value of epsilon + to avoid numerical instabilities and remove noise. + + Use for middle layers. + + Args: + epsilon (integer, float): Value by which is added to the + discriminator during propagation. + """ + + def __init__(self, epsilon=1e-9) -> None: + self.STABILITY_FACTOR = epsilon + + def _manipulate_weights(self, module, inputs, outputs): + pass + + +class GammaRule(PropagationRule): + """ + Gamma rule for relevance propagation, gives more importance to + positive relevance. + + Use for lower layers. 
+ + Args: + gamma (float): The gamma parameter determines by how much + the positive relevance is increased. + """ + + def __init__(self, gamma=0.25, set_bias_to_zero=False) -> None: + self.gamma = gamma + self.set_bias_to_zero = set_bias_to_zero + + def _manipulate_weights(self, module, inputs, outputs): + if hasattr(module, "weight"): + module.weight.data = ( + module.weight.data + self.gamma * module.weight.data.clamp(min=0) + ) + if self.set_bias_to_zero and hasattr(module, "bias"): + if module.bias is not None: + module.bias.data = torch.zeros_like(module.bias.data) + + +class Alpha1_Beta0_Rule(PropagationRule): + """ + Alpha1_Beta0 rule for relevance backpropagation, also known + as Deep-Taylor. Only positive relevance is propagated, resulting + in stable results, therefore recommended as the initial choice. + + Warning: Does not work for BatchNorm modules because weight and bias + are defined differently. + + Use for lower layers. + """ + + def __init__(self, set_bias_to_zero=False) -> None: + self.set_bias_to_zero = set_bias_to_zero + + def _manipulate_weights(self, module, inputs, outputs): + if hasattr(module, "weight"): + module.weight.data = module.weight.data.clamp(min=0) + if self.set_bias_to_zero and hasattr(module, "bias"): + if module.bias is not None: + module.bias.data = torch.zeros_like(module.bias.data) + + +class IdentityRule(EpsilonRule): + """ + Identity rule for skipping layer manipulation and propagating the + relevance over a layer. Only valid for modules with same dimensions for + inputs and outputs. + + Can be used for BatchNorm2D. + """ + + def _create_backward_hook_input(self, inputs): + def _backward_hook_input(grad): + return self.relevance_output[grad.device] + + return _backward_hook_input diff --git a/captum/attr/_utils/stat.py b/captum/attr/_utils/stat.py new file mode 100644 index 0000000000000000000000000000000000000000..8677642b03dab8ede8955f54332f19cc954b0339 --- /dev/null +++ b/captum/attr/_utils/stat.py @@ -0,0 +1,276 @@ +#!/usr/bin/env python3 +from typing import Any, Callable, List, Optional, TYPE_CHECKING + +import torch +from torch import Tensor + +if TYPE_CHECKING: + from captum.attr._utils.summarizer import SummarizerSingleTensor + + +class Stat: + """ + The Stat class represents a statistic that can be updated and retrieved + at any point in time. + + The basic functionality this class provides is: + 1. A update/get method to actually compute the statistic + 2. A statistic store/cache to retrieve dependent information + (e.g. other stat values that are required for computation) + 3. The name of the statistic that is used for the user to refer to + """ + + def __init__(self, name: Optional[str] = None, **kwargs: Any) -> None: + """ + Args: + name (str, optional): + The name of the statistic. 
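As a concrete check of what `GammaRule._manipulate_weights` does (w <- w + gamma * max(w, 0)), here is a small sketch on a throwaway `nn.Linear`; calling the private method directly is for illustration only.

```python
# Hedged sketch: the GammaRule weight update on a tiny Linear layer.
import torch
import torch.nn as nn
from captum.attr._utils.lrp_rules import GammaRule

layer = nn.Linear(2, 1, bias=False)
with torch.no_grad():
    layer.weight.copy_(torch.tensor([[1.0, -2.0]]))

rule = GammaRule(gamma=0.25)
rule._manipulate_weights(layer, inputs=None, outputs=None)

# Positive weights are amplified, negative weights are left unchanged:
print(layer.weight.data)  # tensor([[ 1.25, -2.00]])
```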
If not provided, + the class name will be used alongside it's parameters + kwargs (Any): + Additional arguments used to construct the statistic + """ + self.params = kwargs + self._name = name + + self._other_stats: Optional[SummarizerSingleTensor] = None + + def init(self): + pass + + def _get_stat(self, stat: "Stat") -> Optional["Stat"]: + assert self._other_stats is not None + return self._other_stats.get(stat) + + def update(self, x: Tensor): + raise NotImplementedError() + + def get(self) -> Optional[Tensor]: + raise NotImplementedError() + + def __hash__(self): + return hash((self.__class__, frozenset(self.params.items()))) + + def __eq__(self, other: object) -> bool: + if isinstance(other, Stat): + return self.__class__ == other.__class__ and frozenset( + self.params.items() + ) == frozenset(other.params.items()) + else: + return False + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) + + @property + def name(self): + """ + The name of the statistic. i.e. it is the key in a .summary + + This will be the class name or a custom name if provided. + + See Summarizer or SummarizerSingleTensor + """ + default_name = self.__class__.__name__.lower() + if len(self.params) > 0: + default_name += f"({self.params})" + + return default_name if self._name is None else self._name + + +class Count(Stat): + """ + Counts the number of elements, i.e. the + number of `update`'s called + """ + + def __init__(self, name: Optional[str] = None) -> None: + super().__init__(name=name) + self.n = None + + def get(self): + return self.n + + def update(self, x): + if self.n is None: + self.n = 0 + self.n += 1 + + +class Mean(Stat): + """ + Calculates the average of a tensor + """ + + def __init__(self, name: Optional[str] = None) -> None: + super().__init__(name=name) + self.rolling_mean: Optional[Tensor] = None + self.n: Optional[Count] = None + + def get(self) -> Optional[Tensor]: + return self.rolling_mean + + def init(self): + self.n = self._get_stat(Count()) + + def update(self, x): + n = self.n.get() + + if self.rolling_mean is None: + # Ensures rolling_mean is a float tensor + self.rolling_mean = x.clone() if x.is_floating_point() else x.double() + else: + delta = x - self.rolling_mean + self.rolling_mean += delta / n + + +class MSE(Stat): + """ + Calculates the mean squared error of a tensor + """ + + def __init__(self, name: Optional[str] = None) -> None: + super().__init__(name=name) + self.prev_mean = None + self.mse = None + + def init(self): + self.mean = self._get_stat(Mean()) + + def get(self) -> Optional[Tensor]: + if self.mse is None and self.prev_mean is not None: + return torch.zeros_like(self.prev_mean) + return self.mse + + def update(self, x: Tensor): + mean = self.mean.get() + + if mean is not None and self.prev_mean is not None: + rhs = (x - self.prev_mean) * (x - mean) + if self.mse is None: + self.mse = rhs + else: + self.mse += rhs + + # do not not clone + self.prev_mean = mean.clone() + + +class Var(Stat): + """ + Calculates the variance of a tensor, with an order. e.g. + if `order = 1` then it will calculate sample variance. 
+ + This is equal to mse / (n - order) + """ + + def __init__(self, name: Optional[str] = None, order: int = 0) -> None: + if name is None: + if order == 0: + name = "variance" + elif order == 1: + name = "sample_variance" + else: + name = f"variance({order})" + + super().__init__(name=name, order=order) + self.order = order + + def init(self): + self.mse = self._get_stat(MSE()) + self.n = self._get_stat(Count()) + + def update(self, x: Tensor): + pass + + def get(self) -> Optional[Tensor]: + mse = self.mse.get() + n = self.n.get() + + if mse is None: + return None + + if n <= self.order: + return torch.zeros_like(mse) + + # NOTE: The following ensures mse is a float tensor. + # torch.true_divide is available in PyTorch 1.5 and later. + # This is for compatibility with 1.4. + return mse.to(torch.float64) / (n - self.order) + + +class StdDev(Stat): + """ + The standard deviation, with an associated order. + """ + + def __init__(self, name: Optional[str] = None, order: int = 0) -> None: + if name is None: + if order == 0: + name = "std_dev" + elif order == 1: + name = "sample_std_dev" + else: + name = f"std_dev{order})" + + super().__init__(name=name, order=order) + self.order = order + + def init(self): + self.var = self._get_stat(Var(order=self.order)) + + def update(self, x: Tensor): + pass + + def get(self) -> Optional[Tensor]: + var = self.var.get() + return var ** 0.5 if var is not None else None + + +class GeneralAccumFn(Stat): + """ + Performs update(x): result = fn(result, x) + where fn is a custom function + """ + + def __init__(self, fn: Callable, name: Optional[str] = None) -> None: + super().__init__(name=name) + self.result = None + self.fn = fn + + def get(self) -> Optional[Tensor]: + return self.result + + def update(self, x): + if self.result is None: + self.result = x + else: + self.result = self.fn(self.result, x) + + +class Min(GeneralAccumFn): + def __init__( + self, name: Optional[str] = None, min_fn: Callable = torch.min + ) -> None: + super().__init__(name=name, fn=min_fn) + + +class Max(GeneralAccumFn): + def __init__( + self, name: Optional[str] = None, max_fn: Callable = torch.max + ) -> None: + super().__init__(name=name, fn=max_fn) + + +class Sum(GeneralAccumFn): + def __init__( + self, name: Optional[str] = None, add_fn: Callable = torch.add + ) -> None: + super().__init__(name=name, fn=add_fn) + + +def CommonStats() -> List[Stat]: + r""" + Returns common summary statistics, specifically: + Mean, Sample Variance, Sample Std Dev, Min, Max + """ + return [Mean(), Var(order=1), StdDev(order=1), Min(), Max()] diff --git a/captum/attr/_utils/summarizer.py b/captum/attr/_utils/summarizer.py new file mode 100644 index 0000000000000000000000000000000000000000..874e5d263b5811ecfdc3055f1608bccfab1832bc --- /dev/null +++ b/captum/attr/_utils/summarizer.py @@ -0,0 +1,233 @@ +#!/usr/bin/env python3 + +from typing import Dict, List, Optional, Tuple, Type, Union + +import torch +from captum.attr._utils.stat import Count, Max, Mean, Min, MSE, Stat, StdDev, Sum, Var +from captum.log import log_usage +from torch import Tensor + + +class Summarizer: + r""" + This class simply wraps over a given a set of SummarizerSingleTensor's in order + to summarise multiple input tensors. 
+ + Basic usage: + + >>>from captum.attr.aggregator import Summarizer + >>>from captum.attr._utils.stats import Mean, StdDev + >>> + >>>attrib = torch.tensor([1, 2, 3, 4, 5]) + >>> + >>>summ = Summarizer([Mean(), StdDev(0]) + >>>summ.update(attrib) + >>> + >>>print(summ.summary['mean']) + """ + + @log_usage() + def __init__(self, stats: List[Stat]) -> None: + r""" + Args: + stats (List[Stat]): + The list of statistics you wish to track + """ + self._summarizers: List[SummarizerSingleTensor] = [] + self._is_inputs_tuple: Optional[bool] = None + self._stats, self._summary_stats_indicies = _reorder_stats(stats) + + def _copy_stats(self): + import copy + + return copy.deepcopy(self._stats) + + def update(self, x: Union[float, Tensor, Tuple[Union[float, Tensor], ...]]): + r""" + Calls `update` on each `Stat` object within the summarizer + + Args: + x (Tensor or Tuple[Tensor, ...]): + The input(s) you wish to summarize + """ + if self._is_inputs_tuple is None: + self._is_inputs_tuple = isinstance(x, tuple) + else: + # we want input to be consistently a single input or a tuple + assert not (self._is_inputs_tuple ^ isinstance(x, tuple)) + + from captum._utils.common import _format_float_or_tensor_into_tuples + + x = _format_float_or_tensor_into_tuples(x) + + for i, inp in enumerate(x): + if i >= len(self._summarizers): + # _summarizers[i] is a new SummarizerSingleTensor, which + # aims to summarize input i (i.e. x[i]) + # + # Thus, we must copy our stats, as otherwise + # in the best case the statistics for each input will be mangled + # and in the worst case we will run into an error due to different + # dimensionality in the input tensors tensors (i.e. + # x[i].shape != x[j].shape for some pair i, j) + stats = self._copy_stats() + self._summarizers.append( + SummarizerSingleTensor( + stats=stats, summary_stats_indices=self._summary_stats_indicies + ) + ) + if not isinstance(inp, torch.Tensor): + inp = torch.tensor(inp, dtype=torch.float) + self._summarizers[i].update(inp) + + @property + def summary( + self, + ) -> Optional[ + Union[Dict[str, Optional[Tensor]], List[Dict[str, Optional[Tensor]]]] + ]: + r""" + Effectively calls `get` on each `Stat` object within this object for each input + + Returns: + A dict or list of dict: mapping from the Stat + object's `name` to the associated value of `get` + """ + if len(self._summarizers) == 0: + return None + + temp = [summ.summary for summ in self._summarizers] + return temp if self._is_inputs_tuple else temp[0] + + +def _reorder_stats(stats: List[Stat]) -> Tuple[List[Stat], List[int]]: + # We want to want to store two things: + # 1. A mapping from a Stat to Stat object (self._stat_to_stat): + # This is to retrieve an existing Stat object for dependency + # resolution, e.g. Mean needs the Count stat - we want to + # retrieve it in O(1) + # + # 2. All of the necessary stats, in the correct order, + # to perform an update for each Stat (self.stats) trivially + + # As a reference, the dependency graph for our stats is as follows: + # StdDev(x) -> Var(x) -> MSE -> Mean -> Count, for all valid x + # + # Step 1: + # Ensure we have all the necessary stats + # i.e. 
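A minimal usage sketch for `Summarizer` together with the `CommonStats()` bundle from `stat.py`; the update tensors are random stand-ins.

```python
# Hedged sketch: summarizing a stream of attribution tensors element-wise.
import torch
from captum.attr._utils.stat import CommonStats
from captum.attr._utils.summarizer import Summarizer

summ = Summarizer(CommonStats())
for _ in range(8):
    summ.update(torch.randn(16))

s = summ.summary          # single (non-tuple) input -> a single dict
print(sorted(s.keys()))   # ['max', 'mean', 'min', 'sample_std_dev', 'sample_variance']
print(s["mean"].shape)    # torch.Size([16]) -- stats are computed element-wise
```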
ensure we have the dependencies + # Step 2: + # Figure out the order to update them + dep_order = [StdDev, Var, MSE, Mean, Count] + + # remove dupe stats + stats = set(stats) + summary_stats = set(stats) + + from collections import defaultdict + + stats_by_module: Dict[Type, List[Stat]] = defaultdict(list) + for stat in stats: + stats_by_module[stat.__class__].append(stat) + + # StdDev is an odd case since it is parameterized, thus + # for each StdDev(order) we must ensure there is an associated Var(order) + for std_dev in stats_by_module[StdDev]: + stat_to_add = Var(order=std_dev.order) # type: ignore + stats.add(stat_to_add) + stats_by_module[stat_to_add.__class__].append(stat_to_add) + + # For the other modules (deps[1:n-1]): if i exists => + # we want to ensure i...n-1 exists + for i, dep in enumerate(dep_order[1:]): + if dep in stats_by_module: + stats.update([mod() for mod in dep_order[i + 1 :]]) + break + + # Step 2: get the correct order + # NOTE: we are sorting via a given topological order + sort_order = {mod: i for i, mod in enumerate(dep_order)} + sort_order[Min] = -1 + sort_order[Max] = -1 + sort_order[Sum] = -1 + + stats = list(stats) + stats.sort(key=lambda x: sort_order[x.__class__], reverse=True) + + # get the summary stat indices + summary_stat_indexs = [] + for i, stat in enumerate(stats): + if stat in summary_stats: + summary_stat_indexs.append(i) + return stats, summary_stat_indexs + + +class SummarizerSingleTensor: + r""" + A simple class that summarizes a single tensor. The basic functionality + of this class is two operations .update and .summary + + If possible use `Summarizer` instead. + """ + + def __init__(self, stats: List[Stat], summary_stats_indices: List[int]) -> None: + r""" + Args: + stats (list of Stat): A list of all the Stat objects that + need to be updated. This must be in the appropriate order for + updates (see `_reorder_stats`) + summary_stats (list of int): A list of indicies, referencing `stats`, + which are the stats you want to show in the .summary property. This + does not require any specific order. + """ + self._stats = stats + self._stat_to_stat = {stat: stat for stat in self._stats} + self._summary_stats = [stats[i] for i in summary_stats_indices] + + for stat in stats: + stat._other_stats = self + stat.init() + + def update(self, x: Tensor): + r""" + Updates the summary of a given tensor `x` + + Args: + x (Tensor): + The tensor to summarize + """ + for stat in self._stats: + stat.update(x) + + def get(self, stat: Stat) -> Optional[Stat]: + r""" + Retrieves `stat` from cache if this summarizer contains it. + + Note that `Stat` has it's hash/equality method overridden, such + that an object with the same class and parameters will have the + same hash. Thus, if you call `get` with a `Stat`, an associated + `Stat` with the same class and parameters belonging to this object + will be retrieved if it exists. + + If no such object is retrieved then `None` is returned. 
+ + Args: + stat (Stat): + The stat to retrieve + Returns: + Stat + The cached stat object or `None` + """ + if stat not in self._stat_to_stat: + return None + + return self._stat_to_stat[stat] + + @property + def summary(self) -> Dict[str, Optional[Tensor]]: + """ + Returns: + Optional[Dict[str, Optional[Tensor]]] + The cached stat object + """ + return {stat.name: stat.get() for stat in self._summary_stats} diff --git a/captum/attr/_utils/visualization.py b/captum/attr/_utils/visualization.py new file mode 100644 index 0000000000000000000000000000000000000000..241759f653ed3b76ddbbd74f2955b27792040652 --- /dev/null +++ b/captum/attr/_utils/visualization.py @@ -0,0 +1,573 @@ +#!/usr/bin/env python3 +import warnings +from enum import Enum +from typing import Any, Iterable, List, Tuple, Union + +import numpy as np +from matplotlib import pyplot as plt +from matplotlib.colors import LinearSegmentedColormap +from matplotlib.figure import Figure +from matplotlib.pyplot import axis, figure +from mpl_toolkits.axes_grid1 import make_axes_locatable +from numpy import ndarray + +try: + from IPython.core.display import display, HTML + + HAS_IPYTHON = True +except ImportError: + HAS_IPYTHON = False + + +class ImageVisualizationMethod(Enum): + heat_map = 1 + blended_heat_map = 2 + original_image = 3 + masked_image = 4 + alpha_scaling = 5 + + +class VisualizeSign(Enum): + positive = 1 + absolute_value = 2 + negative = 3 + all = 4 + + +def _prepare_image(attr_visual: ndarray): + return np.clip(attr_visual.astype(int), 0, 255) + + +def _normalize_scale(attr: ndarray, scale_factor: float): + assert scale_factor != 0, "Cannot normalize by scale factor = 0" + if abs(scale_factor) < 1e-5: + warnings.warn( + "Attempting to normalize by value approximately 0, visualized results" + "may be misleading. This likely means that attribution values are all" + "close to 0." + ) + attr_norm = attr / scale_factor + return np.clip(attr_norm, -1, 1) + + +def _cumulative_sum_threshold(values: ndarray, percentile: Union[int, float]): + # given values should be non-negative + assert percentile >= 0 and percentile <= 100, ( + "Percentile for thresholding must be " "between 0 and 100 inclusive." + ) + sorted_vals = np.sort(values.flatten()) + cum_sums = np.cumsum(sorted_vals) + threshold_id = np.where(cum_sums >= cum_sums[-1] * 0.01 * percentile)[0][0] + return sorted_vals[threshold_id] + + +def _normalize_image_attr( + attr: ndarray, sign: str, outlier_perc: Union[int, float] = 2 +): + attr_combined = np.sum(attr, axis=2) + # Choose appropriate signed values and rescale, removing given outlier percentage. 
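A numeric sketch (invented values) of the thresholding used above: `_cumulative_sum_threshold` returns the value at which the running sum of the sorted inputs first reaches the given percentage of the total, and `_normalize_image_attr` uses it to clip outliers before scaling into [-1, 1].

```python
# Hedged sketch of the outlier-percentile threshold and normalization above.
import numpy as np
from captum.attr._utils.visualization import (
    _cumulative_sum_threshold,
    _normalize_image_attr,
)

values = np.array([1.0, 2.0, 3.0, 4.0, 90.0])  # total mass = 100
# The running sum of sorted values reaches 95% of the total only at the last
# element, so the threshold for percentile=95 is 90.0; for percentile=10 it is
# 4.0, since 1 + 2 + 3 + 4 = 10 already covers 10% of the total.
print(_cumulative_sum_threshold(values, 95))  # 90.0
print(_cumulative_sum_threshold(values, 10))  # 4.0

# End to end: normalize an (H, W, C) attribution map for sign="positive".
attr = np.abs(np.random.randn(8, 8, 3))
norm = _normalize_image_attr(attr, sign="positive", outlier_perc=2)
print(norm.shape, norm.min() >= 0.0, norm.max() <= 1.0)  # (8, 8) True True
```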
+ if VisualizeSign[sign] == VisualizeSign.all: + threshold = _cumulative_sum_threshold(np.abs(attr_combined), 100 - outlier_perc) + elif VisualizeSign[sign] == VisualizeSign.positive: + attr_combined = (attr_combined > 0) * attr_combined + threshold = _cumulative_sum_threshold(attr_combined, 100 - outlier_perc) + elif VisualizeSign[sign] == VisualizeSign.negative: + attr_combined = (attr_combined < 0) * attr_combined + threshold = -1 * _cumulative_sum_threshold( + np.abs(attr_combined), 100 - outlier_perc + ) + elif VisualizeSign[sign] == VisualizeSign.absolute_value: + attr_combined = np.abs(attr_combined) + threshold = _cumulative_sum_threshold(attr_combined, 100 - outlier_perc) + else: + raise AssertionError("Visualize Sign type is not valid.") + return _normalize_scale(attr_combined, threshold) + + +def visualize_image_attr( + attr: ndarray, + original_image: Union[None, ndarray] = None, + method: str = "heat_map", + sign: str = "absolute_value", + plt_fig_axis: Union[None, Tuple[figure, axis]] = None, + outlier_perc: Union[int, float] = 2, + cmap: Union[None, str] = None, + alpha_overlay: float = 0.5, + show_colorbar: bool = False, + title: Union[None, str] = None, + fig_size: Tuple[int, int] = (6, 6), + use_pyplot: bool = True, +): + r""" + Visualizes attribution for a given image by normalizing attribution values + of the desired sign (positive, negative, absolute value, or all) and displaying + them using the desired mode in a matplotlib figure. + + Args: + + attr (numpy.array): Numpy array corresponding to attributions to be + visualized. Shape must be in the form (H, W, C), with + channels as last dimension. Shape must also match that of + the original image if provided. + original_image (numpy.array, optional): Numpy array corresponding to + original image. Shape must be in the form (H, W, C), with + channels as the last dimension. Image can be provided either + with float values in range 0-1 or int values between 0-255. + This is a necessary argument for any visualization method + which utilizes the original image. + Default: None + method (string, optional): Chosen method for visualizing attribution. + Supported options are: + + 1. `heat_map` - Display heat map of chosen attributions + + 2. `blended_heat_map` - Overlay heat map over greyscale + version of original image. Parameter alpha_overlay + corresponds to alpha of heat map. + + 3. `original_image` - Only display original image. + + 4. `masked_image` - Mask image (pixel-wise multiply) + by normalized attribution values. + + 5. `alpha_scaling` - Sets alpha channel of each pixel + to be equal to normalized attribution value. + Default: `heat_map` + sign (string, optional): Chosen sign of attributions to visualize. Supported + options are: + + 1. `positive` - Displays only positive pixel attributions. + + 2. `absolute_value` - Displays absolute value of + attributions. + + 3. `negative` - Displays only negative pixel attributions. + + 4. `all` - Displays both positive and negative attribution + values. This is not supported for `masked_image` or + `alpha_scaling` modes, since signed information cannot + be represented in these modes. + Default: `absolute_value` + plt_fig_axis (tuple, optional): Tuple of matplotlib.pyplot.figure and axis + on which to visualize. If None is provided, then a new figure + and axis are created. 
+ Default: None + outlier_perc (float or int, optional): Top attribution values which + correspond to a total of outlier_perc percentage of the + total attribution are set to 1 and scaling is performed + using the minimum of these values. For sign=`all`, outliers + and scale value are computed using absolute value of + attributions. + Default: 2 + cmap (string, optional): String corresponding to desired colormap for + heatmap visualization. This defaults to "Reds" for negative + sign, "Blues" for absolute value, "Greens" for positive sign, + and a spectrum from red to green for all. Note that this + argument is only used for visualizations displaying heatmaps. + Default: None + alpha_overlay (float, optional): Alpha to set for heatmap when using + `blended_heat_map` visualization mode, which overlays the + heat map over the greyscaled original image. + Default: 0.5 + show_colorbar (boolean, optional): Displays colorbar for heatmap below + the visualization. If given method does not use a heatmap, + then a colormap axis is created and hidden. This is + necessary for appropriate alignment when visualizing + multiple plots, some with colorbars and some without. + Default: False + title (string, optional): Title string for plot. If None, no title is + set. + Default: None + fig_size (tuple, optional): Size of figure created. + Default: (6,6) + use_pyplot (boolean, optional): If true, uses pyplot to create and show + figure and displays the figure after creating. If False, + uses Matplotlib object oriented API and simply returns a + figure object without showing. + Default: True. + + Returns: + 2-element tuple of **figure**, **axis**: + - **figure** (*matplotlib.pyplot.figure*): + Figure object on which visualization + is created. If plt_fig_axis argument is given, this is the + same figure provided. + - **axis** (*matplotlib.pyplot.axis*): + Axis object on which visualization + is created. If plt_fig_axis argument is given, this is the + same axis provided. + + Examples:: + + >>> # ImageClassifier takes a single input tensor of images Nx3x32x32, + >>> # and returns an Nx10 tensor of class probabilities. + >>> net = ImageClassifier() + >>> ig = IntegratedGradients(net) + >>> # Computes integrated gradients for class 3 for a given image . + >>> attribution, delta = ig.attribute(orig_image, target=3) + >>> # Displays blended heat map visualization of computed attributions. + >>> _ = visualize_image_attr(attribution, orig_image, "blended_heat_map") + """ + # Create plot if figure, axis not provided + if plt_fig_axis is not None: + plt_fig, plt_axis = plt_fig_axis + else: + if use_pyplot: + plt_fig, plt_axis = plt.subplots(figsize=fig_size) + else: + plt_fig = Figure(figsize=fig_size) + plt_axis = plt_fig.subplots() + + if original_image is not None: + if np.max(original_image) <= 1.0: + original_image = _prepare_image(original_image * 255) + else: + assert ( + ImageVisualizationMethod[method] == ImageVisualizationMethod.heat_map + ), "Original Image must be provided for any visualization other than heatmap." + + # Remove ticks and tick labels from plot. 
+ plt_axis.xaxis.set_ticks_position("none") + plt_axis.yaxis.set_ticks_position("none") + plt_axis.set_yticklabels([]) + plt_axis.set_xticklabels([]) + plt_axis.grid(b=False) + + heat_map = None + # Show original image + if ImageVisualizationMethod[method] == ImageVisualizationMethod.original_image: + if len(original_image.shape) > 2 and original_image.shape[2] == 1: + original_image = np.squeeze(original_image, axis=2) + plt_axis.imshow(original_image) + else: + # Choose appropriate signed attributions and normalize. + norm_attr = _normalize_image_attr(attr, sign, outlier_perc) + + # Set default colormap and bounds based on sign. + if VisualizeSign[sign] == VisualizeSign.all: + default_cmap = LinearSegmentedColormap.from_list( + "RdWhGn", ["red", "white", "green"] + ) + vmin, vmax = -1, 1 + elif VisualizeSign[sign] == VisualizeSign.positive: + default_cmap = "Greens" + vmin, vmax = 0, 1 + elif VisualizeSign[sign] == VisualizeSign.negative: + default_cmap = "Reds" + vmin, vmax = 0, 1 + elif VisualizeSign[sign] == VisualizeSign.absolute_value: + default_cmap = "Blues" + vmin, vmax = 0, 1 + else: + raise AssertionError("Visualize Sign type is not valid.") + cmap = cmap if cmap is not None else default_cmap + + # Show appropriate image visualization. + if ImageVisualizationMethod[method] == ImageVisualizationMethod.heat_map: + heat_map = plt_axis.imshow(norm_attr, cmap=cmap, vmin=vmin, vmax=vmax) + elif ( + ImageVisualizationMethod[method] + == ImageVisualizationMethod.blended_heat_map + ): + plt_axis.imshow(np.mean(original_image, axis=2), cmap="gray") + heat_map = plt_axis.imshow( + norm_attr, cmap=cmap, vmin=vmin, vmax=vmax, alpha=alpha_overlay + ) + elif ImageVisualizationMethod[method] == ImageVisualizationMethod.masked_image: + assert VisualizeSign[sign] != VisualizeSign.all, ( + "Cannot display masked image with both positive and negative " + "attributions, choose a different sign option." + ) + plt_axis.imshow( + _prepare_image(original_image * np.expand_dims(norm_attr, 2)) + ) + elif ImageVisualizationMethod[method] == ImageVisualizationMethod.alpha_scaling: + assert VisualizeSign[sign] != VisualizeSign.all, ( + "Cannot display alpha scaling with both positive and negative " + "attributions, choose a different sign option." + ) + plt_axis.imshow( + np.concatenate( + [ + original_image, + _prepare_image(np.expand_dims(norm_attr, 2) * 255), + ], + axis=2, + ) + ) + else: + raise AssertionError("Visualize Method type is not valid.") + + # Add colorbar. If given method is not a heatmap and no colormap is relevant, + # then a colormap axis is created and hidden. This is necessary for appropriate + # alignment when visualizing multiple plots, some with heatmaps and some + # without. + if show_colorbar: + axis_separator = make_axes_locatable(plt_axis) + colorbar_axis = axis_separator.append_axes("bottom", size="5%", pad=0.1) + if heat_map: + plt_fig.colorbar(heat_map, orientation="horizontal", cax=colorbar_axis) + else: + colorbar_axis.axis("off") + if title: + plt_axis.set_title(title) + + if use_pyplot: + plt.show() + + return plt_fig, plt_axis + + +def visualize_image_attr_multiple( + attr: ndarray, + original_image: Union[None, ndarray], + methods: List[str], + signs: List[str], + titles: Union[None, List[str]] = None, + fig_size: Tuple[int, int] = (8, 6), + use_pyplot: bool = True, + **kwargs: Any, +): + r""" + Visualizes attribution using multiple visualization methods displayed + in a 1 x k grid, where k is the number of desired visualizations. 
+ + Args: + + attr (numpy.array): Numpy array corresponding to attributions to be + visualized. Shape must be in the form (H, W, C), with + channels as last dimension. Shape must also match that of + the original image if provided. + original_image (numpy.array, optional): Numpy array corresponding to + original image. Shape must be in the form (H, W, C), with + channels as the last dimension. Image can be provided either + with values in range 0-1 or 0-255. This is a necessary + argument for any visualization method which utilizes + the original image. + methods (list of strings): List of strings of length k, defining method + for each visualization. Each method must be a valid + string argument for method to visualize_image_attr. + signs (list of strings): List of strings of length k, defining signs for + each visualization. Each sign must be a valid + string argument for sign to visualize_image_attr. + titles (list of strings, optional): List of strings of length k, providing + a title string for each plot. If None is provided, no titles + are added to subplots. + Default: None + fig_size (tuple, optional): Size of figure created. + Default: (8, 6) + use_pyplot (boolean, optional): If true, uses pyplot to create and show + figure and displays the figure after creating. If False, + uses Matplotlib object oriented API and simply returns a + figure object without showing. + Default: True. + **kwargs (Any, optional): Any additional arguments which will be passed + to every individual visualization. Such arguments include + `show_colorbar`, `alpha_overlay`, `cmap`, etc. + + + Returns: + 2-element tuple of **figure**, **axis**: + - **figure** (*matplotlib.pyplot.figure*): + Figure object on which visualization + is created. If plt_fig_axis argument is given, this is the + same figure provided. + - **axis** (*matplotlib.pyplot.axis*): + Axis object on which visualization + is created. If plt_fig_axis argument is given, this is the + same axis provided. + + Examples:: + + >>> # ImageClassifier takes a single input tensor of images Nx3x32x32, + >>> # and returns an Nx10 tensor of class probabilities. + >>> net = ImageClassifier() + >>> ig = IntegratedGradients(net) + >>> # Computes integrated gradients for class 3 for a given image . + >>> attribution, delta = ig.attribute(orig_image, target=3) + >>> # Displays original image and heat map visualization of + >>> # computed attributions side by side. + >>> _ = visualize_image_attr_multiple(attribution, orig_image, + >>> ["original_image", "heat_map"], ["all", "positive"]) + """ + assert len(methods) == len(signs), "Methods and signs array lengths must match." + if titles is not None: + assert len(methods) == len(titles), ( + "If titles list is given, length must " "match that of methods list." + ) + if use_pyplot: + plt_fig = plt.figure(figsize=fig_size) + else: + plt_fig = Figure(figsize=fig_size) + plt_axis = plt_fig.subplots(1, len(methods)) + + # When visualizing one + if len(methods) == 1: + plt_axis = [plt_axis] + + for i in range(len(methods)): + visualize_image_attr( + attr, + original_image=original_image, + method=methods[i], + sign=signs[i], + plt_fig_axis=(plt_fig, plt_axis[i]), + use_pyplot=False, + title=titles[i] if titles else None, + **kwargs, + ) + plt_fig.tight_layout() + if use_pyplot: + plt.show() + return plt_fig, plt_axis + + +# These visualization methods are for text and are partially copied from +# experiments conducted by Davide Testuggine at Facebook. 
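For reference, the image-attribution helpers defined above (visualize_image_attr and visualize_image_attr_multiple) are normally driven from user code; the following is a minimal, illustrative sketch only, assuming the module is importable as captum.attr._utils.visualization and using random placeholder arrays and a placeholder output path rather than real data:

    import numpy as np
    from captum.attr._utils.visualization import visualize_image_attr_multiple

    # Placeholder inputs: an HxWxC image in [0, 1] and an attribution map of the same shape.
    rng = np.random.default_rng(0)
    img = rng.random((32, 100, 3)).astype(np.float32)
    attr = rng.normal(size=(32, 100, 3)).astype(np.float32)

    # With use_pyplot=False the helpers stay on matplotlib's object-oriented API,
    # so the returned Figure can be saved to disk instead of shown interactively.
    fig, axes = visualize_image_attr_multiple(
        attr,
        img,
        methods=["original_image", "blended_heat_map"],
        signs=["all", "absolute_value"],
        show_colorbar=True,
        use_pyplot=False,
    )
    fig.savefig("attr_panels.png", bbox_inches="tight")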
+
+
+class VisualizationDataRecord:
+    r"""
+    A data record for storing attribution relevant information
+    """
+    __slots__ = [
+        "word_attributions",
+        "pred_prob",
+        "pred_class",
+        "true_class",
+        "attr_class",
+        "attr_score",
+        "raw_input_ids",
+        "convergence_score",
+    ]
+
+    def __init__(
+        self,
+        word_attributions,
+        pred_prob,
+        pred_class,
+        true_class,
+        attr_class,
+        attr_score,
+        raw_input_ids,
+        convergence_score,
+    ) -> None:
+        self.word_attributions = word_attributions
+        self.pred_prob = pred_prob
+        self.pred_class = pred_class
+        self.true_class = true_class
+        self.attr_class = attr_class
+        self.attr_score = attr_score
+        self.raw_input_ids = raw_input_ids
+        self.convergence_score = convergence_score
+
+
+def _get_color(attr):
+    # clip values to prevent CSS errors (Values should be from [-1,1])
+    attr = max(-1, min(1, attr))
+    if attr > 0:
+        hue = 120
+        sat = 75
+        lig = 100 - int(50 * attr)
+    else:
+        hue = 0
+        sat = 75
+        lig = 100 - int(-40 * attr)
+    return "hsl({}, {}%, {}%)".format(hue, sat, lig)
+
+
+def format_classname(classname):
+    return '<td><text style="padding-right:2em"><b>{}</b></text></td>'.format(classname)
+
+
+def format_special_tokens(token):
+    if token.startswith("<") and token.endswith(">"):
+        return "#" + token.strip("<>")
+    return token
+
+
+def format_tooltip(item, text):
+    return '<div class="tooltip">{item}\
+        <span class="tooltiptext">{text}</span>\
+        </div>'.format(
+        item=item, text=text
+    )
+
+
+def format_word_importances(words, importances):
+    if importances is None or len(importances) == 0:
+        return "<td></td>"
+    assert len(words) <= len(importances)
+    tags = ["<td>"]
+    for word, importance in zip(words, importances[: len(words)]):
+        word = format_special_tokens(word)
+        color = _get_color(importance)
+        unwrapped_tag = '<mark style="background-color: {color}; opacity:1.0; \
+                    line-height:1.75"><font color="black"> {word}\
+                    </font></mark>'.format(
+            color=color, word=word
+        )
+        tags.append(unwrapped_tag)
+    tags.append("</td>")
+    return "".join(tags)
+
+
+def visualize_text(
+    datarecords: Iterable[VisualizationDataRecord], legend: bool = True
+) -> "HTML":  # In quotes because this type doesn't exist in standalone mode
+    assert HAS_IPYTHON, (
+        "IPython must be available to visualize text. "
+        "Please run 'pip install ipython'."
+    )
+    dom = ["<table width: 100%>"]
+    rows = [
+        "<tr><th>True Label</th>"
+        "<th>Predicted Label</th>"
+        "<th>Attribution Label</th>"
+        "<th>Attribution Score</th>"
+        "<th>Word Importance</th>"
+    ]
+    for datarecord in datarecords:
+        rows.append(
+            "".join(
+                [
+                    "<tr>",
+                    format_classname(datarecord.true_class),
+                    format_classname(
+                        "{0} ({1:.2f})".format(
+                            datarecord.pred_class, datarecord.pred_prob
+                        )
+                    ),
+                    format_classname(datarecord.attr_class),
+                    format_classname("{0:.2f}".format(datarecord.attr_score)),
+                    format_word_importances(
+                        datarecord.raw_input_ids, datarecord.word_attributions
+                    ),
+                    "</tr>",
+                ]
+            )
+        )
+
+    if legend:
+        dom.append(
+            '<div style="border-top: 1px solid; margin-top: 5px; \
+            padding-top: 5px; display: inline-block">'
+        )
+        dom.append("<b>Legend: </b>")
+
+        for value, label in zip([-1, 0, 1], ["Negative", "Neutral", "Positive"]):
+            dom.append(
+                '<span style="display: inline-block; width: 10px; height: 10px; \
+                border: 1px solid; background-color: \
+                {value}"></span> {label}  '.format(
+                    value=_get_color(value), label=label
+                )
+            )
+        dom.append("</div>")
+
+    dom.append("".join(rows))
+    dom.append("</table>
") + html = HTML("".join(dom)) + display(html) + + return html diff --git a/captum/concept/__init__.py b/captum/concept/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0a1eee9e117c876c1d0a72754e394d59216f6618 --- /dev/null +++ b/captum/concept/__init__.py @@ -0,0 +1,5 @@ +#!/usr/bin/env python3 +from captum.concept._core.cav import CAV # noqa +from captum.concept._core.concept import Concept, ConceptInterpreter # noqa +from captum.concept._core.tcav import TCAV # noqa +from captum.concept._utils.classifier import Classifier, DefaultClassifier # noqa diff --git a/captum/concept/_core/__init__.py b/captum/concept/_core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/captum/concept/_core/cav.py b/captum/concept/_core/cav.py new file mode 100644 index 0000000000000000000000000000000000000000..39aa9fba85655cdb3934c79f35eb48d07061677a --- /dev/null +++ b/captum/concept/_core/cav.py @@ -0,0 +1,177 @@ +#!/usr/bin/env python3 + +import os +from typing import Any, Dict, List + +import torch +from captum.concept._core.concept import Concept +from captum.concept._utils.common import concepts_to_str + + +class CAV: + r""" + Concept Activation Vector (CAV) is a vector orthogonal to the decision + boundary of a classifier which distinguishes between activation + vectors produced by different concepts. + More details can be found in the paper: + https://arxiv.org/pdf/1711.11279.pdf + """ + + def __init__( + self, + concepts: List[Concept], + layer: str, + stats: Dict[str, Any] = None, + save_path: str = "./cav/", + model_id: str = "default_model_id", + ) -> None: + r""" + This class encapsulates the instances of CAVs objects, saves them in + and loads them from the disk (storage). + + Args: + concepts (list[Concept]): a List of Concept objects. Only their + names will be saved and loaded. + layer (str): The layer where concept activation vectors are + computed using a predefined classifier. + stats (dict, optional): a dictionary that retains information about + the CAV classifier such as CAV weights and accuracies. + Ex.: stats = {"weights": weights, "classes": classes, + "accs": accs}, where "weights" are learned + model parameters, "classes" are a list of classes used + by the model to generate the "weights" and "accs" + the classifier training or validation accuracy. + save_path (str, optional): The path where the CAV objects are stored. + model_id (str, optional): A unique model identifier associated with + this CAV instance. + """ + + self.concepts = concepts + self.layer = layer + self.stats = stats + self.save_path = save_path + self.model_id = model_id + + @staticmethod + def assemble_save_path( + path: str, model_id: str, concepts: List[Concept], layer: str + ) -> str: + r""" + A utility method for assembling filename and its path, from + a concept list and a layer name. + + Args: + path (str): A path to be concatenated with the concepts key and + layer name. + model_id (str): A unique model identifier associated with input + `layer` and `concepts` + concepts (list(Concept)): A list of concepts that are concatenated + together and used as a concept key using their ids. These + concept ids are retrieved from TCAV s`Concept` objects. + layer (str): The name of the layer for which the activations are + computed. + + Returns: + cav_path(str): A string containing the path where the computed CAVs + will be stored. 
+ For example, given: + concept_ids = [0, 1, 2] + concept_names = ["striped", "random_0", "random_1"] + layer = "inception4c" + path = "/cavs", + the resulting save path will be: + "/cavs/default_model_id/0-1-2-inception4c.pkl" + + """ + + file_name = concepts_to_str(concepts) + "-" + layer + ".pkl" + return os.path.join(path, model_id, file_name) + + def save(self): + r""" + Saves a dictionary of the CAV computed values into a pickle file in the + location returned by the "assemble_save_path" static methods. The + dictionary contains the concept names list, the layer name for which + the activations are computed for, the stats dictionary which contains + information about the classifier train/eval statistics such as the + weights and training accuracies. Ex.: + + save_dict = { + "concept_ids": [0, 1, 2], + "concept_names": ["striped", "random_0", "random_1"], + "layer": "inception4c", + "stats": {"weights": weights, "classes": classes, "accs": accs} + } + + """ + + save_dict = { + "concept_ids": [c.id for c in self.concepts], + "concept_names": [c.name for c in self.concepts], + "layer": self.layer, + "stats": self.stats, + } + + cavs_path = CAV.assemble_save_path( + self.save_path, self.model_id, self.concepts, self.layer + ) + torch.save(save_dict, cavs_path) + + @staticmethod + def create_cav_dir_if_missing(save_path: str, model_id: str) -> None: + r""" + A utility function for creating the directories where the CAVs will + be stored. CAVs are saved in a folder under named by `model_id` + under `save_path`. + Args: + save_path (str): A root path where the CAVs will be stored + model_id (str): A unique model identifier associated with the + CAVs. A folder named `model_id` is created under + `save_path`. The CAVs are later stored there. + """ + cav_model_id_path = os.path.join(save_path, model_id) + if not os.path.exists(cav_model_id_path): + os.makedirs(cav_model_id_path) + + @staticmethod + def load(cavs_path: str, model_id: str, concepts: List[Concept], layer: str): + r""" + Loads CAV dictionary from a pickle file for given input + `layer` and `concepts`. + + Args: + cavs_path (str): The root path where the cavs are stored + in the storage (on the disk). + Ex.: "/cavs" + model_id (str): A unique model identifier associated with the + CAVs. There exist a folder named `model_id` under + `cavs_path` path. The CAVs are loaded from this folder. + concepts (list[Concept]): A List of concepts for which + we would like to load the cavs. + layer (str): The layer name. Ex.: "inception4c". In case of nested + layers we use dots to specify the depth / hierarchy. + Ex.: "layer.sublayer.subsublayer" + + Returns: + cav(CAV): An instance of a CAV class, containing the respective CAV + score per concept and layer. 
An example of a path where the + cavs are loaded from is: + "/cavs/default_model_id/0-1-2-inception4c.pkl" + """ + + cavs_path = CAV.assemble_save_path(cavs_path, model_id, concepts, layer) + + if os.path.exists(cavs_path): + save_dict = torch.load(cavs_path) + + concept_names = save_dict["concept_names"] + concept_ids = save_dict["concept_ids"] + concepts = [ + Concept(concept_id, concept_name, None) + for concept_id, concept_name in zip(concept_ids, concept_names) + ] + cav = CAV(concepts, save_dict["layer"], save_dict["stats"]) + + return cav + + return None diff --git a/captum/concept/_core/concept.py b/captum/concept/_core/concept.py new file mode 100644 index 0000000000000000000000000000000000000000..a550ab8a9db87ff4de57a199aae7ab7d47883219 --- /dev/null +++ b/captum/concept/_core/concept.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python3 + +from typing import Callable, Union + +import torch +from torch.nn import Module + + +class Concept: + + r""" + Concepts are human-friendly abstract representations that can be + numerically encoded into torch tensors. They can be illustrated as + images, text or any other form of representation. In case of images, + for example, "stripes" concept can be represented through a number + of example images resembling "stripes" in various different + contexts. In case of Natural Language Processing, the concept of + "happy", for instance, can be illustrated through a number of + adjectives and words that convey happiness. + """ + + def __init__( + self, id: int, name: str, data_iter: Union[None, torch.utils.data.DataLoader] + ) -> None: + + r""" + Args: + id (int): The unique identifier of the concept. + name (str): A unique name of the concept. + data_iter (DataLoader): A pytorch DataLoader object that combines a dataset + and a sampler, and provides an iterable over a given + dataset. Only the input batches are provided by `data_iter`. + Concept ids can be used as labels if necessary. + For more information, please check: + https://pytorch.org/docs/stable/data.html + + Example:: + >>> # Creates a Concept object named "striped", with a data_iter + >>> # object to iterate over all files in "./concepts/striped" + >>> concept_name = "striped" + >>> concept_path = os.path.join("./concepts", concept_name) + "/" + >>> concept_iter = dataset_to_dataloader( + >>> get_tensor_from_filename, concepts_path=concept_path) + >>> concept_object = Concept( + id=0, name=concept_name, data_iter=concept_iter) + """ + + self.id = id + self.name = name + self.data_iter = data_iter + + @property + def identifier(self) -> str: + return "%s-%s" % (self.name, self.id) + + def __repr__(self) -> str: + return "Concept(%r, %r)" % (self.id, self.name) + + +class ConceptInterpreter: + r""" + An abstract class that exposes an abstract interpret method + that has to be implemented by a specific algorithm for + concept-based model interpretability. + """ + + def __init__(self, model: Module) -> None: + r""" + Args: + model (torch.nn.Module): An instance of pytorch model. + """ + self.model = model + + interpret: Callable + r""" + An abstract interpret method that performs concept-based model interpretability + and returns the interpretation results in form of tensors, dictionaries or other + data structures. + + Args: + + inputs (tensor or tuple of tensors): Inputs for which concept-based + interpretation scores are computed. It can be provided as + a single tensor or a tuple of multiple tensors. 
If multiple + input tensors are provided, the batch size (the first + dimension of the tensors) must be aligned across all tensors. + """ diff --git a/captum/concept/_core/tcav.py b/captum/concept/_core/tcav.py new file mode 100644 index 0000000000000000000000000000000000000000..6d79ba06ae7e7ef21becb0bf354ef44eaf4341ca --- /dev/null +++ b/captum/concept/_core/tcav.py @@ -0,0 +1,790 @@ +#!/usr/bin/env python3 + +from collections import defaultdict +from typing import Any, cast, Dict, List, Set, Tuple, Union + +import numpy as np +import torch +import torch.multiprocessing as multiprocessing +from captum._utils.av import AV +from captum._utils.common import _format_tensor_into_tuples, _get_module_from_name +from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric +from captum.attr import LayerActivation, LayerAttribution, LayerGradientXActivation +from captum.concept._core.cav import CAV +from captum.concept._core.concept import Concept, ConceptInterpreter +from captum.concept._utils.classifier import Classifier, DefaultClassifier +from captum.concept._utils.common import concepts_to_str +from captum.log import log_usage +from torch import Tensor +from torch.nn import Module +from torch.utils.data import DataLoader, Dataset + + +class LabelledDataset(Dataset): + """ + A torch Dataset whose __getitem__ returns both a batch of activation vectors, + as well as a batch of labels associated with those activation vectors. + It is used to train a classifier in train_tcav + """ + + def __init__(self, datasets: List[AV.AVDataset], labels: List[int]): + """ + Creates the LabelledDataset given a list of K Datasets, and a length K + list of integer labels representing K different concepts. + The assumption is that the k-th Dataset of datasets is associated with + the k-th element of labels. + The LabelledDataset is the concatenation of the K Datasets in datasets. + However, __get_item__ not only returns a batch of activation vectors, + but also a batch of labels indicating which concept that batch of + activation vectors is associated with. + Args: + datasets (list[Dataset]): The k-th element of datasets is a Dataset + representing activation vectors associated with the k-th + concept + labels (list[Int]): The k-th element of labels is the integer label + associated with the k-th concept + """ + assert len(datasets) == len( + labels + ), "number of datasets does not match the number of concepts" + + from itertools import accumulate + + offsets = [0] + list(accumulate(map(len, datasets), (lambda x, y: x + y))) + self.length = offsets[-1] + self.datasets = datasets + self.labels = labels + self.lowers = offsets[:-1] + self.uppers = offsets[1:] + + def _i_to_k(self, i): + + left, right = 0, len(self.uppers) + while left < right: + mid = (left + right) // 2 + if self.lowers[mid] <= i and i < self.uppers[mid]: + return mid + if i >= self.uppers[mid]: + left = mid + else: + right = mid + + def __getitem__(self, i): + """ + Returns a batch of activation vectors, as well as a batch of labels + indicating which concept the batch of activation vectors is associated + with. 
+ + args: + i (int): which (activation vector, label) batch in the dataset to + return + returns: + inputs (Tensor): i-th batch in Dataset (representing activation + vectors) + labels (Tensor): labels of i-th batch in Dataset + """ + assert i < self.length + k = self._i_to_k(i) + inputs = self.datasets[k][i - self.lowers[k]] + assert len(inputs.shape) == 2 + + labels = torch.tensor([self.labels[k]] * inputs.size(0), device=inputs.device) + return inputs, labels + + def __len__(self): + """ + returns the total number of batches in the labelled_dataset + """ + return self.length + + +def train_cav( + model_id, + concepts: List[Concept], + layers: Union[str, List[str]], + classifier: Classifier, + save_path: str, + classifier_kwargs: Dict, +) -> Dict[str, Dict[str, CAV]]: + r""" + A helper function for parallel CAV computations that can be called + from a python process. + + Please see the TCAV class documentation for further information. + + Args: + model_id (str): A unique identifier for the PyTorch model for which + we would like to load the layer activations and train a + model in order to compute CAVs. + concepts (list[Concept]): A list of Concept objects that are used + to train a classifier and learn decision boundaries between + those concepts for each layer defined in the `layers` + argument. + layers (str, list[str]): A list of layer names or a single layer + name that is used to compute the activations of all concept + examples per concept and train a classifier using those + activations. + classifier (Classifier): A custom classifier class, such as the + Sklearn "linear_model" that allows us to train a model + using the activation vectors extracted for a layer per concept. + It also allows us to access trained weights of the classifier + and the list of prediction classes. + save_path (str): The path for storing Concept Activation + Vectors (CAVs) and Activation Vectors (AVs). + classifier_kwargs (dict): Additional named arguments that are passed to + concept classifier's `train_and_eval` method. + + Returns: + cavs (dict): A dictionary of CAV objects indexed by concept ids and + layer names. It gives access to the weights of each concept + in a given layer and model statistics such as accuracies + that resulted in trained concept weights. + """ + + concepts_key = concepts_to_str(concepts) + cavs: Dict[str, Dict[str, CAV]] = defaultdict() + cavs[concepts_key] = defaultdict() + layers = [layers] if isinstance(layers, str) else layers + for layer in layers: + + # Create data loader to initialize the trainer. 
+ datasets = [ + AV.load(save_path, model_id, concept.identifier, layer) + for concept in concepts + ] + + labels = [concept.id for concept in concepts] + + labelled_dataset = LabelledDataset(cast(List[AV.AVDataset], datasets), labels) + + def batch_collate(batch): + inputs, labels = zip(*batch) + return torch.cat(inputs), torch.cat(labels) + + dataloader = DataLoader(labelled_dataset, collate_fn=batch_collate) + + classifier_stats_dict = classifier.train_and_eval( + dataloader, **classifier_kwargs + ) + classifier_stats_dict = ( + {} if classifier_stats_dict is None else classifier_stats_dict + ) + + weights = classifier.weights() + assert ( + weights is not None and len(weights) > 0 + ), "Model weights connot be None or empty" + + classes = classifier.classes() + assert ( + classes is not None and len(classes) > 0 + ), "Classes cannot be None or empty" + + classes = ( + cast(torch.Tensor, classes).detach().numpy() + if isinstance(classes, torch.Tensor) + else classes + ) + cavs[concepts_key][layer] = CAV( + concepts, + layer, + {"weights": weights, "classes": classes, **classifier_stats_dict}, + save_path, + model_id, + ) + # Saving cavs on the disk + cavs[concepts_key][layer].save() + + return cavs + + +class TCAV(ConceptInterpreter): + r""" + This class implements ConceptInterpreter abstract class using an + approach called Testing with Concept Activation Vectors (TCAVs), + as described in the paper: + https://arxiv.org/pdf/1711.11279.pdf + + TCAV scores for a given layer, a list of concepts and input example + are computed using the dot product between prediction's layer + sensitivities for given input examples and Concept Activation Vectors + (CAVs) in that same layer. + + CAVs are defined as vectors that are orthogonal to the classification boundary + hyperplane that separate given concepts in a given layer from each other. + For a given layer, CAVs are computed by training a classifier that uses the + layer activation vectors for a set of concept examples as input examples and + concept ids as corresponding input labels. Trained weights of + that classifier represent CAVs. + + CAVs are represented as a learned weight matrix with the dimensionality + C X F, where: + F represents the number of input features in the classifier. + C is the number of concepts used for the classification. Concept + ids are used as labels for concept examples during the training. + + We can use any layer attribution algorithm to compute layer sensitivities + of a model prediction. + For example, the gradients of an output prediction w.r.t. the outputs of + the layer. + The CAVs and the Sensitivities (SENS) are used to compute the TCAV score: + + 0. TCAV = CAV • SENS, a dot product between those two vectors + + The final TCAV score can be computed by aggregating the TCAV scores + for each input concept based on the sign or magnitude of the tcav scores. + + 1. sign_count_score = | TCAV > 0 | / | TCAV | + 2. magnitude_score = SUM(ABS(TCAV * (TCAV > 0))) / SUM(ABS(TCAV)) + """ + + def __init__( + self, + model: Module, + layers: Union[str, List[str]], + model_id: str = "default_model_id", + classifier: Classifier = None, + layer_attr_method: LayerAttribution = None, + attribute_to_layer_input=False, + save_path: str = "./cav/", + **classifier_kwargs: Any, + ) -> None: + r""" + Args: + model (Module): An instance of pytorch model that is used to compute + layer activations and attributions. 
+ layers (str, list[str]): A list of layer name(s) that are + used for computing concept activations (cavs) and layer + attributions. + model_id (str, optional): A unique identifier for the PyTorch `model` + passed as first argument to the constructor of TCAV class. It + is used to store and load activations for given input `model` + and associated `layers`. + classifier (Classifier, optional): A custom classifier class, such as the + Sklearn "linear_model" that allows us to train a model + using the activation vectors extracted for a layer per concept. + It also allows us to access trained weights of the model + and the list of prediction classes. + layer_attr_method (LayerAttribution, optional): An instance of a layer + attribution algorithm that helps us to compute model prediction + sensitivity scores. + + Default: None + If `layer_attr_method` is None, we default it to gradients + for the layers using `LayerGradientXActivation` layer + attribution algorithm. + save_path (str, optional): The path for storing CAVs and + Activation Vectors (AVs). + classifier_kwargs (any, optional): Additional arguments such as + `test_split_ratio` that are passed to concept `classifier`. + + Examples:: + >>> + >>> # TCAV use example: + >>> + >>> # Define the concepts + >>> stripes = Concept(0, "stripes", striped_data_iter) + >>> random = Concept(1, "random", random_data_iter) + >>> + >>> + >>> mytcav = TCAV(model=imagenet, + >>> layers=['inception4c', 'inception4d']) + >>> + >>> scores = mytcav.interpret(inputs, [[stripes, random]], target = 0) + >>> + For more thorough examples, please check out TCAV tutorial and test cases. + """ + ConceptInterpreter.__init__(self, model) + self.layers = [layers] if isinstance(layers, str) else layers + self.model_id = model_id + self.concepts: Set[Concept] = set() + self.classifier = classifier + self.classifier_kwargs = classifier_kwargs + self.cavs: Dict[str, Dict[str, CAV]] = defaultdict(lambda: defaultdict()) + if self.classifier is None: + self.classifier = DefaultClassifier() + if layer_attr_method is None: + self.layer_attr_method = cast( + LayerAttribution, + LayerGradientXActivation( # type: ignore + model, None, multiply_by_inputs=False + ), + ) + else: + self.layer_attr_method = layer_attr_method + + assert model_id, ( + "`model_id` cannot be None or empty. Consider giving `model_id` " + "a meaningful name or leave it unspecified. If model_id is unspecified we " + "will use `default_model_id` as its default value." + ) + + self.attribute_to_layer_input = attribute_to_layer_input + self.save_path = save_path + + # Creates CAV save directory if it doesn't exist. It is created once in the + # constructor before generating the CAVs. + # It is assumed that `model_id` can be used as a valid directory name + # otherwise `create_cav_dir_if_missing` will raise an error + CAV.create_cav_dir_if_missing(self.save_path, model_id) + + def generate_all_activations(self) -> None: + r""" + Computes layer activations for all concepts and layers that are + defined in `self.layers` and `self.concepts` instance variables. + """ + for concept in self.concepts: + self.generate_activation(self.layers, concept) + + def generate_activation(self, layers: Union[str, List], concept: Concept) -> None: + r""" + Computes layer activations for the specified `concept` and + the list of layer(s) `layers`. + + Args: + layers (str, list[str]): A list of layer names or a layer name + that is used to compute layer activations for the + specific `concept`. 
+ concept (Concept): A single Concept object that provides access + to concept examples using a data iterator. + """ + layers = [layers] if isinstance(layers, str) else layers + layer_modules = [_get_module_from_name(self.model, layer) for layer in layers] + + layer_act = LayerActivation(self.model, layer_modules) + assert concept.data_iter is not None, ( + "Data iterator for concept id:", + "{} must be specified".format(concept.id), + ) + for i, examples in enumerate(concept.data_iter): + activations = layer_act.attribute.__wrapped__( # type: ignore + layer_act, + examples, + attribute_to_layer_input=self.attribute_to_layer_input, + ) + for activation, layer_name in zip(activations, layers): + activation = torch.reshape(activation, (activation.shape[0], -1)) + AV.save( + self.save_path, + self.model_id, + concept.identifier, + layer_name, + activation.detach(), + str(i), + ) + + def generate_activations(self, concept_layers: Dict[Concept, List[str]]) -> None: + r""" + Computes layer activations for the concepts and layers specified in + `concept_layers` dictionary. + + Args: + concept_layers (dict[Concept, list[str]]): Dictionay that maps + Concept objects to a list of layer names to generate + the activations. Ex.: concept_layers = + {"striped": ['inception4c', 'inception4d']} + """ + for concept in concept_layers: + self.generate_activation(concept_layers[concept], concept) + + def load_cavs( + self, concepts: List[Concept] + ) -> Tuple[List[str], Dict[Concept, List[str]]]: + r""" + This function load CAVs as a dictionary of concept ids and + layers. CAVs are stored in a directory located under + `self.save_path` path, in .pkl files with the format: + /-.pkl. Ex.: + "/cavs/0-1-2-inception4c.pkl", where 0, 1 and 2 are concept ids. + + It returns a list of layers and a dictionary of concept-layers mapping + for the concepts and layer that require CAV computation through training. + This can happen if the CAVs aren't already pre-computed for a given list + of concepts and layer. + + Args: + concepts (list[Concept]): A list of Concept objects for which we want + to load the CAV. + + Returns: + layers (list[layer]): A list of layers for which some CAVs still need + to be computed. + concept_layers (dict[concept, layer]): A dictionay of concept-layers + mapping for which we need to perform CAV computation through + training. + """ + + concepts_key = concepts_to_str(concepts) + + layers = [] + concept_layers = defaultdict(list) + + for layer in self.layers: + self.cavs[concepts_key][layer] = CAV.load( + self.save_path, self.model_id, concepts, layer + ) + + # If CAV aren't loaded + if ( + concepts_key not in self.cavs + or layer not in self.cavs[concepts_key] + or not self.cavs[concepts_key][layer] + ): + + layers.append(layer) + # For all concepts in this experimental_set + for concept in concepts: + # Collect not activated layers for this concept + if not AV.exists( + self.save_path, self.model_id, layer, concept.identifier + ): + concept_layers[concept].append(layer) + return layers, concept_layers + + def compute_cavs( + self, + experimental_sets: List[List[Concept]], + force_train: bool = False, + processes: int = None, + ): + r""" + This method computes CAVs for given `experiments_sets` and layers + specified in `self.layers` instance variable. Internally, it + trains a classifier and creates an instance of CAV class using the + weights of the trained classifier for each experimental set. 
+ + It also allows to compute the CAVs in parallel using python's + multiprocessing API and the number of processes specified in + the argument. + + Args: + experimental_sets (list[list[Concept]]): A list of lists of concept + instances for which the cavs will be computed. + force_train (bool, optional): A flag that indicates whether to + train the CAVs regardless of whether they are saved or not. + Default: False + processes (int, optional): The number of processes to be created + when running in multi-processing mode. If processes > 0 then + CAV computation will be performed in parallel using + multi-processing, otherwise it will be performed sequentially + in a single process. + Default: None + Returns: + cavs (dict) : A mapping of concept ids and layers to CAV objects. + If CAVs for the concept_ids-layer pairs are present in the + data storage they will be loaded into the memory, otherwise + they will be computed using a training process and stored + in the data storage that can be configured using `save_path` + input argument. + """ + + # Update self.concepts with concepts + for concepts in experimental_sets: + self.concepts.update(concepts) + + concept_ids = [] + for concept in self.concepts: + assert concept.id not in concept_ids, ( + "There is more than one instance " + "of a concept with id {} defined in experimental sets. Please, " + "make sure to reuse the same instance of concept".format( + str(concept.id) + ) + ) + concept_ids.append(concept.id) + + if force_train: + self.generate_all_activations() + + # List of layers per concept key (experimental_set item) to be trained + concept_key_to_layers = defaultdict(list) + + for concepts in experimental_sets: + + concepts_key = concepts_to_str(concepts) + + # If not 'force_train', try to load a saved CAV + if not force_train: + layers, concept_layers = self.load_cavs(concepts) + concept_key_to_layers[concepts_key] = layers + # Generate activations for missing (concept, layers) + self.generate_activations(concept_layers) + else: + concept_key_to_layers[concepts_key] = self.layers + if processes is not None and processes > 1: + pool = multiprocessing.Pool(processes) + cavs_list = pool.starmap( + train_cav, + [ + ( + self.model_id, + concepts, + concept_key_to_layers[concepts_to_str(concepts)], + self.classifier, + self.save_path, + self.classifier_kwargs, + ) + for concepts in experimental_sets + ], + ) + + pool.close() + pool.join() + + else: + cavs_list = [] + for concepts in experimental_sets: + cavs_list.append( + train_cav( + self.model_id, + concepts, + concept_key_to_layers[concepts_to_str(concepts)], + cast(Classifier, self.classifier), + self.save_path, + self.classifier_kwargs, + ) + ) + + # list[Dict[concept, Dict[layer, list]]] => Dict[concept, Dict[layer, list]] + for cavs in cavs_list: + for c_key in cavs: + self.cavs[c_key].update(cavs[c_key]) + + return self.cavs + + @log_usage() + def interpret( + self, + inputs: TensorOrTupleOfTensorsGeneric, + experimental_sets: List[List[Concept]], + target: TargetType = None, + additional_forward_args: Any = None, + processes: int = None, + **kwargs: Any, + ) -> Dict[str, Dict[str, Dict[str, Tensor]]]: + r""" + This method computes magnitude and sign-based TCAV scores for each + experimental sets in `experimental_sets` list. + TCAV scores are computed using a dot product between layer attribution + scores for specific predictions and CAV vectors. + + Args: + inputs (tensor or tuple of tensors): Inputs for which predictions + are performed and attributions are computed. 
+ If model takes a single tensor as + input, a single input tensor should be provided. + If model takes multiple tensors as + input, a tuple of the input tensors should be provided. + It is assumed that for all given input tensors, + dimension 0 corresponds to the number of examples + (aka batch size), and if multiple input tensors are + provided, the examples must be aligned appropriately. + experimental_sets (list[list[Concept]]): A list of list of Concept + instances. + target (int, tuple, tensor or list, optional): Output indices for + which attributions are computed (for classification cases, + this is usually the target class). + If the network returns a scalar value per example, + no target index is necessary. + For general 2D outputs, targets can be either: + + - a single integer or a tensor containing a single + integer, which is applied to all input examples + - a list of integers or a 1D tensor, with length matching + the number of examples in inputs (dim 0). Each integer + is applied as the target for the corresponding example. + + For outputs with > 2 dimensions, targets can be either: + + - A single tuple, which contains #output_dims - 1 + elements. This target index is applied to all examples. + - A list of tuples with length equal to the number of + examples in inputs (dim 0), and each tuple containing + #output_dims - 1 elements. Each tuple is applied as the + target for the corresponding example. + + additional_forward_args (Any, optional): Extra arguments that are passed to + model when computing the attributions for `inputs` + w.r.t. layer output. + Default: None + processes (int, optional): The number of processes to be created. if + processes is larger than one then CAV computations will be + performed in parallel using the number of processes equal to + `processes`. Otherwise, CAV computations will be performed + sequential. + Default:None + **kwargs (Any, optional): A list of arguments that are passed to layer + attribution algorithm's attribute method. This could be for + example `n_steps` in case of integrated gradients. + Default: None + Returns: + results (dict): A dictionary of sign and magnitude -based tcav scores + for each concept set per layer. + The order of TCAV scores in the resulting tensor for each + experimental set follows the order in which concepts + are passed in `experimental_sets` input argument. + + results example:: + >>> # + >>> # scores = + >>> # {'0-1': + >>> # {'inception4c': + >>> # {'sign_count': tensor([0.5800, 0.4200]), + >>> # 'magnitude': tensor([0.6613, 0.3387])}, + >>> # 'inception4d': + >>> # {'sign_count': tensor([0.6200, 0.3800]), + >>> # 'magnitude': tensor([0.7707, 0.2293])}}), + >>> # '0-2': + >>> # {'inception4c': + >>> # {'sign_count': tensor([0.6200, 0.3800]), + >>> # 'magnitude': tensor([0.6806, 0.3194])}, + >>> # 'inception4d': + >>> # {'sign_count': tensor([0.6400, 0.3600]), + >>> # 'magnitude': tensor([0.6563, 0.3437])}})}) + >>> # + + """ + assert "attribute_to_layer_input" not in kwargs, ( + "Please, set `attribute_to_layer_input` flag as a constructor " + "argument to TCAV class. In that case it will be applied " + "consistently to both layer activation and layer attribution methods." + ) + self.compute_cavs(experimental_sets, processes=processes) + + scores: Dict[str, Dict[str, Dict[str, Tensor]]] = defaultdict( + lambda: defaultdict() + ) + + # Retrieves the lengths of the experimental sets so that we can sort + # them by the length and compute TCAV scores in batches. 
+ exp_set_lens = np.array( + list(map(lambda exp_set: len(exp_set), experimental_sets)), dtype=object + ) + exp_set_lens_arg_sort = np.argsort(exp_set_lens) + + # compute offsets using sorted lengths using their indices + exp_set_lens_sort = exp_set_lens[exp_set_lens_arg_sort] + exp_set_offsets_bool = [False] + list( + exp_set_lens_sort[:-1] == exp_set_lens_sort[1:] + ) + exp_set_offsets = [] + for i, offset in enumerate(exp_set_offsets_bool): + if not offset: + exp_set_offsets.append(i) + + exp_set_offsets.append(len(exp_set_lens)) + + # sort experimental sets using the length of the concepts in each set + experimental_sets_sorted = np.array(experimental_sets, dtype=object)[ + exp_set_lens_arg_sort + ] + + for layer in self.layers: + layer_module = _get_module_from_name(self.model, layer) + self.layer_attr_method.layer = layer_module + attribs = self.layer_attr_method.attribute.__wrapped__( # type: ignore + self.layer_attr_method, # self + inputs, + target=target, + additional_forward_args=additional_forward_args, + attribute_to_layer_input=self.attribute_to_layer_input, + **kwargs, + ) + + attribs = _format_tensor_into_tuples(attribs) + # n_inputs x n_features + attribs = torch.cat( + [torch.reshape(attrib, (attrib.shape[0], -1)) for attrib in attribs], + dim=1, + ) + + # n_experiments x n_concepts x n_features + cavs = [] + classes = [] + for concepts in experimental_sets: + concepts_key = concepts_to_str(concepts) + cavs_stats = cast(Dict[str, Any], self.cavs[concepts_key][layer].stats) + cavs.append(cavs_stats["weights"].float().detach().tolist()) + classes.append(cavs_stats["classes"]) + + # sort cavs and classes using the length of the concepts in each set + cavs_sorted = np.array(cavs, dtype=object)[exp_set_lens_arg_sort] + classes_sorted = np.array(classes, dtype=object)[exp_set_lens_arg_sort] + i = 0 + while i < len(exp_set_offsets) - 1: + cav_subset = np.array( + cavs_sorted[exp_set_offsets[i] : exp_set_offsets[i + 1]], + dtype=object, + ).tolist() + classes_subset = classes_sorted[ + exp_set_offsets[i] : exp_set_offsets[i + 1] + ].tolist() + + # n_experiments x n_concepts x n_features + cav_subset = torch.tensor(cav_subset) + cav_subset = cav_subset.to(attribs.device) + assert len(cav_subset.shape) == 3, ( + "cav should have 3 dimensions: n_experiments x " + "n_concepts x n_features." + ) + + experimental_subset_sorted = experimental_sets_sorted[ + exp_set_offsets[i] : exp_set_offsets[i + 1] + ] + self._tcav_sub_computation( + scores, + layer, + attribs, + cav_subset, + classes_subset, + experimental_subset_sorted, + ) + i += 1 + + return scores + + def _tcav_sub_computation( + self, + scores: Dict[str, Dict[str, Dict[str, Tensor]]], + layer: str, + attribs: Tensor, + cavs: Tensor, + classes: List[List[int]], + experimental_sets: List[List[Concept]], + ) -> None: + # n_inputs x n_concepts + tcav_score = torch.matmul(attribs.float(), torch.transpose(cavs, 1, 2)) + assert len(tcav_score.shape) == 3, ( + "tcav_score should have 3 dimensions: n_experiments x " + "n_inputs x n_concepts." + ) + + assert attribs.shape[0] == tcav_score.shape[1], ( + "attrib and tcav_score should have the same 1st and " + "2nd dimensions respectively (n_inputs)." 
+ ) + # n_experiments x n_concepts + sign_count_score = torch.mean((tcav_score > 0.0).float(), dim=1) + + magnitude_score = torch.mean(tcav_score, dim=1) + + for i, (cls_set, concepts) in enumerate(zip(classes, experimental_sets)): + concepts_key = concepts_to_str(concepts) + + # sort classes / concepts in the order specified in concept_keys + concept_ord = [concept.id for concept in concepts] + class_ord = {cls_: idx for idx, cls_ in enumerate(cls_set)} + + new_ord = torch.tensor( + [class_ord[cncpt] for cncpt in concept_ord], device=tcav_score.device + ) + + # sort based on classes + scores[concepts_key][layer] = { + "sign_count": torch.index_select( + sign_count_score[i, :], dim=0, index=new_ord + ), + "magnitude": torch.index_select( + magnitude_score[i, :], dim=0, index=new_ord + ), + } diff --git a/captum/concept/_utils/__init__.py b/captum/concept/_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/captum/concept/_utils/classifier.py b/captum/concept/_utils/classifier.py new file mode 100644 index 0000000000000000000000000000000000000000..b9b21f809dc56ef81d3c1fa9b4a8fabc08445a9b --- /dev/null +++ b/captum/concept/_utils/classifier.py @@ -0,0 +1,239 @@ +#!/usr/bin/env python3 + +import random +import warnings +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Tuple, Union + +import torch +from captum._utils.models.linear_model import model +from torch import Tensor +from torch.utils.data import DataLoader, TensorDataset + + +class Classifier(ABC): + r""" + An abstract class definition of any classifier that allows to train a model + and access trained weights of that model. + + More specifically the classifier can, for instance, be trained on the + activations of a particular layer. Below we can see an example a sklearn + linear classifier wrapped by the `CustomClassifier` which extends `Classifier` + abstract class. + + Example:: + + >>> from sklearn import linear_model + >>> + >>> class CustomClassifier(Classifier): + >>> + >>> def __init__(self): + >>> + >>> self.lm = linear_model.SGDClassifier(alpha=0.01, max_iter=1000, + >>> tol=1e-3) + >>> + >>> def train_and_eval(self, dataloader): + >>> + >>> x_train, x_test, y_train, y_test = train_test_split(inputs, labels) + >>> self.lm.fit(x_train.detach().numpy(), y_train.detach().numpy()) + >>> + >>> preds = torch.tensor(self.lm.predict(x_test.detach().numpy())) + >>> return {'accs': (preds == y_test).float().mean()} + >>> + >>> + >>> def weights(self): + >>> + >>> if len(self.lm.coef_) == 1: + >>> # if there are two concepts, there is only one label. + >>> # We split it in two. + >>> return torch.tensor([-1 * self.lm.coef_[0], self.lm.coef_[0]]) + >>> else: + >>> return torch.tensor(self.lm.coef_) + >>> + >>> + >>> def classes(self): + >>> return self.lm.classes_ + >>> + >>> + + """ + + @abstractmethod + def __init__(self) -> None: + pass + + @abstractmethod + def train_and_eval( + self, dataloader: DataLoader, **kwargs: Any + ) -> Union[Dict, None]: + r""" + This method is responsible for training a classifier using the data + provided through `dataloader` input arguments. Based on the specific + implementation, it may or may not return a statistics about model + training and evaluation. + + Args: + dataloader (dataloader): A dataloader that enables batch-wise access to + the inputs and corresponding labels. Dataloader allows us to + iterate over the dataset by loading the batches in lazy manner. 
+ kwargs (dict): Named arguments that are used for training and evaluating + concept classifier. + Default: None + Returns: + stats (dict): a dictionary of statistics about the performance of the model. + For example the accuracy of the model on the test and/or + train dataset(s). The user may decide to return None or an + empty dictionary if she/he decides to not return any performance + statistics. + """ + pass + + @abstractmethod + def weights(self) -> Tensor: + r""" + This function returns a C x F tensor weights, where + C is the number of classes and F is the number of features. + + Returns: + weights (tensor): A torch Tensor with the weights resulting from + the model training. + """ + pass + + @abstractmethod + def classes(self) -> List[int]: + r""" + This function returns the list of all classes that are used by the + classifier to train the model in the `train_and_eval` method. + The order of returned classes has to match the same order used in + the weights matrix returned by the `weights` method. + + Returns: + classes (list): The list of classes used by the classifier to train + the model in the `train_and_eval` method. + """ + pass + + +class DefaultClassifier(Classifier): + r""" + A default Linear Classifier based on sklearn's SGDClassifier for + learning decision boundaries between concepts. + Note that default implementation slices input dataset into train and test + splits and keeps them in memory. + In case concept datasets are large, this can lead to out of memory and we + recommend to provide a custom Classier that extends `Classifier` abstract + class and handles large concept datasets accordingly. + """ + + def __init__(self): + warnings.warn( + "Using default classifier for TCAV which keeps input" + " both train and test datasets in the memory. Consider defining" + " your own classifier that doesn't rely heavily on memory, for" + " large number of concepts, by extending" + " `Classifer` abstract class" + ) + self.lm = model.SkLearnSGDClassifier(alpha=0.01, max_iter=1000, tol=1e-3) + + def train_and_eval( + self, dataloader: DataLoader, test_split_ratio: float = 0.33, **kwargs: Any + ) -> Union[Dict, None]: + r""" + Implements Classifier::train_and_eval abstract method for small concept + datsets provided by `dataloader`. + It is assumed that when iterating over `dataloader` we can still + retain the entire dataset in the memory. + This method shuffles all examples randomly provided, splits them + into train and test partitions and trains an SGDClassifier using sklearn + library. Ultimately, it measures and returns model accuracy using test + split of the dataset. + + Args: + dataloader (dataloader): A dataloader that enables batch-wise access to + the inputs and corresponding labels. Dataloader allows us to + iterate over the dataset by loading the batches in lazy manner. + test_split_ratio (float): The ratio of test split in the entire dataset + served by input data loader `dataloader`. + + Default: 0.33 + Returns: + stats (dict): a dictionary of statistics about the performance of the model. + In this case stats represents a dictionary of model accuracy + measured on the test split of the dataset. 
+ + """ + inputs = [] + labels = [] + for input, label in dataloader: + inputs.append(input) + labels.append(label) + + device = "cpu" if input is None else input.device + x_train, x_test, y_train, y_test = _train_test_split( + torch.cat(inputs), torch.cat(labels), test_split=test_split_ratio + ) + self.lm.device = device + self.lm.fit(DataLoader(TensorDataset(x_train, y_train))) + + predict = self.lm(x_test) + + predict = self.lm.classes()[torch.argmax(predict, dim=1)] + score = predict.long() == y_test.long().cpu() + + accs = score.float().mean() + + return {"accs": accs} + + def weights(self) -> Tensor: + r""" + This function returns a C x F tensor weights, where + C is the number of classes and F is the number of features. + In case of binary classification, C = 2 othewise it is > 2. + + Returns: + weights (tensor): A torch Tensor with the weights resulting from + the model training. + """ + assert self.lm.linear is not None, ( + "The weights cannot be obtained because no model was trained." + "In order to train the model call `train_and_eval` method first." + ) + weights = self.lm.representation() + if weights.shape[0] == 1: + # if there are two concepts, there is only one label. We split it in two. + return torch.stack([-1 * weights[0], weights[0]]) + else: + return weights + + def classes(self) -> List[int]: + r""" + This function returns the list of all classes that are used by the + classifier to train the model in the `train_and_eval` method. + The order of returned classes has to match the same order used in + the weights matrix returned by the `weights` method. + + Returns: + classes (list): The list of classes used by the classifier to train + the model in the `train_and_eval` method. + """ + return self.lm.classes().detach().numpy() + + +def _train_test_split( + x_list: Tensor, y_list: Tensor, test_split: float = 0.33 +) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + # Shuffle + z_list = list(zip(x_list, y_list)) + random.shuffle(z_list) + # Split + test_size = int(test_split * len(z_list)) + z_test, z_train = z_list[:test_size], z_list[test_size:] + x_test, y_test = zip(*z_test) + x_train, y_train = zip(*z_train) + return ( + torch.stack(x_train), + torch.stack(x_test), + torch.stack(y_train), + torch.stack(y_test), + ) diff --git a/captum/concept/_utils/common.py b/captum/concept/_utils/common.py new file mode 100644 index 0000000000000000000000000000000000000000..61617365099a619346023b3e88cc74b23d11972b --- /dev/null +++ b/captum/concept/_utils/common.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python3 + +from typing import List + +from captum.concept._core.concept import Concept + + +def concepts_to_str(concepts: List[Concept]) -> str: + r""" + Returns a string of hyphen("-") concatenated concept names. + Example output: "striped-random_0-random_1" + + Args: + concepts (list[Concept]): a List of concept names to be + concatenated and used as a concepts key. These concept + names are respective to the Concept objects used for + the classifier train. + Returns: + names_str (str): A string of hyphen("-") concatenated + concept names. 
Ex.: "striped-random_0-random_1" + """ + + return "-".join([str(c.id) for c in concepts]) diff --git a/captum/concept/_utils/data_iterator.py b/captum/concept/_utils/data_iterator.py new file mode 100644 index 0000000000000000000000000000000000000000..6a8a48f197a424e5d8791955788537255ce4099f --- /dev/null +++ b/captum/concept/_utils/data_iterator.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python3 + +import glob +import os +from typing import Callable, Iterator + +from torch import Tensor +from torch.utils.data import DataLoader, Dataset, IterableDataset + + +class CustomIterableDataset(IterableDataset): + r""" + An auxiliary class for iterating through a dataset. + """ + + def __init__(self, transform_filename_to_tensor: Callable, path: str) -> None: + r""" + Args: + transform_filename_to_tensor (callable): Function to read a data + file from path and return a tensor from that file. + path (str): Path to dataset files. This can be either a path to a + directory or a file where input examples are stored. + """ + self.file_itr = None + self.path = path + + if os.path.isdir(self.path): + self.file_itr = glob.glob(self.path + "*") + + self.transform_filename_to_tensor = transform_filename_to_tensor + + def __iter__(self) -> Iterator[Tensor]: + r""" + Returns: + iter (Iterator[Tensor]): A map from a function that + processes a list of file path(s) to a list of Tensors. + """ + if self.file_itr is not None: + return map(self.transform_filename_to_tensor, self.file_itr) + else: + return self.transform_filename_to_tensor(self.path) + + +def dataset_to_dataloader(dataset: Dataset, batch_size: int = 64) -> DataLoader: + r""" + An auxiliary function that creates torch DataLoader from torch Dataset + using input `batch_size`. + + Args: + dataset (Dataset): A torch dataset that allows to iterate over + the batches of examples. + batch_size (int, optional): Batch size of for each tensor in the + iteration. + + Returns: + dataloader_iter (DataLoader): a DataLoader for data iteration. + """ + + return DataLoader(dataset, batch_size=batch_size) diff --git a/captum/influence/__init__.py b/captum/influence/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ac2c40a618dae492064b873d481285bb58e1fd84 --- /dev/null +++ b/captum/influence/__init__.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python3 + +from captum.influence._core.influence import DataInfluence # noqa +from captum.influence._core.similarity_influence import SimilarityInfluence # noqa +from captum.influence._core.tracincp import TracInCP, TracInCPBase # noqa +from captum.influence._core.tracincp_fast_rand_proj import ( + TracInCPFast, + TracInCPFastRandProj, +) # noqa + +__all__ = [ + "DataInfluence", + "SimilarityInfluence", + "TracInCPBase", + "TracInCP", + "TracInCPFast", + "TracInCPFastRandProj", +] diff --git a/captum/influence/_core/__init__.py b/captum/influence/_core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/captum/influence/_core/influence.py b/captum/influence/_core/influence.py new file mode 100644 index 0000000000000000000000000000000000000000..f8ef1eb882165f32e01c543eee2d73320e5e24e7 --- /dev/null +++ b/captum/influence/_core/influence.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 + +from abc import ABC, abstractmethod +from typing import Any + +from torch.nn import Module +from torch.utils.data import Dataset + + +class DataInfluence(ABC): + r""" + An abstract class to define model data influence skeleton. 
+ """ + + def __init_( + self, model: Module, influence_src_dataset: Dataset, **kwargs: Any + ) -> None: + r""" + Args: + model (torch.nn.Module): An instance of pytorch model. + influence_src_dataset (torch.utils.data.Dataset): PyTorch Dataset that is + used to create a PyTorch Dataloader to iterate over the dataset and + its labels. This is the dataset for which we will be seeking for + influential instances. In most cases this is the training dataset. + **kwargs: Additional key-value arguments that are necessary for specific + implementation of `DataInfluence` abstract class. + """ + self.model = model + self.influence_src_dataset = influence_src_dataset + + @abstractmethod + def influence(self, inputs: Any = None, **kwargs: Any) -> Any: + r""" + Args: + inputs (Any): Batch of examples for which influential + instances are computed. They are passed to the forward_func. If + `inputs` if a tensor or tuple of tensors, the first dimension + of a tensor corresponds to the batch dimension. + **kwargs: Additional key-value arguments that are necessary for specific + implementation of `DataInfluence` abstract class. + + Returns: + influences (Any): We do not add restrictions on the return type for now, + though this may change in the future. + """ + pass diff --git a/captum/influence/_core/similarity_influence.py b/captum/influence/_core/similarity_influence.py new file mode 100644 index 0000000000000000000000000000000000000000..f781079a484ce9ea44ed25dd7bbe813d7588cd31 --- /dev/null +++ b/captum/influence/_core/similarity_influence.py @@ -0,0 +1,298 @@ +#!/usr/bin/env python3 + +import warnings +from functools import partial +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import captum._utils.common as common +import torch +from captum._utils.av import AV +from captum.attr import LayerActivation +from captum.influence._core.influence import DataInfluence +from torch import Tensor +from torch.nn import Module +from torch.utils.data import DataLoader, Dataset + +r""" +Additional helper functions to calculate similarity metrics. +""" + + +def euclidean_distance(test, train) -> Tensor: + r""" + Calculates the pairwise euclidean distance for batches of feature vectors. + Tensors test and train have shape (batch_size_1, *), and (batch_size_2, *). + Returns pairwise euclidean distance Tensor of shape (batch_size_1, batch_size_2). + """ + similarity = torch.cdist( + test.view(test.shape[0], -1).unsqueeze(0), + train.view(train.shape[0], -1).unsqueeze(0), + ).squeeze(0) + return similarity + + +def cosine_similarity(test, train, replace_nan=0) -> Tensor: + r""" + Calculates the pairwise cosine similarity for batches of feature vectors. + Tensors test and train have shape (batch_size_1, *), and (batch_size_2, *). + Returns pairwise cosine similarity Tensor of shape (batch_size_1, batch_size_2). 
+ """ + test = test.view(test.shape[0], -1) + train = train.view(train.shape[0], -1) + + if torch.__version__ <= "1.6.0": + test_norm = torch.norm(test, p=None, dim=1, keepdim=True) + train_norm = torch.norm(train, p=None, dim=1, keepdim=True) + else: + test_norm = torch.linalg.norm(test, ord=2, dim=1, keepdim=True) + train_norm = torch.linalg.norm(train, ord=2, dim=1, keepdim=True) + + test = torch.where(test_norm != 0.0, test / test_norm, Tensor([replace_nan])) + train = torch.where(train_norm != 0.0, train / train_norm, Tensor([replace_nan])).T + + similarity = torch.mm(test, train) + return similarity + + +r""" +Implements abstract DataInfluence class and provides implementation details for +similarity metric-based influence computation. Similarity metrics can be used to compare +intermediate or final activation vectors of a model for different sets of input. Then, +these can be used to draw conclusions about influential instances. + +Some standard similarity metrics such as dot product similarity or euclidean distance +are provided, but the user can provide any custom similarity metric as well. +""" + + +class SimilarityInfluence(DataInfluence): + def __init__( + self, + module: Module, + layers: Union[str, List[str]], + influence_src_dataset: Dataset, + activation_dir: str, + model_id: str = "", + similarity_metric: Callable = cosine_similarity, + similarity_direction: str = "max", + batch_size: int = 1, + **kwargs: Any, + ): + r""" + Args: + module (torch.nn.Module): An instance of pytorch model. This model should + define all of its layers as attributes of the model. + layers (str or List of str): The fully qualified layer(s) for which the + activation vectors are computed. + influence_src_dataset (torch.utils.data.Dataset): PyTorch Dataset that is + used to create a PyTorch Dataloader to iterate over the dataset and + its labels. This is the dataset for which we will be seeking for + influential instances. In most cases this is the training dataset. + activation_dir (str): The directory of the path to store + and retrieve activation computations. Best practice would be to use + an absolute path. + model_id (str): The name/version of the model for which layer + activations are being computed. Activations will be stored and + loaded under the subdirectory with this name if provided. + similarity_metric (Callable): This is a callable function that computes a + similarity metric between two representations. For example, the + representations pair could be from the training and test sets. + + This function must adhere to certain standards. The inputs should be + torch Tensors with shape (batch_size_i/j, feature dimensions). The + output Tensor should have shape (batch_size_i, batch_size_j) with + scalar values corresponding to the similarity metric used for each + pairwise combination from the two batches. + + For example, suppose we use `batch_size_1 = 16` for iterating + through `influence_src_dataset`, and for the `inputs` argument + we pass in a Tensor with 3 examples, i.e. batch_size_2 = 3. Also, + suppose that our inputs and intermediate activations throughout the + model will have dimension (N, C, H, W). Then, the feature dimensions + should be flattened within this function. For example:: + + >>> av_test.shape + torch.Size([3, N, C, H, W]) + >>> av_src.shape + torch.Size([16, N, C, H, W]) + >>> av_test = torch.view(av_test.shape[0], -1) + >>> av_test.shape + torch.Size([3, N x C x H x W]) + + and similarly for av_src. 
The similarity_metric should then use + these flattened tensors to return the pairwise similarity matrix. + For example, `similarity_metric(av_test, av_src)` should return a + tensor of shape (3, 16). + + batch_size (int): Batch size for iterating through `influence_src_dataset`. + **kwargs: Additional key-value arguments that are necessary for specific + implementation of `DataInfluence` abstract class. + """ + self.module = module + self.layers = [layers] if isinstance(layers, str) else layers + self.influence_src_dataset = influence_src_dataset + self.activation_dir = activation_dir + self.model_id = model_id + self.batch_size = batch_size + + if similarity_direction == "max" or similarity_direction == "min": + self.similarity_direction = similarity_direction + else: + raise ValueError( + f"{similarity_direction} is not a valid value. " + "Must be either 'max' or 'min'" + ) + + if similarity_metric is cosine_similarity: + if "replace_nan" in kwargs: + self.replace_nan = kwargs["replace_nan"] + else: + self.replace_nan = -2 if self.similarity_direction == "max" else 2 + similarity_metric = partial(cosine_similarity, replace_nan=self.replace_nan) + + self.similarity_metric = similarity_metric + + self.influence_src_dataloader = DataLoader( + influence_src_dataset, batch_size, shuffle=False + ) + + def influence( # type: ignore[override] + self, + inputs: Union[Tensor, Tuple[Tensor, ...]], + top_k: int = 1, + additional_forward_args: Optional[Any] = None, + load_src_from_disk: bool = True, + **kwargs: Any, + ) -> Dict: + r""" + Args: + inputs (tensor or tuple of tensors): Batch of examples for which influential + instances are computed. They are passed to the forward_func. The + first dimension in `inputs` tensor or tuple of tensors corresponds + to the batch size. A tuple of tensors is only passed in if this + is the input form that `module` accepts. + top_k (int): The number of top-matching activations to return + additional_forward_args (optional): Additional arguments that will be + passed to forward_func after inputs. + load_src_from_disk (bool): Loads activations for `influence_src_dataset` + where possible. Setting to False would force regeneration of + activations. + load_input_from_disk (bool): Regenerates activations for inputs by default + and removes previous `inputs` activations that are flagged with + `inputs_id`. Setting to True will load prior matching inputs + activations. Note that this could lead to unexpected behavior if + `inputs_id` is not configured properly and activations are loaded + for a different, prior `inputs`. + inputs_id (str): Used to identify inputs for loading activations. + + **kwargs: Additional key-value arguments that are necessary for specific + implementation of `DataInfluence` abstract class. + + Returns: + + influences (dict): Returns the influential instances retrieved from + `influence_src_dataset` for each test example represented through a + tensor or a tuple of tensor in `inputs`. Returned influential + examples are represented as dict, with keys corresponding to + the layer names passed in `layers`. Each value in the dict is a + tuple containing the indices and values for the top k similarities + from `influence_src_dataset` by the chosen metric. The first value + in the tuple corresponds to the indices corresponding to the top k + most similar examples, and the second value is the similarity score. + The batch dimension corresponds to the batch dimension of `inputs`. + If inputs.shape[0] == 5, then dict[`layer_name`][0].shape[0] == 5. 
+ These tensors will be of shape (inputs.shape[0], top_k). + """ + inputs_batch_size = ( + inputs[0].shape[0] if isinstance(inputs, tuple) else inputs.shape[0] + ) + + influences: Dict[str, Any] = {} + + layer_AVDatasets = AV.generate_dataset_activations( + self.activation_dir, + self.module, + self.model_id, + self.layers, + DataLoader(self.influence_src_dataset, self.batch_size, shuffle=False), + identifier="src", + load_from_disk=load_src_from_disk, + return_activations=True, + ) + + assert layer_AVDatasets is not None and not isinstance( + layer_AVDatasets, AV.AVDataset + ) + + layer_modules = [ + common._get_module_from_name(self.module, layer) for layer in self.layers + ] + test_activations = LayerActivation(self.module, layer_modules).attribute( + inputs, additional_forward_args + ) + + minmax = self.similarity_direction == "max" + + # av_inputs shape: (inputs_batch_size, *) e.g. (inputs_batch_size, N, C, H, W) + # av_src shape: (self.batch_size, *) e.g. (self.batch_size, N, C, H, W) + test_activations = ( + test_activations if len(self.layers) > 1 else [test_activations] + ) + for i, (layer, layer_AVDataset) in enumerate( + zip(self.layers, layer_AVDatasets) + ): + topk_val, topk_idx = torch.Tensor(), torch.Tensor().long() + zero_acts = torch.Tensor().long() + + av_inputs = test_activations[i] + src_loader = DataLoader(layer_AVDataset) + for j, av_src in enumerate(src_loader): + av_src = av_src.squeeze(0) + + similarity = self.similarity_metric(av_inputs, av_src) + msg = ( + "Output of custom similarity does not meet required dimensions. " + f"Your output has shape {similarity.shape}.\nPlease ensure the " + "output shape matches (inputs_batch_size, src_dataset_batch_size), " + f"which should be {(inputs_batch_size, self.batch_size)}." + ) + assert similarity.shape == (inputs_batch_size, av_src.shape[0]), msg + if hasattr(self, "replace_nan"): + idx = (similarity == self.replace_nan).nonzero() + zero_acts = torch.cat((zero_acts, idx)) + + r""" + TODO: For models that can have tuples as activations, we should + allow similarity metrics to accept tuples, support topk selection. + """ + + topk_batch = min(top_k, self.batch_size) + values, indices = torch.topk( + similarity, topk_batch, dim=1, largest=minmax + ) + indices += int(j * self.batch_size) + + topk_val = torch.cat((topk_val, values), dim=1) + topk_idx = torch.cat((topk_idx, indices), dim=1) + + # can modify how often to sort for efficiency? minor + sort_idx = torch.argsort(topk_val, dim=1, descending=minmax) + topk_val = torch.gather(topk_val, 1, sort_idx[:, :top_k]) + topk_idx = torch.gather(topk_idx, 1, sort_idx[:, :top_k]) + + influences[layer] = (topk_idx, topk_val) + + if torch.numel(zero_acts != 0): + zero_warning = ( + f"Layer {layer} has zero-vector activations for some inputs. This " + "may cause undefined behavior for cosine similarity. The indices " + "for the offending inputs will be included under the key " + f"'zero_acts-{layer}' in the output dictionary. Indices are " + "returned as a tensor with [inputs_idx, src_dataset_idx] pairs " + "which may have corrupted similarity scores." 
+ ) + warnings.warn(zero_warning, RuntimeWarning) + key = "-".join(["zero_acts", layer]) + influences[key] = zero_acts + + return influences diff --git a/captum/influence/_core/tracincp.py b/captum/influence/_core/tracincp.py new file mode 100644 index 0000000000000000000000000000000000000000..d3671767ce79c5f1a1ef0c5f4e25502449b62df0 --- /dev/null +++ b/captum/influence/_core/tracincp.py @@ -0,0 +1,1007 @@ +#!/usr/bin/env python3 + +import glob +import warnings +from abc import abstractmethod +from os.path import join +from typing import ( + Any, + Callable, + Iterator, + List, + Optional, + Union, + Tuple, + NamedTuple, + Type, +) + +import torch +from captum._utils.av import AV +from captum._utils.common import _format_inputs +from captum._utils.gradient import ( + _compute_jacobian_wrt_params, + _compute_jacobian_wrt_params_with_sample_wise_trick, +) +from captum._utils.progress import progress +from captum.influence._core.influence import DataInfluence +from captum.influence._utils.common import ( + _get_k_most_influential_helper, + _gradient_dot_product, + _load_flexible_state_dict, +) +from captum.log import log_usage +from torch import Tensor +from torch.nn import Module +from torch.utils.data import DataLoader, Dataset + + +r""" + +Note: methods starting with "_" are protected, not private, and can be overridden in +child classes. They are not part of the API. + +Implements abstract DataInfluence class and provides implementation details for +influence computation based on the logic provided in TracIn paper +(https://arxiv.org/pdf/2002.08484.pdf). + +The TracIn paper proposes an idealized notion of influence which can be represented by +the total amount a training example reduces loss for a test example via a training +process such as stochastic gradient descent. As this idealized notion of influence is +impractical to compute, the TracIn paper proposes instead to compute an influence +score, which uses a first-order approximation for the change in loss for a test example +by a training example, which is accumulated across saved model checkpoints. This +influence score is accumulated via a summed dot-product of gradient vectors for the +scores/loss of a test and training example. +""" + +""" +TODO: Support for checkpoint type. Currently only supports model parameters as saved +checkpoints. Can use enum or string. + +Potential implementation from design doc: +checkpoint_type (Enum = [Parameters | Loss_Grad]): For performance, + saved / loaded checkpoints can be either model parameters, or + gradient of the loss function on an input w.r.t parameters. +""" + + +class KMostInfluentialResults(NamedTuple): + """ + This namedtuple stores the results of using the `influence` method. This method + is implemented by all subclasses of `TracInCPBase` to calculate + proponents / opponents. The `indices` field stores the indices of the + proponents / opponents for each example in the test batch. For example, if finding + opponents, `indices[i][j]` stores the index in the training data of the example + with the `j`-th highest influence score on the `i`-th example in the test batch. + Similarly, the `influence_scores` field stores the actual influence scores, so that + `influence_scores[i][j]` is the influence score of example `indices[i][j]` in the + training data on example `i` of the test batch. Please see `TracInCPBase.influence` + for more details. 
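+
+    Illustrative construction (index and score values are hand-picked for
+    exposition only)::
+
+        >>> res = KMostInfluentialResults(
+        ...     indices=torch.tensor([[7, 2], [0, 9]]),
+        ...     influence_scores=torch.tensor([[0.9, 0.4], [1.2, 0.8]]),
+        ... )
+        >>> res.indices[0, 0]  # index of the strongest proponent of test example 0
+        tensor(7)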
+ """ + + indices: Tensor + influence_scores: Tensor + + +class TracInCPBase(DataInfluence): + """ + To implement the `influence` method, classes inheriting from `TracInCPBase` will + separately implement the private `_self_influence`, `_get_k_most_influential`, + and `_influence` methods. The public `influence` method is a wrapper for these + private methods. + """ + + def __init__( + self, + model: Module, + influence_src_dataset: Union[Dataset, DataLoader], + checkpoints: Union[str, List[str], Iterator], + checkpoints_load_func: Callable = _load_flexible_state_dict, + loss_fn: Optional[Union[Module, Callable]] = None, + batch_size: Union[int, None] = 1, + ) -> None: + r""" + Args: + model (torch.nn.Module): An instance of pytorch model. This model should + define all of its layers as attributes of the model. + influence_src_dataset (torch.utils.data.Dataset or torch.utils.DataLoader): + In the `influence` method, we either compute the influence score of + training examples on examples in a test batch, or self influence + scores for those training examples, depending on which mode is used. + This argument represents the training dataset containing those + training examples. In order to compute those influence scores, we + will create a Pytorch DataLoader yielding batches of training + examples that is then used for processing. If this argument is + already a Pytorch Dataloader, that DataLoader can be directly + used for processing. If it is instead a Pytorch Dataset, we will + create a DataLoader using it, with batch size specified by + `batch_size`. For efficiency purposes, the batch size of the + DataLoader used for processing should be as large as possible, but + not too large, so that certain intermediate quantities created + from a batch still fit in memory. Therefore, if + `influence_src_dataset` is a Dataset, `batch_size` should be large. + If `influence_src_dataset` was already a DataLoader to begin with, + it should have been constructed to have a large batch size. + checkpoints (str or List of str or Iterator): Either the directory of the + path to store and retrieve model checkpoints, a list of + filepaths with checkpoints from which to load, or an iterator which + returns objects from which to load checkpoints. + checkpoints_load_func (Callable, optional): The function to load a saved + checkpoint into a model to update its parameters, and get the + learning rate if it is saved. By default uses a utility to load a + model saved as a state dict. + Default: _load_flexible_state_dict + layers (List of str or None, optional): A list of layer names for which + gradients should be computed. If `layers` is None, gradients will + be computed for all layers. Otherwise, they will only be computed + for the layers specified in `layers`. + Default: None + loss_fn (Callable, optional): The loss function applied to model. + Default: None + batch_size (int or None, optional): Batch size of the DataLoader created to + iterate through `influence_src_dataset`, if it is a Dataset. + `batch_size` should be chosen as large as possible so that certain + intermediate quantities created from a batch still fit in memory. + Specific implementations of `TracInCPBase` will detail the size of + the intermediate quantities. `batch_size` must be an int if + `influence_src_dataset` is a Dataset. If `influence_src_dataset` + is a DataLoader, then `batch_size` is ignored as an argument. 
+ Default: 1 + """ + + self.model = model + + if isinstance(checkpoints, str): + self.checkpoints = AV.sort_files(glob.glob(join(checkpoints, "*"))) + elif isinstance(checkpoints, List) and isinstance(checkpoints[0], str): + self.checkpoints = AV.sort_files(checkpoints) + else: + self.checkpoints = list(checkpoints) # cast to avoid mypy error + if isinstance(self.checkpoints, List): + assert len(self.checkpoints) > 0, "No checkpoints saved!" + + self.checkpoints_load_func = checkpoints_load_func + self.loss_fn = loss_fn + self.batch_size = batch_size + + if not isinstance(influence_src_dataset, DataLoader): + assert isinstance(batch_size, int), ( + "since the `influence_src_dataset` argument was a `Dataset`, " + "`batch_size` must be an int." + ) + self.influence_src_dataloader = DataLoader( + influence_src_dataset, batch_size, shuffle=False + ) + else: + self.influence_src_dataloader = influence_src_dataset + + self.influence_src_dataloader_len: Optional[int] = None + try: + + # since we will calculate the number of batches in + # `self.influence_src_dataloader` whenever we use progress bar, calculate + # it once in initialization, for re-use. + self.influence_src_dataloader_len = len(self.influence_src_dataloader) + except AttributeError: + pass + + @abstractmethod + def _self_influence(self, show_progress: bool = False): + """ + Returns: + self influence scores (tensor): 1D tensor containing self influence + scores for all examples in training dataset + `influence_src_dataset`. + show_progress (bool, optional): To compute the self influence scores for + all examples in training dataset `influence_src_dataset`, we + compute the self influence scores for each batch. If + `show_progress`is true, the progress of this computation will be + displayed. In particular, the number of batches for which self + influence scores have been computed will be displayed. It will + try to use tqdm if available for advanced features (e.g. time + estimation). Otherwise, it will fallback to a simple output of + progress. + Default: False + """ + pass + + @abstractmethod + def _get_k_most_influential( + self, + inputs: Tuple[Any, ...], + targets: Optional[Tensor] = None, + k: int = 5, + proponents: bool = True, + show_progress: bool = False, + ) -> KMostInfluentialResults: + r""" + Args: + inputs (Tuple of Any): A tuple that represents a batch of examples. It does + not represent labels, which are passed as `targets`. + targets (tensor, optional): If computing influence scores on a loss + function, these are the labels corresponding to the batch `inputs`. + Default: None + k (int, optional): The number of proponents or opponents to return per test + example. + Default: 5 + proponents (bool, optional): Whether seeking proponents (`proponents=True`) + or opponents (`proponents=False`) + Default: True + show_progress (bool, optional): To compute the proponents (or opponents) + for the batch of examples, we perform computation for each batch in + training dataset `influence_src_dataset`, If `show_progress`is + true, the progress of this computation will be displayed. In + particular, the number of batches for which the computation has + been performed will be displayed. It will try to use tqdm if + available for advanced features (e.g. time estimation). Otherwise, + it will fallback to a simple output of progress. + Default: False + + Returns: + (indices, influence_scores) (namedtuple): `indices` is a torch.long Tensor + that contains the indices of the proponents (or opponents) for each + test example. 
Its dimension is `(inputs_batch_size, k)`, where + `inputs_batch_size` is the number of examples in `inputs`. For + example, if `proponents==True`, `indices[i][j]` is the index of the + example in training dataset `influence_src_dataset` with the + k-th highest influence score for the j-th example in `inputs`. + `indices` is a `torch.long` tensor so that it can directly be used + to index other tensors. Each row of `influence_scores` contains the + influence scores for a different test example, in sorted order. In + particular, `influence_scores[i][j]` is the influence score of + example `indices[i][j]` in training dataset `influence_src_dataset` + on example `i` in the test batch represented by `inputs` and + `targets`. + """ + pass + + @abstractmethod + def _influence( + self, + inputs: Tuple[Any, ...], + targets: Optional[Tensor] = None, + show_progress: bool = False, + ) -> Tensor: + r""" + Args: + inputs (Tuple of Any): A batch of examples. Does not represent labels, + which are passed as `targets`. The assumption is that + `self.model(*inputs)` produces the predictions for the batch. + targets (tensor, optional): If computing influence scores on a loss + function, these are the labels corresponding to the batch + `inputs`. + Default: None + + Returns: + influence_scores (tensor): Influence scores over the entire + training dataset `influence_src_dataset`. Dimensionality is + (inputs_batch_size, src_dataset_size). For example: + influence_scores[i][j] = the influence score for the j-th training + example to the i-th input example. + show_progress (bool, optional): To compute the influence of examples in + training dataset `influence_src_dataset`, we compute the influence + of each batch. If `show_progress`is true, the progress of this + computation will be displayed. In particular, the number of batches + for which influence has been computed will be displayed. It will + try to use tqdm if available for advanced features (e.g. time + estimation). Otherwise, it will fallback to a simple output of + progress. + Default: False + """ + pass + + @abstractmethod + def influence( # type: ignore[override] + self, + inputs: Any = None, + targets: Optional[Tensor] = None, + k: Optional[int] = None, + proponents: bool = True, + unpack_inputs: bool = True, + show_progress: bool = False, + ) -> Union[Tensor, KMostInfluentialResults]: + r""" + This is the key method of this class, and can be run in 3 different modes, + where the mode that is run depends on the arguments passed to this method: + + - self influence mode: This mode is used if `inputs` is None. This mode + computes the self influence scores for every example in + the training dataset `influence_src_dataset`. + - influence score mode: This mode is used if `inputs` is not None, and `k` is + None. This mode computes the influence score of every example in + training dataset `influence_src_dataset` on every example in the test + batch represented by `inputs` and `targets`. + - k-most influential mode: This mode is used if `inputs` is not None, and + `k` is not None, and an int. This mode computes the proponents or + opponents of every example in the test batch represented by `inputs` + and `targets`. In particular, for each test example in the test batch, + this mode computes its proponents (resp. opponents), which are the + indices in the training dataset `influence_src_dataset` of the training + examples with the `k` highest (resp. lowest) influence scores on the + test example. Proponents are computed if `proponents` is True. 
+ Otherwise, opponents are computed. For each test example, this method + also returns the actual influence score of each proponent (resp. + opponent) on the test example. + + Args: + inputs (Any, optional): If not provided or `None`, the self influence mode + will be run. Otherwise, `inputs` is the test batch that will be + used when running in either influence score or k-most influential + mode. If the argument `unpack_inputs` is False, the + assumption is that `self.model(inputs)` produces the predictions + for a batch, and `inputs` can be of any type. Otherwise if the + argument `unpack_inputs` is True, the assumption is that + `self.model(*inputs)` produces the predictions for a batch, and + `inputs` will need to be a tuple. In other words, `inputs` will be + unpacked as an argument when passing to `self.model`. + Default: None + targets (tensor, optional): If computing influence scores on a loss + function, these are the labels corresponding to the batch `inputs`. + Default: None + k (int, optional): If not provided or `None`, the influence score mode will + be run. Otherwise, the k-most influential mode will be run, + and `k` is the number of proponents / opponents to return per + example in the test batch. + Default: None + proponents (bool, optional): Whether seeking proponents (`proponents=True`) + or opponents (`proponents=False`), if running in k-most influential + mode. + Default: True + unpack_inputs (bool, optional): Whether to unpack the `inputs` argument to + when passing it to `model`, if `inputs` is a tuple (no unpacking + done otherwise). + Default: True + show_progress (bool, optional): For all modes, computation of results + requires "training dataset computations": computations for each + batch in the training dataset `influence_src_dataset`, which may + take a long time. If `show_progress`is true, the progress of + "training dataset computations" will be displayed. In particular, + the number of batches for which computations have been performed + will be displayed. It will try to use tqdm if available for + advanced features (e.g. time estimation). Otherwise, it will + fallback to a simple output of progress. + Default: False + + Returns: + The return value of this method depends on which mode is run. + + - self influence mode: if this mode is run (`inputs` is None), returns a 1D + tensor of self influence scores over training dataset + `influence_src_dataset`. The length of this tensor is the number of + examples in `influence_src_dataset`, regardless of whether it is a + Dataset or DataLoader. + - influence score mode: if this mode is run (`inputs is not None, `k` is + None), returns a 2D tensor `influence_scores` of shape + `(input_size, influence_src_dataset_size)`, where `input_size` is + the number of examples in the test batch, and + `influence_src_dataset_size` is the number of examples in + training dataset `influence_src_dataset`. In other words, + `influence_scores[i][j]` is the influence score of the `j`-th + example in `influence_src_dataset` on the `i`-th example in the + test batch. + - k-most influential mode: if this mode is run (`inputs` is not None, + `k` is an int), returns a namedtuple `(indices, influence_scores)`. + `indices` is a 2D tensor of shape `(input_size, k)`, where + `input_size` is the number of examples in the test batch. If + computing proponents (resp. opponents), `indices[i][j]` is the + index in training dataset `influence_src_dataset` of the example + with the `j`-th highest (resp. 
lowest) influence score (out of the + examples in `influence_src_dataset`) on the `i`-th example in the + test batch. `influence_scores` contains the corresponding influence + scores. In particular, `influence_scores[i][j]` is the influence + score of example `indices[i][j]` in `influence_src_dataset` on + example `i` in the test batch represented by `inputs` and + `targets`. + """ + pass + + @classmethod + def get_name(cls: Type["TracInCPBase"]) -> str: + r""" + Create readable class name. Due to the nature of the names of `TracInCPBase` + subclasses, simplies returns the class name. For example, for a class called + TracInCP, we return the string TracInCP. + + Returns: + name (str): a readable class name + """ + return cls.__name__ + + +def _influence_route_to_helpers( + influence_instance: TracInCPBase, + inputs: Any = None, + targets: Optional[Tensor] = None, + k: Optional[int] = None, + proponents: bool = True, + unpack_inputs: bool = True, + show_progress: bool = False, +) -> Union[Tensor, KMostInfluentialResults]: + """ + This is a helper function called by `TracInCP.influence` and + `TracInCPFast.influence`. Those methods share a common logic in that they assume + an instance of their respective classes implement 3 private methods + (`_self_influence`, `_influence`, `_get_k_most_influential`), and the logic of + which private method to call is common, as described in the documentation of the + `influence` method. The arguments and return values of this function are the exact + same as the `influence` method. Note that `influence_instance` refers to the + instance for which the `influence` method was called. + """ + _inputs = _format_inputs(inputs, unpack_inputs) + + if inputs is None: + return influence_instance._self_influence(show_progress) + elif k is None: + return influence_instance._influence(_inputs, targets, show_progress) + else: + return influence_instance._get_k_most_influential( + _inputs, targets, k, proponents, show_progress + ) + + +class TracInCP(TracInCPBase): + def __init__( + self, + model: Module, + influence_src_dataset: Union[Dataset, DataLoader], + checkpoints: Union[str, List[str], Iterator], + checkpoints_load_func: Callable = _load_flexible_state_dict, + layers: Optional[List[str]] = None, + loss_fn: Optional[Union[Module, Callable]] = None, + batch_size: Union[int, None] = 1, + sample_wise_grads_per_batch: bool = False, + ) -> None: + r""" + Args: + model (torch.nn.Module): An instance of pytorch model. This model should + define all of its layers as attributes of the model. + influence_src_dataset (torch.utils.data.Dataset or torch.utils.DataLoader): + In the `influence` method, we either compute the influence score of + training examples on examples in a test batch, or self influence + scores for those training examples, depending on which mode is used. + This argument represents the training dataset containing those + training examples. In order to compute those influence scores, we + will create a Pytorch DataLoader yielding batches of training + examples that is then used for processing. If this argument is + already a Pytorch Dataloader, that DataLoader can be directly + used for processing. If it is instead a Pytorch Dataset, we will + create a DataLoader using it, with batch size specified by + `batch_size`. For efficiency purposes, the batch size of the + DataLoader used for processing should be as large as possible, but + not too large, so that certain intermediate quantities created + from a batch still fit in memory. 
Therefore, if + `influence_src_dataset` is a Dataset, `batch_size` should be large. + If `influence_src_dataset` was already a DataLoader to begin with, + it should have been constructed to have a large batch size. + checkpoints (str or List of str or Iterator): Either the directory of the + path to store and retrieve model checkpoints, a list of + filepaths with checkpoints from which to load, or an iterator which + returns objects from which to load checkpoints. + checkpoints_load_func (Callable, optional): The function to load a saved + checkpoint into a model to update its parameters, and get the + learning rate if it is saved. By default uses a utility to load a + model saved as a state dict. + Default: _load_flexible_state_dict + layers (List of str or None, optional): A list of layer names for which + gradients should be computed. If `layers` is None, gradients will + be computed for all layers. Otherwise, they will only be computed + for the layers specified in `layers`. + Default: None + loss_fn (Callable, optional): The loss function applied to model. There + are two options for the return type of `loss_fn`. First, `loss_fn` + can be a "per-example" loss function - returns a 1D Tensor of + losses for each example in a batch. `nn.BCELoss(reduction="none")` + would be an "per-example" loss function. Second, `loss_fn` can be + a "reduction" loss function that reduces the per-example losses, + in a batch, and returns a single scalar Tensor. For this option, + the reduction must be the *sum* or the *mean* of the per-example + losses. For instance, `nn.BCELoss(reduction="sum")` is acceptable. + Note for the first option, the `sample_wise_grads_per_batch` + argument must be False, and for the second option, + `sample_wise_grads_per_batch` must be True. Also note that for + the second option, if `loss_fn` has no "reduction" attribute, + the implementation assumes that the reduction is the *sum* of the + per-example losses. If this is not the case, i.e. the reduction + is the *mean*, please set the "reduction" attribute of `loss_fn` + to "mean", i.e. `loss_fn.reduction = "mean"`. + Default: None + batch_size (int or None, optional): Batch size of the DataLoader created to + iterate through `influence_src_dataset`, if it is a Dataset. + `batch_size` should be chosen as large as possible so that certain + intermediate quantities created from a batch still fit in memory. + Specific implementations of `TracInCPBase` will detail the size of + the intermediate quantities. `batch_size` must be an int if + `influence_src_dataset` is a Dataset. If `influence_src_dataset` + is a DataLoader, then `batch_size` is ignored as an argument. + Default: 1 + sample_wise_grads_per_batch (bool, optional): PyTorch's native gradient + computations w.r.t. model parameters aggregates the results for a + batch and does not allow to access sample-wise gradients w.r.t. + model parameters. This forces us to iterate over each sample in + the batch if we want sample-wise gradients which is computationally + inefficient. We offer an implementation of batch-wise gradient + computations w.r.t. to model parameters which is computationally + more efficient. This implementation can be enabled by setting the + `sample_wise_grad_per_batch` argument to `True`, and should be + enabled if and only if the `loss_fn` argument is a "reduction" loss + function. For example, `nn.BCELoss(reduction="sum")` would be a + valid `loss_fn` if this implementation is enabled (see + documentation for `loss_fn` for more details). 
Note that our + current implementation enables batch-wise gradient computations + only for a limited number of PyTorch nn.Modules: Conv2D and Linear. + This list will be expanded in the near future. Therefore, please + do not enable this implementation if gradients will be computed + for other kinds of layers. + Default: False + """ + + TracInCPBase.__init__( + self, + model, + influence_src_dataset, + checkpoints, + checkpoints_load_func, + loss_fn, + batch_size, + ) + + self.sample_wise_grads_per_batch = sample_wise_grads_per_batch + + # If we are able to access the reduction used by `loss_fn`, we check whether + # the reduction is compatible with `sample_wise_grads_per_batch` + if isinstance(loss_fn, Module) and hasattr( + loss_fn, "reduction" + ): # TODO: allow loss_fn to be Callable + if self.sample_wise_grads_per_batch: + assert loss_fn.reduction in ["sum", "mean"], ( + 'reduction for `loss_fn` must be "sum" or "mean" when ' + "`sample_wise_grads_per_batch` is True" + ) + self.reduction_type = str(loss_fn.reduction) + else: + assert loss_fn.reduction == "none", ( + 'reduction for `loss_fn` must be "none" when ' + "`sample_wise_grads_per_batch` is False" + ) + else: + # if we are unable to access the reduction used by `loss_fn`, we warn + # the user about the assumptions we are making regarding the reduction + # used by `loss_fn` + if self.sample_wise_grads_per_batch: + warnings.warn( + 'Since `loss_fn` has no "reduction" attribute, and ' + "`sample_wise_grads_per_batch` is True, the implementation assumes " + 'that `loss_fn` is a "reduction" loss function that reduces the ' + "per-example losses by taking their *sum*. If `loss_fn` " + "instead reduces the per-example losses by taking their mean, " + 'please set the reduction attribute of `loss_fn` to "mean", i.e. ' + '`loss_fn.reduction = "mean"`. Note that if ' + "`sample_wise_grads_per_batch` is True, the implementation " + "assumes the reduction is either a sum or mean reduction." + ) + self.reduction_type = "sum" + else: + warnings.warn( + 'Since `loss_fn` has no "reduction" attribute, and ' + "`sample_wise_grads_per_batch` is False, the implementation " + 'assumes that `loss_fn` is a "per-example" loss function (see ' + "documentation for `loss_fn` for details). Please ensure that " + "this is the case." + ) + + r""" + TODO: Either restore model state after done (would have to place functionality + within influence to restore after every influence call)? or make a copy so that + changes to grad_requires aren't persistent after using TracIn. + """ + if layers is not None: + assert isinstance(layers, List), "`layers` should be a list!" + assert len(layers) > 0, "`layers` cannot be empty!" + assert isinstance( + layers[0], str + ), "`layers` should contain str layer names." + layerstr = " ".join(layers) + gradset = False + for layer in layers: + for name, param in model.named_parameters(): + param.requires_grad = False + if name in layerstr or layer in name: + param.requires_grad = True + gradset = True + assert gradset, "At least one parameter of network must require gradient." 
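+
+    # Illustrative end-to-end usage sketch. ``net``, ``train_dataset``,
+    # ``test_x``/``test_y`` and the checkpoint directory are assumed
+    # placeholders, not objects defined in this file:
+    #
+    #     tracin = TracInCP(
+    #         net,                                # any torch.nn.Module
+    #         train_dataset,                      # Dataset yielding (x, y) pairs
+    #         checkpoints="checkpoints/",         # directory of saved state dicts
+    #         loss_fn=torch.nn.CrossEntropyLoss(reduction="none"),
+    #         batch_size=128,
+    #     )
+    #     proponents = tracin.influence((test_x,), test_y, k=5, proponents=True)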
+ + @log_usage() + def influence( # type: ignore[override] + self, + inputs: Any = None, + targets: Optional[Tensor] = None, + k: Optional[int] = None, + proponents: bool = True, + unpack_inputs: bool = True, + show_progress: bool = False, + ) -> Union[Tensor, KMostInfluentialResults]: + r""" + This is the key method of this class, and can be run in 3 different modes, + where the mode that is run depends on the arguments passed to this method: + + - self influence mode: This mode is used if `inputs` is None. This mode + computes the self influence scores for every example in + the training dataset `influence_src_dataset`. + - influence score mode: This mode is used if `inputs` is not None, and `k` is + None. This mode computes the influence score of every example in + training dataset `influence_src_dataset` on every example in the test + batch represented by `inputs` and `targets`. + - k-most influential mode: This mode is used if `inputs` is not None, and + `k` is not None, and an int. This mode computes the proponents or + opponents of every example in the test batch represented by `inputs` + and `targets`. In particular, for each test example in the test batch, + this mode computes its proponents (resp. opponents), which are the + indices in the training dataset `influence_src_dataset` of the training + examples with the `k` highest (resp. lowest) influence scores on the + test example. Proponents are computed if `proponents` is True. + Otherwise, opponents are computed. For each test example, this method + also returns the actual influence score of each proponent (resp. + opponent) on the test example. + + Args: + inputs (Any, optional): If not provided or `None`, the self influence mode + will be run. Otherwise, `inputs` is the test batch that will be + used when running in either influence score or k-most influential + mode. If the argument `unpack_inputs` is False, the + assumption is that `self.model(inputs)` produces the predictions + for a batch, and `inputs` can be of any type. Otherwise if the + argument `unpack_inputs` is True, the assumption is that + `self.model(*inputs)` produces the predictions for a batch, and + `inputs` will need to be a tuple. In other words, `inputs` will be + unpacked as an argument when passing to `self.model`. + Default: None + targets (tensor, optional): If computing influence scores on a loss + function, these are the labels corresponding to the batch `inputs`. + Default: None + k (int, optional): If not provided or `None`, the influence score mode will + be run. Otherwise, the k-most influential mode will be run, + and `k` is the number of proponents / opponents to return per + example in the test batch. + Default: None + proponents (bool, optional): Whether seeking proponents (`proponents=True`) + or opponents (`proponents=False`), if running in k-most influential + mode. + Default: True + unpack_inputs (bool, optional): Whether to unpack the `inputs` argument to + when passing it to `model`, if `inputs` is a tuple (no unpacking + done otherwise). + Default: True + show_progress (bool, optional): For all modes, computation of results + requires "training dataset computations": computations for each + batch in the training dataset `influence_src_dataset`, which may + take a long time. If `show_progress`is true, the progress of + "training dataset computations" will be displayed. In particular, + the number of batches for which computations have been performed + will be displayed. It will try to use tqdm if available for + advanced features (e.g. 
time estimation). Otherwise, it will + fallback to a simple output of progress. + Default: False + + Returns: + The return value of this method depends on which mode is run. + + - self influence mode: if this mode is run (`inputs` is None), returns a 1D + tensor of self influence scores over training dataset + `influence_src_dataset`. The length of this tensor is the number of + examples in `influence_src_dataset`, regardless of whether it is a + Dataset or DataLoader. + - influence score mode: if this mode is run (`inputs is not None, `k` is + None), returns a 2D tensor `influence_scores` of shape + `(input_size, influence_src_dataset_size)`, where `input_size` is + the number of examples in the test batch, and + `influence_src_dataset_size` is the number of examples in + training dataset `influence_src_dataset`. In other words, + `influence_scores[i][j]` is the influence score of the `j`-th + example in `influence_src_dataset` on the `i`-th example in the + test batch. + - k-most influential mode: if this mode is run (`inputs` is not None, + `k` is an int), returns a namedtuple `(indices, influence_scores)`. + `indices` is a 2D tensor of shape `(input_size, k)`, where + `input_size` is the number of examples in the test batch. If + computing proponents (resp. opponents), `indices[i][j]` is the + index in training dataset `influence_src_dataset` of the example + with the `j`-th highest (resp. lowest) influence score (out of the + examples in `influence_src_dataset`) on the `i`-th example in the + test batch. `influence_scores` contains the corresponding influence + scores. In particular, `influence_scores[i][j]` is the influence + score of example `indices[i][j]` in `influence_src_dataset` on + example `i` in the test batch represented by `inputs` and + `targets`. + """ + return _influence_route_to_helpers( + self, + inputs, + targets, + k, + proponents, + unpack_inputs, + show_progress, + ) + + def _influence_batch_tracincp( + self, + inputs: Tuple[Any, ...], + targets: Optional[Tensor], + batch: Tuple[Any, ...], + ): + """ + computes influence scores for a single training batch + """ + + def get_checkpoint_contribution(checkpoint): + + assert ( + checkpoint is not None + ), "None returned from `checkpoints`, cannot load." + + learning_rate = self.checkpoints_load_func(self.model, checkpoint) + + input_jacobians = self._basic_computation_tracincp( + inputs, + targets, + ) + + return ( + _gradient_dot_product( + input_jacobians, + self._basic_computation_tracincp(batch[0:-1], batch[-1]), + ) + * learning_rate + ) + + batch_tracin_scores = get_checkpoint_contribution(self.checkpoints[0]) + + for checkpoint in self.checkpoints[1:]: + batch_tracin_scores += get_checkpoint_contribution(checkpoint) + + return batch_tracin_scores + + def _influence( + self, + inputs: Tuple[Any, ...], + targets: Optional[Tensor] = None, + show_progress: bool = False, + ) -> Tensor: + r""" + Computes the influence of examples in training dataset `influence_src_dataset` + on the examples in the test batch represented by `inputs` and `targets`. + This implementation does not require knowing the number of training examples + in advance. Instead, the number of training examples is inferred from the + output of `self._basic_computation_tracincp`. + + Args: + inputs (Tuple of Any): A test batch of examples. Does not represent labels, + which are passed as `targets`. The assumption is that + `self.model(*inputs)` produces the predictions for the batch. 
+ targets (tensor, optional): If computing influence scores on a loss + function, these are the labels corresponding to the batch `inputs`. + Default: None + show_progress (bool, optional): To compute the influence of examples in + training dataset `influence_src_dataset`, we compute the influence + of each batch. If `show_progress`is true, the progress of this + computation will be displayed. In particular, the number of batches + for which influence has been computed will be displayed. It will + try to use tqdm if available for advanced features (e.g. time + estimation). Otherwise, it will fallback to a simple output of + progress. + Default: False + + Returns: + influence_scores (tensor): Influence scores from the TracInCP method. + Its shape is `(input_size, influence_src_dataset_size)`, where `input_size` + is the number of examples in the test batch, and + `influence_src_dataset_size` is the number of examples in + training dataset `influence_src_dataset`. For example: + `influence_scores[i][j]` is the influence score for the j-th training + example to the i-th input example. + """ + influence_src_dataloader = self.influence_src_dataloader + + if show_progress: + influence_src_dataloader = progress( + influence_src_dataloader, + desc=( + f"Using {self.get_name()} to compute " + "influence for training batches" + ), + total=self.influence_src_dataloader_len, + ) + + return torch.cat( + [ + self._influence_batch_tracincp(inputs, targets, batch) + for batch in influence_src_dataloader + ], + dim=1, + ) + + def _get_k_most_influential( + self, + inputs: Tuple[Any, ...], + targets: Optional[Tensor] = None, + k: int = 5, + proponents: bool = True, + show_progress: bool = False, + ) -> KMostInfluentialResults: + r""" + Args: + inputs (Tuple of Any): A tuple that represents a batch of examples. It does + not represent labels, which are passed as `targets`. + targets (Tensor, optional): If computing influence scores on a loss + function, these are the labels corresponding to the batch `inputs`. + Default: None + k (int, optional): The number of proponents or opponents to return per test + example. + Default: 5 + proponents (bool, optional): Whether seeking proponents (`proponents=True`) + or opponents (`proponents=False`) + Default: True + show_progress (bool, optional): To compute the proponents (or opponents) + for the batch of examples, we perform computation for each batch in + training dataset `influence_src_dataset`, If `show_progress`is + true, the progress of this computation will be displayed. In + particular, the number of batches for which the computation has + been performed will be displayed. It will try to use tqdm if + available for advanced features (e.g. time estimation). Otherwise, + it will fallback to a simple output of progress. + Default: False + + Returns: + (indices, influence_scores) (namedtuple): `indices` is a torch.long Tensor + that contains the indices of the proponents (or opponents) for each + test example. Its dimension is `(inputs_batch_size, k)`, where + `inputs_batch_size` is the number of examples in `inputs`. For + example, if `proponents==True`, `indices[i][j]` is the index of the + example in training dataset `influence_src_dataset` with the + k-th highest influence score for the j-th example in `inputs`. + `indices` is a `torch.long` tensor so that it can directly be used + to index other tensors. Each row of `influence_scores` contains the + influence scores for a different test example, in sorted order. 
In + particular, `influence_scores[i][j]` is the influence score of + example `indices[i][j]` in training dataset `influence_src_dataset` + on example `i` in the test batch represented by `inputs` and + `targets`. + """ + desc = ( + None + if not show_progress + else ( + ( + f"Using {self.get_name()} to perform computation for " + f'getting {"proponents" if proponents else "opponents"}. ' + "Processing training batches: 100%" + ) + ) + ) + return KMostInfluentialResults( + *_get_k_most_influential_helper( + self.influence_src_dataloader, + self._influence_batch_tracincp, + inputs, + targets, + k, + proponents, + show_progress, + desc, + ) + ) + + def _self_influence_batch_tracincp(self, batch: Tuple[Any, ...]): + """ + Computes self influence scores for a single batch + """ + + def get_checkpoint_contribution(checkpoint): + + assert ( + checkpoint is not None + ), "None returned from `checkpoints`, cannot load." + + learning_rate = self.checkpoints_load_func(self.model, checkpoint) + + layer_jacobians = self._basic_computation_tracincp(batch[0:-1], batch[-1]) + + # note that all variables in this function are for an entire batch. + # each `layer_jacobian` in `layer_jacobians` corresponds to a different + # layer. `layer_jacobian` is the jacobian w.r.t to a given layer's + # parameters. if the given layer's parameters are of shape *, then + # `layer_jacobian` is of shape (batch_size, *). for each layer, we need + # the squared jacobian for each example. so we square the jacobian and + # sum over all dimensions except the 0-th (the batch dimension). We then + # sum the contribution over all layers. + return ( + torch.sum( + torch.stack( + [ + torch.sum(layer_jacobian.flatten(start_dim=1) ** 2, dim=1) + for layer_jacobian in layer_jacobians + ], + dim=0, + ), + dim=0, + ) + * learning_rate + ) + + batch_self_tracin_scores = get_checkpoint_contribution(self.checkpoints[0]) + + for checkpoint in self.checkpoints[1:]: + batch_self_tracin_scores += get_checkpoint_contribution(checkpoint) + + return batch_self_tracin_scores + + def _self_influence(self, show_progress: bool = False): + """ + Returns: + self influence scores (tensor): 1D tensor containing self influence + scores for all examples in training dataset + `influence_src_dataset`. + show_progress (bool, optional): To compute the self influence scores for + all examples in training dataset `influence_src_dataset`, we + compute the self influence scores for each batch. If + `show_progress`is true, the progress of this computation will be + displayed. In particular, the number of batches for which self + influence scores have been computed will be displayed. It will + try to use tqdm if available for advanced features (e.g. time + estimation). Otherwise, it will fallback to a simple output of + progress. + Default: False + """ + influence_src_dataloader = self.influence_src_dataloader + + if show_progress: + influence_src_dataloader = progress( + influence_src_dataloader, + desc=( + f"Using {self.get_name()} to compute self " + "influence for training batches" + ), + total=self.influence_src_dataloader_len, + ) + + return torch.cat( + [ + self._self_influence_batch_tracincp(batch) + for batch in influence_src_dataloader + ], + dim=0, + ) + + def _basic_computation_tracincp( + self, + inputs: Tuple[Any, ...], + targets: Optional[Tensor] = None, + ) -> Tuple[Tensor, ...]: + """ + For instances of TracInCP, computation of influence scores or self influence + scores repeatedly calls this function for different checkpoints + and batches. 
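+
+        Illustrative shapes (assumed, for exposition only): for a model whose
+        parameters have shapes ``(out, in)`` and ``(out,)``, and a batch of
+        ``batch_size`` examples, the returned value is a tuple with one
+        per-example jacobian per parameter, of shapes ``(batch_size, out, in)``
+        and ``(batch_size, out)`` respectively.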
+ + Args: + inputs (Tuple of Any): A batch of examples, which could be a training batch + or test batch, depending which method is the caller. Does not + represent labels, which are passed as `targets`. The assumption is + that `self.model(*inputs)` produces the predictions for the batch. + targets (tensor or None): If computing influence scores on a loss function, + these are the labels corresponding to the batch `inputs`. + """ + if self.sample_wise_grads_per_batch: + return _compute_jacobian_wrt_params_with_sample_wise_trick( + self.model, + inputs, + targets, + self.loss_fn, + self.reduction_type, + ) + return _compute_jacobian_wrt_params( + self.model, + inputs, + targets, + self.loss_fn, + ) diff --git a/captum/influence/_core/tracincp_fast_rand_proj.py b/captum/influence/_core/tracincp_fast_rand_proj.py new file mode 100644 index 0000000000000000000000000000000000000000..930b0c5e752294a8968163c6267ec687aa1d001b --- /dev/null +++ b/captum/influence/_core/tracincp_fast_rand_proj.py @@ -0,0 +1,1188 @@ +#!/usr/bin/env python3 + +import warnings +from typing import Any, Callable, Iterator, List, Optional, Union, Tuple + +import torch +from captum._utils.common import _get_module_from_name, _format_inputs +from captum._utils.progress import progress +from captum.influence._core.tracincp import ( + TracInCPBase, + KMostInfluentialResults, + _influence_route_to_helpers, +) +from captum.influence._utils.common import ( + _jacobian_loss_wrt_inputs, + _load_flexible_state_dict, + _tensor_batch_dot, + _get_k_most_influential_helper, + _DatasetFromList, +) +from captum.influence._utils.nearest_neighbors import ( + NearestNeighbors, + AnnoyNearestNeighbors, +) +from captum.log import log_usage +from torch import Tensor +from torch.nn import Module +from torch.utils.data import DataLoader, Dataset + +layer_inputs = [] + + +def _capture_inputs(layer: Module, input: Tensor, output: Tensor) -> None: + r"""Save activations into layer.activations in forward pass""" + + layer_inputs.append(input[0].detach()) + + +r""" +Implements abstract DataInfluence class and also provides implementation details for +influence computation based on the logic provided in TracIn paper +(https://arxiv.org/pdf/2002.08484.pdf). + +The TracIn paper proposes an idealized notion of influence which can be represented by +the total amount a training example reduces loss for a test example via a training +process such as stochastic gradient descent. As this idealized notion of influence is +impractical to compute, the TracIn paper proposes instead to compute an influence +score, which uses a first-order approximation for the change in loss for a test example +by a training example, which is accumulated across saved model checkpoints. This +influence score is accumulated via a summed dot-product of gradient vectors for the +scores/loss of a test and training example. +""" + +""" +TODO: Support for checkpoint type. Currently only supports model parameters as saved +checkpoints. Can use enum or string. + +Potential implementation from design doc: +checkpoint_type (Enum = [Parameters | Loss_Grad]): For performance, + saved / loaded checkpoints can be either model parameters, or + gradient of the loss function on an input w.r.t parameters. 
+""" + + +class TracInCPFast(TracInCPBase): + r""" + In Appendix F, Page 14 of the TracIn paper, they show that the calculation + of the influence score of between a test example x' and a training example x, + can be computed much more quickly than naive back-propagation in the special + case when considering only gradients in the last fully-connected layer. This class + computes influence scores for that special case. Note that the computed + influence scores are exactly the same as when naive back-propagation is used - + there is no loss in accuracy. + """ + + def __init__( + self, + model: Module, + final_fc_layer: Union[Module, str], + influence_src_dataset: Union[Dataset, DataLoader], + checkpoints: Union[str, List[str], Iterator], + checkpoints_load_func: Callable = _load_flexible_state_dict, + loss_fn: Optional[Union[Module, Callable]] = None, + batch_size: Union[int, None] = 1, + vectorize: bool = False, + ) -> None: + r""" + Args: + model (torch.nn.Module): An instance of pytorch model. This model should + define all of its layers as attributes of the model. + final_fc_layer (torch.nn.Module or str): The last fully connected layer in + the network for which gradients will be approximated via fast random + projection method. Can be either the layer module itself, or the + fully qualified name of the layer if it is a defined attribute of + the passed `model`. + influence_src_dataset (torch.utils.data.Dataset or torch.utils.DataLoader): + In the `influence` method, we either compute the influence score of + training examples on examples in a test batch, or self influence + scores for those training examples, depending on which mode is used. + This argument represents the training dataset containing those + training examples. In order to compute those influence scores, we + will create a Pytorch DataLoader yielding batches of training + examples that is then used for processing. If this argument is + already a Pytorch Dataloader, that DataLoader can be directly + used for processing. If it is instead a Pytorch Dataset, we will + create a DataLoader using it, with batch size specified by + `batch_size`. For efficiency purposes, the batch size of the + DataLoader used for processing should be as large as possible, but + not too large, so that certain intermediate quantities created + from a batch still fit in memory. Therefore, if + `influence_src_dataset` is a Dataset, `batch_size` should be large. + If `influence_src_dataset` was already a DataLoader to begin with, + it should have been constructed to have a large batch size. + checkpoints (str or List of str or Iterator): Either the directory of the + path to store and retrieve model checkpoints, a list of + filepaths with checkpoints from which to load, or an iterator which + returns objects from which to load checkpoints. + checkpoints_load_func (Callable, optional): The function to load a saved + checkpoint into a model to update its parameters, and get the + learning rate if it is saved. By default uses a utility to load a + model saved as a state dict. + Default: _load_flexible_state_dict + loss_fn (Callable, optional): The loss function applied to model. `loss_fn` + must be a "reduction" loss function that reduces the per-example + losses in a batch, and returns a single scalar Tensor. Furthermore, + the reduction must be the *sum* or the *mean* of the per-example + losses. For instance, `nn.BCELoss(reduction="sum")` is acceptable. 
+ Also note that if `loss_fn` has no "reduction" attribute, + the implementation assumes that the reduction is the *sum* of the + per-example losses. If this is not the case, i.e. the reduction + is the *mean*, please set the "reduction" attribute of `loss_fn` + to "mean", i.e. `loss_fn.reduction = "mean"`. + Default: None + batch_size (int or None, optional): Batch size of the DataLoader created to + iterate through `influence_src_dataset`, if it is a Dataset. + `batch_size` should be chosen as large as possible so that certain + intermediate quantities created from a batch still fit in memory. + Specific implementations of `TracInCPBase` will detail the size of + the intermediate quantities. `batch_size` must be an int if + `influence_src_dataset` is a Dataset. If `influence_src_dataset` + is a DataLoader, then `batch_size` is ignored as an argument. + Default: 1 + vectorize (bool, optional): Flag to use experimental vectorize functionality + for `torch.autograd.functional.jacobian`. + Default: False + """ + TracInCPBase.__init__( + self, + model, + influence_src_dataset, + checkpoints, + checkpoints_load_func, + loss_fn, + batch_size, + ) + + self.vectorize = vectorize + + # TODO: restore prior state + self.final_fc_layer = final_fc_layer + if isinstance(self.final_fc_layer, str): + self.final_fc_layer = _get_module_from_name(model, self.final_fc_layer) + assert isinstance(self.final_fc_layer, Module) + for param in self.final_fc_layer.parameters(): + param.requires_grad = True + + assert loss_fn is not None, "loss function must not be none" + + # If we are able to access the reduction used by `loss_fn`, we check whether + # the reduction is either 'sum' or 'mean', as required + if isinstance(loss_fn, Module) and hasattr( + loss_fn, "reduction" + ): # TODO: allow loss_fn to be Callable + assert loss_fn.reduction in [ + "sum", + "mean", + ], 'reduction for `loss_fn` must be "sum" or "mean"' + self.reduction_type = str(loss_fn.reduction) + else: + # if we are unable to access the reduction used by `loss_fn`, we warn + # the user about the assumptions we are making regarding the reduction + # used by `loss_fn` + warnings.warn( + 'Since `loss_fn` has no "reduction" attribute, the implementation ' + 'assumes that `loss_fn` is a "reduction" loss function that ' + "reduces the per-example losses by taking their *sum*. If " + "`loss_fn` instead reduces the per-example losses by taking their " + 'mean, please set the reduction attribute of `loss_fn` to "mean", ' + 'i.e. `loss_fn.reduction = "mean"`.' + ) + self.reduction_type = "sum" + + @log_usage() + def influence( # type: ignore[override] + self, + inputs: Any = None, + targets: Optional[Tensor] = None, + k: Optional[int] = None, + proponents: bool = True, + unpack_inputs: bool = True, + show_progress: bool = False, + ) -> Union[Tensor, KMostInfluentialResults]: + r""" + This is the key method of this class, and can be run in 3 different modes, + where the mode that is run depends on the arguments passed to this method: + + - self influence mode: This mode is used if `inputs` is None. This mode + computes the self influence scores for every example in + the training dataset `influence_src_dataset`. + - influence score mode: This mode is used if `inputs` is not None, and `k` is + None. This mode computes the influence score of every example in + training dataset `influence_src_dataset` on every example in the test + batch represented by `inputs` and `targets`. 
+ - k-most influential mode: This mode is used if `inputs` is not None, and + `k` is not None, and an int. This mode computes the proponents or + opponents of every example in the test batch represented by `inputs` + and `targets`. In particular, for each test example in the test batch, + this mode computes its proponents (resp. opponents), which are the + indices in the training dataset `influence_src_dataset` of the training + examples with the `k` highest (resp. lowest) influence scores on the + test example. Proponents are computed if `proponents` is True. + Otherwise, opponents are computed. For each test example, this method + also returns the actual influence score of each proponent (resp. + opponent) on the test example. + + Args: + inputs (Any, optional): If not provided or `None`, the self influence mode + will be run. Otherwise, `inputs` is the test batch that will be + used when running in either influence score or k-most influential + mode. If the argument `unpack_inputs` is False, the + assumption is that `self.model(inputs)` produces the predictions + for a batch, and `inputs` can be of any type. Otherwise if the + argument `unpack_inputs` is True, the assumption is that + `self.model(*inputs)` produces the predictions for a batch, and + `inputs` will need to be a tuple. In other words, `inputs` will be + unpacked as an argument when passing to `self.model`. + Default: None + targets (tensor, optional): The labels corresponding to the batch `inputs`. + This method is designed to be applied for a loss function, so + `targets` is required, unless running in "self influence" mode. + Default: None + k (int, optional): If not provided or `None`, the influence score mode will + be run. Otherwise, the k-most influential mode will be run, + and `k` is the number of proponents / opponents to return per + example in the test batch. + Default: None + proponents (bool, optional): Whether seeking proponents (`proponents=True`) + or opponents (`proponents=False`), if running in k-most influential + mode. + Default: True + unpack_inputs (bool, optional): Whether to unpack the `inputs` argument to + when passing it to `model`, if `inputs` is a tuple (no unpacking + done otherwise). + Default: True + show_progress (bool, optional): For all modes, computation of results + requires "training dataset computations": computations for each + batch in the training dataset `influence_src_dataset`, which may + take a long time. If `show_progress`is true, the progress of + "training dataset computations" will be displayed. In particular, + the number of batches for which computations have been performed + will be displayed. It will try to use tqdm if available for + advanced features (e.g. time estimation). Otherwise, it will + fallback to a simple output of progress. + Default: False + + Returns: + The return value of this method depends on which mode is run. + + - self influence mode: if this mode is run (`inputs` is None), returns a 1D + tensor of self influence scores over training dataset + `influence_src_dataset`. The length of this tensor is the number of + examples in `influence_src_dataset`, regardless of whether it is a + Dataset or DataLoader. + - influence score mode: if this mode is run (`inputs is not None, `k` is + None), returns a 2D tensor `influence_scores` of shape + `(input_size, influence_src_dataset_size)`, where `input_size` is + the number of examples in the test batch, and + `influence_src_dataset_size` is the number of examples in + training dataset `influence_src_dataset`. 
In other words, + `influence_scores[i][j]` is the influence score of the `j`-th + example in `influence_src_dataset` on the `i`-th example in the + test batch. + - k-most influential mode: if this mode is run (`inputs` is not None, + `k` is an int), returns a namedtuple `(indices, influence_scores)`. + `indices` is a 2D tensor of shape `(input_size, k)`, where + `input_size` is the number of examples in the test batch. If + computing proponents (resp. opponents), `indices[i][j]` is the + index in training dataset `influence_src_dataset` of the example + with the `j`-th highest (resp. lowest) influence score (out of the + examples in `influence_src_dataset`) on the `i`-th example in the + test batch. `influence_scores` contains the corresponding influence + scores. In particular, `influence_scores[i][j]` is the influence + score of example `indices[i][j]` in `influence_src_dataset` on + example `i` in the test batch represented by `inputs` and + `targets`. + """ + return _influence_route_to_helpers( + self, + inputs, + targets, + k, + proponents, + unpack_inputs, + show_progress, + ) + + def _influence_batch_tracincp_fast( + self, + inputs: Tuple[Any, ...], + targets: Tensor, + batch: Tuple[Any, ...], + ): + """ + computes influence scores for a single training batch + """ + + def get_checkpoint_contribution(checkpoint): + + assert ( + checkpoint is not None + ), "None returned from `checkpoints`, cannot load." + + learning_rate = self.checkpoints_load_func(self.model, checkpoint) + + input_jacobians, input_layer_inputs = _basic_computation_tracincp_fast( + self, + inputs, + targets, + ) + + src_jacobian, src_layer_input = _basic_computation_tracincp_fast( + self, batch[0:-1], batch[-1] + ) + return ( + _tensor_batch_dot(input_jacobians, src_jacobian) + * _tensor_batch_dot(input_layer_inputs, src_layer_input) + * learning_rate + ) + + batch_tracin_scores = get_checkpoint_contribution(self.checkpoints[0]) + + for checkpoint in self.checkpoints[1:]: + batch_tracin_scores += get_checkpoint_contribution(checkpoint) + + return batch_tracin_scores + + def _influence( # type: ignore[override] + self, + inputs: Tuple[Any, ...], + targets: Tensor, + show_progress: bool = False, + ) -> Tensor: + r""" + Computes the influence of examples in training dataset `influence_src_dataset` + on the examples in the test batch represented by `inputs` and `targets`. + This implementation does not require knowing the number of training examples + in advance. Instead, the number of training examples is inferred from the + output of `_basic_computation_tracincp_fast`. + + Args: + inputs (Tuple of Any): A batch of examples. Does not represent labels, + which are passed as `targets`. The assumption is that + `self.model(*inputs)` produces the predictions for the batch. + targets (tensor): The labels corresponding to the batch `inputs`. This + method is designed to be applied for a loss function, so labels + are required. + show_progress (bool, optional): To compute the influence of examples in + training dataset `influence_src_dataset`, we compute the influence + of each batch. If `show_progress`is true, the progress of this + computation will be displayed. In particular, the number of batches + for which influence has been computed will be displayed. It will + try to use tqdm if available for advanced features (e.g. time + estimation). Otherwise, it will fallback to a simple output of + progress. + Default: False + + Returns: + influence_scores (tensor): Influence scores from the TracInCPFast method. 
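Continuing the hypothetical `tracin_fast` setup from the earlier sketch, the three modes described above might be exercised as follows, assuming the returned namedtuple exposes `indices` and `influence_scores` as documented.

```python
import torch

x_test = torch.randn(2, 20)            # hypothetical test batch
y_test = torch.randint(0, 3, (2,))

self_scores = tracin_fast.influence()                 # self influence mode: 1D tensor over training examples
scores = tracin_fast.influence((x_test,), y_test)     # influence score mode: shape (2, len(train_ds))
top = tracin_fast.influence((x_test,), y_test, k=3)   # k-most influential mode
print(top.indices.shape, top.influence_scores.shape)  # both torch.Size([2, 3])
```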
+ Its shape is `(input_size, influence_src_dataset_size)`, where `input_size` + is the number of examples in the test batch, and + `influence_src_dataset_size` is the number of examples in + training dataset `influence_src_dataset`. For example: + `influence_scores[i][j]` is the influence score for the j-th training + example to the i-th input example. + """ + assert targets is not None + + influence_src_dataloader = self.influence_src_dataloader + + if show_progress: + influence_src_dataloader = progress( + influence_src_dataloader, + desc=( + f"Using {self.get_name()} to compute " + "influence for training batches" + ), + total=self.influence_src_dataloader_len, + ) + + return torch.cat( + [ + self._influence_batch_tracincp_fast(inputs, targets, batch) + for batch in influence_src_dataloader + ], + dim=1, + ) + + def _get_k_most_influential( # type: ignore[override] + self, + inputs: Tuple[Any, ...], + targets: Tensor, + k: int = 5, + proponents: bool = True, + show_progress: bool = False, + ) -> KMostInfluentialResults: + r""" + Args: + inputs (Tuple of Any): A tuple that represents a batch of examples. It does + not represent labels, which are passed as `targets`. + targets (tensor): The labels corresponding to the batch `inputs`. This + method is designed to be applied for a loss function, so labels + are required. + k (int, optional): The number of proponents or opponents to return per test + example. + Default: 5 + proponents (bool, optional): Whether seeking proponents (`proponents=True`) + or opponents (`proponents=False`) + Default: True + show_progress (bool, optional): To compute the proponents (or opponents) + for the batch of examples, we perform computation for each batch in + training dataset `influence_src_dataset`, If `show_progress`is + true, the progress of this computation will be displayed. In + particular, the number of batches for which the computation has + been performed will be displayed. It will try to use tqdm if + available for advanced features (e.g. time estimation). Otherwise, + it will fallback to a simple output of progress. + Default: False + + Returns: + (indices, influence_scores) (namedtuple): `indices` is a torch.long Tensor + that contains the indices of the proponents (or opponents) for each + test example. Its dimension is `(inputs_batch_size, k)`, where + `inputs_batch_size` is the number of examples in `inputs`. For + example, if `proponents==True`, `indices[i][j]` is the index of the + example in training dataset `influence_src_dataset` with the + k-th highest influence score for the j-th example in `inputs`. + `indices` is a `torch.long` tensor so that it can directly be used + to index other tensors. Each row of `influence_scores` contains the + influence scores for a different test example, in sorted order. In + particular, `influence_scores[i][j]` is the influence score of + example `indices[i][j]` in training dataset `influence_src_dataset` + on example `i` in the test batch represented by `inputs` and + `targets`. + """ + desc = ( + None + if not show_progress + else ( + ( + f"Using {self.get_name()} to perform computation for " + f'getting {"proponents" if proponents else "opponents"}. 
' + "Processing training batches: 100%" + ) + ) + ) + return KMostInfluentialResults( + *_get_k_most_influential_helper( + self.influence_src_dataloader, + self._influence_batch_tracincp_fast, + inputs, + targets, + k, + proponents, + show_progress, + desc, + ) + ) + + def _self_influence_batch_tracincp_fast(self, batch: Tuple[Any, ...]): + """ + Computes self influence scores for a single batch + """ + + def get_checkpoint_contribution(checkpoint): + + assert ( + checkpoint is not None + ), "None returned from `checkpoints`, cannot load." + + learning_rate = self.checkpoints_load_func(self.model, checkpoint) + + batch_jacobian, batch_layer_input = _basic_computation_tracincp_fast( + self, batch[0:-1], batch[-1] + ) + + return ( + torch.sum(batch_jacobian ** 2, dim=1) + * torch.sum(batch_layer_input ** 2, dim=1) + * learning_rate + ) + + batch_self_tracin_scores = get_checkpoint_contribution(self.checkpoints[0]) + + for checkpoint in self.checkpoints[1:]: + batch_self_tracin_scores += get_checkpoint_contribution(checkpoint) + + return batch_self_tracin_scores + + def _self_influence(self, show_progress: bool = False): + """ + Returns: + self influence scores (tensor): 1D tensor containing self influence + scores for all examples in training dataset + `influence_src_dataset`. + show_progress (bool, optional): To compute the self influence scores for + all examples in training dataset `influence_src_dataset`, we + compute the self influence scores for each batch. If + `show_progress`is true, the progress of this computation will be + displayed. In particular, the number of batches for which self + influence scores have been computed will be displayed. It will + try to use tqdm if available for advanced features (e.g. time + estimation). Otherwise, it will fallback to a simple output of + progress. + Default: False + """ + influence_src_dataloader = self.influence_src_dataloader + + if show_progress: + influence_src_dataloader = progress( + influence_src_dataloader, + desc=( + f"Using {self.get_name()} to compute self " + "influence for training batches" + ), + total=self.influence_src_dataloader_len, + ) + + return torch.cat( + [ + self._self_influence_batch_tracincp_fast(batch) + for batch in influence_src_dataloader + ], + dim=0, + ) + + +def _basic_computation_tracincp_fast( + influence_instance: TracInCPFast, + inputs: Tuple[Any, ...], + targets: Tensor, +): + """ + For instances of TracInCPFast and children classes, computation of influence scores + or self influence scores repeatedly calls this function for different checkpoints + and batches. + + Args: + influence_instance (TracInCPFast): A instance of TracInCPFast or its children. + We assume `influence_instance` has a `loss_fn` attribute, i.e. the loss + function applied to the output of the last fully-connected layer, as + well as a `reduction_type` attribute, which indicates whether `loss_fn` + reduces the per-example losses by using their mean or sum. The + `reduction_type` attribute must either be "mean" or "sum". + inputs (Tuple of Any): A batch of examples, which could be a training batch + or test batch, depending which method is the caller. Does not + represent labels, which are passed as `targets`. The assumption is + that `self.model(*inputs)` produces the predictions for the batch. + targets (tensor): If computing influence scores on a loss function, + these are the labels corresponding to the batch `inputs`. 
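The factorization used above rests on a simple identity: for a final linear layer, the gradient of the loss with respect to the weight matrix is the outer product of the loss jacobian (w.r.t. the layer output) and the layer input, so its squared norm is the product of the two squared norms. A quick numerical check with arbitrary toy dimensions:

```python
import torch

j = torch.randn(5)          # jacobian of the loss w.r.t. the final layer's output
a = torch.randn(8)          # input (activation) to the final fully-connected layer
grad_W = torch.outer(j, a)  # gradient w.r.t. the weight matrix
lhs = (grad_W ** 2).sum()
rhs = (j ** 2).sum() * (a ** 2).sum()
print(torch.allclose(lhs, rhs))  # True
```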
+ """ + global layer_inputs + layer_inputs = [] + assert isinstance(influence_instance.final_fc_layer, Module) + handle = influence_instance.final_fc_layer.register_forward_hook(_capture_inputs) + out = influence_instance.model(*inputs) + + assert influence_instance.loss_fn is not None, "loss function is required" + assert influence_instance.reduction_type in [ + "sum", + "mean", + ], 'reduction_type must be either "mean" or "sum"' + input_jacobians = _jacobian_loss_wrt_inputs( + influence_instance.loss_fn, + out, + targets, + influence_instance.vectorize, + influence_instance.reduction_type, + ) + handle.remove() + _layer_inputs = layer_inputs[0] + + assert len(input_jacobians.shape) == 2 + + return input_jacobians, _layer_inputs + + +class TracInCPFastRandProj(TracInCPFast): + def __init__( + self, + model: Module, + final_fc_layer: Union[Module, str], + influence_src_dataset: Union[Dataset, DataLoader], + checkpoints: Union[str, List[str], Iterator], + checkpoints_load_func: Callable = _load_flexible_state_dict, + loss_fn: Optional[Union[Module, Callable]] = None, + batch_size: Union[int, None] = 1, + vectorize: bool = False, + nearest_neighbors: Optional[NearestNeighbors] = None, + projection_dim: int = None, + seed: int = 0, + ) -> None: + r""" + A version of TracInCPFast which is optimized for "interactive" calls to + `influence` for the purpose of calculating proponents / opponents, or + influence scores. "Interactive" means there will be multiple calls to + `influence`, with each call for a different batch of test examples, and + subsequent calls rely on the results of previous calls. The implementation in + this class has been optimized so that each call to `influence` is fast, so that + it can be used for interactive analysis. This class should only be used for + interactive use cases. It should not be used if `influence` will only be + called once, because to enable fast calls to `influence`, time and memory + intensive preprocessing is required in `__init__`. Furthermore, it should not + be used to calculate self influencs scores - `TracInCPFast` should be used + instead for that purpose. To enable interactive analysis, this implementation + saves pre-computed vectors for all training examples in + `influence_src_dataset`. Crucially, the influence score of a training + example on a test example is simply the dot-product of their corresponding + vectors, and proponents / opponents can be found by first storing vectors for + training examples in a nearest-neighbor data structure, and then finding the + nearest-neighbors for a test example in terms of dot-product (see appendix F + of the TracIn paper). This class should only be used if calls to `influence` + to obtain proponents / opponents or influence scores will be made in an + "interactive" manner, and there is sufficient memory to store vectors for the + entire `influence_src_dataset`. This is because in order to enable interactive + analysis, this implementation incures overhead in ``__init__` to setup the + nearest-neighbors data structure, which is both time and memory intensive, as + vectors corresponding to all training examples needed to be stored. To reduce + memory usage, this implementation enables random projections of those vectors. + Note that the influence scores computed with random projections are less + accurate, though correct in expectation. + + Args: + model (torch.nn.Module): An instance of pytorch model. This model should + define all of its layers as attributes of the model. 
+ final_fc_layer (torch.nn.Module or str): The last fully connected layer in + the network for which gradients will be approximated via fast random + projection method. Can be either the layer module itself, or the + fully qualified name of the layer if it is a defined attribute of + the passed `model`. + influence_src_dataset (torch.utils.data.Dataset or torch.utils.DataLoader): + In the `influence` method, we either compute the influence score of + training examples on examples in a test batch, or self influence + scores for those training examples, depending on which mode is used. + This argument represents the training dataset containing those + training examples. In order to compute those influence scores, we + will create a Pytorch DataLoader yielding batches of training + examples that is then used for processing. If this argument is + already a Pytorch Dataloader, that DataLoader can be directly + used for processing. If it is instead a Pytorch Dataset, we will + create a DataLoader using it, with batch size specified by + `batch_size`. For efficiency purposes, the batch size of the + DataLoader used for processing should be as large as possible, but + not too large, so that certain intermediate quantities created + from a batch still fit in memory. Therefore, if + `influence_src_dataset` is a Dataset, `batch_size` should be large. + If `influence_src_dataset` was already a DataLoader to begin with, + it should have been constructed to have a large batch size. + checkpoints (str or List of str or Iterator): Either the directory of the + path to store and retrieve model checkpoints, a list of + filepaths with checkpoints from which to load, or an iterator which + returns objects from which to load checkpoints. + checkpoints_load_func (Callable, optional): The function to load a saved + checkpoint into a model to update its parameters, and get the + learning rate if it is saved. By default uses a utility to load a + model saved as a state dict. + Default: _load_flexible_state_dict + loss_fn (Callable, optional): The loss function applied to model. `loss_fn` + must be a "reduction" loss function that reduces the per-example + losses in a batch, and returns a single scalar Tensor. Furthermore, + the reduction must be the *sum* of the per-example losses. For + instance, `nn.BCELoss(reduction="sum")` is acceptable, but + `nn.BCELoss(reduction="mean")` is *not* acceptable. + Default: None + batch_size (int or None, optional): Batch size of the DataLoader created to + iterate through `influence_src_dataset`, if it is a Dataset. + `batch_size` should be chosen as large as possible so that certain + intermediate quantities created from a batch still fit in memory. + Specific implementations of `TracInCPBase` will detail the size of + the intermediate quantities. `batch_size` must be an int if + `influence_src_dataset` is a Dataset. If `influence_src_dataset` + is a DataLoader, then `batch_size` is ignored as an argument. + Default: 1 + vectorize (bool): Flag to use experimental vectorize functionality + for `torch.autograd.functional.jacobian`. + Default: False + nearest_neighbors (NearestNeighbors, optional): The NearestNeighbors + instance for finding nearest neighbors. If None, defaults to + `AnnoyNearestNeighbors(n_trees=10)`. + Default: None + projection_dim (int, optional): Each example will be represented in + the nearest neighbors data structure with a vector. 
This vector + is the concatenation of several "checkpoint vectors", each of which + is computed using a different checkpoint in the `checkpoints` + argument. If `projection_dim` is an int, it represents the + dimension we will project each "checkpoint vector" to, so that the + vector for each example will be of dimension at most + `projection_dim` * C, where C is the number of checkpoints. + Regarding the dimension of each vector, D: Let I be the dimension + of the output of the last fully-connected layer times the dimension + of the input of the last fully-connected layer. If `projection_dim` + is not `None`, then D = min(I * C, `projection_dim` * C). + Otherwise, D = I * C. In summary, if `projection_dim` is None, the + dimension of this vector will be determined by the size of the + input and output of the last fully-connected layer of `model`, and + the number of checkpoints. Otherwise, `projection_dim` must be an + int, and random projection will be performed to ensure that the + vector is of dimension no more than `projection_dim` * C. + `projection_dim` corresponds to the variable d in the top of page + 15 of the TracIn paper: https://arxiv.org/pdf/2002.08484.pdf. + Default: None + seed (int, optional): Because this implementation chooses a random + projection, its output is random. Setting this seed specifies the + random seed when choosing the random projection. + Default: 0 + """ + + TracInCPFast.__init__( + self, + model, + final_fc_layer, + influence_src_dataset, + checkpoints, + checkpoints_load_func, + loss_fn, + batch_size, + vectorize, + ) + + warnings.warn( + ( + "WARNING: Using this implementation stores quantities related to the " + "entire `influence_src_dataset` in memory, and may results in running " + "out of memory. If this happens, consider using %s instead, for which " + "each call to `influence` to compute influence scores or proponents " + "will be slower, but may avoid running out of memory." + ) + % "`TracInCPFast`" + ) + + self.nearest_neighbors = ( + AnnoyNearestNeighbors() if nearest_neighbors is None else nearest_neighbors + ) + + self.projection_dim = projection_dim + + torch.manual_seed(seed) # for reproducibility + self.projection_quantities = self._set_projections_tracincp_fast_rand_proj( + self.influence_src_dataloader, + ) + + self.src_intermediate_quantities = ( + self._get_intermediate_quantities_tracincp_fast_rand_proj( + self.influence_src_dataloader, + self.projection_quantities, + ) + ) + + self._process_src_intermediate_quantities_tracincp_fast_rand_proj( + self.src_intermediate_quantities, + ) + + def _influence( # type: ignore[override] + self, + inputs: Tuple[Any, ...], + targets: Tensor, + ) -> Tensor: + r""" + Args: + inputs (tuple of Any): A batch of examples. Does not represent labels, + which are passed as `targets`. The assumption is that + `self.model(*inputs)` produces the predictions for the batch. + targets (tensor): The labels corresponding to the batch `inputs`. This + method is designed to be applied for a loss function, so labels + are required. + + Returns: + influence_scores (tensor): Influence scores from the + TracInCPFastRandProj method. Its shape is + `(input_size, influence_src_dataset_size)`, where `input_size` is the + number of examples in the test batch, and `influence_src_dataset_size` is + the number of examples in training dataset `influence_src_dataset`. For + example, `influence_scores[i][j]` is the influence score for the j-th + training example to the i-th input example. 
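A construction sketch for `TracInCPFastRandProj`, reusing the hypothetical `model`, `train_ds`, checkpoint paths, and test batch from the earlier sketches; note the "sum" reduction and that the time- and memory-intensive preprocessing happens in `__init__`, after which queries are cheap.

```python
import torch.nn as nn
from captum.influence._core.tracincp_fast_rand_proj import TracInCPFastRandProj

tracin_rp = TracInCPFastRandProj(
    model=model,
    final_fc_layer="fc",
    influence_src_dataset=train_ds,
    checkpoints=["ckpt_epoch_1.pt", "ckpt_epoch_2.pt"],  # hypothetical paths
    loss_fn=nn.CrossEntropyLoss(reduction="sum"),        # must be a "sum" reduction here
    batch_size=64,
    projection_dim=100,  # each per-checkpoint vector is projected to at most ~100 dims
    seed=0,
)

# interactive queries are now fast: proponents come from a nearest-neighbors lookup
indices, scores = tracin_rp.influence((x_test,), y_test, k=5, proponents=True)
```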
+ """ + inputs_batch = (*inputs, targets) + input_projections = self._get_intermediate_quantities_tracincp_fast_rand_proj( + DataLoader( + _DatasetFromList([inputs_batch]), shuffle=False, batch_size=None + ), + self.projection_quantities, + ) + + src_projections = self.src_intermediate_quantities + + return torch.matmul(input_projections, src_projections.T) + + def _get_k_most_influential( # type: ignore[override] + self, + inputs: Tuple[Any, ...], + targets: Tensor, + k: int = 5, + proponents: bool = True, + ) -> KMostInfluentialResults: + r""" + Args: + inputs (Tuple of Any): A tuple that represents a batch of examples. It does + not represent labels, which are passed as `targets`. + targets (tensor): The labels corresponding to the batch `inputs`. This + method is designed to be applied for a loss function, so labels + are required. + k (int, optional): The number of proponents or opponents to return per test + example. + Default: 5 + proponents (bool, optional): Whether seeking proponents (`proponents=True`) + or opponents (`proponents=False`) + Default: True + + Returns: + (indices, influence_scores) (namedtuple): `indices` is a torch.long Tensor + that contains the indices of the proponents (or opponents) for each + test example. Its dimension is `(inputs_batch_size, k)`, where + `inputs_batch_size` is the number of examples in `inputs`. For + example, if `proponents==True`, `indices[i][j]` is the index of the + example in training dataset `influence_src_dataset` with the + k-th highest influence score for the j-th example in `inputs`. + `indices` is a `torch.long` tensor so that it can directly be used + to index other tensors. Each row of `influence_scores` contains the + influence scores for a different test example, in sorted order. In + particular, `influence_scores[i][j]` is the influence score of + example `indices[i][j]` in training dataset `influence_src_dataset` + on example `i` in the test batch represented by `inputs` and + `targets`. + """ + inputs_batch = (*inputs, targets) + input_projections = self._get_intermediate_quantities_tracincp_fast_rand_proj( + DataLoader( + _DatasetFromList([inputs_batch]), shuffle=False, batch_size=None + ), + self.projection_quantities, + ) + multiplier = 1 if proponents else -1 + + input_projections *= multiplier + + indices, distances = self.nearest_neighbors.get_nearest_neighbors( + input_projections, k + ) + + distances *= multiplier + + return KMostInfluentialResults(indices, distances) + + def _self_influence(self): + """ + NOT IMPLEMENTED - no need to implement `TracInCPFastRandProj._self_influence`, + as `TracInCPFast._self_influence` is sufficient - the latter does not benefit + from random projections, since no quantities associated with a training + example are stored (other than its self influence score) + + Returns: + self influence scores (Tensor): 1-d Tensor containing self influence + scores for all examples in training dataset + `influence_src_dataset`. + """ + warnings.warn( + ( + "WARNING: If calculating self influence scores, when only considering " + "gradients with respect to the last fully-connected layer, " + "`TracInCPFastRandProj` should not be used. Instead, please use " + "`TracInCPFast`. This is because when calculating self influence " + "scores, no quantities associated with a training example are stored, " + "so the memory-saving benefit of the random projections used by " + "`TracInCPFastRandProj` is not needed.
Further considering the fact that " + "random projections results only in approximate self influence " + "scores, there is no reason to use `TracInCPFastRandProj` when " + "calculating self-influence scores." + ) + ) + raise NotImplementedError + + @log_usage() + def influence( # type: ignore[override] + self, + inputs: Any, + targets: Tensor, + k: int = 5, + proponents: bool = True, + unpack_inputs: bool = True, + ) -> Union[Tensor, KMostInfluentialResults]: + r""" + This is the key method of this class, and can be run in 2 different modes, + where the mode that is run depends on the arguments passed to this method + + - influence score mode: This mode is used if `inputs` is not None, and `k` is + None. This mode computes the influence score of every example in + training dataset `influence_src_dataset` on every example in the test + batch represented by `inputs` and `targets`. + + - k-most influential mode: This mode is used if `inputs` is not None, and + `k` is not None, and an int. This mode computes the proponents or + opponents of every example in the test batch represented by `inputs` + and `targets`. In particular, for each test example in the test batch, + this mode computes its proponents (resp. opponents), which are the + indices in the training dataset `influence_src_dataset` of the training + examples with the `k` highest (resp. lowest) influence scores on the + test example. Proponents are computed if `proponents` is True. + Otherwise, opponents are computed. For each test example, this method + also returns the actual influence score of each proponent (resp. + opponent) on the test example. + + Note that unlike `TracInCPFast`, this class should *not* be run in self + influence mode. To compute self influence scores when only considering + gradients in the last fully-connected layer, please use `TracInCPFast` instead. + + Args: + inputs (Any, optional): If not provided or `None`, the self influence mode + will be run. Otherwise, `inputs` is the test batch that will be + used when running in either influence score or k-most influential + mode. If the argument `unpack_inputs` is False, the + assumption is that `self.model(inputs)` produces the predictions + for a batch, and `inputs` can be of any type. Otherwise if the + argument `unpack_inputs` is True, the assumption is that + `self.model(*inputs)` produces the predictions for a batch, and + `inputs` will need to be a tuple. In other words, `inputs` will be + unpacked as an argument when passing to `self.model`. + Default: None + targets (tensor): The labels corresponding to the batch `inputs`. This + method is designed to be applied for a loss function, so `targets` + is required. + k (int, optional): If not provided or `None`, the influence score mode will + be run. Otherwise, the k-most influential mode will be run, + and `k` is the number of proponents / opponents to return per + example in the test batch. + Default: None + proponents (bool, optional): Whether seeking proponents (`proponents=True`) + or opponents (`proponents=False`), if running in k-most influential + mode. + Default: True + unpack_inputs (bool, optional): Whether to unpack the `inputs` argument to + when passing it to `model`, if `inputs` is a tuple (no unpacking + done otherwise). + Default: True + + Returns: + + The return value of this method depends on which mode is run. 
+ + - influence score mode: if this mode is run (`inputs is not None, `k` is + None), returns a 2D tensor `influence_scores` of shape + `(input_size, influence_src_dataset_size)`, where `input_size` is + the number of examples in the test batch, and + `influence_src_dataset_size` is the number of examples in + training dataset `influence_src_dataset`. In other words, + `influence_scores[i][j]` is the influence score of the `j`-th + example in `influence_src_dataset` on the `i`-th example in the + test batch. + - k-most influential mode: if this mode is run (`inputs` is not None, + `k` is an int), returns a namedtuple `(indices, influence_scores)`. + `indices` is a 2D tensor of shape `(input_size, k)`, where + `input_size` is the number of examples in the test batch. If + computing proponents (resp. opponents), `indices[i][j]` is the + index in training dataset `influence_src_dataset` of the example + with the `j`-th highest (resp. lowest) influence score (out of the + examples in `influence_src_dataset`) on the `i`-th example in the + test batch. `influence_scores` contains the corresponding influence + scores. In particular, `influence_scores[i][j]` is the influence + score of example `indices[i][j]` in `influence_src_dataset` on + example `i` in the test batch represented by `inputs` and + `targets`. + """ + msg = ( + "Since `inputs` is None, this suggests `TracInCPFastRandProj` is being " + "used in self influence mode. However, `TracInCPFastRandProj` should not " + "be used to compute self influence scores. If desiring self influence " + "scores which only consider gradients in the last fully-connected layer, " + "please use `TracInCPFast` instead." + ) + assert inputs is not None, msg + + _inputs = _format_inputs(inputs, unpack_inputs) + + if inputs is None: + return self._self_influence() + elif k is None: + return self._influence(_inputs, targets) + else: + return self._get_k_most_influential(_inputs, targets, k, proponents) + + def _set_projections_tracincp_fast_rand_proj( + self, + dataloader: DataLoader, + ) -> Optional[Tuple[torch.Tensor, torch.Tensor]]: + """ + returns the variables `jacobian_projection` and `layer_input_projection` + if needed, based on `self.projection_dim`. The two variables are + used by `self._get_intermediate_quantities_fast_rand_proj`. They are both None + if projection is not needed, due to the intermediate quantities (see the + `_get_intermediate_quantities_fast_rand_proj` method for details) being no + greater than `self.projection_dim` * C even without projection, where C is the + number of checkpoints in the `checkpoints` argument to + `TracInCPFastRandProj.__init__`. + + Args: + dataloader (DataLoader): determining the projection requires knowing the + dimensionality of the last layer's parameters (`jacobian_dim` + below) and its input (`layer_input_dim` below). These are + determined by passing a batch to `self.model`. `dataloader` + provides that batch. + + Returns: + jacobian_projection (tensor or None): Projection matrix to apply to + Jacobian of last layer to reduce its dimension, if needed. + None otherwise. + input_projection (tensor or None): Projection matrix to apply to input of + last layer to reduce its dimension, if needed. None otherwise. 
+ """ + # figure out projection dimensions, if needed + + projection_dim = self.projection_dim + + projection_quantities = None + + if not (projection_dim is None): + + # figure out original dimensions by looking at data, passing through network + self.checkpoints_load_func(self.model, next(iter(self.checkpoints))) + + batch = next(iter(dataloader)) + batch_jacobians, batch_layer_inputs = _basic_computation_tracincp_fast( + self, + batch[0:-1], + batch[-1], + ) + + jacobian_dim = batch_jacobians.shape[ + 1 + ] # this is the dimension of the output of the last fully-connected layer + layer_input_dim = batch_layer_inputs.shape[ + 1 + ] # this is the dimension of the input of the last fully-connected layer + + # choose projection if needed + # without projection, the dimension of the intermediate quantities returned + # by `_get_intermediate_quantities_fast_rand_proj` will be + # `jacobian_dim` * `layer_input_dim` * number of checkpoints + # this is because for each checkpoint, we compute a "partial" intermediate + # quantity, and the intermediate quantity is the concatenation of the + # "partial" intermediate quantities, and the dimension of each "partial" + # intermediate quantity, without projection, is `jacobian_dim` * + # `layer_input_dim`. However, `projection_dim` refers to the maximum + # allowable dimension of the "partial" intermediate quantity. Therefore, + # we only project if `jacobian_dim` * `layer_input_dim` > `projection_dim`. + # `projection_dim` corresponds to the variable d in the top of page 15 of + # the TracIn paper: https://arxiv.org/pdf/2002.08484.pdf. + if jacobian_dim * layer_input_dim > projection_dim: + jacobian_projection_dim = min(int(projection_dim ** 0.5), jacobian_dim) + layer_input_projection_dim = min( + int(projection_dim ** 0.5), layer_input_dim + ) + jacobian_projection = torch.normal( + torch.zeros(jacobian_dim, jacobian_projection_dim), + 1.0 / jacobian_projection_dim ** 0.5, + ) + layer_input_projection = torch.normal( + torch.zeros(layer_input_dim, layer_input_projection_dim), + 1.0 / layer_input_projection_dim ** 0.5, + ) + + projection_quantities = jacobian_projection, layer_input_projection + + return projection_quantities + + def _process_src_intermediate_quantities_tracincp_fast_rand_proj( + self, + src_intermediate_quantities: torch.Tensor, + ): + """ + Assumes `self._get_intermediate_quantities_tracin_fast_rand_proj` returns + vector representations for each example, and that influence between a + training and test example is obtained by taking the dot product of their + vector representations. In this case, given a test example, its proponents + can be found by storing the vector representations for training examples + into a data structure enablng fast largest-dot-product computation. This + method creates that data structure. This method has side effects. + + Args: + src_intermediate_quantities (tensor): the output of the + `_get_intermediate_quantities_tracin_fast_rand_proj` function when + applied to training dataset `influence_src_dataset`. This + output is the vector representation of all training examples. + The dot product between the representation of a training example + and the representation of a test example gives the influence score + of the training example on the test example. 
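The projection choice described above can be illustrated in isolation: each factor (the last-layer loss jacobian and the last-layer input) gets its own Gaussian projection of dimension roughly sqrt(`projection_dim`), so the flattened outer product stays within `projection_dim` while dot products are preserved in expectation. The dimensions below are arbitrary toy values, not the library defaults.

```python
import torch

torch.manual_seed(0)
jacobian_dim, layer_input_dim, projection_dim = 30, 40, 100
d = int(projection_dim ** 0.5)  # per-factor projection dimension

jac_proj = torch.normal(torch.zeros(jacobian_dim, d), 1.0 / d ** 0.5)
inp_proj = torch.normal(torch.zeros(layer_input_dim, d), 1.0 / d ** 0.5)

def vector_for(jacobian, layer_input):
    # flattened outer product of the projected factors: one "checkpoint vector"
    return torch.outer(jacobian @ jac_proj, layer_input @ inp_proj).flatten()

v1 = vector_for(torch.randn(jacobian_dim), torch.randn(layer_input_dim))
v2 = vector_for(torch.randn(jacobian_dim), torch.randn(layer_input_dim))
print(v1.shape, torch.dot(v1, v2).item())  # vectors of dimension <= projection_dim
```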
+ """ + self.nearest_neighbors.setup(src_intermediate_quantities) + + def _get_intermediate_quantities_tracincp_fast_rand_proj( + self, + dataloader: DataLoader, + projection_quantities: Optional[Tuple[torch.Tensor, torch.Tensor]], + ) -> torch.Tensor: + r""" + This method computes vectors that can be used to compute influence. (see + Appendix F, page 15). Crucially, the influence score between a test example + and a training example is simply the dot product of their respective + vectors. This means that the training example with the largest influence score + on a given test example can be found using a nearest-neighbor (more + specifically, largest dot-product) data structure. + + Args: + dataloader (DataLoader): DataLoader for which the intermediate quantities + are computed. + projection_quantities (tuple or None): Is either the two tensors defining + the randomized projections to apply, or None, which means no + projection is to be applied. + + Returns: + checkpoint_projections (tensor): A tensor of dimension + (N, D * C), where N is total number of examples in `dataloader`, C + is the number of checkpoints passed as the `checkpoints` argument + of `TracInCPFastRandProj.__init__`, and each row represents the + vector for an example. Regarding D: Let I be the dimension of the + output of the last fully-connected layer times the dimension of the + input of the last fully-connected layer. If `self.projection_dim` + is specified in initialization, + D = min(I * C, `self.projection_dim` * C). Otherwise, D = I * C. + In summary, if `self.projection_dim` is None, the dimension of each + vector will be determined by the size of the input and output of + the last fully-connected layer of `model`. Otherwise, + `self.projection_dim` must be an int, and random projection will be + performed to ensure that the vector is of dimension no more than + `self.projection_dim` * C. `self.projection_dim` corresponds to + the variable d in the top of page 15 of the TracIn paper: + https://arxiv.org/pdf/2002.08484.pdf. + """ + checkpoint_projections: List[Any] = [[] for _ in self.checkpoints] + + if projection_quantities is None: + project = False + else: + project = True + jacobian_projection, layer_input_projection = projection_quantities + + for (j, checkpoint) in enumerate(self.checkpoints): + assert ( + checkpoint is not None + ), "None returned from `checkpoints`, cannot load." 
+ + learning_rate = self.checkpoints_load_func(self.model, checkpoint) + learning_rate_root = learning_rate ** 0.5 + + for batch in dataloader: + + batch_jacobians, batch_layer_inputs = _basic_computation_tracincp_fast( + self, + batch[0:-1], + batch[-1], + ) + + if project: + + batch_jacobians = torch.matmul(batch_jacobians, jacobian_projection) + + batch_layer_inputs = torch.matmul( + batch_layer_inputs, layer_input_projection + ) + + checkpoint_projections[j].append( + torch.matmul( + torch.unsqueeze(batch_jacobians, 2), + torch.unsqueeze(batch_layer_inputs, 1), + ).flatten(start_dim=1) + * learning_rate_root + ) + + checkpoint_projections[j] = torch.cat(checkpoint_projections[j], dim=0) + + return torch.cat(checkpoint_projections, dim=1) diff --git a/captum/influence/_utils/__init__.py b/captum/influence/_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/captum/influence/_utils/common.py b/captum/influence/_utils/common.py new file mode 100644 index 0000000000000000000000000000000000000000..28c76ebbc3b1c49dbad299029872d369f8931a70 --- /dev/null +++ b/captum/influence/_utils/common.py @@ -0,0 +1,315 @@ +#!/usr/bin/env python3 + +from typing import Callable, Optional, Tuple, Union, Any, List + +import torch +import torch.nn as nn +from captum._utils.progress import progress +from torch import Tensor +from torch.nn import Module +from torch.utils.data import DataLoader, Dataset + + +def _tensor_batch_dot(t1: Tensor, t2: Tensor) -> Tensor: + r""" + Computes pairwise dot product between two tensors + + Args: + Tensors t1 and t2 are feature vectors with dimension (batch_size_1, *) and + (batch_size_2, *). The * dimensions must match in total number of elements. + + Returns: + Tensor with shape (batch_size_1, batch_size_2) containing the pairwise dot + products. For example, Tensor[i][j] would be the dot product between + t1[i] and t2[j]. + """ + + msg = ( + "Please ensure each batch member has the same feature dimension. " + f"First input has {torch.numel(t1) / t1.shape[0]} features, and " + f"second input has {torch.numel(t2) / t2.shape[0]} features." + ) + assert torch.numel(t1) / t1.shape[0] == torch.numel(t2) / t2.shape[0], msg + + return torch.mm( + t1.view(t1.shape[0], -1), + t2.view(t2.shape[0], -1).T, + ) + + +def _gradient_dot_product( + input_grads: Tuple[Tensor], src_grads: Tuple[Tensor] +) -> Tensor: + r""" + Computes the dot product between the gradient vector for a model on an input batch + and src batch, for each pairwise batch member. Gradients are passed in as a tuple + corresponding to the trainable parameters returned by model.parameters(). Output + corresponds to a tensor of size (inputs_batch_size, src_batch_size) with all + pairwise dot products. + """ + + assert len(input_grads) == len(src_grads), "Mismatching gradient parameters." + + iterator = zip(input_grads, src_grads) + total = _tensor_batch_dot(*next(iterator)) + for input_grad, src_grad in iterator: + total += _tensor_batch_dot(input_grad, src_grad) + total = torch.Tensor(total) + + return total + + +def _jacobian_loss_wrt_inputs( + loss_fn: Union[Module, Callable], + out: Tensor, + targets: Tensor, + vectorize: bool, + reduction_type: str, +) -> Tensor: + r""" + Often, we have a loss function that computes a per-sample loss given a 1D tensor + input, and we want to calculate the jacobian of the loss w.r.t. that input. 
For + example, the input could be a length K tensor specifying the probability a given + sample belongs to each of K possible classes, and the loss function could be + cross-entropy loss. This function performs that calculation, but does so for a + *batch* of inputs. We create this helper function for two reasons: 1) to handle + differences between Pytorch versiosn for vectorized jacobian calculations, and + 2) this function does not accept the aforementioned per-sample loss function. + Instead, it accepts a "reduction" loss function that *reduces* the per-sample loss + for a batch into a single loss. Using a "reduction" loss improves speed. + We will allow this reduction to either be the mean or sum of the per-sample losses, + and this function provides an uniform way to handle different possible reductions, + and also check if the reduction used is valid. Regardless of the reduction used, + this function returns the jacobian for the per-sample loss (for each sample in the + batch). + + Args: + loss_fn (torch.nn.Module or Callable or None): The loss function. If a library + defined loss function is provided, it would be expected to be a + torch.nn.Module. If a custom loss is provided, it can be either type, + but must behave as a library loss function would if `reduction='sum'` + or `reduction='mean'`. + out (tensor): This is a tensor that represents the batch of inputs to + `loss_fn`. In practice, this will be the output of a model; this is + why this argument is named `out`. `out` is a 2D tensor of shape + (batch size, model output dimensionality). We will call `loss_fn` via + `loss_fn(out, targets)`. + targets (tensor): The labels for the batch of inputs. + vectorize (bool): Flag to use experimental vectorize functionality for + `torch.autograd.functional.jacobian`. + reduction_type (str): The type of reduction used by `loss_fn`. If `loss_fn` + has the "reduction" attribute, we will check that they match. Can + only be "mean" or "sum". + + Returns: + jacobians (tensor): Returns the jacobian of the per-sample loss (implicitly + defined by `loss_fn` and `reduction_type`) w.r.t each sample + in the batch represented by `out`. This is a 2D tensor, where the + first dimension is the batch dimension. + """ + # TODO: allow loss_fn to be Callable + if isinstance(loss_fn, Module) and hasattr(loss_fn, "reduction"): + msg0 = "Please ensure that loss_fn.reduction is set to `sum` or `mean`" + + assert loss_fn.reduction != "none", msg0 + msg1 = ( + f"loss_fn.reduction ({loss_fn.reduction}) does not match" + f"reduction type ({reduction_type}). Please ensure they are" + " matching." + ) + assert loss_fn.reduction == reduction_type, msg1 + + if reduction_type != "sum" and reduction_type != "mean": + raise ValueError( + f"{reduction_type} is not a valid value for reduction_type. " + "Must be either 'sum' or 'mean'." + ) + + if torch.__version__ >= "1.8": + input_jacobians = torch.autograd.functional.jacobian( + lambda out: loss_fn(out, targets), out, vectorize=vectorize + ) + else: + input_jacobians = torch.autograd.functional.jacobian( + lambda out: loss_fn(out, targets), out + ) + + if reduction_type == "mean": + input_jacobians = input_jacobians * len(input_jacobians) + + return input_jacobians + + +def _load_flexible_state_dict( + model: Module, path: str, device_ids: str = "cpu", keyname: Optional[str] = None +) -> int: + r""" + Helper to load pytorch models. 
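The reduction handling in `_jacobian_loss_wrt_inputs` above relies on the fact that the jacobian of the *mean* loss, multiplied by the batch size, equals the jacobian of the *sum* loss, i.e. the per-example loss gradients. A quick check of that step on toy data:

```python
import torch
import torch.nn as nn

out = torch.randn(4, 3)              # model outputs for a batch of 4 examples
targets = torch.randint(0, 3, (4,))

jac_sum = torch.autograd.functional.jacobian(
    lambda o: nn.CrossEntropyLoss(reduction="sum")(o, targets), out
)
jac_mean = torch.autograd.functional.jacobian(
    lambda o: nn.CrossEntropyLoss(reduction="mean")(o, targets), out
)
print(torch.allclose(jac_sum, jac_mean * len(out), atol=1e-6))  # True
```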
This function attempts to find compatibility for + loading models that were trained on different devices / with DataParallel but are + being loaded in a different environment. + + Assumes that the model has been saved as a state_dict in some capacity. This can + either be a single state dict, or a nesting dictionary which contains the model + state_dict and other information. + + Args: + model: The model for which to load a checkpoint + path: The filepath to the checkpoint + keyname: The key under which the model state_dict is stored, if any. + + The module state_dict is modified in-place, and the learning rate is returned. + """ + + device = device_ids + + checkpoint = torch.load(path, map_location=device) + + learning_rate = checkpoint.get("learning_rate", 1) + # can get learning rate from optimizer state_dict? + + if keyname is not None: + checkpoint = checkpoint[keyname] + + if "module." in next(iter(checkpoint)): + if isinstance(model, nn.DataParallel): + model.load_state_dict(checkpoint) + else: + model = nn.DataParallel(model) + model.load_state_dict(checkpoint) + model = model.module + else: + if isinstance(model, nn.DataParallel): + model = model.module + model.load_state_dict(checkpoint) + model = nn.DataParallel(model) + else: + model.load_state_dict(checkpoint) + + return learning_rate + + +def _get_k_most_influential_helper( + influence_src_dataloader: DataLoader, + influence_batch_fn: Callable, + inputs: Tuple[Any, ...], + targets: Optional[Tensor], + k: int = 5, + proponents: bool = True, + show_progress: bool = False, + desc: Optional[str] = None, +) -> Tuple[Tensor, Tensor]: + r""" + Helper function that computes the quantities returned by + `TracInCPBase._get_k_most_influential`, using a specific implementation that is + constant memory. + + Args: + influence_src_dataloader (DataLoader): The DataLoader, representing training + data, for which we want to compute proponents / opponents. + influence_batch_fn (Callable): A callable that will be called via + `influence_batch_fn(inputs, targets, batch)`, where `batch` is a batch + in the `influence_src_dataloader` argument. + inputs (Tuple of Any): A batch of examples. Does not represent labels, + which are passed as `targets`. + targets (Tensor, optional): If computing TracIn scores on a loss function, + these are the labels corresponding to the batch `inputs`. + Default: None + k (int, optional): The number of proponents or opponents to return per test + instance. + Default: 5 + proponents (bool, optional): Whether seeking proponents (`proponents=True`) + or opponents (`proponents=False`) + Default: True + show_progress (bool, optional): To compute the proponents (or opponents) + for the batch of examples, we perform computation for each batch in + training dataset `influence_src_dataloader`, If `show_progress`is + true, the progress of this computation will be displayed. In + particular, the number of batches for which the computation has + been performed will be displayed. It will try to use tqdm if + available for advanced features (e.g. time estimation). Otherwise, + it will fallback to a simple output of progress. + Default: False + desc (str, optional): If `show_progress` is true, this is the description to + show when displaying progress. If `desc` is none, no description is + shown. + Default: None + + Returns: + (indices, influence_scores): `indices` is a torch.long Tensor that contains the + indices of the proponents (or opponents) for each test example. 
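One checkpoint layout compatible with `_load_flexible_state_dict` above is a dictionary holding the state dict under a key, plus an optional `learning_rate` entry; the key name, the path, and the reuse of the earlier hypothetical `model` are all assumptions of this sketch.

```python
import torch
from captum.influence._utils.common import _load_flexible_state_dict

torch.save(
    {"model_state_dict": model.state_dict(), "learning_rate": 0.01},
    "ckpt_epoch_1.pt",  # hypothetical path
)

# loads parameters in place and returns the stored learning rate (defaults to 1)
lr = _load_flexible_state_dict(model, "ckpt_epoch_1.pt", keyname="model_state_dict")
print(lr)  # 0.01
```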
Its + dimension is `(inputs_batch_size, k)`, where `inputs_batch_size` is the + number of examples in `inputs`. For example, if `proponents==True`, + `indices[i][j]` is the index of the example in training dataset + `influence_src_dataloader` with the k-th highest influence score for + the j-th example in `inputs`. `indices` is a `torch.long` tensor so that + it can directly be used to index other tensors. Each row of + `influence_scores` contains the influence scores for a different test + example, in sorted order. In particular, `influence_scores[i][j]` is + the influence score of example `indices[i][j]` in training dataset + `influence_src_dataloader` on example `i` in the test batch represented + by `inputs` and `targets`. + """ + # For each test instance, maintain the best indices and corresponding distances + # initially, these will be empty + topk_indices = torch.Tensor().long() + topk_tracin_scores = torch.Tensor() + + multiplier = 1.0 if proponents else -1.0 + + # needed to map from relative index in a batch fo index within entire `dataloader` + num_instances_processed = 0 + + # if show_progress, create progress bar + total: Optional[int] = None + if show_progress: + try: + total = len(influence_src_dataloader) + except AttributeError: + pass + influence_src_dataloader = progress( + influence_src_dataloader, + desc=desc, + total=total, + ) + + for batch in influence_src_dataloader: + + # calculate tracin_scores for the batch + batch_tracin_scores = influence_batch_fn(inputs, targets, batch) + batch_tracin_scores *= multiplier + + # get the top-k indices and tracin_scores for the batch + batch_size = batch_tracin_scores.shape[1] + batch_topk_tracin_scores, batch_topk_indices = torch.topk( + batch_tracin_scores, min(batch_size, k), dim=1 + ) + batch_topk_indices = batch_topk_indices + num_instances_processed + num_instances_processed += batch_size + + # combine the top-k for the batch with those for previously seen batches + topk_indices = torch.cat([topk_indices, batch_topk_indices], dim=1) + topk_tracin_scores = torch.cat( + [topk_tracin_scores, batch_topk_tracin_scores], dim=1 + ) + + # retain only the top-k in terms of tracin_scores + topk_tracin_scores, topk_argsort = torch.topk( + topk_tracin_scores, min(k, topk_indices.shape[1]), dim=1 + ) + topk_indices = torch.gather(topk_indices, dim=1, index=topk_argsort) + + # if seeking opponents, we were actually keeping track of negative tracin_scores + topk_tracin_scores *= multiplier + + return topk_indices, topk_tracin_scores + + +class _DatasetFromList(Dataset): + def __init__(self, _l: List[Any]): + self._l = _l + + def __getitem__(self, i: int) -> Any: + return self._l[i] + + def __len__(self) -> int: + return len(self._l) diff --git a/captum/influence/_utils/nearest_neighbors.py b/captum/influence/_utils/nearest_neighbors.py new file mode 100644 index 0000000000000000000000000000000000000000..3c26d1d44866534aa8c58c64584cb3f8a493237f --- /dev/null +++ b/captum/influence/_utils/nearest_neighbors.py @@ -0,0 +1,187 @@ +from abc import ABC, abstractmethod +from typing import Tuple + +import torch +from torch import Tensor + + +class NearestNeighbors(ABC): + r""" + An abstract class to define a nearest neighbors data structure. Classes + implementing this interface are intended for computing proponents / opponents in + certain implementations of `TracInCPBase`. 
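To make the constant-memory strategy of `_get_k_most_influential_helper` above concrete, here is a stripped-down sketch of the same running top-k merge (an illustration only; the function name and the convention of one `(num_tests, batch_size)` score block per training batch are assumptions). Only `k` candidate indices and scores per test example are retained after each batch, so memory does not grow with the size of the training data.

```python
import torch
from typing import Iterable, Tuple

def streaming_topk(
    score_batches: Iterable[torch.Tensor], k: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    top_vals, top_idx = None, None
    offset = 0  # maps batch-local column indices to indices in the full training set
    for scores in score_batches:                      # scores: (num_tests, batch_size)
        vals, idx = torch.topk(scores, min(k, scores.shape[1]), dim=1)
        idx = idx + offset
        offset += scores.shape[1]
        if top_vals is None:
            top_vals, top_idx = vals, idx
        else:
            top_vals = torch.cat([top_vals, vals], dim=1)
            top_idx = torch.cat([top_idx, idx], dim=1)
        # keep only the best k of (previous winners + new candidates)
        top_vals, argsort = torch.topk(top_vals, min(k, top_vals.shape[1]), dim=1)
        top_idx = torch.gather(top_idx, dim=1, index=argsort)
    return top_idx, top_vals

# e.g. two training batches of sizes 7 and 5, scored against 2 test examples
indices, scores = streaming_topk(iter([torch.randn(2, 7), torch.randn(2, 5)]), k=3)
```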
In particular, it is for use in + implementations which compute proponents / opponents of a test instance by + 1) storing representations of training instances within a nearest neighbors data + structure, and 2) finding within that structure the nearest neighbor of the + representation of a test instance. The assumption is that the data structure + stores the tensors passed to the `setup` method, which we refer to as the "stored + tensors". If this class is used to find proponents / opponents, the nearest + neighbors of a tensor should be the stored tensors that have the largest + dot-product with the query. + """ + + @abstractmethod + def get_nearest_neighbors( + self, query: torch.Tensor, k: int + ) -> Tuple[Tensor, Tensor]: + r""" + Given a `query`, a tensor of shape (N, *), returns the nearest neighbors in the + "stored tensors" (see above). `query` represents a batch of N tensors, each + of common but arbitrary shape *. We always assume the 0-th dimension indexes + the batch. In use cases of this class for computing proponents / opponents, + the nearest neighbors of a tensor should be the stored tensors with the largest + dot-product with the tensor, and the tensors in `query` will all be 1D, + so that `query` is 2D. + + Args: + query (tensor): tensor representing the batch of tensors for which k-nearest + neighbors are desired. `query` is of shape (N, *), where N is the + size of the batch, i.e. the 0-th dimension of `query` indexes the + batch. * denotes an arbitrary shape, so that each tensor in the + batch can be of a common, but arbitrary shape. + k (int): The number of nearest neighbors to return. + + Returns: + results (tuple): A tuple of `(indices, distances)` is returned. `indices` + is a 2D tensor where `indices[i,j]` is the index (within the + "stored tensors" passed to the `setup` method) of the `j`-th + nearest neighbor of the `i`-th instance in query, and + `distances[i,j]` is the corresponding distance. `indices` should + be of dtype `torch.long` so that it can be used to index torch + tensors. + """ + pass + + @abstractmethod + def setup(self, data: torch.Tensor) -> None: + r""" + `data` denotes the "stored tensors". These are the tensors within which we + want to find the nearest neighbors to each tensor in a batch of tensors, via a + call to the`get_nearest_neighbors` method. Before we can call it, however, + we need to first store the stored tensors, by doing processing that indexes + the stored tensors in a form that enables nearest-neighbors computation. + This method does that preprocessing, and is assumed to be called before any + call to `get_nearest_neighbors`. For example, this method might put the + stored tensors in a K-d tree. The tensors in the "stored tensors" can be of a + common, but arbitrary shape, denoted *, so that `data` is of shape (N, *), + where N is the number of tensors in the stored tensors. Therefore, the 0-th + dimension indexes the tensors in the stored tensors. + + Args: + data (tensor): A tensor of shape (N, *) representing the stored tensors. + The 0-th dimension indexes the tensors in the stored tensors, + so that `data[i]` is the tensor with index `i`. The nearest + neighbors of a query will be referred to by their index. + """ + pass + + +class AnnoyNearestNeighbors(NearestNeighbors): + """ + This is an implementation of `NearestNeighbors` that uses the Annoy module. At a + high level, Annoy finds nearest neighbors by constructing binary trees in which + vectors reside at leaf nodes. 
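Before the Annoy-backed implementation, the smallest possible realization of this interface may help fix ideas: an exact, brute-force version (purely illustrative and not part of the module; the class name is made up) in which `setup` just stores the flattened tensors and `get_nearest_neighbors` ranks them by dot product.

```python
import torch
from typing import Tuple

class BruteForceDotProductNeighbors:
    """Exact nearest neighbors under the 'largest dot product' notion used above."""

    def setup(self, data: torch.Tensor) -> None:
        self._stored = data.view(len(data), -1)          # (N, d) flattened stored tensors

    def get_nearest_neighbors(
        self, query: torch.Tensor, k: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        sims = query.view(len(query), -1) @ self._stored.t()  # (num_queries, N) dot products
        distances, indices = torch.topk(sims, min(k, sims.shape[1]), dim=1)
        return indices.long(), distances
```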
Vectors near each other will tend to be in the same + leaf node. See https://tinyurl.com/2p89sb2h and https://github.com/spotify/annoy + for more details. Annoy has 1 key parameter: the number of trees to construct. + Increasing the number of trees leads to more accurate results, but longer time to + create the trees and memory usage. As mentioned in the `NearestNeighbors` + documentation, for the use case of computing proponents / opponents, the nearest + neighbors returned should be those with the largest dot product with the query + vector. The term "vector" is used here because Annoy stores 1D vectors. However + in our wrapper around Annoy, we will allow the stored tensors to be of a common + but arbitrary shape *, and flatten them before storing in the Annoy data structure. + """ + + def __init__(self, num_trees: int = 10): + """ + Args: + num_trees (int): The number of trees to use. Increasing this number gives + more accurate computation of nearest neighbors, but requires longer + setup time to create the trees, as well as memory. + """ + try: + import annoy # noqa + except ImportError: + raise ValueError( + ( + "Using `AnnoyNearestNeighbors` requires installing the annoy " + "module. If pip is installed, this can be done with " + "`pip install --user annoy`." + ) + ) + + self.num_trees = num_trees + + def setup(self, data: torch.Tensor) -> None: + """ + `data` denotes the "stored tensors". These are the tensors within which we + want to find the nearest neighbors to a query tensor, via a call to the + `get_nearest_neighbors` method. Before we can call `get_nearest_neighbors`, + we need to first store the stored tensors, by doing processing that indexes + the stored tensors in a form that enables nearest-neighbors computation. + This method does that preprocessing, and is assumed to be called before any + call to `get_nearest_neighbors`. In particular, it creates the trees used to + index the stored tensors. This index is built to enable computation of + vectors that have the largest dot-product with the query tensors. The tensors + in the "stored tensors" can be of a common, but arbitrary shape, denoted *, so + that `data` is of shape (N, *), where N is the number of tensors in the stored + tensors. Therefore, the 0-th dimension indexes the tensors in the stored + tensors. + + Args: + data (tensor): A tensor of shape (N, *) representing the stored tensors. + The 0-th dimension indexes the tensors in the stored tensors, + so that `data[i]` is the tensor with index `i`. The nearest + neighbors of a query will be referred to by their index. + """ + import annoy + + data = data.view((len(data), -1)) + projection_dim = data.shape[1] + self.knn_index = annoy.AnnoyIndex(projection_dim, "dot") + for (i, projection) in enumerate(data): + self.knn_index.add_item(i, projection) + self.knn_index.build(self.num_trees) + + def get_nearest_neighbors( + self, query: torch.Tensor, k: int + ) -> Tuple[Tensor, Tensor]: + r""" + Given a `query`, a tensor of shape (N, *), returns the nearest neighbors in the + "stored tensors" (see above). `query` represents a batch of N tensors, each + of common but arbitrary shape *. We always assume the 0-th dimension indexes + the batch. In use cases of this class for computing proponents / opponents, + the nearest neighbors of a tensor should be the stored tensors with the largest + dot-product with the tensor, and the tensors in `query` will all be 1D, + so that `query` is 2D. 
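For reference, this is roughly how an Annoy index configured with the `"dot"` metric is driven, independent of the wrapper class below (a usage sketch only; it assumes `pip install annoy`, and the dimensionality, item count, and tree count are arbitrary choices):

```python
import torch
from annoy import AnnoyIndex

stored = torch.randn(100, 16)        # 100 stored vectors of dimension 16
index = AnnoyIndex(16, "dot")        # neighbors are ranked by dot product
for i, vec in enumerate(stored):
    index.add_item(i, vec.tolist())
index.build(10)                      # more trees: better accuracy, slower build, more memory

query = torch.randn(16)
indices, dots = index.get_nns_by_vector(query.tolist(), 5, include_distances=True)
```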
This implementation returns the stored tensors + that have the largest dot-product with the query tensor, and does not constrain + the tensors in `query` or in the stored tensors to be 1D. If tensors are of + dimension greater than 1D, their dot-product will be defined to be the + dot-product of the flattened version of tensors. + + Args: + query (tensor): tensor representing the batch of tensors for which k-nearest + neighbors are desired. `query` is of shape (N, *), where N is the + size of the batch, i.e. the 0-th dimension of `query` indexes the + batch. * denotes an arbitrary shape, so that each tensor in the + batch can be of a common, but arbitrary shape. + k (int): The number of nearest neighbors to return. + + Returns: + results (tuple): A tuple of `(indices, distances)` is returned. `indices` + is a 2D tensor where `indices[i,j]` is the index (within the + "stored tensors" passed to the `setup` method) of the `j`-th + nearest neighbor of the `i`-th instance in query, and + `distances[i,j]` is the corresponding distance. `indices` should + be of dtype `torch.long` so that it can be used to index torch + tensors. + """ + query = query.view((len(query), -1)) + indices_and_distances = [ + self.knn_index.get_nns_by_vector(instance, k, include_distances=True) + for instance in query + ] + indices, distances = zip(*indices_and_distances) + indices = torch.Tensor(indices).type(torch.long) + distances = torch.Tensor(distances) + return indices, distances diff --git a/captum/insights/__init__.py b/captum/insights/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..48ba6fdfa0f21e7392cdf7bb4827356644596c3b --- /dev/null +++ b/captum/insights/__init__.py @@ -0,0 +1 @@ +from captum.insights.attr_vis import AttributionVisualizer, Batch # noqa diff --git a/captum/insights/_utils/__init__.py b/captum/insights/_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/captum/insights/attr_vis/__init__.py b/captum/insights/attr_vis/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a5d0102ff6237c6be870c6fe94645129c712b5a4 --- /dev/null +++ b/captum/insights/attr_vis/__init__.py @@ -0,0 +1 @@ +from captum.insights.attr_vis.app import AttributionVisualizer, Batch # noqa diff --git a/captum/insights/attr_vis/_utils/__init__.py b/captum/insights/attr_vis/_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/captum/insights/attr_vis/_utils/transforms.py b/captum/insights/attr_vis/_utils/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..fb376b7c3b5ec824ab506bd4057002be7f787fdd --- /dev/null +++ b/captum/insights/attr_vis/_utils/transforms.py @@ -0,0 +1,13 @@ +#!/usr/bin/env python3 + +from typing import Callable, List, Optional, Union + + +def format_transforms( + transforms: Optional[Union[Callable, List[Callable]]] +) -> List[Callable]: + if transforms is None: + return [] + if callable(transforms): + return [transforms] + return transforms diff --git a/captum/insights/attr_vis/app.py b/captum/insights/attr_vis/app.py new file mode 100644 index 0000000000000000000000000000000000000000..9a0433090bcea2a62de80020884fdee230d37bee --- /dev/null +++ b/captum/insights/attr_vis/app.py @@ -0,0 +1,493 @@ +#!/usr/bin/env python3 +from collections import namedtuple +from itertools import cycle +from typing import ( + Any, + Callable, + Dict, + Iterable, + 
List, + NamedTuple, + Optional, + Tuple, + Union, +) + +import torch +from captum.attr import IntegratedGradients +from captum.attr._utils.batching import _batched_generator +from captum.insights.attr_vis.attribution_calculation import ( + AttributionCalculation, + OutputScore, +) +from captum.insights.attr_vis.config import ( + ATTRIBUTION_METHOD_CONFIG, + ATTRIBUTION_NAMES_TO_METHODS, +) +from captum.insights.attr_vis.features import BaseFeature +from captum.insights.attr_vis.server import namedtuple_to_dict +from captum.log import log_usage +from torch import Tensor +from torch.nn import Module + +_CONTEXT_COLAB = "_CONTEXT_COLAB" +_CONTEXT_IPYTHON = "_CONTEXT_IPYTHON" +_CONTEXT_NONE = "_CONTEXT_NONE" + + +def _get_context(): + """Determine the most specific context that we're in. + Implementation from TensorBoard: https://git.io/JvObD. + + Returns: + _CONTEXT_COLAB: If in Colab with an IPython notebook context. + _CONTEXT_IPYTHON: If not in Colab, but we are in an IPython notebook + context (e.g., from running `jupyter notebook` at the command + line). + _CONTEXT_NONE: Otherwise (e.g., by running a Python script at the + command-line or using the `ipython` interactive shell). + """ + # In Colab, the `google.colab` module is available, but the shell + # returned by `IPython.get_ipython` does not have a `get_trait` + # method. + try: + import google.colab # noqa: F401 + import IPython + except ImportError: + pass + else: + if IPython.get_ipython() is not None: + # We'll assume that we're in a Colab notebook context. + return _CONTEXT_COLAB + + # In an IPython command line shell or Jupyter notebook, we can + # directly query whether we're in a notebook context. + try: + import IPython + except ImportError: + pass + else: + ipython = IPython.get_ipython() + if ipython is not None and ipython.has_trait("kernel"): + return _CONTEXT_IPYTHON + + # Otherwise, we're not in a known notebook context. + return _CONTEXT_NONE + + +VisualizationOutput = namedtuple( + "VisualizationOutput", "feature_outputs actual predicted active_index model_index" +) +Contribution = namedtuple("Contribution", "name percent") +SampleCache = namedtuple("SampleCache", "inputs additional_forward_args label") + + +class FilterConfig(NamedTuple): + attribution_method: str = IntegratedGradients.get_name() + # issue with mypy github.com/python/mypy/issues/8376 + attribution_arguments: Dict[str, Any] = { + arg: config.value # type: ignore + for arg, config in ATTRIBUTION_METHOD_CONFIG[ + IntegratedGradients.get_name() + ].params.items() + } + prediction: str = "all" + classes: List[str] = [] + num_examples: int = 4 + + +class Batch: + def __init__( + self, + inputs: Union[Tensor, Tuple[Tensor, ...]], + labels: Optional[Tensor], + additional_args=None, + ) -> None: + r""" + Constructs batch of inputs to be attributed and visualized. + + Args: + + inputs (tensor or tuple of tensors): Batch of inputs for a model. + These may be either a Tensor or tuple of tensors. Each tensor + must correspond to a feature for AttributionVisualizer, and + the corresponding input transform function of the feature + is applied to each input tensor prior to passing it to the + model. It is assumed that the first dimension of each + input tensor corresponds to the number of examples + (batch size) and is aligned for all input tensors. + labels (tensor): Tensor containing correct labels for input examples. + This must be a 1D tensor with length matching the first + dimension of each input tensor. 
+ additional_args (tuple, optional): If the forward function + requires additional arguments other than the inputs for + which attributions should not be computed, this argument + can be provided. It must be either a single additional + argument of a Tensor or arbitrary (non-tuple) type or a + tuple containing multiple additional arguments including + tensors or any arbitrary python types. These arguments + are provided to ``forward_func`` in order following the + arguments in inputs. + For a tensor, the first dimension of the tensor must + correspond to the number of examples. + """ + self.inputs = inputs + self.labels = labels + self.additional_args = additional_args + + +class AttributionVisualizer: + def __init__( + self, + models: Union[List[Module], Module], + classes: List[str], + features: Union[List[BaseFeature], BaseFeature], + dataset: Iterable[Batch], + score_func: Optional[Callable] = None, + use_label_for_attr: bool = True, + ) -> None: + r""" + Args: + + models (torch.nn.module): One or more PyTorch modules (models) for + attribution visualization. + classes (list of string): List of strings corresponding to the names of + classes for classification. + features (list of BaseFeature): List of BaseFeatures, which correspond + to input arguments to the model. Each feature object defines + relevant transformations for converting to model input, + constructing baselines, and visualizing. The length of the + features list should exactly match the number of (tensor) + arguments expected by the given model. + For instance, an image classifier should only provide + a single BaseFeature, while a multimodal classifier may + provide a list of features, each corresponding to a different + tensor input and potentially different modalities. + dataset (iterable of Batch): Defines the dataset to visualize attributions + for. This must be an iterable of batch objects, each of which + may contain multiple input examples. + score_func (callable, optional): This function is applied to the model + output to obtain the score for each class. For instance, + this function could be the softmax or final non-linearity + of the network, applied to the model output. The indices + of the second dimension of the output should correspond + to the class names provided. If None, the model outputs + are taken directly and assumed to correspond to the + class scores. + Default: None + use_label_for_attr (boolean, optional): If true, the class index is passed + to the relevant attribution method. This is necessary in most + cases where there is an output neuron corresponding to each + class. When the model output is a scalar and class index + (e.g. positive, negative) is inferred from the output value, + this argument should be False. 
+ Default: True + """ + if not isinstance(models, List): + models = [models] + + if not isinstance(features, List): + features = [features] + + self.classes = classes + self.features = features + self.dataset = dataset + self.models = models + self.attribution_calculation = AttributionCalculation( + models, classes, features, score_func, use_label_for_attr + ) + self._outputs: List[VisualizationOutput] = [] + self._config = FilterConfig(prediction="all", classes=[], num_examples=4) + self._dataset_iter = iter(dataset) + self._dataset_cache: List[Batch] = [] + + def _calculate_attribution_from_cache( + self, input_index: int, model_index: int, target: Optional[Tensor] + ) -> Optional[VisualizationOutput]: + c = self._outputs[input_index][1] + result = self._calculate_vis_output( + c.inputs, + c.additional_forward_args, + c.label, + torch.tensor(target), + model_index, + ) + + if not result: + return None + return result[0] + + def _update_config(self, settings): + self._config = FilterConfig( + attribution_method=settings["attribution_method"], + attribution_arguments=settings["arguments"], + prediction=settings["prediction"], + classes=settings["classes"], + num_examples=4, + ) + + @log_usage() + def render(self, debug=True): + from captum.insights.attr_vis.widget import CaptumInsights + from IPython.display import display + + widget = CaptumInsights(visualizer=self) + display(widget) + if debug: + display(widget.out) + + @log_usage() + def serve(self, blocking=False, debug=False, port=None, bind_all=False): + context = _get_context() + if context == _CONTEXT_COLAB: + return self._serve_colab(blocking=blocking, debug=debug, port=port) + else: + return self._serve( + blocking=blocking, debug=debug, port=port, bind_all=bind_all + ) + + def _serve(self, blocking=False, debug=False, port=None, bind_all=False): + from captum.insights.attr_vis.server import start_server + + return start_server( + self, blocking=blocking, debug=debug, _port=port, bind_all=bind_all + ) + + def _serve_colab(self, blocking=False, debug=False, port=None): + import ipywidgets as widgets + from captum.insights.attr_vis.server import start_server + from IPython.display import display, HTML + + # TODO: Output widget only captures beginning of server logs. It seems + # the context manager isn't respected when the web server is run on a + # separate thread. We should fix to display entirety of the logs + out = widgets.Output() + with out: + port = start_server(self, blocking=blocking, debug=debug, _port=port) + shell = """ +
+ + """.replace( + "%PORT%", str(port) + ) + html = HTML(shell) + display(html) + display(out) + + def _predictions_matches_labels( + self, predicted_scores: List[OutputScore], labels: Union[str, List[str]] + ) -> bool: + if len(predicted_scores) == 0: + return False + + predicted_label = predicted_scores[0].label + + if isinstance(labels, List): + return predicted_label in labels + + return labels == predicted_label + + def _should_keep_prediction( + self, predicted_scores: List[OutputScore], actual_label: Optional[OutputScore] + ) -> bool: + # filter by class + if len(self._config.classes) != 0: + if not self._predictions_matches_labels( + predicted_scores, self._config.classes + ): + return False + + if not actual_label: + return True + + # filter by accuracy + label_name = actual_label.label + if self._config.prediction == "all": + pass + elif self._config.prediction == "correct": + if not self._predictions_matches_labels(predicted_scores, label_name): + return False + elif self._config.prediction == "incorrect": + if self._predictions_matches_labels(predicted_scores, label_name): + return False + else: + raise Exception(f"Invalid prediction config: {self._config.prediction}") + + return True + + def _calculate_vis_output( + self, + inputs, + additional_forward_args, + label, + target=None, + single_model_index=None, + ) -> Optional[List[VisualizationOutput]]: + # Use all models, unless the user wants to render data for a particular one + models_used = ( + [self.models[single_model_index]] + if single_model_index is not None + else self.models + ) + results = [] + for model_index, model in enumerate(models_used): + # Get list of model visualizations for each input + actual_label_output = None + if label is not None and len(label) > 0: + label_index = int(label[0]) + actual_label_output = OutputScore( + score=100, index=label_index, label=self.classes[label_index] + ) + + ( + predicted_scores, + baselines, + transformed_inputs, + ) = self.attribution_calculation.calculate_predicted_scores( + inputs, additional_forward_args, model + ) + + # Filter based on UI configuration + if actual_label_output is None or not self._should_keep_prediction( + predicted_scores, actual_label_output + ): + continue + + if target is None: + target = ( + predicted_scores[0].index if len(predicted_scores) > 0 else None + ) + + # attributions are given per input* + # inputs given to the model are described via `self.features` + # + # *an input contains multiple features that represent it + # e.g. 
all the pixels that describe an image is an input + + attrs_per_feature = self.attribution_calculation.calculate_attribution( + baselines, + transformed_inputs, + additional_forward_args, + target, + self._config.attribution_method, + self._config.attribution_arguments, + model, + ) + + net_contrib = self.attribution_calculation.calculate_net_contrib( + attrs_per_feature + ) + + # the features per input given + features_per_input = [ + feature.visualize(attr, data, contrib) + for feature, attr, data, contrib in zip( + self.features, attrs_per_feature, inputs, net_contrib + ) + ] + + results.append( + VisualizationOutput( + feature_outputs=features_per_input, + actual=actual_label_output, + predicted=predicted_scores, + active_index=target + if target is not None + else actual_label_output.index, + # Even if we only iterated over one model, the index should be fixed + # to show the index the model would have had in the list + model_index=single_model_index + if single_model_index is not None + else model_index, + ) + ) + + return results if results else None + + def _get_outputs(self) -> List[Tuple[List[VisualizationOutput], SampleCache]]: + # If we run out of new batches, then we need to + # display data which was already shown before. + # However, since the dataset given to us is a generator, + # we can't reset it to return to the beginning. + # Because of this, we store a small cache of stale + # data, and iterate on it after the main generator + # stops returning new batches. + try: + batch_data = next(self._dataset_iter) + self._dataset_cache.append(batch_data) + if len(self._dataset_cache) > self._config.num_examples: + self._dataset_cache.pop(0) + except StopIteration: + self._dataset_iter = cycle(self._dataset_cache) + batch_data = next(self._dataset_iter) + + vis_outputs = [] + + # Type ignore for issue with passing union to function taking generic + # https://github.com/python/mypy/issues/1533 + for ( + inputs, + additional_forward_args, + label, + ) in _batched_generator( # type: ignore + inputs=batch_data.inputs, + additional_forward_args=batch_data.additional_args, + target_ind=batch_data.labels, + internal_batch_size=1, # should be 1 until we have batch label support + ): + output = self._calculate_vis_output(inputs, additional_forward_args, label) + if output is not None: + cache = SampleCache(inputs, additional_forward_args, label) + vis_outputs.append((output, cache)) + + return vis_outputs + + @log_usage() + def visualize(self): + self._outputs = [] + while len(self._outputs) < self._config.num_examples: + self._outputs.extend(self._get_outputs()) + return [o[0] for o in self._outputs] + + def get_insights_config(self): + return { + "classes": self.classes, + "methods": list(ATTRIBUTION_NAMES_TO_METHODS.keys()), + "method_arguments": namedtuple_to_dict( + {k: v.params for (k, v) in ATTRIBUTION_METHOD_CONFIG.items()} + ), + "selected_method": self._config.attribution_method, + } diff --git a/captum/insights/attr_vis/attribution_calculation.py b/captum/insights/attr_vis/attribution_calculation.py new file mode 100644 index 0000000000000000000000000000000000000000..3f695b1807b6329d55839d578ec1b4c9509f4ff2 --- /dev/null +++ b/captum/insights/attr_vis/attribution_calculation.py @@ -0,0 +1,186 @@ +#!/usr/bin/env python3 +import inspect +from collections import namedtuple +from typing import ( + Callable, + cast, + Dict, + Iterable, + List, + Optional, + Sequence, + Tuple, + Union, +) + +import torch +from captum._utils.common import _run_forward, safe_div +from 
captum.insights.attr_vis.config import ( + ATTRIBUTION_METHOD_CONFIG, + ATTRIBUTION_NAMES_TO_METHODS, +) +from captum.insights.attr_vis.features import BaseFeature +from torch import Tensor +from torch.nn import Module + +OutputScore = namedtuple("OutputScore", "score index label") + + +class AttributionCalculation: + def __init__( + self, + models: Sequence[Module], + classes: Sequence[str], + features: List[BaseFeature], + score_func: Optional[Callable] = None, + use_label_for_attr: bool = True, + ) -> None: + self.models = models + self.classes = classes + self.features = features + self.score_func = score_func + self.use_label_for_attr = use_label_for_attr + self.baseline_cache: dict = {} + self.transformed_input_cache: dict = {} + + def calculate_predicted_scores( + self, inputs, additional_forward_args, model + ) -> Tuple[ + List[OutputScore], Optional[List[Tuple[Tensor, ...]]], Tuple[Tensor, ...] + ]: + # Check if inputs have cached baselines and transformed inputs + hashable_inputs = tuple(inputs) + if hashable_inputs in self.baseline_cache: + baselines_group = self.baseline_cache[hashable_inputs] + transformed_inputs = self.transformed_input_cache[hashable_inputs] + else: + # Initialize baselines + baseline_transforms_len = 1 # todo support multiple baselines + baselines: List[List[Optional[Tensor]]] = [ + [None] * len(self.features) for _ in range(baseline_transforms_len) + ] + transformed_inputs = list(inputs) + for feature_i, feature in enumerate(self.features): + transformed_inputs[feature_i] = self._transform( + feature.input_transforms, transformed_inputs[feature_i], True + ) + for baseline_i in range(baseline_transforms_len): + if baseline_i > len(feature.baseline_transforms) - 1: + baselines[baseline_i][feature_i] = torch.zeros_like( + transformed_inputs[feature_i] + ) + else: + baselines[baseline_i][feature_i] = self._transform( + [feature.baseline_transforms[baseline_i]], + transformed_inputs[feature_i], + True, + ) + + baselines = cast(List[List[Optional[Tensor]]], baselines) + baselines_group = [tuple(b) for b in baselines] + self.baseline_cache[hashable_inputs] = baselines_group + self.transformed_input_cache[hashable_inputs] = transformed_inputs + + outputs = _run_forward( + model, + tuple(transformed_inputs), + additional_forward_args=additional_forward_args, + ) + + if self.score_func is not None: + outputs = self.score_func(outputs) + + if outputs.nelement() == 1: + scores = outputs + predicted = scores.round().to(torch.int) + else: + scores, predicted = outputs.topk(min(4, outputs.shape[-1])) + + scores = scores.cpu().squeeze(0) + predicted = predicted.cpu().squeeze(0) + + predicted_scores = self._get_labels_from_scores(scores, predicted) + + return predicted_scores, baselines_group, tuple(transformed_inputs) + + def calculate_attribution( + self, + baselines: Optional[Sequence[Tuple[Tensor, ...]]], + data: Tuple[Tensor, ...], + additional_forward_args: Optional[Tuple[Tensor, ...]], + label: Optional[Union[Tensor]], + attribution_method_name: str, + attribution_arguments: Dict, + model: Module, + ) -> Tuple[Tensor, ...]: + attribution_cls = ATTRIBUTION_NAMES_TO_METHODS[attribution_method_name] + attribution_method = attribution_cls(model) + if attribution_method_name in ATTRIBUTION_METHOD_CONFIG: + param_config = ATTRIBUTION_METHOD_CONFIG[attribution_method_name] + if param_config.post_process: + for k, v in attribution_arguments.items(): + if k in param_config.post_process: + attribution_arguments[k] = param_config.post_process[k](v) + + # TODO support multiple 
baselines + baseline = baselines[0] if baselines and len(baselines) > 0 else None + label = ( + None + if not self.use_label_for_attr or label is None or label.nelement() == 0 + else label + ) + if "baselines" in inspect.signature(attribution_method.attribute).parameters: + attribution_arguments["baselines"] = baseline + attr = attribution_method.attribute.__wrapped__( + attribution_method, # self + data, + additional_forward_args=additional_forward_args, + target=label, + **attribution_arguments, + ) + + return attr + + def calculate_net_contrib( + self, attrs_per_input_feature: Tuple[Tensor, ...] + ) -> List[float]: + # get the net contribution per feature (input) + net_contrib = torch.stack( + [attrib.flatten().sum() for attrib in attrs_per_input_feature] + ) + + # normalise the contribution, s.t. sum(abs(x_i)) = 1 + norm = torch.norm(net_contrib, p=1) + # if norm is 0, all net_contrib elements are 0 + net_contrib = safe_div(net_contrib, norm) + + return net_contrib.tolist() + + def _transform( + self, transforms: Iterable[Callable], inputs: Tensor, batch: bool = False + ) -> Tensor: + transformed_inputs = inputs + # TODO support batch size > 1 + if batch: + transformed_inputs = inputs.squeeze(0) + + for t in transforms: + transformed_inputs = t(transformed_inputs) + + if batch: + transformed_inputs = transformed_inputs.unsqueeze(0) + + return transformed_inputs + + def _get_labels_from_scores( + self, scores: Tensor, indices: Tensor + ) -> List[OutputScore]: + pred_scores: List[OutputScore] = [] + if indices.nelement() < 2: + return pred_scores + for i in range(len(indices)): + score = scores[i] + pred_scores.append( + OutputScore(score, indices[i], self.classes[int(indices[i])]) + ) + return pred_scores diff --git a/captum/insights/attr_vis/config.py b/captum/insights/attr_vis/config.py new file mode 100644 index 0000000000000000000000000000000000000000..b5d88cc92207e62e1a19a25c9283c58f1e14a6f0 --- /dev/null +++ b/captum/insights/attr_vis/config.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python3 +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union + +from captum.attr import ( + Deconvolution, + DeepLift, + FeatureAblation, + GuidedBackprop, + InputXGradient, + IntegratedGradients, + Occlusion, + Saliency, +) +from captum.attr._utils.approximation_methods import SUPPORTED_METHODS + + +class NumberConfig(NamedTuple): + value: int = 1 + limit: Tuple[Optional[int], Optional[int]] = (None, None) + type: str = "number" + + +class StrEnumConfig(NamedTuple): + value: str + limit: List[str] + type: str = "enum" + + +class StrConfig(NamedTuple): + value: str + type: str = "string" + + +Config = Union[NumberConfig, StrEnumConfig, StrConfig] + +SUPPORTED_ATTRIBUTION_METHODS = [ + Deconvolution, + DeepLift, + GuidedBackprop, + InputXGradient, + IntegratedGradients, + Saliency, + FeatureAblation, + Occlusion, +] + + +class ConfigParameters(NamedTuple): + params: Dict[str, Config] + help_info: Optional[str] = None # TODO fill out help for each method + post_process: Optional[Dict[str, Callable[[Any], Any]]] = None + + +ATTRIBUTION_NAMES_TO_METHODS = { + # mypy bug - treating it as a type instead of a class + cls.get_name(): cls # type: ignore + for cls in SUPPORTED_ATTRIBUTION_METHODS +} + + +def _str_to_tuple(s): + if isinstance(s, tuple): + return s + return tuple([int(i) for i in s.split()]) + + +ATTRIBUTION_METHOD_CONFIG: Dict[str, ConfigParameters] = { + IntegratedGradients.get_name(): ConfigParameters( + params={ + "n_steps": NumberConfig(value=25, limit=(2, None)), + 
"method": StrEnumConfig(limit=SUPPORTED_METHODS, value="gausslegendre"), + }, + post_process={"n_steps": int}, + ), + FeatureAblation.get_name(): ConfigParameters( + params={"perturbations_per_eval": NumberConfig(value=1, limit=(1, 100))}, + ), + Occlusion.get_name(): ConfigParameters( + params={ + "sliding_window_shapes": StrConfig(value=""), + "strides": StrConfig(value=""), + "perturbations_per_eval": NumberConfig(value=1, limit=(1, 100)), + }, + post_process={ + "sliding_window_shapes": _str_to_tuple, + "strides": _str_to_tuple, + "perturbations_per_eval": int, + }, + ), +} diff --git a/captum/insights/attr_vis/example.py b/captum/insights/attr_vis/example.py new file mode 100644 index 0000000000000000000000000000000000000000..72d7892758da21571fa9d801c12ad6b3559e80bc --- /dev/null +++ b/captum/insights/attr_vis/example.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +import os + +import torch +import torch.nn as nn +import torchvision +import torchvision.transforms as transforms +from captum.insights import AttributionVisualizer, Batch +from captum.insights.attr_vis.features import ImageFeature + + +def get_classes(): + classes = [ + "Plane", + "Car", + "Bird", + "Cat", + "Deer", + "Dog", + "Frog", + "Horse", + "Ship", + "Truck", + ] + return classes + + +def get_pretrained_model(): + class Net(nn.Module): + def __init__(self) -> None: + super(Net, self).__init__() + self.conv1 = nn.Conv2d(3, 6, 5) + self.pool1 = nn.MaxPool2d(2, 2) + self.pool2 = nn.MaxPool2d(2, 2) + self.conv2 = nn.Conv2d(6, 16, 5) + self.fc1 = nn.Linear(16 * 5 * 5, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + self.relu1 = nn.ReLU() + self.relu2 = nn.ReLU() + self.relu3 = nn.ReLU() + self.relu4 = nn.ReLU() + + def forward(self, x): + x = self.pool1(self.relu1(self.conv1(x))) + x = self.pool2(self.relu2(self.conv2(x))) + x = x.view(-1, 16 * 5 * 5) + x = self.relu3(self.fc1(x)) + x = self.relu4(self.fc2(x)) + x = self.fc3(x) + return x + + net = Net() + pt_path = os.path.abspath( + os.path.join(os.path.dirname(__file__), "models/cifar_torchvision.pt") + ) + net.load_state_dict(torch.load(pt_path)) + return net + + +def baseline_func(input): + return input * 0 + + +def formatted_data_iter(): + dataset = torchvision.datasets.CIFAR10( + root="data/test", train=False, download=True, transform=transforms.ToTensor() + ) + dataloader = iter( + torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=False, num_workers=2) + ) + while True: + images, labels = next(dataloader) + yield Batch(inputs=images, labels=labels) + + +def main(): + normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + model = get_pretrained_model() + visualizer = AttributionVisualizer( + models=[model], + score_func=lambda o: torch.nn.functional.softmax(o, 1), + classes=get_classes(), + features=[ + ImageFeature( + "Photo", + baseline_transforms=[baseline_func], + input_transforms=[normalize], + ) + ], + dataset=formatted_data_iter(), + ) + + visualizer.serve(debug=True) + + +if __name__ == "__main__": + main() diff --git a/captum/insights/attr_vis/features.py b/captum/insights/attr_vis/features.py new file mode 100644 index 0000000000000000000000000000000000000000..0986170758cd37e38a60abefe034e4104d76c8a9 --- /dev/null +++ b/captum/insights/attr_vis/features.py @@ -0,0 +1,304 @@ +#!/usr/bin/env python3 +import base64 +import warnings +from collections import namedtuple +from io import BytesIO +from typing import Callable, List, Optional, Union + +from captum._utils.common import safe_div +from captum.attr._utils 
import visualization as viz +from captum.insights.attr_vis._utils.transforms import format_transforms + +FeatureOutput = namedtuple("FeatureOutput", "name base modified type contribution") + + +def _convert_figure_base64(fig): + buff = BytesIO() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + fig.tight_layout() # removes padding + fig.savefig(buff, format="png") + base64img = base64.b64encode(buff.getvalue()).decode("utf-8") + return base64img + + +class BaseFeature: + r""" + All Feature classes extend this class to implement custom visualizations in + Insights. + + It enforces child classes to implement ``visualization_type`` and ``visualize`` + methods. + """ + + def __init__( + self, + name: str, + baseline_transforms: Optional[Union[Callable, List[Callable]]], + input_transforms: Optional[Union[Callable, List[Callable]]], + visualization_transform: Optional[Callable], + ) -> None: + r""" + Args: + + name (str): The label of the specific feature. For example, an + ImageFeature's name can be "Photo". + baseline_transforms (list, callable, optional): Optional list of + callables (e.g. functions) to be called on the input tensor + to construct multiple baselines. Currently only one baseline + is supported. See + :py:class:`.IntegratedGradients` for more + information about baselines. + input_transforms (list, callable, optional): Optional list of callables + (e.g. functions) called on the input tensor sequentially to + convert it into the format expected by the model. + visualization_transform (callable, optional): Optional callable (e.g. + function) applied as a postprocessing step of the original + input data (before ``input_transforms``) to convert it to a + format to be understood by the frontend visualizer as + specified in ``captum/captum/insights/frontend/App.js``. + """ + self.name = name + self.baseline_transforms = format_transforms(baseline_transforms) + self.input_transforms = format_transforms(input_transforms) + self.visualization_transform = visualization_transform + + @staticmethod + def visualization_type() -> str: + raise NotImplementedError + + def visualize(self, attribution, data, contribution_frac) -> FeatureOutput: + raise NotImplementedError + + +class ImageFeature(BaseFeature): + r""" + ImageFeature is used to visualize image features in Insights. It expects an image in + NCHW format. If C has a dimension of 1, its assumed to be a greyscale image. + If it has a dimension of 3, its expected to be in RGB format. + """ + + def __init__( + self, + name: str, + baseline_transforms: Union[Callable, List[Callable]], + input_transforms: Union[Callable, List[Callable]], + visualization_transform: Optional[Callable] = None, + ) -> None: + r""" + Args: + name (str): The label of the specific feature. For example, an + ImageFeature's name can be "Photo". + baseline_transforms (list, callable, optional): Optional list of + callables (e.g. functions) to be called on the input tensor + to construct multiple baselines. Currently only one baseline + is supported. See + :py:class:`.IntegratedGradients` for more + information about baselines. + input_transforms (list, callable, optional): A list of transforms + or transform to be applied to the input. For images, + normalization is often applied here. + visualization_transform (callable, optional): Optional callable (e.g. + function) applied as a postprocessing step of the original + input data (before input_transforms) to convert it to a + format to be visualized. 
+ """ + super().__init__( + name, + baseline_transforms=baseline_transforms, + input_transforms=input_transforms, + visualization_transform=visualization_transform, + ) + + @staticmethod + def visualization_type() -> str: + return "image" + + def visualize(self, attribution, data, contribution_frac) -> FeatureOutput: + if self.visualization_transform: + data = self.visualization_transform(data) + + data_t, attribution_t = [ + t.detach().squeeze().permute((1, 2, 0)).cpu().numpy() + for t in (data, attribution) + ] + + orig_fig, _ = viz.visualize_image_attr( + attribution_t, data_t, method="original_image", use_pyplot=False + ) + attr_fig, _ = viz.visualize_image_attr( + attribution_t, + data_t, + method="heat_map", + sign="absolute_value", + use_pyplot=False, + ) + + img_64 = _convert_figure_base64(orig_fig) + attr_img_64 = _convert_figure_base64(attr_fig) + + return FeatureOutput( + name=self.name, + base=img_64, + modified=attr_img_64, + type=self.visualization_type(), + contribution=contribution_frac, + ) + + +class TextFeature(BaseFeature): + r""" + TextFeature is used to visualize text (e.g. sentences) in Insights. + It expects the visualization transform to convert the input data (e.g. index to + string) to the raw text. + """ + + def __init__( + self, + name: str, + baseline_transforms: Union[Callable, List[Callable]], + input_transforms: Union[Callable, List[Callable]], + visualization_transform: Callable, + ) -> None: + r""" + Args: + name (str): The label of the specific feature. For example, an + ImageFeature's name can be "Photo". + baseline_transforms (list, callable, optional): Optional list of + callables (e.g. functions) to be called on the input tensor + to construct multiple baselines. Currently only one baseline + is supported. See + :py:class:`.IntegratedGradients` for more + information about baselines. + For text features, a common baseline is a tensor of indices + corresponding to PAD with the same size as the input + tensor. See :py:class:`.TokenReferenceBase` for more + information. + input_transforms (list, callable, optional): A list of transforms + or transform to be applied to the input. For text, a common + transform is to convert the tokenized input tensor into an + interpretable embedding. See + :py:class:`.InterpretableEmbeddingBase` + and + :py:func:`~.configure_interpretable_embedding_layer` + for more information. + visualization_transform (callable, optional): Optional callable (e.g. + function) applied as a postprocessing step of the original + input data (before ``input_transforms``) to convert it to a + suitable format for visualization. For text features, + a common function is to convert the token indices to their + corresponding (sub)words. 
+ """ + super().__init__( + name, + baseline_transforms=baseline_transforms, + input_transforms=input_transforms, + visualization_transform=visualization_transform, + ) + + @staticmethod + def visualization_type() -> str: + return "text" + + def visualize(self, attribution, data, contribution_frac) -> FeatureOutput: + if self.visualization_transform: + text = self.visualization_transform(data) + else: + text = data + + attribution = attribution.squeeze(0) + data = data.squeeze(0) + if len(attribution.shape) > 1: + attribution = attribution.sum(dim=1) + + # L-Infinity norm, if norm is 0, all attr elements are 0 + attr_max = attribution.abs().max() + normalized_attribution = safe_div(attribution, attr_max) + + modified = [x * 100 for x in normalized_attribution.tolist()] + return FeatureOutput( + name=self.name, + base=text, + modified=modified, + type=self.visualization_type(), + contribution=contribution_frac, + ) + + +class GeneralFeature(BaseFeature): + r""" + GeneralFeature is used for non-specified feature visualization in Insights. + It can be used for dense or sparse features. + + Currently general features are only supported for 2-d tensors, in the format (N, C) + where N is the number of samples and C is the number of categories. + """ + + def __init__(self, name: str, categories: List[str]) -> None: + r""" + Args: + name (str): The label of the specific feature. For example, an + ImageFeature's name can be "Photo". + categories (list[str]): Category labels for the general feature. The + order and size should match the second dimension of the + ``data`` tensor parameter in ``visualize``. + """ + super().__init__( + name, + baseline_transforms=None, + input_transforms=None, + visualization_transform=None, + ) + self.categories = categories + + @staticmethod + def visualization_type() -> str: + return "general" + + def visualize(self, attribution, data, contribution_frac) -> FeatureOutput: + attribution = attribution.squeeze(0) + data = data.squeeze(0) + + # L-2 norm, if norm is 0, all attr elements are 0 + l2_norm = attribution.norm() + normalized_attribution = safe_div(attribution, l2_norm) + + modified = [x * 100 for x in normalized_attribution.tolist()] + + base = [f"{c}: {d:.2f}" for c, d in zip(self.categories, data.tolist())] + return FeatureOutput( + name=self.name, + base=base, + modified=modified, + type=self.visualization_type(), + contribution=contribution_frac, + ) + + +class EmptyFeature(BaseFeature): + def __init__( + self, + name: str = "empty", + baseline_transforms: Optional[Union[Callable, List[Callable]]] = None, + input_transforms: Optional[Union[Callable, List[Callable]]] = None, + visualization_transform: Optional[Callable] = None, + ) -> None: + super().__init__( + name, + baseline_transforms=baseline_transforms, + input_transforms=input_transforms, + visualization_transform=visualization_transform, + ) + + @staticmethod + def visualization_type() -> str: + return "empty" + + def visualize(self, _attribution, _data, contribution_frac) -> FeatureOutput: + return FeatureOutput( + name=self.name, + base=None, + modified=None, + type=self.visualization_type(), + contribution=contribution_frac, + ) diff --git a/captum/insights/attr_vis/frontend/README.md b/captum/insights/attr_vis/frontend/README.md new file mode 100644 index 0000000000000000000000000000000000000000..757cd9392e0a637787b31af9595d35bdf21fd624 --- /dev/null +++ b/captum/insights/attr_vis/frontend/README.md @@ -0,0 +1,10 @@ +# PyTorch Captum Insights + +## Development + +Captum Insights frontend is 
built on top of React, using create-react-app. + +`yarn build` to compile files. + +Javascript should be prettier formatted. +CSS uses BEM syntax. diff --git a/captum/insights/attr_vis/frontend/public/index.html b/captum/insights/attr_vis/frontend/public/index.html new file mode 100644 index 0000000000000000000000000000000000000000..d39c84fa4976cd8e94f7ec36fb18e2ac6366ae0e --- /dev/null +++ b/captum/insights/attr_vis/frontend/public/index.html @@ -0,0 +1,12 @@ + + + + + + Captum Insights + + + +
+ + diff --git a/captum/insights/attr_vis/frontend/src/App.css b/captum/insights/attr_vis/frontend/src/App.css new file mode 100644 index 0000000000000000000000000000000000000000..60344bbfa18cf9e75de40d5abc85eb8f3092a42f --- /dev/null +++ b/captum/insights/attr_vis/frontend/src/App.css @@ -0,0 +1,132 @@ +.react-tags { + position: relative; + padding: 6px 0 0 6px; + border: 1px solid #d1d1d1; + border-radius: 1px; + + /* shared font styles */ + font-size: 12px; + line-height: 1.2; + + /* clicking anywhere will focus the input */ + cursor: text; +} + +.react-tags.is-focused { + border-color: #b1b1b1; +} + +.react-tags__selected { + display: inline; +} + +.react-tags__selected-tag { + display: inline-block; + box-sizing: border-box; + margin: 0 6px 6px 0; + padding: 6px 8px; + border: 1px solid #d1d1d1; + border-radius: 2px; + background: #f1f1f1; + + /* match the font styles */ + font-size: inherit; + line-height: inherit; +} + +.react-tags__selected-tag:after { + content: "\2715"; + color: #aaa; + margin-left: 8px; +} + +.react-tags__selected-tag:hover, +.react-tags__selected-tag:focus { + border-color: #b1b1b1; +} + +.react-tags__search { + display: inline-block; + + /* match tag layout */ + padding: 7px 2px; + margin-bottom: 6px; + + /* prevent autoresize overflowing the container */ + max-width: 100%; +} + +@media screen and (min-width: 30em) { + .react-tags__search { + /* this will become the offsetParent for suggestions */ + position: relative; + } +} + +.react-tags__search input { + /* prevent autoresize overflowing the container */ + max-width: 100%; + + /* remove styles and layout from this element */ + margin: 0; + padding: 0; + border: 0; + outline: none; + + /* match the font styles */ + font-size: inherit; + line-height: inherit; +} + +.react-tags__search input::-ms-clear { + display: none; +} + +.react-tags__suggestions { + z-index: 9999999; + position: absolute; + top: 100%; + left: 0; + width: 100%; +} + +@media screen and (min-width: 30em) { + .react-tags__suggestions { + width: 100px; + } +} + +.react-tags__suggestions ul { + margin: 4px -1px; + padding: 0; + list-style: none; + background: rgba(255, 255, 255, 0.95); + border: 1px solid #d1d1d1; + border-radius: 2px; + box-shadow: 0 2px 6px rgba(0, 0, 0, 0.2); +} + +.react-tags__suggestions li { + border-bottom: 1px solid #ddd; + padding: 6px 8px; +} + +.react-tags__suggestions li mark { + text-decoration: underline; + background: none; + font-weight: 600; +} + +.react-tags__suggestions li:hover { + cursor: pointer; + background: #eee; +} + +.react-tags__suggestions li.is-active { + background: #b7cfe0; +} + +.react-tags__suggestions li.is-disabled { + opacity: 0.5; + cursor: auto; +} diff --git a/captum/insights/attr_vis/frontend/src/App.module.css b/captum/insights/attr_vis/frontend/src/App.module.css new file mode 100644 index 0000000000000000000000000000000000000000..0c658f89bbd72c70e092c1a05e7958784640421a --- /dev/null +++ b/captum/insights/attr_vis/frontend/src/App.module.css @@ -0,0 +1,376 @@ +.app { + width: 100%; + height: 100%; +} + +.header { +} + +.header__name { + font-size: 1.5em; + font-weight: bold; + padding: 16px 32px; + text-transform: uppercase; +} + +.header__nav { + border-bottom: solid 1px #ccd0d5; + padding-left: 32px; + background-color: white; + padding-top: 4px; +} + +.header__nav ul { + list-style: none; + margin: 0; + padding: 0; +} + +.header__nav__item { + color: #606770; + display: inline-block; + font-size: 1.2em; + line-height: 2em; + margin: 0 8px; + padding: 0 12px; +} + 
+.header__nav__item--active { + color: black; + font-weight: 600; + border-bottom: solid 4px #ee4c2c; +} + +.filter-panel { + background-color: white; + padding: 0 24px; +} + +.filter-panel, +.viz__panel { + align-content: space-between; + display: flex; + flex-direction: row; +} + +.filter-panel__column { + width: 33%; + padding: 12px 8px; +} + +.filter-panel__column__title { + font-weight: bold; + color: #1c1e21; + padding-bottom: 12px; +} + +.filter-panel__column__body { + color: #606770; +} + +.filter-panel__column--end { + flex-grow: 1; + align-self: center; + width: auto; +} + +.select { + background: none; + border: none; + border-radius: 0; + text-align-last: center; + padding: 0 8px; + margin: 0 4px; + border-bottom: solid 1px #1c1e21; + font-size: 1em; + appearance: none; + color: #1c1e21; +} + +.input { + background: none; + box-shadow: none; + border: none; + font-size: 1em; + border-bottom: solid 1px #1c1e21; + text-align-last: center; + padding: 0; + margin: 0; +} + +.input--narrow { + width: 100px; +} + +.row { + display: block; +} + +.row--padding { + margin: 8px 0; +} + +.btn { + border: solid 1px #ee4c2c; + background: white; + text-align: center; + font-weight: 600; + font-size: 1em; + border-radius: 4px; + padding: 6px 8px; + display: inline-block; + cursor: pointer; +} + +.btn--large { + font-size: 1.1em; + padding: 8px 10px; +} + +.btn--outline { + color: #ee4c2c; +} + +.btn--outline:hover { + background-color: rgba(0, 0, 0, 0.05); + border: solid 1px #ee4c2c; +} + +.btn--solid { + background-color: #ee4c2c; + color: white; +} + +.btn--solid:hover { + background-color: #d7725e; +} + +.viz { + display: block; +} + +.loading { + margin-top: 150px; + position: absolute; + width: 100%; + align-items: center; + justify-content: center; + display: flex; +} + +.panel { + margin: 16px; + padding: 24px; + background: white; + border-radius: 8px; + box-shadow: 0px 3px 6px 0px rgba(0, 0, 0, 0.18); + transition: opacity 0.2s; /* for loading */ + overflow-y: scroll; +} + +.panel__column__title { + font-weight: 700; + border-bottom: 2px solid #c1c1c1; + color: #1c1e21; + padding-bottom: 2px; + margin-bottom: 15px; +} + +.panel--loading { + opacity: 0.5; + pointer-events: none; /* disables all interactions inside panel */ +} + +.panel--center { + display: flex; + align-items: center; + justify-content: center; +} + +.panel__column { + padding: 0 8px; + flex: 1; +} + +.panel__column--stretch { + flex-grow: 3; +} + +.gallery { + display: flex; +} + +.gallery__item + .gallery__item { + padding: 0 8px; +} + +.gallery__item__image img { + height: 200px; + width: auto; +} + +.gallery__item__description { + text-align: center; +} + +.bar-chart__group { + padding: 2px 0; + display: flex; +} + +.bar-chart__group__bar { + width: 10px; + border-radius: 2px; + flex-shrink: 0; +} + +.bar-chart__group__title { + padding-left: 8px; +} + +.percentage-blue { + background-color: #80aaff; +} + +.percentage-light-blue { + background-color: rgba(128, 170, 255, 0.6); +} + +.percentage-light-red { + background-color: #e79c8d; +} + +.percentage-red { + background-color: #d45c43; +} + +.percentage-gray { + background: #c6c6c6; +} + +.percentage-white { + background-color: auto; +} + +.text-feature-word { + position: relative; + display: inline-block; + border-radius: 2px; + padding: 2px; +} + +.tooltip__label { + z-index: 999999; + max-width: 200px; + background: rgba(0, 0, 0, 0.75); + position: absolute; + text-align: center; + border-radius: 4px; + color: white; + padding: 4px; + font-size: 1.1em; + 
font-weight: 600; + visibility: hidden; + width: 80px; + bottom: 100%; + left: 50%; + margin-left: -40px; +} + +.tooltip__label::after { + z-index: 999999; + content: " "; + position: absolute; + top: 100%; + left: 50%; + margin-left: -6px; + border-style: solid; + border-width: 6px; + border-color: rgba(0, 0, 0, 0.75) transparent transparent transparent; +} + +.text-feature-word:hover .tooltip__label { + visibility: visible; +} + +.general-feature__label-container { + display: inline-flex; + justify-content: space-between; + width: 20%; + padding-right: 8px; +} + +.general-feature__label { + font-weight: 600; +} + +.general-feature__percent { + display: inline-block; +} + +.general-feature__bar-container { + display: inline-block; + width: 70%; +} +.general-feature__bar { + display: inline-block; + border-radius: 4px; + height: 12px; +} + +.general-feature__bar__positive { + background: #80aaff; +} + +.general-feature__bar__negative { + background: #d45c43; +} + +.spinner { + display: inline-block; + width: 64px; + height: 64px; +} + +.spinner:after { + content: " "; + display: block; + width: 46px; + height: 46px; + margin: 1px; + border-radius: 50%; + border: 5px solid #ee4c2c; + border-color: #ee4c2c transparent #ee4c2c transparent; + animation: spinner 1.2s linear infinite; +} + +@keyframes spinner { + 0% { + transform: rotate(0deg); + } + 100% { + transform: rotate(360deg); + } +} + +.visualization-container { + display: flex; +} + +.model-number { + display: block; + height: 2em; + font-size: 16px; + font-weight: 800; +} + +.model-number-spacer { + display: block; + height: 2em; +} + +.model-separator { + width: 100%; + border-bottom: 2px solid #c1c1c1; + margin: 10px 0px; +} diff --git a/captum/insights/attr_vis/frontend/src/App.test.js b/captum/insights/attr_vis/frontend/src/App.test.js new file mode 100644 index 0000000000000000000000000000000000000000..b82a25bb7220a3de9a35abb68edb29235b23c048 --- /dev/null +++ b/captum/insights/attr_vis/frontend/src/App.test.js @@ -0,0 +1,9 @@ +import React from "react"; +import ReactDOM from "react-dom"; +import WebApp from "./WebApp"; + +it("renders without crashing", () => { + const div = document.createElement("div"); + ReactDOM.render(, div); + ReactDOM.unmountComponentAtNode(div); +}); diff --git a/captum/insights/attr_vis/frontend/src/App.tsx b/captum/insights/attr_vis/frontend/src/App.tsx new file mode 100644 index 0000000000000000000000000000000000000000..a829cb693e84811460e923dce0210f9cc8b13c7a --- /dev/null +++ b/captum/insights/attr_vis/frontend/src/App.tsx @@ -0,0 +1,99 @@ +import React from "react"; +import styles from "./App.module.css"; +import Header from "./components/Header"; +import cx from "./utils/cx"; +import Spinner from "./components/Spinner"; +import FilterContainer from "./components/FilterContainer"; +import VisualizationGroupDisplay from "./components/VisualizationGroup"; +import "./App.css"; +import { VisualizationGroup } from "./models/visualizationOutput"; +import { FilterConfig } from "./models/filter"; + +interface VisualizationsProps { + loading: boolean; + data: VisualizationGroup[]; + onTargetClick: ( + labelIndex: number, + inputIndex: number, + modelIndex: number, + callback: () => void + ) => void; +} + +function Visualizations(props: VisualizationsProps) { + if (props.loading) { + return ( +
+
+ <Spinner /> +
+
+ ); + } + + if (!props.data || props.data.length === 0) { + return ( +
+
+
+ Please press{" "} + Fetch to + start loading data. +
+
+
+ ); + } + return ( +
+ {props.data.map((vg, i) => ( + <VisualizationGroupDisplay data={vg} inputIndex={i} key={i} onTargetClick={props.onTargetClick} /> + ))} +
+ ); +} + +interface AppBaseProps { + fetchInit: () => void; + fetchData: (filter_config: FilterConfig) => void; + config: any; + data: VisualizationGroup[]; + loading: boolean; + onTargetClick: ( + labelIndex: number, + inputIndex: number, + modelIndex: number, + callback: () => void + ) => void; +} + +class AppBase extends React.Component { + componentDidMount() { + this.props.fetchInit(); + } + + render() { + return ( +
+
+ <Header /> + <FilterContainer fetchData={this.props.fetchData} config={this.props.config} /> + <Visualizations data={this.props.data} loading={this.props.loading} onTargetClick={this.props.onTargetClick} />
+ ); + } +} + +export default AppBase; diff --git a/captum/insights/attr_vis/frontend/src/WebApp.tsx b/captum/insights/attr_vis/frontend/src/WebApp.tsx new file mode 100644 index 0000000000000000000000000000000000000000..7390565ab6ecb6f9e3cb34d108472d09dd18d094 --- /dev/null +++ b/captum/insights/attr_vis/frontend/src/WebApp.tsx @@ -0,0 +1,84 @@ +import React from "react"; +import AppBase from "./App"; +import { FilterConfig } from "./models/filter"; +import { VisualizationGroup } from "./models/visualizationOutput"; +import { InsightsConfig } from "./models/insightsConfig"; + +interface WebAppState { + data: VisualizationGroup[]; + config: InsightsConfig; + loading: boolean; +} + +class WebApp extends React.Component<{}, WebAppState> { + constructor(props: {}) { + super(props); + this.state = { + data: [], + config: { + classes: [], + methods: [], + method_arguments: {}, + selected_method: "", + }, + loading: false, + }; + this._fetchInit(); + } + + _fetchInit = () => { + fetch("init") + .then((r) => r.json()) + .then((r) => this.setState({ config: r })); + }; + + fetchData = (filter_config: FilterConfig) => { + this.setState({ loading: true }); + fetch("fetch", { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify(filter_config), + }) + .then((response) => response.json()) + .then((response) => this.setState({ data: response, loading: false })); + }; + + onTargetClick = ( + labelIndex: number, + inputIndex: number, + modelIndex: number, + callback: () => void + ) => { + fetch("attribute", { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ labelIndex, inputIndex, modelIndex }), + }) + .then((response) => response.json()) + .then((response) => { + const data = this.state.data ?? []; + data[inputIndex][modelIndex] = response; + this.setState({ data }); + callback(); + }); + }; + + render() { + return ( + + ); + } +} + +export default WebApp; diff --git a/captum/insights/attr_vis/frontend/src/components/Arguments.tsx b/captum/insights/attr_vis/frontend/src/components/Arguments.tsx new file mode 100644 index 0000000000000000000000000000000000000000..1e98b0fc89437051e2702b555ca6996ea4c0b3e8 --- /dev/null +++ b/captum/insights/attr_vis/frontend/src/components/Arguments.tsx @@ -0,0 +1,65 @@ +import React from "react"; +import styles from "../App.module.css"; +import cx from "../utils/cx"; +import { GenericArgumentConfig } from "../models/insightsConfig"; +import { UserInputField } from "../models/typeHelpers"; + +interface ArgumentProps { + name: string; + handleInputChange: React.ChangeEventHandler; +} + +function NumberArgument(props: ArgumentProps & GenericArgumentConfig) { + var min = props.limit[0]; + var max = props.limit[1]; + return ( +
+ {props.name}: + <input name={props.name} type="number" value={props.value} min={min} max={max} onChange={props.handleInputChange} /> +
+ ); +} + +function EnumArgument(props: ArgumentProps & GenericArgumentConfig<string>) { + const options = props.limit.map((item, key) => ( + <option value={item} key={key}> + {item} + </option> + )); + return ( +
+ {props.name}: + <select name={props.name} value={props.value} onChange={props.handleInputChange}> + {options} + </select> +
+ ); +} + +function StringArgument(props: ArgumentProps & { value: string }) { + return ( +
+ {props.name}: + <input name={props.name} value={props.value} onChange={props.handleInputChange} /> +
+ ); +} + +export { StringArgument, EnumArgument, NumberArgument }; diff --git a/captum/insights/attr_vis/frontend/src/components/ClassFilter.tsx b/captum/insights/attr_vis/frontend/src/components/ClassFilter.tsx new file mode 100644 index 0000000000000000000000000000000000000000..d169145bf29c34eaa170535eea9446567e40da22 --- /dev/null +++ b/captum/insights/attr_vis/frontend/src/components/ClassFilter.tsx @@ -0,0 +1,38 @@ +import React from "react"; +import ReactTags from "react-tag-autocomplete"; +import { TagClass } from "../models/filter"; + +interface ClassFilterProps { + suggestedClasses: TagClass[]; + classes: TagClass[]; + handleClassDelete: (classId: number) => void; + handleClassAdd: (newClass: TagClass) => void; +} + +function ClassFilter(props: ClassFilterProps) { + const handleAddition = (newTag: { id: number | string; name: string }) => { + /** + * Need this type check as we expect tagId to be number while the `react-tag-autocomplete` has + * id as number | string. + */ + if (typeof newTag.id === "string") { + throw Error("Invalid tag id received from ReactTags"); + } else { + props.handleClassAdd({ id: newTag.id, name: newTag.name }); + } + }; + + return ( + + ); +} + +export default ClassFilter; diff --git a/captum/insights/attr_vis/frontend/src/components/Contributions.tsx b/captum/insights/attr_vis/frontend/src/components/Contributions.tsx new file mode 100644 index 0000000000000000000000000000000000000000..18c401547ae80087adba6a659c12c6a00d064d9e --- /dev/null +++ b/captum/insights/attr_vis/frontend/src/components/Contributions.tsx @@ -0,0 +1,35 @@ +import React from "react"; +import styles from "../App.module.css"; +import { calcHSLFromScore } from "../utils/color"; +import { FeatureOutput } from "../models/visualizationOutput"; + +interface ContributionsProps { + feature_outputs: FeatureOutput[]; +} + +function Contributions(props: ContributionsProps) { + return ( + <> + {props.feature_outputs.map((f) => { + // pad bar height so features with 0 contribution can still be seen + // in graph + const contribution = f.contribution * 100; + const bar_height = contribution > 10 ? contribution : contribution + 10; + return ( +
+
+
{f.name}
+
+ ); + })} + + ); +} + +export default Contributions; diff --git a/captum/insights/attr_vis/frontend/src/components/Feature.tsx b/captum/insights/attr_vis/frontend/src/components/Feature.tsx new file mode 100644 index 0000000000000000000000000000000000000000..9060cf9331f3045b6f0a7bf80508e842d8a9adc6 --- /dev/null +++ b/captum/insights/attr_vis/frontend/src/components/Feature.tsx @@ -0,0 +1,164 @@ +import { calcHSLFromScore } from "../utils/color"; +import { DataPoint } from "../utils/dataPoint"; +import React from "react"; +import styles from "../App.module.css"; +import Tooltip from "./Tooltip"; +import { Bar } from "react-chartjs-2"; +import { FeatureOutput } from "../models/visualizationOutput"; + +interface FeatureProps { + data: T; + hideHeaders?: boolean; +} + +type ImageFeatureProps = FeatureProps<{ + base: string; + modified: string; + name: string; +}>; + +function ImageFeature(props: ImageFeatureProps) { + return ( + <> + {props.hideHeaders && ( +
+ {props.data.name} (Image) +
+ )} +
+
+
+
+
+ original +
+
Original
+
+
+
+ attribution +
+
+ Attribution Magnitude +
+
+
+
+ + ); +} + +type TextFeatureProps = FeatureProps<{ + base: number[]; + name: string; + modified: number[]; +}>; + +function TextFeature(props: TextFeatureProps) { + const color_words = props.data.base.map((w, i) => { + return ( + <> + + {w} + + {" "} + + ); + }); + return ( + <> + {props.hideHeaders && ( +
+ {props.data.name} (Text) +
+ )} +
+
+ {color_words} +
+ + ); +} + +type GeneralFeatureProps = FeatureProps<{ + base: number[]; + modified: number[]; + name: string; +}>; + +function GeneralFeature(props: GeneralFeatureProps) { + const data = { + labels: props.data.base, + datasets: [ + { + barPercentage: 0.5, + data: props.data.modified, + backgroundColor: (dataPoint: DataPoint) => { + if (!dataPoint.dataset || !dataPoint.dataset.data || dataPoint.datasetIndex === undefined) { + return "#d45c43"; // Default to red + } + const yValue = dataPoint.dataset.data[dataPoint.dataIndex as number] || 0; + return yValue < 0 ? "#d45c43" : "#80aaff"; // Red if negative, else blue + }, + }, + ], + }; + + return ( + + ); +} + +function Feature(props: { data: FeatureOutput; hideHeaders: boolean }) { + const data = props.data; + switch (data.type) { + case "image": + return ; + case "text": + return ; + case "general": + return ; + case "empty": + return <>; + default: + throw new Error("Unsupported feature visualization type: " + data.type); + } +} + +export default Feature; diff --git a/captum/insights/attr_vis/frontend/src/components/Filter.tsx b/captum/insights/attr_vis/frontend/src/components/Filter.tsx new file mode 100644 index 0000000000000000000000000000000000000000..25e235b6ecc93db783cb8933edab7f9aff8b6260 --- /dev/null +++ b/captum/insights/attr_vis/frontend/src/components/Filter.tsx @@ -0,0 +1,156 @@ +import React from "react"; +import { StringArgument, EnumArgument, NumberArgument } from "./Arguments"; +import cx from "../utils/cx"; +import styles from "../App.module.css"; +import ClassFilter from "./ClassFilter"; +import { + MethodsArguments, + ArgumentConfig, + ArgumentType, +} from "../models/insightsConfig"; +import { TagClass } from "../models/filter"; +import { UserInputField } from "../models/typeHelpers"; + +interface FilterProps { + prediction: string; + selectedMethod: string; + methodArguments: MethodsArguments; + suggestedClasses: TagClass[]; + classes: TagClass[]; + methods: string[]; + handleInputChange: React.ChangeEventHandler; + handleArgumentChange: React.ChangeEventHandler; + handleSubmit: React.FormEventHandler; + handleClassAdd: (newClass: TagClass) => void; + handleClassDelete: (id: number) => void; +} + +function Filter(props: FilterProps) { + const createComponentFromConfig = (name: string, config: ArgumentConfig) => { + switch (config.type) { + case ArgumentType.Number: + return ( + + ); + case ArgumentType.Enum: + return ( + + ); + case ArgumentType.String: + return ( + + ); + default: + throw new Error("Unsupported config type: " + config.type); + } + }; + + const methods = props.methods.map((item, key) => ( + + )); + var method_args_components = null; + if (props.selectedMethod in props.methodArguments) { + const method_arguments = props.methodArguments[props.selectedMethod]; + method_args_components = Object.keys(method_arguments).map((key, idx) => + createComponentFromConfig(key, method_arguments[key]) + ); + } + return ( +
+
+
+
+ Filter by Classes +
+
+ +
+
+
+
+ Filter by Instances +
+
+ Prediction:{" "} + +
+
+
+
+ Choose Attribution Method +
+
+ Attribution Method:{" "} + <select name="selected_method" value={props.selectedMethod} onChange={props.handleInputChange}> + {methods} + </select> +
+
+
+
+ Attribution Method Arguments +
+
+ {method_args_components} +
+
+
+ +
+
+
+ ); +} + +export default Filter; diff --git a/captum/insights/attr_vis/frontend/src/components/FilterContainer.tsx b/captum/insights/attr_vis/frontend/src/components/FilterContainer.tsx new file mode 100644 index 0000000000000000000000000000000000000000..52871ce67a8d0d4f3cac4e79aa4d34fa4d66f71f --- /dev/null +++ b/captum/insights/attr_vis/frontend/src/components/FilterContainer.tsx @@ -0,0 +1,124 @@ +import React from "react"; +import Filter from "./Filter"; +import { InsightsConfig, MethodsArguments } from "../models/insightsConfig"; +import { TagClass, FilterConfig } from "../models/filter"; +import { UserInputField } from "../models/typeHelpers"; + +function parseEventTargetValue(target: UserInputField) { + switch (target.type) { + case "checkbox": + return (target as HTMLInputElement).checked; + case "number": + return parseInt(target.value); + default: + return target.value; + } +} + +interface FilterContainerProps { + config: InsightsConfig; + fetchData: (filter_config: FilterConfig) => void; +} + +interface FilterContainerState { + prediction: string; + classes: TagClass[]; + suggested_classes: TagClass[]; + selected_method: string; + method_arguments: MethodsArguments; +} + +class FilterContainer extends React.Component< + FilterContainerProps, + FilterContainerState +> { + constructor(props: FilterContainerProps) { + super(props); + const suggested_classes = props.config.classes.map((c, classId) => ({ + id: classId, + name: c, + })); + this.state = { + prediction: "all", + classes: [], + suggested_classes: suggested_classes, + selected_method: props.config.selected_method, + method_arguments: props.config.method_arguments, + }; + } + + handleClassDelete = (classId: number) => { + const classes = this.state.classes.slice(0); + const removed_class = classes.splice(classId, 1); + const suggested_classes = [ + ...this.state.suggested_classes, + ...removed_class, + ]; + this.setState({ classes, suggested_classes }); + }; + + handleClassAdd = (added_class: TagClass) => { + const classes = [...this.state.classes, added_class]; + const suggested_classes = this.state.suggested_classes.filter( + (t) => t.id !== added_class.id + ); + this.setState({ classes, suggested_classes }); + }; + + handleInputChange = (event: React.ChangeEvent) => { + const target = event.target; + const value = parseEventTargetValue(event.target); + const name = target.name; + this.setState({ + [name]: value, + } as any); + }; + + handleArgumentChange = (event: React.ChangeEvent) => { + const target = event.target; + const name = target.name; + const value = parseEventTargetValue(target); + const method_arguments = this.state.method_arguments; + method_arguments[this.state.selected_method][name].value = value; + this.setState({ method_arguments }); + }; + + handleSubmit = (event: React.FormEvent) => { + const method = this.state.selected_method; + const method_arguments = this.state.method_arguments; + const argument_config = + method in method_arguments ? 
method_arguments[method] : {}; + const args: { [key: string]: string | boolean | number } = {}; + Object.keys(argument_config).forEach(function (key) { + args[key] = argument_config[key].value; + }); + const data = { + prediction: this.state.prediction, + classes: this.state.classes.map((classId) => classId["name"]), + attribution_method: method, + arguments: args, + }; + this.props.fetchData(data); + event.preventDefault(); + }; + + render() { + return ( + + ); + } +} + +export default FilterContainer; diff --git a/captum/insights/attr_vis/frontend/src/components/Header.tsx b/captum/insights/attr_vis/frontend/src/components/Header.tsx new file mode 100644 index 0000000000000000000000000000000000000000..4d2ef187e9075239fc2dc7ac7efd2b3b856bef85 --- /dev/null +++ b/captum/insights/attr_vis/frontend/src/components/Header.tsx @@ -0,0 +1,25 @@ +import React from "react"; +import styles from "../App.module.css"; +import cx from "../utils/cx"; + +function Header() { + return ( +
+
Captum Insights
+ +
+ ); +} + +export default Header; diff --git a/captum/insights/attr_vis/frontend/src/components/LabelButton.tsx b/captum/insights/attr_vis/frontend/src/components/LabelButton.tsx new file mode 100644 index 0000000000000000000000000000000000000000..68589be0219e57647577cb6fe276b094e9457531 --- /dev/null +++ b/captum/insights/attr_vis/frontend/src/components/LabelButton.tsx @@ -0,0 +1,37 @@ +import React from "react"; +import cx from "../utils/cx"; +import styles from "../App.module.css"; + +interface LabelButtonProps { + labelIndex: number; + inputIndex: number; + modelIndex: number; + active: boolean; + onTargetClick: ( + labelIndex: number, + inputIndex: number, + modelIndex: number + ) => void; +} + +function LabelButton(props: React.PropsWithChildren) { + const onClick = (e: React.MouseEvent) => { + e.preventDefault(); + props.onTargetClick(props.labelIndex, props.inputIndex, props.modelIndex); + }; + + return ( + + ); +} + +export default LabelButton; diff --git a/captum/insights/attr_vis/frontend/src/components/Spinner.tsx b/captum/insights/attr_vis/frontend/src/components/Spinner.tsx new file mode 100644 index 0000000000000000000000000000000000000000..9f63bd69d621810d4a5bb579586cda0c9e81d0aa --- /dev/null +++ b/captum/insights/attr_vis/frontend/src/components/Spinner.tsx @@ -0,0 +1,8 @@ +import React from "react"; +import styles from "../App.module.css"; + +function Spinner() { + return
; +} + +export default Spinner; diff --git a/captum/insights/attr_vis/frontend/src/components/Tooltip.tsx b/captum/insights/attr_vis/frontend/src/components/Tooltip.tsx new file mode 100644 index 0000000000000000000000000000000000000000..9e2cecf6f171233b5b060e3489c340ec36d581fd --- /dev/null +++ b/captum/insights/attr_vis/frontend/src/components/Tooltip.tsx @@ -0,0 +1,13 @@ +import React from "react"; + +import styles from "../App.module.css"; + +function Tooltip(props: { label: string }) { + return ( +
+
{props.label}
+
+ ); +} + +export default Tooltip; diff --git a/captum/insights/attr_vis/frontend/src/components/Visualization.tsx b/captum/insights/attr_vis/frontend/src/components/Visualization.tsx new file mode 100644 index 0000000000000000000000000000000000000000..8f8d04119c61f32c9a74aa63b11e5d361aafd03b --- /dev/null +++ b/captum/insights/attr_vis/frontend/src/components/Visualization.tsx @@ -0,0 +1,131 @@ +import React from "react"; +import styles from "../App.module.css"; +import cx from "../utils/cx"; +import Feature from "./Feature"; +import Spinner from "./Spinner"; +import LabelButton from "./LabelButton"; +import Contributions from "./Contributions"; +import { VisualizationOutput } from "../models/visualizationOutput"; + +interface VisualizationProps { + data: VisualizationOutput; + instance: number; + onTargetClick: ( + labelIndex: number, + inputIndex: number, + modelIndex: number, + callback: () => void + ) => void; +} + +interface VisualizationState { + loading: boolean; +} + +class Visualization extends React.Component< + VisualizationProps, + VisualizationState +> { + constructor(props: VisualizationProps) { + super(props); + this.state = { + loading: false, + }; + } + + onTargetClick = ( + labelIndex: number, + inputIndex: number, + modelIndex: number + ) => { + this.setState({ loading: true }); + this.props.onTargetClick(labelIndex, inputIndex, modelIndex, () => + this.setState({ loading: false }) + ); + }; + + //TODO: Refactor the visualization table as a instead of columns, in order to have cleaner styling + render() { + const data = this.props.data; + const isFirstInGroup = this.props.data.model_index === 0; + const features = data.feature_outputs.map((f) => ( + + )); + + return ( + <> + {this.state.loading && ( +
+ +
+ )} + {!isFirstInGroup &&
} +
+
+ {isFirstInGroup && ( +
Predicted
+ )} +
+
+ Model {data.model_index + 1} +
+ {data.predicted.map((p) => ( +
+ + {p.label} ({p.score.toFixed(3)}) + +
+ ))} +
+
+
+ {isFirstInGroup && ( +
Label
+ )} +
+
+
+ + {data.actual.label} + +
+
+
+
+ {isFirstInGroup && ( +
Contribution
+ )} +
+
+
+ +
+
+
+
+ {features} +
+
+ + ); + } +} + +export default Visualization; diff --git a/captum/insights/attr_vis/frontend/src/components/VisualizationGroup.tsx b/captum/insights/attr_vis/frontend/src/components/VisualizationGroup.tsx new file mode 100644 index 0000000000000000000000000000000000000000..023699413fc2886d36fc93869be474e981362406 --- /dev/null +++ b/captum/insights/attr_vis/frontend/src/components/VisualizationGroup.tsx @@ -0,0 +1,38 @@ +import React from "react"; +import styles from "../App.module.css"; +import cx from "../utils/cx"; +import Visualization from "../components/Visualization"; +import { VisualizationGroup } from "../models/visualizationOutput"; + +interface VisualizationGroupDisplayProps { + inputIndex: number; + data: VisualizationGroup; + onTargetClick: ( + labelIndex: number, + inputIndex: number, + modelIndex: number, + callback: () => void + ) => void; +} + +function VisualizationGroupDisplay(props: VisualizationGroupDisplayProps) { + return ( +
+ {props.data.map((v, i) => ( + <Visualization data={v} instance={props.inputIndex} key={i} onTargetClick={props.onTargetClick} /> + ))} +
+ ); +} + +export default VisualizationGroupDisplay; diff --git a/captum/insights/attr_vis/frontend/src/index.css b/captum/insights/attr_vis/frontend/src/index.css new file mode 100644 index 0000000000000000000000000000000000000000..fb3e89d8ed1a411a010ce7d00af7878a8c5f2ac2 --- /dev/null +++ b/captum/insights/attr_vis/frontend/src/index.css @@ -0,0 +1,17 @@ +html, +body { + height: 100%; +} + +body { + margin: 0; + font-family: "Segoe UI", "Roboto", "Oxygen", "Ubuntu", "Cantarell", + "Fira Sans", "Droid Sans", "Helvetica Neue", sans-serif; + -webkit-font-smoothing: antialiased; + -moz-osx-font-smoothing: grayscale; + background-color: rgba(0, 0, 0, 0.05); +} + +* { + box-sizing: border-box; +} diff --git a/captum/insights/attr_vis/frontend/src/index.tsx b/captum/insights/attr_vis/frontend/src/index.tsx new file mode 100644 index 0000000000000000000000000000000000000000..e496bdcd5c220dc5656b341dffe942ae61bbb3be --- /dev/null +++ b/captum/insights/attr_vis/frontend/src/index.tsx @@ -0,0 +1,6 @@ +import React from "react"; +import ReactDOM from "react-dom"; +import "./index.css"; +import WebApp from "./WebApp"; + +ReactDOM.render(, document.getElementById("root")); diff --git a/captum/insights/attr_vis/frontend/src/models/filter.ts b/captum/insights/attr_vis/frontend/src/models/filter.ts new file mode 100644 index 0000000000000000000000000000000000000000..4002ccd44b71c7ce74aad9256d3633b6a225d146 --- /dev/null +++ b/captum/insights/attr_vis/frontend/src/models/filter.ts @@ -0,0 +1,11 @@ +export interface FilterConfig { + attribution_method: string; + arguments: { [key: string]: any }; + prediction: string; + classes: string[]; +} + +export interface TagClass { + id: number; + name: string; +} diff --git a/captum/insights/attr_vis/frontend/src/models/insightsConfig.ts b/captum/insights/attr_vis/frontend/src/models/insightsConfig.ts new file mode 100644 index 0000000000000000000000000000000000000000..54afb547ee6b7e7b8d910a31d6d848c214e0838b --- /dev/null +++ b/captum/insights/attr_vis/frontend/src/models/insightsConfig.ts @@ -0,0 +1,30 @@ +export enum ArgumentType { + Number = "number", + Enum = "enum", + String = "string", + Boolean = "boolean", +} + +export type GenericArgumentConfig = { + value: T; + limit: T[]; +}; + +export type ArgumentConfig = + | ({ type: ArgumentType.Number } & GenericArgumentConfig) + | ({ type: ArgumentType.Enum } & GenericArgumentConfig) + | ({ type: ArgumentType.String } & { value: string }) + | ({ type: ArgumentType.Boolean } & { value: boolean }); + +export interface MethodsArguments { + [method_name: string]: { + [arg_name: string]: ArgumentConfig; + }; +} + +export interface InsightsConfig { + classes: string[]; + methods: string[]; + method_arguments: MethodsArguments; + selected_method: string; +} diff --git a/captum/insights/attr_vis/frontend/src/models/typeHelpers.ts b/captum/insights/attr_vis/frontend/src/models/typeHelpers.ts new file mode 100644 index 0000000000000000000000000000000000000000..b1386f36efb4342013b63979568a9110055ca099 --- /dev/null +++ b/captum/insights/attr_vis/frontend/src/models/typeHelpers.ts @@ -0,0 +1 @@ +export type UserInputField = HTMLInputElement | HTMLSelectElement; diff --git a/captum/insights/attr_vis/frontend/src/models/visualizationOutput.ts b/captum/insights/attr_vis/frontend/src/models/visualizationOutput.ts new file mode 100644 index 0000000000000000000000000000000000000000..923d5e54388d060965fdf803c92c60ec909bde0e --- /dev/null +++ b/captum/insights/attr_vis/frontend/src/models/visualizationOutput.ts @@ -0,0 +1,41 @@ 
+interface OutputScore { + label: string; + index: number; + score: number; +} + +export enum FeatureType { + TEXT = "text", + IMAGE = "image", + GENERAL = "general", + EMPTY = "empty", +} + +type GenericFeatureOutput = { + type: F; + name: string; + contribution: number; +} & T; + +export type FeatureOutput = + | GenericFeatureOutput< + FeatureType.TEXT, + { base: number[]; modified: number[] } + > + | GenericFeatureOutput + | GenericFeatureOutput< + FeatureType.GENERAL, + { base: number[]; modified: number[] } + > + | GenericFeatureOutput; + +export interface VisualizationOutput { + model_index: number; + feature_outputs: FeatureOutput[]; + actual: OutputScore; + predicted: OutputScore[]; + active_index: number; +} + +//When multiple models are compared, visualizations are grouped together +export type VisualizationGroup = VisualizationOutput[]; diff --git a/captum/insights/attr_vis/frontend/src/react-app-env.d.ts b/captum/insights/attr_vis/frontend/src/react-app-env.d.ts new file mode 100644 index 0000000000000000000000000000000000000000..6431bc5fc6b2c932dfe5d0418fc667b86c18b9fc --- /dev/null +++ b/captum/insights/attr_vis/frontend/src/react-app-env.d.ts @@ -0,0 +1 @@ +/// diff --git a/captum/insights/attr_vis/frontend/src/utils/color.ts b/captum/insights/attr_vis/frontend/src/utils/color.ts new file mode 100644 index 0000000000000000000000000000000000000000..6c4df2149bd66b2ea7f9e629b57d549f0bd7eb23 --- /dev/null +++ b/captum/insights/attr_vis/frontend/src/utils/color.ts @@ -0,0 +1,26 @@ +function calcHSLFromScore(percentage: number, zeroDefault = false) { + const blue_hsl = [220, 100, 80]; + const red_hsl = [10, 100, 67]; + + let target_hsl = null; + if (percentage > 0) { + target_hsl = blue_hsl; + } else { + target_hsl = red_hsl; + } + + const default_hsl = [0, 40, zeroDefault ? 100 : 90]; + const abs_percent = Math.abs(percentage * 0.01); + if (abs_percent < 0.02) { + return `hsl(${default_hsl[0]}, ${default_hsl[1]}%, ${default_hsl[2]}%)`; + } + + const color = [ + target_hsl[0], + (target_hsl[1] - default_hsl[1]) * abs_percent + default_hsl[1], + (target_hsl[2] - default_hsl[2]) * abs_percent + default_hsl[2], + ]; + return `hsl(${color[0]}, ${color[1]}%, ${color[2]}%)`; +} + +export { calcHSLFromScore }; diff --git a/captum/insights/attr_vis/frontend/src/utils/cx.ts b/captum/insights/attr_vis/frontend/src/utils/cx.ts new file mode 100644 index 0000000000000000000000000000000000000000..3b8c3e44274251e4d4920ad6a1d3ca426fd07f40 --- /dev/null +++ b/captum/insights/attr_vis/frontend/src/utils/cx.ts @@ -0,0 +1,11 @@ +// helper method to convert an array or object into a valid classname +function cx(obj: any) { + if (Array.isArray(obj)) { + return obj.join(" "); + } + return Object.keys(obj) + .filter((k) => !!obj[k]) + .join(" "); +} + +export default cx; diff --git a/captum/insights/attr_vis/frontend/src/utils/dataPoint.ts b/captum/insights/attr_vis/frontend/src/utils/dataPoint.ts new file mode 100644 index 0000000000000000000000000000000000000000..c4852ec5feb3568937255b682b5de5a20b395223 --- /dev/null +++ b/captum/insights/attr_vis/frontend/src/utils/dataPoint.ts @@ -0,0 +1,11 @@ +import * as chartjs from "chart.js"; + +// Because there's no data point type exported by the +// main type declaration for chart.js, we have our own. 
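+// Usage sketch (illustrative; assumes the chart.js 2.x scriptable-option context
+// shape that GeneralFeature in components/Feature.tsx already relies on -- the
+// callback below and its name are not part of chart.js itself):
+//
+//   const barColor = (ctx: DataPoint): string => {
+//     const values = (ctx.dataset && ctx.dataset.data) || [];
+//     const y = Number(values[ctx.dataIndex ?? 0]) || 0;
+//     return y < 0 ? "#d45c43" : "#80aaff"; // red for negative, blue otherwise
+//   };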
+ +export interface DataPoint { + chart?: object; + dataIndex?: number; + dataset?: chartjs.ChartDataSets; + datasetIndex?: number; +} diff --git a/captum/insights/attr_vis/frontend/widget/src/Widget.js b/captum/insights/attr_vis/frontend/widget/src/Widget.js new file mode 100644 index 0000000000000000000000000000000000000000..71efd4f4a3465e90e66e19080292c843aeb72aee --- /dev/null +++ b/captum/insights/attr_vis/frontend/widget/src/Widget.js @@ -0,0 +1,102 @@ +import React from "react"; +import ReactDOM from "react-dom"; +import AppBase from "../../src/App"; +import * as widgets from "@jupyter-widgets/base"; +import * as _ from "lodash"; + +class Widget extends React.Component { + constructor(props) { + super(props); + this.state = { + data: [], + config: { + classes: [], + methods: [], + method_arguments: {}, + }, + loading: false, + callback: null, + }; + this.backbone = this.props.backbone; + } + + componentDidMount() { + this.backbone.model.on("change:output", this._outputChanged, this); + this.backbone.model.on( + "change:attribution", + this._attributionChanged, + this + ); + } + + _outputChanged(model, output, options) { + if (_.isEmpty(output)) return; + this.setState({ data: output, loading: false }); + } + + _attributionChanged(model, attribution, options) { + if (_.isEmpty(attribution)) return; + const data = Object.assign([], this.state.data); + const callback = this.state.callback; + const labelDetails = model.attributes.label_details; + data[labelDetails.inputIndex][labelDetails.modelIndex] = attribution; + this.setState({ data: data, callback: null }, () => { + callback(); + }); + } + + _fetchInit = () => { + this.setState({ + config: this.backbone.model.get("insights_config"), + }); + }; + + fetchData = (filterConfig) => { + this.setState({ loading: true }, () => { + this.backbone.model.save({ config: filterConfig, output: [] }); + }); + }; + + onTargetClick = (labelIndex, inputIndex, modelIndex, callback) => { + this.setState({ callback: callback }, () => { + this.backbone.model.save({ + label_details: { labelIndex, inputIndex, modelIndex }, + attribution: {}, + }); + }); + }; + + render() { + return ( + + ); + } +} + +var CaptumInsightsModel = widgets.DOMWidgetModel.extend({ + defaults: _.extend(widgets.DOMWidgetModel.prototype.defaults(), { + _model_name: "CaptumInsightsModel", + _view_name: "CaptumInsightsView", + _model_module: "jupyter-captum-insights", + _view_module: "jupyter-captum-insights", + _model_module_version: "0.1.0", + _view_module_version: "0.1.0", + }), +}); + +var CaptumInsightsView = widgets.DOMWidgetView.extend({ + initialize() { + const $app = document.createElement("div"); + ReactDOM.render(, $app); + this.el.append($app); + }, +}); + +export { Widget as default, CaptumInsightsModel, CaptumInsightsView }; diff --git a/captum/insights/attr_vis/frontend/widget/src/embed.js b/captum/insights/attr_vis/frontend/widget/src/embed.js new file mode 100644 index 0000000000000000000000000000000000000000..ed7984351d51c3c1cf9edf03ff23e015d2492cc1 --- /dev/null +++ b/captum/insights/attr_vis/frontend/widget/src/embed.js @@ -0,0 +1,9 @@ +// Entry point for the unpkg bundle containing custom model definitions. +// +// It differs from the notebook bundle in that it does not need to define a +// dynamic baseURL for the static assets and may load some css that would +// already be loaded by the notebook otherwise. + +// Export widget models and views, and the npm package version number. 
+module.exports = require("./Widget.js"); +module.exports["version"] = require("../../package.json").version; diff --git a/captum/insights/attr_vis/frontend/widget/src/extension.js b/captum/insights/attr_vis/frontend/widget/src/extension.js new file mode 100644 index 0000000000000000000000000000000000000000..4a5214baebf23c64a0146ea9a4f231032122cf6e --- /dev/null +++ b/captum/insights/attr_vis/frontend/widget/src/extension.js @@ -0,0 +1,26 @@ +// This file contains the javascript that is run when the notebook is loaded. +// It contains some requirejs configuration and the `load_ipython_extension` +// which is required for any notebook extension. +// +// Some static assets may be required by the custom widget javascript. The base +// url for the notebook is not known at build time and is therefore computed +// dynamically. +__webpack_public_path__ = + document.querySelector("body").getAttribute("data-base-url") + + "nbextensions/jupyter-captum-insights"; + +// Configure requirejs +if (window.require) { + window.require.config({ + map: { + "*": { + "jupyter-captum-insights": "nbextensions/jupyter-captum-insights/index", + }, + }, + }); +} + +// Export the required load_ipython_extension +module.exports = { + load_ipython_extension: function () {}, +}; diff --git a/captum/insights/attr_vis/frontend/widget/src/index.js b/captum/insights/attr_vis/frontend/widget/src/index.js new file mode 100644 index 0000000000000000000000000000000000000000..1975ca891971854e2fa911530d4079294ae6d72e --- /dev/null +++ b/captum/insights/attr_vis/frontend/widget/src/index.js @@ -0,0 +1,2 @@ +module.exports = require("./Widget.js"); +module.exports["version"] = require("../../package.json").version; diff --git a/captum/insights/attr_vis/frontend/widget/webpack.config.js b/captum/insights/attr_vis/frontend/widget/webpack.config.js new file mode 100644 index 0000000000000000000000000000000000000000..f3c2b01fa3b831b9bb734f56af4a0e1ecab3dab8 --- /dev/null +++ b/captum/insights/attr_vis/frontend/widget/webpack.config.js @@ -0,0 +1,88 @@ +var path = require("path"); + +// Custom webpack rules are generally the same for all webpack bundles, hence +// stored in a separate local variable. +var rules = [ + { + test: /\.module.css$/, + use: [ + "style-loader", + { + loader: "css-loader", + options: { + modules: true, + }, + }, + ], + }, + { test: /^((?!\.module).)*.css$/, use: ["style-loader", "css-loader"] }, + { + test: /\.(js|ts|tsx)$/, + exclude: /node_modules/, + loaders: "babel-loader", + options: { + presets: [ + "@babel/preset-react", + "@babel/preset-env", + "@babel/preset-typescript", + ], + plugins: ["@babel/plugin-proposal-class-properties"], + }, + }, +]; + +var extensions = [".js", ".ts", ".tsx"]; + +module.exports = [ + { + // Notebook extension + // + // This bundle only contains the part of the JavaScript that is run on + // load of the notebook. This section generally only performs + // some configuration for requirejs, and provides the legacy + // "load_ipython_extension" function which is required for any notebook + // extension. 
+ // + mode: "production", + entry: "./src/extension.js", + output: { + filename: "extension.js", + path: path.resolve(__dirname, "..", "..", "widget", "static"), + libraryTarget: "amd", + }, + resolveLoader: { + modules: ["../node_modules"], + extensions: extensions, + }, + resolve: { + modules: ["../node_modules"], + }, + externals: ["moment"], // Removes unused dependency-of-dependency + }, + { + // Bundle for the notebook containing the custom widget views and models + // + // This bundle contains the implementation for the custom widget views and + // custom widget. + // It must be an amd module + // + mode: "production", + entry: "./src/index.js", + output: { + filename: "index.js", + path: path.resolve(__dirname, "..", "..", "widget", "static"), + libraryTarget: "amd", + }, + module: { + rules: rules, + }, + resolveLoader: { + modules: ["../node_modules"], + }, + resolve: { + modules: ["../node_modules"], + extensions: extensions, + }, + externals: ["@jupyter-widgets/base", "moment"], + }, +]; diff --git a/captum/insights/attr_vis/frontend/yarn.lock b/captum/insights/attr_vis/frontend/yarn.lock new file mode 100644 index 0000000000000000000000000000000000000000..588f31aa81dc57be110ea1e7935ac5178a51e47a --- /dev/null +++ b/captum/insights/attr_vis/frontend/yarn.lock @@ -0,0 +1,11054 @@ +# THIS IS AN AUTOGENERATED FILE. DO NOT EDIT THIS FILE DIRECTLY. +# yarn lockfile v1 + + +"@babel/code-frame@7.8.3", "@babel/code-frame@^7.0.0", "@babel/code-frame@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/code-frame/-/code-frame-7.8.3.tgz#33e25903d7481181534e12ec0a25f16b6fcf419e" + integrity sha512-a9gxpmdXtZEInkCSHUJDLHZVBgb1QS0jhss4cPP93EW7s+uC5bikET2twEF3KV+7rDblJcmNvTR7VJejqd2C2g== + dependencies: + "@babel/highlight" "^7.8.3" + +"@babel/code-frame@^7.10.1": + version "7.10.1" + resolved "https://registry.yarnpkg.com/@babel/code-frame/-/code-frame-7.10.1.tgz#d5481c5095daa1c57e16e54c6f9198443afb49ff" + integrity sha512-IGhtTmpjGbYzcEDOw7DcQtbQSXcG9ftmAXtWTu9V936vDye4xjjekktFAtgZsWpzTj/X01jocB46mTywm/4SZw== + dependencies: + "@babel/highlight" "^7.10.1" + +"@babel/compat-data@^7.8.6", "@babel/compat-data@^7.9.0": + version "7.9.0" + resolved "https://registry.yarnpkg.com/@babel/compat-data/-/compat-data-7.9.0.tgz#04815556fc90b0c174abd2c0c1bb966faa036a6c" + integrity sha512-zeFQrr+284Ekvd9e7KAX954LkapWiOmQtsfHirhxqfdlX6MEC32iRE+pqUGlYIBchdevaCwvzxWGSy/YBNI85g== + dependencies: + browserslist "^4.9.1" + invariant "^2.2.4" + semver "^5.5.0" + +"@babel/core@7.9.0", "@babel/core@^7.1.0", "@babel/core@^7.4.5": + version "7.9.0" + resolved "https://registry.yarnpkg.com/@babel/core/-/core-7.9.0.tgz#ac977b538b77e132ff706f3b8a4dbad09c03c56e" + integrity sha512-kWc7L0fw1xwvI0zi8OKVBuxRVefwGOrKSQMvrQ3dW+bIIavBY3/NpXmpjMy7bQnLgwgzWQZ8TlM57YHpHNHz4w== + dependencies: + "@babel/code-frame" "^7.8.3" + "@babel/generator" "^7.9.0" + "@babel/helper-module-transforms" "^7.9.0" + "@babel/helpers" "^7.9.0" + "@babel/parser" "^7.9.0" + "@babel/template" "^7.8.6" + "@babel/traverse" "^7.9.0" + "@babel/types" "^7.9.0" + convert-source-map "^1.7.0" + debug "^4.1.0" + gensync "^1.0.0-beta.1" + json5 "^2.1.2" + lodash "^4.17.13" + resolve "^1.3.2" + semver "^5.4.1" + source-map "^0.5.0" + +"@babel/generator@^7.10.1": + version "7.10.2" + resolved "https://registry.yarnpkg.com/@babel/generator/-/generator-7.10.2.tgz#0fa5b5b2389db8bfdfcc3492b551ee20f5dd69a9" + integrity sha512-AxfBNHNu99DTMvlUPlt1h2+Hn7knPpH5ayJ8OqDWSeLld+Fi2AYBTC/IejWDM9Edcii4UzZRCsbUt0WlSDsDsA== + dependencies: + 
"@babel/types" "^7.10.2" + jsesc "^2.5.1" + lodash "^4.17.13" + source-map "^0.5.0" + +"@babel/generator@^7.4.0", "@babel/generator@^7.9.0": + version "7.9.4" + resolved "https://registry.yarnpkg.com/@babel/generator/-/generator-7.9.4.tgz#12441e90c3b3c4159cdecf312075bf1a8ce2dbce" + integrity sha512-rjP8ahaDy/ouhrvCoU1E5mqaitWrxwuNGU+dy1EpaoK48jZay4MdkskKGIMHLZNewg8sAsqpGSREJwP0zH3YQA== + dependencies: + "@babel/types" "^7.9.0" + jsesc "^2.5.1" + lodash "^4.17.13" + source-map "^0.5.0" + +"@babel/helper-annotate-as-pure@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/helper-annotate-as-pure/-/helper-annotate-as-pure-7.8.3.tgz#60bc0bc657f63a0924ff9a4b4a0b24a13cf4deee" + integrity sha512-6o+mJrZBxOoEX77Ezv9zwW7WV8DdluouRKNY/IR5u/YTMuKHgugHOzYWlYvYLpLA9nPsQCAAASpCIbjI9Mv+Uw== + dependencies: + "@babel/types" "^7.8.3" + +"@babel/helper-builder-binary-assignment-operator-visitor@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/helper-builder-binary-assignment-operator-visitor/-/helper-builder-binary-assignment-operator-visitor-7.8.3.tgz#c84097a427a061ac56a1c30ebf54b7b22d241503" + integrity sha512-5eFOm2SyFPK4Rh3XMMRDjN7lBH0orh3ss0g3rTYZnBQ+r6YPj7lgDyCvPphynHvUrobJmeMignBr6Acw9mAPlw== + dependencies: + "@babel/helper-explode-assignable-expression" "^7.8.3" + "@babel/types" "^7.8.3" + +"@babel/helper-builder-react-jsx-experimental@^7.9.0": + version "7.9.0" + resolved "https://registry.yarnpkg.com/@babel/helper-builder-react-jsx-experimental/-/helper-builder-react-jsx-experimental-7.9.0.tgz#066d80262ade488f9c1b1823ce5db88a4cedaa43" + integrity sha512-3xJEiyuYU4Q/Ar9BsHisgdxZsRlsShMe90URZ0e6przL26CCs8NJbDoxH94kKT17PcxlMhsCAwZd90evCo26VQ== + dependencies: + "@babel/helper-annotate-as-pure" "^7.8.3" + "@babel/helper-module-imports" "^7.8.3" + "@babel/types" "^7.9.0" + +"@babel/helper-builder-react-jsx@^7.9.0": + version "7.9.0" + resolved "https://registry.yarnpkg.com/@babel/helper-builder-react-jsx/-/helper-builder-react-jsx-7.9.0.tgz#16bf391990b57732700a3278d4d9a81231ea8d32" + integrity sha512-weiIo4gaoGgnhff54GQ3P5wsUQmnSwpkvU0r6ZHq6TzoSzKy4JxHEgnxNytaKbov2a9z/CVNyzliuCOUPEX3Jw== + dependencies: + "@babel/helper-annotate-as-pure" "^7.8.3" + "@babel/types" "^7.9.0" + +"@babel/helper-call-delegate@^7.8.7": + version "7.8.7" + resolved "https://registry.yarnpkg.com/@babel/helper-call-delegate/-/helper-call-delegate-7.8.7.tgz#28a279c2e6c622a6233da548127f980751324cab" + integrity sha512-doAA5LAKhsFCR0LAFIf+r2RSMmC+m8f/oQ+URnUET/rWeEzC0yTRmAGyWkD4sSu3xwbS7MYQ2u+xlt1V5R56KQ== + dependencies: + "@babel/helper-hoist-variables" "^7.8.3" + "@babel/traverse" "^7.8.3" + "@babel/types" "^7.8.7" + +"@babel/helper-compilation-targets@^7.8.7": + version "7.8.7" + resolved "https://registry.yarnpkg.com/@babel/helper-compilation-targets/-/helper-compilation-targets-7.8.7.tgz#dac1eea159c0e4bd46e309b5a1b04a66b53c1dde" + integrity sha512-4mWm8DCK2LugIS+p1yArqvG1Pf162upsIsjE7cNBjez+NjliQpVhj20obE520nao0o14DaTnFJv+Fw5a0JpoUw== + dependencies: + "@babel/compat-data" "^7.8.6" + browserslist "^4.9.1" + invariant "^2.2.4" + levenary "^1.1.1" + semver "^5.5.0" + +"@babel/helper-create-class-features-plugin@^7.10.1": + version "7.10.2" + resolved "https://registry.yarnpkg.com/@babel/helper-create-class-features-plugin/-/helper-create-class-features-plugin-7.10.2.tgz#7474295770f217dbcf288bf7572eb213db46ee67" + integrity sha512-5C/QhkGFh1vqcziq1vAL6SI9ymzUp8BCYjFpvYVhWP4DlATIb3u5q3iUd35mvlyGs8fO7hckkW7i0tmH+5+bvQ== + dependencies: + "@babel/helper-function-name" "^7.10.1" + 
"@babel/helper-member-expression-to-functions" "^7.10.1" + "@babel/helper-optimise-call-expression" "^7.10.1" + "@babel/helper-plugin-utils" "^7.10.1" + "@babel/helper-replace-supers" "^7.10.1" + "@babel/helper-split-export-declaration" "^7.10.1" + +"@babel/helper-create-class-features-plugin@^7.8.3": + version "7.8.6" + resolved "https://registry.yarnpkg.com/@babel/helper-create-class-features-plugin/-/helper-create-class-features-plugin-7.8.6.tgz#243a5b46e2f8f0f674dc1387631eb6b28b851de0" + integrity sha512-klTBDdsr+VFFqaDHm5rR69OpEQtO2Qv8ECxHS1mNhJJvaHArR6a1xTf5K/eZW7eZpJbhCx3NW1Yt/sKsLXLblg== + dependencies: + "@babel/helper-function-name" "^7.8.3" + "@babel/helper-member-expression-to-functions" "^7.8.3" + "@babel/helper-optimise-call-expression" "^7.8.3" + "@babel/helper-plugin-utils" "^7.8.3" + "@babel/helper-replace-supers" "^7.8.6" + "@babel/helper-split-export-declaration" "^7.8.3" + +"@babel/helper-create-regexp-features-plugin@^7.8.3", "@babel/helper-create-regexp-features-plugin@^7.8.8": + version "7.8.8" + resolved "https://registry.yarnpkg.com/@babel/helper-create-regexp-features-plugin/-/helper-create-regexp-features-plugin-7.8.8.tgz#5d84180b588f560b7864efaeea89243e58312087" + integrity sha512-LYVPdwkrQEiX9+1R29Ld/wTrmQu1SSKYnuOk3g0CkcZMA1p0gsNxJFj/3gBdaJ7Cg0Fnek5z0DsMULePP7Lrqg== + dependencies: + "@babel/helper-annotate-as-pure" "^7.8.3" + "@babel/helper-regex" "^7.8.3" + regexpu-core "^4.7.0" + +"@babel/helper-define-map@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/helper-define-map/-/helper-define-map-7.8.3.tgz#a0655cad5451c3760b726eba875f1cd8faa02c15" + integrity sha512-PoeBYtxoZGtct3md6xZOCWPcKuMuk3IHhgxsRRNtnNShebf4C8YonTSblsK4tvDbm+eJAw2HAPOfCr+Q/YRG/g== + dependencies: + "@babel/helper-function-name" "^7.8.3" + "@babel/types" "^7.8.3" + lodash "^4.17.13" + +"@babel/helper-explode-assignable-expression@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/helper-explode-assignable-expression/-/helper-explode-assignable-expression-7.8.3.tgz#a728dc5b4e89e30fc2dfc7d04fa28a930653f982" + integrity sha512-N+8eW86/Kj147bO9G2uclsg5pwfs/fqqY5rwgIL7eTBklgXjcOJ3btzS5iM6AitJcftnY7pm2lGsrJVYLGjzIw== + dependencies: + "@babel/traverse" "^7.8.3" + "@babel/types" "^7.8.3" + +"@babel/helper-function-name@^7.10.1": + version "7.10.1" + resolved "https://registry.yarnpkg.com/@babel/helper-function-name/-/helper-function-name-7.10.1.tgz#92bd63829bfc9215aca9d9defa85f56b539454f4" + integrity sha512-fcpumwhs3YyZ/ttd5Rz0xn0TpIwVkN7X0V38B9TWNfVF42KEkhkAAuPCQ3oXmtTRtiPJrmZ0TrfS0GKF0eMaRQ== + dependencies: + "@babel/helper-get-function-arity" "^7.10.1" + "@babel/template" "^7.10.1" + "@babel/types" "^7.10.1" + +"@babel/helper-function-name@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/helper-function-name/-/helper-function-name-7.8.3.tgz#eeeb665a01b1f11068e9fb86ad56a1cb1a824cca" + integrity sha512-BCxgX1BC2hD/oBlIFUgOCQDOPV8nSINxCwM3o93xP4P9Fq6aV5sgv2cOOITDMtCfQ+3PvHp3l689XZvAM9QyOA== + dependencies: + "@babel/helper-get-function-arity" "^7.8.3" + "@babel/template" "^7.8.3" + "@babel/types" "^7.8.3" + +"@babel/helper-get-function-arity@^7.10.1": + version "7.10.1" + resolved "https://registry.yarnpkg.com/@babel/helper-get-function-arity/-/helper-get-function-arity-7.10.1.tgz#7303390a81ba7cb59613895a192b93850e373f7d" + integrity sha512-F5qdXkYGOQUb0hpRaPoetF9AnsXknKjWMZ+wmsIRsp5ge5sFh4c3h1eH2pRTTuy9KKAA2+TTYomGXAtEL2fQEw== + dependencies: + "@babel/types" "^7.10.1" + +"@babel/helper-get-function-arity@^7.8.3": + 
version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/helper-get-function-arity/-/helper-get-function-arity-7.8.3.tgz#b894b947bd004381ce63ea1db9f08547e920abd5" + integrity sha512-FVDR+Gd9iLjUMY1fzE2SR0IuaJToR4RkCDARVfsBBPSP53GEqSFjD8gNyxg246VUyc/ALRxFaAK8rVG7UT7xRA== + dependencies: + "@babel/types" "^7.8.3" + +"@babel/helper-hoist-variables@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/helper-hoist-variables/-/helper-hoist-variables-7.8.3.tgz#1dbe9b6b55d78c9b4183fc8cdc6e30ceb83b7134" + integrity sha512-ky1JLOjcDUtSc+xkt0xhYff7Z6ILTAHKmZLHPxAhOP0Nd77O+3nCsd6uSVYur6nJnCI029CrNbYlc0LoPfAPQg== + dependencies: + "@babel/types" "^7.8.3" + +"@babel/helper-member-expression-to-functions@^7.10.1": + version "7.10.1" + resolved "https://registry.yarnpkg.com/@babel/helper-member-expression-to-functions/-/helper-member-expression-to-functions-7.10.1.tgz#432967fd7e12a4afef66c4687d4ca22bc0456f15" + integrity sha512-u7XLXeM2n50gb6PWJ9hoO5oO7JFPaZtrh35t8RqKLT1jFKj9IWeD1zrcrYp1q1qiZTdEarfDWfTIP8nGsu0h5g== + dependencies: + "@babel/types" "^7.10.1" + +"@babel/helper-member-expression-to-functions@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/helper-member-expression-to-functions/-/helper-member-expression-to-functions-7.8.3.tgz#659b710498ea6c1d9907e0c73f206eee7dadc24c" + integrity sha512-fO4Egq88utkQFjbPrSHGmGLFqmrshs11d46WI+WZDESt7Wu7wN2G2Iu+NMMZJFDOVRHAMIkB5SNh30NtwCA7RA== + dependencies: + "@babel/types" "^7.8.3" + +"@babel/helper-module-imports@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/helper-module-imports/-/helper-module-imports-7.8.3.tgz#7fe39589b39c016331b6b8c3f441e8f0b1419498" + integrity sha512-R0Bx3jippsbAEtzkpZ/6FIiuzOURPcMjHp+Z6xPe6DtApDJx+w7UYyOLanZqO8+wKR9G10s/FmHXvxaMd9s6Kg== + dependencies: + "@babel/types" "^7.8.3" + +"@babel/helper-module-transforms@^7.9.0": + version "7.9.0" + resolved "https://registry.yarnpkg.com/@babel/helper-module-transforms/-/helper-module-transforms-7.9.0.tgz#43b34dfe15961918707d247327431388e9fe96e5" + integrity sha512-0FvKyu0gpPfIQ8EkxlrAydOWROdHpBmiCiRwLkUiBGhCUPRRbVD2/tm3sFr/c/GWFrQ/ffutGUAnx7V0FzT2wA== + dependencies: + "@babel/helper-module-imports" "^7.8.3" + "@babel/helper-replace-supers" "^7.8.6" + "@babel/helper-simple-access" "^7.8.3" + "@babel/helper-split-export-declaration" "^7.8.3" + "@babel/template" "^7.8.6" + "@babel/types" "^7.9.0" + lodash "^4.17.13" + +"@babel/helper-optimise-call-expression@^7.10.1": + version "7.10.1" + resolved "https://registry.yarnpkg.com/@babel/helper-optimise-call-expression/-/helper-optimise-call-expression-7.10.1.tgz#b4a1f2561870ce1247ceddb02a3860fa96d72543" + integrity sha512-a0DjNS1prnBsoKx83dP2falChcs7p3i8VMzdrSbfLhuQra/2ENC4sbri34dz/rWmDADsmF1q5GbfaXydh0Jbjg== + dependencies: + "@babel/types" "^7.10.1" + +"@babel/helper-optimise-call-expression@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/helper-optimise-call-expression/-/helper-optimise-call-expression-7.8.3.tgz#7ed071813d09c75298ef4f208956006b6111ecb9" + integrity sha512-Kag20n86cbO2AvHca6EJsvqAd82gc6VMGule4HwebwMlwkpXuVqrNRj6CkCV2sKxgi9MyAUnZVnZ6lJ1/vKhHQ== + dependencies: + "@babel/types" "^7.8.3" + +"@babel/helper-plugin-utils@^7.0.0", "@babel/helper-plugin-utils@^7.8.0", "@babel/helper-plugin-utils@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/helper-plugin-utils/-/helper-plugin-utils-7.8.3.tgz#9ea293be19babc0f52ff8ca88b34c3611b208670" + integrity 
sha512-j+fq49Xds2smCUNYmEHF9kGNkhbet6yVIBp4e6oeQpH1RUs/Ir06xUKzDjDkGcaaokPiTNs2JBWHjaE4csUkZQ== + +"@babel/helper-plugin-utils@^7.10.1": + version "7.10.1" + resolved "https://registry.yarnpkg.com/@babel/helper-plugin-utils/-/helper-plugin-utils-7.10.1.tgz#ec5a5cf0eec925b66c60580328b122c01230a127" + integrity sha512-fvoGeXt0bJc7VMWZGCAEBEMo/HAjW2mP8apF5eXK0wSqwLAVHAISCWRoLMBMUs2kqeaG77jltVqu4Hn8Egl3nA== + +"@babel/helper-regex@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/helper-regex/-/helper-regex-7.8.3.tgz#139772607d51b93f23effe72105b319d2a4c6965" + integrity sha512-BWt0QtYv/cg/NecOAZMdcn/waj/5P26DR4mVLXfFtDokSR6fyuG0Pj+e2FqtSME+MqED1khnSMulkmGl8qWiUQ== + dependencies: + lodash "^4.17.13" + +"@babel/helper-remap-async-to-generator@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/helper-remap-async-to-generator/-/helper-remap-async-to-generator-7.8.3.tgz#273c600d8b9bf5006142c1e35887d555c12edd86" + integrity sha512-kgwDmw4fCg7AVgS4DukQR/roGp+jP+XluJE5hsRZwxCYGg+Rv9wSGErDWhlI90FODdYfd4xG4AQRiMDjjN0GzA== + dependencies: + "@babel/helper-annotate-as-pure" "^7.8.3" + "@babel/helper-wrap-function" "^7.8.3" + "@babel/template" "^7.8.3" + "@babel/traverse" "^7.8.3" + "@babel/types" "^7.8.3" + +"@babel/helper-replace-supers@^7.10.1": + version "7.10.1" + resolved "https://registry.yarnpkg.com/@babel/helper-replace-supers/-/helper-replace-supers-7.10.1.tgz#ec6859d20c5d8087f6a2dc4e014db7228975f13d" + integrity sha512-SOwJzEfpuQwInzzQJGjGaiG578UYmyi2Xw668klPWV5n07B73S0a9btjLk/52Mlcxa+5AdIYqws1KyXRfMoB7A== + dependencies: + "@babel/helper-member-expression-to-functions" "^7.10.1" + "@babel/helper-optimise-call-expression" "^7.10.1" + "@babel/traverse" "^7.10.1" + "@babel/types" "^7.10.1" + +"@babel/helper-replace-supers@^7.8.3", "@babel/helper-replace-supers@^7.8.6": + version "7.8.6" + resolved "https://registry.yarnpkg.com/@babel/helper-replace-supers/-/helper-replace-supers-7.8.6.tgz#5ada744fd5ad73203bf1d67459a27dcba67effc8" + integrity sha512-PeMArdA4Sv/Wf4zXwBKPqVj7n9UF/xg6slNRtZW84FM7JpE1CbG8B612FyM4cxrf4fMAMGO0kR7voy1ForHHFA== + dependencies: + "@babel/helper-member-expression-to-functions" "^7.8.3" + "@babel/helper-optimise-call-expression" "^7.8.3" + "@babel/traverse" "^7.8.6" + "@babel/types" "^7.8.6" + +"@babel/helper-simple-access@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/helper-simple-access/-/helper-simple-access-7.8.3.tgz#7f8109928b4dab4654076986af575231deb639ae" + integrity sha512-VNGUDjx5cCWg4vvCTR8qQ7YJYZ+HBjxOgXEl7ounz+4Sn7+LMD3CFrCTEU6/qXKbA2nKg21CwhhBzO0RpRbdCw== + dependencies: + "@babel/template" "^7.8.3" + "@babel/types" "^7.8.3" + +"@babel/helper-split-export-declaration@^7.10.1": + version "7.10.1" + resolved "https://registry.yarnpkg.com/@babel/helper-split-export-declaration/-/helper-split-export-declaration-7.10.1.tgz#c6f4be1cbc15e3a868e4c64a17d5d31d754da35f" + integrity sha512-UQ1LVBPrYdbchNhLwj6fetj46BcFwfS4NllJo/1aJsT+1dLTEnXJL0qHqtY7gPzF8S2fXBJamf1biAXV3X077g== + dependencies: + "@babel/types" "^7.10.1" + +"@babel/helper-split-export-declaration@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/helper-split-export-declaration/-/helper-split-export-declaration-7.8.3.tgz#31a9f30070f91368a7182cf05f831781065fc7a9" + integrity sha512-3x3yOeyBhW851hroze7ElzdkeRXQYQbFIb7gLK1WQYsw2GWDay5gAJNw1sWJ0VFP6z5J1whqeXH/WCdCjZv6dA== + dependencies: + "@babel/types" "^7.8.3" + +"@babel/helper-validator-identifier@^7.10.1": + version "7.10.1" + resolved 
"https://registry.yarnpkg.com/@babel/helper-validator-identifier/-/helper-validator-identifier-7.10.1.tgz#5770b0c1a826c4f53f5ede5e153163e0318e94b5" + integrity sha512-5vW/JXLALhczRCWP0PnFDMCJAchlBvM7f4uk/jXritBnIa6E1KmqmtrS3yn1LAnxFBypQ3eneLuXjsnfQsgILw== + +"@babel/helper-validator-identifier@^7.9.0": + version "7.9.0" + resolved "https://registry.yarnpkg.com/@babel/helper-validator-identifier/-/helper-validator-identifier-7.9.0.tgz#ad53562a7fc29b3b9a91bbf7d10397fd146346ed" + integrity sha512-6G8bQKjOh+of4PV/ThDm/rRqlU7+IGoJuofpagU5GlEl29Vv0RGqqt86ZGRV8ZuSOY3o+8yXl5y782SMcG7SHw== + +"@babel/helper-wrap-function@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/helper-wrap-function/-/helper-wrap-function-7.8.3.tgz#9dbdb2bb55ef14aaa01fe8c99b629bd5352d8610" + integrity sha512-LACJrbUET9cQDzb6kG7EeD7+7doC3JNvUgTEQOx2qaO1fKlzE/Bf05qs9w1oXQMmXlPO65lC3Tq9S6gZpTErEQ== + dependencies: + "@babel/helper-function-name" "^7.8.3" + "@babel/template" "^7.8.3" + "@babel/traverse" "^7.8.3" + "@babel/types" "^7.8.3" + +"@babel/helpers@^7.9.0": + version "7.9.2" + resolved "https://registry.yarnpkg.com/@babel/helpers/-/helpers-7.9.2.tgz#b42a81a811f1e7313b88cba8adc66b3d9ae6c09f" + integrity sha512-JwLvzlXVPjO8eU9c/wF9/zOIN7X6h8DYf7mG4CiFRZRvZNKEF5dQ3H3V+ASkHoIB3mWhatgl5ONhyqHRI6MppA== + dependencies: + "@babel/template" "^7.8.3" + "@babel/traverse" "^7.9.0" + "@babel/types" "^7.9.0" + +"@babel/highlight@^7.10.1": + version "7.10.1" + resolved "https://registry.yarnpkg.com/@babel/highlight/-/highlight-7.10.1.tgz#841d098ba613ba1a427a2b383d79e35552c38ae0" + integrity sha512-8rMof+gVP8mxYZApLF/JgNDAkdKa+aJt3ZYxF8z6+j/hpeXL7iMsKCPHa2jNMHu/qqBwzQF4OHNoYi8dMA/rYg== + dependencies: + "@babel/helper-validator-identifier" "^7.10.1" + chalk "^2.0.0" + js-tokens "^4.0.0" + +"@babel/highlight@^7.8.3": + version "7.9.0" + resolved "https://registry.yarnpkg.com/@babel/highlight/-/highlight-7.9.0.tgz#4e9b45ccb82b79607271b2979ad82c7b68163079" + integrity sha512-lJZPilxX7Op3Nv/2cvFdnlepPXDxi29wxteT57Q965oc5R9v86ztx0jfxVrTcBk8C2kcPkkDa2Z4T3ZsPPVWsQ== + dependencies: + "@babel/helper-validator-identifier" "^7.9.0" + chalk "^2.0.0" + js-tokens "^4.0.0" + +"@babel/parser@^7.1.0", "@babel/parser@^7.4.3", "@babel/parser@^7.7.0", "@babel/parser@^7.8.6", "@babel/parser@^7.9.0": + version "7.9.4" + resolved "https://registry.yarnpkg.com/@babel/parser/-/parser-7.9.4.tgz#68a35e6b0319bbc014465be43828300113f2f2e8" + integrity sha512-bC49otXX6N0/VYhgOMh4gnP26E9xnDZK3TmbNpxYzzz9BQLBosQwfyOe9/cXUU3txYhTzLCbcqd5c8y/OmCjHA== + +"@babel/parser@^7.10.1": + version "7.10.2" + resolved "https://registry.yarnpkg.com/@babel/parser/-/parser-7.10.2.tgz#871807f10442b92ff97e4783b9b54f6a0ca812d0" + integrity sha512-PApSXlNMJyB4JiGVhCOlzKIif+TKFTvu0aQAhnTvfP/z3vVSN6ZypH5bfUNwFXXjRQtUEBNFd2PtmCmG2Py3qQ== + +"@babel/plugin-proposal-async-generator-functions@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/plugin-proposal-async-generator-functions/-/plugin-proposal-async-generator-functions-7.8.3.tgz#bad329c670b382589721b27540c7d288601c6e6f" + integrity sha512-NZ9zLv848JsV3hs8ryEh7Uaz/0KsmPLqv0+PdkDJL1cJy0K4kOCFa8zc1E3mp+RHPQcpdfb/6GovEsW4VDrOMw== + dependencies: + "@babel/helper-plugin-utils" "^7.8.3" + "@babel/helper-remap-async-to-generator" "^7.8.3" + "@babel/plugin-syntax-async-generators" "^7.8.0" + +"@babel/plugin-proposal-class-properties@7.8.3", "@babel/plugin-proposal-class-properties@^7.5.5": + version "7.8.3" + resolved 
"https://registry.yarnpkg.com/@babel/plugin-proposal-class-properties/-/plugin-proposal-class-properties-7.8.3.tgz#5e06654af5cd04b608915aada9b2a6788004464e" + integrity sha512-EqFhbo7IosdgPgZggHaNObkmO1kNUe3slaKu54d5OWvy+p9QIKOzK1GAEpAIsZtWVtPXUHSMcT4smvDrCfY4AA== + dependencies: + "@babel/helper-create-class-features-plugin" "^7.8.3" + "@babel/helper-plugin-utils" "^7.8.3" + +"@babel/plugin-proposal-decorators@7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/plugin-proposal-decorators/-/plugin-proposal-decorators-7.8.3.tgz#2156860ab65c5abf068c3f67042184041066543e" + integrity sha512-e3RvdvS4qPJVTe288DlXjwKflpfy1hr0j5dz5WpIYYeP7vQZg2WfAEIp8k5/Lwis/m5REXEteIz6rrcDtXXG7w== + dependencies: + "@babel/helper-create-class-features-plugin" "^7.8.3" + "@babel/helper-plugin-utils" "^7.8.3" + "@babel/plugin-syntax-decorators" "^7.8.3" + +"@babel/plugin-proposal-dynamic-import@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/plugin-proposal-dynamic-import/-/plugin-proposal-dynamic-import-7.8.3.tgz#38c4fe555744826e97e2ae930b0fb4cc07e66054" + integrity sha512-NyaBbyLFXFLT9FP+zk0kYlUlA8XtCUbehs67F0nnEg7KICgMc2mNkIeu9TYhKzyXMkrapZFwAhXLdnt4IYHy1w== + dependencies: + "@babel/helper-plugin-utils" "^7.8.3" + "@babel/plugin-syntax-dynamic-import" "^7.8.0" + +"@babel/plugin-proposal-json-strings@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/plugin-proposal-json-strings/-/plugin-proposal-json-strings-7.8.3.tgz#da5216b238a98b58a1e05d6852104b10f9a70d6b" + integrity sha512-KGhQNZ3TVCQG/MjRbAUwuH+14y9q0tpxs1nWWs3pbSleRdDro9SAMMDyye8HhY1gqZ7/NqIc8SKhya0wRDgP1Q== + dependencies: + "@babel/helper-plugin-utils" "^7.8.3" + "@babel/plugin-syntax-json-strings" "^7.8.0" + +"@babel/plugin-proposal-nullish-coalescing-operator@7.8.3", "@babel/plugin-proposal-nullish-coalescing-operator@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/plugin-proposal-nullish-coalescing-operator/-/plugin-proposal-nullish-coalescing-operator-7.8.3.tgz#e4572253fdeed65cddeecfdab3f928afeb2fd5d2" + integrity sha512-TS9MlfzXpXKt6YYomudb/KU7nQI6/xnapG6in1uZxoxDghuSMZsPb6D2fyUwNYSAp4l1iR7QtFOjkqcRYcUsfw== + dependencies: + "@babel/helper-plugin-utils" "^7.8.3" + "@babel/plugin-syntax-nullish-coalescing-operator" "^7.8.0" + +"@babel/plugin-proposal-numeric-separator@7.8.3", "@babel/plugin-proposal-numeric-separator@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/plugin-proposal-numeric-separator/-/plugin-proposal-numeric-separator-7.8.3.tgz#5d6769409699ec9b3b68684cd8116cedff93bad8" + integrity sha512-jWioO1s6R/R+wEHizfaScNsAx+xKgwTLNXSh7tTC4Usj3ItsPEhYkEpU4h+lpnBwq7NBVOJXfO6cRFYcX69JUQ== + dependencies: + "@babel/helper-plugin-utils" "^7.8.3" + "@babel/plugin-syntax-numeric-separator" "^7.8.3" + +"@babel/plugin-proposal-object-rest-spread@^7.9.0": + version "7.9.0" + resolved "https://registry.yarnpkg.com/@babel/plugin-proposal-object-rest-spread/-/plugin-proposal-object-rest-spread-7.9.0.tgz#a28993699fc13df165995362693962ba6b061d6f" + integrity sha512-UgqBv6bjq4fDb8uku9f+wcm1J7YxJ5nT7WO/jBr0cl0PLKb7t1O6RNR1kZbjgx2LQtsDI9hwoQVmn0yhXeQyow== + dependencies: + "@babel/helper-plugin-utils" "^7.8.3" + "@babel/plugin-syntax-object-rest-spread" "^7.8.0" + +"@babel/plugin-proposal-optional-catch-binding@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/plugin-proposal-optional-catch-binding/-/plugin-proposal-optional-catch-binding-7.8.3.tgz#9dee96ab1650eed88646ae9734ca167ac4a9c5c9" + integrity 
sha512-0gkX7J7E+AtAw9fcwlVQj8peP61qhdg/89D5swOkjYbkboA2CVckn3kiyum1DE0wskGb7KJJxBdyEBApDLLVdw== + dependencies: + "@babel/helper-plugin-utils" "^7.8.3" + "@babel/plugin-syntax-optional-catch-binding" "^7.8.0" + +"@babel/plugin-proposal-optional-chaining@7.9.0", "@babel/plugin-proposal-optional-chaining@^7.9.0": + version "7.9.0" + resolved "https://registry.yarnpkg.com/@babel/plugin-proposal-optional-chaining/-/plugin-proposal-optional-chaining-7.9.0.tgz#31db16b154c39d6b8a645292472b98394c292a58" + integrity sha512-NDn5tu3tcv4W30jNhmc2hyD5c56G6cXx4TesJubhxrJeCvuuMpttxr0OnNCqbZGhFjLrg+NIhxxC+BK5F6yS3w== + dependencies: + "@babel/helper-plugin-utils" "^7.8.3" + "@babel/plugin-syntax-optional-chaining" "^7.8.0" + +"@babel/plugin-proposal-unicode-property-regex@^7.4.4", "@babel/plugin-proposal-unicode-property-regex@^7.8.3": + version "7.8.8" + resolved "https://registry.yarnpkg.com/@babel/plugin-proposal-unicode-property-regex/-/plugin-proposal-unicode-property-regex-7.8.8.tgz#ee3a95e90cdc04fe8cd92ec3279fa017d68a0d1d" + integrity sha512-EVhjVsMpbhLw9ZfHWSx2iy13Q8Z/eg8e8ccVWt23sWQK5l1UdkoLJPN5w69UA4uITGBnEZD2JOe4QOHycYKv8A== + dependencies: + "@babel/helper-create-regexp-features-plugin" "^7.8.8" + "@babel/helper-plugin-utils" "^7.8.3" + +"@babel/plugin-syntax-async-generators@^7.8.0": + version "7.8.4" + resolved "https://registry.yarnpkg.com/@babel/plugin-syntax-async-generators/-/plugin-syntax-async-generators-7.8.4.tgz#a983fb1aeb2ec3f6ed042a210f640e90e786fe0d" + integrity sha512-tycmZxkGfZaxhMRbXlPXuVFpdWlXpir2W4AMhSJgRKzk/eDlIXOhb2LHWoLpDF7TEHylV5zNhykX6KAgHJmTNw== + dependencies: + "@babel/helper-plugin-utils" "^7.8.0" + +"@babel/plugin-syntax-decorators@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/plugin-syntax-decorators/-/plugin-syntax-decorators-7.8.3.tgz#8d2c15a9f1af624b0025f961682a9d53d3001bda" + integrity sha512-8Hg4dNNT9/LcA1zQlfwuKR8BUc/if7Q7NkTam9sGTcJphLwpf2g4S42uhspQrIrR+dpzE0dtTqBVFoHl8GtnnQ== + dependencies: + "@babel/helper-plugin-utils" "^7.8.3" + +"@babel/plugin-syntax-dynamic-import@^7.8.0": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/plugin-syntax-dynamic-import/-/plugin-syntax-dynamic-import-7.8.3.tgz#62bf98b2da3cd21d626154fc96ee5b3cb68eacb3" + integrity sha512-5gdGbFon+PszYzqs83S3E5mpi7/y/8M9eC90MRTZfduQOYW76ig6SOSPNe41IG5LoP3FGBn2N0RjVDSQiS94kQ== + dependencies: + "@babel/helper-plugin-utils" "^7.8.0" + +"@babel/plugin-syntax-flow@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/plugin-syntax-flow/-/plugin-syntax-flow-7.8.3.tgz#f2c883bd61a6316f2c89380ae5122f923ba4527f" + integrity sha512-innAx3bUbA0KSYj2E2MNFSn9hiCeowOFLxlsuhXzw8hMQnzkDomUr9QCD7E9VF60NmnG1sNTuuv6Qf4f8INYsg== + dependencies: + "@babel/helper-plugin-utils" "^7.8.3" + +"@babel/plugin-syntax-json-strings@^7.8.0": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/plugin-syntax-json-strings/-/plugin-syntax-json-strings-7.8.3.tgz#01ca21b668cd8218c9e640cb6dd88c5412b2c96a" + integrity sha512-lY6kdGpWHvjoe2vk4WrAapEuBR69EMxZl+RoGRhrFGNYVK8mOPAW8VfbT/ZgrFbXlDNiiaxQnAtgVCZ6jv30EA== + dependencies: + "@babel/helper-plugin-utils" "^7.8.0" + +"@babel/plugin-syntax-jsx@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/plugin-syntax-jsx/-/plugin-syntax-jsx-7.8.3.tgz#521b06c83c40480f1e58b4fd33b92eceb1d6ea94" + integrity sha512-WxdW9xyLgBdefoo0Ynn3MRSkhe5tFVxxKNVdnZSh318WrG2e2jH+E9wd/++JsqcLJZPfz87njQJ8j2Upjm0M0A== + dependencies: + "@babel/helper-plugin-utils" "^7.8.3" + 
+"@babel/plugin-syntax-nullish-coalescing-operator@^7.8.0": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/plugin-syntax-nullish-coalescing-operator/-/plugin-syntax-nullish-coalescing-operator-7.8.3.tgz#167ed70368886081f74b5c36c65a88c03b66d1a9" + integrity sha512-aSff4zPII1u2QD7y+F8oDsz19ew4IGEJg9SVW+bqwpwtfFleiQDMdzA/R+UlWDzfnHFCxxleFT0PMIrR36XLNQ== + dependencies: + "@babel/helper-plugin-utils" "^7.8.0" + +"@babel/plugin-syntax-numeric-separator@^7.8.0", "@babel/plugin-syntax-numeric-separator@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/plugin-syntax-numeric-separator/-/plugin-syntax-numeric-separator-7.8.3.tgz#0e3fb63e09bea1b11e96467271c8308007e7c41f" + integrity sha512-H7dCMAdN83PcCmqmkHB5dtp+Xa9a6LKSvA2hiFBC/5alSHxM5VgWZXFqDi0YFe8XNGT6iCa+z4V4zSt/PdZ7Dw== + dependencies: + "@babel/helper-plugin-utils" "^7.8.3" + +"@babel/plugin-syntax-object-rest-spread@^7.0.0", "@babel/plugin-syntax-object-rest-spread@^7.8.0": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/plugin-syntax-object-rest-spread/-/plugin-syntax-object-rest-spread-7.8.3.tgz#60e225edcbd98a640332a2e72dd3e66f1af55871" + integrity sha512-XoqMijGZb9y3y2XskN+P1wUGiVwWZ5JmoDRwx5+3GmEplNyVM2s2Dg8ILFQm8rWM48orGy5YpI5Bl8U1y7ydlA== + dependencies: + "@babel/helper-plugin-utils" "^7.8.0" + +"@babel/plugin-syntax-optional-catch-binding@^7.8.0": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/plugin-syntax-optional-catch-binding/-/plugin-syntax-optional-catch-binding-7.8.3.tgz#6111a265bcfb020eb9efd0fdfd7d26402b9ed6c1" + integrity sha512-6VPD0Pc1lpTqw0aKoeRTMiB+kWhAoT24PA+ksWSBrFtl5SIRVpZlwN3NNPQjehA2E/91FV3RjLWoVTglWcSV3Q== + dependencies: + "@babel/helper-plugin-utils" "^7.8.0" + +"@babel/plugin-syntax-optional-chaining@^7.8.0": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/plugin-syntax-optional-chaining/-/plugin-syntax-optional-chaining-7.8.3.tgz#4f69c2ab95167e0180cd5336613f8c5788f7d48a" + integrity sha512-KoK9ErH1MBlCPxV0VANkXW2/dw4vlbGDrFgz8bmUsBGYkFRcbRwMh6cIJubdPrkxRwuGdtCk0v/wPTKbQgBjkg== + dependencies: + "@babel/helper-plugin-utils" "^7.8.0" + +"@babel/plugin-syntax-top-level-await@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/plugin-syntax-top-level-await/-/plugin-syntax-top-level-await-7.8.3.tgz#3acdece695e6b13aaf57fc291d1a800950c71391" + integrity sha512-kwj1j9lL/6Wd0hROD3b/OZZ7MSrZLqqn9RAZ5+cYYsflQ9HZBIKCUkr3+uL1MEJ1NePiUbf98jjiMQSv0NMR9g== + dependencies: + "@babel/helper-plugin-utils" "^7.8.3" + +"@babel/plugin-syntax-typescript@^7.10.1": + version "7.10.1" + resolved "https://registry.yarnpkg.com/@babel/plugin-syntax-typescript/-/plugin-syntax-typescript-7.10.1.tgz#5e82bc27bb4202b93b949b029e699db536733810" + integrity sha512-X/d8glkrAtra7CaQGMiGs/OGa6XgUzqPcBXCIGFCpCqnfGlT0Wfbzo/B89xHhnInTaItPK8LALblVXcUOEh95Q== + dependencies: + "@babel/helper-plugin-utils" "^7.10.1" + +"@babel/plugin-syntax-typescript@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/plugin-syntax-typescript/-/plugin-syntax-typescript-7.8.3.tgz#c1f659dda97711a569cef75275f7e15dcaa6cabc" + integrity sha512-GO1MQ/SGGGoiEXY0e0bSpHimJvxqB7lktLLIq2pv8xG7WZ8IMEle74jIe1FhprHBWjwjZtXHkycDLZXIWM5Wfg== + dependencies: + "@babel/helper-plugin-utils" "^7.8.3" + +"@babel/plugin-transform-arrow-functions@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/plugin-transform-arrow-functions/-/plugin-transform-arrow-functions-7.8.3.tgz#82776c2ed0cd9e1a49956daeb896024c9473b8b6" + integrity 
sha512-0MRF+KC8EqH4dbuITCWwPSzsyO3HIWWlm30v8BbbpOrS1B++isGxPnnuq/IZvOX5J2D/p7DQalQm+/2PnlKGxg== + dependencies: + "@babel/helper-plugin-utils" "^7.8.3" + +"@babel/plugin-transform-async-to-generator@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/plugin-transform-async-to-generator/-/plugin-transform-async-to-generator-7.8.3.tgz#4308fad0d9409d71eafb9b1a6ee35f9d64b64086" + integrity sha512-imt9tFLD9ogt56Dd5CI/6XgpukMwd/fLGSrix2httihVe7LOGVPhyhMh1BU5kDM7iHD08i8uUtmV2sWaBFlHVQ== + dependencies: + "@babel/helper-module-imports" "^7.8.3" + "@babel/helper-plugin-utils" "^7.8.3" + "@babel/helper-remap-async-to-generator" "^7.8.3" + +"@babel/plugin-transform-block-scoped-functions@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/plugin-transform-block-scoped-functions/-/plugin-transform-block-scoped-functions-7.8.3.tgz#437eec5b799b5852072084b3ae5ef66e8349e8a3" + integrity sha512-vo4F2OewqjbB1+yaJ7k2EJFHlTP3jR634Z9Cj9itpqNjuLXvhlVxgnjsHsdRgASR8xYDrx6onw4vW5H6We0Jmg== + dependencies: + "@babel/helper-plugin-utils" "^7.8.3" + +"@babel/plugin-transform-block-scoping@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/plugin-transform-block-scoping/-/plugin-transform-block-scoping-7.8.3.tgz#97d35dab66857a437c166358b91d09050c868f3a" + integrity sha512-pGnYfm7RNRgYRi7bids5bHluENHqJhrV4bCZRwc5GamaWIIs07N4rZECcmJL6ZClwjDz1GbdMZFtPs27hTB06w== + dependencies: + "@babel/helper-plugin-utils" "^7.8.3" + lodash "^4.17.13" + +"@babel/plugin-transform-classes@^7.9.0": + version "7.9.2" + resolved "https://registry.yarnpkg.com/@babel/plugin-transform-classes/-/plugin-transform-classes-7.9.2.tgz#8603fc3cc449e31fdbdbc257f67717536a11af8d" + integrity sha512-TC2p3bPzsfvSsqBZo0kJnuelnoK9O3welkUpqSqBQuBF6R5MN2rysopri8kNvtlGIb2jmUO7i15IooAZJjZuMQ== + dependencies: + "@babel/helper-annotate-as-pure" "^7.8.3" + "@babel/helper-define-map" "^7.8.3" + "@babel/helper-function-name" "^7.8.3" + "@babel/helper-optimise-call-expression" "^7.8.3" + "@babel/helper-plugin-utils" "^7.8.3" + "@babel/helper-replace-supers" "^7.8.6" + "@babel/helper-split-export-declaration" "^7.8.3" + globals "^11.1.0" + +"@babel/plugin-transform-computed-properties@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/plugin-transform-computed-properties/-/plugin-transform-computed-properties-7.8.3.tgz#96d0d28b7f7ce4eb5b120bb2e0e943343c86f81b" + integrity sha512-O5hiIpSyOGdrQZRQ2ccwtTVkgUDBBiCuK//4RJ6UfePllUTCENOzKxfh6ulckXKc0DixTFLCfb2HVkNA7aDpzA== + dependencies: + "@babel/helper-plugin-utils" "^7.8.3" + +"@babel/plugin-transform-destructuring@^7.8.3": + version "7.8.8" + resolved "https://registry.yarnpkg.com/@babel/plugin-transform-destructuring/-/plugin-transform-destructuring-7.8.8.tgz#fadb2bc8e90ccaf5658de6f8d4d22ff6272a2f4b" + integrity sha512-eRJu4Vs2rmttFCdhPUM3bV0Yo/xPSdPw6ML9KHs/bjB4bLA5HXlbvYXPOD5yASodGod+krjYx21xm1QmL8dCJQ== + dependencies: + "@babel/helper-plugin-utils" "^7.8.3" + +"@babel/plugin-transform-dotall-regex@^7.4.4", "@babel/plugin-transform-dotall-regex@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/plugin-transform-dotall-regex/-/plugin-transform-dotall-regex-7.8.3.tgz#c3c6ec5ee6125c6993c5cbca20dc8621a9ea7a6e" + integrity sha512-kLs1j9Nn4MQoBYdRXH6AeaXMbEJFaFu/v1nQkvib6QzTj8MZI5OQzqmD83/2jEM1z0DLilra5aWO5YpyC0ALIw== + dependencies: + "@babel/helper-create-regexp-features-plugin" "^7.8.3" + "@babel/helper-plugin-utils" "^7.8.3" + +"@babel/plugin-transform-duplicate-keys@^7.8.3": + version "7.8.3" + 
resolved "https://registry.yarnpkg.com/@babel/plugin-transform-duplicate-keys/-/plugin-transform-duplicate-keys-7.8.3.tgz#8d12df309aa537f272899c565ea1768e286e21f1" + integrity sha512-s8dHiBUbcbSgipS4SMFuWGqCvyge5V2ZeAWzR6INTVC3Ltjig/Vw1G2Gztv0vU/hRG9X8IvKvYdoksnUfgXOEQ== + dependencies: + "@babel/helper-plugin-utils" "^7.8.3" + +"@babel/plugin-transform-exponentiation-operator@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/plugin-transform-exponentiation-operator/-/plugin-transform-exponentiation-operator-7.8.3.tgz#581a6d7f56970e06bf51560cd64f5e947b70d7b7" + integrity sha512-zwIpuIymb3ACcInbksHaNcR12S++0MDLKkiqXHl3AzpgdKlFNhog+z/K0+TGW+b0w5pgTq4H6IwV/WhxbGYSjQ== + dependencies: + "@babel/helper-builder-binary-assignment-operator-visitor" "^7.8.3" + "@babel/helper-plugin-utils" "^7.8.3" + +"@babel/plugin-transform-flow-strip-types@7.9.0": + version "7.9.0" + resolved "https://registry.yarnpkg.com/@babel/plugin-transform-flow-strip-types/-/plugin-transform-flow-strip-types-7.9.0.tgz#8a3538aa40434e000b8f44a3c5c9ac7229bd2392" + integrity sha512-7Qfg0lKQhEHs93FChxVLAvhBshOPQDtJUTVHr/ZwQNRccCm4O9D79r9tVSoV8iNwjP1YgfD+e/fgHcPkN1qEQg== + dependencies: + "@babel/helper-plugin-utils" "^7.8.3" + "@babel/plugin-syntax-flow" "^7.8.3" + +"@babel/plugin-transform-for-of@^7.9.0": + version "7.9.0" + resolved "https://registry.yarnpkg.com/@babel/plugin-transform-for-of/-/plugin-transform-for-of-7.9.0.tgz#0f260e27d3e29cd1bb3128da5e76c761aa6c108e" + integrity sha512-lTAnWOpMwOXpyDx06N+ywmF3jNbafZEqZ96CGYabxHrxNX8l5ny7dt4bK/rGwAh9utyP2b2Hv7PlZh1AAS54FQ== + dependencies: + "@babel/helper-plugin-utils" "^7.8.3" + +"@babel/plugin-transform-function-name@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/plugin-transform-function-name/-/plugin-transform-function-name-7.8.3.tgz#279373cb27322aaad67c2683e776dfc47196ed8b" + integrity sha512-rO/OnDS78Eifbjn5Py9v8y0aR+aSYhDhqAwVfsTl0ERuMZyr05L1aFSCJnbv2mmsLkit/4ReeQ9N2BgLnOcPCQ== + dependencies: + "@babel/helper-function-name" "^7.8.3" + "@babel/helper-plugin-utils" "^7.8.3" + +"@babel/plugin-transform-literals@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/plugin-transform-literals/-/plugin-transform-literals-7.8.3.tgz#aef239823d91994ec7b68e55193525d76dbd5dc1" + integrity sha512-3Tqf8JJ/qB7TeldGl+TT55+uQei9JfYaregDcEAyBZ7akutriFrt6C/wLYIer6OYhleVQvH/ntEhjE/xMmy10A== + dependencies: + "@babel/helper-plugin-utils" "^7.8.3" + +"@babel/plugin-transform-member-expression-literals@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/plugin-transform-member-expression-literals/-/plugin-transform-member-expression-literals-7.8.3.tgz#963fed4b620ac7cbf6029c755424029fa3a40410" + integrity sha512-3Wk2EXhnw+rP+IDkK6BdtPKsUE5IeZ6QOGrPYvw52NwBStw9V1ZVzxgK6fSKSxqUvH9eQPR3tm3cOq79HlsKYA== + dependencies: + "@babel/helper-plugin-utils" "^7.8.3" + +"@babel/plugin-transform-modules-amd@^7.9.0": + version "7.9.0" + resolved "https://registry.yarnpkg.com/@babel/plugin-transform-modules-amd/-/plugin-transform-modules-amd-7.9.0.tgz#19755ee721912cf5bb04c07d50280af3484efef4" + integrity sha512-vZgDDF003B14O8zJy0XXLnPH4sg+9X5hFBBGN1V+B2rgrB+J2xIypSN6Rk9imB2hSTHQi5OHLrFWsZab1GMk+Q== + dependencies: + "@babel/helper-module-transforms" "^7.9.0" + "@babel/helper-plugin-utils" "^7.8.3" + babel-plugin-dynamic-import-node "^2.3.0" + +"@babel/plugin-transform-modules-commonjs@^7.9.0": + version "7.9.0" + resolved 
"https://registry.yarnpkg.com/@babel/plugin-transform-modules-commonjs/-/plugin-transform-modules-commonjs-7.9.0.tgz#e3e72f4cbc9b4a260e30be0ea59bdf5a39748940" + integrity sha512-qzlCrLnKqio4SlgJ6FMMLBe4bySNis8DFn1VkGmOcxG9gqEyPIOzeQrA//u0HAKrWpJlpZbZMPB1n/OPa4+n8g== + dependencies: + "@babel/helper-module-transforms" "^7.9.0" + "@babel/helper-plugin-utils" "^7.8.3" + "@babel/helper-simple-access" "^7.8.3" + babel-plugin-dynamic-import-node "^2.3.0" + +"@babel/plugin-transform-modules-systemjs@^7.9.0": + version "7.9.0" + resolved "https://registry.yarnpkg.com/@babel/plugin-transform-modules-systemjs/-/plugin-transform-modules-systemjs-7.9.0.tgz#e9fd46a296fc91e009b64e07ddaa86d6f0edeb90" + integrity sha512-FsiAv/nao/ud2ZWy4wFacoLOm5uxl0ExSQ7ErvP7jpoihLR6Cq90ilOFyX9UXct3rbtKsAiZ9kFt5XGfPe/5SQ== + dependencies: + "@babel/helper-hoist-variables" "^7.8.3" + "@babel/helper-module-transforms" "^7.9.0" + "@babel/helper-plugin-utils" "^7.8.3" + babel-plugin-dynamic-import-node "^2.3.0" + +"@babel/plugin-transform-modules-umd@^7.9.0": + version "7.9.0" + resolved "https://registry.yarnpkg.com/@babel/plugin-transform-modules-umd/-/plugin-transform-modules-umd-7.9.0.tgz#e909acae276fec280f9b821a5f38e1f08b480697" + integrity sha512-uTWkXkIVtg/JGRSIABdBoMsoIeoHQHPTL0Y2E7xf5Oj7sLqwVsNXOkNk0VJc7vF0IMBsPeikHxFjGe+qmwPtTQ== + dependencies: + "@babel/helper-module-transforms" "^7.9.0" + "@babel/helper-plugin-utils" "^7.8.3" + +"@babel/plugin-transform-named-capturing-groups-regex@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/plugin-transform-named-capturing-groups-regex/-/plugin-transform-named-capturing-groups-regex-7.8.3.tgz#a2a72bffa202ac0e2d0506afd0939c5ecbc48c6c" + integrity sha512-f+tF/8UVPU86TrCb06JoPWIdDpTNSGGcAtaD9mLP0aYGA0OS0j7j7DHJR0GTFrUZPUU6loZhbsVZgTh0N+Qdnw== + dependencies: + "@babel/helper-create-regexp-features-plugin" "^7.8.3" + +"@babel/plugin-transform-new-target@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/plugin-transform-new-target/-/plugin-transform-new-target-7.8.3.tgz#60cc2ae66d85c95ab540eb34babb6434d4c70c43" + integrity sha512-QuSGysibQpyxexRyui2vca+Cmbljo8bcRckgzYV4kRIsHpVeyeC3JDO63pY+xFZ6bWOBn7pfKZTqV4o/ix9sFw== + dependencies: + "@babel/helper-plugin-utils" "^7.8.3" + +"@babel/plugin-transform-object-super@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/plugin-transform-object-super/-/plugin-transform-object-super-7.8.3.tgz#ebb6a1e7a86ffa96858bd6ac0102d65944261725" + integrity sha512-57FXk+gItG/GejofIyLIgBKTas4+pEU47IXKDBWFTxdPd7F80H8zybyAY7UoblVfBhBGs2EKM+bJUu2+iUYPDQ== + dependencies: + "@babel/helper-plugin-utils" "^7.8.3" + "@babel/helper-replace-supers" "^7.8.3" + +"@babel/plugin-transform-parameters@^7.8.7": + version "7.9.3" + resolved "https://registry.yarnpkg.com/@babel/plugin-transform-parameters/-/plugin-transform-parameters-7.9.3.tgz#3028d0cc20ddc733166c6e9c8534559cee09f54a" + integrity sha512-fzrQFQhp7mIhOzmOtPiKffvCYQSK10NR8t6BBz2yPbeUHb9OLW8RZGtgDRBn8z2hGcwvKDL3vC7ojPTLNxmqEg== + dependencies: + "@babel/helper-get-function-arity" "^7.8.3" + "@babel/helper-plugin-utils" "^7.8.3" + +"@babel/plugin-transform-property-literals@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/plugin-transform-property-literals/-/plugin-transform-property-literals-7.8.3.tgz#33194300d8539c1ed28c62ad5087ba3807b98263" + integrity sha512-uGiiXAZMqEoQhRWMK17VospMZh5sXWg+dlh2soffpkAl96KAm+WZuJfa6lcELotSRmooLqg0MWdH6UUq85nmmg== + dependencies: + "@babel/helper-plugin-utils" "^7.8.3" + 
+"@babel/plugin-transform-react-constant-elements@^7.0.0": + version "7.9.0" + resolved "https://registry.yarnpkg.com/@babel/plugin-transform-react-constant-elements/-/plugin-transform-react-constant-elements-7.9.0.tgz#a75abc936a3819edec42d3386d9f1c93f28d9d9e" + integrity sha512-wXMXsToAUOxJuBBEHajqKLFWcCkOSLshTI2ChCFFj1zDd7od4IOxiwLCOObNUvOpkxLpjIuaIdBMmNt6ocCPAw== + dependencies: + "@babel/helper-plugin-utils" "^7.8.3" + +"@babel/plugin-transform-react-display-name@7.8.3", "@babel/plugin-transform-react-display-name@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/plugin-transform-react-display-name/-/plugin-transform-react-display-name-7.8.3.tgz#70ded987c91609f78353dd76d2fb2a0bb991e8e5" + integrity sha512-3Jy/PCw8Fe6uBKtEgz3M82ljt+lTg+xJaM4og+eyu83qLT87ZUSckn0wy7r31jflURWLO83TW6Ylf7lyXj3m5A== + dependencies: + "@babel/helper-plugin-utils" "^7.8.3" + +"@babel/plugin-transform-react-jsx-development@^7.9.0": + version "7.9.0" + resolved "https://registry.yarnpkg.com/@babel/plugin-transform-react-jsx-development/-/plugin-transform-react-jsx-development-7.9.0.tgz#3c2a130727caf00c2a293f0aed24520825dbf754" + integrity sha512-tK8hWKrQncVvrhvtOiPpKrQjfNX3DtkNLSX4ObuGcpS9p0QrGetKmlySIGR07y48Zft8WVgPakqd/bk46JrMSw== + dependencies: + "@babel/helper-builder-react-jsx-experimental" "^7.9.0" + "@babel/helper-plugin-utils" "^7.8.3" + "@babel/plugin-syntax-jsx" "^7.8.3" + +"@babel/plugin-transform-react-jsx-self@^7.9.0": + version "7.9.0" + resolved "https://registry.yarnpkg.com/@babel/plugin-transform-react-jsx-self/-/plugin-transform-react-jsx-self-7.9.0.tgz#f4f26a325820205239bb915bad8e06fcadabb49b" + integrity sha512-K2ObbWPKT7KUTAoyjCsFilOkEgMvFG+y0FqOl6Lezd0/13kMkkjHskVsZvblRPj1PHA44PrToaZANrryppzTvQ== + dependencies: + "@babel/helper-plugin-utils" "^7.8.3" + "@babel/plugin-syntax-jsx" "^7.8.3" + +"@babel/plugin-transform-react-jsx-source@^7.9.0": + version "7.9.0" + resolved "https://registry.yarnpkg.com/@babel/plugin-transform-react-jsx-source/-/plugin-transform-react-jsx-source-7.9.0.tgz#89ef93025240dd5d17d3122294a093e5e0183de0" + integrity sha512-K6m3LlSnTSfRkM6FcRk8saNEeaeyG5k7AVkBU2bZK3+1zdkSED3qNdsWrUgQBeTVD2Tp3VMmerxVO2yM5iITmw== + dependencies: + "@babel/helper-plugin-utils" "^7.8.3" + "@babel/plugin-syntax-jsx" "^7.8.3" + +"@babel/plugin-transform-react-jsx@^7.9.1", "@babel/plugin-transform-react-jsx@^7.9.4": + version "7.9.4" + resolved "https://registry.yarnpkg.com/@babel/plugin-transform-react-jsx/-/plugin-transform-react-jsx-7.9.4.tgz#86f576c8540bd06d0e95e0b61ea76d55f6cbd03f" + integrity sha512-Mjqf3pZBNLt854CK0C/kRuXAnE6H/bo7xYojP+WGtX8glDGSibcwnsWwhwoSuRg0+EBnxPC1ouVnuetUIlPSAw== + dependencies: + "@babel/helper-builder-react-jsx" "^7.9.0" + "@babel/helper-builder-react-jsx-experimental" "^7.9.0" + "@babel/helper-plugin-utils" "^7.8.3" + "@babel/plugin-syntax-jsx" "^7.8.3" + +"@babel/plugin-transform-regenerator@^7.8.7": + version "7.8.7" + resolved "https://registry.yarnpkg.com/@babel/plugin-transform-regenerator/-/plugin-transform-regenerator-7.8.7.tgz#5e46a0dca2bee1ad8285eb0527e6abc9c37672f8" + integrity sha512-TIg+gAl4Z0a3WmD3mbYSk+J9ZUH6n/Yc57rtKRnlA/7rcCvpekHXe0CMZHP1gYp7/KLe9GHTuIba0vXmls6drA== + dependencies: + regenerator-transform "^0.14.2" + +"@babel/plugin-transform-reserved-words@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/plugin-transform-reserved-words/-/plugin-transform-reserved-words-7.8.3.tgz#9a0635ac4e665d29b162837dd3cc50745dfdf1f5" + integrity 
sha512-mwMxcycN3omKFDjDQUl+8zyMsBfjRFr0Zn/64I41pmjv4NJuqcYlEtezwYtw9TFd9WR1vN5kiM+O0gMZzO6L0A== + dependencies: + "@babel/helper-plugin-utils" "^7.8.3" + +"@babel/plugin-transform-runtime@7.9.0": + version "7.9.0" + resolved "https://registry.yarnpkg.com/@babel/plugin-transform-runtime/-/plugin-transform-runtime-7.9.0.tgz#45468c0ae74cc13204e1d3b1f4ce6ee83258af0b" + integrity sha512-pUu9VSf3kI1OqbWINQ7MaugnitRss1z533436waNXp+0N3ur3zfut37sXiQMxkuCF4VUjwZucen/quskCh7NHw== + dependencies: + "@babel/helper-module-imports" "^7.8.3" + "@babel/helper-plugin-utils" "^7.8.3" + resolve "^1.8.1" + semver "^5.5.1" + +"@babel/plugin-transform-shorthand-properties@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/plugin-transform-shorthand-properties/-/plugin-transform-shorthand-properties-7.8.3.tgz#28545216e023a832d4d3a1185ed492bcfeac08c8" + integrity sha512-I9DI6Odg0JJwxCHzbzW08ggMdCezoWcuQRz3ptdudgwaHxTjxw5HgdFJmZIkIMlRymL6YiZcped4TTCB0JcC8w== + dependencies: + "@babel/helper-plugin-utils" "^7.8.3" + +"@babel/plugin-transform-spread@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/plugin-transform-spread/-/plugin-transform-spread-7.8.3.tgz#9c8ffe8170fdfb88b114ecb920b82fb6e95fe5e8" + integrity sha512-CkuTU9mbmAoFOI1tklFWYYbzX5qCIZVXPVy0jpXgGwkplCndQAa58s2jr66fTeQnA64bDox0HL4U56CFYoyC7g== + dependencies: + "@babel/helper-plugin-utils" "^7.8.3" + +"@babel/plugin-transform-sticky-regex@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/plugin-transform-sticky-regex/-/plugin-transform-sticky-regex-7.8.3.tgz#be7a1290f81dae767475452199e1f76d6175b100" + integrity sha512-9Spq0vGCD5Bb4Z/ZXXSK5wbbLFMG085qd2vhL1JYu1WcQ5bXqZBAYRzU1d+p79GcHs2szYv5pVQCX13QgldaWw== + dependencies: + "@babel/helper-plugin-utils" "^7.8.3" + "@babel/helper-regex" "^7.8.3" + +"@babel/plugin-transform-template-literals@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/plugin-transform-template-literals/-/plugin-transform-template-literals-7.8.3.tgz#7bfa4732b455ea6a43130adc0ba767ec0e402a80" + integrity sha512-820QBtykIQOLFT8NZOcTRJ1UNuztIELe4p9DCgvj4NK+PwluSJ49we7s9FB1HIGNIYT7wFUJ0ar2QpCDj0escQ== + dependencies: + "@babel/helper-annotate-as-pure" "^7.8.3" + "@babel/helper-plugin-utils" "^7.8.3" + +"@babel/plugin-transform-typeof-symbol@^7.8.4": + version "7.8.4" + resolved "https://registry.yarnpkg.com/@babel/plugin-transform-typeof-symbol/-/plugin-transform-typeof-symbol-7.8.4.tgz#ede4062315ce0aaf8a657a920858f1a2f35fc412" + integrity sha512-2QKyfjGdvuNfHsb7qnBBlKclbD4CfshH2KvDabiijLMGXPHJXGxtDzwIF7bQP+T0ysw8fYTtxPafgfs/c1Lrqg== + dependencies: + "@babel/helper-plugin-utils" "^7.8.3" + +"@babel/plugin-transform-typescript@^7.10.1": + version "7.10.1" + resolved "https://registry.yarnpkg.com/@babel/plugin-transform-typescript/-/plugin-transform-typescript-7.10.1.tgz#2c54daea231f602468686d9faa76f182a94507a6" + integrity sha512-v+QWKlmCnsaimLeqq9vyCsVRMViZG1k2SZTlcZvB+TqyH570Zsij8nvVUZzOASCRiQFUxkLrn9Wg/kH0zgy5OQ== + dependencies: + "@babel/helper-create-class-features-plugin" "^7.10.1" + "@babel/helper-plugin-utils" "^7.10.1" + "@babel/plugin-syntax-typescript" "^7.10.1" + +"@babel/plugin-transform-typescript@^7.9.0": + version "7.9.4" + resolved "https://registry.yarnpkg.com/@babel/plugin-transform-typescript/-/plugin-transform-typescript-7.9.4.tgz#4bb4dde4f10bbf2d787fce9707fb09b483e33359" + integrity sha512-yeWeUkKx2auDbSxRe8MusAG+n4m9BFY/v+lPjmQDgOFX5qnySkUY5oXzkp6FwPdsYqnKay6lorXYdC0n3bZO7w== + dependencies: + 
"@babel/helper-create-class-features-plugin" "^7.8.3" + "@babel/helper-plugin-utils" "^7.8.3" + "@babel/plugin-syntax-typescript" "^7.8.3" + +"@babel/plugin-transform-unicode-regex@^7.8.3": + version "7.8.3" + resolved "https://registry.yarnpkg.com/@babel/plugin-transform-unicode-regex/-/plugin-transform-unicode-regex-7.8.3.tgz#0cef36e3ba73e5c57273effb182f46b91a1ecaad" + integrity sha512-+ufgJjYdmWfSQ+6NS9VGUR2ns8cjJjYbrbi11mZBTaWm+Fui/ncTLFF28Ei1okavY+xkojGr1eJxNsWYeA5aZw== + dependencies: + "@babel/helper-create-regexp-features-plugin" "^7.8.3" + "@babel/helper-plugin-utils" "^7.8.3" + +"@babel/preset-env@7.9.0", "@babel/preset-env@^7.4.5": + version "7.9.0" + resolved "https://registry.yarnpkg.com/@babel/preset-env/-/preset-env-7.9.0.tgz#a5fc42480e950ae8f5d9f8f2bbc03f52722df3a8" + integrity sha512-712DeRXT6dyKAM/FMbQTV/FvRCms2hPCx+3weRjZ8iQVQWZejWWk1wwG6ViWMyqb/ouBbGOl5b6aCk0+j1NmsQ== + dependencies: + "@babel/compat-data" "^7.9.0" + "@babel/helper-compilation-targets" "^7.8.7" + "@babel/helper-module-imports" "^7.8.3" + "@babel/helper-plugin-utils" "^7.8.3" + "@babel/plugin-proposal-async-generator-functions" "^7.8.3" + "@babel/plugin-proposal-dynamic-import" "^7.8.3" + "@babel/plugin-proposal-json-strings" "^7.8.3" + "@babel/plugin-proposal-nullish-coalescing-operator" "^7.8.3" + "@babel/plugin-proposal-numeric-separator" "^7.8.3" + "@babel/plugin-proposal-object-rest-spread" "^7.9.0" + "@babel/plugin-proposal-optional-catch-binding" "^7.8.3" + "@babel/plugin-proposal-optional-chaining" "^7.9.0" + "@babel/plugin-proposal-unicode-property-regex" "^7.8.3" + "@babel/plugin-syntax-async-generators" "^7.8.0" + "@babel/plugin-syntax-dynamic-import" "^7.8.0" + "@babel/plugin-syntax-json-strings" "^7.8.0" + "@babel/plugin-syntax-nullish-coalescing-operator" "^7.8.0" + "@babel/plugin-syntax-numeric-separator" "^7.8.0" + "@babel/plugin-syntax-object-rest-spread" "^7.8.0" + "@babel/plugin-syntax-optional-catch-binding" "^7.8.0" + "@babel/plugin-syntax-optional-chaining" "^7.8.0" + "@babel/plugin-syntax-top-level-await" "^7.8.3" + "@babel/plugin-transform-arrow-functions" "^7.8.3" + "@babel/plugin-transform-async-to-generator" "^7.8.3" + "@babel/plugin-transform-block-scoped-functions" "^7.8.3" + "@babel/plugin-transform-block-scoping" "^7.8.3" + "@babel/plugin-transform-classes" "^7.9.0" + "@babel/plugin-transform-computed-properties" "^7.8.3" + "@babel/plugin-transform-destructuring" "^7.8.3" + "@babel/plugin-transform-dotall-regex" "^7.8.3" + "@babel/plugin-transform-duplicate-keys" "^7.8.3" + "@babel/plugin-transform-exponentiation-operator" "^7.8.3" + "@babel/plugin-transform-for-of" "^7.9.0" + "@babel/plugin-transform-function-name" "^7.8.3" + "@babel/plugin-transform-literals" "^7.8.3" + "@babel/plugin-transform-member-expression-literals" "^7.8.3" + "@babel/plugin-transform-modules-amd" "^7.9.0" + "@babel/plugin-transform-modules-commonjs" "^7.9.0" + "@babel/plugin-transform-modules-systemjs" "^7.9.0" + "@babel/plugin-transform-modules-umd" "^7.9.0" + "@babel/plugin-transform-named-capturing-groups-regex" "^7.8.3" + "@babel/plugin-transform-new-target" "^7.8.3" + "@babel/plugin-transform-object-super" "^7.8.3" + "@babel/plugin-transform-parameters" "^7.8.7" + "@babel/plugin-transform-property-literals" "^7.8.3" + "@babel/plugin-transform-regenerator" "^7.8.7" + "@babel/plugin-transform-reserved-words" "^7.8.3" + "@babel/plugin-transform-shorthand-properties" "^7.8.3" + "@babel/plugin-transform-spread" "^7.8.3" + "@babel/plugin-transform-sticky-regex" "^7.8.3" + 
"@babel/plugin-transform-template-literals" "^7.8.3" + "@babel/plugin-transform-typeof-symbol" "^7.8.4" + "@babel/plugin-transform-unicode-regex" "^7.8.3" + "@babel/preset-modules" "^0.1.3" + "@babel/types" "^7.9.0" + browserslist "^4.9.1" + core-js-compat "^3.6.2" + invariant "^2.2.2" + levenary "^1.1.1" + semver "^5.5.0" + +"@babel/preset-modules@^0.1.3": + version "0.1.3" + resolved "https://registry.yarnpkg.com/@babel/preset-modules/-/preset-modules-0.1.3.tgz#13242b53b5ef8c883c3cf7dddd55b36ce80fbc72" + integrity sha512-Ra3JXOHBq2xd56xSF7lMKXdjBn3T772Y1Wet3yWnkDly9zHvJki029tAFzvAAK5cf4YV3yoxuP61crYRol6SVg== + dependencies: + "@babel/helper-plugin-utils" "^7.0.0" + "@babel/plugin-proposal-unicode-property-regex" "^7.4.4" + "@babel/plugin-transform-dotall-regex" "^7.4.4" + "@babel/types" "^7.4.4" + esutils "^2.0.2" + +"@babel/preset-react@7.9.1": + version "7.9.1" + resolved "https://registry.yarnpkg.com/@babel/preset-react/-/preset-react-7.9.1.tgz#b346403c36d58c3bb544148272a0cefd9c28677a" + integrity sha512-aJBYF23MPj0RNdp/4bHnAP0NVqqZRr9kl0NAOP4nJCex6OYVio59+dnQzsAWFuogdLyeaKA1hmfUIVZkY5J+TQ== + dependencies: + "@babel/helper-plugin-utils" "^7.8.3" + "@babel/plugin-transform-react-display-name" "^7.8.3" + "@babel/plugin-transform-react-jsx" "^7.9.1" + "@babel/plugin-transform-react-jsx-development" "^7.9.0" + "@babel/plugin-transform-react-jsx-self" "^7.9.0" + "@babel/plugin-transform-react-jsx-source" "^7.9.0" + +"@babel/preset-react@^7.0.0": + version "7.9.4" + resolved "https://registry.yarnpkg.com/@babel/preset-react/-/preset-react-7.9.4.tgz#c6c97693ac65b6b9c0b4f25b948a8f665463014d" + integrity sha512-AxylVB3FXeOTQXNXyiuAQJSvss62FEotbX2Pzx3K/7c+MKJMdSg6Ose6QYllkdCFA8EInCJVw7M/o5QbLuA4ZQ== + dependencies: + "@babel/helper-plugin-utils" "^7.8.3" + "@babel/plugin-transform-react-display-name" "^7.8.3" + "@babel/plugin-transform-react-jsx" "^7.9.4" + "@babel/plugin-transform-react-jsx-development" "^7.9.0" + "@babel/plugin-transform-react-jsx-self" "^7.9.0" + "@babel/plugin-transform-react-jsx-source" "^7.9.0" + +"@babel/preset-typescript@7.9.0": + version "7.9.0" + resolved "https://registry.yarnpkg.com/@babel/preset-typescript/-/preset-typescript-7.9.0.tgz#87705a72b1f0d59df21c179f7c3d2ef4b16ce192" + integrity sha512-S4cueFnGrIbvYJgwsVFKdvOmpiL0XGw9MFW9D0vgRys5g36PBhZRL8NX8Gr2akz8XRtzq6HuDXPD/1nniagNUg== + dependencies: + "@babel/helper-plugin-utils" "^7.8.3" + "@babel/plugin-transform-typescript" "^7.9.0" + +"@babel/preset-typescript@^7.10.1": + version "7.10.1" + resolved "https://registry.yarnpkg.com/@babel/preset-typescript/-/preset-typescript-7.10.1.tgz#a8d8d9035f55b7d99a2461a0bdc506582914d07e" + integrity sha512-m6GV3y1ShiqxnyQj10600ZVOFrSSAa8HQ3qIUk2r+gcGtHTIRw0dJnFLt1WNXpKjtVw7yw1DAPU/6ma2ZvgJuA== + dependencies: + "@babel/helper-plugin-utils" "^7.10.1" + "@babel/plugin-transform-typescript" "^7.10.1" + +"@babel/runtime-corejs3@^7.8.3": + version "7.9.2" + resolved "https://registry.yarnpkg.com/@babel/runtime-corejs3/-/runtime-corejs3-7.9.2.tgz#26fe4aa77e9f1ecef9b776559bbb8e84d34284b7" + integrity sha512-HHxmgxbIzOfFlZ+tdeRKtaxWOMUoCG5Mu3wKeUmOxjYrwb3AAHgnmtCUbPPK11/raIWLIBK250t8E2BPO0p7jA== + dependencies: + core-js-pure "^3.0.0" + regenerator-runtime "^0.13.4" + +"@babel/runtime@7.9.0": + version "7.9.0" + resolved "https://registry.yarnpkg.com/@babel/runtime/-/runtime-7.9.0.tgz#337eda67401f5b066a6f205a3113d4ac18ba495b" + integrity sha512-cTIudHnzuWLS56ik4DnRnqqNf8MkdUzV4iFFI1h7Jo9xvrpQROYaAnaSd2mHLQAzzZAPfATynX5ord6YlNYNMA== + dependencies: + regenerator-runtime "^0.13.4" + 
+"@babel/runtime@^7.0.0", "@babel/runtime@^7.3.4", "@babel/runtime@^7.4.5", "@babel/runtime@^7.7.2", "@babel/runtime@^7.8.4", "@babel/runtime@^7.8.7": + version "7.9.2" + resolved "https://registry.yarnpkg.com/@babel/runtime/-/runtime-7.9.2.tgz#d90df0583a3a252f09aaa619665367bae518db06" + integrity sha512-NE2DtOdufG7R5vnfQUTehdTfNycfUANEtCa9PssN9O/xmTzP4E08UI797ixaei6hBEVL9BI/PsdJS5x7mWoB9Q== + dependencies: + regenerator-runtime "^0.13.4" + +"@babel/template@^7.10.1": + version "7.10.1" + resolved "https://registry.yarnpkg.com/@babel/template/-/template-7.10.1.tgz#e167154a94cb5f14b28dc58f5356d2162f539811" + integrity sha512-OQDg6SqvFSsc9A0ej6SKINWrpJiNonRIniYondK2ViKhB06i3c0s+76XUft71iqBEe9S1OKsHwPAjfHnuvnCig== + dependencies: + "@babel/code-frame" "^7.10.1" + "@babel/parser" "^7.10.1" + "@babel/types" "^7.10.1" + +"@babel/template@^7.4.0", "@babel/template@^7.8.3", "@babel/template@^7.8.6": + version "7.8.6" + resolved "https://registry.yarnpkg.com/@babel/template/-/template-7.8.6.tgz#86b22af15f828dfb086474f964dcc3e39c43ce2b" + integrity sha512-zbMsPMy/v0PWFZEhQJ66bqjhH+z0JgMoBWuikXybgG3Gkd/3t5oQ1Rw2WQhnSrsOmsKXnZOx15tkC4qON/+JPg== + dependencies: + "@babel/code-frame" "^7.8.3" + "@babel/parser" "^7.8.6" + "@babel/types" "^7.8.6" + +"@babel/traverse@^7.1.0", "@babel/traverse@^7.4.3", "@babel/traverse@^7.7.0", "@babel/traverse@^7.8.3", "@babel/traverse@^7.8.6", "@babel/traverse@^7.9.0": + version "7.9.0" + resolved "https://registry.yarnpkg.com/@babel/traverse/-/traverse-7.9.0.tgz#d3882c2830e513f4fe4cec9fe76ea1cc78747892" + integrity sha512-jAZQj0+kn4WTHO5dUZkZKhbFrqZE7K5LAQ5JysMnmvGij+wOdr+8lWqPeW0BcF4wFwrEXXtdGO7wcV6YPJcf3w== + dependencies: + "@babel/code-frame" "^7.8.3" + "@babel/generator" "^7.9.0" + "@babel/helper-function-name" "^7.8.3" + "@babel/helper-split-export-declaration" "^7.8.3" + "@babel/parser" "^7.9.0" + "@babel/types" "^7.9.0" + debug "^4.1.0" + globals "^11.1.0" + lodash "^4.17.13" + +"@babel/traverse@^7.10.1": + version "7.10.1" + resolved "https://registry.yarnpkg.com/@babel/traverse/-/traverse-7.10.1.tgz#bbcef3031e4152a6c0b50147f4958df54ca0dd27" + integrity sha512-C/cTuXeKt85K+p08jN6vMDz8vSV0vZcI0wmQ36o6mjbuo++kPMdpOYw23W2XH04dbRt9/nMEfA4W3eR21CD+TQ== + dependencies: + "@babel/code-frame" "^7.10.1" + "@babel/generator" "^7.10.1" + "@babel/helper-function-name" "^7.10.1" + "@babel/helper-split-export-declaration" "^7.10.1" + "@babel/parser" "^7.10.1" + "@babel/types" "^7.10.1" + debug "^4.1.0" + globals "^11.1.0" + lodash "^4.17.13" + +"@babel/types@^7.0.0", "@babel/types@^7.3.0", "@babel/types@^7.4.0", "@babel/types@^7.4.4", "@babel/types@^7.7.0", "@babel/types@^7.8.3", "@babel/types@^7.8.6", "@babel/types@^7.8.7", "@babel/types@^7.9.0": + version "7.9.0" + resolved "https://registry.yarnpkg.com/@babel/types/-/types-7.9.0.tgz#00b064c3df83ad32b2dbf5ff07312b15c7f1efb5" + integrity sha512-BS9JKfXkzzJl8RluW4JGknzpiUV7ZrvTayM6yfqLTVBEnFtyowVIOu6rqxRd5cVO6yGoWf4T8u8dgK9oB+GCng== + dependencies: + "@babel/helper-validator-identifier" "^7.9.0" + lodash "^4.17.13" + to-fast-properties "^2.0.0" + +"@babel/types@^7.10.1", "@babel/types@^7.10.2": + version "7.10.2" + resolved "https://registry.yarnpkg.com/@babel/types/-/types-7.10.2.tgz#30283be31cad0dbf6fb00bd40641ca0ea675172d" + integrity sha512-AD3AwWBSz0AWF0AkCN9VPiWrvldXq+/e3cHa4J89vo4ymjz1XwrBFFVZmkJTsQIPNk+ZVomPSXUJqq8yyjZsng== + dependencies: + "@babel/helper-validator-identifier" "^7.10.1" + lodash "^4.17.13" + to-fast-properties "^2.0.0" + +"@cnakazawa/watch@^1.0.3": + version "1.0.4" + resolved 
"https://registry.yarnpkg.com/@cnakazawa/watch/-/watch-1.0.4.tgz#f864ae85004d0fcab6f50be9141c4da368d1656a" + integrity sha512-v9kIhKwjeZThiWrLmj0y17CWoyddASLj9O2yvbZkbvw/N3rWOYy9zkV66ursAoVr0mV15bL8g0c4QZUE6cdDoQ== + dependencies: + exec-sh "^0.3.2" + minimist "^1.2.0" + +"@csstools/convert-colors@^1.4.0": + version "1.4.0" + resolved "https://registry.yarnpkg.com/@csstools/convert-colors/-/convert-colors-1.4.0.tgz#ad495dc41b12e75d588c6db8b9834f08fa131eb7" + integrity sha512-5a6wqoJV/xEdbRNKVo6I4hO3VjyDq//8q2f9I6PBAvMesJHFauXDorcNCsr9RzvsZnaWi5NYCcfyqP1QeFHFbw== + +"@csstools/normalize.css@^10.1.0": + version "10.1.0" + resolved "https://registry.yarnpkg.com/@csstools/normalize.css/-/normalize.css-10.1.0.tgz#f0950bba18819512d42f7197e56c518aa491cf18" + integrity sha512-ij4wRiunFfaJxjB0BdrYHIH8FxBJpOwNPhhAcunlmPdXudL1WQV1qoP9un6JsEBAgQH+7UXyyjh0g7jTxXK6tg== + +"@hapi/address@2.x.x": + version "2.1.4" + resolved "https://registry.yarnpkg.com/@hapi/address/-/address-2.1.4.tgz#5d67ed43f3fd41a69d4b9ff7b56e7c0d1d0a81e5" + integrity sha512-QD1PhQk+s31P1ixsX0H0Suoupp3VMXzIVMSwobR3F3MSUO2YCV0B7xqLcUw/Bh8yuvd3LhpyqLQWTNcRmp6IdQ== + +"@hapi/bourne@1.x.x": + version "1.3.2" + resolved "https://registry.yarnpkg.com/@hapi/bourne/-/bourne-1.3.2.tgz#0a7095adea067243ce3283e1b56b8a8f453b242a" + integrity sha512-1dVNHT76Uu5N3eJNTYcvxee+jzX4Z9lfciqRRHCU27ihbUcYi+iSc2iml5Ke1LXe1SyJCLA0+14Jh4tXJgOppA== + +"@hapi/hoek@8.x.x", "@hapi/hoek@^8.3.0": + version "8.5.1" + resolved "https://registry.yarnpkg.com/@hapi/hoek/-/hoek-8.5.1.tgz#fde96064ca446dec8c55a8c2f130957b070c6e06" + integrity sha512-yN7kbciD87WzLGc5539Tn0sApjyiGHAJgKvG9W8C7O+6c7qmoQMfVs0W4bX17eqz6C78QJqqFrtgdK5EWf6Qow== + +"@hapi/joi@^15.0.0": + version "15.1.1" + resolved "https://registry.yarnpkg.com/@hapi/joi/-/joi-15.1.1.tgz#c675b8a71296f02833f8d6d243b34c57b8ce19d7" + integrity sha512-entf8ZMOK8sc+8YfeOlM8pCfg3b5+WZIKBfUaaJT8UsjAAPjartzxIYm3TIbjvA4u+u++KbcXD38k682nVHDAQ== + dependencies: + "@hapi/address" "2.x.x" + "@hapi/bourne" "1.x.x" + "@hapi/hoek" "8.x.x" + "@hapi/topo" "3.x.x" + +"@hapi/topo@3.x.x": + version "3.1.6" + resolved "https://registry.yarnpkg.com/@hapi/topo/-/topo-3.1.6.tgz#68d935fa3eae7fdd5ab0d7f953f3205d8b2bfc29" + integrity sha512-tAag0jEcjwH+P2quUfipd7liWCNX2F8NvYjQp2wtInsZxnMlypdw0FtAOLxtvvkO+GSRRbmNi8m/5y42PQJYCQ== + dependencies: + "@hapi/hoek" "^8.3.0" + +"@jest/console@^24.7.1", "@jest/console@^24.9.0": + version "24.9.0" + resolved "https://registry.yarnpkg.com/@jest/console/-/console-24.9.0.tgz#79b1bc06fb74a8cfb01cbdedf945584b1b9707f0" + integrity sha512-Zuj6b8TnKXi3q4ymac8EQfc3ea/uhLeCGThFqXeC8H9/raaH8ARPUTdId+XyGd03Z4In0/VjD2OYFcBF09fNLQ== + dependencies: + "@jest/source-map" "^24.9.0" + chalk "^2.0.1" + slash "^2.0.0" + +"@jest/core@^24.9.0": + version "24.9.0" + resolved "https://registry.yarnpkg.com/@jest/core/-/core-24.9.0.tgz#2ceccd0b93181f9c4850e74f2a9ad43d351369c4" + integrity sha512-Fogg3s4wlAr1VX7q+rhV9RVnUv5tD7VuWfYy1+whMiWUrvl7U3QJSJyWcDio9Lq2prqYsZaeTv2Rz24pWGkJ2A== + dependencies: + "@jest/console" "^24.7.1" + "@jest/reporters" "^24.9.0" + "@jest/test-result" "^24.9.0" + "@jest/transform" "^24.9.0" + "@jest/types" "^24.9.0" + ansi-escapes "^3.0.0" + chalk "^2.0.1" + exit "^0.1.2" + graceful-fs "^4.1.15" + jest-changed-files "^24.9.0" + jest-config "^24.9.0" + jest-haste-map "^24.9.0" + jest-message-util "^24.9.0" + jest-regex-util "^24.3.0" + jest-resolve "^24.9.0" + jest-resolve-dependencies "^24.9.0" + jest-runner "^24.9.0" + jest-runtime "^24.9.0" + jest-snapshot "^24.9.0" + jest-util "^24.9.0" + 
jest-validate "^24.9.0" + jest-watcher "^24.9.0" + micromatch "^3.1.10" + p-each-series "^1.0.0" + realpath-native "^1.1.0" + rimraf "^2.5.4" + slash "^2.0.0" + strip-ansi "^5.0.0" + +"@jest/environment@^24.3.0", "@jest/environment@^24.9.0": + version "24.9.0" + resolved "https://registry.yarnpkg.com/@jest/environment/-/environment-24.9.0.tgz#21e3afa2d65c0586cbd6cbefe208bafade44ab18" + integrity sha512-5A1QluTPhvdIPFYnO3sZC3smkNeXPVELz7ikPbhUj0bQjB07EoE9qtLrem14ZUYWdVayYbsjVwIiL4WBIMV4aQ== + dependencies: + "@jest/fake-timers" "^24.9.0" + "@jest/transform" "^24.9.0" + "@jest/types" "^24.9.0" + jest-mock "^24.9.0" + +"@jest/fake-timers@^24.3.0", "@jest/fake-timers@^24.9.0": + version "24.9.0" + resolved "https://registry.yarnpkg.com/@jest/fake-timers/-/fake-timers-24.9.0.tgz#ba3e6bf0eecd09a636049896434d306636540c93" + integrity sha512-eWQcNa2YSwzXWIMC5KufBh3oWRIijrQFROsIqt6v/NS9Io/gknw1jsAC9c+ih/RQX4A3O7SeWAhQeN0goKhT9A== + dependencies: + "@jest/types" "^24.9.0" + jest-message-util "^24.9.0" + jest-mock "^24.9.0" + +"@jest/reporters@^24.9.0": + version "24.9.0" + resolved "https://registry.yarnpkg.com/@jest/reporters/-/reporters-24.9.0.tgz#86660eff8e2b9661d042a8e98a028b8d631a5b43" + integrity sha512-mu4X0yjaHrffOsWmVLzitKmmmWSQ3GGuefgNscUSWNiUNcEOSEQk9k3pERKEQVBb0Cnn88+UESIsZEMH3o88Gw== + dependencies: + "@jest/environment" "^24.9.0" + "@jest/test-result" "^24.9.0" + "@jest/transform" "^24.9.0" + "@jest/types" "^24.9.0" + chalk "^2.0.1" + exit "^0.1.2" + glob "^7.1.2" + istanbul-lib-coverage "^2.0.2" + istanbul-lib-instrument "^3.0.1" + istanbul-lib-report "^2.0.4" + istanbul-lib-source-maps "^3.0.1" + istanbul-reports "^2.2.6" + jest-haste-map "^24.9.0" + jest-resolve "^24.9.0" + jest-runtime "^24.9.0" + jest-util "^24.9.0" + jest-worker "^24.6.0" + node-notifier "^5.4.2" + slash "^2.0.0" + source-map "^0.6.0" + string-length "^2.0.0" + +"@jest/source-map@^24.3.0", "@jest/source-map@^24.9.0": + version "24.9.0" + resolved "https://registry.yarnpkg.com/@jest/source-map/-/source-map-24.9.0.tgz#0e263a94430be4b41da683ccc1e6bffe2a191714" + integrity sha512-/Xw7xGlsZb4MJzNDgB7PW5crou5JqWiBQaz6xyPd3ArOg2nfn/PunV8+olXbbEZzNl591o5rWKE9BRDaFAuIBg== + dependencies: + callsites "^3.0.0" + graceful-fs "^4.1.15" + source-map "^0.6.0" + +"@jest/test-result@^24.9.0": + version "24.9.0" + resolved "https://registry.yarnpkg.com/@jest/test-result/-/test-result-24.9.0.tgz#11796e8aa9dbf88ea025757b3152595ad06ba0ca" + integrity sha512-XEFrHbBonBJ8dGp2JmF8kP/nQI/ImPpygKHwQ/SY+es59Z3L5PI4Qb9TQQMAEeYsThG1xF0k6tmG0tIKATNiiA== + dependencies: + "@jest/console" "^24.9.0" + "@jest/types" "^24.9.0" + "@types/istanbul-lib-coverage" "^2.0.0" + +"@jest/test-sequencer@^24.9.0": + version "24.9.0" + resolved "https://registry.yarnpkg.com/@jest/test-sequencer/-/test-sequencer-24.9.0.tgz#f8f334f35b625a4f2f355f2fe7e6036dad2e6b31" + integrity sha512-6qqsU4o0kW1dvA95qfNog8v8gkRN9ph6Lz7r96IvZpHdNipP2cBcb07J1Z45mz/VIS01OHJ3pY8T5fUY38tg4A== + dependencies: + "@jest/test-result" "^24.9.0" + jest-haste-map "^24.9.0" + jest-runner "^24.9.0" + jest-runtime "^24.9.0" + +"@jest/transform@^24.9.0": + version "24.9.0" + resolved "https://registry.yarnpkg.com/@jest/transform/-/transform-24.9.0.tgz#4ae2768b296553fadab09e9ec119543c90b16c56" + integrity sha512-TcQUmyNRxV94S0QpMOnZl0++6RMiqpbH/ZMccFB/amku6Uwvyb1cjYX7xkp5nGNkbX4QPH/FcB6q1HBTHynLmQ== + dependencies: + "@babel/core" "^7.1.0" + "@jest/types" "^24.9.0" + babel-plugin-istanbul "^5.1.0" + chalk "^2.0.1" + convert-source-map "^1.4.0" + fast-json-stable-stringify "^2.0.0" + graceful-fs 
"^4.1.15" + jest-haste-map "^24.9.0" + jest-regex-util "^24.9.0" + jest-util "^24.9.0" + micromatch "^3.1.10" + pirates "^4.0.1" + realpath-native "^1.1.0" + slash "^2.0.0" + source-map "^0.6.1" + write-file-atomic "2.4.1" + +"@jest/types@^24.3.0", "@jest/types@^24.9.0": + version "24.9.0" + resolved "https://registry.yarnpkg.com/@jest/types/-/types-24.9.0.tgz#63cb26cb7500d069e5a389441a7c6ab5e909fc59" + integrity sha512-XKK7ze1apu5JWQ5eZjHITP66AX+QsLlbaJRBGYr8pNzwcAE2JVkwnf0yqjHTsDRcjR0mujy/NmZMXw5kl+kGBw== + dependencies: + "@types/istanbul-lib-coverage" "^2.0.0" + "@types/istanbul-reports" "^1.1.1" + "@types/yargs" "^13.0.0" + +"@jest/types@^25.5.0": + version "25.5.0" + resolved "https://registry.yarnpkg.com/@jest/types/-/types-25.5.0.tgz#4d6a4793f7b9599fc3680877b856a97dbccf2a9d" + integrity sha512-OXD0RgQ86Tu3MazKo8bnrkDRaDXXMGUqd+kTtLtK1Zb7CRzQcaSRPPPV37SvYTdevXEBVxe0HXylEjs8ibkmCw== + dependencies: + "@types/istanbul-lib-coverage" "^2.0.0" + "@types/istanbul-reports" "^1.1.1" + "@types/yargs" "^15.0.0" + chalk "^3.0.0" + +"@mrmlnc/readdir-enhanced@^2.2.1": + version "2.2.1" + resolved "https://registry.yarnpkg.com/@mrmlnc/readdir-enhanced/-/readdir-enhanced-2.2.1.tgz#524af240d1a360527b730475ecfa1344aa540dde" + integrity sha512-bPHp6Ji8b41szTOcaP63VlnbbO5Ny6dwAATtY6JTjh5N2OLrb5Qk/Th5cRkRQhkWCt+EJsYrNB0MiL+Gpn6e3g== + dependencies: + call-me-maybe "^1.0.1" + glob-to-regexp "^0.3.0" + +"@nodelib/fs.stat@^1.1.2": + version "1.1.3" + resolved "https://registry.yarnpkg.com/@nodelib/fs.stat/-/fs.stat-1.1.3.tgz#2b5a3ab3f918cca48a8c754c08168e3f03eba61b" + integrity sha512-shAmDyaQC4H92APFoIaVDHCx5bStIocgvbwQyxPRrbUY20V1EYTbSDchWbuwlMG3V17cprZhA6+78JfB+3DTPw== + +"@svgr/babel-plugin-add-jsx-attribute@^4.2.0": + version "4.2.0" + resolved "https://registry.yarnpkg.com/@svgr/babel-plugin-add-jsx-attribute/-/babel-plugin-add-jsx-attribute-4.2.0.tgz#dadcb6218503532d6884b210e7f3c502caaa44b1" + integrity sha512-j7KnilGyZzYr/jhcrSYS3FGWMZVaqyCG0vzMCwzvei0coIkczuYMcniK07nI0aHJINciujjH11T72ICW5eL5Ig== + +"@svgr/babel-plugin-remove-jsx-attribute@^4.2.0": + version "4.2.0" + resolved "https://registry.yarnpkg.com/@svgr/babel-plugin-remove-jsx-attribute/-/babel-plugin-remove-jsx-attribute-4.2.0.tgz#297550b9a8c0c7337bea12bdfc8a80bb66f85abc" + integrity sha512-3XHLtJ+HbRCH4n28S7y/yZoEQnRpl0tvTZQsHqvaeNXPra+6vE5tbRliH3ox1yZYPCxrlqaJT/Mg+75GpDKlvQ== + +"@svgr/babel-plugin-remove-jsx-empty-expression@^4.2.0": + version "4.2.0" + resolved "https://registry.yarnpkg.com/@svgr/babel-plugin-remove-jsx-empty-expression/-/babel-plugin-remove-jsx-empty-expression-4.2.0.tgz#c196302f3e68eab6a05e98af9ca8570bc13131c7" + integrity sha512-yTr2iLdf6oEuUE9MsRdvt0NmdpMBAkgK8Bjhl6epb+eQWk6abBaX3d65UZ3E3FWaOwePyUgNyNCMVG61gGCQ7w== + +"@svgr/babel-plugin-replace-jsx-attribute-value@^4.2.0": + version "4.2.0" + resolved "https://registry.yarnpkg.com/@svgr/babel-plugin-replace-jsx-attribute-value/-/babel-plugin-replace-jsx-attribute-value-4.2.0.tgz#310ec0775de808a6a2e4fd4268c245fd734c1165" + integrity sha512-U9m870Kqm0ko8beHawRXLGLvSi/ZMrl89gJ5BNcT452fAjtF2p4uRzXkdzvGJJJYBgx7BmqlDjBN/eCp5AAX2w== + +"@svgr/babel-plugin-svg-dynamic-title@^4.3.3": + version "4.3.3" + resolved "https://registry.yarnpkg.com/@svgr/babel-plugin-svg-dynamic-title/-/babel-plugin-svg-dynamic-title-4.3.3.tgz#2cdedd747e5b1b29ed4c241e46256aac8110dd93" + integrity sha512-w3Be6xUNdwgParsvxkkeZb545VhXEwjGMwExMVBIdPQJeyMQHqm9Msnb2a1teHBqUYL66qtwfhNkbj1iarCG7w== + +"@svgr/babel-plugin-svg-em-dimensions@^4.2.0": + version "4.2.0" + resolved 
"https://registry.yarnpkg.com/@svgr/babel-plugin-svg-em-dimensions/-/babel-plugin-svg-em-dimensions-4.2.0.tgz#9a94791c9a288108d20a9d2cc64cac820f141391" + integrity sha512-C0Uy+BHolCHGOZ8Dnr1zXy/KgpBOkEUYY9kI/HseHVPeMbluaX3CijJr7D4C5uR8zrc1T64nnq/k63ydQuGt4w== + +"@svgr/babel-plugin-transform-react-native-svg@^4.2.0": + version "4.2.0" + resolved "https://registry.yarnpkg.com/@svgr/babel-plugin-transform-react-native-svg/-/babel-plugin-transform-react-native-svg-4.2.0.tgz#151487322843359a1ca86b21a3815fd21a88b717" + integrity sha512-7YvynOpZDpCOUoIVlaaOUU87J4Z6RdD6spYN4eUb5tfPoKGSF9OG2NuhgYnq4jSkAxcpMaXWPf1cePkzmqTPNw== + +"@svgr/babel-plugin-transform-svg-component@^4.2.0": + version "4.2.0" + resolved "https://registry.yarnpkg.com/@svgr/babel-plugin-transform-svg-component/-/babel-plugin-transform-svg-component-4.2.0.tgz#5f1e2f886b2c85c67e76da42f0f6be1b1767b697" + integrity sha512-hYfYuZhQPCBVotABsXKSCfel2slf/yvJY8heTVX1PCTaq/IgASq1IyxPPKJ0chWREEKewIU/JMSsIGBtK1KKxw== + +"@svgr/babel-preset@^4.3.3": + version "4.3.3" + resolved "https://registry.yarnpkg.com/@svgr/babel-preset/-/babel-preset-4.3.3.tgz#a75d8c2f202ac0e5774e6bfc165d028b39a1316c" + integrity sha512-6PG80tdz4eAlYUN3g5GZiUjg2FMcp+Wn6rtnz5WJG9ITGEF1pmFdzq02597Hn0OmnQuCVaBYQE1OVFAnwOl+0A== + dependencies: + "@svgr/babel-plugin-add-jsx-attribute" "^4.2.0" + "@svgr/babel-plugin-remove-jsx-attribute" "^4.2.0" + "@svgr/babel-plugin-remove-jsx-empty-expression" "^4.2.0" + "@svgr/babel-plugin-replace-jsx-attribute-value" "^4.2.0" + "@svgr/babel-plugin-svg-dynamic-title" "^4.3.3" + "@svgr/babel-plugin-svg-em-dimensions" "^4.2.0" + "@svgr/babel-plugin-transform-react-native-svg" "^4.2.0" + "@svgr/babel-plugin-transform-svg-component" "^4.2.0" + +"@svgr/core@^4.3.3": + version "4.3.3" + resolved "https://registry.yarnpkg.com/@svgr/core/-/core-4.3.3.tgz#b37b89d5b757dc66e8c74156d00c368338d24293" + integrity sha512-qNuGF1QON1626UCaZamWt5yedpgOytvLj5BQZe2j1k1B8DUG4OyugZyfEwBeXozCUwhLEpsrgPrE+eCu4fY17w== + dependencies: + "@svgr/plugin-jsx" "^4.3.3" + camelcase "^5.3.1" + cosmiconfig "^5.2.1" + +"@svgr/hast-util-to-babel-ast@^4.3.2": + version "4.3.2" + resolved "https://registry.yarnpkg.com/@svgr/hast-util-to-babel-ast/-/hast-util-to-babel-ast-4.3.2.tgz#1d5a082f7b929ef8f1f578950238f630e14532b8" + integrity sha512-JioXclZGhFIDL3ddn4Kiq8qEqYM2PyDKV0aYno8+IXTLuYt6TOgHUbUAAFvqtb0Xn37NwP0BTHglejFoYr8RZg== + dependencies: + "@babel/types" "^7.4.4" + +"@svgr/plugin-jsx@^4.3.3": + version "4.3.3" + resolved "https://registry.yarnpkg.com/@svgr/plugin-jsx/-/plugin-jsx-4.3.3.tgz#e2ba913dbdfbe85252a34db101abc7ebd50992fa" + integrity sha512-cLOCSpNWQnDB1/v+SUENHH7a0XY09bfuMKdq9+gYvtuwzC2rU4I0wKGFEp1i24holdQdwodCtDQdFtJiTCWc+w== + dependencies: + "@babel/core" "^7.4.5" + "@svgr/babel-preset" "^4.3.3" + "@svgr/hast-util-to-babel-ast" "^4.3.2" + svg-parser "^2.0.0" + +"@svgr/plugin-svgo@^4.3.1": + version "4.3.1" + resolved "https://registry.yarnpkg.com/@svgr/plugin-svgo/-/plugin-svgo-4.3.1.tgz#daac0a3d872e3f55935c6588dd370336865e9e32" + integrity sha512-PrMtEDUWjX3Ea65JsVCwTIXuSqa3CG9px+DluF1/eo9mlDrgrtFE7NE/DjdhjJgSM9wenlVBzkzneSIUgfUI/w== + dependencies: + cosmiconfig "^5.2.1" + merge-deep "^3.0.2" + svgo "^1.2.2" + +"@svgr/webpack@4.3.3": + version "4.3.3" + resolved "https://registry.yarnpkg.com/@svgr/webpack/-/webpack-4.3.3.tgz#13cc2423bf3dff2d494f16b17eb7eacb86895017" + integrity sha512-bjnWolZ6KVsHhgyCoYRFmbd26p8XVbulCzSG53BDQqAr+JOAderYK7CuYrB3bDjHJuF6LJ7Wrr42+goLRV9qIg== + dependencies: + "@babel/core" "^7.4.5" + 
"@babel/plugin-transform-react-constant-elements" "^7.0.0" + "@babel/preset-env" "^7.4.5" + "@babel/preset-react" "^7.0.0" + "@svgr/core" "^4.3.3" + "@svgr/plugin-jsx" "^4.3.3" + "@svgr/plugin-svgo" "^4.3.1" + loader-utils "^1.2.3" + +"@types/babel__core@^7.1.0": + version "7.1.7" + resolved "https://registry.yarnpkg.com/@types/babel__core/-/babel__core-7.1.7.tgz#1dacad8840364a57c98d0dd4855c6dd3752c6b89" + integrity sha512-RL62NqSFPCDK2FM1pSDH0scHpJvsXtZNiYlMB73DgPBaG1E38ZYVL+ei5EkWRbr+KC4YNiAUNBnRj+bgwpgjMw== + dependencies: + "@babel/parser" "^7.1.0" + "@babel/types" "^7.0.0" + "@types/babel__generator" "*" + "@types/babel__template" "*" + "@types/babel__traverse" "*" + +"@types/babel__generator@*": + version "7.6.1" + resolved "https://registry.yarnpkg.com/@types/babel__generator/-/babel__generator-7.6.1.tgz#4901767b397e8711aeb99df8d396d7ba7b7f0e04" + integrity sha512-bBKm+2VPJcMRVwNhxKu8W+5/zT7pwNEqeokFOmbvVSqGzFneNxYcEBro9Ac7/N9tlsaPYnZLK8J1LWKkMsLAew== + dependencies: + "@babel/types" "^7.0.0" + +"@types/babel__template@*": + version "7.0.2" + resolved "https://registry.yarnpkg.com/@types/babel__template/-/babel__template-7.0.2.tgz#4ff63d6b52eddac1de7b975a5223ed32ecea9307" + integrity sha512-/K6zCpeW7Imzgab2bLkLEbz0+1JlFSrUMdw7KoIIu+IUdu51GWaBZpd3y1VXGVXzynvGa4DaIaxNZHiON3GXUg== + dependencies: + "@babel/parser" "^7.1.0" + "@babel/types" "^7.0.0" + +"@types/babel__traverse@*", "@types/babel__traverse@^7.0.6": + version "7.0.9" + resolved "https://registry.yarnpkg.com/@types/babel__traverse/-/babel__traverse-7.0.9.tgz#be82fab304b141c3eee81a4ce3b034d0eba1590a" + integrity sha512-jEFQ8L1tuvPjOI8lnpaf73oCJe+aoxL6ygqSy6c8LcW98zaC+4mzWuQIRCEvKeCOu+lbqdXcg4Uqmm1S8AP1tw== + dependencies: + "@babel/types" "^7.3.0" + +"@types/chart.js@^2.9.29": + version "2.9.29" + resolved "https://registry.yarnpkg.com/@types/chart.js/-/chart.js-2.9.29.tgz#73bf7f02387402943f29946012492f10bde7ed43" + integrity sha512-WOZMitUU3gHDM0oQsCsVivX+oDsIki93szcTmmUPBm39cCvAELBjokjSDVOoA3xiIEbb+jp17z/3S2tIqruwOQ== + dependencies: + moment "^2.10.2" + +"@types/color-name@^1.1.1": + version "1.1.1" + resolved "https://registry.yarnpkg.com/@types/color-name/-/color-name-1.1.1.tgz#1c1261bbeaa10a8055bbc5d8ab84b7b2afc846a0" + integrity sha512-rr+OQyAjxze7GgWrSaJwydHStIhHq2lvY3BOC2Mj7KnzI7XK0Uw1TOOdI9lDoajEbSWLiYgoo4f1R51erQfhPQ== + +"@types/eslint-visitor-keys@^1.0.0": + version "1.0.0" + resolved "https://registry.yarnpkg.com/@types/eslint-visitor-keys/-/eslint-visitor-keys-1.0.0.tgz#1ee30d79544ca84d68d4b3cdb0af4f205663dd2d" + integrity sha512-OCutwjDZ4aFS6PB1UZ988C4YgwlBHJd6wCeQqaLdmadZ/7e+w79+hbMUFC1QXDNCmdyoRfAFdm0RypzwR+Qpag== + +"@types/events@*": + version "3.0.0" + resolved "https://registry.yarnpkg.com/@types/events/-/events-3.0.0.tgz#2862f3f58a9a7f7c3e78d79f130dd4d71c25c2a7" + integrity sha512-EaObqwIvayI5a8dCzhFrjKzVwKLxjoG9T6Ppd5CEo07LRKfQ8Yokw54r5+Wq7FaBQ+yXRvQAYPrHwya1/UFt9g== + +"@types/glob@^7.1.1": + version "7.1.1" + resolved "https://registry.yarnpkg.com/@types/glob/-/glob-7.1.1.tgz#aa59a1c6e3fbc421e07ccd31a944c30eba521575" + integrity sha512-1Bh06cbWJUHMC97acuD6UMG29nMt0Aqz1vF3guLfG+kHHJhy3AyohZFFxYk2f7Q1SQIrNwvncxAE0N/9s70F2w== + dependencies: + "@types/events" "*" + "@types/minimatch" "*" + "@types/node" "*" + +"@types/istanbul-lib-coverage@*", "@types/istanbul-lib-coverage@^2.0.0": + version "2.0.1" + resolved "https://registry.yarnpkg.com/@types/istanbul-lib-coverage/-/istanbul-lib-coverage-2.0.1.tgz#42995b446db9a48a11a07ec083499a860e9138ff" + integrity 
sha512-hRJD2ahnnpLgsj6KWMYSrmXkM3rm2Dl1qkx6IOFD5FnuNPXJIG5L0dhgKXCYTRMGzU4n0wImQ/xfmRc4POUFlg== + +"@types/istanbul-lib-report@*": + version "3.0.0" + resolved "https://registry.yarnpkg.com/@types/istanbul-lib-report/-/istanbul-lib-report-3.0.0.tgz#c14c24f18ea8190c118ee7562b7ff99a36552686" + integrity sha512-plGgXAPfVKFoYfa9NpYDAkseG+g6Jr294RqeqcqDixSbU34MZVJRi/P+7Y8GDpzkEwLaGZZOpKIEmeVZNtKsrg== + dependencies: + "@types/istanbul-lib-coverage" "*" + +"@types/istanbul-reports@^1.1.1": + version "1.1.1" + resolved "https://registry.yarnpkg.com/@types/istanbul-reports/-/istanbul-reports-1.1.1.tgz#7a8cbf6a406f36c8add871625b278eaf0b0d255a" + integrity sha512-UpYjBi8xefVChsCoBpKShdxTllC9pwISirfoZsUa2AAdQg/Jd2KQGtSbw+ya7GPo7x/wAPlH6JBhKhAsXUEZNA== + dependencies: + "@types/istanbul-lib-coverage" "*" + "@types/istanbul-lib-report" "*" + +"@types/jest@^26.0.0": + version "26.0.0" + resolved "https://registry.yarnpkg.com/@types/jest/-/jest-26.0.0.tgz#a6d7573dffa9c68cbbdf38f2e0de26f159e11134" + integrity sha512-/yeMsH9HQ1RLORlXAwoLXe8S98xxvhNtUz3yrgrwbaxYjT+6SFPZZRksmRKRA6L5vsUtSHeN71viDOTTyYAD+g== + dependencies: + jest-diff "^25.2.1" + pretty-format "^25.2.1" + +"@types/json-schema@^7.0.3": + version "7.0.4" + resolved "https://registry.yarnpkg.com/@types/json-schema/-/json-schema-7.0.4.tgz#38fd73ddfd9b55abb1e1b2ed578cb55bd7b7d339" + integrity sha512-8+KAKzEvSUdeo+kmqnKrqgeE+LcA0tjYWFY7RPProVYwnqDjukzO+3b6dLD56rYX5TdWejnEOLJYOIeh4CXKuA== + +"@types/minimatch@*": + version "3.0.3" + resolved "https://registry.yarnpkg.com/@types/minimatch/-/minimatch-3.0.3.tgz#3dca0e3f33b200fc7d1139c0cd96c1268cadfd9d" + integrity sha512-tHq6qdbT9U1IRSGf14CL0pUlULksvY9OZ+5eEgl1N7t+OA3tGvNpxJCzuKQlsNgCVwbAs670L1vcVQi8j9HjnA== + +"@types/node@*": + version "13.9.8" + resolved "https://registry.yarnpkg.com/@types/node/-/node-13.9.8.tgz#09976420fc80a7a00bf40680c63815ed8c7616f4" + integrity sha512-1WgO8hsyHynlx7nhP1kr0OFzsgKz5XDQL+Lfc3b1Q3qIln/n8cKD4m09NJ0+P1Rq7Zgnc7N0+SsMnoD1rEb0kA== + +"@types/node@^14.0.13": + version "14.0.13" + resolved "https://registry.yarnpkg.com/@types/node/-/node-14.0.13.tgz#ee1128e881b874c371374c1f72201893616417c9" + integrity sha512-rouEWBImiRaSJsVA+ITTFM6ZxibuAlTuNOCyxVbwreu6k6+ujs7DfnU9o+PShFhET78pMBl3eH+AGSI5eOTkPA== + +"@types/parse-json@^4.0.0": + version "4.0.0" + resolved "https://registry.yarnpkg.com/@types/parse-json/-/parse-json-4.0.0.tgz#2f8bb441434d163b35fb8ffdccd7138927ffb8c0" + integrity sha512-//oorEZjL6sbPcKUaCdIGlIUeH26mgzimjBB77G6XRgnDl/L5wOnpyBGRe/Mmf5CVW3PwEBE1NjiMZ/ssFh4wA== + +"@types/prop-types@*": + version "15.7.3" + resolved "https://registry.yarnpkg.com/@types/prop-types/-/prop-types-15.7.3.tgz#2ab0d5da2e5815f94b0b9d4b95d1e5f243ab2ca7" + integrity sha512-KfRL3PuHmqQLOG+2tGpRO26Ctg+Cq1E01D2DMriKEATHgWLfeNDmq9e29Q9WIky0dQ3NPkd1mzYH8Lm936Z9qw== + +"@types/q@^1.5.1": + version "1.5.2" + resolved "https://registry.yarnpkg.com/@types/q/-/q-1.5.2.tgz#690a1475b84f2a884fd07cd797c00f5f31356ea8" + integrity sha512-ce5d3q03Ex0sy4R14722Rmt6MT07Ua+k4FwDfdcToYJcMKNtRVQvJ6JCAPdAmAnbRb6CsX6aYb9m96NGod9uTw== + +"@types/react-dom@^16.9.8": + version "16.9.8" + resolved "https://registry.yarnpkg.com/@types/react-dom/-/react-dom-16.9.8.tgz#fe4c1e11dfc67155733dfa6aa65108b4971cb423" + integrity sha512-ykkPQ+5nFknnlU6lDd947WbQ6TE3NNzbQAkInC2EKY1qeYdTKp7onFusmYZb+ityzx2YviqT6BXSu+LyWWJwcA== + dependencies: + "@types/react" "*" + +"@types/react-tag-autocomplete@^5.12.0": + version "5.12.0" + resolved 
"https://registry.yarnpkg.com/@types/react-tag-autocomplete/-/react-tag-autocomplete-5.12.0.tgz#bcebefd07abd20ee7a1d9594cf4f2f0297625868" + integrity sha512-QvPBUrXnkU5e3EaHiXQcI0oRsV1JwvL2vvgXtY5x7SkW19O3FRGFdcPg4+SA545+pcIsijf1O2wQ/sCCTeEe9w== + dependencies: + "@types/react" "*" + +"@types/react@*", "@types/react@^16.9.38": + version "16.9.38" + resolved "https://registry.yarnpkg.com/@types/react/-/react-16.9.38.tgz#868405dace93a4095d3e054f4c4a1de7a1ac0680" + integrity sha512-pHAeZbjjNRa/hxyNuLrvbxhhnKyKNiLC6I5fRF2Zr/t/S6zS41MiyzH4+c+1I9vVfvuRt1VS2Lodjr4ZWnxrdA== + dependencies: + "@types/prop-types" "*" + csstype "^2.2.0" + +"@types/stack-utils@^1.0.1": + version "1.0.1" + resolved "https://registry.yarnpkg.com/@types/stack-utils/-/stack-utils-1.0.1.tgz#0a851d3bd96498fa25c33ab7278ed3bd65f06c3e" + integrity sha512-l42BggppR6zLmpfU6fq9HEa2oGPEI8yrSPL3GITjfRInppYFahObbIQOQK3UGxEnyQpltZLaPe75046NOZQikw== + +"@types/yargs-parser@*": + version "15.0.0" + resolved "https://registry.yarnpkg.com/@types/yargs-parser/-/yargs-parser-15.0.0.tgz#cb3f9f741869e20cce330ffbeb9271590483882d" + integrity sha512-FA/BWv8t8ZWJ+gEOnLLd8ygxH/2UFbAvgEonyfN6yWGLKc7zVjbpl2Y4CTjid9h2RfgPP6SEt6uHwEOply00yw== + +"@types/yargs@^13.0.0": + version "13.0.8" + resolved "https://registry.yarnpkg.com/@types/yargs/-/yargs-13.0.8.tgz#a38c22def2f1c2068f8971acb3ea734eb3c64a99" + integrity sha512-XAvHLwG7UQ+8M4caKIH0ZozIOYay5fQkAgyIXegXT9jPtdIGdhga+sUEdAr1CiG46aB+c64xQEYyEzlwWVTNzA== + dependencies: + "@types/yargs-parser" "*" + +"@types/yargs@^15.0.0": + version "15.0.5" + resolved "https://registry.yarnpkg.com/@types/yargs/-/yargs-15.0.5.tgz#947e9a6561483bdee9adffc983e91a6902af8b79" + integrity sha512-Dk/IDOPtOgubt/IaevIUbTgV7doaKkoorvOyYM2CMwuDyP89bekI7H4xLIwunNYiK9jhCkmc6pUrJk3cj2AB9w== + dependencies: + "@types/yargs-parser" "*" + +"@typescript-eslint/eslint-plugin@^2.10.0": + version "2.26.0" + resolved "https://registry.yarnpkg.com/@typescript-eslint/eslint-plugin/-/eslint-plugin-2.26.0.tgz#04c96560c8981421e5a9caad8394192363cc423f" + integrity sha512-4yUnLv40bzfzsXcTAtZyTjbiGUXMrcIJcIMioI22tSOyAxpdXiZ4r7YQUU8Jj6XXrLz9d5aMHPQf5JFR7h27Nw== + dependencies: + "@typescript-eslint/experimental-utils" "2.26.0" + functional-red-black-tree "^1.0.1" + regexpp "^3.0.0" + tsutils "^3.17.1" + +"@typescript-eslint/experimental-utils@2.26.0": + version "2.26.0" + resolved "https://registry.yarnpkg.com/@typescript-eslint/experimental-utils/-/experimental-utils-2.26.0.tgz#063390c404d9980767d76274df386c0aa675d91d" + integrity sha512-RELVoH5EYd+JlGprEyojUv9HeKcZqF7nZUGSblyAw1FwOGNnmQIU8kxJ69fttQvEwCsX5D6ECJT8GTozxrDKVQ== + dependencies: + "@types/json-schema" "^7.0.3" + "@typescript-eslint/typescript-estree" "2.26.0" + eslint-scope "^5.0.0" + eslint-utils "^2.0.0" + +"@typescript-eslint/parser@^2.10.0": + version "2.26.0" + resolved "https://registry.yarnpkg.com/@typescript-eslint/parser/-/parser-2.26.0.tgz#385463615818b33acb72a25b39c03579df93d76f" + integrity sha512-+Xj5fucDtdKEVGSh9353wcnseMRkPpEAOY96EEenN7kJVrLqy/EVwtIh3mxcUz8lsFXW1mT5nN5vvEam/a5HiQ== + dependencies: + "@types/eslint-visitor-keys" "^1.0.0" + "@typescript-eslint/experimental-utils" "2.26.0" + "@typescript-eslint/typescript-estree" "2.26.0" + eslint-visitor-keys "^1.1.0" + +"@typescript-eslint/typescript-estree@2.26.0": + version "2.26.0" + resolved "https://registry.yarnpkg.com/@typescript-eslint/typescript-estree/-/typescript-estree-2.26.0.tgz#d8132cf1ee8a72234f996519a47d8a9118b57d56" + integrity 
sha512-3x4SyZCLB4zsKsjuhxDLeVJN6W29VwBnYpCsZ7vIdPel9ZqLfIZJgJXO47MNUkurGpQuIBALdPQKtsSnWpE1Yg== + dependencies: + debug "^4.1.1" + eslint-visitor-keys "^1.1.0" + glob "^7.1.6" + is-glob "^4.0.1" + lodash "^4.17.15" + semver "^6.3.0" + tsutils "^3.17.1" + +"@webassemblyjs/ast@1.8.5": + version "1.8.5" + resolved "https://registry.yarnpkg.com/@webassemblyjs/ast/-/ast-1.8.5.tgz#51b1c5fe6576a34953bf4b253df9f0d490d9e359" + integrity sha512-aJMfngIZ65+t71C3y2nBBg5FFG0Okt9m0XEgWZ7Ywgn1oMAT8cNwx00Uv1cQyHtidq0Xn94R4TAywO+LCQ+ZAQ== + dependencies: + "@webassemblyjs/helper-module-context" "1.8.5" + "@webassemblyjs/helper-wasm-bytecode" "1.8.5" + "@webassemblyjs/wast-parser" "1.8.5" + +"@webassemblyjs/floating-point-hex-parser@1.8.5": + version "1.8.5" + resolved "https://registry.yarnpkg.com/@webassemblyjs/floating-point-hex-parser/-/floating-point-hex-parser-1.8.5.tgz#1ba926a2923613edce496fd5b02e8ce8a5f49721" + integrity sha512-9p+79WHru1oqBh9ewP9zW95E3XAo+90oth7S5Re3eQnECGq59ly1Ri5tsIipKGpiStHsUYmY3zMLqtk3gTcOtQ== + +"@webassemblyjs/helper-api-error@1.8.5": + version "1.8.5" + resolved "https://registry.yarnpkg.com/@webassemblyjs/helper-api-error/-/helper-api-error-1.8.5.tgz#c49dad22f645227c5edb610bdb9697f1aab721f7" + integrity sha512-Za/tnzsvnqdaSPOUXHyKJ2XI7PDX64kWtURyGiJJZKVEdFOsdKUCPTNEVFZq3zJ2R0G5wc2PZ5gvdTRFgm81zA== + +"@webassemblyjs/helper-buffer@1.8.5": + version "1.8.5" + resolved "https://registry.yarnpkg.com/@webassemblyjs/helper-buffer/-/helper-buffer-1.8.5.tgz#fea93e429863dd5e4338555f42292385a653f204" + integrity sha512-Ri2R8nOS0U6G49Q86goFIPNgjyl6+oE1abW1pS84BuhP1Qcr5JqMwRFT3Ah3ADDDYGEgGs1iyb1DGX+kAi/c/Q== + +"@webassemblyjs/helper-code-frame@1.8.5": + version "1.8.5" + resolved "https://registry.yarnpkg.com/@webassemblyjs/helper-code-frame/-/helper-code-frame-1.8.5.tgz#9a740ff48e3faa3022b1dff54423df9aa293c25e" + integrity sha512-VQAadSubZIhNpH46IR3yWO4kZZjMxN1opDrzePLdVKAZ+DFjkGD/rf4v1jap744uPVU6yjL/smZbRIIJTOUnKQ== + dependencies: + "@webassemblyjs/wast-printer" "1.8.5" + +"@webassemblyjs/helper-fsm@1.8.5": + version "1.8.5" + resolved "https://registry.yarnpkg.com/@webassemblyjs/helper-fsm/-/helper-fsm-1.8.5.tgz#ba0b7d3b3f7e4733da6059c9332275d860702452" + integrity sha512-kRuX/saORcg8se/ft6Q2UbRpZwP4y7YrWsLXPbbmtepKr22i8Z4O3V5QE9DbZK908dh5Xya4Un57SDIKwB9eow== + +"@webassemblyjs/helper-module-context@1.8.5": + version "1.8.5" + resolved "https://registry.yarnpkg.com/@webassemblyjs/helper-module-context/-/helper-module-context-1.8.5.tgz#def4b9927b0101dc8cbbd8d1edb5b7b9c82eb245" + integrity sha512-/O1B236mN7UNEU4t9X7Pj38i4VoU8CcMHyy3l2cV/kIF4U5KoHXDVqcDuOs1ltkac90IM4vZdHc52t1x8Yfs3g== + dependencies: + "@webassemblyjs/ast" "1.8.5" + mamacro "^0.0.3" + +"@webassemblyjs/helper-wasm-bytecode@1.8.5": + version "1.8.5" + resolved "https://registry.yarnpkg.com/@webassemblyjs/helper-wasm-bytecode/-/helper-wasm-bytecode-1.8.5.tgz#537a750eddf5c1e932f3744206551c91c1b93e61" + integrity sha512-Cu4YMYG3Ddl72CbmpjU/wbP6SACcOPVbHN1dI4VJNJVgFwaKf1ppeFJrwydOG3NDHxVGuCfPlLZNyEdIYlQ6QQ== + +"@webassemblyjs/helper-wasm-section@1.8.5": + version "1.8.5" + resolved "https://registry.yarnpkg.com/@webassemblyjs/helper-wasm-section/-/helper-wasm-section-1.8.5.tgz#74ca6a6bcbe19e50a3b6b462847e69503e6bfcbf" + integrity sha512-VV083zwR+VTrIWWtgIUpqfvVdK4ff38loRmrdDBgBT8ADXYsEZ5mPQ4Nde90N3UYatHdYoDIFb7oHzMncI02tA== + dependencies: + "@webassemblyjs/ast" "1.8.5" + "@webassemblyjs/helper-buffer" "1.8.5" + "@webassemblyjs/helper-wasm-bytecode" "1.8.5" + "@webassemblyjs/wasm-gen" "1.8.5" + 
+"@webassemblyjs/ieee754@1.8.5": + version "1.8.5" + resolved "https://registry.yarnpkg.com/@webassemblyjs/ieee754/-/ieee754-1.8.5.tgz#712329dbef240f36bf57bd2f7b8fb9bf4154421e" + integrity sha512-aaCvQYrvKbY/n6wKHb/ylAJr27GglahUO89CcGXMItrOBqRarUMxWLJgxm9PJNuKULwN5n1csT9bYoMeZOGF3g== + dependencies: + "@xtuc/ieee754" "^1.2.0" + +"@webassemblyjs/leb128@1.8.5": + version "1.8.5" + resolved "https://registry.yarnpkg.com/@webassemblyjs/leb128/-/leb128-1.8.5.tgz#044edeb34ea679f3e04cd4fd9824d5e35767ae10" + integrity sha512-plYUuUwleLIziknvlP8VpTgO4kqNaH57Y3JnNa6DLpu/sGcP6hbVdfdX5aHAV716pQBKrfuU26BJK29qY37J7A== + dependencies: + "@xtuc/long" "4.2.2" + +"@webassemblyjs/utf8@1.8.5": + version "1.8.5" + resolved "https://registry.yarnpkg.com/@webassemblyjs/utf8/-/utf8-1.8.5.tgz#a8bf3b5d8ffe986c7c1e373ccbdc2a0915f0cedc" + integrity sha512-U7zgftmQriw37tfD934UNInokz6yTmn29inT2cAetAsaU9YeVCveWEwhKL1Mg4yS7q//NGdzy79nlXh3bT8Kjw== + +"@webassemblyjs/wasm-edit@1.8.5": + version "1.8.5" + resolved "https://registry.yarnpkg.com/@webassemblyjs/wasm-edit/-/wasm-edit-1.8.5.tgz#962da12aa5acc1c131c81c4232991c82ce56e01a" + integrity sha512-A41EMy8MWw5yvqj7MQzkDjU29K7UJq1VrX2vWLzfpRHt3ISftOXqrtojn7nlPsZ9Ijhp5NwuODuycSvfAO/26Q== + dependencies: + "@webassemblyjs/ast" "1.8.5" + "@webassemblyjs/helper-buffer" "1.8.5" + "@webassemblyjs/helper-wasm-bytecode" "1.8.5" + "@webassemblyjs/helper-wasm-section" "1.8.5" + "@webassemblyjs/wasm-gen" "1.8.5" + "@webassemblyjs/wasm-opt" "1.8.5" + "@webassemblyjs/wasm-parser" "1.8.5" + "@webassemblyjs/wast-printer" "1.8.5" + +"@webassemblyjs/wasm-gen@1.8.5": + version "1.8.5" + resolved "https://registry.yarnpkg.com/@webassemblyjs/wasm-gen/-/wasm-gen-1.8.5.tgz#54840766c2c1002eb64ed1abe720aded714f98bc" + integrity sha512-BCZBT0LURC0CXDzj5FXSc2FPTsxwp3nWcqXQdOZE4U7h7i8FqtFK5Egia6f9raQLpEKT1VL7zr4r3+QX6zArWg== + dependencies: + "@webassemblyjs/ast" "1.8.5" + "@webassemblyjs/helper-wasm-bytecode" "1.8.5" + "@webassemblyjs/ieee754" "1.8.5" + "@webassemblyjs/leb128" "1.8.5" + "@webassemblyjs/utf8" "1.8.5" + +"@webassemblyjs/wasm-opt@1.8.5": + version "1.8.5" + resolved "https://registry.yarnpkg.com/@webassemblyjs/wasm-opt/-/wasm-opt-1.8.5.tgz#b24d9f6ba50394af1349f510afa8ffcb8a63d264" + integrity sha512-HKo2mO/Uh9A6ojzu7cjslGaHaUU14LdLbGEKqTR7PBKwT6LdPtLLh9fPY33rmr5wcOMrsWDbbdCHq4hQUdd37Q== + dependencies: + "@webassemblyjs/ast" "1.8.5" + "@webassemblyjs/helper-buffer" "1.8.5" + "@webassemblyjs/wasm-gen" "1.8.5" + "@webassemblyjs/wasm-parser" "1.8.5" + +"@webassemblyjs/wasm-parser@1.8.5": + version "1.8.5" + resolved "https://registry.yarnpkg.com/@webassemblyjs/wasm-parser/-/wasm-parser-1.8.5.tgz#21576f0ec88b91427357b8536383668ef7c66b8d" + integrity sha512-pi0SYE9T6tfcMkthwcgCpL0cM9nRYr6/6fjgDtL6q/ZqKHdMWvxitRi5JcZ7RI4SNJJYnYNaWy5UUrHQy998lw== + dependencies: + "@webassemblyjs/ast" "1.8.5" + "@webassemblyjs/helper-api-error" "1.8.5" + "@webassemblyjs/helper-wasm-bytecode" "1.8.5" + "@webassemblyjs/ieee754" "1.8.5" + "@webassemblyjs/leb128" "1.8.5" + "@webassemblyjs/utf8" "1.8.5" + +"@webassemblyjs/wast-parser@1.8.5": + version "1.8.5" + resolved "https://registry.yarnpkg.com/@webassemblyjs/wast-parser/-/wast-parser-1.8.5.tgz#e10eecd542d0e7bd394f6827c49f3df6d4eefb8c" + integrity sha512-daXC1FyKWHF1i11obK086QRlsMsY4+tIOKgBqI1lxAnkp9xe9YMcgOxm9kLe+ttjs5aWV2KKE1TWJCN57/Btsg== + dependencies: + "@webassemblyjs/ast" "1.8.5" + "@webassemblyjs/floating-point-hex-parser" "1.8.5" + "@webassemblyjs/helper-api-error" "1.8.5" + "@webassemblyjs/helper-code-frame" "1.8.5" + "@webassemblyjs/helper-fsm" 
"1.8.5" + "@xtuc/long" "4.2.2" + +"@webassemblyjs/wast-printer@1.8.5": + version "1.8.5" + resolved "https://registry.yarnpkg.com/@webassemblyjs/wast-printer/-/wast-printer-1.8.5.tgz#114bbc481fd10ca0e23b3560fa812748b0bae5bc" + integrity sha512-w0U0pD4EhlnvRyeJzBqaVSJAo9w/ce7/WPogeXLzGkO6hzhr4GnQIZ4W4uUt5b9ooAaXPtnXlj0gzsXEOUNYMg== + dependencies: + "@webassemblyjs/ast" "1.8.5" + "@webassemblyjs/wast-parser" "1.8.5" + "@xtuc/long" "4.2.2" + +"@xtuc/ieee754@^1.2.0": + version "1.2.0" + resolved "https://registry.yarnpkg.com/@xtuc/ieee754/-/ieee754-1.2.0.tgz#eef014a3145ae477a1cbc00cd1e552336dceb790" + integrity sha512-DX8nKgqcGwsc0eJSqYt5lwP4DH5FlHnmuWWBRy7X0NcaGR0ZtuyeESgMwTYVEtxmsNGY+qit4QYT/MIYTOTPeA== + +"@xtuc/long@4.2.2": + version "4.2.2" + resolved "https://registry.yarnpkg.com/@xtuc/long/-/long-4.2.2.tgz#d291c6a4e97989b5c61d9acf396ae4fe133a718d" + integrity sha512-NuHqBY1PB/D8xU6s/thBgOAiAP7HOYDQ32+BFZILJ8ivkUkAHQnWfn6WhL79Owj1qmUnoN/YPhktdIoucipkAQ== + +abab@^2.0.0: + version "2.0.3" + resolved "https://registry.yarnpkg.com/abab/-/abab-2.0.3.tgz#623e2075e02eb2d3f2475e49f99c91846467907a" + integrity sha512-tsFzPpcttalNjFBCFMqsKYQcWxxen1pgJR56by//QwvJc4/OUS3kPOOttx2tSIfjsylB0pYu7f5D3K1RCxUnUg== + +accepts@~1.3.4, accepts@~1.3.5, accepts@~1.3.7: + version "1.3.7" + resolved "https://registry.yarnpkg.com/accepts/-/accepts-1.3.7.tgz#531bc726517a3b2b41f850021c6cc15eaab507cd" + integrity sha512-Il80Qs2WjYlJIBNzNkK6KYqlVMTbZLXgHx2oT0pU/fjRHyEp+PEfEPY0R3WCwAGVOtauxh1hOxNgIf5bv7dQpA== + dependencies: + mime-types "~2.1.24" + negotiator "0.6.2" + +acorn-globals@^4.1.0, acorn-globals@^4.3.0: + version "4.3.4" + resolved "https://registry.yarnpkg.com/acorn-globals/-/acorn-globals-4.3.4.tgz#9fa1926addc11c97308c4e66d7add0d40c3272e7" + integrity sha512-clfQEh21R+D0leSbUdWf3OcfqyaCSAQ8Ryq00bofSekfr9W8u1jyYZo6ir0xu9Gtcf7BjcHJpnbZH7JOCpP60A== + dependencies: + acorn "^6.0.1" + acorn-walk "^6.0.1" + +acorn-jsx@^5.2.0: + version "5.2.0" + resolved "https://registry.yarnpkg.com/acorn-jsx/-/acorn-jsx-5.2.0.tgz#4c66069173d6fdd68ed85239fc256226182b2ebe" + integrity sha512-HiUX/+K2YpkpJ+SzBffkM/AQ2YE03S0U1kjTLVpoJdhZMOWy8qvXVN9JdLqv2QsaQ6MPYQIuNmwD8zOiYUofLQ== + +acorn-walk@^6.0.1: + version "6.2.0" + resolved "https://registry.yarnpkg.com/acorn-walk/-/acorn-walk-6.2.0.tgz#123cb8f3b84c2171f1f7fb252615b1c78a6b1a8c" + integrity sha512-7evsyfH1cLOCdAzZAd43Cic04yKydNx0cF+7tiA19p1XnLLPU4dpCQOqpjqwokFe//vS0QqfqqjCS2JkiIs0cA== + +acorn@^5.5.3: + version "5.7.4" + resolved "https://registry.yarnpkg.com/acorn/-/acorn-5.7.4.tgz#3e8d8a9947d0599a1796d10225d7432f4a4acf5e" + integrity sha512-1D++VG7BhrtvQpNbBzovKNc1FLGGEE/oGe7b9xJm/RFHMBeUaUGpluV9RLjZa47YFdPcDAenEYuq9pQPcMdLJg== + +acorn@^6.0.1, acorn@^6.0.4, acorn@^6.2.1: + version "6.4.1" + resolved "https://registry.yarnpkg.com/acorn/-/acorn-6.4.1.tgz#531e58ba3f51b9dacb9a6646ca4debf5b14ca474" + integrity sha512-ZVA9k326Nwrj3Cj9jlh3wGFutC2ZornPNARZwsNYqQYgN0EsV2d53w5RN/co65Ohn4sUAUtb1rSUAOD6XN9idA== + +acorn@^7.1.1: + version "7.1.1" + resolved "https://registry.yarnpkg.com/acorn/-/acorn-7.1.1.tgz#e35668de0b402f359de515c5482a1ab9f89a69bf" + integrity sha512-add7dgA5ppRPxCFJoAGfMDi7PIBXq1RtGo7BhbLaxwrXPOmw8gq48Y9ozT01hUKy9byMjlR20EJhu5zlkErEkg== + +address@1.1.2, address@^1.0.1: + version "1.1.2" + resolved "https://registry.yarnpkg.com/address/-/address-1.1.2.tgz#bf1116c9c758c51b7a933d296b72c221ed9428b6" + integrity sha512-aT6camzM4xEA54YVJYSqxz1kv4IHnQZRtThJJHhUMRExaU5spC7jX5ugSwTaTgJliIgs4VhZOk7htClvQ/LmRA== + +adjust-sourcemap-loader@2.0.0: + version "2.0.0" + 
resolved "https://registry.yarnpkg.com/adjust-sourcemap-loader/-/adjust-sourcemap-loader-2.0.0.tgz#6471143af75ec02334b219f54bc7970c52fb29a4" + integrity sha512-4hFsTsn58+YjrU9qKzML2JSSDqKvN8mUGQ0nNIrfPi8hmIONT4L3uUaT6MKdMsZ9AjsU6D2xDkZxCkbQPxChrA== + dependencies: + assert "1.4.1" + camelcase "5.0.0" + loader-utils "1.2.3" + object-path "0.11.4" + regex-parser "2.2.10" + +aggregate-error@^3.0.0: + version "3.0.1" + resolved "https://registry.yarnpkg.com/aggregate-error/-/aggregate-error-3.0.1.tgz#db2fe7246e536f40d9b5442a39e117d7dd6a24e0" + integrity sha512-quoaXsZ9/BLNae5yiNoUz+Nhkwz83GhWwtYFglcjEQB2NDHCIpApbqXxIFnm4Pq/Nvhrsq5sYJFyohrrxnTGAA== + dependencies: + clean-stack "^2.0.0" + indent-string "^4.0.0" + +ajv-errors@^1.0.0: + version "1.0.1" + resolved "https://registry.yarnpkg.com/ajv-errors/-/ajv-errors-1.0.1.tgz#f35986aceb91afadec4102fbd85014950cefa64d" + integrity sha512-DCRfO/4nQ+89p/RK43i8Ezd41EqdGIU4ld7nGF8OQ14oc/we5rEntLCUa7+jrn3nn83BosfwZA0wb4pon2o8iQ== + +ajv-keywords@^3.1.0, ajv-keywords@^3.4.1: + version "3.4.1" + resolved "https://registry.yarnpkg.com/ajv-keywords/-/ajv-keywords-3.4.1.tgz#ef916e271c64ac12171fd8384eaae6b2345854da" + integrity sha512-RO1ibKvd27e6FEShVFfPALuHI3WjSVNeK5FIsmme/LYRNxjKuNj+Dt7bucLa6NdSv3JcVTyMlm9kGR84z1XpaQ== + +ajv@^6.1.0, ajv@^6.10.0, ajv@^6.10.2, ajv@^6.12.0, ajv@^6.5.5: + version "6.12.0" + resolved "https://registry.yarnpkg.com/ajv/-/ajv-6.12.0.tgz#06d60b96d87b8454a5adaba86e7854da629db4b7" + integrity sha512-D6gFiFA0RRLyUbvijN74DWAjXSFxWKaWP7mldxkVhyhAV3+SWA9HEJPHQ2c9soIeTFJqcSdFDGFgdqs1iUU2Hw== + dependencies: + fast-deep-equal "^3.1.1" + fast-json-stable-stringify "^2.0.0" + json-schema-traverse "^0.4.1" + uri-js "^4.2.2" + +alphanum-sort@^1.0.0: + version "1.0.2" + resolved "https://registry.yarnpkg.com/alphanum-sort/-/alphanum-sort-1.0.2.tgz#97a1119649b211ad33691d9f9f486a8ec9fbe0a3" + integrity sha1-l6ERlkmyEa0zaR2fn0hqjsn74KM= + +ansi-colors@^3.0.0: + version "3.2.4" + resolved "https://registry.yarnpkg.com/ansi-colors/-/ansi-colors-3.2.4.tgz#e3a3da4bfbae6c86a9c285625de124a234026fbf" + integrity sha512-hHUXGagefjN2iRrID63xckIvotOXOojhQKWIPUZ4mNUZ9nLZW+7FMNoE1lOkEhNWYsx/7ysGIuJYCiMAA9FnrA== + +ansi-escapes@^3.0.0: + version "3.2.0" + resolved "https://registry.yarnpkg.com/ansi-escapes/-/ansi-escapes-3.2.0.tgz#8780b98ff9dbf5638152d1f1fe5c1d7b4442976b" + integrity sha512-cBhpre4ma+U0T1oM5fXg7Dy1Jw7zzwv7lt/GoCpr+hDQJoYnKVPLL4dCvSEFMmQurOQvSrwT7SL/DAlhBI97RQ== + +ansi-escapes@^4.2.1: + version "4.3.1" + resolved "https://registry.yarnpkg.com/ansi-escapes/-/ansi-escapes-4.3.1.tgz#a5c47cc43181f1f38ffd7076837700d395522a61" + integrity sha512-JWF7ocqNrp8u9oqpgV+wH5ftbt+cfvv+PTjOvKLT3AdYly/LmORARfEVT1iyjwN+4MqE5UmVKoAdIBqeoCHgLA== + dependencies: + type-fest "^0.11.0" + +ansi-html@0.0.7: + version "0.0.7" + resolved "https://registry.yarnpkg.com/ansi-html/-/ansi-html-0.0.7.tgz#813584021962a9e9e6fd039f940d12f56ca7859e" + integrity sha1-gTWEAhliqenm/QOflA0S9WynhZ4= + +ansi-regex@^2.0.0: + version "2.1.1" + resolved "https://registry.yarnpkg.com/ansi-regex/-/ansi-regex-2.1.1.tgz#c3b33ab5ee360d86e0e628f0468ae7ef27d654df" + integrity sha1-w7M6te42DYbg5ijwRorn7yfWVN8= + +ansi-regex@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/ansi-regex/-/ansi-regex-3.0.0.tgz#ed0317c322064f79466c02966bddb605ab37d998" + integrity sha1-7QMXwyIGT3lGbAKWa922Bas32Zg= + +ansi-regex@^4.0.0, ansi-regex@^4.1.0: + version "4.1.0" + resolved "https://registry.yarnpkg.com/ansi-regex/-/ansi-regex-4.1.0.tgz#8b9f8f08cf1acb843756a839ca8c7e3168c51997" + integrity 
sha512-1apePfXM1UOSqw0o9IiFAovVz9M5S1Dg+4TrDwfMewQ6p/rmMueb7tWZjQ1rx4Loy1ArBggoqGpfqqdI4rondg== + +ansi-regex@^5.0.0: + version "5.0.0" + resolved "https://registry.yarnpkg.com/ansi-regex/-/ansi-regex-5.0.0.tgz#388539f55179bf39339c81af30a654d69f87cb75" + integrity sha512-bY6fj56OUQ0hU1KjFNDQuJFezqKdrAyFdIevADiqrWHwSlbmBNMHp5ak2f40Pm8JTFyM2mqxkG6ngkHO11f/lg== + +ansi-styles@^2.2.1: + version "2.2.1" + resolved "https://registry.yarnpkg.com/ansi-styles/-/ansi-styles-2.2.1.tgz#b432dd3358b634cf75e1e4664368240533c1ddbe" + integrity sha1-tDLdM1i2NM914eRmQ2gkBTPB3b4= + +ansi-styles@^3.2.0, ansi-styles@^3.2.1: + version "3.2.1" + resolved "https://registry.yarnpkg.com/ansi-styles/-/ansi-styles-3.2.1.tgz#41fbb20243e50b12be0f04b8dedbf07520ce841d" + integrity sha512-VT0ZI6kZRdTh8YyJw3SMbYm/u+NqfsAxEpWO0Pf9sq8/e94WxxOpPKx9FR1FlyCtOVDNOQ+8ntlqFxiRc+r5qA== + dependencies: + color-convert "^1.9.0" + +ansi-styles@^4.0.0, ansi-styles@^4.1.0: + version "4.2.1" + resolved "https://registry.yarnpkg.com/ansi-styles/-/ansi-styles-4.2.1.tgz#90ae75c424d008d2624c5bf29ead3177ebfcf359" + integrity sha512-9VGjrMsG1vePxcSweQsN20KY/c4zN0h9fLjqAbwbPfahM3t+NL+M9HC8xeXG2I8pX5NoamTGNuomEUFI7fcUjA== + dependencies: + "@types/color-name" "^1.1.1" + color-convert "^2.0.1" + +anymatch@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/anymatch/-/anymatch-2.0.0.tgz#bcb24b4f37934d9aa7ac17b4adaf89e7c76ef2eb" + integrity sha512-5teOsQWABXHHBFP9y3skS5P3d/WfWXpv3FUpy+LorMrNYaT9pI4oLMQX7jzQ2KklNpGpWHzdCXTDT2Y3XGlZBw== + dependencies: + micromatch "^3.1.4" + normalize-path "^2.1.1" + +anymatch@~3.1.1: + version "3.1.1" + resolved "https://registry.yarnpkg.com/anymatch/-/anymatch-3.1.1.tgz#c55ecf02185e2469259399310c173ce31233b142" + integrity sha512-mM8522psRCqzV+6LhomX5wgp25YVibjh8Wj23I5RPkPppSVSjyKD2A2mBJmWGa+KN7f2D6LNh9jkBCeyLktzjg== + dependencies: + normalize-path "^3.0.0" + picomatch "^2.0.4" + +aproba@^1.1.1: + version "1.2.0" + resolved "https://registry.yarnpkg.com/aproba/-/aproba-1.2.0.tgz#6802e6264efd18c790a1b0d517f0f2627bf2c94a" + integrity sha512-Y9J6ZjXtoYh8RnXVCMOU/ttDmk1aBjunq9vO0ta5x85WDQiQfUF9sIPBITdbiiIVcBo03Hi3jMxigBtsddlXRw== + +argparse@^1.0.7: + version "1.0.10" + resolved "https://registry.yarnpkg.com/argparse/-/argparse-1.0.10.tgz#bcd6791ea5ae09725e17e5ad988134cd40b3d911" + integrity sha512-o5Roy6tNG4SL/FOkCAN6RzjiakZS25RLYFrcMttJqbdd8BWrnA+fGz57iN5Pb06pvBGvl5gQ0B48dJlslXvoTg== + dependencies: + sprintf-js "~1.0.2" + +aria-query@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/aria-query/-/aria-query-3.0.0.tgz#65b3fcc1ca1155a8c9ae64d6eee297f15d5133cc" + integrity sha1-ZbP8wcoRVajJrmTW7uKX8V1RM8w= + dependencies: + ast-types-flow "0.0.7" + commander "^2.11.0" + +arity-n@^1.0.4: + version "1.0.4" + resolved "https://registry.yarnpkg.com/arity-n/-/arity-n-1.0.4.tgz#d9e76b11733e08569c0847ae7b39b2860b30b745" + integrity sha1-2edrEXM+CFacCEeuezmyhgswt0U= + +arr-diff@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/arr-diff/-/arr-diff-4.0.0.tgz#d6461074febfec71e7e15235761a329a5dc7c520" + integrity sha1-1kYQdP6/7HHn4VI1dhoyml3HxSA= + +arr-flatten@^1.1.0: + version "1.1.0" + resolved "https://registry.yarnpkg.com/arr-flatten/-/arr-flatten-1.1.0.tgz#36048bbff4e7b47e136644316c99669ea5ae91f1" + integrity sha512-L3hKV5R/p5o81R7O02IGnwpDmkp6E982XhtbuwSe3O4qOtMMMtodicASA1Cny2U+aCXcNpml+m4dPsvsJ3jatg== + +arr-union@^3.1.0: + version "3.1.0" + resolved "https://registry.yarnpkg.com/arr-union/-/arr-union-3.1.0.tgz#e39b09aea9def866a8f206e288af63919bae39c4" + integrity 
sha1-45sJrqne+Gao8gbiiK9jkZuuOcQ= + +array-equal@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/array-equal/-/array-equal-1.0.0.tgz#8c2a5ef2472fd9ea742b04c77a75093ba2757c93" + integrity sha1-jCpe8kcv2ep0KwTHenUJO6J1fJM= + +array-flatten@1.1.1: + version "1.1.1" + resolved "https://registry.yarnpkg.com/array-flatten/-/array-flatten-1.1.1.tgz#9a5f699051b1e7073328f2a008968b64ea2955d2" + integrity sha1-ml9pkFGx5wczKPKgCJaLZOopVdI= + +array-flatten@^2.1.0: + version "2.1.2" + resolved "https://registry.yarnpkg.com/array-flatten/-/array-flatten-2.1.2.tgz#24ef80a28c1a893617e2149b0c6d0d788293b099" + integrity sha512-hNfzcOV8W4NdualtqBFPyVO+54DSJuZGY9qT4pRroB6S9e3iiido2ISIC5h9R2sPJ8H3FHCIiEnsv1lPXO3KtQ== + +array-includes@^3.0.3, array-includes@^3.1.1: + version "3.1.1" + resolved "https://registry.yarnpkg.com/array-includes/-/array-includes-3.1.1.tgz#cdd67e6852bdf9c1215460786732255ed2459348" + integrity sha512-c2VXaCHl7zPsvpkFsw4nxvFie4fh1ur9bpcgsVkIjqn0H/Xwdg+7fv3n2r/isyS8EBj5b06M9kHyZuIr4El6WQ== + dependencies: + define-properties "^1.1.3" + es-abstract "^1.17.0" + is-string "^1.0.5" + +array-union@^1.0.1: + version "1.0.2" + resolved "https://registry.yarnpkg.com/array-union/-/array-union-1.0.2.tgz#9a34410e4f4e3da23dea375be5be70f24778ec39" + integrity sha1-mjRBDk9OPaI96jdb5b5w8kd47Dk= + dependencies: + array-uniq "^1.0.1" + +array-uniq@^1.0.1: + version "1.0.3" + resolved "https://registry.yarnpkg.com/array-uniq/-/array-uniq-1.0.3.tgz#af6ac877a25cc7f74e058894753858dfdb24fdb6" + integrity sha1-r2rId6Jcx/dOBYiUdThY39sk/bY= + +array-unique@^0.3.2: + version "0.3.2" + resolved "https://registry.yarnpkg.com/array-unique/-/array-unique-0.3.2.tgz#a894b75d4bc4f6cd679ef3244a9fd8f46ae2d428" + integrity sha1-qJS3XUvE9s1nnvMkSp/Y9Gri1Cg= + +array.prototype.flat@^1.2.1: + version "1.2.3" + resolved "https://registry.yarnpkg.com/array.prototype.flat/-/array.prototype.flat-1.2.3.tgz#0de82b426b0318dbfdb940089e38b043d37f6c7b" + integrity sha512-gBlRZV0VSmfPIeWfuuy56XZMvbVfbEUnOXUvt3F/eUUUSyzlgLxhEX4YAEpxNAogRGehPSnfXyPtYyKAhkzQhQ== + dependencies: + define-properties "^1.1.3" + es-abstract "^1.17.0-next.1" + +arrify@^1.0.1: + version "1.0.1" + resolved "https://registry.yarnpkg.com/arrify/-/arrify-1.0.1.tgz#898508da2226f380df904728456849c1501a4b0d" + integrity sha1-iYUI2iIm84DfkEcoRWhJwVAaSw0= + +asap@~2.0.6: + version "2.0.6" + resolved "https://registry.yarnpkg.com/asap/-/asap-2.0.6.tgz#e50347611d7e690943208bbdafebcbc2fb866d46" + integrity sha1-5QNHYR1+aQlDIIu9r+vLwvuGbUY= + +asn1.js@^4.0.0: + version "4.10.1" + resolved "https://registry.yarnpkg.com/asn1.js/-/asn1.js-4.10.1.tgz#b9c2bf5805f1e64aadeed6df3a2bfafb5a73f5a0" + integrity sha512-p32cOF5q0Zqs9uBiONKYLm6BClCoBCM5O9JfeUSlnQLBTxYdTK+pW+nXflm8UkKd2UYlEbYz5qEi0JuZR9ckSw== + dependencies: + bn.js "^4.0.0" + inherits "^2.0.1" + minimalistic-assert "^1.0.0" + +asn1@~0.2.3: + version "0.2.4" + resolved "https://registry.yarnpkg.com/asn1/-/asn1-0.2.4.tgz#8d2475dfab553bb33e77b54e59e880bb8ce23136" + integrity sha512-jxwzQpLQjSmWXgwaCZE9Nz+glAG01yF1QnWgbhGwHI5A6FRIEY6IVqtHhIepHqI7/kyEyQEagBC5mBEFlIYvdg== + dependencies: + safer-buffer "~2.1.0" + +assert-plus@1.0.0, assert-plus@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/assert-plus/-/assert-plus-1.0.0.tgz#f12e0f3c5d77b0b1cdd9146942e4e96c1e4dd525" + integrity sha1-8S4PPF13sLHN2RRpQuTpbB5N1SU= + +assert@1.4.1: + version "1.4.1" + resolved "https://registry.yarnpkg.com/assert/-/assert-1.4.1.tgz#99912d591836b5a6f5b345c0f07eefc08fc65d91" + integrity 
sha1-mZEtWRg2tab1s0XA8H7vwI/GXZE= + dependencies: + util "0.10.3" + +assert@^1.1.1: + version "1.5.0" + resolved "https://registry.yarnpkg.com/assert/-/assert-1.5.0.tgz#55c109aaf6e0aefdb3dc4b71240c70bf574b18eb" + integrity sha512-EDsgawzwoun2CZkCgtxJbv392v4nbk9XDD06zI+kQYoBM/3RBWLlEyJARDOmhAAosBjWACEkKL6S+lIZtcAubA== + dependencies: + object-assign "^4.1.1" + util "0.10.3" + +assign-symbols@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/assign-symbols/-/assign-symbols-1.0.0.tgz#59667f41fadd4f20ccbc2bb96b8d4f7f78ec0367" + integrity sha1-WWZ/QfrdTyDMvCu5a41Pf3jsA2c= + +ast-types-flow@0.0.7, ast-types-flow@^0.0.7: + version "0.0.7" + resolved "https://registry.yarnpkg.com/ast-types-flow/-/ast-types-flow-0.0.7.tgz#f70b735c6bca1a5c9c22d982c3e39e7feba3bdad" + integrity sha1-9wtzXGvKGlycItmCw+Oef+ujva0= + +astral-regex@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/astral-regex/-/astral-regex-1.0.0.tgz#6c8c3fb827dd43ee3918f27b82782ab7658a6fd9" + integrity sha512-+Ryf6g3BKoRc7jfp7ad8tM4TtMiaWvbF/1/sQcZPkkS7ag3D5nMBCe2UfOTONtAkaG0tO0ij3C5Lwmf1EiyjHg== + +async-each@^1.0.1: + version "1.0.3" + resolved "https://registry.yarnpkg.com/async-each/-/async-each-1.0.3.tgz#b727dbf87d7651602f06f4d4ac387f47d91b0cbf" + integrity sha512-z/WhQ5FPySLdvREByI2vZiTWwCnF0moMJ1hK9YQwDTHKh6I7/uSckMetoRGb5UBZPC1z0jlw+n/XCgjeH7y1AQ== + +async-limiter@~1.0.0: + version "1.0.1" + resolved "https://registry.yarnpkg.com/async-limiter/-/async-limiter-1.0.1.tgz#dd379e94f0db8310b08291f9d64c3209766617fd" + integrity sha512-csOlWGAcRFJaI6m+F2WKdnMKr4HhdhFVBk0H/QbJFMCr+uO2kwohwXQPxw/9OCxp05r5ghVBFSyioixx3gfkNQ== + +async@^2.6.2: + version "2.6.3" + resolved "https://registry.yarnpkg.com/async/-/async-2.6.3.tgz#d72625e2344a3656e3a3ad4fa749fa83299d82ff" + integrity sha512-zflvls11DCy+dQWzTW2dzuilv8Z5X/pjfmZOWba6TNIVDm+2UDaJmXSOXlasHKfNBs8oo3M0aT50fDEWfKZjXg== + dependencies: + lodash "^4.17.14" + +asynckit@^0.4.0: + version "0.4.0" + resolved "https://registry.yarnpkg.com/asynckit/-/asynckit-0.4.0.tgz#c79ed97f7f34cb8f2ba1bc9790bcc366474b4b79" + integrity sha1-x57Zf380y48robyXkLzDZkdLS3k= + +atob@^2.1.2: + version "2.1.2" + resolved "https://registry.yarnpkg.com/atob/-/atob-2.1.2.tgz#6d9517eb9e030d2436666651e86bd9f6f13533c9" + integrity sha512-Wm6ukoaOGJi/73p/cl2GvLjTI5JM1k/O14isD73YML8StrH/7/lRFgmg8nICZgD3bZZvjwCGxtMOD3wWNAu8cg== + +autoprefixer@^9.6.1: + version "9.7.5" + resolved "https://registry.yarnpkg.com/autoprefixer/-/autoprefixer-9.7.5.tgz#8df10b9ff9b5814a8d411a5cfbab9c793c392376" + integrity sha512-URo6Zvt7VYifomeAfJlMFnYDhow1rk2bufwkbamPEAtQFcL11moLk4PnR7n9vlu7M+BkXAZkHFA0mIcY7tjQFg== + dependencies: + browserslist "^4.11.0" + caniuse-lite "^1.0.30001036" + chalk "^2.4.2" + normalize-range "^0.1.2" + num2fraction "^1.2.2" + postcss "^7.0.27" + postcss-value-parser "^4.0.3" + +aws-sign2@~0.7.0: + version "0.7.0" + resolved "https://registry.yarnpkg.com/aws-sign2/-/aws-sign2-0.7.0.tgz#b46e890934a9591f2d2f6f86d7e6a9f1b3fe76a8" + integrity sha1-tG6JCTSpWR8tL2+G1+ap8bP+dqg= + +aws4@^1.8.0: + version "1.9.1" + resolved "https://registry.yarnpkg.com/aws4/-/aws4-1.9.1.tgz#7e33d8f7d449b3f673cd72deb9abdc552dbe528e" + integrity sha512-wMHVg2EOHaMRxbzgFJ9gtjOOCrI80OHLG14rxi28XwOW8ux6IiEbRCGGGqCtdAIg4FQCbW20k9RsT4y3gJlFug== + +axobject-query@^2.0.2: + version "2.1.2" + resolved "https://registry.yarnpkg.com/axobject-query/-/axobject-query-2.1.2.tgz#2bdffc0371e643e5f03ba99065d5179b9ca79799" + integrity sha512-ICt34ZmrVt8UQnvPl6TVyDTkmhXmAyAT4Jh5ugfGUX4MOrZ+U/ZY6/sdylRw3qGNr9Ub5AJsaHeDMzNLehRdOQ== 
+ +babel-code-frame@^6.22.0: + version "6.26.0" + resolved "https://registry.yarnpkg.com/babel-code-frame/-/babel-code-frame-6.26.0.tgz#63fd43f7dc1e3bb7ce35947db8fe369a3f58c74b" + integrity sha1-Y/1D99weO7fONZR9uP42mj9Yx0s= + dependencies: + chalk "^1.1.3" + esutils "^2.0.2" + js-tokens "^3.0.2" + +babel-eslint@10.1.0: + version "10.1.0" + resolved "https://registry.yarnpkg.com/babel-eslint/-/babel-eslint-10.1.0.tgz#6968e568a910b78fb3779cdd8b6ac2f479943232" + integrity sha512-ifWaTHQ0ce+448CYop8AdrQiBsGrnC+bMgfyKFdi6EsPLTAWG+QfyDeM6OH+FmWnKvEq5NnBMLvlBUPKQZoDSg== + dependencies: + "@babel/code-frame" "^7.0.0" + "@babel/parser" "^7.7.0" + "@babel/traverse" "^7.7.0" + "@babel/types" "^7.7.0" + eslint-visitor-keys "^1.0.0" + resolve "^1.12.0" + +babel-extract-comments@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/babel-extract-comments/-/babel-extract-comments-1.0.0.tgz#0a2aedf81417ed391b85e18b4614e693a0351a21" + integrity sha512-qWWzi4TlddohA91bFwgt6zO/J0X+io7Qp184Fw0m2JYRSTZnJbFR8+07KmzudHCZgOiKRCrjhylwv9Xd8gfhVQ== + dependencies: + babylon "^6.18.0" + +babel-jest@^24.9.0: + version "24.9.0" + resolved "https://registry.yarnpkg.com/babel-jest/-/babel-jest-24.9.0.tgz#3fc327cb8467b89d14d7bc70e315104a783ccd54" + integrity sha512-ntuddfyiN+EhMw58PTNL1ph4C9rECiQXjI4nMMBKBaNjXvqLdkXpPRcMSr4iyBrJg/+wz9brFUD6RhOAT6r4Iw== + dependencies: + "@jest/transform" "^24.9.0" + "@jest/types" "^24.9.0" + "@types/babel__core" "^7.1.0" + babel-plugin-istanbul "^5.1.0" + babel-preset-jest "^24.9.0" + chalk "^2.4.2" + slash "^2.0.0" + +babel-loader@8.1.0, babel-loader@^8.0.6: + version "8.1.0" + resolved "https://registry.yarnpkg.com/babel-loader/-/babel-loader-8.1.0.tgz#c611d5112bd5209abe8b9fa84c3e4da25275f1c3" + integrity sha512-7q7nC1tYOrqvUrN3LQK4GwSk/TQorZSOlO9C+RZDZpODgyN4ZlCqE5q9cDsyWOliN+aU9B4JX01xK9eJXowJLw== + dependencies: + find-cache-dir "^2.1.0" + loader-utils "^1.4.0" + mkdirp "^0.5.3" + pify "^4.0.1" + schema-utils "^2.6.5" + +babel-plugin-dynamic-import-node@^2.3.0: + version "2.3.0" + resolved "https://registry.yarnpkg.com/babel-plugin-dynamic-import-node/-/babel-plugin-dynamic-import-node-2.3.0.tgz#f00f507bdaa3c3e3ff6e7e5e98d90a7acab96f7f" + integrity sha512-o6qFkpeQEBxcqt0XYlWzAVxNCSCZdUgcR8IRlhD/8DylxjjO4foPcvTW0GGKa/cVt3rvxZ7o5ippJ+/0nvLhlQ== + dependencies: + object.assign "^4.1.0" + +babel-plugin-istanbul@^5.1.0: + version "5.2.0" + resolved "https://registry.yarnpkg.com/babel-plugin-istanbul/-/babel-plugin-istanbul-5.2.0.tgz#df4ade83d897a92df069c4d9a25cf2671293c854" + integrity sha512-5LphC0USA8t4i1zCtjbbNb6jJj/9+X6P37Qfirc/70EQ34xKlMW+a1RHGwxGI+SwWpNwZ27HqvzAobeqaXwiZw== + dependencies: + "@babel/helper-plugin-utils" "^7.0.0" + find-up "^3.0.0" + istanbul-lib-instrument "^3.3.0" + test-exclude "^5.2.3" + +babel-plugin-jest-hoist@^24.9.0: + version "24.9.0" + resolved "https://registry.yarnpkg.com/babel-plugin-jest-hoist/-/babel-plugin-jest-hoist-24.9.0.tgz#4f837091eb407e01447c8843cbec546d0002d756" + integrity sha512-2EMA2P8Vp7lG0RAzr4HXqtYwacfMErOuv1U3wrvxHX6rD1sV6xS3WXG3r8TRQ2r6w8OhvSdWt+z41hQNwNm3Xw== + dependencies: + "@types/babel__traverse" "^7.0.6" + +babel-plugin-macros@2.8.0: + version "2.8.0" + resolved "https://registry.yarnpkg.com/babel-plugin-macros/-/babel-plugin-macros-2.8.0.tgz#0f958a7cc6556b1e65344465d99111a1e5e10138" + integrity sha512-SEP5kJpfGYqYKpBrj5XU3ahw5p5GOHJ0U5ssOSQ/WBVdwkD2Dzlce95exQTs3jOVWPPKLBN2rlEWkCK7dSmLvg== + dependencies: + "@babel/runtime" "^7.7.2" + cosmiconfig "^6.0.0" + resolve "^1.12.0" + +babel-plugin-named-asset-import@^0.3.6: 
+ version "0.3.6" + resolved "https://registry.yarnpkg.com/babel-plugin-named-asset-import/-/babel-plugin-named-asset-import-0.3.6.tgz#c9750a1b38d85112c9e166bf3ef7c5dbc605f4be" + integrity sha512-1aGDUfL1qOOIoqk9QKGIo2lANk+C7ko/fqH0uIyC71x3PEGz0uVP8ISgfEsFuG+FKmjHTvFK/nNM8dowpmUxLA== + +babel-plugin-syntax-object-rest-spread@^6.8.0: + version "6.13.0" + resolved "https://registry.yarnpkg.com/babel-plugin-syntax-object-rest-spread/-/babel-plugin-syntax-object-rest-spread-6.13.0.tgz#fd6536f2bce13836ffa3a5458c4903a597bb3bf5" + integrity sha1-/WU28rzhODb/o6VFjEkDpZe7O/U= + +babel-plugin-transform-object-rest-spread@^6.26.0: + version "6.26.0" + resolved "https://registry.yarnpkg.com/babel-plugin-transform-object-rest-spread/-/babel-plugin-transform-object-rest-spread-6.26.0.tgz#0f36692d50fef6b7e2d4b3ac1478137a963b7b06" + integrity sha1-DzZpLVD+9rfi1LOsFHgTepY7ewY= + dependencies: + babel-plugin-syntax-object-rest-spread "^6.8.0" + babel-runtime "^6.26.0" + +babel-plugin-transform-react-remove-prop-types@0.4.24: + version "0.4.24" + resolved "https://registry.yarnpkg.com/babel-plugin-transform-react-remove-prop-types/-/babel-plugin-transform-react-remove-prop-types-0.4.24.tgz#f2edaf9b4c6a5fbe5c1d678bfb531078c1555f3a" + integrity sha512-eqj0hVcJUR57/Ug2zE1Yswsw4LhuqqHhD+8v120T1cl3kjg76QwtyBrdIk4WVwK+lAhBJVYCd/v+4nc4y+8JsA== + +babel-preset-jest@^24.9.0: + version "24.9.0" + resolved "https://registry.yarnpkg.com/babel-preset-jest/-/babel-preset-jest-24.9.0.tgz#192b521e2217fb1d1f67cf73f70c336650ad3cdc" + integrity sha512-izTUuhE4TMfTRPF92fFwD2QfdXaZW08qvWTFCI51V8rW5x00UuPgc3ajRoWofXOuxjfcOM5zzSYsQS3H8KGCAg== + dependencies: + "@babel/plugin-syntax-object-rest-spread" "^7.0.0" + babel-plugin-jest-hoist "^24.9.0" + +babel-preset-react-app@^9.1.2: + version "9.1.2" + resolved "https://registry.yarnpkg.com/babel-preset-react-app/-/babel-preset-react-app-9.1.2.tgz#54775d976588a8a6d1a99201a702befecaf48030" + integrity sha512-k58RtQOKH21NyKtzptoAvtAODuAJJs3ZhqBMl456/GnXEQ/0La92pNmwgWoMn5pBTrsvk3YYXdY7zpY4e3UIxA== + dependencies: + "@babel/core" "7.9.0" + "@babel/plugin-proposal-class-properties" "7.8.3" + "@babel/plugin-proposal-decorators" "7.8.3" + "@babel/plugin-proposal-nullish-coalescing-operator" "7.8.3" + "@babel/plugin-proposal-numeric-separator" "7.8.3" + "@babel/plugin-proposal-optional-chaining" "7.9.0" + "@babel/plugin-transform-flow-strip-types" "7.9.0" + "@babel/plugin-transform-react-display-name" "7.8.3" + "@babel/plugin-transform-runtime" "7.9.0" + "@babel/preset-env" "7.9.0" + "@babel/preset-react" "7.9.1" + "@babel/preset-typescript" "7.9.0" + "@babel/runtime" "7.9.0" + babel-plugin-macros "2.8.0" + babel-plugin-transform-react-remove-prop-types "0.4.24" + +babel-runtime@^6.26.0: + version "6.26.0" + resolved "https://registry.yarnpkg.com/babel-runtime/-/babel-runtime-6.26.0.tgz#965c7058668e82b55d7bfe04ff2337bc8b5647fe" + integrity sha1-llxwWGaOgrVde/4E/yM3vItWR/4= + dependencies: + core-js "^2.4.0" + regenerator-runtime "^0.11.0" + +babylon@^6.18.0: + version "6.18.0" + resolved "https://registry.yarnpkg.com/babylon/-/babylon-6.18.0.tgz#af2f3b88fa6f5c1e4c634d1a0f8eac4f55b395e3" + integrity sha512-q/UEjfGJ2Cm3oKV71DJz9d25TPnq5rhBVL2Q4fA5wcC3jcrdn7+SssEybFIxwAvvP+YCsCYNKughoF33GxgycQ== + +balanced-match@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/balanced-match/-/balanced-match-1.0.0.tgz#89b4d199ab2bee49de164ea02b89ce462d71b767" + integrity sha1-ibTRmasr7kneFk6gK4nORi1xt2c= + +base64-js@^1.0.2: + version "1.3.1" + resolved 
"https://registry.yarnpkg.com/base64-js/-/base64-js-1.3.1.tgz#58ece8cb75dd07e71ed08c736abc5fac4dbf8df1" + integrity sha512-mLQ4i2QO1ytvGWFWmcngKO//JXAQueZvwEKtjgQFM4jIK0kU+ytMfplL8j+n5mspOfjHwoAg+9yhb7BwAHm36g== + +base@^0.11.1: + version "0.11.2" + resolved "https://registry.yarnpkg.com/base/-/base-0.11.2.tgz#7bde5ced145b6d551a90db87f83c558b4eb48a8f" + integrity sha512-5T6P4xPgpp0YDFvSWwEZ4NoE3aM4QBQXDzmVbraCkFj8zHM+mba8SyqB5DbZWyR7mYHo6Y7BdQo3MoA4m0TeQg== + dependencies: + cache-base "^1.0.1" + class-utils "^0.3.5" + component-emitter "^1.2.1" + define-property "^1.0.0" + isobject "^3.0.1" + mixin-deep "^1.2.0" + pascalcase "^0.1.1" + +batch@0.6.1: + version "0.6.1" + resolved "https://registry.yarnpkg.com/batch/-/batch-0.6.1.tgz#dc34314f4e679318093fc760272525f94bf25c16" + integrity sha1-3DQxT05nkxgJP8dgJyUl+UvyXBY= + +bcrypt-pbkdf@^1.0.0: + version "1.0.2" + resolved "https://registry.yarnpkg.com/bcrypt-pbkdf/-/bcrypt-pbkdf-1.0.2.tgz#a4301d389b6a43f9b67ff3ca11a3f6637e360e9e" + integrity sha1-pDAdOJtqQ/m2f/PKEaP2Y342Dp4= + dependencies: + tweetnacl "^0.14.3" + +big.js@^5.2.2: + version "5.2.2" + resolved "https://registry.yarnpkg.com/big.js/-/big.js-5.2.2.tgz#65f0af382f578bcdc742bd9c281e9cb2d7768328" + integrity sha512-vyL2OymJxmarO8gxMr0mhChsO9QGwhynfuu4+MHTAW6czfq9humCB7rKpUjDd9YUiDPU4mzpyupFSvOClAwbmQ== + +binary-extensions@^1.0.0: + version "1.13.1" + resolved "https://registry.yarnpkg.com/binary-extensions/-/binary-extensions-1.13.1.tgz#598afe54755b2868a5330d2aff9d4ebb53209b65" + integrity sha512-Un7MIEDdUC5gNpcGDV97op1Ywk748MpHcFTHoYs6qnj1Z3j7I53VG3nwZhKzoBZmbdRNnb6WRdFlwl7tSDuZGw== + +binary-extensions@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/binary-extensions/-/binary-extensions-2.0.0.tgz#23c0df14f6a88077f5f986c0d167ec03c3d5537c" + integrity sha512-Phlt0plgpIIBOGTT/ehfFnbNlfsDEiqmzE2KRXoX1bLIlir4X/MR+zSyBEkL05ffWgnRSf/DXv+WrUAVr93/ow== + +bindings@^1.5.0: + version "1.5.0" + resolved "https://registry.yarnpkg.com/bindings/-/bindings-1.5.0.tgz#10353c9e945334bc0511a6d90b38fbc7c9c504df" + integrity sha512-p2q/t/mhvuOj/UeLlV6566GD/guowlr0hHxClI0W9m7MWYkL1F0hLo+0Aexs9HSPCtR1SXQ0TD3MMKrXZajbiQ== + dependencies: + file-uri-to-path "1.0.0" + +bluebird@^3.5.5: + version "3.7.2" + resolved "https://registry.yarnpkg.com/bluebird/-/bluebird-3.7.2.tgz#9f229c15be272454ffa973ace0dbee79a1b0c36f" + integrity sha512-XpNj6GDQzdfW+r2Wnn7xiSAd7TM3jzkxGXBGTtWKuSXv1xUV+azxAm8jdWZN06QTQk+2N2XB9jRDkvbmQmcRtg== + +bn.js@^4.0.0, bn.js@^4.1.0, bn.js@^4.1.1, bn.js@^4.4.0: + version "4.11.8" + resolved "https://registry.yarnpkg.com/bn.js/-/bn.js-4.11.8.tgz#2cde09eb5ee341f484746bb0309b3253b1b1442f" + integrity sha512-ItfYfPLkWHUjckQCk8xC+LwxgK8NYcXywGigJgSwOP8Y2iyWT4f2vsZnoOXTTbo+o5yXmIUJ4gn5538SO5S3gA== + +body-parser@1.19.0: + version "1.19.0" + resolved "https://registry.yarnpkg.com/body-parser/-/body-parser-1.19.0.tgz#96b2709e57c9c4e09a6fd66a8fd979844f69f08a" + integrity sha512-dhEPs72UPbDnAQJ9ZKMNTP6ptJaionhP5cBb541nXPlW60Jepo9RV/a4fX4XWW9CuFNK22krhrj1+rgzifNCsw== + dependencies: + bytes "3.1.0" + content-type "~1.0.4" + debug "2.6.9" + depd "~1.1.2" + http-errors "1.7.2" + iconv-lite "0.4.24" + on-finished "~2.3.0" + qs "6.7.0" + raw-body "2.4.0" + type-is "~1.6.17" + +bonjour@^3.5.0: + version "3.5.0" + resolved "https://registry.yarnpkg.com/bonjour/-/bonjour-3.5.0.tgz#8e890a183d8ee9a2393b3844c691a42bcf7bc9f5" + integrity sha1-jokKGD2O6aI5OzhExpGkK897yfU= + dependencies: + array-flatten "^2.1.0" + deep-equal "^1.0.1" + dns-equal "^1.0.0" + dns-txt "^2.0.2" + multicast-dns 
"^6.0.1" + multicast-dns-service-types "^1.1.0" + +boolbase@^1.0.0, boolbase@~1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/boolbase/-/boolbase-1.0.0.tgz#68dff5fbe60c51eb37725ea9e3ed310dcc1e776e" + integrity sha1-aN/1++YMUes3cl6p4+0xDcwed24= + +brace-expansion@^1.1.7: + version "1.1.11" + resolved "https://registry.yarnpkg.com/brace-expansion/-/brace-expansion-1.1.11.tgz#3c7fcbf529d87226f3d2f52b966ff5271eb441dd" + integrity sha512-iCuPHDFgrHX7H2vEI/5xpz07zSHB00TpugqhmYtVmMO6518mCuRMoOYFldEBl0g187ufozdaHgWKcYFb61qGiA== + dependencies: + balanced-match "^1.0.0" + concat-map "0.0.1" + +braces@^2.3.1, braces@^2.3.2: + version "2.3.2" + resolved "https://registry.yarnpkg.com/braces/-/braces-2.3.2.tgz#5979fd3f14cd531565e5fa2df1abfff1dfaee729" + integrity sha512-aNdbnj9P8PjdXU4ybaWLK2IF3jc/EoDYbC7AazW6to3TRsfXxscC9UXOB5iDiEQrkyIbWp2SLQda4+QAa7nc3w== + dependencies: + arr-flatten "^1.1.0" + array-unique "^0.3.2" + extend-shallow "^2.0.1" + fill-range "^4.0.0" + isobject "^3.0.1" + repeat-element "^1.1.2" + snapdragon "^0.8.1" + snapdragon-node "^2.0.1" + split-string "^3.0.2" + to-regex "^3.0.1" + +braces@~3.0.2: + version "3.0.2" + resolved "https://registry.yarnpkg.com/braces/-/braces-3.0.2.tgz#3454e1a462ee8d599e236df336cd9ea4f8afe107" + integrity sha512-b8um+L1RzM3WDSzvhm6gIz1yfTbBt6YTlcEKAvsmqCZZFw46z626lVj9j1yEPW33H5H+lBQpZMP1k8l+78Ha0A== + dependencies: + fill-range "^7.0.1" + +brorand@^1.0.1: + version "1.1.0" + resolved "https://registry.yarnpkg.com/brorand/-/brorand-1.1.0.tgz#12c25efe40a45e3c323eb8675a0a0ce57b22371f" + integrity sha1-EsJe/kCkXjwyPrhnWgoM5XsiNx8= + +browser-process-hrtime@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/browser-process-hrtime/-/browser-process-hrtime-1.0.0.tgz#3c9b4b7d782c8121e56f10106d84c0d0ffc94626" + integrity sha512-9o5UecI3GhkpM6DrXr69PblIuWxPKk9Y0jHBRhdocZ2y7YECBFCsHm79Pr3OyR2AvjhDkabFJaDJMYRazHgsow== + +browser-resolve@^1.11.3: + version "1.11.3" + resolved "https://registry.yarnpkg.com/browser-resolve/-/browser-resolve-1.11.3.tgz#9b7cbb3d0f510e4cb86bdbd796124d28b5890af6" + integrity sha512-exDi1BYWB/6raKHmDTCicQfTkqwN5fioMFV4j8BsfMU4R2DK/QfZfK7kOVkmWCNANf0snkBzqGqAJBao9gZMdQ== + dependencies: + resolve "1.1.7" + +browserify-aes@^1.0.0, browserify-aes@^1.0.4: + version "1.2.0" + resolved "https://registry.yarnpkg.com/browserify-aes/-/browserify-aes-1.2.0.tgz#326734642f403dabc3003209853bb70ad428ef48" + integrity sha512-+7CHXqGuspUn/Sl5aO7Ea0xWGAtETPXNSAjHo48JfLdPWcMng33Xe4znFvQweqc/uzk5zSOI3H52CYnjCfb5hA== + dependencies: + buffer-xor "^1.0.3" + cipher-base "^1.0.0" + create-hash "^1.1.0" + evp_bytestokey "^1.0.3" + inherits "^2.0.1" + safe-buffer "^5.0.1" + +browserify-cipher@^1.0.0: + version "1.0.1" + resolved "https://registry.yarnpkg.com/browserify-cipher/-/browserify-cipher-1.0.1.tgz#8d6474c1b870bfdabcd3bcfcc1934a10e94f15f0" + integrity sha512-sPhkz0ARKbf4rRQt2hTpAHqn47X3llLkUGn+xEJzLjwY8LRs2p0v7ljvI5EyoRO/mexrNunNECisZs+gw2zz1w== + dependencies: + browserify-aes "^1.0.4" + browserify-des "^1.0.0" + evp_bytestokey "^1.0.0" + +browserify-des@^1.0.0: + version "1.0.2" + resolved "https://registry.yarnpkg.com/browserify-des/-/browserify-des-1.0.2.tgz#3af4f1f59839403572f1c66204375f7a7f703e9c" + integrity sha512-BioO1xf3hFwz4kc6iBhI3ieDFompMhrMlnDFC4/0/vd5MokpuAc3R+LYbwTA9A5Yc9pq9UYPqffKpW2ObuwX5A== + dependencies: + cipher-base "^1.0.1" + des.js "^1.0.0" + inherits "^2.0.1" + safe-buffer "^5.1.2" + +browserify-rsa@^4.0.0: + version "4.0.1" + resolved 
"https://registry.yarnpkg.com/browserify-rsa/-/browserify-rsa-4.0.1.tgz#21e0abfaf6f2029cf2fafb133567a701d4135524" + integrity sha1-IeCr+vbyApzy+vsTNWenAdQTVSQ= + dependencies: + bn.js "^4.1.0" + randombytes "^2.0.1" + +browserify-sign@^4.0.0: + version "4.0.4" + resolved "https://registry.yarnpkg.com/browserify-sign/-/browserify-sign-4.0.4.tgz#aa4eb68e5d7b658baa6bf6a57e630cbd7a93d298" + integrity sha1-qk62jl17ZYuqa/alfmMMvXqT0pg= + dependencies: + bn.js "^4.1.1" + browserify-rsa "^4.0.0" + create-hash "^1.1.0" + create-hmac "^1.1.2" + elliptic "^6.0.0" + inherits "^2.0.1" + parse-asn1 "^5.0.0" + +browserify-zlib@^0.2.0: + version "0.2.0" + resolved "https://registry.yarnpkg.com/browserify-zlib/-/browserify-zlib-0.2.0.tgz#2869459d9aa3be245fe8fe2ca1f46e2e7f54d73f" + integrity sha512-Z942RysHXmJrhqk88FmKBVq/v5tqmSkDz7p54G/MGyjMnCFFnC79XWNbg+Vta8W6Wb2qtSZTSxIGkJrRpCFEiA== + dependencies: + pako "~1.0.5" + +browserslist@4.10.0: + version "4.10.0" + resolved "https://registry.yarnpkg.com/browserslist/-/browserslist-4.10.0.tgz#f179737913eaf0d2b98e4926ac1ca6a15cbcc6a9" + integrity sha512-TpfK0TDgv71dzuTsEAlQiHeWQ/tiPqgNZVdv046fvNtBZrjbv2O3TsWCDU0AWGJJKCF/KsjNdLzR9hXOsh/CfA== + dependencies: + caniuse-lite "^1.0.30001035" + electron-to-chromium "^1.3.378" + node-releases "^1.1.52" + pkg-up "^3.1.0" + +browserslist@^4.0.0, browserslist@^4.11.0, browserslist@^4.6.2, browserslist@^4.6.4, browserslist@^4.8.3, browserslist@^4.9.1: + version "4.11.1" + resolved "https://registry.yarnpkg.com/browserslist/-/browserslist-4.11.1.tgz#92f855ee88d6e050e7e7311d987992014f1a1f1b" + integrity sha512-DCTr3kDrKEYNw6Jb9HFxVLQNaue8z+0ZfRBRjmCunKDEXEBajKDj2Y+Uelg+Pi29OnvaSGwjOsnRyNEkXzHg5g== + dependencies: + caniuse-lite "^1.0.30001038" + electron-to-chromium "^1.3.390" + node-releases "^1.1.53" + pkg-up "^2.0.0" + +bser@2.1.1: + version "2.1.1" + resolved "https://registry.yarnpkg.com/bser/-/bser-2.1.1.tgz#e6787da20ece9d07998533cfd9de6f5c38f4bc05" + integrity sha512-gQxTNE/GAfIIrmHLUE3oJyp5FO6HRBfhjnw4/wMmA63ZGDJnWBmgY/lyQBpnDUkGmAhbSe39tx2d/iTOAfglwQ== + dependencies: + node-int64 "^0.4.0" + +buffer-from@^1.0.0: + version "1.1.1" + resolved "https://registry.yarnpkg.com/buffer-from/-/buffer-from-1.1.1.tgz#32713bc028f75c02fdb710d7c7bcec1f2c6070ef" + integrity sha512-MQcXEUbCKtEo7bhqEs6560Hyd4XaovZlO/k9V3hjVUF/zwW7KBVdSK4gIt/bzwS9MbR5qob+F5jusZsb0YQK2A== + +buffer-indexof@^1.0.0: + version "1.1.1" + resolved "https://registry.yarnpkg.com/buffer-indexof/-/buffer-indexof-1.1.1.tgz#52fabcc6a606d1a00302802648ef68f639da268c" + integrity sha512-4/rOEg86jivtPTeOUUT61jJO1Ya1TrR/OkqCSZDyq84WJh3LuuiphBYJN+fm5xufIk4XAFcEwte/8WzC8If/1g== + +buffer-xor@^1.0.3: + version "1.0.3" + resolved "https://registry.yarnpkg.com/buffer-xor/-/buffer-xor-1.0.3.tgz#26e61ed1422fb70dd42e6e36729ed51d855fe8d9" + integrity sha1-JuYe0UIvtw3ULm42cp7VHYVf6Nk= + +buffer@^4.3.0: + version "4.9.2" + resolved "https://registry.yarnpkg.com/buffer/-/buffer-4.9.2.tgz#230ead344002988644841ab0244af8c44bbe3ef8" + integrity sha512-xq+q3SRMOxGivLhBNaUdC64hDTQwejJ+H0T/NB1XMtTVEwNTrfFF3gAxiyW0Bu/xWEGhjVKgUcMhCrUy2+uCWg== + dependencies: + base64-js "^1.0.2" + ieee754 "^1.1.4" + isarray "^1.0.0" + +builtin-status-codes@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/builtin-status-codes/-/builtin-status-codes-3.0.0.tgz#85982878e21b98e1c66425e03d0174788f569ee8" + integrity sha1-hZgoeOIbmOHGZCXgPQF0eI9Wnug= + +bytes@3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/bytes/-/bytes-3.0.0.tgz#d32815404d689699f85a4ea4fa8755dd13a96048" + 
integrity sha1-0ygVQE1olpn4Wk6k+odV3ROpYEg= + +bytes@3.1.0: + version "3.1.0" + resolved "https://registry.yarnpkg.com/bytes/-/bytes-3.1.0.tgz#f6cf7933a360e0588fa9fde85651cdc7f805d1f6" + integrity sha512-zauLjrfCG+xvoyaqLoV8bLVXXNGC4JqlxFCutSDWA6fJrTo2ZuvLYTqZ7aHBLZSMOopbzwv8f+wZcVzfVTI2Dg== + +cacache@^12.0.2: + version "12.0.4" + resolved "https://registry.yarnpkg.com/cacache/-/cacache-12.0.4.tgz#668bcbd105aeb5f1d92fe25570ec9525c8faa40c" + integrity sha512-a0tMB40oefvuInr4Cwb3GerbL9xTj1D5yg0T5xrjGCGyfvbxseIXX7BAO/u/hIXdafzOI5JC3wDwHyf24buOAQ== + dependencies: + bluebird "^3.5.5" + chownr "^1.1.1" + figgy-pudding "^3.5.1" + glob "^7.1.4" + graceful-fs "^4.1.15" + infer-owner "^1.0.3" + lru-cache "^5.1.1" + mississippi "^3.0.0" + mkdirp "^0.5.1" + move-concurrently "^1.0.1" + promise-inflight "^1.0.1" + rimraf "^2.6.3" + ssri "^6.0.1" + unique-filename "^1.1.1" + y18n "^4.0.0" + +cacache@^13.0.1: + version "13.0.1" + resolved "https://registry.yarnpkg.com/cacache/-/cacache-13.0.1.tgz#a8000c21697089082f85287a1aec6e382024a71c" + integrity sha512-5ZvAxd05HDDU+y9BVvcqYu2LLXmPnQ0hW62h32g4xBTgL/MppR4/04NHfj/ycM2y6lmTnbw6HVi+1eN0Psba6w== + dependencies: + chownr "^1.1.2" + figgy-pudding "^3.5.1" + fs-minipass "^2.0.0" + glob "^7.1.4" + graceful-fs "^4.2.2" + infer-owner "^1.0.4" + lru-cache "^5.1.1" + minipass "^3.0.0" + minipass-collect "^1.0.2" + minipass-flush "^1.0.5" + minipass-pipeline "^1.2.2" + mkdirp "^0.5.1" + move-concurrently "^1.0.1" + p-map "^3.0.0" + promise-inflight "^1.0.1" + rimraf "^2.7.1" + ssri "^7.0.0" + unique-filename "^1.1.1" + +cache-base@^1.0.1: + version "1.0.1" + resolved "https://registry.yarnpkg.com/cache-base/-/cache-base-1.0.1.tgz#0a7f46416831c8b662ee36fe4e7c59d76f666ab2" + integrity sha512-AKcdTnFSWATd5/GCPRxr2ChwIJ85CeyrEyjRHlKxQ56d4XJMGym0uAiKn0xbLOGOl3+yRpOTi484dVCEc5AUzQ== + dependencies: + collection-visit "^1.0.0" + component-emitter "^1.2.1" + get-value "^2.0.6" + has-value "^1.0.0" + isobject "^3.0.1" + set-value "^2.0.0" + to-object-path "^0.3.0" + union-value "^1.0.0" + unset-value "^1.0.0" + +call-me-maybe@^1.0.1: + version "1.0.1" + resolved "https://registry.yarnpkg.com/call-me-maybe/-/call-me-maybe-1.0.1.tgz#26d208ea89e37b5cbde60250a15f031c16a4d66b" + integrity sha1-JtII6onje1y95gJQoV8DHBak1ms= + +caller-callsite@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/caller-callsite/-/caller-callsite-2.0.0.tgz#847e0fce0a223750a9a027c54b33731ad3154134" + integrity sha1-hH4PzgoiN1CpoCfFSzNzGtMVQTQ= + dependencies: + callsites "^2.0.0" + +caller-path@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/caller-path/-/caller-path-2.0.0.tgz#468f83044e369ab2010fac5f06ceee15bb2cb1f4" + integrity sha1-Ro+DBE42mrIBD6xfBs7uFbsssfQ= + dependencies: + caller-callsite "^2.0.0" + +callsites@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/callsites/-/callsites-2.0.0.tgz#06eb84f00eea413da86affefacbffb36093b3c50" + integrity sha1-BuuE8A7qQT2oav/vrL/7Ngk7PFA= + +callsites@^3.0.0: + version "3.1.0" + resolved "https://registry.yarnpkg.com/callsites/-/callsites-3.1.0.tgz#b3630abd8943432f54b3f0519238e33cd7df2f73" + integrity sha512-P8BjAsXvZS+VIDUI11hHCQEv74YT67YUi5JJFNWIqL235sBmjX4+qx9Muvls5ivyNENctx46xQLQ3aTuE7ssaQ== + +camel-case@^4.1.1: + version "4.1.1" + resolved "https://registry.yarnpkg.com/camel-case/-/camel-case-4.1.1.tgz#1fc41c854f00e2f7d0139dfeba1542d6896fe547" + integrity sha512-7fa2WcG4fYFkclIvEmxBbTvmibwF2/agfEBc6q3lOpVu0A13ltLsA+Hr/8Hp6kp5f+G7hKi6t8lys6XxP+1K6Q== + dependencies: + pascal-case "^3.1.1" + tslib 
"^1.10.0" + +camelcase@5.0.0: + version "5.0.0" + resolved "https://registry.yarnpkg.com/camelcase/-/camelcase-5.0.0.tgz#03295527d58bd3cd4aa75363f35b2e8d97be2f42" + integrity sha512-faqwZqnWxbxn+F1d399ygeamQNy3lPp/H9H6rNrqYh4FSVCtcY+3cub1MxA8o9mDd55mM8Aghuu/kuyYA6VTsA== + +camelcase@5.3.1, camelcase@^5.0.0, camelcase@^5.3.1: + version "5.3.1" + resolved "https://registry.yarnpkg.com/camelcase/-/camelcase-5.3.1.tgz#e3c9b31569e106811df242f715725a1f4c494320" + integrity sha512-L28STB170nwWS63UjtlEOE3dldQApaJXZkOI1uMFfzf3rRuPegHaHesyee+YxQ+W6SvRDQV6UrdOdRiR153wJg== + +caniuse-api@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/caniuse-api/-/caniuse-api-3.0.0.tgz#5e4d90e2274961d46291997df599e3ed008ee4c0" + integrity sha512-bsTwuIg/BZZK/vreVTYYbSWoe2F+71P7K5QGEX+pT250DZbfU1MQ5prOKpPR+LL6uWKK3KMwMCAS74QB3Um1uw== + dependencies: + browserslist "^4.0.0" + caniuse-lite "^1.0.0" + lodash.memoize "^4.1.2" + lodash.uniq "^4.5.0" + +caniuse-lite@^1.0.0, caniuse-lite@^1.0.30000981, caniuse-lite@^1.0.30001035, caniuse-lite@^1.0.30001036, caniuse-lite@^1.0.30001038: + version "1.0.30001038" + resolved "https://registry.yarnpkg.com/caniuse-lite/-/caniuse-lite-1.0.30001038.tgz#44da3cbca2ab6cb6aa83d1be5d324e17f141caff" + integrity sha512-zii9quPo96XfOiRD4TrfYGs+QsGZpb2cGiMAzPjtf/hpFgB6zCPZgJb7I1+EATeMw/o+lG8FyRAnI+CWStHcaQ== + +capture-exit@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/capture-exit/-/capture-exit-2.0.0.tgz#fb953bfaebeb781f62898239dabb426d08a509a4" + integrity sha512-PiT/hQmTonHhl/HFGN+Lx3JJUznrVYJ3+AQsnthneZbvW7x+f08Tk7yLJTLEOUvBTbduLeeBkxEaYXUOUrRq6g== + dependencies: + rsvp "^4.8.4" + +case-sensitive-paths-webpack-plugin@2.3.0: + version "2.3.0" + resolved "https://registry.yarnpkg.com/case-sensitive-paths-webpack-plugin/-/case-sensitive-paths-webpack-plugin-2.3.0.tgz#23ac613cc9a856e4f88ff8bb73bbb5e989825cf7" + integrity sha512-/4YgnZS8y1UXXmC02xD5rRrBEu6T5ub+mQHLNRj0fzTRbgdBYhsNo2V5EqwgqrExjxsjtF/OpAKAMkKsxbD5XQ== + +caseless@~0.12.0: + version "0.12.0" + resolved "https://registry.yarnpkg.com/caseless/-/caseless-0.12.0.tgz#1b681c21ff84033c826543090689420d187151dc" + integrity sha1-G2gcIf+EAzyCZUMJBolCDRhxUdw= + +chalk@2.4.2, chalk@^2.0.0, chalk@^2.0.1, chalk@^2.1.0, chalk@^2.4.1, chalk@^2.4.2: + version "2.4.2" + resolved "https://registry.yarnpkg.com/chalk/-/chalk-2.4.2.tgz#cd42541677a54333cf541a49108c1432b44c9424" + integrity sha512-Mti+f9lpJNcwF4tWV8/OrTTtF1gZi+f8FqlyAdouralcFWFQWF2+NgCHShjkCb+IFBLq9buZwE1xckQU4peSuQ== + dependencies: + ansi-styles "^3.2.1" + escape-string-regexp "^1.0.5" + supports-color "^5.3.0" + +chalk@^1.1.3: + version "1.1.3" + resolved "https://registry.yarnpkg.com/chalk/-/chalk-1.1.3.tgz#a8115c55e4a702fe4d150abd3872822a7e09fc98" + integrity sha1-qBFcVeSnAv5NFQq9OHKCKn4J/Jg= + dependencies: + ansi-styles "^2.2.1" + escape-string-regexp "^1.0.2" + has-ansi "^2.0.0" + strip-ansi "^3.0.0" + supports-color "^2.0.0" + +chalk@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/chalk/-/chalk-3.0.0.tgz#3f73c2bf526591f574cc492c51e2456349f844e4" + integrity sha512-4D3B6Wf41KOYRFdszmDqMCGq5VV/uMAB273JILmO+3jAlh8X4qDtdtgCR3fxtbLEMzSx22QdhnDcJvu2u1fVwg== + dependencies: + ansi-styles "^4.1.0" + supports-color "^7.1.0" + +chardet@^0.7.0: + version "0.7.0" + resolved "https://registry.yarnpkg.com/chardet/-/chardet-0.7.0.tgz#90094849f0937f2eedc2425d0d28a9e5f0cbad9e" + integrity sha512-mT8iDcrh03qDGRRmoA2hmBJnxpllMR+0/0qlzjqZES6NdiWDcZkCNAk4rPFZ9Q85r27unkiNNg8ZOiwZXBHwcA== + +chart.js@^2.9.4: + version "2.9.4" + resolved 
"https://registry.yarnpkg.com/chart.js/-/chart.js-2.9.4.tgz#0827f9563faffb2dc5c06562f8eb10337d5b9684" + integrity sha512-B07aAzxcrikjAPyV+01j7BmOpxtQETxTSlQ26BEYJ+3iUkbNKaOJ/nDbT6JjyqYxseM0ON12COHYdU2cTIjC7A== + dependencies: + chartjs-color "^2.1.0" + moment "^2.10.2" + +chartjs-color-string@^0.6.0: + version "0.6.0" + resolved "https://registry.yarnpkg.com/chartjs-color-string/-/chartjs-color-string-0.6.0.tgz#1df096621c0e70720a64f4135ea171d051402f71" + integrity sha512-TIB5OKn1hPJvO7JcteW4WY/63v6KwEdt6udfnDE9iCAZgy+V4SrbSxoIbTw/xkUIapjEI4ExGtD0+6D3KyFd7A== + dependencies: + color-name "^1.0.0" + +chartjs-color@^2.1.0: + version "2.4.1" + resolved "https://registry.yarnpkg.com/chartjs-color/-/chartjs-color-2.4.1.tgz#6118bba202fe1ea79dd7f7c0f9da93467296c3b0" + integrity sha512-haqOg1+Yebys/Ts/9bLo/BqUcONQOdr/hoEr2LLTRl6C5LXctUdHxsCYfvQVg5JIxITrfCNUDr4ntqmQk9+/0w== + dependencies: + chartjs-color-string "^0.6.0" + color-convert "^1.9.3" + +chokidar@^2.1.8: + version "2.1.8" + resolved "https://registry.yarnpkg.com/chokidar/-/chokidar-2.1.8.tgz#804b3a7b6a99358c3c5c61e71d8728f041cff917" + integrity sha512-ZmZUazfOzf0Nve7duiCKD23PFSCs4JPoYyccjUFF3aQkQadqBhfzhjkwBH2mNOG9cTBwhamM37EIsIkZw3nRgg== + dependencies: + anymatch "^2.0.0" + async-each "^1.0.1" + braces "^2.3.2" + glob-parent "^3.1.0" + inherits "^2.0.3" + is-binary-path "^1.0.0" + is-glob "^4.0.0" + normalize-path "^3.0.0" + path-is-absolute "^1.0.0" + readdirp "^2.2.1" + upath "^1.1.1" + optionalDependencies: + fsevents "^1.2.7" + +chokidar@^3.3.0: + version "3.3.1" + resolved "https://registry.yarnpkg.com/chokidar/-/chokidar-3.3.1.tgz#c84e5b3d18d9a4d77558fef466b1bf16bbeb3450" + integrity sha512-4QYCEWOcK3OJrxwvyyAOxFuhpvOVCYkr33LPfFNBjAD/w3sEzWsp2BUOkI4l9bHvWioAd0rc6NlHUOEaWkTeqg== + dependencies: + anymatch "~3.1.1" + braces "~3.0.2" + glob-parent "~5.1.0" + is-binary-path "~2.1.0" + is-glob "~4.0.1" + normalize-path "~3.0.0" + readdirp "~3.3.0" + optionalDependencies: + fsevents "~2.1.2" + +chownr@^1.1.1, chownr@^1.1.2: + version "1.1.4" + resolved "https://registry.yarnpkg.com/chownr/-/chownr-1.1.4.tgz#6fc9d7b42d32a583596337666e7d08084da2cc6b" + integrity sha512-jJ0bqzaylmJtVnNgzTeSOs8DPavpbYgEr/b0YL8/2GO3xJEhInFmhKMUnEJQjZumK7KXGFhUy89PrsJWlakBVg== + +chrome-trace-event@^1.0.2: + version "1.0.2" + resolved "https://registry.yarnpkg.com/chrome-trace-event/-/chrome-trace-event-1.0.2.tgz#234090ee97c7d4ad1a2c4beae27505deffc608a4" + integrity sha512-9e/zx1jw7B4CO+c/RXoCsfg/x1AfUBioy4owYH0bJprEYAx5hRFLRhWBqHAG57D0ZM4H7vxbP7bPe0VwhQRYDQ== + dependencies: + tslib "^1.9.0" + +ci-info@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/ci-info/-/ci-info-2.0.0.tgz#67a9e964be31a51e15e5010d58e6f12834002f46" + integrity sha512-5tK7EtrZ0N+OLFMthtqOj4fI2Jeb88C4CAZPu25LDVUgXJ0A3Js4PMGqrn0JU1W0Mh1/Z8wZzYPxqUrXeBboCQ== + +cipher-base@^1.0.0, cipher-base@^1.0.1, cipher-base@^1.0.3: + version "1.0.4" + resolved "https://registry.yarnpkg.com/cipher-base/-/cipher-base-1.0.4.tgz#8760e4ecc272f4c363532f926d874aae2c1397de" + integrity sha512-Kkht5ye6ZGmwv40uUDZztayT2ThLQGfnj/T71N/XzeZeo3nf8foyW7zGTsPYkEya3m5f3cAypH+qe7YOrM1U2Q== + dependencies: + inherits "^2.0.1" + safe-buffer "^5.0.1" + +class-utils@^0.3.5: + version "0.3.6" + resolved "https://registry.yarnpkg.com/class-utils/-/class-utils-0.3.6.tgz#f93369ae8b9a7ce02fd41faad0ca83033190c463" + integrity sha512-qOhPa/Fj7s6TY8H8esGu5QNpMMQxz79h+urzrNYN6mn+9BnxlDGf5QZ+XeCDsxSjPqsSR56XOZOJmpeurnLMeg== + dependencies: + arr-union "^3.1.0" + define-property "^0.2.5" + isobject "^3.0.0" + 
static-extend "^0.1.1" + +clean-css@^4.2.3: + version "4.2.3" + resolved "https://registry.yarnpkg.com/clean-css/-/clean-css-4.2.3.tgz#507b5de7d97b48ee53d84adb0160ff6216380f78" + integrity sha512-VcMWDN54ZN/DS+g58HYL5/n4Zrqe8vHJpGA8KdgUXFU4fuP/aHNw8eld9SyEIyabIMJX/0RaY/fplOo5hYLSFA== + dependencies: + source-map "~0.6.0" + +clean-stack@^2.0.0: + version "2.2.0" + resolved "https://registry.yarnpkg.com/clean-stack/-/clean-stack-2.2.0.tgz#ee8472dbb129e727b31e8a10a427dee9dfe4008b" + integrity sha512-4diC9HaTE+KRAMWhDhrGOECgWZxoevMc5TlkObMqNSsVU62PYzXZ/SMTjzyGAFF1YusgxGcSWTEXBhp0CPwQ1A== + +cli-cursor@^3.1.0: + version "3.1.0" + resolved "https://registry.yarnpkg.com/cli-cursor/-/cli-cursor-3.1.0.tgz#264305a7ae490d1d03bf0c9ba7c925d1753af307" + integrity sha512-I/zHAwsKf9FqGoXM4WWRACob9+SNukZTd94DWF57E4toouRulbCxcUh6RKUEOQlYTHJnzkPMySvPNaaSLNfLZw== + dependencies: + restore-cursor "^3.1.0" + +cli-width@^2.0.0: + version "2.2.0" + resolved "https://registry.yarnpkg.com/cli-width/-/cli-width-2.2.0.tgz#ff19ede8a9a5e579324147b0c11f0fbcbabed639" + integrity sha1-/xnt6Kml5XkyQUewwR8PvLq+1jk= + +cliui@^4.0.0: + version "4.1.0" + resolved "https://registry.yarnpkg.com/cliui/-/cliui-4.1.0.tgz#348422dbe82d800b3022eef4f6ac10bf2e4d1b49" + integrity sha512-4FG+RSG9DL7uEwRUZXZn3SS34DiDPfzP0VOiEwtUWlE+AR2EIg+hSyvrIgUUfhdgR/UkAeW2QHgeP+hWrXs7jQ== + dependencies: + string-width "^2.1.1" + strip-ansi "^4.0.0" + wrap-ansi "^2.0.0" + +cliui@^5.0.0: + version "5.0.0" + resolved "https://registry.yarnpkg.com/cliui/-/cliui-5.0.0.tgz#deefcfdb2e800784aa34f46fa08e06851c7bbbc5" + integrity sha512-PYeGSEmmHM6zvoef2w8TPzlrnNpXIjTipYK780YswmIP9vjxmd6Y2a3CB2Ks6/AU8NHjZugXvo8w3oWM2qnwXA== + dependencies: + string-width "^3.1.0" + strip-ansi "^5.2.0" + wrap-ansi "^5.1.0" + +clone-deep@^0.2.4: + version "0.2.4" + resolved "https://registry.yarnpkg.com/clone-deep/-/clone-deep-0.2.4.tgz#4e73dd09e9fb971cc38670c5dced9c1896481cc6" + integrity sha1-TnPdCen7lxzDhnDF3O2cGJZIHMY= + dependencies: + for-own "^0.1.3" + is-plain-object "^2.0.1" + kind-of "^3.0.2" + lazy-cache "^1.0.3" + shallow-clone "^0.1.2" + +clone-deep@^4.0.1: + version "4.0.1" + resolved "https://registry.yarnpkg.com/clone-deep/-/clone-deep-4.0.1.tgz#c19fd9bdbbf85942b4fd979c84dcf7d5f07c2387" + integrity sha512-neHB9xuzh/wk0dIHweyAXv2aPGZIVk3pLMe+/RNzINf17fe0OG96QroktYAUm7SM1PBnzTabaLboqqxDyMU+SQ== + dependencies: + is-plain-object "^2.0.4" + kind-of "^6.0.2" + shallow-clone "^3.0.0" + +co@^4.6.0: + version "4.6.0" + resolved "https://registry.yarnpkg.com/co/-/co-4.6.0.tgz#6ea6bdf3d853ae54ccb8e47bfa0bf3f9031fb184" + integrity sha1-bqa989hTrlTMuOR7+gvz+QMfsYQ= + +coa@^2.0.2: + version "2.0.2" + resolved "https://registry.yarnpkg.com/coa/-/coa-2.0.2.tgz#43f6c21151b4ef2bf57187db0d73de229e3e7ec3" + integrity sha512-q5/jG+YQnSy4nRTV4F7lPepBJZ8qBNJJDBuJdoejDyLXgmL7IEo+Le2JDZudFTFt7mrCqIRaSjws4ygRCTCAXA== + dependencies: + "@types/q" "^1.5.1" + chalk "^2.4.1" + q "^1.1.2" + +code-point-at@^1.0.0: + version "1.1.0" + resolved "https://registry.yarnpkg.com/code-point-at/-/code-point-at-1.1.0.tgz#0d070b4d043a5bea33a2f1a40e2edb3d9a4ccf77" + integrity sha1-DQcLTQQ6W+ozovGkDi7bPZpMz3c= + +collection-visit@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/collection-visit/-/collection-visit-1.0.0.tgz#4bc0373c164bc3291b4d368c829cf1a80a59dca0" + integrity sha1-S8A3PBZLwykbTTaMgpzxqApZ3KA= + dependencies: + map-visit "^1.0.0" + object-visit "^1.0.0" + +color-convert@^1.9.0, color-convert@^1.9.1, color-convert@^1.9.3: + version "1.9.3" + resolved 
"https://registry.yarnpkg.com/color-convert/-/color-convert-1.9.3.tgz#bb71850690e1f136567de629d2d5471deda4c1e8" + integrity sha512-QfAUtd+vFdAtFQcC8CCyYt1fYWxSqAiK2cSD6zDB8N3cpsEBAvRxp9zOGg6G/SHHJYAT88/az/IuDGALsNVbGg== + dependencies: + color-name "1.1.3" + +color-convert@^2.0.1: + version "2.0.1" + resolved "https://registry.yarnpkg.com/color-convert/-/color-convert-2.0.1.tgz#72d3a68d598c9bdb3af2ad1e84f21d896abd4de3" + integrity sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ== + dependencies: + color-name "~1.1.4" + +color-name@1.1.3: + version "1.1.3" + resolved "https://registry.yarnpkg.com/color-name/-/color-name-1.1.3.tgz#a7d0558bd89c42f795dd42328f740831ca53bc25" + integrity sha1-p9BVi9icQveV3UIyj3QIMcpTvCU= + +color-name@^1.0.0, color-name@~1.1.4: + version "1.1.4" + resolved "https://registry.yarnpkg.com/color-name/-/color-name-1.1.4.tgz#c2a09a87acbde69543de6f63fa3995c826c536a2" + integrity sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA== + +color-string@^1.5.2: + version "1.5.3" + resolved "https://registry.yarnpkg.com/color-string/-/color-string-1.5.3.tgz#c9bbc5f01b58b5492f3d6857459cb6590ce204cc" + integrity sha512-dC2C5qeWoYkxki5UAXapdjqO672AM4vZuPGRQfO8b5HKuKGBbKWpITyDYN7TOFKvRW7kOgAn3746clDBMDJyQw== + dependencies: + color-name "^1.0.0" + simple-swizzle "^0.2.2" + +color@^3.0.0: + version "3.1.2" + resolved "https://registry.yarnpkg.com/color/-/color-3.1.2.tgz#68148e7f85d41ad7649c5fa8c8106f098d229e10" + integrity sha512-vXTJhHebByxZn3lDvDJYw4lR5+uB3vuoHsuYA5AKuxRVn5wzzIfQKGLBmgdVRHKTJYeK5rvJcHnrd0Li49CFpg== + dependencies: + color-convert "^1.9.1" + color-string "^1.5.2" + +combined-stream@^1.0.6, combined-stream@~1.0.6: + version "1.0.8" + resolved "https://registry.yarnpkg.com/combined-stream/-/combined-stream-1.0.8.tgz#c3d45a8b34fd730631a110a8a2520682b31d5a7f" + integrity sha512-FQN4MRfuJeHf7cBbBMJFXhKSDq+2kAArBlmRBvcvFE5BB1HZKXtSFASDhdlz9zOYwxh8lDdnvmMOe/+5cdoEdg== + dependencies: + delayed-stream "~1.0.0" + +commander@^2.11.0, commander@^2.20.0: + version "2.20.3" + resolved "https://registry.yarnpkg.com/commander/-/commander-2.20.3.tgz#fd485e84c03eb4881c20722ba48035e8531aeb33" + integrity sha512-GpVkmM8vF2vQUkj2LvZmD35JxeJOLCwJ9cUkugyk2nuhbv3+mJvpLYYt+0+USMxE+oj+ey/lJEnhZw75x/OMcQ== + +commander@^4.1.1: + version "4.1.1" + resolved "https://registry.yarnpkg.com/commander/-/commander-4.1.1.tgz#9fd602bd936294e9e9ef46a3f4d6964044b18068" + integrity sha512-NOKm8xhkzAjzFx8B2v5OAHT+u5pRQc2UCa2Vq9jYL/31o2wi9mxBA7LIFs3sV5VSC49z6pEhfbMULvShKj26WA== + +common-tags@^1.8.0: + version "1.8.0" + resolved "https://registry.yarnpkg.com/common-tags/-/common-tags-1.8.0.tgz#8e3153e542d4a39e9b10554434afaaf98956a937" + integrity sha512-6P6g0uetGpW/sdyUy/iQQCbFF0kWVMSIVSyYz7Zgjcgh8mgw8PQzDNZeyZ5DQ2gM7LBoZPHmnjz8rUthkBG5tw== + +commondir@^1.0.1: + version "1.0.1" + resolved "https://registry.yarnpkg.com/commondir/-/commondir-1.0.1.tgz#ddd800da0c66127393cca5950ea968a3aaf1253b" + integrity sha1-3dgA2gxmEnOTzKWVDqloo6rxJTs= + +component-emitter@^1.2.1: + version "1.3.0" + resolved "https://registry.yarnpkg.com/component-emitter/-/component-emitter-1.3.0.tgz#16e4070fba8ae29b679f2215853ee181ab2eabc0" + integrity sha512-Rd3se6QB+sO1TwqZjscQrurpEPIfO0/yYnSin6Q/rD3mOutHvUrCAhJub3r90uNb+SESBuE0QYoB90YdfatsRg== + +compose-function@3.0.3: + version "3.0.3" + resolved "https://registry.yarnpkg.com/compose-function/-/compose-function-3.0.3.tgz#9ed675f13cc54501d30950a486ff6a7ba3ab185f" + integrity 
sha1-ntZ18TzFRQHTCVCkhv9qe6OrGF8= + dependencies: + arity-n "^1.0.4" + +compressible@~2.0.16: + version "2.0.18" + resolved "https://registry.yarnpkg.com/compressible/-/compressible-2.0.18.tgz#af53cca6b070d4c3c0750fbd77286a6d7cc46fba" + integrity sha512-AF3r7P5dWxL8MxyITRMlORQNaOA2IkAFaTr4k7BUumjPtRpGDTZpl0Pb1XCO6JeDCBdp126Cgs9sMxqSjgYyRg== + dependencies: + mime-db ">= 1.43.0 < 2" + +compression@^1.7.4: + version "1.7.4" + resolved "https://registry.yarnpkg.com/compression/-/compression-1.7.4.tgz#95523eff170ca57c29a0ca41e6fe131f41e5bb8f" + integrity sha512-jaSIDzP9pZVS4ZfQ+TzvtiWhdpFhE2RDHz8QJkpX9SIpLq88VueF5jJw6t+6CUQcAoA6t+x89MLrWAqpfDE8iQ== + dependencies: + accepts "~1.3.5" + bytes "3.0.0" + compressible "~2.0.16" + debug "2.6.9" + on-headers "~1.0.2" + safe-buffer "5.1.2" + vary "~1.1.2" + +concat-map@0.0.1: + version "0.0.1" + resolved "https://registry.yarnpkg.com/concat-map/-/concat-map-0.0.1.tgz#d8a96bd77fd68df7793a73036a3ba0d5405d477b" + integrity sha1-2Klr13/Wjfd5OnMDajug1UBdR3s= + +concat-stream@^1.5.0: + version "1.6.2" + resolved "https://registry.yarnpkg.com/concat-stream/-/concat-stream-1.6.2.tgz#904bdf194cd3122fc675c77fc4ac3d4ff0fd1a34" + integrity sha512-27HBghJxjiZtIk3Ycvn/4kbJk/1uZuJFfuPEns6LaEvpvG1f0hTea8lilrouyo9mVc2GWdcEZ8OLoGmSADlrCw== + dependencies: + buffer-from "^1.0.0" + inherits "^2.0.3" + readable-stream "^2.2.2" + typedarray "^0.0.6" + +confusing-browser-globals@^1.0.9: + version "1.0.9" + resolved "https://registry.yarnpkg.com/confusing-browser-globals/-/confusing-browser-globals-1.0.9.tgz#72bc13b483c0276801681871d4898516f8f54fdd" + integrity sha512-KbS1Y0jMtyPgIxjO7ZzMAuUpAKMt1SzCL9fsrKsX6b0zJPTaT0SiSPmewwVZg9UAO83HVIlEhZF84LIjZ0lmAw== + +connect-history-api-fallback@^1.6.0: + version "1.6.0" + resolved "https://registry.yarnpkg.com/connect-history-api-fallback/-/connect-history-api-fallback-1.6.0.tgz#8b32089359308d111115d81cad3fceab888f97bc" + integrity sha512-e54B99q/OUoH64zYYRf3HBP5z24G38h5D3qXu23JGRoigpX5Ss4r9ZnDk3g0Z8uQC2x2lPaJ+UlWBc1ZWBWdLg== + +console-browserify@^1.1.0: + version "1.2.0" + resolved "https://registry.yarnpkg.com/console-browserify/-/console-browserify-1.2.0.tgz#67063cef57ceb6cf4993a2ab3a55840ae8c49336" + integrity sha512-ZMkYO/LkF17QvCPqM0gxw8yUzigAOZOSWSHg91FH6orS7vcEj5dVZTidN2fQ14yBSdg97RqhSNwLUXInd52OTA== + +constants-browserify@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/constants-browserify/-/constants-browserify-1.0.0.tgz#c20b96d8c617748aaf1c16021760cd27fcb8cb75" + integrity sha1-wguW2MYXdIqvHBYCF2DNJ/y4y3U= + +contains-path@^0.1.0: + version "0.1.0" + resolved "https://registry.yarnpkg.com/contains-path/-/contains-path-0.1.0.tgz#fe8cf184ff6670b6baef01a9d4861a5cbec4120a" + integrity sha1-/ozxhP9mcLa67wGp1IYaXL7EEgo= + +content-disposition@0.5.3: + version "0.5.3" + resolved "https://registry.yarnpkg.com/content-disposition/-/content-disposition-0.5.3.tgz#e130caf7e7279087c5616c2007d0485698984fbd" + integrity sha512-ExO0774ikEObIAEV9kDo50o+79VCUdEB6n6lzKgGwupcVeRlhrj3qGAfwq8G6uBJjkqLrhT0qEYFcWng8z1z0g== + dependencies: + safe-buffer "5.1.2" + +content-type@~1.0.4: + version "1.0.4" + resolved "https://registry.yarnpkg.com/content-type/-/content-type-1.0.4.tgz#e138cc75e040c727b1966fe5e5f8c9aee256fe3b" + integrity sha512-hIP3EEPs8tB9AT1L+NUqtwOAps4mk2Zob89MWXMHjHWg9milF/j4osnnQLXBCBFBk/tvIG/tUc9mOUJiPBhPXA== + +convert-source-map@1.7.0, convert-source-map@^1.4.0, convert-source-map@^1.7.0: + version "1.7.0" + resolved 
"https://registry.yarnpkg.com/convert-source-map/-/convert-source-map-1.7.0.tgz#17a2cb882d7f77d3490585e2ce6c524424a3a442" + integrity sha512-4FJkXzKXEDB1snCFZlLP4gpC3JILicCpGbzG9f9G7tGqGCzETQ2hWPrcinA9oU4wtf2biUaEH5065UnMeR33oA== + dependencies: + safe-buffer "~5.1.1" + +convert-source-map@^0.3.3: + version "0.3.5" + resolved "https://registry.yarnpkg.com/convert-source-map/-/convert-source-map-0.3.5.tgz#f1d802950af7dd2631a1febe0596550c86ab3190" + integrity sha1-8dgClQr33SYxof6+BZZVDIarMZA= + +cookie-signature@1.0.6: + version "1.0.6" + resolved "https://registry.yarnpkg.com/cookie-signature/-/cookie-signature-1.0.6.tgz#e303a882b342cc3ee8ca513a79999734dab3ae2c" + integrity sha1-4wOogrNCzD7oylE6eZmXNNqzriw= + +cookie@0.4.0: + version "0.4.0" + resolved "https://registry.yarnpkg.com/cookie/-/cookie-0.4.0.tgz#beb437e7022b3b6d49019d088665303ebe9c14ba" + integrity sha512-+Hp8fLp57wnUSt0tY0tHEXh4voZRDnoIrZPqlo3DPiI4y9lwg/jqx+1Om94/W6ZaPDOUbnjOt/99w66zk+l1Xg== + +copy-concurrently@^1.0.0: + version "1.0.5" + resolved "https://registry.yarnpkg.com/copy-concurrently/-/copy-concurrently-1.0.5.tgz#92297398cae34937fcafd6ec8139c18051f0b5e0" + integrity sha512-f2domd9fsVDFtaFcbaRZuYXwtdmnzqbADSwhSWYxYB/Q8zsdUUFMXVRwXGDMWmbEzAn1kdRrtI1T/KTFOL4X2A== + dependencies: + aproba "^1.1.1" + fs-write-stream-atomic "^1.0.8" + iferr "^0.1.5" + mkdirp "^0.5.1" + rimraf "^2.5.4" + run-queue "^1.0.0" + +copy-descriptor@^0.1.0: + version "0.1.1" + resolved "https://registry.yarnpkg.com/copy-descriptor/-/copy-descriptor-0.1.1.tgz#676f6eb3c39997c2ee1ac3a924fd6124748f578d" + integrity sha1-Z29us8OZl8LuGsOpJP1hJHSPV40= + +core-js-compat@^3.6.2: + version "3.6.4" + resolved "https://registry.yarnpkg.com/core-js-compat/-/core-js-compat-3.6.4.tgz#938476569ebb6cda80d339bcf199fae4f16fff17" + integrity sha512-zAa3IZPvsJ0slViBQ2z+vgyyTuhd3MFn1rBQjZSKVEgB0UMYhUkCj9jJUVPgGTGqWvsBVmfnruXgTcNyTlEiSA== + dependencies: + browserslist "^4.8.3" + semver "7.0.0" + +core-js-pure@^3.0.0: + version "3.6.4" + resolved "https://registry.yarnpkg.com/core-js-pure/-/core-js-pure-3.6.4.tgz#4bf1ba866e25814f149d4e9aaa08c36173506e3a" + integrity sha512-epIhRLkXdgv32xIUFaaAry2wdxZYBi6bgM7cB136dzzXXa+dFyRLTZeLUJxnd8ShrmyVXBub63n2NHo2JAt8Cw== + +core-js@^2.4.0: + version "2.6.11" + resolved "https://registry.yarnpkg.com/core-js/-/core-js-2.6.11.tgz#38831469f9922bded8ee21c9dc46985e0399308c" + integrity sha512-5wjnpaT/3dV+XB4borEsnAYQchn00XSgTAWKDkEqv+K8KevjbzmofK6hfJ9TZIlpj2N0xQpazy7PiRQiWHqzWg== + +core-js@^3.5.0: + version "3.6.4" + resolved "https://registry.yarnpkg.com/core-js/-/core-js-3.6.4.tgz#440a83536b458114b9cb2ac1580ba377dc470647" + integrity sha512-4paDGScNgZP2IXXilaffL9X7968RuvwlkK3xWtZRVqgd8SYNiVKRJvkFd1aqqEuPfN7E68ZHEp9hDj6lHj4Hyw== + +core-util-is@1.0.2, core-util-is@~1.0.0: + version "1.0.2" + resolved "https://registry.yarnpkg.com/core-util-is/-/core-util-is-1.0.2.tgz#b5fd54220aa2bc5ab57aab7140c940754503c1a7" + integrity sha1-tf1UIgqivFq1eqtxQMlAdUUDwac= + +cosmiconfig@^5.0.0, cosmiconfig@^5.2.1: + version "5.2.1" + resolved "https://registry.yarnpkg.com/cosmiconfig/-/cosmiconfig-5.2.1.tgz#040f726809c591e77a17c0a3626ca45b4f168b1a" + integrity sha512-H65gsXo1SKjf8zmrJ67eJk8aIRKV5ff2D4uKZIBZShbhGSpEmsQOPW/SKMKYhSTrqR7ufy6RP69rPogdaPh/kA== + dependencies: + import-fresh "^2.0.0" + is-directory "^0.3.1" + js-yaml "^3.13.1" + parse-json "^4.0.0" + +cosmiconfig@^6.0.0: + version "6.0.0" + resolved "https://registry.yarnpkg.com/cosmiconfig/-/cosmiconfig-6.0.0.tgz#da4fee853c52f6b1e6935f41c1a2fc50bd4a9982" + integrity 
sha512-xb3ZL6+L8b9JLLCx3ZdoZy4+2ECphCMo2PwqgP1tlfVq6M6YReyzBJtvWWtbDSpNr9hn96pkCiZqUcFEc+54Qg== + dependencies: + "@types/parse-json" "^4.0.0" + import-fresh "^3.1.0" + parse-json "^5.0.0" + path-type "^4.0.0" + yaml "^1.7.2" + +create-ecdh@^4.0.0: + version "4.0.3" + resolved "https://registry.yarnpkg.com/create-ecdh/-/create-ecdh-4.0.3.tgz#c9111b6f33045c4697f144787f9254cdc77c45ff" + integrity sha512-GbEHQPMOswGpKXM9kCWVrremUcBmjteUaQ01T9rkKCPDXfUHX0IoP9LpHYo2NPFampa4e+/pFDc3jQdxrxQLaw== + dependencies: + bn.js "^4.1.0" + elliptic "^6.0.0" + +create-hash@^1.1.0, create-hash@^1.1.2: + version "1.2.0" + resolved "https://registry.yarnpkg.com/create-hash/-/create-hash-1.2.0.tgz#889078af11a63756bcfb59bd221996be3a9ef196" + integrity sha512-z00bCGNHDG8mHAkP7CtT1qVu+bFQUPjYq/4Iv3C3kWjTFV10zIjfSoeqXo9Asws8gwSHDGj/hl2u4OGIjapeCg== + dependencies: + cipher-base "^1.0.1" + inherits "^2.0.1" + md5.js "^1.3.4" + ripemd160 "^2.0.1" + sha.js "^2.4.0" + +create-hmac@^1.1.0, create-hmac@^1.1.2, create-hmac@^1.1.4: + version "1.1.7" + resolved "https://registry.yarnpkg.com/create-hmac/-/create-hmac-1.1.7.tgz#69170c78b3ab957147b2b8b04572e47ead2243ff" + integrity sha512-MJG9liiZ+ogc4TzUwuvbER1JRdgvUFSB5+VR/g5h82fGaIRWMWddtKBHi7/sVhfjQZ6SehlyhvQYrcYkaUIpLg== + dependencies: + cipher-base "^1.0.3" + create-hash "^1.1.0" + inherits "^2.0.1" + ripemd160 "^2.0.0" + safe-buffer "^5.0.1" + sha.js "^2.4.8" + +cross-spawn@6.0.5, cross-spawn@^6.0.0, cross-spawn@^6.0.5: + version "6.0.5" + resolved "https://registry.yarnpkg.com/cross-spawn/-/cross-spawn-6.0.5.tgz#4a5ec7c64dfae22c3a14124dbacdee846d80cbc4" + integrity sha512-eTVLrBSt7fjbDygz805pMnstIs2VTBNkRm0qxZd+M7A5XDdxVRWO5MxGBXZhjY4cqLYLdtrGqRf8mBPmzwSpWQ== + dependencies: + nice-try "^1.0.4" + path-key "^2.0.1" + semver "^5.5.0" + shebang-command "^1.2.0" + which "^1.2.9" + +cross-spawn@7.0.1: + version "7.0.1" + resolved "https://registry.yarnpkg.com/cross-spawn/-/cross-spawn-7.0.1.tgz#0ab56286e0f7c24e153d04cc2aa027e43a9a5d14" + integrity sha512-u7v4o84SwFpD32Z8IIcPZ6z1/ie24O6RU3RbtL5Y316l3KuHVPx9ItBgWQ6VlfAFnRnTtMUrsQ9MUUTuEZjogg== + dependencies: + path-key "^3.1.0" + shebang-command "^2.0.0" + which "^2.0.1" + +crypto-browserify@^3.11.0: + version "3.12.0" + resolved "https://registry.yarnpkg.com/crypto-browserify/-/crypto-browserify-3.12.0.tgz#396cf9f3137f03e4b8e532c58f698254e00f80ec" + integrity sha512-fz4spIh+znjO2VjL+IdhEpRJ3YN6sMzITSBijk6FK2UvTqruSQW+/cCZTSNsMiZNvUeq0CqurF+dAbyiGOY6Wg== + dependencies: + browserify-cipher "^1.0.0" + browserify-sign "^4.0.0" + create-ecdh "^4.0.0" + create-hash "^1.1.0" + create-hmac "^1.1.0" + diffie-hellman "^5.0.0" + inherits "^2.0.1" + pbkdf2 "^3.0.3" + public-encrypt "^4.0.0" + randombytes "^2.0.0" + randomfill "^1.0.3" + +css-blank-pseudo@^0.1.4: + version "0.1.4" + resolved "https://registry.yarnpkg.com/css-blank-pseudo/-/css-blank-pseudo-0.1.4.tgz#dfdefd3254bf8a82027993674ccf35483bfcb3c5" + integrity sha512-LHz35Hr83dnFeipc7oqFDmsjHdljj3TQtxGGiNWSOsTLIAubSm4TEz8qCaKFpk7idaQ1GfWscF4E6mgpBysA1w== + dependencies: + postcss "^7.0.5" + +css-color-names@0.0.4, css-color-names@^0.0.4: + version "0.0.4" + resolved "https://registry.yarnpkg.com/css-color-names/-/css-color-names-0.0.4.tgz#808adc2e79cf84738069b646cb20ec27beb629e0" + integrity sha1-gIrcLnnPhHOAabZGyyDsJ762KeA= + +css-declaration-sorter@^4.0.1: + version "4.0.1" + resolved "https://registry.yarnpkg.com/css-declaration-sorter/-/css-declaration-sorter-4.0.1.tgz#c198940f63a76d7e36c1e71018b001721054cb22" + integrity 
sha512-BcxQSKTSEEQUftYpBVnsH4SF05NTuBokb19/sBt6asXGKZ/6VP7PLG1CBCkFDYOnhXhPh0jMhO6xZ71oYHXHBA== + dependencies: + postcss "^7.0.1" + timsort "^0.3.0" + +css-has-pseudo@^0.10.0: + version "0.10.0" + resolved "https://registry.yarnpkg.com/css-has-pseudo/-/css-has-pseudo-0.10.0.tgz#3c642ab34ca242c59c41a125df9105841f6966ee" + integrity sha512-Z8hnfsZu4o/kt+AuFzeGpLVhFOGO9mluyHBaA2bA8aCGTwah5sT3WV/fTHH8UNZUytOIImuGPrl/prlb4oX4qQ== + dependencies: + postcss "^7.0.6" + postcss-selector-parser "^5.0.0-rc.4" + +css-loader@3.3.0: + version "3.3.0" + resolved "https://registry.yarnpkg.com/css-loader/-/css-loader-3.3.0.tgz#65f889807baec3197313965d6cda9899f936734d" + integrity sha512-x9Y1vvHe5RR+4tzwFdWExPueK00uqFTCw7mZy+9aE/X1SKWOArm5luaOrtJ4d05IpOwJ6S86b/tVcIdhw1Bu4A== + dependencies: + camelcase "^5.3.1" + cssesc "^3.0.0" + icss-utils "^4.1.1" + loader-utils "^1.2.3" + normalize-path "^3.0.0" + postcss "^7.0.23" + postcss-modules-extract-imports "^2.0.0" + postcss-modules-local-by-default "^3.0.2" + postcss-modules-scope "^2.1.1" + postcss-modules-values "^3.0.0" + postcss-value-parser "^4.0.2" + schema-utils "^2.6.0" + +css-loader@3.4.2: + version "3.4.2" + resolved "https://registry.yarnpkg.com/css-loader/-/css-loader-3.4.2.tgz#d3fdb3358b43f233b78501c5ed7b1c6da6133202" + integrity sha512-jYq4zdZT0oS0Iykt+fqnzVLRIeiPWhka+7BqPn+oSIpWJAHak5tmB/WZrJ2a21JhCeFyNnnlroSl8c+MtVndzA== + dependencies: + camelcase "^5.3.1" + cssesc "^3.0.0" + icss-utils "^4.1.1" + loader-utils "^1.2.3" + normalize-path "^3.0.0" + postcss "^7.0.23" + postcss-modules-extract-imports "^2.0.0" + postcss-modules-local-by-default "^3.0.2" + postcss-modules-scope "^2.1.1" + postcss-modules-values "^3.0.0" + postcss-value-parser "^4.0.2" + schema-utils "^2.6.0" + +css-prefers-color-scheme@^3.1.1: + version "3.1.1" + resolved "https://registry.yarnpkg.com/css-prefers-color-scheme/-/css-prefers-color-scheme-3.1.1.tgz#6f830a2714199d4f0d0d0bb8a27916ed65cff1f4" + integrity sha512-MTu6+tMs9S3EUqzmqLXEcgNRbNkkD/TGFvowpeoWJn5Vfq7FMgsmRQs9X5NXAURiOBmOxm/lLjsDNXDE6k9bhg== + dependencies: + postcss "^7.0.5" + +css-select-base-adapter@^0.1.1: + version "0.1.1" + resolved "https://registry.yarnpkg.com/css-select-base-adapter/-/css-select-base-adapter-0.1.1.tgz#3b2ff4972cc362ab88561507a95408a1432135d7" + integrity sha512-jQVeeRG70QI08vSTwf1jHxp74JoZsr2XSgETae8/xC8ovSnL2WF87GTLO86Sbwdt2lK4Umg4HnnwMO4YF3Ce7w== + +css-select@^1.1.0: + version "1.2.0" + resolved "https://registry.yarnpkg.com/css-select/-/css-select-1.2.0.tgz#2b3a110539c5355f1cd8d314623e870b121ec858" + integrity sha1-KzoRBTnFNV8c2NMUYj6HCxIeyFg= + dependencies: + boolbase "~1.0.0" + css-what "2.1" + domutils "1.5.1" + nth-check "~1.0.1" + +css-select@^2.0.0: + version "2.1.0" + resolved "https://registry.yarnpkg.com/css-select/-/css-select-2.1.0.tgz#6a34653356635934a81baca68d0255432105dbef" + integrity sha512-Dqk7LQKpwLoH3VovzZnkzegqNSuAziQyNZUcrdDM401iY+R5NkGBXGmtO05/yaXQziALuPogeG0b7UAgjnTJTQ== + dependencies: + boolbase "^1.0.0" + css-what "^3.2.1" + domutils "^1.7.0" + nth-check "^1.0.2" + +css-tree@1.0.0-alpha.37: + version "1.0.0-alpha.37" + resolved "https://registry.yarnpkg.com/css-tree/-/css-tree-1.0.0-alpha.37.tgz#98bebd62c4c1d9f960ec340cf9f7522e30709a22" + integrity sha512-DMxWJg0rnz7UgxKT0Q1HU/L9BeJI0M6ksor0OgqOnF+aRCDWg/N2641HmVyU9KVIu0OVVWOb2IpC9A+BJRnejg== + dependencies: + mdn-data "2.0.4" + source-map "^0.6.1" + +css-tree@1.0.0-alpha.39: + version "1.0.0-alpha.39" + resolved 
"https://registry.yarnpkg.com/css-tree/-/css-tree-1.0.0-alpha.39.tgz#2bff3ffe1bb3f776cf7eefd91ee5cba77a149eeb" + integrity sha512-7UvkEYgBAHRG9Nt980lYxjsTrCyHFN53ky3wVsDkiMdVqylqRt+Zc+jm5qw7/qyOvN2dHSYtX0e4MbCCExSvnA== + dependencies: + mdn-data "2.0.6" + source-map "^0.6.1" + +css-what@2.1: + version "2.1.3" + resolved "https://registry.yarnpkg.com/css-what/-/css-what-2.1.3.tgz#a6d7604573365fe74686c3f311c56513d88285f2" + integrity sha512-a+EPoD+uZiNfh+5fxw2nO9QwFa6nJe2Or35fGY6Ipw1R3R4AGz1d1TEZrCegvw2YTmZ0jXirGYlzxxpYSHwpEg== + +css-what@^3.2.1: + version "3.2.1" + resolved "https://registry.yarnpkg.com/css-what/-/css-what-3.2.1.tgz#f4a8f12421064621b456755e34a03a2c22df5da1" + integrity sha512-WwOrosiQTvyms+Ti5ZC5vGEK0Vod3FTt1ca+payZqvKuGJF+dq7bG63DstxtN0dpm6FxY27a/zS3Wten+gEtGw== + +css@^2.0.0: + version "2.2.4" + resolved "https://registry.yarnpkg.com/css/-/css-2.2.4.tgz#c646755c73971f2bba6a601e2cf2fd71b1298929" + integrity sha512-oUnjmWpy0niI3x/mPL8dVEI1l7MnG3+HHyRPHf+YFSbK+svOhXpmSOcDURUh2aOCgl2grzrOPt1nHLuCVFULLw== + dependencies: + inherits "^2.0.3" + source-map "^0.6.1" + source-map-resolve "^0.5.2" + urix "^0.1.0" + +cssdb@^4.4.0: + version "4.4.0" + resolved "https://registry.yarnpkg.com/cssdb/-/cssdb-4.4.0.tgz#3bf2f2a68c10f5c6a08abd92378331ee803cddb0" + integrity sha512-LsTAR1JPEM9TpGhl/0p3nQecC2LJ0kD8X5YARu1hk/9I1gril5vDtMZyNxcEpxxDj34YNck/ucjuoUd66K03oQ== + +cssesc@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/cssesc/-/cssesc-2.0.0.tgz#3b13bd1bb1cb36e1bcb5a4dcd27f54c5dcb35703" + integrity sha512-MsCAG1z9lPdoO/IUMLSBWBSVxVtJ1395VGIQ+Fc2gNdkQ1hNDnQdw3YhA71WJCBW1vdwA0cAnk/DnW6bqoEUYg== + +cssesc@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/cssesc/-/cssesc-3.0.0.tgz#37741919903b868565e1c09ea747445cd18983ee" + integrity sha512-/Tb/JcjK111nNScGob5MNtsntNM1aCNUDipB/TkwZFhyDrrE47SOx/18wF2bbjgc3ZzCSKW1T5nt5EbFoAz/Vg== + +cssnano-preset-default@^4.0.7: + version "4.0.7" + resolved "https://registry.yarnpkg.com/cssnano-preset-default/-/cssnano-preset-default-4.0.7.tgz#51ec662ccfca0f88b396dcd9679cdb931be17f76" + integrity sha512-x0YHHx2h6p0fCl1zY9L9roD7rnlltugGu7zXSKQx6k2rYw0Hi3IqxcoAGF7u9Q5w1nt7vK0ulxV8Lo+EvllGsA== + dependencies: + css-declaration-sorter "^4.0.1" + cssnano-util-raw-cache "^4.0.1" + postcss "^7.0.0" + postcss-calc "^7.0.1" + postcss-colormin "^4.0.3" + postcss-convert-values "^4.0.1" + postcss-discard-comments "^4.0.2" + postcss-discard-duplicates "^4.0.2" + postcss-discard-empty "^4.0.1" + postcss-discard-overridden "^4.0.1" + postcss-merge-longhand "^4.0.11" + postcss-merge-rules "^4.0.3" + postcss-minify-font-values "^4.0.2" + postcss-minify-gradients "^4.0.2" + postcss-minify-params "^4.0.2" + postcss-minify-selectors "^4.0.2" + postcss-normalize-charset "^4.0.1" + postcss-normalize-display-values "^4.0.2" + postcss-normalize-positions "^4.0.2" + postcss-normalize-repeat-style "^4.0.2" + postcss-normalize-string "^4.0.2" + postcss-normalize-timing-functions "^4.0.2" + postcss-normalize-unicode "^4.0.1" + postcss-normalize-url "^4.0.1" + postcss-normalize-whitespace "^4.0.2" + postcss-ordered-values "^4.1.2" + postcss-reduce-initial "^4.0.3" + postcss-reduce-transforms "^4.0.2" + postcss-svgo "^4.0.2" + postcss-unique-selectors "^4.0.1" + +cssnano-util-get-arguments@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/cssnano-util-get-arguments/-/cssnano-util-get-arguments-4.0.0.tgz#ed3a08299f21d75741b20f3b81f194ed49cc150f" + integrity sha1-7ToIKZ8h11dBsg87gfGU7UnMFQ8= + +cssnano-util-get-match@^4.0.0: + version 
"4.0.0" + resolved "https://registry.yarnpkg.com/cssnano-util-get-match/-/cssnano-util-get-match-4.0.0.tgz#c0e4ca07f5386bb17ec5e52250b4f5961365156d" + integrity sha1-wOTKB/U4a7F+xeUiULT1lhNlFW0= + +cssnano-util-raw-cache@^4.0.1: + version "4.0.1" + resolved "https://registry.yarnpkg.com/cssnano-util-raw-cache/-/cssnano-util-raw-cache-4.0.1.tgz#b26d5fd5f72a11dfe7a7846fb4c67260f96bf282" + integrity sha512-qLuYtWK2b2Dy55I8ZX3ky1Z16WYsx544Q0UWViebptpwn/xDBmog2TLg4f+DBMg1rJ6JDWtn96WHbOKDWt1WQA== + dependencies: + postcss "^7.0.0" + +cssnano-util-same-parent@^4.0.0: + version "4.0.1" + resolved "https://registry.yarnpkg.com/cssnano-util-same-parent/-/cssnano-util-same-parent-4.0.1.tgz#574082fb2859d2db433855835d9a8456ea18bbf3" + integrity sha512-WcKx5OY+KoSIAxBW6UBBRay1U6vkYheCdjyVNDm85zt5K9mHoGOfsOsqIszfAqrQQFIIKgjh2+FDgIj/zsl21Q== + +cssnano@^4.1.10: + version "4.1.10" + resolved "https://registry.yarnpkg.com/cssnano/-/cssnano-4.1.10.tgz#0ac41f0b13d13d465487e111b778d42da631b8b2" + integrity sha512-5wny+F6H4/8RgNlaqab4ktc3e0/blKutmq8yNlBFXA//nSFFAqAngjNVRzUvCgYROULmZZUoosL/KSoZo5aUaQ== + dependencies: + cosmiconfig "^5.0.0" + cssnano-preset-default "^4.0.7" + is-resolvable "^1.0.0" + postcss "^7.0.0" + +csso@^4.0.2: + version "4.0.3" + resolved "https://registry.yarnpkg.com/csso/-/csso-4.0.3.tgz#0d9985dc852c7cc2b2cacfbbe1079014d1a8e903" + integrity sha512-NL3spysxUkcrOgnpsT4Xdl2aiEiBG6bXswAABQVHcMrfjjBisFOKwLDOmf4wf32aPdcJws1zds2B0Rg+jqMyHQ== + dependencies: + css-tree "1.0.0-alpha.39" + +cssom@0.3.x, "cssom@>= 0.3.2 < 0.4.0", cssom@^0.3.4: + version "0.3.8" + resolved "https://registry.yarnpkg.com/cssom/-/cssom-0.3.8.tgz#9f1276f5b2b463f2114d3f2c75250af8c1a36f4a" + integrity sha512-b0tGHbfegbhPJpxpiBPU2sCkigAqtM9O121le6bbOlgyV+NyGyCmVfJ6QW9eRjz8CpNfWEOYBIMIGRYkLwsIYg== + +cssstyle@^1.0.0, cssstyle@^1.1.1: + version "1.4.0" + resolved "https://registry.yarnpkg.com/cssstyle/-/cssstyle-1.4.0.tgz#9d31328229d3c565c61e586b02041a28fccdccf1" + integrity sha512-GBrLZYZ4X4x6/QEoBnIrqb8B/f5l4+8me2dkom/j1Gtbxy0kBv6OGzKuAsGM75bkGwGAFkt56Iwg28S3XTZgSA== + dependencies: + cssom "0.3.x" + +csstype@^2.2.0: + version "2.6.10" + resolved "https://registry.yarnpkg.com/csstype/-/csstype-2.6.10.tgz#e63af50e66d7c266edb6b32909cfd0aabe03928b" + integrity sha512-D34BqZU4cIlMCY93rZHbrq9pjTAQJ3U8S8rfBqjwHxkGPThWFjzZDQpgMJY0QViLxth6ZKYiwFBo14RdN44U/w== + +cyclist@^1.0.1: + version "1.0.1" + resolved "https://registry.yarnpkg.com/cyclist/-/cyclist-1.0.1.tgz#596e9698fd0c80e12038c2b82d6eb1b35b6224d9" + integrity sha1-WW6WmP0MgOEgOMK4LW6xs1tiJNk= + +d@1, d@^1.0.1: + version "1.0.1" + resolved "https://registry.yarnpkg.com/d/-/d-1.0.1.tgz#8698095372d58dbee346ffd0c7093f99f8f9eb5a" + integrity sha512-m62ShEObQ39CfralilEQRjH6oAMtNCV1xJyEx5LpRYUVN+EviphDgUc/F3hnYbADmkiNs67Y+3ylmlG7Lnu+FA== + dependencies: + es5-ext "^0.10.50" + type "^1.0.1" + +damerau-levenshtein@^1.0.4: + version "1.0.6" + resolved "https://registry.yarnpkg.com/damerau-levenshtein/-/damerau-levenshtein-1.0.6.tgz#143c1641cb3d85c60c32329e26899adea8701791" + integrity sha512-JVrozIeElnj3QzfUIt8tB8YMluBJom4Vw9qTPpjGYQ9fYlB3D/rb6OordUxf3xeFB35LKWs0xqcO5U6ySvBtug== + +dashdash@^1.12.0: + version "1.14.1" + resolved "https://registry.yarnpkg.com/dashdash/-/dashdash-1.14.1.tgz#853cfa0f7cbe2fed5de20326b8dd581035f6e2f0" + integrity sha1-hTz6D3y+L+1d4gMmuN1YEDX24vA= + dependencies: + assert-plus "^1.0.0" + +data-urls@^1.0.0, data-urls@^1.1.0: + version "1.1.0" + resolved "https://registry.yarnpkg.com/data-urls/-/data-urls-1.1.0.tgz#15ee0582baa5e22bb59c77140da8f9c76963bbfe" + 
integrity sha512-YTWYI9se1P55u58gL5GkQHW4P6VJBJ5iBT+B5a7i2Tjadhv52paJG0qHX4A0OR6/t52odI64KP2YvFpkDOi3eQ== + dependencies: + abab "^2.0.0" + whatwg-mimetype "^2.2.0" + whatwg-url "^7.0.0" + +debug@2.6.9, debug@^2.2.0, debug@^2.3.3, debug@^2.6.0, debug@^2.6.9: + version "2.6.9" + resolved "https://registry.yarnpkg.com/debug/-/debug-2.6.9.tgz#5d128515df134ff327e90a4c93f4e077a536341f" + integrity sha512-bC7ElrdJaJnPbAP+1EotYvqZsb3ecl5wi6Bfi6BJTUcNowp6cvspg0jXznRTKDjm/E7AdgFBVeAPVMNcKGsHMA== + dependencies: + ms "2.0.0" + +debug@^3.0.0, debug@^3.1.1, debug@^3.2.5: + version "3.2.6" + resolved "https://registry.yarnpkg.com/debug/-/debug-3.2.6.tgz#e83d17de16d8a7efb7717edbe5fb10135eee629b" + integrity sha512-mel+jf7nrtEl5Pn1Qx46zARXKDpBbvzezse7p7LqINmdoIk8PYP5SySaxEmYv6TZ0JyEKA1hsCId6DIhgITtWQ== + dependencies: + ms "^2.1.1" + +debug@^4.0.1, debug@^4.1.0, debug@^4.1.1: + version "4.1.1" + resolved "https://registry.yarnpkg.com/debug/-/debug-4.1.1.tgz#3b72260255109c6b589cee050f1d516139664791" + integrity sha512-pYAIzeRo8J6KPEaJ0VWOh5Pzkbw/RetuzehGM7QRRX5he4fPHx2rdKMB256ehJCkX+XRQm16eZLqLNS8RSZXZw== + dependencies: + ms "^2.1.1" + +decamelize@^1.2.0: + version "1.2.0" + resolved "https://registry.yarnpkg.com/decamelize/-/decamelize-1.2.0.tgz#f6534d15148269b20352e7bee26f501f9a191290" + integrity sha1-9lNNFRSCabIDUue+4m9QH5oZEpA= + +decode-uri-component@^0.2.0: + version "0.2.0" + resolved "https://registry.yarnpkg.com/decode-uri-component/-/decode-uri-component-0.2.0.tgz#eb3913333458775cb84cd1a1fae062106bb87545" + integrity sha1-6zkTMzRYd1y4TNGh+uBiEGu4dUU= + +deep-equal@^1.0.1: + version "1.1.1" + resolved "https://registry.yarnpkg.com/deep-equal/-/deep-equal-1.1.1.tgz#b5c98c942ceffaf7cb051e24e1434a25a2e6076a" + integrity sha512-yd9c5AdiqVcR+JjcwUQb9DkhJc8ngNr0MahEBGvDiJw8puWab2yZlh+nkasOnZP+EGTAP6rRp2JzJhJZzvNF8g== + dependencies: + is-arguments "^1.0.4" + is-date-object "^1.0.1" + is-regex "^1.0.4" + object-is "^1.0.1" + object-keys "^1.1.1" + regexp.prototype.flags "^1.2.0" + +deep-is@~0.1.3: + version "0.1.3" + resolved "https://registry.yarnpkg.com/deep-is/-/deep-is-0.1.3.tgz#b369d6fb5dbc13eecf524f91b070feedc357cf34" + integrity sha1-s2nW+128E+7PUk+RsHD+7cNXzzQ= + +default-gateway@^4.2.0: + version "4.2.0" + resolved "https://registry.yarnpkg.com/default-gateway/-/default-gateway-4.2.0.tgz#167104c7500c2115f6dd69b0a536bb8ed720552b" + integrity sha512-h6sMrVB1VMWVrW13mSc6ia/DwYYw5MN6+exNu1OaJeFac5aSAvwM7lZ0NVfTABuSkQelr4h5oebg3KB1XPdjgA== + dependencies: + execa "^1.0.0" + ip-regex "^2.1.0" + +define-properties@^1.1.2, define-properties@^1.1.3: + version "1.1.3" + resolved "https://registry.yarnpkg.com/define-properties/-/define-properties-1.1.3.tgz#cf88da6cbee26fe6db7094f61d870cbd84cee9f1" + integrity sha512-3MqfYKj2lLzdMSf8ZIZE/V+Zuy+BgD6f164e8K2w7dgnpKArBDerGYpM46IYYcjnkdPNMjPk9A6VFB8+3SKlXQ== + dependencies: + object-keys "^1.0.12" + +define-property@^0.2.5: + version "0.2.5" + resolved "https://registry.yarnpkg.com/define-property/-/define-property-0.2.5.tgz#c35b1ef918ec3c990f9a5bc57be04aacec5c8116" + integrity sha1-w1se+RjsPJkPmlvFe+BKrOxcgRY= + dependencies: + is-descriptor "^0.1.0" + +define-property@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/define-property/-/define-property-1.0.0.tgz#769ebaaf3f4a63aad3af9e8d304c9bbe79bfb0e6" + integrity sha1-dp66rz9KY6rTr56NMEybvnm/sOY= + dependencies: + is-descriptor "^1.0.0" + +define-property@^2.0.2: + version "2.0.2" + resolved 
"https://registry.yarnpkg.com/define-property/-/define-property-2.0.2.tgz#d459689e8d654ba77e02a817f8710d702cb16e9d" + integrity sha512-jwK2UV4cnPpbcG7+VRARKTZPUWowwXA8bzH5NP6ud0oeAxyYPuGZUAC7hMugpCdz4BeSZl2Dl9k66CHJ/46ZYQ== + dependencies: + is-descriptor "^1.0.2" + isobject "^3.0.1" + +del@^4.1.1: + version "4.1.1" + resolved "https://registry.yarnpkg.com/del/-/del-4.1.1.tgz#9e8f117222ea44a31ff3a156c049b99052a9f0b4" + integrity sha512-QwGuEUouP2kVwQenAsOof5Fv8K9t3D8Ca8NxcXKrIpEHjTXK5J2nXLdP+ALI1cgv8wj7KuwBhTwBkOZSJKM5XQ== + dependencies: + "@types/glob" "^7.1.1" + globby "^6.1.0" + is-path-cwd "^2.0.0" + is-path-in-cwd "^2.0.0" + p-map "^2.0.0" + pify "^4.0.1" + rimraf "^2.6.3" + +delayed-stream@~1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/delayed-stream/-/delayed-stream-1.0.0.tgz#df3ae199acadfb7d440aaae0b29e2272b24ec619" + integrity sha1-3zrhmayt+31ECqrgsp4icrJOxhk= + +depd@~1.1.2: + version "1.1.2" + resolved "https://registry.yarnpkg.com/depd/-/depd-1.1.2.tgz#9bcd52e14c097763e749b274c4346ed2e560b5a9" + integrity sha1-m81S4UwJd2PnSbJ0xDRu0uVgtak= + +des.js@^1.0.0: + version "1.0.1" + resolved "https://registry.yarnpkg.com/des.js/-/des.js-1.0.1.tgz#5382142e1bdc53f85d86d53e5f4aa7deb91e0843" + integrity sha512-Q0I4pfFrv2VPd34/vfLrFOoRmlYj3OV50i7fskps1jZWK1kApMWWT9G6RRUeYedLcBDIhnSDaUvJMb3AhUlaEA== + dependencies: + inherits "^2.0.1" + minimalistic-assert "^1.0.0" + +destroy@~1.0.4: + version "1.0.4" + resolved "https://registry.yarnpkg.com/destroy/-/destroy-1.0.4.tgz#978857442c44749e4206613e37946205826abd80" + integrity sha1-l4hXRCxEdJ5CBmE+N5RiBYJqvYA= + +detect-file@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/detect-file/-/detect-file-1.0.0.tgz#f0d66d03672a825cb1b73bdb3fe62310c8e552b7" + integrity sha1-8NZtA2cqglyxtzvbP+YjEMjlUrc= + +detect-newline@^2.1.0: + version "2.1.0" + resolved "https://registry.yarnpkg.com/detect-newline/-/detect-newline-2.1.0.tgz#f41f1c10be4b00e87b5f13da680759f2c5bfd3e2" + integrity sha1-9B8cEL5LAOh7XxPaaAdZ8sW/0+I= + +detect-node@^2.0.4: + version "2.0.4" + resolved "https://registry.yarnpkg.com/detect-node/-/detect-node-2.0.4.tgz#014ee8f8f669c5c58023da64b8179c083a28c46c" + integrity sha512-ZIzRpLJrOj7jjP2miAtgqIfmzbxa4ZOr5jJc601zklsfEx9oTzmmj2nVpIPRpNlRTIh8lc1kyViIY7BWSGNmKw== + +detect-port-alt@1.1.6: + version "1.1.6" + resolved "https://registry.yarnpkg.com/detect-port-alt/-/detect-port-alt-1.1.6.tgz#24707deabe932d4a3cf621302027c2b266568275" + integrity sha512-5tQykt+LqfJFBEYaDITx7S7cR7mJ/zQmLXZ2qt5w04ainYZw6tBf9dBunMjVeVOdYVRUzUOE4HkY5J7+uttb5Q== + dependencies: + address "^1.0.1" + debug "^2.6.0" + +diff-sequences@^24.9.0: + version "24.9.0" + resolved "https://registry.yarnpkg.com/diff-sequences/-/diff-sequences-24.9.0.tgz#5715d6244e2aa65f48bba0bc972db0b0b11e95b5" + integrity sha512-Dj6Wk3tWyTE+Fo1rW8v0Xhwk80um6yFYKbuAxc9c3EZxIHFDYwbi34Uk42u1CdnIiVorvt4RmlSDjIPyzGC2ew== + +diff-sequences@^25.2.6: + version "25.2.6" + resolved "https://registry.yarnpkg.com/diff-sequences/-/diff-sequences-25.2.6.tgz#5f467c00edd35352b7bca46d7927d60e687a76dd" + integrity sha512-Hq8o7+6GaZeoFjtpgvRBUknSXNeJiCx7V9Fr94ZMljNiCr9n9L8H8aJqgWOQiDDGdyn29fRNcDdRVJ5fdyihfg== + +diffie-hellman@^5.0.0: + version "5.0.3" + resolved "https://registry.yarnpkg.com/diffie-hellman/-/diffie-hellman-5.0.3.tgz#40e8ee98f55a2149607146921c63e1ae5f3d2875" + integrity sha512-kqag/Nl+f3GwyK25fhUMYj81BUOrZ9IuJsjIcDE5icNM9FJHAVm3VcUDxdLPoQtTuUylWm6ZIknYJwwaPxsUzg== + dependencies: + bn.js "^4.1.0" + miller-rabin "^4.0.0" + randombytes "^2.0.0" + 
+dir-glob@2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/dir-glob/-/dir-glob-2.0.0.tgz#0b205d2b6aef98238ca286598a8204d29d0a0034" + integrity sha512-37qirFDz8cA5fimp9feo43fSuRo2gHwaIn6dXL8Ber1dGwUosDrGZeCCXq57WnIqE4aQ+u3eQZzsk1yOzhdwag== + dependencies: + arrify "^1.0.1" + path-type "^3.0.0" + +dns-equal@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/dns-equal/-/dns-equal-1.0.0.tgz#b39e7f1da6eb0a75ba9c17324b34753c47e0654d" + integrity sha1-s55/HabrCnW6nBcySzR1PEfgZU0= + +dns-packet@^1.3.1: + version "1.3.1" + resolved "https://registry.yarnpkg.com/dns-packet/-/dns-packet-1.3.1.tgz#12aa426981075be500b910eedcd0b47dd7deda5a" + integrity sha512-0UxfQkMhYAUaZI+xrNZOz/as5KgDU0M/fQ9b6SpkyLbk3GEswDi6PADJVaYJradtRVsRIlF1zLyOodbcTCDzUg== + dependencies: + ip "^1.1.0" + safe-buffer "^5.0.1" + +dns-txt@^2.0.2: + version "2.0.2" + resolved "https://registry.yarnpkg.com/dns-txt/-/dns-txt-2.0.2.tgz#b91d806f5d27188e4ab3e7d107d881a1cc4642b6" + integrity sha1-uR2Ab10nGI5Ks+fRB9iBocxGQrY= + dependencies: + buffer-indexof "^1.0.0" + +doctrine@1.5.0: + version "1.5.0" + resolved "https://registry.yarnpkg.com/doctrine/-/doctrine-1.5.0.tgz#379dce730f6166f76cefa4e6707a159b02c5a6fa" + integrity sha1-N53Ocw9hZvds76TmcHoVmwLFpvo= + dependencies: + esutils "^2.0.2" + isarray "^1.0.0" + +doctrine@^2.1.0: + version "2.1.0" + resolved "https://registry.yarnpkg.com/doctrine/-/doctrine-2.1.0.tgz#5cd01fc101621b42c4cd7f5d1a66243716d3f39d" + integrity sha512-35mSku4ZXK0vfCuHEDAwt55dg2jNajHZ1odvF+8SSr82EsZY4QmXfuWso8oEd8zRhVObSN18aM0CjSdoBX7zIw== + dependencies: + esutils "^2.0.2" + +doctrine@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/doctrine/-/doctrine-3.0.0.tgz#addebead72a6574db783639dc87a121773973961" + integrity sha512-yS+Q5i3hBf7GBkd4KG8a7eBNNWNGLTaEwwYWUijIYM7zrlYDM0BFXHjjPWlWZ1Rg7UaddZeIDmi9jF3HmqiQ2w== + dependencies: + esutils "^2.0.2" + +dom-converter@^0.2: + version "0.2.0" + resolved "https://registry.yarnpkg.com/dom-converter/-/dom-converter-0.2.0.tgz#6721a9daee2e293682955b6afe416771627bb768" + integrity sha512-gd3ypIPfOMr9h5jIKq8E3sHOTCjeirnl0WK5ZdS1AW0Odt0b1PaWaHdJ4Qk4klv+YB9aJBS7mESXjFoDQPu6DA== + dependencies: + utila "~0.4" + +dom-serializer@0: + version "0.2.2" + resolved "https://registry.yarnpkg.com/dom-serializer/-/dom-serializer-0.2.2.tgz#1afb81f533717175d478655debc5e332d9f9bb51" + integrity sha512-2/xPb3ORsQ42nHYiSunXkDjPLBaEj/xTwUO4B7XCZQTRk7EBtTOPaygh10YAAh2OI1Qrp6NWfpAhzswj0ydt9g== + dependencies: + domelementtype "^2.0.1" + entities "^2.0.0" + +domain-browser@^1.1.1: + version "1.2.0" + resolved "https://registry.yarnpkg.com/domain-browser/-/domain-browser-1.2.0.tgz#3d31f50191a6749dd1375a7f522e823d42e54eda" + integrity sha512-jnjyiM6eRyZl2H+W8Q/zLMA481hzi0eszAaBUzIVnmYVDBbnLxVNnfu1HgEBvCbL+71FrxMl3E6lpKH7Ge3OXA== + +domelementtype@1, domelementtype@^1.3.1: + version "1.3.1" + resolved "https://registry.yarnpkg.com/domelementtype/-/domelementtype-1.3.1.tgz#d048c44b37b0d10a7f2a3d5fee3f4333d790481f" + integrity sha512-BSKB+TSpMpFI/HOxCNr1O8aMOTZ8hT3pM3GQ0w/mWRmkhEDSFJkkyzz4XQsBV44BChwGkrDfMyjVD0eA2aFV3w== + +domelementtype@^2.0.1: + version "2.0.1" + resolved "https://registry.yarnpkg.com/domelementtype/-/domelementtype-2.0.1.tgz#1f8bdfe91f5a78063274e803b4bdcedf6e94f94d" + integrity sha512-5HOHUDsYZWV8FGWN0Njbr/Rn7f/eWSQi1v7+HsUVwXgn8nWWlL64zKDkS0n8ZmQ3mlWOMuXOnR+7Nx/5tMO5AQ== + +domexception@^1.0.1: + version "1.0.1" + resolved "https://registry.yarnpkg.com/domexception/-/domexception-1.0.1.tgz#937442644ca6a31261ef36e3ec677fe805582c90" + 
integrity sha512-raigMkn7CJNNo6Ihro1fzG7wr3fHuYVytzquZKX5n0yizGsTcYgzdIUwj1X9pK0VvjeihV+XiclP+DjwbsSKug== + dependencies: + webidl-conversions "^4.0.2" + +domhandler@^2.3.0: + version "2.4.2" + resolved "https://registry.yarnpkg.com/domhandler/-/domhandler-2.4.2.tgz#8805097e933d65e85546f726d60f5eb88b44f803" + integrity sha512-JiK04h0Ht5u/80fdLMCEmV4zkNh2BcoMFBmZ/91WtYZ8qVXSKjiw7fXMgFPnHcSZgOo3XdinHvmnDUeMf5R4wA== + dependencies: + domelementtype "1" + +domutils@1.5.1: + version "1.5.1" + resolved "https://registry.yarnpkg.com/domutils/-/domutils-1.5.1.tgz#dcd8488a26f563d61079e48c9f7b7e32373682cf" + integrity sha1-3NhIiib1Y9YQeeSMn3t+Mjc2gs8= + dependencies: + dom-serializer "0" + domelementtype "1" + +domutils@^1.5.1, domutils@^1.7.0: + version "1.7.0" + resolved "https://registry.yarnpkg.com/domutils/-/domutils-1.7.0.tgz#56ea341e834e06e6748af7a1cb25da67ea9f8c2a" + integrity sha512-Lgd2XcJ/NjEw+7tFvfKxOzCYKZsdct5lczQ2ZaQY8Djz7pfAD3Gbp8ySJWtreII/vDlMVmxwa6pHmdxIYgttDg== + dependencies: + dom-serializer "0" + domelementtype "1" + +dot-case@^3.0.3: + version "3.0.3" + resolved "https://registry.yarnpkg.com/dot-case/-/dot-case-3.0.3.tgz#21d3b52efaaba2ea5fda875bb1aa8124521cf4aa" + integrity sha512-7hwEmg6RiSQfm/GwPL4AAWXKy3YNNZA3oFv2Pdiey0mwkRCPZ9x6SZbkLcn8Ma5PYeVokzoD4Twv2n7LKp5WeA== + dependencies: + no-case "^3.0.3" + tslib "^1.10.0" + +dot-prop@^5.2.0: + version "5.2.0" + resolved "https://registry.yarnpkg.com/dot-prop/-/dot-prop-5.2.0.tgz#c34ecc29556dc45f1f4c22697b6f4904e0cc4fcb" + integrity sha512-uEUyaDKoSQ1M4Oq8l45hSE26SnTxL6snNnqvK/VWx5wJhmff5z0FUVJDKDanor/6w3kzE3i7XZOk+7wC0EXr1A== + dependencies: + is-obj "^2.0.0" + +dotenv-expand@5.1.0: + version "5.1.0" + resolved "https://registry.yarnpkg.com/dotenv-expand/-/dotenv-expand-5.1.0.tgz#3fbaf020bfd794884072ea26b1e9791d45a629f0" + integrity sha512-YXQl1DSa4/PQyRfgrv6aoNjhasp/p4qs9FjJ4q4cQk+8m4r6k4ZSiEyytKG8f8W9gi8WsQtIObNmKd+tMzNTmA== + +dotenv@8.2.0: + version "8.2.0" + resolved "https://registry.yarnpkg.com/dotenv/-/dotenv-8.2.0.tgz#97e619259ada750eea3e4ea3e26bceea5424b16a" + integrity sha512-8sJ78ElpbDJBHNeBzUbUVLsqKdccaa/BXF1uPTw3GrvQTBgrQrtObr2mUrE38vzYd8cEv+m/JBfDLioYcfXoaw== + +duplexer@^0.1.1: + version "0.1.1" + resolved "https://registry.yarnpkg.com/duplexer/-/duplexer-0.1.1.tgz#ace6ff808c1ce66b57d1ebf97977acb02334cfc1" + integrity sha1-rOb/gIwc5mtX0ev5eXessCM0z8E= + +duplexify@^3.4.2, duplexify@^3.6.0: + version "3.7.1" + resolved "https://registry.yarnpkg.com/duplexify/-/duplexify-3.7.1.tgz#2a4df5317f6ccfd91f86d6fd25d8d8a103b88309" + integrity sha512-07z8uv2wMyS51kKhD1KsdXJg5WQ6t93RneqRxUHnskXVtlYYkLqM0gqStQZ3pj073g687jPCHrqNfCzawLYh5g== + dependencies: + end-of-stream "^1.0.0" + inherits "^2.0.1" + readable-stream "^2.0.0" + stream-shift "^1.0.0" + +ecc-jsbn@~0.1.1: + version "0.1.2" + resolved "https://registry.yarnpkg.com/ecc-jsbn/-/ecc-jsbn-0.1.2.tgz#3a83a904e54353287874c564b7549386849a98c9" + integrity sha1-OoOpBOVDUyh4dMVkt1SThoSamMk= + dependencies: + jsbn "~0.1.0" + safer-buffer "^2.1.0" + +ee-first@1.1.1: + version "1.1.1" + resolved "https://registry.yarnpkg.com/ee-first/-/ee-first-1.1.1.tgz#590c61156b0ae2f4f0255732a158b266bc56b21d" + integrity sha1-WQxhFWsK4vTwJVcyoViyZrxWsh0= + +electron-to-chromium@^1.3.378, electron-to-chromium@^1.3.390: + version "1.3.392" + resolved "https://registry.yarnpkg.com/electron-to-chromium/-/electron-to-chromium-1.3.392.tgz#280ab4f7a3ae47419cfabb15dbfc1567be7f1111" + integrity sha512-/hsgeVdReDsyTBE0aU9FRdh1wnNPrX3xlz3t61F+CJPOT+Umfi9DXHsCX85TEgWZQqlow0Rw44/4/jbU2Sqgkg== + 
+elliptic@^6.0.0: + version "6.5.2" + resolved "https://registry.yarnpkg.com/elliptic/-/elliptic-6.5.2.tgz#05c5678d7173c049d8ca433552224a495d0e3762" + integrity sha512-f4x70okzZbIQl/NSRLkI/+tteV/9WqL98zx+SQ69KbXxmVrmjwsNUPn/gYJJ0sHvEak24cZgHIPegRePAtA/xw== + dependencies: + bn.js "^4.4.0" + brorand "^1.0.1" + hash.js "^1.0.0" + hmac-drbg "^1.0.0" + inherits "^2.0.1" + minimalistic-assert "^1.0.0" + minimalistic-crypto-utils "^1.0.0" + +emoji-regex@^7.0.1, emoji-regex@^7.0.2: + version "7.0.3" + resolved "https://registry.yarnpkg.com/emoji-regex/-/emoji-regex-7.0.3.tgz#933a04052860c85e83c122479c4748a8e4c72156" + integrity sha512-CwBLREIQ7LvYFB0WyRvwhq5N5qPhc6PMjD6bYggFlI5YyDgl+0vxq5VHbMOFqLg7hfWzmu8T5Z1QofhmTIhItA== + +emoji-regex@^8.0.0: + version "8.0.0" + resolved "https://registry.yarnpkg.com/emoji-regex/-/emoji-regex-8.0.0.tgz#e818fd69ce5ccfcb404594f842963bf53164cc37" + integrity sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A== + +emojis-list@^2.0.0: + version "2.1.0" + resolved "https://registry.yarnpkg.com/emojis-list/-/emojis-list-2.1.0.tgz#4daa4d9db00f9819880c79fa457ae5b09a1fd389" + integrity sha1-TapNnbAPmBmIDHn6RXrlsJof04k= + +emojis-list@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/emojis-list/-/emojis-list-3.0.0.tgz#5570662046ad29e2e916e71aae260abdff4f6a78" + integrity sha512-/kyM18EfinwXZbno9FyUGeFh87KC8HRQBQGildHZbEuRyWFOmv1U10o9BBp8XVZDVNNuQKyIGIu5ZYAAXJ0V2Q== + +encodeurl@~1.0.2: + version "1.0.2" + resolved "https://registry.yarnpkg.com/encodeurl/-/encodeurl-1.0.2.tgz#ad3ff4c86ec2d029322f5a02c3a9a606c95b3f59" + integrity sha1-rT/0yG7C0CkyL1oCw6mmBslbP1k= + +end-of-stream@^1.0.0, end-of-stream@^1.1.0: + version "1.4.4" + resolved "https://registry.yarnpkg.com/end-of-stream/-/end-of-stream-1.4.4.tgz#5ae64a5f45057baf3626ec14da0ca5e4b2431eb0" + integrity sha512-+uw1inIHVPQoaVuHzRyXd21icM+cnt4CzD5rW+NC1wjOUSTOs+Te7FOv7AhN7vS9x/oIyhLP5PR1H+phQAHu5Q== + dependencies: + once "^1.4.0" + +enhanced-resolve@4.1.0: + version "4.1.0" + resolved "https://registry.yarnpkg.com/enhanced-resolve/-/enhanced-resolve-4.1.0.tgz#41c7e0bfdfe74ac1ffe1e57ad6a5c6c9f3742a7f" + integrity sha512-F/7vkyTtyc/llOIn8oWclcB25KdRaiPBpZYDgJHgh/UHtpgT2p2eldQgtQnLtUvfMKPKxbRaQM/hHkvLHt1Vng== + dependencies: + graceful-fs "^4.1.2" + memory-fs "^0.4.0" + tapable "^1.0.0" + +enhanced-resolve@^4.1.0: + version "4.1.1" + resolved "https://registry.yarnpkg.com/enhanced-resolve/-/enhanced-resolve-4.1.1.tgz#2937e2b8066cd0fe7ce0990a98f0d71a35189f66" + integrity sha512-98p2zE+rL7/g/DzMHMTF4zZlCgeVdJ7yr6xzEpJRYwFYrGi9ANdn5DnJURg6RpBkyk60XYDnWIv51VfIhfNGuA== + dependencies: + graceful-fs "^4.1.2" + memory-fs "^0.5.0" + tapable "^1.0.0" + +entities@^1.1.1: + version "1.1.2" + resolved "https://registry.yarnpkg.com/entities/-/entities-1.1.2.tgz#bdfa735299664dfafd34529ed4f8522a275fea56" + integrity sha512-f2LZMYl1Fzu7YSBKg+RoROelpOaNrcGmE9AZubeDfrCEia483oW4MI4VyFd5VNHIgQ/7qm1I0wUHK1eJnn2y2w== + +entities@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/entities/-/entities-2.0.0.tgz#68d6084cab1b079767540d80e56a39b423e4abf4" + integrity sha512-D9f7V0JSRwIxlRI2mjMqufDrRDnx8p+eEOz7aUM9SuvF8gsBzra0/6tbjl1m8eQHrZlYj6PxqE00hZ1SAIKPLw== + +errno@^0.1.3, errno@~0.1.7: + version "0.1.7" + resolved "https://registry.yarnpkg.com/errno/-/errno-0.1.7.tgz#4684d71779ad39af177e3f007996f7c67c852618" + integrity sha512-MfrRBDWzIWifgq6tJj60gkAwtLNb6sQPlcFrSOflcP1aFmmruKQ2wRnze/8V6kgyz7H3FF8Npzv78mZ7XLLflg== + dependencies: + prr "~1.0.1" + +error-ex@^1.2.0, 
error-ex@^1.3.1: + version "1.3.2" + resolved "https://registry.yarnpkg.com/error-ex/-/error-ex-1.3.2.tgz#b4ac40648107fdcdcfae242f428bea8a14d4f1bf" + integrity sha512-7dFHNmqeFSEt2ZBsCriorKnn3Z2pj+fd9kmI6QoWw4//DL+icEBfc0U7qJCisqrTsKTjw4fNFy2pW9OqStD84g== + dependencies: + is-arrayish "^0.2.1" + +es-abstract@^1.17.0, es-abstract@^1.17.0-next.1, es-abstract@^1.17.2, es-abstract@^1.17.5: + version "1.17.5" + resolved "https://registry.yarnpkg.com/es-abstract/-/es-abstract-1.17.5.tgz#d8c9d1d66c8981fb9200e2251d799eee92774ae9" + integrity sha512-BR9auzDbySxOcfog0tLECW8l28eRGpDpU3Dm3Hp4q/N+VtLTmyj4EUN088XZWQDW/hzj6sYRDXeOFsaAODKvpg== + dependencies: + es-to-primitive "^1.2.1" + function-bind "^1.1.1" + has "^1.0.3" + has-symbols "^1.0.1" + is-callable "^1.1.5" + is-regex "^1.0.5" + object-inspect "^1.7.0" + object-keys "^1.1.1" + object.assign "^4.1.0" + string.prototype.trimleft "^2.1.1" + string.prototype.trimright "^2.1.1" + +es-to-primitive@^1.2.1: + version "1.2.1" + resolved "https://registry.yarnpkg.com/es-to-primitive/-/es-to-primitive-1.2.1.tgz#e55cd4c9cdc188bcefb03b366c736323fc5c898a" + integrity sha512-QCOllgZJtaUo9miYBcLChTUaHNjJF3PYs1VidD7AwiEj1kYxKeQTctLAezAOH5ZKRH0g2IgPn6KwB4IT8iRpvA== + dependencies: + is-callable "^1.1.4" + is-date-object "^1.0.1" + is-symbol "^1.0.2" + +es5-ext@^0.10.35, es5-ext@^0.10.50: + version "0.10.53" + resolved "https://registry.yarnpkg.com/es5-ext/-/es5-ext-0.10.53.tgz#93c5a3acfdbef275220ad72644ad02ee18368de1" + integrity sha512-Xs2Stw6NiNHWypzRTY1MtaG/uJlwCk8kH81920ma8mvN8Xq1gsfhZvpkImLQArw8AHnv8MT2I45J3c0R8slE+Q== + dependencies: + es6-iterator "~2.0.3" + es6-symbol "~3.1.3" + next-tick "~1.0.0" + +es6-iterator@2.0.3, es6-iterator@~2.0.3: + version "2.0.3" + resolved "https://registry.yarnpkg.com/es6-iterator/-/es6-iterator-2.0.3.tgz#a7de889141a05a94b0854403b2d0a0fbfa98f3b7" + integrity sha1-p96IkUGgWpSwhUQDstCg+/qY87c= + dependencies: + d "1" + es5-ext "^0.10.35" + es6-symbol "^3.1.1" + +es6-symbol@^3.1.1, es6-symbol@~3.1.3: + version "3.1.3" + resolved "https://registry.yarnpkg.com/es6-symbol/-/es6-symbol-3.1.3.tgz#bad5d3c1bcdac28269f4cb331e431c78ac705d18" + integrity sha512-NJ6Yn3FuDinBaBRWl/q5X/s4koRHBrgKAu+yGI6JCBeiu3qrcbJhwT2GeR/EXVfylRk8dpQVJoLEFhK+Mu31NA== + dependencies: + d "^1.0.1" + ext "^1.1.2" + +escape-html@~1.0.3: + version "1.0.3" + resolved "https://registry.yarnpkg.com/escape-html/-/escape-html-1.0.3.tgz#0258eae4d3d0c0974de1c169188ef0051d1d1988" + integrity sha1-Aljq5NPQwJdN4cFpGI7wBR0dGYg= + +escape-string-regexp@2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/escape-string-regexp/-/escape-string-regexp-2.0.0.tgz#a30304e99daa32e23b2fd20f51babd07cffca344" + integrity sha512-UpzcLCXolUWcNu5HtVMHYdXJjArjsF9C0aNnquZYY4uW/Vu0miy5YoWvbV345HauVvcAUnpRuhMMcqTcGOY2+w== + +escape-string-regexp@^1.0.2, escape-string-regexp@^1.0.5: + version "1.0.5" + resolved "https://registry.yarnpkg.com/escape-string-regexp/-/escape-string-regexp-1.0.5.tgz#1b61c0562190a8dff6ae3bb2cf0200ca130b86d4" + integrity sha1-G2HAViGQqN/2rjuyzwIAyhMLhtQ= + +escodegen@^1.11.0, escodegen@^1.9.1: + version "1.14.1" + resolved "https://registry.yarnpkg.com/escodegen/-/escodegen-1.14.1.tgz#ba01d0c8278b5e95a9a45350142026659027a457" + integrity sha512-Bmt7NcRySdIfNPfU2ZoXDrrXsG9ZjvDxcAlMfDUgRBjLOWTuIACXPBFJH7Z+cLb40JeQco5toikyc9t9P8E9SQ== + dependencies: + esprima "^4.0.1" + estraverse "^4.2.0" + esutils "^2.0.2" + optionator "^0.8.1" + optionalDependencies: + source-map "~0.6.1" + +eslint-config-react-app@^5.2.1: + version "5.2.1" + resolved 
"https://registry.yarnpkg.com/eslint-config-react-app/-/eslint-config-react-app-5.2.1.tgz#698bf7aeee27f0cea0139eaef261c7bf7dd623df" + integrity sha512-pGIZ8t0mFLcV+6ZirRgYK6RVqUIKRIi9MmgzUEmrIknsn3AdO0I32asO86dJgloHq+9ZPl8UIg8mYrvgP5u2wQ== + dependencies: + confusing-browser-globals "^1.0.9" + +eslint-import-resolver-node@^0.3.2: + version "0.3.3" + resolved "https://registry.yarnpkg.com/eslint-import-resolver-node/-/eslint-import-resolver-node-0.3.3.tgz#dbaa52b6b2816b50bc6711af75422de808e98404" + integrity sha512-b8crLDo0M5RSe5YG8Pu2DYBj71tSB6OvXkfzwbJU2w7y8P4/yo0MyF8jU26IEuEuHF2K5/gcAJE3LhQGqBBbVg== + dependencies: + debug "^2.6.9" + resolve "^1.13.1" + +eslint-loader@3.0.3: + version "3.0.3" + resolved "https://registry.yarnpkg.com/eslint-loader/-/eslint-loader-3.0.3.tgz#e018e3d2722381d982b1201adb56819c73b480ca" + integrity sha512-+YRqB95PnNvxNp1HEjQmvf9KNvCin5HXYYseOXVC2U0KEcw4IkQ2IQEBG46j7+gW39bMzeu0GsUhVbBY3Votpw== + dependencies: + fs-extra "^8.1.0" + loader-fs-cache "^1.0.2" + loader-utils "^1.2.3" + object-hash "^2.0.1" + schema-utils "^2.6.1" + +eslint-module-utils@^2.4.1: + version "2.6.0" + resolved "https://registry.yarnpkg.com/eslint-module-utils/-/eslint-module-utils-2.6.0.tgz#579ebd094f56af7797d19c9866c9c9486629bfa6" + integrity sha512-6j9xxegbqe8/kZY8cYpcp0xhbK0EgJlg3g9mib3/miLaExuuwc3n5UEfSnU6hWMbT0FAYVvDbL9RrRgpUeQIvA== + dependencies: + debug "^2.6.9" + pkg-dir "^2.0.0" + +eslint-plugin-flowtype@4.6.0: + version "4.6.0" + resolved "https://registry.yarnpkg.com/eslint-plugin-flowtype/-/eslint-plugin-flowtype-4.6.0.tgz#82b2bd6f21770e0e5deede0228e456cb35308451" + integrity sha512-W5hLjpFfZyZsXfo5anlu7HM970JBDqbEshAJUkeczP6BFCIfJXuiIBQXyberLRtOStT0OGPF8efeTbxlHk4LpQ== + dependencies: + lodash "^4.17.15" + +eslint-plugin-import@2.20.1: + version "2.20.1" + resolved "https://registry.yarnpkg.com/eslint-plugin-import/-/eslint-plugin-import-2.20.1.tgz#802423196dcb11d9ce8435a5fc02a6d3b46939b3" + integrity sha512-qQHgFOTjguR+LnYRoToeZWT62XM55MBVXObHM6SKFd1VzDcX/vqT1kAz8ssqigh5eMj8qXcRoXXGZpPP6RfdCw== + dependencies: + array-includes "^3.0.3" + array.prototype.flat "^1.2.1" + contains-path "^0.1.0" + debug "^2.6.9" + doctrine "1.5.0" + eslint-import-resolver-node "^0.3.2" + eslint-module-utils "^2.4.1" + has "^1.0.3" + minimatch "^3.0.4" + object.values "^1.1.0" + read-pkg-up "^2.0.0" + resolve "^1.12.0" + +eslint-plugin-jsx-a11y@6.2.3: + version "6.2.3" + resolved "https://registry.yarnpkg.com/eslint-plugin-jsx-a11y/-/eslint-plugin-jsx-a11y-6.2.3.tgz#b872a09d5de51af70a97db1eea7dc933043708aa" + integrity sha512-CawzfGt9w83tyuVekn0GDPU9ytYtxyxyFZ3aSWROmnRRFQFT2BiPJd7jvRdzNDi6oLWaS2asMeYSNMjWTV4eNg== + dependencies: + "@babel/runtime" "^7.4.5" + aria-query "^3.0.0" + array-includes "^3.0.3" + ast-types-flow "^0.0.7" + axobject-query "^2.0.2" + damerau-levenshtein "^1.0.4" + emoji-regex "^7.0.2" + has "^1.0.3" + jsx-ast-utils "^2.2.1" + +eslint-plugin-react-hooks@^1.6.1: + version "1.7.0" + resolved "https://registry.yarnpkg.com/eslint-plugin-react-hooks/-/eslint-plugin-react-hooks-1.7.0.tgz#6210b6d5a37205f0b92858f895a4e827020a7d04" + integrity sha512-iXTCFcOmlWvw4+TOE8CLWj6yX1GwzT0Y6cUfHHZqWnSk144VmVIRcVGtUAzrLES7C798lmvnt02C7rxaOX1HNA== + +eslint-plugin-react@7.19.0: + version "7.19.0" + resolved "https://registry.yarnpkg.com/eslint-plugin-react/-/eslint-plugin-react-7.19.0.tgz#6d08f9673628aa69c5559d33489e855d83551666" + integrity sha512-SPT8j72CGuAP+JFbT0sJHOB80TX/pu44gQ4vXH/cq+hQTiY2PuZ6IHkqXJV6x1b28GDdo1lbInjKUrrdUf0LOQ== + dependencies: + array-includes "^3.1.1" + doctrine 
"^2.1.0" + has "^1.0.3" + jsx-ast-utils "^2.2.3" + object.entries "^1.1.1" + object.fromentries "^2.0.2" + object.values "^1.1.1" + prop-types "^15.7.2" + resolve "^1.15.1" + semver "^6.3.0" + string.prototype.matchall "^4.0.2" + xregexp "^4.3.0" + +eslint-scope@^4.0.3: + version "4.0.3" + resolved "https://registry.yarnpkg.com/eslint-scope/-/eslint-scope-4.0.3.tgz#ca03833310f6889a3264781aa82e63eb9cfe7848" + integrity sha512-p7VutNr1O/QrxysMo3E45FjYDTeXBy0iTltPFNSqKAIfjDSXC+4dj+qfyuD8bfAXrW/y6lW3O76VaYNPKfpKrg== + dependencies: + esrecurse "^4.1.0" + estraverse "^4.1.1" + +eslint-scope@^5.0.0: + version "5.0.0" + resolved "https://registry.yarnpkg.com/eslint-scope/-/eslint-scope-5.0.0.tgz#e87c8887c73e8d1ec84f1ca591645c358bfc8fb9" + integrity sha512-oYrhJW7S0bxAFDvWqzvMPRm6pcgcnWc4QnofCAqRTRfQC0JcwenzGglTtsLyIuuWFfkqDG9vz67cnttSd53djw== + dependencies: + esrecurse "^4.1.0" + estraverse "^4.1.1" + +eslint-utils@^1.4.3: + version "1.4.3" + resolved "https://registry.yarnpkg.com/eslint-utils/-/eslint-utils-1.4.3.tgz#74fec7c54d0776b6f67e0251040b5806564e981f" + integrity sha512-fbBN5W2xdY45KulGXmLHZ3c3FHfVYmKg0IrAKGOkT/464PQsx2UeIzfz1RmEci+KLm1bBaAzZAh8+/E+XAeZ8Q== + dependencies: + eslint-visitor-keys "^1.1.0" + +eslint-utils@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/eslint-utils/-/eslint-utils-2.0.0.tgz#7be1cc70f27a72a76cd14aa698bcabed6890e1cd" + integrity sha512-0HCPuJv+7Wv1bACm8y5/ECVfYdfsAm9xmVb7saeFlxjPYALefjhbYoCkBjPdPzGH8wWyTpAez82Fh3VKYEZ8OA== + dependencies: + eslint-visitor-keys "^1.1.0" + +eslint-visitor-keys@^1.0.0, eslint-visitor-keys@^1.1.0: + version "1.1.0" + resolved "https://registry.yarnpkg.com/eslint-visitor-keys/-/eslint-visitor-keys-1.1.0.tgz#e2a82cea84ff246ad6fb57f9bde5b46621459ec2" + integrity sha512-8y9YjtM1JBJU/A9Kc+SbaOV4y29sSWckBwMHa+FGtVj5gN/sbnKDf6xJUl+8g7FAij9LVaP8C24DUiH/f/2Z9A== + +eslint@^6.6.0: + version "6.8.0" + resolved "https://registry.yarnpkg.com/eslint/-/eslint-6.8.0.tgz#62262d6729739f9275723824302fb227c8c93ffb" + integrity sha512-K+Iayyo2LtyYhDSYwz5D5QdWw0hCacNzyq1Y821Xna2xSJj7cijoLLYmLxTQgcgZ9mC61nryMy9S7GRbYpI5Ig== + dependencies: + "@babel/code-frame" "^7.0.0" + ajv "^6.10.0" + chalk "^2.1.0" + cross-spawn "^6.0.5" + debug "^4.0.1" + doctrine "^3.0.0" + eslint-scope "^5.0.0" + eslint-utils "^1.4.3" + eslint-visitor-keys "^1.1.0" + espree "^6.1.2" + esquery "^1.0.1" + esutils "^2.0.2" + file-entry-cache "^5.0.1" + functional-red-black-tree "^1.0.1" + glob-parent "^5.0.0" + globals "^12.1.0" + ignore "^4.0.6" + import-fresh "^3.0.0" + imurmurhash "^0.1.4" + inquirer "^7.0.0" + is-glob "^4.0.0" + js-yaml "^3.13.1" + json-stable-stringify-without-jsonify "^1.0.1" + levn "^0.3.0" + lodash "^4.17.14" + minimatch "^3.0.4" + mkdirp "^0.5.1" + natural-compare "^1.4.0" + optionator "^0.8.3" + progress "^2.0.0" + regexpp "^2.0.1" + semver "^6.1.2" + strip-ansi "^5.2.0" + strip-json-comments "^3.0.1" + table "^5.2.3" + text-table "^0.2.0" + v8-compile-cache "^2.0.3" + +espree@^6.1.2: + version "6.2.1" + resolved "https://registry.yarnpkg.com/espree/-/espree-6.2.1.tgz#77fc72e1fd744a2052c20f38a5b575832e82734a" + integrity sha512-ysCxRQY3WaXJz9tdbWOwuWr5Y/XrPTGX9Kiz3yoUXwW0VZ4w30HTkQLaGx/+ttFjF8i+ACbArnB4ce68a9m5hw== + dependencies: + acorn "^7.1.1" + acorn-jsx "^5.2.0" + eslint-visitor-keys "^1.1.0" + +esprima@^4.0.0, esprima@^4.0.1: + version "4.0.1" + resolved "https://registry.yarnpkg.com/esprima/-/esprima-4.0.1.tgz#13b04cdb3e6c5d19df91ab6987a8695619b0aa71" + integrity 
sha512-eGuFFw7Upda+g4p+QHvnW0RyTX/SVeJBDM/gCtMARO0cLuT2HcEKnTPvhjV6aGeqrCB/sbNop0Kszm0jsaWU4A== + +esquery@^1.0.1: + version "1.2.0" + resolved "https://registry.yarnpkg.com/esquery/-/esquery-1.2.0.tgz#a010a519c0288f2530b3404124bfb5f02e9797fe" + integrity sha512-weltsSqdeWIX9G2qQZz7KlTRJdkkOCTPgLYJUz1Hacf48R4YOwGPHO3+ORfWedqJKbq5WQmsgK90n+pFLIKt/Q== + dependencies: + estraverse "^5.0.0" + +esrecurse@^4.1.0: + version "4.2.1" + resolved "https://registry.yarnpkg.com/esrecurse/-/esrecurse-4.2.1.tgz#007a3b9fdbc2b3bb87e4879ea19c92fdbd3942cf" + integrity sha512-64RBB++fIOAXPw3P9cy89qfMlvZEXZkqqJkjqqXIvzP5ezRZjW+lPWjw35UX/3EhUPFYbg5ER4JYgDw4007/DQ== + dependencies: + estraverse "^4.1.0" + +estraverse@^4.1.0, estraverse@^4.1.1, estraverse@^4.2.0: + version "4.3.0" + resolved "https://registry.yarnpkg.com/estraverse/-/estraverse-4.3.0.tgz#398ad3f3c5a24948be7725e83d11a7de28cdbd1d" + integrity sha512-39nnKffWz8xN1BU/2c79n9nB9HDzo0niYUqx6xyqUnyoAnQyyWpOTdZEeiCch8BBu515t4wp9ZmgVfVhn9EBpw== + +estraverse@^5.0.0: + version "5.0.0" + resolved "https://registry.yarnpkg.com/estraverse/-/estraverse-5.0.0.tgz#ac81750b482c11cca26e4b07e83ed8f75fbcdc22" + integrity sha512-j3acdrMzqrxmJTNj5dbr1YbjacrYgAxVMeF0gK16E3j494mOe7xygM/ZLIguEQ0ETwAg2hlJCtHRGav+y0Ny5A== + +esutils@^2.0.2: + version "2.0.3" + resolved "https://registry.yarnpkg.com/esutils/-/esutils-2.0.3.tgz#74d2eb4de0b8da1293711910d50775b9b710ef64" + integrity sha512-kVscqXk4OCp68SZ0dkgEKVi6/8ij300KBWTJq32P/dYeWTSwK41WyTxalN1eRmA5Z9UU/LX9D7FWSmV9SAYx6g== + +etag@~1.8.1: + version "1.8.1" + resolved "https://registry.yarnpkg.com/etag/-/etag-1.8.1.tgz#41ae2eeb65efa62268aebfea83ac7d79299b0887" + integrity sha1-Qa4u62XvpiJorr/qg6x9eSmbCIc= + +eventemitter3@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/eventemitter3/-/eventemitter3-4.0.0.tgz#d65176163887ee59f386d64c82610b696a4a74eb" + integrity sha512-qerSRB0p+UDEssxTtm6EDKcE7W4OaoisfIMl4CngyEhjpYglocpNg6UEqCvemdGhosAsg4sO2dXJOdyBifPGCg== + +events@^3.0.0: + version "3.1.0" + resolved "https://registry.yarnpkg.com/events/-/events-3.1.0.tgz#84279af1b34cb75aa88bf5ff291f6d0bd9b31a59" + integrity sha512-Rv+u8MLHNOdMjTAFeT3nCjHn2aGlx435FP/sDHNaRhDEMwyI/aB22Kj2qIN8R0cw3z28psEQLYwxVKLsKrMgWg== + +eventsource@^1.0.7: + version "1.0.7" + resolved "https://registry.yarnpkg.com/eventsource/-/eventsource-1.0.7.tgz#8fbc72c93fcd34088090bc0a4e64f4b5cee6d8d0" + integrity sha512-4Ln17+vVT0k8aWq+t/bF5arcS3EpT9gYtW66EPacdj/mAFevznsnyoHLPy2BA8gbIQeIHoPsvwmfBftfcG//BQ== + dependencies: + original "^1.0.0" + +evp_bytestokey@^1.0.0, evp_bytestokey@^1.0.3: + version "1.0.3" + resolved "https://registry.yarnpkg.com/evp_bytestokey/-/evp_bytestokey-1.0.3.tgz#7fcbdb198dc71959432efe13842684e0525acb02" + integrity sha512-/f2Go4TognH/KvCISP7OUsHn85hT9nUkxxA9BEWxFn+Oj9o8ZNLm/40hdlgSLyuOimsrTKLUMEorQexp/aPQeA== + dependencies: + md5.js "^1.3.4" + safe-buffer "^5.1.1" + +exec-sh@^0.3.2: + version "0.3.4" + resolved "https://registry.yarnpkg.com/exec-sh/-/exec-sh-0.3.4.tgz#3a018ceb526cc6f6df2bb504b2bfe8e3a4934ec5" + integrity sha512-sEFIkc61v75sWeOe72qyrqg2Qg0OuLESziUDk/O/z2qgS15y2gWVFrI6f2Qn/qw/0/NCfCEsmNA4zOjkwEZT1A== + +execa@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/execa/-/execa-1.0.0.tgz#c6236a5bb4df6d6f15e88e7f017798216749ddd8" + integrity sha512-adbxcyWV46qiHyvSp50TKt05tB4tK3HcmF7/nxfAdhnox83seTDbwnaqKO4sXRy7roHAIFqJP/Rw/AuEbX61LA== + dependencies: + cross-spawn "^6.0.0" + get-stream "^4.0.0" + is-stream "^1.1.0" + npm-run-path "^2.0.0" + p-finally "^1.0.0" + signal-exit "^3.0.0" + strip-eof 
"^1.0.0" + +exit@^0.1.2: + version "0.1.2" + resolved "https://registry.yarnpkg.com/exit/-/exit-0.1.2.tgz#0632638f8d877cc82107d30a0fff1a17cba1cd0c" + integrity sha1-BjJjj42HfMghB9MKD/8aF8uhzQw= + +expand-brackets@^2.1.4: + version "2.1.4" + resolved "https://registry.yarnpkg.com/expand-brackets/-/expand-brackets-2.1.4.tgz#b77735e315ce30f6b6eff0f83b04151a22449622" + integrity sha1-t3c14xXOMPa27/D4OwQVGiJEliI= + dependencies: + debug "^2.3.3" + define-property "^0.2.5" + extend-shallow "^2.0.1" + posix-character-classes "^0.1.0" + regex-not "^1.0.0" + snapdragon "^0.8.1" + to-regex "^3.0.1" + +expand-tilde@^2.0.0, expand-tilde@^2.0.2: + version "2.0.2" + resolved "https://registry.yarnpkg.com/expand-tilde/-/expand-tilde-2.0.2.tgz#97e801aa052df02454de46b02bf621642cdc8502" + integrity sha1-l+gBqgUt8CRU3kawK/YhZCzchQI= + dependencies: + homedir-polyfill "^1.0.1" + +expect@^24.9.0: + version "24.9.0" + resolved "https://registry.yarnpkg.com/expect/-/expect-24.9.0.tgz#b75165b4817074fa4a157794f46fe9f1ba15b6ca" + integrity sha512-wvVAx8XIol3Z5m9zvZXiyZOQ+sRJqNTIm6sGjdWlaZIeupQGO3WbYI+15D/AmEwZywL6wtJkbAbJtzkOfBuR0Q== + dependencies: + "@jest/types" "^24.9.0" + ansi-styles "^3.2.0" + jest-get-type "^24.9.0" + jest-matcher-utils "^24.9.0" + jest-message-util "^24.9.0" + jest-regex-util "^24.9.0" + +express@^4.17.1: + version "4.17.1" + resolved "https://registry.yarnpkg.com/express/-/express-4.17.1.tgz#4491fc38605cf51f8629d39c2b5d026f98a4c134" + integrity sha512-mHJ9O79RqluphRrcw2X/GTh3k9tVv8YcoyY4Kkh4WDMUYKRZUq0h1o0w2rrrxBqM7VoeUVqgb27xlEMXTnYt4g== + dependencies: + accepts "~1.3.7" + array-flatten "1.1.1" + body-parser "1.19.0" + content-disposition "0.5.3" + content-type "~1.0.4" + cookie "0.4.0" + cookie-signature "1.0.6" + debug "2.6.9" + depd "~1.1.2" + encodeurl "~1.0.2" + escape-html "~1.0.3" + etag "~1.8.1" + finalhandler "~1.1.2" + fresh "0.5.2" + merge-descriptors "1.0.1" + methods "~1.1.2" + on-finished "~2.3.0" + parseurl "~1.3.3" + path-to-regexp "0.1.7" + proxy-addr "~2.0.5" + qs "6.7.0" + range-parser "~1.2.1" + safe-buffer "5.1.2" + send "0.17.1" + serve-static "1.14.1" + setprototypeof "1.1.1" + statuses "~1.5.0" + type-is "~1.6.18" + utils-merge "1.0.1" + vary "~1.1.2" + +ext@^1.1.2: + version "1.4.0" + resolved "https://registry.yarnpkg.com/ext/-/ext-1.4.0.tgz#89ae7a07158f79d35517882904324077e4379244" + integrity sha512-Key5NIsUxdqKg3vIsdw9dSuXpPCQ297y6wBjL30edxwPgt2E44WcWBZey/ZvUc6sERLTxKdyCu4gZFmUbk1Q7A== + dependencies: + type "^2.0.0" + +extend-shallow@^2.0.1: + version "2.0.1" + resolved "https://registry.yarnpkg.com/extend-shallow/-/extend-shallow-2.0.1.tgz#51af7d614ad9a9f610ea1bafbb989d6b1c56890f" + integrity sha1-Ua99YUrZqfYQ6huvu5idaxxWiQ8= + dependencies: + is-extendable "^0.1.0" + +extend-shallow@^3.0.0, extend-shallow@^3.0.2: + version "3.0.2" + resolved "https://registry.yarnpkg.com/extend-shallow/-/extend-shallow-3.0.2.tgz#26a71aaf073b39fb2127172746131c2704028db8" + integrity sha1-Jqcarwc7OfshJxcnRhMcJwQCjbg= + dependencies: + assign-symbols "^1.0.0" + is-extendable "^1.0.1" + +extend@~3.0.2: + version "3.0.2" + resolved "https://registry.yarnpkg.com/extend/-/extend-3.0.2.tgz#f8b1136b4071fbd8eb140aff858b1019ec2915fa" + integrity sha512-fjquC59cD7CyW6urNXK0FBufkZcoiGG80wTuPujX590cB5Ttln20E2UB4S/WARVqhXffZl2LNgS+gQdPIIim/g== + +external-editor@^3.0.3: + version "3.1.0" + resolved "https://registry.yarnpkg.com/external-editor/-/external-editor-3.1.0.tgz#cb03f740befae03ea4d283caed2741a83f335495" + integrity 
sha512-hMQ4CX1p1izmuLYyZqLMO/qGNw10wSv9QDCPfzXfyFrOaCSSoRfqE1Kf1s5an66J5JZC62NewG+mK49jOCtQew== + dependencies: + chardet "^0.7.0" + iconv-lite "^0.4.24" + tmp "^0.0.33" + +extglob@^2.0.4: + version "2.0.4" + resolved "https://registry.yarnpkg.com/extglob/-/extglob-2.0.4.tgz#ad00fe4dc612a9232e8718711dc5cb5ab0285543" + integrity sha512-Nmb6QXkELsuBr24CJSkilo6UHHgbekK5UiZgfE6UHD3Eb27YC6oD+bhcT+tJ6cl8dmsgdQxnWlcry8ksBIBLpw== + dependencies: + array-unique "^0.3.2" + define-property "^1.0.0" + expand-brackets "^2.1.4" + extend-shallow "^2.0.1" + fragment-cache "^0.2.1" + regex-not "^1.0.0" + snapdragon "^0.8.1" + to-regex "^3.0.1" + +extsprintf@1.3.0: + version "1.3.0" + resolved "https://registry.yarnpkg.com/extsprintf/-/extsprintf-1.3.0.tgz#96918440e3041a7a414f8c52e3c574eb3c3e1e05" + integrity sha1-lpGEQOMEGnpBT4xS48V06zw+HgU= + +extsprintf@^1.2.0: + version "1.4.0" + resolved "https://registry.yarnpkg.com/extsprintf/-/extsprintf-1.4.0.tgz#e2689f8f356fad62cca65a3a91c5df5f9551692f" + integrity sha1-4mifjzVvrWLMplo6kcXfX5VRaS8= + +fast-deep-equal@^3.1.1: + version "3.1.1" + resolved "https://registry.yarnpkg.com/fast-deep-equal/-/fast-deep-equal-3.1.1.tgz#545145077c501491e33b15ec408c294376e94ae4" + integrity sha512-8UEa58QDLauDNfpbrX55Q9jrGHThw2ZMdOky5Gl1CDtVeJDPVrG4Jxx1N8jw2gkWaff5UUuX1KJd+9zGe2B+ZA== + +fast-glob@^2.0.2: + version "2.2.7" + resolved "https://registry.yarnpkg.com/fast-glob/-/fast-glob-2.2.7.tgz#6953857c3afa475fff92ee6015d52da70a4cd39d" + integrity sha512-g1KuQwHOZAmOZMuBtHdxDtju+T2RT8jgCC9aANsbpdiDDTSnjgfuVsIBNKbUeJI3oKMRExcfNDtJl4OhbffMsw== + dependencies: + "@mrmlnc/readdir-enhanced" "^2.2.1" + "@nodelib/fs.stat" "^1.1.2" + glob-parent "^3.1.0" + is-glob "^4.0.0" + merge2 "^1.2.3" + micromatch "^3.1.10" + +fast-json-stable-stringify@^2.0.0: + version "2.1.0" + resolved "https://registry.yarnpkg.com/fast-json-stable-stringify/-/fast-json-stable-stringify-2.1.0.tgz#874bf69c6f404c2b5d99c481341399fd55892633" + integrity sha512-lhd/wF+Lk98HZoTCtlVraHtfh5XYijIjalXck7saUtuanSDyLMxnHhSXEDJqHxD7msR8D0uCmqlkwjCV8xvwHw== + +fast-levenshtein@~2.0.6: + version "2.0.6" + resolved "https://registry.yarnpkg.com/fast-levenshtein/-/fast-levenshtein-2.0.6.tgz#3d8a5c66883a16a30ca8643e851f19baa7797917" + integrity sha1-PYpcZog6FqMMqGQ+hR8Zuqd5eRc= + +faye-websocket@^0.10.0: + version "0.10.0" + resolved "https://registry.yarnpkg.com/faye-websocket/-/faye-websocket-0.10.0.tgz#4e492f8d04dfb6f89003507f6edbf2d501e7c6f4" + integrity sha1-TkkvjQTftviQA1B/btvy1QHnxvQ= + dependencies: + websocket-driver ">=0.5.1" + +faye-websocket@~0.11.1: + version "0.11.3" + resolved "https://registry.yarnpkg.com/faye-websocket/-/faye-websocket-0.11.3.tgz#5c0e9a8968e8912c286639fde977a8b209f2508e" + integrity sha512-D2y4bovYpzziGgbHYtGCMjlJM36vAl/y+xUyn1C+FVx8szd1E+86KwVw6XvYSzOP8iMpm1X0I4xJD+QtUb36OA== + dependencies: + websocket-driver ">=0.5.1" + +fb-watchman@^2.0.0: + version "2.0.1" + resolved "https://registry.yarnpkg.com/fb-watchman/-/fb-watchman-2.0.1.tgz#fc84fb39d2709cf3ff6d743706157bb5708a8a85" + integrity sha512-DkPJKQeY6kKwmuMretBhr7G6Vodr7bFwDYTXIkfG1gjvNpaxBTQV3PbXg6bR1c1UP4jPOX0jHUbbHANL9vRjVg== + dependencies: + bser "2.1.1" + +figgy-pudding@^3.5.1: + version "3.5.2" + resolved "https://registry.yarnpkg.com/figgy-pudding/-/figgy-pudding-3.5.2.tgz#b4eee8148abb01dcf1d1ac34367d59e12fa61d6e" + integrity sha512-0btnI/H8f2pavGMN8w40mlSKOfTK2SVJmBfBeVIj3kNw0swwgzyRq0d5TJVOwodFmtvpPeWPN/MCcfuWF0Ezbw== + +figures@^3.0.0: + version "3.2.0" + resolved 
"https://registry.yarnpkg.com/figures/-/figures-3.2.0.tgz#625c18bd293c604dc4a8ddb2febf0c88341746af" + integrity sha512-yaduQFRKLXYOGgEn6AZau90j3ggSOyiqXU0F9JZfeXYhNa+Jk4X+s45A2zg5jns87GAFa34BBm2kXw4XpNcbdg== + dependencies: + escape-string-regexp "^1.0.5" + +file-entry-cache@^5.0.1: + version "5.0.1" + resolved "https://registry.yarnpkg.com/file-entry-cache/-/file-entry-cache-5.0.1.tgz#ca0f6efa6dd3d561333fb14515065c2fafdf439c" + integrity sha512-bCg29ictuBaKUwwArK4ouCaqDgLZcysCFLmM/Yn/FDoqndh/9vNuQfXRDvTuXKLxfD/JtZQGKFT8MGcJBK644g== + dependencies: + flat-cache "^2.0.1" + +file-loader@4.3.0: + version "4.3.0" + resolved "https://registry.yarnpkg.com/file-loader/-/file-loader-4.3.0.tgz#780f040f729b3d18019f20605f723e844b8a58af" + integrity sha512-aKrYPYjF1yG3oX0kWRrqrSMfgftm7oJW5M+m4owoldH5C51C0RkIwB++JbRvEW3IU6/ZG5n8UvEcdgwOt2UOWA== + dependencies: + loader-utils "^1.2.3" + schema-utils "^2.5.0" + +file-uri-to-path@1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/file-uri-to-path/-/file-uri-to-path-1.0.0.tgz#553a7b8446ff6f684359c445f1e37a05dacc33dd" + integrity sha512-0Zt+s3L7Vf1biwWZ29aARiVYLx7iMGnEUl9x33fbB/j3jR81u/O2LbqK+Bm1CDSNDKVtJ/YjwY7TUd5SkeLQLw== + +filesize@6.0.1: + version "6.0.1" + resolved "https://registry.yarnpkg.com/filesize/-/filesize-6.0.1.tgz#f850b509909c7c86f7e450ea19006c31c2ed3d2f" + integrity sha512-u4AYWPgbI5GBhs6id1KdImZWn5yfyFrrQ8OWZdN7ZMfA8Bf4HcO0BGo9bmUIEV8yrp8I1xVfJ/dn90GtFNNJcg== + +fill-range@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/fill-range/-/fill-range-4.0.0.tgz#d544811d428f98eb06a63dc402d2403c328c38f7" + integrity sha1-1USBHUKPmOsGpj3EAtJAPDKMOPc= + dependencies: + extend-shallow "^2.0.1" + is-number "^3.0.0" + repeat-string "^1.6.1" + to-regex-range "^2.1.0" + +fill-range@^7.0.1: + version "7.0.1" + resolved "https://registry.yarnpkg.com/fill-range/-/fill-range-7.0.1.tgz#1919a6a7c75fe38b2c7c77e5198535da9acdda40" + integrity sha512-qOo9F+dMUmC2Lcb4BbVvnKJxTPjCm+RRpe4gDuGrzkL7mEVl/djYSu2OdQ2Pa302N4oqkSg9ir6jaLWJ2USVpQ== + dependencies: + to-regex-range "^5.0.1" + +finalhandler@~1.1.2: + version "1.1.2" + resolved "https://registry.yarnpkg.com/finalhandler/-/finalhandler-1.1.2.tgz#b7e7d000ffd11938d0fdb053506f6ebabe9f587d" + integrity sha512-aAWcW57uxVNrQZqFXjITpW3sIUQmHGG3qSb9mUah9MgMC4NeWhNOlNjXEYq3HjRAvL6arUviZGGJsBg6z0zsWA== + dependencies: + debug "2.6.9" + encodeurl "~1.0.2" + escape-html "~1.0.3" + on-finished "~2.3.0" + parseurl "~1.3.3" + statuses "~1.5.0" + unpipe "~1.0.0" + +find-cache-dir@^0.1.1: + version "0.1.1" + resolved "https://registry.yarnpkg.com/find-cache-dir/-/find-cache-dir-0.1.1.tgz#c8defae57c8a52a8a784f9e31c57c742e993a0b9" + integrity sha1-yN765XyKUqinhPnjHFfHQumToLk= + dependencies: + commondir "^1.0.1" + mkdirp "^0.5.1" + pkg-dir "^1.0.0" + +find-cache-dir@^2.1.0: + version "2.1.0" + resolved "https://registry.yarnpkg.com/find-cache-dir/-/find-cache-dir-2.1.0.tgz#8d0f94cd13fe43c6c7c261a0d86115ca918c05f7" + integrity sha512-Tq6PixE0w/VMFfCgbONnkiQIVol/JJL7nRMi20fqzA4NRs9AfeqMGeRdPi3wIhYkxjeBaWh2rxwapn5Tu3IqOQ== + dependencies: + commondir "^1.0.1" + make-dir "^2.0.0" + pkg-dir "^3.0.0" + +find-cache-dir@^3.2.0: + version "3.3.1" + resolved "https://registry.yarnpkg.com/find-cache-dir/-/find-cache-dir-3.3.1.tgz#89b33fad4a4670daa94f855f7fbe31d6d84fe880" + integrity sha512-t2GDMt3oGC/v+BMwzmllWDuJF/xcDtE5j/fCGbqDD7OLuJkj0cfh1YSA5VKPvwMeLFLNDBkwOKZ2X85jGLVftQ== + dependencies: + commondir "^1.0.1" + make-dir "^3.0.2" + pkg-dir "^4.1.0" + +find-up@4.1.0, find-up@^4.0.0: + version "4.1.0" + 
resolved "https://registry.yarnpkg.com/find-up/-/find-up-4.1.0.tgz#97afe7d6cdc0bc5928584b7c8d7b16e8a9aa5d19" + integrity sha512-PpOwAdQ/YlXQ2vj8a3h8IipDuYRi3wceVQQGYWxNINccq40Anw7BlsEXCMbt1Zt+OLA6Fq9suIpIWD0OsnISlw== + dependencies: + locate-path "^5.0.0" + path-exists "^4.0.0" + +find-up@^1.0.0: + version "1.1.2" + resolved "https://registry.yarnpkg.com/find-up/-/find-up-1.1.2.tgz#6b2e9822b1a2ce0a60ab64d610eccad53cb24d0f" + integrity sha1-ay6YIrGizgpgq2TWEOzK1TyyTQ8= + dependencies: + path-exists "^2.0.0" + pinkie-promise "^2.0.0" + +find-up@^2.0.0, find-up@^2.1.0: + version "2.1.0" + resolved "https://registry.yarnpkg.com/find-up/-/find-up-2.1.0.tgz#45d1b7e506c717ddd482775a2b77920a3c0c57a7" + integrity sha1-RdG35QbHF93UgndaK3eSCjwMV6c= + dependencies: + locate-path "^2.0.0" + +find-up@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/find-up/-/find-up-3.0.0.tgz#49169f1d7993430646da61ecc5ae355c21c97b73" + integrity sha512-1yD6RmLI1XBfxugvORwlck6f75tYL+iR0jqwsOrOxMZyGYqUuDhJ0l4AXdO1iX/FTs9cBAMEk1gWSEx1kSbylg== + dependencies: + locate-path "^3.0.0" + +findup-sync@3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/findup-sync/-/findup-sync-3.0.0.tgz#17b108f9ee512dfb7a5c7f3c8b27ea9e1a9c08d1" + integrity sha512-YbffarhcicEhOrm4CtrwdKBdCuz576RLdhJDsIfvNtxUuhdRet1qZcsMjqbePtAseKdAnDyM/IyXbu7PRPRLYg== + dependencies: + detect-file "^1.0.0" + is-glob "^4.0.0" + micromatch "^3.0.4" + resolve-dir "^1.0.1" + +flat-cache@^2.0.1: + version "2.0.1" + resolved "https://registry.yarnpkg.com/flat-cache/-/flat-cache-2.0.1.tgz#5d296d6f04bda44a4630a301413bdbc2ec085ec0" + integrity sha512-LoQe6yDuUMDzQAEH8sgmh4Md6oZnc/7PjtwjNFSzveXqSHt6ka9fPBuso7IGf9Rz4uqnSnWiFH2B/zj24a5ReA== + dependencies: + flatted "^2.0.0" + rimraf "2.6.3" + write "1.0.3" + +flatted@^2.0.0: + version "2.0.2" + resolved "https://registry.yarnpkg.com/flatted/-/flatted-2.0.2.tgz#4575b21e2bcee7434aa9be662f4b7b5f9c2b5138" + integrity sha512-r5wGx7YeOwNWNlCA0wQ86zKyDLMQr+/RB8xy74M4hTphfmjlijTSSXGuH8rnvKZnfT9i+75zmd8jcKdMR4O6jA== + +flatten@^1.0.2: + version "1.0.3" + resolved "https://registry.yarnpkg.com/flatten/-/flatten-1.0.3.tgz#c1283ac9f27b368abc1e36d1ff7b04501a30356b" + integrity sha512-dVsPA/UwQ8+2uoFe5GHtiBMu48dWLTdsuEd7CKGlZlD78r1TTWBvDuFaFGKCo/ZfEr95Uk56vZoX86OsHkUeIg== + +flush-write-stream@^1.0.0: + version "1.1.1" + resolved "https://registry.yarnpkg.com/flush-write-stream/-/flush-write-stream-1.1.1.tgz#8dd7d873a1babc207d94ead0c2e0e44276ebf2e8" + integrity sha512-3Z4XhFZ3992uIq0XOqb9AreonueSYphE6oYbpt5+3u06JWklbsPkNv3ZKkP9Bz/r+1MWCaMoSQ28P85+1Yc77w== + dependencies: + inherits "^2.0.3" + readable-stream "^2.3.6" + +follow-redirects@^1.0.0: + version "1.11.0" + resolved "https://registry.yarnpkg.com/follow-redirects/-/follow-redirects-1.11.0.tgz#afa14f08ba12a52963140fe43212658897bc0ecb" + integrity sha512-KZm0V+ll8PfBrKwMzdo5D13b1bur9Iq9Zd/RMmAoQQcl2PxxFml8cxXPaaPYVbV0RjNjq1CU7zIzAOqtUPudmA== + dependencies: + debug "^3.0.0" + +for-in@^0.1.3: + version "0.1.8" + resolved "https://registry.yarnpkg.com/for-in/-/for-in-0.1.8.tgz#d8773908e31256109952b1fdb9b3fa867d2775e1" + integrity sha1-2Hc5COMSVhCZUrH9ubP6hn0ndeE= + +for-in@^1.0.1, for-in@^1.0.2: + version "1.0.2" + resolved "https://registry.yarnpkg.com/for-in/-/for-in-1.0.2.tgz#81068d295a8142ec0ac726c6e2200c30fb6d5e80" + integrity sha1-gQaNKVqBQuwKxybG4iAMMPttXoA= + +for-own@^0.1.3: + version "0.1.5" + resolved "https://registry.yarnpkg.com/for-own/-/for-own-0.1.5.tgz#5265c681a4f294dabbf17c9509b6763aa84510ce" + integrity 
sha1-UmXGgaTylNq78XyVCbZ2OqhFEM4= + dependencies: + for-in "^1.0.1" + +forever-agent@~0.6.1: + version "0.6.1" + resolved "https://registry.yarnpkg.com/forever-agent/-/forever-agent-0.6.1.tgz#fbc71f0c41adeb37f96c577ad1ed42d8fdacca91" + integrity sha1-+8cfDEGt6zf5bFd60e1C2P2sypE= + +fork-ts-checker-webpack-plugin@3.1.1: + version "3.1.1" + resolved "https://registry.yarnpkg.com/fork-ts-checker-webpack-plugin/-/fork-ts-checker-webpack-plugin-3.1.1.tgz#a1642c0d3e65f50c2cc1742e9c0a80f441f86b19" + integrity sha512-DuVkPNrM12jR41KM2e+N+styka0EgLkTnXmNcXdgOM37vtGeY+oCBK/Jx0hzSeEU6memFCtWb4htrHPMDfwwUQ== + dependencies: + babel-code-frame "^6.22.0" + chalk "^2.4.1" + chokidar "^3.3.0" + micromatch "^3.1.10" + minimatch "^3.0.4" + semver "^5.6.0" + tapable "^1.0.0" + worker-rpc "^0.1.0" + +form-data@~2.3.2: + version "2.3.3" + resolved "https://registry.yarnpkg.com/form-data/-/form-data-2.3.3.tgz#dcce52c05f644f298c6a7ab936bd724ceffbf3a6" + integrity sha512-1lLKB2Mu3aGP1Q/2eCOx0fNbRMe7XdwktwOruhfqqd0rIJWwN4Dh+E3hrPSlDCXnSR7UtZ1N38rVXm+6+MEhJQ== + dependencies: + asynckit "^0.4.0" + combined-stream "^1.0.6" + mime-types "^2.1.12" + +forwarded@~0.1.2: + version "0.1.2" + resolved "https://registry.yarnpkg.com/forwarded/-/forwarded-0.1.2.tgz#98c23dab1175657b8c0573e8ceccd91b0ff18c84" + integrity sha1-mMI9qxF1ZXuMBXPozszZGw/xjIQ= + +fragment-cache@^0.2.1: + version "0.2.1" + resolved "https://registry.yarnpkg.com/fragment-cache/-/fragment-cache-0.2.1.tgz#4290fad27f13e89be7f33799c6bc5a0abfff0d19" + integrity sha1-QpD60n8T6Jvn8zeZxrxaCr//DRk= + dependencies: + map-cache "^0.2.2" + +fresh@0.5.2: + version "0.5.2" + resolved "https://registry.yarnpkg.com/fresh/-/fresh-0.5.2.tgz#3d8cadd90d976569fa835ab1f8e4b23a105605a7" + integrity sha1-PYyt2Q2XZWn6g1qx+OSyOhBWBac= + +from2@^2.1.0: + version "2.3.0" + resolved "https://registry.yarnpkg.com/from2/-/from2-2.3.0.tgz#8bfb5502bde4a4d36cfdeea007fcca21d7e382af" + integrity sha1-i/tVAr3kpNNs/e6gB/zKIdfjgq8= + dependencies: + inherits "^2.0.1" + readable-stream "^2.0.0" + +fs-extra@^4.0.2: + version "4.0.3" + resolved "https://registry.yarnpkg.com/fs-extra/-/fs-extra-4.0.3.tgz#0d852122e5bc5beb453fb028e9c0c9bf36340c94" + integrity sha512-q6rbdDd1o2mAnQreO7YADIxf/Whx4AHBiRf6d+/cVT8h44ss+lHgxf1FemcqDnQt9X3ct4McHr+JMGlYSsK7Cg== + dependencies: + graceful-fs "^4.1.2" + jsonfile "^4.0.0" + universalify "^0.1.0" + +fs-extra@^7.0.0: + version "7.0.1" + resolved "https://registry.yarnpkg.com/fs-extra/-/fs-extra-7.0.1.tgz#4f189c44aa123b895f722804f55ea23eadc348e9" + integrity sha512-YJDaCJZEnBmcbw13fvdAM9AwNOJwOzrE4pqMqBq5nFiEqXUqHwlK4B+3pUw6JNvfSPtX05xFHtYy/1ni01eGCw== + dependencies: + graceful-fs "^4.1.2" + jsonfile "^4.0.0" + universalify "^0.1.0" + +fs-extra@^8.1.0: + version "8.1.0" + resolved "https://registry.yarnpkg.com/fs-extra/-/fs-extra-8.1.0.tgz#49d43c45a88cd9677668cb7be1b46efdb8d2e1c0" + integrity sha512-yhlQgA6mnOJUKOsRUFsgJdQCvkKhcz8tlZG5HBQfReYZy46OwLcY+Zia0mtdHsOo9y/hP+CxMN0TU9QxoOtG4g== + dependencies: + graceful-fs "^4.2.0" + jsonfile "^4.0.0" + universalify "^0.1.0" + +fs-minipass@^2.0.0: + version "2.1.0" + resolved "https://registry.yarnpkg.com/fs-minipass/-/fs-minipass-2.1.0.tgz#7f5036fdbf12c63c169190cbe4199c852271f9fb" + integrity sha512-V/JgOLFCS+R6Vcq0slCuaeWEdNC3ouDlJMNIsacH2VtALiu9mV4LPrHc5cDl8k5aw6J8jwgWWpiTo5RYhmIzvg== + dependencies: + minipass "^3.0.0" + +fs-write-stream-atomic@^1.0.8: + version "1.0.10" + resolved "https://registry.yarnpkg.com/fs-write-stream-atomic/-/fs-write-stream-atomic-1.0.10.tgz#b47df53493ef911df75731e70a9ded0189db40c9" + 
integrity sha1-tH31NJPvkR33VzHnCp3tAYnbQMk= + dependencies: + graceful-fs "^4.1.2" + iferr "^0.1.5" + imurmurhash "^0.1.4" + readable-stream "1 || 2" + +fs.realpath@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/fs.realpath/-/fs.realpath-1.0.0.tgz#1504ad2523158caa40db4a2787cb01411994ea4f" + integrity sha1-FQStJSMVjKpA20onh8sBQRmU6k8= + +fsevents@2.1.2, fsevents@~2.1.2: + version "2.1.2" + resolved "https://registry.yarnpkg.com/fsevents/-/fsevents-2.1.2.tgz#4c0a1fb34bc68e543b4b82a9ec392bfbda840805" + integrity sha512-R4wDiBwZ0KzpgOWetKDug1FZcYhqYnUYKtfZYt4mD5SBz76q0KR4Q9o7GIPamsVPGmW3EYPPJ0dOOjvx32ldZA== + +fsevents@^1.2.7: + version "1.2.12" + resolved "https://registry.yarnpkg.com/fsevents/-/fsevents-1.2.12.tgz#db7e0d8ec3b0b45724fd4d83d43554a8f1f0de5c" + integrity sha512-Ggd/Ktt7E7I8pxZRbGIs7vwqAPscSESMrCSkx2FtWeqmheJgCo2R74fTsZFCifr0VTPwqRpPv17+6b8Zp7th0Q== + dependencies: + bindings "^1.5.0" + nan "^2.12.1" + +function-bind@^1.1.1: + version "1.1.1" + resolved "https://registry.yarnpkg.com/function-bind/-/function-bind-1.1.1.tgz#a56899d3ea3c9bab874bb9773b7c5ede92f4895d" + integrity sha512-yIovAzMX49sF8Yl58fSCWJ5svSLuaibPxXQJFLmBObTuCr0Mf1KiPopGM9NiFjiYBCbfaa2Fh6breQ6ANVTI0A== + +functional-red-black-tree@^1.0.1: + version "1.0.1" + resolved "https://registry.yarnpkg.com/functional-red-black-tree/-/functional-red-black-tree-1.0.1.tgz#1b0ab3bd553b2a0d6399d29c0e3ea0b252078327" + integrity sha1-GwqzvVU7Kg1jmdKcDj6gslIHgyc= + +gensync@^1.0.0-beta.1: + version "1.0.0-beta.1" + resolved "https://registry.yarnpkg.com/gensync/-/gensync-1.0.0-beta.1.tgz#58f4361ff987e5ff6e1e7a210827aa371eaac269" + integrity sha512-r8EC6NO1sngH/zdD9fiRDLdcgnbayXah+mLgManTaIZJqEC1MZstmnox8KpnI2/fxQwrp5OpCOYWLp4rBl4Jcg== + +get-caller-file@^1.0.1: + version "1.0.3" + resolved "https://registry.yarnpkg.com/get-caller-file/-/get-caller-file-1.0.3.tgz#f978fa4c90d1dfe7ff2d6beda2a515e713bdcf4a" + integrity sha512-3t6rVToeoZfYSGd8YoLFR2DJkiQrIiUrGcjvFX2mDw3bn6k2OtwHN0TNCLbBO+w8qTvimhDkv+LSscbJY1vE6w== + +get-caller-file@^2.0.1: + version "2.0.5" + resolved "https://registry.yarnpkg.com/get-caller-file/-/get-caller-file-2.0.5.tgz#4f94412a82db32f36e3b0b9741f8a97feb031f7e" + integrity sha512-DyFP3BM/3YHTQOCUL/w0OZHR0lpKeGrxotcHWcqNEdnltqFwXVfhEBQ94eIo34AfQpo0rGki4cyIiftY06h2Fg== + +get-own-enumerable-property-symbols@^3.0.0: + version "3.0.2" + resolved "https://registry.yarnpkg.com/get-own-enumerable-property-symbols/-/get-own-enumerable-property-symbols-3.0.2.tgz#b5fde77f22cbe35f390b4e089922c50bce6ef664" + integrity sha512-I0UBV/XOz1XkIJHEUDMZAbzCThU/H8DxmSfmdGcKPnVhu2VfFqr34jr9777IyaTYvxjedWhqVIilEDsCdP5G6g== + +get-stream@^4.0.0: + version "4.1.0" + resolved "https://registry.yarnpkg.com/get-stream/-/get-stream-4.1.0.tgz#c1b255575f3dc21d59bfc79cd3d2b46b1c3a54b5" + integrity sha512-GMat4EJ5161kIy2HevLlr4luNjBgvmj413KaQA7jt4V8B4RDsfpHk7WQ9GVqfYyyx8OS/L66Kox+rJRNklLK7w== + dependencies: + pump "^3.0.0" + +get-value@^2.0.3, get-value@^2.0.6: + version "2.0.6" + resolved "https://registry.yarnpkg.com/get-value/-/get-value-2.0.6.tgz#dc15ca1c672387ca76bd37ac0a395ba2042a2c28" + integrity sha1-3BXKHGcjh8p2vTesCjlbogQqLCg= + +getpass@^0.1.1: + version "0.1.7" + resolved "https://registry.yarnpkg.com/getpass/-/getpass-0.1.7.tgz#5eff8e3e684d569ae4cb2b1282604e8ba62149fa" + integrity sha1-Xv+OPmhNVprkyysSgmBOi6YhSfo= + dependencies: + assert-plus "^1.0.0" + +glob-parent@^3.1.0: + version "3.1.0" + resolved "https://registry.yarnpkg.com/glob-parent/-/glob-parent-3.1.0.tgz#9e6af6299d8d3bd2bd40430832bd113df906c5ae" + 
integrity sha1-nmr2KZ2NO9K9QEMIMr0RPfkGxa4= + dependencies: + is-glob "^3.1.0" + path-dirname "^1.0.0" + +glob-parent@^5.0.0, glob-parent@~5.1.0: + version "5.1.1" + resolved "https://registry.yarnpkg.com/glob-parent/-/glob-parent-5.1.1.tgz#b6c1ef417c4e5663ea498f1c45afac6916bbc229" + integrity sha512-FnI+VGOpnlGHWZxthPGR+QhR78fuiK0sNLkHQv+bL9fQi57lNNdquIbna/WrfROrolq8GK5Ek6BiMwqL/voRYQ== + dependencies: + is-glob "^4.0.1" + +glob-to-regexp@^0.3.0: + version "0.3.0" + resolved "https://registry.yarnpkg.com/glob-to-regexp/-/glob-to-regexp-0.3.0.tgz#8c5a1494d2066c570cc3bfe4496175acc4d502ab" + integrity sha1-jFoUlNIGbFcMw7/kSWF1rMTVAqs= + +glob@^7.0.3, glob@^7.1.1, glob@^7.1.2, glob@^7.1.3, glob@^7.1.4, glob@^7.1.6: + version "7.1.6" + resolved "https://registry.yarnpkg.com/glob/-/glob-7.1.6.tgz#141f33b81a7c2492e125594307480c46679278a6" + integrity sha512-LwaxwyZ72Lk7vZINtNNrywX0ZuLyStrdDtabefZKAY5ZGJhVtgdznluResxNmPitE0SAO+O26sWTHeKSI2wMBA== + dependencies: + fs.realpath "^1.0.0" + inflight "^1.0.4" + inherits "2" + minimatch "^3.0.4" + once "^1.3.0" + path-is-absolute "^1.0.0" + +global-modules@2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/global-modules/-/global-modules-2.0.0.tgz#997605ad2345f27f51539bea26574421215c7780" + integrity sha512-NGbfmJBp9x8IxyJSd1P+otYK8vonoJactOogrVfFRIAEY1ukil8RSKDz2Yo7wh1oihl51l/r6W4epkeKJHqL8A== + dependencies: + global-prefix "^3.0.0" + +global-modules@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/global-modules/-/global-modules-1.0.0.tgz#6d770f0eb523ac78164d72b5e71a8877265cc3ea" + integrity sha512-sKzpEkf11GpOFuw0Zzjzmt4B4UZwjOcG757PPvrfhxcLFbq0wpsgpOqxpxtxFiCG4DtG93M6XRVbF2oGdev7bg== + dependencies: + global-prefix "^1.0.1" + is-windows "^1.0.1" + resolve-dir "^1.0.0" + +global-prefix@^1.0.1: + version "1.0.2" + resolved "https://registry.yarnpkg.com/global-prefix/-/global-prefix-1.0.2.tgz#dbf743c6c14992593c655568cb66ed32c0122ebe" + integrity sha1-2/dDxsFJklk8ZVVoy2btMsASLr4= + dependencies: + expand-tilde "^2.0.2" + homedir-polyfill "^1.0.1" + ini "^1.3.4" + is-windows "^1.0.1" + which "^1.2.14" + +global-prefix@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/global-prefix/-/global-prefix-3.0.0.tgz#fc85f73064df69f50421f47f883fe5b913ba9b97" + integrity sha512-awConJSVCHVGND6x3tmMaKcQvwXLhjdkmomy2W+Goaui8YPgYgXJZewhg3fWC+DlfqqQuWg8AwqjGTD2nAPVWg== + dependencies: + ini "^1.3.5" + kind-of "^6.0.2" + which "^1.3.1" + +globals@^11.1.0: + version "11.12.0" + resolved "https://registry.yarnpkg.com/globals/-/globals-11.12.0.tgz#ab8795338868a0babd8525758018c2a7eb95c42e" + integrity sha512-WOBp/EEGUiIsJSp7wcv/y6MO+lV9UoncWqxuFfm8eBwzWNgyfBd6Gz+IeKQ9jCmyhoH99g15M3T+QaVHFjizVA== + +globals@^12.1.0: + version "12.4.0" + resolved "https://registry.yarnpkg.com/globals/-/globals-12.4.0.tgz#a18813576a41b00a24a97e7f815918c2e19925f8" + integrity sha512-BWICuzzDvDoH54NHKCseDanAhE3CeDorgDL5MT6LMXXj2WCnd9UC2szdk4AWLfjdgNBCXLUanXYcpBBKOSWGwg== + dependencies: + type-fest "^0.8.1" + +globby@8.0.2: + version "8.0.2" + resolved "https://registry.yarnpkg.com/globby/-/globby-8.0.2.tgz#5697619ccd95c5275dbb2d6faa42087c1a941d8d" + integrity sha512-yTzMmKygLp8RUpG1Ymu2VXPSJQZjNAZPD4ywgYEaG7e4tBJeUQBO8OpXrf1RCNcEs5alsoJYPAMiIHP0cmeC7w== + dependencies: + array-union "^1.0.1" + dir-glob "2.0.0" + fast-glob "^2.0.2" + glob "^7.1.2" + ignore "^3.3.5" + pify "^3.0.0" + slash "^1.0.0" + +globby@^6.1.0: + version "6.1.0" + resolved "https://registry.yarnpkg.com/globby/-/globby-6.1.0.tgz#f5a6d70e8395e21c858fb0489d64df02424d506c" + 
integrity sha1-9abXDoOV4hyFj7BInWTfAkJNUGw= + dependencies: + array-union "^1.0.1" + glob "^7.0.3" + object-assign "^4.0.1" + pify "^2.0.0" + pinkie-promise "^2.0.0" + +graceful-fs@^4.1.11, graceful-fs@^4.1.15, graceful-fs@^4.1.2, graceful-fs@^4.1.6, graceful-fs@^4.2.0, graceful-fs@^4.2.2: + version "4.2.3" + resolved "https://registry.yarnpkg.com/graceful-fs/-/graceful-fs-4.2.3.tgz#4a12ff1b60376ef09862c2093edd908328be8423" + integrity sha512-a30VEBm4PEdx1dRB7MFK7BejejvCvBronbLjht+sHuGYj8PHs7M/5Z+rt5lw551vZ7yfTCj4Vuyy3mSJytDWRQ== + +growly@^1.3.0: + version "1.3.0" + resolved "https://registry.yarnpkg.com/growly/-/growly-1.3.0.tgz#f10748cbe76af964b7c96c93c6bcc28af120c081" + integrity sha1-8QdIy+dq+WS3yWyTxrzCivEgwIE= + +gzip-size@5.1.1: + version "5.1.1" + resolved "https://registry.yarnpkg.com/gzip-size/-/gzip-size-5.1.1.tgz#cb9bee692f87c0612b232840a873904e4c135274" + integrity sha512-FNHi6mmoHvs1mxZAds4PpdCS6QG8B4C1krxJsMutgxl5t3+GlRTzzI3NEkifXx2pVsOvJdOGSmIgDhQ55FwdPA== + dependencies: + duplexer "^0.1.1" + pify "^4.0.1" + +handle-thing@^2.0.0: + version "2.0.1" + resolved "https://registry.yarnpkg.com/handle-thing/-/handle-thing-2.0.1.tgz#857f79ce359580c340d43081cc648970d0bb234e" + integrity sha512-9Qn4yBxelxoh2Ow62nP+Ka/kMnOXRi8BXnRaUwezLNhqelnN49xKz4F/dPP8OYLxLxq6JDtZb2i9XznUQbNPTg== + +har-schema@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/har-schema/-/har-schema-2.0.0.tgz#a94c2224ebcac04782a0d9035521f24735b7ec92" + integrity sha1-qUwiJOvKwEeCoNkDVSHyRzW37JI= + +har-validator@~5.1.3: + version "5.1.3" + resolved "https://registry.yarnpkg.com/har-validator/-/har-validator-5.1.3.tgz#1ef89ebd3e4996557675eed9893110dc350fa080" + integrity sha512-sNvOCzEQNr/qrvJgc3UG/kD4QtlHycrzwS+6mfTrrSq97BvaYcPZZI1ZSqGSPR73Cxn4LKTD4PttRwfU7jWq5g== + dependencies: + ajv "^6.5.5" + har-schema "^2.0.0" + +harmony-reflect@^1.4.6: + version "1.6.1" + resolved "https://registry.yarnpkg.com/harmony-reflect/-/harmony-reflect-1.6.1.tgz#c108d4f2bb451efef7a37861fdbdae72c9bdefa9" + integrity sha512-WJTeyp0JzGtHcuMsi7rw2VwtkvLa+JyfEKJCFyfcS0+CDkjQ5lHPu7zEhFZP+PDSRrEgXa5Ah0l1MbgbE41XjA== + +has-ansi@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/has-ansi/-/has-ansi-2.0.0.tgz#34f5049ce1ecdf2b0649af3ef24e45ed35416d91" + integrity sha1-NPUEnOHs3ysGSa8+8k5F7TVBbZE= + dependencies: + ansi-regex "^2.0.0" + +has-flag@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/has-flag/-/has-flag-3.0.0.tgz#b5d454dc2199ae225699f3467e5a07f3b955bafd" + integrity sha1-tdRU3CGZriJWmfNGfloH87lVuv0= + +has-flag@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/has-flag/-/has-flag-4.0.0.tgz#944771fd9c81c81265c4d6941860da06bb59479b" + integrity sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ== + +has-symbols@^1.0.0, has-symbols@^1.0.1: + version "1.0.1" + resolved "https://registry.yarnpkg.com/has-symbols/-/has-symbols-1.0.1.tgz#9f5214758a44196c406d9bd76cebf81ec2dd31e8" + integrity sha512-PLcsoqu++dmEIZB+6totNFKq/7Do+Z0u4oT0zKOJNl3lYK6vGwwu2hjHs+68OEZbTjiUE9bgOABXbP/GvrS0Kg== + +has-value@^0.3.1: + version "0.3.1" + resolved "https://registry.yarnpkg.com/has-value/-/has-value-0.3.1.tgz#7b1f58bada62ca827ec0a2078025654845995e1f" + integrity sha1-ex9YutpiyoJ+wKIHgCVlSEWZXh8= + dependencies: + get-value "^2.0.3" + has-values "^0.1.4" + isobject "^2.0.0" + +has-value@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/has-value/-/has-value-1.0.0.tgz#18b281da585b1c5c51def24c930ed29a0be6b177" + integrity 
sha1-GLKB2lhbHFxR3vJMkw7SmgvmsXc= + dependencies: + get-value "^2.0.6" + has-values "^1.0.0" + isobject "^3.0.0" + +has-values@^0.1.4: + version "0.1.4" + resolved "https://registry.yarnpkg.com/has-values/-/has-values-0.1.4.tgz#6d61de95d91dfca9b9a02089ad384bff8f62b771" + integrity sha1-bWHeldkd/Km5oCCJrThL/49it3E= + +has-values@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/has-values/-/has-values-1.0.0.tgz#95b0b63fec2146619a6fe57fe75628d5a39efe4f" + integrity sha1-lbC2P+whRmGab+V/51Yo1aOe/k8= + dependencies: + is-number "^3.0.0" + kind-of "^4.0.0" + +has@^1.0.0, has@^1.0.3: + version "1.0.3" + resolved "https://registry.yarnpkg.com/has/-/has-1.0.3.tgz#722d7cbfc1f6aa8241f16dd814e011e1f41e8796" + integrity sha512-f2dvO0VU6Oej7RkWJGrehjbzMAjFp5/VKPp5tTpWIV4JHHZK1/BxbFRtf/siA2SWTe09caDmVtYYzWEIbBS4zw== + dependencies: + function-bind "^1.1.1" + +hash-base@^3.0.0: + version "3.0.4" + resolved "https://registry.yarnpkg.com/hash-base/-/hash-base-3.0.4.tgz#5fc8686847ecd73499403319a6b0a3f3f6ae4918" + integrity sha1-X8hoaEfs1zSZQDMZprCj8/auSRg= + dependencies: + inherits "^2.0.1" + safe-buffer "^5.0.1" + +hash.js@^1.0.0, hash.js@^1.0.3: + version "1.1.7" + resolved "https://registry.yarnpkg.com/hash.js/-/hash.js-1.1.7.tgz#0babca538e8d4ee4a0f8988d68866537a003cf42" + integrity sha512-taOaskGt4z4SOANNseOviYDvjEJinIkRgmp7LbKP2YTTmVxWBl87s/uzK9r+44BclBSp2X7K1hqeNfz9JbBeXA== + dependencies: + inherits "^2.0.3" + minimalistic-assert "^1.0.1" + +he@^1.2.0: + version "1.2.0" + resolved "https://registry.yarnpkg.com/he/-/he-1.2.0.tgz#84ae65fa7eafb165fddb61566ae14baf05664f0f" + integrity sha512-F/1DnUGPopORZi0ni+CvrCgHQ5FyEAHRLSApuYWMmrbSwoN2Mn/7k+Gl38gJnR7yyDZk6WLXwiGod1JOWNDKGw== + +hex-color-regex@^1.1.0: + version "1.1.0" + resolved "https://registry.yarnpkg.com/hex-color-regex/-/hex-color-regex-1.1.0.tgz#4c06fccb4602fe2602b3c93df82d7e7dbf1a8a8e" + integrity sha512-l9sfDFsuqtOqKDsQdqrMRk0U85RZc0RtOR9yPI7mRVOa4FsR/BVnZ0shmQRM96Ji99kYZP/7hn1cedc1+ApsTQ== + +hmac-drbg@^1.0.0: + version "1.0.1" + resolved "https://registry.yarnpkg.com/hmac-drbg/-/hmac-drbg-1.0.1.tgz#d2745701025a6c775a6c545793ed502fc0c649a1" + integrity sha1-0nRXAQJabHdabFRXk+1QL8DGSaE= + dependencies: + hash.js "^1.0.3" + minimalistic-assert "^1.0.0" + minimalistic-crypto-utils "^1.0.1" + +homedir-polyfill@^1.0.1: + version "1.0.3" + resolved "https://registry.yarnpkg.com/homedir-polyfill/-/homedir-polyfill-1.0.3.tgz#743298cef4e5af3e194161fbadcc2151d3a058e8" + integrity sha512-eSmmWE5bZTK2Nou4g0AI3zZ9rswp7GRKoKXS1BLUkvPviOqs4YTN1djQIqrXy9k5gEtdLPy86JjRwsNM9tnDcA== + dependencies: + parse-passwd "^1.0.0" + +hosted-git-info@^2.1.4: + version "2.8.8" + resolved "https://registry.yarnpkg.com/hosted-git-info/-/hosted-git-info-2.8.8.tgz#7539bd4bc1e0e0a895815a2e0262420b12858488" + integrity sha512-f/wzC2QaWBs7t9IYqB4T3sR1xviIViXJRJTWBlx2Gf3g0Xi5vI7Yy4koXQ1c9OYDGHN9sBy1DQ2AB8fqZBWhUg== + +hpack.js@^2.1.6: + version "2.1.6" + resolved "https://registry.yarnpkg.com/hpack.js/-/hpack.js-2.1.6.tgz#87774c0949e513f42e84575b3c45681fade2a0b2" + integrity sha1-h3dMCUnlE/QuhFdbPEVoH63ioLI= + dependencies: + inherits "^2.0.1" + obuf "^1.0.0" + readable-stream "^2.0.1" + wbuf "^1.1.0" + +hsl-regex@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/hsl-regex/-/hsl-regex-1.0.0.tgz#d49330c789ed819e276a4c0d272dffa30b18fe6e" + integrity sha1-1JMwx4ntgZ4nakwNJy3/owsY/m4= + +hsla-regex@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/hsla-regex/-/hsla-regex-1.0.0.tgz#c1ce7a3168c8c6614033a4b5f7877f3b225f9c38" + 
integrity sha1-wc56MWjIxmFAM6S194d/OyJfnDg= + +html-comment-regex@^1.1.0: + version "1.1.2" + resolved "https://registry.yarnpkg.com/html-comment-regex/-/html-comment-regex-1.1.2.tgz#97d4688aeb5c81886a364faa0cad1dda14d433a7" + integrity sha512-P+M65QY2JQ5Y0G9KKdlDpo0zK+/OHptU5AaBwUfAIDJZk1MYf32Frm84EcOytfJE0t5JvkAnKlmjsXDnWzCJmQ== + +html-encoding-sniffer@^1.0.2: + version "1.0.2" + resolved "https://registry.yarnpkg.com/html-encoding-sniffer/-/html-encoding-sniffer-1.0.2.tgz#e70d84b94da53aa375e11fe3a351be6642ca46f8" + integrity sha512-71lZziiDnsuabfdYiUeWdCVyKuqwWi23L8YeIgV9jSSZHCtb6wB1BKWooH7L3tn4/FuZJMVWyNaIDr4RGmaSYw== + dependencies: + whatwg-encoding "^1.0.1" + +html-entities@^1.2.1: + version "1.2.1" + resolved "https://registry.yarnpkg.com/html-entities/-/html-entities-1.2.1.tgz#0df29351f0721163515dfb9e5543e5f6eed5162f" + integrity sha1-DfKTUfByEWNRXfueVUPl9u7VFi8= + +html-escaper@^2.0.0: + version "2.0.2" + resolved "https://registry.yarnpkg.com/html-escaper/-/html-escaper-2.0.2.tgz#dfd60027da36a36dfcbe236262c00a5822681453" + integrity sha512-H2iMtd0I4Mt5eYiapRdIDjp+XzelXQ0tFE4JS7YFwFevXXMmOp9myNrUvCg0D6ws8iqkRPBfKHgbwig1SmlLfg== + +html-minifier-terser@^5.0.1: + version "5.0.5" + resolved "https://registry.yarnpkg.com/html-minifier-terser/-/html-minifier-terser-5.0.5.tgz#8f12f639789f04faa9f5cf2ff9b9f65607f21f8b" + integrity sha512-cBSFFghQh/uHcfSiL42KxxIRMF7A144+3E44xdlctIjxEmkEfCvouxNyFH2wysXk1fCGBPwtcr3hDWlGTfkDew== + dependencies: + camel-case "^4.1.1" + clean-css "^4.2.3" + commander "^4.1.1" + he "^1.2.0" + param-case "^3.0.3" + relateurl "^0.2.7" + terser "^4.6.3" + +html-webpack-plugin@4.0.0-beta.11: + version "4.0.0-beta.11" + resolved "https://registry.yarnpkg.com/html-webpack-plugin/-/html-webpack-plugin-4.0.0-beta.11.tgz#3059a69144b5aecef97708196ca32f9e68677715" + integrity sha512-4Xzepf0qWxf8CGg7/WQM5qBB2Lc/NFI7MhU59eUDTkuQp3skZczH4UA1d6oQyDEIoMDgERVhRyTdtUPZ5s5HBg== + dependencies: + html-minifier-terser "^5.0.1" + loader-utils "^1.2.3" + lodash "^4.17.15" + pretty-error "^2.1.1" + tapable "^1.1.3" + util.promisify "1.0.0" + +htmlparser2@^3.3.0: + version "3.10.1" + resolved "https://registry.yarnpkg.com/htmlparser2/-/htmlparser2-3.10.1.tgz#bd679dc3f59897b6a34bb10749c855bb53a9392f" + integrity sha512-IgieNijUMbkDovyoKObU1DUhm1iwNYE/fuifEoEHfd1oZKZDaONBSkal7Y01shxsM49R4XaMdGez3WnF9UfiCQ== + dependencies: + domelementtype "^1.3.1" + domhandler "^2.3.0" + domutils "^1.5.1" + entities "^1.1.1" + inherits "^2.0.1" + readable-stream "^3.1.1" + +http-deceiver@^1.2.7: + version "1.2.7" + resolved "https://registry.yarnpkg.com/http-deceiver/-/http-deceiver-1.2.7.tgz#fa7168944ab9a519d337cb0bec7284dc3e723d87" + integrity sha1-+nFolEq5pRnTN8sL7HKE3D5yPYc= + +http-errors@1.7.2: + version "1.7.2" + resolved "https://registry.yarnpkg.com/http-errors/-/http-errors-1.7.2.tgz#4f5029cf13239f31036e5b2e55292bcfbcc85c8f" + integrity sha512-uUQBt3H/cSIVfch6i1EuPNy/YsRSOUBXTVfZ+yR7Zjez3qjBz6i9+i4zjNaoqcoFVI4lQJ5plg63TvGfRSDCRg== + dependencies: + depd "~1.1.2" + inherits "2.0.3" + setprototypeof "1.1.1" + statuses ">= 1.5.0 < 2" + toidentifier "1.0.0" + +http-errors@~1.6.2: + version "1.6.3" + resolved "https://registry.yarnpkg.com/http-errors/-/http-errors-1.6.3.tgz#8b55680bb4be283a0b5bf4ea2e38580be1d9320d" + integrity sha1-i1VoC7S+KDoLW/TqLjhYC+HZMg0= + dependencies: + depd "~1.1.2" + inherits "2.0.3" + setprototypeof "1.1.0" + statuses ">= 1.4.0 < 2" + +http-errors@~1.7.2: + version "1.7.3" + resolved 
"https://registry.yarnpkg.com/http-errors/-/http-errors-1.7.3.tgz#6c619e4f9c60308c38519498c14fbb10aacebb06" + integrity sha512-ZTTX0MWrsQ2ZAhA1cejAwDLycFsd7I7nVtnkT3Ol0aqodaKW+0CTZDQ1uBv5whptCnc8e8HeRRJxRs0kmm/Qfw== + dependencies: + depd "~1.1.2" + inherits "2.0.4" + setprototypeof "1.1.1" + statuses ">= 1.5.0 < 2" + toidentifier "1.0.0" + +"http-parser-js@>=0.4.0 <0.4.11": + version "0.4.10" + resolved "https://registry.yarnpkg.com/http-parser-js/-/http-parser-js-0.4.10.tgz#92c9c1374c35085f75db359ec56cc257cbb93fa4" + integrity sha1-ksnBN0w1CF912zWexWzCV8u5P6Q= + +http-proxy-middleware@0.19.1: + version "0.19.1" + resolved "https://registry.yarnpkg.com/http-proxy-middleware/-/http-proxy-middleware-0.19.1.tgz#183c7dc4aa1479150306498c210cdaf96080a43a" + integrity sha512-yHYTgWMQO8VvwNS22eLLloAkvungsKdKTLO8AJlftYIKNfJr3GK3zK0ZCfzDDGUBttdGc8xFy1mCitvNKQtC3Q== + dependencies: + http-proxy "^1.17.0" + is-glob "^4.0.0" + lodash "^4.17.11" + micromatch "^3.1.10" + +http-proxy@^1.17.0: + version "1.18.0" + resolved "https://registry.yarnpkg.com/http-proxy/-/http-proxy-1.18.0.tgz#dbe55f63e75a347db7f3d99974f2692a314a6a3a" + integrity sha512-84I2iJM/n1d4Hdgc6y2+qY5mDaz2PUVjlg9znE9byl+q0uC3DeByqBGReQu5tpLK0TAqTIXScRUV+dg7+bUPpQ== + dependencies: + eventemitter3 "^4.0.0" + follow-redirects "^1.0.0" + requires-port "^1.0.0" + +http-signature@~1.2.0: + version "1.2.0" + resolved "https://registry.yarnpkg.com/http-signature/-/http-signature-1.2.0.tgz#9aecd925114772f3d95b65a60abb8f7c18fbace1" + integrity sha1-muzZJRFHcvPZW2WmCruPfBj7rOE= + dependencies: + assert-plus "^1.0.0" + jsprim "^1.2.2" + sshpk "^1.7.0" + +https-browserify@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/https-browserify/-/https-browserify-1.0.0.tgz#ec06c10e0a34c0f2faf199f7fd7fc78fffd03c73" + integrity sha1-7AbBDgo0wPL68Zn3/X/Hj//QPHM= + +iconv-lite@0.4.24, iconv-lite@^0.4.24: + version "0.4.24" + resolved "https://registry.yarnpkg.com/iconv-lite/-/iconv-lite-0.4.24.tgz#2022b4b25fbddc21d2f524974a474aafe733908b" + integrity sha512-v3MXnZAcvnywkTUEZomIActle7RXXeedOR31wwl7VlyoXO4Qi9arvSenNQWne1TcRwhCL1HwLI21bEqdpj8/rA== + dependencies: + safer-buffer ">= 2.1.2 < 3" + +icss-utils@^4.0.0, icss-utils@^4.1.1: + version "4.1.1" + resolved "https://registry.yarnpkg.com/icss-utils/-/icss-utils-4.1.1.tgz#21170b53789ee27447c2f47dd683081403f9a467" + integrity sha512-4aFq7wvWyMHKgxsH8QQtGpvbASCf+eM3wPRLI6R+MgAnTCZ6STYsRvttLvRWK0Nfif5piF394St3HeJDaljGPA== + dependencies: + postcss "^7.0.14" + +identity-obj-proxy@3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/identity-obj-proxy/-/identity-obj-proxy-3.0.0.tgz#94d2bda96084453ef36fbc5aaec37e0f79f1fc14" + integrity sha1-lNK9qWCERT7zb7xarsN+D3nx/BQ= + dependencies: + harmony-reflect "^1.4.6" + +ieee754@^1.1.4: + version "1.1.13" + resolved "https://registry.yarnpkg.com/ieee754/-/ieee754-1.1.13.tgz#ec168558e95aa181fd87d37f55c32bbcb6708b84" + integrity sha512-4vf7I2LYV/HaWerSo3XmlMkp5eZ83i+/CDluXi/IGTs/O1sejBNhTtnxzmRZfvOUqj7lZjqHkeTvpgSFDlWZTg== + +iferr@^0.1.5: + version "0.1.5" + resolved "https://registry.yarnpkg.com/iferr/-/iferr-0.1.5.tgz#c60eed69e6d8fdb6b3104a1fcbca1c192dc5b501" + integrity sha1-xg7taebY/bazEEofy8ocGS3FtQE= + +ignore@^3.3.5: + version "3.3.10" + resolved "https://registry.yarnpkg.com/ignore/-/ignore-3.3.10.tgz#0a97fb876986e8081c631160f8f9f389157f0043" + integrity sha512-Pgs951kaMm5GXP7MOvxERINe3gsaVjUWFm+UZPSq9xYriQAksyhg0csnS0KXSNRD5NmNdapXEpjxG49+AKh/ug== + +ignore@^4.0.6: + version "4.0.6" + resolved 
"https://registry.yarnpkg.com/ignore/-/ignore-4.0.6.tgz#750e3db5862087b4737ebac8207ffd1ef27b25fc" + integrity sha512-cyFDKrqc/YdcWFniJhzI42+AzS+gNwmUzOSFcRCQYwySuBBBy/KjuxWLZ/FHEH6Moq1NizMOBWyTcv8O4OZIMg== + +immer@1.10.0: + version "1.10.0" + resolved "https://registry.yarnpkg.com/immer/-/immer-1.10.0.tgz#bad67605ba9c810275d91e1c2a47d4582e98286d" + integrity sha512-O3sR1/opvCDGLEVcvrGTMtLac8GJ5IwZC4puPrLuRj3l7ICKvkmA0vGuU9OW8mV9WIBRnaxp5GJh9IEAaNOoYg== + +import-cwd@^2.0.0: + version "2.1.0" + resolved "https://registry.yarnpkg.com/import-cwd/-/import-cwd-2.1.0.tgz#aa6cf36e722761285cb371ec6519f53e2435b0a9" + integrity sha1-qmzzbnInYShcs3HsZRn1PiQ1sKk= + dependencies: + import-from "^2.1.0" + +import-fresh@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/import-fresh/-/import-fresh-2.0.0.tgz#d81355c15612d386c61f9ddd3922d4304822a546" + integrity sha1-2BNVwVYS04bGH53dOSLUMEgipUY= + dependencies: + caller-path "^2.0.0" + resolve-from "^3.0.0" + +import-fresh@^3.0.0, import-fresh@^3.1.0: + version "3.2.1" + resolved "https://registry.yarnpkg.com/import-fresh/-/import-fresh-3.2.1.tgz#633ff618506e793af5ac91bf48b72677e15cbe66" + integrity sha512-6e1q1cnWP2RXD9/keSkxHScg508CdXqXWgWBaETNhyuBFz+kUZlKboh+ISK+bU++DmbHimVBrOz/zzPe0sZ3sQ== + dependencies: + parent-module "^1.0.0" + resolve-from "^4.0.0" + +import-from@^2.1.0: + version "2.1.0" + resolved "https://registry.yarnpkg.com/import-from/-/import-from-2.1.0.tgz#335db7f2a7affd53aaa471d4b8021dee36b7f3b1" + integrity sha1-M1238qev/VOqpHHUuAId7ja387E= + dependencies: + resolve-from "^3.0.0" + +import-local@2.0.0, import-local@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/import-local/-/import-local-2.0.0.tgz#55070be38a5993cf18ef6db7e961f5bee5c5a09d" + integrity sha512-b6s04m3O+s3CGSbqDIyP4R6aAwAeYlVq9+WUWep6iHa8ETRf9yei1U48C5MmfJmV9AiLYYBKPMq/W+/WRpQmCQ== + dependencies: + pkg-dir "^3.0.0" + resolve-cwd "^2.0.0" + +imurmurhash@^0.1.4: + version "0.1.4" + resolved "https://registry.yarnpkg.com/imurmurhash/-/imurmurhash-0.1.4.tgz#9218b9b2b928a238b13dc4fb6b6d576f231453ea" + integrity sha1-khi5srkoojixPcT7a21XbyMUU+o= + +indent-string@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/indent-string/-/indent-string-4.0.0.tgz#624f8f4497d619b2d9768531d58f4122854d7251" + integrity sha512-EdDDZu4A2OyIK7Lr/2zG+w5jmbuk1DVBnEwREQvBzspBJkCEbRa8GxU1lghYcaGJCnRWibjDXlq779X1/y5xwg== + +indexes-of@^1.0.1: + version "1.0.1" + resolved "https://registry.yarnpkg.com/indexes-of/-/indexes-of-1.0.1.tgz#f30f716c8e2bd346c7b67d3df3915566a7c05607" + integrity sha1-8w9xbI4r00bHtn0985FVZqfAVgc= + +infer-owner@^1.0.3, infer-owner@^1.0.4: + version "1.0.4" + resolved "https://registry.yarnpkg.com/infer-owner/-/infer-owner-1.0.4.tgz#c4cefcaa8e51051c2a40ba2ce8a3d27295af9467" + integrity sha512-IClj+Xz94+d7irH5qRyfJonOdfTzuDaifE6ZPWfx0N0+/ATZCbuTPq2prFl526urkQd90WyUKIh1DfBQ2hMz9A== + +inflight@^1.0.4: + version "1.0.6" + resolved "https://registry.yarnpkg.com/inflight/-/inflight-1.0.6.tgz#49bd6331d7d02d0c09bc910a1075ba8165b56df9" + integrity sha1-Sb1jMdfQLQwJvJEKEHW6gWW1bfk= + dependencies: + once "^1.3.0" + wrappy "1" + +inherits@2, inherits@2.0.4, inherits@^2.0.1, inherits@^2.0.3, inherits@~2.0.1, inherits@~2.0.3: + version "2.0.4" + resolved "https://registry.yarnpkg.com/inherits/-/inherits-2.0.4.tgz#0fa2c64f932917c3433a0ded55363aae37416b7c" + integrity sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ== + +inherits@2.0.1: + version "2.0.1" + resolved 
"https://registry.yarnpkg.com/inherits/-/inherits-2.0.1.tgz#b17d08d326b4423e568eff719f91b0b1cbdf69f1" + integrity sha1-sX0I0ya0Qj5Wjv9xn5GwscvfafE= + +inherits@2.0.3: + version "2.0.3" + resolved "https://registry.yarnpkg.com/inherits/-/inherits-2.0.3.tgz#633c2c83e3da42a502f52466022480f4208261de" + integrity sha1-Yzwsg+PaQqUC9SRmAiSA9CCCYd4= + +ini@^1.3.4, ini@^1.3.5: + version "1.3.5" + resolved "https://registry.yarnpkg.com/ini/-/ini-1.3.5.tgz#eee25f56db1c9ec6085e0c22778083f596abf927" + integrity sha512-RZY5huIKCMRWDUqZlEi72f/lmXKMvuszcMBduliQ3nnWbx9X/ZBQO7DijMEYS9EhHBb2qacRUMtC7svLwe0lcw== + +inquirer@7.0.4: + version "7.0.4" + resolved "https://registry.yarnpkg.com/inquirer/-/inquirer-7.0.4.tgz#99af5bde47153abca23f5c7fc30db247f39da703" + integrity sha512-Bu5Td5+j11sCkqfqmUTiwv+tWisMtP0L7Q8WrqA2C/BbBhy1YTdFrvjjlrKq8oagA/tLQBski2Gcx/Sqyi2qSQ== + dependencies: + ansi-escapes "^4.2.1" + chalk "^2.4.2" + cli-cursor "^3.1.0" + cli-width "^2.0.0" + external-editor "^3.0.3" + figures "^3.0.0" + lodash "^4.17.15" + mute-stream "0.0.8" + run-async "^2.2.0" + rxjs "^6.5.3" + string-width "^4.1.0" + strip-ansi "^5.1.0" + through "^2.3.6" + +inquirer@^7.0.0: + version "7.1.0" + resolved "https://registry.yarnpkg.com/inquirer/-/inquirer-7.1.0.tgz#1298a01859883e17c7264b82870ae1034f92dd29" + integrity sha512-5fJMWEmikSYu0nv/flMc475MhGbB7TSPd/2IpFV4I4rMklboCH2rQjYY5kKiYGHqUF9gvaambupcJFFG9dvReg== + dependencies: + ansi-escapes "^4.2.1" + chalk "^3.0.0" + cli-cursor "^3.1.0" + cli-width "^2.0.0" + external-editor "^3.0.3" + figures "^3.0.0" + lodash "^4.17.15" + mute-stream "0.0.8" + run-async "^2.4.0" + rxjs "^6.5.3" + string-width "^4.1.0" + strip-ansi "^6.0.0" + through "^2.3.6" + +internal-ip@^4.3.0: + version "4.3.0" + resolved "https://registry.yarnpkg.com/internal-ip/-/internal-ip-4.3.0.tgz#845452baad9d2ca3b69c635a137acb9a0dad0907" + integrity sha512-S1zBo1D6zcsyuC6PMmY5+55YMILQ9av8lotMx447Bq6SAgo/sDK6y6uUKmuYhW7eacnIhFfsPmCNYdDzsnnDCg== + dependencies: + default-gateway "^4.2.0" + ipaddr.js "^1.9.0" + +internal-slot@^1.0.2: + version "1.0.2" + resolved "https://registry.yarnpkg.com/internal-slot/-/internal-slot-1.0.2.tgz#9c2e9fb3cd8e5e4256c6f45fe310067fcfa378a3" + integrity sha512-2cQNfwhAfJIkU4KZPkDI+Gj5yNNnbqi40W9Gge6dfnk4TocEVm00B3bdiL+JINrbGJil2TeHvM4rETGzk/f/0g== + dependencies: + es-abstract "^1.17.0-next.1" + has "^1.0.3" + side-channel "^1.0.2" + +interpret@1.2.0: + version "1.2.0" + resolved "https://registry.yarnpkg.com/interpret/-/interpret-1.2.0.tgz#d5061a6224be58e8083985f5014d844359576296" + integrity sha512-mT34yGKMNceBQUoVn7iCDKDntA7SC6gycMAWzGx1z/CMCTV7b2AAtXlo3nRyHZ1FelRkQbQjprHSYGwzLtkVbw== + +invariant@^2.2.2, invariant@^2.2.4: + version "2.2.4" + resolved "https://registry.yarnpkg.com/invariant/-/invariant-2.2.4.tgz#610f3c92c9359ce1db616e538008d23ff35158e6" + integrity sha512-phJfQVBuaJM5raOpJjSfkiD6BpbCE4Ns//LaXl6wGYtUBY83nWS6Rf9tXm2e8VaK60JEjYldbPif/A2B1C2gNA== + dependencies: + loose-envify "^1.0.0" + +invert-kv@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/invert-kv/-/invert-kv-2.0.0.tgz#7393f5afa59ec9ff5f67a27620d11c226e3eec02" + integrity sha512-wPVv/y/QQ/Uiirj/vh3oP+1Ww+AWehmi1g5fFWGPF6IpCBCDVrhgHRMvrLfdYcwDh3QJbGXDW4JAuzxElLSqKA== + +ip-regex@^2.1.0: + version "2.1.0" + resolved "https://registry.yarnpkg.com/ip-regex/-/ip-regex-2.1.0.tgz#fa78bf5d2e6913c911ce9f819ee5146bb6d844e9" + integrity sha1-+ni/XS5pE8kRzp+BnuUUa7bYROk= + +ip@^1.1.0, ip@^1.1.5: + version "1.1.5" + resolved 
"https://registry.yarnpkg.com/ip/-/ip-1.1.5.tgz#bdded70114290828c0a039e72ef25f5aaec4354a" + integrity sha1-vd7XARQpCCjAoDnnLvJfWq7ENUo= + +ipaddr.js@1.9.1, ipaddr.js@^1.9.0: + version "1.9.1" + resolved "https://registry.yarnpkg.com/ipaddr.js/-/ipaddr.js-1.9.1.tgz#bff38543eeb8984825079ff3a2a8e6cbd46781b3" + integrity sha512-0KI/607xoxSToH7GjN1FfSbLoU0+btTicjsQSWQlh/hZykN8KpmMf7uYwPW3R+akZ6R/w18ZlXSHBYXiYUPO3g== + +is-absolute-url@^2.0.0: + version "2.1.0" + resolved "https://registry.yarnpkg.com/is-absolute-url/-/is-absolute-url-2.1.0.tgz#50530dfb84fcc9aa7dbe7852e83a37b93b9f2aa6" + integrity sha1-UFMN+4T8yap9vnhS6Do3uTufKqY= + +is-absolute-url@^3.0.3: + version "3.0.3" + resolved "https://registry.yarnpkg.com/is-absolute-url/-/is-absolute-url-3.0.3.tgz#96c6a22b6a23929b11ea0afb1836c36ad4a5d698" + integrity sha512-opmNIX7uFnS96NtPmhWQgQx6/NYFgsUXYMllcfzwWKUMwfo8kku1TvE6hkNcH+Q1ts5cMVrsY7j0bxXQDciu9Q== + +is-accessor-descriptor@^0.1.6: + version "0.1.6" + resolved "https://registry.yarnpkg.com/is-accessor-descriptor/-/is-accessor-descriptor-0.1.6.tgz#a9e12cb3ae8d876727eeef3843f8a0897b5c98d6" + integrity sha1-qeEss66Nh2cn7u84Q/igiXtcmNY= + dependencies: + kind-of "^3.0.2" + +is-accessor-descriptor@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/is-accessor-descriptor/-/is-accessor-descriptor-1.0.0.tgz#169c2f6d3df1f992618072365c9b0ea1f6878656" + integrity sha512-m5hnHTkcVsPfqx3AKlyttIPb7J+XykHvJP2B9bZDjlhLIoEq4XoK64Vg7boZlVWYK6LUY94dYPEE7Lh0ZkZKcQ== + dependencies: + kind-of "^6.0.0" + +is-arguments@^1.0.4: + version "1.0.4" + resolved "https://registry.yarnpkg.com/is-arguments/-/is-arguments-1.0.4.tgz#3faf966c7cba0ff437fb31f6250082fcf0448cf3" + integrity sha512-xPh0Rmt8NE65sNzvyUmWgI1tz3mKq74lGA0mL8LYZcoIzKOzDh6HmrYm3d18k60nHerC8A9Km8kYu87zfSFnLA== + +is-arrayish@^0.2.1: + version "0.2.1" + resolved "https://registry.yarnpkg.com/is-arrayish/-/is-arrayish-0.2.1.tgz#77c99840527aa8ecb1a8ba697b80645a7a926a9d" + integrity sha1-d8mYQFJ6qOyxqLppe4BkWnqSap0= + +is-arrayish@^0.3.1: + version "0.3.2" + resolved "https://registry.yarnpkg.com/is-arrayish/-/is-arrayish-0.3.2.tgz#4574a2ae56f7ab206896fb431eaeed066fdf8f03" + integrity sha512-eVRqCvVlZbuw3GrM63ovNSNAeA1K16kaR/LRY/92w0zxQ5/1YzwblUX652i4Xs9RwAGjW9d9y6X88t8OaAJfWQ== + +is-binary-path@^1.0.0: + version "1.0.1" + resolved "https://registry.yarnpkg.com/is-binary-path/-/is-binary-path-1.0.1.tgz#75f16642b480f187a711c814161fd3a4a7655898" + integrity sha1-dfFmQrSA8YenEcgUFh/TpKdlWJg= + dependencies: + binary-extensions "^1.0.0" + +is-binary-path@~2.1.0: + version "2.1.0" + resolved "https://registry.yarnpkg.com/is-binary-path/-/is-binary-path-2.1.0.tgz#ea1f7f3b80f064236e83470f86c09c254fb45b09" + integrity sha512-ZMERYes6pDydyuGidse7OsHxtbI7WVeUEozgR/g7rd0xUimYNlvZRE/K2MgZTjWy725IfelLeVcEM97mmtRGXw== + dependencies: + binary-extensions "^2.0.0" + +is-buffer@^1.0.2, is-buffer@^1.1.5: + version "1.1.6" + resolved "https://registry.yarnpkg.com/is-buffer/-/is-buffer-1.1.6.tgz#efaa2ea9daa0d7ab2ea13a97b2b8ad51fefbe8be" + integrity sha512-NcdALwpXkTm5Zvvbk7owOUSvVvBKDgKP5/ewfXEznmQFfs4ZRmanOeKBTjRVjka3QFoN6XJ+9F3USqfHqTaU5w== + +is-callable@^1.1.4, is-callable@^1.1.5: + version "1.1.5" + resolved "https://registry.yarnpkg.com/is-callable/-/is-callable-1.1.5.tgz#f7e46b596890456db74e7f6e976cb3273d06faab" + integrity sha512-ESKv5sMCJB2jnHTWZ3O5itG+O128Hsus4K4Qh1h2/cgn2vbgnLSVqfV46AeJA9D5EeeLa9w81KUXMtn34zhX+Q== + +is-ci@^2.0.0: + version "2.0.0" + resolved 
"https://registry.yarnpkg.com/is-ci/-/is-ci-2.0.0.tgz#6bc6334181810e04b5c22b3d589fdca55026404c" + integrity sha512-YfJT7rkpQB0updsdHLGWrvhBJfcfzNNawYDNIyQXJz0IViGf75O8EBPKSdvw2rF+LGCsX4FZ8tcr3b19LcZq4w== + dependencies: + ci-info "^2.0.0" + +is-color-stop@^1.0.0: + version "1.1.0" + resolved "https://registry.yarnpkg.com/is-color-stop/-/is-color-stop-1.1.0.tgz#cfff471aee4dd5c9e158598fbe12967b5cdad345" + integrity sha1-z/9HGu5N1cnhWFmPvhKWe1za00U= + dependencies: + css-color-names "^0.0.4" + hex-color-regex "^1.1.0" + hsl-regex "^1.0.0" + hsla-regex "^1.0.0" + rgb-regex "^1.0.1" + rgba-regex "^1.0.0" + +is-data-descriptor@^0.1.4: + version "0.1.4" + resolved "https://registry.yarnpkg.com/is-data-descriptor/-/is-data-descriptor-0.1.4.tgz#0b5ee648388e2c860282e793f1856fec3f301b56" + integrity sha1-C17mSDiOLIYCgueT8YVv7D8wG1Y= + dependencies: + kind-of "^3.0.2" + +is-data-descriptor@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/is-data-descriptor/-/is-data-descriptor-1.0.0.tgz#d84876321d0e7add03990406abbbbd36ba9268c7" + integrity sha512-jbRXy1FmtAoCjQkVmIVYwuuqDFUbaOeDjmed1tOGPrsMhtJA4rD9tkgA0F1qJ3gRFRXcHYVkdeaP50Q5rE/jLQ== + dependencies: + kind-of "^6.0.0" + +is-date-object@^1.0.1: + version "1.0.2" + resolved "https://registry.yarnpkg.com/is-date-object/-/is-date-object-1.0.2.tgz#bda736f2cd8fd06d32844e7743bfa7494c3bfd7e" + integrity sha512-USlDT524woQ08aoZFzh3/Z6ch9Y/EWXEHQ/AaRN0SkKq4t2Jw2R2339tSXmwuVoY7LLlBCbOIlx2myP/L5zk0g== + +is-descriptor@^0.1.0: + version "0.1.6" + resolved "https://registry.yarnpkg.com/is-descriptor/-/is-descriptor-0.1.6.tgz#366d8240dde487ca51823b1ab9f07a10a78251ca" + integrity sha512-avDYr0SB3DwO9zsMov0gKCESFYqCnE4hq/4z3TdUlukEy5t9C0YRq7HLrsN52NAcqXKaepeCD0n+B0arnVG3Hg== + dependencies: + is-accessor-descriptor "^0.1.6" + is-data-descriptor "^0.1.4" + kind-of "^5.0.0" + +is-descriptor@^1.0.0, is-descriptor@^1.0.2: + version "1.0.2" + resolved "https://registry.yarnpkg.com/is-descriptor/-/is-descriptor-1.0.2.tgz#3b159746a66604b04f8c81524ba365c5f14d86ec" + integrity sha512-2eis5WqQGV7peooDyLmNEPUrps9+SXX5c9pL3xEB+4e9HnGuDa7mB7kHxHw4CbqS9k1T2hOH3miL8n8WtiYVtg== + dependencies: + is-accessor-descriptor "^1.0.0" + is-data-descriptor "^1.0.0" + kind-of "^6.0.2" + +is-directory@^0.3.1: + version "0.3.1" + resolved "https://registry.yarnpkg.com/is-directory/-/is-directory-0.3.1.tgz#61339b6f2475fc772fd9c9d83f5c8575dc154ae1" + integrity sha1-YTObbyR1/Hcv2cnYP1yFddwVSuE= + +is-docker@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/is-docker/-/is-docker-2.0.0.tgz#2cb0df0e75e2d064fe1864c37cdeacb7b2dcf25b" + integrity sha512-pJEdRugimx4fBMra5z2/5iRdZ63OhYV0vr0Dwm5+xtW4D1FvRkB8hamMIhnWfyJeDdyr/aa7BDyNbtG38VxgoQ== + +is-extendable@^0.1.0, is-extendable@^0.1.1: + version "0.1.1" + resolved "https://registry.yarnpkg.com/is-extendable/-/is-extendable-0.1.1.tgz#62b110e289a471418e3ec36a617d472e301dfc89" + integrity sha1-YrEQ4omkcUGOPsNqYX1HLjAd/Ik= + +is-extendable@^1.0.1: + version "1.0.1" + resolved "https://registry.yarnpkg.com/is-extendable/-/is-extendable-1.0.1.tgz#a7470f9e426733d81bd81e1155264e3a3507cab4" + integrity sha512-arnXMxT1hhoKo9k1LZdmlNyJdDDfy2v0fXjFlmok4+i8ul/6WlbVge9bhM74OpNPQPMGUToDtz+KXa1PneJxOA== + dependencies: + is-plain-object "^2.0.4" + +is-extglob@^2.1.0, is-extglob@^2.1.1: + version "2.1.1" + resolved "https://registry.yarnpkg.com/is-extglob/-/is-extglob-2.1.1.tgz#a88c02535791f02ed37c76a1b9ea9773c833f8c2" + integrity sha1-qIwCU1eR8C7TfHahueqXc8gz+MI= + +is-fullwidth-code-point@^1.0.0: + version "1.0.0" + resolved 
"https://registry.yarnpkg.com/is-fullwidth-code-point/-/is-fullwidth-code-point-1.0.0.tgz#ef9e31386f031a7f0d643af82fde50c457ef00cb" + integrity sha1-754xOG8DGn8NZDr4L95QxFfvAMs= + dependencies: + number-is-nan "^1.0.0" + +is-fullwidth-code-point@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/is-fullwidth-code-point/-/is-fullwidth-code-point-2.0.0.tgz#a3b30a5c4f199183167aaab93beefae3ddfb654f" + integrity sha1-o7MKXE8ZkYMWeqq5O+764937ZU8= + +is-fullwidth-code-point@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/is-fullwidth-code-point/-/is-fullwidth-code-point-3.0.0.tgz#f116f8064fe90b3f7844a38997c0b75051269f1d" + integrity sha512-zymm5+u+sCsSWyD9qNaejV3DFvhCKclKdizYaJUuHA83RLjb7nSuGnddCHGv0hk+KY7BMAlsWeK4Ueg6EV6XQg== + +is-generator-fn@^2.0.0: + version "2.1.0" + resolved "https://registry.yarnpkg.com/is-generator-fn/-/is-generator-fn-2.1.0.tgz#7d140adc389aaf3011a8f2a2a4cfa6faadffb118" + integrity sha512-cTIB4yPYL/Grw0EaSzASzg6bBy9gqCofvWN8okThAYIxKJZC+udlRAmGbM0XLeniEJSs8uEgHPGuHSe1XsOLSQ== + +is-glob@^3.1.0: + version "3.1.0" + resolved "https://registry.yarnpkg.com/is-glob/-/is-glob-3.1.0.tgz#7ba5ae24217804ac70707b96922567486cc3e84a" + integrity sha1-e6WuJCF4BKxwcHuWkiVnSGzD6Eo= + dependencies: + is-extglob "^2.1.0" + +is-glob@^4.0.0, is-glob@^4.0.1, is-glob@~4.0.1: + version "4.0.1" + resolved "https://registry.yarnpkg.com/is-glob/-/is-glob-4.0.1.tgz#7567dbe9f2f5e2467bc77ab83c4a29482407a5dc" + integrity sha512-5G0tKtBTFImOqDnLB2hG6Bp2qcKEFduo4tZu9MT/H6NQv/ghhy30o55ufafxJ/LdH79LLs2Kfrn85TLKyA7BUg== + dependencies: + is-extglob "^2.1.1" + +is-number@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/is-number/-/is-number-3.0.0.tgz#24fd6201a4782cf50561c810276afc7d12d71195" + integrity sha1-JP1iAaR4LPUFYcgQJ2r8fRLXEZU= + dependencies: + kind-of "^3.0.2" + +is-number@^7.0.0: + version "7.0.0" + resolved "https://registry.yarnpkg.com/is-number/-/is-number-7.0.0.tgz#7535345b896734d5f80c4d06c50955527a14f12b" + integrity sha512-41Cifkg6e8TylSpdtTpeLVMqvSBEVzTttHvERD741+pnZ8ANv0004MRL43QKPDlK9cGvNp6NZWZUBlbGXYxxng== + +is-obj@^1.0.1: + version "1.0.1" + resolved "https://registry.yarnpkg.com/is-obj/-/is-obj-1.0.1.tgz#3e4729ac1f5fde025cd7d83a896dab9f4f67db0f" + integrity sha1-PkcprB9f3gJc19g6iW2rn09n2w8= + +is-obj@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/is-obj/-/is-obj-2.0.0.tgz#473fb05d973705e3fd9620545018ca8e22ef4982" + integrity sha512-drqDG3cbczxxEJRoOXcOjtdp1J/lyp1mNn0xaznRs8+muBhgQcrnbspox5X5fOw0HnMnbfDzvnEMEtqDEJEo8w== + +is-path-cwd@^2.0.0: + version "2.2.0" + resolved "https://registry.yarnpkg.com/is-path-cwd/-/is-path-cwd-2.2.0.tgz#67d43b82664a7b5191fd9119127eb300048a9fdb" + integrity sha512-w942bTcih8fdJPJmQHFzkS76NEP8Kzzvmw92cXsazb8intwLqPibPPdXf4ANdKV3rYMuuQYGIWtvz9JilB3NFQ== + +is-path-in-cwd@^2.0.0: + version "2.1.0" + resolved "https://registry.yarnpkg.com/is-path-in-cwd/-/is-path-in-cwd-2.1.0.tgz#bfe2dca26c69f397265a4009963602935a053acb" + integrity sha512-rNocXHgipO+rvnP6dk3zI20RpOtrAM/kzbB258Uw5BWr3TpXi861yzjo16Dn4hUox07iw5AyeMLHWsujkjzvRQ== + dependencies: + is-path-inside "^2.1.0" + +is-path-inside@^2.1.0: + version "2.1.0" + resolved "https://registry.yarnpkg.com/is-path-inside/-/is-path-inside-2.1.0.tgz#7c9810587d659a40d27bcdb4d5616eab059494b2" + integrity sha512-wiyhTzfDWsvwAW53OBWF5zuvaOGlZ6PwYxAbPVDhpm+gM09xKQGjBq/8uYN12aDvMxnAnq3dxTyoSoRNmg5YFg== + dependencies: + path-is-inside "^1.0.2" + +is-plain-obj@^1.0.0: + version "1.1.0" + resolved 
"https://registry.yarnpkg.com/is-plain-obj/-/is-plain-obj-1.1.0.tgz#71a50c8429dfca773c92a390a4a03b39fcd51d3e" + integrity sha1-caUMhCnfync8kqOQpKA7OfzVHT4= + +is-plain-object@^2.0.1, is-plain-object@^2.0.3, is-plain-object@^2.0.4: + version "2.0.4" + resolved "https://registry.yarnpkg.com/is-plain-object/-/is-plain-object-2.0.4.tgz#2c163b3fafb1b606d9d17928f05c2a1c38e07677" + integrity sha512-h5PpgXkWitc38BBMYawTYMWJHFZJVnBquFE57xFpjB8pJFiF6gZ+bU+WyI/yqXiFR5mdLsgYNaPe8uao6Uv9Og== + dependencies: + isobject "^3.0.1" + +is-promise@^2.1.0: + version "2.1.0" + resolved "https://registry.yarnpkg.com/is-promise/-/is-promise-2.1.0.tgz#79a2a9ece7f096e80f36d2b2f3bc16c1ff4bf3fa" + integrity sha1-eaKp7OfwlugPNtKy87wWwf9L8/o= + +is-regex@^1.0.4, is-regex@^1.0.5: + version "1.0.5" + resolved "https://registry.yarnpkg.com/is-regex/-/is-regex-1.0.5.tgz#39d589a358bf18967f726967120b8fc1aed74eae" + integrity sha512-vlKW17SNq44owv5AQR3Cq0bQPEb8+kF3UKZ2fiZNOWtztYE5i0CzCZxFDwO58qAOWtxdBRVO/V5Qin1wjCqFYQ== + dependencies: + has "^1.0.3" + +is-regexp@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/is-regexp/-/is-regexp-1.0.0.tgz#fd2d883545c46bac5a633e7b9a09e87fa2cb5069" + integrity sha1-/S2INUXEa6xaYz57mgnof6LLUGk= + +is-resolvable@^1.0.0: + version "1.1.0" + resolved "https://registry.yarnpkg.com/is-resolvable/-/is-resolvable-1.1.0.tgz#fb18f87ce1feb925169c9a407c19318a3206ed88" + integrity sha512-qgDYXFSR5WvEfuS5dMj6oTMEbrrSaM0CrFk2Yiq/gXnBvD9pMa2jGXxyhGLfvhZpuMZe18CJpFxAt3CRs42NMg== + +is-root@2.1.0: + version "2.1.0" + resolved "https://registry.yarnpkg.com/is-root/-/is-root-2.1.0.tgz#809e18129cf1129644302a4f8544035d51984a9c" + integrity sha512-AGOriNp96vNBd3HtU+RzFEc75FfR5ymiYv8E553I71SCeXBiMsVDUtdio1OEFvrPyLIQ9tVR5RxXIFe5PUFjMg== + +is-stream@^1.1.0: + version "1.1.0" + resolved "https://registry.yarnpkg.com/is-stream/-/is-stream-1.1.0.tgz#12d4a3dd4e68e0b79ceb8dbc84173ae80d91ca44" + integrity sha1-EtSj3U5o4Lec6428hBc66A2RykQ= + +is-string@^1.0.5: + version "1.0.5" + resolved "https://registry.yarnpkg.com/is-string/-/is-string-1.0.5.tgz#40493ed198ef3ff477b8c7f92f644ec82a5cd3a6" + integrity sha512-buY6VNRjhQMiF1qWDouloZlQbRhDPCebwxSjxMjxgemYT46YMd2NR0/H+fBhEfWX4A/w9TBJ+ol+okqJKFE6vQ== + +is-svg@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/is-svg/-/is-svg-3.0.0.tgz#9321dbd29c212e5ca99c4fa9794c714bcafa2f75" + integrity sha512-gi4iHK53LR2ujhLVVj+37Ykh9GLqYHX6JOVXbLAucaG/Cqw9xwdFOjDM2qeifLs1sF1npXXFvDu0r5HNgCMrzQ== + dependencies: + html-comment-regex "^1.1.0" + +is-symbol@^1.0.2: + version "1.0.3" + resolved "https://registry.yarnpkg.com/is-symbol/-/is-symbol-1.0.3.tgz#38e1014b9e6329be0de9d24a414fd7441ec61937" + integrity sha512-OwijhaRSgqvhm/0ZdAcXNZt9lYdKFpcRDT5ULUuYXPoT794UNOdU+gpT6Rzo7b4V2HUl/op6GqY894AZwv9faQ== + dependencies: + has-symbols "^1.0.1" + +is-typedarray@~1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/is-typedarray/-/is-typedarray-1.0.0.tgz#e479c80858df0c1b11ddda6940f96011fcda4a9a" + integrity sha1-5HnICFjfDBsR3dppQPlgEfzaSpo= + +is-windows@^1.0.1, is-windows@^1.0.2: + version "1.0.2" + resolved "https://registry.yarnpkg.com/is-windows/-/is-windows-1.0.2.tgz#d1850eb9791ecd18e6182ce12a30f396634bb19d" + integrity sha512-eXK1UInq2bPmjyX6e3VHIzMLobc4J94i4AWn+Hpq3OU5KkrRC96OAcR3PRJ/pGu6m8TRnBHP9dkXQVsT/COVIA== + +is-wsl@^1.1.0: + version "1.1.0" + resolved "https://registry.yarnpkg.com/is-wsl/-/is-wsl-1.1.0.tgz#1f16e4aa22b04d1336b66188a66af3c600c3a66d" + integrity sha1-HxbkqiKwTRM2tmGIpmrzxgDDpm0= + +is-wsl@^2.1.1: + version "2.1.1" + resolved 
"https://registry.yarnpkg.com/is-wsl/-/is-wsl-2.1.1.tgz#4a1c152d429df3d441669498e2486d3596ebaf1d" + integrity sha512-umZHcSrwlDHo2TGMXv0DZ8dIUGunZ2Iv68YZnrmCiBPkZ4aaOhtv7pXJKeki9k3qJ3RJr0cDyitcl5wEH3AYog== + +isarray@1.0.0, isarray@^1.0.0, isarray@~1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/isarray/-/isarray-1.0.0.tgz#bb935d48582cba168c06834957a54a3e07124f11" + integrity sha1-u5NdSFgsuhaMBoNJV6VKPgcSTxE= + +isexe@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/isexe/-/isexe-2.0.0.tgz#e8fbf374dc556ff8947a10dcb0572d633f2cfa10" + integrity sha1-6PvzdNxVb/iUehDcsFctYz8s+hA= + +isobject@^2.0.0: + version "2.1.0" + resolved "https://registry.yarnpkg.com/isobject/-/isobject-2.1.0.tgz#f065561096a3f1da2ef46272f815c840d87e0c89" + integrity sha1-8GVWEJaj8dou9GJy+BXIQNh+DIk= + dependencies: + isarray "1.0.0" + +isobject@^3.0.0, isobject@^3.0.1: + version "3.0.1" + resolved "https://registry.yarnpkg.com/isobject/-/isobject-3.0.1.tgz#4e431e92b11a9731636aa1f9c8d1ccbcfdab78df" + integrity sha1-TkMekrEalzFjaqH5yNHMvP2reN8= + +isstream@~0.1.2: + version "0.1.2" + resolved "https://registry.yarnpkg.com/isstream/-/isstream-0.1.2.tgz#47e63f7af55afa6f92e1500e690eb8b8529c099a" + integrity sha1-R+Y/evVa+m+S4VAOaQ64uFKcCZo= + +istanbul-lib-coverage@^2.0.2, istanbul-lib-coverage@^2.0.5: + version "2.0.5" + resolved "https://registry.yarnpkg.com/istanbul-lib-coverage/-/istanbul-lib-coverage-2.0.5.tgz#675f0ab69503fad4b1d849f736baaca803344f49" + integrity sha512-8aXznuEPCJvGnMSRft4udDRDtb1V3pkQkMMI5LI+6HuQz5oQ4J2UFn1H82raA3qJtyOLkkwVqICBQkjnGtn5mA== + +istanbul-lib-instrument@^3.0.1, istanbul-lib-instrument@^3.3.0: + version "3.3.0" + resolved "https://registry.yarnpkg.com/istanbul-lib-instrument/-/istanbul-lib-instrument-3.3.0.tgz#a5f63d91f0bbc0c3e479ef4c5de027335ec6d630" + integrity sha512-5nnIN4vo5xQZHdXno/YDXJ0G+I3dAm4XgzfSVTPLQpj/zAV2dV6Juy0yaf10/zrJOJeHoN3fraFe+XRq2bFVZA== + dependencies: + "@babel/generator" "^7.4.0" + "@babel/parser" "^7.4.3" + "@babel/template" "^7.4.0" + "@babel/traverse" "^7.4.3" + "@babel/types" "^7.4.0" + istanbul-lib-coverage "^2.0.5" + semver "^6.0.0" + +istanbul-lib-report@^2.0.4: + version "2.0.8" + resolved "https://registry.yarnpkg.com/istanbul-lib-report/-/istanbul-lib-report-2.0.8.tgz#5a8113cd746d43c4889eba36ab10e7d50c9b4f33" + integrity sha512-fHBeG573EIihhAblwgxrSenp0Dby6tJMFR/HvlerBsrCTD5bkUuoNtn3gVh29ZCS824cGGBPn7Sg7cNk+2xUsQ== + dependencies: + istanbul-lib-coverage "^2.0.5" + make-dir "^2.1.0" + supports-color "^6.1.0" + +istanbul-lib-source-maps@^3.0.1: + version "3.0.6" + resolved "https://registry.yarnpkg.com/istanbul-lib-source-maps/-/istanbul-lib-source-maps-3.0.6.tgz#284997c48211752ec486253da97e3879defba8c8" + integrity sha512-R47KzMtDJH6X4/YW9XTx+jrLnZnscW4VpNN+1PViSYTejLVPWv7oov+Duf8YQSPyVRUvueQqz1TcsC6mooZTXw== + dependencies: + debug "^4.1.1" + istanbul-lib-coverage "^2.0.5" + make-dir "^2.1.0" + rimraf "^2.6.3" + source-map "^0.6.1" + +istanbul-reports@^2.2.6: + version "2.2.7" + resolved "https://registry.yarnpkg.com/istanbul-reports/-/istanbul-reports-2.2.7.tgz#5d939f6237d7b48393cc0959eab40cd4fd056931" + integrity sha512-uu1F/L1o5Y6LzPVSVZXNOoD/KXpJue9aeLRd0sM9uMXfZvzomB0WxVamWb5ue8kA2vVWEmW7EG+A5n3f1kqHKg== + dependencies: + html-escaper "^2.0.0" + +jest-changed-files@^24.9.0: + version "24.9.0" + resolved "https://registry.yarnpkg.com/jest-changed-files/-/jest-changed-files-24.9.0.tgz#08d8c15eb79a7fa3fc98269bc14b451ee82f8039" + integrity 
sha512-6aTWpe2mHF0DhL28WjdkO8LyGjs3zItPET4bMSeXU6T3ub4FPMw+mcOcbdGXQOAfmLcxofD23/5Bl9Z4AkFwqg== + dependencies: + "@jest/types" "^24.9.0" + execa "^1.0.0" + throat "^4.0.0" + +jest-cli@^24.9.0: + version "24.9.0" + resolved "https://registry.yarnpkg.com/jest-cli/-/jest-cli-24.9.0.tgz#ad2de62d07472d419c6abc301fc432b98b10d2af" + integrity sha512-+VLRKyitT3BWoMeSUIHRxV/2g8y9gw91Jh5z2UmXZzkZKpbC08CSehVxgHUwTpy+HwGcns/tqafQDJW7imYvGg== + dependencies: + "@jest/core" "^24.9.0" + "@jest/test-result" "^24.9.0" + "@jest/types" "^24.9.0" + chalk "^2.0.1" + exit "^0.1.2" + import-local "^2.0.0" + is-ci "^2.0.0" + jest-config "^24.9.0" + jest-util "^24.9.0" + jest-validate "^24.9.0" + prompts "^2.0.1" + realpath-native "^1.1.0" + yargs "^13.3.0" + +jest-config@^24.9.0: + version "24.9.0" + resolved "https://registry.yarnpkg.com/jest-config/-/jest-config-24.9.0.tgz#fb1bbc60c73a46af03590719efa4825e6e4dd1b5" + integrity sha512-RATtQJtVYQrp7fvWg6f5y3pEFj9I+H8sWw4aKxnDZ96mob5i5SD6ZEGWgMLXQ4LE8UurrjbdlLWdUeo+28QpfQ== + dependencies: + "@babel/core" "^7.1.0" + "@jest/test-sequencer" "^24.9.0" + "@jest/types" "^24.9.0" + babel-jest "^24.9.0" + chalk "^2.0.1" + glob "^7.1.1" + jest-environment-jsdom "^24.9.0" + jest-environment-node "^24.9.0" + jest-get-type "^24.9.0" + jest-jasmine2 "^24.9.0" + jest-regex-util "^24.3.0" + jest-resolve "^24.9.0" + jest-util "^24.9.0" + jest-validate "^24.9.0" + micromatch "^3.1.10" + pretty-format "^24.9.0" + realpath-native "^1.1.0" + +jest-diff@^24.9.0: + version "24.9.0" + resolved "https://registry.yarnpkg.com/jest-diff/-/jest-diff-24.9.0.tgz#931b7d0d5778a1baf7452cb816e325e3724055da" + integrity sha512-qMfrTs8AdJE2iqrTp0hzh7kTd2PQWrsFyj9tORoKmu32xjPjeE4NyjVRDz8ybYwqS2ik8N4hsIpiVTyFeo2lBQ== + dependencies: + chalk "^2.0.1" + diff-sequences "^24.9.0" + jest-get-type "^24.9.0" + pretty-format "^24.9.0" + +jest-diff@^25.2.1: + version "25.5.0" + resolved "https://registry.yarnpkg.com/jest-diff/-/jest-diff-25.5.0.tgz#1dd26ed64f96667c068cef026b677dfa01afcfa9" + integrity sha512-z1kygetuPiREYdNIumRpAHY6RXiGmp70YHptjdaxTWGmA085W3iCnXNx0DhflK3vwrKmrRWyY1wUpkPMVxMK7A== + dependencies: + chalk "^3.0.0" + diff-sequences "^25.2.6" + jest-get-type "^25.2.6" + pretty-format "^25.5.0" + +jest-docblock@^24.3.0: + version "24.9.0" + resolved "https://registry.yarnpkg.com/jest-docblock/-/jest-docblock-24.9.0.tgz#7970201802ba560e1c4092cc25cbedf5af5a8ce2" + integrity sha512-F1DjdpDMJMA1cN6He0FNYNZlo3yYmOtRUnktrT9Q37njYzC5WEaDdmbynIgy0L/IvXvvgsG8OsqhLPXTpfmZAA== + dependencies: + detect-newline "^2.1.0" + +jest-each@^24.9.0: + version "24.9.0" + resolved "https://registry.yarnpkg.com/jest-each/-/jest-each-24.9.0.tgz#eb2da602e2a610898dbc5f1f6df3ba86b55f8b05" + integrity sha512-ONi0R4BvW45cw8s2Lrx8YgbeXL1oCQ/wIDwmsM3CqM/nlblNCPmnC3IPQlMbRFZu3wKdQ2U8BqM6lh3LJ5Bsog== + dependencies: + "@jest/types" "^24.9.0" + chalk "^2.0.1" + jest-get-type "^24.9.0" + jest-util "^24.9.0" + pretty-format "^24.9.0" + +jest-environment-jsdom-fourteen@1.0.1: + version "1.0.1" + resolved "https://registry.yarnpkg.com/jest-environment-jsdom-fourteen/-/jest-environment-jsdom-fourteen-1.0.1.tgz#4cd0042f58b4ab666950d96532ecb2fc188f96fb" + integrity sha512-DojMX1sY+at5Ep+O9yME34CdidZnO3/zfPh8UW+918C5fIZET5vCjfkegixmsi7AtdYfkr4bPlIzmWnlvQkP7Q== + dependencies: + "@jest/environment" "^24.3.0" + "@jest/fake-timers" "^24.3.0" + "@jest/types" "^24.3.0" + jest-mock "^24.0.0" + jest-util "^24.0.0" + jsdom "^14.1.0" + +jest-environment-jsdom@^24.9.0: + version "24.9.0" + resolved 
"https://registry.yarnpkg.com/jest-environment-jsdom/-/jest-environment-jsdom-24.9.0.tgz#4b0806c7fc94f95edb369a69cc2778eec2b7375b" + integrity sha512-Zv9FV9NBRzLuALXjvRijO2351DRQeLYXtpD4xNvfoVFw21IOKNhZAEUKcbiEtjTkm2GsJ3boMVgkaR7rN8qetA== + dependencies: + "@jest/environment" "^24.9.0" + "@jest/fake-timers" "^24.9.0" + "@jest/types" "^24.9.0" + jest-mock "^24.9.0" + jest-util "^24.9.0" + jsdom "^11.5.1" + +jest-environment-node@^24.9.0: + version "24.9.0" + resolved "https://registry.yarnpkg.com/jest-environment-node/-/jest-environment-node-24.9.0.tgz#333d2d2796f9687f2aeebf0742b519f33c1cbfd3" + integrity sha512-6d4V2f4nxzIzwendo27Tr0aFm+IXWa0XEUnaH6nU0FMaozxovt+sfRvh4J47wL1OvF83I3SSTu0XK+i4Bqe7uA== + dependencies: + "@jest/environment" "^24.9.0" + "@jest/fake-timers" "^24.9.0" + "@jest/types" "^24.9.0" + jest-mock "^24.9.0" + jest-util "^24.9.0" + +jest-get-type@^24.9.0: + version "24.9.0" + resolved "https://registry.yarnpkg.com/jest-get-type/-/jest-get-type-24.9.0.tgz#1684a0c8a50f2e4901b6644ae861f579eed2ef0e" + integrity sha512-lUseMzAley4LhIcpSP9Jf+fTrQ4a1yHQwLNeeVa2cEmbCGeoZAtYPOIv8JaxLD/sUpKxetKGP+gsHl8f8TSj8Q== + +jest-get-type@^25.2.6: + version "25.2.6" + resolved "https://registry.yarnpkg.com/jest-get-type/-/jest-get-type-25.2.6.tgz#0b0a32fab8908b44d508be81681487dbabb8d877" + integrity sha512-DxjtyzOHjObRM+sM1knti6or+eOgcGU4xVSb2HNP1TqO4ahsT+rqZg+nyqHWJSvWgKC5cG3QjGFBqxLghiF/Ig== + +jest-haste-map@^24.9.0: + version "24.9.0" + resolved "https://registry.yarnpkg.com/jest-haste-map/-/jest-haste-map-24.9.0.tgz#b38a5d64274934e21fa417ae9a9fbeb77ceaac7d" + integrity sha512-kfVFmsuWui2Sj1Rp1AJ4D9HqJwE4uwTlS/vO+eRUaMmd54BFpli2XhMQnPC2k4cHFVbB2Q2C+jtI1AGLgEnCjQ== + dependencies: + "@jest/types" "^24.9.0" + anymatch "^2.0.0" + fb-watchman "^2.0.0" + graceful-fs "^4.1.15" + invariant "^2.2.4" + jest-serializer "^24.9.0" + jest-util "^24.9.0" + jest-worker "^24.9.0" + micromatch "^3.1.10" + sane "^4.0.3" + walker "^1.0.7" + optionalDependencies: + fsevents "^1.2.7" + +jest-jasmine2@^24.9.0: + version "24.9.0" + resolved "https://registry.yarnpkg.com/jest-jasmine2/-/jest-jasmine2-24.9.0.tgz#1f7b1bd3242c1774e62acabb3646d96afc3be6a0" + integrity sha512-Cq7vkAgaYKp+PsX+2/JbTarrk0DmNhsEtqBXNwUHkdlbrTBLtMJINADf2mf5FkowNsq8evbPc07/qFO0AdKTzw== + dependencies: + "@babel/traverse" "^7.1.0" + "@jest/environment" "^24.9.0" + "@jest/test-result" "^24.9.0" + "@jest/types" "^24.9.0" + chalk "^2.0.1" + co "^4.6.0" + expect "^24.9.0" + is-generator-fn "^2.0.0" + jest-each "^24.9.0" + jest-matcher-utils "^24.9.0" + jest-message-util "^24.9.0" + jest-runtime "^24.9.0" + jest-snapshot "^24.9.0" + jest-util "^24.9.0" + pretty-format "^24.9.0" + throat "^4.0.0" + +jest-leak-detector@^24.9.0: + version "24.9.0" + resolved "https://registry.yarnpkg.com/jest-leak-detector/-/jest-leak-detector-24.9.0.tgz#b665dea7c77100c5c4f7dfcb153b65cf07dcf96a" + integrity sha512-tYkFIDsiKTGwb2FG1w8hX9V0aUb2ot8zY/2nFg087dUageonw1zrLMP4W6zsRO59dPkTSKie+D4rhMuP9nRmrA== + dependencies: + jest-get-type "^24.9.0" + pretty-format "^24.9.0" + +jest-matcher-utils@^24.9.0: + version "24.9.0" + resolved "https://registry.yarnpkg.com/jest-matcher-utils/-/jest-matcher-utils-24.9.0.tgz#f5b3661d5e628dffe6dd65251dfdae0e87c3a073" + integrity sha512-OZz2IXsu6eaiMAwe67c1T+5tUAtQyQx27/EMEkbFAGiw52tB9em+uGbzpcgYVpA8wl0hlxKPZxrly4CXU/GjHA== + dependencies: + chalk "^2.0.1" + jest-diff "^24.9.0" + jest-get-type "^24.9.0" + pretty-format "^24.9.0" + +jest-message-util@^24.9.0: + version "24.9.0" + resolved 
"https://registry.yarnpkg.com/jest-message-util/-/jest-message-util-24.9.0.tgz#527f54a1e380f5e202a8d1149b0ec872f43119e3" + integrity sha512-oCj8FiZ3U0hTP4aSui87P4L4jC37BtQwUMqk+zk/b11FR19BJDeZsZAvIHutWnmtw7r85UmR3CEWZ0HWU2mAlw== + dependencies: + "@babel/code-frame" "^7.0.0" + "@jest/test-result" "^24.9.0" + "@jest/types" "^24.9.0" + "@types/stack-utils" "^1.0.1" + chalk "^2.0.1" + micromatch "^3.1.10" + slash "^2.0.0" + stack-utils "^1.0.1" + +jest-mock@^24.0.0, jest-mock@^24.9.0: + version "24.9.0" + resolved "https://registry.yarnpkg.com/jest-mock/-/jest-mock-24.9.0.tgz#c22835541ee379b908673ad51087a2185c13f1c6" + integrity sha512-3BEYN5WbSq9wd+SyLDES7AHnjH9A/ROBwmz7l2y+ol+NtSFO8DYiEBzoO1CeFc9a8DYy10EO4dDFVv/wN3zl1w== + dependencies: + "@jest/types" "^24.9.0" + +jest-pnp-resolver@^1.2.1: + version "1.2.1" + resolved "https://registry.yarnpkg.com/jest-pnp-resolver/-/jest-pnp-resolver-1.2.1.tgz#ecdae604c077a7fbc70defb6d517c3c1c898923a" + integrity sha512-pgFw2tm54fzgYvc/OHrnysABEObZCUNFnhjoRjaVOCN8NYc032/gVjPaHD4Aq6ApkSieWtfKAFQtmDKAmhupnQ== + +jest-regex-util@^24.3.0, jest-regex-util@^24.9.0: + version "24.9.0" + resolved "https://registry.yarnpkg.com/jest-regex-util/-/jest-regex-util-24.9.0.tgz#c13fb3380bde22bf6575432c493ea8fe37965636" + integrity sha512-05Cmb6CuxaA+Ys6fjr3PhvV3bGQmO+2p2La4hFbU+W5uOc479f7FdLXUWXw4pYMAhhSZIuKHwSXSu6CsSBAXQA== + +jest-resolve-dependencies@^24.9.0: + version "24.9.0" + resolved "https://registry.yarnpkg.com/jest-resolve-dependencies/-/jest-resolve-dependencies-24.9.0.tgz#ad055198959c4cfba8a4f066c673a3f0786507ab" + integrity sha512-Fm7b6AlWnYhT0BXy4hXpactHIqER7erNgIsIozDXWl5dVm+k8XdGVe1oTg1JyaFnOxarMEbax3wyRJqGP2Pq+g== + dependencies: + "@jest/types" "^24.9.0" + jest-regex-util "^24.3.0" + jest-snapshot "^24.9.0" + +jest-resolve@24.9.0, jest-resolve@^24.9.0: + version "24.9.0" + resolved "https://registry.yarnpkg.com/jest-resolve/-/jest-resolve-24.9.0.tgz#dff04c7687af34c4dd7e524892d9cf77e5d17321" + integrity sha512-TaLeLVL1l08YFZAt3zaPtjiVvyy4oSA6CRe+0AFPPVX3Q/VI0giIWWoAvoS5L96vj9Dqxj4fB5p2qrHCmTU/MQ== + dependencies: + "@jest/types" "^24.9.0" + browser-resolve "^1.11.3" + chalk "^2.0.1" + jest-pnp-resolver "^1.2.1" + realpath-native "^1.1.0" + +jest-runner@^24.9.0: + version "24.9.0" + resolved "https://registry.yarnpkg.com/jest-runner/-/jest-runner-24.9.0.tgz#574fafdbd54455c2b34b4bdf4365a23857fcdf42" + integrity sha512-KksJQyI3/0mhcfspnxxEOBueGrd5E4vV7ADQLT9ESaCzz02WnbdbKWIf5Mkaucoaj7obQckYPVX6JJhgUcoWWg== + dependencies: + "@jest/console" "^24.7.1" + "@jest/environment" "^24.9.0" + "@jest/test-result" "^24.9.0" + "@jest/types" "^24.9.0" + chalk "^2.4.2" + exit "^0.1.2" + graceful-fs "^4.1.15" + jest-config "^24.9.0" + jest-docblock "^24.3.0" + jest-haste-map "^24.9.0" + jest-jasmine2 "^24.9.0" + jest-leak-detector "^24.9.0" + jest-message-util "^24.9.0" + jest-resolve "^24.9.0" + jest-runtime "^24.9.0" + jest-util "^24.9.0" + jest-worker "^24.6.0" + source-map-support "^0.5.6" + throat "^4.0.0" + +jest-runtime@^24.9.0: + version "24.9.0" + resolved "https://registry.yarnpkg.com/jest-runtime/-/jest-runtime-24.9.0.tgz#9f14583af6a4f7314a6a9d9f0226e1a781c8e4ac" + integrity sha512-8oNqgnmF3v2J6PVRM2Jfuj8oX3syKmaynlDMMKQ4iyzbQzIG6th5ub/lM2bCMTmoTKM3ykcUYI2Pw9xwNtjMnw== + dependencies: + "@jest/console" "^24.7.1" + "@jest/environment" "^24.9.0" + "@jest/source-map" "^24.3.0" + "@jest/transform" "^24.9.0" + "@jest/types" "^24.9.0" + "@types/yargs" "^13.0.0" + chalk "^2.0.1" + exit "^0.1.2" + glob "^7.1.3" + graceful-fs "^4.1.15" + jest-config "^24.9.0" + 
jest-haste-map "^24.9.0" + jest-message-util "^24.9.0" + jest-mock "^24.9.0" + jest-regex-util "^24.3.0" + jest-resolve "^24.9.0" + jest-snapshot "^24.9.0" + jest-util "^24.9.0" + jest-validate "^24.9.0" + realpath-native "^1.1.0" + slash "^2.0.0" + strip-bom "^3.0.0" + yargs "^13.3.0" + +jest-serializer@^24.9.0: + version "24.9.0" + resolved "https://registry.yarnpkg.com/jest-serializer/-/jest-serializer-24.9.0.tgz#e6d7d7ef96d31e8b9079a714754c5d5c58288e73" + integrity sha512-DxYipDr8OvfrKH3Kel6NdED3OXxjvxXZ1uIY2I9OFbGg+vUkkg7AGvi65qbhbWNPvDckXmzMPbK3u3HaDO49bQ== + +jest-snapshot@^24.9.0: + version "24.9.0" + resolved "https://registry.yarnpkg.com/jest-snapshot/-/jest-snapshot-24.9.0.tgz#ec8e9ca4f2ec0c5c87ae8f925cf97497b0e951ba" + integrity sha512-uI/rszGSs73xCM0l+up7O7a40o90cnrk429LOiK3aeTvfC0HHmldbd81/B7Ix81KSFe1lwkbl7GnBGG4UfuDew== + dependencies: + "@babel/types" "^7.0.0" + "@jest/types" "^24.9.0" + chalk "^2.0.1" + expect "^24.9.0" + jest-diff "^24.9.0" + jest-get-type "^24.9.0" + jest-matcher-utils "^24.9.0" + jest-message-util "^24.9.0" + jest-resolve "^24.9.0" + mkdirp "^0.5.1" + natural-compare "^1.4.0" + pretty-format "^24.9.0" + semver "^6.2.0" + +jest-util@^24.0.0, jest-util@^24.9.0: + version "24.9.0" + resolved "https://registry.yarnpkg.com/jest-util/-/jest-util-24.9.0.tgz#7396814e48536d2e85a37de3e4c431d7cb140162" + integrity sha512-x+cZU8VRmOJxbA1K5oDBdxQmdq0OIdADarLxk0Mq+3XS4jgvhG/oKGWcIDCtPG0HgjxOYvF+ilPJQsAyXfbNOg== + dependencies: + "@jest/console" "^24.9.0" + "@jest/fake-timers" "^24.9.0" + "@jest/source-map" "^24.9.0" + "@jest/test-result" "^24.9.0" + "@jest/types" "^24.9.0" + callsites "^3.0.0" + chalk "^2.0.1" + graceful-fs "^4.1.15" + is-ci "^2.0.0" + mkdirp "^0.5.1" + slash "^2.0.0" + source-map "^0.6.0" + +jest-validate@^24.9.0: + version "24.9.0" + resolved "https://registry.yarnpkg.com/jest-validate/-/jest-validate-24.9.0.tgz#0775c55360d173cd854e40180756d4ff52def8ab" + integrity sha512-HPIt6C5ACwiqSiwi+OfSSHbK8sG7akG8eATl+IPKaeIjtPOeBUd/g3J7DghugzxrGjI93qS/+RPKe1H6PqvhRQ== + dependencies: + "@jest/types" "^24.9.0" + camelcase "^5.3.1" + chalk "^2.0.1" + jest-get-type "^24.9.0" + leven "^3.1.0" + pretty-format "^24.9.0" + +jest-watch-typeahead@0.4.2: + version "0.4.2" + resolved "https://registry.yarnpkg.com/jest-watch-typeahead/-/jest-watch-typeahead-0.4.2.tgz#e5be959698a7fa2302229a5082c488c3c8780a4a" + integrity sha512-f7VpLebTdaXs81rg/oj4Vg/ObZy2QtGzAmGLNsqUS5G5KtSN68tFcIsbvNODfNyQxU78g7D8x77o3bgfBTR+2Q== + dependencies: + ansi-escapes "^4.2.1" + chalk "^2.4.1" + jest-regex-util "^24.9.0" + jest-watcher "^24.3.0" + slash "^3.0.0" + string-length "^3.1.0" + strip-ansi "^5.0.0" + +jest-watcher@^24.3.0, jest-watcher@^24.9.0: + version "24.9.0" + resolved "https://registry.yarnpkg.com/jest-watcher/-/jest-watcher-24.9.0.tgz#4b56e5d1ceff005f5b88e528dc9afc8dd4ed2b3b" + integrity sha512-+/fLOfKPXXYJDYlks62/4R4GoT+GU1tYZed99JSCOsmzkkF7727RqKrjNAxtfO4YpGv11wybgRvCjR73lK2GZw== + dependencies: + "@jest/test-result" "^24.9.0" + "@jest/types" "^24.9.0" + "@types/yargs" "^13.0.0" + ansi-escapes "^3.0.0" + chalk "^2.0.1" + jest-util "^24.9.0" + string-length "^2.0.0" + +jest-worker@^24.6.0, jest-worker@^24.9.0: + version "24.9.0" + resolved "https://registry.yarnpkg.com/jest-worker/-/jest-worker-24.9.0.tgz#5dbfdb5b2d322e98567898238a9697bcce67b3e5" + integrity sha512-51PE4haMSXcHohnSMdM42anbvZANYTqMrr52tVKPqqsPJMzoP6FYYDVqahX/HrAoKEKz3uUPzSvKs9A3qR4iVw== + dependencies: + merge-stream "^2.0.0" + supports-color "^6.1.0" + +jest-worker@^25.1.0: + version "25.2.1" + resolved 
"https://registry.yarnpkg.com/jest-worker/-/jest-worker-25.2.1.tgz#209617015c768652646aa33a7828cc2ab472a18a" + integrity sha512-IHnpekk8H/hCUbBlfeaPZzU6v75bqwJp3n4dUrQuQOAgOneI4tx3jV2o8pvlXnDfcRsfkFIUD//HWXpCmR+evQ== + dependencies: + merge-stream "^2.0.0" + supports-color "^7.0.0" + +jest@24.9.0: + version "24.9.0" + resolved "https://registry.yarnpkg.com/jest/-/jest-24.9.0.tgz#987d290c05a08b52c56188c1002e368edb007171" + integrity sha512-YvkBL1Zm7d2B1+h5fHEOdyjCG+sGMz4f8D86/0HiqJ6MB4MnDc8FgP5vdWsGnemOQro7lnYo8UakZ3+5A0jxGw== + dependencies: + import-local "^2.0.0" + jest-cli "^24.9.0" + +js-levenshtein@^1.1.6: + version "1.1.6" + resolved "https://registry.yarnpkg.com/js-levenshtein/-/js-levenshtein-1.1.6.tgz#c6cee58eb3550372df8deb85fad5ce66ce01d59d" + integrity sha512-X2BB11YZtrRqY4EnQcLX5Rh373zbK4alC1FW7D7MBhL2gtcC17cTnr6DmfHZeS0s2rTHjUTMMHfG7gO8SSdw+g== + +"js-tokens@^3.0.0 || ^4.0.0", js-tokens@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/js-tokens/-/js-tokens-4.0.0.tgz#19203fb59991df98e3a287050d4647cdeaf32499" + integrity sha512-RdJUflcE3cUzKiMqQgsCu06FPu9UdIJO0beYbPhHN4k6apgJtifcoCtT9bcxOpYBtpD2kCM6Sbzg4CausW/PKQ== + +js-tokens@^3.0.2: + version "3.0.2" + resolved "https://registry.yarnpkg.com/js-tokens/-/js-tokens-3.0.2.tgz#9866df395102130e38f7f996bceb65443209c25b" + integrity sha1-mGbfOVECEw449/mWvOtlRDIJwls= + +js-yaml@^3.13.1: + version "3.13.1" + resolved "https://registry.yarnpkg.com/js-yaml/-/js-yaml-3.13.1.tgz#aff151b30bfdfa8e49e05da22e7415e9dfa37847" + integrity sha512-YfbcO7jXDdyj0DGxYVSlSeQNHbD7XPWvrVWeVUujrQEoZzWJIRrCPoyk6kL6IAjAG2IolMK4T0hNUe0HOUs5Jw== + dependencies: + argparse "^1.0.7" + esprima "^4.0.0" + +jsbn@~0.1.0: + version "0.1.1" + resolved "https://registry.yarnpkg.com/jsbn/-/jsbn-0.1.1.tgz#a5e654c2e5a2deb5f201d96cefbca80c0ef2f513" + integrity sha1-peZUwuWi3rXyAdls77yoDA7y9RM= + +jsdom@^11.5.1: + version "11.12.0" + resolved "https://registry.yarnpkg.com/jsdom/-/jsdom-11.12.0.tgz#1a80d40ddd378a1de59656e9e6dc5a3ba8657bc8" + integrity sha512-y8Px43oyiBM13Zc1z780FrfNLJCXTL40EWlty/LXUtcjykRBNgLlCjWXpfSPBl2iv+N7koQN+dvqszHZgT/Fjw== + dependencies: + abab "^2.0.0" + acorn "^5.5.3" + acorn-globals "^4.1.0" + array-equal "^1.0.0" + cssom ">= 0.3.2 < 0.4.0" + cssstyle "^1.0.0" + data-urls "^1.0.0" + domexception "^1.0.1" + escodegen "^1.9.1" + html-encoding-sniffer "^1.0.2" + left-pad "^1.3.0" + nwsapi "^2.0.7" + parse5 "4.0.0" + pn "^1.1.0" + request "^2.87.0" + request-promise-native "^1.0.5" + sax "^1.2.4" + symbol-tree "^3.2.2" + tough-cookie "^2.3.4" + w3c-hr-time "^1.0.1" + webidl-conversions "^4.0.2" + whatwg-encoding "^1.0.3" + whatwg-mimetype "^2.1.0" + whatwg-url "^6.4.1" + ws "^5.2.0" + xml-name-validator "^3.0.0" + +jsdom@^14.1.0: + version "14.1.0" + resolved "https://registry.yarnpkg.com/jsdom/-/jsdom-14.1.0.tgz#916463b6094956b0a6c1782c94e380cd30e1981b" + integrity sha512-O901mfJSuTdwU2w3Sn+74T+RnDVP+FuV5fH8tcPWyqrseRAb0s5xOtPgCFiPOtLcyK7CLIJwPyD83ZqQWvA5ng== + dependencies: + abab "^2.0.0" + acorn "^6.0.4" + acorn-globals "^4.3.0" + array-equal "^1.0.0" + cssom "^0.3.4" + cssstyle "^1.1.1" + data-urls "^1.1.0" + domexception "^1.0.1" + escodegen "^1.11.0" + html-encoding-sniffer "^1.0.2" + nwsapi "^2.1.3" + parse5 "5.1.0" + pn "^1.1.0" + request "^2.88.0" + request-promise-native "^1.0.5" + saxes "^3.1.9" + symbol-tree "^3.2.2" + tough-cookie "^2.5.0" + w3c-hr-time "^1.0.1" + w3c-xmlserializer "^1.1.2" + webidl-conversions "^4.0.2" + whatwg-encoding "^1.0.5" + whatwg-mimetype "^2.3.0" + whatwg-url "^7.0.0" + ws "^6.1.2" + 
xml-name-validator "^3.0.0" + +jsesc@^2.5.1: + version "2.5.2" + resolved "https://registry.yarnpkg.com/jsesc/-/jsesc-2.5.2.tgz#80564d2e483dacf6e8ef209650a67df3f0c283a4" + integrity sha512-OYu7XEzjkCQ3C5Ps3QIZsQfNpqoJyZZA99wd9aWd05NCtC5pWOkShK2mkL6HXQR6/Cy2lbNdPlZBpuQHXE63gA== + +jsesc@~0.5.0: + version "0.5.0" + resolved "https://registry.yarnpkg.com/jsesc/-/jsesc-0.5.0.tgz#e7dee66e35d6fc16f710fe91d5cf69f70f08911d" + integrity sha1-597mbjXW/Bb3EP6R1c9p9w8IkR0= + +json-parse-better-errors@^1.0.1, json-parse-better-errors@^1.0.2: + version "1.0.2" + resolved "https://registry.yarnpkg.com/json-parse-better-errors/-/json-parse-better-errors-1.0.2.tgz#bb867cfb3450e69107c131d1c514bab3dc8bcaa9" + integrity sha512-mrqyZKfX5EhL7hvqcV6WG1yYjnjeuYDzDhhcAAUrq8Po85NBQBJP+ZDUT75qZQ98IkUoBqdkExkukOU7Ts2wrw== + +json-schema-traverse@^0.4.1: + version "0.4.1" + resolved "https://registry.yarnpkg.com/json-schema-traverse/-/json-schema-traverse-0.4.1.tgz#69f6a87d9513ab8bb8fe63bdb0979c448e684660" + integrity sha512-xbbCH5dCYU5T8LcEhhuh7HJ88HXuW3qsI3Y0zOZFKfZEHcpWiHU/Jxzk629Brsab/mMiHQti9wMP+845RPe3Vg== + +json-schema@0.2.3: + version "0.2.3" + resolved "https://registry.yarnpkg.com/json-schema/-/json-schema-0.2.3.tgz#b480c892e59a2f05954ce727bd3f2a4e882f9e13" + integrity sha1-tIDIkuWaLwWVTOcnvT8qTogvnhM= + +json-stable-stringify-without-jsonify@^1.0.1: + version "1.0.1" + resolved "https://registry.yarnpkg.com/json-stable-stringify-without-jsonify/-/json-stable-stringify-without-jsonify-1.0.1.tgz#9db7b59496ad3f3cfef30a75142d2d930ad72651" + integrity sha1-nbe1lJatPzz+8wp1FC0tkwrXJlE= + +json-stable-stringify@^1.0.1: + version "1.0.1" + resolved "https://registry.yarnpkg.com/json-stable-stringify/-/json-stable-stringify-1.0.1.tgz#9a759d39c5f2ff503fd5300646ed445f88c4f9af" + integrity sha1-mnWdOcXy/1A/1TAGRu1EX4jE+a8= + dependencies: + jsonify "~0.0.0" + +json-stringify-safe@~5.0.1: + version "5.0.1" + resolved "https://registry.yarnpkg.com/json-stringify-safe/-/json-stringify-safe-5.0.1.tgz#1296a2d58fd45f19a0f6ce01d65701e2c735b6eb" + integrity sha1-Epai1Y/UXxmg9s4B1lcB4sc1tus= + +json3@^3.3.2: + version "3.3.3" + resolved "https://registry.yarnpkg.com/json3/-/json3-3.3.3.tgz#7fc10e375fc5ae42c4705a5cc0aa6f62be305b81" + integrity sha512-c7/8mbUsKigAbLkD5B010BK4D9LZm7A1pNItkEwiUZRpIN66exu/e7YQWysGun+TRKaJp8MhemM+VkfWv42aCA== + +json5@^1.0.1: + version "1.0.1" + resolved "https://registry.yarnpkg.com/json5/-/json5-1.0.1.tgz#779fb0018604fa854eacbf6252180d83543e3dbe" + integrity sha512-aKS4WQjPenRxiQsC93MNfjx+nbF4PAdYzmd/1JIj8HYzqfbu86beTuNgXDzPknWk0n0uARlyewZo4s++ES36Ow== + dependencies: + minimist "^1.2.0" + +json5@^2.1.2: + version "2.1.2" + resolved "https://registry.yarnpkg.com/json5/-/json5-2.1.2.tgz#43ef1f0af9835dd624751a6b7fa48874fb2d608e" + integrity sha512-MoUOQ4WdiN3yxhm7NEVJSJrieAo5hNSLQ5sj05OTRHPL9HOBy8u4Bu88jsC1jvqAdN+E1bJmsUcZH+1HQxliqQ== + dependencies: + minimist "^1.2.5" + +jsonfile@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/jsonfile/-/jsonfile-4.0.0.tgz#8771aae0799b64076b76640fca058f9c10e33ecb" + integrity sha1-h3Gq4HmbZAdrdmQPygWPnBDjPss= + optionalDependencies: + graceful-fs "^4.1.6" + +jsonify@~0.0.0: + version "0.0.0" + resolved "https://registry.yarnpkg.com/jsonify/-/jsonify-0.0.0.tgz#2c74b6ee41d93ca51b7b5aaee8f503631d252a73" + integrity sha1-LHS27kHZPKUbe1qu6PUDYx0lKnM= + +jsprim@^1.2.2: + version "1.4.1" + resolved "https://registry.yarnpkg.com/jsprim/-/jsprim-1.4.1.tgz#313e66bc1e5cc06e438bc1b7499c2e5c56acb6a2" + integrity sha1-MT5mvB5cwG5Di8G3SZwuXFastqI= + 
dependencies: + assert-plus "1.0.0" + extsprintf "1.3.0" + json-schema "0.2.3" + verror "1.10.0" + +jsx-ast-utils@^2.2.1, jsx-ast-utils@^2.2.3: + version "2.2.3" + resolved "https://registry.yarnpkg.com/jsx-ast-utils/-/jsx-ast-utils-2.2.3.tgz#8a9364e402448a3ce7f14d357738310d9248054f" + integrity sha512-EdIHFMm+1BPynpKOpdPqiOsvnIrInRGJD7bzPZdPkjitQEqpdpUuFpq4T0npZFKTiB3RhWFdGN+oqOJIdhDhQA== + dependencies: + array-includes "^3.0.3" + object.assign "^4.1.0" + +killable@^1.0.1: + version "1.0.1" + resolved "https://registry.yarnpkg.com/killable/-/killable-1.0.1.tgz#4c8ce441187a061c7474fb87ca08e2a638194892" + integrity sha512-LzqtLKlUwirEUyl/nicirVmNiPvYs7l5n8wOPP7fyJVpUPkvCnW/vuiXGpylGUlnPDnB7311rARzAt3Mhswpjg== + +kind-of@^2.0.1: + version "2.0.1" + resolved "https://registry.yarnpkg.com/kind-of/-/kind-of-2.0.1.tgz#018ec7a4ce7e3a86cb9141be519d24c8faa981b5" + integrity sha1-AY7HpM5+OobLkUG+UZ0kyPqpgbU= + dependencies: + is-buffer "^1.0.2" + +kind-of@^3.0.2, kind-of@^3.0.3, kind-of@^3.2.0: + version "3.2.2" + resolved "https://registry.yarnpkg.com/kind-of/-/kind-of-3.2.2.tgz#31ea21a734bab9bbb0f32466d893aea51e4a3c64" + integrity sha1-MeohpzS6ubuw8yRm2JOupR5KPGQ= + dependencies: + is-buffer "^1.1.5" + +kind-of@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/kind-of/-/kind-of-4.0.0.tgz#20813df3d712928b207378691a45066fae72dd57" + integrity sha1-IIE989cSkosgc3hpGkUGb65y3Vc= + dependencies: + is-buffer "^1.1.5" + +kind-of@^5.0.0: + version "5.1.0" + resolved "https://registry.yarnpkg.com/kind-of/-/kind-of-5.1.0.tgz#729c91e2d857b7a419a1f9aa65685c4c33f5845d" + integrity sha512-NGEErnH6F2vUuXDh+OlbcKW7/wOcfdRHaZ7VWtqCztfHri/++YKmP51OdWeGPuqCOba6kk2OTe5d02VmTB80Pw== + +kind-of@^6.0.0, kind-of@^6.0.2: + version "6.0.3" + resolved "https://registry.yarnpkg.com/kind-of/-/kind-of-6.0.3.tgz#07c05034a6c349fa06e24fa35aa76db4580ce4dd" + integrity sha512-dcS1ul+9tmeD95T+x28/ehLgd9mENa3LsvDTtzm3vyBEO7RPptvAD+t44WVXaUjTBRcrpFeFlC8WCruUR456hw== + +kleur@^3.0.3: + version "3.0.3" + resolved "https://registry.yarnpkg.com/kleur/-/kleur-3.0.3.tgz#a79c9ecc86ee1ce3fa6206d1216c501f147fc07e" + integrity sha512-eTIzlVOSUR+JxdDFepEYcBMtZ9Qqdef+rnzWdRZuMbOywu5tO2w2N7rqjoANZ5k9vywhL6Br1VRjUIgTQx4E8w== + +last-call-webpack-plugin@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/last-call-webpack-plugin/-/last-call-webpack-plugin-3.0.0.tgz#9742df0e10e3cf46e5c0381c2de90d3a7a2d7555" + integrity sha512-7KI2l2GIZa9p2spzPIVZBYyNKkN+e/SQPpnjlTiPhdbDW3F86tdKKELxKpzJ5sgU19wQWsACULZmpTPYHeWO5w== + dependencies: + lodash "^4.17.5" + webpack-sources "^1.1.0" + +lazy-cache@^0.2.3: + version "0.2.7" + resolved "https://registry.yarnpkg.com/lazy-cache/-/lazy-cache-0.2.7.tgz#7feddf2dcb6edb77d11ef1d117ab5ffdf0ab1b65" + integrity sha1-f+3fLctu23fRHvHRF6tf/fCrG2U= + +lazy-cache@^1.0.3: + version "1.0.4" + resolved "https://registry.yarnpkg.com/lazy-cache/-/lazy-cache-1.0.4.tgz#a1d78fc3a50474cb80845d3b3b6e1da49a446e8e" + integrity sha1-odePw6UEdMuAhF07O24dpJpEbo4= + +lcid@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/lcid/-/lcid-2.0.0.tgz#6ef5d2df60e52f82eb228a4c373e8d1f397253cf" + integrity sha512-avPEb8P8EGnwXKClwsNUgryVjllcRqtMYa49NTsbQagYuT1DcXnl1915oxWjoyGrXR6zH/Y0Zc96xWsPcoDKeA== + dependencies: + invert-kv "^2.0.0" + +left-pad@^1.3.0: + version "1.3.0" + resolved "https://registry.yarnpkg.com/left-pad/-/left-pad-1.3.0.tgz#5b8a3a7765dfe001261dde915589e782f8c94d1e" + integrity sha512-XI5MPzVNApjAyhQzphX8BkmKsKUxD4LdyK24iZeQGinBN9yTQT3bFlCBy/aVx2HrNcqQGsdot8ghrjyrvMCoEA== + 
+leven@^3.1.0: + version "3.1.0" + resolved "https://registry.yarnpkg.com/leven/-/leven-3.1.0.tgz#77891de834064cccba82ae7842bb6b14a13ed7f2" + integrity sha512-qsda+H8jTaUaN/x5vzW2rzc+8Rw4TAQ/4KjB46IwK5VH+IlVeeeje/EoZRpiXvIqjFgK84QffqPztGI3VBLG1A== + +levenary@^1.1.1: + version "1.1.1" + resolved "https://registry.yarnpkg.com/levenary/-/levenary-1.1.1.tgz#842a9ee98d2075aa7faeedbe32679e9205f46f77" + integrity sha512-mkAdOIt79FD6irqjYSs4rdbnlT5vRonMEvBVPVb3XmevfS8kgRXwfes0dhPdEtzTWD/1eNE/Bm/G1iRt6DcnQQ== + dependencies: + leven "^3.1.0" + +levn@^0.3.0, levn@~0.3.0: + version "0.3.0" + resolved "https://registry.yarnpkg.com/levn/-/levn-0.3.0.tgz#3b09924edf9f083c0490fdd4c0bc4421e04764ee" + integrity sha1-OwmSTt+fCDwEkP3UwLxEIeBHZO4= + dependencies: + prelude-ls "~1.1.2" + type-check "~0.3.2" + +lines-and-columns@^1.1.6: + version "1.1.6" + resolved "https://registry.yarnpkg.com/lines-and-columns/-/lines-and-columns-1.1.6.tgz#1c00c743b433cd0a4e80758f7b64a57440d9ff00" + integrity sha1-HADHQ7QzzQpOgHWPe2SldEDZ/wA= + +load-json-file@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/load-json-file/-/load-json-file-2.0.0.tgz#7947e42149af80d696cbf797bcaabcfe1fe29ca8" + integrity sha1-eUfkIUmvgNaWy/eXvKq8/h/inKg= + dependencies: + graceful-fs "^4.1.2" + parse-json "^2.2.0" + pify "^2.0.0" + strip-bom "^3.0.0" + +load-json-file@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/load-json-file/-/load-json-file-4.0.0.tgz#2f5f45ab91e33216234fd53adab668eb4ec0993b" + integrity sha1-L19Fq5HjMhYjT9U62rZo607AmTs= + dependencies: + graceful-fs "^4.1.2" + parse-json "^4.0.0" + pify "^3.0.0" + strip-bom "^3.0.0" + +loader-fs-cache@^1.0.2: + version "1.0.3" + resolved "https://registry.yarnpkg.com/loader-fs-cache/-/loader-fs-cache-1.0.3.tgz#f08657646d607078be2f0a032f8bd69dd6f277d9" + integrity sha512-ldcgZpjNJj71n+2Mf6yetz+c9bM4xpKtNds4LbqXzU/PTdeAX0g3ytnU1AJMEcTk2Lex4Smpe3Q/eCTsvUBxbA== + dependencies: + find-cache-dir "^0.1.1" + mkdirp "^0.5.1" + +loader-runner@^2.4.0: + version "2.4.0" + resolved "https://registry.yarnpkg.com/loader-runner/-/loader-runner-2.4.0.tgz#ed47066bfe534d7e84c4c7b9998c2a75607d9357" + integrity sha512-Jsmr89RcXGIwivFY21FcRrisYZfvLMTWx5kOLc+JTxtpBOG6xML0vzbc6SEQG2FO9/4Fc3wW4LVcB5DmGflaRw== + +loader-utils@1.2.3: + version "1.2.3" + resolved "https://registry.yarnpkg.com/loader-utils/-/loader-utils-1.2.3.tgz#1ff5dc6911c9f0a062531a4c04b609406108c2c7" + integrity sha512-fkpz8ejdnEMG3s37wGL07iSBDg99O9D5yflE9RGNH3hRdx9SOwYfnGYdZOUIZitN8E+E2vkq3MUMYMvPYl5ZZA== + dependencies: + big.js "^5.2.2" + emojis-list "^2.0.0" + json5 "^1.0.1" + +loader-utils@^1.1.0, loader-utils@^1.2.3, loader-utils@^1.4.0: + version "1.4.0" + resolved "https://registry.yarnpkg.com/loader-utils/-/loader-utils-1.4.0.tgz#c579b5e34cb34b1a74edc6c1fb36bfa371d5a613" + integrity sha512-qH0WSMBtn/oHuwjy/NucEgbx5dbxxnxup9s4PVXJUDHZBQY+s0NWA9rJf53RBnQZxfch7euUui7hpoAPvALZdA== + dependencies: + big.js "^5.2.2" + emojis-list "^3.0.0" + json5 "^1.0.1" + +locate-path@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/locate-path/-/locate-path-2.0.0.tgz#2b568b265eec944c6d9c0de9c3dbbbca0354cd8e" + integrity sha1-K1aLJl7slExtnA3pw9u7ygNUzY4= + dependencies: + p-locate "^2.0.0" + path-exists "^3.0.0" + +locate-path@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/locate-path/-/locate-path-3.0.0.tgz#dbec3b3ab759758071b58fe59fc41871af21400e" + integrity sha512-7AO748wWnIhNqAuaty2ZWHkQHRSNfPVIsPIfwEOWO22AmaoVrWavlOcMR5nzTLNYvp36X220/maaRsrec1G65A== + dependencies: + p-locate 
"^3.0.0" + path-exists "^3.0.0" + +locate-path@^5.0.0: + version "5.0.0" + resolved "https://registry.yarnpkg.com/locate-path/-/locate-path-5.0.0.tgz#1afba396afd676a6d42504d0a67a3a7eb9f62aa0" + integrity sha512-t7hw9pI+WvuwNJXwk5zVHpyhIqzg2qTlklJOf0mVxGSbe3Fp2VieZcduNYjaLDoy6p9uGpQEGWG87WpMKlNq8g== + dependencies: + p-locate "^4.1.0" + +lodash._reinterpolate@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/lodash._reinterpolate/-/lodash._reinterpolate-3.0.0.tgz#0ccf2d89166af03b3663c796538b75ac6e114d9d" + integrity sha1-DM8tiRZq8Ds2Y8eWU4t1rG4RTZ0= + +lodash.memoize@^4.1.2: + version "4.1.2" + resolved "https://registry.yarnpkg.com/lodash.memoize/-/lodash.memoize-4.1.2.tgz#bcc6c49a42a2840ed997f323eada5ecd182e0bfe" + integrity sha1-vMbEmkKihA7Zl/Mj6tpezRguC/4= + +lodash.sortby@^4.7.0: + version "4.7.0" + resolved "https://registry.yarnpkg.com/lodash.sortby/-/lodash.sortby-4.7.0.tgz#edd14c824e2cc9c1e0b0a1b42bb5210516a42438" + integrity sha1-7dFMgk4sycHgsKG0K7UhBRakJDg= + +lodash.template@^4.4.0, lodash.template@^4.5.0: + version "4.5.0" + resolved "https://registry.yarnpkg.com/lodash.template/-/lodash.template-4.5.0.tgz#f976195cf3f347d0d5f52483569fe8031ccce8ab" + integrity sha512-84vYFxIkmidUiFxidA/KjjH9pAycqW+h980j7Fuz5qxRtO9pgB7MDFTdys1N7A5mcucRiDyEq4fusljItR1T/A== + dependencies: + lodash._reinterpolate "^3.0.0" + lodash.templatesettings "^4.0.0" + +lodash.templatesettings@^4.0.0: + version "4.2.0" + resolved "https://registry.yarnpkg.com/lodash.templatesettings/-/lodash.templatesettings-4.2.0.tgz#e481310f049d3cf6d47e912ad09313b154f0fb33" + integrity sha512-stgLz+i3Aa9mZgnjr/O+v9ruKZsPsndy7qPZOchbqk2cnTU1ZaldKK+v7m54WoKIyxiuMZTKT2H81F8BeAc3ZQ== + dependencies: + lodash._reinterpolate "^3.0.0" + +lodash.uniq@^4.5.0: + version "4.5.0" + resolved "https://registry.yarnpkg.com/lodash.uniq/-/lodash.uniq-4.5.0.tgz#d0225373aeb652adc1bc82e4945339a842754773" + integrity sha1-0CJTc662Uq3BvILklFM5qEJ1R3M= + +"lodash@>=3.5 <5", lodash@^4.17.11, lodash@^4.17.13, lodash@^4.17.14, lodash@^4.17.15, lodash@^4.17.5: + version "4.17.15" + resolved "https://registry.yarnpkg.com/lodash/-/lodash-4.17.15.tgz#b447f6670a0455bbfeedd11392eff330ea097548" + integrity sha512-8xOcRHvCjnocdS5cpwXQXVzmmh5e5+saE2QGoeQmbKmRS6J3VQppPOIt0MnmE+4xlZoumy0GPG0D0MVIQbNA1A== + +lodash@^4.17.19: + version "4.17.20" + resolved "https://registry.yarnpkg.com/lodash/-/lodash-4.17.20.tgz#b44a9b6297bcb698f1c51a3545a2b3b368d59c52" + integrity sha512-PlhdFcillOINfeV7Ni6oF1TAEayyZBoZ8bcshTHqOYJYlrqzRK5hagpagky5o4HfCzzd1TRkXPMFq6cKk9rGmA== + +loglevel@^1.6.6: + version "1.6.7" + resolved "https://registry.yarnpkg.com/loglevel/-/loglevel-1.6.7.tgz#b3e034233188c68b889f5b862415306f565e2c56" + integrity sha512-cY2eLFrQSAfVPhCgH1s7JI73tMbg9YC3v3+ZHVW67sBS7UxWzNEk/ZBbSfLykBWHp33dqqtOv82gjhKEi81T/A== + +loose-envify@^1.0.0, loose-envify@^1.1.0, loose-envify@^1.4.0: + version "1.4.0" + resolved "https://registry.yarnpkg.com/loose-envify/-/loose-envify-1.4.0.tgz#71ee51fa7be4caec1a63839f7e682d8132d30caf" + integrity sha512-lyuxPGr/Wfhrlem2CL/UcnUc1zcqKAImBDzukY7Y5F/yQiNdko6+fRLevlw1HgMySw7f611UIY408EtxRSoK3Q== + dependencies: + js-tokens "^3.0.0 || ^4.0.0" + +lower-case@^2.0.1: + version "2.0.1" + resolved "https://registry.yarnpkg.com/lower-case/-/lower-case-2.0.1.tgz#39eeb36e396115cc05e29422eaea9e692c9408c7" + integrity sha512-LiWgfDLLb1dwbFQZsSglpRj+1ctGnayXz3Uv0/WO8n558JycT5fg6zkNcnW0G68Nn0aEldTFeEfmjCfmqry/rQ== + dependencies: + tslib "^1.10.0" + +lru-cache@^5.1.1: + version "5.1.1" + resolved 
"https://registry.yarnpkg.com/lru-cache/-/lru-cache-5.1.1.tgz#1da27e6710271947695daf6848e847f01d84b920" + integrity sha512-KpNARQA3Iwv+jTA0utUVVbrh+Jlrr1Fv0e56GGzAFOXN7dk/FviaDW8LHmK52DlcH4WP2n6gI8vN1aesBFgo9w== + dependencies: + yallist "^3.0.2" + +make-dir@^2.0.0, make-dir@^2.1.0: + version "2.1.0" + resolved "https://registry.yarnpkg.com/make-dir/-/make-dir-2.1.0.tgz#5f0310e18b8be898cc07009295a30ae41e91e6f5" + integrity sha512-LS9X+dc8KLxXCb8dni79fLIIUA5VyZoyjSMCwTluaXA0o27cCK0bhXkpgw+sTXVpPy/lSO57ilRixqk0vDmtRA== + dependencies: + pify "^4.0.1" + semver "^5.6.0" + +make-dir@^3.0.2: + version "3.0.2" + resolved "https://registry.yarnpkg.com/make-dir/-/make-dir-3.0.2.tgz#04a1acbf22221e1d6ef43559f43e05a90dbb4392" + integrity sha512-rYKABKutXa6vXTXhoV18cBE7PaewPXHe/Bdq4v+ZLMhxbWApkFFplT0LcbMW+6BbjnQXzZ/sAvSE/JdguApG5w== + dependencies: + semver "^6.0.0" + +makeerror@1.0.x: + version "1.0.11" + resolved "https://registry.yarnpkg.com/makeerror/-/makeerror-1.0.11.tgz#e01a5c9109f2af79660e4e8b9587790184f5a96c" + integrity sha1-4BpckQnyr3lmDk6LlYd5AYT1qWw= + dependencies: + tmpl "1.0.x" + +mamacro@^0.0.3: + version "0.0.3" + resolved "https://registry.yarnpkg.com/mamacro/-/mamacro-0.0.3.tgz#ad2c9576197c9f1abf308d0787865bd975a3f3e4" + integrity sha512-qMEwh+UujcQ+kbz3T6V+wAmO2U8veoq2w+3wY8MquqwVA3jChfwY+Tk52GZKDfACEPjuZ7r2oJLejwpt8jtwTA== + +map-age-cleaner@^0.1.1: + version "0.1.3" + resolved "https://registry.yarnpkg.com/map-age-cleaner/-/map-age-cleaner-0.1.3.tgz#7d583a7306434c055fe474b0f45078e6e1b4b92a" + integrity sha512-bJzx6nMoP6PDLPBFmg7+xRKeFZvFboMrGlxmNj9ClvX53KrmvM5bXFXEWjbz4cz1AFn+jWJ9z/DJSz7hrs0w3w== + dependencies: + p-defer "^1.0.0" + +map-cache@^0.2.2: + version "0.2.2" + resolved "https://registry.yarnpkg.com/map-cache/-/map-cache-0.2.2.tgz#c32abd0bd6525d9b051645bb4f26ac5dc98a0dbf" + integrity sha1-wyq9C9ZSXZsFFkW7TyasXcmKDb8= + +map-visit@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/map-visit/-/map-visit-1.0.0.tgz#ecdca8f13144e660f1b5bd41f12f3479d98dfb8f" + integrity sha1-7Nyo8TFE5mDxtb1B8S80edmN+48= + dependencies: + object-visit "^1.0.0" + +md5.js@^1.3.4: + version "1.3.5" + resolved "https://registry.yarnpkg.com/md5.js/-/md5.js-1.3.5.tgz#b5d07b8e3216e3e27cd728d72f70d1e6a342005f" + integrity sha512-xitP+WxNPcTTOgnTJcrhM0xvdPepipPSf3I8EIpGKeFLjt3PlJLIDG3u8EX53ZIubkb+5U2+3rELYpEhHhzdkg== + dependencies: + hash-base "^3.0.0" + inherits "^2.0.1" + safe-buffer "^5.1.2" + +mdn-data@2.0.4: + version "2.0.4" + resolved "https://registry.yarnpkg.com/mdn-data/-/mdn-data-2.0.4.tgz#699b3c38ac6f1d728091a64650b65d388502fd5b" + integrity sha512-iV3XNKw06j5Q7mi6h+9vbx23Tv7JkjEVgKHW4pimwyDGWm0OIQntJJ+u1C6mg6mK1EaTv42XQ7w76yuzH7M2cA== + +mdn-data@2.0.6: + version "2.0.6" + resolved "https://registry.yarnpkg.com/mdn-data/-/mdn-data-2.0.6.tgz#852dc60fcaa5daa2e8cf6c9189c440ed3e042978" + integrity sha512-rQvjv71olwNHgiTbfPZFkJtjNMciWgswYeciZhtvWLO8bmX3TnhyA62I6sTWOyZssWHJJjY6/KiWwqQsWWsqOA== + +media-typer@0.3.0: + version "0.3.0" + resolved "https://registry.yarnpkg.com/media-typer/-/media-typer-0.3.0.tgz#8710d7af0aa626f8fffa1ce00168545263255748" + integrity sha1-hxDXrwqmJvj/+hzgAWhUUmMlV0g= + +mem@^4.0.0: + version "4.3.0" + resolved "https://registry.yarnpkg.com/mem/-/mem-4.3.0.tgz#461af497bc4ae09608cdb2e60eefb69bff744178" + integrity sha512-qX2bG48pTqYRVmDB37rn/6PT7LcR8T7oAX3bf99u1Tt1nzxYfxkgqDwUwolPlXweM0XzBOBFzSx4kfp7KP1s/w== + dependencies: + map-age-cleaner "^0.1.1" + mimic-fn "^2.0.0" + p-is-promise "^2.0.0" + +memory-fs@^0.4.0, memory-fs@^0.4.1: + version "0.4.1" + 
resolved "https://registry.yarnpkg.com/memory-fs/-/memory-fs-0.4.1.tgz#3a9a20b8462523e447cfbc7e8bb80ed667bfc552" + integrity sha1-OpoguEYlI+RHz7x+i7gO1me/xVI= + dependencies: + errno "^0.1.3" + readable-stream "^2.0.1" + +memory-fs@^0.5.0: + version "0.5.0" + resolved "https://registry.yarnpkg.com/memory-fs/-/memory-fs-0.5.0.tgz#324c01288b88652966d161db77838720845a8e3c" + integrity sha512-jA0rdU5KoQMC0e6ppoNRtpp6vjFq6+NY7r8hywnC7V+1Xj/MtHwGIbB1QaK/dunyjWteJzmkpd7ooeWg10T7GA== + dependencies: + errno "^0.1.3" + readable-stream "^2.0.1" + +merge-deep@^3.0.2: + version "3.0.2" + resolved "https://registry.yarnpkg.com/merge-deep/-/merge-deep-3.0.2.tgz#f39fa100a4f1bd34ff29f7d2bf4508fbb8d83ad2" + integrity sha512-T7qC8kg4Zoti1cFd8Cr0M+qaZfOwjlPDEdZIIPPB2JZctjaPM4fX+i7HOId69tAti2fvO6X5ldfYUONDODsrkA== + dependencies: + arr-union "^3.1.0" + clone-deep "^0.2.4" + kind-of "^3.0.2" + +merge-descriptors@1.0.1: + version "1.0.1" + resolved "https://registry.yarnpkg.com/merge-descriptors/-/merge-descriptors-1.0.1.tgz#b00aaa556dd8b44568150ec9d1b953f3f90cbb61" + integrity sha1-sAqqVW3YtEVoFQ7J0blT8/kMu2E= + +merge-stream@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/merge-stream/-/merge-stream-2.0.0.tgz#52823629a14dd00c9770fb6ad47dc6310f2c1f60" + integrity sha512-abv/qOcuPfk3URPfDzmZU1LKmuw8kT+0nIHvKrKgFrwifol/doWcdA4ZqsWQ8ENrFKkd67Mfpo/LovbIUsbt3w== + +merge2@^1.2.3: + version "1.3.0" + resolved "https://registry.yarnpkg.com/merge2/-/merge2-1.3.0.tgz#5b366ee83b2f1582c48f87e47cf1a9352103ca81" + integrity sha512-2j4DAdlBOkiSZIsaXk4mTE3sRS02yBHAtfy127xRV3bQUFqXkjHCHLW6Scv7DwNRbIWNHH8zpnz9zMaKXIdvYw== + +methods@~1.1.2: + version "1.1.2" + resolved "https://registry.yarnpkg.com/methods/-/methods-1.1.2.tgz#5529a4d67654134edcc5266656835b0f851afcee" + integrity sha1-VSmk1nZUE07cxSZmVoNbD4Ua/O4= + +microevent.ts@~0.1.1: + version "0.1.1" + resolved "https://registry.yarnpkg.com/microevent.ts/-/microevent.ts-0.1.1.tgz#70b09b83f43df5172d0205a63025bce0f7357fa0" + integrity sha512-jo1OfR4TaEwd5HOrt5+tAZ9mqT4jmpNAusXtyfNzqVm9uiSYFZlKM1wYL4oU7azZW/PxQW53wM0S6OR1JHNa2g== + +micromatch@^3.0.4, micromatch@^3.1.10, micromatch@^3.1.4: + version "3.1.10" + resolved "https://registry.yarnpkg.com/micromatch/-/micromatch-3.1.10.tgz#70859bc95c9840952f359a068a3fc49f9ecfac23" + integrity sha512-MWikgl9n9M3w+bpsY3He8L+w9eF9338xRl8IAO5viDizwSzziFEyUzo2xrrloB64ADbTf8uA8vRqqttDTOmccg== + dependencies: + arr-diff "^4.0.0" + array-unique "^0.3.2" + braces "^2.3.1" + define-property "^2.0.2" + extend-shallow "^3.0.2" + extglob "^2.0.4" + fragment-cache "^0.2.1" + kind-of "^6.0.2" + nanomatch "^1.2.9" + object.pick "^1.3.0" + regex-not "^1.0.0" + snapdragon "^0.8.1" + to-regex "^3.0.2" + +miller-rabin@^4.0.0: + version "4.0.1" + resolved "https://registry.yarnpkg.com/miller-rabin/-/miller-rabin-4.0.1.tgz#f080351c865b0dc562a8462966daa53543c78a4d" + integrity sha512-115fLhvZVqWwHPbClyntxEVfVDfl9DLLTuJvq3g2O/Oxi8AiNouAHvDSzHS0viUJc+V5vm3eq91Xwqn9dp4jRA== + dependencies: + bn.js "^4.0.0" + brorand "^1.0.1" + +mime-db@1.43.0, "mime-db@>= 1.43.0 < 2": + version "1.43.0" + resolved "https://registry.yarnpkg.com/mime-db/-/mime-db-1.43.0.tgz#0a12e0502650e473d735535050e7c8f4eb4fae58" + integrity sha512-+5dsGEEovYbT8UY9yD7eE4XTc4UwJ1jBYlgaQQF38ENsKR3wj/8q8RFZrF9WIZpB2V1ArTVFUva8sAul1NzRzQ== + +mime-types@^2.1.12, mime-types@~2.1.17, mime-types@~2.1.19, mime-types@~2.1.24: + version "2.1.26" + resolved "https://registry.yarnpkg.com/mime-types/-/mime-types-2.1.26.tgz#9c921fc09b7e149a65dfdc0da4d20997200b0a06" + integrity 
sha512-01paPWYgLrkqAyrlDorC1uDwl2p3qZT7yl806vW7DvDoxwXi46jsjFbg+WdwotBIk6/MbEhO/dh5aZ5sNj/dWQ== + dependencies: + mime-db "1.43.0" + +mime@1.6.0: + version "1.6.0" + resolved "https://registry.yarnpkg.com/mime/-/mime-1.6.0.tgz#32cd9e5c64553bd58d19a568af452acff04981b1" + integrity sha512-x0Vn8spI+wuJ1O6S7gnbaQg8Pxh4NNHb7KSINmEWKiPE4RKOplvijn+NkmYmmRgP68mc70j2EbeTFRsrswaQeg== + +mime@^2.4.4: + version "2.4.4" + resolved "https://registry.yarnpkg.com/mime/-/mime-2.4.4.tgz#bd7b91135fc6b01cde3e9bae33d659b63d8857e5" + integrity sha512-LRxmNwziLPT828z+4YkNzloCFC2YM4wrB99k+AV5ZbEyfGNWfG8SO1FUXLmLDBSo89NrJZ4DIWeLjy1CHGhMGA== + +mimic-fn@^2.0.0, mimic-fn@^2.1.0: + version "2.1.0" + resolved "https://registry.yarnpkg.com/mimic-fn/-/mimic-fn-2.1.0.tgz#7ed2c2ccccaf84d3ffcb7a69b57711fc2083401b" + integrity sha512-OqbOk5oEQeAZ8WXWydlu9HJjz9WVdEIvamMCcXmuqUYjTknH/sqsWvhQ3vgwKFRR1HpjvNBKQ37nbJgYzGqGcg== + +mini-css-extract-plugin@0.9.0: + version "0.9.0" + resolved "https://registry.yarnpkg.com/mini-css-extract-plugin/-/mini-css-extract-plugin-0.9.0.tgz#47f2cf07aa165ab35733b1fc97d4c46c0564339e" + integrity sha512-lp3GeY7ygcgAmVIcRPBVhIkf8Us7FZjA+ILpal44qLdSu11wmjKQ3d9k15lfD7pO4esu9eUIAW7qiYIBppv40A== + dependencies: + loader-utils "^1.1.0" + normalize-url "1.9.1" + schema-utils "^1.0.0" + webpack-sources "^1.1.0" + +minimalistic-assert@^1.0.0, minimalistic-assert@^1.0.1: + version "1.0.1" + resolved "https://registry.yarnpkg.com/minimalistic-assert/-/minimalistic-assert-1.0.1.tgz#2e194de044626d4a10e7f7fbc00ce73e83e4d5c7" + integrity sha512-UtJcAD4yEaGtjPezWuO9wC4nwUnVH/8/Im3yEHQP4b67cXlD/Qr9hdITCU1xDbSEXg2XKNaP8jsReV7vQd00/A== + +minimalistic-crypto-utils@^1.0.0, minimalistic-crypto-utils@^1.0.1: + version "1.0.1" + resolved "https://registry.yarnpkg.com/minimalistic-crypto-utils/-/minimalistic-crypto-utils-1.0.1.tgz#f6c00c1c0b082246e5c4d99dfb8c7c083b2b582a" + integrity sha1-9sAMHAsIIkblxNmd+4x8CDsrWCo= + +minimatch@3.0.4, minimatch@^3.0.4: + version "3.0.4" + resolved "https://registry.yarnpkg.com/minimatch/-/minimatch-3.0.4.tgz#5166e286457f03306064be5497e8dbb0c3d32083" + integrity sha512-yJHVQEhyqPLUTgt9B83PXu6W3rx4MvvHvSUvToogpwoGDOUQ+yDrR0HRot+yOCdCO7u4hX3pWft6kWBBcqh0UA== + dependencies: + brace-expansion "^1.1.7" + +minimist@^1.1.1, minimist@^1.2.0, minimist@^1.2.5: + version "1.2.5" + resolved "https://registry.yarnpkg.com/minimist/-/minimist-1.2.5.tgz#67d66014b66a6a8aaa0c083c5fd58df4e4e97602" + integrity sha512-FM9nNUYrRBAELZQT3xeZQ7fmMOBg6nWNmJKTcgsJeaLstP/UODVpGsr5OhXhhXg6f+qtJ8uiZ+PUxkDWcgIXLw== + +minipass-collect@^1.0.2: + version "1.0.2" + resolved "https://registry.yarnpkg.com/minipass-collect/-/minipass-collect-1.0.2.tgz#22b813bf745dc6edba2576b940022ad6edc8c617" + integrity sha512-6T6lH0H8OG9kITm/Jm6tdooIbogG9e0tLgpY6mphXSm/A9u8Nq1ryBG+Qspiub9LjWlBPsPS3tWQ/Botq4FdxA== + dependencies: + minipass "^3.0.0" + +minipass-flush@^1.0.5: + version "1.0.5" + resolved "https://registry.yarnpkg.com/minipass-flush/-/minipass-flush-1.0.5.tgz#82e7135d7e89a50ffe64610a787953c4c4cbb373" + integrity sha512-JmQSYYpPUqX5Jyn1mXaRwOda1uQ8HP5KAT/oDSLCzt1BYRhQU0/hDtsB1ufZfEEzMZ9aAVmsBw8+FWsIXlClWw== + dependencies: + minipass "^3.0.0" + +minipass-pipeline@^1.2.2: + version "1.2.2" + resolved "https://registry.yarnpkg.com/minipass-pipeline/-/minipass-pipeline-1.2.2.tgz#3dcb6bb4a546e32969c7ad710f2c79a86abba93a" + integrity sha512-3JS5A2DKhD2g0Gg8x3yamO0pj7YeKGwVlDS90pF++kxptwx/F+B//roxf9SqYil5tQo65bijy+dAuAFZmYOouA== + dependencies: + minipass "^3.0.0" + +minipass@^3.0.0, minipass@^3.1.1: + version "3.1.1" + resolved 
"https://registry.yarnpkg.com/minipass/-/minipass-3.1.1.tgz#7607ce778472a185ad6d89082aa2070f79cedcd5" + integrity sha512-UFqVihv6PQgwj8/yTGvl9kPz7xIAY+R5z6XYjRInD3Gk3qx6QGSD6zEcpeG4Dy/lQnv1J6zv8ejV90hyYIKf3w== + dependencies: + yallist "^4.0.0" + +mississippi@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/mississippi/-/mississippi-3.0.0.tgz#ea0a3291f97e0b5e8776b363d5f0a12d94c67022" + integrity sha512-x471SsVjUtBRtcvd4BzKE9kFC+/2TeWgKCgw0bZcw1b9l2X3QX5vCWgF+KaZaYm87Ss//rHnWryupDrgLvmSkA== + dependencies: + concat-stream "^1.5.0" + duplexify "^3.4.2" + end-of-stream "^1.1.0" + flush-write-stream "^1.0.0" + from2 "^2.1.0" + parallel-transform "^1.1.0" + pump "^3.0.0" + pumpify "^1.3.3" + stream-each "^1.1.0" + through2 "^2.0.0" + +mixin-deep@^1.2.0: + version "1.3.2" + resolved "https://registry.yarnpkg.com/mixin-deep/-/mixin-deep-1.3.2.tgz#1120b43dc359a785dce65b55b82e257ccf479566" + integrity sha512-WRoDn//mXBiJ1H40rqa3vH0toePwSsGb45iInWlTySa+Uu4k3tYUSxa2v1KqAiLtvlrSzaExqS1gtk96A9zvEA== + dependencies: + for-in "^1.0.2" + is-extendable "^1.0.1" + +mixin-object@^2.0.1: + version "2.0.1" + resolved "https://registry.yarnpkg.com/mixin-object/-/mixin-object-2.0.1.tgz#4fb949441dab182540f1fe035ba60e1947a5e57e" + integrity sha1-T7lJRB2rGCVA8f4DW6YOGUel5X4= + dependencies: + for-in "^0.1.3" + is-extendable "^0.1.1" + +mkdirp@^0.5.1, mkdirp@^0.5.3, mkdirp@~0.5.1: + version "0.5.4" + resolved "https://registry.yarnpkg.com/mkdirp/-/mkdirp-0.5.4.tgz#fd01504a6797ec5c9be81ff43d204961ed64a512" + integrity sha512-iG9AK/dJLtJ0XNgTuDbSyNS3zECqDlAhnQW4CsNxBG3LQJBbHmRX1egw39DmtOdCAqY+dKXV+sgPgilNWUKMVw== + dependencies: + minimist "^1.2.5" + +moment@^2.10.2: + version "2.29.1" + resolved "https://registry.yarnpkg.com/moment/-/moment-2.29.1.tgz#b2be769fa31940be9eeea6469c075e35006fa3d3" + integrity sha512-kHmoybcPV8Sqy59DwNDY3Jefr64lK/by/da0ViFcuA4DH0vQg5Q6Ze5VimxkfQNSC+Mls/Kx53s7TjP1RhFEDQ== + +move-concurrently@^1.0.1: + version "1.0.1" + resolved "https://registry.yarnpkg.com/move-concurrently/-/move-concurrently-1.0.1.tgz#be2c005fda32e0b29af1f05d7c4b33214c701f92" + integrity sha1-viwAX9oy4LKa8fBdfEszIUxwH5I= + dependencies: + aproba "^1.1.1" + copy-concurrently "^1.0.0" + fs-write-stream-atomic "^1.0.8" + mkdirp "^0.5.1" + rimraf "^2.5.4" + run-queue "^1.0.3" + +ms@2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/ms/-/ms-2.0.0.tgz#5608aeadfc00be6c2901df5f9861788de0d597c8" + integrity sha1-VgiurfwAvmwpAd9fmGF4jeDVl8g= + +ms@2.1.1: + version "2.1.1" + resolved "https://registry.yarnpkg.com/ms/-/ms-2.1.1.tgz#30a5864eb3ebb0a66f2ebe6d727af06a09d86e0a" + integrity sha512-tgp+dl5cGk28utYktBsrFqA7HKgrhgPsg6Z/EfhWI4gl1Hwq8B/GmY/0oXZ6nF8hDVesS/FpnYaD/kOWhYQvyg== + +ms@^2.1.1: + version "2.1.2" + resolved "https://registry.yarnpkg.com/ms/-/ms-2.1.2.tgz#d09d1f357b443f493382a8eb3ccd183872ae6009" + integrity sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w== + +multicast-dns-service-types@^1.1.0: + version "1.1.0" + resolved "https://registry.yarnpkg.com/multicast-dns-service-types/-/multicast-dns-service-types-1.1.0.tgz#899f11d9686e5e05cb91b35d5f0e63b773cfc901" + integrity sha1-iZ8R2WhuXgXLkbNdXw5jt3PPyQE= + +multicast-dns@^6.0.1: + version "6.2.3" + resolved "https://registry.yarnpkg.com/multicast-dns/-/multicast-dns-6.2.3.tgz#a0ec7bd9055c4282f790c3c82f4e28db3b31b229" + integrity sha512-ji6J5enbMyGRHIAkAOu3WdV8nggqviKCEKtXcOqfphZZtQrmHKycfynJ2V7eVPUA4NhJ6V7Wf4TmGbTwKE9B6g== + dependencies: + dns-packet "^1.3.1" + thunky "^1.0.2" + 
+mute-stream@0.0.8: + version "0.0.8" + resolved "https://registry.yarnpkg.com/mute-stream/-/mute-stream-0.0.8.tgz#1630c42b2251ff81e2a283de96a5497ea92e5e0d" + integrity sha512-nnbWWOkoWyUsTjKrhgD0dcz22mdkSnpYqbEjIm2nhwhuxlSkpywJmBo8h0ZqJdkp73mb90SssHkN4rsRaBAfAA== + +nan@^2.12.1: + version "2.14.0" + resolved "https://registry.yarnpkg.com/nan/-/nan-2.14.0.tgz#7818f722027b2459a86f0295d434d1fc2336c52c" + integrity sha512-INOFj37C7k3AfaNTtX8RhsTw7qRy7eLET14cROi9+5HAVbbHuIWUHEauBv5qT4Av2tWasiTY1Jw6puUNqRJXQg== + +nanomatch@^1.2.9: + version "1.2.13" + resolved "https://registry.yarnpkg.com/nanomatch/-/nanomatch-1.2.13.tgz#b87a8aa4fc0de8fe6be88895b38983ff265bd119" + integrity sha512-fpoe2T0RbHwBTBUOftAfBPaDEi06ufaUai0mE6Yn1kacc3SnTErfb/h+X94VXzI64rKFHYImXSvdwGGCmwOqCA== + dependencies: + arr-diff "^4.0.0" + array-unique "^0.3.2" + define-property "^2.0.2" + extend-shallow "^3.0.2" + fragment-cache "^0.2.1" + is-windows "^1.0.2" + kind-of "^6.0.2" + object.pick "^1.3.0" + regex-not "^1.0.0" + snapdragon "^0.8.1" + to-regex "^3.0.1" + +natural-compare@^1.4.0: + version "1.4.0" + resolved "https://registry.yarnpkg.com/natural-compare/-/natural-compare-1.4.0.tgz#4abebfeed7541f2c27acfb29bdbbd15c8d5ba4f7" + integrity sha1-Sr6/7tdUHywnrPspvbvRXI1bpPc= + +negotiator@0.6.2: + version "0.6.2" + resolved "https://registry.yarnpkg.com/negotiator/-/negotiator-0.6.2.tgz#feacf7ccf525a77ae9634436a64883ffeca346fb" + integrity sha512-hZXc7K2e+PgeI1eDBe/10Ard4ekbfrrqG8Ep+8Jmf4JID2bNg7NvCPOZN+kfF574pFQI7mum2AUqDidoKqcTOw== + +neo-async@^2.5.0, neo-async@^2.6.1: + version "2.6.1" + resolved "https://registry.yarnpkg.com/neo-async/-/neo-async-2.6.1.tgz#ac27ada66167fa8849a6addd837f6b189ad2081c" + integrity sha512-iyam8fBuCUpWeKPGpaNMetEocMt364qkCsfL9JuhjXX6dRnguRVOfk2GZaDpPjcOKiiXCPINZC1GczQ7iTq3Zw== + +next-tick@~1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/next-tick/-/next-tick-1.0.0.tgz#ca86d1fe8828169b0120208e3dc8424b9db8342c" + integrity sha1-yobR/ogoFpsBICCOPchCS524NCw= + +nice-try@^1.0.4: + version "1.0.5" + resolved "https://registry.yarnpkg.com/nice-try/-/nice-try-1.0.5.tgz#a3378a7696ce7d223e88fc9b764bd7ef1089e366" + integrity sha512-1nh45deeb5olNY7eX82BkPO7SSxR5SSYJiPTrTdFUVYwAl8CKMA5N9PjTYkHiRjisVcxcQ1HXdLhx2qxxJzLNQ== + +no-case@^3.0.3: + version "3.0.3" + resolved "https://registry.yarnpkg.com/no-case/-/no-case-3.0.3.tgz#c21b434c1ffe48b39087e86cfb4d2582e9df18f8" + integrity sha512-ehY/mVQCf9BL0gKfsJBvFJen+1V//U+0HQMPrWct40ixE4jnv0bfvxDbWtAHL9EcaPEOJHVVYKoQn1TlZUB8Tw== + dependencies: + lower-case "^2.0.1" + tslib "^1.10.0" + +node-forge@0.9.0: + version "0.9.0" + resolved "https://registry.yarnpkg.com/node-forge/-/node-forge-0.9.0.tgz#d624050edbb44874adca12bb9a52ec63cb782579" + integrity sha512-7ASaDa3pD+lJ3WvXFsxekJQelBKRpne+GOVbLbtHYdd7pFspyeuJHnWfLplGf3SwKGbfs/aYl5V/JCIaHVUKKQ== + +node-int64@^0.4.0: + version "0.4.0" + resolved "https://registry.yarnpkg.com/node-int64/-/node-int64-0.4.0.tgz#87a9065cdb355d3182d8f94ce11188b825c68a3b" + integrity sha1-h6kGXNs1XTGC2PlM4RGIuCXGijs= + +node-libs-browser@^2.2.1: + version "2.2.1" + resolved "https://registry.yarnpkg.com/node-libs-browser/-/node-libs-browser-2.2.1.tgz#b64f513d18338625f90346d27b0d235e631f6425" + integrity sha512-h/zcD8H9kaDZ9ALUWwlBUDo6TKF8a7qBSCSEGfjTVIYeqsioSKaAX+BN7NgiMGp6iSIXZ3PxgCu8KS3b71YK5Q== + dependencies: + assert "^1.1.1" + browserify-zlib "^0.2.0" + buffer "^4.3.0" + console-browserify "^1.1.0" + constants-browserify "^1.0.0" + crypto-browserify "^3.11.0" + domain-browser "^1.1.1" + events "^3.0.0" + 
https-browserify "^1.0.0" + os-browserify "^0.3.0" + path-browserify "0.0.1" + process "^0.11.10" + punycode "^1.2.4" + querystring-es3 "^0.2.0" + readable-stream "^2.3.3" + stream-browserify "^2.0.1" + stream-http "^2.7.2" + string_decoder "^1.0.0" + timers-browserify "^2.0.4" + tty-browserify "0.0.0" + url "^0.11.0" + util "^0.11.0" + vm-browserify "^1.0.1" + +node-modules-regexp@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/node-modules-regexp/-/node-modules-regexp-1.0.0.tgz#8d9dbe28964a4ac5712e9131642107c71e90ec40" + integrity sha1-jZ2+KJZKSsVxLpExZCEHxx6Q7EA= + +node-notifier@^5.4.2: + version "5.4.3" + resolved "https://registry.yarnpkg.com/node-notifier/-/node-notifier-5.4.3.tgz#cb72daf94c93904098e28b9c590fd866e464bd50" + integrity sha512-M4UBGcs4jeOK9CjTsYwkvH6/MzuUmGCyTW+kCY7uO+1ZVr0+FHGdPdIf5CCLqAaxnRrWidyoQlNkMIIVwbKB8Q== + dependencies: + growly "^1.3.0" + is-wsl "^1.1.0" + semver "^5.5.0" + shellwords "^0.1.1" + which "^1.3.0" + +node-releases@^1.1.52, node-releases@^1.1.53: + version "1.1.53" + resolved "https://registry.yarnpkg.com/node-releases/-/node-releases-1.1.53.tgz#2d821bfa499ed7c5dffc5e2f28c88e78a08ee3f4" + integrity sha512-wp8zyQVwef2hpZ/dJH7SfSrIPD6YoJz6BDQDpGEkcA0s3LpAQoxBIYmfIq6QAhC1DhwsyCgTaTTcONwX8qzCuQ== + +normalize-package-data@^2.3.2: + version "2.5.0" + resolved "https://registry.yarnpkg.com/normalize-package-data/-/normalize-package-data-2.5.0.tgz#e66db1838b200c1dfc233225d12cb36520e234a8" + integrity sha512-/5CMN3T0R4XTj4DcGaexo+roZSdSFW/0AOOTROrjxzCG1wrWXEsGbRKevjlIL+ZDE4sZlJr5ED4YW0yqmkK+eA== + dependencies: + hosted-git-info "^2.1.4" + resolve "^1.10.0" + semver "2 || 3 || 4 || 5" + validate-npm-package-license "^3.0.1" + +normalize-path@^2.1.1: + version "2.1.1" + resolved "https://registry.yarnpkg.com/normalize-path/-/normalize-path-2.1.1.tgz#1ab28b556e198363a8c1a6f7e6fa20137fe6aed9" + integrity sha1-GrKLVW4Zg2Oowab35vogE3/mrtk= + dependencies: + remove-trailing-separator "^1.0.1" + +normalize-path@^3.0.0, normalize-path@~3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/normalize-path/-/normalize-path-3.0.0.tgz#0dcd69ff23a1c9b11fd0978316644a0388216a65" + integrity sha512-6eZs5Ls3WtCisHWp9S2GUy8dqkpGi4BVSz3GaqiE6ezub0512ESztXUwUB6C6IKbQkY2Pnb/mD4WYojCRwcwLA== + +normalize-range@^0.1.2: + version "0.1.2" + resolved "https://registry.yarnpkg.com/normalize-range/-/normalize-range-0.1.2.tgz#2d10c06bdfd312ea9777695a4d28439456b75942" + integrity sha1-LRDAa9/TEuqXd2laTShDlFa3WUI= + +normalize-url@1.9.1: + version "1.9.1" + resolved "https://registry.yarnpkg.com/normalize-url/-/normalize-url-1.9.1.tgz#2cc0d66b31ea23036458436e3620d85954c66c3c" + integrity sha1-LMDWazHqIwNkWENuNiDYWVTGbDw= + dependencies: + object-assign "^4.0.1" + prepend-http "^1.0.0" + query-string "^4.1.0" + sort-keys "^1.0.0" + +normalize-url@^3.0.0: + version "3.3.0" + resolved "https://registry.yarnpkg.com/normalize-url/-/normalize-url-3.3.0.tgz#b2e1c4dc4f7c6d57743df733a4f5978d18650559" + integrity sha512-U+JJi7duF1o+u2pynbp2zXDW2/PADgC30f0GsHZtRh+HOcXHnw137TrNlyxxRvWW5fjKd3bcLHPxofWuCjaeZg== + +npm-run-path@^2.0.0: + version "2.0.2" + resolved "https://registry.yarnpkg.com/npm-run-path/-/npm-run-path-2.0.2.tgz#35a9232dfa35d7067b4cb2ddf2357b1871536c5f" + integrity sha1-NakjLfo11wZ7TLLd8jV7GHFTbF8= + dependencies: + path-key "^2.0.0" + +nth-check@^1.0.2, nth-check@~1.0.1: + version "1.0.2" + resolved "https://registry.yarnpkg.com/nth-check/-/nth-check-1.0.2.tgz#b2bd295c37e3dd58a3bf0700376663ba4d9cf05c" + integrity 
sha512-WeBOdju8SnzPN5vTUJYxYUxLeXpCaVP5i5e0LF8fg7WORF2Wd7wFX/pk0tYZk7s8T+J7VLy0Da6J1+wCT0AtHg== + dependencies: + boolbase "~1.0.0" + +num2fraction@^1.2.2: + version "1.2.2" + resolved "https://registry.yarnpkg.com/num2fraction/-/num2fraction-1.2.2.tgz#6f682b6a027a4e9ddfa4564cd2589d1d4e669ede" + integrity sha1-b2gragJ6Tp3fpFZM0lidHU5mnt4= + +number-is-nan@^1.0.0: + version "1.0.1" + resolved "https://registry.yarnpkg.com/number-is-nan/-/number-is-nan-1.0.1.tgz#097b602b53422a522c1afb8790318336941a011d" + integrity sha1-CXtgK1NCKlIsGvuHkDGDNpQaAR0= + +nwsapi@^2.0.7, nwsapi@^2.1.3: + version "2.2.0" + resolved "https://registry.yarnpkg.com/nwsapi/-/nwsapi-2.2.0.tgz#204879a9e3d068ff2a55139c2c772780681a38b7" + integrity sha512-h2AatdwYH+JHiZpv7pt/gSX1XoRGb7L/qSIeuqA6GwYoF9w1vP1cw42TO0aI2pNyshRK5893hNSl+1//vHK7hQ== + +oauth-sign@~0.9.0: + version "0.9.0" + resolved "https://registry.yarnpkg.com/oauth-sign/-/oauth-sign-0.9.0.tgz#47a7b016baa68b5fa0ecf3dee08a85c679ac6455" + integrity sha512-fexhUFFPTGV8ybAtSIGbV6gOkSv8UtRbDBnAyLQw4QPKkgNlsH2ByPGtMUqdWkos6YCRmAqViwgZrJc/mRDzZQ== + +object-assign@^4.0.1, object-assign@^4.1.0, object-assign@^4.1.1: + version "4.1.1" + resolved "https://registry.yarnpkg.com/object-assign/-/object-assign-4.1.1.tgz#2109adc7965887cfc05cbbd442cac8bfbb360863" + integrity sha1-IQmtx5ZYh8/AXLvUQsrIv7s2CGM= + +object-copy@^0.1.0: + version "0.1.0" + resolved "https://registry.yarnpkg.com/object-copy/-/object-copy-0.1.0.tgz#7e7d858b781bd7c991a41ba975ed3812754e998c" + integrity sha1-fn2Fi3gb18mRpBupde04EnVOmYw= + dependencies: + copy-descriptor "^0.1.0" + define-property "^0.2.5" + kind-of "^3.0.3" + +object-hash@^2.0.1: + version "2.0.3" + resolved "https://registry.yarnpkg.com/object-hash/-/object-hash-2.0.3.tgz#d12db044e03cd2ca3d77c0570d87225b02e1e6ea" + integrity sha512-JPKn0GMu+Fa3zt3Bmr66JhokJU5BaNBIh4ZeTlaCBzrBsOeXzwcKKAK1tbLiPKgvwmPXsDvvLHoWh5Bm7ofIYg== + +object-inspect@^1.7.0: + version "1.7.0" + resolved "https://registry.yarnpkg.com/object-inspect/-/object-inspect-1.7.0.tgz#f4f6bd181ad77f006b5ece60bd0b6f398ff74a67" + integrity sha512-a7pEHdh1xKIAgTySUGgLMx/xwDZskN1Ud6egYYN3EdRW4ZMPNEDUTF+hwy2LUC+Bl+SyLXANnwz/jyh/qutKUw== + +object-is@^1.0.1: + version "1.0.2" + resolved "https://registry.yarnpkg.com/object-is/-/object-is-1.0.2.tgz#6b80eb84fe451498f65007982f035a5b445edec4" + integrity sha512-Epah+btZd5wrrfjkJZq1AOB9O6OxUQto45hzFd7lXGrpHPGE0W1k+426yrZV+k6NJOzLNNW/nVsmZdIWsAqoOQ== + +object-keys@^1.0.11, object-keys@^1.0.12, object-keys@^1.1.1: + version "1.1.1" + resolved "https://registry.yarnpkg.com/object-keys/-/object-keys-1.1.1.tgz#1c47f272df277f3b1daf061677d9c82e2322c60e" + integrity sha512-NuAESUOUMrlIXOfHKzD6bpPu3tYt3xvjNdRIQ+FeT0lNb4K8WR70CaDxhuNguS2XG+GjkyMwOzsN5ZktImfhLA== + +object-path@0.11.4: + version "0.11.4" + resolved "https://registry.yarnpkg.com/object-path/-/object-path-0.11.4.tgz#370ae752fbf37de3ea70a861c23bba8915691949" + integrity sha1-NwrnUvvzfePqcKhhwju6iRVpGUk= + +object-visit@^1.0.0: + version "1.0.1" + resolved "https://registry.yarnpkg.com/object-visit/-/object-visit-1.0.1.tgz#f79c4493af0c5377b59fe39d395e41042dd045bb" + integrity sha1-95xEk68MU3e1n+OdOV5BBC3QRbs= + dependencies: + isobject "^3.0.0" + +object.assign@^4.1.0: + version "4.1.0" + resolved "https://registry.yarnpkg.com/object.assign/-/object.assign-4.1.0.tgz#968bf1100d7956bb3ca086f006f846b3bc4008da" + integrity sha512-exHJeq6kBKj58mqGyTQ9DFvrZC/eR6OwxzoM9YRoGBqrXYonaFyGiFMuc9VZrXf7DarreEwMpurG3dd+CNyW5w== + dependencies: + define-properties "^1.1.2" + function-bind "^1.1.1" + 
has-symbols "^1.0.0" + object-keys "^1.0.11" + +object.entries@^1.1.0, object.entries@^1.1.1: + version "1.1.1" + resolved "https://registry.yarnpkg.com/object.entries/-/object.entries-1.1.1.tgz#ee1cf04153de02bb093fec33683900f57ce5399b" + integrity sha512-ilqR7BgdyZetJutmDPfXCDffGa0/Yzl2ivVNpbx/g4UeWrCdRnFDUBrKJGLhGieRHDATnyZXWBeCb29k9CJysQ== + dependencies: + define-properties "^1.1.3" + es-abstract "^1.17.0-next.1" + function-bind "^1.1.1" + has "^1.0.3" + +object.fromentries@^2.0.2: + version "2.0.2" + resolved "https://registry.yarnpkg.com/object.fromentries/-/object.fromentries-2.0.2.tgz#4a09c9b9bb3843dd0f89acdb517a794d4f355ac9" + integrity sha512-r3ZiBH7MQppDJVLx6fhD618GKNG40CZYH9wgwdhKxBDDbQgjeWGGd4AtkZad84d291YxvWe7bJGuE65Anh0dxQ== + dependencies: + define-properties "^1.1.3" + es-abstract "^1.17.0-next.1" + function-bind "^1.1.1" + has "^1.0.3" + +object.getownpropertydescriptors@^2.0.3, object.getownpropertydescriptors@^2.1.0: + version "2.1.0" + resolved "https://registry.yarnpkg.com/object.getownpropertydescriptors/-/object.getownpropertydescriptors-2.1.0.tgz#369bf1f9592d8ab89d712dced5cb81c7c5352649" + integrity sha512-Z53Oah9A3TdLoblT7VKJaTDdXdT+lQO+cNpKVnya5JDe9uLvzu1YyY1yFDFrcxrlRgWrEFH0jJtD/IbuwjcEVg== + dependencies: + define-properties "^1.1.3" + es-abstract "^1.17.0-next.1" + +object.pick@^1.3.0: + version "1.3.0" + resolved "https://registry.yarnpkg.com/object.pick/-/object.pick-1.3.0.tgz#87a10ac4c1694bd2e1cbf53591a66141fb5dd747" + integrity sha1-h6EKxMFpS9Lhy/U1kaZhQftd10c= + dependencies: + isobject "^3.0.1" + +object.values@^1.1.0, object.values@^1.1.1: + version "1.1.1" + resolved "https://registry.yarnpkg.com/object.values/-/object.values-1.1.1.tgz#68a99ecde356b7e9295a3c5e0ce31dc8c953de5e" + integrity sha512-WTa54g2K8iu0kmS/us18jEmdv1a4Wi//BZ/DTVYEcH0XhLM5NYdpDHja3gt57VrZLcNAO2WGA+KpWsDBaHt6eA== + dependencies: + define-properties "^1.1.3" + es-abstract "^1.17.0-next.1" + function-bind "^1.1.1" + has "^1.0.3" + +obuf@^1.0.0, obuf@^1.1.2: + version "1.1.2" + resolved "https://registry.yarnpkg.com/obuf/-/obuf-1.1.2.tgz#09bea3343d41859ebd446292d11c9d4db619084e" + integrity sha512-PX1wu0AmAdPqOL1mWhqmlOd8kOIZQwGZw6rh7uby9fTc5lhaOWFLX3I6R1hrF9k3zUY40e6igsLGkDXK92LJNg== + +on-finished@~2.3.0: + version "2.3.0" + resolved "https://registry.yarnpkg.com/on-finished/-/on-finished-2.3.0.tgz#20f1336481b083cd75337992a16971aa2d906947" + integrity sha1-IPEzZIGwg811M3mSoWlxqi2QaUc= + dependencies: + ee-first "1.1.1" + +on-headers@~1.0.2: + version "1.0.2" + resolved "https://registry.yarnpkg.com/on-headers/-/on-headers-1.0.2.tgz#772b0ae6aaa525c399e489adfad90c403eb3c28f" + integrity sha512-pZAE+FJLoyITytdqK0U5s+FIpjN0JP3OzFi/u8Rx+EV5/W+JTWGXG8xFzevE7AjBfDqHv/8vL8qQsIhHnqRkrA== + +once@^1.3.0, once@^1.3.1, once@^1.4.0: + version "1.4.0" + resolved "https://registry.yarnpkg.com/once/-/once-1.4.0.tgz#583b1aa775961d4b113ac17d9c50baef9dd76bd1" + integrity sha1-WDsap3WWHUsROsF9nFC6753Xa9E= + dependencies: + wrappy "1" + +onetime@^5.1.0: + version "5.1.0" + resolved "https://registry.yarnpkg.com/onetime/-/onetime-5.1.0.tgz#fff0f3c91617fe62bb50189636e99ac8a6df7be5" + integrity sha512-5NcSkPHhwTVFIQN+TUqXoS5+dlElHXdpAWu9I0HP20YOtIi+aZ0Ct82jdlILDxjLEAWwvm+qj1m6aEtsDVmm6Q== + dependencies: + mimic-fn "^2.1.0" + +open@^7.0.2: + version "7.0.3" + resolved "https://registry.yarnpkg.com/open/-/open-7.0.3.tgz#db551a1af9c7ab4c7af664139930826138531c48" + integrity sha512-sP2ru2v0P290WFfv49Ap8MF6PkzGNnGlAwHweB4WR4mr5d2d0woiCluUeJ218w7/+PmoBy9JmYgD5A4mLcWOFA== + dependencies: + is-docker "^2.0.0" + 
is-wsl "^2.1.1" + +opn@^5.5.0: + version "5.5.0" + resolved "https://registry.yarnpkg.com/opn/-/opn-5.5.0.tgz#fc7164fab56d235904c51c3b27da6758ca3b9bfc" + integrity sha512-PqHpggC9bLV0VeWcdKhkpxY+3JTzetLSqTCWL/z/tFIbI6G8JCjondXklT1JinczLz2Xib62sSp0T/gKT4KksA== + dependencies: + is-wsl "^1.1.0" + +optimize-css-assets-webpack-plugin@5.0.3: + version "5.0.3" + resolved "https://registry.yarnpkg.com/optimize-css-assets-webpack-plugin/-/optimize-css-assets-webpack-plugin-5.0.3.tgz#e2f1d4d94ad8c0af8967ebd7cf138dcb1ef14572" + integrity sha512-q9fbvCRS6EYtUKKSwI87qm2IxlyJK5b4dygW1rKUBT6mMDhdG5e5bZT63v6tnJR9F9FB/H5a0HTmtw+laUBxKA== + dependencies: + cssnano "^4.1.10" + last-call-webpack-plugin "^3.0.0" + +optionator@^0.8.1, optionator@^0.8.3: + version "0.8.3" + resolved "https://registry.yarnpkg.com/optionator/-/optionator-0.8.3.tgz#84fa1d036fe9d3c7e21d99884b601167ec8fb495" + integrity sha512-+IW9pACdk3XWmmTXG8m3upGUJst5XRGzxMRjXzAuJ1XnIFNvfhjjIuYkDvysnPQ7qzqVzLt78BCruntqRhWQbA== + dependencies: + deep-is "~0.1.3" + fast-levenshtein "~2.0.6" + levn "~0.3.0" + prelude-ls "~1.1.2" + type-check "~0.3.2" + word-wrap "~1.2.3" + +original@^1.0.0: + version "1.0.2" + resolved "https://registry.yarnpkg.com/original/-/original-1.0.2.tgz#e442a61cffe1c5fd20a65f3261c26663b303f25f" + integrity sha512-hyBVl6iqqUOJ8FqRe+l/gS8H+kKYjrEndd5Pm1MfBtsEKA038HkkdbAl/72EAXGyonD/PFsvmVG+EvcIpliMBg== + dependencies: + url-parse "^1.4.3" + +os-browserify@^0.3.0: + version "0.3.0" + resolved "https://registry.yarnpkg.com/os-browserify/-/os-browserify-0.3.0.tgz#854373c7f5c2315914fc9bfc6bd8238fdda1ec27" + integrity sha1-hUNzx/XCMVkU/Jv8a9gjj92h7Cc= + +os-locale@^3.0.0, os-locale@^3.1.0: + version "3.1.0" + resolved "https://registry.yarnpkg.com/os-locale/-/os-locale-3.1.0.tgz#a802a6ee17f24c10483ab9935719cef4ed16bf1a" + integrity sha512-Z8l3R4wYWM40/52Z+S265okfFj8Kt2cC2MKY+xNi3kFs+XGI7WXu/I309QQQYbRW4ijiZ+yxs9pqEhJh0DqW3Q== + dependencies: + execa "^1.0.0" + lcid "^2.0.0" + mem "^4.0.0" + +os-tmpdir@~1.0.2: + version "1.0.2" + resolved "https://registry.yarnpkg.com/os-tmpdir/-/os-tmpdir-1.0.2.tgz#bbe67406c79aa85c5cfec766fe5734555dfa1274" + integrity sha1-u+Z0BseaqFxc/sdm/lc0VV36EnQ= + +p-defer@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/p-defer/-/p-defer-1.0.0.tgz#9f6eb182f6c9aa8cd743004a7d4f96b196b0fb0c" + integrity sha1-n26xgvbJqozXQwBKfU+WsZaw+ww= + +p-each-series@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/p-each-series/-/p-each-series-1.0.0.tgz#930f3d12dd1f50e7434457a22cd6f04ac6ad7f71" + integrity sha1-kw89Et0fUOdDRFeiLNbwSsatf3E= + dependencies: + p-reduce "^1.0.0" + +p-finally@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/p-finally/-/p-finally-1.0.0.tgz#3fbcfb15b899a44123b34b6dcc18b724336a2cae" + integrity sha1-P7z7FbiZpEEjs0ttzBi3JDNqLK4= + +p-is-promise@^2.0.0: + version "2.1.0" + resolved "https://registry.yarnpkg.com/p-is-promise/-/p-is-promise-2.1.0.tgz#918cebaea248a62cf7ffab8e3bca8c5f882fc42e" + integrity sha512-Y3W0wlRPK8ZMRbNq97l4M5otioeA5lm1z7bkNkxCka8HSPjR0xRWmpCmc9utiaLP9Jb1eD8BgeIxTW4AIF45Pg== + +p-limit@^1.1.0: + version "1.3.0" + resolved "https://registry.yarnpkg.com/p-limit/-/p-limit-1.3.0.tgz#b86bd5f0c25690911c7590fcbfc2010d54b3ccb8" + integrity sha512-vvcXsLAJ9Dr5rQOPk7toZQZJApBl2K4J6dANSsEuh6QI41JYcsS/qhTGa9ErIUUgK3WNQoJYvylxvjqmiqEA9Q== + dependencies: + p-try "^1.0.0" + +p-limit@^2.0.0, p-limit@^2.2.0, p-limit@^2.2.2: + version "2.2.2" + resolved 
"https://registry.yarnpkg.com/p-limit/-/p-limit-2.2.2.tgz#61279b67721f5287aa1c13a9a7fbbc48c9291b1e" + integrity sha512-WGR+xHecKTr7EbUEhyLSh5Dube9JtdiG78ufaeLxTgpudf/20KqyMioIUZJAezlTIi6evxuoUs9YXc11cU+yzQ== + dependencies: + p-try "^2.0.0" + +p-locate@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/p-locate/-/p-locate-2.0.0.tgz#20a0103b222a70c8fd39cc2e580680f3dde5ec43" + integrity sha1-IKAQOyIqcMj9OcwuWAaA893l7EM= + dependencies: + p-limit "^1.1.0" + +p-locate@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/p-locate/-/p-locate-3.0.0.tgz#322d69a05c0264b25997d9f40cd8a891ab0064a4" + integrity sha512-x+12w/To+4GFfgJhBEpiDcLozRJGegY+Ei7/z0tSLkMmxGZNybVMSfWj9aJn8Z5Fc7dBUNJOOVgPv2H7IwulSQ== + dependencies: + p-limit "^2.0.0" + +p-locate@^4.1.0: + version "4.1.0" + resolved "https://registry.yarnpkg.com/p-locate/-/p-locate-4.1.0.tgz#a3428bb7088b3a60292f66919278b7c297ad4f07" + integrity sha512-R79ZZ/0wAxKGu3oYMlz8jy/kbhsNrS7SKZ7PxEHBgJ5+F2mtFW2fK2cOtBh1cHYkQsbzFV7I+EoRKe6Yt0oK7A== + dependencies: + p-limit "^2.2.0" + +p-map@^2.0.0: + version "2.1.0" + resolved "https://registry.yarnpkg.com/p-map/-/p-map-2.1.0.tgz#310928feef9c9ecc65b68b17693018a665cea175" + integrity sha512-y3b8Kpd8OAN444hxfBbFfj1FY/RjtTd8tzYwhUqNYXx0fXx2iX4maP4Qr6qhIKbQXI02wTLAda4fYUbDagTUFw== + +p-map@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/p-map/-/p-map-3.0.0.tgz#d704d9af8a2ba684e2600d9a215983d4141a979d" + integrity sha512-d3qXVTF/s+W+CdJ5A29wywV2n8CQQYahlgz2bFiA+4eVNJbHJodPZ+/gXwPGh0bOqA+j8S+6+ckmvLGPk1QpxQ== + dependencies: + aggregate-error "^3.0.0" + +p-reduce@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/p-reduce/-/p-reduce-1.0.0.tgz#18c2b0dd936a4690a529f8231f58a0fdb6a47dfa" + integrity sha1-GMKw3ZNqRpClKfgjH1ig/bakffo= + +p-retry@^3.0.1: + version "3.0.1" + resolved "https://registry.yarnpkg.com/p-retry/-/p-retry-3.0.1.tgz#316b4c8893e2c8dc1cfa891f406c4b422bebf328" + integrity sha512-XE6G4+YTTkT2a0UWb2kjZe8xNwf8bIbnqpc/IS/idOBVhyves0mK5OJgeocjx7q5pvX/6m23xuzVPYT1uGM73w== + dependencies: + retry "^0.12.0" + +p-try@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/p-try/-/p-try-1.0.0.tgz#cbc79cdbaf8fd4228e13f621f2b1a237c1b207b3" + integrity sha1-y8ec26+P1CKOE/Yh8rGiN8GyB7M= + +p-try@^2.0.0: + version "2.2.0" + resolved "https://registry.yarnpkg.com/p-try/-/p-try-2.2.0.tgz#cb2868540e313d61de58fafbe35ce9004d5540e6" + integrity sha512-R4nPAVTAU0B9D35/Gk3uJf/7XYbQcyohSKdvAxIRSNghFl4e71hVoGnBNQz9cWaXxO2I10KTC+3jMdvvoKw6dQ== + +pako@~1.0.5: + version "1.0.11" + resolved "https://registry.yarnpkg.com/pako/-/pako-1.0.11.tgz#6c9599d340d54dfd3946380252a35705a6b992bf" + integrity sha512-4hLB8Py4zZce5s4yd9XzopqwVv/yGNhV1Bl8NTmCq1763HeK2+EwVTv+leGeL13Dnh2wfbqowVPXCIO0z4taYw== + +parallel-transform@^1.1.0: + version "1.2.0" + resolved "https://registry.yarnpkg.com/parallel-transform/-/parallel-transform-1.2.0.tgz#9049ca37d6cb2182c3b1d2c720be94d14a5814fc" + integrity sha512-P2vSmIu38uIlvdcU7fDkyrxj33gTUy/ABO5ZUbGowxNCopBq/OoD42bP4UmMrJoPyk4Uqf0mu3mtWBhHCZD8yg== + dependencies: + cyclist "^1.0.1" + inherits "^2.0.3" + readable-stream "^2.1.5" + +param-case@^3.0.3: + version "3.0.3" + resolved "https://registry.yarnpkg.com/param-case/-/param-case-3.0.3.tgz#4be41f8399eff621c56eebb829a5e451d9801238" + integrity sha512-VWBVyimc1+QrzappRs7waeN2YmoZFCGXWASRYX1/rGHtXqEcrGEIDm+jqIwFa2fRXNgQEwrxaYuIrX0WcAguTA== + dependencies: + dot-case "^3.0.3" + tslib "^1.10.0" + +parent-module@^1.0.0: + version "1.0.1" + resolved 
"https://registry.yarnpkg.com/parent-module/-/parent-module-1.0.1.tgz#691d2709e78c79fae3a156622452d00762caaaa2" + integrity sha512-GQ2EWRpQV8/o+Aw8YqtfZZPfNRWZYkbidE9k5rpl/hC3vtHHBfGm2Ifi6qWV+coDGkrUKZAxE3Lot5kcsRlh+g== + dependencies: + callsites "^3.0.0" + +parse-asn1@^5.0.0: + version "5.1.5" + resolved "https://registry.yarnpkg.com/parse-asn1/-/parse-asn1-5.1.5.tgz#003271343da58dc94cace494faef3d2147ecea0e" + integrity sha512-jkMYn1dcJqF6d5CpU689bq7w/b5ALS9ROVSpQDPrZsqqesUJii9qutvoT5ltGedNXMO2e16YUWIghG9KxaViTQ== + dependencies: + asn1.js "^4.0.0" + browserify-aes "^1.0.0" + create-hash "^1.1.0" + evp_bytestokey "^1.0.0" + pbkdf2 "^3.0.3" + safe-buffer "^5.1.1" + +parse-json@^2.2.0: + version "2.2.0" + resolved "https://registry.yarnpkg.com/parse-json/-/parse-json-2.2.0.tgz#f480f40434ef80741f8469099f8dea18f55a4dc9" + integrity sha1-9ID0BDTvgHQfhGkJn43qGPVaTck= + dependencies: + error-ex "^1.2.0" + +parse-json@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/parse-json/-/parse-json-4.0.0.tgz#be35f5425be1f7f6c747184f98a788cb99477ee0" + integrity sha1-vjX1Qlvh9/bHRxhPmKeIy5lHfuA= + dependencies: + error-ex "^1.3.1" + json-parse-better-errors "^1.0.1" + +parse-json@^5.0.0: + version "5.0.0" + resolved "https://registry.yarnpkg.com/parse-json/-/parse-json-5.0.0.tgz#73e5114c986d143efa3712d4ea24db9a4266f60f" + integrity sha512-OOY5b7PAEFV0E2Fir1KOkxchnZNCdowAJgQ5NuxjpBKTRP3pQhwkrkxqQjeoKJ+fO7bCpmIZaogI4eZGDMEGOw== + dependencies: + "@babel/code-frame" "^7.0.0" + error-ex "^1.3.1" + json-parse-better-errors "^1.0.1" + lines-and-columns "^1.1.6" + +parse-passwd@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/parse-passwd/-/parse-passwd-1.0.0.tgz#6d5b934a456993b23d37f40a382d6f1666a8e5c6" + integrity sha1-bVuTSkVpk7I9N/QKOC1vFmao5cY= + +parse5@4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/parse5/-/parse5-4.0.0.tgz#6d78656e3da8d78b4ec0b906f7c08ef1dfe3f608" + integrity sha512-VrZ7eOd3T1Fk4XWNXMgiGBK/z0MG48BWG2uQNU4I72fkQuKUTZpl+u9k+CxEG0twMVzSmXEEz12z5Fnw1jIQFA== + +parse5@5.1.0: + version "5.1.0" + resolved "https://registry.yarnpkg.com/parse5/-/parse5-5.1.0.tgz#c59341c9723f414c452975564c7c00a68d58acd2" + integrity sha512-fxNG2sQjHvlVAYmzBZS9YlDp6PTSSDwa98vkD4QgVDDCAo84z5X1t5XyJQ62ImdLXx5NdIIfihey6xpum9/gRQ== + +parseurl@~1.3.2, parseurl@~1.3.3: + version "1.3.3" + resolved "https://registry.yarnpkg.com/parseurl/-/parseurl-1.3.3.tgz#9da19e7bee8d12dff0513ed5b76957793bc2e8d4" + integrity sha512-CiyeOxFT/JZyN5m0z9PfXw4SCBJ6Sygz1Dpl0wqjlhDEGGBP1GnsUVEL0p63hoG1fcj3fHynXi9NYO4nWOL+qQ== + +pascal-case@^3.1.1: + version "3.1.1" + resolved "https://registry.yarnpkg.com/pascal-case/-/pascal-case-3.1.1.tgz#5ac1975133ed619281e88920973d2cd1f279de5f" + integrity sha512-XIeHKqIrsquVTQL2crjq3NfJUxmdLasn3TYOU0VBM+UX2a6ztAWBlJQBePLGY7VHW8+2dRadeIPK5+KImwTxQA== + dependencies: + no-case "^3.0.3" + tslib "^1.10.0" + +pascalcase@^0.1.1: + version "0.1.1" + resolved "https://registry.yarnpkg.com/pascalcase/-/pascalcase-0.1.1.tgz#b363e55e8006ca6fe21784d2db22bd15d7917f14" + integrity sha1-s2PlXoAGym/iF4TS2yK9FdeRfxQ= + +path-browserify@0.0.1: + version "0.0.1" + resolved "https://registry.yarnpkg.com/path-browserify/-/path-browserify-0.0.1.tgz#e6c4ddd7ed3aa27c68a20cc4e50e1a4ee83bbc4a" + integrity sha512-BapA40NHICOS+USX9SN4tyhq+A2RrN/Ws5F0Z5aMHDp98Fl86lX8Oti8B7uN93L4Ifv4fHOEA+pQw87gmMO/lQ== + +path-dirname@^1.0.0: + version "1.0.2" + resolved "https://registry.yarnpkg.com/path-dirname/-/path-dirname-1.0.2.tgz#cc33d24d525e099a5388c0336c6e32b9160609e0" + integrity 
sha1-zDPSTVJeCZpTiMAzbG4yuRYGCeA= + +path-exists@^2.0.0: + version "2.1.0" + resolved "https://registry.yarnpkg.com/path-exists/-/path-exists-2.1.0.tgz#0feb6c64f0fc518d9a754dd5efb62c7022761f4b" + integrity sha1-D+tsZPD8UY2adU3V77YscCJ2H0s= + dependencies: + pinkie-promise "^2.0.0" + +path-exists@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/path-exists/-/path-exists-3.0.0.tgz#ce0ebeaa5f78cb18925ea7d810d7b59b010fd515" + integrity sha1-zg6+ql94yxiSXqfYENe1mwEP1RU= + +path-exists@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/path-exists/-/path-exists-4.0.0.tgz#513bdbe2d3b95d7762e8c1137efa195c6c61b5b3" + integrity sha512-ak9Qy5Q7jYb2Wwcey5Fpvg2KoAc/ZIhLSLOSBmRmygPsGwkVVt0fZa0qrtMz+m6tJTAHfZQ8FnmB4MG4LWy7/w== + +path-is-absolute@^1.0.0: + version "1.0.1" + resolved "https://registry.yarnpkg.com/path-is-absolute/-/path-is-absolute-1.0.1.tgz#174b9268735534ffbc7ace6bf53a5a9e1b5c5f5f" + integrity sha1-F0uSaHNVNP+8es5r9TpanhtcX18= + +path-is-inside@^1.0.2: + version "1.0.2" + resolved "https://registry.yarnpkg.com/path-is-inside/-/path-is-inside-1.0.2.tgz#365417dede44430d1c11af61027facf074bdfc53" + integrity sha1-NlQX3t5EQw0cEa9hAn+s8HS9/FM= + +path-key@^2.0.0, path-key@^2.0.1: + version "2.0.1" + resolved "https://registry.yarnpkg.com/path-key/-/path-key-2.0.1.tgz#411cadb574c5a140d3a4b1910d40d80cc9f40b40" + integrity sha1-QRyttXTFoUDTpLGRDUDYDMn0C0A= + +path-key@^3.1.0: + version "3.1.1" + resolved "https://registry.yarnpkg.com/path-key/-/path-key-3.1.1.tgz#581f6ade658cbba65a0d3380de7753295054f375" + integrity sha512-ojmeN0qd+y0jszEtoY48r0Peq5dwMEkIlCOu6Q5f41lfkswXuKtYrhgoTpLnyIcHm24Uhqx+5Tqm2InSwLhE6Q== + +path-parse@^1.0.6: + version "1.0.6" + resolved "https://registry.yarnpkg.com/path-parse/-/path-parse-1.0.6.tgz#d62dbb5679405d72c4737ec58600e9ddcf06d24c" + integrity sha512-GSmOT2EbHrINBf9SR7CDELwlJ8AENk3Qn7OikK4nFYAu3Ote2+JYNVvkpAEQm3/TLNEJFD/xZJjzyxg3KBWOzw== + +path-to-regexp@0.1.7: + version "0.1.7" + resolved "https://registry.yarnpkg.com/path-to-regexp/-/path-to-regexp-0.1.7.tgz#df604178005f522f15eb4490e7247a1bfaa67f8c" + integrity sha1-32BBeABfUi8V60SQ5yR6G/qmf4w= + +path-type@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/path-type/-/path-type-2.0.0.tgz#f012ccb8415b7096fc2daa1054c3d72389594c73" + integrity sha1-8BLMuEFbcJb8LaoQVMPXI4lZTHM= + dependencies: + pify "^2.0.0" + +path-type@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/path-type/-/path-type-3.0.0.tgz#cef31dc8e0a1a3bb0d105c0cd97cf3bf47f4e36f" + integrity sha512-T2ZUsdZFHgA3u4e5PfPbjd7HDDpxPnQb5jN0SrDsjNSuVXHJqtwTnWqG0B1jZrgmJ/7lj1EmVIByWt1gxGkWvg== + dependencies: + pify "^3.0.0" + +path-type@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/path-type/-/path-type-4.0.0.tgz#84ed01c0a7ba380afe09d90a8c180dcd9d03043b" + integrity sha512-gDKb8aZMDeD/tZWs9P6+q0J9Mwkdl6xMV8TjnGP3qJVJ06bdMgkbBlLU8IdfOsIsFz2BW1rNVT3XuNEl8zPAvw== + +pbkdf2@^3.0.3: + version "3.0.17" + resolved "https://registry.yarnpkg.com/pbkdf2/-/pbkdf2-3.0.17.tgz#976c206530617b14ebb32114239f7b09336e93a6" + integrity sha512-U/il5MsrZp7mGg3mSQfn742na2T+1/vHDCG5/iTI3X9MKUuYUZVLQhyRsg06mCgDBTd57TxzgZt7P+fYfjRLtA== + dependencies: + create-hash "^1.1.2" + create-hmac "^1.1.4" + ripemd160 "^2.0.1" + safe-buffer "^5.0.1" + sha.js "^2.4.8" + +performance-now@^2.1.0: + version "2.1.0" + resolved "https://registry.yarnpkg.com/performance-now/-/performance-now-2.1.0.tgz#6309f4e0e5fa913ec1c69307ae364b4b377c9e7b" + integrity sha1-Ywn04OX6kT7BxpMHrjZLSzd8nns= + +picomatch@^2.0.4, 
picomatch@^2.0.7: + version "2.2.2" + resolved "https://registry.yarnpkg.com/picomatch/-/picomatch-2.2.2.tgz#21f333e9b6b8eaff02468f5146ea406d345f4dad" + integrity sha512-q0M/9eZHzmr0AulXyPwNfZjtwZ/RBZlbN3K3CErVrk50T2ASYI7Bye0EvekFY3IP1Nt2DHu0re+V2ZHIpMkuWg== + +pify@^2.0.0: + version "2.3.0" + resolved "https://registry.yarnpkg.com/pify/-/pify-2.3.0.tgz#ed141a6ac043a849ea588498e7dca8b15330e90c" + integrity sha1-7RQaasBDqEnqWISY59yosVMw6Qw= + +pify@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/pify/-/pify-3.0.0.tgz#e5a4acd2c101fdf3d9a4d07f0dbc4db49dd28176" + integrity sha1-5aSs0sEB/fPZpNB/DbxNtJ3SgXY= + +pify@^4.0.1: + version "4.0.1" + resolved "https://registry.yarnpkg.com/pify/-/pify-4.0.1.tgz#4b2cd25c50d598735c50292224fd8c6df41e3231" + integrity sha512-uB80kBFb/tfd68bVleG9T5GGsGPjJrLAUpR5PZIrhBnIaRTQRjqdJSsIKkOP6OAIFbj7GOrcudc5pNjZ+geV2g== + +pinkie-promise@^2.0.0: + version "2.0.1" + resolved "https://registry.yarnpkg.com/pinkie-promise/-/pinkie-promise-2.0.1.tgz#2135d6dfa7a358c069ac9b178776288228450ffa" + integrity sha1-ITXW36ejWMBprJsXh3YogihFD/o= + dependencies: + pinkie "^2.0.0" + +pinkie@^2.0.0: + version "2.0.4" + resolved "https://registry.yarnpkg.com/pinkie/-/pinkie-2.0.4.tgz#72556b80cfa0d48a974e80e77248e80ed4f7f870" + integrity sha1-clVrgM+g1IqXToDnckjoDtT3+HA= + +pirates@^4.0.1: + version "4.0.1" + resolved "https://registry.yarnpkg.com/pirates/-/pirates-4.0.1.tgz#643a92caf894566f91b2b986d2c66950a8e2fb87" + integrity sha512-WuNqLTbMI3tmfef2TKxlQmAiLHKtFhlsCZnPIpuv2Ow0RDVO8lfy1Opf4NUzlMXLjPl+Men7AuVdX6TA+s+uGA== + dependencies: + node-modules-regexp "^1.0.0" + +pkg-dir@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/pkg-dir/-/pkg-dir-1.0.0.tgz#7a4b508a8d5bb2d629d447056ff4e9c9314cf3d4" + integrity sha1-ektQio1bstYp1EcFb/TpyTFM89Q= + dependencies: + find-up "^1.0.0" + +pkg-dir@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/pkg-dir/-/pkg-dir-2.0.0.tgz#f6d5d1109e19d63edf428e0bd57e12777615334b" + integrity sha1-9tXREJ4Z1j7fQo4L1X4Sd3YVM0s= + dependencies: + find-up "^2.1.0" + +pkg-dir@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/pkg-dir/-/pkg-dir-3.0.0.tgz#2749020f239ed990881b1f71210d51eb6523bea3" + integrity sha512-/E57AYkoeQ25qkxMj5PBOVgF8Kiu/h7cYS30Z5+R7WaiCCBfLq58ZI/dSeaEKb9WVJV5n/03QwrN3IeWIFllvw== + dependencies: + find-up "^3.0.0" + +pkg-dir@^4.1.0: + version "4.2.0" + resolved "https://registry.yarnpkg.com/pkg-dir/-/pkg-dir-4.2.0.tgz#f099133df7ede422e81d1d8448270eeb3e4261f3" + integrity sha512-HRDzbaKjC+AOWVXxAU/x54COGeIv9eb+6CkDSQoNTt4XyWoIJvuPsXizxu/Fr23EiekbtZwmh1IcIG/l/a10GQ== + dependencies: + find-up "^4.0.0" + +pkg-up@3.1.0, pkg-up@^3.1.0: + version "3.1.0" + resolved "https://registry.yarnpkg.com/pkg-up/-/pkg-up-3.1.0.tgz#100ec235cc150e4fd42519412596a28512a0def5" + integrity sha512-nDywThFk1i4BQK4twPQ6TA4RT8bDY96yeuCVBWL3ePARCiEKDRSrNGbFIgUJpLp+XeIR65v8ra7WuJOFUBtkMA== + dependencies: + find-up "^3.0.0" + +pkg-up@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/pkg-up/-/pkg-up-2.0.0.tgz#c819ac728059a461cab1c3889a2be3c49a004d7f" + integrity sha1-yBmscoBZpGHKscOImivjxJoATX8= + dependencies: + find-up "^2.1.0" + +pn@^1.1.0: + version "1.1.0" + resolved "https://registry.yarnpkg.com/pn/-/pn-1.1.0.tgz#e2f4cef0e219f463c179ab37463e4e1ecdccbafb" + integrity sha512-2qHaIQr2VLRFoxe2nASzsV6ef4yOOH+Fi9FBOVH6cqeSgUnoyySPZkxzLuzd+RYOQTRpROA0ztTMqxROKSb/nA== + +pnp-webpack-plugin@1.6.4: + version "1.6.4" + resolved 
"https://registry.yarnpkg.com/pnp-webpack-plugin/-/pnp-webpack-plugin-1.6.4.tgz#c9711ac4dc48a685dabafc86f8b6dd9f8df84149" + integrity sha512-7Wjy+9E3WwLOEL30D+m8TSTF7qJJUJLONBnwQp0518siuMxUQUbgZwssaFX+QKlZkjHZcw/IpZCt/H0srrntSg== + dependencies: + ts-pnp "^1.1.6" + +portfinder@^1.0.25: + version "1.0.25" + resolved "https://registry.yarnpkg.com/portfinder/-/portfinder-1.0.25.tgz#254fd337ffba869f4b9d37edc298059cb4d35eca" + integrity sha512-6ElJnHBbxVA1XSLgBp7G1FiCkQdlqGzuF7DswL5tcea+E8UpuvPU7beVAjjRwCioTS9ZluNbu+ZyRvgTsmqEBg== + dependencies: + async "^2.6.2" + debug "^3.1.1" + mkdirp "^0.5.1" + +posix-character-classes@^0.1.0: + version "0.1.1" + resolved "https://registry.yarnpkg.com/posix-character-classes/-/posix-character-classes-0.1.1.tgz#01eac0fe3b5af71a2a6c02feabb8c1fef7e00eab" + integrity sha1-AerA/jta9xoqbAL+q7jB/vfgDqs= + +postcss-attribute-case-insensitive@^4.0.1: + version "4.0.2" + resolved "https://registry.yarnpkg.com/postcss-attribute-case-insensitive/-/postcss-attribute-case-insensitive-4.0.2.tgz#d93e46b504589e94ac7277b0463226c68041a880" + integrity sha512-clkFxk/9pcdb4Vkn0hAHq3YnxBQ2p0CGD1dy24jN+reBck+EWxMbxSUqN4Yj7t0w8csl87K6p0gxBe1utkJsYA== + dependencies: + postcss "^7.0.2" + postcss-selector-parser "^6.0.2" + +postcss-browser-comments@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/postcss-browser-comments/-/postcss-browser-comments-3.0.0.tgz#1248d2d935fb72053c8e1f61a84a57292d9f65e9" + integrity sha512-qfVjLfq7HFd2e0HW4s1dvU8X080OZdG46fFbIBFjW7US7YPDcWfRvdElvwMJr2LI6hMmD+7LnH2HcmXTs+uOig== + dependencies: + postcss "^7" + +postcss-calc@^7.0.1: + version "7.0.2" + resolved "https://registry.yarnpkg.com/postcss-calc/-/postcss-calc-7.0.2.tgz#504efcd008ca0273120568b0792b16cdcde8aac1" + integrity sha512-rofZFHUg6ZIrvRwPeFktv06GdbDYLcGqh9EwiMutZg+a0oePCCw1zHOEiji6LCpyRcjTREtPASuUqeAvYlEVvQ== + dependencies: + postcss "^7.0.27" + postcss-selector-parser "^6.0.2" + postcss-value-parser "^4.0.2" + +postcss-color-functional-notation@^2.0.1: + version "2.0.1" + resolved "https://registry.yarnpkg.com/postcss-color-functional-notation/-/postcss-color-functional-notation-2.0.1.tgz#5efd37a88fbabeb00a2966d1e53d98ced93f74e0" + integrity sha512-ZBARCypjEDofW4P6IdPVTLhDNXPRn8T2s1zHbZidW6rPaaZvcnCS2soYFIQJrMZSxiePJ2XIYTlcb2ztr/eT2g== + dependencies: + postcss "^7.0.2" + postcss-values-parser "^2.0.0" + +postcss-color-gray@^5.0.0: + version "5.0.0" + resolved "https://registry.yarnpkg.com/postcss-color-gray/-/postcss-color-gray-5.0.0.tgz#532a31eb909f8da898ceffe296fdc1f864be8547" + integrity sha512-q6BuRnAGKM/ZRpfDascZlIZPjvwsRye7UDNalqVz3s7GDxMtqPY6+Q871liNxsonUw8oC61OG+PSaysYpl1bnw== + dependencies: + "@csstools/convert-colors" "^1.4.0" + postcss "^7.0.5" + postcss-values-parser "^2.0.0" + +postcss-color-hex-alpha@^5.0.3: + version "5.0.3" + resolved "https://registry.yarnpkg.com/postcss-color-hex-alpha/-/postcss-color-hex-alpha-5.0.3.tgz#a8d9ca4c39d497c9661e374b9c51899ef0f87388" + integrity sha512-PF4GDel8q3kkreVXKLAGNpHKilXsZ6xuu+mOQMHWHLPNyjiUBOr75sp5ZKJfmv1MCus5/DWUGcK9hm6qHEnXYw== + dependencies: + postcss "^7.0.14" + postcss-values-parser "^2.0.1" + +postcss-color-mod-function@^3.0.3: + version "3.0.3" + resolved "https://registry.yarnpkg.com/postcss-color-mod-function/-/postcss-color-mod-function-3.0.3.tgz#816ba145ac11cc3cb6baa905a75a49f903e4d31d" + integrity sha512-YP4VG+xufxaVtzV6ZmhEtc+/aTXH3d0JLpnYfxqTvwZPbJhWqp8bSY3nfNzNRFLgB4XSaBA82OE4VjOOKpCdVQ== + dependencies: + "@csstools/convert-colors" "^1.4.0" + postcss "^7.0.2" + postcss-values-parser "^2.0.0" + 
+postcss-color-rebeccapurple@^4.0.1: + version "4.0.1" + resolved "https://registry.yarnpkg.com/postcss-color-rebeccapurple/-/postcss-color-rebeccapurple-4.0.1.tgz#c7a89be872bb74e45b1e3022bfe5748823e6de77" + integrity sha512-aAe3OhkS6qJXBbqzvZth2Au4V3KieR5sRQ4ptb2b2O8wgvB3SJBsdG+jsn2BZbbwekDG8nTfcCNKcSfe/lEy8g== + dependencies: + postcss "^7.0.2" + postcss-values-parser "^2.0.0" + +postcss-colormin@^4.0.3: + version "4.0.3" + resolved "https://registry.yarnpkg.com/postcss-colormin/-/postcss-colormin-4.0.3.tgz#ae060bce93ed794ac71264f08132d550956bd381" + integrity sha512-WyQFAdDZpExQh32j0U0feWisZ0dmOtPl44qYmJKkq9xFWY3p+4qnRzCHeNrkeRhwPHz9bQ3mo0/yVkaply0MNw== + dependencies: + browserslist "^4.0.0" + color "^3.0.0" + has "^1.0.0" + postcss "^7.0.0" + postcss-value-parser "^3.0.0" + +postcss-convert-values@^4.0.1: + version "4.0.1" + resolved "https://registry.yarnpkg.com/postcss-convert-values/-/postcss-convert-values-4.0.1.tgz#ca3813ed4da0f812f9d43703584e449ebe189a7f" + integrity sha512-Kisdo1y77KUC0Jmn0OXU/COOJbzM8cImvw1ZFsBgBgMgb1iL23Zs/LXRe3r+EZqM3vGYKdQ2YJVQ5VkJI+zEJQ== + dependencies: + postcss "^7.0.0" + postcss-value-parser "^3.0.0" + +postcss-custom-media@^7.0.8: + version "7.0.8" + resolved "https://registry.yarnpkg.com/postcss-custom-media/-/postcss-custom-media-7.0.8.tgz#fffd13ffeffad73621be5f387076a28b00294e0c" + integrity sha512-c9s5iX0Ge15o00HKbuRuTqNndsJUbaXdiNsksnVH8H4gdc+zbLzr/UasOwNG6CTDpLFekVY4672eWdiiWu2GUg== + dependencies: + postcss "^7.0.14" + +postcss-custom-properties@^8.0.11: + version "8.0.11" + resolved "https://registry.yarnpkg.com/postcss-custom-properties/-/postcss-custom-properties-8.0.11.tgz#2d61772d6e92f22f5e0d52602df8fae46fa30d97" + integrity sha512-nm+o0eLdYqdnJ5abAJeXp4CEU1c1k+eB2yMCvhgzsds/e0umabFrN6HoTy/8Q4K5ilxERdl/JD1LO5ANoYBeMA== + dependencies: + postcss "^7.0.17" + postcss-values-parser "^2.0.1" + +postcss-custom-selectors@^5.1.2: + version "5.1.2" + resolved "https://registry.yarnpkg.com/postcss-custom-selectors/-/postcss-custom-selectors-5.1.2.tgz#64858c6eb2ecff2fb41d0b28c9dd7b3db4de7fba" + integrity sha512-DSGDhqinCqXqlS4R7KGxL1OSycd1lydugJ1ky4iRXPHdBRiozyMHrdu0H3o7qNOCiZwySZTUI5MV0T8QhCLu+w== + dependencies: + postcss "^7.0.2" + postcss-selector-parser "^5.0.0-rc.3" + +postcss-dir-pseudo-class@^5.0.0: + version "5.0.0" + resolved "https://registry.yarnpkg.com/postcss-dir-pseudo-class/-/postcss-dir-pseudo-class-5.0.0.tgz#6e3a4177d0edb3abcc85fdb6fbb1c26dabaeaba2" + integrity sha512-3pm4oq8HYWMZePJY+5ANriPs3P07q+LW6FAdTlkFH2XqDdP4HeeJYMOzn0HYLhRSjBO3fhiqSwwU9xEULSrPgw== + dependencies: + postcss "^7.0.2" + postcss-selector-parser "^5.0.0-rc.3" + +postcss-discard-comments@^4.0.2: + version "4.0.2" + resolved "https://registry.yarnpkg.com/postcss-discard-comments/-/postcss-discard-comments-4.0.2.tgz#1fbabd2c246bff6aaad7997b2b0918f4d7af4033" + integrity sha512-RJutN259iuRf3IW7GZyLM5Sw4GLTOH8FmsXBnv8Ab/Tc2k4SR4qbV4DNbyyY4+Sjo362SyDmW2DQ7lBSChrpkg== + dependencies: + postcss "^7.0.0" + +postcss-discard-duplicates@^4.0.2: + version "4.0.2" + resolved "https://registry.yarnpkg.com/postcss-discard-duplicates/-/postcss-discard-duplicates-4.0.2.tgz#3fe133cd3c82282e550fc9b239176a9207b784eb" + integrity sha512-ZNQfR1gPNAiXZhgENFfEglF93pciw0WxMkJeVmw8eF+JZBbMD7jp6C67GqJAXVZP2BWbOztKfbsdmMp/k8c6oQ== + dependencies: + postcss "^7.0.0" + +postcss-discard-empty@^4.0.1: + version "4.0.1" + resolved "https://registry.yarnpkg.com/postcss-discard-empty/-/postcss-discard-empty-4.0.1.tgz#c8c951e9f73ed9428019458444a02ad90bb9f765" + integrity 
sha512-B9miTzbznhDjTfjvipfHoqbWKwd0Mj+/fL5s1QOz06wufguil+Xheo4XpOnc4NqKYBCNqqEzgPv2aPBIJLox0w== + dependencies: + postcss "^7.0.0" + +postcss-discard-overridden@^4.0.1: + version "4.0.1" + resolved "https://registry.yarnpkg.com/postcss-discard-overridden/-/postcss-discard-overridden-4.0.1.tgz#652aef8a96726f029f5e3e00146ee7a4e755ff57" + integrity sha512-IYY2bEDD7g1XM1IDEsUT4//iEYCxAmP5oDSFMVU/JVvT7gh+l4fmjciLqGgwjdWpQIdb0Che2VX00QObS5+cTg== + dependencies: + postcss "^7.0.0" + +postcss-double-position-gradients@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/postcss-double-position-gradients/-/postcss-double-position-gradients-1.0.0.tgz#fc927d52fddc896cb3a2812ebc5df147e110522e" + integrity sha512-G+nV8EnQq25fOI8CH/B6krEohGWnF5+3A6H/+JEpOncu5dCnkS1QQ6+ct3Jkaepw1NGVqqOZH6lqrm244mCftA== + dependencies: + postcss "^7.0.5" + postcss-values-parser "^2.0.0" + +postcss-env-function@^2.0.2: + version "2.0.2" + resolved "https://registry.yarnpkg.com/postcss-env-function/-/postcss-env-function-2.0.2.tgz#0f3e3d3c57f094a92c2baf4b6241f0b0da5365d7" + integrity sha512-rwac4BuZlITeUbiBq60h/xbLzXY43qOsIErngWa4l7Mt+RaSkT7QBjXVGTcBHupykkblHMDrBFh30zchYPaOUw== + dependencies: + postcss "^7.0.2" + postcss-values-parser "^2.0.0" + +postcss-flexbugs-fixes@4.1.0: + version "4.1.0" + resolved "https://registry.yarnpkg.com/postcss-flexbugs-fixes/-/postcss-flexbugs-fixes-4.1.0.tgz#e094a9df1783e2200b7b19f875dcad3b3aff8b20" + integrity sha512-jr1LHxQvStNNAHlgco6PzY308zvLklh7SJVYuWUwyUQncofaAlD2l+P/gxKHOdqWKe7xJSkVLFF/2Tp+JqMSZA== + dependencies: + postcss "^7.0.0" + +postcss-focus-visible@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/postcss-focus-visible/-/postcss-focus-visible-4.0.0.tgz#477d107113ade6024b14128317ade2bd1e17046e" + integrity sha512-Z5CkWBw0+idJHSV6+Bgf2peDOFf/x4o+vX/pwcNYrWpXFrSfTkQ3JQ1ojrq9yS+upnAlNRHeg8uEwFTgorjI8g== + dependencies: + postcss "^7.0.2" + +postcss-focus-within@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/postcss-focus-within/-/postcss-focus-within-3.0.0.tgz#763b8788596cee9b874c999201cdde80659ef680" + integrity sha512-W0APui8jQeBKbCGZudW37EeMCjDeVxKgiYfIIEo8Bdh5SpB9sxds/Iq8SEuzS0Q4YFOlG7EPFulbbxujpkrV2w== + dependencies: + postcss "^7.0.2" + +postcss-font-variant@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/postcss-font-variant/-/postcss-font-variant-4.0.0.tgz#71dd3c6c10a0d846c5eda07803439617bbbabacc" + integrity sha512-M8BFYKOvCrI2aITzDad7kWuXXTm0YhGdP9Q8HanmN4EF1Hmcgs1KK5rSHylt/lUJe8yLxiSwWAHdScoEiIxztg== + dependencies: + postcss "^7.0.2" + +postcss-gap-properties@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/postcss-gap-properties/-/postcss-gap-properties-2.0.0.tgz#431c192ab3ed96a3c3d09f2ff615960f902c1715" + integrity sha512-QZSqDaMgXCHuHTEzMsS2KfVDOq7ZFiknSpkrPJY6jmxbugUPTuSzs/vuE5I3zv0WAS+3vhrlqhijiprnuQfzmg== + dependencies: + postcss "^7.0.2" + +postcss-image-set-function@^3.0.1: + version "3.0.1" + resolved "https://registry.yarnpkg.com/postcss-image-set-function/-/postcss-image-set-function-3.0.1.tgz#28920a2f29945bed4c3198d7df6496d410d3f288" + integrity sha512-oPTcFFip5LZy8Y/whto91L9xdRHCWEMs3e1MdJxhgt4jy2WYXfhkng59fH5qLXSCPN8k4n94p1Czrfe5IOkKUw== + dependencies: + postcss "^7.0.2" + postcss-values-parser "^2.0.0" + +postcss-initial@^3.0.0: + version "3.0.2" + resolved "https://registry.yarnpkg.com/postcss-initial/-/postcss-initial-3.0.2.tgz#f018563694b3c16ae8eaabe3c585ac6319637b2d" + integrity 
sha512-ugA2wKonC0xeNHgirR4D3VWHs2JcU08WAi1KFLVcnb7IN89phID6Qtg2RIctWbnvp1TM2BOmDtX8GGLCKdR8YA== + dependencies: + lodash.template "^4.5.0" + postcss "^7.0.2" + +postcss-lab-function@^2.0.1: + version "2.0.1" + resolved "https://registry.yarnpkg.com/postcss-lab-function/-/postcss-lab-function-2.0.1.tgz#bb51a6856cd12289ab4ae20db1e3821ef13d7d2e" + integrity sha512-whLy1IeZKY+3fYdqQFuDBf8Auw+qFuVnChWjmxm/UhHWqNHZx+B99EwxTvGYmUBqe3Fjxs4L1BoZTJmPu6usVg== + dependencies: + "@csstools/convert-colors" "^1.4.0" + postcss "^7.0.2" + postcss-values-parser "^2.0.0" + +postcss-load-config@^2.0.0: + version "2.1.0" + resolved "https://registry.yarnpkg.com/postcss-load-config/-/postcss-load-config-2.1.0.tgz#c84d692b7bb7b41ddced94ee62e8ab31b417b003" + integrity sha512-4pV3JJVPLd5+RueiVVB+gFOAa7GWc25XQcMp86Zexzke69mKf6Nx9LRcQywdz7yZI9n1udOxmLuAwTBypypF8Q== + dependencies: + cosmiconfig "^5.0.0" + import-cwd "^2.0.0" + +postcss-loader@3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/postcss-loader/-/postcss-loader-3.0.0.tgz#6b97943e47c72d845fa9e03f273773d4e8dd6c2d" + integrity sha512-cLWoDEY5OwHcAjDnkyRQzAXfs2jrKjXpO/HQFcc5b5u/r7aa471wdmChmwfnv7x2u840iat/wi0lQ5nbRgSkUA== + dependencies: + loader-utils "^1.1.0" + postcss "^7.0.0" + postcss-load-config "^2.0.0" + schema-utils "^1.0.0" + +postcss-logical@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/postcss-logical/-/postcss-logical-3.0.0.tgz#2495d0f8b82e9f262725f75f9401b34e7b45d5b5" + integrity sha512-1SUKdJc2vuMOmeItqGuNaC+N8MzBWFWEkAnRnLpFYj1tGGa7NqyVBujfRtgNa2gXR+6RkGUiB2O5Vmh7E2RmiA== + dependencies: + postcss "^7.0.2" + +postcss-media-minmax@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/postcss-media-minmax/-/postcss-media-minmax-4.0.0.tgz#b75bb6cbc217c8ac49433e12f22048814a4f5ed5" + integrity sha512-fo9moya6qyxsjbFAYl97qKO9gyre3qvbMnkOZeZwlsW6XYFsvs2DMGDlchVLfAd8LHPZDxivu/+qW2SMQeTHBw== + dependencies: + postcss "^7.0.2" + +postcss-merge-longhand@^4.0.11: + version "4.0.11" + resolved "https://registry.yarnpkg.com/postcss-merge-longhand/-/postcss-merge-longhand-4.0.11.tgz#62f49a13e4a0ee04e7b98f42bb16062ca2549e24" + integrity sha512-alx/zmoeXvJjp7L4mxEMjh8lxVlDFX1gqWHzaaQewwMZiVhLo42TEClKaeHbRf6J7j82ZOdTJ808RtN0ZOZwvw== + dependencies: + css-color-names "0.0.4" + postcss "^7.0.0" + postcss-value-parser "^3.0.0" + stylehacks "^4.0.0" + +postcss-merge-rules@^4.0.3: + version "4.0.3" + resolved "https://registry.yarnpkg.com/postcss-merge-rules/-/postcss-merge-rules-4.0.3.tgz#362bea4ff5a1f98e4075a713c6cb25aefef9a650" + integrity sha512-U7e3r1SbvYzO0Jr3UT/zKBVgYYyhAz0aitvGIYOYK5CPmkNih+WDSsS5tvPrJ8YMQYlEMvsZIiqmn7HdFUaeEQ== + dependencies: + browserslist "^4.0.0" + caniuse-api "^3.0.0" + cssnano-util-same-parent "^4.0.0" + postcss "^7.0.0" + postcss-selector-parser "^3.0.0" + vendors "^1.0.0" + +postcss-minify-font-values@^4.0.2: + version "4.0.2" + resolved "https://registry.yarnpkg.com/postcss-minify-font-values/-/postcss-minify-font-values-4.0.2.tgz#cd4c344cce474343fac5d82206ab2cbcb8afd5a6" + integrity sha512-j85oO6OnRU9zPf04+PZv1LYIYOprWm6IA6zkXkrJXyRveDEuQggG6tvoy8ir8ZwjLxLuGfNkCZEQG7zan+Hbtg== + dependencies: + postcss "^7.0.0" + postcss-value-parser "^3.0.0" + +postcss-minify-gradients@^4.0.2: + version "4.0.2" + resolved "https://registry.yarnpkg.com/postcss-minify-gradients/-/postcss-minify-gradients-4.0.2.tgz#93b29c2ff5099c535eecda56c4aa6e665a663471" + integrity sha512-qKPfwlONdcf/AndP1U8SJ/uzIJtowHlMaSioKzebAXSG4iJthlWC9iSWznQcX4f66gIWX44RSA841HTHj3wK+Q== + dependencies: + 
cssnano-util-get-arguments "^4.0.0" + is-color-stop "^1.0.0" + postcss "^7.0.0" + postcss-value-parser "^3.0.0" + +postcss-minify-params@^4.0.2: + version "4.0.2" + resolved "https://registry.yarnpkg.com/postcss-minify-params/-/postcss-minify-params-4.0.2.tgz#6b9cef030c11e35261f95f618c90036d680db874" + integrity sha512-G7eWyzEx0xL4/wiBBJxJOz48zAKV2WG3iZOqVhPet/9geefm/Px5uo1fzlHu+DOjT+m0Mmiz3jkQzVHe6wxAWg== + dependencies: + alphanum-sort "^1.0.0" + browserslist "^4.0.0" + cssnano-util-get-arguments "^4.0.0" + postcss "^7.0.0" + postcss-value-parser "^3.0.0" + uniqs "^2.0.0" + +postcss-minify-selectors@^4.0.2: + version "4.0.2" + resolved "https://registry.yarnpkg.com/postcss-minify-selectors/-/postcss-minify-selectors-4.0.2.tgz#e2e5eb40bfee500d0cd9243500f5f8ea4262fbd8" + integrity sha512-D5S1iViljXBj9kflQo4YutWnJmwm8VvIsU1GeXJGiG9j8CIg9zs4voPMdQDUmIxetUOh60VilsNzCiAFTOqu3g== + dependencies: + alphanum-sort "^1.0.0" + has "^1.0.0" + postcss "^7.0.0" + postcss-selector-parser "^3.0.0" + +postcss-modules-extract-imports@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/postcss-modules-extract-imports/-/postcss-modules-extract-imports-2.0.0.tgz#818719a1ae1da325f9832446b01136eeb493cd7e" + integrity sha512-LaYLDNS4SG8Q5WAWqIJgdHPJrDDr/Lv775rMBFUbgjTz6j34lUznACHcdRWroPvXANP2Vj7yNK57vp9eFqzLWQ== + dependencies: + postcss "^7.0.5" + +postcss-modules-local-by-default@^3.0.2: + version "3.0.2" + resolved "https://registry.yarnpkg.com/postcss-modules-local-by-default/-/postcss-modules-local-by-default-3.0.2.tgz#e8a6561be914aaf3c052876377524ca90dbb7915" + integrity sha512-jM/V8eqM4oJ/22j0gx4jrp63GSvDH6v86OqyTHHUvk4/k1vceipZsaymiZ5PvocqZOl5SFHiFJqjs3la0wnfIQ== + dependencies: + icss-utils "^4.1.1" + postcss "^7.0.16" + postcss-selector-parser "^6.0.2" + postcss-value-parser "^4.0.0" + +postcss-modules-scope@^2.1.1: + version "2.2.0" + resolved "https://registry.yarnpkg.com/postcss-modules-scope/-/postcss-modules-scope-2.2.0.tgz#385cae013cc7743f5a7d7602d1073a89eaae62ee" + integrity sha512-YyEgsTMRpNd+HmyC7H/mh3y+MeFWevy7V1evVhJWewmMbjDHIbZbOXICC2y+m1xI1UVfIT1HMW/O04Hxyu9oXQ== + dependencies: + postcss "^7.0.6" + postcss-selector-parser "^6.0.0" + +postcss-modules-values@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/postcss-modules-values/-/postcss-modules-values-3.0.0.tgz#5b5000d6ebae29b4255301b4a3a54574423e7f10" + integrity sha512-1//E5jCBrZ9DmRX+zCtmQtRSV6PV42Ix7Bzj9GbwJceduuf7IqP8MgeTXuRDHOWj2m0VzZD5+roFWDuU8RQjcg== + dependencies: + icss-utils "^4.0.0" + postcss "^7.0.6" + +postcss-nesting@^7.0.0: + version "7.0.1" + resolved "https://registry.yarnpkg.com/postcss-nesting/-/postcss-nesting-7.0.1.tgz#b50ad7b7f0173e5b5e3880c3501344703e04c052" + integrity sha512-FrorPb0H3nuVq0Sff7W2rnc3SmIcruVC6YwpcS+k687VxyxO33iE1amna7wHuRVzM8vfiYofXSBHNAZ3QhLvYg== + dependencies: + postcss "^7.0.2" + +postcss-normalize-charset@^4.0.1: + version "4.0.1" + resolved "https://registry.yarnpkg.com/postcss-normalize-charset/-/postcss-normalize-charset-4.0.1.tgz#8b35add3aee83a136b0471e0d59be58a50285dd4" + integrity sha512-gMXCrrlWh6G27U0hF3vNvR3w8I1s2wOBILvA87iNXaPvSNo5uZAMYsZG7XjCUf1eVxuPfyL4TJ7++SGZLc9A3g== + dependencies: + postcss "^7.0.0" + +postcss-normalize-display-values@^4.0.2: + version "4.0.2" + resolved "https://registry.yarnpkg.com/postcss-normalize-display-values/-/postcss-normalize-display-values-4.0.2.tgz#0dbe04a4ce9063d4667ed2be476bb830c825935a" + integrity sha512-3F2jcsaMW7+VtRMAqf/3m4cPFhPD3EFRgNs18u+k3lTJJlVe7d0YPO+bnwqo2xg8YiRpDXJI2u8A0wqJxMsQuQ== + dependencies: + 
cssnano-util-get-match "^4.0.0" + postcss "^7.0.0" + postcss-value-parser "^3.0.0" + +postcss-normalize-positions@^4.0.2: + version "4.0.2" + resolved "https://registry.yarnpkg.com/postcss-normalize-positions/-/postcss-normalize-positions-4.0.2.tgz#05f757f84f260437378368a91f8932d4b102917f" + integrity sha512-Dlf3/9AxpxE+NF1fJxYDeggi5WwV35MXGFnnoccP/9qDtFrTArZ0D0R+iKcg5WsUd8nUYMIl8yXDCtcrT8JrdA== + dependencies: + cssnano-util-get-arguments "^4.0.0" + has "^1.0.0" + postcss "^7.0.0" + postcss-value-parser "^3.0.0" + +postcss-normalize-repeat-style@^4.0.2: + version "4.0.2" + resolved "https://registry.yarnpkg.com/postcss-normalize-repeat-style/-/postcss-normalize-repeat-style-4.0.2.tgz#c4ebbc289f3991a028d44751cbdd11918b17910c" + integrity sha512-qvigdYYMpSuoFs3Is/f5nHdRLJN/ITA7huIoCyqqENJe9PvPmLhNLMu7QTjPdtnVf6OcYYO5SHonx4+fbJE1+Q== + dependencies: + cssnano-util-get-arguments "^4.0.0" + cssnano-util-get-match "^4.0.0" + postcss "^7.0.0" + postcss-value-parser "^3.0.0" + +postcss-normalize-string@^4.0.2: + version "4.0.2" + resolved "https://registry.yarnpkg.com/postcss-normalize-string/-/postcss-normalize-string-4.0.2.tgz#cd44c40ab07a0c7a36dc5e99aace1eca4ec2690c" + integrity sha512-RrERod97Dnwqq49WNz8qo66ps0swYZDSb6rM57kN2J+aoyEAJfZ6bMx0sx/F9TIEX0xthPGCmeyiam/jXif0eA== + dependencies: + has "^1.0.0" + postcss "^7.0.0" + postcss-value-parser "^3.0.0" + +postcss-normalize-timing-functions@^4.0.2: + version "4.0.2" + resolved "https://registry.yarnpkg.com/postcss-normalize-timing-functions/-/postcss-normalize-timing-functions-4.0.2.tgz#8e009ca2a3949cdaf8ad23e6b6ab99cb5e7d28d9" + integrity sha512-acwJY95edP762e++00Ehq9L4sZCEcOPyaHwoaFOhIwWCDfik6YvqsYNxckee65JHLKzuNSSmAdxwD2Cud1Z54A== + dependencies: + cssnano-util-get-match "^4.0.0" + postcss "^7.0.0" + postcss-value-parser "^3.0.0" + +postcss-normalize-unicode@^4.0.1: + version "4.0.1" + resolved "https://registry.yarnpkg.com/postcss-normalize-unicode/-/postcss-normalize-unicode-4.0.1.tgz#841bd48fdcf3019ad4baa7493a3d363b52ae1cfb" + integrity sha512-od18Uq2wCYn+vZ/qCOeutvHjB5jm57ToxRaMeNuf0nWVHaP9Hua56QyMF6fs/4FSUnVIw0CBPsU0K4LnBPwYwg== + dependencies: + browserslist "^4.0.0" + postcss "^7.0.0" + postcss-value-parser "^3.0.0" + +postcss-normalize-url@^4.0.1: + version "4.0.1" + resolved "https://registry.yarnpkg.com/postcss-normalize-url/-/postcss-normalize-url-4.0.1.tgz#10e437f86bc7c7e58f7b9652ed878daaa95faae1" + integrity sha512-p5oVaF4+IHwu7VpMan/SSpmpYxcJMtkGppYf0VbdH5B6hN8YNmVyJLuY9FmLQTzY3fag5ESUUHDqM+heid0UVA== + dependencies: + is-absolute-url "^2.0.0" + normalize-url "^3.0.0" + postcss "^7.0.0" + postcss-value-parser "^3.0.0" + +postcss-normalize-whitespace@^4.0.2: + version "4.0.2" + resolved "https://registry.yarnpkg.com/postcss-normalize-whitespace/-/postcss-normalize-whitespace-4.0.2.tgz#bf1d4070fe4fcea87d1348e825d8cc0c5faa7d82" + integrity sha512-tO8QIgrsI3p95r8fyqKV+ufKlSHh9hMJqACqbv2XknufqEDhDvbguXGBBqxw9nsQoXWf0qOqppziKJKHMD4GtA== + dependencies: + postcss "^7.0.0" + postcss-value-parser "^3.0.0" + +postcss-normalize@8.0.1: + version "8.0.1" + resolved "https://registry.yarnpkg.com/postcss-normalize/-/postcss-normalize-8.0.1.tgz#90e80a7763d7fdf2da6f2f0f82be832ce4f66776" + integrity sha512-rt9JMS/m9FHIRroDDBGSMsyW1c0fkvOJPy62ggxSHUldJO7B195TqFMqIf+lY5ezpDcYOV4j86aUp3/XbxzCCQ== + dependencies: + "@csstools/normalize.css" "^10.1.0" + browserslist "^4.6.2" + postcss "^7.0.17" + postcss-browser-comments "^3.0.0" + sanitize.css "^10.0.0" + +postcss-ordered-values@^4.1.2: + version "4.1.2" + resolved 
"https://registry.yarnpkg.com/postcss-ordered-values/-/postcss-ordered-values-4.1.2.tgz#0cf75c820ec7d5c4d280189559e0b571ebac0eee" + integrity sha512-2fCObh5UanxvSxeXrtLtlwVThBvHn6MQcu4ksNT2tsaV2Fg76R2CV98W7wNSlX+5/pFwEyaDwKLLoEV7uRybAw== + dependencies: + cssnano-util-get-arguments "^4.0.0" + postcss "^7.0.0" + postcss-value-parser "^3.0.0" + +postcss-overflow-shorthand@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/postcss-overflow-shorthand/-/postcss-overflow-shorthand-2.0.0.tgz#31ecf350e9c6f6ddc250a78f0c3e111f32dd4c30" + integrity sha512-aK0fHc9CBNx8jbzMYhshZcEv8LtYnBIRYQD5i7w/K/wS9c2+0NSR6B3OVMu5y0hBHYLcMGjfU+dmWYNKH0I85g== + dependencies: + postcss "^7.0.2" + +postcss-page-break@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/postcss-page-break/-/postcss-page-break-2.0.0.tgz#add52d0e0a528cabe6afee8b46e2abb277df46bf" + integrity sha512-tkpTSrLpfLfD9HvgOlJuigLuk39wVTbbd8RKcy8/ugV2bNBUW3xU+AIqyxhDrQr1VUj1RmyJrBn1YWrqUm9zAQ== + dependencies: + postcss "^7.0.2" + +postcss-place@^4.0.1: + version "4.0.1" + resolved "https://registry.yarnpkg.com/postcss-place/-/postcss-place-4.0.1.tgz#e9f39d33d2dc584e46ee1db45adb77ca9d1dcc62" + integrity sha512-Zb6byCSLkgRKLODj/5mQugyuj9bvAAw9LqJJjgwz5cYryGeXfFZfSXoP1UfveccFmeq0b/2xxwcTEVScnqGxBg== + dependencies: + postcss "^7.0.2" + postcss-values-parser "^2.0.0" + +postcss-preset-env@6.7.0: + version "6.7.0" + resolved "https://registry.yarnpkg.com/postcss-preset-env/-/postcss-preset-env-6.7.0.tgz#c34ddacf8f902383b35ad1e030f178f4cdf118a5" + integrity sha512-eU4/K5xzSFwUFJ8hTdTQzo2RBLbDVt83QZrAvI07TULOkmyQlnYlpwep+2yIK+K+0KlZO4BvFcleOCCcUtwchg== + dependencies: + autoprefixer "^9.6.1" + browserslist "^4.6.4" + caniuse-lite "^1.0.30000981" + css-blank-pseudo "^0.1.4" + css-has-pseudo "^0.10.0" + css-prefers-color-scheme "^3.1.1" + cssdb "^4.4.0" + postcss "^7.0.17" + postcss-attribute-case-insensitive "^4.0.1" + postcss-color-functional-notation "^2.0.1" + postcss-color-gray "^5.0.0" + postcss-color-hex-alpha "^5.0.3" + postcss-color-mod-function "^3.0.3" + postcss-color-rebeccapurple "^4.0.1" + postcss-custom-media "^7.0.8" + postcss-custom-properties "^8.0.11" + postcss-custom-selectors "^5.1.2" + postcss-dir-pseudo-class "^5.0.0" + postcss-double-position-gradients "^1.0.0" + postcss-env-function "^2.0.2" + postcss-focus-visible "^4.0.0" + postcss-focus-within "^3.0.0" + postcss-font-variant "^4.0.0" + postcss-gap-properties "^2.0.0" + postcss-image-set-function "^3.0.1" + postcss-initial "^3.0.0" + postcss-lab-function "^2.0.1" + postcss-logical "^3.0.0" + postcss-media-minmax "^4.0.0" + postcss-nesting "^7.0.0" + postcss-overflow-shorthand "^2.0.0" + postcss-page-break "^2.0.0" + postcss-place "^4.0.1" + postcss-pseudo-class-any-link "^6.0.0" + postcss-replace-overflow-wrap "^3.0.0" + postcss-selector-matches "^4.0.0" + postcss-selector-not "^4.0.0" + +postcss-pseudo-class-any-link@^6.0.0: + version "6.0.0" + resolved "https://registry.yarnpkg.com/postcss-pseudo-class-any-link/-/postcss-pseudo-class-any-link-6.0.0.tgz#2ed3eed393b3702879dec4a87032b210daeb04d1" + integrity sha512-lgXW9sYJdLqtmw23otOzrtbDXofUdfYzNm4PIpNE322/swES3VU9XlXHeJS46zT2onFO7V1QFdD4Q9LiZj8mew== + dependencies: + postcss "^7.0.2" + postcss-selector-parser "^5.0.0-rc.3" + +postcss-reduce-initial@^4.0.3: + version "4.0.3" + resolved "https://registry.yarnpkg.com/postcss-reduce-initial/-/postcss-reduce-initial-4.0.3.tgz#7fd42ebea5e9c814609639e2c2e84ae270ba48df" + integrity 
sha512-gKWmR5aUulSjbzOfD9AlJiHCGH6AEVLaM0AV+aSioxUDd16qXP1PCh8d1/BGVvpdWn8k/HiK7n6TjeoXN1F7DA== + dependencies: + browserslist "^4.0.0" + caniuse-api "^3.0.0" + has "^1.0.0" + postcss "^7.0.0" + +postcss-reduce-transforms@^4.0.2: + version "4.0.2" + resolved "https://registry.yarnpkg.com/postcss-reduce-transforms/-/postcss-reduce-transforms-4.0.2.tgz#17efa405eacc6e07be3414a5ca2d1074681d4e29" + integrity sha512-EEVig1Q2QJ4ELpJXMZR8Vt5DQx8/mo+dGWSR7vWXqcob2gQLyQGsionYcGKATXvQzMPn6DSN1vTN7yFximdIAg== + dependencies: + cssnano-util-get-match "^4.0.0" + has "^1.0.0" + postcss "^7.0.0" + postcss-value-parser "^3.0.0" + +postcss-replace-overflow-wrap@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/postcss-replace-overflow-wrap/-/postcss-replace-overflow-wrap-3.0.0.tgz#61b360ffdaedca84c7c918d2b0f0d0ea559ab01c" + integrity sha512-2T5hcEHArDT6X9+9dVSPQdo7QHzG4XKclFT8rU5TzJPDN7RIRTbO9c4drUISOVemLj03aezStHCR2AIcr8XLpw== + dependencies: + postcss "^7.0.2" + +postcss-safe-parser@4.0.1: + version "4.0.1" + resolved "https://registry.yarnpkg.com/postcss-safe-parser/-/postcss-safe-parser-4.0.1.tgz#8756d9e4c36fdce2c72b091bbc8ca176ab1fcdea" + integrity sha512-xZsFA3uX8MO3yAda03QrG3/Eg1LN3EPfjjf07vke/46HERLZyHrTsQ9E1r1w1W//fWEhtYNndo2hQplN2cVpCQ== + dependencies: + postcss "^7.0.0" + +postcss-selector-matches@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/postcss-selector-matches/-/postcss-selector-matches-4.0.0.tgz#71c8248f917ba2cc93037c9637ee09c64436fcff" + integrity sha512-LgsHwQR/EsRYSqlwdGzeaPKVT0Ml7LAT6E75T8W8xLJY62CE4S/l03BWIt3jT8Taq22kXP08s2SfTSzaraoPww== + dependencies: + balanced-match "^1.0.0" + postcss "^7.0.2" + +postcss-selector-not@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/postcss-selector-not/-/postcss-selector-not-4.0.0.tgz#c68ff7ba96527499e832724a2674d65603b645c0" + integrity sha512-W+bkBZRhqJaYN8XAnbbZPLWMvZD1wKTu0UxtFKdhtGjWYmxhkUneoeOhRJKdAE5V7ZTlnbHfCR+6bNwK9e1dTQ== + dependencies: + balanced-match "^1.0.0" + postcss "^7.0.2" + +postcss-selector-parser@^3.0.0: + version "3.1.2" + resolved "https://registry.yarnpkg.com/postcss-selector-parser/-/postcss-selector-parser-3.1.2.tgz#b310f5c4c0fdaf76f94902bbaa30db6aa84f5270" + integrity sha512-h7fJ/5uWuRVyOtkO45pnt1Ih40CEleeyCHzipqAZO2e5H20g25Y48uYnFUiShvY4rZWNJ/Bib/KVPmanaCtOhA== + dependencies: + dot-prop "^5.2.0" + indexes-of "^1.0.1" + uniq "^1.0.1" + +postcss-selector-parser@^5.0.0-rc.3, postcss-selector-parser@^5.0.0-rc.4: + version "5.0.0" + resolved "https://registry.yarnpkg.com/postcss-selector-parser/-/postcss-selector-parser-5.0.0.tgz#249044356697b33b64f1a8f7c80922dddee7195c" + integrity sha512-w+zLE5Jhg6Liz8+rQOWEAwtwkyqpfnmsinXjXg6cY7YIONZZtgvE0v2O0uhQBs0peNomOJwWRKt6JBfTdTd3OQ== + dependencies: + cssesc "^2.0.0" + indexes-of "^1.0.1" + uniq "^1.0.1" + +postcss-selector-parser@^6.0.0, postcss-selector-parser@^6.0.2: + version "6.0.2" + resolved "https://registry.yarnpkg.com/postcss-selector-parser/-/postcss-selector-parser-6.0.2.tgz#934cf799d016c83411859e09dcecade01286ec5c" + integrity sha512-36P2QR59jDTOAiIkqEprfJDsoNrvwFei3eCqKd1Y0tUsBimsq39BLp7RD+JWny3WgB1zGhJX8XVePwm9k4wdBg== + dependencies: + cssesc "^3.0.0" + indexes-of "^1.0.1" + uniq "^1.0.1" + +postcss-svgo@^4.0.2: + version "4.0.2" + resolved "https://registry.yarnpkg.com/postcss-svgo/-/postcss-svgo-4.0.2.tgz#17b997bc711b333bab143aaed3b8d3d6e3d38258" + integrity sha512-C6wyjo3VwFm0QgBy+Fu7gCYOkCmgmClghO+pjcxvrcBKtiKt0uCF+hvbMO1fyv5BMImRK90SMb+dwUnfbGd+jw== + dependencies: + is-svg "^3.0.0" + postcss "^7.0.0" 
+ postcss-value-parser "^3.0.0" + svgo "^1.0.0" + +postcss-unique-selectors@^4.0.1: + version "4.0.1" + resolved "https://registry.yarnpkg.com/postcss-unique-selectors/-/postcss-unique-selectors-4.0.1.tgz#9446911f3289bfd64c6d680f073c03b1f9ee4bac" + integrity sha512-+JanVaryLo9QwZjKrmJgkI4Fn8SBgRO6WXQBJi7KiAVPlmxikB5Jzc4EvXMT2H0/m0RjrVVm9rGNhZddm/8Spg== + dependencies: + alphanum-sort "^1.0.0" + postcss "^7.0.0" + uniqs "^2.0.0" + +postcss-value-parser@^3.0.0: + version "3.3.1" + resolved "https://registry.yarnpkg.com/postcss-value-parser/-/postcss-value-parser-3.3.1.tgz#9ff822547e2893213cf1c30efa51ac5fd1ba8281" + integrity sha512-pISE66AbVkp4fDQ7VHBwRNXzAAKJjw4Vw7nWI/+Q3vuly7SNfgYXvm6i5IgFylHGK5sP/xHAbB7N49OS4gWNyQ== + +postcss-value-parser@^4.0.0, postcss-value-parser@^4.0.2, postcss-value-parser@^4.0.3: + version "4.0.3" + resolved "https://registry.yarnpkg.com/postcss-value-parser/-/postcss-value-parser-4.0.3.tgz#651ff4593aa9eda8d5d0d66593a2417aeaeb325d" + integrity sha512-N7h4pG+Nnu5BEIzyeaaIYWs0LI5XC40OrRh5L60z0QjFsqGWcHcbkBvpe1WYpcIS9yQ8sOi/vIPt1ejQCrMVrg== + +postcss-values-parser@^2.0.0, postcss-values-parser@^2.0.1: + version "2.0.1" + resolved "https://registry.yarnpkg.com/postcss-values-parser/-/postcss-values-parser-2.0.1.tgz#da8b472d901da1e205b47bdc98637b9e9e550e5f" + integrity sha512-2tLuBsA6P4rYTNKCXYG/71C7j1pU6pK503suYOmn4xYrQIzW+opD+7FAFNuGSdZC/3Qfy334QbeMu7MEb8gOxg== + dependencies: + flatten "^1.0.2" + indexes-of "^1.0.1" + uniq "^1.0.1" + +postcss@7.0.21: + version "7.0.21" + resolved "https://registry.yarnpkg.com/postcss/-/postcss-7.0.21.tgz#06bb07824c19c2021c5d056d5b10c35b989f7e17" + integrity sha512-uIFtJElxJo29QC753JzhidoAhvp/e/Exezkdhfmt8AymWT6/5B7W1WmponYWkHk2eg6sONyTch0A3nkMPun3SQ== + dependencies: + chalk "^2.4.2" + source-map "^0.6.1" + supports-color "^6.1.0" + +postcss@^7, postcss@^7.0.0, postcss@^7.0.1, postcss@^7.0.14, postcss@^7.0.16, postcss@^7.0.17, postcss@^7.0.2, postcss@^7.0.23, postcss@^7.0.27, postcss@^7.0.5, postcss@^7.0.6: + version "7.0.27" + resolved "https://registry.yarnpkg.com/postcss/-/postcss-7.0.27.tgz#cc67cdc6b0daa375105b7c424a85567345fc54d9" + integrity sha512-WuQETPMcW9Uf1/22HWUWP9lgsIC+KEHg2kozMflKjbeUtw9ujvFX6QmIfozaErDkmLWS9WEnEdEe6Uo9/BNTdQ== + dependencies: + chalk "^2.4.2" + source-map "^0.6.1" + supports-color "^6.1.0" + +prelude-ls@~1.1.2: + version "1.1.2" + resolved "https://registry.yarnpkg.com/prelude-ls/-/prelude-ls-1.1.2.tgz#21932a549f5e52ffd9a827f570e04be62a97da54" + integrity sha1-IZMqVJ9eUv/ZqCf1cOBL5iqX2lQ= + +prepend-http@^1.0.0: + version "1.0.4" + resolved "https://registry.yarnpkg.com/prepend-http/-/prepend-http-1.0.4.tgz#d4f4562b0ce3696e41ac52d0e002e57a635dc6dc" + integrity sha1-1PRWKwzjaW5BrFLQ4ALlemNdxtw= + +pretty-bytes@^5.1.0: + version "5.3.0" + resolved "https://registry.yarnpkg.com/pretty-bytes/-/pretty-bytes-5.3.0.tgz#f2849e27db79fb4d6cfe24764fc4134f165989f2" + integrity sha512-hjGrh+P926p4R4WbaB6OckyRtO0F0/lQBiT+0gnxjV+5kjPBrfVBFCsCLbMqVQeydvIoouYTCmmEURiH3R1Bdg== + +pretty-error@^2.1.1: + version "2.1.1" + resolved "https://registry.yarnpkg.com/pretty-error/-/pretty-error-2.1.1.tgz#5f4f87c8f91e5ae3f3ba87ab4cf5e03b1a17f1a3" + integrity sha1-X0+HyPkeWuPzuoerTPXgOxoX8aM= + dependencies: + renderkid "^2.0.1" + utila "~0.4" + +pretty-format@^24.9.0: + version "24.9.0" + resolved "https://registry.yarnpkg.com/pretty-format/-/pretty-format-24.9.0.tgz#12fac31b37019a4eea3c11aa9a959eb7628aa7c9" + integrity sha512-00ZMZUiHaJrNfk33guavqgvfJS30sLYf0f8+Srklv0AMPodGGHcoHgksZ3OThYnIvOd+8yMCn0YiEOogjlgsnA== + 
dependencies: + "@jest/types" "^24.9.0" + ansi-regex "^4.0.0" + ansi-styles "^3.2.0" + react-is "^16.8.4" + +pretty-format@^25.2.1, pretty-format@^25.5.0: + version "25.5.0" + resolved "https://registry.yarnpkg.com/pretty-format/-/pretty-format-25.5.0.tgz#7873c1d774f682c34b8d48b6743a2bf2ac55791a" + integrity sha512-kbo/kq2LQ/A/is0PQwsEHM7Ca6//bGPPvU6UnsdDRSKTWxT/ru/xb88v4BJf6a69H+uTytOEsTusT9ksd/1iWQ== + dependencies: + "@jest/types" "^25.5.0" + ansi-regex "^5.0.0" + ansi-styles "^4.0.0" + react-is "^16.12.0" + +private@^0.1.8: + version "0.1.8" + resolved "https://registry.yarnpkg.com/private/-/private-0.1.8.tgz#2381edb3689f7a53d653190060fcf822d2f368ff" + integrity sha512-VvivMrbvd2nKkiG38qjULzlc+4Vx4wm/whI9pQD35YrARNnhxeiRktSOhSukRLFNlzg6Br/cJPet5J/u19r/mg== + +process-nextick-args@~2.0.0: + version "2.0.1" + resolved "https://registry.yarnpkg.com/process-nextick-args/-/process-nextick-args-2.0.1.tgz#7820d9b16120cc55ca9ae7792680ae7dba6d7fe2" + integrity sha512-3ouUOpQhtgrbOa17J7+uxOTpITYWaGP7/AhoR3+A+/1e9skrzelGi/dXzEYyvbxubEF6Wn2ypscTKiKJFFn1ag== + +process@^0.11.10: + version "0.11.10" + resolved "https://registry.yarnpkg.com/process/-/process-0.11.10.tgz#7332300e840161bda3e69a1d1d91a7d4bc16f182" + integrity sha1-czIwDoQBYb2j5podHZGn1LwW8YI= + +progress@^2.0.0: + version "2.0.3" + resolved "https://registry.yarnpkg.com/progress/-/progress-2.0.3.tgz#7e8cf8d8f5b8f239c1bc68beb4eb78567d572ef8" + integrity sha512-7PiHtLll5LdnKIMw100I+8xJXR5gW2QwWYkT6iJva0bXitZKa/XMrSbdmg3r2Xnaidz9Qumd0VPaMrZlF9V9sA== + +promise-inflight@^1.0.1: + version "1.0.1" + resolved "https://registry.yarnpkg.com/promise-inflight/-/promise-inflight-1.0.1.tgz#98472870bf228132fcbdd868129bad12c3c029e3" + integrity sha1-mEcocL8igTL8vdhoEputEsPAKeM= + +promise@^8.0.3: + version "8.1.0" + resolved "https://registry.yarnpkg.com/promise/-/promise-8.1.0.tgz#697c25c3dfe7435dd79fcd58c38a135888eaf05e" + integrity sha512-W04AqnILOL/sPRXziNicCjSNRruLAuIHEOVBazepu0545DDNGYHz7ar9ZgZ1fMU8/MA4mVxp5rkBWRi6OXIy3Q== + dependencies: + asap "~2.0.6" + +prompts@^2.0.1: + version "2.3.2" + resolved "https://registry.yarnpkg.com/prompts/-/prompts-2.3.2.tgz#480572d89ecf39566d2bd3fe2c9fccb7c4c0b068" + integrity sha512-Q06uKs2CkNYVID0VqwfAl9mipo99zkBv/n2JtWY89Yxa3ZabWSrs0e2KTudKVa3peLUvYXMefDqIleLPVUBZMA== + dependencies: + kleur "^3.0.3" + sisteransi "^1.0.4" + +prop-types@^15.6.2, prop-types@^15.7.2: + version "15.7.2" + resolved "https://registry.yarnpkg.com/prop-types/-/prop-types-15.7.2.tgz#52c41e75b8c87e72b9d9360e0206b99dcbffa6c5" + integrity sha512-8QQikdH7//R2vurIJSutZ1smHYTcLpRWEOlHnzcWHmBYrOGUysKwSsrC89BCiFj3CbrfJ/nXFdJepOVrY1GCHQ== + dependencies: + loose-envify "^1.4.0" + object-assign "^4.1.1" + react-is "^16.8.1" + +proxy-addr@~2.0.5: + version "2.0.6" + resolved "https://registry.yarnpkg.com/proxy-addr/-/proxy-addr-2.0.6.tgz#fdc2336505447d3f2f2c638ed272caf614bbb2bf" + integrity sha512-dh/frvCBVmSsDYzw6n926jv974gddhkFPfiN8hPOi30Wax25QZyZEGveluCgliBnqmuM+UJmBErbAUFIoDbjOw== + dependencies: + forwarded "~0.1.2" + ipaddr.js "1.9.1" + +prr@~1.0.1: + version "1.0.1" + resolved "https://registry.yarnpkg.com/prr/-/prr-1.0.1.tgz#d3fc114ba06995a45ec6893f484ceb1d78f5f476" + integrity sha1-0/wRS6BplaRexok/SEzrHXj19HY= + +psl@^1.1.28: + version "1.8.0" + resolved "https://registry.yarnpkg.com/psl/-/psl-1.8.0.tgz#9326f8bcfb013adcc005fdff056acce020e51c24" + integrity sha512-RIdOzyoavK+hA18OGGWDqUTsCLhtA7IcZ/6NCs4fFJaHBDab+pDDmDIByWFRQJq2Cd7r1OoQxBGKOaztq+hjIQ== + +public-encrypt@^4.0.0: + version "4.0.3" + resolved 
"https://registry.yarnpkg.com/public-encrypt/-/public-encrypt-4.0.3.tgz#4fcc9d77a07e48ba7527e7cbe0de33d0701331e0" + integrity sha512-zVpa8oKZSz5bTMTFClc1fQOnyyEzpl5ozpi1B5YcvBrdohMjH2rfsBtyXcuNuwjsDIXmBYlF2N5FlJYhR29t8Q== + dependencies: + bn.js "^4.1.0" + browserify-rsa "^4.0.0" + create-hash "^1.1.0" + parse-asn1 "^5.0.0" + randombytes "^2.0.1" + safe-buffer "^5.1.2" + +pump@^2.0.0: + version "2.0.1" + resolved "https://registry.yarnpkg.com/pump/-/pump-2.0.1.tgz#12399add6e4cf7526d973cbc8b5ce2e2908b3909" + integrity sha512-ruPMNRkN3MHP1cWJc9OWr+T/xDP0jhXYCLfJcBuX54hhfIBnaQmAUMfDcG4DM5UMWByBbJY69QSphm3jtDKIkA== + dependencies: + end-of-stream "^1.1.0" + once "^1.3.1" + +pump@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/pump/-/pump-3.0.0.tgz#b4a2116815bde2f4e1ea602354e8c75565107a64" + integrity sha512-LwZy+p3SFs1Pytd/jYct4wpv49HiYCqd9Rlc5ZVdk0V+8Yzv6jR5Blk3TRmPL1ft69TxP0IMZGJ+WPFU2BFhww== + dependencies: + end-of-stream "^1.1.0" + once "^1.3.1" + +pumpify@^1.3.3: + version "1.5.1" + resolved "https://registry.yarnpkg.com/pumpify/-/pumpify-1.5.1.tgz#36513be246ab27570b1a374a5ce278bfd74370ce" + integrity sha512-oClZI37HvuUJJxSKKrC17bZ9Cu0ZYhEAGPsPUy9KlMUmv9dKX2o77RUmq7f3XjIxbwyGwYzbzQ1L2Ks8sIradQ== + dependencies: + duplexify "^3.6.0" + inherits "^2.0.3" + pump "^2.0.0" + +punycode@1.3.2: + version "1.3.2" + resolved "https://registry.yarnpkg.com/punycode/-/punycode-1.3.2.tgz#9653a036fb7c1ee42342f2325cceefea3926c48d" + integrity sha1-llOgNvt8HuQjQvIyXM7v6jkmxI0= + +punycode@^1.2.4: + version "1.4.1" + resolved "https://registry.yarnpkg.com/punycode/-/punycode-1.4.1.tgz#c0d5a63b2718800ad8e1eb0fa5269c84dd41845e" + integrity sha1-wNWmOycYgArY4esPpSachN1BhF4= + +punycode@^2.1.0, punycode@^2.1.1: + version "2.1.1" + resolved "https://registry.yarnpkg.com/punycode/-/punycode-2.1.1.tgz#b58b010ac40c22c5657616c8d2c2c02c7bf479ec" + integrity sha512-XRsRjdf+j5ml+y/6GKHPZbrF/8p2Yga0JPtdqTIY2Xe5ohJPD9saDJJLPvp9+NSBprVvevdXZybnj2cv8OEd0A== + +q@^1.1.2: + version "1.5.1" + resolved "https://registry.yarnpkg.com/q/-/q-1.5.1.tgz#7e32f75b41381291d04611f1bf14109ac00651d7" + integrity sha1-fjL3W0E4EpHQRhHxvxQQmsAGUdc= + +qs@6.7.0: + version "6.7.0" + resolved "https://registry.yarnpkg.com/qs/-/qs-6.7.0.tgz#41dc1a015e3d581f1621776be31afb2876a9b1bc" + integrity sha512-VCdBRNFTX1fyE7Nb6FYoURo/SPe62QCaAyzJvUjwRaIsc+NePBEniHlvxFmmX56+HZphIGtV0XeCirBtpDrTyQ== + +qs@~6.5.2: + version "6.5.2" + resolved "https://registry.yarnpkg.com/qs/-/qs-6.5.2.tgz#cb3ae806e8740444584ef154ce8ee98d403f3e36" + integrity sha512-N5ZAX4/LxJmF+7wN74pUD6qAh9/wnvdQcjq9TZjevvXzSUo7bfmw91saqMjzGS2xq91/odN2dW/WOl7qQHNDGA== + +query-string@^4.1.0: + version "4.3.4" + resolved "https://registry.yarnpkg.com/query-string/-/query-string-4.3.4.tgz#bbb693b9ca915c232515b228b1a02b609043dbeb" + integrity sha1-u7aTucqRXCMlFbIosaArYJBD2+s= + dependencies: + object-assign "^4.1.0" + strict-uri-encode "^1.0.0" + +querystring-es3@^0.2.0: + version "0.2.1" + resolved "https://registry.yarnpkg.com/querystring-es3/-/querystring-es3-0.2.1.tgz#9ec61f79049875707d69414596fd907a4d711e73" + integrity sha1-nsYfeQSYdXB9aUFFlv2Qek1xHnM= + +querystring@0.2.0: + version "0.2.0" + resolved "https://registry.yarnpkg.com/querystring/-/querystring-0.2.0.tgz#b209849203bb25df820da756e747005878521620" + integrity sha1-sgmEkgO7Jd+CDadW50cAWHhSFiA= + +querystringify@^2.1.1: + version "2.1.1" + resolved "https://registry.yarnpkg.com/querystringify/-/querystringify-2.1.1.tgz#60e5a5fd64a7f8bfa4d2ab2ed6fdf4c85bad154e" + integrity 
sha512-w7fLxIRCRT7U8Qu53jQnJyPkYZIaR4n5151KMfcJlO/A9397Wxb1amJvROTK6TOnp7PfoAmg/qXiNHI+08jRfA== + +raf@^3.4.1: + version "3.4.1" + resolved "https://registry.yarnpkg.com/raf/-/raf-3.4.1.tgz#0742e99a4a6552f445d73e3ee0328af0ff1ede39" + integrity sha512-Sq4CW4QhwOHE8ucn6J34MqtZCeWFP2aQSmrlroYgqAV1PjStIhJXxYuTgUIfkEk7zTLjmIjLmU5q+fbD1NnOJA== + dependencies: + performance-now "^2.1.0" + +randombytes@^2.0.0, randombytes@^2.0.1, randombytes@^2.0.5: + version "2.1.0" + resolved "https://registry.yarnpkg.com/randombytes/-/randombytes-2.1.0.tgz#df6f84372f0270dc65cdf6291349ab7a473d4f2a" + integrity sha512-vYl3iOX+4CKUWuxGi9Ukhie6fsqXqS9FE2Zaic4tNFD2N2QQaXOMFbuKK4QmDHC0JO6B1Zp41J0LpT0oR68amQ== + dependencies: + safe-buffer "^5.1.0" + +randomfill@^1.0.3: + version "1.0.4" + resolved "https://registry.yarnpkg.com/randomfill/-/randomfill-1.0.4.tgz#c92196fc86ab42be983f1bf31778224931d61458" + integrity sha512-87lcbR8+MhcWcUiQ+9e+Rwx8MyR2P7qnt15ynUlbm3TU/fjbgz4GsvfSUDTemtCCtVCqb4ZcEFlyPNTh9bBTLw== + dependencies: + randombytes "^2.0.5" + safe-buffer "^5.1.0" + +range-parser@^1.2.1, range-parser@~1.2.1: + version "1.2.1" + resolved "https://registry.yarnpkg.com/range-parser/-/range-parser-1.2.1.tgz#3cf37023d199e1c24d1a55b84800c2f3e6468031" + integrity sha512-Hrgsx+orqoygnmhFbKaHE6c296J+HTAQXoxEF6gNupROmmGJRoyzfG3ccAveqCBrwr/2yxQ5BVd/GTl5agOwSg== + +raw-body@2.4.0: + version "2.4.0" + resolved "https://registry.yarnpkg.com/raw-body/-/raw-body-2.4.0.tgz#a1ce6fb9c9bc356ca52e89256ab59059e13d0332" + integrity sha512-4Oz8DUIwdvoa5qMJelxipzi/iJIi40O5cGV1wNYp5hvZP8ZN0T+jiNkL0QepXs+EsQ9XJ8ipEDoiH70ySUJP3Q== + dependencies: + bytes "3.1.0" + http-errors "1.7.2" + iconv-lite "0.4.24" + unpipe "1.0.0" + +react-app-polyfill@^1.0.6: + version "1.0.6" + resolved "https://registry.yarnpkg.com/react-app-polyfill/-/react-app-polyfill-1.0.6.tgz#890f8d7f2842ce6073f030b117de9130a5f385f0" + integrity sha512-OfBnObtnGgLGfweORmdZbyEz+3dgVePQBb3zipiaDsMHV1NpWm0rDFYIVXFV/AK+x4VIIfWHhrdMIeoTLyRr2g== + dependencies: + core-js "^3.5.0" + object-assign "^4.1.1" + promise "^8.0.3" + raf "^3.4.1" + regenerator-runtime "^0.13.3" + whatwg-fetch "^3.0.0" + +react-chartjs-2@^2.11.1: + version "2.11.1" + resolved "https://registry.yarnpkg.com/react-chartjs-2/-/react-chartjs-2-2.11.1.tgz#a78d0df05fc8bc8ffcd4c4ab5b89a25dd2ca3278" + integrity sha512-G7cNq/n2Bkh/v4vcI+GKx7Q1xwZexKYhOSj2HmrFXlvNeaURWXun6KlOUpEQwi1cv9Tgs4H3kGywDWMrX2kxfA== + dependencies: + lodash "^4.17.19" + prop-types "^15.7.2" + +react-dev-utils@^10.2.1: + version "10.2.1" + resolved "https://registry.yarnpkg.com/react-dev-utils/-/react-dev-utils-10.2.1.tgz#f6de325ae25fa4d546d09df4bb1befdc6dd19c19" + integrity sha512-XxTbgJnYZmxuPtY3y/UV0D8/65NKkmaia4rXzViknVnZeVlklSh8u6TnaEYPfAi/Gh1TP4mEOXHI6jQOPbeakQ== + dependencies: + "@babel/code-frame" "7.8.3" + address "1.1.2" + browserslist "4.10.0" + chalk "2.4.2" + cross-spawn "7.0.1" + detect-port-alt "1.1.6" + escape-string-regexp "2.0.0" + filesize "6.0.1" + find-up "4.1.0" + fork-ts-checker-webpack-plugin "3.1.1" + global-modules "2.0.0" + globby "8.0.2" + gzip-size "5.1.1" + immer "1.10.0" + inquirer "7.0.4" + is-root "2.1.0" + loader-utils "1.2.3" + open "^7.0.2" + pkg-up "3.1.0" + react-error-overlay "^6.0.7" + recursive-readdir "2.2.2" + shell-quote "1.7.2" + strip-ansi "6.0.0" + text-table "0.2.0" + +react-dom@^16.9.0: + version "16.13.1" + resolved "https://registry.yarnpkg.com/react-dom/-/react-dom-16.13.1.tgz#c1bd37331a0486c078ee54c4740720993b2e0e7f" + integrity 
sha512-81PIMmVLnCNLO/fFOQxdQkvEq/+Hfpv24XNJfpyZhTRfO0QcmQIF/PgCa1zCOj2w1hrn12MFLyaJ/G0+Mxtfag== + dependencies: + loose-envify "^1.1.0" + object-assign "^4.1.1" + prop-types "^15.6.2" + scheduler "^0.19.1" + +react-error-overlay@^6.0.7: + version "6.0.7" + resolved "https://registry.yarnpkg.com/react-error-overlay/-/react-error-overlay-6.0.7.tgz#1dcfb459ab671d53f660a991513cb2f0a0553108" + integrity sha512-TAv1KJFh3RhqxNvhzxj6LeT5NWklP6rDr2a0jaTfsZ5wSZWHOGeqQyejUp3xxLfPt2UpyJEcVQB/zyPcmonNFA== + +react-is@^16.12.0, react-is@^16.8.1, react-is@^16.8.4: + version "16.13.1" + resolved "https://registry.yarnpkg.com/react-is/-/react-is-16.13.1.tgz#789729a4dc36de2999dc156dd6c1d9c18cea56a4" + integrity sha512-24e6ynE2H+OKt4kqsOvNd8kBpV65zoxbA4BVsEOB3ARVWQki/DHzaUoC5KuON/BiccDaCCTZBuOcfZs70kR8bQ== + +react-scripts@3.4.1: + version "3.4.1" + resolved "https://registry.yarnpkg.com/react-scripts/-/react-scripts-3.4.1.tgz#f551298b5c71985cc491b9acf3c8e8c0ae3ada0a" + integrity sha512-JpTdi/0Sfd31mZA6Ukx+lq5j1JoKItX7qqEK4OiACjVQletM1P38g49d9/D0yTxp9FrSF+xpJFStkGgKEIRjlQ== + dependencies: + "@babel/core" "7.9.0" + "@svgr/webpack" "4.3.3" + "@typescript-eslint/eslint-plugin" "^2.10.0" + "@typescript-eslint/parser" "^2.10.0" + babel-eslint "10.1.0" + babel-jest "^24.9.0" + babel-loader "8.1.0" + babel-plugin-named-asset-import "^0.3.6" + babel-preset-react-app "^9.1.2" + camelcase "^5.3.1" + case-sensitive-paths-webpack-plugin "2.3.0" + css-loader "3.4.2" + dotenv "8.2.0" + dotenv-expand "5.1.0" + eslint "^6.6.0" + eslint-config-react-app "^5.2.1" + eslint-loader "3.0.3" + eslint-plugin-flowtype "4.6.0" + eslint-plugin-import "2.20.1" + eslint-plugin-jsx-a11y "6.2.3" + eslint-plugin-react "7.19.0" + eslint-plugin-react-hooks "^1.6.1" + file-loader "4.3.0" + fs-extra "^8.1.0" + html-webpack-plugin "4.0.0-beta.11" + identity-obj-proxy "3.0.0" + jest "24.9.0" + jest-environment-jsdom-fourteen "1.0.1" + jest-resolve "24.9.0" + jest-watch-typeahead "0.4.2" + mini-css-extract-plugin "0.9.0" + optimize-css-assets-webpack-plugin "5.0.3" + pnp-webpack-plugin "1.6.4" + postcss-flexbugs-fixes "4.1.0" + postcss-loader "3.0.0" + postcss-normalize "8.0.1" + postcss-preset-env "6.7.0" + postcss-safe-parser "4.0.1" + react-app-polyfill "^1.0.6" + react-dev-utils "^10.2.1" + resolve "1.15.0" + resolve-url-loader "3.1.1" + sass-loader "8.0.2" + semver "6.3.0" + style-loader "0.23.1" + terser-webpack-plugin "2.3.5" + ts-pnp "1.1.6" + url-loader "2.3.0" + webpack "4.42.0" + webpack-dev-server "3.10.3" + webpack-manifest-plugin "2.2.0" + workbox-webpack-plugin "4.3.1" + optionalDependencies: + fsevents "2.1.2" + +react-tag-autocomplete@^5.11.1: + version "5.12.1" + resolved "https://registry.yarnpkg.com/react-tag-autocomplete/-/react-tag-autocomplete-5.12.1.tgz#f6a56aaa151771af3db99780b5f32150e827e835" + integrity sha512-tVY1U/QLfwtkkaPWVVE+MZl7w27Sbzms1atWT+x/PQdN46aPwrViMx02vUQ0TGD4IWPT5sFtdfv6/Pf0qcYhSw== + +react@^16.9.0: + version "16.13.1" + resolved "https://registry.yarnpkg.com/react/-/react-16.13.1.tgz#2e818822f1a9743122c063d6410d85c1e3afe48e" + integrity sha512-YMZQQq32xHLX0bz5Mnibv1/LHb3Sqzngu7xstSM+vrkE5Kzr9xE0yMByK5kMoTK30YVJE61WfbxIFFvfeDKT1w== + dependencies: + loose-envify "^1.1.0" + object-assign "^4.1.1" + prop-types "^15.6.2" + +read-pkg-up@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/read-pkg-up/-/read-pkg-up-2.0.0.tgz#6b72a8048984e0c41e79510fd5e9fa99b3b549be" + integrity sha1-a3KoBImE4MQeeVEP1en6mbO1Sb4= + dependencies: + find-up "^2.0.0" + read-pkg "^2.0.0" + +read-pkg-up@^4.0.0: + 
version "4.0.0" + resolved "https://registry.yarnpkg.com/read-pkg-up/-/read-pkg-up-4.0.0.tgz#1b221c6088ba7799601c808f91161c66e58f8978" + integrity sha512-6etQSH7nJGsK0RbG/2TeDzZFa8shjQ1um+SwQQ5cwKy0dhSXdOncEhb1CPpvQG4h7FyOV6EB6YlV0yJvZQNAkA== + dependencies: + find-up "^3.0.0" + read-pkg "^3.0.0" + +read-pkg@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/read-pkg/-/read-pkg-2.0.0.tgz#8ef1c0623c6a6db0dc6713c4bfac46332b2368f8" + integrity sha1-jvHAYjxqbbDcZxPEv6xGMysjaPg= + dependencies: + load-json-file "^2.0.0" + normalize-package-data "^2.3.2" + path-type "^2.0.0" + +read-pkg@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/read-pkg/-/read-pkg-3.0.0.tgz#9cbc686978fee65d16c00e2b19c237fcf6e38389" + integrity sha1-nLxoaXj+5l0WwA4rGcI3/Pbjg4k= + dependencies: + load-json-file "^4.0.0" + normalize-package-data "^2.3.2" + path-type "^3.0.0" + +"readable-stream@1 || 2", readable-stream@^2.0.0, readable-stream@^2.0.1, readable-stream@^2.0.2, readable-stream@^2.1.5, readable-stream@^2.2.2, readable-stream@^2.3.3, readable-stream@^2.3.6, readable-stream@~2.3.6: + version "2.3.7" + resolved "https://registry.yarnpkg.com/readable-stream/-/readable-stream-2.3.7.tgz#1eca1cf711aef814c04f62252a36a62f6cb23b57" + integrity sha512-Ebho8K4jIbHAxnuxi7o42OrZgF/ZTNcsZj6nRKyUmkhLFq8CHItp/fy6hQZuZmP/n3yZ9VBUbp4zz/mX8hmYPw== + dependencies: + core-util-is "~1.0.0" + inherits "~2.0.3" + isarray "~1.0.0" + process-nextick-args "~2.0.0" + safe-buffer "~5.1.1" + string_decoder "~1.1.1" + util-deprecate "~1.0.1" + +readable-stream@^3.0.6, readable-stream@^3.1.1: + version "3.6.0" + resolved "https://registry.yarnpkg.com/readable-stream/-/readable-stream-3.6.0.tgz#337bbda3adc0706bd3e024426a286d4b4b2c9198" + integrity sha512-BViHy7LKeTz4oNnkcLJ+lVSL6vpiFeX6/d3oSH8zCW7UxP2onchk+vTGB143xuFjHS3deTgkKoXXymXqymiIdA== + dependencies: + inherits "^2.0.3" + string_decoder "^1.1.1" + util-deprecate "^1.0.1" + +readdirp@^2.2.1: + version "2.2.1" + resolved "https://registry.yarnpkg.com/readdirp/-/readdirp-2.2.1.tgz#0e87622a3325aa33e892285caf8b4e846529a525" + integrity sha512-1JU/8q+VgFZyxwrJ+SVIOsh+KywWGpds3NTqikiKpDMZWScmAYyKIgqkO+ARvNWJfXeXR1zxz7aHF4u4CyH6vQ== + dependencies: + graceful-fs "^4.1.11" + micromatch "^3.1.10" + readable-stream "^2.0.2" + +readdirp@~3.3.0: + version "3.3.0" + resolved "https://registry.yarnpkg.com/readdirp/-/readdirp-3.3.0.tgz#984458d13a1e42e2e9f5841b129e162f369aff17" + integrity sha512-zz0pAkSPOXXm1viEwygWIPSPkcBYjW1xU5j/JBh5t9bGCJwa6f9+BJa6VaB2g+b55yVrmXzqkyLf4xaWYM0IkQ== + dependencies: + picomatch "^2.0.7" + +realpath-native@^1.1.0: + version "1.1.0" + resolved "https://registry.yarnpkg.com/realpath-native/-/realpath-native-1.1.0.tgz#2003294fea23fb0672f2476ebe22fcf498a2d65c" + integrity sha512-wlgPA6cCIIg9gKz0fgAPjnzh4yR/LnXovwuo9hvyGvx3h8nX4+/iLZplfUWasXpqD8BdnGnP5njOFjkUwPzvjA== + dependencies: + util.promisify "^1.0.0" + +recursive-readdir@2.2.2: + version "2.2.2" + resolved "https://registry.yarnpkg.com/recursive-readdir/-/recursive-readdir-2.2.2.tgz#9946fb3274e1628de6e36b2f6714953b4845094f" + integrity sha512-nRCcW9Sj7NuZwa2XvH9co8NPeXUBhZP7CRKJtU+cS6PW9FpCIFoI5ib0NT1ZrbNuPoRy0ylyCaUL8Gih4LSyFg== + dependencies: + minimatch "3.0.4" + +regenerate-unicode-properties@^8.2.0: + version "8.2.0" + resolved "https://registry.yarnpkg.com/regenerate-unicode-properties/-/regenerate-unicode-properties-8.2.0.tgz#e5de7111d655e7ba60c057dbe9ff37c87e65cdec" + integrity sha512-F9DjY1vKLo/tPePDycuH3dn9H1OTPIkVD9Kz4LODu+F2C75mgjAJ7x/gwy6ZcSNRAAkhNlJSOHRe8k3p+K9WhA== + 
dependencies: + regenerate "^1.4.0" + +regenerate@^1.4.0: + version "1.4.0" + resolved "https://registry.yarnpkg.com/regenerate/-/regenerate-1.4.0.tgz#4a856ec4b56e4077c557589cae85e7a4c8869a11" + integrity sha512-1G6jJVDWrt0rK99kBjvEtziZNCICAuvIPkSiUFIQxVP06RCVpq3dmDo2oi6ABpYaDYaTRr67BEhL8r1wgEZZKg== + +regenerator-runtime@^0.11.0: + version "0.11.1" + resolved "https://registry.yarnpkg.com/regenerator-runtime/-/regenerator-runtime-0.11.1.tgz#be05ad7f9bf7d22e056f9726cee5017fbf19e2e9" + integrity sha512-MguG95oij0fC3QV3URf4V2SDYGJhJnJGqvIIgdECeODCT98wSWDAJ94SSuVpYQUoTcGUIL6L4yNB7j1DFFHSBg== + +regenerator-runtime@^0.13.3, regenerator-runtime@^0.13.4: + version "0.13.5" + resolved "https://registry.yarnpkg.com/regenerator-runtime/-/regenerator-runtime-0.13.5.tgz#d878a1d094b4306d10b9096484b33ebd55e26697" + integrity sha512-ZS5w8CpKFinUzOwW3c83oPeVXoNsrLsaCoLtJvAClH135j/R77RuymhiSErhm2lKcwSCIpmvIWSbDkIfAqKQlA== + +regenerator-transform@^0.14.2: + version "0.14.4" + resolved "https://registry.yarnpkg.com/regenerator-transform/-/regenerator-transform-0.14.4.tgz#5266857896518d1616a78a0479337a30ea974cc7" + integrity sha512-EaJaKPBI9GvKpvUz2mz4fhx7WPgvwRLY9v3hlNHWmAuJHI13T4nwKnNvm5RWJzEdnI5g5UwtOww+S8IdoUC2bw== + dependencies: + "@babel/runtime" "^7.8.4" + private "^0.1.8" + +regex-not@^1.0.0, regex-not@^1.0.2: + version "1.0.2" + resolved "https://registry.yarnpkg.com/regex-not/-/regex-not-1.0.2.tgz#1f4ece27e00b0b65e0247a6810e6a85d83a5752c" + integrity sha512-J6SDjUgDxQj5NusnOtdFxDwN/+HWykR8GELwctJ7mdqhcyy1xEc4SRFHUXvxTp661YaVKAjfRLZ9cCqS6tn32A== + dependencies: + extend-shallow "^3.0.2" + safe-regex "^1.1.0" + +regex-parser@2.2.10: + version "2.2.10" + resolved "https://registry.yarnpkg.com/regex-parser/-/regex-parser-2.2.10.tgz#9e66a8f73d89a107616e63b39d4deddfee912b37" + integrity sha512-8t6074A68gHfU8Neftl0Le6KTDwfGAj7IyjPIMSfikI2wJUTHDMaIq42bUsfVnj8mhx0R+45rdUXHGpN164avA== + +regexp.prototype.flags@^1.2.0, regexp.prototype.flags@^1.3.0: + version "1.3.0" + resolved "https://registry.yarnpkg.com/regexp.prototype.flags/-/regexp.prototype.flags-1.3.0.tgz#7aba89b3c13a64509dabcf3ca8d9fbb9bdf5cb75" + integrity sha512-2+Q0C5g951OlYlJz6yu5/M33IcsESLlLfsyIaLJaG4FA2r4yP8MvVMJUUP/fVBkSpbbbZlS5gynbEWLipiiXiQ== + dependencies: + define-properties "^1.1.3" + es-abstract "^1.17.0-next.1" + +regexpp@^2.0.1: + version "2.0.1" + resolved "https://registry.yarnpkg.com/regexpp/-/regexpp-2.0.1.tgz#8d19d31cf632482b589049f8281f93dbcba4d07f" + integrity sha512-lv0M6+TkDVniA3aD1Eg0DVpfU/booSu7Eev3TDO/mZKHBfVjgCGTV4t4buppESEYDtkArYFOxTJWv6S5C+iaNw== + +regexpp@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/regexpp/-/regexpp-3.0.0.tgz#dd63982ee3300e67b41c1956f850aa680d9d330e" + integrity sha512-Z+hNr7RAVWxznLPuA7DIh8UNX1j9CDrUQxskw9IrBE1Dxue2lyXT+shqEIeLUjrokxIP8CMy1WkjgG3rTsd5/g== + +regexpu-core@^4.7.0: + version "4.7.0" + resolved "https://registry.yarnpkg.com/regexpu-core/-/regexpu-core-4.7.0.tgz#fcbf458c50431b0bb7b45d6967b8192d91f3d938" + integrity sha512-TQ4KXRnIn6tz6tjnrXEkD/sshygKH/j5KzK86X8MkeHyZ8qst/LZ89j3X4/8HEIfHANTFIP/AbXakeRhWIl5YQ== + dependencies: + regenerate "^1.4.0" + regenerate-unicode-properties "^8.2.0" + regjsgen "^0.5.1" + regjsparser "^0.6.4" + unicode-match-property-ecmascript "^1.0.4" + unicode-match-property-value-ecmascript "^1.2.0" + +regjsgen@^0.5.1: + version "0.5.1" + resolved "https://registry.yarnpkg.com/regjsgen/-/regjsgen-0.5.1.tgz#48f0bf1a5ea205196929c0d9798b42d1ed98443c" + integrity 
sha512-5qxzGZjDs9w4tzT3TPhCJqWdCc3RLYwy9J2NB0nm5Lz+S273lvWcpjaTGHsT1dc6Hhfq41uSEOw8wBmxrKOuyg== + +regjsparser@^0.6.4: + version "0.6.4" + resolved "https://registry.yarnpkg.com/regjsparser/-/regjsparser-0.6.4.tgz#a769f8684308401a66e9b529d2436ff4d0666272" + integrity sha512-64O87/dPDgfk8/RQqC4gkZoGyyWFIEUTTh80CU6CWuK5vkCGyekIx+oKcEIYtP/RAxSQltCZHCNu/mdd7fqlJw== + dependencies: + jsesc "~0.5.0" + +relateurl@^0.2.7: + version "0.2.7" + resolved "https://registry.yarnpkg.com/relateurl/-/relateurl-0.2.7.tgz#54dbf377e51440aca90a4cd274600d3ff2d888a9" + integrity sha1-VNvzd+UUQKypCkzSdGANP/LYiKk= + +remove-trailing-separator@^1.0.1: + version "1.1.0" + resolved "https://registry.yarnpkg.com/remove-trailing-separator/-/remove-trailing-separator-1.1.0.tgz#c24bce2a283adad5bc3f58e0d48249b92379d8ef" + integrity sha1-wkvOKig62tW8P1jg1IJJuSN52O8= + +renderkid@^2.0.1: + version "2.0.3" + resolved "https://registry.yarnpkg.com/renderkid/-/renderkid-2.0.3.tgz#380179c2ff5ae1365c522bf2fcfcff01c5b74149" + integrity sha512-z8CLQp7EZBPCwCnncgf9C4XAi3WR0dv+uWu/PjIyhhAb5d6IJ/QZqlHFprHeKT+59//V6BNUsLbvN8+2LarxGA== + dependencies: + css-select "^1.1.0" + dom-converter "^0.2" + htmlparser2 "^3.3.0" + strip-ansi "^3.0.0" + utila "^0.4.0" + +repeat-element@^1.1.2: + version "1.1.3" + resolved "https://registry.yarnpkg.com/repeat-element/-/repeat-element-1.1.3.tgz#782e0d825c0c5a3bb39731f84efee6b742e6b1ce" + integrity sha512-ahGq0ZnV5m5XtZLMb+vP76kcAM5nkLqk0lpqAuojSKGgQtn4eRi4ZZGm2olo2zKFH+sMsWaqOCW1dqAnOru72g== + +repeat-string@^1.6.1: + version "1.6.1" + resolved "https://registry.yarnpkg.com/repeat-string/-/repeat-string-1.6.1.tgz#8dcae470e1c88abc2d600fff4a776286da75e637" + integrity sha1-jcrkcOHIirwtYA//Sndihtp15jc= + +request-promise-core@1.1.3: + version "1.1.3" + resolved "https://registry.yarnpkg.com/request-promise-core/-/request-promise-core-1.1.3.tgz#e9a3c081b51380dfea677336061fea879a829ee9" + integrity sha512-QIs2+ArIGQVp5ZYbWD5ZLCY29D5CfWizP8eWnm8FoGD1TX61veauETVQbrV60662V0oFBkrDOuaBI8XgtuyYAQ== + dependencies: + lodash "^4.17.15" + +request-promise-native@^1.0.5: + version "1.0.8" + resolved "https://registry.yarnpkg.com/request-promise-native/-/request-promise-native-1.0.8.tgz#a455b960b826e44e2bf8999af64dff2bfe58cb36" + integrity sha512-dapwLGqkHtwL5AEbfenuzjTYg35Jd6KPytsC2/TLkVMz8rm+tNt72MGUWT1RP/aYawMpN6HqbNGBQaRcBtjQMQ== + dependencies: + request-promise-core "1.1.3" + stealthy-require "^1.1.1" + tough-cookie "^2.3.3" + +request@^2.87.0, request@^2.88.0: + version "2.88.2" + resolved "https://registry.yarnpkg.com/request/-/request-2.88.2.tgz#d73c918731cb5a87da047e207234146f664d12b3" + integrity sha512-MsvtOrfG9ZcrOwAW+Qi+F6HbD0CWXEh9ou77uOb7FM2WPhwT7smM833PzanhJLsgXjN89Ir6V2PczXNnMpwKhw== + dependencies: + aws-sign2 "~0.7.0" + aws4 "^1.8.0" + caseless "~0.12.0" + combined-stream "~1.0.6" + extend "~3.0.2" + forever-agent "~0.6.1" + form-data "~2.3.2" + har-validator "~5.1.3" + http-signature "~1.2.0" + is-typedarray "~1.0.0" + isstream "~0.1.2" + json-stringify-safe "~5.0.1" + mime-types "~2.1.19" + oauth-sign "~0.9.0" + performance-now "^2.1.0" + qs "~6.5.2" + safe-buffer "^5.1.2" + tough-cookie "~2.5.0" + tunnel-agent "^0.6.0" + uuid "^3.3.2" + +require-directory@^2.1.1: + version "2.1.1" + resolved "https://registry.yarnpkg.com/require-directory/-/require-directory-2.1.1.tgz#8c64ad5fd30dab1c976e2344ffe7f792a6a6df42" + integrity sha1-jGStX9MNqxyXbiNE/+f3kqam30I= + +require-main-filename@^1.0.1: + version "1.0.1" + resolved 
"https://registry.yarnpkg.com/require-main-filename/-/require-main-filename-1.0.1.tgz#97f717b69d48784f5f526a6c5aa8ffdda055a4d1" + integrity sha1-l/cXtp1IeE9fUmpsWqj/3aBVpNE= + +require-main-filename@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/require-main-filename/-/require-main-filename-2.0.0.tgz#d0b329ecc7cc0f61649f62215be69af54aa8989b" + integrity sha512-NKN5kMDylKuldxYLSUfrbo5Tuzh4hd+2E8NPPX02mZtn1VuREQToYe/ZdlJy+J3uCpfaiGF05e7B8W0iXbQHmg== + +requires-port@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/requires-port/-/requires-port-1.0.0.tgz#925d2601d39ac485e091cf0da5c6e694dc3dcaff" + integrity sha1-kl0mAdOaxIXgkc8NpcbmlNw9yv8= + +resolve-cwd@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/resolve-cwd/-/resolve-cwd-2.0.0.tgz#00a9f7387556e27038eae232caa372a6a59b665a" + integrity sha1-AKn3OHVW4nA46uIyyqNypqWbZlo= + dependencies: + resolve-from "^3.0.0" + +resolve-dir@^1.0.0, resolve-dir@^1.0.1: + version "1.0.1" + resolved "https://registry.yarnpkg.com/resolve-dir/-/resolve-dir-1.0.1.tgz#79a40644c362be82f26effe739c9bb5382046f43" + integrity sha1-eaQGRMNivoLybv/nOcm7U4IEb0M= + dependencies: + expand-tilde "^2.0.0" + global-modules "^1.0.0" + +resolve-from@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/resolve-from/-/resolve-from-3.0.0.tgz#b22c7af7d9d6881bc8b6e653335eebcb0a188748" + integrity sha1-six699nWiBvItuZTM17rywoYh0g= + +resolve-from@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/resolve-from/-/resolve-from-4.0.0.tgz#4abcd852ad32dd7baabfe9b40e00a36db5f392e6" + integrity sha512-pb/MYmXstAkysRFx8piNI1tGFNQIFA3vkE3Gq4EuA1dF6gHp/+vgZqsCGJapvy8N3Q+4o7FwvquPJcnZ7RYy4g== + +resolve-url-loader@3.1.1: + version "3.1.1" + resolved "https://registry.yarnpkg.com/resolve-url-loader/-/resolve-url-loader-3.1.1.tgz#28931895fa1eab9be0647d3b2958c100ae3c0bf0" + integrity sha512-K1N5xUjj7v0l2j/3Sgs5b8CjrrgtC70SmdCuZiJ8tSyb5J+uk3FoeZ4b7yTnH6j7ngI+Bc5bldHJIa8hYdu2gQ== + dependencies: + adjust-sourcemap-loader "2.0.0" + camelcase "5.3.1" + compose-function "3.0.3" + convert-source-map "1.7.0" + es6-iterator "2.0.3" + loader-utils "1.2.3" + postcss "7.0.21" + rework "1.0.1" + rework-visit "1.0.0" + source-map "0.6.1" + +resolve-url@^0.2.1: + version "0.2.1" + resolved "https://registry.yarnpkg.com/resolve-url/-/resolve-url-0.2.1.tgz#2c637fe77c893afd2a663fe21aa9080068e2052a" + integrity sha1-LGN/53yJOv0qZj/iGqkIAGjiBSo= + +resolve@1.1.7: + version "1.1.7" + resolved "https://registry.yarnpkg.com/resolve/-/resolve-1.1.7.tgz#203114d82ad2c5ed9e8e0411b3932875e889e97b" + integrity sha1-IDEU2CrSxe2ejgQRs5ModeiJ6Xs= + +resolve@1.15.0: + version "1.15.0" + resolved "https://registry.yarnpkg.com/resolve/-/resolve-1.15.0.tgz#1b7ca96073ebb52e741ffd799f6b39ea462c67f5" + integrity sha512-+hTmAldEGE80U2wJJDC1lebb5jWqvTYAfm3YZ1ckk1gBr0MnCqUKlwK1e+anaFljIl+F5tR5IoZcm4ZDA1zMQw== + dependencies: + path-parse "^1.0.6" + +resolve@^1.10.0, resolve@^1.12.0, resolve@^1.13.1, resolve@^1.15.1, resolve@^1.3.2, resolve@^1.8.1: + version "1.15.1" + resolved "https://registry.yarnpkg.com/resolve/-/resolve-1.15.1.tgz#27bdcdeffeaf2d6244b95bb0f9f4b4653451f3e8" + integrity sha512-84oo6ZTtoTUpjgNEr5SJyzQhzL72gaRodsSfyxC/AXRvwu0Yse9H8eF9IpGo7b8YetZhlI6v7ZQ6bKBFV/6S7w== + dependencies: + path-parse "^1.0.6" + +restore-cursor@^3.1.0: + version "3.1.0" + resolved "https://registry.yarnpkg.com/restore-cursor/-/restore-cursor-3.1.0.tgz#39f67c54b3a7a58cea5236d95cf0034239631f7e" + integrity 
sha512-l+sSefzHpj5qimhFSE5a8nufZYAM3sBSVMAPtYkmC+4EH2anSGaEMXSD0izRQbu9nfyQ9y5JrVmp7E8oZrUjvA== + dependencies: + onetime "^5.1.0" + signal-exit "^3.0.2" + +ret@~0.1.10: + version "0.1.15" + resolved "https://registry.yarnpkg.com/ret/-/ret-0.1.15.tgz#b8a4825d5bdb1fc3f6f53c2bc33f81388681c7bc" + integrity sha512-TTlYpa+OL+vMMNG24xSlQGEJ3B/RzEfUlLct7b5G/ytav+wPrplCpVMFuwzXbkecJrb6IYo1iFb0S9v37754mg== + +retry@^0.12.0: + version "0.12.0" + resolved "https://registry.yarnpkg.com/retry/-/retry-0.12.0.tgz#1b42a6266a21f07421d1b0b54b7dc167b01c013b" + integrity sha1-G0KmJmoh8HQh0bC1S33BZ7AcATs= + +rework-visit@1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/rework-visit/-/rework-visit-1.0.0.tgz#9945b2803f219e2f7aca00adb8bc9f640f842c9a" + integrity sha1-mUWygD8hni96ygCtuLyfZA+ELJo= + +rework@1.0.1: + version "1.0.1" + resolved "https://registry.yarnpkg.com/rework/-/rework-1.0.1.tgz#30806a841342b54510aa4110850cd48534144aa7" + integrity sha1-MIBqhBNCtUUQqkEQhQzUhTQUSqc= + dependencies: + convert-source-map "^0.3.3" + css "^2.0.0" + +rgb-regex@^1.0.1: + version "1.0.1" + resolved "https://registry.yarnpkg.com/rgb-regex/-/rgb-regex-1.0.1.tgz#c0e0d6882df0e23be254a475e8edd41915feaeb1" + integrity sha1-wODWiC3w4jviVKR16O3UGRX+rrE= + +rgba-regex@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/rgba-regex/-/rgba-regex-1.0.0.tgz#43374e2e2ca0968b0ef1523460b7d730ff22eeb3" + integrity sha1-QzdOLiyglosO8VI0YLfXMP8i7rM= + +rimraf@2.6.3: + version "2.6.3" + resolved "https://registry.yarnpkg.com/rimraf/-/rimraf-2.6.3.tgz#b2d104fe0d8fb27cf9e0a1cda8262dd3833c6cab" + integrity sha512-mwqeW5XsA2qAejG46gYdENaxXjx9onRNCfn7L0duuP4hCuTIi/QO7PDK07KJfp1d+izWPrzEJDcSqBa0OZQriA== + dependencies: + glob "^7.1.3" + +rimraf@^2.5.4, rimraf@^2.6.3, rimraf@^2.7.1: + version "2.7.1" + resolved "https://registry.yarnpkg.com/rimraf/-/rimraf-2.7.1.tgz#35797f13a7fdadc566142c29d4f07ccad483e3ec" + integrity sha512-uWjbaKIK3T1OSVptzX7Nl6PvQ3qAGtKEtVRjRuazjfL3Bx5eI409VZSqgND+4UNnmzLVdPj9FqFJNPqBZFve4w== + dependencies: + glob "^7.1.3" + +ripemd160@^2.0.0, ripemd160@^2.0.1: + version "2.0.2" + resolved "https://registry.yarnpkg.com/ripemd160/-/ripemd160-2.0.2.tgz#a1c1a6f624751577ba5d07914cbc92850585890c" + integrity sha512-ii4iagi25WusVoiC4B4lq7pbXfAp3D9v5CwfkY33vffw2+pkDjY1D8GaN7spsxvCSx8dkPqOZCEZyfxcmJG2IA== + dependencies: + hash-base "^3.0.0" + inherits "^2.0.1" + +rsvp@^4.8.4: + version "4.8.5" + resolved "https://registry.yarnpkg.com/rsvp/-/rsvp-4.8.5.tgz#c8f155311d167f68f21e168df71ec5b083113734" + integrity sha512-nfMOlASu9OnRJo1mbEk2cz0D56a1MBNrJ7orjRZQG10XDyuvwksKbuXNp6qa+kbn839HwjwhBzhFmdsaEAfauA== + +run-async@^2.2.0, run-async@^2.4.0: + version "2.4.0" + resolved "https://registry.yarnpkg.com/run-async/-/run-async-2.4.0.tgz#e59054a5b86876cfae07f431d18cbaddc594f1e8" + integrity sha512-xJTbh/d7Lm7SBhc1tNvTpeCHaEzoyxPrqNlvSdMfBTYwaY++UJFyXUOxAtsRUXjlqOfj8luNaR9vjCh4KeV+pg== + dependencies: + is-promise "^2.1.0" + +run-queue@^1.0.0, run-queue@^1.0.3: + version "1.0.3" + resolved "https://registry.yarnpkg.com/run-queue/-/run-queue-1.0.3.tgz#e848396f057d223f24386924618e25694161ec47" + integrity sha1-6Eg5bwV9Ij8kOGkkYY4laUFh7Ec= + dependencies: + aproba "^1.1.1" + +rxjs@^6.5.3: + version "6.5.4" + resolved "https://registry.yarnpkg.com/rxjs/-/rxjs-6.5.4.tgz#e0777fe0d184cec7872df147f303572d414e211c" + integrity sha512-naMQXcgEo3csAEGvw/NydRA0fuS2nDZJiw1YUWFKU7aPPAPGZEsD4Iimit96qwCieH6y614MCLYwdkrWx7z/7Q== + dependencies: + tslib "^1.9.0" + +safe-buffer@5.1.2, safe-buffer@~5.1.0, safe-buffer@~5.1.1: + 
version "5.1.2" + resolved "https://registry.yarnpkg.com/safe-buffer/-/safe-buffer-5.1.2.tgz#991ec69d296e0313747d59bdfd2b745c35f8828d" + integrity sha512-Gd2UZBJDkXlY7GbJxfsE8/nvKkUEU1G38c1siN6QP6a9PT9MmHB8GnpscSmMJSoF8LOIrt8ud/wPtojys4G6+g== + +safe-buffer@>=5.1.0, safe-buffer@^5.0.1, safe-buffer@^5.1.0, safe-buffer@^5.1.1, safe-buffer@^5.1.2, safe-buffer@~5.2.0: + version "5.2.0" + resolved "https://registry.yarnpkg.com/safe-buffer/-/safe-buffer-5.2.0.tgz#b74daec49b1148f88c64b68d49b1e815c1f2f519" + integrity sha512-fZEwUGbVl7kouZs1jCdMLdt95hdIv0ZeHg6L7qPeciMZhZ+/gdesW4wgTARkrFWEpspjEATAzUGPG8N2jJiwbg== + +safe-regex@^1.1.0: + version "1.1.0" + resolved "https://registry.yarnpkg.com/safe-regex/-/safe-regex-1.1.0.tgz#40a3669f3b077d1e943d44629e157dd48023bf2e" + integrity sha1-QKNmnzsHfR6UPURinhV91IAjvy4= + dependencies: + ret "~0.1.10" + +"safer-buffer@>= 2.1.2 < 3", safer-buffer@^2.0.2, safer-buffer@^2.1.0, safer-buffer@~2.1.0: + version "2.1.2" + resolved "https://registry.yarnpkg.com/safer-buffer/-/safer-buffer-2.1.2.tgz#44fa161b0187b9549dd84bb91802f9bd8385cd6a" + integrity sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg== + +sane@^4.0.3: + version "4.1.0" + resolved "https://registry.yarnpkg.com/sane/-/sane-4.1.0.tgz#ed881fd922733a6c461bc189dc2b6c006f3ffded" + integrity sha512-hhbzAgTIX8O7SHfp2c8/kREfEn4qO/9q8C9beyY6+tvZ87EpoZ3i1RIEvp27YBswnNbY9mWd6paKVmKbAgLfZA== + dependencies: + "@cnakazawa/watch" "^1.0.3" + anymatch "^2.0.0" + capture-exit "^2.0.0" + exec-sh "^0.3.2" + execa "^1.0.0" + fb-watchman "^2.0.0" + micromatch "^3.1.4" + minimist "^1.1.1" + walker "~1.0.5" + +sanitize.css@^10.0.0: + version "10.0.0" + resolved "https://registry.yarnpkg.com/sanitize.css/-/sanitize.css-10.0.0.tgz#b5cb2547e96d8629a60947544665243b1dc3657a" + integrity sha512-vTxrZz4dX5W86M6oVWVdOVe72ZiPs41Oi7Z6Km4W5Turyz28mrXSJhhEBZoRtzJWIv3833WKVwLSDWWkEfupMg== + +sass-loader@8.0.2: + version "8.0.2" + resolved "https://registry.yarnpkg.com/sass-loader/-/sass-loader-8.0.2.tgz#debecd8c3ce243c76454f2e8290482150380090d" + integrity sha512-7o4dbSK8/Ol2KflEmSco4jTjQoV988bM82P9CZdmo9hR3RLnvNc0ufMNdMrB0caq38JQ/FgF4/7RcbcfKzxoFQ== + dependencies: + clone-deep "^4.0.1" + loader-utils "^1.2.3" + neo-async "^2.6.1" + schema-utils "^2.6.1" + semver "^6.3.0" + +sax@^1.2.4, sax@~1.2.4: + version "1.2.4" + resolved "https://registry.yarnpkg.com/sax/-/sax-1.2.4.tgz#2816234e2378bddc4e5354fab5caa895df7100d9" + integrity sha512-NqVDv9TpANUjFm0N8uM5GxL36UgKi9/atZw+x7YFnQ8ckwFGKrl4xX4yWtrey3UJm5nP1kUbnYgLopqWNSRhWw== + +saxes@^3.1.9: + version "3.1.11" + resolved "https://registry.yarnpkg.com/saxes/-/saxes-3.1.11.tgz#d59d1fd332ec92ad98a2e0b2ee644702384b1c5b" + integrity sha512-Ydydq3zC+WYDJK1+gRxRapLIED9PWeSuuS41wqyoRmzvhhh9nc+QQrVMKJYzJFULazeGhzSV0QleN2wD3boh2g== + dependencies: + xmlchars "^2.1.1" + +scheduler@^0.19.1: + version "0.19.1" + resolved "https://registry.yarnpkg.com/scheduler/-/scheduler-0.19.1.tgz#4f3e2ed2c1a7d65681f4c854fa8c5a1ccb40f196" + integrity sha512-n/zwRWRYSUj0/3g/otKDRPMh6qv2SYMWNq85IEa8iZyAv8od9zDYpGSnpBEjNgcMNq6Scbu5KfIPxNF72R/2EA== + dependencies: + loose-envify "^1.1.0" + object-assign "^4.1.1" + +schema-utils@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/schema-utils/-/schema-utils-1.0.0.tgz#0b79a93204d7b600d4b2850d1f66c2a34951c770" + integrity sha512-i27Mic4KovM/lnGsy8whRCHhc7VicJajAjTrYg11K9zfZXnYIt4k5F+kZkwjnrhKzLic/HLU4j11mjsz2G/75g== + dependencies: + ajv "^6.1.0" + ajv-errors "^1.0.0" + ajv-keywords "^3.1.0" + 
+schema-utils@^2.5.0, schema-utils@^2.6.0, schema-utils@^2.6.1, schema-utils@^2.6.4, schema-utils@^2.6.5: + version "2.6.5" + resolved "https://registry.yarnpkg.com/schema-utils/-/schema-utils-2.6.5.tgz#c758f0a7e624263073d396e29cd40aa101152d8a" + integrity sha512-5KXuwKziQrTVHh8j/Uxz+QUbxkaLW9X/86NBlx/gnKgtsZA2GIVMUn17qWhRFwF8jdYb3Dig5hRO/W5mZqy6SQ== + dependencies: + ajv "^6.12.0" + ajv-keywords "^3.4.1" + +select-hose@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/select-hose/-/select-hose-2.0.0.tgz#625d8658f865af43ec962bfc376a37359a4994ca" + integrity sha1-Yl2GWPhlr0Psliv8N2o3NZpJlMo= + +selfsigned@^1.10.7: + version "1.10.7" + resolved "https://registry.yarnpkg.com/selfsigned/-/selfsigned-1.10.7.tgz#da5819fd049d5574f28e88a9bcc6dbc6e6f3906b" + integrity sha512-8M3wBCzeWIJnQfl43IKwOmC4H/RAp50S8DF60znzjW5GVqTcSe2vWclt7hmYVPkKPlHWOu5EaWOMZ2Y6W8ZXTA== + dependencies: + node-forge "0.9.0" + +"semver@2 || 3 || 4 || 5", semver@^5.4.1, semver@^5.5.0, semver@^5.5.1, semver@^5.6.0: + version "5.7.1" + resolved "https://registry.yarnpkg.com/semver/-/semver-5.7.1.tgz#a954f931aeba508d307bbf069eff0c01c96116f7" + integrity sha512-sauaDf/PZdVgrLTNYHRtpXa1iRiKcaebiKQ1BJdpQlWH2lCvexQdX55snPFyK7QzpudqbCI0qXFfOasHdyNDGQ== + +semver@6.3.0, semver@^6.0.0, semver@^6.1.2, semver@^6.2.0, semver@^6.3.0: + version "6.3.0" + resolved "https://registry.yarnpkg.com/semver/-/semver-6.3.0.tgz#ee0a64c8af5e8ceea67687b133761e1becbd1d3d" + integrity sha512-b39TBaTSfV6yBrapU89p5fKekE2m/NwnDocOVruQFS1/veMgdzuPcnOM34M6CwxW8jH/lxEa5rBoDeUwu5HHTw== + +semver@7.0.0: + version "7.0.0" + resolved "https://registry.yarnpkg.com/semver/-/semver-7.0.0.tgz#5f3ca35761e47e05b206c6daff2cf814f0316b8e" + integrity sha512-+GB6zVA9LWh6zovYQLALHwv5rb2PHGlJi3lfiqIHxR0uuwCgefcOJc59v9fv1w8GbStwxuuqqAjI9NMAOOgq1A== + +send@0.17.1: + version "0.17.1" + resolved "https://registry.yarnpkg.com/send/-/send-0.17.1.tgz#c1d8b059f7900f7466dd4938bdc44e11ddb376c8" + integrity sha512-BsVKsiGcQMFwT8UxypobUKyv7irCNRHk1T0G680vk88yf6LBByGcZJOTJCrTP2xVN6yI+XjPJcNuE3V4fT9sAg== + dependencies: + debug "2.6.9" + depd "~1.1.2" + destroy "~1.0.4" + encodeurl "~1.0.2" + escape-html "~1.0.3" + etag "~1.8.1" + fresh "0.5.2" + http-errors "~1.7.2" + mime "1.6.0" + ms "2.1.1" + on-finished "~2.3.0" + range-parser "~1.2.1" + statuses "~1.5.0" + +serialize-javascript@^2.1.2: + version "2.1.2" + resolved "https://registry.yarnpkg.com/serialize-javascript/-/serialize-javascript-2.1.2.tgz#ecec53b0e0317bdc95ef76ab7074b7384785fa61" + integrity sha512-rs9OggEUF0V4jUSecXazOYsLfu7OGK2qIn3c7IPBiffz32XniEp/TX9Xmc9LQfK2nQ2QKHvZ2oygKUGU0lG4jQ== + +serve-index@^1.9.1: + version "1.9.1" + resolved "https://registry.yarnpkg.com/serve-index/-/serve-index-1.9.1.tgz#d3768d69b1e7d82e5ce050fff5b453bea12a9239" + integrity sha1-03aNabHn2C5c4FD/9bRTvqEqkjk= + dependencies: + accepts "~1.3.4" + batch "0.6.1" + debug "2.6.9" + escape-html "~1.0.3" + http-errors "~1.6.2" + mime-types "~2.1.17" + parseurl "~1.3.2" + +serve-static@1.14.1: + version "1.14.1" + resolved "https://registry.yarnpkg.com/serve-static/-/serve-static-1.14.1.tgz#666e636dc4f010f7ef29970a88a674320898b2f9" + integrity sha512-JMrvUwE54emCYWlTI+hGrGv5I8dEwmco/00EvkzIIsR7MqrHonbD9pO2MOfFnpFntl7ecpZs+3mW+XbQZu9QCg== + dependencies: + encodeurl "~1.0.2" + escape-html "~1.0.3" + parseurl "~1.3.3" + send "0.17.1" + +set-blocking@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/set-blocking/-/set-blocking-2.0.0.tgz#045f9782d011ae9a6803ddd382b24392b3d890f7" + integrity sha1-BF+XgtARrppoA93TgrJDkrPYkPc= 
+ +set-value@^2.0.0, set-value@^2.0.1: + version "2.0.1" + resolved "https://registry.yarnpkg.com/set-value/-/set-value-2.0.1.tgz#a18d40530e6f07de4228c7defe4227af8cad005b" + integrity sha512-JxHc1weCN68wRY0fhCoXpyK55m/XPHafOmK4UWD7m2CI14GMcFypt4w/0+NV5f/ZMby2F6S2wwA7fgynh9gWSw== + dependencies: + extend-shallow "^2.0.1" + is-extendable "^0.1.1" + is-plain-object "^2.0.3" + split-string "^3.0.1" + +setimmediate@^1.0.4: + version "1.0.5" + resolved "https://registry.yarnpkg.com/setimmediate/-/setimmediate-1.0.5.tgz#290cbb232e306942d7d7ea9b83732ab7856f8285" + integrity sha1-KQy7Iy4waULX1+qbg3Mqt4VvgoU= + +setprototypeof@1.1.0: + version "1.1.0" + resolved "https://registry.yarnpkg.com/setprototypeof/-/setprototypeof-1.1.0.tgz#d0bd85536887b6fe7c0d818cb962d9d91c54e656" + integrity sha512-BvE/TwpZX4FXExxOxZyRGQQv651MSwmWKZGqvmPcRIjDqWub67kTKuIMx43cZZrS/cBBzwBcNDWoFxt2XEFIpQ== + +setprototypeof@1.1.1: + version "1.1.1" + resolved "https://registry.yarnpkg.com/setprototypeof/-/setprototypeof-1.1.1.tgz#7e95acb24aa92f5885e0abef5ba131330d4ae683" + integrity sha512-JvdAWfbXeIGaZ9cILp38HntZSFSo3mWg6xGcJJsd+d4aRMOqauag1C63dJfDw7OaMYwEbHMOxEZ1lqVRYP2OAw== + +sha.js@^2.4.0, sha.js@^2.4.8: + version "2.4.11" + resolved "https://registry.yarnpkg.com/sha.js/-/sha.js-2.4.11.tgz#37a5cf0b81ecbc6943de109ba2960d1b26584ae7" + integrity sha512-QMEp5B7cftE7APOjk5Y6xgrbWu+WkLVQwk8JNjZ8nKRciZaByEW6MubieAiToS7+dwvrjGhH8jRXz3MVd0AYqQ== + dependencies: + inherits "^2.0.1" + safe-buffer "^5.0.1" + +shallow-clone@^0.1.2: + version "0.1.2" + resolved "https://registry.yarnpkg.com/shallow-clone/-/shallow-clone-0.1.2.tgz#5909e874ba77106d73ac414cfec1ffca87d97060" + integrity sha1-WQnodLp3EG1zrEFM/sH/yofZcGA= + dependencies: + is-extendable "^0.1.1" + kind-of "^2.0.1" + lazy-cache "^0.2.3" + mixin-object "^2.0.1" + +shallow-clone@^3.0.0: + version "3.0.1" + resolved "https://registry.yarnpkg.com/shallow-clone/-/shallow-clone-3.0.1.tgz#8f2981ad92531f55035b01fb230769a40e02efa3" + integrity sha512-/6KqX+GVUdqPuPPd2LxDDxzX6CAbjJehAAOKlNpqqUpAqPM6HeL8f+o3a+JsyGjn2lv0WY8UsTgUJjU9Ok55NA== + dependencies: + kind-of "^6.0.2" + +shebang-command@^1.2.0: + version "1.2.0" + resolved "https://registry.yarnpkg.com/shebang-command/-/shebang-command-1.2.0.tgz#44aac65b695b03398968c39f363fee5deafdf1ea" + integrity sha1-RKrGW2lbAzmJaMOfNj/uXer98eo= + dependencies: + shebang-regex "^1.0.0" + +shebang-command@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/shebang-command/-/shebang-command-2.0.0.tgz#ccd0af4f8835fbdc265b82461aaf0c36663f34ea" + integrity sha512-kHxr2zZpYtdmrN1qDjrrX/Z1rR1kG8Dx+gkpK1G4eXmvXswmcE1hTWBWYUzlraYw1/yZp6YuDY77YtvbN0dmDA== + dependencies: + shebang-regex "^3.0.0" + +shebang-regex@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/shebang-regex/-/shebang-regex-1.0.0.tgz#da42f49740c0b42db2ca9728571cb190c98efea3" + integrity sha1-2kL0l0DAtC2yypcoVxyxkMmO/qM= + +shebang-regex@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/shebang-regex/-/shebang-regex-3.0.0.tgz#ae16f1644d873ecad843b0307b143362d4c42172" + integrity sha512-7++dFhtcx3353uBaq8DDR4NuxBetBzC7ZQOhmTQInHEd6bSrXdiEyzCvG07Z44UYdLShWUyXt5M/yhz8ekcb1A== + +shell-quote@1.7.2: + version "1.7.2" + resolved "https://registry.yarnpkg.com/shell-quote/-/shell-quote-1.7.2.tgz#67a7d02c76c9da24f99d20808fcaded0e0e04be2" + integrity sha512-mRz/m/JVscCrkMyPqHc/bczi3OQHkLTqXHEFu0zDhK/qfv3UcOA4SVmRCLmos4bhjr9ekVQubj/R7waKapmiQg== + +shellwords@^0.1.1: + version "0.1.1" + resolved 
"https://registry.yarnpkg.com/shellwords/-/shellwords-0.1.1.tgz#d6b9181c1a48d397324c84871efbcfc73fc0654b" + integrity sha512-vFwSUfQvqybiICwZY5+DAWIPLKsWO31Q91JSKl3UYv+K5c2QRPzn0qzec6QPu1Qc9eHYItiP3NdJqNVqetYAww== + +side-channel@^1.0.2: + version "1.0.2" + resolved "https://registry.yarnpkg.com/side-channel/-/side-channel-1.0.2.tgz#df5d1abadb4e4bf4af1cd8852bf132d2f7876947" + integrity sha512-7rL9YlPHg7Ancea1S96Pa8/QWb4BtXL/TZvS6B8XFetGBeuhAsfmUspK6DokBeZ64+Kj9TCNRD/30pVz1BvQNA== + dependencies: + es-abstract "^1.17.0-next.1" + object-inspect "^1.7.0" + +signal-exit@^3.0.0, signal-exit@^3.0.2: + version "3.0.3" + resolved "https://registry.yarnpkg.com/signal-exit/-/signal-exit-3.0.3.tgz#a1410c2edd8f077b08b4e253c8eacfcaf057461c" + integrity sha512-VUJ49FC8U1OxwZLxIbTTrDvLnf/6TDgxZcK8wxR8zs13xpx7xbG60ndBlhNrFi2EMuFRoeDoJO7wthSLq42EjA== + +simple-swizzle@^0.2.2: + version "0.2.2" + resolved "https://registry.yarnpkg.com/simple-swizzle/-/simple-swizzle-0.2.2.tgz#a4da6b635ffcccca33f70d17cb92592de95e557a" + integrity sha1-pNprY1/8zMoz9w0Xy5JZLeleVXo= + dependencies: + is-arrayish "^0.3.1" + +sisteransi@^1.0.4: + version "1.0.5" + resolved "https://registry.yarnpkg.com/sisteransi/-/sisteransi-1.0.5.tgz#134d681297756437cc05ca01370d3a7a571075ed" + integrity sha512-bLGGlR1QxBcynn2d5YmDX4MGjlZvy2MRBDRNHLJ8VI6l6+9FUiyTFNJ0IveOSP0bcXgVDPRcfGqA0pjaqUpfVg== + +slash@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/slash/-/slash-1.0.0.tgz#c41f2f6c39fc16d1cd17ad4b5d896114ae470d55" + integrity sha1-xB8vbDn8FtHNF61LXYlhFK5HDVU= + +slash@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/slash/-/slash-2.0.0.tgz#de552851a1759df3a8f206535442f5ec4ddeab44" + integrity sha512-ZYKh3Wh2z1PpEXWr0MpSBZ0V6mZHAQfYevttO11c51CaWjGTaadiKZ+wVt1PbMlDV5qhMFslpZCemhwOK7C89A== + +slash@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/slash/-/slash-3.0.0.tgz#6539be870c165adbd5240220dbe361f1bc4d4634" + integrity sha512-g9Q1haeby36OSStwb4ntCGGGaKsaVSjQ68fBxoQcutl5fS1vuY18H3wSt3jFyFtrkx+Kz0V1G85A4MyAdDMi2Q== + +slice-ansi@^2.1.0: + version "2.1.0" + resolved "https://registry.yarnpkg.com/slice-ansi/-/slice-ansi-2.1.0.tgz#cacd7693461a637a5788d92a7dd4fba068e81636" + integrity sha512-Qu+VC3EwYLldKa1fCxuuvULvSJOKEgk9pi8dZeCVK7TqBfUNTH4sFkk4joj8afVSfAYgJoSOetjx9QWOJ5mYoQ== + dependencies: + ansi-styles "^3.2.0" + astral-regex "^1.0.0" + is-fullwidth-code-point "^2.0.0" + +snapdragon-node@^2.0.1: + version "2.1.1" + resolved "https://registry.yarnpkg.com/snapdragon-node/-/snapdragon-node-2.1.1.tgz#6c175f86ff14bdb0724563e8f3c1b021a286853b" + integrity sha512-O27l4xaMYt/RSQ5TR3vpWCAB5Kb/czIcqUFOM/C4fYcLnbZUc1PkjTAMjof2pBWaSTwOUd6qUHcFGVGj7aIwnw== + dependencies: + define-property "^1.0.0" + isobject "^3.0.0" + snapdragon-util "^3.0.1" + +snapdragon-util@^3.0.1: + version "3.0.1" + resolved "https://registry.yarnpkg.com/snapdragon-util/-/snapdragon-util-3.0.1.tgz#f956479486f2acd79700693f6f7b805e45ab56e2" + integrity sha512-mbKkMdQKsjX4BAL4bRYTj21edOf8cN7XHdYUJEe+Zn99hVEYcMvKPct1IqNe7+AZPirn8BCDOQBHQZknqmKlZQ== + dependencies: + kind-of "^3.2.0" + +snapdragon@^0.8.1: + version "0.8.2" + resolved "https://registry.yarnpkg.com/snapdragon/-/snapdragon-0.8.2.tgz#64922e7c565b0e14204ba1aa7d6964278d25182d" + integrity sha512-FtyOnWN/wCHTVXOMwvSv26d+ko5vWlIDD6zoUJ7LW8vh+ZBC8QdljveRP+crNrtBwioEUWy/4dMtbBjA4ioNlg== + dependencies: + base "^0.11.1" + debug "^2.2.0" + define-property "^0.2.5" + extend-shallow "^2.0.1" + map-cache "^0.2.2" + source-map "^0.5.6" + source-map-resolve "^0.5.0" + use 
"^3.1.0" + +sockjs-client@1.4.0: + version "1.4.0" + resolved "https://registry.yarnpkg.com/sockjs-client/-/sockjs-client-1.4.0.tgz#c9f2568e19c8fd8173b4997ea3420e0bb306c7d5" + integrity sha512-5zaLyO8/nri5cua0VtOrFXBPK1jbL4+1cebT/mmKA1E1ZXOvJrII75bPu0l0k843G/+iAbhEqzyKr0w/eCCj7g== + dependencies: + debug "^3.2.5" + eventsource "^1.0.7" + faye-websocket "~0.11.1" + inherits "^2.0.3" + json3 "^3.3.2" + url-parse "^1.4.3" + +sockjs@0.3.19: + version "0.3.19" + resolved "https://registry.yarnpkg.com/sockjs/-/sockjs-0.3.19.tgz#d976bbe800af7bd20ae08598d582393508993c0d" + integrity sha512-V48klKZl8T6MzatbLlzzRNhMepEys9Y4oGFpypBFFn1gLI/QQ9HtLLyWJNbPlwGLelOVOEijUbTTJeLLI59jLw== + dependencies: + faye-websocket "^0.10.0" + uuid "^3.0.1" + +sort-keys@^1.0.0: + version "1.1.2" + resolved "https://registry.yarnpkg.com/sort-keys/-/sort-keys-1.1.2.tgz#441b6d4d346798f1b4e49e8920adfba0e543f9ad" + integrity sha1-RBttTTRnmPG05J6JIK37oOVD+a0= + dependencies: + is-plain-obj "^1.0.0" + +source-list-map@^2.0.0: + version "2.0.1" + resolved "https://registry.yarnpkg.com/source-list-map/-/source-list-map-2.0.1.tgz#3993bd873bfc48479cca9ea3a547835c7c154b34" + integrity sha512-qnQ7gVMxGNxsiL4lEuJwe/To8UnK7fAnmbGEEH8RpLouuKbeEm0lhbQVFIrNSuB+G7tVrAlVsZgETT5nljf+Iw== + +source-map-resolve@^0.5.0, source-map-resolve@^0.5.2: + version "0.5.3" + resolved "https://registry.yarnpkg.com/source-map-resolve/-/source-map-resolve-0.5.3.tgz#190866bece7553e1f8f267a2ee82c606b5509a1a" + integrity sha512-Htz+RnsXWk5+P2slx5Jh3Q66vhQj1Cllm0zvnaY98+NFx+Dv2CF/f5O/t8x+KaNdrdIAsruNzoh/KpialbqAnw== + dependencies: + atob "^2.1.2" + decode-uri-component "^0.2.0" + resolve-url "^0.2.1" + source-map-url "^0.4.0" + urix "^0.1.0" + +source-map-support@^0.5.6, source-map-support@~0.5.12: + version "0.5.16" + resolved "https://registry.yarnpkg.com/source-map-support/-/source-map-support-0.5.16.tgz#0ae069e7fe3ba7538c64c98515e35339eac5a042" + integrity sha512-efyLRJDr68D9hBBNIPWFjhpFzURh+KJykQwvMyW5UiZzYwoF6l4YMMDIJJEyFWxWCqfyxLzz6tSfUFR+kXXsVQ== + dependencies: + buffer-from "^1.0.0" + source-map "^0.6.0" + +source-map-url@^0.4.0: + version "0.4.0" + resolved "https://registry.yarnpkg.com/source-map-url/-/source-map-url-0.4.0.tgz#3e935d7ddd73631b97659956d55128e87b5084a3" + integrity sha1-PpNdfd1zYxuXZZlW1VEo6HtQhKM= + +source-map@0.6.1, source-map@^0.6.0, source-map@^0.6.1, source-map@~0.6.0, source-map@~0.6.1: + version "0.6.1" + resolved "https://registry.yarnpkg.com/source-map/-/source-map-0.6.1.tgz#74722af32e9614e9c287a8d0bbde48b5e2f1a263" + integrity sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g== + +source-map@^0.5.0, source-map@^0.5.6: + version "0.5.7" + resolved "https://registry.yarnpkg.com/source-map/-/source-map-0.5.7.tgz#8a039d2d1021d22d1ea14c80d8ea468ba2ef3fcc" + integrity sha1-igOdLRAh0i0eoUyA2OpGi6LvP8w= + +spdx-correct@^3.0.0: + version "3.1.0" + resolved "https://registry.yarnpkg.com/spdx-correct/-/spdx-correct-3.1.0.tgz#fb83e504445268f154b074e218c87c003cd31df4" + integrity sha512-lr2EZCctC2BNR7j7WzJ2FpDznxky1sjfxvvYEyzxNyb6lZXHODmEoJeFu4JupYlkfha1KZpJyoqiJ7pgA1qq8Q== + dependencies: + spdx-expression-parse "^3.0.0" + spdx-license-ids "^3.0.0" + +spdx-exceptions@^2.1.0: + version "2.2.0" + resolved "https://registry.yarnpkg.com/spdx-exceptions/-/spdx-exceptions-2.2.0.tgz#2ea450aee74f2a89bfb94519c07fcd6f41322977" + integrity sha512-2XQACfElKi9SlVb1CYadKDXvoajPgBVPn/gOQLrTvHdElaVhr7ZEbqJaRnJLVNeaI4cMEAgVCeBMKF6MWRDCRA== + +spdx-expression-parse@^3.0.0: + version "3.0.0" + resolved 
"https://registry.yarnpkg.com/spdx-expression-parse/-/spdx-expression-parse-3.0.0.tgz#99e119b7a5da00e05491c9fa338b7904823b41d0" + integrity sha512-Yg6D3XpRD4kkOmTpdgbUiEJFKghJH03fiC1OPll5h/0sO6neh2jqRDVHOQ4o/LMea0tgCkbMgea5ip/e+MkWyg== + dependencies: + spdx-exceptions "^2.1.0" + spdx-license-ids "^3.0.0" + +spdx-license-ids@^3.0.0: + version "3.0.5" + resolved "https://registry.yarnpkg.com/spdx-license-ids/-/spdx-license-ids-3.0.5.tgz#3694b5804567a458d3c8045842a6358632f62654" + integrity sha512-J+FWzZoynJEXGphVIS+XEh3kFSjZX/1i9gFBaWQcB+/tmpe2qUsSBABpcxqxnAxFdiUFEgAX1bjYGQvIZmoz9Q== + +spdy-transport@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/spdy-transport/-/spdy-transport-3.0.0.tgz#00d4863a6400ad75df93361a1608605e5dcdcf31" + integrity sha512-hsLVFE5SjA6TCisWeJXFKniGGOpBgMLmerfO2aCyCU5s7nJ/rpAepqmFifv/GCbSbueEeAJJnmSQ2rKC/g8Fcw== + dependencies: + debug "^4.1.0" + detect-node "^2.0.4" + hpack.js "^2.1.6" + obuf "^1.1.2" + readable-stream "^3.0.6" + wbuf "^1.7.3" + +spdy@^4.0.1: + version "4.0.1" + resolved "https://registry.yarnpkg.com/spdy/-/spdy-4.0.1.tgz#6f12ed1c5db7ea4f24ebb8b89ba58c87c08257f2" + integrity sha512-HeZS3PBdMA+sZSu0qwpCxl3DeALD5ASx8pAX0jZdKXSpPWbQ6SYGnlg3BBmYLx5LtiZrmkAZfErCm2oECBcioA== + dependencies: + debug "^4.1.0" + handle-thing "^2.0.0" + http-deceiver "^1.2.7" + select-hose "^2.0.0" + spdy-transport "^3.0.0" + +split-string@^3.0.1, split-string@^3.0.2: + version "3.1.0" + resolved "https://registry.yarnpkg.com/split-string/-/split-string-3.1.0.tgz#7cb09dda3a86585705c64b39a6466038682e8fe2" + integrity sha512-NzNVhJDYpwceVVii8/Hu6DKfD2G+NrQHlS/V/qgv763EYudVwEcMQNxd2lh+0VrUByXN/oJkl5grOhYWvQUYiw== + dependencies: + extend-shallow "^3.0.0" + +sprintf-js@~1.0.2: + version "1.0.3" + resolved "https://registry.yarnpkg.com/sprintf-js/-/sprintf-js-1.0.3.tgz#04e6926f662895354f3dd015203633b857297e2c" + integrity sha1-BOaSb2YolTVPPdAVIDYzuFcpfiw= + +sshpk@^1.7.0: + version "1.16.1" + resolved "https://registry.yarnpkg.com/sshpk/-/sshpk-1.16.1.tgz#fb661c0bef29b39db40769ee39fa70093d6f6877" + integrity sha512-HXXqVUq7+pcKeLqqZj6mHFUMvXtOJt1uoUx09pFW6011inTMxqI8BA8PM95myrIyyKwdnzjdFjLiE6KBPVtJIg== + dependencies: + asn1 "~0.2.3" + assert-plus "^1.0.0" + bcrypt-pbkdf "^1.0.0" + dashdash "^1.12.0" + ecc-jsbn "~0.1.1" + getpass "^0.1.1" + jsbn "~0.1.0" + safer-buffer "^2.0.2" + tweetnacl "~0.14.0" + +ssri@^6.0.1: + version "6.0.1" + resolved "https://registry.yarnpkg.com/ssri/-/ssri-6.0.1.tgz#2a3c41b28dd45b62b63676ecb74001265ae9edd8" + integrity sha512-3Wge10hNcT1Kur4PDFwEieXSCMCJs/7WvSACcrMYrNp+b8kDL1/0wJch5Ni2WrtwEa2IO8OsVfeKIciKCDx/QA== + dependencies: + figgy-pudding "^3.5.1" + +ssri@^7.0.0: + version "7.1.0" + resolved "https://registry.yarnpkg.com/ssri/-/ssri-7.1.0.tgz#92c241bf6de82365b5c7fb4bd76e975522e1294d" + integrity sha512-77/WrDZUWocK0mvA5NTRQyveUf+wsrIc6vyrxpS8tVvYBcX215QbafrJR3KtkpskIzoFLqqNuuYQvxaMjXJ/0g== + dependencies: + figgy-pudding "^3.5.1" + minipass "^3.1.1" + +stable@^0.1.8: + version "0.1.8" + resolved "https://registry.yarnpkg.com/stable/-/stable-0.1.8.tgz#836eb3c8382fe2936feaf544631017ce7d47a3cf" + integrity sha512-ji9qxRnOVfcuLDySj9qzhGSEFVobyt1kIOSkj1qZzYLzq7Tos/oUUWvotUPQLlrsidqsK6tBH89Bc9kL5zHA6w== + +stack-utils@^1.0.1: + version "1.0.2" + resolved "https://registry.yarnpkg.com/stack-utils/-/stack-utils-1.0.2.tgz#33eba3897788558bebfc2db059dc158ec36cebb8" + integrity sha512-MTX+MeG5U994cazkjd/9KNAapsHnibjMLnfXodlkXw76JEea0UiNzrqidzo1emMwk7w5Qhc9jd4Bn9TBb1MFwA== + +static-extend@^0.1.1: + version "0.1.2" + resolved 
"https://registry.yarnpkg.com/static-extend/-/static-extend-0.1.2.tgz#60809c39cbff55337226fd5e0b520f341f1fb5c6" + integrity sha1-YICcOcv/VTNyJv1eC1IPNB8ftcY= + dependencies: + define-property "^0.2.5" + object-copy "^0.1.0" + +"statuses@>= 1.4.0 < 2", "statuses@>= 1.5.0 < 2", statuses@~1.5.0: + version "1.5.0" + resolved "https://registry.yarnpkg.com/statuses/-/statuses-1.5.0.tgz#161c7dac177659fd9811f43771fa99381478628c" + integrity sha1-Fhx9rBd2Wf2YEfQ3cfqZOBR4Yow= + +stealthy-require@^1.1.1: + version "1.1.1" + resolved "https://registry.yarnpkg.com/stealthy-require/-/stealthy-require-1.1.1.tgz#35b09875b4ff49f26a777e509b3090a3226bf24b" + integrity sha1-NbCYdbT/SfJqd35QmzCQoyJr8ks= + +stream-browserify@^2.0.1: + version "2.0.2" + resolved "https://registry.yarnpkg.com/stream-browserify/-/stream-browserify-2.0.2.tgz#87521d38a44aa7ee91ce1cd2a47df0cb49dd660b" + integrity sha512-nX6hmklHs/gr2FuxYDltq8fJA1GDlxKQCz8O/IM4atRqBH8OORmBNgfvW5gG10GT/qQ9u0CzIvr2X5Pkt6ntqg== + dependencies: + inherits "~2.0.1" + readable-stream "^2.0.2" + +stream-each@^1.1.0: + version "1.2.3" + resolved "https://registry.yarnpkg.com/stream-each/-/stream-each-1.2.3.tgz#ebe27a0c389b04fbcc233642952e10731afa9bae" + integrity sha512-vlMC2f8I2u/bZGqkdfLQW/13Zihpej/7PmSiMQsbYddxuTsJp8vRe2x2FvVExZg7FaOds43ROAuFJwPR4MTZLw== + dependencies: + end-of-stream "^1.1.0" + stream-shift "^1.0.0" + +stream-http@^2.7.2: + version "2.8.3" + resolved "https://registry.yarnpkg.com/stream-http/-/stream-http-2.8.3.tgz#b2d242469288a5a27ec4fe8933acf623de6514fc" + integrity sha512-+TSkfINHDo4J+ZobQLWiMouQYB+UVYFttRA94FpEzzJ7ZdqcL4uUUQ7WkdkI4DSozGmgBUE/a47L+38PenXhUw== + dependencies: + builtin-status-codes "^3.0.0" + inherits "^2.0.1" + readable-stream "^2.3.6" + to-arraybuffer "^1.0.0" + xtend "^4.0.0" + +stream-shift@^1.0.0: + version "1.0.1" + resolved "https://registry.yarnpkg.com/stream-shift/-/stream-shift-1.0.1.tgz#d7088281559ab2778424279b0877da3c392d5a3d" + integrity sha512-AiisoFqQ0vbGcZgQPY1cdP2I76glaVA/RauYR4G4thNFgkTqr90yXTo4LYX60Jl+sIlPNHHdGSwo01AvbKUSVQ== + +strict-uri-encode@^1.0.0: + version "1.1.0" + resolved "https://registry.yarnpkg.com/strict-uri-encode/-/strict-uri-encode-1.1.0.tgz#279b225df1d582b1f54e65addd4352e18faa0713" + integrity sha1-J5siXfHVgrH1TmWt3UNS4Y+qBxM= + +string-length@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/string-length/-/string-length-2.0.0.tgz#d40dbb686a3ace960c1cffca562bf2c45f8363ed" + integrity sha1-1A27aGo6zpYMHP/KVivyxF+DY+0= + dependencies: + astral-regex "^1.0.0" + strip-ansi "^4.0.0" + +string-length@^3.1.0: + version "3.1.0" + resolved "https://registry.yarnpkg.com/string-length/-/string-length-3.1.0.tgz#107ef8c23456e187a8abd4a61162ff4ac6e25837" + integrity sha512-Ttp5YvkGm5v9Ijagtaz1BnN+k9ObpvS0eIBblPMp2YWL8FBmi9qblQ9fexc2k/CXFgrTIteU3jAw3payCnwSTA== + dependencies: + astral-regex "^1.0.0" + strip-ansi "^5.2.0" + +string-width@^1.0.1: + version "1.0.2" + resolved "https://registry.yarnpkg.com/string-width/-/string-width-1.0.2.tgz#118bdf5b8cdc51a2a7e70d211e07e2b0b9b107d3" + integrity sha1-EYvfW4zcUaKn5w0hHgfisLmxB9M= + dependencies: + code-point-at "^1.0.0" + is-fullwidth-code-point "^1.0.0" + strip-ansi "^3.0.0" + +string-width@^2.0.0, string-width@^2.1.1: + version "2.1.1" + resolved "https://registry.yarnpkg.com/string-width/-/string-width-2.1.1.tgz#ab93f27a8dc13d28cac815c462143a6d9012ae9e" + integrity sha512-nOqH59deCq9SRHlxq1Aw85Jnt4w6KvLKqWVik6oA9ZklXLNIOlqg4F2yrT1MVaTjAqvVwdfeZ7w7aCvJD7ugkw== + dependencies: + is-fullwidth-code-point "^2.0.0" + strip-ansi 
"^4.0.0" + +string-width@^3.0.0, string-width@^3.1.0: + version "3.1.0" + resolved "https://registry.yarnpkg.com/string-width/-/string-width-3.1.0.tgz#22767be21b62af1081574306f69ac51b62203961" + integrity sha512-vafcv6KjVZKSgz06oM/H6GDBrAtz8vdhQakGjFIvNrHA6y3HCF1CInLy+QLq8dTJPQ1b+KDUqDFctkdRW44e1w== + dependencies: + emoji-regex "^7.0.1" + is-fullwidth-code-point "^2.0.0" + strip-ansi "^5.1.0" + +string-width@^4.1.0: + version "4.2.0" + resolved "https://registry.yarnpkg.com/string-width/-/string-width-4.2.0.tgz#952182c46cc7b2c313d1596e623992bd163b72b5" + integrity sha512-zUz5JD+tgqtuDjMhwIg5uFVV3dtqZ9yQJlZVfq4I01/K5Paj5UHj7VyrQOJvzawSVlKpObApbfD0Ed6yJc+1eg== + dependencies: + emoji-regex "^8.0.0" + is-fullwidth-code-point "^3.0.0" + strip-ansi "^6.0.0" + +string.prototype.matchall@^4.0.2: + version "4.0.2" + resolved "https://registry.yarnpkg.com/string.prototype.matchall/-/string.prototype.matchall-4.0.2.tgz#48bb510326fb9fdeb6a33ceaa81a6ea04ef7648e" + integrity sha512-N/jp6O5fMf9os0JU3E72Qhf590RSRZU/ungsL/qJUYVTNv7hTG0P/dbPjxINVN9jpscu3nzYwKESU3P3RY5tOg== + dependencies: + define-properties "^1.1.3" + es-abstract "^1.17.0" + has-symbols "^1.0.1" + internal-slot "^1.0.2" + regexp.prototype.flags "^1.3.0" + side-channel "^1.0.2" + +string.prototype.trimend@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/string.prototype.trimend/-/string.prototype.trimend-1.0.0.tgz#ee497fd29768646d84be2c9b819e292439614373" + integrity sha512-EEJnGqa/xNfIg05SxiPSqRS7S9qwDhYts1TSLR1BQfYUfPe1stofgGKvwERK9+9yf+PpfBMlpBaCHucXGPQfUA== + dependencies: + define-properties "^1.1.3" + es-abstract "^1.17.5" + +string.prototype.trimleft@^2.1.1: + version "2.1.2" + resolved "https://registry.yarnpkg.com/string.prototype.trimleft/-/string.prototype.trimleft-2.1.2.tgz#4408aa2e5d6ddd0c9a80739b087fbc067c03b3cc" + integrity sha512-gCA0tza1JBvqr3bfAIFJGqfdRTyPae82+KTnm3coDXkZN9wnuW3HjGgN386D7hfv5CHQYCI022/rJPVlqXyHSw== + dependencies: + define-properties "^1.1.3" + es-abstract "^1.17.5" + string.prototype.trimstart "^1.0.0" + +string.prototype.trimright@^2.1.1: + version "2.1.2" + resolved "https://registry.yarnpkg.com/string.prototype.trimright/-/string.prototype.trimright-2.1.2.tgz#c76f1cef30f21bbad8afeb8db1511496cfb0f2a3" + integrity sha512-ZNRQ7sY3KroTaYjRS6EbNiiHrOkjihL9aQE/8gfQ4DtAC/aEBRHFJa44OmoWxGGqXuJlfKkZW4WcXErGr+9ZFg== + dependencies: + define-properties "^1.1.3" + es-abstract "^1.17.5" + string.prototype.trimend "^1.0.0" + +string.prototype.trimstart@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/string.prototype.trimstart/-/string.prototype.trimstart-1.0.0.tgz#afe596a7ce9de905496919406c9734845f01a2f2" + integrity sha512-iCP8g01NFYiiBOnwG1Xc3WZLyoo+RuBymwIlWncShXDDJYWN6DbnM3odslBJdgCdRlq94B5s63NWAZlcn2CS4w== + dependencies: + define-properties "^1.1.3" + es-abstract "^1.17.5" + +string_decoder@^1.0.0, string_decoder@^1.1.1: + version "1.3.0" + resolved "https://registry.yarnpkg.com/string_decoder/-/string_decoder-1.3.0.tgz#42f114594a46cf1a8e30b0a84f56c78c3edac21e" + integrity sha512-hkRX8U1WjJFd8LsDJ2yQ/wWWxaopEsABU1XfkM8A+j0+85JAGppt16cr1Whg6KIbb4okU6Mql6BOj+uup/wKeA== + dependencies: + safe-buffer "~5.2.0" + +string_decoder@~1.1.1: + version "1.1.1" + resolved "https://registry.yarnpkg.com/string_decoder/-/string_decoder-1.1.1.tgz#9cf1611ba62685d7030ae9e4ba34149c3af03fc8" + integrity sha512-n/ShnvDi6FHbbVfviro+WojiFzv+s8MPMHBczVePfUpDJLwoLT0ht1l4YwBCbi8pJAveEEdnkHyPyTP/mzRfwg== + dependencies: + safe-buffer "~5.1.0" + +stringify-object@^3.3.0: + version "3.3.0" + resolved 
"https://registry.yarnpkg.com/stringify-object/-/stringify-object-3.3.0.tgz#703065aefca19300d3ce88af4f5b3956d7556629" + integrity sha512-rHqiFh1elqCQ9WPLIC8I0Q/g/wj5J1eMkyoiD6eoQApWHP0FtlK7rqnhmabL5VUY9JQCcqwwvlOaSuutekgyrw== + dependencies: + get-own-enumerable-property-symbols "^3.0.0" + is-obj "^1.0.1" + is-regexp "^1.0.0" + +strip-ansi@6.0.0, strip-ansi@^6.0.0: + version "6.0.0" + resolved "https://registry.yarnpkg.com/strip-ansi/-/strip-ansi-6.0.0.tgz#0b1571dd7669ccd4f3e06e14ef1eed26225ae532" + integrity sha512-AuvKTrTfQNYNIctbR1K/YGTR1756GycPsg7b9bdV9Duqur4gv6aKqHXah67Z8ImS7WEz5QVcOtlfW2rZEugt6w== + dependencies: + ansi-regex "^5.0.0" + +strip-ansi@^3.0.0, strip-ansi@^3.0.1: + version "3.0.1" + resolved "https://registry.yarnpkg.com/strip-ansi/-/strip-ansi-3.0.1.tgz#6a385fb8853d952d5ff05d0e8aaf94278dc63dcf" + integrity sha1-ajhfuIU9lS1f8F0Oiq+UJ43GPc8= + dependencies: + ansi-regex "^2.0.0" + +strip-ansi@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/strip-ansi/-/strip-ansi-4.0.0.tgz#a8479022eb1ac368a871389b635262c505ee368f" + integrity sha1-qEeQIusaw2iocTibY1JixQXuNo8= + dependencies: + ansi-regex "^3.0.0" + +strip-ansi@^5.0.0, strip-ansi@^5.1.0, strip-ansi@^5.2.0: + version "5.2.0" + resolved "https://registry.yarnpkg.com/strip-ansi/-/strip-ansi-5.2.0.tgz#8c9a536feb6afc962bdfa5b104a5091c1ad9c0ae" + integrity sha512-DuRs1gKbBqsMKIZlrffwlug8MHkcnpjs5VPmL1PAh+mA30U0DTotfDZ0d2UUsXpPmPmMMJ6W773MaA3J+lbiWA== + dependencies: + ansi-regex "^4.1.0" + +strip-bom@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/strip-bom/-/strip-bom-3.0.0.tgz#2334c18e9c759f7bdd56fdef7e9ae3d588e68ed3" + integrity sha1-IzTBjpx1n3vdVv3vfprj1YjmjtM= + +strip-comments@^1.0.2: + version "1.0.2" + resolved "https://registry.yarnpkg.com/strip-comments/-/strip-comments-1.0.2.tgz#82b9c45e7f05873bee53f37168af930aa368679d" + integrity sha512-kL97alc47hoyIQSV165tTt9rG5dn4w1dNnBhOQ3bOU1Nc1hel09jnXANaHJ7vzHLd4Ju8kseDGzlev96pghLFw== + dependencies: + babel-extract-comments "^1.0.0" + babel-plugin-transform-object-rest-spread "^6.26.0" + +strip-eof@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/strip-eof/-/strip-eof-1.0.0.tgz#bb43ff5598a6eb05d89b59fcd129c983313606bf" + integrity sha1-u0P/VZim6wXYm1n80SnJgzE2Br8= + +strip-json-comments@^3.0.1: + version "3.0.1" + resolved "https://registry.yarnpkg.com/strip-json-comments/-/strip-json-comments-3.0.1.tgz#85713975a91fb87bf1b305cca77395e40d2a64a7" + integrity sha512-VTyMAUfdm047mwKl+u79WIdrZxtFtn+nBxHeb844XBQ9uMNTuTHdx2hc5RiAJYqwTj3wc/xe5HLSdJSkJ+WfZw== + +style-loader@0.23.1: + version "0.23.1" + resolved "https://registry.yarnpkg.com/style-loader/-/style-loader-0.23.1.tgz#cb9154606f3e771ab6c4ab637026a1049174d925" + integrity sha512-XK+uv9kWwhZMZ1y7mysB+zoihsEj4wneFWAS5qoiLwzW0WzSqMrrsIy+a3zkQJq0ipFtBpX5W3MqyRIBF/WFGg== + dependencies: + loader-utils "^1.1.0" + schema-utils "^1.0.0" + +stylehacks@^4.0.0: + version "4.0.3" + resolved "https://registry.yarnpkg.com/stylehacks/-/stylehacks-4.0.3.tgz#6718fcaf4d1e07d8a1318690881e8d96726a71d5" + integrity sha512-7GlLk9JwlElY4Y6a/rmbH2MhVlTyVmiJd1PfTCqFaIBEGMYNsrO/v3SeGTdhBThLg4Z+NbOk/qFMwCa+J+3p/g== + dependencies: + browserslist "^4.0.0" + postcss "^7.0.0" + postcss-selector-parser "^3.0.0" + +supports-color@6.1.0, supports-color@^6.1.0: + version "6.1.0" + resolved "https://registry.yarnpkg.com/supports-color/-/supports-color-6.1.0.tgz#0764abc69c63d5ac842dd4867e8d025e880df8f3" + integrity 
sha512-qe1jfm1Mg7Nq/NSh6XE24gPXROEVsWHxC1LIx//XNlD9iw7YZQGjZNjYN7xGaEG6iKdA8EtNFW6R0gjnVXp+wQ== + dependencies: + has-flag "^3.0.0" + +supports-color@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/supports-color/-/supports-color-2.0.0.tgz#535d045ce6b6363fa40117084629995e9df324c7" + integrity sha1-U10EXOa2Nj+kARcIRimZXp3zJMc= + +supports-color@^5.3.0: + version "5.5.0" + resolved "https://registry.yarnpkg.com/supports-color/-/supports-color-5.5.0.tgz#e2e69a44ac8772f78a1ec0b35b689df6530efc8f" + integrity sha512-QjVjwdXIt408MIiAqCX4oUKsgU2EqAGzs2Ppkm4aQYbjm+ZEWEcW4SfFNTr4uMNZma0ey4f5lgLrkB0aX0QMow== + dependencies: + has-flag "^3.0.0" + +supports-color@^7.0.0, supports-color@^7.1.0: + version "7.1.0" + resolved "https://registry.yarnpkg.com/supports-color/-/supports-color-7.1.0.tgz#68e32591df73e25ad1c4b49108a2ec507962bfd1" + integrity sha512-oRSIpR8pxT1Wr2FquTNnGet79b3BWljqOuoW/h4oBhxJ/HUbX5nX6JSruTkvXDCFMwDPvsaTTbvMLKZWSy0R5g== + dependencies: + has-flag "^4.0.0" + +svg-parser@^2.0.0: + version "2.0.4" + resolved "https://registry.yarnpkg.com/svg-parser/-/svg-parser-2.0.4.tgz#fdc2e29e13951736140b76cb122c8ee6630eb6b5" + integrity sha512-e4hG1hRwoOdRb37cIMSgzNsxyzKfayW6VOflrwvR+/bzrkyxY/31WkbgnQpgtrNp1SdpJvpUAGTa/ZoiPNDuRQ== + +svgo@^1.0.0, svgo@^1.2.2: + version "1.3.2" + resolved "https://registry.yarnpkg.com/svgo/-/svgo-1.3.2.tgz#b6dc511c063346c9e415b81e43401145b96d4167" + integrity sha512-yhy/sQYxR5BkC98CY7o31VGsg014AKLEPxdfhora76l36hD9Rdy5NZA/Ocn6yayNPgSamYdtX2rFJdcv07AYVw== + dependencies: + chalk "^2.4.1" + coa "^2.0.2" + css-select "^2.0.0" + css-select-base-adapter "^0.1.1" + css-tree "1.0.0-alpha.37" + csso "^4.0.2" + js-yaml "^3.13.1" + mkdirp "~0.5.1" + object.values "^1.1.0" + sax "~1.2.4" + stable "^0.1.8" + unquote "~1.1.1" + util.promisify "~1.0.0" + +symbol-tree@^3.2.2: + version "3.2.4" + resolved "https://registry.yarnpkg.com/symbol-tree/-/symbol-tree-3.2.4.tgz#430637d248ba77e078883951fb9aa0eed7c63fa2" + integrity sha512-9QNk5KwDF+Bvz+PyObkmSYjI5ksVUYtjW7AU22r2NKcfLJcXp96hkDWU3+XndOsUb+AQ9QhfzfCT2O+CNWT5Tw== + +table@^5.2.3: + version "5.4.6" + resolved "https://registry.yarnpkg.com/table/-/table-5.4.6.tgz#1292d19500ce3f86053b05f0e8e7e4a3bb21079e" + integrity sha512-wmEc8m4fjnob4gt5riFRtTu/6+4rSe12TpAELNSqHMfF3IqnA+CH37USM6/YR3qRZv7e56kAEAtd6nKZaxe0Ug== + dependencies: + ajv "^6.10.2" + lodash "^4.17.14" + slice-ansi "^2.1.0" + string-width "^3.0.0" + +tapable@^1.0.0, tapable@^1.1.3: + version "1.1.3" + resolved "https://registry.yarnpkg.com/tapable/-/tapable-1.1.3.tgz#a1fccc06b58db61fd7a45da2da44f5f3a3e67ba2" + integrity sha512-4WK/bYZmj8xLr+HUCODHGF1ZFzsYffasLUgEiMBY4fgtltdO6B4WJtlSbPaDTLpYTcGVwM2qLnFTICEcNxs3kA== + +terser-webpack-plugin@2.3.5: + version "2.3.5" + resolved "https://registry.yarnpkg.com/terser-webpack-plugin/-/terser-webpack-plugin-2.3.5.tgz#5ad971acce5c517440ba873ea4f09687de2f4a81" + integrity sha512-WlWksUoq+E4+JlJ+h+U+QUzXpcsMSSNXkDy9lBVkSqDn1w23Gg29L/ary9GeJVYCGiNJJX7LnVc4bwL1N3/g1w== + dependencies: + cacache "^13.0.1" + find-cache-dir "^3.2.0" + jest-worker "^25.1.0" + p-limit "^2.2.2" + schema-utils "^2.6.4" + serialize-javascript "^2.1.2" + source-map "^0.6.1" + terser "^4.4.3" + webpack-sources "^1.4.3" + +terser-webpack-plugin@^1.4.3: + version "1.4.3" + resolved "https://registry.yarnpkg.com/terser-webpack-plugin/-/terser-webpack-plugin-1.4.3.tgz#5ecaf2dbdc5fb99745fd06791f46fc9ddb1c9a7c" + integrity sha512-QMxecFz/gHQwteWwSo5nTc6UaICqN1bMedC5sMtUc7y3Ha3Q8y6ZO0iCR8pq4RJC8Hjf0FEPEHZqcMB/+DFCrA== + dependencies: + cacache 
"^12.0.2" + find-cache-dir "^2.1.0" + is-wsl "^1.1.0" + schema-utils "^1.0.0" + serialize-javascript "^2.1.2" + source-map "^0.6.1" + terser "^4.1.2" + webpack-sources "^1.4.0" + worker-farm "^1.7.0" + +terser@^4.1.2, terser@^4.4.3, terser@^4.6.3: + version "4.6.10" + resolved "https://registry.yarnpkg.com/terser/-/terser-4.6.10.tgz#90f5bd069ff456ddbc9503b18e52f9c493d3b7c2" + integrity sha512-qbF/3UOo11Hggsbsqm2hPa6+L4w7bkr+09FNseEe8xrcVD3APGLFqE+Oz1ZKAxjYnFsj80rLOfgAtJ0LNJjtTA== + dependencies: + commander "^2.20.0" + source-map "~0.6.1" + source-map-support "~0.5.12" + +test-exclude@^5.2.3: + version "5.2.3" + resolved "https://registry.yarnpkg.com/test-exclude/-/test-exclude-5.2.3.tgz#c3d3e1e311eb7ee405e092dac10aefd09091eac0" + integrity sha512-M+oxtseCFO3EDtAaGH7iiej3CBkzXqFMbzqYAACdzKui4eZA+pq3tZEwChvOdNfa7xxy8BfbmgJSIr43cC/+2g== + dependencies: + glob "^7.1.3" + minimatch "^3.0.4" + read-pkg-up "^4.0.0" + require-main-filename "^2.0.0" + +text-table@0.2.0, text-table@^0.2.0: + version "0.2.0" + resolved "https://registry.yarnpkg.com/text-table/-/text-table-0.2.0.tgz#7f5ee823ae805207c00af2df4a84ec3fcfa570b4" + integrity sha1-f17oI66AUgfACvLfSoTsP8+lcLQ= + +throat@^4.0.0: + version "4.1.0" + resolved "https://registry.yarnpkg.com/throat/-/throat-4.1.0.tgz#89037cbc92c56ab18926e6ba4cbb200e15672a6a" + integrity sha1-iQN8vJLFarGJJua6TLsgDhVnKmo= + +through2@^2.0.0: + version "2.0.5" + resolved "https://registry.yarnpkg.com/through2/-/through2-2.0.5.tgz#01c1e39eb31d07cb7d03a96a70823260b23132cd" + integrity sha512-/mrRod8xqpA+IHSLyGCQ2s8SPHiCDEeQJSep1jqLYeEUClOFG2Qsh+4FU6G9VeqpZnGW/Su8LQGc4YKni5rYSQ== + dependencies: + readable-stream "~2.3.6" + xtend "~4.0.1" + +through@^2.3.6: + version "2.3.8" + resolved "https://registry.yarnpkg.com/through/-/through-2.3.8.tgz#0dd4c9ffaabc357960b1b724115d7e0e86a2e1f5" + integrity sha1-DdTJ/6q8NXlgsbckEV1+Doai4fU= + +thunky@^1.0.2: + version "1.1.0" + resolved "https://registry.yarnpkg.com/thunky/-/thunky-1.1.0.tgz#5abaf714a9405db0504732bbccd2cedd9ef9537d" + integrity sha512-eHY7nBftgThBqOyHGVN+l8gF0BucP09fMo0oO/Lb0w1OF80dJv+lDVpXG60WMQvkcxAkNybKsrEIE3ZtKGmPrA== + +timers-browserify@^2.0.4: + version "2.0.11" + resolved "https://registry.yarnpkg.com/timers-browserify/-/timers-browserify-2.0.11.tgz#800b1f3eee272e5bc53ee465a04d0e804c31211f" + integrity sha512-60aV6sgJ5YEbzUdn9c8kYGIqOubPoUdqQCul3SBAsRCZ40s6Y5cMcrW4dt3/k/EsbLVJNl9n6Vz3fTc+k2GeKQ== + dependencies: + setimmediate "^1.0.4" + +timsort@^0.3.0: + version "0.3.0" + resolved "https://registry.yarnpkg.com/timsort/-/timsort-0.3.0.tgz#405411a8e7e6339fe64db9a234de11dc31e02bd4" + integrity sha1-QFQRqOfmM5/mTbmiNN4R3DHgK9Q= + +tmp@^0.0.33: + version "0.0.33" + resolved "https://registry.yarnpkg.com/tmp/-/tmp-0.0.33.tgz#6d34335889768d21b2bcda0aa277ced3b1bfadf9" + integrity sha512-jRCJlojKnZ3addtTOjdIqoRuPEKBvNXcGYqzO6zWZX8KfKEpnGY5jfggJQ3EjKuu8D4bJRr0y+cYJFmYbImXGw== + dependencies: + os-tmpdir "~1.0.2" + +tmpl@1.0.x: + version "1.0.4" + resolved "https://registry.yarnpkg.com/tmpl/-/tmpl-1.0.4.tgz#23640dd7b42d00433911140820e5cf440e521dd1" + integrity sha1-I2QN17QtAEM5ERQIIOXPRA5SHdE= + +to-arraybuffer@^1.0.0: + version "1.0.1" + resolved "https://registry.yarnpkg.com/to-arraybuffer/-/to-arraybuffer-1.0.1.tgz#7d229b1fcc637e466ca081180836a7aabff83f43" + integrity sha1-fSKbH8xjfkZsoIEYCDanqr/4P0M= + +to-fast-properties@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/to-fast-properties/-/to-fast-properties-2.0.0.tgz#dc5e698cbd079265bc73e0377681a4e4e83f616e" + integrity 
sha1-3F5pjL0HkmW8c+A3doGk5Og/YW4= + +to-object-path@^0.3.0: + version "0.3.0" + resolved "https://registry.yarnpkg.com/to-object-path/-/to-object-path-0.3.0.tgz#297588b7b0e7e0ac08e04e672f85c1f4999e17af" + integrity sha1-KXWIt7Dn4KwI4E5nL4XB9JmeF68= + dependencies: + kind-of "^3.0.2" + +to-regex-range@^2.1.0: + version "2.1.1" + resolved "https://registry.yarnpkg.com/to-regex-range/-/to-regex-range-2.1.1.tgz#7c80c17b9dfebe599e27367e0d4dd5590141db38" + integrity sha1-fIDBe53+vlmeJzZ+DU3VWQFB2zg= + dependencies: + is-number "^3.0.0" + repeat-string "^1.6.1" + +to-regex-range@^5.0.1: + version "5.0.1" + resolved "https://registry.yarnpkg.com/to-regex-range/-/to-regex-range-5.0.1.tgz#1648c44aae7c8d988a326018ed72f5b4dd0392e4" + integrity sha512-65P7iz6X5yEr1cwcgvQxbbIw7Uk3gOy5dIdtZ4rDveLqhrdJP+Li/Hx6tyK0NEb+2GCyneCMJiGqrADCSNk8sQ== + dependencies: + is-number "^7.0.0" + +to-regex@^3.0.1, to-regex@^3.0.2: + version "3.0.2" + resolved "https://registry.yarnpkg.com/to-regex/-/to-regex-3.0.2.tgz#13cfdd9b336552f30b51f33a8ae1b42a7a7599ce" + integrity sha512-FWtleNAtZ/Ki2qtqej2CXTOayOH9bHDQF+Q48VpWyDXjbYxA4Yz8iDB31zXOBUlOHHKidDbqGVrTUvQMPmBGBw== + dependencies: + define-property "^2.0.2" + extend-shallow "^3.0.2" + regex-not "^1.0.2" + safe-regex "^1.1.0" + +toidentifier@1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/toidentifier/-/toidentifier-1.0.0.tgz#7e1be3470f1e77948bc43d94a3c8f4d7752ba553" + integrity sha512-yaOH/Pk/VEhBWWTlhI+qXxDFXlejDGcQipMlyxda9nthulaxLZUNcUqFxokp0vcYnvteJln5FNQDRrxj3YcbVw== + +tough-cookie@^2.3.3, tough-cookie@^2.3.4, tough-cookie@^2.5.0, tough-cookie@~2.5.0: + version "2.5.0" + resolved "https://registry.yarnpkg.com/tough-cookie/-/tough-cookie-2.5.0.tgz#cd9fb2a0aa1d5a12b473bd9fb96fa3dcff65ade2" + integrity sha512-nlLsUzgm1kfLXSXfRZMc1KLAugd4hqJHDTvc2hDIwS3mZAfMEuMbc03SujMF+GEcpaX/qboeycw6iO8JwVv2+g== + dependencies: + psl "^1.1.28" + punycode "^2.1.1" + +tr46@^1.0.1: + version "1.0.1" + resolved "https://registry.yarnpkg.com/tr46/-/tr46-1.0.1.tgz#a8b13fd6bfd2489519674ccde55ba3693b706d09" + integrity sha1-qLE/1r/SSJUZZ0zN5VujaTtwbQk= + dependencies: + punycode "^2.1.0" + +ts-pnp@1.1.6: + version "1.1.6" + resolved "https://registry.yarnpkg.com/ts-pnp/-/ts-pnp-1.1.6.tgz#389a24396d425a0d3162e96d2b4638900fdc289a" + integrity sha512-CrG5GqAAzMT7144Cl+UIFP7mz/iIhiy+xQ6GGcnjTezhALT02uPMRw7tgDSESgB5MsfKt55+GPWw4ir1kVtMIQ== + +ts-pnp@^1.1.6: + version "1.2.0" + resolved "https://registry.yarnpkg.com/ts-pnp/-/ts-pnp-1.2.0.tgz#a500ad084b0798f1c3071af391e65912c86bca92" + integrity sha512-csd+vJOb/gkzvcCHgTGSChYpy5f1/XKNsmvBGO4JXS+z1v2HobugDz4s1IeFXM3wZB44uczs+eazB5Q/ccdhQw== + +tslib@^1.10.0, tslib@^1.8.1, tslib@^1.9.0: + version "1.11.1" + resolved "https://registry.yarnpkg.com/tslib/-/tslib-1.11.1.tgz#eb15d128827fbee2841549e171f45ed338ac7e35" + integrity sha512-aZW88SY8kQbU7gpV19lN24LtXh/yD4ZZg6qieAJDDg+YBsJcSmLGK9QpnUjAKVG/xefmvJGd1WUmfpT/g6AJGA== + +tsutils@^3.17.1: + version "3.17.1" + resolved "https://registry.yarnpkg.com/tsutils/-/tsutils-3.17.1.tgz#ed719917f11ca0dee586272b2ac49e015a2dd759" + integrity sha512-kzeQ5B8H3w60nFY2g8cJIuH7JDpsALXySGtwGJ0p2LSjLgay3NdIpqq5SoOBe46bKDW2iq25irHCr8wjomUS2g== + dependencies: + tslib "^1.8.1" + +tty-browserify@0.0.0: + version "0.0.0" + resolved "https://registry.yarnpkg.com/tty-browserify/-/tty-browserify-0.0.0.tgz#a157ba402da24e9bf957f9aa69d524eed42901a6" + integrity sha1-oVe6QC2iTpv5V/mqadUk7tQpAaY= + +tunnel-agent@^0.6.0: + version "0.6.0" + resolved 
"https://registry.yarnpkg.com/tunnel-agent/-/tunnel-agent-0.6.0.tgz#27a5dea06b36b04a0a9966774b290868f0fc40fd" + integrity sha1-J6XeoGs2sEoKmWZ3SykIaPD8QP0= + dependencies: + safe-buffer "^5.0.1" + +tweetnacl@^0.14.3, tweetnacl@~0.14.0: + version "0.14.5" + resolved "https://registry.yarnpkg.com/tweetnacl/-/tweetnacl-0.14.5.tgz#5ae68177f192d4456269d108afa93ff8743f4f64" + integrity sha1-WuaBd/GS1EViadEIr6k/+HQ/T2Q= + +type-check@~0.3.2: + version "0.3.2" + resolved "https://registry.yarnpkg.com/type-check/-/type-check-0.3.2.tgz#5884cab512cf1d355e3fb784f30804b2b520db72" + integrity sha1-WITKtRLPHTVeP7eE8wgEsrUg23I= + dependencies: + prelude-ls "~1.1.2" + +type-fest@^0.11.0: + version "0.11.0" + resolved "https://registry.yarnpkg.com/type-fest/-/type-fest-0.11.0.tgz#97abf0872310fed88a5c466b25681576145e33f1" + integrity sha512-OdjXJxnCN1AvyLSzeKIgXTXxV+99ZuXl3Hpo9XpJAv9MBcHrrJOQ5kV7ypXOuQie+AmWG25hLbiKdwYTifzcfQ== + +type-fest@^0.8.1: + version "0.8.1" + resolved "https://registry.yarnpkg.com/type-fest/-/type-fest-0.8.1.tgz#09e249ebde851d3b1e48d27c105444667f17b83d" + integrity sha512-4dbzIzqvjtgiM5rw1k5rEHtBANKmdudhGyBEajN01fEyhaAIhsoKNy6y7+IN93IfpFtwY9iqi7kD+xwKhQsNJA== + +type-is@~1.6.17, type-is@~1.6.18: + version "1.6.18" + resolved "https://registry.yarnpkg.com/type-is/-/type-is-1.6.18.tgz#4e552cd05df09467dcbc4ef739de89f2cf37c131" + integrity sha512-TkRKr9sUTxEH8MdfuCSP7VizJyzRNMjj2J2do2Jr3Kym598JVdEksuzPQCnlFPW4ky9Q+iA+ma9BGm06XQBy8g== + dependencies: + media-typer "0.3.0" + mime-types "~2.1.24" + +type@^1.0.1: + version "1.2.0" + resolved "https://registry.yarnpkg.com/type/-/type-1.2.0.tgz#848dd7698dafa3e54a6c479e759c4bc3f18847a0" + integrity sha512-+5nt5AAniqsCnu2cEQQdpzCAh33kVx8n0VoFidKpB1dVVLAN/F+bgVOqOJqOnEnrhp222clB5p3vUlD+1QAnfg== + +type@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/type/-/type-2.0.0.tgz#5f16ff6ef2eb44f260494dae271033b29c09a9c3" + integrity sha512-KBt58xCHry4Cejnc2ISQAF7QY+ORngsWfxezO68+12hKV6lQY8P/psIkcbjeHWn7MqcgciWJyCCevFMJdIXpow== + +typedarray@^0.0.6: + version "0.0.6" + resolved "https://registry.yarnpkg.com/typedarray/-/typedarray-0.0.6.tgz#867ac74e3864187b1d3d47d996a78ec5c8830777" + integrity sha1-hnrHTjhkGHsdPUfZlqeOxciDB3c= + +typescript@^3.9.5: + version "3.9.5" + resolved "https://registry.yarnpkg.com/typescript/-/typescript-3.9.5.tgz#586f0dba300cde8be52dd1ac4f7e1009c1b13f36" + integrity sha512-hSAifV3k+i6lEoCJ2k6R2Z/rp/H3+8sdmcn5NrS3/3kE7+RyZXm9aqvxWqjEXHAd8b0pShatpcdMTvEdvAJltQ== + +unicode-canonical-property-names-ecmascript@^1.0.4: + version "1.0.4" + resolved "https://registry.yarnpkg.com/unicode-canonical-property-names-ecmascript/-/unicode-canonical-property-names-ecmascript-1.0.4.tgz#2619800c4c825800efdd8343af7dd9933cbe2818" + integrity sha512-jDrNnXWHd4oHiTZnx/ZG7gtUTVp+gCcTTKr8L0HjlwphROEW3+Him+IpvC+xcJEFegapiMZyZe02CyuOnRmbnQ== + +unicode-match-property-ecmascript@^1.0.4: + version "1.0.4" + resolved "https://registry.yarnpkg.com/unicode-match-property-ecmascript/-/unicode-match-property-ecmascript-1.0.4.tgz#8ed2a32569961bce9227d09cd3ffbb8fed5f020c" + integrity sha512-L4Qoh15vTfntsn4P1zqnHulG0LdXgjSO035fEpdtp6YxXhMT51Q6vgM5lYdG/5X3MjS+k/Y9Xw4SFCY9IkR0rg== + dependencies: + unicode-canonical-property-names-ecmascript "^1.0.4" + unicode-property-aliases-ecmascript "^1.0.4" + +unicode-match-property-value-ecmascript@^1.2.0: + version "1.2.0" + resolved "https://registry.yarnpkg.com/unicode-match-property-value-ecmascript/-/unicode-match-property-value-ecmascript-1.2.0.tgz#0d91f600eeeb3096aa962b1d6fc88876e64ea531" + 
integrity sha512-wjuQHGQVofmSJv1uVISKLE5zO2rNGzM/KCYZch/QQvez7C1hUhBIuZ701fYXExuufJFMPhv2SyL8CyoIfMLbIQ== + +unicode-property-aliases-ecmascript@^1.0.4: + version "1.1.0" + resolved "https://registry.yarnpkg.com/unicode-property-aliases-ecmascript/-/unicode-property-aliases-ecmascript-1.1.0.tgz#dd57a99f6207bedff4628abefb94c50db941c8f4" + integrity sha512-PqSoPh/pWetQ2phoj5RLiaqIk4kCNwoV3CI+LfGmWLKI3rE3kl1h59XpX2BjgDrmbxD9ARtQobPGU1SguCYuQg== + +union-value@^1.0.0: + version "1.0.1" + resolved "https://registry.yarnpkg.com/union-value/-/union-value-1.0.1.tgz#0b6fe7b835aecda61c6ea4d4f02c14221e109847" + integrity sha512-tJfXmxMeWYnczCVs7XAEvIV7ieppALdyepWMkHkwciRpZraG/xwT+s2JN8+pr1+8jCRf80FFzvr+MpQeeoF4Xg== + dependencies: + arr-union "^3.1.0" + get-value "^2.0.6" + is-extendable "^0.1.1" + set-value "^2.0.1" + +uniq@^1.0.1: + version "1.0.1" + resolved "https://registry.yarnpkg.com/uniq/-/uniq-1.0.1.tgz#b31c5ae8254844a3a8281541ce2b04b865a734ff" + integrity sha1-sxxa6CVIRKOoKBVBzisEuGWnNP8= + +uniqs@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/uniqs/-/uniqs-2.0.0.tgz#ffede4b36b25290696e6e165d4a59edb998e6b02" + integrity sha1-/+3ks2slKQaW5uFl1KWe25mOawI= + +unique-filename@^1.1.1: + version "1.1.1" + resolved "https://registry.yarnpkg.com/unique-filename/-/unique-filename-1.1.1.tgz#1d69769369ada0583103a1e6ae87681b56573230" + integrity sha512-Vmp0jIp2ln35UTXuryvjzkjGdRyf9b2lTXuSYUiPmzRcl3FDtYqAwOnTJkAngD9SWhnoJzDbTKwaOrZ+STtxNQ== + dependencies: + unique-slug "^2.0.0" + +unique-slug@^2.0.0: + version "2.0.2" + resolved "https://registry.yarnpkg.com/unique-slug/-/unique-slug-2.0.2.tgz#baabce91083fc64e945b0f3ad613e264f7cd4e6c" + integrity sha512-zoWr9ObaxALD3DOPfjPSqxt4fnZiWblxHIgeWqW8x7UqDzEtHEQLzji2cuJYQFCU6KmoJikOYAZlrTHHebjx2w== + dependencies: + imurmurhash "^0.1.4" + +universalify@^0.1.0: + version "0.1.2" + resolved "https://registry.yarnpkg.com/universalify/-/universalify-0.1.2.tgz#b646f69be3942dabcecc9d6639c80dc105efaa66" + integrity sha512-rBJeI5CXAlmy1pV+617WB9J63U6XcazHHF2f2dbJix4XzpUF0RS3Zbj0FGIOCAva5P/d/GBOYaACQ1w+0azUkg== + +unpipe@1.0.0, unpipe@~1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/unpipe/-/unpipe-1.0.0.tgz#b2bf4ee8514aae6165b4817829d21b2ef49904ec" + integrity sha1-sr9O6FFKrmFltIF4KdIbLvSZBOw= + +unquote@~1.1.1: + version "1.1.1" + resolved "https://registry.yarnpkg.com/unquote/-/unquote-1.1.1.tgz#8fded7324ec6e88a0ff8b905e7c098cdc086d544" + integrity sha1-j97XMk7G6IoP+LkF58CYzcCG1UQ= + +unset-value@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/unset-value/-/unset-value-1.0.0.tgz#8376873f7d2335179ffb1e6fc3a8ed0dfc8ab559" + integrity sha1-g3aHP30jNRef+x5vw6jtDfyKtVk= + dependencies: + has-value "^0.3.1" + isobject "^3.0.0" + +upath@^1.1.1: + version "1.2.0" + resolved "https://registry.yarnpkg.com/upath/-/upath-1.2.0.tgz#8f66dbcd55a883acdae4408af8b035a5044c1894" + integrity sha512-aZwGpamFO61g3OlfT7OQCHqhGnW43ieH9WZeP7QxN/G/jS4jfqUkZxoryvJgVPEcrl5NL/ggHsSmLMHuH64Lhg== + +uri-js@^4.2.2: + version "4.2.2" + resolved "https://registry.yarnpkg.com/uri-js/-/uri-js-4.2.2.tgz#94c540e1ff772956e2299507c010aea6c8838eb0" + integrity sha512-KY9Frmirql91X2Qgjry0Wd4Y+YTdrdZheS8TFwvkbLWf/G5KNJDCh6pKL5OZctEW4+0Baa5idK2ZQuELRwPznQ== + dependencies: + punycode "^2.1.0" + +urix@^0.1.0: + version "0.1.0" + resolved "https://registry.yarnpkg.com/urix/-/urix-0.1.0.tgz#da937f7a62e21fec1fd18d49b35c2935067a6c72" + integrity sha1-2pN/emLiH+wf0Y1Js1wpNQZ6bHI= + +url-loader@2.3.0: + version "2.3.0" + resolved 
"https://registry.yarnpkg.com/url-loader/-/url-loader-2.3.0.tgz#e0e2ef658f003efb8ca41b0f3ffbf76bab88658b" + integrity sha512-goSdg8VY+7nPZKUEChZSEtW5gjbS66USIGCeSJ1OVOJ7Yfuh/36YxCwMi5HVEJh6mqUYOoy3NJ0vlOMrWsSHog== + dependencies: + loader-utils "^1.2.3" + mime "^2.4.4" + schema-utils "^2.5.0" + +url-parse@^1.4.3: + version "1.4.7" + resolved "https://registry.yarnpkg.com/url-parse/-/url-parse-1.4.7.tgz#a8a83535e8c00a316e403a5db4ac1b9b853ae278" + integrity sha512-d3uaVyzDB9tQoSXFvuSUNFibTd9zxd2bkVrDRvF5TmvWWQwqE4lgYJ5m+x1DbecWkw+LK4RNl2CU1hHuOKPVlg== + dependencies: + querystringify "^2.1.1" + requires-port "^1.0.0" + +url@^0.11.0: + version "0.11.0" + resolved "https://registry.yarnpkg.com/url/-/url-0.11.0.tgz#3838e97cfc60521eb73c525a8e55bfdd9e2e28f1" + integrity sha1-ODjpfPxgUh63PFJajlW/3Z4uKPE= + dependencies: + punycode "1.3.2" + querystring "0.2.0" + +use@^3.1.0: + version "3.1.1" + resolved "https://registry.yarnpkg.com/use/-/use-3.1.1.tgz#d50c8cac79a19fbc20f2911f56eb973f4e10070f" + integrity sha512-cwESVXlO3url9YWlFW/TA9cshCEhtu7IKJ/p5soJ/gGpj7vbvFrAY/eIioQ6Dw23KjZhYgiIo8HOs1nQ2vr/oQ== + +util-deprecate@^1.0.1, util-deprecate@~1.0.1: + version "1.0.2" + resolved "https://registry.yarnpkg.com/util-deprecate/-/util-deprecate-1.0.2.tgz#450d4dc9fa70de732762fbd2d4a28981419a0ccf" + integrity sha1-RQ1Nyfpw3nMnYvvS1KKJgUGaDM8= + +util.promisify@1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/util.promisify/-/util.promisify-1.0.0.tgz#440f7165a459c9a16dc145eb8e72f35687097030" + integrity sha512-i+6qA2MPhvoKLuxnJNpXAGhg7HphQOSUq2LKMZD0m15EiskXUkMvKdF4Uui0WYeCUGea+o2cw/ZuwehtfsrNkA== + dependencies: + define-properties "^1.1.2" + object.getownpropertydescriptors "^2.0.3" + +util.promisify@^1.0.0, util.promisify@~1.0.0: + version "1.0.1" + resolved "https://registry.yarnpkg.com/util.promisify/-/util.promisify-1.0.1.tgz#6baf7774b80eeb0f7520d8b81d07982a59abbaee" + integrity sha512-g9JpC/3He3bm38zsLupWryXHoEcS22YHthuPQSJdMy6KNrzIRzWqcsHzD/WUnqe45whVou4VIsPew37DoXWNrA== + dependencies: + define-properties "^1.1.3" + es-abstract "^1.17.2" + has-symbols "^1.0.1" + object.getownpropertydescriptors "^2.1.0" + +util@0.10.3: + version "0.10.3" + resolved "https://registry.yarnpkg.com/util/-/util-0.10.3.tgz#7afb1afe50805246489e3db7fe0ed379336ac0f9" + integrity sha1-evsa/lCAUkZInj23/g7TeTNqwPk= + dependencies: + inherits "2.0.1" + +util@^0.11.0: + version "0.11.1" + resolved "https://registry.yarnpkg.com/util/-/util-0.11.1.tgz#3236733720ec64bb27f6e26f421aaa2e1b588d61" + integrity sha512-HShAsny+zS2TZfaXxD9tYj4HQGlBezXZMZuM/S5PKLLoZkShZiGk9o5CzukI1LVHZvjdvZ2Sj1aW/Ndn2NB/HQ== + dependencies: + inherits "2.0.3" + +utila@^0.4.0, utila@~0.4: + version "0.4.0" + resolved "https://registry.yarnpkg.com/utila/-/utila-0.4.0.tgz#8a16a05d445657a3aea5eecc5b12a4fa5379772c" + integrity sha1-ihagXURWV6Oupe7MWxKk+lN5dyw= + +utils-merge@1.0.1: + version "1.0.1" + resolved "https://registry.yarnpkg.com/utils-merge/-/utils-merge-1.0.1.tgz#9f95710f50a267947b2ccc124741c1028427e713" + integrity sha1-n5VxD1CiZ5R7LMwSR0HBAoQn5xM= + +uuid@^3.0.1, uuid@^3.3.2: + version "3.4.0" + resolved "https://registry.yarnpkg.com/uuid/-/uuid-3.4.0.tgz#b23e4358afa8a202fe7a100af1f5f883f02007ee" + integrity sha512-HjSDRw6gZE5JMggctHBcjVak08+KEVhSIiDzFnT9S9aegmp85S/bReBVTb4QTFaRNptJ9kuYaNhnbNEOkbKb/A== + +v8-compile-cache@2.0.3: + version "2.0.3" + resolved "https://registry.yarnpkg.com/v8-compile-cache/-/v8-compile-cache-2.0.3.tgz#00f7494d2ae2b688cfe2899df6ed2c54bef91dbe" + integrity 
sha512-CNmdbwQMBjwr9Gsmohvm0pbL954tJrNzf6gWL3K+QMQf00PF7ERGrEiLgjuU3mKreLC2MeGhUsNV9ybTbLgd3w== + +v8-compile-cache@^2.0.3: + version "2.1.0" + resolved "https://registry.yarnpkg.com/v8-compile-cache/-/v8-compile-cache-2.1.0.tgz#e14de37b31a6d194f5690d67efc4e7f6fc6ab30e" + integrity sha512-usZBT3PW+LOjM25wbqIlZwPeJV+3OSz3M1k1Ws8snlW39dZyYL9lOGC5FgPVHfk0jKmjiDV8Z0mIbVQPiwFs7g== + +validate-npm-package-license@^3.0.1: + version "3.0.4" + resolved "https://registry.yarnpkg.com/validate-npm-package-license/-/validate-npm-package-license-3.0.4.tgz#fc91f6b9c7ba15c857f4cb2c5defeec39d4f410a" + integrity sha512-DpKm2Ui/xN7/HQKCtpZxoRWBhZ9Z0kqtygG8XCgNQ8ZlDnxuQmWhj566j8fN4Cu3/JmbhsDo7fcAJq4s9h27Ew== + dependencies: + spdx-correct "^3.0.0" + spdx-expression-parse "^3.0.0" + +vary@~1.1.2: + version "1.1.2" + resolved "https://registry.yarnpkg.com/vary/-/vary-1.1.2.tgz#2299f02c6ded30d4a5961b0b9f74524a18f634fc" + integrity sha1-IpnwLG3tMNSllhsLn3RSShj2NPw= + +vendors@^1.0.0: + version "1.0.4" + resolved "https://registry.yarnpkg.com/vendors/-/vendors-1.0.4.tgz#e2b800a53e7a29b93506c3cf41100d16c4c4ad8e" + integrity sha512-/juG65kTL4Cy2su4P8HjtkTxk6VmJDiOPBufWniqQ6wknac6jNiXS9vU+hO3wgusiyqWlzTbVHi0dyJqRONg3w== + +verror@1.10.0: + version "1.10.0" + resolved "https://registry.yarnpkg.com/verror/-/verror-1.10.0.tgz#3a105ca17053af55d6e270c1f8288682e18da400" + integrity sha1-OhBcoXBTr1XW4nDB+CiGguGNpAA= + dependencies: + assert-plus "^1.0.0" + core-util-is "1.0.2" + extsprintf "^1.2.0" + +vm-browserify@^1.0.1: + version "1.1.2" + resolved "https://registry.yarnpkg.com/vm-browserify/-/vm-browserify-1.1.2.tgz#78641c488b8e6ca91a75f511e7a3b32a86e5dda0" + integrity sha512-2ham8XPWTONajOR0ohOKOHXkm3+gaBmGut3SRuu75xLd/RRaY6vqgh8NBYYk7+RW3u5AtzPQZG8F10LHkl0lAQ== + +w3c-hr-time@^1.0.1: + version "1.0.2" + resolved "https://registry.yarnpkg.com/w3c-hr-time/-/w3c-hr-time-1.0.2.tgz#0a89cdf5cc15822df9c360543676963e0cc308cd" + integrity sha512-z8P5DvDNjKDoFIHK7q8r8lackT6l+jo/Ye3HOle7l9nICP9lf1Ci25fy9vHd0JOWewkIFzXIEig3TdKT7JQ5fQ== + dependencies: + browser-process-hrtime "^1.0.0" + +w3c-xmlserializer@^1.1.2: + version "1.1.2" + resolved "https://registry.yarnpkg.com/w3c-xmlserializer/-/w3c-xmlserializer-1.1.2.tgz#30485ca7d70a6fd052420a3d12fd90e6339ce794" + integrity sha512-p10l/ayESzrBMYWRID6xbuCKh2Fp77+sA0doRuGn4tTIMrrZVeqfpKjXHY+oDh3K4nLdPgNwMTVP6Vp4pvqbNg== + dependencies: + domexception "^1.0.1" + webidl-conversions "^4.0.2" + xml-name-validator "^3.0.0" + +walker@^1.0.7, walker@~1.0.5: + version "1.0.7" + resolved "https://registry.yarnpkg.com/walker/-/walker-1.0.7.tgz#2f7f9b8fd10d677262b18a884e28d19618e028fb" + integrity sha1-L3+bj9ENZ3JisYqITijRlhjgKPs= + dependencies: + makeerror "1.0.x" + +watchpack@^1.6.0: + version "1.6.1" + resolved "https://registry.yarnpkg.com/watchpack/-/watchpack-1.6.1.tgz#280da0a8718592174010c078c7585a74cd8cd0e2" + integrity sha512-+IF9hfUFOrYOOaKyfaI7h7dquUIOgyEMoQMLA7OP5FxegKA2+XdXThAZ9TU2kucfhDH7rfMHs1oPYziVGWRnZA== + dependencies: + chokidar "^2.1.8" + graceful-fs "^4.1.2" + neo-async "^2.5.0" + +wbuf@^1.1.0, wbuf@^1.7.3: + version "1.7.3" + resolved "https://registry.yarnpkg.com/wbuf/-/wbuf-1.7.3.tgz#c1d8d149316d3ea852848895cb6a0bfe887b87df" + integrity sha512-O84QOnr0icsbFGLS0O3bI5FswxzRr8/gHwWkDlQFskhSPryQXvrTMxjxGP4+iWYoauLoBvfDpkrOauZ+0iZpDA== + dependencies: + minimalistic-assert "^1.0.0" + +webidl-conversions@^4.0.2: + version "4.0.2" + resolved "https://registry.yarnpkg.com/webidl-conversions/-/webidl-conversions-4.0.2.tgz#a855980b1f0b6b359ba1d5d9fb39ae941faa63ad" + integrity 
sha512-YQ+BmxuTgd6UXZW3+ICGfyqRyHXVlD5GtQr5+qjiNW7bF0cqrzX500HVXPBOvgXb5YnzDd+h0zqyv61KUD7+Sg== + +webpack-cli@^3.3.9: + version "3.3.11" + resolved "https://registry.yarnpkg.com/webpack-cli/-/webpack-cli-3.3.11.tgz#3bf21889bf597b5d82c38f215135a411edfdc631" + integrity sha512-dXlfuml7xvAFwYUPsrtQAA9e4DOe58gnzSxhgrO/ZM/gyXTBowrsYeubyN4mqGhYdpXMFNyQ6emjJS9M7OBd4g== + dependencies: + chalk "2.4.2" + cross-spawn "6.0.5" + enhanced-resolve "4.1.0" + findup-sync "3.0.0" + global-modules "2.0.0" + import-local "2.0.0" + interpret "1.2.0" + loader-utils "1.2.3" + supports-color "6.1.0" + v8-compile-cache "2.0.3" + yargs "13.2.4" + +webpack-dev-middleware@^3.7.2: + version "3.7.2" + resolved "https://registry.yarnpkg.com/webpack-dev-middleware/-/webpack-dev-middleware-3.7.2.tgz#0019c3db716e3fa5cecbf64f2ab88a74bab331f3" + integrity sha512-1xC42LxbYoqLNAhV6YzTYacicgMZQTqRd27Sim9wn5hJrX3I5nxYy1SxSd4+gjUFsz1dQFj+yEe6zEVmSkeJjw== + dependencies: + memory-fs "^0.4.1" + mime "^2.4.4" + mkdirp "^0.5.1" + range-parser "^1.2.1" + webpack-log "^2.0.0" + +webpack-dev-server@3.10.3: + version "3.10.3" + resolved "https://registry.yarnpkg.com/webpack-dev-server/-/webpack-dev-server-3.10.3.tgz#f35945036813e57ef582c2420ef7b470e14d3af0" + integrity sha512-e4nWev8YzEVNdOMcNzNeCN947sWJNd43E5XvsJzbAL08kGc2frm1tQ32hTJslRS+H65LCb/AaUCYU7fjHCpDeQ== + dependencies: + ansi-html "0.0.7" + bonjour "^3.5.0" + chokidar "^2.1.8" + compression "^1.7.4" + connect-history-api-fallback "^1.6.0" + debug "^4.1.1" + del "^4.1.1" + express "^4.17.1" + html-entities "^1.2.1" + http-proxy-middleware "0.19.1" + import-local "^2.0.0" + internal-ip "^4.3.0" + ip "^1.1.5" + is-absolute-url "^3.0.3" + killable "^1.0.1" + loglevel "^1.6.6" + opn "^5.5.0" + p-retry "^3.0.1" + portfinder "^1.0.25" + schema-utils "^1.0.0" + selfsigned "^1.10.7" + semver "^6.3.0" + serve-index "^1.9.1" + sockjs "0.3.19" + sockjs-client "1.4.0" + spdy "^4.0.1" + strip-ansi "^3.0.1" + supports-color "^6.1.0" + url "^0.11.0" + webpack-dev-middleware "^3.7.2" + webpack-log "^2.0.0" + ws "^6.2.1" + yargs "12.0.5" + +webpack-log@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/webpack-log/-/webpack-log-2.0.0.tgz#5b7928e0637593f119d32f6227c1e0ac31e1b47f" + integrity sha512-cX8G2vR/85UYG59FgkoMamwHUIkSSlV3bBMRsbxVXVUk2j6NleCKjQ/WE9eYg9WY4w25O9w8wKP4rzNZFmUcUg== + dependencies: + ansi-colors "^3.0.0" + uuid "^3.3.2" + +webpack-manifest-plugin@2.2.0: + version "2.2.0" + resolved "https://registry.yarnpkg.com/webpack-manifest-plugin/-/webpack-manifest-plugin-2.2.0.tgz#19ca69b435b0baec7e29fbe90fb4015de2de4f16" + integrity sha512-9S6YyKKKh/Oz/eryM1RyLVDVmy3NSPV0JXMRhZ18fJsq+AwGxUY34X54VNwkzYcEmEkDwNxuEOboCZEebJXBAQ== + dependencies: + fs-extra "^7.0.0" + lodash ">=3.5 <5" + object.entries "^1.1.0" + tapable "^1.0.0" + +webpack-sources@^1.1.0, webpack-sources@^1.4.0, webpack-sources@^1.4.1, webpack-sources@^1.4.3: + version "1.4.3" + resolved "https://registry.yarnpkg.com/webpack-sources/-/webpack-sources-1.4.3.tgz#eedd8ec0b928fbf1cbfe994e22d2d890f330a933" + integrity sha512-lgTS3Xhv1lCOKo7SA5TjKXMjpSM4sBjNV5+q2bqesbSPs5FjGmU6jjtBSkX9b4qW87vDIsCIlUPOEhbZrMdjeQ== + dependencies: + source-list-map "^2.0.0" + source-map "~0.6.1" + +webpack@4.42.0: + version "4.42.0" + resolved "https://registry.yarnpkg.com/webpack/-/webpack-4.42.0.tgz#b901635dd6179391d90740a63c93f76f39883eb8" + integrity sha512-EzJRHvwQyBiYrYqhyjW9AqM90dE4+s1/XtCfn7uWg6cS72zH+2VPFAlsnW0+W0cDi0XRjNKUMoJtpSi50+Ph6w== + dependencies: + "@webassemblyjs/ast" "1.8.5" + 
"@webassemblyjs/helper-module-context" "1.8.5" + "@webassemblyjs/wasm-edit" "1.8.5" + "@webassemblyjs/wasm-parser" "1.8.5" + acorn "^6.2.1" + ajv "^6.10.2" + ajv-keywords "^3.4.1" + chrome-trace-event "^1.0.2" + enhanced-resolve "^4.1.0" + eslint-scope "^4.0.3" + json-parse-better-errors "^1.0.2" + loader-runner "^2.4.0" + loader-utils "^1.2.3" + memory-fs "^0.4.1" + micromatch "^3.1.10" + mkdirp "^0.5.1" + neo-async "^2.6.1" + node-libs-browser "^2.2.1" + schema-utils "^1.0.0" + tapable "^1.1.3" + terser-webpack-plugin "^1.4.3" + watchpack "^1.6.0" + webpack-sources "^1.4.1" + +websocket-driver@>=0.5.1: + version "0.7.3" + resolved "https://registry.yarnpkg.com/websocket-driver/-/websocket-driver-0.7.3.tgz#a2d4e0d4f4f116f1e6297eba58b05d430100e9f9" + integrity sha512-bpxWlvbbB459Mlipc5GBzzZwhoZgGEZLuqPaR0INBGnPAY1vdBX6hPnoFXiw+3yWxDuHyQjO2oXTMyS8A5haFg== + dependencies: + http-parser-js ">=0.4.0 <0.4.11" + safe-buffer ">=5.1.0" + websocket-extensions ">=0.1.1" + +websocket-extensions@>=0.1.1: + version "0.1.3" + resolved "https://registry.yarnpkg.com/websocket-extensions/-/websocket-extensions-0.1.3.tgz#5d2ff22977003ec687a4b87073dfbbac146ccf29" + integrity sha512-nqHUnMXmBzT0w570r2JpJxfiSD1IzoI+HGVdd3aZ0yNi3ngvQ4jv1dtHt5VGxfI2yj5yqImPhOK4vmIh2xMbGg== + +whatwg-encoding@^1.0.1, whatwg-encoding@^1.0.3, whatwg-encoding@^1.0.5: + version "1.0.5" + resolved "https://registry.yarnpkg.com/whatwg-encoding/-/whatwg-encoding-1.0.5.tgz#5abacf777c32166a51d085d6b4f3e7d27113ddb0" + integrity sha512-b5lim54JOPN9HtzvK9HFXvBma/rnfFeqsic0hSpjtDbVxR3dJKLc+KB4V6GgiGOvl7CY/KNh8rxSo9DKQrnUEw== + dependencies: + iconv-lite "0.4.24" + +whatwg-fetch@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/whatwg-fetch/-/whatwg-fetch-3.0.0.tgz#fc804e458cc460009b1a2b966bc8817d2578aefb" + integrity sha512-9GSJUgz1D4MfyKU7KRqwOjXCXTqWdFNvEr7eUBYchQiVc744mqK/MzXPNR2WsPkmkOa4ywfg8C2n8h+13Bey1Q== + +whatwg-mimetype@^2.1.0, whatwg-mimetype@^2.2.0, whatwg-mimetype@^2.3.0: + version "2.3.0" + resolved "https://registry.yarnpkg.com/whatwg-mimetype/-/whatwg-mimetype-2.3.0.tgz#3d4b1e0312d2079879f826aff18dbeeca5960fbf" + integrity sha512-M4yMwr6mAnQz76TbJm914+gPpB/nCwvZbJU28cUD6dR004SAxDLOOSUaB1JDRqLtaOV/vi0IC5lEAGFgrjGv/g== + +whatwg-url@^6.4.1: + version "6.5.0" + resolved "https://registry.yarnpkg.com/whatwg-url/-/whatwg-url-6.5.0.tgz#f2df02bff176fd65070df74ad5ccbb5a199965a8" + integrity sha512-rhRZRqx/TLJQWUpQ6bmrt2UV4f0HCQ463yQuONJqC6fO2VoEb1pTYddbe59SkYq87aoM5A3bdhMZiUiVws+fzQ== + dependencies: + lodash.sortby "^4.7.0" + tr46 "^1.0.1" + webidl-conversions "^4.0.2" + +whatwg-url@^7.0.0: + version "7.1.0" + resolved "https://registry.yarnpkg.com/whatwg-url/-/whatwg-url-7.1.0.tgz#c2c492f1eca612988efd3d2266be1b9fc6170d06" + integrity sha512-WUu7Rg1DroM7oQvGWfOiAK21n74Gg+T4elXEQYkOhtyLeWiJFoOGLXPKI/9gzIie9CtwVLm8wtw6YJdKyxSjeg== + dependencies: + lodash.sortby "^4.7.0" + tr46 "^1.0.1" + webidl-conversions "^4.0.2" + +which-module@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/which-module/-/which-module-2.0.0.tgz#d9ef07dce77b9902b8a3a8fa4b31c3e3f7e6e87a" + integrity sha1-2e8H3Od7mQK4o6j6SzHD4/fm6Ho= + +which@^1.2.14, which@^1.2.9, which@^1.3.0, which@^1.3.1: + version "1.3.1" + resolved "https://registry.yarnpkg.com/which/-/which-1.3.1.tgz#a45043d54f5805316da8d62f9f50918d3da70b0a" + integrity sha512-HxJdYWq1MTIQbJ3nw0cqssHoTNU267KlrDuGZ1WYlxDStUtKUhOaJmh112/TZmHxxUfuJqPXSOm7tDyas0OSIQ== + dependencies: + isexe "^2.0.0" + +which@^2.0.1: + version "2.0.2" + resolved 
"https://registry.yarnpkg.com/which/-/which-2.0.2.tgz#7c6a8dd0a636a0327e10b59c9286eee93f3f51b1" + integrity sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA== + dependencies: + isexe "^2.0.0" + +word-wrap@~1.2.3: + version "1.2.3" + resolved "https://registry.yarnpkg.com/word-wrap/-/word-wrap-1.2.3.tgz#610636f6b1f703891bd34771ccb17fb93b47079c" + integrity sha512-Hz/mrNwitNRh/HUAtM/VT/5VH+ygD6DV7mYKZAtHOrbs8U7lvPS6xf7EJKMF0uW1KJCl0H701g3ZGus+muE5vQ== + +workbox-background-sync@^4.3.1: + version "4.3.1" + resolved "https://registry.yarnpkg.com/workbox-background-sync/-/workbox-background-sync-4.3.1.tgz#26821b9bf16e9e37fd1d640289edddc08afd1950" + integrity sha512-1uFkvU8JXi7L7fCHVBEEnc3asPpiAL33kO495UMcD5+arew9IbKW2rV5lpzhoWcm/qhGB89YfO4PmB/0hQwPRg== + dependencies: + workbox-core "^4.3.1" + +workbox-broadcast-update@^4.3.1: + version "4.3.1" + resolved "https://registry.yarnpkg.com/workbox-broadcast-update/-/workbox-broadcast-update-4.3.1.tgz#e2c0280b149e3a504983b757606ad041f332c35b" + integrity sha512-MTSfgzIljpKLTBPROo4IpKjESD86pPFlZwlvVG32Kb70hW+aob4Jxpblud8EhNb1/L5m43DUM4q7C+W6eQMMbA== + dependencies: + workbox-core "^4.3.1" + +workbox-build@^4.3.1: + version "4.3.1" + resolved "https://registry.yarnpkg.com/workbox-build/-/workbox-build-4.3.1.tgz#414f70fb4d6de47f6538608b80ec52412d233e64" + integrity sha512-UHdwrN3FrDvicM3AqJS/J07X0KXj67R8Cg0waq1MKEOqzo89ap6zh6LmaLnRAjpB+bDIz+7OlPye9iii9KBnxw== + dependencies: + "@babel/runtime" "^7.3.4" + "@hapi/joi" "^15.0.0" + common-tags "^1.8.0" + fs-extra "^4.0.2" + glob "^7.1.3" + lodash.template "^4.4.0" + pretty-bytes "^5.1.0" + stringify-object "^3.3.0" + strip-comments "^1.0.2" + workbox-background-sync "^4.3.1" + workbox-broadcast-update "^4.3.1" + workbox-cacheable-response "^4.3.1" + workbox-core "^4.3.1" + workbox-expiration "^4.3.1" + workbox-google-analytics "^4.3.1" + workbox-navigation-preload "^4.3.1" + workbox-precaching "^4.3.1" + workbox-range-requests "^4.3.1" + workbox-routing "^4.3.1" + workbox-strategies "^4.3.1" + workbox-streams "^4.3.1" + workbox-sw "^4.3.1" + workbox-window "^4.3.1" + +workbox-cacheable-response@^4.3.1: + version "4.3.1" + resolved "https://registry.yarnpkg.com/workbox-cacheable-response/-/workbox-cacheable-response-4.3.1.tgz#f53e079179c095a3f19e5313b284975c91428c91" + integrity sha512-Rp5qlzm6z8IOvnQNkCdO9qrDgDpoPNguovs0H8C+wswLuPgSzSp9p2afb5maUt9R1uTIwOXrVQMmPfPypv+npw== + dependencies: + workbox-core "^4.3.1" + +workbox-core@^4.3.1: + version "4.3.1" + resolved "https://registry.yarnpkg.com/workbox-core/-/workbox-core-4.3.1.tgz#005d2c6a06a171437afd6ca2904a5727ecd73be6" + integrity sha512-I3C9jlLmMKPxAC1t0ExCq+QoAMd0vAAHULEgRZ7kieCdUd919n53WC0AfvokHNwqRhGn+tIIj7vcb5duCjs2Kg== + +workbox-expiration@^4.3.1: + version "4.3.1" + resolved "https://registry.yarnpkg.com/workbox-expiration/-/workbox-expiration-4.3.1.tgz#d790433562029e56837f341d7f553c4a78ebe921" + integrity sha512-vsJLhgQsQouv9m0rpbXubT5jw0jMQdjpkum0uT+d9tTwhXcEZks7qLfQ9dGSaufTD2eimxbUOJfWLbNQpIDMPw== + dependencies: + workbox-core "^4.3.1" + +workbox-google-analytics@^4.3.1: + version "4.3.1" + resolved "https://registry.yarnpkg.com/workbox-google-analytics/-/workbox-google-analytics-4.3.1.tgz#9eda0183b103890b5c256e6f4ea15a1f1548519a" + integrity sha512-xzCjAoKuOb55CBSwQrbyWBKqp35yg1vw9ohIlU2wTy06ZrYfJ8rKochb1MSGlnoBfXGWss3UPzxR5QL5guIFdg== + dependencies: + workbox-background-sync "^4.3.1" + workbox-core "^4.3.1" + workbox-routing "^4.3.1" + workbox-strategies "^4.3.1" + 
+workbox-navigation-preload@^4.3.1: + version "4.3.1" + resolved "https://registry.yarnpkg.com/workbox-navigation-preload/-/workbox-navigation-preload-4.3.1.tgz#29c8e4db5843803b34cd96dc155f9ebd9afa453d" + integrity sha512-K076n3oFHYp16/C+F8CwrRqD25GitA6Rkd6+qAmLmMv1QHPI2jfDwYqrytOfKfYq42bYtW8Pr21ejZX7GvALOw== + dependencies: + workbox-core "^4.3.1" + +workbox-precaching@^4.3.1: + version "4.3.1" + resolved "https://registry.yarnpkg.com/workbox-precaching/-/workbox-precaching-4.3.1.tgz#9fc45ed122d94bbe1f0ea9584ff5940960771cba" + integrity sha512-piSg/2csPoIi/vPpp48t1q5JLYjMkmg5gsXBQkh/QYapCdVwwmKlU9mHdmy52KsDGIjVaqEUMFvEzn2LRaigqQ== + dependencies: + workbox-core "^4.3.1" + +workbox-range-requests@^4.3.1: + version "4.3.1" + resolved "https://registry.yarnpkg.com/workbox-range-requests/-/workbox-range-requests-4.3.1.tgz#f8a470188922145cbf0c09a9a2d5e35645244e74" + integrity sha512-S+HhL9+iTFypJZ/yQSl/x2Bf5pWnbXdd3j57xnb0V60FW1LVn9LRZkPtneODklzYuFZv7qK6riZ5BNyc0R0jZA== + dependencies: + workbox-core "^4.3.1" + +workbox-routing@^4.3.1: + version "4.3.1" + resolved "https://registry.yarnpkg.com/workbox-routing/-/workbox-routing-4.3.1.tgz#a675841af623e0bb0c67ce4ed8e724ac0bed0cda" + integrity sha512-FkbtrODA4Imsi0p7TW9u9MXuQ5P4pVs1sWHK4dJMMChVROsbEltuE79fBoIk/BCztvOJ7yUpErMKa4z3uQLX+g== + dependencies: + workbox-core "^4.3.1" + +workbox-strategies@^4.3.1: + version "4.3.1" + resolved "https://registry.yarnpkg.com/workbox-strategies/-/workbox-strategies-4.3.1.tgz#d2be03c4ef214c115e1ab29c9c759c9fe3e9e646" + integrity sha512-F/+E57BmVG8dX6dCCopBlkDvvhg/zj6VDs0PigYwSN23L8hseSRwljrceU2WzTvk/+BSYICsWmRq5qHS2UYzhw== + dependencies: + workbox-core "^4.3.1" + +workbox-streams@^4.3.1: + version "4.3.1" + resolved "https://registry.yarnpkg.com/workbox-streams/-/workbox-streams-4.3.1.tgz#0b57da70e982572de09c8742dd0cb40a6b7c2cc3" + integrity sha512-4Kisis1f/y0ihf4l3u/+ndMkJkIT4/6UOacU3A4BwZSAC9pQ9vSvJpIi/WFGQRH/uPXvuVjF5c2RfIPQFSS2uA== + dependencies: + workbox-core "^4.3.1" + +workbox-sw@^4.3.1: + version "4.3.1" + resolved "https://registry.yarnpkg.com/workbox-sw/-/workbox-sw-4.3.1.tgz#df69e395c479ef4d14499372bcd84c0f5e246164" + integrity sha512-0jXdusCL2uC5gM3yYFT6QMBzKfBr2XTk0g5TPAV4y8IZDyVNDyj1a8uSXy3/XrvkVTmQvLN4O5k3JawGReXr9w== + +workbox-webpack-plugin@4.3.1: + version "4.3.1" + resolved "https://registry.yarnpkg.com/workbox-webpack-plugin/-/workbox-webpack-plugin-4.3.1.tgz#47ff5ea1cc074b6c40fb5a86108863a24120d4bd" + integrity sha512-gJ9jd8Mb8wHLbRz9ZvGN57IAmknOipD3W4XNE/Lk/4lqs5Htw4WOQgakQy/o/4CoXQlMCYldaqUg+EJ35l9MEQ== + dependencies: + "@babel/runtime" "^7.0.0" + json-stable-stringify "^1.0.1" + workbox-build "^4.3.1" + +workbox-window@^4.3.1: + version "4.3.1" + resolved "https://registry.yarnpkg.com/workbox-window/-/workbox-window-4.3.1.tgz#ee6051bf10f06afa5483c9b8dfa0531994ede0f3" + integrity sha512-C5gWKh6I58w3GeSc0wp2Ne+rqVw8qwcmZnQGpjiek8A2wpbxSJb1FdCoQVO+jDJs35bFgo/WETgl1fqgsxN0Hg== + dependencies: + workbox-core "^4.3.1" + +worker-farm@^1.7.0: + version "1.7.0" + resolved "https://registry.yarnpkg.com/worker-farm/-/worker-farm-1.7.0.tgz#26a94c5391bbca926152002f69b84a4bf772e5a8" + integrity sha512-rvw3QTZc8lAxyVrqcSGVm5yP/IJ2UcB3U0graE3LCFoZ0Yn2x4EoVSqJKdB/T5M+FLcRPjz4TDacRf3OCfNUzw== + dependencies: + errno "~0.1.7" + +worker-rpc@^0.1.0: + version "0.1.1" + resolved "https://registry.yarnpkg.com/worker-rpc/-/worker-rpc-0.1.1.tgz#cb565bd6d7071a8f16660686051e969ad32f54d5" + integrity sha512-P1WjMrUB3qgJNI9jfmpZ/htmBEjFh//6l/5y8SD9hg1Ef5zTTVVoRjTrTEzPrNBQvmhMxkoTsjOXN10GWU7aCg== + 
dependencies: + microevent.ts "~0.1.1" + +wrap-ansi@^2.0.0: + version "2.1.0" + resolved "https://registry.yarnpkg.com/wrap-ansi/-/wrap-ansi-2.1.0.tgz#d8fc3d284dd05794fe84973caecdd1cf824fdd85" + integrity sha1-2Pw9KE3QV5T+hJc8rs3Rz4JP3YU= + dependencies: + string-width "^1.0.1" + strip-ansi "^3.0.1" + +wrap-ansi@^5.1.0: + version "5.1.0" + resolved "https://registry.yarnpkg.com/wrap-ansi/-/wrap-ansi-5.1.0.tgz#1fd1f67235d5b6d0fee781056001bfb694c03b09" + integrity sha512-QC1/iN/2/RPVJ5jYK8BGttj5z83LmSKmvbvrXPNCLZSEb32KKVDJDl/MOt2N01qU2H/FkzEa9PKto1BqDjtd7Q== + dependencies: + ansi-styles "^3.2.0" + string-width "^3.0.0" + strip-ansi "^5.0.0" + +wrappy@1: + version "1.0.2" + resolved "https://registry.yarnpkg.com/wrappy/-/wrappy-1.0.2.tgz#b5243d8f3ec1aa35f1364605bc0d1036e30ab69f" + integrity sha1-tSQ9jz7BqjXxNkYFvA0QNuMKtp8= + +write-file-atomic@2.4.1: + version "2.4.1" + resolved "https://registry.yarnpkg.com/write-file-atomic/-/write-file-atomic-2.4.1.tgz#d0b05463c188ae804396fd5ab2a370062af87529" + integrity sha512-TGHFeZEZMnv+gBFRfjAcxL5bPHrsGKtnb4qsFAws7/vlh+QfwAaySIw4AXP9ZskTTh5GWu3FLuJhsWVdiJPGvg== + dependencies: + graceful-fs "^4.1.11" + imurmurhash "^0.1.4" + signal-exit "^3.0.2" + +write@1.0.3: + version "1.0.3" + resolved "https://registry.yarnpkg.com/write/-/write-1.0.3.tgz#0800e14523b923a387e415123c865616aae0f5c3" + integrity sha512-/lg70HAjtkUgWPVZhZcm+T4hkL8Zbtp1nFNOn3lRrxnlv50SRBv7cR7RqR+GMsd3hUXy9hWBo4CHTbFTcOYwig== + dependencies: + mkdirp "^0.5.1" + +ws@^5.2.0: + version "5.2.2" + resolved "https://registry.yarnpkg.com/ws/-/ws-5.2.2.tgz#dffef14866b8e8dc9133582514d1befaf96e980f" + integrity sha512-jaHFD6PFv6UgoIVda6qZllptQsMlDEJkTQcybzzXDYM1XO9Y8em691FGMPmM46WGyLU4z9KMgQN+qrux/nhlHA== + dependencies: + async-limiter "~1.0.0" + +ws@^6.1.2, ws@^6.2.1: + version "6.2.1" + resolved "https://registry.yarnpkg.com/ws/-/ws-6.2.1.tgz#442fdf0a47ed64f59b6a5d8ff130f4748ed524fb" + integrity sha512-GIyAXC2cB7LjvpgMt9EKS2ldqr0MTrORaleiOno6TweZ6r3TKtoFQWay/2PceJ3RuBasOHzXNn5Lrw1X0bEjqA== + dependencies: + async-limiter "~1.0.0" + +xml-name-validator@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/xml-name-validator/-/xml-name-validator-3.0.0.tgz#6ae73e06de4d8c6e47f9fb181f78d648ad457c6a" + integrity sha512-A5CUptxDsvxKJEU3yO6DuWBSJz/qizqzJKOMIfUJHETbBw/sFaDxgd6fxm1ewUaM0jZ444Fc5vC5ROYurg/4Pw== + +xmlchars@^2.1.1: + version "2.2.0" + resolved "https://registry.yarnpkg.com/xmlchars/-/xmlchars-2.2.0.tgz#060fe1bcb7f9c76fe2a17db86a9bc3ab894210cb" + integrity sha512-JZnDKK8B0RCDw84FNdDAIpZK+JuJw+s7Lz8nksI7SIuU3UXJJslUthsi+uWBUYOwPFwW7W7PRLRfUKpxjtjFCw== + +xregexp@^4.3.0: + version "4.3.0" + resolved "https://registry.yarnpkg.com/xregexp/-/xregexp-4.3.0.tgz#7e92e73d9174a99a59743f67a4ce879a04b5ae50" + integrity sha512-7jXDIFXh5yJ/orPn4SXjuVrWWoi4Cr8jfV1eHv9CixKSbU+jY4mxfrBwAuDvupPNKpMUY+FeIqsVw/JLT9+B8g== + dependencies: + "@babel/runtime-corejs3" "^7.8.3" + +xtend@^4.0.0, xtend@~4.0.1: + version "4.0.2" + resolved "https://registry.yarnpkg.com/xtend/-/xtend-4.0.2.tgz#bb72779f5fa465186b1f438f674fa347fdb5db54" + integrity sha512-LKYU1iAXJXUgAXn9URjiu+MWhyUXHsvfp7mcuYm9dSUKK0/CjtrUwFAxD82/mCWbtLsGjFIad0wIsod4zrTAEQ== + +"y18n@^3.2.1 || ^4.0.0", y18n@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/y18n/-/y18n-4.0.0.tgz#95ef94f85ecc81d007c264e190a120f0a3c8566b" + integrity sha512-r9S/ZyXu/Xu9q1tYlpsLIsa3EeLXXk0VwlxqTcFRfg9EhMW+17kbt9G0NrgCmhGb5vT2hyhJZLfDGx+7+5Uj/w== + +yallist@^3.0.2: + version "3.1.1" + resolved 
"https://registry.yarnpkg.com/yallist/-/yallist-3.1.1.tgz#dbb7daf9bfd8bac9ab45ebf602b8cbad0d5d08fd" + integrity sha512-a4UGQaWPH59mOXUYnAG2ewncQS4i4F43Tv3JoAM+s2VDAmS9NsK8GpDMLrCHPksFT7h3K6TOoUNn2pb7RoXx4g== + +yallist@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/yallist/-/yallist-4.0.0.tgz#9bb92790d9c0effec63be73519e11a35019a3a72" + integrity sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A== + +yaml@^1.7.2: + version "1.8.3" + resolved "https://registry.yarnpkg.com/yaml/-/yaml-1.8.3.tgz#2f420fca58b68ce3a332d0ca64be1d191dd3f87a" + integrity sha512-X/v7VDnK+sxbQ2Imq4Jt2PRUsRsP7UcpSl3Llg6+NRRqWLIvxkMFYtH1FmvwNGYRKKPa+EPA4qDBlI9WVG1UKw== + dependencies: + "@babel/runtime" "^7.8.7" + +yargs-parser@^11.1.1: + version "11.1.1" + resolved "https://registry.yarnpkg.com/yargs-parser/-/yargs-parser-11.1.1.tgz#879a0865973bca9f6bab5cbdf3b1c67ec7d3bcf4" + integrity sha512-C6kB/WJDiaxONLJQnF8ccx9SEeoTTLek8RVbaOIsrAUS8VrBEXfmeSnCZxygc+XC2sNMBIwOOnfcxiynjHsVSQ== + dependencies: + camelcase "^5.0.0" + decamelize "^1.2.0" + +yargs-parser@^13.1.0, yargs-parser@^13.1.2: + version "13.1.2" + resolved "https://registry.yarnpkg.com/yargs-parser/-/yargs-parser-13.1.2.tgz#130f09702ebaeef2650d54ce6e3e5706f7a4fb38" + integrity sha512-3lbsNRf/j+A4QuSZfDRA7HRSfWrzO0YjqTJd5kjAq37Zep1CEgaYmrH9Q3GwPiB9cHyd1Y1UwggGhJGoxipbzg== + dependencies: + camelcase "^5.0.0" + decamelize "^1.2.0" + +yargs@12.0.5: + version "12.0.5" + resolved "https://registry.yarnpkg.com/yargs/-/yargs-12.0.5.tgz#05f5997b609647b64f66b81e3b4b10a368e7ad13" + integrity sha512-Lhz8TLaYnxq/2ObqHDql8dX8CJi97oHxrjUcYtzKbbykPtVW9WB+poxI+NM2UIzsMgNCZTIf0AQwsjK5yMAqZw== + dependencies: + cliui "^4.0.0" + decamelize "^1.2.0" + find-up "^3.0.0" + get-caller-file "^1.0.1" + os-locale "^3.0.0" + require-directory "^2.1.1" + require-main-filename "^1.0.1" + set-blocking "^2.0.0" + string-width "^2.0.0" + which-module "^2.0.0" + y18n "^3.2.1 || ^4.0.0" + yargs-parser "^11.1.1" + +yargs@13.2.4: + version "13.2.4" + resolved "https://registry.yarnpkg.com/yargs/-/yargs-13.2.4.tgz#0b562b794016eb9651b98bd37acf364aa5d6dc83" + integrity sha512-HG/DWAJa1PAnHT9JAhNa8AbAv3FPaiLzioSjCcmuXXhP8MlpHO5vwls4g4j6n30Z74GVQj8Xa62dWVx1QCGklg== + dependencies: + cliui "^5.0.0" + find-up "^3.0.0" + get-caller-file "^2.0.1" + os-locale "^3.1.0" + require-directory "^2.1.1" + require-main-filename "^2.0.0" + set-blocking "^2.0.0" + string-width "^3.0.0" + which-module "^2.0.0" + y18n "^4.0.0" + yargs-parser "^13.1.0" + +yargs@^13.3.0: + version "13.3.2" + resolved "https://registry.yarnpkg.com/yargs/-/yargs-13.3.2.tgz#ad7ffefec1aa59565ac915f82dccb38a9c31a2dd" + integrity sha512-AX3Zw5iPruN5ie6xGRIDgqkT+ZhnRlZMLMHAs8tg7nRruy2Nb+i5o9bwghAogtM08q1dpr2LVoS8KSTMYpWXUw== + dependencies: + cliui "^5.0.0" + find-up "^3.0.0" + get-caller-file "^2.0.1" + require-directory "^2.1.1" + require-main-filename "^2.0.0" + set-blocking "^2.0.0" + string-width "^3.0.0" + which-module "^2.0.0" + y18n "^4.0.0" + yargs-parser "^13.1.2" diff --git a/captum/insights/attr_vis/models/cifar_torchvision.pt b/captum/insights/attr_vis/models/cifar_torchvision.pt new file mode 100644 index 0000000000000000000000000000000000000000..2f241ac1cac94b808fcf9c77b8140e071eac9498 --- /dev/null +++ b/captum/insights/attr_vis/models/cifar_torchvision.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:228d8332085282b6d2626d1b1b1b17944c760507d9b0cabe318dc6b108acf6e9 +size 249703 diff --git a/captum/insights/attr_vis/server.py 
b/captum/insights/attr_vis/server.py new file mode 100644 index 0000000000000000000000000000000000000000..124d152fce26b1ab7672774d2b2c9ea2fa83847f --- /dev/null +++ b/captum/insights/attr_vis/server.py @@ -0,0 +1,117 @@ +#!/usr/bin/env python3 +import logging +import os +import socket +import threading +from time import sleep +from typing import Optional + +from captum.log import log_usage +from flask import Flask, jsonify, render_template, request +from flask_compress import Compress +from torch import Tensor + +app = Flask( + __name__, static_folder="frontend/build/static", template_folder="frontend/build" +) +visualizer = None +port = None +Compress(app) + + +def namedtuple_to_dict(obj): + if isinstance(obj, Tensor): + return obj.item() + if hasattr(obj, "_asdict"): # detect namedtuple + return dict(zip(obj._fields, (namedtuple_to_dict(item) for item in obj))) + elif isinstance(obj, str): # iterables - strings + return obj + elif hasattr(obj, "keys"): # iterables - mapping + return dict( + zip(obj.keys(), (namedtuple_to_dict(item) for item in obj.values())) + ) + elif hasattr(obj, "__iter__"): # iterables - sequence + return type(obj)((namedtuple_to_dict(item) for item in obj)) + else: # non-iterable cannot contain namedtuples + return obj + + +@app.route("/attribute", methods=["POST"]) +def attribute(): + # force=True needed for Colab notebooks, which doesn't use the correct + # Content-Type header when forwarding requests through the Colab proxy + r = request.get_json(force=True) + return jsonify( + namedtuple_to_dict( + visualizer._calculate_attribution_from_cache( + r["inputIndex"], r["modelIndex"], r["labelIndex"] + ) + ) + ) + + +@app.route("/fetch", methods=["POST"]) +def fetch(): + # force=True needed, see comment for "/attribute" route above + visualizer._update_config(request.get_json(force=True)) + visualizer_output = visualizer.visualize() + clean_output = namedtuple_to_dict(visualizer_output) + return jsonify(clean_output) + + +@app.route("/init") +def init(): + return jsonify(visualizer.get_insights_config()) + + +@app.route("/") +def index(id=0): + return render_template("index.html") + + +def get_free_tcp_port(): + tcp = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + tcp.bind(("", 0)) + addr, port = tcp.getsockname() + tcp.close() + return port + + +def run_app(debug: bool = True, bind_all: bool = False): + if bind_all: + app.run(port=port, use_reloader=False, debug=debug, host="0.0.0.0") + else: + app.run(port=port, use_reloader=False, debug=debug) + + +@log_usage() +def start_server( + _viz, + blocking: bool = False, + debug: bool = False, + _port: Optional[int] = None, + bind_all: bool = False, +): + global visualizer + visualizer = _viz + + global port + if port is None: + os.environ["WERKZEUG_RUN_MAIN"] = "true" # hides starting message + if not debug: + log = logging.getLogger("werkzeug") + log.disabled = True + app.logger.disabled = True + + port = _port or get_free_tcp_port() + # Start in a new thread to not block notebook execution + t = threading.Thread( + target=run_app, kwargs={"debug": debug, "bind_all": bind_all} + ) + t.start() + sleep(0.01) # add a short delay to allow server to start up + if blocking: + t.join() + + print(f"\nFetch data and view Captum Insights at http://localhost:{port}/\n") + return port diff --git a/captum/insights/attr_vis/widget/__init__.py b/captum/insights/attr_vis/widget/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..82f0af8d40acba40a629f0030ca1346af22fd364 --- /dev/null +++ 
b/captum/insights/attr_vis/widget/__init__.py @@ -0,0 +1,13 @@ +from captum.insights.attr_vis.widget._version import __version__, version_info # noqa +from captum.insights.attr_vis.widget.widget import * # noqa + + +def _jupyter_nbextension_paths(): + return [ + { + "section": "notebook", + "src": "static", + "dest": "jupyter-captum-insights", + "require": "jupyter-captum-insights/extension", + } + ] diff --git a/captum/insights/attr_vis/widget/_version.py b/captum/insights/attr_vis/widget/_version.py new file mode 100644 index 0000000000000000000000000000000000000000..adb82fbbe21763144e206a9ad1fe78ff08e1a519 --- /dev/null +++ b/captum/insights/attr_vis/widget/_version.py @@ -0,0 +1,12 @@ +version_info = (0, 1, 0, "alpha", 0) + +_specifier_ = {"alpha": "a", "beta": "b", "candidate": "rc", "final": ""} + +__version__ = "%s.%s.%s%s" % ( + version_info[0], + version_info[1], + version_info[2], + "" + if version_info[3] == "final" + else _specifier_[version_info[3]] + str(version_info[4]), +) diff --git a/captum/insights/attr_vis/widget/widget.py b/captum/insights/attr_vis/widget/widget.py new file mode 100644 index 0000000000000000000000000000000000000000..2f5adbfcedd63b8d386ba287cd94bc9d3034cf65 --- /dev/null +++ b/captum/insights/attr_vis/widget/widget.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +import ipywidgets as widgets +from captum.insights import AttributionVisualizer +from captum.insights.attr_vis.server import namedtuple_to_dict +from traitlets import Dict, Instance, List, observe, Unicode + + +@widgets.register +class CaptumInsights(widgets.DOMWidget): + """A widget for interacting with Captum Insights.""" + + _view_name = Unicode("CaptumInsightsView").tag(sync=True) + _model_name = Unicode("CaptumInsightsModel").tag(sync=True) + _view_module = Unicode("jupyter-captum-insights").tag(sync=True) + _model_module = Unicode("jupyter-captum-insights").tag(sync=True) + _view_module_version = Unicode("^0.1.0").tag(sync=True) + _model_module_version = Unicode("^0.1.0").tag(sync=True) + + visualizer = Instance(klass=AttributionVisualizer) + + insights_config = Dict().tag(sync=True) + label_details = Dict().tag(sync=True) + attribution = Dict().tag(sync=True) + config = Dict().tag(sync=True) + output = List().tag(sync=True) + + def __init__(self, **kwargs) -> None: + super(CaptumInsights, self).__init__(**kwargs) + self.insights_config = self.visualizer.get_insights_config() + self.out = widgets.Output() + with self.out: + print("Captum Insights widget created.") + + @observe("config") + def _fetch_data(self, change): + if not self.config: + return + with self.out: + self.visualizer._update_config(self.config) + self.output = namedtuple_to_dict(self.visualizer.visualize()) + self.config = dict() + + @observe("label_details") + def _fetch_attribution(self, change): + if not self.label_details: + return + with self.out: + self.attribution = namedtuple_to_dict( + self.visualizer._calculate_attribution_from_cache( + self.label_details["inputIndex"], + self.label_details["modelIndex"], + self.label_details["labelIndex"], + ) + ) + self.label_details = dict() diff --git a/captum/insights/example.py b/captum/insights/example.py new file mode 100644 index 0000000000000000000000000000000000000000..a29a685a6d2725c5e6f66e0ffb6595aa64a48cb2 --- /dev/null +++ b/captum/insights/example.py @@ -0,0 +1,11 @@ +# for legacy purposes +import warnings + +from captum.insights.attr_vis.example import * # noqa + +warnings.warn( + "Deprecated. Please import from captum.insights.attr_vis.example instead." 
+) + + +main() # noqa diff --git a/captum/log/__init__.py b/captum/log/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..81d61383d0acbf3dfeff9d5ae0562a864031e6d4 --- /dev/null +++ b/captum/log/__init__.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 + +try: + from captum.log.fb.internal_log import ( + log, + log_usage, + patch_methods, + set_environment, + TimedLog, + ) + + __all__ = ["log", "log_usage", "TimedLog", "set_environment"] + +except ImportError: + from functools import wraps + + def log(*args, **kwargs): + pass + + # bug with mypy: https://github.com/python/mypy/issues/1153 + class TimedLog: # type: ignore + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exception_type, exception_value, traceback): + return exception_value is not None + + def log_usage(*log_args, **log_kwargs): + def _log_usage(func): + @wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return wrapper + + return _log_usage + + def set_environment(env): + pass + + def patch_methods(tester, patch_log=True): + pass diff --git a/captum/metrics/__init__.py b/captum/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6c8e5a3ac3045a8d9ebaabad9540c19f7079652e --- /dev/null +++ b/captum/metrics/__init__.py @@ -0,0 +1,7 @@ +#!/usr/bin/env python3 + +from captum.metrics._core.infidelity import ( # noqa + infidelity, + infidelity_perturb_func_decorator, +) +from captum.metrics._core.sensitivity import sensitivity_max # noqa diff --git a/captum/metrics/_core/__init__.py b/captum/metrics/_core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e5a0d9b4834ec8f46d6e0d1256c6dcaad2e460fe --- /dev/null +++ b/captum/metrics/_core/__init__.py @@ -0,0 +1 @@ +#!/usr/bin/env python3 diff --git a/captum/metrics/_core/infidelity.py b/captum/metrics/_core/infidelity.py new file mode 100644 index 0000000000000000000000000000000000000000..a2f014b67f38a908bb0c16a179c72ede87c55965 --- /dev/null +++ b/captum/metrics/_core/infidelity.py @@ -0,0 +1,580 @@ +#!/usr/bin/env python3 + +from typing import Any, Callable, cast, Tuple, Union + +import torch +from captum._utils.common import ( + _expand_additional_forward_args, + _expand_target, + _format_additional_forward_args, + _format_baseline, + _format_tensor_into_tuples, + _run_forward, + ExpansionTypes, + safe_div, +) +from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric +from captum.log import log_usage +from captum.metrics._utils.batching import _divide_and_aggregate_metrics +from torch import Tensor + + +def infidelity_perturb_func_decorator(multipy_by_inputs: bool = True) -> Callable: + r"""An auxiliary, decorator function that helps with computing + perturbations given perturbed inputs. It can be useful for cases + when `pertub_func` returns only perturbed inputs and we + internally compute the perturbations as + (input - perturbed_input) / (input - baseline) if + multipy_by_inputs is set to True and + (input - perturbed_input) otherwise. + + If users decorate their `pertub_func` with + `@infidelity_perturb_func_decorator` function then their `pertub_func` + needs to only return perturbed inputs. + + Args: + + multipy_by_inputs (bool): Indicates whether model inputs' + multiplier is factored in the computation of + attribution scores. 
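# A minimal numeric sketch of the perturbation formula described above; the
# values are arbitrary, chosen only for illustration, and are not part of the
# captum sources in this diff.
import torch

inp = torch.tensor([1.0])
inp_perturbed = torch.tensor([0.7])
baseline = torch.tensor([0.5])

# multipy_by_inputs=True: the raw difference is rescaled by (input - baseline)
perturbation_global = (inp - inp_perturbed) / (inp - baseline)   # -> 0.6
# multipy_by_inputs=False: the raw difference is used as-is
perturbation_local = inp - inp_perturbed                          # -> 0.3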
+ + """ + + def sub_infidelity_perturb_func_decorator(pertub_func: Callable) -> Callable: + r""" + Args: + + pertub_func(callable): Input perturbation function that takes inputs + and optionally baselines and returns perturbed inputs + + Returns: + + default_perturb_func(callable): Internal default perturbation + function that computes the perturbations internally and returns + perturbations and perturbed inputs. + + Examples:: + >>> @infidelity_perturb_func_decorator(True) + >>> def perturb_fn(inputs): + >>> noise = torch.tensor(np.random.normal(0, 0.003, + >>> inputs.shape)).float() + >>> return inputs - noise + >>> # Computes infidelity score using `perturb_fn` + >>> infidelity = infidelity(model, perturb_fn, input, ...) + + """ + + def default_perturb_func( + inputs: TensorOrTupleOfTensorsGeneric, baselines: BaselineType = None + ): + r""" """ + inputs_perturbed = ( + pertub_func(inputs, baselines) + if baselines is not None + else pertub_func(inputs) + ) + inputs_perturbed = _format_tensor_into_tuples(inputs_perturbed) + inputs = _format_tensor_into_tuples(inputs) + baselines = _format_baseline(baselines, inputs) + if baselines is None: + perturbations = tuple( + safe_div( + input - input_perturbed, + input, + default_denom=1.0, + ) + if multipy_by_inputs + else input - input_perturbed + for input, input_perturbed in zip(inputs, inputs_perturbed) + ) + else: + perturbations = tuple( + safe_div( + input - input_perturbed, + input - baseline, + default_denom=1.0, + ) + if multipy_by_inputs + else input - input_perturbed + for input, input_perturbed, baseline in zip( + inputs, inputs_perturbed, baselines + ) + ) + return perturbations, inputs_perturbed + + return default_perturb_func + + return sub_infidelity_perturb_func_decorator + + +@log_usage() +def infidelity( + forward_func: Callable, + perturb_func: Callable, + inputs: TensorOrTupleOfTensorsGeneric, + attributions: TensorOrTupleOfTensorsGeneric, + baselines: BaselineType = None, + additional_forward_args: Any = None, + target: TargetType = None, + n_perturb_samples: int = 10, + max_examples_per_batch: int = None, + normalize: bool = False, +) -> Tensor: + r""" + Explanation infidelity represents the expected mean-squared error + between the explanation multiplied by a meaningful input perturbation + and the differences between the predictor function at its input + and perturbed input. + More details about the measure can be found in the following paper: + https://arxiv.org/pdf/1901.09392.pdf + + It is derived from the completeness property of well-known attribution + algorithms and is a computationally more efficient and generalized + notion of Sensitivy-n. The latter measures correlations between the sum + of the attributions and the differences of the predictor function at + its input and fixed baseline. More details about the Sensitivity-n can + be found here: + https://arxiv.org/pdf/1711.06104.pdfs + + The users can perturb the inputs any desired way by providing any + perturbation function that takes the inputs (and optionally baselines) + and returns perturbed inputs or perturbed inputs and corresponding + perturbations. + + This specific implementation is primarily tested for attribution-based + explanation methods but the idea can be expanded to use for non + attribution-based interpretability methods as well. + + Args: + + forward_func (callable): + The forward function of the model or any modification of it. + + perturb_func (callable): + The perturbation function of model inputs. 
This function takes + model inputs and optionally baselines as input arguments and returns + either a tuple of perturbations and perturbed inputs or just + perturbed inputs. For example: + + >>> def my_perturb_func(inputs): + >>> + >>> return perturbations, perturbed_inputs + + If we want to only return perturbed inputs and compute + perturbations internally then we can wrap perturb_func with + `infidelity_perturb_func_decorator` decorator such as: + + >>> from captum.metrics import infidelity_perturb_func_decorator + + >>> @infidelity_perturb_func_decorator() + >>> def my_perturb_func(inputs): + >>> + >>> return perturbed_inputs + + In case `multipy_by_inputs` is False we compute perturbations by + `input - perturbed_input` difference and in case `multipy_by_inputs` + flag is True we compute it by dividing + (input - perturbed_input) by (input - baselines). + The user needs to only return perturbed inputs in `perturb_func` + as described above. + + `infidelity_perturb_func_decorator` needs to be used with + `multipy_by_inputs` flag set to False in case infidelity + score is being computed for attribution maps that are local aka + that do not factor in inputs in the final attribution score. + Such attribution algorithms include Saliency, GradCam, Guided Backprop, + or Integrated Gradients and DeepLift attribution scores that are already + computed with `multipy_by_inputs=False` flag. + + If there are more than one inputs passed to infidelity function those + will be passed to `perturb_func` as tuples in the same order as they + are passed to infidelity function. + + If inputs + - is a single tensor, the function needs to return a tuple + of perturbations and perturbed input such as: + perturb, perturbed_input and only perturbed_input in case + `infidelity_perturb_func_decorator` is used. + - is a tuple of tensors, corresponding perturbations and perturbed + inputs must be computed and returned as tuples in the + following format: + + (perturb1, perturb2, ... perturbN), (perturbed_input1, + perturbed_input2, ... perturbed_inputN) + + Similar to previous case here as well we need to return only + perturbed inputs in case `infidelity_perturb_func_decorator` + decorates out `perturb_func`. + It is important to note that for performance reasons `perturb_func` + isn't called for each example individually but on a batch of + input examples that are repeated `max_examples_per_batch / batch_size` + times within the batch. + + inputs (tensor or tuple of tensors): Input for which + attributions are computed. If forward_func takes a single + tensor as input, a single input tensor should be provided. + If forward_func takes multiple tensors as input, a tuple + of the input tensors should be provided. It is assumed + that for all given input tensors, dimension 0 corresponds + to the number of examples (aka batch size), and if + multiple input tensors are provided, the examples must + be aligned appropriately. + + baselines (scalar, tensor, tuple of scalars or tensors, optional): + Baselines define reference values which sometimes represent ablated + values and are used to compare with the actual inputs to compute + importance scores in attribution algorithms. They can be represented + as: + + - a single tensor, if inputs is a single tensor, with + exactly the same dimensions as inputs or the first + dimension is one and the remaining dimensions match + with inputs. + + - a single scalar, if inputs is a single tensor, which will + be broadcasted for each input value in input tensor. 
+ + - a tuple of tensors or scalars, the baseline corresponding + to each tensor in the inputs' tuple can be: + + - either a tensor with matching dimensions to + corresponding tensor in the inputs' tuple + or the first dimension is one and the remaining + dimensions match with the corresponding + input tensor. + + - or a scalar, corresponding to a tensor in the + inputs' tuple. This scalar value is broadcasted + for corresponding input tensor. + + Default: None + + attributions (tensor or tuple of tensors): + Attribution scores computed based on an attribution algorithm. + This attribution scores can be computed using the implementations + provided in the `captum.attr` package. Some of those attribution + approaches are so called global methods, which means that + they factor in model inputs' multiplier, as described in: + https://arxiv.org/pdf/1711.06104.pdf + Many global attribution algorithms can be used in local modes, + meaning that the inputs multiplier isn't factored in the + attribution scores. + This can be done duing the definition of the attribution algorithm + by passing `multipy_by_inputs=False` flag. + For example in case of Integrated Gradients (IG) we can obtain + local attribution scores if we define the constructor of IG as: + ig = IntegratedGradients(multipy_by_inputs=False) + + Some attribution algorithms are inherently local. + Examples of inherently local attribution methods include: + Saliency, Guided GradCam, Guided Backprop and Deconvolution. + + For local attributions we can use real-valued perturbations + whereas for global attributions that perturbation is binary. + https://arxiv.org/pdf/1901.09392.pdf + + If we want to compute the infidelity of global attributions we + can use a binary perturbation matrix that will allow us to select + a subset of features from `inputs` or `inputs - baselines` space. + This will allow us to approximate sensitivity-n for a global + attribution algorithm. + + `infidelity_perturb_func_decorator` function decorator is a helper + function that computes perturbations under the hood if perturbed + inputs are provided. + + For more details about how to use `infidelity_perturb_func_decorator`, + please, read the documentation about `perturb_func` + + Attributions have the same shape and dimensionality as the inputs. + If inputs is a single tensor then the attributions is a single + tensor as well. If inputs is provided as a tuple of tensors + then attributions will be tuples of tensors as well. + + additional_forward_args (any, optional): If the forward function + requires additional arguments other than the inputs for + which attributions should not be computed, this argument + can be provided. It must be either a single additional + argument of a Tensor or arbitrary (non-tuple) type or a tuple + containing multiple additional arguments including tensors + or any arbitrary python types. These arguments are provided to + forward_func in order, following the arguments in inputs. + Note that the perturbations are not computed with respect + to these arguments. This means that these arguments aren't + being passed to `perturb_func` as an input argument. + + Default: None + target (int, tuple, tensor or list, optional): Indices for selecting + predictions from output(for classification cases, + this is usually the target class). + If the network returns a scalar value per example, no target + index is necessary. 
+ For general 2D outputs, targets can be either: + + - A single integer or a tensor containing a single + integer, which is applied to all input examples + + - A list of integers or a 1D tensor, with length matching + the number of examples in inputs (dim 0). Each integer + is applied as the target for the corresponding example. + + For outputs with > 2 dimensions, targets can be either: + + - A single tuple, which contains #output_dims - 1 + elements. This target index is applied to all examples. + + - A list of tuples with length equal to the number of + examples in inputs (dim 0), and each tuple containing + #output_dims - 1 elements. Each tuple is applied as the + target for the corresponding example. + + Default: None + n_perturb_samples (int, optional): The number of times input tensors + are perturbed. Each input example in the inputs tensor is expanded + `n_perturb_samples` + times before calling `perturb_func` function. + + Default: 10 + max_examples_per_batch (int, optional): The number of maximum input + examples that are processed together. In case the number of + examples (`input batch size * n_perturb_samples`) exceeds + `max_examples_per_batch`, they will be sliced + into batches of `max_examples_per_batch` examples and processed + in a sequential order. If `max_examples_per_batch` is None, all + examples are processed together. `max_examples_per_batch` should + at least be equal `input batch size` and at most + `input batch size * n_perturb_samples`. + + Default: None + normalize (bool, optional): Normalize the dot product of the input + perturbation and the attribution so the infidelity value is invariant + to constant scaling of the attribution values. The normalization factor + beta is defined as the ratio of two mean values: + + .. math:: + \beta = \frac{ + \mathbb{E}_{I \sim \mu_I} [ I^T \Phi(f, x) (f(x) - f(x - I)) ] + }{ + \mathbb{E}_{I \sim \mu_I} [ (I^T \Phi(f, x))^2 ] + } + + Please refer the original paper for the meaning of the symbols. Same + normalization can be found in the paper's official implementation + https://github.com/chihkuanyeh/saliency_evaluation + + Default: False + Returns: + + infidelities (tensor): A tensor of scalar infidelity scores per + input example. The first dimension is equal to the + number of examples in the input batch and the second + dimension is one. + + Examples:: + >>> # ImageClassifier takes a single input tensor of images Nx3x32x32, + >>> # and returns an Nx10 tensor of class probabilities. + >>> net = ImageClassifier() + >>> saliency = Saliency(net) + >>> input = torch.randn(2, 3, 32, 32, requires_grad=True) + >>> # Computes saliency maps for class 3. + >>> attribution = saliency.attribute(input, target=3) + >>> # define a perturbation function for the input + >>> def perturb_fn(inputs): + >>> noise = torch.tensor(np.random.normal(0, 0.003, inputs.shape)).float() + >>> return noise, inputs - noise + >>> # Computes infidelity score for saliency maps + >>> infid = infidelity(net, perturb_fn, input, attribution) + """ + + def _generate_perturbations( + current_n_perturb_samples: int, + ) -> Tuple[TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]: + r""" + The perturbations are generated for each example + `current_n_perturb_samples` times. + + For performance reasons we are not calling `perturb_func` on each example but + on a batch that contains `current_n_perturb_samples` + repeated instances per example. 
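# A small standalone sketch of the batched expansion described above; the
# batch size, feature count, and sample count are illustrative assumptions,
# not values taken from the captum sources in this diff.
import torch

inputs = torch.arange(6.0).view(2, 3)          # batch of 2 examples, 3 features each
current_n_perturb_samples = 4
expanded = torch.repeat_interleave(inputs, current_n_perturb_samples, dim=0)
# expanded.shape == (8, 3): rows 0-3 repeat example 0, rows 4-7 repeat example 1,
# so perturb_func is invoked once on the whole expanded batch rather than once
# per individual example.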
+ """ + + def call_perturb_func(): + r""" """ + baselines_pert = None + inputs_pert: Union[Tensor, Tuple[Tensor, ...]] + if len(inputs_expanded) == 1: + inputs_pert = inputs_expanded[0] + if baselines_expanded is not None: + baselines_pert = cast(Tuple, baselines_expanded)[0] + else: + inputs_pert = inputs_expanded + baselines_pert = baselines_expanded + return ( + perturb_func(inputs_pert, baselines_pert) + if baselines_pert is not None + else perturb_func(inputs_pert) + ) + + inputs_expanded = tuple( + torch.repeat_interleave(input, current_n_perturb_samples, dim=0) + for input in inputs + ) + + baselines_expanded = baselines + if baselines is not None: + baselines_expanded = tuple( + baseline.repeat_interleave(current_n_perturb_samples, dim=0) + if isinstance(baseline, torch.Tensor) + and baseline.shape[0] == input.shape[0] + and baseline.shape[0] > 1 + else baseline + for input, baseline in zip(inputs, cast(Tuple, baselines)) + ) + + return call_perturb_func() + + def _validate_inputs_and_perturbations( + inputs: Tuple[Tensor, ...], + inputs_perturbed: Tuple[Tensor, ...], + perturbations: Tuple[Tensor, ...], + ) -> None: + # asserts the sizes of the perturbations and inputs + assert len(perturbations) == len(inputs), ( + """The number of perturbed + inputs and corresponding perturbations must have the same number of + elements. Found number of inputs is: {} and perturbations: + {}""" + ).format(len(perturbations), len(inputs)) + + # asserts the shapes of the perturbations and perturbed inputs + for perturb, input_perturbed in zip(perturbations, inputs_perturbed): + assert perturb[0].shape == input_perturbed[0].shape, ( + """Perturbed input + and corresponding perturbation must have the same shape and + dimensionality. Found perturbation shape is: {} and the input shape + is: {}""" + ).format(perturb[0].shape, input_perturbed[0].shape) + + def _next_infidelity_tensors( + current_n_perturb_samples: int, + ) -> Union[Tuple[Tensor], Tuple[Tensor, Tensor, Tensor]]: + perturbations, inputs_perturbed = _generate_perturbations( + current_n_perturb_samples + ) + + perturbations = _format_tensor_into_tuples(perturbations) + inputs_perturbed = _format_tensor_into_tuples(inputs_perturbed) + + _validate_inputs_and_perturbations( + cast(Tuple[Tensor, ...], inputs), + cast(Tuple[Tensor, ...], inputs_perturbed), + cast(Tuple[Tensor, ...], perturbations), + ) + + targets_expanded = _expand_target( + target, + current_n_perturb_samples, + expansion_type=ExpansionTypes.repeat_interleave, + ) + additional_forward_args_expanded = _expand_additional_forward_args( + additional_forward_args, + current_n_perturb_samples, + expansion_type=ExpansionTypes.repeat_interleave, + ) + + inputs_perturbed_fwd = _run_forward( + forward_func, + inputs_perturbed, + targets_expanded, + additional_forward_args_expanded, + ) + inputs_fwd = _run_forward(forward_func, inputs, target, additional_forward_args) + inputs_fwd = torch.repeat_interleave( + inputs_fwd, current_n_perturb_samples, dim=0 + ) + perturbed_fwd_diffs = inputs_fwd - inputs_perturbed_fwd + attributions_expanded = tuple( + torch.repeat_interleave(attribution, current_n_perturb_samples, dim=0) + for attribution in attributions + ) + + attributions_times_perturb = tuple( + (attribution_expanded * perturbation).view(attribution_expanded.size(0), -1) + for attribution_expanded, perturbation in zip( + attributions_expanded, perturbations + ) + ) + + attr_times_perturb_sums = sum( + torch.sum(attribution_times_perturb, dim=1) + for attribution_times_perturb in 
attributions_times_perturb + ) + attr_times_perturb_sums = cast(Tensor, attr_times_perturb_sums) + + # reshape as Tensor(bsz, current_n_perturb_samples) + attr_times_perturb_sums = attr_times_perturb_sums.view(bsz, -1) + perturbed_fwd_diffs = perturbed_fwd_diffs.view(bsz, -1) + + if normalize: + # in order to normalize, we have to aggregate the following tensors + # to calculate MSE in its polynomial expansion: + # (a-b)^2 = a^2 - 2ab + b^2 + return ( + attr_times_perturb_sums.pow(2).sum(-1), + (attr_times_perturb_sums * perturbed_fwd_diffs).sum(-1), + perturbed_fwd_diffs.pow(2).sum(-1), + ) + else: + # returns (a-b)^2 if no need to normalize + return ((attr_times_perturb_sums - perturbed_fwd_diffs).pow(2).sum(-1),) + + def _sum_infidelity_tensors(agg_tensors, tensors): + return tuple(agg_t + t for agg_t, t in zip(agg_tensors, tensors)) + + # perform argument formattings + inputs = _format_tensor_into_tuples(inputs) # type: ignore + if baselines is not None: + baselines = _format_baseline(baselines, cast(Tuple[Tensor, ...], inputs)) + additional_forward_args = _format_additional_forward_args(additional_forward_args) + attributions = _format_tensor_into_tuples(attributions) # type: ignore + + # Make sure that inputs and corresponding attributions have matching sizes. + assert len(inputs) == len(attributions), ( + """The number of tensors in the inputs and + attributions must match. Found number of tensors in the inputs is: {} and in the + attributions: {}""" + ).format(len(inputs), len(attributions)) + for inp, attr in zip(inputs, attributions): + assert inp.shape == attr.shape, ( + """Inputs and attributions must have + matching shapes. One of the input tensor's shape is {} and the + attribution tensor's shape is: {}""" + ).format(inp.shape, attr.shape) + + bsz = inputs[0].size(0) + with torch.no_grad(): + # if not normalize, directly return aggrgated MSE ((a-b)^2,) + # else return aggregated MSE's polynomial expansion tensors (a^2, ab, b^2) + agg_tensors = _divide_and_aggregate_metrics( + cast(Tuple[Tensor, ...], inputs), + n_perturb_samples, + _next_infidelity_tensors, + agg_func=_sum_infidelity_tensors, + max_examples_per_batch=max_examples_per_batch, + ) + + if normalize: + beta_num = agg_tensors[1] + beta_denorm = agg_tensors[0] + + beta = safe_div(beta_num, beta_denorm) + + infidelity_values = ( + beta ** 2 * agg_tensors[0] - 2 * beta * agg_tensors[1] + agg_tensors[2] + ) + else: + infidelity_values = agg_tensors[0] + + infidelity_values /= n_perturb_samples + + return infidelity_values diff --git a/captum/metrics/_core/sensitivity.py b/captum/metrics/_core/sensitivity.py new file mode 100644 index 0000000000000000000000000000000000000000..77d87e62918b458b7e904e763e28dabc926eb99b --- /dev/null +++ b/captum/metrics/_core/sensitivity.py @@ -0,0 +1,316 @@ +#!/usr/bin/env python3 + +from copy import deepcopy +from inspect import signature +from typing import Any, Callable, cast, Tuple, Union + +import torch +from captum._utils.common import ( + _expand_and_update_additional_forward_args, + _expand_and_update_baselines, + _expand_and_update_target, + _format_baseline, + _format_tensor_into_tuples, +) +from captum._utils.typing import TensorOrTupleOfTensorsGeneric +from captum.log import log_usage +from captum.metrics._utils.batching import _divide_and_aggregate_metrics +from torch import Tensor + + +def default_perturb_func( + inputs: TensorOrTupleOfTensorsGeneric, perturb_radius: float = 0.02 +) -> Tuple[Tensor, ...]: + r"""A default function for generating perturbations of `inputs` + 
within perturbation radius of `perturb_radius`.
+ This function samples uniformly at random from the L_Infinity ball
+ with radius `perturb_radius`.
+ Users can override this function if they prefer to use a
+ different perturbation function.
+
+ Args:
+
+ inputs (tensor or a tuple of tensors): The input tensors that we'd
+ like to perturb by adding random noise sampled uniformly
+ at random from an L_infinity ball with radius `perturb_radius`.
+
+ perturb_radius (float): The radius used for sampling from
+ an L_infinity ball.
+
+ Returns:
+
+ perturbed_input (tuple(tensor)): A tuple of perturbed inputs that
+ are created by adding noise sampled uniformly at random
+ from an L_infinity ball with radius `perturb_radius` to the
+ original inputs.
+
+ """
+ inputs = _format_tensor_into_tuples(inputs)
+ perturbed_input = tuple(
+ input
+ + torch.FloatTensor(input.size()) # type: ignore
+ .uniform_(-perturb_radius, perturb_radius)
+ .to(input.device)
+ for input in inputs
+ )
+ return perturbed_input
+
+
+ @log_usage()
+ def sensitivity_max(
+ explanation_func: Callable,
+ inputs: TensorOrTupleOfTensorsGeneric,
+ perturb_func: Callable = default_perturb_func,
+ perturb_radius: float = 0.02,
+ n_perturb_samples: int = 10,
+ norm_ord: str = "fro",
+ max_examples_per_batch: int = None,
+ **kwargs: Any,
+ ) -> Tensor:
+ r"""
+ Explanation sensitivity measures the extent of explanation change when
+ the input is slightly perturbed. It has been shown that models that
+ have high explanation sensitivity are prone to adversarial attacks:
+ `Interpretation of Neural Networks is Fragile`
+ https://www.aaai.org/ojs/index.php/AAAI/article/view/4252
+
+ The `sensitivity_max` metric measures the maximum sensitivity of an
+ explanation using a Monte Carlo sampling-based approximation. By default,
+ in order to do so, it samples multiple data points from a sub-space of an
+ L-Infinity ball with radius `perturb_radius` using the `default_perturb_func`
+ default perturbation function. In the general case users can
+ use any L_p ball or any other custom sampling technique that they
+ prefer by providing a custom `perturb_func`.
+
+ Note that max sensitivity is similar to the Lipschitz continuity metric;
+ however, it is more robust and easier to estimate.
+ Since the explanation, for instance an attribution function,
+ may not always be continuous, it can lead to an unbounded
+ Lipschitz constant. Therefore the latter isn't always appropriate.
+
+ More about the Lipschitz continuity metric can also be found here:
+ `On the Robustness of Interpretability Methods`
+ https://arxiv.org/pdf/1806.08049.pdf
+ and
+ `Towards Robust Interpretability with Self-Explaining Neural Networks`
+ https://papers.nips.cc/paper\
+ 8003-towards-robust-interpretability-
+ with-self-explaining-neural-networks.pdf
+
+ More details about sensitivity max can be found here:
+ `On the (In)fidelity and Sensitivity of Explanations`
+ https://arxiv.org/pdf/1901.09392.pdf
+
+ Args:
+
+ explanation_func (callable):
+ This function can be the `attribute` method of an
+ attribution algorithm or any other explanation method
+ that returns the explanations.
+
+ inputs (tensor or tuple of tensors): Input for which
+ explanations are computed. If `explanation_func` takes a
+ single tensor as input, a single input tensor should
+ be provided.
+ If `explanation_func` takes multiple tensors as input, a tuple
+ of the input tensors should be provided.
It is assumed + that for all given input tensors, dimension 0 corresponds + to the number of examples (aka batch size), and if + multiple input tensors are provided, the examples must + be aligned appropriately. + + perturb_func (callable): + The perturbation function of model inputs. This function takes + model inputs and optionally `perturb_radius` if + the function takes more than one argument and returns + perturbed inputs. + + If there are more than one inputs passed to sensitivity function those + will be passed to `perturb_func` as tuples in the same order as they + are passed to sensitivity function. + + It is important to note that for performance reasons `perturb_func` + isn't called for each example individually but on a batch of + input examples that are repeated `max_examples_per_batch / batch_size` + times within the batch. + + Default: default_perturb_func + perturb_radius (float, optional): The epsilon radius used for sampling. + In the `default_perturb_func` it is used as the radius of + the L-Infinity ball. In a general case it can serve as a radius of + any L_p nom. + This argument is passed to `perturb_func` if it takes more than + one argument. + + Default: 0.02 + n_perturb_samples (int, optional): The number of times input tensors + are perturbed. Each input example in the inputs tensor is + expanded `n_perturb_samples` times before calling + `perturb_func` function. + + Default: 10 + norm_ord (int, float, inf, -inf, 'fro', 'nuc', optional): The type of norm + that is used to compute the + norm of the sensitivity matrix which is defined as the difference + between the explanation function at its input and perturbed input. + + Default: 'fro' + max_examples_per_batch (int, optional): The number of maximum input + examples that are processed together. In case the number of + examples (`input batch size * n_perturb_samples`) exceeds + `max_examples_per_batch`, they will be sliced + into batches of `max_examples_per_batch` examples and processed + in a sequential order. If `max_examples_per_batch` is None, all + examples are processed together. `max_examples_per_batch` should + at least be equal `input batch size` and at most + `input batch size * n_perturb_samples`. + + Default: None + **kwargs (Any, optional): Contains a list of arguments that are passed + to `explanation_func` explanation function which in some cases + could be the `attribute` function of an attribution algorithm. + Any additional arguments that need be passed to the explanation + function should be included here. + For instance, such arguments include: + `additional_forward_args`, `baselines` and `target`. + + Returns: + + sensitivities (tensor): A tensor of scalar sensitivity scores per + input example. The first dimension is equal to the + number of examples in the input batch and the second + dimension is one. Returned sensitivities are normalized by + the magnitudes of the input explanations. + + Examples:: + >>> # ImageClassifier takes a single input tensor of images Nx3x32x32, + >>> # and returns an Nx10 tensor of class probabilities. + >>> net = ImageClassifier() + >>> saliency = Saliency(net) + >>> input = torch.randn(2, 3, 32, 32, requires_grad=True) + >>> # Computes sensitivity score for saliency maps of class 3 + >>> sens = sensitivity_max(saliency.attribute, input, target = 3) + + """ + + def _generate_perturbations( + current_n_perturb_samples: int, + ) -> TensorOrTupleOfTensorsGeneric: + r""" + The perturbations are generated for each example + `current_n_perturb_samples` times. 
+ + For perfomance reasons we are not calling `perturb_func` on each example but + on a batch that contains `current_n_perturb_samples` repeated instances + per example. + """ + inputs_expanded: Union[Tensor, Tuple[Tensor, ...]] = tuple( + torch.repeat_interleave(input, current_n_perturb_samples, dim=0) + for input in inputs + ) + if len(inputs_expanded) == 1: + inputs_expanded = inputs_expanded[0] + + return ( + perturb_func(inputs_expanded, perturb_radius) + if len(signature(perturb_func).parameters) > 1 + else perturb_func(inputs_expanded) + ) + + def max_values(input_tnsr: Tensor) -> Tensor: + return torch.max(input_tnsr, dim=1).values # type: ignore + + kwarg_expanded_for = None + kwargs_copy: Any = None + + def _next_sensitivity_max(current_n_perturb_samples: int) -> Tensor: + inputs_perturbed = _generate_perturbations(current_n_perturb_samples) + + # copy kwargs and update some of the arguments that need to be expanded + nonlocal kwarg_expanded_for + nonlocal kwargs_copy + if ( + kwarg_expanded_for is None + or kwarg_expanded_for != current_n_perturb_samples + ): + kwarg_expanded_for = current_n_perturb_samples + kwargs_copy = deepcopy(kwargs) + _expand_and_update_additional_forward_args( + current_n_perturb_samples, kwargs_copy + ) + _expand_and_update_target(current_n_perturb_samples, kwargs_copy) + if "baselines" in kwargs: + baselines = kwargs["baselines"] + baselines = _format_baseline( + baselines, cast(Tuple[Tensor, ...], inputs) + ) + if ( + isinstance(baselines[0], Tensor) + and baselines[0].shape == inputs[0].shape + ): + _expand_and_update_baselines( + cast(Tuple[Tensor, ...], inputs), + current_n_perturb_samples, + kwargs_copy, + ) + + expl_perturbed_inputs = explanation_func(inputs_perturbed, **kwargs_copy) + + # tuplize `expl_perturbed_inputs` in case it is not + expl_perturbed_inputs = _format_tensor_into_tuples(expl_perturbed_inputs) + + expl_inputs_expanded = tuple( + expl_input.repeat_interleave(current_n_perturb_samples, dim=0) + for expl_input in expl_inputs + ) + + sensitivities = torch.cat( + [ + (expl_input - expl_perturbed).view(expl_perturbed.size(0), -1) + for expl_perturbed, expl_input in zip( + expl_perturbed_inputs, expl_inputs_expanded + ) + ], + dim=1, + ) + # compute the norm of original input explanations + expl_inputs_norm_expanded = torch.norm( + torch.cat( + [expl_input.view(expl_input.size(0), -1) for expl_input in expl_inputs], + dim=1, + ), + p=norm_ord, + dim=1, + keepdim=True, + ).repeat_interleave(current_n_perturb_samples, dim=0) + expl_inputs_norm_expanded = torch.where( + expl_inputs_norm_expanded == 0.0, + torch.tensor( + 1.0, + device=expl_inputs_norm_expanded.device, + dtype=expl_inputs_norm_expanded.dtype, + ), + expl_inputs_norm_expanded, + ) + + # compute the norm for each input noisy example + sensitivities_norm = ( + torch.norm(sensitivities, p=norm_ord, dim=1, keepdim=True) + / expl_inputs_norm_expanded + ) + return max_values(sensitivities_norm.view(bsz, -1)) + + inputs = _format_tensor_into_tuples(inputs) # type: ignore + + bsz = inputs[0].size(0) + + with torch.no_grad(): + expl_inputs = explanation_func(inputs, **kwargs) + metrics_max = _divide_and_aggregate_metrics( + cast(Tuple[Tensor, ...], inputs), + n_perturb_samples, + _next_sensitivity_max, + max_examples_per_batch=max_examples_per_batch, + agg_func=torch.max, + ) + return metrics_max diff --git a/captum/metrics/_utils/__init__.py b/captum/metrics/_utils/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/captum/metrics/_utils/batching.py b/captum/metrics/_utils/batching.py new file mode 100644 index 0000000000000000000000000000000000000000..ee3b38f58e6642ed777761a9df114679b20b167c --- /dev/null +++ b/captum/metrics/_utils/batching.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python3 + +import warnings +from typing import Callable, Tuple + +import torch +from torch import Tensor + + +def _divide_and_aggregate_metrics( + inputs: Tuple[Tensor, ...], + n_perturb_samples: int, + metric_func: Callable, + agg_func: Callable = torch.add, + max_examples_per_batch: int = None, +) -> Tensor: + r""" + This function is used to slice large number of samples `n_perturb_samples` per + input example into smaller pieces, computing the metrics for each small piece and + aggregating the results across all `n_perturb_samples` per example. The function + returns overall aggregated metric per sample. The size of each slice is determined + by the `max_examples_per_batch` input parameter. + + Args: + + inputs (tuple): The original inputs formatted in a tuple that are passed to + the metrics function and that are used to compute the + attributions for. + n_perturb_samples (int): The number of samples per example that are used for + perturbation purposes for example. + metric_func (callable): This function takes the number of samples per + input batch and returns an overall metric for each example. + agg_func (callable, optional): This function is used to aggregate the + metrics across multiple sub-batches and that are + generated by `metric_func`. + max_examples_per_batch (int, optional): The maximum number of allowed examples + per batch. + + Returns: + + metric (tensor): A metric score estimated by `metric_func` per + input example. + """ + bsz = inputs[0].size(0) + + if max_examples_per_batch is not None and ( + max_examples_per_batch // bsz < 1 + or max_examples_per_batch // bsz > n_perturb_samples + ): + warnings.warn( + ( + "`max_examples_per_batch` must be at least equal to the" + " input batch size and at most to " + "`input batch size` * `n_perturb_samples`." + "`max_examples_per_batch` is: {} and the input batch size is: {}." + "This is necessary because we require that each sub-batch that is used " + "to compute the metrics, contains at least an instance of " + "the original example and doesn't exceed the number of " + "expanded n_perturb_samples." 
+ ).format(max_examples_per_batch, bsz) + ) + + max_inps_per_batch = ( + n_perturb_samples + if max_examples_per_batch is None + else min(max(max_examples_per_batch // bsz, 1), n_perturb_samples) + ) + + current_n_steps = max_inps_per_batch + + metrics_sum = metric_func(max_inps_per_batch) + + while current_n_steps < n_perturb_samples: + current_n_steps += max_inps_per_batch + + metric = metric_func( + max_inps_per_batch + if current_n_steps <= n_perturb_samples + else max_inps_per_batch - (current_n_steps - n_perturb_samples) + ) + + current_n_steps = min(current_n_steps, n_perturb_samples) + + metrics_sum = agg_func(metrics_sum, metric) + return metrics_sum diff --git a/captum/robust/__init__.py b/captum/robust/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..42eb81886002c0eeaa5e1b235b96524b5fbfa3a4 --- /dev/null +++ b/captum/robust/__init__.py @@ -0,0 +1,9 @@ +#!/usr/bin/env python3 + +from captum.robust._core.fgsm import FGSM # noqa +from captum.robust._core.metrics.attack_comparator import AttackComparator # noqa +from captum.robust._core.metrics.min_param_perturbation import ( # noqa + MinParamPerturbation, +) +from captum.robust._core.perturbation import Perturbation # noqa +from captum.robust._core.pgd import PGD # noqa diff --git a/captum/robust/_core/__init__.py b/captum/robust/_core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/captum/robust/_core/fgsm.py b/captum/robust/_core/fgsm.py new file mode 100644 index 0000000000000000000000000000000000000000..38bae86f07f594583778cbba24eb9bc042e0ae1d --- /dev/null +++ b/captum/robust/_core/fgsm.py @@ -0,0 +1,181 @@ +#!/usr/bin/env python3 +from typing import Any, Callable, Tuple + +import torch +from captum._utils.common import ( + _format_additional_forward_args, + _format_output, + _format_tensor_into_tuples, + _is_tuple, + _select_targets, +) +from captum._utils.gradient import ( + apply_gradient_requirements, + compute_gradients, + undo_gradient_requirements, +) +from captum._utils.typing import TensorOrTupleOfTensorsGeneric +from captum.robust._core.perturbation import Perturbation +from torch import Tensor + + +class FGSM(Perturbation): + r""" + Fast Gradient Sign Method is an one-step method that can generate + adversarial examples. For non-targeted attack, the formulation is + x' = x + epsilon * sign(gradient of L(theta, x, y)). + For targeted attack on t, the formulation is + x' = x - epsilon * sign(gradient of L(theta, x, t)). + L(theta, x, y) is the model's loss function with respect to model + parameters, inputs and labels. + + More details on Fast Gradient Sign Method can be found in the original + paper: + https://arxiv.org/pdf/1412.6572.pdf + """ + + def __init__( + self, + forward_func: Callable, + loss_func: Callable = None, + lower_bound: float = float("-inf"), + upper_bound: float = float("inf"), + ) -> None: + r""" + Args: + forward_func (callable): The pytorch model for which the attack is + computed. + loss_func (callable, optional): Loss function of which the gradient + computed. The loss function should take in outputs of the + model and labels, and return a loss tensor. + The default loss function is negative log. + lower_bound (float, optional): Lower bound of input values. + upper_bound (float, optional): Upper bound of input values. + e.g. 
image pixels must be in the range 0-255 + + Attributes: + bound (Callable): A function that bounds the input values based on + given lower_bound and upper_bound. Can be overwritten for + custom use cases if necessary. + zero_thresh (float): The threshold below which gradient will be treated + as zero. Can be modified for custom use cases if necessary. + """ + super().__init__() + self.forward_func = forward_func + self.loss_func = loss_func + self.bound = lambda x: torch.clamp(x, min=lower_bound, max=upper_bound) + self.zero_thresh = 10 ** -6 + + def perturb( + self, + inputs: TensorOrTupleOfTensorsGeneric, + epsilon: float, + target: Any, + additional_forward_args: Any = None, + targeted: bool = False, + ) -> TensorOrTupleOfTensorsGeneric: + r""" + This method computes and returns the perturbed input for each input tensor. + It supports both targeted and non-targeted attacks. + + Args: + + inputs (tensor or tuple of tensors): Input for which adversarial + attack is computed. It can be provided as a single + tensor or a tuple of multiple tensors. If multiple + input tensors are provided, the batch sizes must be + aligned accross all tensors. + epsilon (float): Step size of perturbation. + target (any): True labels of inputs if non-targeted attack is + desired. Target class of inputs if targeted attack + is desired. Target will be passed to the loss function + to compute loss, so the type needs to match the + argument type of the loss function. + + If using the default negative log as loss function, + labels should be of type int, tuple, tensor or list. + For general 2D outputs, labels can be either: + + - a single integer or a tensor containing a single + integer, which is applied to all input examples + + - a list of integers or a 1D tensor, with length matching + the number of examples in inputs (dim 0). Each integer + is applied as the label for the corresponding example. + + For outputs with > 2 dimensions, labels can be either: + + - A single tuple, which contains #output_dims - 1 + elements. This label index is applied to all examples. + + - A list of tuples with length equal to the number of + examples in inputs (dim 0), and each tuple containing + #output_dims - 1 elements. Each tuple is applied as the + label for the corresponding example. + additional_forward_args (any, optional): If the forward function + requires additional arguments other than the inputs for + which attributions should not be computed, this argument + can be provided. These arguments are provided to + forward_func in order following the arguments in inputs. + Default: None. + targeted (bool, optional): If attack should be targeted. + Default: False. + + + Returns: + + - **perturbed inputs** (*tensor* or tuple of *tensors*): + Perturbed input for each + input tensor. The perturbed inputs have the same shape and + dimensionality as the inputs. + If a single tensor is provided as inputs, a single tensor + is returned. If a tuple is provided for inputs, a tuple of + corresponding sized tensors is returned. + """ + is_inputs_tuple = _is_tuple(inputs) + inputs: Tuple[Tensor, ...] 
= _format_tensor_into_tuples(inputs) + gradient_mask = apply_gradient_requirements(inputs) + + def _forward_with_loss() -> Tensor: + additional_inputs = _format_additional_forward_args(additional_forward_args) + outputs = self.forward_func( # type: ignore + *(*inputs, *additional_inputs) # type: ignore + if additional_inputs is not None + else inputs + ) + if self.loss_func is not None: + return self.loss_func(outputs, target) + else: + loss = -torch.log(outputs) + return _select_targets(loss, target) + + grads = compute_gradients(_forward_with_loss, inputs) + undo_gradient_requirements(inputs, gradient_mask) + perturbed_inputs = self._perturb(inputs, grads, epsilon, targeted) + perturbed_inputs = tuple( + self.bound(perturbed_inputs[i]) for i in range(len(perturbed_inputs)) + ) + return _format_output(is_inputs_tuple, perturbed_inputs) + + def _perturb( + self, + inputs: Tuple, + grads: Tuple, + epsilon: float, + targeted: bool, + ) -> Tuple: + r""" + A helper function to calculate the perturbed inputs given original + inputs, gradient of loss function and epsilon. The calculation is + different for targetd v.s. non-targeted as described above. + """ + multiplier = -1 if targeted else 1 + inputs = tuple( + torch.where( + torch.abs(grad) > self.zero_thresh, + inp + multiplier * epsilon * torch.sign(grad), + inp, + ) + for grad, inp in zip(grads, inputs) + ) + return inputs diff --git a/captum/robust/_core/metrics/__init__.py b/captum/robust/_core/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/captum/robust/_core/metrics/attack_comparator.py b/captum/robust/_core/metrics/attack_comparator.py new file mode 100644 index 0000000000000000000000000000000000000000..57b03e8f1828cb645283b88bbbfd935ca6acbb39 --- /dev/null +++ b/captum/robust/_core/metrics/attack_comparator.py @@ -0,0 +1,467 @@ +#!/usr/bin/env python3 +import warnings +from collections import namedtuple +from typing import ( + Any, + Callable, + cast, + Dict, + Generic, + List, + NamedTuple, + Optional, + Tuple, + TypeVar, + Union, +) + +from captum._utils.common import ( + _expand_additional_forward_args, + _format_additional_forward_args, + _reduce_list, +) +from captum.attr import Max, Mean, Min, Summarizer +from captum.robust._core.perturbation import Perturbation +from torch import Tensor + +ORIGINAL_KEY = "Original" + +MetricResultType = TypeVar( + "MetricResultType", float, Tensor, Tuple[Union[float, Tensor], ...] +) + + +class AttackInfo(NamedTuple): + attack_fn: Union[Perturbation, Callable] + name: str + num_attempts: int + apply_before_preproc: bool + attack_kwargs: Dict[str, Any] + additional_args: List[str] + + +def agg_metric(inp): + if isinstance(inp, Tensor): + return inp.mean(dim=0) + elif isinstance(inp, tuple): + return tuple(agg_metric(elem) for elem in inp) + return inp + + +class AttackComparator(Generic[MetricResultType]): + r""" + Allows measuring model robustness for a given attack or set of attacks. This class + can be used with any metric(s) as well as any set of attacks, either based on + attacks / perturbations from captum.robust such as FGSM or PGD or external + augmentation methods or perturbations such as torchvision transforms. 
+ """ + + def __init__( + self, + forward_func: Callable, + metric: Callable[..., MetricResultType], + preproc_fn: Callable = None, + ) -> None: + r""" + Args: + forward_func (callable or torch.nn.Module): This can either be an instance + of pytorch model or any modification of a model's forward + function. + + metric (callable): This function is applied to the model output in + order to compute the desired performance metric or metrics. + This function should have the following signature:: + + >>> def model_metric(model_out: Tensor, **kwargs: Any) + >>> -> Union[float, Tensor, Tuple[Union[float, Tensor], ...]: + + All kwargs provided to evaluate are provided to the metric function, + following the model output. A single metric can be returned as + a float or tensor, and multiple metrics should be returned as either + a tuple or named tuple of floats or tensors. For a tensor metric, + the first dimension should match the batch size, corresponding to + metrics for each example. Tensor metrics are averaged over the first + dimension when aggregating multiple batch results. + If tensor metrics represent results for the full batch, the size of the + first dimension should be 1. + + preproc_fn (callable, optional): Optional method applied to inputs. Output + of preproc_fn is then provided as input to model, in addition to + additional_forward_args provided to evaluate. + """ + self.forward_func = forward_func + self.metric: Callable = metric + self.preproc_fn = preproc_fn + self.attacks: Dict[str, AttackInfo] = {} + self.summary_results: Dict[str, Summarizer] = {} + self.metric_aggregator = agg_metric + self.batch_stats = [Mean, Min, Max] + self.aggregate_stats = [Mean] + self.summary_results = {} + self.out_format = None + + def add_attack( + self, + attack: Union[Perturbation, Callable], + name: Optional[str] = None, + num_attempts: int = 1, + apply_before_preproc: bool = True, + attack_kwargs: Optional[Dict[str, Any]] = None, + additional_attack_arg_names: Optional[List[str]] = None, + ) -> None: + r""" + Adds attack to be evaluated when calling evaluate. + + Args: + attack (perturbation or callable): This can either be an instance + of a Captum Perturbation / Attack + or any other perturbation or attack function such + as a torchvision transform. + + name (optional, str): Name or identifier for attack, used as key for + attack results. This defaults to attack.__class__.__name__ + if not provided and must be unique for all added attacks. + + num_attempts (int): Number of attempts that attack should be + repeated. This should only be set to > 1 for non-deterministic + attacks. The minimum, maximum, and average (best, worst, and + average case) are tracked for attack attempts. + + apply_before_preproc (bool): Defines whether attack should be applied + before or after preproc function. + + attack_kwargs (dict): Additional arguments to be provided to given attack. + This should be provided as a dictionary of keyword arguments. + + additional_attack_arg_names (list[str]): Any additional arguments for the + attack which are specific to the particular input example or batch. + An example of this is target, which is necessary for some attacks such + as FGSM or PGD. These arguments are included if provided as a kwarg + to evaluate. 
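+
+ The following is an illustrative sketch (it assumes a comparator instance
+ `attack_metric` and a model `net` are already defined, and uses
+ `apply_before_preproc=False` so that FGSM acts on preprocessed tensors);
+ registering the same attack class twice under distinct names allows a
+ simple parameter sweep::
+
+ >>> attack_metric.add_attack(FGSM(net), "FGSM eps=0.1",
+ >>> attack_kwargs={"epsilon": 0.1},
+ >>> apply_before_preproc=False,
+ >>> additional_attack_arg_names=["target"])
+ >>> attack_metric.add_attack(FGSM(net), "FGSM eps=0.2",
+ >>> attack_kwargs={"epsilon": 0.2},
+ >>> apply_before_preproc=False,
+ >>> additional_attack_arg_names=["target"])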
+ """ + if name is None: + name = attack.__class__.__name__ + + if attack_kwargs is None: + attack_kwargs = {} + + if additional_attack_arg_names is None: + additional_attack_arg_names = [] + + if name in self.attacks: + raise RuntimeError( + "Cannot add attack with same name as existing attack {}".format(name) + ) + + self.attacks[name] = AttackInfo( + attack_fn=attack, + name=name, + num_attempts=num_attempts, + apply_before_preproc=apply_before_preproc, + attack_kwargs=attack_kwargs, + additional_args=additional_attack_arg_names, + ) + + def _format_summary( + self, summary: Union[Dict, List[Dict]] + ) -> Dict[str, MetricResultType]: + r""" + This method reformats a given summary; particularly for tuples, + the Summarizer's summary format is a list of dictionaries, + each containing the summary for the corresponding elements. + We reformat this to return a dictionary with tuples containing + the summary results. + """ + if isinstance(summary, dict): + return summary + else: + summary_dict: Dict[str, Tuple] = {} + for key in summary[0]: + summary_dict[key] = tuple(s[key] for s in summary) + if self.out_format: + summary_dict[key] = self.out_format(*summary_dict[key]) + return summary_dict # type: ignore + + def _update_out_format( + self, out_metric: Union[float, Tensor, Tuple[Union[float, Tensor], ...]] + ) -> None: + if ( + not self.out_format + and isinstance(out_metric, tuple) + and hasattr(out_metric, "_fields") + ): + self.out_format = namedtuple( # type: ignore + type(out_metric).__name__, cast(NamedTuple, out_metric)._fields + ) + + def _evaluate_batch( + self, + input_list: List[Any], + additional_forward_args: Optional[Tuple], + key_list: List[str], + batch_summarizers: Dict[str, Summarizer], + metric_kwargs: Dict[str, Any], + ) -> None: + if additional_forward_args is None: + additional_forward_args = () + if len(input_list) == 1: + model_out = self.forward_func(input_list[0], *additional_forward_args) + out_metric = self.metric(model_out, **metric_kwargs) + self._update_out_format(out_metric) + batch_summarizers[key_list[0]].update(out_metric) + else: + batched_inps = _reduce_list(input_list) + model_out = self.forward_func(batched_inps, *additional_forward_args) + current_count = 0 + for i in range(len(input_list)): + batch_size = ( + input_list[i].shape[0] + if isinstance(input_list[i], Tensor) + else input_list[i][0].shape[0] + ) + out_metric = self.metric( + model_out[current_count : current_count + batch_size], + **metric_kwargs, + ) + self._update_out_format(out_metric) + batch_summarizers[key_list[i]].update(out_metric) + current_count += batch_size + + def evaluate( + self, + inputs: Any, + additional_forward_args: Any = None, + perturbations_per_eval: int = 1, + **kwargs, + ) -> Dict[str, Union[MetricResultType, Dict[str, MetricResultType]]]: + r""" + Evaluate model and attack performance on provided inputs + + Args: + + inputs (any): Input for which attack metrics + are computed. It can be provided as a tensor, tuple of tensors, + or any raw input type (e.g. PIL image or text string). + This input is provided directly as input to preproc function as well + as any attack applied before preprocessing. If no pre-processing + function is provided, this input is provided directly to the main + model and all attacks. + + additional_forward_args (any, optional): If the forward function + requires additional arguments other than the preprocessing + outputs (or inputs if preproc_fn is None), this argument + can be provided. 
It must be either a single additional
+ argument of a Tensor or arbitrary (non-tuple) type or a
+ tuple containing multiple additional arguments including
+ tensors or any arbitrary python types. These arguments
+ are provided to forward_func in order, following the
+ arguments in inputs.
+ For a tensor, the first dimension of the tensor must
+ correspond to the number of examples. For all other types,
+ the given argument is used for all forward evaluations.
+ Default: None
+ perturbations_per_eval (int, optional): Allows perturbations of multiple
+ attacks to be grouped and evaluated in one call of forward_fn.
+ Each forward pass will contain a maximum of
+ perturbations_per_eval * #examples samples.
+ For DataParallel models, each batch is split among the
+ available devices, so evaluations on each available
+ device contain at most
+ (perturbations_per_eval * #examples) / num_devices
+ samples.
+ In order to apply this functionality, the output of preproc_fn
+ (or inputs itself if no preproc_fn is provided) must be a tensor
+ or tuple of tensors.
+ Default: 1
+ kwargs (any, optional): Additional keyword arguments provided to the metric
+ function as well as to selected attacks based on the chosen additional_args.
+
+ Returns:
+
+ - **attack results** Dict: str -> Dict[str, Union[Tensor, Tuple[Tensor, ...]]]:
+ Dictionary containing attack results for the provided batch.
+ Maps attack name to a dictionary
+ containing best-case, worst-case and average-case results for the attack.
+ The dictionary contains keys "mean", "max" and "min" when num_attempts > 1
+ and only "mean" for num_attempts = 1, which contains the (single) metric
+ result for the attack attempt.
+ An additional key of 'Original' is included with metric results
+ without any perturbations.
+
+
+ Examples::
+
+ >>> def accuracy_metric(model_out: Tensor, target: Tensor):
+ >>> return (torch.argmax(model_out, dim=1) == target).float()
+
+ >>> attack_metric = AttackComparator(forward_func=resnet18,
 metric=accuracy_metric,
 preproc_fn=normalize)
+
+ >>> random_rotation = transforms.RandomRotation(degrees=30)
+ >>> jitter = transforms.ColorJitter()
+
+ >>> attack_metric.add_attack(random_rotation, "Random Rotation",
+ >>> num_attempts = 5)
+ >>> attack_metric.add_attack(jitter, "Jitter", num_attempts = 1)
+ >>> attack_metric.add_attack(FGSM(resnet18), "FGSM 0.1", num_attempts = 1,
+ >>> apply_before_preproc=False,
+ >>> attack_kwargs={"epsilon": 0.1},
+ >>> additional_attack_arg_names=["target"])
+
+ >>> for images, labels in dataloader:
+ >>> batch_results = attack_metric.evaluate(inputs=images, target=labels)
+ """
+ additional_forward_args = _format_additional_forward_args(
+ additional_forward_args
+ )
+ expanded_additional_args = (
+ _expand_additional_forward_args(
+ additional_forward_args, perturbations_per_eval
+ )
+ if perturbations_per_eval > 1
+ else additional_forward_args
+ )
+
+ preproc_input = None
+ if self.preproc_fn is not None:
+ preproc_input = self.preproc_fn(inputs)
+ else:
+ preproc_input = inputs
+
+ input_list = [preproc_input]
+ key_list = [ORIGINAL_KEY]
+
+ batch_summarizers = {ORIGINAL_KEY: Summarizer([Mean()])}
+ if ORIGINAL_KEY not in self.summary_results:
+ self.summary_results[ORIGINAL_KEY] = Summarizer(
+ [stat() for stat in self.aggregate_stats]
+ )
+
+ def _check_and_evaluate(input_list, key_list):
+ if len(input_list) == perturbations_per_eval:
+ self._evaluate_batch(
+ input_list,
+ expanded_additional_args,
+ key_list,
+ batch_summarizers,
+ kwargs,
+ )
+ return [], []
+ return input_list, key_list
+
+ input_list, key_list =
_check_and_evaluate(input_list, key_list) + + for attack_key in self.attacks: + attack = self.attacks[attack_key] + if attack.num_attempts > 1: + stats = [stat() for stat in self.batch_stats] + else: + stats = [Mean()] + batch_summarizers[attack.name] = Summarizer(stats) + additional_attack_args = {} + for key in attack.additional_args: + if key not in kwargs: + warnings.warn( + f"Additional sample arg {key} not provided for {attack_key}" + ) + else: + additional_attack_args[key] = kwargs[key] + + for _ in range(attack.num_attempts): + if attack.apply_before_preproc: + attacked_inp = attack.attack_fn( + inputs, **additional_attack_args, **attack.attack_kwargs + ) + preproc_attacked_inp = ( + self.preproc_fn(attacked_inp) + if self.preproc_fn + else attacked_inp + ) + else: + preproc_attacked_inp = attack.attack_fn( + preproc_input, **additional_attack_args, **attack.attack_kwargs + ) + + input_list.append(preproc_attacked_inp) + key_list.append(attack.name) + + input_list, key_list = _check_and_evaluate(input_list, key_list) + + if len(input_list) > 0: + final_add_args = _expand_additional_forward_args( + additional_forward_args, len(input_list) + ) + self._evaluate_batch( + input_list, final_add_args, key_list, batch_summarizers, kwargs + ) + + return self._parse_and_update_results(batch_summarizers) + + def _parse_and_update_results( + self, batch_summarizers: Dict[str, Summarizer] + ) -> Dict[str, Union[MetricResultType, Dict[str, MetricResultType]]]: + results: Dict[str, Union[MetricResultType, Dict[str, MetricResultType]]] = { + ORIGINAL_KEY: self._format_summary( + cast(Union[Dict, List], batch_summarizers[ORIGINAL_KEY].summary) + )["mean"] + } + self.summary_results[ORIGINAL_KEY].update( + self.metric_aggregator(results[ORIGINAL_KEY]) + ) + for attack_key in self.attacks: + attack = self.attacks[attack_key] + attack_results = self._format_summary( + cast(Union[Dict, List], batch_summarizers[attack.name].summary) + ) + results[attack.name] = attack_results + + if len(attack_results) == 1: + key = next(iter(attack_results)) + if attack.name not in self.summary_results: + self.summary_results[attack.name] = Summarizer( + [stat() for stat in self.aggregate_stats] + ) + self.summary_results[attack.name].update( + self.metric_aggregator(attack_results[key]) + ) + else: + for key in attack_results: + summary_key = f"{attack.name} {key.title()} Attempt" + if summary_key not in self.summary_results: + self.summary_results[summary_key] = Summarizer( + [stat() for stat in self.aggregate_stats] + ) + self.summary_results[summary_key].update( + self.metric_aggregator(attack_results[key]) + ) + return results + + def summary(self) -> Dict[str, Dict[str, MetricResultType]]: + r""" + Returns average results over all previous batches evaluated. + + Returns: + + - **summary** Dict: str -> Dict[str, Union[Tensor, Tuple[Tensor, ...]]]: + Dictionary containing summarized average attack results. + Maps attack name (with "Mean Attempt", "Max Attempt" and "Min Attempt" + suffixes if num_attempts > 1) to dictionary containing a key of "mean" + maintaining summarized results, + which is the running mean of results over all batches + since construction or previous reset call. Tensor metrics are averaged + over dimension 0 for each batch, in order to aggregte metrics collected + per batch. 
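+
+ As an illustrative sketch (hypothetical attack registered as "FGSM 0.1"
+ with num_attempts > 1 and an accuracy-style metric), the returned
+ dictionary could look like::
+
+ >>> {"Original": {"mean": tensor(0.91)},
+ >>> "FGSM 0.1 Mean Attempt": {"mean": tensor(0.47)},
+ >>> "FGSM 0.1 Max Attempt": {"mean": tensor(0.63)},
+ >>> "FGSM 0.1 Min Attempt": {"mean": tensor(0.31)}}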
+ """ + return { + key: self._format_summary( + cast(Union[Dict, List], self.summary_results[key].summary) + ) + for key in self.summary_results + } + + def reset(self) -> None: + r""" + Reset stored average summary results for previous batches + """ + self.summary_results = {} diff --git a/captum/robust/_core/metrics/min_param_perturbation.py b/captum/robust/_core/metrics/min_param_perturbation.py new file mode 100644 index 0000000000000000000000000000000000000000..279179ab648fde5cfeb3622276f0017f48be7e7c --- /dev/null +++ b/captum/robust/_core/metrics/min_param_perturbation.py @@ -0,0 +1,457 @@ +#!/usr/bin/env python3 +import math +from enum import Enum +from typing import Any, Callable, cast, Dict, Generator, List, Optional, Tuple, Union + +import torch +from captum._utils.common import ( + _expand_additional_forward_args, + _format_additional_forward_args, + _reduce_list, +) +from captum._utils.typing import TargetType +from captum.robust._core.perturbation import Perturbation +from torch import Tensor + + +def drange( + min_val: Union[int, float], max_val: Union[int, float], step_val: Union[int, float] +) -> Generator[Union[int, float], None, None]: + curr = min_val + while curr < max_val: + yield curr + curr += step_val + + +def default_correct_fn(model_out: Tensor, target: TargetType) -> bool: + assert ( + isinstance(model_out, Tensor) and model_out.ndim == 2 + ), "Model output must be a 2D tensor to use default correct function;" + " otherwise custom correct function must be provided" + target_tensor = torch.tensor(target) if not isinstance(target, Tensor) else target + return all(torch.argmax(model_out, dim=1) == target_tensor) + + +class MinParamPerturbationMode(Enum): + LINEAR = 0 + BINARY = 1 + + +class MinParamPerturbation: + def __init__( + self, + forward_func: Callable, + attack: Union[Callable, Perturbation], + arg_name: str, + arg_min: Union[int, float], + arg_max: Union[int, float], + arg_step: Union[int, float], + mode: str = "linear", + num_attempts: int = 1, + preproc_fn: Optional[Callable] = None, + apply_before_preproc: bool = False, + correct_fn: Optional[Callable] = None, + ): + r""" + Identifies minimal perturbation based on target variable which causes + misclassification (or other incorrect prediction) of target input. + + More specifically, given a perturbation parametrized by a single value + (e.g. rotation by angle or mask percentage of top features based on + attribution results), MinParamPerturbation helps identify the minimum value + which leads to misclassification (or other model output change) with the + corresponding perturbed input. + + Args: + forward_func (callable or torch.nn.Module): This can either be an instance + of pytorch model or any modification of a model's forward + function. + + attack (Perturbation or Callable): This can either be an instance + of a Captum Perturbation / Attack + or any other perturbation or attack function such + as a torchvision transform. + Perturb function must take additional argument (var_name) used for + minimal perturbation search. + + arg_name (str): Name of argument / variable paramterizing attack, must be + kwarg of attack. Examples are num_dropout or stdevs + + arg_min (int, float): Minimum value of target variable + + arg_max (int, float): Maximum value of target variable + (not included in range) + + arg_step (int, float): Minimum interval for increase of target variable. 
+ + mode (str, optional): Mode for search of minimum attack value; + either 'linear' for linear search on variable, or 'binary' for + binary search of variable + Default: 'linear' + + num_attempts (int, optional): Number of attempts or trials with + given variable. This should only be set to > 1 for non-deterministic + perturbation / attack functions + Default: 1 + + preproc_fn (callable, optional): Optional method applied to inputs. Output + of preproc_fn is then provided as input to model, in addition to + additional_forward_args provided to evaluate. + Default: None + + apply_before_preproc (bool, optional): Defines whether attack should be + applied before or after preproc function. + Default: False + + correct_fn (Callable, optional): This determines whether the perturbed input + leads to a correct or incorrect prediction. By default, this function + is set to the standard classification test for correctness + (comparing argmax of output with target), which requires model output to + be a 2D tensor, returning True if all batch examples are correct and + false otherwise. Setting this method allows + any custom behavior defining whether the perturbation is successful + at fooling the model. For non-classification use cases, a custom + function must be provided which determines correctness. + + The first argument to this function must be the model out; + any additional arguments should be provided through correct_fn_kwargs. + + This function should have the following signature: + def correct_fn(model_out: Tensor, **kwargs: Any) -> bool + + Method should return a boolean if correct (True) and incorrect (False). + Default: None (applies standard correct_fn for classification) + """ + self.forward_func = forward_func + self.attack = attack + self.arg_name = arg_name + self.arg_min = arg_min + self.arg_max = arg_max + self.arg_step = arg_step + assert self.arg_max > ( + self.arg_min + self.arg_step + ), "Step size cannot be smaller than range between min and max" + + self.num_attempts = num_attempts + self.preproc_fn = preproc_fn + self.apply_before_preproc = apply_before_preproc + self.correct_fn = cast( + Callable, correct_fn if correct_fn is not None else default_correct_fn + ) + + assert ( + mode.upper() in MinParamPerturbationMode.__members__ + ), f"Provided perturb mode {mode} is not valid - must be linear or binary" + self.mode = MinParamPerturbationMode[mode.upper()] + + def _evaluate_batch( + self, + input_list: List, + additional_forward_args: Any, + correct_fn_kwargs: Optional[Dict[str, Any]], + target: TargetType, + ) -> Optional[int]: + if additional_forward_args is None: + additional_forward_args = () + + all_kwargs = {} + if target is not None: + all_kwargs["target"] = target + if correct_fn_kwargs is not None: + all_kwargs.update(correct_fn_kwargs) + + if len(input_list) == 1: + model_out = self.forward_func(input_list[0], *additional_forward_args) + out_metric = self.correct_fn(model_out, **all_kwargs) + return 0 if not out_metric else None + else: + batched_inps = _reduce_list(input_list) + model_out = self.forward_func(batched_inps, *additional_forward_args) + current_count = 0 + for i in range(len(input_list)): + batch_size = ( + input_list[i].shape[0] + if isinstance(input_list[i], Tensor) + else input_list[i][0].shape[0] + ) + out_metric = self.correct_fn( + model_out[current_count : current_count + batch_size], **all_kwargs + ) + if not out_metric: + return i + current_count += batch_size + return None + + def _apply_attack( + self, + inputs: Any, + preproc_input: Any, 
+ attack_kwargs: Optional[Dict[str, Any]], + param: Union[int, float], + ) -> Tuple[Any, Any]: + if attack_kwargs is None: + attack_kwargs = {} + if self.apply_before_preproc: + attacked_inp = self.attack( + inputs, **attack_kwargs, **{self.arg_name: param} + ) + preproc_attacked_inp = ( + self.preproc_fn(attacked_inp) if self.preproc_fn else attacked_inp + ) + else: + attacked_inp = self.attack( + preproc_input, **attack_kwargs, **{self.arg_name: param} + ) + preproc_attacked_inp = attacked_inp + return preproc_attacked_inp, attacked_inp + + def _linear_search( + self, + inputs: Any, + preproc_input: Any, + attack_kwargs: Optional[Dict[str, Any]], + additional_forward_args: Any, + expanded_additional_args: Any, + correct_fn_kwargs: Optional[Dict[str, Any]], + target: TargetType, + perturbations_per_eval: int, + ) -> Tuple[Any, Optional[Union[int, float]]]: + input_list = [] + attack_inp_list = [] + param_list = [] + + for param in drange(self.arg_min, self.arg_max, self.arg_step): + for _ in range(self.num_attempts): + preproc_attacked_inp, attacked_inp = self._apply_attack( + inputs, preproc_input, attack_kwargs, param + ) + + input_list.append(preproc_attacked_inp) + param_list.append(param) + attack_inp_list.append(attacked_inp) + + if len(input_list) == perturbations_per_eval: + successful_ind = self._evaluate_batch( + input_list, + expanded_additional_args, + correct_fn_kwargs, + target, + ) + if successful_ind is not None: + return ( + attack_inp_list[successful_ind], + param_list[successful_ind], + ) + input_list = [] + param_list = [] + attack_inp_list = [] + if len(input_list) > 0: + final_add_args = _expand_additional_forward_args( + additional_forward_args, len(input_list) + ) + successful_ind = self._evaluate_batch( + input_list, + final_add_args, + correct_fn_kwargs, + target, + ) + if successful_ind is not None: + return ( + attack_inp_list[successful_ind], + param_list[successful_ind], + ) + return None, None + + def _binary_search( + self, + inputs: Any, + preproc_input: Any, + attack_kwargs: Optional[Dict[str, Any]], + additional_forward_args: Any, + expanded_additional_args: Any, + correct_fn_kwargs: Optional[Dict[str, Any]], + target: TargetType, + perturbations_per_eval: int, + ) -> Tuple[Any, Optional[Union[int, float]]]: + min_range = self.arg_min + max_range = self.arg_max + min_so_far = None + min_input = None + while max_range > min_range: + mid_step = ((max_range - min_range) // self.arg_step) // 2 + + if mid_step == 0 and min_range + self.arg_step < max_range: + mid_step = 1 + mid = min_range + (mid_step * self.arg_step) + + input_list = [] + param_list = [] + attack_inp_list = [] + attack_success = False + + for i in range(self.num_attempts): + preproc_attacked_inp, attacked_inp = self._apply_attack( + inputs, preproc_input, attack_kwargs, mid + ) + + input_list.append(preproc_attacked_inp) + param_list.append(mid) + attack_inp_list.append(attacked_inp) + + if len(input_list) == perturbations_per_eval or i == ( + self.num_attempts - 1 + ): + additional_args = expanded_additional_args + if len(input_list) != perturbations_per_eval: + additional_args = _expand_additional_forward_args( + additional_forward_args, len(input_list) + ) + + successful_ind = self._evaluate_batch( + input_list, + additional_args, + correct_fn_kwargs, + target, + ) + if successful_ind is not None: + attack_success = True + max_range = mid + if min_so_far is None or min_so_far > mid: + min_so_far = mid + min_input = attack_inp_list[successful_ind] + break + + input_list = [] + param_list = 
[] + attack_inp_list = [] + + if math.isclose(min_range, mid): + break + + if not attack_success: + min_range = mid + + return min_input, min_so_far + + def evaluate( + self, + inputs: Any, + additional_forward_args: Optional[Tuple] = None, + target: TargetType = None, + perturbations_per_eval: int = 1, + attack_kwargs: Optional[Dict[str, Any]] = None, + correct_fn_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[Any, Optional[Union[int, float]]]: + r""" + This method evaluates the model at each perturbed input and identifies + the minimum perturbation that leads to an incorrect model prediction. + + It is recommended to provide a single input (batch size = 1) when using + this to identify a minimal perturbation for the chosen example. If a + batch of examples is provided, the default correct function identifies + the minimal perturbation for at least 1 example in the batch to be + misclassified. A custom correct_fn can be provided to customize + this behavior and define correctness for the batch. + + Args: + + inputs (Any): Input for which minimal perturbation + is computed. It can be provided as a tensor, tuple of tensors, + or any raw input type (e.g. PIL image or text string). + This input is provided directly as input to preproc function + as well as any attack applied before preprocessing. If no + pre-processing function is provided, + this input is provided directly to the main model and all attacks. + + additional_forward_args (any, optional): If the forward function + requires additional arguments other than the preprocessing + outputs (or inputs if preproc_fn is None), this argument + can be provided. It must be either a single additional + argument of a Tensor or arbitrary (non-tuple) type or a + tuple containing multiple additional arguments including + tensors or any arbitrary python types. These arguments + are provided to forward_func in order following the + arguments in inputs. + For a tensor, the first dimension of the tensor must + correspond to the number of examples. For all other types, + the given argument is used for all forward evaluations. + Default: None + target (TargetType): Target class for classification. This is required if + using the default correct_fn + + perturbations_per_eval (int, optional): Allows perturbations of multiple + attacks to be grouped and evaluated in one call of forward_fn + Each forward pass will contain a maximum of + perturbations_per_eval * #examples samples. + For DataParallel models, each batch is split among the + available devices, so evaluations on each available + device contain at most + (perturbations_per_eval * #examples) / num_devices + samples. + In order to apply this functionality, the output of preproc_fn + (or inputs itself if no preproc_fn is provided) must be a tensor + or tuple of tensors. + Default: 1 + attack_kwargs (dictionary, optional): Optional dictionary of keyword + arguments provided to attack function + correct_fn_kwargs (dictionary, optional): Optional dictionary of keyword + arguments provided to correct function + + Returns: + + Tuple of (perturbed_inputs, param_val) if successful + else Tuple of (None, None) + + - **perturbed inputs** (Any): + Perturbed input (output of attack) which results in incorrect + prediction. 
+ - param_val (int, float) + Param value leading to perturbed inputs causing misclassification + + Examples:: + + >>> def gaussian_noise(inp: Tensor, std: float) -> Tensor: + >>> return inp + std*torch.randn_like(inp) + + >>> min_pert = MinParamPerturbation(forward_func=resnet18, + attack=gaussian_noise, + arg_name="std", + arg_min=0.0, + arg_max=2.0, + arg_step=0.01, + ) + >>> for images, labels in dataloader: + >>> noised_image, min_std = min_pert.evaluate(inputs=images, target=labels) + + """ + additional_forward_args = _format_additional_forward_args( + additional_forward_args + ) + expanded_additional_args = ( + _expand_additional_forward_args( + additional_forward_args, perturbations_per_eval + ) + if perturbations_per_eval > 1 + else additional_forward_args + ) + preproc_input = inputs if not self.preproc_fn else self.preproc_fn(inputs) + + if self.mode is MinParamPerturbationMode.LINEAR: + search_fn = self._linear_search + elif self.mode is MinParamPerturbationMode.BINARY: + search_fn = self._binary_search + else: + raise NotImplementedError( + "Chosen MinParamPerturbationMode is not supported!" + ) + + return search_fn( + inputs, + preproc_input, + attack_kwargs, + additional_forward_args, + expanded_additional_args, + correct_fn_kwargs, + target, + perturbations_per_eval, + ) diff --git a/captum/robust/_core/perturbation.py b/captum/robust/_core/perturbation.py new file mode 100644 index 0000000000000000000000000000000000000000..9eb6d534816e9bb70955ac1c4222f92394753c8d --- /dev/null +++ b/captum/robust/_core/perturbation.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python3 +from typing import Callable + + +class Perturbation: + r""" + All perturbation and attack algorithms extend this class. It enforces + its child classes to extend and override core `perturb` method. + """ + + perturb: Callable + r""" + This method computes and returns the perturbed input for each input tensor. + Deriving classes are responsible for implementing its logic accordingly. + + Specific adversarial attack algorithms that extend this class take relevant + arguments. + + Args: + + inputs (tensor or tuple of tensors): Input for which adversarial attack + is computed. It can be provided as a single tensor or + a tuple of multiple tensors. If multiple input tensors + are provided, the batch sizes must be aligned accross all + tensors. + + Returns: + + - **perturbed inputs** (*tensor* or tuple of *tensors*): + Perturbed input for each + input tensor. The perturbed inputs have the same shape and + dimensionality as the inputs. + If a single tensor is provided as inputs, a single tensor + is returned. If a tuple is provided for inputs, a tuple of + corresponding sized tensors is returned. 
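+
+    Example (a minimal sketch for illustration only; ``AddGaussianNoise``,
+    ``std`` and ``image_batch`` are hypothetical names, not part of this
+    module)::
+
+        >>> class AddGaussianNoise(Perturbation):
+        >>>     def perturb(self, inputs, std=0.1):
+        >>>         return inputs + std * torch.randn_like(inputs)
+
+        >>> attack = AddGaussianNoise()
+        >>> perturbed = attack(image_batch, std=0.05)  # __call__ forwards to perturb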
+ """ + + def __call__(self, *args, **kwargs): + return self.perturb(*args, **kwargs) diff --git a/captum/robust/_core/pgd.py b/captum/robust/_core/pgd.py new file mode 100644 index 0000000000000000000000000000000000000000..b14239c681cc2585a0e3a1032a6a2c07ec26c824 --- /dev/null +++ b/captum/robust/_core/pgd.py @@ -0,0 +1,195 @@ +#!/usr/bin/env python3 +from typing import Any, Callable + +import torch +import torch.nn.functional as F +from captum._utils.common import _format_output, _format_tensor_into_tuples, _is_tuple +from captum._utils.typing import TensorOrTupleOfTensorsGeneric +from captum.robust._core.fgsm import FGSM +from captum.robust._core.perturbation import Perturbation +from torch import Tensor + + +class PGD(Perturbation): + r""" + Projected Gradient Descent is an iterative version of the one-step attack + FGSM that can generate adversarial examples. It takes multiple gradient + steps to search for an adversarial perturbation within the desired + neighbor ball around the original inputs. In a non-targeted attack, the + formulation is:: + + x_0 = x + x_(t+1) = Clip_r(x_t + alpha * sign(gradient of L(theta, x, t))) + + where Clip denotes the function that projects its argument to the r-neighbor + ball around x so that the perturbation will be bounded. Alpha is the step + size. L(theta, x, y) is the model's loss function with respect to model + parameters, inputs and targets. + In a targeted attack, the formulation is similar:: + + x_0 = x + x_(t+1) = Clip_r(x_t - alpha * sign(gradient of L(theta, x, t))) + + More details on Projected Gradient Descent can be found in the original + paper: + https://arxiv.org/pdf/1706.06083.pdf + """ + + def __init__( + self, + forward_func: Callable, + loss_func: Callable = None, + lower_bound: float = float("-inf"), + upper_bound: float = float("inf"), + ) -> None: + r""" + Args: + forward_func (callable): The pytorch model for which the attack is + computed. + loss_func (callable, optional): Loss function of which the gradient + computed. The loss function should take in outputs of the + model and labels, and return the loss for each input tensor. + The default loss function is negative log. + lower_bound (float, optional): Lower bound of input values. + upper_bound (float, optional): Upper bound of input values. + e.g. image pixels must be in the range 0-255 + + Attributes: + bound (Callable): A function that bounds the input values based on + given lower_bound and upper_bound. Can be overwritten for + custom use cases if necessary. + """ + super().__init__() + self.forward_func = forward_func + self.fgsm = FGSM(forward_func, loss_func) + self.bound = lambda x: torch.clamp(x, min=lower_bound, max=upper_bound) + + def perturb( + self, + inputs: TensorOrTupleOfTensorsGeneric, + radius: float, + step_size: float, + step_num: int, + target: Any, + additional_forward_args: Any = None, + targeted: bool = False, + random_start: bool = False, + norm: str = "Linf", + ) -> TensorOrTupleOfTensorsGeneric: + r""" + This method computes and returns the perturbed input for each input tensor. + It supports both targeted and non-targeted attacks. + + Args: + + inputs (tensor or tuple of tensors): Input for which adversarial + attack is computed. It can be provided as a single + tensor or a tuple of multiple tensors. If multiple + input tensors are provided, the batch sizes must be + aligned accross all tensors. + radius (float): Radius of the neighbor ball centered around inputs. + The perturbation should be within this range. 
+ step_size (float): Step size of each gradient step. + step_num (int): Step numbers. It usually guarantees that the perturbation + can reach the border. + target (any): True labels of inputs if non-targeted attack is + desired. Target class of inputs if targeted attack + is desired. Target will be passed to the loss function + to compute loss, so the type needs to match the + argument type of the loss function. + + If using the default negative log as loss function, + labels should be of type int, tuple, tensor or list. + For general 2D outputs, labels can be either: + + - a single integer or a tensor containing a single + integer, which is applied to all input examples + + - a list of integers or a 1D tensor, with length matching + the number of examples in inputs (dim 0). Each integer + is applied as the label for the corresponding example. + + For outputs with > 2 dimensions, labels can be either: + + - A single tuple, which contains #output_dims - 1 + elements. This label index is applied to all examples. + + - A list of tuples with length equal to the number of + examples in inputs (dim 0), and each tuple containing + #output_dims - 1 elements. Each tuple is applied as the + label for the corresponding example. + additional_forward_args (any, optional): If the forward function + requires additional arguments other than the inputs for + which attributions should not be computed, this argument + can be provided. These arguments are provided to + forward_func in order following the arguments in inputs. + Default: None. + targeted (bool, optional): If attack should be targeted. + Default: False. + random_start (bool, optional): If a random initialization is added to + inputs. Default: False. + norm (str, optional): Specifies the norm to calculate distance from + original inputs: 'Linf'|'L2'. + Default: 'Linf'. + + Returns: + + - **perturbed inputs** (*tensor* or tuple of *tensors*): + Perturbed input for each + input tensor. The perturbed inputs have the same shape and + dimensionality as the inputs. + If a single tensor is provided as inputs, a single tensor + is returned. If a tuple is provided for inputs, a tuple of + corresponding sized tensors is returned. + """ + + def _clip(inputs: Tensor, outputs: Tensor) -> Tensor: + diff = outputs - inputs + if norm == "Linf": + return inputs + torch.clamp(diff, -radius, radius) + elif norm == "L2": + return inputs + torch.renorm(diff, 2, 0, radius) + else: + raise AssertionError("Norm constraint must be L2 or Linf.") + + is_inputs_tuple = _is_tuple(inputs) + formatted_inputs = _format_tensor_into_tuples(inputs) + perturbed_inputs = formatted_inputs + if random_start: + perturbed_inputs = tuple( + self.bound(self._random_point(formatted_inputs[i], radius, norm)) + for i in range(len(formatted_inputs)) + ) + for _i in range(step_num): + perturbed_inputs = self.fgsm.perturb( + perturbed_inputs, step_size, target, additional_forward_args, targeted + ) + perturbed_inputs = tuple( + _clip(formatted_inputs[j], perturbed_inputs[j]) + for j in range(len(perturbed_inputs)) + ) + # Detaching inputs to avoid dependency of gradient between steps + perturbed_inputs = tuple( + self.bound(perturbed_inputs[j]).detach() + for j in range(len(perturbed_inputs)) + ) + return _format_output(is_inputs_tuple, perturbed_inputs) + + def _random_point(self, center: Tensor, radius: float, norm: str) -> Tensor: + r""" + A helper function that returns a uniform random point within the ball + with the given center and radius. Norm should be either L2 or Linf. 
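+
+        For the L2 case, the code below draws a Gaussian direction, normalizes it
+        per example, and scales it by ``radius * U ** (1 / d)``, where ``U`` is
+        uniform on [0, 1] and ``d`` is the number of elements per example; this
+        inverse-CDF scaling is what makes the sample uniform inside the
+        d-dimensional L2 ball rather than concentrated near its surface. For the
+        Linf case, every element is drawn independently from
+        ``Uniform(-radius, radius)``.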
+ """ + if norm == "L2": + u = torch.randn_like(center) + unit_u = F.normalize(u.view(u.size(0), -1)).view(u.size()) + d = torch.numel(center[0]) + r = (torch.rand(u.size(0)) ** (1.0 / d)) * radius + r = r[(...,) + (None,) * (r.dim() - 1)] + x = r * unit_u + return center + x + elif norm == "Linf": + x = torch.rand_like(center) * radius * 2 - radius + return center + x + else: + raise AssertionError("Norm constraint must be L2 or Linf.") diff --git a/captum_improve_abinet.py b/captum_improve_abinet.py new file mode 100644 index 0000000000000000000000000000000000000000..34df91d0f5b6580cbae966943ed456ce39808b8e --- /dev/null +++ b/captum_improve_abinet.py @@ -0,0 +1,769 @@ +import os +import time +import string +import argparse +import re +import sys +import random +import pickle +import logging +from fastai.distributed import * +from fastai.vision import * + +import settings +import torch +import torch.backends.cudnn as cudnn +import torch.utils.data +import torch.nn.functional as F +from torch.autograd import Variable +import numpy as np +from skimage.color import gray2rgb +from nltk.metrics.distance import edit_distance +import cv2 +import pickle +import copy + +# from dataset import hierarchical_dataset, AlignCollate +# from model import Model, SuperPixler, CastNumpy, STRScore +# import hiddenlayer as hl +from callbacks import DumpPrediction, IterationCallback, TextAccuracy, TopKTextAccuracy +from dataset_abinet import ImageDataset, CustomImageDataset, TextDataset +from losses import MultiLosses +import matplotlib.pyplot as plt +import random +from utils_abinet import Config, Logger, MyDataParallel, MyConcatDataset, CharsetMapper +from utils import SRNConverter +from model_abinet import STRScore +from lime.wrappers.scikit_image import SegmentationAlgorithm +from captum._utils.models.linear_model import SkLearnLinearModel, SkLearnRidge +from captum_test import acquire_average_auc, saveAttrData, acquire_bestacc_attr, acquireAttribution + +device = torch.device('cpu') + +from captum.attr import ( + GradientShap, + DeepLift, + DeepLiftShap, + IntegratedGradients, + LayerConductance, + NeuronConductance, + NoiseTunnel, + Saliency, + InputXGradient, + GuidedBackprop, + Deconvolution, + GuidedGradCam, + FeatureAblation, + ShapleyValueSampling, + Lime, + KernelShap +) + +from captum.metrics import ( + infidelity, + sensitivity_max +) + +from captum.attr._utils.visualization import visualize_image_attr + +### Acquire pixelwise attributions and replace them with ranked numbers averaged +### across segmentation with the largest contribution having the largest number +### and the smallest set to 1, which is the minimum number. 
+### attr - original attribution +### segm - image segmentations +def rankedAttributionsBySegm(attr, segm): + aveSegmentations, sortedDict = averageSegmentsOut(attr[0,0], segm) + totalSegm = len(sortedDict.keys()) # total segmentations + sortedKeys = [k for k, v in sorted(sortedDict.items(), key=lambda item: item[1])] + sortedKeys = sortedKeys[::-1] ### A list that should contain largest to smallest score + currentRank = totalSegm + rankedSegmImg = torch.clone(attr) + for totalSegToHide in range(0, len(sortedKeys)): + currentSegmentToHide = sortedKeys[totalSegToHide] + rankedSegmImg[0,0][segm == currentSegmentToHide] = currentRank + currentRank -= 1 + return rankedSegmImg + +### Returns the mean for each segmentation having shape as the same as the input +### This function can only one attribution image at a time +def averageSegmentsOut(attr, segments): + averagedInput = torch.clone(attr) + sortedDict = {} + for x in np.unique(segments): + segmentMean = torch.mean(attr[segments == x][:]) + sortedDict[x] = float(segmentMean.detach().cpu().numpy()) + averagedInput[segments == x] = segmentMean + return averagedInput, sortedDict + +def acquireSelectivityHit(origImg, attributions, segmentations, model, charset, labels, scoring): + # print("segmentations unique len: ", np.unique(segmentations)) + aveSegmentations, sortedDict = averageSegmentsOut(attributions[0,0], segmentations) + sortedKeys = [k for k, v in sorted(sortedDict.items(), key=lambda item: item[1])] + sortedKeys = sortedKeys[::-1] ### A list that should contain largest to smallest score + # print("sortedDict: ", sortedDict) # {0: -5.51e-06, 1: -1.469e-05, 2: -3.06e-05,...} + # print("aveSegmentations unique len: ", np.unique(aveSegmentations)) + # print("aveSegmentations device: ", aveSegmentations.device) # cuda:0 + # print("aveSegmentations shape: ", aveSegmentations.shape) # (224,224) + # print("aveSegmentations: ", aveSegmentations) + + n_correct = [] + confidenceList = [] # First index is one feature removed, second index two features removed, and so on... + clonedImg = torch.clone(origImg) + gt = labels + for totalSegToHide in range(0, len(sortedKeys)): + ### Acquire LIME prediction result + currentSegmentToHide = sortedKeys[totalSegToHide] + clonedImg[0,0][segmentations == currentSegmentToHide] = 0.0 + modelOut = model(clonedImg) ### Returns a tuple of dictionaries + confScore = scoring(modelOut).cpu().detach().numpy() + pred, _, __ = postprocess(modelOut[0], charset, config.model_eval) + pred = pred[0] # outputs a list, so query [0] + if pred.lower() == gt.lower(): ### not lowercase gt labels, pred only predicts lowercase + n_correct.append(1) + else: + n_correct.append(0) + confScore = confScore[0][0]*100 + confidenceList.append(confScore) + return n_correct, confidenceList + +def _set_random_seed(seed): + if seed is not None: + random.seed(seed) + torch.manual_seed(seed) + cudnn.deterministic = True + logging.warning('You have chosen to seed training. 
' + 'This will slow down your training!') + +def get_model(config): + import importlib + names = config.model_name.split('.') + module_name, class_name = '.'.join(names[:-1]), names[-1] + cls = getattr(importlib.import_module(module_name), class_name) + model = cls(config) + logging.info(model) + model = model.eval() + return model + +def load(model, file, device=None, strict=True): + if device is None: device = 'cpu' + elif isinstance(device, int): device = torch.device('cuda', device) + assert os.path.isfile(file) + state = torch.load(file, map_location=device) + if set(state.keys()) == {'model', 'opt'}: + state = state['model'] + model.load_state_dict(state, strict=strict) + return model + +def _get_dataset(ds_type, paths, is_training, config, **kwargs): + kwargs.update({ + 'img_h': config.dataset_image_height, + 'img_w': config.dataset_image_width, + 'max_length': config.dataset_max_length, + 'case_sensitive': config.dataset_case_sensitive, + 'charset_path': config.dataset_charset_path, + 'data_aug': config.dataset_data_aug, + 'deteriorate_ratio': config.dataset_deteriorate_ratio, + 'is_training': is_training, + 'multiscales': config.dataset_multiscales, + 'one_hot_y': config.dataset_one_hot_y, + }) + datasets = [ds_type(p, **kwargs) for p in paths] + if len(datasets) > 1: return MyConcatDataset(datasets) + else: return datasets[0] + +def _get_databaunch(config): + # An awkward way to reduce loadding data time during test + if config.global_phase == 'test': config.dataset_train_roots = config.dataset_test_roots + train_ds = _get_dataset(ImageDataset, config.dataset_train_roots, True, config) + valid_ds = _get_dataset(ImageDataset, config.dataset_test_roots, False, config) + data = ImageDataBunch.create( + train_ds=train_ds, + valid_ds=valid_ds, + bs=config.dataset_train_batch_size, + val_bs=config.dataset_test_batch_size, + num_workers=config.dataset_num_workers, + pin_memory=config.dataset_pin_memory).normalize(imagenet_stats) + ar_tfm = lambda x: ((x[0], x[1]), x[1]) # auto-regression only for dtd + data.add_tfm(ar_tfm) + + logging.info(f'{len(data.train_ds)} training items found.') + if not data.empty_val: + logging.info(f'{len(data.valid_ds)} valid items found.') + + return data + +def postprocess(output, charset, model_eval): + def _get_output(last_output, model_eval): + if isinstance(last_output, (tuple, list)): + for res in last_output: + if res['name'] == model_eval: return res + return last_output + + def _decode(logit): + """ Greed decode """ + out = F.softmax(logit, dim=2) + pt_text, pt_scores, pt_lengths = [], [], [] + for o in out: + text = charset.get_text(o.argmax(dim=1), padding=False, trim=False) + text = text.split(charset.null_char)[0] # end at end-token + pt_text.append(text) + pt_scores.append(o.max(dim=1)[0]) + pt_lengths.append(min(len(text) + 1, charset.max_length)) # one for end-token + return pt_text, pt_scores, pt_lengths + + output = _get_output(output, model_eval) + # print("output type: ", type(output)) + logits, pt_lengths = output['logits'], output['pt_lengths'] + pt_text, pt_scores, pt_lengths_ = _decode(logits) + + return pt_text, pt_scores, pt_lengths_ + +def main(config): + height = config.imgH + width = config.imgW + # 'IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', 'IC13_857', 'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80' + targetDataset = settings.TARGET_DATASET # Change also the configs/train_abinet.yaml test.roots test folder + segmRootDir = "{}/{}X{}/{}/".format(settings.SEGM_DIR, height, width, targetDataset) + outputSelectivityPkl = 
"strexp_ave_{}_{}.pkl".format(settings.MODEL, targetDataset) + outputDir = "./attributionImgs/{}/{}/".format(settings.MODEL, targetDataset) + attrOutputDir = "./attributionData/{}/{}/".format(settings.MODEL, targetDataset) + resumePkl = "" # Use to resume when session destroyed. Set to "" to disable + acquireSelectivity = True + acquireInfidelity = False + acquireSensitivity = False + if not os.path.exists(outputDir): + os.makedirs(outputDir) + if not os.path.exists(attrOutputDir): + os.makedirs(attrOutputDir) + charset = CharsetMapper(filename=config.dataset_charset_path, + max_length=config.dataset_max_length + 1) + config.character = "abcdefghijklmnopqrstuvwxyz1234567890$#" # See charset_36.txt + converter = SRNConverter(config.character, 36) + + model = get_model(config).to(device) + model = load(model, config.model_checkpoint, device=device) + + """ evaluation """ + modelCopy = copy.deepcopy(model) + scoring_singlechar = STRScore(config=config, charsetMapper=charset, postprocessFunc=postprocess, device=device, enableSingleCharAttrAve=True) + super_pixel_model_singlechar = torch.nn.Sequential( + modelCopy, + scoring_singlechar + ).to(device) + modelCopy.eval() + scoring_singlechar.eval() + super_pixel_model_singlechar.eval() + + scoring = STRScore(config=config, charsetMapper=charset, postprocessFunc=postprocess, device=device) + ### SuperModel + super_pixel_model = torch.nn.Sequential( + model, + scoring + ).to(device) + model.eval() + scoring.eval() + super_pixel_model.eval() + + selectivity_eval_results = [] + + if config.blackbg: + shapImgLs = np.zeros(shape=(1, 3, 32, 128)).astype(np.float32) + trainList = np.array(shapImgLs) + background = torch.from_numpy(trainList).to(device) + + # define a perturbation function for the input (used for calculating infidelity) + def perturb_fn(modelInputs): + noise = torch.tensor(np.random.normal(0, 0.003, modelInputs.shape)).float() + noise = noise.to(device) + return noise, modelInputs - noise + + strict = ifnone(config.model_strict, True) + ### Dataset not shuffled because it is not a dataloader, just a dataset + valid_ds = _get_dataset(CustomImageDataset, config.dataset_test_roots, False, config) + # print("valid_ds: ", len(valid_ds[0])) + testImgCount = 0 + if resumePkl != "": + with open(resumePkl, 'rb') as filePkl: + selectivity_eval_results = pickle.load(filePkl) + testImgCount = selectivity_eval_results[-1]["testImgCount"] # ResumeCount + try: + for i, (orig_img_tensors, labels, labels_tensor) in enumerate(valid_ds): + if i <= testImgCount: + continue + orig_img_tensors = orig_img_tensors.unsqueeze(0) + # print("orig_img_tensors: ", orig_img_tensors.shape) # (3, 32, 128) + # img_rgb *= 255.0 + # img_rgb = img_rgb.astype('int') + # print("img_rgb max: ", img_rgb.max()) ### 255 + # img_rgb = np.asarray(orig_img_tensors) + # segmentations = segmentation_fn(img_rgb) + # print("segmentations shape: ", segmentations.shape) # (224, 224) + # print("segmentations min: ", segmentations.min()) 0 + # print("Unique: ", len(np.unique(segmentations))) # (70) + results_dict = {} + with open(segmRootDir + "{}.pkl".format(i), 'rb') as f: + pklData = pickle.load(f) + # segmData, labels = segAndLabels[0] + segmDataNP = pklData["segdata"] + labels = labels.lower() # For fair evaluation for all + assert pklData['label'] == labels + segmTensor = torch.from_numpy(segmDataNP).unsqueeze(0).unsqueeze(0) + # print("segmTensor min: ", segmTensor.min()) # 0 starting segmentation + segmTensor = segmTensor.to(device) + # print("segmTensor shape: ", segmTensor.shape) + 
# img1 = np.asarray(imgPIL.convert('L')) + # sys.exit() + # img1 = img1 / 255.0 + # img1 = torch.from_numpy(img1).unsqueeze(0).unsqueeze(0).type(torch.FloatTensor).to(device) + img1 = orig_img_tensors.to(device) + img1.requires_grad = True + bgImg = torch.zeros(img1.shape).to(device) + # preds = model(img1, seqlen=converter.batch_max_length) + input = img1 + origImgNP = torch.clone(orig_img_tensors).detach().cpu().numpy()[0][0] # (1, 1, 224, 224) + origImgNP = gray2rgb(origImgNP) + + ### Integrated Gradients + ig = IntegratedGradients(super_pixel_model) + attributions = ig.attribute(input, target=0) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_intgrad.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_intgrad.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, charset, labels, scoring) + results_dict["intgrad_acc"] = n_correct + results_dict["intgrad_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions, normalize=True).detach().cpu().numpy()) + results_dict["intgrad_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(ig.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["intgrad_sens"] = sens + + ### Gradient SHAP using zero-background + gs = GradientShap(super_pixel_model) + # We define a distribution of baselines and draw `n_samples` from that + # distribution in order to estimate the expectations of gradients across all baselines + baseline_dist = torch.zeros((1, 3, height, width)) + baseline_dist = baseline_dist.to(device) + attributions = gs.attribute(input, baselines=baseline_dist, target=0) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_gradshap.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_gradshap.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, charset, labels, scoring) + results_dict["gradshap_acc"] = n_correct + results_dict["gradshap_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions, normalize=True).detach().cpu().numpy()) + results_dict["gradshap_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(gs.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["gradshap_sens"] = sens + + ### DeepLift using zero-background + dl = DeepLift(super_pixel_model) + attributions = dl.attribute(input, target=0) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_deeplift.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_deeplift.pkl', attributions, 
segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, charset, labels, scoring) + results_dict["deeplift_acc"] = n_correct + results_dict["deeplift_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions, normalize=True).detach().cpu().numpy()) + results_dict["deeplift_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(dl.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["deeplift_sens"] = sens + + ### Saliency + saliency = Saliency(super_pixel_model) + attributions = saliency.attribute(input, target=0) ### target=class0 + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_saliency.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_saliency.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, charset, labels, scoring) + results_dict["saliency_acc"] = n_correct + results_dict["saliency_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions, normalize=True).detach().cpu().numpy()) + results_dict["saliency_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(saliency.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["saliency_sens"] = sens + + ### InputXGradient + input_x_gradient = InputXGradient(super_pixel_model) + attributions = input_x_gradient.attribute(input, target=0) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_inpxgrad.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_inpxgrad.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, charset, labels, scoring) + results_dict["inpxgrad_acc"] = n_correct + results_dict["inpxgrad_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions, normalize=True).detach().cpu().numpy()) + results_dict["inpxgrad_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(input_x_gradient.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["inpxgrad_sens"] = sens + + ### GuidedBackprop + gbp = GuidedBackprop(super_pixel_model) + attributions = gbp.attribute(input, target=0) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_guidedbp.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_guidedbp.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, charset, 
labels, scoring) + results_dict["guidedbp_acc"] = n_correct + results_dict["guidedbp_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions, normalize=True).detach().cpu().numpy()) + results_dict["guidedbp_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(gbp.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["guidedbp_sens"] = sens + + ### Deconvolution + deconv = Deconvolution(super_pixel_model) + attributions = deconv.attribute(input, target=0) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_deconv.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_deconv.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, charset, labels, scoring) + results_dict["deconv_acc"] = n_correct + results_dict["deconv_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions, normalize=True).detach().cpu().numpy()) + results_dict["deconv_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(deconv.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["deconv_sens"] = sens + + ### Feature ablator + ablator = FeatureAblation(super_pixel_model) + attributions = ablator.attribute(input, target=0, feature_mask=segmTensor) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_featablt.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_featablt.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, charset, labels, scoring) + results_dict["featablt_acc"] = n_correct + results_dict["featablt_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions, normalize=True).detach().cpu().numpy()) + results_dict["featablt_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(ablator.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["featablt_sens"] = sens + + ### Shapley Value Sampling + svs = ShapleyValueSampling(super_pixel_model) + # attr = svs.attribute(input, target=0, n_samples=200) ### Individual pixels, too long to calculate + attributions = svs.attribute(input, target=0, feature_mask=segmTensor) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_shapley.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_shapley.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, charset, labels, scoring) + 
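+                # Entries stored below: shapley_acc holds 0/1 flags (word still read
+                # correctly after the k highest-ranked segments are zeroed out) and
+                # shapley_conf the corresponding STR confidence scores, one per k.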
results_dict["shapley_acc"] = n_correct + results_dict["shapley_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions, normalize=True).detach().cpu().numpy()) + results_dict["shapley_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(svs.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["shapley_sens"] = sens + + ## LIME + interpretable_model = SkLearnRidge(alpha=1, fit_intercept=True) ### This is the default used by LIME + lime = Lime(super_pixel_model, interpretable_model=interpretable_model) + attributions = lime.attribute(input, target=0, feature_mask=segmTensor) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_lime.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_lime.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, charset, labels, scoring) + results_dict["lime_acc"] = n_correct + results_dict["lime_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions, normalize=True).detach().cpu().numpy()) + results_dict["lime_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(lime.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["lime_sens"] = sens + + ### KernelSHAP + ks = KernelShap(super_pixel_model) + attributions = ks.attribute(input, target=0, feature_mask=segmTensor) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_kernelshap.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_kernelshap.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, charset, labels, scoring) + results_dict["kernelshap_acc"] = n_correct + results_dict["kernelshap_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions, normalize=True).detach().cpu().numpy()) + results_dict["kernelshap_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(ks.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["kernelshap_sens"] = sens + + # Other data + results_dict["testImgCount"] = testImgCount # 0 to N-1 + selectivity_eval_results.append(results_dict) + + with open(outputSelectivityPkl, 'wb') as f: + pickle.dump(selectivity_eval_results, f) + + testImgCount += 1 + print("testImgCount: ", testImgCount) + except: + print("An exception occurred1") + + del valid_ds + valid_ds = _get_dataset(CustomImageDataset, config.dataset_test_roots, False, config) + bestAttributionKeyStr = acquire_bestacc_attr(config, outputSelectivityPkl) + bestAttrName = bestAttributionKeyStr.split('_')[0] + + testImgCount = 0 + try: + for i, (orig_img_tensors, labels, labels_tensor) in enumerate(valid_ds): + orig_img_tensors = orig_img_tensors.unsqueeze(0) + # print("orig_img_tensors: ", 
orig_img_tensors.shape) # (3, 32, 128) + # img_rgb *= 255.0 + # img_rgb = img_rgb.astype('int') + # print("img_rgb max: ", img_rgb.max()) ### 255 + # img_rgb = np.asarray(orig_img_tensors) + # segmentations = segmentation_fn(img_rgb) + # print("segmentations shape: ", segmentations.shape) # (224, 224) + # print("segmentations min: ", segmentations.min()) 0 + # print("Unique: ", len(np.unique(segmentations))) # (70) + results_dict = {} + with open(segmRootDir + "{}.pkl".format(i), 'rb') as f: + pklData = pickle.load(f) + # segmData, labels = segAndLabels[0] + segmDataNP = pklData["segdata"] + labels = labels.lower() # For fair evaluation for all + assert pklData['label'] == labels + # labels = "lama0" + target = converter.encode([labels], len(config.character)) + target = target[0] + 1 # Idx predicted by ABINET is 1 to N chars, not 0 to N-1 + target[target > 36] = 0 # Remove EOS predictions, set endpoint chars to 0 + segmTensor = torch.from_numpy(segmDataNP).unsqueeze(0).unsqueeze(0) + # print("segmTensor min: ", segmTensor.min()) # 0 starting segmentation + segmTensor = segmTensor.to(device) + # print("segmTensor shape: ", segmTensor.shape) + # img1 = np.asarray(imgPIL.convert('L')) + # sys.exit() + # img1 = img1 / 255.0 + # img1 = torch.from_numpy(img1).unsqueeze(0).unsqueeze(0).type(torch.FloatTensor).to(device) + img1 = orig_img_tensors.to(device) + img1.requires_grad = True + bgImg = torch.zeros(img1.shape).to(device) + # preds = model(img1, seqlen=converter.batch_max_length) + input = img1 + origImgNP = torch.clone(orig_img_tensors).detach().cpu().numpy()[0][0] # (1, 1, 224, 224) + origImgNP = gray2rgb(origImgNP) + + charOffset = 0 + ### Local explanations only + collectedAttributions = [] + for charIdx in range(0, len(labels)): + scoring_singlechar.setSingleCharOutput(charIdx + charOffset) + # print("charIdx + charOffset: ", charIdx + charOffset) + # print("target[0]: ", target[0]) + gtClassNum = target[0][charIdx + charOffset] + + ### Best local + attributions = acquireAttribution(config, super_pixel_model_singlechar, \ + input, segmTensor, gtClassNum, bestAttributionKeyStr, device) + collectedAttributions.append(attributions) + aveAttributions = torch.mean(torch.cat(collectedAttributions,dim=0), dim=0).unsqueeze(0) + rankedAttr = rankedAttributionsBySegm(aveAttributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_{}_l.png'.format(i, bestAttrName)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_{bestAttrName}_l.pkl', aveAttributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, aveAttributions, segmDataNP, modelCopy, charset, labels, scoring_singlechar) + results_dict[f"{bestAttrName}_local_acc"] = n_correct + results_dict[f"{bestAttrName}_local_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model_singlechar, perturb_fn, img1, aveAttributions, normalize=True).detach().cpu().numpy()) + results_dict[f"{bestAttrName}_local_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(svs.attribute, img1, target=0).detach().cpu().numpy()) + results_dict[f"{bestAttrName}_local_sens"] = sens + + ### Best global + attributions = acquireAttribution(config, super_pixel_model, \ + input, segmTensor, 0, bestAttributionKeyStr, device) + 
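+            # Reading of the code above: the 0 passed here appears to play the role of
+            # target=0 used elsewhere in this script, i.e. it selects the single
+            # word-level score emitted by super_pixel_model (model followed by
+            # STRScore), so this attribution explains the whole predicted word
+            # ("global"), in contrast to the per-character attributions collected
+            # in the loop above ("local").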
collectedAttributions.append(attributions) + + ### Global + Local context + aveAttributions = torch.mean(torch.cat(collectedAttributions,dim=0), dim=0).unsqueeze(0) + rankedAttr = rankedAttributionsBySegm(aveAttributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_{}_gl.png'.format(i, bestAttrName)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_{bestAttrName}_gl.pkl', aveAttributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, aveAttributions, segmDataNP, modelCopy, charset, labels, scoring_singlechar) + results_dict[f"{bestAttrName}_global_local_acc"] = n_correct + results_dict[f"{bestAttrName}_global_local_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model_singlechar, perturb_fn, img1, aveAttributions).detach().cpu().numpy()) + results_dict[f"{bestAttrName}_global_local_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(svs.attribute, img1, target=0).detach().cpu().numpy()) + results_dict[f"{bestAttrName}_global_local_sens"] = sens + + selectivity_eval_results.append(results_dict) + + with open(outputSelectivityPkl, 'wb') as f: + pickle.dump(selectivity_eval_results, f) + + testImgCount += 1 + print("testImgCount GlobLoc: ", testImgCount) + except: + print("An exception occurred2") + +### Use to check if the model predicted the image or not. Output a pickle file with the image index. +def modelDatasetPredOnly(opt): + # 'IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', 'IC13_857', + # 'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80' + datasetName = "IIIT5k_3000" + outputSelectivityPkl = "metrics_predictonly_eval_results_{}.pkl".format(datasetName) + charset = CharsetMapper(filename=config.dataset_charset_path, + max_length=config.dataset_max_length + 1) + model = get_model(config).to(device) + model = load(model, config.model_checkpoint, device=device) + model.eval() + strict = ifnone(config.model_strict, True) + ### Dataset not shuffled because it is not a dataloader, just a dataset + valid_ds = _get_dataset(CustomImageDataset, config.dataset_test_roots, False, config) + # print("valid_ds: ", len(valid_ds[0])) + testImgCount = 0 + predOutput = [] + for i, (orig_img_tensors, labels, labels_tensor) in enumerate(valid_ds): + orig_img_tensors = orig_img_tensors.unsqueeze(0).to(device) + modelOut = model(orig_img_tensors) ### Returns a tuple of dictionaries + pred, _, __ = postprocess(modelOut[0], charset, config.model_eval) + pred = pred[0] # outputs a list, so query [0] + if pred.lower() == labels.lower(): predOutput.append(1) + else: predOutput.append(0) + with open(outputSelectivityPkl, 'wb') as f: + pickle.dump(predOutput, f) + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--config', type=str, required=True, + help='path to config file') + parser.add_argument('--phase', type=str, default=None, choices=['train', 'test']) + parser.add_argument('--name', type=str, default=None) + parser.add_argument('--checkpoint', type=str, default=None) + parser.add_argument('--test_root', type=str, default=None) + parser.add_argument('--imgH', type=int, default=32, help='the height of the input image') + parser.add_argument('--imgW', type=int, default=128, help='the width of the input image') + parser.add_argument('--scorer', 
type=str, default='mean', help='See STRScore: cumprod | mean') + parser.add_argument('--rgb', action='store_true', help='use rgb input') + parser.add_argument("--local_rank", type=int, default=None) + parser.add_argument('--debug', action='store_true', default=None) + parser.add_argument('--image_only', action='store_true', default=None) + parser.add_argument('--blackbg', action='store_true', default=None) + parser.add_argument('--model_strict', action='store_false', default=None) + parser.add_argument('--model_eval', type=str, default=None, + choices=['alignment', 'vision', 'language']) + args = parser.parse_args() + config = Config(args.config) + if args.name is not None: config.global_name = args.name + if args.phase is not None: config.global_phase = args.phase + if args.test_root is not None: config.dataset_test_roots = [args.test_root] + if args.scorer is not None: config.scorer = args.scorer + if args.blackbg is not None: config.blackbg = args.blackbg + if args.rgb is not None: config.rgb = args.rgb + if args.imgH is not None: config.imgH = args.imgH + if args.imgW is not None: config.imgW = args.imgW + if args.checkpoint is not None: config.model_checkpoint = args.checkpoint + if args.debug is not None: config.global_debug = args.debug + if args.image_only is not None: config.global_image_only = args.image_only + if args.model_eval is not None: config.model_eval = args.model_eval + if args.model_strict is not None: config.model_strict = args.model_strict + + Logger.init(config.global_workdir, config.global_name, config.global_phase) + Logger.enable_file() + _set_random_seed(config.global_seed) + logging.info(config) + + # acquire_average_auc(config) + main(config) + # modelDatasetPredOnly(config) diff --git a/captum_improve_matrn.py b/captum_improve_matrn.py new file mode 100644 index 0000000000000000000000000000000000000000..9a917d5df42b88e10c27367bee30852e7f235001 --- /dev/null +++ b/captum_improve_matrn.py @@ -0,0 +1,772 @@ +import os +import time +import string +import argparse +import re +import sys +import random +import pickle +import logging +from fastai.distributed import * +from fastai.vision import * +import glob + +import settings +import torch +import torch.backends.cudnn as cudnn +import torch.utils.data +import torch.nn.functional as F +from torch.autograd import Variable +import numpy as np +from skimage.color import gray2rgb +from nltk.metrics.distance import edit_distance +import cv2 +import pickle +import copy + +# from dataset import hierarchical_dataset, AlignCollate +# from model import Model, SuperPixler, CastNumpy, STRScore +# import hiddenlayer as hl +from callbacks import DumpPrediction, IterationCallback, TextAccuracy, TopKTextAccuracy +from dataset_matrn import ImageDataset, CustomImageDataset, TextDataset +from losses_matrn import MultiLosses +from lime import lime_image +import matplotlib.pyplot as plt +import random +from transforms import CVColorJitter, CVDeterioration, CVGeometry +from utils_matrn import Config, Logger, CharsetMapper, MyConcatDataset +from utils import SRNConverter +from model_matrn import STRScore +from lime.wrappers.scikit_image import SegmentationAlgorithm +from captum._utils.models.linear_model import SkLearnLinearModel, SkLearnRidge +from captum_test import acquire_average_auc, acquire_bestacc_attr, acquireAttribution, saveAttrData + +# device = torch.device('cpu') +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +from captum.attr import ( + GradientShap, + DeepLift, + DeepLiftShap, + 
IntegratedGradients, + LayerConductance, + NeuronConductance, + NoiseTunnel, + Saliency, + InputXGradient, + GuidedBackprop, + Deconvolution, + GuidedGradCam, + FeatureAblation, + ShapleyValueSampling, + Lime, + KernelShap +) + +from captum.metrics import ( + infidelity, + sensitivity_max +) + +from captum.attr._utils.visualization import visualize_image_attr + +### Acquire pixelwise attributions and replace them with ranked numbers averaged +### across segmentation with the largest contribution having the largest number +### and the smallest set to 1, which is the minimum number. +### attr - original attribution +### segm - image segmentations +def rankedAttributionsBySegm(attr, segm): + aveSegmentations, sortedDict = averageSegmentsOut(attr[0,0], segm) + totalSegm = len(sortedDict.keys()) # total segmentations + sortedKeys = [k for k, v in sorted(sortedDict.items(), key=lambda item: item[1])] + sortedKeys = sortedKeys[::-1] ### A list that should contain largest to smallest score + currentRank = totalSegm + rankedSegmImg = torch.clone(attr) + for totalSegToHide in range(0, len(sortedKeys)): + currentSegmentToHide = sortedKeys[totalSegToHide] + rankedSegmImg[0,0][segm == currentSegmentToHide] = currentRank + currentRank -= 1 + return rankedSegmImg + +### Returns the mean for each segmentation having shape as the same as the input +### This function can only one attribution image at a time +def averageSegmentsOut(attr, segments): + averagedInput = torch.clone(attr) + sortedDict = {} + for x in np.unique(segments): + segmentMean = torch.mean(attr[segments == x][:]) + sortedDict[x] = float(segmentMean.detach().cpu().numpy()) + averagedInput[segments == x] = segmentMean + return averagedInput, sortedDict + +def acquireSelectivityHit(origImg, attributions, segmentations, model, charset, labels, scoring): + # print("segmentations unique len: ", np.unique(segmentations)) + aveSegmentations, sortedDict = averageSegmentsOut(attributions[0,0], segmentations) + sortedKeys = [k for k, v in sorted(sortedDict.items(), key=lambda item: item[1])] + sortedKeys = sortedKeys[::-1] ### A list that should contain largest to smallest score + # print("sortedDict: ", sortedDict) # {0: -5.51e-06, 1: -1.469e-05, 2: -3.06e-05,...} + # print("aveSegmentations unique len: ", np.unique(aveSegmentations)) + # print("aveSegmentations device: ", aveSegmentations.device) # cuda:0 + # print("aveSegmentations shape: ", aveSegmentations.shape) # (224,224) + # print("aveSegmentations: ", aveSegmentations) + + n_correct = [] + confidenceList = [] # First index is one feature removed, second index two features removed, and so on... 
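+    # Selectivity protocol implemented below: segments are zeroed out cumulatively,
+    # from highest to lowest average attribution; after each removal the model is
+    # re-run, n_correct records whether the full word is still read correctly, and
+    # confidenceList records the STR confidence (as a percentage).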
+ clonedImg = torch.clone(origImg) + gt = labels + for totalSegToHide in range(0, len(sortedKeys)): + ### Acquire LIME prediction result + currentSegmentToHide = sortedKeys[totalSegToHide] + clonedImg[0,0][segmentations == currentSegmentToHide] = 0.0 + modelOut = model(clonedImg) ### Returns a tuple of dictionaries + confScore = scoring(modelOut).cpu().detach().numpy() + pred, _, __ = postprocess(modelOut[0], charset, config.model_eval) + pred = pred[0] # outputs a list, so query [0] + if pred.lower() == gt.lower(): ### not lowercase gt labels, pred only predicts lowercase + n_correct.append(1) + else: + n_correct.append(0) + confScore = confScore[0][0]*100 + confidenceList.append(confScore) + return n_correct, confidenceList + +def _get_dataset(ds_type, paths, is_training, config, **kwargs): + kwargs.update({ + 'img_h': config.dataset_image_height, + 'img_w': config.dataset_image_width, + 'max_length': config.dataset_max_length, + 'case_sensitive': config.dataset_case_sensitive, + 'charset_path': config.dataset_charset_path, + 'data_aug': config.dataset_data_aug, + 'deteriorate_ratio': config.dataset_deteriorate_ratio, + 'is_training': is_training, + 'multiscales': config.dataset_multiscales, + 'one_hot_y': config.dataset_one_hot_y, + }) + datasets = [ds_type(p, **kwargs) for p in paths] + if len(datasets) > 1: return MyConcatDataset(datasets) + else: return datasets[0] + +def get_model(config): + import importlib + names = config.model_name.split('.') + module_name, class_name = '.'.join(names[:-1]), names[-1] + cls = getattr(importlib.import_module(module_name), class_name) + model = cls(config) + logging.info(model) + model = model.eval() + return model + +def preprocess(img, width, height): + img = cv2.resize(np.array(img), (width, height)) + img = transforms.ToTensor()(img).unsqueeze(0) + mean = torch.tensor([0.485, 0.456, 0.406]) + std = torch.tensor([0.229, 0.224, 0.225]) + return (img-mean[...,None,None]) / std[...,None,None] + +def postprocess(output, charset, model_eval): + def _get_output(last_output, model_eval): + if isinstance(last_output, (tuple, list)): + for res in last_output: + if res['name'] == model_eval: output = res + else: output = last_output + return output + + def _decode(logit): + """ Greed decode """ + out = F.softmax(logit, dim=2) + pt_text, pt_scores, pt_lengths = [], [], [] + for o in out: + text = charset.get_text(o.argmax(dim=1), padding=False, trim=False) + text = text.split(charset.null_char)[0] # end at end-token + pt_text.append(text) + pt_scores.append(o.max(dim=1)[0]) + pt_lengths.append(min(len(text) + 1, charset.max_length)) # one for end-token + return pt_text, pt_scores, pt_lengths + + output = _get_output(output, model_eval) + logits, pt_lengths = output['logits'], output['pt_lengths'] + pt_text, pt_scores, pt_lengths_ = _decode(logits) + + return pt_text, pt_scores, pt_lengths_ + +def load(model, file, device=None, strict=True): + if device is None: device = 'cpu' + elif isinstance(device, int): device = torch.device('cuda', device) + assert os.path.isfile(file) + state = torch.load(file, map_location=device) + if set(state.keys()) == {'model', 'opt'}: + state = state['model'] + model.load_state_dict(state, strict=strict) + return model + +def main(config): + height = config.imgH + width = config.imgW + # 'IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', 'IC13_857', 'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80' + targetDataset = settings.TARGET_DATASET # Change also the configs/train_matrn.yaml test.roots test folder + segmRootDir = 
"{}/{}X{}/{}/".format(settings.SEGM_DIR, height, width, targetDataset) + outputSelectivityPkl = "strexp_ave_{}_{}.pkl".format(settings.MODEL, targetDataset) + outputDir = "./attributionImgs/{}/{}/".format(settings.MODEL, targetDataset) + attrOutputDir = "./attributionData/{}/{}/".format(settings.MODEL, targetDataset) + resumePkl = "" # Use to resume when session destroyed. Set to "" to disable + resumePkl2 = "" # To enable global resume 2nd part. Set to "" to disable + acquireSelectivity = True + acquireInfidelity = False + acquireSensitivity = False + if not os.path.exists(outputDir): + os.makedirs(outputDir) + if not os.path.exists(attrOutputDir): + os.makedirs(attrOutputDir) + + config.character = "abcdefghijklmnopqrstuvwxyz1234567890$#" # See charset_36.txt + converter = SRNConverter(config.character, 36) + model = get_model(config).to(device) + model = load(model, config.model_checkpoint, device=device) + charset = CharsetMapper(filename=config.dataset_charset_path, + max_length=config.dataset_max_length + 1) + + # if os.path.isdir(args.input): + # paths = [os.path.join(args.input, fname) for fname in os.listdir(args.input)] + # else: + # paths = glob.glob(os.path.expanduser(args.input)) + # assert paths, "The input path(s) was not found" + # paths = sorted(paths) + # for path in tqdm.tqdm(paths): + # img = PIL.Image.open(path).convert('RGB') + # img = preprocess(img, config.dataset_image_width, config.dataset_image_height) + # img = img.to(device) + + """ evaluation """ + modelCopy = copy.deepcopy(model) + scoring_singlechar = STRScore(config=config, charsetMapper=charset, postprocessFunc=postprocess, device=device, enableSingleCharAttrAve=True) + super_pixel_model_singlechar = torch.nn.Sequential( + modelCopy, + scoring_singlechar + ).to(device) + modelCopy.eval() + scoring_singlechar.eval() + super_pixel_model_singlechar.eval() + + scoring = STRScore(config=config, charsetMapper=charset, postprocessFunc=postprocess, device=device) + ### SuperModel + super_pixel_model = torch.nn.Sequential( + model, + scoring + ).to(device) + model.eval() + scoring.eval() + super_pixel_model.eval() + + selectivity_eval_results = [] + + if config.blackbg: + shapImgLs = np.zeros(shape=(1, 3, 32, 128)).astype(np.float32) + trainList = np.array(shapImgLs) + background = torch.from_numpy(trainList).to(device) + + # define a perturbation function for the input (used for calculating infidelity) + # def perturb_fn(modelInputs): + # noise = torch.tensor(np.random.normal(0, 0.003, modelInputs.shape)).float() + # noise = noise.to(device) + # return noise, modelInputs - noise + + strict = ifnone(config.model_strict, True) + ### Dataset not shuffled because it is not a dataloader, just a dataset + valid_ds = _get_dataset(CustomImageDataset, config.dataset_test_roots, False, config) + # print("valid_ds: ", len(valid_ds[0])) + testImgCount = 0 + if resumePkl != "": + with open(resumePkl, 'rb') as filePkl: + selectivity_eval_results = pickle.load(filePkl) + for h in range(1, len(selectivity_eval_results)): + if "testImgCount" in selectivity_eval_results[-h]: + testImgCount = selectivity_eval_results[-h]["testImgCount"] # ResumeCount + break + try: + for i, (orig_img_tensors, labels, labels_tensor) in enumerate(valid_ds): + if i <= testImgCount: + continue + orig_img_tensors = orig_img_tensors.unsqueeze(0) + # print("orig_img_tensors: ", orig_img_tensors.shape) # (3, 32, 128) + # img_rgb *= 255.0 + # img_rgb = img_rgb.astype('int') + # print("img_rgb max: ", img_rgb.max()) ### 255 + # img_rgb = 
np.asarray(orig_img_tensors) + # segmentations = segmentation_fn(img_rgb) + # print("segmentations shape: ", segmentations.shape) # (224, 224) + # print("segmentations min: ", segmentations.min()) 0 + # print("Unique: ", len(np.unique(segmentations))) # (70) + results_dict = {} + with open(segmRootDir + "{}.pkl".format(i), 'rb') as f: + pklData = pickle.load(f) + # segmData, labels = segAndLabels[0] + segmDataNP = pklData["segdata"] + labels = labels.lower() # For fair evaluation for all + assert pklData['label'] == labels + # labels = "lama0" + segmTensor = torch.from_numpy(segmDataNP).unsqueeze(0).unsqueeze(0) + # print("segmTensor min: ", segmTensor.min()) # 0 starting segmentation + segmTensor = segmTensor.to(device) + # print("segmTensor shape: ", segmTensor.shape) + # img1 = np.asarray(imgPIL.convert('L')) + # sys.exit() + # img1 = img1 / 255.0 + # img1 = torch.from_numpy(img1).unsqueeze(0).unsqueeze(0).type(torch.FloatTensor).to(device) + img1 = orig_img_tensors.to(device) + img1.requires_grad = True + bgImg = torch.zeros(img1.shape).to(device) + ## Required preprocessing for MATRN + mean = torch.tensor([0.485, 0.456, 0.406]) + std = torch.tensor([0.229, 0.224, 0.225]) + img1 = (img1-mean[...,None,None]) / std[...,None,None] + # preds = model(img1, seqlen=converter.batch_max_length) + input = img1 + origImgNP = torch.clone(orig_img_tensors).detach().cpu().numpy()[0][0] # (1, 1, 224, 224) + origImgNP = gray2rgb(origImgNP) + + ### Integrated Gradients + ig = IntegratedGradients(super_pixel_model) + attributions = ig.attribute(input, target=0) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_intgrad.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_intgrad.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, charset, labels, scoring) + results_dict["intgrad_acc"] = n_correct + results_dict["intgrad_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions, normalize=True).detach().cpu().numpy()) + results_dict["intgrad_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(ig.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["intgrad_sens"] = sens + + ### Gradient SHAP using zero-background + gs = GradientShap(super_pixel_model) + # We define a distribution of baselines and draw `n_samples` from that + # distribution in order to estimate the expectations of gradients across all baselines + baseline_dist = torch.zeros((1, 3, height, width)) + baseline_dist = baseline_dist.to(device) + attributions = gs.attribute(input, baselines=baseline_dist, target=0) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_gradshap.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_gradshap.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, charset, labels, 
scoring) + results_dict["gradshap_acc"] = n_correct + results_dict["gradshap_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions, normalize=True).detach().cpu().numpy()) + results_dict["gradshap_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(gs.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["gradshap_sens"] = sens + + ### DeepLift using zero-background + dl = DeepLift(super_pixel_model) + attributions = dl.attribute(input, target=0) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_deeplift.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_deeplift.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, charset, labels, scoring) + results_dict["deeplift_acc"] = n_correct + results_dict["deeplift_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions, normalize=True).detach().cpu().numpy()) + results_dict["deeplift_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(dl.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["deeplift_sens"] = sens + + ### Saliency + saliency = Saliency(super_pixel_model) + attributions = saliency.attribute(input, target=0) ### target=class0 + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_saliency.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_saliency.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, charset, labels, scoring) + results_dict["saliency_acc"] = n_correct + results_dict["saliency_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions, normalize=True).detach().cpu().numpy()) + results_dict["saliency_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(saliency.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["saliency_sens"] = sens + + ### InputXGradient + input_x_gradient = InputXGradient(super_pixel_model) + attributions = input_x_gradient.attribute(input, target=0) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_inpxgrad.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_inpxgrad.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, charset, labels, scoring) + results_dict["inpxgrad_acc"] = n_correct + results_dict["inpxgrad_conf"] = confidenceList + if acquireInfidelity: + infid = 
float(infidelity(super_pixel_model, perturb_fn, img1, attributions, normalize=True).detach().cpu().numpy()) + results_dict["inpxgrad_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(input_x_gradient.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["inpxgrad_sens"] = sens + + ### GuidedBackprop + gbp = GuidedBackprop(super_pixel_model) + attributions = gbp.attribute(input, target=0) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_guidedbp.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_guidedbp.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, charset, labels, scoring) + results_dict["guidedbp_acc"] = n_correct + results_dict["guidedbp_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions, normalize=True).detach().cpu().numpy()) + results_dict["guidedbp_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(gbp.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["guidedbp_sens"] = sens + + ### Deconvolution + deconv = Deconvolution(super_pixel_model) + attributions = deconv.attribute(input, target=0) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_deconv.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_deconv.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, charset, labels, scoring) + results_dict["deconv_acc"] = n_correct + results_dict["deconv_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions, normalize=True).detach().cpu().numpy()) + results_dict["deconv_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(deconv.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["deconv_sens"] = sens + + ### Feature ablator + ablator = FeatureAblation(super_pixel_model) + attributions = ablator.attribute(input, target=0, feature_mask=segmTensor) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_featablt.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_featablt.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, charset, labels, scoring) + results_dict["featablt_acc"] = n_correct + results_dict["featablt_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions, normalize=True).detach().cpu().numpy()) + results_dict["featablt_infid"] = infid + if 
acquireSensitivity: + sens = float(sensitivity_max(ablator.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["featablt_sens"] = sens + + ### Shapley Value Sampling + svs = ShapleyValueSampling(super_pixel_model) + # attr = svs.attribute(input, target=0, n_samples=200) ### Individual pixels, too long to calculate + attributions = svs.attribute(input, target=0, feature_mask=segmTensor) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_shapley.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_shapley.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, charset, labels, scoring) + results_dict["shapley_acc"] = n_correct + results_dict["shapley_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions, normalize=True).detach().cpu().numpy()) + results_dict["shapley_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(svs.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["shapley_sens"] = sens + + ## LIME + interpretable_model = SkLearnRidge(alpha=1, fit_intercept=True) ### This is the default used by LIME + lime = Lime(super_pixel_model, interpretable_model=interpretable_model) + attributions = lime.attribute(input, target=0, feature_mask=segmTensor) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_lime.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_lime.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, charset, labels, scoring) + results_dict["lime_acc"] = n_correct + results_dict["lime_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions, normalize=True).detach().cpu().numpy()) + results_dict["lime_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(lime.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["lime_sens"] = sens + + ### KernelSHAP + ks = KernelShap(super_pixel_model) + attributions = ks.attribute(input, target=0, feature_mask=segmTensor) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_kernelshap.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_kernelshap.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, charset, labels, scoring) + results_dict["kernelshap_acc"] = n_correct + results_dict["kernelshap_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions, 
normalize=True).detach().cpu().numpy()) + results_dict["kernelshap_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(ks.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["kernelshap_sens"] = sens + + # Other data + results_dict["testImgCount"] = testImgCount # 0 to N-1 + selectivity_eval_results.append(results_dict) + + with open(outputSelectivityPkl, 'wb') as f: + pickle.dump(selectivity_eval_results, f) + + testImgCount += 1 + print("testImgCount: ", testImgCount) + except: + print("An exception occurred1") + + del valid_ds + valid_ds = _get_dataset(CustomImageDataset, config.dataset_test_roots, False, config) + bestAttributionKeyStr = acquire_bestacc_attr(config, outputSelectivityPkl) + bestAttrName = bestAttributionKeyStr.split('_')[0] + + ### Run another forloop + testImgCount = 0 + if resumePkl2 != "": + with open(resumePkl2, 'rb') as filePkl: + selectivity_eval_results = pickle.load(filePkl) + for h in range(1, len(selectivity_eval_results)): + if "testImgCount2" in selectivity_eval_results[-h]: + testImgCount = selectivity_eval_results[-h]["testImgCount2"] # ResumeCount + break + try: + for i, (orig_img_tensors, labels, labels_tensor) in enumerate(valid_ds): + if i <= testImgCount: + continue + orig_img_tensors = orig_img_tensors.unsqueeze(0) + results_dict = {} + with open(segmRootDir + "{}.pkl".format(i), 'rb') as f: + pklData = pickle.load(f) + # segmData, labels = segAndLabels[0] + segmDataNP = pklData["segdata"] + labels = labels.lower() # For fair evaluation for all + assert pklData['label'] == labels + # labels = "lama0" + target = converter.encode([labels], len(config.character)) + target = target[0] + 1 # Idx predicted by ABINET is 1 to N chars, not 0 to N-1 + target[target > 36] = 0 # Remove EOS predictions, set endpoint chars to 0 + segmTensor = torch.from_numpy(segmDataNP).unsqueeze(0).unsqueeze(0) + # print("segmTensor min: ", segmTensor.min()) # 0 starting segmentation + segmTensor = segmTensor.to(device) + # print("segmTensor shape: ", segmTensor.shape) + # img1 = np.asarray(imgPIL.convert('L')) + # sys.exit() + # img1 = img1 / 255.0 + # img1 = torch.from_numpy(img1).unsqueeze(0).unsqueeze(0).type(torch.FloatTensor).to(device) + img1 = orig_img_tensors.to(device) + img1.requires_grad = True + bgImg = torch.zeros(img1.shape).to(device) + ## Required preprocessing for MATRN + mean = torch.tensor([0.485, 0.456, 0.406]) + std = torch.tensor([0.229, 0.224, 0.225]) + img1 = (img1-mean[...,None,None]) / std[...,None,None] + # preds = model(img1, seqlen=converter.batch_max_length) + input = img1 + origImgNP = torch.clone(orig_img_tensors).detach().cpu().numpy()[0][0] # (1, 1, 224, 224) + origImgNP = gray2rgb(origImgNP) + + charOffset = 0 + ### Local explanations only + collectedAttributions = [] + for charIdx in range(0, len(labels)): + scoring_singlechar.setSingleCharOutput(charIdx + charOffset) + # print("charIdx + charOffset: ", charIdx + charOffset) + # print("target[0]: ", target[0]) + gtClassNum = target[0][charIdx + charOffset] + + ### Gradient SHAP using zero-background + # gs = GradientShap(super_pixel_model_singlechar) + # baseline_dist = torch.zeros((1, 3, height, width)) + # baseline_dist = baseline_dist.to(device) + # attributions = gs.attribute(input, baselines=baseline_dist, target=gtClassNum) + attributions = acquireAttribution(config, super_pixel_model_singlechar, \ + input, segmTensor, gtClassNum, bestAttributionKeyStr, device) + collectedAttributions.append(attributions) + aveAttributions = 
torch.mean(torch.cat(collectedAttributions,dim=0), dim=0).unsqueeze(0) + rankedAttr = rankedAttributionsBySegm(aveAttributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_{}_l.png'.format(i, bestAttrName)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_{bestAttrName}_l.pkl', aveAttributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, aveAttributions, segmDataNP, modelCopy, charset, labels, scoring_singlechar) + results_dict[f"{bestAttrName}_local_acc"] = n_correct + results_dict[f"{bestAttrName}_local_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model_singlechar, perturb_fn, img1, aveAttributions, normalize=True).detach().cpu().numpy()) + results_dict[f"{bestAttrName}_local_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(svs.attribute, img1, target=0).detach().cpu().numpy()) + results_dict[f"{bestAttrName}_local_sens"] = sens + + ### Best attribution-based method using zero-background + attributions = acquireAttribution(config, super_pixel_model, \ + input, segmTensor, 0, bestAttributionKeyStr, device) + collectedAttributions.append(attributions) + + ### Global + Local context + aveAttributions = torch.mean(torch.cat(collectedAttributions,dim=0), dim=0).unsqueeze(0) + rankedAttr = rankedAttributionsBySegm(aveAttributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_{}_gl.png'.format(i, bestAttrName)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_{bestAttrName}_gl.pkl', aveAttributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, aveAttributions, segmDataNP, modelCopy, charset, labels, scoring_singlechar) + results_dict[f"{bestAttrName}_global_local_acc"] = n_correct + results_dict[f"{bestAttrName}_global_local_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model_singlechar, perturb_fn, img1, aveAttributions).detach().cpu().numpy()) + results_dict[f"{bestAttrName}_global_local_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(svs.attribute, img1, target=0).detach().cpu().numpy()) + results_dict[f"{bestAttrName}_global_local_sens"] = sens + + results_dict["testImgCount2"] = testImgCount # 0 to N-1 + selectivity_eval_results.append(results_dict) + + with open(outputSelectivityPkl, 'wb') as f: + pickle.dump(selectivity_eval_results, f) + + testImgCount += 1 + print("testImgCount GlobLoc: ", testImgCount) + except: + print("An exception occurred2") + +### Use to check if the model predicted the image or not. Output a pickle file with the image index. 
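+### The saved list is index-aligned with the (unshuffled) test dataset: entry i is 1 if the
+### model's prediction matched the ground truth for image i (case-insensitive), 0 otherwise.
+### Minimal sketch of reading the output produced below (the filename follows the format
+### string defined in this function):
+###   with open("metrics_predictonly_eval_results_IIIT5k_3000.pkl", "rb") as f:
+###       hits = pickle.load(f)            # list of 0/1, one entry per test image
+###   print(sum(hits) / len(hits))         # word accuracy on the unshuffled test set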
+def modelDatasetPredOnly(opt): + # 'IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', 'IC13_857', + # 'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80' + datasetName = "IIIT5k_3000" + outputSelectivityPkl = "metrics_predictonly_eval_results_{}.pkl".format(datasetName) + charset = CharsetMapper(filename=config.dataset_charset_path, + max_length=config.dataset_max_length + 1) + model = get_model(config).to(device) + model = load(model, config.model_checkpoint, device=device) + model.eval() + strict = ifnone(config.model_strict, True) + ### Dataset not shuffled because it is not a dataloader, just a dataset + valid_ds = _get_dataset(CustomImageDataset, config.dataset_test_roots, False, config) + # print("valid_ds: ", len(valid_ds[0])) + testImgCount = 0 + predOutput = [] + for i, (orig_img_tensors, labels, labels_tensor) in enumerate(valid_ds): + orig_img_tensors = orig_img_tensors.unsqueeze(0).to(device) + modelOut = model(orig_img_tensors) ### Returns a tuple of dictionaries + pred, _, __ = postprocess(modelOut[0], charset, config.model_eval) + pred = pred[0] # outputs a list, so query [0] + if pred.lower() == labels.lower(): predOutput.append(1) + else: predOutput.append(0) + with open(outputSelectivityPkl, 'wb') as f: + pickle.dump(predOutput, f) + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--config', type=str, default='configs/train_matrn.yaml', + help='path to config file') + parser.add_argument('--input', type=str, default='figs/test') + parser.add_argument('--imgH', type=int, default=32, help='the height of the input image') + parser.add_argument('--imgW', type=int, default=128, help='the width of the input image') + parser.add_argument('--scorer', type=str, default='mean', help='See STRScore: cumprod | mean') + parser.add_argument('--blackbg', action='store_true', default=None) + parser.add_argument('--cuda', type=int, default=-1) + parser.add_argument('--rgb', action='store_true', help='use rgb input') + parser.add_argument('--checkpoint', type=str, default='workdir/train-abinet/best-train-abinet.pth') + parser.add_argument('--model_eval', type=str, default='alignment', + choices=['alignment', 'vision', 'language']) + args = parser.parse_args() + config = Config(args.config) + if args.checkpoint is not None: config.model_checkpoint = args.checkpoint + if args.model_eval is not None: config.model_eval = args.model_eval + if args.imgH is not None: config.imgH = args.imgH + if args.imgW is not None: config.imgW = args.imgW + if args.scorer is not None: config.scorer = args.scorer + if args.blackbg is not None: config.blackbg = args.blackbg + if args.rgb is not None: config.rgb = args.rgb + config.global_phase = 'test' + config.model_vision_checkpoint, config.model_language_checkpoint = None, None + device = 'cpu' if args.cuda < 0 else f'cuda:{args.cuda}' + + Logger.init(config.global_workdir, config.global_name, config.global_phase) + Logger.enable_file() + logging.info(config) + + # acquire_average_auc(config) + main(config) + # modelDatasetPredOnly(config) diff --git a/captum_improve_parseq.py b/captum_improve_parseq.py new file mode 100644 index 0000000000000000000000000000000000000000..9b50900f16735c0ec7cad52c1c043daef083556f --- /dev/null +++ b/captum_improve_parseq.py @@ -0,0 +1,668 @@ +import settings +import captum +import numpy as np +import torch +import torch.nn.functional as F +import torch.backends.cudnn as cudnn +from torchvision import transforms +from utils import get_args +from utils import CTCLabelConverter, 
AttnLabelConverter, Averager, TokenLabelConverter +import string +import time +import sys +from dataset import hierarchical_dataset, AlignCollate +import validators +from model import Model, STRScore +from PIL import Image +from lime.wrappers.scikit_image import SegmentationAlgorithm +from captum._utils.models.linear_model import SkLearnLinearModel, SkLearnRidge +import random +import os +from skimage.color import gray2rgb +import pickle +from train_shap_corr import getPredAndConf +import re +from captum_test import acquire_average_auc, acquireListOfAveAUC, acquire_bestacc_attr, acquireAttribution, saveAttrData +import copy +from captum_improve_vitstr import rankedAttributionsBySegm +from matplotlib import pyplot as plt +from captum.attr._utils.visualization import visualize_image_attr + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +from captum.attr import ( + GradientShap, + DeepLift, + DeepLiftShap, + IntegratedGradients, + LayerConductance, + NeuronConductance, + NoiseTunnel, + Saliency, + InputXGradient, + GuidedBackprop, + Deconvolution, + GuidedGradCam, + FeatureAblation, + ShapleyValueSampling, + Lime, + KernelShap +) + +from captum.metrics import ( + infidelity, + sensitivity_max +) + +### Returns the mean for each segmentation having shape as the same as the input +### This function can only one attribution image at a time +def averageSegmentsOut(attr, segments): + averagedInput = torch.clone(attr) + sortedDict = {} + for x in np.unique(segments): + segmentMean = torch.mean(attr[segments == x][:]) + sortedDict[x] = float(segmentMean.detach().cpu().numpy()) + averagedInput[segments == x] = segmentMean + return averagedInput, sortedDict + +### Output and save segmentations only for one dataset only +def outputSegmOnly(opt): + ### targetDataset - one dataset only, SVTP-645, CUTE80-288images + targetDataset = "CUTE80" # ['IIIT5k_3000', 'SVT', 'IC03_867', 'IC13_1015', 'IC15_2077', 'SVTP', 'CUTE80'] + segmRootDir = "/home/uclpc1/Documents/STR/datasets/segmentations/224X224/{}/".format(targetDataset) + + if not os.path.exists(segmRootDir): + os.makedirs(segmRootDir) + + opt.eval = True + ### Only IIIT5k_3000 + if opt.fast_acc: + # # To easily compute the total accuracy of our paper. + eval_data_list = [targetDataset] + else: + # The evaluation datasets, dataset order is same with Table 1 in our paper. 
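+ # (Both branches collapse to the single targetDataset above, since this helper writes
+ # segmentation pickles for one dataset at a time.)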
+ eval_data_list = [targetDataset] + + ### Taken from LIME + segmentation_fn = SegmentationAlgorithm('quickshift', kernel_size=4, + max_dist=200, ratio=0.2, + random_seed=random.randint(0, 1000)) + + for eval_data in eval_data_list: + eval_data_path = os.path.join(opt.eval_data, eval_data) + AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD, opt=opt) + eval_data, eval_data_log = hierarchical_dataset(root=eval_data_path, opt=opt) + evaluation_loader = torch.utils.data.DataLoader( + eval_data, batch_size=1, + shuffle=False, + num_workers=int(opt.workers), + collate_fn=AlignCollate_evaluation, pin_memory=True) + for i, (image_tensors, labels) in enumerate(evaluation_loader): + imgDataDict = {} + img_numpy = image_tensors.cpu().detach().numpy()[0] ### Need to set batch size to 1 only + if img_numpy.shape[0] == 1: + img_numpy = gray2rgb(img_numpy[0]) + # print("img_numpy shape: ", img_numpy.shape) # (224,224,3) + segmOutput = segmentation_fn(img_numpy) + imgDataDict['segdata'] = segmOutput + imgDataDict['label'] = labels[0] + outputPickleFile = segmRootDir + "{}.pkl".format(i) + with open(outputPickleFile, 'wb') as f: + pickle.dump(imgDataDict, f) + +def acquireSelectivityHit(origImg, attributions, segmentations, model, converter, labels, scoring): + # print("segmentations unique len: ", np.unique(segmentations)) + aveSegmentations, sortedDict = averageSegmentsOut(attributions[0,0], segmentations) + sortedKeys = [k for k, v in sorted(sortedDict.items(), key=lambda item: item[1])] + sortedKeys = sortedKeys[::-1] ### A list that should contain largest to smallest score + # print("sortedDict: ", sortedDict) # {0: -5.51e-06, 1: -1.469e-05, 2: -3.06e-05,...} + # print("aveSegmentations unique len: ", np.unique(aveSegmentations)) + # print("aveSegmentations device: ", aveSegmentations.device) # cuda:0 + # print("aveSegmentations shape: ", aveSegmentations.shape) # (224,224) + # print("aveSegmentations: ", aveSegmentations) + + n_correct = [] + confidenceList = [] # First index is one feature removed, second index two features removed, and so on... + clonedImg = torch.clone(origImg) + gt = str(labels) + for totalSegToHide in range(0, len(sortedKeys)): + ### Acquire LIME prediction result + currentSegmentToHide = sortedKeys[totalSegToHide] + clonedImg[0,0][segmentations == currentSegmentToHide] = 0.0 + pred, confScore = getPredAndConf(opt, model, scoring, clonedImg, converter, np.array([gt])) + # To evaluate 'case sensitive model' with alphanumeric and case insensitve setting. 
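+ # Illustrative example (hypothetical strings): pred "Ronaldo!" and gt "RONALDO" both
+ # reduce to "ronaldo" after the lowercasing and non-alphanumeric stripping below, so
+ # case and punctuation differences do not count as selectivity errors.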
+ if opt.sensitive and opt.data_filtering_off: + pred = pred.lower() + gt = gt.lower() + alphanumeric_case_insensitve = '0123456789abcdefghijklmnopqrstuvwxyz' + out_of_alphanumeric_case_insensitve = f'[^{alphanumeric_case_insensitve}]' + pred = re.sub(out_of_alphanumeric_case_insensitve, '', pred) + gt = re.sub(out_of_alphanumeric_case_insensitve, '', gt) + if pred == gt: + n_correct.append(1) + else: + n_correct.append(0) + confScore = confScore[0][0]*100 + confidenceList.append(confScore) + return n_correct, confidenceList + +### Once you have the selectivity_eval_results.pkl file, +def acquire_selectivity_auc(opt, pkl_filename=None): + if pkl_filename is None: + pkl_filename = "/home/goo/str/str_vit_dataexplain_lambda/metrics_sensitivity_eval_results_CUTE80.pkl" # VITSTR + accKeys = [] + + with open(pkl_filename, 'rb') as f: + selectivity_data = pickle.load(f) + + for resDictIdx, resDict in enumerate(selectivity_data): + keylistAcc = [] + keylistConf = [] + metricsKeys = resDict.keys() + for keyStr in resDict.keys(): + if "_acc" in keyStr: keylistAcc.append(keyStr) + if "_conf" in keyStr: keylistConf.append(keyStr) + # Need to check if network correctly predicted the image + for metrics_accStr in keylistAcc: + if 1 not in resDict[metrics_accStr]: print("resDictIdx") + +### This acquires the attributes of the STR network on individual character levels, +### then averages them. +def acquireSingleCharAttrAve(opt): + ### targetDataset - one dataset only, CUTE80 has 288 samples + # 'IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', 'IC13_857', 'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80' + targetDataset = settings.TARGET_DATASET + segmRootDir = "{}/32X128/{}/".format(settings.SEGM_DIR, targetDataset) + outputSelectivityPkl = "strexp_ave_{}_{}.pkl".format(settings.MODEL, targetDataset) + outputDir = "./attributionImgs/{}/{}/".format(settings.MODEL, targetDataset) + attrOutputDir = "./attributionData/{}/{}/".format(settings.MODEL, targetDataset) + ### Set only one below to True to have enough GPU + acquireSelectivity = True + acquireInfidelity = False + acquireSensitivity = False ### GPU error + if not os.path.exists(outputDir): + os.makedirs(outputDir) + if not os.path.exists(attrOutputDir): + os.makedirs(attrOutputDir) + + model = torch.hub.load('baudm/parseq', 'parseq', pretrained=True) + model = model.to(device) + model_obj = model + converter = TokenLabelConverter(opt) + + modelCopy = copy.deepcopy(model) + + """ evaluation """ + scoring_singlechar = STRScore(opt=opt, converter=converter, device=device, enableSingleCharAttrAve=True, model=modelCopy) + super_pixel_model_singlechar = torch.nn.Sequential( + # super_pixler, + # numpy2torch_converter, + modelCopy, + scoring_singlechar + ).to(device) + modelCopy.eval() + scoring_singlechar.eval() + super_pixel_model_singlechar.eval() + + # Single Char Attribution Averaging + # enableSingleCharAttrAve - set to True + scoring = STRScore(opt=opt, converter=converter, device=device, model=model) + super_pixel_model = torch.nn.Sequential( + # super_pixler, + # numpy2torch_converter, + model, + scoring + ).to(device) + model.eval() + scoring.eval() + super_pixel_model.eval() + + if opt.blackbg: + shapImgLs = np.zeros(shape=(1, 1, 224, 224)).astype(np.float32) + trainList = np.array(shapImgLs) + background = torch.from_numpy(trainList).to(device) + + opt.eval = True + + ### Only IIIT5k_3000 + if opt.fast_acc: + # # To easily compute the total accuracy of our paper. 
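+ # Both branches resolve to settings.TARGET_DATASET, so a single dataset is evaluated per run.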
+ eval_data_list = [targetDataset] ### One dataset only + else: + # The evaluation datasets, dataset order is same with Table 1 in our paper. + eval_data_list = [targetDataset] + + if opt.calculate_infer_time: + evaluation_batch_size = 1 # batch_size should be 1 to calculate the GPU inference time per image. + else: + evaluation_batch_size = opt.batch_size + + selectivity_eval_results = [] + + testImgCount = 0 + list_accuracy = [] + total_forward_time = 0 + total_evaluation_data_number = 0 + total_correct_number = 0 + + segmentation_fn = SegmentationAlgorithm('quickshift', kernel_size=4, + max_dist=200, ratio=0.2, + random_seed=random.randint(0, 1000)) + + for eval_data in eval_data_list: + eval_data_path = os.path.join(opt.eval_data, eval_data) + AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD, opt=opt) + eval_data, eval_data_log = hierarchical_dataset(root=eval_data_path, opt=opt, segmRootDir=segmRootDir) + evaluation_loader = torch.utils.data.DataLoader( + eval_data, batch_size=1, + shuffle=False, + num_workers=int(opt.workers), + collate_fn=AlignCollate_evaluation, pin_memory=True) + testImgCount = 0 + + for i, (orig_img_tensors, segAndLabels) in enumerate(evaluation_loader): + results_dict = {} + aveAttr = [] + aveAttr_charContrib = [] + segmData, labels = segAndLabels[0] + target = converter.encode([labels]) + + # labels: RONALDO + segmDataNP = segmData["segdata"] + segmTensor = torch.from_numpy(segmDataNP).unsqueeze(0).unsqueeze(0) + # print("segmTensor min: ", segmTensor.min()) # 0 starting segmentation + segmTensor = segmTensor.to(device) + img1 = orig_img_tensors.to(device) + img1.requires_grad = True + bgImg = torch.zeros(img1.shape).to(device) + + ### Single char averaging + if settings.MODEL == 'vitstr': + charOffset = 1 + elif settings.MODEL == 'parseq': + charOffset = 0 + img1 = transforms.Normalize(0.5, 0.5)(img1) # Between -1 to 1 + + # preds = model(img1, seqlen=converter.batch_max_length) + input = img1 + origImgNP = torch.clone(orig_img_tensors).detach().cpu().numpy()[0][0] # (1, 1, 224, 224) + origImgNP = gray2rgb(origImgNP) + + ### BASELINE Evaluations + + ### Integrated Gradients + ig = IntegratedGradients(super_pixel_model) + attributions = ig.attribute(input, target=0) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_intgrad.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_intgrad.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, converter, labels, scoring) + results_dict["intgrad_acc"] = n_correct + results_dict["intgrad_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions).detach().cpu().numpy()) + results_dict["intgrad_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(ig.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["intgrad_sens"] = sens + + ### Gradient SHAP using zero-background + gs = GradientShap(super_pixel_model) + # We define a distribution of baselines and draw `n_samples` from that + # distribution in order to estimate the expectations of gradients across all baselines + baseline_dist = torch.zeros((1, 3, opt.imgH, 
opt.imgW)) + baseline_dist = baseline_dist.to(device) + attributions = gs.attribute(input, baselines=baseline_dist, target=0) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_gradshap.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_gradshap.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, converter, labels, scoring) + results_dict["gradshap_acc"] = n_correct + results_dict["gradshap_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions).detach().cpu().numpy()) + results_dict["gradshap_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(gs.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["gradshap_sens"] = sens + + ### DeepLift using zero-background + dl = DeepLift(super_pixel_model) + attributions = dl.attribute(input, target=0) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_deeplift.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_deeplift.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, converter, labels, scoring) + results_dict["deeplift_acc"] = n_correct + results_dict["deeplift_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions).detach().cpu().numpy()) + results_dict["deeplift_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(dl.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["deeplift_sens"] = sens + + ### Saliency + saliency = Saliency(super_pixel_model) + attributions = saliency.attribute(input, target=0) ### target=class0 + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_saliency.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_saliency.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, converter, labels, scoring) + results_dict["saliency_acc"] = n_correct + results_dict["saliency_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions).detach().cpu().numpy()) + results_dict["saliency_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(saliency.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["saliency_sens"] = sens + + ### InputXGradient + input_x_gradient = InputXGradient(super_pixel_model) + attributions = input_x_gradient.attribute(input, target=0) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + 
rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_inpxgrad.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_inpxgrad.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, converter, labels, scoring) + results_dict["inpxgrad_acc"] = n_correct + results_dict["inpxgrad_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions).detach().cpu().numpy()) + results_dict["inpxgrad_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(input_x_gradient.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["inpxgrad_sens"] = sens + + ### GuidedBackprop + gbp = GuidedBackprop(super_pixel_model) + attributions = gbp.attribute(input, target=0) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_guidedbp.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_guidedbp.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, converter, labels, scoring) + results_dict["guidedbp_acc"] = n_correct + results_dict["guidedbp_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions).detach().cpu().numpy()) + results_dict["guidedbp_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(gbp.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["guidedbp_sens"] = sens + + ### Deconvolution + deconv = Deconvolution(super_pixel_model) + attributions = deconv.attribute(input, target=0) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_deconv.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_deconv.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, converter, labels, scoring) + results_dict["deconv_acc"] = n_correct + results_dict["deconv_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions).detach().cpu().numpy()) + results_dict["deconv_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(deconv.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["deconv_sens"] = sens + + ### Feature ablator + ablator = FeatureAblation(super_pixel_model) + attributions = ablator.attribute(input, target=0, feature_mask=segmTensor) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + 
'{}_featablt.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_featablt.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, converter, labels, scoring) + results_dict["featablt_acc"] = n_correct + results_dict["featablt_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions).detach().cpu().numpy()) + results_dict["featablt_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(ablator.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["featablt_sens"] = sens + + ### Shapley Value Sampling + svs = ShapleyValueSampling(super_pixel_model) + # attr = svs.attribute(input, target=0, n_samples=200) ### Individual pixels, too long to calculate + attributions = svs.attribute(input, target=0, feature_mask=segmTensor) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_shapley.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_shapley.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, converter, labels, scoring) + results_dict["shapley_acc"] = n_correct + results_dict["shapley_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions).detach().cpu().numpy()) + results_dict["shapley_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(svs.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["shapley_sens"] = sens + + ## LIME + interpretable_model = SkLearnRidge(alpha=1, fit_intercept=True) ### This is the default used by LIME + lime = Lime(super_pixel_model, interpretable_model=interpretable_model) + attributions = lime.attribute(input, target=0, feature_mask=segmTensor) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_lime.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_lime.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, converter, labels, scoring) + results_dict["lime_acc"] = n_correct + results_dict["lime_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions).detach().cpu().numpy()) + results_dict["lime_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(lime.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["lime_sens"] = sens + + ### KernelSHAP + ks = KernelShap(super_pixel_model) + attributions = ks.attribute(input, target=0, feature_mask=segmTensor) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + 
mplotfig.savefig(outputDir + '{}_kernelshap.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_kernelshap.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, converter, labels, scoring) + results_dict["kernelshap_acc"] = n_correct + results_dict["kernelshap_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions).detach().cpu().numpy()) + results_dict["kernelshap_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(ks.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["kernelshap_sens"] = sens + + selectivity_eval_results.append(results_dict) + + with open(outputSelectivityPkl, 'wb') as f: + pickle.dump(selectivity_eval_results, f) + + testImgCount += 1 + print("testImgCount: ", testImgCount) + + bestAttributionKeyStr = acquire_bestacc_attr(opt, outputSelectivityPkl) + bestAttrName = bestAttributionKeyStr.split('_')[0] + + testImgCount = 0 + for i, (orig_img_tensors, segAndLabels) in enumerate(evaluation_loader): + results_dict = {} + aveAttr = [] + aveAttr_charContrib = [] + segmData, labels = segAndLabels[0] + target = converter.encode([labels]) + + # labels: RONALDO + segmDataNP = segmData["segdata"] + segmTensor = torch.from_numpy(segmDataNP).unsqueeze(0).unsqueeze(0) + # print("segmTensor min: ", segmTensor.min()) # 0 starting segmentation + segmTensor = segmTensor.to(device) + # print("segmTensor shape: ", segmTensor.shape) + # img1 = np.asarray(imgPIL.convert('L')) + # sys.exit() + # img1 = img1 / 255.0 + # img1 = torch.from_numpy(img1).unsqueeze(0).unsqueeze(0).type(torch.FloatTensor).to(device) + img1 = orig_img_tensors.to(device) + img1.requires_grad = True + bgImg = torch.zeros(img1.shape).to(device) + + ### Single char averaging + if settings.MODEL == 'vitstr': + charOffset = 1 + elif settings.MODEL == 'parseq': + target = target[:, 1:] # First position [GO] not used in parseq too. 
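+ # Together with the index shift below, this aligns parseq's token ids with the 0-based
+ # class indices later used as gtClassNum when attributing each character position.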
+ # 0 index is [GO] char, not used in parseq, only the [EOS] which is in 1 index + target[target > 0] -= 1 + charOffset = 0 + img1 = transforms.Normalize(0.5, 0.5)(img1) # Between -1 to 1 + + # preds = model(img1, seqlen=converter.batch_max_length) + input = img1 + origImgNP = torch.clone(orig_img_tensors).detach().cpu().numpy()[0][0] # (1, 1, 224, 224) + origImgNP = gray2rgb(origImgNP) + + ### Captum test + collectedAttributions = [] + for charIdx in range(0, len(labels)): + scoring_singlechar.setSingleCharOutput(charIdx + charOffset) + gtClassNum = target[0][charIdx + charOffset] + + # Best + attributions = acquireAttribution(opt, super_pixel_model_singlechar, \ + input, segmTensor, gtClassNum, bestAttributionKeyStr, device) + collectedAttributions.append(attributions) + aveAttributions = torch.mean(torch.cat(collectedAttributions,dim=0), dim=0).unsqueeze(0) + rankedAttr = rankedAttributionsBySegm(aveAttributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_{}_l.png'.format(i, bestAttrName)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_{bestAttrName}_l.pkl', aveAttributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, aveAttributions, segmDataNP, modelCopy, converter, labels, scoring_singlechar) + results_dict[f"{bestAttrName}_local_acc"] = n_correct + results_dict[f"{bestAttrName}_local_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model_singlechar, perturb_fn, img1, aveAttributions).detach().cpu().numpy()) + results_dict[f"{bestAttrName}_local_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(svs.attribute, img1, target=0).detach().cpu().numpy()) + results_dict[f"{bestAttrName}_local_sens"] = sens + + ### Best single + attributions = acquireAttribution(opt, super_pixel_model, \ + input, segmTensor, 0, bestAttributionKeyStr, device) + collectedAttributions.append(attributions) + + ### Global + Local context + aveAttributions = torch.mean(torch.cat(collectedAttributions,dim=0), dim=0).unsqueeze(0) + rankedAttr = rankedAttributionsBySegm(aveAttributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_{}_gl.png'.format(i, bestAttrName)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_{bestAttrName}_gl.pkl', aveAttributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, aveAttributions, segmDataNP, modelCopy, converter, labels, scoring_singlechar) + results_dict[f"{bestAttrName}_global_local_acc"] = n_correct + results_dict[f"{bestAttrName}_global_local_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model_singlechar, perturb_fn, img1, aveAttributions).detach().cpu().numpy()) + results_dict[f"{bestAttrName}_global_local_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(svs.attribute, img1, target=0).detach().cpu().numpy()) + results_dict[f"{bestAttrName}_global_local_sens"] = sens + + selectivity_eval_results.append(results_dict) + + with open(outputSelectivityPkl, 'wb') as f: + pickle.dump(selectivity_eval_results, f) + + testImgCount 
+= 1 + print("testImgCount GlobLoc: ", testImgCount) + +if __name__ == '__main__': + # deleteInf() + opt = get_args(is_train=False) + + """ vocab / character number configuration """ + if opt.sensitive: + opt.character = string.printable[:-6] # same with ASTER setting (use 94 char). + + cudnn.benchmark = True + cudnn.deterministic = True + opt.num_gpu = torch.cuda.device_count() + + # combineBestDataXAI(opt) + # acquire_average_auc(opt) + # acquireListOfAveAUC(opt) + acquireSingleCharAttrAve(opt) diff --git a/captum_improve_srn.py b/captum_improve_srn.py new file mode 100644 index 0000000000000000000000000000000000000000..fefc98786d1a1a43a6a534ff2ef5a9a5eb21ca41 --- /dev/null +++ b/captum_improve_srn.py @@ -0,0 +1,700 @@ +import settings +import captum +import numpy as np +import argparse +import torch +import torch.nn.functional as F +import torch.backends.cudnn as cudnn +from torchvision import transforms +from utils import get_args +from utils import CTCLabelConverter, AttnLabelConverter, Averager, TokenLabelConverter, SRNConverter +import string +import time +import sys +from dataset import hierarchical_dataset, AlignCollate +import validators +from model_srn import Model, STRScore +from PIL import Image +from lime.wrappers.scikit_image import SegmentationAlgorithm +from captum._utils.models.linear_model import SkLearnLinearModel, SkLearnRidge +import random +import os +from skimage.color import gray2rgb +import pickle +from train_shap_corr import getPredAndConf +import re +from captum_test import acquire_average_auc, acquireListOfAveAUC, saveAttrData +import copy +from model_srn import Model +from captum_improve_vitstr import rankedAttributionsBySegm +from matplotlib import pyplot as plt +from captum.attr._utils.visualization import visualize_image_attr + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +from captum.attr import ( + GradientShap, + DeepLift, + DeepLiftShap, + IntegratedGradients, + LayerConductance, + NeuronConductance, + NoiseTunnel, + Saliency, + InputXGradient, + GuidedBackprop, + Deconvolution, + GuidedGradCam, + FeatureAblation, + ShapleyValueSampling, + Lime, + KernelShap +) + +from captum.metrics import ( + infidelity, + sensitivity_max +) + +### Returns the mean for each segmentation having shape as the same as the input +### This function can only one attribution image at a time +def averageSegmentsOut(attr, segments): + averagedInput = torch.clone(attr) + sortedDict = {} + for x in np.unique(segments): + segmentMean = torch.mean(attr[segments == x][:]) + sortedDict[x] = float(segmentMean.detach().cpu().numpy()) + averagedInput[segments == x] = segmentMean + return averagedInput, sortedDict + +### Output and save segmentations only for one dataset only +def outputSegmOnly(opt): + ### targetDataset - one dataset only, SVTP-645, CUTE80-288images + targetDataset = "CUTE80" # ['IIIT5k_3000', 'SVT', 'IC03_867', 'IC13_1015', 'IC15_2077', 'SVTP', 'CUTE80'] + segmRootDir = "/home/uclpc1/Documents/STR/datasets/segmentations/224X224/{}/".format(targetDataset) + + if not os.path.exists(segmRootDir): + os.makedirs(segmRootDir) + + opt.eval = True + ### Only IIIT5k_3000 + if opt.fast_acc: + # # To easily compute the total accuracy of our paper. + eval_data_list = [targetDataset] + else: + # The evaluation datasets, dataset order is same with Table 1 in our paper. 
+ eval_data_list = [targetDataset] + + ### Taken from LIME + segmentation_fn = SegmentationAlgorithm('quickshift', kernel_size=4, + max_dist=200, ratio=0.2, + random_seed=random.randint(0, 1000)) + + for eval_data in eval_data_list: + eval_data_path = os.path.join(opt.eval_data, eval_data) + AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD, opt=opt) + eval_data, eval_data_log = hierarchical_dataset(root=eval_data_path, opt=opt) + evaluation_loader = torch.utils.data.DataLoader( + eval_data, batch_size=1, + shuffle=False, + num_workers=int(opt.workers), + collate_fn=AlignCollate_evaluation, pin_memory=True) + for i, (image_tensors, labels) in enumerate(evaluation_loader): + imgDataDict = {} + img_numpy = image_tensors.cpu().detach().numpy()[0] ### Need to set batch size to 1 only + if img_numpy.shape[0] == 1: + img_numpy = gray2rgb(img_numpy[0]) + # print("img_numpy shape: ", img_numpy.shape) # (224,224,3) + segmOutput = segmentation_fn(img_numpy) + imgDataDict['segdata'] = segmOutput + imgDataDict['label'] = labels[0] + outputPickleFile = segmRootDir + "{}.pkl".format(i) + with open(outputPickleFile, 'wb') as f: + pickle.dump(imgDataDict, f) + +def acquireSelectivityHit(origImg, attributions, segmentations, model, converter, labels, scoring): + # print("segmentations unique len: ", np.unique(segmentations)) + aveSegmentations, sortedDict = averageSegmentsOut(attributions[0,0], segmentations) + sortedKeys = [k for k, v in sorted(sortedDict.items(), key=lambda item: item[1])] + sortedKeys = sortedKeys[::-1] ### A list that should contain largest to smallest score + # print("sortedDict: ", sortedDict) # {0: -5.51e-06, 1: -1.469e-05, 2: -3.06e-05,...} + # print("aveSegmentations unique len: ", np.unique(aveSegmentations)) + # print("aveSegmentations device: ", aveSegmentations.device) # cuda:0 + # print("aveSegmentations shape: ", aveSegmentations.shape) # (224,224) + # print("aveSegmentations: ", aveSegmentations) + + n_correct = [] + confidenceList = [] # First index is one feature removed, second index two features removed, and so on... + clonedImg = torch.clone(origImg) + gt = str(labels) + for totalSegToHide in range(0, len(sortedKeys)): + ### Acquire LIME prediction result + currentSegmentToHide = sortedKeys[totalSegToHide] + clonedImg[0,0][segmentations == currentSegmentToHide] = 0.0 + pred, confScore = getPredAndConf(opt, model, scoring, clonedImg, converter, np.array([gt])) + # To evaluate 'case sensitive model' with alphanumeric and case insensitve setting. 
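+ # Segments are zeroed out cumulatively from highest to lowest attribution, so entry k of
+ # n_correct / confidenceList corresponds to the top-(k+1) attributed segments removed.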
+ if opt.sensitive and opt.data_filtering_off: + pred = pred.lower() + gt = gt.lower() + alphanumeric_case_insensitve = '0123456789abcdefghijklmnopqrstuvwxyz' + out_of_alphanumeric_case_insensitve = f'[^{alphanumeric_case_insensitve}]' + pred = re.sub(out_of_alphanumeric_case_insensitve, '', pred) + gt = re.sub(out_of_alphanumeric_case_insensitve, '', gt) + if pred == gt: + n_correct.append(1) + else: + n_correct.append(0) + confScore = confScore[0][0]*100 + confidenceList.append(confScore) + return n_correct, confidenceList + +### Once you have the selectivity_eval_results.pkl file, +def acquire_selectivity_auc(opt, pkl_filename=None): + if pkl_filename is None: + pkl_filename = "/home/goo/str/str_vit_dataexplain_lambda/metrics_sensitivity_eval_results_CUTE80.pkl" # VITSTR + accKeys = [] + + with open(pkl_filename, 'rb') as f: + selectivity_data = pickle.load(f) + + for resDictIdx, resDict in enumerate(selectivity_data): + keylistAcc = [] + keylistConf = [] + metricsKeys = resDict.keys() + for keyStr in resDict.keys(): + if "_acc" in keyStr: keylistAcc.append(keyStr) + if "_conf" in keyStr: keylistConf.append(keyStr) + # Need to check if network correctly predicted the image + for metrics_accStr in keylistAcc: + if 1 not in resDict[metrics_accStr]: print("resDictIdx") + +### This acquires the attributes of the STR network on individual character levels, +### then averages them. +def acquireSingleCharAttrAve(opt): + ### targetDataset - one dataset only, CUTE80 has 288 samples + # 'IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', 'IC13_857', 'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80' + targetDataset = settings.TARGET_DATASET + segmRootDir = "{}/{}X{}/{}/".format(settings.SEGM_DIR, opt.imgH, opt.imgW, targetDataset) + outputSelectivityPkl = "strexp_ave_{}_{}.pkl".format(settings.MODEL, targetDataset) + outputDir = "./attributionImgs/{}/{}/".format(settings.MODEL, targetDataset) + attrOutputDir = "./attributionData/{}/{}/".format(settings.MODEL, targetDataset) + ### Set only one below to True to have enough GPU + acquireSelectivity = True + acquireInfidelity = False + acquireSensitivity = False ### GPU error + if not os.path.exists(outputDir): + os.makedirs(outputDir) + if not os.path.exists(attrOutputDir): + os.makedirs(attrOutputDir) + + converter = SRNConverter(opt.character, opt.SRN_PAD) + opt.num_class = len(converter.character) + length_for_pred = torch.cuda.IntTensor([opt.batch_max_length] * opt.batch_size) + + model = Model(opt) + print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel, + opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction, + opt.SequenceModeling, opt.Prediction) + model = torch.nn.DataParallel(model).cuda() + + # load model + print('loading pretrained model from %s' % opt.saved_model) + model.load_state_dict(torch.load(opt.saved_model)) + model = model.to(device) + model_obj = model + + modelCopy = copy.deepcopy(model) + + """ evaluation """ + scoring_singlechar = STRScore(opt=opt, converter=converter, device=device, enableSingleCharAttrAve=True, model=modelCopy) + super_pixel_model_singlechar = torch.nn.Sequential( + # super_pixler, + # numpy2torch_converter, + modelCopy, + scoring_singlechar + ).to(device) + modelCopy.eval() + scoring_singlechar.eval() + super_pixel_model_singlechar.eval() + + # Single Char Attribution Averaging + # enableSingleCharAttrAve - set to True + scoring = STRScore(opt=opt, converter=converter, device=device, model=model) + 
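+    # Wrapping the recognizer with STRScore collapses its sequence output into a
+    # confidence score (see --scorer: cumprod | mean), so the Captum attribution
+    # methods below can treat the STR pipeline like an ordinary classifier output.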
super_pixel_model = torch.nn.Sequential( + # super_pixler, + # numpy2torch_converter, + model, + scoring + ).to(device) + model.eval() + scoring.eval() + super_pixel_model.eval() + + # scoring_charContrib = STRScore(opt=opt, converter=converter, device=device, hasCharContrib=True) + # super_pixel_model_charContrib = torch.nn.Sequential( + # # super_pixler, + # # numpy2torch_converter, + # model, + # scoring_charContrib + # ).to(device) + # model.eval() + # scoring_charContrib.eval() + # super_pixel_model_charContrib.eval() + + shapImgLs = np.zeros(shape=(1, 1, 224, 224)).astype(np.float32) + trainList = np.array(shapImgLs) + background = torch.from_numpy(trainList).to(device) + + opt.eval = True + + # if opt.fast_acc: + # # # To easily compute the total accuracy of our paper. + # eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_867', 'IC13_1015', 'IC15_2077', 'SVTP', 'CUTE80'] + # else: + # # The evaluation datasets, dataset order is same with Table 1 in our paper. + # eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', 'IC13_857', + # 'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80'] + + # # To easily compute the total accuracy of our paper. + eval_data_list = [targetDataset] ### One dataset only + evaluation_batch_size = opt.batch_size + + selectivity_eval_results = [] + + testImgCount = 0 + list_accuracy = [] + total_forward_time = 0 + total_evaluation_data_number = 0 + total_correct_number = 0 + # log = open(f'./result/{opt.exp_name}/log_all_evaluation.txt', 'a') + # dashed_line = '-' * 80 + # print(dashed_line) + # log.write(dashed_line + '\n') + + segmentation_fn = SegmentationAlgorithm('quickshift', kernel_size=4, + max_dist=200, ratio=0.2, + random_seed=random.randint(0, 1000)) + + for eval_data in eval_data_list: + eval_data_path = os.path.join(opt.eval_data, eval_data) + AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD, opt=opt) + eval_data, eval_data_log = hierarchical_dataset(root=eval_data_path, opt=opt, segmRootDir=segmRootDir) + evaluation_loader = torch.utils.data.DataLoader( + eval_data, batch_size=1, + shuffle=False, + num_workers=int(opt.workers), + collate_fn=AlignCollate_evaluation, pin_memory=True) + testImgCount = 0 + + for i, (orig_img_tensors, segAndLabels) in enumerate(evaluation_loader): + # img_rgb *= 255.0 + # img_rgb = img_rgb.astype('int') + # print("img_rgb max: ", img_rgb.max()) ### 255 + # img_rgb = np.asarray(orig_img_tensors) + # segmentations = segmentation_fn(img_rgb) + # print("segmentations shape: ", segmentations.shape) # (224, 224) + # print("segmentations min: ", segmentations.min()) 0 + # print("Unique: ", len(np.unique(segmentations))) # (70) + results_dict = {} + aveAttr = [] + aveAttr_charContrib = [] + segmData, labels = segAndLabels[0] + target = converter.encode([labels]) + + # labels: RONALDO + segmDataNP = segmData["segdata"] + segmTensor = torch.from_numpy(segmDataNP).unsqueeze(0).unsqueeze(0) + # print("segmTensor min: ", segmTensor.min()) # 0 starting segmentation + segmTensor = segmTensor.to(device) + # print("segmTensor shape: ", segmTensor.shape) + # img1 = np.asarray(imgPIL.convert('L')) + # sys.exit() + # img1 = img1 / 255.0 + # img1 = torch.from_numpy(img1).unsqueeze(0).unsqueeze(0).type(torch.FloatTensor).to(device) + img1 = orig_img_tensors.to(device) + img1.requires_grad = True + bgImg = torch.zeros(img1.shape).to(device) + + ### Single char averaging + if settings.MODEL == 'vitstr': + charOffset = 1 + # img1 = transforms.Normalize(0.5, 0.5)(img1) # Between -1 to 
1 + elif settings.MODEL == 'srn': + charOffset = 0 # SRN has no 'GO' token + elif settings.MODEL == 'parseq': + target = target[:, 1:] # First position [GO] not used in parseq too. + # 0 index is [GO] char, not used in parseq, only the [EOS] which is in 1 index + target[target > 0] -= 1 + charOffset = 0 + img1 = transforms.Normalize(0.5, 0.5)(img1) # Between -1 to 1 + + # preds = model(img1, seqlen=converter.batch_max_length) + input = img1 + origImgNP = torch.clone(orig_img_tensors).detach().cpu().numpy()[0][0] # (1, 1, 224, 224) + origImgNP = gray2rgb(origImgNP) + ### Captum test + collectedAttributions = [] + for charIdx in range(0, len(labels)): + scoring_singlechar.setSingleCharOutput(charIdx + charOffset) + # print("charIdx + charOffset: ", charIdx + charOffset) + # print("target[0]: ", target[0]) + gtClassNum = target[0][0][charIdx + charOffset] + + ### Shapley Value Sampling + svs = ShapleyValueSampling(super_pixel_model_singlechar) + # attr = svs.attribute(input, target=0, n_samples=200) ### Individual pixels, too long to calculate + attributions = svs.attribute(input, target=gtClassNum, feature_mask=segmTensor) + collectedAttributions.append(attributions) + aveAttributions = torch.mean(torch.cat(collectedAttributions,dim=0), dim=0).unsqueeze(0) + rankedAttr = rankedAttributionsBySegm(aveAttributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_shapley_l.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_shapley_l.pkl', aveAttributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, aveAttributions, segmDataNP, modelCopy, converter, labels, scoring_singlechar) + results_dict["shapley_local_acc"] = n_correct + results_dict["shapley_local_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model_singlechar, perturb_fn, img1, aveAttributions).detach().cpu().numpy()) + results_dict["shapley_local_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(svs.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["shapley_local_sens"] = sens + + ### Shapley Value Sampling + svs = ShapleyValueSampling(super_pixel_model) + # attr = svs.attribute(input, target=0, n_samples=200) ### Individual pixels, too long to calculate + attributions = svs.attribute(input, target=0, feature_mask=segmTensor) + collectedAttributions.append(attributions) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_shapley.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_shapley.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, converter, labels, scoring) + results_dict["shapley_acc"] = n_correct + results_dict["shapley_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions).detach().cpu().numpy()) + results_dict["shapley_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(svs.attribute, img1, 
target=0).detach().cpu().numpy()) + results_dict["shapley_sens"] = sens + + ### Global + Local context + aveAttributions = torch.mean(torch.cat(collectedAttributions,dim=0), dim=0).unsqueeze(0) + rankedAttr = rankedAttributionsBySegm(aveAttributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_shapley_gl.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_shapley_gl.pkl', aveAttributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, aveAttributions, segmDataNP, modelCopy, converter, labels, scoring_singlechar) + results_dict["shapley_global_local_acc"] = n_correct + results_dict["shapley_global_local_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model_singlechar, perturb_fn, img1, aveAttributions).detach().cpu().numpy()) + results_dict["shapley_global_local_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(svs.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["shapley_global_local_sens"] = sens + + + ### BASELINE Evaluations + + ### Integrated Gradients + ig = IntegratedGradients(super_pixel_model) + attributions = ig.attribute(input, target=0) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_intgrad.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_intgrad.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, converter, labels, scoring) + results_dict["intgrad_acc"] = n_correct + results_dict["intgrad_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions).detach().cpu().numpy()) + results_dict["intgrad_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(ig.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["intgrad_sens"] = sens + + ### Gradient SHAP using zero-background + gs = GradientShap(super_pixel_model) + # We define a distribution of baselines and draw `n_samples` from that + # distribution in order to estimate the expectations of gradients across all baselines + channelDim = 3 if opt.rgb else 1 + baseline_dist = torch.zeros((1, channelDim, opt.imgH, opt.imgW)) + baseline_dist = baseline_dist.to(device) + attributions = gs.attribute(input, baselines=baseline_dist, target=0) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_gradshap.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_gradshap.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, converter, labels, scoring) + results_dict["gradshap_acc"] = n_correct + results_dict["gradshap_conf"] = confidenceList + if 
acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions).detach().cpu().numpy()) + results_dict["gradshap_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(gs.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["gradshap_sens"] = sens + + ### DeepLift using zero-background + dl = DeepLift(super_pixel_model) + attributions = dl.attribute(input, target=0) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_deeplift.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_deeplift.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, converter, labels, scoring) + results_dict["deeplift_acc"] = n_correct + results_dict["deeplift_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions).detach().cpu().numpy()) + results_dict["deeplift_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(dl.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["deeplift_sens"] = sens + + ### Saliency + saliency = Saliency(super_pixel_model) + attributions = saliency.attribute(input, target=0) ### target=class0 + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_saliency.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_saliency.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, converter, labels, scoring) + results_dict["saliency_acc"] = n_correct + results_dict["saliency_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions).detach().cpu().numpy()) + results_dict["saliency_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(saliency.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["saliency_sens"] = sens + + ### InputXGradient + input_x_gradient = InputXGradient(super_pixel_model) + attributions = input_x_gradient.attribute(input, target=0) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_inpxgrad.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_inpxgrad.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, converter, labels, scoring) + results_dict["inpxgrad_acc"] = n_correct + results_dict["inpxgrad_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions).detach().cpu().numpy()) + results_dict["inpxgrad_infid"] = infid + if acquireSensitivity: + sens = 
float(sensitivity_max(input_x_gradient.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["inpxgrad_sens"] = sens + + ### GuidedBackprop + gbp = GuidedBackprop(super_pixel_model) + attributions = gbp.attribute(input, target=0) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_guidedbp.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_guidedbp.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, converter, labels, scoring) + results_dict["guidedbp_acc"] = n_correct + results_dict["guidedbp_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions).detach().cpu().numpy()) + results_dict["guidedbp_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(gbp.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["guidedbp_sens"] = sens + + ### Deconvolution + deconv = Deconvolution(super_pixel_model) + attributions = deconv.attribute(input, target=0) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_deconv.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_deconv.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, converter, labels, scoring) + results_dict["deconv_acc"] = n_correct + results_dict["deconv_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions).detach().cpu().numpy()) + results_dict["deconv_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(deconv.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["deconv_sens"] = sens + + ### Feature ablator + ablator = FeatureAblation(super_pixel_model) + attributions = ablator.attribute(input, target=0, feature_mask=segmTensor) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_featablt.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_featablt.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, converter, labels, scoring) + results_dict["featablt_acc"] = n_correct + results_dict["featablt_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions).detach().cpu().numpy()) + results_dict["featablt_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(ablator.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["featablt_sens"] = sens + + ## LIME + interpretable_model = SkLearnRidge(alpha=1, fit_intercept=True) ### This is 
the default used by LIME + lime = Lime(super_pixel_model, interpretable_model=interpretable_model) + attributions = lime.attribute(input, target=0, feature_mask=segmTensor) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_lime.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_lime.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, converter, labels, scoring) + results_dict["lime_acc"] = n_correct + results_dict["lime_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions).detach().cpu().numpy()) + results_dict["lime_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(lime.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["lime_sens"] = sens + + ### KernelSHAP + ks = KernelShap(super_pixel_model) + attributions = ks.attribute(input, target=0, feature_mask=segmTensor) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_kernelshap.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_kernelshap.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, converter, labels, scoring) + results_dict["kernelshap_acc"] = n_correct + results_dict["kernelshap_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions).detach().cpu().numpy()) + results_dict["kernelshap_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(ks.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["kernelshap_sens"] = sens + + selectivity_eval_results.append(results_dict) + + with open(outputSelectivityPkl, 'wb') as f: + pickle.dump(selectivity_eval_results, f) + + testImgCount += 1 + print("testImgCount: ", testImgCount) + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--eval_data', default='/home/deepblue/deepbluetwo/chenjun/1_OCR/data/data_lmdb_release/evaluation', help='path to evaluation dataset') + parser.add_argument('--benchmark_all_eval', default=True, help='evaluate 10 benchmark evaluation datasets') + parser.add_argument('--workers', type=int, help='number of data loading workers', default=4) + parser.add_argument('--batch_size', type=int, default=64, help='input batch size') + parser.add_argument('--saved_model', default='./saved_models/None-ResNet-SRN-SRN-Seed666/iter_65000.pth', help="path to saved_model to evaluation") + """ Data processing """ + parser.add_argument('--scorer', type=str, default='mean', help='See STRScore: cumprod | mean') + parser.add_argument('--Transformer', action='store_true', help='Use end-to-end transformer') + parser.add_argument('--selective_sample_str', type=str, default='', help='If =='', only sample images with string matching this (see --sensitive for case sensitivity)') + 
parser.add_argument('--max_selective_list', type=int, default=-1, help='if selective sample list has elements greater than this, autoclear list for batch selection') + parser.add_argument('--batch_max_length', type=int, default=25, help='maximum-label-length') + parser.add_argument('--imgH', type=int, default=32, help='the height of the input image') + parser.add_argument('--imgW', type=int, default=100, help='the width of the input image') + parser.add_argument('--rgb', action='store_true', help='use rgb input') + parser.add_argument('--character', type=str, default='0123456789abcdefghijklmnopqrstuvwxyz$#', help='character label') + parser.add_argument('--sensitive', action='store_true', help='for sensitive character mode') + parser.add_argument('--PAD', action='store_true', help='whether to keep ratio then pad for image resize') + parser.add_argument('--data_filtering_off', action='store_true', help='for data_filtering_off mode') + """ Model Architecture """ + parser.add_argument('--confidence_mode', type=int, default=0, help='0-sum of argmax; 1-edit distance') + parser.add_argument('--Transformation', type=str, default='None', help='Transformation stage. None|TPS') + parser.add_argument('--FeatureExtraction', type=str, default='ResNet', help='FeatureExtraction stage. VGG|RCNN|ResNet|AsterRes') + parser.add_argument('--SequenceModeling', type=str, default='SRN', help='SequenceModeling stage. None|BiLSTM|Bert') + parser.add_argument('--Prediction', type=str, default='SRN', help='Prediction stage. CTC|Attn|Bert_pred') + parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN') + parser.add_argument('--input_channel', type=int, default=1, help='the number of input channel of Feature extractor') + parser.add_argument('--output_channel', type=int, default=512, + help='the number of output channel of Feature extractor') + parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state') + parser.add_argument('--position_dim', type=int, default=26, help='the length sequence out from cnn encoder,resnet:65;resnetfpn:256') + + parser.add_argument('--SRN_PAD', type=int, default=36, help='the pad character for srn') + parser.add_argument('--batch_max_character', type=int, default=25, help='the max sequence length') + opt = parser.parse_args() + + """ vocab / character number configuration """ + if opt.sensitive: + opt.character = string.printable[:-6] # same with ASTER setting (use 94 char). 
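+        # string.printable[:-6] drops the 6 trailing whitespace characters, leaving the
+        # 94 printable ASCII characters used for the case-sensitive (ASTER) setting,
+        # i.e. len(string.printable[:-6]) == 94.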
+ + opt.alphabet_size = len(opt.character) # + opt.SRN_PAD = len(opt.character)-1 + cudnn.benchmark = True + cudnn.deterministic = True + opt.num_gpu = torch.cuda.device_count() + + # combineBestDataXAI(opt) + # acquire_average_auc(opt) + # acquireListOfAveAUC(opt) + acquireSingleCharAttrAve(opt) diff --git a/captum_improve_trba.py b/captum_improve_trba.py new file mode 100644 index 0000000000000000000000000000000000000000..750bd205217b1e2d5fa03be1f3ecb7f7ccc66896 --- /dev/null +++ b/captum_improve_trba.py @@ -0,0 +1,941 @@ +import os +import time +import string +import argparse +import re +import sys +import random +import pickle + +import torch +import torch.backends.cudnn as cudnn +import torch.utils.data +import torch.nn.functional as F +from torch.autograd import Variable +import numpy as np +from skimage.color import gray2rgb +from nltk.metrics.distance import edit_distance +import cv2 + +from utils import CTCLabelConverter, AttnLabelConverter, Averager +from dataset_trba import hierarchical_dataset, AlignCollate +from model_trba import Model, SuperPixler, CastNumpy, STRScore +# import hiddenlayer as hl +from lime import lime_image +from lime.wrappers.scikit_image import SegmentationAlgorithm +import matplotlib.pyplot as plt +import random +from captum._utils.models.linear_model import SkLearnLinearModel, SkLearnRidge +import statistics +import settings +import sys +import copy +from captum_test import acquire_average_auc, saveAttrData +from captum_improve_vitstr import rankedAttributionsBySegm +from matplotlib import pyplot as plt +from captum.attr._utils.visualization import visualize_image_attr + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +from captum.attr import ( + GradientShap, + DeepLift, + DeepLiftShap, + IntegratedGradients, + LayerConductance, + NeuronConductance, + NoiseTunnel, + Saliency, + InputXGradient, + GuidedBackprop, + Deconvolution, + GuidedGradCam, + FeatureAblation, + ShapleyValueSampling, + Lime, + KernelShap +) + +from captum.metrics import ( + infidelity, + sensitivity_max +) + +def getPredAndConf(opt, model, scoring, image, converter, labels): + batch_size = image.size(0) + length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device) + text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device) + text_for_loss, length_for_loss = converter.encode(labels, batch_max_length=opt.batch_max_length) + if 'CTC' in opt.Prediction: + preds = model(image, text_for_pred) + + confScore = scoring(preds) + confScore = confScore.detach().cpu().numpy() + + # Calculate evaluation loss for CTC deocder. 
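+            # Greedy CTC decoding: take the argmax class at each time step; the converter
+            # then collapses repeated symbols and blanks into the predicted string.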
+ preds_size = torch.IntTensor([preds.size(1)] * batch_size) + + # Select max probabilty (greedy decoding) then decode index to character + if opt.baiduCTC: + _, preds_index = preds.max(2) + preds_index = preds_index.view(-1) + else: + _, preds_index = preds.max(2) + preds_str = converter.decode(preds_index.data, preds_size.data)[0] + else: + preds = model(image, text_for_pred, is_train=False) + + confScore = scoring(preds) + confScore = confScore.detach().cpu().numpy() + + preds = preds[:, :text_for_loss.shape[1] - 1, :] + target = text_for_loss[:, 1:] # without [GO] Symbol + # cost = criterion(preds.contiguous().view(-1, preds.shape[-1]), target.contiguous().view(-1)) + + # select max probabilty (greedy decoding) then decode index to character + _, preds_index = preds.max(2) + preds_str = converter.decode(preds_index, length_for_pred) + + ### Remove all chars after '[s]' + preds_str = preds_str[0] + preds_str = preds_str[:preds_str.find('[s]')] + # pred = pred[:pred_EOS] + return preds_str, confScore + +### Output and save segmentations only for one dataset only +def outputSegmOnly(opt): + ### targetDataset - one dataset only, SVTP-645, CUTE80-288images + targetDataset = "CUTE80" # ['IIIT5k_3000', 'SVT', 'IC03_867', 'IC13_1015', 'IC15_2077', 'SVTP', 'CUTE80'] + targetHeight = 32 + targetWidth = 100 + segmRootDir = "/home/uclpc1/Documents/STR/datasets/segmen"\ + "tations/{}X{}/{}/".format(targetHeight, targetWidth, targetDataset) + + if not os.path.exists(segmRootDir): + os.makedirs(segmRootDir) + + opt.eval = True + ### Only IIIT5k_3000 + eval_data_list = [targetDataset] + target_output_orig = opt.outputOrigDir + + ### Taken from LIME + segmentation_fn = SegmentationAlgorithm('quickshift', kernel_size=4, + max_dist=200, ratio=0.2, + random_seed=random.randint(0, 1000)) + for eval_data in eval_data_list: + eval_data_path = os.path.join(opt.eval_data, eval_data) + AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) + eval_data, eval_data_log = hierarchical_dataset(root=eval_data_path, opt=opt, targetDir=target_output_orig) + evaluation_loader = torch.utils.data.DataLoader( + eval_data, batch_size=1, + shuffle=False, + num_workers=int(opt.workers), + collate_fn=AlignCollate_evaluation, pin_memory=True) + for i, (image_tensors, labels) in enumerate(evaluation_loader): + image_tensors = ((image_tensors + 1.0) / 2.0) * 255.0 + imgDataDict = {} + img_numpy = image_tensors.cpu().detach().numpy()[0] ### Need to set batch size to 1 only + if img_numpy.shape[0] == 1: + img_numpy = gray2rgb(img_numpy[0]) + # print("img_numpy shape: ", img_numpy.shape) # (32,100,3) + segmOutput = segmentation_fn(img_numpy) + # print("segmOutput unique: ", len(np.unique(segmOutput))) + imgDataDict['segdata'] = segmOutput + imgDataDict['label'] = labels[0] + outputPickleFile = segmRootDir + "{}.pkl".format(i) + with open(outputPickleFile, 'wb') as f: + pickle.dump(imgDataDict, f) + +### Returns the mean for each segmentation having shape as the same as the input +### This function can only one attribution image at a time +def averageSegmentsOut(attr, segments): + averagedInput = torch.clone(attr) + sortedDict = {} + for x in np.unique(segments): + segmentMean = torch.mean(attr[segments == x][:]) + sortedDict[x] = float(segmentMean.detach().cpu().numpy()) + averagedInput[segments == x] = segmentMean + return averagedInput, sortedDict + +def acquireSelectivityHit(origImg, attributions, segmentations, model, converter, labels, scoring): + # print("segmentations unique len: ", 
np.unique(segmentations)) + aveSegmentations, sortedDict = averageSegmentsOut(attributions[0,0], segmentations) + sortedKeys = [k for k, v in sorted(sortedDict.items(), key=lambda item: item[1])] + sortedKeys = sortedKeys[::-1] ### A list that should contain largest to smallest score + # print("sortedDict: ", sortedDict) # {0: -5.51e-06, 1: -1.469e-05, 2: -3.06e-05,...} + # print("aveSegmentations unique len: ", np.unique(aveSegmentations)) + # print("aveSegmentations device: ", aveSegmentations.device) # cuda:0 + # print("aveSegmentations shape: ", aveSegmentations.shape) # (224,224) + # print("aveSegmentations: ", aveSegmentations) + + n_correct = [] + confidenceList = [] # First index is one feature removed, second index two features removed, and so on... + clonedImg = torch.clone(origImg) + gt = str(labels[0]) + for totalSegToHide in range(0, len(sortedKeys)): + ### Acquire LIME prediction result + currentSegmentToHide = sortedKeys[totalSegToHide] + clonedImg[0,0][segmentations == currentSegmentToHide] = 0.0 + pred, confScore = getPredAndConf(opt, model, scoring, clonedImg, converter, np.array([gt])) + # To evaluate 'case sensitive model' with alphanumeric and case insensitve setting. + if opt.sensitive and opt.data_filtering_off: + pred = pred.lower() + gt = gt.lower() + alphanumeric_case_insensitve = '0123456789abcdefghijklmnopqrstuvwxyz' + out_of_alphanumeric_case_insensitve = f'[^{alphanumeric_case_insensitve}]' + pred = re.sub(out_of_alphanumeric_case_insensitve, '', pred) + gt = re.sub(out_of_alphanumeric_case_insensitve, '', gt) + if pred == gt: + n_correct.append(1) + else: + n_correct.append(0) + confScore = confScore[0][0]*100 + confidenceList.append(confScore) + return n_correct, confidenceList + +def main(opt): + # 'IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', 'IC13_857', 'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80' + datasetName = settings.TARGET_DATASET + custom_segm_dataroot = "{}/{}X{}/{}/".format(settings.SEGM_DIR, opt.imgH, opt.imgW, datasetName) + outputSelectivityPkl = "strexp_ave_{}_{}.pkl".format(settings.MODEL, datasetName) + outputDir = "./attributionImgs/{}/{}/".format(settings.MODEL, datasetName) + attrOutputDir = "./attributionData/{}/{}/".format(settings.MODEL, datasetName) + acquireSelectivity = True + acquireInfidelity = False + acquireSensitivity = False ### GPU error + imgHeight = 32 + imgWidth = 100 + if not os.path.exists(outputDir): + os.makedirs(outputDir) + if not os.path.exists(attrOutputDir): + os.makedirs(attrOutputDir) + + """ model configuration """ + if 'CTC' in opt.Prediction: + converter = CTCLabelConverter(opt.character) + else: + converter = AttnLabelConverter(opt.character) + opt.num_class = len(converter.character) + + if opt.rgb: + opt.input_channel = 3 + model_obj = Model(opt, device) + print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel, + opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction, + opt.SequenceModeling, opt.Prediction) + model = torch.nn.DataParallel(model_obj).to(device) + + # load model + print('loading pretrained model from %s' % opt.saved_model) + model.load_state_dict(torch.load(opt.saved_model, map_location=device)) + opt.exp_name = '_'.join(opt.saved_model.split('/')[1:]) + + modelCopy = copy.deepcopy(model) + scoring_singlechar = STRScore(opt=opt, converter=converter, device=device, enableSingleCharAttrAve=True) + super_pixel_model_singlechar = torch.nn.Sequential( + # super_pixler, + # 
numpy2torch_converter, + modelCopy, + scoring_singlechar + ).to(device) + modelCopy.train() + scoring_singlechar.train() + super_pixel_model_singlechar.train() + + scoring = STRScore(opt=opt, converter=converter, device=device) + super_pixel_model = torch.nn.Sequential( + model, + scoring + ) + model.train() + scoring.train() + super_pixel_model.train() + + """ keep evaluation model and result logs """ + os.makedirs(f'./result/{opt.exp_name}', exist_ok=True) + os.system(f'cp {opt.saved_model} ./result/{opt.exp_name}/') + + """ setup loss """ + if 'CTC' in opt.Prediction: + criterion = torch.nn.CTCLoss(zero_infinity=True).to(device) + else: + criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) # ignore [GO] token = ignore index 0 + + """Output shap values""" + """ evaluation with 10 benchmark evaluation datasets """ + # The evaluation datasets, dataset order is same with Table 1 in our paper. + # eval_data_list = ['IIIT5k_3000', 'IC03_860', 'IC03_867', 'IC15_1811'] + target_output_orig = opt.outputOrigDir + # eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', 'IC13_857', + # 'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80'] + # eval_data_list = ['IIIT5k_3000'] + eval_data_list = [datasetName] + # # To easily compute the total accuracy of our paper. + # eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_867', + # 'IC13_1015', 'IC15_2077', 'SVTP', 'CUTE80'] + + list_accuracy = [] + total_forward_time = 0 + total_evaluation_data_number = 0 + total_correct_number = 0 + log = open(f'./result/{opt.exp_name}/log_all_evaluation.txt', 'a') + dashed_line = '-' * 80 + print(dashed_line) + log.write(dashed_line + '\n') + + selectivity_eval_results = [] + imageData = [] + targetText = "all" + middleMaskThreshold = 5 + testImgCount = 0 + imgResultDir = str(opt.Transformation) + "-" + str(opt.FeatureExtraction) + "-" + str(opt.SequenceModeling) + "-" + str(opt.Prediction) + "-" + str(opt.scorer) + + # define a perturbation function for the input (used for calculating infidelity) + def perturb_fn(modelInputs): + noise = torch.tensor(np.random.normal(0, 0.003, modelInputs.shape)).float() + noise = noise.to(device) + return noise, modelInputs - noise + + if opt.blackbg: + shapImgLs = np.zeros(shape=(1, 1, 32, 100)).astype(np.float32) + trainList = np.array(shapImgLs) + background = torch.from_numpy(trainList).to(device) + if imgResultDir != "": + if not os.path.exists(imgResultDir): + os.makedirs(imgResultDir) + for eval_data in eval_data_list: + eval_data_path = os.path.join(opt.eval_data, eval_data) + AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) + eval_data, eval_data_log = hierarchical_dataset(root=eval_data_path, opt=opt, targetDir=target_output_orig) + evaluation_loader = torch.utils.data.DataLoader( + eval_data, batch_size=1, + shuffle=False, + num_workers=int(opt.workers), + collate_fn=AlignCollate_evaluation, pin_memory=True) + # image_tensors, labels = next(iter(evaluation_loader)) ### Iterate one batch only + for i, (orig_img_tensors, labels) in enumerate(evaluation_loader): + # img_rgb *= 255.0 + # img_rgb = img_rgb.astype('int') + # print("img_rgb max: ", img_rgb.max()) ### 255 + # img_rgb = np.asarray(orig_img_tensors) + # segmentations = segmentation_fn(img_rgb) + # print("segmentations shape: ", segmentations.shape) # (224, 224) + # print("segmentations min: ", segmentations.min()) 0 + # print("Unique: ", len(np.unique(segmentations))) # (70) + # print("target: ", target) tensor([[ 0, 29, 26, 25, 12 + results_dict 
= {} + pklFilename = custom_segm_dataroot + "{}.pkl".format(i) + with open(pklFilename, 'rb') as f: + pklData = pickle.load(f) + segmDataNP = pklData["segdata"] + # print("segmDataNP unique: ", len(np.unique(segmDataNP))) + assert pklData["label"] == labels[0] + segmTensor = torch.from_numpy(segmDataNP).unsqueeze(0).unsqueeze(0) + # print("segmTensor min: ", segmTensor.min()) # 0 starting segmentation + segmTensor = segmTensor.to(device) + # print("segmTensor shape: ", segmTensor.shape) + # img1 = np.asarray(imgPIL.convert('L')) + # sys.exit() + # img1 = img1 / 255.0 + # img1 = torch.from_numpy(img1).unsqueeze(0).unsqueeze(0).type(torch.FloatTensor).to(device) + img1 = orig_img_tensors.to(device) + img1.requires_grad = True + bgImg = torch.zeros(img1.shape).to(device) + # preds = model(img1, seqlen=converter.batch_max_length) + target = converter.encode(labels) + target = target[0][:, 1:] + charOffset = 0 + input = img1 + origImgNP = torch.clone(orig_img_tensors).detach().cpu().numpy()[0][0] # (1, 1, 224, 224) + origImgNP = gray2rgb(origImgNP) + + # preds = model(input) + # preds_prob = F.softmax(preds, dim=2) + # preds_max_prob, preds_max_idx = preds_prob.max(dim=2) + # print("preds_max_idx: ", preds_max_idx) tensor([[14, 26, 25, 12 + + ### Captum test + collectedAttributions = [] + for charIdx in range(0, len(labels)): + scoring_singlechar.setSingleCharOutput(charIdx + charOffset) + gtClassNum = target[0][charIdx + charOffset] + + ### Shapley Value Sampling + svs = ShapleyValueSampling(super_pixel_model_singlechar) + # attr = svs.attribute(input, target=0, n_samples=200) ### Individual pixels, too long to calculate + attributions = svs.attribute(input, target=gtClassNum, feature_mask=segmTensor) + collectedAttributions.append(attributions) + aveAttributions = torch.mean(torch.cat(collectedAttributions,dim=0), dim=0).unsqueeze(0) + rankedAttr = rankedAttributionsBySegm(aveAttributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_shapley_l.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_shapley_l.pkl', aveAttributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, aveAttributions, segmDataNP, modelCopy, converter, labels, scoring_singlechar) + results_dict["shapley_local_acc"] = n_correct + results_dict["shapley_local_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model_singlechar, perturb_fn, img1, aveAttributions).detach().cpu().numpy()) + results_dict["shapley_local_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(svs.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["shapley_local_sens"] = sens + + ### Shapley Value Sampling + svs = ShapleyValueSampling(super_pixel_model) + # attr = svs.attribute(input, target=0, n_samples=200) ### Individual pixels, too long to calculate + attributions = svs.attribute(input, target=0, feature_mask=segmTensor) + collectedAttributions.append(attributions) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_shapley.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + 
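+            # saveAttrData pickles the raw attribution together with the segmentation and
+            # the original image, so the blended heat maps can be regenerated later without
+            # rerunning attribution.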
saveAttrData(attrOutputDir + f'{i}_shapley.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, converter, labels, scoring) + results_dict["shapley_acc"] = n_correct + results_dict["shapley_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions, normalize=True).detach().cpu().numpy()) + results_dict["shapley_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(svs.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["shapley_sens"] = sens + + ### Global + Local context + aveAttributions = torch.mean(torch.cat(collectedAttributions,dim=0), dim=0).unsqueeze(0) + rankedAttr = rankedAttributionsBySegm(aveAttributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_shapley_gl.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_shapley_gl.pkl', aveAttributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, aveAttributions, segmDataNP, modelCopy, converter, labels, scoring_singlechar) + results_dict["shapley_global_local_acc"] = n_correct + results_dict["shapley_global_local_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model_singlechar, perturb_fn, img1, aveAttributions).detach().cpu().numpy()) + results_dict["shapley_global_local_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(svs.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["shapley_global_local_sens"] = sens + + + # Baselines + ### Integrated Gradients + ig = IntegratedGradients(super_pixel_model) + attributions = ig.attribute(input, target=0) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_intgrad.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_intgrad.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, converter, labels, scoring) + results_dict["intgrad_acc"] = n_correct + results_dict["intgrad_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions, normalize=True).detach().cpu().numpy()) + results_dict["intgrad_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(ig.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["intgrad_sens"] = sens + + ### Gradient SHAP using zero-background + gs = GradientShap(super_pixel_model) + # We define a distribution of baselines and draw `n_samples` from that + # distribution in order to estimate the expectations of gradients across all baselines + baseline_dist = torch.zeros((1, 1, imgHeight, imgWidth)) + baseline_dist = baseline_dist.to(device) + attributions = gs.attribute(input, baselines=baseline_dist, target=0) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = 
gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_gradshap.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_gradshap.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, converter, labels, scoring) + results_dict["gradshap_acc"] = n_correct + results_dict["gradshap_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions, normalize=True).detach().cpu().numpy()) + results_dict["gradshap_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(gs.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["gradshap_sens"] = sens + + ### DeepLift using zero-background + dl = DeepLift(super_pixel_model) + attributions = dl.attribute(input, target=0) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_deeplift.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_deeplift.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, converter, labels, scoring) + results_dict["deeplift_acc"] = n_correct + results_dict["deeplift_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions, normalize=True).detach().cpu().numpy()) + results_dict["deeplift_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(dl.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["deeplift_sens"] = sens + + ### Saliency + saliency = Saliency(super_pixel_model) + attributions = saliency.attribute(input, target=0) ### target=class0 + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_saliency.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_saliency.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, converter, labels, scoring) + results_dict["saliency_acc"] = n_correct + results_dict["saliency_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions, normalize=True).detach().cpu().numpy()) + results_dict["saliency_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(saliency.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["saliency_sens"] = sens + + ### InputXGradient + input_x_gradient = InputXGradient(super_pixel_model) + attributions = input_x_gradient.attribute(input, target=0) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + 
'{}_inpxgrad.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_inpxgrad.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, converter, labels, scoring) + results_dict["inpxgrad_acc"] = n_correct + results_dict["inpxgrad_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions, normalize=True).detach().cpu().numpy()) + results_dict["inpxgrad_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(input_x_gradient.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["inpxgrad_sens"] = sens + + ## GuidedBackprop + gbp = GuidedBackprop(super_pixel_model) + attributions = gbp.attribute(input, target=0) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_guidedbp.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_guidedbp.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, converter, labels, scoring) + results_dict["guidedbp_acc"] = n_correct + results_dict["guidedbp_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions, normalize=True).detach().cpu().numpy()) + results_dict["guidedbp_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(gbp.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["guidedbp_sens"] = sens + # + # ## Deconvolution + deconv = Deconvolution(super_pixel_model) + attributions = deconv.attribute(input, target=0) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_deconv.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_deconv.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, converter, labels, scoring) + results_dict["deconv_acc"] = n_correct + results_dict["deconv_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions, normalize=True).detach().cpu().numpy()) + results_dict["deconv_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(deconv.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["deconv_sens"] = sens + + ### Feature ablator + ablator = FeatureAblation(super_pixel_model) + attributions = ablator.attribute(input, target=0, feature_mask=segmTensor) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_featablt.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_featablt.pkl', attributions, segmDataNP, origImgNP) + 
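+            # Selectivity (here and for every method above): acquireSelectivityHit masks out
+            # superpixels from highest- to lowest-ranked attribution and records how quickly
+            # accuracy and confidence drop as features are removed.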
if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, converter, labels, scoring) + results_dict["featablt_acc"] = n_correct + results_dict["featablt_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions, normalize=True).detach().cpu().numpy()) + results_dict["featablt_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(ablator.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["featablt_sens"] = sens + + ## LIME + interpretable_model = SkLearnRidge(alpha=1, fit_intercept=True) ### This is the default used by LIME + lime = Lime(super_pixel_model, interpretable_model=interpretable_model) + attributions = lime.attribute(input, target=0, feature_mask=segmTensor) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_lime.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_lime.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, converter, labels, scoring) + results_dict["lime_acc"] = n_correct + results_dict["lime_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions, normalize=True).detach().cpu().numpy()) + results_dict["lime_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(lime.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["lime_sens"] = sens + + ### KernelSHAP + ks = KernelShap(super_pixel_model) + attributions = ks.attribute(input, target=0, feature_mask=segmTensor) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_kernelshap.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_kernelshap.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, converter, labels, scoring) + results_dict["kernelshap_acc"] = n_correct + results_dict["kernelshap_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions, normalize=True).detach().cpu().numpy()) + results_dict["kernelshap_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(ks.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["kernelshap_sens"] = sens + + selectivity_eval_results.append(results_dict) + + with open(outputSelectivityPkl, 'wb') as f: + pickle.dump(selectivity_eval_results, f) + + testImgCount += 1 + print("testImgCount: ", testImgCount) + +def outputOrigImagesOnly(opt): + datasetName = "CUTE80" # ['IIIT5k_3000', 'SVT', 'IC03_867', 'IC13_1015', 'IC15_2077', 'SVTP', 'CUTE80'] + opt.outputOrigDir = "./datasetOrigImgs/{}/".format(datasetName) + opt.output_orig = True + opt.corruption_num = 0 + opt.apply_corruptions = False + opt.min_imgnum = 0 + opt.max_imgnum = 1000 + + target_output_orig = opt.outputOrigDir + if not 
os.path.exists(target_output_orig): + os.makedirs(target_output_orig) + + """ model configuration """ + if 'CTC' in opt.Prediction: + converter = CTCLabelConverter(opt.character) + else: + converter = AttnLabelConverter(opt.character) + opt.num_class = len(converter.character) + + if opt.rgb: + opt.input_channel = 3 + model_obj = Model(opt, device) + print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel, + opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction, + opt.SequenceModeling, opt.Prediction) + model = torch.nn.DataParallel(model_obj).to(device) + + # load model + print('loading pretrained model from %s' % opt.saved_model) + model.load_state_dict(torch.load(opt.saved_model, map_location=device)) + opt.exp_name = '_'.join(opt.saved_model.split('/')[1:]) + scoring = STRScore(opt=opt, converter=converter, device=device) + ### + + super_pixel_model = torch.nn.Sequential( + model, + scoring + ) + model.train() + scoring.train() + super_pixel_model.train() + # print(model) + + """ keep evaluation model and result logs """ + os.makedirs(f'./result/{opt.exp_name}', exist_ok=True) + os.system(f'cp {opt.saved_model} ./result/{opt.exp_name}/') + + """ setup loss """ + if 'CTC' in opt.Prediction: + criterion = torch.nn.CTCLoss(zero_infinity=True).to(device) + else: + criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) # ignore [GO] token = ignore index 0 + + """Output shap values""" + """ evaluation with 10 benchmark evaluation datasets """ + # The evaluation datasets, dataset order is same with Table 1 in our paper. + # eval_data_list = ['IIIT5k_3000', 'IC03_860', 'IC03_867', 'IC15_1811'] + # eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', 'IC13_857', + # 'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80'] + # eval_data_list = ['IIIT5k_3000'] + eval_data_list = [datasetName] + # # To easily compute the total accuracy of our paper. 
+ # eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_867', + # 'IC13_1015', 'IC15_2077', 'SVTP', 'CUTE80'] + + list_accuracy = [] + total_forward_time = 0 + total_evaluation_data_number = 0 + total_correct_number = 0 + log = open(f'./result/{opt.exp_name}/log_all_evaluation.txt', 'a') + dashed_line = '-' * 80 + print(dashed_line) + log.write(dashed_line + '\n') + + selectivity_eval_results = [] + imageData = [] + targetText = "all" + middleMaskThreshold = 5 + testImgCount = 0 + imgResultDir = str(opt.Transformation) + "-" + str(opt.FeatureExtraction) + "-" + str(opt.SequenceModeling) + "-" + str(opt.Prediction) + "-" + str(opt.scorer) + + if opt.blackbg: + shapImgLs = np.zeros(shape=(1, 1, 32, 100)).astype(np.float32) + trainList = np.array(shapImgLs) + background = torch.from_numpy(trainList).to(device) + if imgResultDir != "": + if not os.path.exists(imgResultDir): + os.makedirs(imgResultDir) + for eval_data in eval_data_list: + eval_data_path = os.path.join(opt.eval_data, eval_data) + AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) + eval_data, eval_data_log = hierarchical_dataset(root=eval_data_path, opt=opt, targetDir=target_output_orig) + evaluation_loader = torch.utils.data.DataLoader( + eval_data, batch_size=1, + shuffle=False, + num_workers=int(opt.workers), + collate_fn=AlignCollate_evaluation, pin_memory=True) + # image_tensors, labels = next(iter(evaluation_loader)) ### Iterate one batch only + for i, (orig_img_tensors, labels) in enumerate(evaluation_loader): + testImgCount += 1 + print("testImgCount: ", testImgCount) + +### Use to check if the model predicted the image or not. Output a pickle file with the image index. +def modelDatasetPredOnly(opt): + ### targetDataset - one dataset only, CUTE80 has 288 samples + targetDataset = "CUTE80" # ['IIIT5k_3000', 'SVT', 'IC03_867', 'IC13_1015', 'IC15_2077', 'SVTP', 'CUTE80'] + outputSelectivityPkl = "metrics_predictonly_results_{}.pkl".format(targetDataset) + start_time = time.time() + + """ model configuration """ + if 'CTC' in opt.Prediction: + converter = CTCLabelConverter(opt.character) + else: + converter = AttnLabelConverter(opt.character) + opt.num_class = len(converter.character) + + if opt.rgb: + opt.input_channel = 3 + model_obj = Model(opt, device) + print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel, + opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction, + opt.SequenceModeling, opt.Prediction) + model = torch.nn.DataParallel(model_obj).to(device) + + # load model + print('loading pretrained model from %s' % opt.saved_model) + model.load_state_dict(torch.load(opt.saved_model, map_location=device)) + opt.exp_name = '_'.join(opt.saved_model.split('/')[1:]) + scoring = STRScore(opt=opt, converter=converter, device=device) + ### + + super_pixel_model = torch.nn.Sequential( + model, + scoring + ) + model.train() + scoring.train() + super_pixel_model.train() + + if opt.blackbg: + shapImgLs = np.zeros(shape=(1, 1, 224, 224)).astype(np.float32) + trainList = np.array(shapImgLs) + background = torch.from_numpy(trainList).to(device) + + opt.eval = True + eval_data_list = [targetDataset] + + testImgCount = 0 + list_accuracy = [] + total_forward_time = 0 + total_evaluation_data_number = 0 + total_correct_number = 0 + log = open(f'./result/{opt.exp_name}/log_all_evaluation.txt', 'a') + dashed_line = '-' * 80 + print(dashed_line) + log.write(dashed_line + '\n') + target_output_orig = 
opt.outputOrigDir + predOutput = [] + for eval_data in eval_data_list: + eval_data_path = os.path.join(opt.eval_data, eval_data) + AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) + eval_data, eval_data_log = hierarchical_dataset(root=eval_data_path, opt=opt, targetDir=target_output_orig) + evaluation_loader = torch.utils.data.DataLoader( + eval_data, batch_size=1, + shuffle=False, + num_workers=int(opt.workers), + collate_fn=AlignCollate_evaluation, pin_memory=True) + testImgCount = 0 + for i, (orig_img_tensors, labels) in enumerate(evaluation_loader): + image = orig_img_tensors.to(device) + batch_size = 1 + length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device) + text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device) + text_for_loss, length_for_loss = converter.encode(labels, batch_max_length=opt.batch_max_length) + if 'CTC' in opt.Prediction: + preds = model(image, text_for_pred) + + confScore = scoring(preds) + confScore = confScore.detach().cpu().numpy() + + # Calculate evaluation loss for CTC deocder. + preds_size = torch.IntTensor([preds.size(1)] * batch_size) + + # Select max probabilty (greedy decoding) then decode index to character + if opt.baiduCTC: + _, preds_index = preds.max(2) + preds_index = preds_index.view(-1) + else: + _, preds_index = preds.max(2) + preds_str = converter.decode(preds_index.data, preds_size.data)[0] + else: + preds = model(image, text_for_pred, is_train=False) + + confScore = scoring(preds) + confScore = confScore.detach().cpu().numpy() + + preds = preds[:, :text_for_loss.shape[1] - 1, :] + target = text_for_loss[:, 1:] # without [GO] Symbol + # cost = criterion(preds.contiguous().view(-1, preds.shape[-1]), target.contiguous().view(-1)) + + # select max probabilty (greedy decoding) then decode index to character + _, preds_index = preds.max(2) + preds_str = converter.decode(preds_index, length_for_pred) + + ### Remove all chars after '[s]' + preds_str = preds_str[0] + preds_str = preds_str[:preds_str.find('[s]')] + # print("preds_str: ", preds_str) # lowercased prediction + # print("labels: ", labels[0]) # gt already in lowercased + if preds_str==labels[0]: predOutput.append(1) + else: predOutput.append(0) + + with open(outputSelectivityPkl, 'wb') as f: + pickle.dump(predOutput, f) + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--eval_data', required=True, help='path to evaluation dataset') + parser.add_argument('--benchmark_all_eval', action='store_true', help='evaluate 10 benchmark evaluation datasets') + parser.add_argument('--workers', type=int, help='number of data loading workers', default=4) + parser.add_argument('--batch_size', type=int, default=192, help='input batch size') + parser.add_argument('--saved_model', required=True, help="path to saved_model to evaluation") + """ Data processing """ + parser.add_argument('--batch_max_length', type=int, default=25, help='maximum-label-length') + parser.add_argument('--imgH', type=int, default=32, help='the height of the input image') + parser.add_argument('--imgW', type=int, default=100, help='the width of the input image') + parser.add_argument('--superHeight', type=int, default=5, help='the height of the superpixel') + parser.add_argument('--superWidth', type=int, default=2, help='the width of the superpixel') + parser.add_argument('--min_imgnum', type=int, default=0, help='set this to skip for loop index of specific image number') + 
parser.add_argument('--max_imgnum', type=int, default=2, help='set this to skip for loop index of specific image number') + parser.add_argument('--severity', type=int, default=1, help='severity level if apply corruptions') + parser.add_argument('--scorer', type=str, default='cumprod', help='See STRScore: cumprod | mean') + parser.add_argument('--corruption_num', type=int, default=0, help='corruption to apply') + parser.add_argument('--confidence_mode', type=int, default=0, help='0-sum of argmax; 1-edit distance') + parser.add_argument('--outputOrigDir', type=str, default="output_orig/", help='output directory to save original \ + images. This will be automatically created. Needs --output_orig too.') + parser.add_argument('--output_orig', action='store_true', help='if true, output first original rgb image of each batch') + parser.add_argument('--compare_corrupt', action='store_true', help='set to true to output results across corruptions') + parser.add_argument('--is_shap', action='store_true', help='no need to call in command line') + parser.add_argument('--blackbg', action='store_true', help='if True, background color for covering features will be black(0)') + parser.add_argument('--rgb', action='store_true', help='use rgb input') + parser.add_argument('--character', type=str, default='0123456789abcdefghijklmnopqrstuvwxyz', help='character label') + parser.add_argument('--sensitive', action='store_true', help='for sensitive character mode') + parser.add_argument('--PAD', action='store_true', help='whether to keep ratio then pad for image resize') + parser.add_argument('--data_filtering_off', action='store_true', help='for data_filtering_off mode') + parser.add_argument('--apply_corruptions', action='store_true', help='apply corruptions to images') + parser.add_argument('--output_feat_maps', action='store_true', help='toggle this to output images of featmaps') + parser.add_argument('--baiduCTC', action='store_true', help='for data_filtering_off mode') + """ Model Architecture """ + parser.add_argument('--Transformation', type=str, required=True, help='Transformation stage. None|TPS') + parser.add_argument('--FeatureExtraction', type=str, required=True, help='FeatureExtraction stage. VGG|RCNN|ResNet') + parser.add_argument('--SequenceModeling', type=str, required=True, help='SequenceModeling stage. None|BiLSTM') + parser.add_argument('--Prediction', type=str, required=True, help='Prediction stage. CTC|Attn') + parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN') + parser.add_argument('--input_channel', type=int, default=1, help='the number of input channel of Feature extractor') + parser.add_argument('--output_channel', type=int, default=512, + help='the number of output channel of Feature extractor') + parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state') + + opt = parser.parse_args() + + """ vocab / character number configuration """ + if opt.sensitive: + opt.character = string.printable[:-6] # same with ASTER setting (use 94 char). 
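+        # string.printable has 100 characters (digits, letters, punctuation and 6
+        # whitespace characters); dropping the trailing whitespace leaves the
+        # 94-character set used in the ASTER setting.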
+ + cudnn.benchmark = True + cudnn.deterministic = True + opt.num_gpu = torch.cuda.device_count() + + # acquire_average_auc(opt) + main(opt) + # outputOrigImagesOnly(opt) diff --git a/captum_improve_vitstr.py b/captum_improve_vitstr.py new file mode 100644 index 0000000000000000000000000000000000000000..9e1c30338e980d30b09e02cf064d1769f1f2ec45 --- /dev/null +++ b/captum_improve_vitstr.py @@ -0,0 +1,671 @@ +import settings +import captum +import numpy as np +import torch +import torch.nn.functional as F +import torch.backends.cudnn as cudnn +from utils import get_args +from utils import CTCLabelConverter, AttnLabelConverter, Averager, TokenLabelConverter +import string +import time +import sys +from dataset import hierarchical_dataset, AlignCollate +import validators +from model import Model, STRScore +from PIL import Image +from lime.wrappers.scikit_image import SegmentationAlgorithm +from captum._utils.models.linear_model import SkLearnLinearModel, SkLearnRidge +import random +import os +from skimage.color import gray2rgb +import pickle +from train_shap_corr import getPredAndConf +import re +from captum_test import acquire_average_auc, saveAttrData +import copy +from skimage.color import gray2rgb +from matplotlib import pyplot as plt + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +from captum.attr import ( + GradientShap, + DeepLift, + DeepLiftShap, + IntegratedGradients, + LayerConductance, + NeuronConductance, + NoiseTunnel, + Saliency, + InputXGradient, + GuidedBackprop, + Deconvolution, + GuidedGradCam, + FeatureAblation, + ShapleyValueSampling, + Lime, + KernelShap +) + +from captum.metrics import ( + infidelity, + sensitivity_max +) + +from captum.attr._utils.visualization import visualize_image_attr + +### Acquire pixelwise attributions and replace them with ranked numbers averaged +### across segmentation with the largest contribution having the largest number +### and the smallest set to 1, which is the minimum number. 
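+### Illustrative example: with three segments whose mean attributions are
+### {2: 0.9, 0: 0.5, 1: -0.1}, segment 2 is assigned rank 3, segment 0 rank 2
+### and segment 1 rank 1.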
+### attr - original attribution
+### segm - image segmentations
+def rankedAttributionsBySegm(attr, segm):
+    aveSegmentations, sortedDict = averageSegmentsOut(attr[0,0], segm)
+    totalSegm = len(sortedDict.keys()) # total segmentations
+    sortedKeys = [k for k, v in sorted(sortedDict.items(), key=lambda item: item[1])]
+    sortedKeys = sortedKeys[::-1] ### A list ordered from the largest to the smallest score
+    currentRank = totalSegm
+    rankedSegmImg = torch.clone(attr)
+    for totalSegToHide in range(0, len(sortedKeys)):
+        currentSegmentToHide = sortedKeys[totalSegToHide]
+        rankedSegmImg[0,0][segm == currentSegmentToHide] = currentRank
+        currentRank -= 1
+    return rankedSegmImg
+
+### Returns the mean of each segment, with the same shape as the input
+### This function can only process one attribution image at a time
+def averageSegmentsOut(attr, segments):
+    averagedInput = torch.clone(attr)
+    sortedDict = {}
+    for x in np.unique(segments):
+        segmentMean = torch.mean(attr[segments == x][:])
+        sortedDict[x] = float(segmentMean.detach().cpu().numpy())
+        averagedInput[segments == x] = segmentMean
+    return averagedInput, sortedDict
+
+### Output and save segmentations for one dataset only
+def outputSegmOnly(opt):
+    ### targetDataset - one dataset only; SVTP has 645 images, CUTE80 has 288 images
+    targetDataset = "CUTE80" # ['IIIT5k_3000', 'SVT', 'IC03_867', 'IC13_1015', 'IC15_2077', 'SVTP', 'CUTE80']
+    segmRootDir = "/home/uclpc1/Documents/STR/datasets/segmentations/224X224/{}/".format(targetDataset)
+
+    if not os.path.exists(segmRootDir):
+        os.makedirs(segmRootDir)
+
+    opt.eval = True
+    ### Only IIIT5k_3000
+    if opt.fast_acc:
+        # # To easily compute the total accuracy of our paper.
+        eval_data_list = [targetDataset]
+    else:
+        # The evaluation datasets; the dataset order is the same as in Table 1 of our paper.
+ eval_data_list = [targetDataset] + + ### Taken from LIME + segmentation_fn = SegmentationAlgorithm('quickshift', kernel_size=4, + max_dist=200, ratio=0.2, + random_seed=random.randint(0, 1000)) + + for eval_data in eval_data_list: + eval_data_path = os.path.join(opt.eval_data, eval_data) + AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD, opt=opt) + eval_data, eval_data_log = hierarchical_dataset(root=eval_data_path, opt=opt) + evaluation_loader = torch.utils.data.DataLoader( + eval_data, batch_size=1, + shuffle=False, + num_workers=int(opt.workers), + collate_fn=AlignCollate_evaluation, pin_memory=True) + for i, (image_tensors, labels) in enumerate(evaluation_loader): + imgDataDict = {} + img_numpy = image_tensors.cpu().detach().numpy()[0] ### Need to set batch size to 1 only + if img_numpy.shape[0] == 1: + img_numpy = gray2rgb(img_numpy[0]) + # print("img_numpy shape: ", img_numpy.shape) # (224,224,3) + segmOutput = segmentation_fn(img_numpy) + imgDataDict['segdata'] = segmOutput + imgDataDict['label'] = labels[0] + outputPickleFile = segmRootDir + "{}.pkl".format(i) + with open(outputPickleFile, 'wb') as f: + pickle.dump(imgDataDict, f) + +def acquireSelectivityHit(origImg, attributions, segmentations, model, converter, labels, scoring): + # print("segmentations unique len: ", np.unique(segmentations)) + aveSegmentations, sortedDict = averageSegmentsOut(attributions[0,0], segmentations) + sortedKeys = [k for k, v in sorted(sortedDict.items(), key=lambda item: item[1])] + sortedKeys = sortedKeys[::-1] ### A list that should contain largest to smallest score + # print("sortedDict: ", sortedDict) # {0: -5.51e-06, 1: -1.469e-05, 2: -3.06e-05,...} + # print("aveSegmentations unique len: ", np.unique(aveSegmentations)) + # print("aveSegmentations device: ", aveSegmentations.device) # cuda:0 + # print("aveSegmentations shape: ", aveSegmentations.shape) # (224,224) + # print("aveSegmentations: ", aveSegmentations) + + n_correct = [] + confidenceList = [] # First index is one feature removed, second index two features removed, and so on... + clonedImg = torch.clone(origImg) + gt = str(labels) + for totalSegToHide in range(0, len(sortedKeys)): + ### Acquire LIME prediction result + currentSegmentToHide = sortedKeys[totalSegToHide] + clonedImg[0,0][segmentations == currentSegmentToHide] = 0.0 + pred, confScore = getPredAndConf(opt, model, scoring, clonedImg, converter, np.array([gt])) + # To evaluate 'case sensitive model' with alphanumeric and case insensitve setting. 
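+        # Both the prediction and the ground truth are lowercased and stripped of
+        # non-alphanumeric characters below before they are compared.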
+ if opt.sensitive and opt.data_filtering_off: + pred = pred.lower() + gt = gt.lower() + alphanumeric_case_insensitve = '0123456789abcdefghijklmnopqrstuvwxyz' + out_of_alphanumeric_case_insensitve = f'[^{alphanumeric_case_insensitve}]' + pred = re.sub(out_of_alphanumeric_case_insensitve, '', pred) + gt = re.sub(out_of_alphanumeric_case_insensitve, '', gt) + if pred == gt: + n_correct.append(1) + else: + n_correct.append(0) + confScore = confScore[0][0]*100 + confidenceList.append(confScore) + return n_correct, confidenceList + +### Once you have the selectivity_eval_results.pkl file, +def acquire_selectivity_auc(opt, pkl_filename=None): + if pkl_filename is None: + pkl_filename = "/home/goo/str/str_vit_dataexplain_lambda/metrics_sensitivity_eval_results_CUTE80.pkl" # VITSTR + accKeys = [] + + with open(pkl_filename, 'rb') as f: + selectivity_data = pickle.load(f) + + for resDictIdx, resDict in enumerate(selectivity_data): + keylistAcc = [] + keylistConf = [] + metricsKeys = resDict.keys() + for keyStr in resDict.keys(): + if "_acc" in keyStr: keylistAcc.append(keyStr) + if "_conf" in keyStr: keylistConf.append(keyStr) + # Need to check if network correctly predicted the image + for metrics_accStr in keylistAcc: + if 1 not in resDict[metrics_accStr]: print("resDictIdx") + +### This acquires the attributes of the STR network on individual character levels, +### then averages them. +def acquireSingleCharAttrAve(opt): + ### targetDataset - one dataset only, CUTE80 has 288 samples + # 'IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', 'IC13_857', 'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80 + targetDataset = settings.TARGET_DATASET + segmRootDir = "{}/224X224/{}/".format(settings.SEGM_DIR, targetDataset) + outputSelectivityPkl = "strexp_ave_{}_{}.pkl".format(settings.MODEL, targetDataset) + outputDir = "./attributionImgs/{}/{}/".format(settings.MODEL, targetDataset) + attrOutputDir = "./attributionData/{}/{}/".format(settings.MODEL, targetDataset) + ### Set only one below to True to have enough GPU + acquireSelectivity = True + acquireInfidelity = False + acquireSensitivity = False ### GPU error + + if not os.path.exists(outputDir): + os.makedirs(outputDir) + if not os.path.exists(attrOutputDir): + os.makedirs(attrOutputDir) + + start_time = time.time() + + """ model configuration """ + if opt.Transformer: + converter = TokenLabelConverter(opt) + elif 'CTC' in opt.Prediction: + converter = CTCLabelConverter(opt.character) + else: + converter = AttnLabelConverter(opt.character) + opt.num_class = len(converter.character) + + if opt.rgb: + opt.input_channel = 3 + model_obj = Model(opt) + + print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel, + opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction, + opt.SequenceModeling, opt.Prediction) + model = torch.nn.DataParallel(model_obj).to(device) + + # load model + print('loading pretrained model from %s' % opt.saved_model) + + if validators.url(opt.saved_model): + print("opt.saved_model: ", opt.saved_model) + model.load_state_dict(torch.hub.load_state_dict_from_url(opt.saved_model, progress=True, map_location=device)) + else: + model.load_state_dict(torch.load(opt.saved_model, map_location=device)) + opt.exp_name = '_'.join(opt.saved_model.split('/')[1:]) + # print(model) + + """ keep evaluation model and result logs """ + os.makedirs(f'./result/{opt.exp_name}', exist_ok=True) + os.system(f'cp {opt.saved_model} ./result/{opt.exp_name}/') + + """ setup 
loss """ + if 'CTC' in opt.Prediction: + criterion = torch.nn.CTCLoss(zero_infinity=True).to(device) + else: + criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) # ignore [GO] token = ignore index 0 + + modelCopy = copy.deepcopy(model) + + """ evaluation """ + scoring_singlechar = STRScore(opt=opt, converter=converter, device=device, enableSingleCharAttrAve=True) + super_pixel_model_singlechar = torch.nn.Sequential( + # super_pixler, + # numpy2torch_converter, + modelCopy, + scoring_singlechar + ).to(device) + modelCopy.eval() + scoring_singlechar.eval() + super_pixel_model_singlechar.eval() + + # Single Char Attribution Averaging + # enableSingleCharAttrAve - set to True + scoring = STRScore(opt=opt, converter=converter, device=device) + super_pixel_model = torch.nn.Sequential( + # super_pixler, + # numpy2torch_converter, + model, + scoring + ).to(device) + model.eval() + scoring.eval() + super_pixel_model.eval() + + if opt.blackbg: + shapImgLs = np.zeros(shape=(1, 1, 224, 224)).astype(np.float32) + trainList = np.array(shapImgLs) + background = torch.from_numpy(trainList).to(device) + + opt.eval = True + + ### Only IIIT5k_3000 + if opt.fast_acc: + # # To easily compute the total accuracy of our paper. + eval_data_list = [targetDataset] ### One dataset only + else: + # The evaluation datasets, dataset order is same with Table 1 in our paper. + eval_data_list = [targetDataset] + + if opt.calculate_infer_time: + evaluation_batch_size = 1 # batch_size should be 1 to calculate the GPU inference time per image. + else: + evaluation_batch_size = opt.batch_size + + selectivity_eval_results = [] + + testImgCount = 0 + list_accuracy = [] + total_forward_time = 0 + total_evaluation_data_number = 0 + total_correct_number = 0 + log = open(f'./result/{opt.exp_name}/log_all_evaluation.txt', 'a') + dashed_line = '-' * 80 + print(dashed_line) + log.write(dashed_line + '\n') + + segmentation_fn = SegmentationAlgorithm('quickshift', kernel_size=4, + max_dist=200, ratio=0.2, + random_seed=random.randint(0, 1000)) + + for eval_data in eval_data_list: + eval_data_path = os.path.join(opt.eval_data, eval_data) + AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD, opt=opt) + eval_data, eval_data_log = hierarchical_dataset(root=eval_data_path, opt=opt, segmRootDir=segmRootDir) + evaluation_loader = torch.utils.data.DataLoader( + eval_data, batch_size=1, + shuffle=False, + num_workers=int(opt.workers), + collate_fn=AlignCollate_evaluation, pin_memory=True) + testImgCount = 0 + + for i, (orig_img_tensors, segAndLabels) in enumerate(evaluation_loader): + results_dict = {} + aveAttr = [] + aveAttr_charContrib = [] + segmData, labels = segAndLabels[0] + target = converter.encode([labels]) + + # labels: RONALDO + segmDataNP = segmData["segdata"] + segmTensor = torch.from_numpy(segmDataNP).unsqueeze(0).unsqueeze(0) + # print("segmTensor min: ", segmTensor.min()) # 0 starting segmentation + segmTensor = segmTensor.to(device) + img1 = orig_img_tensors.to(device) + img1.requires_grad = True + bgImg = torch.zeros(img1.shape).to(device) + + ### Single char averaging + charOffset = 1 + + # preds = model(img1, seqlen=converter.batch_max_length) + input = img1 + origImgNP = torch.clone(orig_img_tensors).detach().cpu().numpy()[0][0] # (1, 1, 224, 224) + origImgNP = gray2rgb(origImgNP) + + ### Local explanations only + collectedAttributions = [] + for charIdx in range(0, len(labels)): + scoring_singlechar.setSingleCharOutput(charIdx + charOffset) + gtClassNum = 
target[0][charIdx + charOffset] + + ### Shapley Value Sampling + svs = ShapleyValueSampling(super_pixel_model_singlechar) + # attr = svs.attribute(input, target=0, n_samples=200) ### Individual pixels, too long to calculate + attributions = svs.attribute(input, target=gtClassNum, feature_mask=segmTensor) + collectedAttributions.append(attributions) + aveAttributions = torch.mean(torch.cat(collectedAttributions,dim=0), dim=0).unsqueeze(0) + rankedAttr = rankedAttributionsBySegm(aveAttributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_shapley_l.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_shapley_l.pkl', aveAttributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, aveAttributions, segmDataNP, modelCopy, converter, labels, scoring_singlechar) + results_dict["shapley_local_acc"] = n_correct + results_dict["shapley_local_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model_singlechar, perturb_fn, img1, aveAttributions).detach().cpu().numpy()) + results_dict["shapley_local_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(svs.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["shapley_local_sens"] = sens + + ### Shapley Value Sampling + svs = ShapleyValueSampling(super_pixel_model) + # attr = svs.attribute(input, target=0, n_samples=200) ### Individual pixels, too long to calculate + attributions = svs.attribute(input, target=0, feature_mask=segmTensor) + collectedAttributions.append(attributions) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_shapley.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_shapley.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, converter, labels, scoring) + results_dict["shapley_acc"] = n_correct + results_dict["shapley_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions).detach().cpu().numpy()) + results_dict["shapley_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(svs.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["shapley_sens"] = sens + + ### Global + Local context + aveAttributions = torch.mean(torch.cat(collectedAttributions,dim=0), dim=0).unsqueeze(0) + rankedAttr = rankedAttributionsBySegm(aveAttributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_shapley_gl.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_shapley_gl.pkl', aveAttributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, aveAttributions, segmDataNP, modelCopy, converter, labels, scoring_singlechar) + results_dict["shapley_global_local_acc"] = n_correct + 
results_dict["shapley_global_local_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model_singlechar, perturb_fn, img1, aveAttributions).detach().cpu().numpy()) + results_dict["shapley_global_local_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(svs.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["shapley_global_local_sens"] = sens + + ### BASELINE Evaluations + + ### Integrated Gradients + ig = IntegratedGradients(super_pixel_model) + attributions = ig.attribute(input, target=0) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_intgrad.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_intgrad.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, converter, labels, scoring) + results_dict["intgrad_acc"] = n_correct + results_dict["intgrad_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions).detach().cpu().numpy()) + results_dict["intgrad_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(ig.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["intgrad_sens"] = sens + + ### Gradient SHAP using zero-background + gs = GradientShap(super_pixel_model) + # We define a distribution of baselines and draw `n_samples` from that + # distribution in order to estimate the expectations of gradients across all baselines + baseline_dist = torch.zeros((1, 1, 224, 224)) + baseline_dist = baseline_dist.to(device) + attributions = gs.attribute(input, baselines=baseline_dist, target=0) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_gradshap.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_gradshap.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, converter, labels, scoring) + results_dict["gradshap_acc"] = n_correct + results_dict["gradshap_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions).detach().cpu().numpy()) + results_dict["gradshap_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(gs.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["gradshap_sens"] = sens + + ### DeepLift using zero-background + dl = DeepLift(super_pixel_model) + attributions = dl.attribute(input, target=0) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_deeplift.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_deeplift.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = 
acquireSelectivityHit(img1, attributions, segmDataNP, model, converter, labels, scoring) + results_dict["deeplift_acc"] = n_correct + results_dict["deeplift_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions).detach().cpu().numpy()) + results_dict["deeplift_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(dl.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["deeplift_sens"] = sens + + ### Saliency + saliency = Saliency(super_pixel_model) + attributions = saliency.attribute(input, target=0) ### target=class0 + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_saliency.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_saliency.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, converter, labels, scoring) + results_dict["saliency_acc"] = n_correct + results_dict["saliency_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions).detach().cpu().numpy()) + results_dict["saliency_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(saliency.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["saliency_sens"] = sens + + ### InputXGradient + input_x_gradient = InputXGradient(super_pixel_model) + attributions = input_x_gradient.attribute(input, target=0) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_inpxgrad.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_inpxgrad.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, converter, labels, scoring) + results_dict["inpxgrad_acc"] = n_correct + results_dict["inpxgrad_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions).detach().cpu().numpy()) + results_dict["inpxgrad_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(input_x_gradient.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["inpxgrad_sens"] = sens + + ### GuidedBackprop + gbp = GuidedBackprop(super_pixel_model) + attributions = gbp.attribute(input, target=0) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_guidedbp.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_guidedbp.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, converter, labels, scoring) + results_dict["guidedbp_acc"] = n_correct + results_dict["guidedbp_conf"] = confidenceList + if 
acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions).detach().cpu().numpy()) + results_dict["guidedbp_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(gbp.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["guidedbp_sens"] = sens + + ### Deconvolution + deconv = Deconvolution(super_pixel_model) + attributions = deconv.attribute(input, target=0) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_deconv.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_deconv.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, converter, labels, scoring) + results_dict["deconv_acc"] = n_correct + results_dict["deconv_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions).detach().cpu().numpy()) + results_dict["deconv_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(deconv.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["deconv_sens"] = sens + + ### Feature ablator + ablator = FeatureAblation(super_pixel_model) + attributions = ablator.attribute(input, target=0, feature_mask=segmTensor) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_featablt.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_featablt.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, converter, labels, scoring) + results_dict["featablt_acc"] = n_correct + results_dict["featablt_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions).detach().cpu().numpy()) + results_dict["featablt_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(ablator.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["featablt_sens"] = sens + + ## LIME + interpretable_model = SkLearnRidge(alpha=1, fit_intercept=True) ### This is the default used by LIME + lime = Lime(super_pixel_model, interpretable_model=interpretable_model) + attributions = lime.attribute(input, target=0, feature_mask=segmTensor) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_lime.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_lime.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, converter, labels, scoring) + results_dict["lime_acc"] = n_correct + results_dict["lime_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, 
attributions).detach().cpu().numpy()) + results_dict["lime_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(lime.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["lime_sens"] = sens + + ### KernelSHAP + ks = KernelShap(super_pixel_model) + attributions = ks.attribute(input, target=0, feature_mask=segmTensor) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map') + mplotfig.savefig(outputDir + '{}_kernelshap.png'.format(i)) + mplotfig.clear() + plt.close(mplotfig) + saveAttrData(attrOutputDir + f'{i}_kernelshap.pkl', attributions, segmDataNP, origImgNP) + if acquireSelectivity: + n_correct, confidenceList = acquireSelectivityHit(img1, attributions, segmDataNP, model, converter, labels, scoring) + results_dict["kernelshap_acc"] = n_correct + results_dict["kernelshap_conf"] = confidenceList + if acquireInfidelity: + infid = float(infidelity(super_pixel_model, perturb_fn, img1, attributions).detach().cpu().numpy()) + results_dict["kernelshap_infid"] = infid + if acquireSensitivity: + sens = float(sensitivity_max(ks.attribute, img1, target=0).detach().cpu().numpy()) + results_dict["kernelshap_sens"] = sens + + selectivity_eval_results.append(results_dict) + + with open(outputSelectivityPkl, 'wb') as f: + pickle.dump(selectivity_eval_results, f) + + testImgCount += 1 + print("testImgCount: ", testImgCount) + +if __name__ == '__main__': + # deleteInf() + opt = get_args(is_train=False) + + """ vocab / character number configuration """ + if opt.sensitive: + opt.character = string.printable[:-6] # same with ASTER setting (use 94 char). 
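+    # Illustrative sketch (made-up values) of how the pickle written above is later
+    # reduced to a normalized selectivity AUC, mirroring acquire_average_auc() in
+    # captum_test.py:
+    #   dataList = [1, 1, 0, 0]                 # e.g. one image's "shapley_acc" list
+    #   dataList.insert(0, 1)                   # -> [1, 1, 1, 0, 0]
+    #   auc_norm = np.trapz(dataList) / np.trapz([1] * len(dataList))  # 2.5 / 4 = 0.625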
+ + cudnn.benchmark = True + cudnn.deterministic = True + opt.num_gpu = torch.cuda.device_count() + + # combineBestDataXAI(opt) + # acquire_average_auc(opt) + acquireSingleCharAttrAve(opt) diff --git a/captum_test.py b/captum_test.py new file mode 100644 index 0000000000000000000000000000000000000000..8fedf8a70c3032bd1424f119e7714dd4d3143caa --- /dev/null +++ b/captum_test.py @@ -0,0 +1,349 @@ +import settings +import captum +import numpy as np +import torch +import torch.nn.functional as F +import torch.backends.cudnn as cudnn +from utils import get_args +from utils import CTCLabelConverter, AttnLabelConverter, Averager, TokenLabelConverter +import string +import time +import sys +from dataset import hierarchical_dataset, AlignCollate +import validators +from model import Model, STRScore +from PIL import Image +from lime.wrappers.scikit_image import SegmentationAlgorithm +from captum._utils.models.linear_model import SkLearnLinearModel, SkLearnRidge +import random +import os +from skimage.color import gray2rgb +import pickle +from train_shap_corr import getPredAndConf +import re +import copy +import statistics + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +from captum.attr import ( + GradientShap, + DeepLift, + DeepLiftShap, + IntegratedGradients, + LayerConductance, + NeuronConductance, + NoiseTunnel, + Saliency, + InputXGradient, + GuidedBackprop, + Deconvolution, + GuidedGradCam, + FeatureAblation, + ShapleyValueSampling, + Lime, + KernelShap +) + +from captum.metrics import ( + infidelity, + sensitivity_max +) + +### Returns the mean for each segmentation having shape as the same as the input +### This function can only one attribution image at a time +def averageSegmentsOut(attr, segments): + averagedInput = torch.clone(attr) + sortedDict = {} + for x in np.unique(segments): + segmentMean = torch.mean(attr[segments == x][:]) + sortedDict[x] = float(segmentMean.detach().cpu().numpy()) + averagedInput[segments == x] = segmentMean + return averagedInput, sortedDict + +### Output and save segmentations only for one dataset only +def outputSegmOnly(opt): + ### targetDataset - one dataset only, SVTP-645, CUTE80-288images + targetDataset = "CUTE80" # ['IIIT5k_3000', 'SVT', 'IC03_867', 'IC13_1015', 'IC15_2077', 'SVTP', 'CUTE80'] + segmRootDir = "/home/uclpc1/Documents/STR/datasets/segmentations/224X224/{}/".format(targetDataset) + + if not os.path.exists(segmRootDir): + os.makedirs(segmRootDir) + + opt.eval = True + ### Only IIIT5k_3000 + if opt.fast_acc: + # # To easily compute the total accuracy of our paper. + eval_data_list = [targetDataset] + else: + # The evaluation datasets, dataset order is same with Table 1 in our paper. 
+ eval_data_list = [targetDataset] + + ### Taken from LIME + segmentation_fn = SegmentationAlgorithm('quickshift', kernel_size=4, + max_dist=200, ratio=0.2, + random_seed=random.randint(0, 1000)) + + for eval_data in eval_data_list: + eval_data_path = os.path.join(opt.eval_data, eval_data) + AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD, opt=opt) + eval_data, eval_data_log = hierarchical_dataset(root=eval_data_path, opt=opt) + evaluation_loader = torch.utils.data.DataLoader( + eval_data, batch_size=1, + shuffle=False, + num_workers=int(opt.workers), + collate_fn=AlignCollate_evaluation, pin_memory=True) + for i, (image_tensors, labels) in enumerate(evaluation_loader): + imgDataDict = {} + img_numpy = image_tensors.cpu().detach().numpy()[0] ### Need to set batch size to 1 only + if img_numpy.shape[0] == 1: + img_numpy = gray2rgb(img_numpy[0]) + # print("img_numpy shape: ", img_numpy.shape) # (224,224,3) + segmOutput = segmentation_fn(img_numpy) + imgDataDict['segdata'] = segmOutput + imgDataDict['label'] = labels[0] + outputPickleFile = segmRootDir + "{}.pkl".format(i) + with open(outputPickleFile, 'wb') as f: + pickle.dump(imgDataDict, f) + +def acquireSelectivityHit(origImg, attributions, segmentations, model, converter, labels, scoring): + # print("segmentations unique len: ", np.unique(segmentations)) + aveSegmentations, sortedDict = averageSegmentsOut(attributions[0,0], segmentations) + sortedKeys = [k for k, v in sorted(sortedDict.items(), key=lambda item: item[1])] + sortedKeys = sortedKeys[::-1] ### A list that should contain largest to smallest score + # print("sortedDict: ", sortedDict) # {0: -5.51e-06, 1: -1.469e-05, 2: -3.06e-05,...} + # print("aveSegmentations unique len: ", np.unique(aveSegmentations)) + # print("aveSegmentations device: ", aveSegmentations.device) # cuda:0 + # print("aveSegmentations shape: ", aveSegmentations.shape) # (224,224) + # print("aveSegmentations: ", aveSegmentations) + + n_correct = [] + confidenceList = [] # First index is one feature removed, second index two features removed, and so on... + clonedImg = torch.clone(origImg) + gt = str(labels) + for totalSegToHide in range(0, len(sortedKeys)): + ### Acquire LIME prediction result + currentSegmentToHide = sortedKeys[totalSegToHide] + clonedImg[0,0][segmentations == currentSegmentToHide] = 0.0 + pred, confScore = getPredAndConf(opt, model, scoring, clonedImg, converter, np.array([gt])) + # To evaluate 'case sensitive model' with alphanumeric and case insensitve setting. 
+ if opt.sensitive and opt.data_filtering_off: + pred = pred.lower() + gt = gt.lower() + alphanumeric_case_insensitve = '0123456789abcdefghijklmnopqrstuvwxyz' + out_of_alphanumeric_case_insensitve = f'[^{alphanumeric_case_insensitve}]' + pred = re.sub(out_of_alphanumeric_case_insensitve, '', pred) + gt = re.sub(out_of_alphanumeric_case_insensitve, '', gt) + if pred == gt: + n_correct.append(1) + else: + n_correct.append(0) + confScore = confScore[0][0]*100 + confidenceList.append(confScore) + return n_correct, confidenceList + +### Once you have the selectivity_eval_results.pkl file, +def acquire_selectivity_auc(opt, pkl_filename=None): + if pkl_filename is None: + pkl_filename = "/home/goo/str/str_vit_dataexplain_lambda/metrics_sensitivity_eval_results_CUTE80.pkl" # VITSTR + accKeys = [] + + with open(pkl_filename, 'rb') as f: + selectivity_data = pickle.load(f) + + for resDictIdx, resDict in enumerate(selectivity_data): + keylistAcc = [] + keylistConf = [] + metricsKeys = resDict.keys() + for keyStr in resDict.keys(): + if "_acc" in keyStr: keylistAcc.append(keyStr) + if "_conf" in keyStr: keylistConf.append(keyStr) + # Need to check if network correctly predicted the image + for metrics_accStr in keylistAcc: + if 1 not in resDict[metrics_accStr]: print("resDictIdx") + +## gtClassNum - set to gtClassNum=0 for standard implemention, or specific class idx for local explanation +def acquireAttribution(opt, super_model, input, segmTensor, gtClassNum, lowestAccKey, device): + channels = 1 + if opt.rgb: + channels = 3 + + ### Perform attribution + if "intgrad_" in lowestAccKey: + ig = IntegratedGradients(super_model) + attributions = ig.attribute(input, target=gtClassNum) + elif "gradshap_" in lowestAccKey: + gs = GradientShap(super_model) + baseline_dist = torch.zeros((1, channels, opt.imgH, opt.imgW)) + baseline_dist = baseline_dist.to(device) + attributions = gs.attribute(input, baselines=baseline_dist, target=gtClassNum) + elif "deeplift_" in lowestAccKey: + dl = DeepLift(super_model) + attributions = dl.attribute(input, target=gtClassNum) + elif "saliency_" in lowestAccKey: + saliency = Saliency(super_model) + attributions = saliency.attribute(input, target=gtClassNum) + elif "inpxgrad_" in lowestAccKey: + input_x_gradient = InputXGradient(super_model) + attributions = input_x_gradient.attribute(input, target=gtClassNum) + elif "guidedbp_" in lowestAccKey: + gbp = GuidedBackprop(super_model) + attributions = gbp.attribute(input, target=gtClassNum) + elif "deconv_" in lowestAccKey: + deconv = Deconvolution(super_model) + attributions = deconv.attribute(input, target=gtClassNum) + elif "featablt_" in lowestAccKey: + ablator = FeatureAblation(super_model) + attributions = ablator.attribute(input, target=gtClassNum, feature_mask=segmTensor) + elif "shapley_" in lowestAccKey: + svs = ShapleyValueSampling(super_model) + attributions = svs.attribute(input, target=gtClassNum, feature_mask=segmTensor) + elif "lime_" in lowestAccKey: + interpretable_model = SkLearnRidge(alpha=1, fit_intercept=True) ### This is the default used by LIME + lime = Lime(super_model, interpretable_model=interpretable_model) + attributions = lime.attribute(input, target=gtClassNum, feature_mask=segmTensor) + elif "kernelshap_" in lowestAccKey: + ks = KernelShap(super_model) + attributions = ks.attribute(input, target=gtClassNum, feature_mask=segmTensor) + else: + assert False + return attributions + +### In addition to acquire_average_auc(), this function also returns the best selectivity_acc attr-based method +### 
pklFile - you need to pass pkl file here +def acquire_bestacc_attr(opt, pickleFile): + # pickleFile = "metrics_sensitivity_eval_results_IIIT5k_3000.pkl" + # pickleFile = "/home/goo/str/str_vit_dataexplain_lambda/shapley_singlechar_ave_matrn_SVT.pkl" + acquireSelectivity = True # If True, set to + acquireInfidelity = False + acquireSensitivity = False + + with open(pickleFile, 'rb') as f: + data = pickle.load(f) + metricDict = {} # Keys: "saliency_acc", "saliency_conf", "saliency_infid", "saliency_sens" + selectivity_acc_auc_normalized = [] # Normalized because it is divided by the full rectangle + for imgData in data: + if acquireSelectivity: + for keyStr in imgData.keys(): + if ("_acc" in keyStr or "_conf" in keyStr) and not ("_local_" in keyStr or "_global_local_" in keyStr): # Accept only selectivity + if keyStr not in metricDict: + metricDict[keyStr] = [] + dataList = copy.deepcopy(imgData[keyStr]) # list of 0,1 [1,1,1,0,0,0,0] + dataList.insert(0, 1) # Insert 1 at beginning to avoid np.trapz([1]) = 0.0 + denom = [1] * len(dataList) # Denominator to normalize AUC + auc_norm = np.trapz(dataList) / np.trapz(denom) + metricDict[keyStr].append(auc_norm) + elif acquireInfidelity: + pass # TODO + elif acquireSensitivity: + pass # TODO + + lowestAccKey = "" + lowestAcc = 10000000 + for metricKey in metricDict: + if "_acc" in metricKey: # Used for selectivity accuracy only + statisticVal = statistics.mean(metricDict[metricKey]) + if statisticVal < lowestAcc: + lowestAcc = statisticVal + lowestAccKey = metricKey + # print("{}: {}".format(metricKey, statisticVal)) + + assert lowestAccKey!="" + return lowestAccKey + +def saveAttrData(filename, attribution, segmData, origImg): + pklData = {} + pklData['attribution'] = torch.clone(attribution).detach().cpu().numpy() + pklData['segmData'] = segmData + pklData['origImg'] = origImg + with open(filename, 'wb') as f: + pickle.dump(pklData, f) + +### New code (8/3/2022) to acquire average selectivity, infidelity, etc. 
after running captum test +def acquire_average_auc(opt): + # pickleFile = "metrics_sensitivity_eval_results_IIIT5k_3000.pkl" + pickleFile = "/home/goo/str/str_vit_dataexplain_lambda/shapley_singlechar_ave_vitstr_IC03_860.pkl" + acquireSelectivity = True # If True, set to + acquireInfidelity = False + acquireSensitivity = False + + with open(pickleFile, 'rb') as f: + data = pickle.load(f) + metricDict = {} # Keys: "saliency_acc", "saliency_conf", "saliency_infid", "saliency_sens" + selectivity_acc_auc_normalized = [] # Normalized because it is divided by the full rectangle + for imgData in data: + if acquireSelectivity: + for keyStr in imgData.keys(): + if "_acc" in keyStr or "_conf" in keyStr: # Accept only selectivity + if keyStr not in metricDict: + metricDict[keyStr] = [] + dataList = copy.deepcopy(imgData[keyStr]) # list of 0,1 [1,1,1,0,0,0,0] + dataList.insert(0, 1) # Insert 1 at beginning to avoid np.trapz([1]) = 0.0 + denom = [1] * len(dataList) # Denominator to normalize AUC + auc_norm = np.trapz(dataList) / np.trapz(denom) + metricDict[keyStr].append(auc_norm) + elif acquireInfidelity: + pass # TODO + elif acquireSensitivity: + pass # TODO + + for metricKey in metricDict: + print("{}: {}".format(metricKey, statistics.mean(metricDict[metricKey]))) + +### Use this acquire list +def acquireListOfAveAUC(opt): + acquireSelectivity = True + acquireInfidelity = False + acquireSensitivity = False + totalChars = 10 + collectedMetricDict = {} + for charNum in range(0, totalChars): + pickleFile = f"/home/goo/str/str_vit_dataexplain_lambda/singlechar{charNum}_results_{totalChars}chardataset.pkl" + with open(pickleFile, 'rb') as f: + data = pickle.load(f) + metricDict = {} # Keys: "saliency_acc", "saliency_conf", "saliency_infid", "saliency_sens" + selectivity_acc_auc_normalized = [] # Normalized because it is divided by the full rectangle + for imgData in data: + if acquireSelectivity: + for keyStr in imgData.keys(): + if "_acc" in keyStr or "_conf" in keyStr: # Accept only selectivity + if keyStr not in metricDict: + metricDict[keyStr] = [] + dataList = copy.deepcopy(imgData[keyStr]) # list of 0,1 [1,1,1,0,0,0,0] + dataList.insert(0, 1) # Insert 1 at beginning to avoid np.trapz([1]) = 0.0 + denom = [1] * len(dataList) # Denominator to normalize AUC + auc_norm = np.trapz(dataList) / np.trapz(denom) + metricDict[keyStr].append(auc_norm) + for metricKey in metricDict: + selec_auc_normalize = statistics.mean(metricDict[metricKey]) + if metricKey not in collectedMetricDict: + collectedMetricDict[metricKey] = [] + collectedMetricDict[metricKey].append(selec_auc_normalize) + for collectedMetricDictKey in collectedMetricDict: + print("{}: {}".format(collectedMetricDictKey, collectedMetricDict[collectedMetricDictKey])) + for charNum in range(0, totalChars): + selectivityAcrossCharsLs = [] + for collectedMetricDictKey in collectedMetricDict: + if "_acc" in collectedMetricDictKey: + selectivityAcrossCharsLs.append(collectedMetricDict[collectedMetricDictKey][charNum]) + print("accuracy -- {}: {}".format(charNum, statistics.mean(selectivityAcrossCharsLs))) + for charNum in range(0, totalChars): + selectivityAcrossCharsLs = [] + for collectedMetricDictKey in collectedMetricDict: + if "_conf" in collectedMetricDictKey: + selectivityAcrossCharsLs.append(collectedMetricDict[collectedMetricDictKey][charNum]) + print("confidence -- {}: {}".format(charNum, statistics.mean(selectivityAcrossCharsLs))) + +if __name__ == '__main__': + # deleteInf() + opt = get_args(is_train=False) + + """ vocab / character number 
configuration """ + if opt.sensitive: + opt.character = string.printable[:-6] # same with ASTER setting (use 94 char). + + cudnn.benchmark = True + cudnn.deterministic = True + opt.num_gpu = torch.cuda.device_count() + + main(opt) diff --git a/captum_test_eval.py b/captum_test_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..80d4091469caf931549f1afb956e87d81fc0b4a1 --- /dev/null +++ b/captum_test_eval.py @@ -0,0 +1,126 @@ +import pickle +import copy +import numpy as np +import statistics +import sys +import os +from captum.attr._utils.visualization import visualize_image_attr +import matplotlib.pyplot as plt + +### New code (8/3/2022) to acquire average selectivity, infidelity, etc. after running captum test +def acquire_average_auc(): + # pickleFile = "metrics_sensitivity_eval_results_IIIT5k_3000.pkl" + pickleFile = "shapley_singlechar_ave_vitstr_IC15_1811.pkl" + acquireSelectivity = True # If True, set to + acquireInfidelity = False + acquireSensitivity = False + + with open(pickleFile, 'rb') as f: + data = pickle.load(f) + metricDict = {} # Keys: "saliency_acc", "saliency_conf", "saliency_infid", "saliency_sens" + selectivity_acc_auc_normalized = [] # Normalized because it is divided by the full rectangle + for imgData in data: + if acquireSelectivity: + for keyStr in imgData.keys(): + if "_acc" in keyStr or "_conf" in keyStr: # Accept only selectivity + if keyStr not in metricDict: + metricDict[keyStr] = [] + dataList = copy.deepcopy(imgData[keyStr]) # list of 0,1 [1,1,1,0,0,0,0] + dataList.insert(0, 1) # Insert 1 at beginning to avoid np.trapz([1]) = 0.0 + denom = [1] * len(dataList) # Denominator to normalize AUC + auc_norm = np.trapz(dataList) / np.trapz(denom) + if not np.isnan(auc_norm).any(): + metricDict[keyStr].append(auc_norm) + elif acquireInfidelity: + pass # TODO + elif acquireSensitivity: + pass # TODO + + for metricKey in metricDict: + print("{}: {}".format(metricKey, statistics.mean(metricDict[metricKey]))) + +### +def sumOfAllAttributions(): + modelName = "trba" + datasetName = "IC15_1811" # IIIT5k_3000, IC03_867, IC13_857, IC15_1811 + mainRootDir = "/data/goo/strattr/" + rootDir = f"{mainRootDir}attributionData/{modelName}/{datasetName}/" + numpyOutputDir = mainRootDir + + if modelName=="vitstr": + shape = [1, 1, 224, 224] + elif modelName =="parseq": + shape = [1, 3, 32, 128] + elif modelName =="trba": + shape = [1, 1, 32, 100] + # pickleFile = f"shapley_singlechar_ave_{modelName}_{datasetName}.pkl" + # acquireSelectivity = True + # with open(pickleFile, 'rb') as f: + # data = pickle.load(f) + # metricDict = {} # Keys: "saliency_acc", "saliency_conf", "saliency_infid", "saliency_sens" + # + # for imgData in data: + # if acquireSelectivity: + # for keyStr in imgData.keys(): + # print("keyStr: ", keyStr) + # if "_acc" in keyStr or "_conf" in keyStr: # Accept only selectivity + # if keyStr not in metricDict: + # metricDict[keyStr] = [] + # dataList = copy.deepcopy(imgData[keyStr]) # list of 0,1 [1,1,1,0,0,0,0] + # dataList.insert(0, 1) # Insert 1 at beginning to avoid np.trapz([1]) = 0.0 + # denom = [1] * len(dataList) # Denominator to normalize AUC + # auc_norm = np.trapz(dataList) / np.trapz(denom) + + totalImgCount = 0 + # From a folder containing saved attribution pickle files, convert them into attribution images + for path, subdirs, files in os.walk(rootDir): + for name in files: + fullfilename = os.path.join(rootDir, name) # Value + # fullfilename: /data/goo/strattr/attributionData/trba/CUTE80/66_featablt.pkl + if "_gl." 
not in fullfilename.split('/')[-1]: # Accept only global+local + continue + totalImgCount += 1 + shape[0] = totalImgCount + main_np = np.memmap(numpyOutputDir+f"aveattr_{modelName}_{datasetName}.dat", dtype='float32', mode='w+', shape=tuple(shape)) + + attrIdx = 0 + # From a folder containing saved attribution pickle files, convert them into attribution images + leftGreaterRightAcc = 0.0 + for path, subdirs, files in os.walk(rootDir): + for name in files: + fullfilename = os.path.join(rootDir, name) # Value + # fullfilename: /data/goo/strattr/attributionData/trba/CUTE80/66_featablt.pkl + if "_gl." not in fullfilename.split('/')[-1]: # Accept only global+local + continue + print("fullfilename: ", fullfilename) + # imgNum = int(partfilename.split('_')[0]) + # attrImgName = partfilename.replace('.pkl', '.png') + # minNumber = min(minNumber, imgNum) + # maxNumber = max(maxNumber, imgNum) + with open(fullfilename, 'rb') as f: + pklData = pickle.load(f) + attributions = pklData['attribution'] + segmDataNP = pklData['segmData'] + origImgNP = pklData['origImg'] + if np.isnan(attributions).any(): + continue + # attributions[0] = (attributions[0] - attributions[0].min()) / (attributions[0].max() - attributions[0].min()) + main_np[attrIdx] = attributions[0] + sumLeft = np.sum(attributions[0,:,:,0:attributions.shape[3]//2]) + sumRight = np.sum(attributions[0,:,:,attributions.shape[3]//2:]) + if sumLeft > sumRight: + leftGreaterRightAcc += 1.0 + attrIdx += 1 + print("leftGreaterRightAcc: ", leftGreaterRightAcc/attrIdx) + main_np.flush() + meanAveAttr = np.transpose(np.mean(main_np, axis=0), (1,2,0)) + print("meanAveAttr shape: ", meanAveAttr.shape) # (1, 3, 32, 128) + meanAveAttr = 2*((meanAveAttr - meanAveAttr.min()) / (meanAveAttr.max() - meanAveAttr.min())) - 1.0 + mplotfig, _ = visualize_image_attr(meanAveAttr, cmap='RdYlGn') # input should be in (H,W,C) + mplotfig.savefig(numpyOutputDir+f"aveattr_{modelName}_{datasetName}.png") + mplotfig.clear() + plt.close(mplotfig) + +if __name__ == '__main__': + # acquire_average_auc() + sumOfAllAttributions() diff --git a/configs/template.yaml b/configs/template.yaml new file mode 100644 index 0000000000000000000000000000000000000000..837ca5c2218cb39b9d98a4e64681a136dc22ee10 --- /dev/null +++ b/configs/template.yaml @@ -0,0 +1,67 @@ +global: + name: exp + phase: train + stage: pretrain-vision + workdir: /tmp/workdir + seed: ~ + +dataset: + train: { + roots: ['data/training/MJ/MJ_train/', + 'data/training/MJ/MJ_test/', + 'data/training/MJ/MJ_valid/', + 'data/training/ST'], + batch_size: 128 + } + test: { + roots: ['data/evaluation/IIIT5k_3000', + 'data/evaluation/SVT', + 'data/evaluation/SVTP', + 'data/evaluation/IC13_857', + 'data/evaluation/IC15_1811', + 'data/evaluation/CUTE80'], + batch_size: 128 + } + charset_path: data/charset_36.txt + num_workers: 4 + max_length: 25 # 30 + image_height: 32 + image_width: 128 + case_sensitive: False + eval_case_sensitive: False + data_aug: True + multiscales: False + pin_memory: False + smooth_label: False + smooth_factor: 0.1 + one_hot_y: True + use_sm: False + +training: + epochs: 6 + show_iters: 50 + eval_iters: 3000 + save_iters: 20000 + start_iters: 0 + stats_iters: 100000 + +optimizer: + type: Adadelta # Adadelta, Adam + true_wd: False + wd: 0. 
# 0.001 + bn_wd: False + args: { + # betas: !!python/tuple [0.9, 0.99], # betas=(0.9,0.99) for AdamW + # betas: !!python/tuple [0.9, 0.999], # for default Adam + } + clip_grad: 20 + lr: [1.0, 1.0, 1.0] # lr: [0.005, 0.005, 0.005] + scheduler: { + periods: [3, 2, 1], + gamma: 0.1, + } + +model: + name: 'modules_abinet.model_abinet.ABINetModel' + checkpoint: ~ + strict: True diff --git a/configs/train_abinet.yaml b/configs/train_abinet.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1828058ab7d62e756e6bc73f2eb72a179205926c --- /dev/null +++ b/configs/train_abinet.yaml @@ -0,0 +1,66 @@ +global: + name: train-abinet + phase: train + stage: train-super + workdir: workdir + seed: ~ + +dataset: + train: { + roots: ['datasets/data_lmdb_release/training/MJ/MJ_train/', + 'datasets/data_lmdb_release/training/MJ/MJ_test/', + 'datasets/data_lmdb_release/training/MJ/MJ_valid/', + 'datasets/data_lmdb_release/training/ST'], + batch_size: 384 + } + test: { + roots: ['datasets/data_lmdb_release/evaluation/SVTP'], + batch_size: 1 + } + data_aug: False + multiscales: False + num_workers: 0 + +training: + epochs: 10 + show_iters: 50 + eval_iters: 3000 + save_iters: 3000 + +optimizer: + type: Adam + true_wd: False + wd: 0.0 + bn_wd: False + clip_grad: 20 + lr: 0.0001 + args: { + betas: !!python/tuple [0.9, 0.999], # for default Adam + } + scheduler: { + periods: [6, 4], + gamma: 0.1, + } + +model: + name: 'modules_abinet.model_abinet_iter.ABINetIterModel' + iter_size: 3 + ensemble: '' + use_vision: False + vision: { + checkpoint: pretrained/abinet_vision.pth, + loss_weight: 1., + attention: 'position', + backbone: 'transformer', + backbone_ln: 3, + } + language: { + checkpoint: pretrained/abinet_language_model.pth, + num_layers: 4, + loss_weight: 1., + detach: True, + use_self_attn: False + } + alignment: { + loss_weight: 1., + } diff --git a/configs/train_matrn.yaml b/configs/train_matrn.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8d6c1a4137f6e7fd045cde1c92b7ba0ef0f6ab03 --- /dev/null +++ b/configs/train_matrn.yaml @@ -0,0 +1,73 @@ +global: + name: train-matrn + phase: train + stage: train-super + workdir: results + seed: ~ + +dataset: + train: { + roots: ['data/training/MJ/MJ_train/', + 'data/training/MJ/MJ_test/', + 'data/training/MJ/MJ_valid/', + 'data/training/ST'], + batch_size: 384 + } + test: { + roots: ['datasets/data_lmdb_release/evaluation/SVTP'], + batch_size: 1 + } + valid: { + roots: ['data/validation'], + batch_size: 384 + } + data_aug: True + multiscales: False + num_workers: 0 + +training: + epochs: 10 + show_iters: 50 + eval_iters: 3000 + save_iters: 3000 + +optimizer: + type: Adam + true_wd: False + wd: 0.0 + bn_wd: False + clip_grad: 20 + lr: 0.0001 + args: { + betas: !!python/tuple [0.9, 0.999], # for default Adam + } + scheduler: { + periods: [6, 4], + gamma: 0.1, + } + +model: + name: 'modules_matrn.model_matrn_iter.MATRN' + iter_size: 3 + ensemble: '' + use_vision: False + vision: { + checkpoint: 'pretrained/best-pretrain-vision-model.pth', + loss_weight: 1., + attention: 'position', + backbone: 'transformer', + backbone_ln: 3, + } + language: { + checkpoint: 'pretrained/pretrain-language-model.pth', + num_layers: 4, + loss_weight: 1., + detach: True, + use_self_attn: False + } + alignment: { + checkpoint: ~, + num_layers: 2, + loss_weight: 1., + use_self_attn: False, + } diff --git a/data/charset_36.txt b/data/charset_36.txt new file mode 100644 index 
0000000000000000000000000000000000000000..212c9c59c7efd6e6a2be73812a33cb26e8e7bc42 --- /dev/null +++ b/data/charset_36.txt @@ -0,0 +1,36 @@ +0 a +1 b +2 c +3 d +4 e +5 f +6 g +7 h +8 i +9 j +10 k +11 l +12 m +13 n +14 o +15 p +16 q +17 r +18 s +19 t +20 u +21 v +22 w +23 x +24 y +25 z +26 1 +27 2 +28 3 +29 4 +30 5 +31 6 +32 7 +33 8 +34 9 +35 0 \ No newline at end of file diff --git a/data/charset_62.txt b/data/charset_62.txt new file mode 100644 index 0000000000000000000000000000000000000000..0eab7a0b57a3a4f6c2ad08d3a6172423895909eb --- /dev/null +++ b/data/charset_62.txt @@ -0,0 +1,62 @@ +0 0 +1 1 +2 2 +3 3 +4 4 +5 5 +6 6 +7 7 +8 8 +9 9 +10 A +11 B +12 C +13 D +14 E +15 F +16 G +17 H +18 I +19 J +20 K +21 L +22 M +23 N +24 O +25 P +26 Q +27 R +28 S +29 T +30 U +31 V +32 W +33 X +34 Y +35 Z +36 a +37 b +38 c +39 d +40 e +41 f +42 g +43 h +44 i +45 j +46 k +47 l +48 m +49 n +50 o +51 p +52 q +53 r +54 s +55 t +56 u +57 v +58 w +59 x +60 y +61 z \ No newline at end of file diff --git a/dataset.py b/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..6e94b8436797e8f1bddd851ba375a4f26a6e6355 --- /dev/null +++ b/dataset.py @@ -0,0 +1,921 @@ +import os +import sys +import re +import six +import math +import lmdb +import torch +import copy +import random +import pickle + +from augmentation.weather import Fog, Snow, Frost +from augmentation.warp import Curve, Distort, Stretch +from augmentation.geometry import Rotate, Perspective, Shrink, TranslateX, TranslateY +from augmentation.pattern import VGrid, HGrid, Grid, RectGrid, EllipseGrid +from augmentation.noise import GaussianNoise, ShotNoise, ImpulseNoise, SpeckleNoise +from augmentation.blur import GaussianBlur, DefocusBlur, MotionBlur, GlassBlur, ZoomBlur +from augmentation.camera import Contrast, Brightness, JpegCompression, Pixelate +from augmentation.weather import Fog, Snow, Frost, Rain, Shadow +from augmentation.process import Posterize, Solarize, Invert, Equalize, AutoContrast, Sharpness, Color + +from natsort import natsorted +from PIL import Image +import PIL.ImageOps +import numpy as np +from torch.utils.data import Dataset, ConcatDataset, Subset +from torch._utils import _accumulate +import torchvision.transforms as transforms +import torchvision.transforms.functional as TF +import random + + +class Batch_Balanced_Dataset(object): + + def __init__(self, opt): + """ + Modulate the data ratio in the batch. + For example, when select_data is "MJ-ST" and batch_ratio is "0.5-0.5", + the 50% of the batch is filled with MJ and the other 50% of the batch is filled with ST. 
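+        For instance, with opt.batch_size=192 and batch_ratio "0.5-0.5", each source contributes
+        max(round(192 * 0.5), 1) = 96 samples per iteration, and opt.batch_size is afterwards
+        overwritten with the summed total.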
+ """ + if not os.path.exists(f'./saved_models/{opt.exp_name}/'): + os.makedirs(f'./saved_models/{opt.exp_name}/') + log = open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a') + dashed_line = '-' * 80 + print(dashed_line) + log.write(dashed_line + '\n') + print(f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}') + log.write(f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}\n') + assert len(opt.select_data) == len(opt.batch_ratio) + + _AlignCollate = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD, opt=opt) + self.data_loader_list = [] + self.dataloader_iter_list = [] + batch_size_list = [] + Total_batch_size = 0 + notSelectiveVal = True + if opt.selective_sample_str != '': + notSelectiveVal = False + for selected_d, batch_ratio_d in zip(opt.select_data, opt.batch_ratio): + _batch_size = max(round(opt.batch_size * float(batch_ratio_d)), 1) + print(dashed_line) + log.write(dashed_line + '\n') + _dataset, _dataset_log = hierarchical_dataset(root=opt.train_data, opt=opt, notSelective=notSelectiveVal, select_data=[selected_d]) + total_number_dataset = len(_dataset) + log.write(_dataset_log) + + """ + The total number of data can be modified with opt.total_data_usage_ratio. + ex) opt.total_data_usage_ratio = 1 indicates 100% usage, and 0.2 indicates 20% usage. + See 4.2 section in our paper. + """ + number_dataset = int(total_number_dataset * float(opt.total_data_usage_ratio)) + dataset_split = [number_dataset, total_number_dataset - number_dataset] + indices = range(total_number_dataset) + _dataset, _ = [Subset(_dataset, indices[offset - length:offset]) + for offset, length in zip(_accumulate(dataset_split), dataset_split)] + selected_d_log = f'num total samples of {selected_d}: {total_number_dataset} x {opt.total_data_usage_ratio} (total_data_usage_ratio) = {len(_dataset)}\n' + selected_d_log += f'num samples of {selected_d} per batch: {opt.batch_size} x {float(batch_ratio_d)} (batch_ratio) = {_batch_size}' + print(selected_d_log) + log.write(selected_d_log + '\n') + batch_size_list.append(str(_batch_size)) + Total_batch_size += _batch_size + + _data_loader = torch.utils.data.DataLoader( + _dataset, batch_size=_batch_size, + shuffle=True, + num_workers=int(opt.workers), + collate_fn=_AlignCollate, pin_memory=True) + self.data_loader_list.append(_data_loader) + self.dataloader_iter_list.append(iter(_data_loader)) + + Total_batch_size_log = f'{dashed_line}\n' + batch_size_sum = '+'.join(batch_size_list) + Total_batch_size_log += f'Total_batch_size: {batch_size_sum} = {Total_batch_size}\n' + Total_batch_size_log += f'{dashed_line}' + opt.batch_size = Total_batch_size + + print(Total_batch_size_log) + log.write(Total_batch_size_log + '\n') + log.close() + + def get_batch(self): + balanced_batch_images = [] + balanced_batch_texts = [] + + for i, data_loader_iter in enumerate(self.dataloader_iter_list): + try: + image, text = data_loader_iter.next() + balanced_batch_images.append(image) + balanced_batch_texts += text + except StopIteration: + self.dataloader_iter_list[i] = iter(self.data_loader_list[i]) + image, text = self.dataloader_iter_list[i].next() + balanced_batch_images.append(image) + balanced_batch_texts += text + except ValueError: + pass + + balanced_batch_images = torch.cat(balanced_batch_images, 0) + + return balanced_batch_images, balanced_batch_texts + +### notSelective - when False, LMDB dataset loader goes to the routine of randomly +### 
sampling indices to match --selective_sample_str, else it will no execute the code in the while loop +### and just do the normal VITSTR code +def hierarchical_dataset(root, opt, notSelective=True, select_data='/', segmRootDir=None, maxImages=None): + """ select_data='/' contains all sub-directory of root directory """ + dataset_list = [] + dataset_log = f'dataset_root: {root}\t dataset: {select_data[0]}' + print(dataset_log) + dataset_log += '\n' + for dirpath, dirnames, filenames in os.walk(root+'/'): + if not dirnames: + select_flag = False + for selected_d in select_data: + if selected_d in dirpath: + select_flag = True + break + + if select_flag: + if segmRootDir is None: + dataset = LmdbDataset(dirpath, opt, notSelective, maxImages=maxImages) + else: + dataset = LMDBSegmentationDataset(dirpath, opt, notSelective, segmRootDir=segmRootDir, maxImages=maxImages) + sub_dataset_log = f'sub-directory:\t/{os.path.relpath(dirpath, root)}\t num samples: {len(dataset)}' + print(sub_dataset_log) + dataset_log += f'{sub_dataset_log}\n' + dataset_list.append(dataset) + + concatenated_dataset = ConcatDataset(dataset_list) + + return concatenated_dataset, dataset_log + +class ValidDataset(Dataset): + ### validPklData - pickle containing mapping of validIdx to original train/test idx + ### knnDataRoot - root dir to open pickle file for knn, with forward slash + ### knnCount - max number of knn from 0-knnCount, not necessarily the same number as + ### inside the pickle knns + ### typeSet - if 'train' or 'test' + ### offsetStartIdx - start index of dataset to sample (0 to N-1), where N is size of valid test set + ### offsetEndIdx - end index of dataset to sample (0 to N-1), where N is size of valid test set + ### actual size of this dataset will be offsetStartIdx - offsetEndIdx + def __init__(self, validPklData, lmdbDataset, typeSet, knnDataRoot, knnCount=None, offsetStartIdx=None, offsetEndIdx=None): + self.validPklData = validPklData + self.lmdbDataset = lmdbDataset + self.typeSet = typeSet + self.knnCount = knnCount + self.totalValidImgs = len(validPklData) + self.knnDataRoot = knnDataRoot + ### this function is only for the test dataloader, remember to set batch size to one + self.currentIdx = None + self.knnPklData = None + self.offsetStartIdx = None + if offsetStartIdx is not None: + self.totalValidImgs = offsetEndIdx - offsetStartIdx + self.offsetStartIdx = offsetStartIdx + ### this function is purposely created for the trainset dataloader + ### call this function to load new pickle file for knn for training set + ### be sure to call this function before looping over the dataloader again + ### This function also applies offsetting for the test index num i + def setCurrentTestNumKNN(self, testValidIdx): + knnPklFile = self.knnDataRoot + "test" + str(testValidIdx + self.offsetStartIdx) + "knn.pkl" + with open(knnPklFile, 'rb') as f: + ### this data is a list of indices with index 0 nearest to the textValidIdx + ### according to FAISS KNN + self.knnPklData = pickle.load(f) + self.totalValidImgs = self.knnCount + ### index should be the same number thrown by __getitem__ function + ### this function will only work properly if the batch size of testdataloader is equal to one + def getValidPklIdx(self): + return self.currentIdx + def __len__(self): + return self.totalValidImgs + def __getitem__(self, index): + if self.typeSet == 'train': + data, label = self.lmdbDataset[self.validPklData[self.knnPklData[index]]] + elif self.typeSet == 'test': + if self.offsetStartIdx is not None: + index = index + 
self.offsetStartIdx + self.currentIdx = index + data, label = self.lmdbDataset[self.validPklData[index]] + else: + assert(False) + return data, label +class NShotDataset(Dataset): + ### infPKLFile - the influence file containing the validTrainIdx list + def __init__(self, infPKLData, validTrainPklData, lmdbDataset): + self.infPKLData = infPKLData + self.totalDataImg = len(infPKLData) + self.validTrainPklData = validTrainPklData + self.lmdbDataset = lmdbDataset + def __len__(self): + return self.totalDataImg + def __getitem__(self, index): + data, label = self.lmdbDataset[self.validTrainPklData[self.infPKLData[index]]] + return data, label +class LmdbDataset(Dataset): + + def __init__(self, root, opt, notSelective, maxImages=None): + + self.root = root + self.opt = opt + if self.opt.eval == False: + self.currentInfluenceLS = copy.deepcopy(self.opt.influence_idx) + random.shuffle(self.currentInfluenceLS) + self.notSelective = notSelective + self.selective_sample_ls = set([]) + self.env = lmdb.open(root, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False) + if not self.env: + print('cannot create lmdb from %s' % (root)) + sys.exit(0) + + with self.env.begin(write=False) as txn: + nSamples = int(txn.get('num-samples'.encode())) + if maxImages is not None: + nSamples = min(nSamples, maxImages) + self.nSamples = nSamples + + if self.opt.data_filtering_off: + # for fast check or benchmark evaluation with no filtering + self.filtered_index_list = [index + 1 for index in range(self.nSamples)] + else: + """ Filtering part + If you want to evaluate IC15-2077 & CUTE datasets which have special character labels, + use --data_filtering_off and only evaluate on alphabets and digits. + see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L190-L192 + + And if you want to evaluate them with the model trained with --sensitive option, + use --sensitive and --data_filtering_off, + see https://github.com/clovaai/deep-text-recognition-benchmark/blob/dff844874dbe9e0ec8c5a52a7bd08c7f20afe704/test.py#L137-L144 + """ + self.filtered_index_list = [] + for index in range(self.nSamples): + index += 1 # lmdb starts with 1 + label_key = 'label-%09d'.encode() % index + label = txn.get(label_key).decode('utf-8') + + if len(label) > self.opt.batch_max_length: + # print(f'The length of the label is longer than max_length: length + # {len(label)}, {label} in dataset {self.root}') + continue + + # By default, images containing characters which are not in opt.character are filtered. + # You can add [UNK] token to `opt.character` in utils.py instead of this filtering. 
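+                # Illustrative example, assuming the usual lowercase alphanumeric default for
+                # opt.character ('0123456789abcdefghijklmnopqrstuvwxyz'): the pattern below becomes
+                # r'[^0123456789abcdefghijklmnopqrstuvwxyz]', so re.search() flags labels such as
+                # 'he-llo' or 'hello!' and those samples are dropped from filtered_index_list.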
+ out_of_char = f'[^{self.opt.character}]' + if re.search(out_of_char, label.lower()): + continue + + self.filtered_index_list.append(index) + + self.nSamples = len(self.filtered_index_list) + + def __len__(self): + return self.nSamples + + def __getitem__(self, index): + assert index <= len(self), 'index range error' + + ### Used for influence function training + if self.opt.eval == False: + index = self.currentInfluenceLS.pop(len(self.currentInfluenceLS)-1) + if len(self.currentInfluenceLS) <= 0: + self.currentInfluenceLS = copy.deepcopy(self.opt.influence_idx) + random.shuffle(self.currentInfluenceLS) + + while True: + index = self.filtered_index_list[index] + + if self.opt.max_selective_list != -1: + if len(self.selective_sample_ls) >= self.opt.max_selective_list: + self.selective_sample_ls.clear() + + with self.env.begin(write=False) as txn: + label_key = 'label-%09d'.encode() % index + label = txn.get(label_key).decode('utf-8') ### label - raw utf8 string output + if self.opt.selective_sample_str != '' and not self.notSelective: + if self.opt.ignore_case_sensitivity: + if label.lower() != self.opt.selective_sample_str.lower(): + ### Reloop + self.selective_sample_ls.add(index) + while True: + index = random.randint(0, len(self)-1) + if index not in self.selective_sample_ls: break + continue + else: + if label != self.opt.selective_sample_str: + ### Reloop + self.selective_sample_ls.add(index) + while True: + index = random.randint(0, len(self)-1) + if index not in self.selective_sample_ls: break + continue + img_key = 'image-%09d'.encode() % index + imgbuf = txn.get(img_key) + + buf = six.BytesIO() + buf.write(imgbuf) + buf.seek(0) + try: + if self.opt.rgb: + img = Image.open(buf).convert('RGB') # for color image + else: + img = Image.open(buf).convert('L') + + except IOError: + print(f'Corrupted image for {index}') + # make dummy image and dummy label for corrupted image. + if self.opt.rgb: + img = Image.new('RGB', (self.opt.imgW, self.opt.imgH)) + else: + img = Image.new('L', (self.opt.imgW, self.opt.imgH)) + label = '[dummy_label]' + + if not self.opt.sensitive: + label = label.lower() + + # We only train and evaluate on alphanumerics (or pre-defined character set in train.py) + out_of_char = f'[^{self.opt.character}]' + label = re.sub(out_of_char, '', label) + break + return (img, label) + + +class RawDataset(Dataset): + + def __init__(self, root, opt): + self.opt = opt + self.image_path_list = [] + for dirpath, dirnames, filenames in os.walk(root): + for name in filenames: + _, ext = os.path.splitext(name) + ext = ext.lower() + if ext == '.jpg' or ext == '.jpeg' or ext == '.png': + self.image_path_list.append(os.path.join(dirpath, name)) + + self.image_path_list = natsorted(self.image_path_list) + self.nSamples = len(self.image_path_list) + + def __len__(self): + return self.nSamples + + def __getitem__(self, index): + + try: + if self.opt.rgb: + img = Image.open(self.image_path_list[index]).convert('RGB') # for color image + else: + img = Image.open(self.image_path_list[index]).convert('L') + + except IOError: + print(f'Corrupted image for {index}') + # make dummy image and dummy label for corrupted image. 
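+                    # The blank dummy image keeps the loader producing batches when an LMDB record
+                    # is corrupted; '[dummy_label]' is later stripped by the out_of_char re.sub
+                    # below (e.g. to 'dummylabel' under an alphanumeric charset).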
+ if self.opt.rgb: + img = Image.new('RGB', (self.opt.imgW, self.opt.imgH)) + else: + img = Image.new('L', (self.opt.imgW, self.opt.imgH)) + + return (img, self.image_path_list[index]) + + +def isless(prob=0.5): + return np.random.uniform(0,1) < prob + +class DataAugment(object): + ''' + Supports with and without data augmentation + ''' + def __init__(self, opt): + self.opt = opt + + if not opt.eval: + self.process = [Posterize(), Solarize(), Invert(), Equalize(), AutoContrast(), Sharpness(), Color()] + self.camera = [Contrast(), Brightness(), JpegCompression(), Pixelate()] + + self.pattern = [VGrid(), HGrid(), Grid(), RectGrid(), EllipseGrid()] + + self.noise = [GaussianNoise(), ShotNoise(), ImpulseNoise(), SpeckleNoise()] + self.blur = [GaussianBlur(), DefocusBlur(), MotionBlur(), GlassBlur(), ZoomBlur()] + self.weather = [Fog(), Snow(), Frost(), Rain(), Shadow()] + + self.noises = [self.blur, self.noise, self.weather] + self.processes = [self.camera, self.process] + + self.warp = [Curve(), Distort(), Stretch()] + self.geometry = [Rotate(), Perspective(), Shrink()] + + self.isbaseline_aug = False + # rand augment + if self.opt.isrand_aug: + self.augs = [self.process, self.camera, self.noise, self.blur, self.weather, self.pattern, self.warp, self.geometry] + # semantic augment + elif self.opt.issemantic_aug: + self.geometry = [Rotate(), Perspective(), Shrink()] + self.noise = [GaussianNoise()] + self.blur = [MotionBlur()] + self.augs = [self.noise, self.blur, self.geometry] + self.isbaseline_aug = True + # pp-ocr augment + elif self.opt.islearning_aug: + self.geometry = [Rotate(), Perspective()] + self.noise = [GaussianNoise()] + self.blur = [MotionBlur()] + self.warp = [Distort()] + self.augs = [self.warp, self.noise, self.blur, self.geometry] + self.isbaseline_aug = True + # scatter augment + elif self.opt.isscatter_aug: + self.geometry = [Shrink()] + self.warp = [Distort()] + self.augs = [self.warp, self.geometry] + self.baseline_aug = True + # rotation augment + elif self.opt.isrotation_aug: + self.geometry = [Rotate()] + self.augs = [self.geometry] + self.isbaseline_aug = True + + self.scale = False if opt.Transformer else True + + def __call__(self, img): + ''' + Must call img.copy() if pattern, Rain or Shadow is used + ''' + img = img.resize((self.opt.imgW, self.opt.imgH), Image.BICUBIC) + + if self.opt.eval or isless(self.opt.intact_prob): + pass + elif self.opt.isshap_aug: + img = self.shap_aug(img) + elif self.opt.isrand_aug or self.isbaseline_aug: + img = self.rand_aug(img) + # individual augment can also be selected + elif self.opt.issel_aug: + img = self.sel_aug(img) + + img = transforms.ToTensor()(img) + if self.scale: + img.sub_(0.5).div_(0.5) + return img + + + def rand_aug(self, img): + augs = np.random.choice(self.augs, self.opt.augs_num, replace=False) + for aug in augs: + index = np.random.randint(0, len(aug)) + op = aug[index] + mag = np.random.randint(0, 3) if self.opt.augs_mag is None else self.opt.augs_mag + if type(op).__name__ == "Rain" or type(op).__name__ == "Grid": + img = op(img.copy(), mag=mag) + else: + img = op(img, mag=mag) + + return img + + def shap_aug(self, img): + weatherProb = 0.094624746 + warpProb = 0.204524008 + geometryProb = 0.332274202 + noiseProb = 0.477033377 + cameraProb = 0.57329097 + patternProb = 0.743824929 + processProb = 0.845809948 + blurProb = 0.946237465 + noCorruptProb = 1 + + prob = 1. 
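+        # The *Prob constants above are cumulative thresholds on a single uniform draw
+        # (corrProb below), so each family's selection probability is the width of its interval:
+        # roughly weather 9.5%, warp 11.0%, geometry 12.8%, noise 14.5%, camera 9.6%,
+        # pattern 17.1%, process 10.2%, blur 10.0%, and about 5.4% of calls apply no corruption.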
+ iscurve = False + + corrProb = random.uniform(0, 1) + if corrProb >= 0 and corrProb < weatherProb: + mag = np.random.randint(self.opt.min_rand, self.opt.max_rand) + index = np.random.randint(0, len(self.weather)) + op = self.weather[index] + if type(op).__name__ == "Rain": #or "Grid" in type(op).__name__ : + img = op(img.copy(), mag=mag, prob=prob) + else: + img = op(img, mag=mag, prob=prob) + elif corrProb >= weatherProb and corrProb < warpProb: + mag = np.random.randint(self.opt.min_rand, self.opt.max_rand) + index = np.random.randint(0, len(self.warp)) + op = self.warp[index] + if type(op).__name__ == "Curve": + iscurve = True + img = op(img, mag=mag, prob=prob) + elif corrProb >= warpProb and corrProb < geometryProb: + mag = np.random.randint(self.opt.min_rand, self.opt.max_rand) + index = np.random.randint(0, len(self.geometry)) + op = self.geometry[index] + if type(op).__name__ == "Rotate": + img = op(img, iscurve=iscurve, mag=mag, prob=prob) + else: + img = op(img, mag=mag, prob=prob) + elif corrProb >= geometryProb and corrProb < noiseProb: + mag = np.random.randint(self.opt.min_rand, self.opt.max_rand) + index = np.random.randint(0, len(self.noise)) + op = self.noise[index] + img = op(img, mag=mag, prob=prob) + elif corrProb >= noiseProb and corrProb < cameraProb: + mag = np.random.randint(self.opt.min_rand, self.opt.max_rand) + index = np.random.randint(0, len(self.camera)) + op = self.camera[index] + img = op(img, mag=mag, prob=prob) + elif corrProb >= cameraProb and corrProb < patternProb: + mag = np.random.randint(self.opt.min_rand, self.opt.max_rand) + index = np.random.randint(0, len(self.pattern)) + op = self.pattern[index] + img = op(img.copy(), mag=mag, prob=prob) + elif corrProb >= patternProb and corrProb < processProb: + mag = np.random.randint(self.opt.min_rand, self.opt.max_rand) + index = np.random.randint(0, len(self.process)) + op = self.process[index] + img = op(img, mag=mag, prob=prob) + elif corrProb >= processProb and corrProb < blurProb: + mag = np.random.randint(self.opt.min_rand, self.opt.max_rand) + index = np.random.randint(0, len(self.blur)) + op = self.blur[index] + img = op(img, mag=mag, prob=prob) + elif corrProb >= blurProb and corrProb <= noCorruptProb: + pass + + return img + + def sel_aug(self, img): + + prob = 1. 
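+        # Unlike rand_aug/shap_aug, sel_aug applies every augmentation family whose flag is set
+        # (opt.process, opt.noise, opt.blur, opt.weather, opt.camera, opt.pattern, opt.warp,
+        # opt.geometry), each exactly once, picking a random operation from that family with a
+        # magnitude drawn from [opt.min_rand, opt.max_rand).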
+ + if self.opt.process: + mag = np.random.randint(self.opt.min_rand, self.opt.max_rand) + index = np.random.randint(0, len(self.process)) + op = self.process[index] + img = op(img, mag=mag, prob=prob) + + if self.opt.noise: + mag = np.random.randint(self.opt.min_rand, self.opt.max_rand) + index = np.random.randint(0, len(self.noise)) + op = self.noise[index] + img = op(img, mag=mag, prob=prob) + + if self.opt.blur: + mag = np.random.randint(self.opt.min_rand, self.opt.max_rand) + index = np.random.randint(0, len(self.blur)) + op = self.blur[index] + img = op(img, mag=mag, prob=prob) + + if self.opt.weather: + mag = np.random.randint(self.opt.min_rand, self.opt.max_rand) + index = np.random.randint(0, len(self.weather)) + op = self.weather[index] + if type(op).__name__ == "Rain": #or "Grid" in type(op).__name__ : + img = op(img.copy(), mag=mag, prob=prob) + else: + img = op(img, mag=mag, prob=prob) + + if self.opt.camera: + mag = np.random.randint(self.opt.min_rand, self.opt.max_rand) + index = np.random.randint(0, len(self.camera)) + op = self.camera[index] + img = op(img, mag=mag, prob=prob) + + if self.opt.pattern: + mag = np.random.randint(self.opt.min_rand, self.opt.max_rand) + index = np.random.randint(0, len(self.pattern)) + op = self.pattern[index] + img = op(img.copy(), mag=mag, prob=prob) + + iscurve = False + if self.opt.warp: + mag = np.random.randint(self.opt.min_rand, self.opt.max_rand) + index = np.random.randint(0, len(self.warp)) + op = self.warp[index] + if type(op).__name__ == "Curve": + iscurve = True + img = op(img, mag=mag, prob=prob) + + if self.opt.geometry: + mag = np.random.randint(self.opt.min_rand, self.opt.max_rand) + index = np.random.randint(0, len(self.geometry)) + op = self.geometry[index] + if type(op).__name__ == "Rotate": + img = op(img, iscurve=iscurve, mag=mag, prob=prob) + else: + img = op(img, mag=mag, prob=prob) + + return img + + +class ResizeNormalize(object): + + def __init__(self, size, interpolation=Image.BICUBIC): + self.size = size + self.interpolation = interpolation + self.toTensor = transforms.ToTensor() + + def __call__(self, img): + img = img.resize(self.size, self.interpolation) + img = self.toTensor(img) + img.sub_(0.5).div_(0.5) + return img + + +class NormalizePAD(object): + + def __init__(self, max_size, PAD_type='right'): + self.toTensor = transforms.ToTensor() + self.max_size = max_size + self.max_width_half = math.floor(max_size[2] / 2) + self.PAD_type = PAD_type + + def __call__(self, img): + img = self.toTensor(img) + img.sub_(0.5).div_(0.5) + c, h, w = img.size() + Pad_img = torch.FloatTensor(*self.max_size).fill_(0) + Pad_img[:, :, :w] = img # right pad + if self.max_size[2] != w: # add border Pad + Pad_img[:, :, w:] = img[:, :, w - 1].unsqueeze(2).expand(c, h, self.max_size[2] - w) + + return Pad_img + + +class AlignCollate(object): + + def __init__(self, imgH=32, imgW=100, keep_ratio_with_pad=False, opt=None): + self.imgH = imgH + self.imgW = imgW + self.keep_ratio_with_pad = keep_ratio_with_pad + self.opt = opt + + def __call__(self, batch): + # print("type batch: ", type(batch)) + # print("type batch[0]: ", type(batch[0])) + batch = filter(lambda x: x is not None, batch) + images, labels = zip(*batch) + + if self.keep_ratio_with_pad: # same concept with 'Rosetta' paper + resized_max_w = self.imgW + input_channel = 3 if images[0].mode == 'RGB' else 1 + transform = NormalizePAD((input_channel, self.imgH, resized_max_w)) + + resized_images = [] + for image in images: + w, h = image.size + ratio = w / float(h) + if 
math.ceil(self.imgH * ratio) > self.imgW: + resized_w = self.imgW + else: + resized_w = math.ceil(self.imgH * ratio) + + resized_image = image.resize((resized_w, self.imgH), Image.BICUBIC) + resized_images.append(transform(resized_image)) + # resized_image.save('./image_test/%d_test.jpg' % w) + + image_tensors = torch.cat([t.unsqueeze(0) for t in resized_images], 0) + + else: + transform = DataAugment(self.opt) + #i = 0 + #for image in images: + # transform(image) + # if i == 1: + # exit(0) + # else: + # i = i + 1 + image_tensors = [transform(image) for image in images] + image_tensors = torch.cat([t.unsqueeze(0) for t in image_tensors], 0) + + #else: + # transform = ResizeNormalize((self.imgW, self.imgH)) + # image_tensors = [transform(image) for image in images] + # image_tensors = torch.cat([t.unsqueeze(0) for t in image_tensors], 0) + + return image_tensors, labels + +class STRCharSegmDataset(Dataset): + ### imgRoot - above the ./images folder + ### minCharNum - set to 0 to deactivate. If greater than 0, this dataset will only output + ### images >= minCharNum + def __init__(self, annotFile, imgRoot, transforms, minCharNum=0,\ + charNum=-1, charToQuery=None): + self.transforms = transforms + self.minCharNum = minCharNum + with open(annotFile) as file: + self.lines = file.readlines() + self.filteredLines = [] + for lineStr in self.lines: + splitStr = lineStr.split() + gtLabel = splitStr[-1] + if self.minCharNum > 0 and len(gtLabel) >= self.minCharNum: + if charNum != -1 and gtLabel[charNum] == charToQuery: + self.filteredLines.append(lineStr) + self.totalItems = len(self.filteredLines) + self.imgRoot = imgRoot + + def __len__(self): + return self.totalItems + + def __getitem__(self, index): + lineStr = self.filteredLines[index] + splitStr = lineStr.split() + imgFilename = splitStr[0] + gtLabel = splitStr[-1] + imgPIL = Image.open(os.path.join(self.imgRoot, imgFilename)).convert('L') + imgPIL = self.transforms(imgPIL) + return imgPIL, gtLabel + +### Class simplifying the LMDB reader +class MyLMDBReader(Dataset): + ### indexMap - pass here the file created that maps indices from + ### limitedCharIdx ---> fullLMDBIdx + ### Should be of format = "char1_N" assumed to be getting only labels + ### where the first char is capital N. char1 is the first char. + ### maxImages - set this to a number to reduce dataset size + def __init__(self, root, opt, indexMap=None, charIdx=None, maxImages=None): + self.root = root + self.opt = opt + self.env = lmdb.open(root, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False) + self.indexMapList = None + if indexMap is not None: + with open(indexMap, 'rb') as f: + self.indexMapList = pickle.load(f)[charIdx] ### type list + lesserSize = min(len(self.indexMapList), maxImages) + self.indexMapList = self.indexMapList[:lesserSize] + if not self.env: + print('cannot create lmdb from %s' % (root)) + sys.exit(0) + + with self.env.begin(write=False) as txn: + self.nSamples = int(txn.get('num-samples'.encode())) + + if self.opt.data_filtering_off: + # for fast check or benchmark evaluation with no filtering + self.filtered_index_list = [index + 1 for index in range(self.nSamples)] + else: + """ Filtering part + If you want to evaluate IC15-2077 & CUTE datasets which have special character labels, + use --data_filtering_off and only evaluate on alphabets and digits. 
+ see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L190-L192 + + And if you want to evaluate them with the model trained with --sensitive option, + use --sensitive and --data_filtering_off, + see https://github.com/clovaai/deep-text-recognition-benchmark/blob/dff844874dbe9e0ec8c5a52a7bd08c7f20afe704/test.py#L137-L144 + """ + self.filtered_index_list = [] + for index in range(self.nSamples): + index += 1 # lmdb starts with 1 + label_key = 'label-%09d'.encode() % index + label = txn.get(label_key).decode('utf-8') + + if len(label) > self.opt.batch_max_length: + # print(f'The length of the label is longer than max_length: length + # {len(label)}, {label} in dataset {self.root}') + continue + + # By default, images containing characters which are not in opt.character are filtered. + # You can add [UNK] token to `opt.character` in utils.py instead of this filtering. + out_of_char = f'[^{self.opt.character}]' + if re.search(out_of_char, label.lower()): + continue + + self.filtered_index_list.append(index) + + self.nSamples = len(self.filtered_index_list) + + if self.indexMapList is not None: + self.nSamples = len(self.indexMapList) + + def __len__(self): + return self.nSamples + + def __getitem__(self, index): + ### Acquire mapped index of filtered char only dataset + if self.indexMapList is not None: + index = self.indexMapList[index] + # assert index <= len(self), 'index range error' + + while True: + index = self.filtered_index_list[index] + + with self.env.begin(write=False) as txn: + label_key = 'label-%09d'.encode() % index + label = txn.get(label_key).decode('utf-8') ### label - raw utf8 string output + img_key = 'image-%09d'.encode() % index + imgbuf = txn.get(img_key) + + buf = six.BytesIO() + buf.write(imgbuf) + buf.seek(0) + try: + if self.opt.rgb: + img = Image.open(buf).convert('RGB') # for color image + else: + img = Image.open(buf).convert('L') + + except IOError: + print(f'Corrupted image for {index}') + # make dummy image and dummy label for corrupted image. 
+ if self.opt.rgb: + img = Image.new('RGB', (self.opt.imgW, self.opt.imgH)) + else: + img = Image.new('L', (self.opt.imgW, self.opt.imgH)) + label = '[dummy_label]' + + if not self.opt.sensitive: + label = label.lower() + + # We only train and evaluate on alphanumerics (or pre-defined character set in train.py) + out_of_char = f'[^{self.opt.character}]' + label = re.sub(out_of_char, '', label) + break + return (img, label) + +class LMDBSegmentationDataset(LmdbDataset): + ### segmRootDir - if not None, + def __init__(self, root, opt, notSelective, segmRootDir, maxImages=None): + super().__init__(root, opt, notSelective, maxImages=maxImages) + self.segmRootDir = segmRootDir + + def __getitem__(self, index): + originalIdx = index + assert index <= len(self), 'index range error' + + ### Used for influence function training + if self.opt.eval == False: + index = self.currentInfluenceLS.pop(len(self.currentInfluenceLS)-1) + if len(self.currentInfluenceLS) <= 0: + self.currentInfluenceLS = copy.deepcopy(self.opt.influence_idx) + random.shuffle(self.currentInfluenceLS) + + while True: + index = self.filtered_index_list[index] + + if self.opt.max_selective_list != -1: + if len(self.selective_sample_ls) >= self.opt.max_selective_list: + self.selective_sample_ls.clear() + + with self.env.begin(write=False) as txn: + label_key = 'label-%09d'.encode() % index + label = txn.get(label_key).decode('utf-8') ### label - raw utf8 string output + if self.opt.selective_sample_str != '' and not self.notSelective: + if self.opt.ignore_case_sensitivity: + if label.lower() != self.opt.selective_sample_str.lower(): + ### Reloop + self.selective_sample_ls.add(index) + while True: + index = random.randint(0, len(self)-1) + if index not in self.selective_sample_ls: break + continue + else: + if label != self.opt.selective_sample_str: + ### Reloop + self.selective_sample_ls.add(index) + while True: + index = random.randint(0, len(self)-1) + if index not in self.selective_sample_ls: break + continue + img_key = 'image-%09d'.encode() % index + imgbuf = txn.get(img_key) + + buf = six.BytesIO() + buf.write(imgbuf) + buf.seek(0) + try: + if self.opt.rgb: + img = Image.open(buf).convert('RGB') # for color image + else: + img = Image.open(buf).convert('L') + + except IOError: + print(f'Corrupted image for {index}') + # make dummy image and dummy label for corrupted image. 
+ if self.opt.rgb: + img = Image.new('RGB', (self.opt.imgW, self.opt.imgH)) + else: + img = Image.new('L', (self.opt.imgW, self.opt.imgH)) + label = '[dummy_label]' + + if not self.opt.sensitive: + label = label.lower() + + # We only train and evaluate on alphanumerics (or pre-defined character set in train.py) + out_of_char = f'[^{self.opt.character}]' + label = re.sub(out_of_char, '', label) + break + + ### Acquire segmentations + with open(self.segmRootDir + "{}.pkl".format(originalIdx), 'rb') as f: + segmData = pickle.load(f) + label = (segmData, label) + return (img, label) + +def tensor2im(image_tensor, imtype=np.uint8): + image_numpy = image_tensor.cpu().float().numpy() + if image_numpy.shape[0] == 1: + image_numpy = np.tile(image_numpy, (3, 1, 1)) + image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 + return image_numpy.astype(imtype) + + +def save_image(image_numpy, image_path): + image_pil = Image.fromarray(image_numpy) + image_pil.save(image_path) diff --git a/dataset_abinet.py b/dataset_abinet.py new file mode 100644 index 0000000000000000000000000000000000000000..7d8b95aac412836c72c9c1bf93857276ef8af710 --- /dev/null +++ b/dataset_abinet.py @@ -0,0 +1,296 @@ +import logging +import re + +import cv2 +import lmdb +import six +from fastai.vision import * +from torchvision import transforms + +from transforms import CVColorJitter, CVDeterioration, CVGeometry +from utils_abinet import CharsetMapper, onehot + + +class ImageDataset(Dataset): + "`ImageDataset` read data from LMDB database." + + def __init__(self, + path:PathOrStr, + is_training:bool=True, + img_h:int=32, + img_w:int=100, + max_length:int=25, + check_length:bool=True, + case_sensitive:bool=False, + charset_path:str='data/charset_36.txt', + convert_mode:str='RGB', + data_aug:bool=True, + deteriorate_ratio:float=0., + multiscales:bool=True, + one_hot_y:bool=True, + return_idx:bool=False, + return_raw:bool=False, + **kwargs): + self.path, self.name = Path(path), Path(path).name + assert self.path.is_dir() and self.path.exists(), f"{path} is not a valid directory." + self.convert_mode, self.check_length = convert_mode, check_length + self.img_h, self.img_w = img_h, img_w + self.max_length, self.one_hot_y = max_length, one_hot_y + self.return_idx, self.return_raw = return_idx, return_raw + self.case_sensitive, self.is_training = case_sensitive, is_training + self.data_aug, self.multiscales = data_aug, multiscales + self.charset = CharsetMapper(charset_path, max_length=max_length+1) + self.c = self.charset.num_classes + + self.env = lmdb.open(str(path), readonly=True, lock=False, readahead=False, meminit=False) + assert self.env, f'Cannot open LMDB dataset from {path}.' 
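+        # The LMDB layout matches the other readers in this diff: 'num-samples' stores the record
+        # count and records are 1-indexed, e.g. (sketch):
+        #     with self.env.begin(write=False) as txn:
+        #         first_label = str(txn.get('label-000000001'.encode()), 'utf-8')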
+ with self.env.begin(write=False) as txn: + self.length = int(txn.get('num-samples'.encode())) + + if self.is_training and self.data_aug: + self.augment_tfs = transforms.Compose([ + CVGeometry(degrees=45, translate=(0.0, 0.0), scale=(0.5, 2.), shear=(45, 15), distortion=0.5, p=0.5), + CVDeterioration(var=20, degrees=6, factor=4, p=0.25), + CVColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, p=0.25) + ]) + self.totensor = transforms.ToTensor() + + def __len__(self): return self.length + + def _next_image(self, index): + next_index = random.randint(0, len(self) - 1) + return self.get(next_index) + + def _check_image(self, x, pixels=6): + if x.size[0] <= pixels or x.size[1] <= pixels: return False + else: return True + + def resize_multiscales(self, img, borderType=cv2.BORDER_CONSTANT): + def _resize_ratio(img, ratio, fix_h=True): + if ratio * self.img_w < self.img_h: + if fix_h: trg_h = self.img_h + else: trg_h = int(ratio * self.img_w) + trg_w = self.img_w + else: trg_h, trg_w = self.img_h, int(self.img_h / ratio) + img = cv2.resize(img, (trg_w, trg_h)) + pad_h, pad_w = (self.img_h - trg_h) / 2, (self.img_w - trg_w) / 2 + top, bottom = math.ceil(pad_h), math.floor(pad_h) + left, right = math.ceil(pad_w), math.floor(pad_w) + img = cv2.copyMakeBorder(img, top, bottom, left, right, borderType) + return img + + if self.is_training: + if random.random() < 0.5: + base, maxh, maxw = self.img_h, self.img_h, self.img_w + h, w = random.randint(base, maxh), random.randint(base, maxw) + return _resize_ratio(img, h/w) + else: return _resize_ratio(img, img.shape[0] / img.shape[1]) # keep aspect ratio + else: return _resize_ratio(img, img.shape[0] / img.shape[1]) # keep aspect ratio + + def resize(self, img): + if self.multiscales: return self.resize_multiscales(img, cv2.BORDER_REPLICATE) + else: return cv2.resize(img, (self.img_w, self.img_h)) + + def get(self, idx): + with self.env.begin(write=False) as txn: + image_key, label_key = f'image-{idx+1:09d}', f'label-{idx+1:09d}' + try: + label = str(txn.get(label_key.encode()), 'utf-8') # label + label = re.sub('[^0-9a-zA-Z]+', '', label) + if self.check_length and self.max_length > 0: + if len(label) > self.max_length or len(label) <= 0: + #logging.info(f'Long or short text image is found: {self.name}, {idx}, {label}, {len(label)}') + return self._next_image(idx) + label = label[:self.max_length] + + imgbuf = txn.get(image_key.encode()) # image + buf = six.BytesIO() + buf.write(imgbuf) + buf.seek(0) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", UserWarning) # EXIF warning from TiffPlugin + image = PIL.Image.open(buf).convert(self.convert_mode) + if self.is_training and not self._check_image(image): + #logging.info(f'Invalid image is found: {self.name}, {idx}, {label}, {len(label)}') + return self._next_image(idx) + except: + import traceback + traceback.print_exc() + # logging.info(f'Corrupted image is found: {self.name}, {idx}, {label}, {len(label)}') + return self._next_image(idx) + return image, label, idx + + def _process_training(self, image): + if self.data_aug: image = self.augment_tfs(image) + image = self.resize(np.array(image)) + return image + + def _process_test(self, image): + return self.resize(np.array(image)) # TODO:move is_training to here + + def __getitem__(self, idx): + image, text, idx_new = self.get(idx) + if not self.is_training: assert idx == idx_new, f'idx {idx} != idx_new {idx_new} during testing.' 
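+        # During testing get() must not have fallen back to _next_image() (which silently
+        # substitutes a random sample), hence the idx == idx_new check. The returned y is
+        # [label, length] (plus idx_new when return_idx is set), where length = len(text) + 1
+        # for the end token and label is one-hot encoded when one_hot_y is True.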
+ + if self.is_training: image = self._process_training(image) + else: image = self._process_test(image) + if self.return_raw: return image, text + image = self.totensor(image) + + length = tensor(len(text) + 1).to(dtype=torch.long) # one for end token + label = self.charset.get_labels(text, case_sensitive=self.case_sensitive) + label = tensor(label).to(dtype=torch.long) + if self.one_hot_y: label = onehot(label, self.charset.num_classes) + + if self.return_idx: y = [label, length, idx_new] + else: y = [label, length] + return image, y +class CustomImageDataset(ImageDataset): + def __getitem__(self, idx): + image, text, idx_new = self.get(idx) + if not self.is_training: assert idx == idx_new, f'idx {idx} != idx_new {idx_new} during testing.' + + if self.is_training: image = self._process_training(image) + else: image = self._process_test(image) + if self.return_raw: return image, text + image = self.totensor(image) + + length = tensor(len(text) + 1).to(dtype=torch.long) # one for end token + label = self.charset.get_labels(text, case_sensitive=self.case_sensitive) + label = tensor(label).to(dtype=torch.long) + if self.one_hot_y: label = onehot(label, self.charset.num_classes) + + if self.return_idx: y = [label, length, idx_new] + else: y = [label, length] + ### Major difference here is that 3 args are returned, with text in the middle + ### y - tensorized (text) + return image, text, y +class TextDataset(Dataset): + def __init__(self, + path:PathOrStr, + delimiter:str='\t', + max_length:int=25, + charset_path:str='data/charset_36.txt', + case_sensitive=False, + one_hot_x=True, + one_hot_y=True, + is_training=True, + smooth_label=False, + smooth_factor=0.2, + use_sm=False, + **kwargs): + self.path = Path(path) + self.case_sensitive, self.use_sm = case_sensitive, use_sm + self.smooth_factor, self.smooth_label = smooth_factor, smooth_label + self.charset = CharsetMapper(charset_path, max_length=max_length+1) + self.one_hot_x, self.one_hot_y, self.is_training = one_hot_x, one_hot_y, is_training + if self.is_training and self.use_sm: self.sm = SpellingMutation(charset=self.charset) + + dtype = {'inp': str, 'gt': str} + self.df = pd.read_csv(self.path, dtype=dtype, delimiter=delimiter, na_filter=False) + self.inp_col, self.gt_col = 0, 1 + + def __len__(self): return len(self.df) + + def __getitem__(self, idx): + text_x = self.df.iloc[idx, self.inp_col] + text_x = re.sub('[^0-9a-zA-Z]+', '', text_x) + if not self.case_sensitive: text_x = text_x.lower() + if self.is_training and self.use_sm: text_x = self.sm(text_x) + + length_x = tensor(len(text_x) + 1).to(dtype=torch.long) # one for end token + label_x = self.charset.get_labels(text_x, case_sensitive=self.case_sensitive) + label_x = tensor(label_x) + if self.one_hot_x: + label_x = onehot(label_x, self.charset.num_classes) + if self.is_training and self.smooth_label: + label_x = torch.stack([self.prob_smooth_label(l) for l in label_x]) + x = [label_x, length_x] + + text_y = self.df.iloc[idx, self.gt_col] + text_y = re.sub('[^0-9a-zA-Z]+', '', text_y) + if not self.case_sensitive: text_y = text_y.lower() + length_y = tensor(len(text_y) + 1).to(dtype=torch.long) # one for end token + label_y = self.charset.get_labels(text_y, case_sensitive=self.case_sensitive) + label_y = tensor(label_y) + if self.one_hot_y: label_y = onehot(label_y, self.charset.num_classes) + y = [label_y, length_y] + + return x, y + + def prob_smooth_label(self, one_hot): + one_hot = one_hot.float() + delta = torch.rand([]) * self.smooth_factor + num_classes = len(one_hot) + 
noise = torch.rand(num_classes) + noise = noise / noise.sum() * delta + one_hot = one_hot * (1 - delta) + noise + return one_hot + + +class SpellingMutation(object): + def __init__(self, pn0=0.7, pn1=0.85, pn2=0.95, pt0=0.7, pt1=0.85, charset=None): + """ + Args: + pn0: the prob of not modifying characters is (pn0) + pn1: the prob of modifying one characters is (pn1 - pn0) + pn2: the prob of modifying two characters is (pn2 - pn1), + and three (1 - pn2) + pt0: the prob of replacing operation is pt0. + pt1: the prob of inserting operation is (pt1 - pt0), + and deleting operation is (1 - pt1) + """ + super().__init__() + self.pn0, self.pn1, self.pn2 = pn0, pn1, pn2 + self.pt0, self.pt1 = pt0, pt1 + self.charset = charset + logging.info(f'the probs: pn0={self.pn0}, pn1={self.pn1} ' + + f'pn2={self.pn2}, pt0={self.pt0}, pt1={self.pt1}') + + def is_digit(self, text, ratio=0.5): + length = max(len(text), 1) + digit_num = sum([t in self.charset.digits for t in text]) + if digit_num / length < ratio: return False + return True + + def is_unk_char(self, char): + # return char == self.charset.unk_char + return (char not in self.charset.digits) and (char not in self.charset.alphabets) + + def get_num_to_modify(self, length): + prob = random.random() + if prob < self.pn0: num_to_modify = 0 + elif prob < self.pn1: num_to_modify = 1 + elif prob < self.pn2: num_to_modify = 2 + else: num_to_modify = 3 + + if length <= 1: num_to_modify = 0 + elif length >= 2 and length <= 4: num_to_modify = min(num_to_modify, 1) + else: num_to_modify = min(num_to_modify, length // 2) # smaller than length // 2 + return num_to_modify + + def __call__(self, text, debug=False): + if self.is_digit(text): return text + length = len(text) + num_to_modify = self.get_num_to_modify(length) + if num_to_modify <= 0: return text + + chars = [] + index = np.arange(0, length) + random.shuffle(index) + index = index[: num_to_modify] + if debug: self.index = index + for i, t in enumerate(text): + if i not in index: chars.append(t) + elif self.is_unk_char(t): chars.append(t) + else: + prob = random.random() + if prob < self.pt0: # replace + chars.append(random.choice(self.charset.alphabets)) + elif prob < self.pt1: # insert + chars.append(random.choice(self.charset.alphabets)) + chars.append(t) + else: # delete + continue + new_text = ''.join(chars[: self.charset.max_length-1]) + return new_text if len(new_text) >= 1 else text diff --git a/dataset_matrn.py b/dataset_matrn.py new file mode 100644 index 0000000000000000000000000000000000000000..7534250449232e6166b2e9e3311ce3cd4ffcd378 --- /dev/null +++ b/dataset_matrn.py @@ -0,0 +1,299 @@ +import logging +import re + +import cv2 +import lmdb +import six +from fastai.vision import * +from torchvision import transforms + +from transforms import CVColorJitter, CVDeterioration, CVGeometry +from utils_matrn import CharsetMapper, onehot + + +class ImageDataset(Dataset): + "`ImageDataset` read data from LMDB database." + + def __init__(self, + path:PathOrStr, + is_training:bool=True, + img_h:int=32, + img_w:int=100, + max_length:int=25, + check_length:bool=True, + case_sensitive:bool=False, + charset_path:str='data/charset_36.txt', + convert_mode:str='RGB', + data_aug:bool=True, + deteriorate_ratio:float=0., + multiscales:bool=True, + one_hot_y:bool=True, + return_idx:bool=False, + return_raw:bool=False, + **kwargs): + self.path, self.name = Path(path), Path(path).name + assert self.path.is_dir() and self.path.exists(), f"{path} is not a valid directory." 
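+        # charset_path defaults to data/charset_36.txt (added above): indices 0-25 map to a-z and
+        # 26-35 to the digits 1-9 followed by 0. The CharsetMapper built below uses
+        # max_length + 1, reserving an extra position (presumably for the end token, matching the
+        # '+ 1' used for sequence lengths in __getitem__).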
+ self.convert_mode, self.check_length = convert_mode, check_length + self.img_h, self.img_w = img_h, img_w + self.max_length, self.one_hot_y = max_length, one_hot_y + self.return_idx, self.return_raw = return_idx, return_raw + self.case_sensitive, self.is_training = case_sensitive, is_training + self.data_aug, self.multiscales = data_aug, multiscales + self.charset = CharsetMapper(charset_path, max_length=max_length+1) + self.c = self.charset.num_classes + + self.env = lmdb.open(str(path), readonly=True, lock=False, readahead=False, meminit=False) + assert self.env, f'Cannot open LMDB dataset from {path}.' + with self.env.begin(write=False) as txn: + self.length = int(txn.get('num-samples'.encode())) + + if self.is_training and self.data_aug: + self.augment_tfs = transforms.Compose([ + CVGeometry(degrees=45, translate=(0.0, 0.0), scale=(0.5, 2.), shear=(45, 15), distortion=0.5, p=0.5), + CVDeterioration(var=20, degrees=6, factor=4, p=0.25), + CVColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, p=0.25) + ]) + self.totensor = transforms.ToTensor() + + def __len__(self): return self.length + + def _next_image(self, index): + next_index = random.randint(0, len(self) - 1) + return self.get(next_index) + + def _check_image(self, x, pixels=6): + if x.size[0] <= pixels or x.size[1] <= pixels: return False + else: return True + + def resize_multiscales(self, img, borderType=cv2.BORDER_CONSTANT): + def _resize_ratio(img, ratio, fix_h=True): + if ratio * self.img_w < self.img_h: + if fix_h: trg_h = self.img_h + else: trg_h = int(ratio * self.img_w) + trg_w = self.img_w + else: trg_h, trg_w = self.img_h, int(self.img_h / ratio) + img = cv2.resize(img, (trg_w, trg_h)) + pad_h, pad_w = (self.img_h - trg_h) / 2, (self.img_w - trg_w) / 2 + top, bottom = math.ceil(pad_h), math.floor(pad_h) + left, right = math.ceil(pad_w), math.floor(pad_w) + img = cv2.copyMakeBorder(img, top, bottom, left, right, borderType) + return img + + if self.is_training: + if random.random() < 0.5: + base, maxh, maxw = self.img_h, self.img_h, self.img_w + h, w = random.randint(base, maxh), random.randint(base, maxw) + return _resize_ratio(img, h/w) + else: return _resize_ratio(img, img.shape[0] / img.shape[1]) # keep aspect ratio + else: return _resize_ratio(img, img.shape[0] / img.shape[1]) # keep aspect ratio + + def resize(self, img): + if self.multiscales: return self.resize_multiscales(img, cv2.BORDER_REPLICATE) + else: return cv2.resize(img, (self.img_w, self.img_h)) + + def get(self, idx): + with self.env.begin(write=False) as txn: + image_key, label_key = f'image-{idx+1:09d}', f'label-{idx+1:09d}' + try: + label = str(txn.get(label_key.encode()), 'utf-8') # label + label = re.sub('[^0-9a-zA-Z]+', '', label) + if self.check_length and self.max_length > 0: + if len(label) > self.max_length or len(label) <= 0: + #logging.info(f'Long or short text image is found: {self.name}, {idx}, {label}, {len(label)}') + return self._next_image(idx) + label = label[:self.max_length] + + imgbuf = txn.get(image_key.encode()) # image + buf = six.BytesIO() + buf.write(imgbuf) + buf.seek(0) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", UserWarning) # EXIF warning from TiffPlugin + image = PIL.Image.open(buf).convert(self.convert_mode) + if self.is_training and not self._check_image(image): + #logging.info(f'Invalid image is found: {self.name}, {idx}, {label}, {len(label)}') + return self._next_image(idx) + except: + import traceback + traceback.print_exc() + logging.info(f'Corrupted image is found: 
{self.name}, {idx}, {label}, {len(label)}') + return self._next_image(idx) + return image, label, idx + + def _process_training(self, image): + if self.data_aug: image = self.augment_tfs(image) + image = self.resize(np.array(image)) + return image + + def _process_test(self, image): + return self.resize(np.array(image)) # TODO:move is_training to here + + def __getitem__(self, idx): + image, text, idx_new = self.get(idx) + if not self.is_training: assert idx == idx_new, f'idx {idx} != idx_new {idx_new} during testing.' + + if self.is_training: image = self._process_training(image) + else: image = self._process_test(image) + if self.return_raw: return image, text + image = self.totensor(image) + + length = tensor(len(text) + 1).to(dtype=torch.long) # one for end token + label = self.charset.get_labels(text, case_sensitive=self.case_sensitive) + label = tensor(label).to(dtype=torch.long) + if self.one_hot_y: label = onehot(label, self.charset.num_classes) + + if self.return_idx: y = [label, length, idx_new] + else: y = [label, length] + return image, y + + +class TextDataset(Dataset): + def __init__(self, + path:PathOrStr, + delimiter:str='\t', + max_length:int=25, + charset_path:str='data/charset_36.txt', + case_sensitive=False, + one_hot_x=True, + one_hot_y=True, + is_training=True, + smooth_label=False, + smooth_factor=0.2, + use_sm=False, + **kwargs): + self.path = Path(path) + self.case_sensitive, self.use_sm = case_sensitive, use_sm + self.smooth_factor, self.smooth_label = smooth_factor, smooth_label + self.charset = CharsetMapper(charset_path, max_length=max_length+1) + self.one_hot_x, self.one_hot_y, self.is_training = one_hot_x, one_hot_y, is_training + if self.is_training and self.use_sm: self.sm = SpellingMutation(charset=self.charset) + + dtype = {'inp': str, 'gt': str} + self.df = pd.read_csv(self.path, dtype=dtype, delimiter=delimiter, na_filter=False) + self.inp_col, self.gt_col = 0, 1 + + def __len__(self): return len(self.df) + + def __getitem__(self, idx): + text_x = self.df.iloc[idx, self.inp_col] + text_x = re.sub('[^0-9a-zA-Z]+', '', text_x) + if not self.case_sensitive: text_x = text_x.lower() + if self.is_training and self.use_sm: text_x = self.sm(text_x) + + length_x = tensor(len(text_x) + 1).to(dtype=torch.long) # one for end token + label_x = self.charset.get_labels(text_x, case_sensitive=self.case_sensitive) + label_x = tensor(label_x) + if self.one_hot_x: + label_x = onehot(label_x, self.charset.num_classes) + if self.is_training and self.smooth_label: + label_x = torch.stack([self.prob_smooth_label(l) for l in label_x]) + x = [label_x, length_x] + + text_y = self.df.iloc[idx, self.gt_col] + text_y = re.sub('[^0-9a-zA-Z]+', '', text_y) + if not self.case_sensitive: text_y = text_y.lower() + length_y = tensor(len(text_y) + 1).to(dtype=torch.long) # one for end token + label_y = self.charset.get_labels(text_y, case_sensitive=self.case_sensitive) + label_y = tensor(label_y) + if self.one_hot_y: label_y = onehot(label_y, self.charset.num_classes) + y = [label_y, length_y] + + return x, y + + def prob_smooth_label(self, one_hot): + one_hot = one_hot.float() + delta = torch.rand([]) * self.smooth_factor + num_classes = len(one_hot) + noise = torch.rand(num_classes) + noise = noise / noise.sum() * delta + one_hot = one_hot * (1 - delta) + noise + return one_hot + + +class SpellingMutation(object): + def __init__(self, pn0=0.7, pn1=0.85, pn2=0.95, pt0=0.7, pt1=0.85, charset=None): + """ + Args: + pn0: the prob of not modifying characters is (pn0) + pn1: the prob of 
modifying one characters is (pn1 - pn0) + pn2: the prob of modifying two characters is (pn2 - pn1), + and three (1 - pn2) + pt0: the prob of replacing operation is pt0. + pt1: the prob of inserting operation is (pt1 - pt0), + and deleting operation is (1 - pt1) + """ + super().__init__() + self.pn0, self.pn1, self.pn2 = pn0, pn1, pn2 + self.pt0, self.pt1 = pt0, pt1 + self.charset = charset + logging.info(f'the probs: pn0={self.pn0}, pn1={self.pn1} ' + + f'pn2={self.pn2}, pt0={self.pt0}, pt1={self.pt1}') + + def is_digit(self, text, ratio=0.5): + length = max(len(text), 1) + digit_num = sum([t in self.charset.digits for t in text]) + if digit_num / length < ratio: return False + return True + + def is_unk_char(self, char): + # return char == self.charset.unk_char + return (char not in self.charset.digits) and (char not in self.charset.alphabets) + + def get_num_to_modify(self, length): + prob = random.random() + if prob < self.pn0: num_to_modify = 0 + elif prob < self.pn1: num_to_modify = 1 + elif prob < self.pn2: num_to_modify = 2 + else: num_to_modify = 3 + + if length <= 1: num_to_modify = 0 + elif length >= 2 and length <= 4: num_to_modify = min(num_to_modify, 1) + else: num_to_modify = min(num_to_modify, length // 2) # smaller than length // 2 + return num_to_modify + + def __call__(self, text, debug=False): + if self.is_digit(text): return text + length = len(text) + num_to_modify = self.get_num_to_modify(length) + if num_to_modify <= 0: return text + + chars = [] + index = np.arange(0, length) + random.shuffle(index) + index = index[: num_to_modify] + if debug: self.index = index + for i, t in enumerate(text): + if i not in index: chars.append(t) + elif self.is_unk_char(t): chars.append(t) + else: + prob = random.random() + if prob < self.pt0: # replace + chars.append(random.choice(self.charset.alphabets)) + elif prob < self.pt1: # insert + chars.append(random.choice(self.charset.alphabets)) + chars.append(t) + else: # delete + continue + new_text = ''.join(chars[: self.charset.max_length-1]) + return new_text if len(new_text) >= 1 else text + +class CustomImageDataset(ImageDataset): + def __getitem__(self, idx): + image, text, idx_new = self.get(idx) + if not self.is_training: assert idx == idx_new, f'idx {idx} != idx_new {idx_new} during testing.' 
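+        # Same flow as ImageDataset.__getitem__; the only difference is that the raw
+        # text string is returned alongside the tensorized label (see the note at the
+        # end of this method).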
+ + if self.is_training: image = self._process_training(image) + else: image = self._process_test(image) + if self.return_raw: return image, text + image = self.totensor(image) + + length = tensor(len(text) + 1).to(dtype=torch.long) # one for end token + label = self.charset.get_labels(text, case_sensitive=self.case_sensitive) + label = tensor(label).to(dtype=torch.long) + if self.one_hot_y: label = onehot(label, self.charset.num_classes) + + if self.return_idx: y = [label, length, idx_new] + else: y = [label, length] + ### Major difference here is that 3 args are returned, with text in the middle + ### y - tensorized (text) + return image, text, y diff --git a/dataset_trba.py b/dataset_trba.py new file mode 100644 index 0000000000000000000000000000000000000000..7d1438c523ccdf67e02836cfa5bad74fd73b4b48 --- /dev/null +++ b/dataset_trba.py @@ -0,0 +1,422 @@ +import os +import sys +import re +import six +import math +import lmdb +import torch +import cv2 + +from natsort import natsorted +from PIL import Image +import numpy as np +from torch.utils.data import Dataset, ConcatDataset, Subset +from torch._utils import _accumulate +import torchvision.transforms as transforms +from imagenet_c import corrupt + +class Batch_Balanced_Dataset(object): + + def __init__(self, opt): + """ + Modulate the data ratio in the batch. + For example, when select_data is "MJ-ST" and batch_ratio is "0.5-0.5", + the 50% of the batch is filled with MJ and the other 50% of the batch is filled with ST. + """ + log = open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a') + dashed_line = '-' * 80 + print(dashed_line) + log.write(dashed_line + '\n') + print(f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}') + log.write(f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}\n') + assert len(opt.select_data) == len(opt.batch_ratio) + + _AlignCollate = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) + self.data_loader_list = [] + self.dataloader_iter_list = [] + batch_size_list = [] + Total_batch_size = 0 + for selected_d, batch_ratio_d in zip(opt.select_data, opt.batch_ratio): + _batch_size = max(round(opt.batch_size * float(batch_ratio_d)), 1) + print(dashed_line) + log.write(dashed_line + '\n') + _dataset, _dataset_log = hierarchical_dataset(root=opt.train_data, opt=opt, select_data=[selected_d]) + total_number_dataset = len(_dataset) + log.write(_dataset_log) + + """ + The total number of data can be modified with opt.total_data_usage_ratio. + ex) opt.total_data_usage_ratio = 1 indicates 100% usage, and 0.2 indicates 20% usage. + See 4.2 section in our paper. 
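+            Illustrative walk-through (numbers here are hypothetical, not from the paper):
+            with opt.batch_size=192, select_data='MJ-ST' and batch_ratio='0.5-0.5',
+            each sub-loader receives _batch_size = max(round(192 * 0.5), 1) = 96 samples,
+            and opt.total_data_usage_ratio=0.2 keeps only the first 20% of each LMDB
+            subset via torch.utils.data.Subset before its DataLoader is built.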
+ """ + number_dataset = int(total_number_dataset * float(opt.total_data_usage_ratio)) + dataset_split = [number_dataset, total_number_dataset - number_dataset] + indices = range(total_number_dataset) + _dataset, _ = [Subset(_dataset, indices[offset - length:offset]) + for offset, length in zip(_accumulate(dataset_split), dataset_split)] + selected_d_log = f'num total samples of {selected_d}: {total_number_dataset} x {opt.total_data_usage_ratio} (total_data_usage_ratio) = {len(_dataset)}\n' + selected_d_log += f'num samples of {selected_d} per batch: {opt.batch_size} x {float(batch_ratio_d)} (batch_ratio) = {_batch_size}' + print(selected_d_log) + log.write(selected_d_log + '\n') + batch_size_list.append(str(_batch_size)) + Total_batch_size += _batch_size + + _data_loader = torch.utils.data.DataLoader( + _dataset, batch_size=_batch_size, + shuffle=True, + num_workers=int(opt.workers), + collate_fn=_AlignCollate, pin_memory=True) + self.data_loader_list.append(_data_loader) + self.dataloader_iter_list.append(iter(_data_loader)) + + Total_batch_size_log = f'{dashed_line}\n' + batch_size_sum = '+'.join(batch_size_list) + Total_batch_size_log += f'Total_batch_size: {batch_size_sum} = {Total_batch_size}\n' + Total_batch_size_log += f'{dashed_line}' + opt.batch_size = Total_batch_size + + print(Total_batch_size_log) + log.write(Total_batch_size_log + '\n') + log.close() + + def get_batch(self): + balanced_batch_images = [] + balanced_batch_texts = [] + + for i, data_loader_iter in enumerate(self.dataloader_iter_list): + try: + image, text = data_loader_iter.next() + balanced_batch_images.append(image) + balanced_batch_texts += text + except StopIteration: + self.dataloader_iter_list[i] = iter(self.data_loader_list[i]) + image, text = self.dataloader_iter_list[i].next() + balanced_batch_images.append(image) + balanced_batch_texts += text + except ValueError: + pass + + balanced_batch_images = torch.cat(balanced_batch_images, 0) + + return balanced_batch_images, balanced_batch_texts + +### targetDir - used for outputting original image +def hierarchical_dataset(root, opt, targetDir, select_data='/'): + """ select_data='/' contains all sub-directory of root directory """ + dataset_list = [] + dataset_log = f'dataset_root: {root}\t dataset: {select_data[0]}' + print(dataset_log) + dataset_log += '\n' + for dirpath, dirnames, filenames in os.walk(root+'/'): + if not dirnames: + select_flag = False + for selected_d in select_data: + if selected_d in dirpath: + select_flag = True + break + + if select_flag: + dataset = LmdbDataset(root=dirpath, opt=opt, targetDir=targetDir) + sub_dataset_log = f'sub-directory:\t/{os.path.relpath(dirpath, root)}\t num samples: {len(dataset)}' + print(sub_dataset_log) + dataset_log += f'{sub_dataset_log}\n' + dataset_list.append(dataset) + + concatenated_dataset = ConcatDataset(dataset_list) + + return concatenated_dataset, dataset_log + + +class LmdbDataset(Dataset): + + def __init__(self, root, opt, targetDir, apply_corruptions=False, severity=1, corruptNum=0): + + self.root = root + self.opt = opt + self.targetDir = targetDir + self.apply_corruptions = opt.apply_corruptions + self.severity = opt.severity + self.corruptNum = opt.corruption_num + self.env = lmdb.open(root, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False) + if not self.env: + print('cannot create lmdb from %s' % (root)) + sys.exit(0) + + with self.env.begin(write=False) as txn: + nSamples = int(txn.get('num-samples'.encode())) + self.nSamples = nSamples + + if 
self.opt.data_filtering_off: + # for fast check or benchmark evaluation with no filtering + self.filtered_index_list = [index + 1 for index in range(self.nSamples)] + else: + """ Filtering part + If you want to evaluate IC15-2077 & CUTE datasets which have special character labels, + use --data_filtering_off and only evaluate on alphabets and digits. + see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L190-L192 + + And if you want to evaluate them with the model trained with --sensitive option, + use --sensitive and --data_filtering_off, + see https://github.com/clovaai/deep-text-recognition-benchmark/blob/dff844874dbe9e0ec8c5a52a7bd08c7f20afe704/test.py#L137-L144 + """ + self.filtered_index_list = [] + for index in range(self.nSamples): + index += 1 # lmdb starts with 1 + if not self.opt.compare_corrupt: + label_key = 'label-%09d'.encode() % index + else: + label_key = 'label-%09d-%04d'.encode() % (index, self.corruptNum) + label = txn.get(label_key).decode('utf-8') + + if len(label) > self.opt.batch_max_length: + # print(f'The length of the label is longer than max_length: length + # {len(label)}, {label} in dataset {self.root}') + continue + + # By default, images containing characters which are not in opt.character are filtered. + # You can add [UNK] token to `opt.character` in utils.py instead of this filtering. + out_of_char = f'[^{self.opt.character}]' + if re.search(out_of_char, label.lower()): + continue + + self.filtered_index_list.append(index) + + self.nSamples = len(self.filtered_index_list) + + def __len__(self): + return self.nSamples + + def __getitem__(self, index): + origIdxInt = index + assert index <= len(self), 'index range error' + index = self.filtered_index_list[index] + if not self.opt.compare_corrupt: + with self.env.begin(write=False) as txn: + label_key = 'label-%09d'.encode() % index + label = txn.get(label_key).decode('utf-8') + img_key = 'image-%09d'.encode() % index + imgbuf = txn.get(img_key) + + buf = six.BytesIO() + buf.write(imgbuf) + buf.seek(0) + + if self.opt.output_orig: + if index >= self.opt.min_imgnum+1 and index <= self.opt.max_imgnum+1: + orig_img = Image.open(buf).convert('RGB') + # totalDir = self.targetDir + "img{}/".format(index) + # if not os.path.exists(totalDir): + # os.makedirs(totalDir) + # orig_img.save(totalDir + "corr{}.png".format(self.corruptNum)) + origImgFileName = self.targetDir + "img_{}.png".format(origIdxInt) + orig_img.save(origImgFileName) + if not self.apply_corruptions: + try: + if self.opt.rgb: + img = Image.open(buf).convert('RGB') # for color image + else: + img = Image.open(buf).convert('L') + + except IOError: + print(f'Corrupted image for {index}') + # make dummy image and dummy label for corrupted image. 
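+                        # The dummy image matches the configured input size (opt.imgW x opt.imgH)
+                        # so batch collation still works; the placeholder label simply marks the
+                        # sample as unreadable instead of dropping it from the batch.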
+ if self.opt.rgb: + img = Image.new('RGB', (self.opt.imgW, self.opt.imgH)) + else: + img = Image.new('L', (self.opt.imgW, self.opt.imgH)) + label = '[dummy_label]' + else: + ### Apply corruption + img = Image.open(buf).convert('RGB') ### 0-255 pillow image value + img = np.array(img).astype(np.uint8) + maxHeight = int(max(224, img.shape[0])) + maxWidth = int(max(224, img.shape[1])) + largerWidth = img.shape[1] > img.shape[0] ### True if width of img is larger than height + if largerWidth: + if img.shape[1] > 224: + imgResized = cv2.resize(img, (224, int(224.*img.shape[0]/img.shape[1]))) + else: + imgResized = img + else: + if img.shape[0] > 224: + imgResized = cv2.resize(img, (int(224.*img.shape[1]/img.shape[0]), 224)) + else: + imgResized = img + tempNP = np.zeros((224, 224, 3)).astype(np.uint8) + tempNP[0:imgResized.shape[0], 0:imgResized.shape[1]] = imgResized + corruptedImg = corrupt(tempNP, severity=self.severity, corruption_number=self.corruptNum) + cropOriginal = corruptedImg[0:imgResized.shape[0], 1:imgResized.shape[1]] + if self.opt.rgb: + img = Image.fromarray(np.uint8(cropOriginal)).convert('RGB') + else: + img = Image.fromarray(np.uint8(cropOriginal)).convert('L') + if not self.opt.sensitive: + label = label.lower() + # We only train and evaluate on alphanumerics (or pre-defined character set in train.py) + out_of_char = f'[^{self.opt.character}]' + label = re.sub(out_of_char, '', label) + return (img, label) + else: + with self.env.begin(write=False) as txn: + label_key = 'label-%09d-%04d'.encode() % (index, self.corruptNum) + label = txn.get(label_key).decode('utf-8') + img_key = 'image-%09d-%04d'.encode() % (index, self.corruptNum) + imgbuf = txn.get(img_key) + + buf = six.BytesIO() + buf.write(imgbuf) + buf.seek(0) + + if self.opt.output_orig: + if index >= self.opt.min_imgnum+1 and index <= self.opt.max_imgnum+1: + orig_img = Image.open(buf).convert('RGB') + totalDir = self.targetDir + "img{}/".format(index) + if not os.path.exists(totalDir): + os.makedirs(totalDir) + orig_img.save(totalDir + "corr{}.png".format(self.corruptNum)) + + try: + if self.opt.rgb: + img = Image.open(buf).convert('RGB') # for color image + else: + img = Image.open(buf).convert('L') + + except IOError: + print(f'Corrupted image for {index}') + # make dummy image and dummy label for corrupted image. 
+ if self.opt.rgb: + img = Image.new('RGB', (self.opt.imgW, self.opt.imgH)) + else: + img = Image.new('L', (self.opt.imgW, self.opt.imgH)) + label = '[dummy_label]' + + if not self.opt.sensitive: + label = label.lower() + # We only train and evaluate on alphanumerics (or pre-defined character set in train.py) + out_of_char = f'[^{self.opt.character}]' + label = re.sub(out_of_char, '', label) + return (img, label) + +class RawDataset(Dataset): + + def __init__(self, root, opt): + self.opt = opt + self.image_path_list = [] + for dirpath, dirnames, filenames in os.walk(root): + for name in filenames: + _, ext = os.path.splitext(name) + ext = ext.lower() + if ext == '.jpg' or ext == '.jpeg' or ext == '.png': + self.image_path_list.append(os.path.join(dirpath, name)) + + self.image_path_list = natsorted(self.image_path_list) + self.nSamples = len(self.image_path_list) + + def __len__(self): + return self.nSamples + + def __getitem__(self, index): + + try: + if self.opt.rgb: + img = Image.open(self.image_path_list[index]).convert('RGB') # for color image + else: + img = Image.open(self.image_path_list[index]).convert('L') + + except IOError: + print(f'Corrupted image for {index}') + # make dummy image and dummy label for corrupted image. + if self.opt.rgb: + img = Image.new('RGB', (self.opt.imgW, self.opt.imgH)) + else: + ### "Passed here" + img = Image.new('L', (self.opt.imgW, self.opt.imgH)) + + return (img, self.image_path_list[index]) + + +class ResizeNormalize(object): + + def __init__(self, size, interpolation=Image.BICUBIC): + self.size = size + self.interpolation = interpolation + self.toTensor = transforms.ToTensor() + + def __call__(self, img): + img = img.resize(self.size, self.interpolation) + img = self.toTensor(img) + img.sub_(0.5).div_(0.5) + return img + + +class NormalizePAD(object): + + def __init__(self, max_size, PAD_type='right'): + self.toTensor = transforms.ToTensor() + self.max_size = max_size + self.max_width_half = math.floor(max_size[2] / 2) + self.PAD_type = PAD_type + + def __call__(self, img): + img = self.toTensor(img) + img.sub_(0.5).div_(0.5) + c, h, w = img.size() + Pad_img = torch.FloatTensor(*self.max_size).fill_(0) + Pad_img[:, :, :w] = img # right pad + if self.max_size[2] != w: # add border Pad + Pad_img[:, :, w:] = img[:, :, w - 1].unsqueeze(2).expand(c, h, self.max_size[2] - w) + + return Pad_img + + +class AlignCollate(object): + + def __init__(self, imgH=32, imgW=100, keep_ratio_with_pad=False): + self.imgH = imgH + self.imgW = imgW + self.keep_ratio_with_pad = keep_ratio_with_pad + + def __call__(self, batch): + ### batch is a list of tuple. 
The __getitem__() function returns a pillow Image, so you have to use Image.size to see shape + ### print("collate batch[0] shape: ", batch[0][0].size) + batch = filter(lambda x: x is not None, batch) + images, labels = zip(*batch) + + if self.keep_ratio_with_pad: # same concept with 'Rosetta' paper + resized_max_w = self.imgW + input_channel = 3 if images[0].mode == 'RGB' else 1 + transform = NormalizePAD((input_channel, self.imgH, resized_max_w)) + + resized_images = [] + for image in images: + w, h = image.size + ratio = w / float(h) + if math.ceil(self.imgH * ratio) > self.imgW: + resized_w = self.imgW + else: + resized_w = math.ceil(self.imgH * ratio) + + resized_image = image.resize((resized_w, self.imgH), Image.BICUBIC) + resized_images.append(transform(resized_image)) + # resized_image.save('./image_test/%d_test.jpg' % w) + + image_tensors = torch.cat([t.unsqueeze(0) for t in resized_images], 0) + + else: + transform = ResizeNormalize((self.imgW, self.imgH)) + image_tensors = [transform(image) for image in images] + image_tensors = torch.cat([t.unsqueeze(0) for t in image_tensors], 0) + return image_tensors, labels + + +def tensor2im(image_tensor, imtype=np.uint8): + image_numpy = image_tensor.cpu().float().numpy() + if image_numpy.shape[0] == 1: + image_numpy = np.tile(image_numpy, (3, 1, 1)) + image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 + return image_numpy.astype(imtype) + + +def save_image(image_numpy, image_path): + image_pil = Image.fromarray(image_numpy) + image_pil.save(image_path) diff --git a/demo_image/demo_10.jpg b/demo_image/demo_10.jpg new file mode 100644 index 0000000000000000000000000000000000000000..597d3d2303a21c699fda94665e40c6532dc9a106 Binary files /dev/null and b/demo_image/demo_10.jpg differ diff --git a/demo_image/demo_2.jpg b/demo_image/demo_2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..01d3f6bd39484d540374dc558da6794f4846ff80 Binary files /dev/null and b/demo_image/demo_2.jpg differ diff --git a/demo_image/demo_8.jpg b/demo_image/demo_8.jpg new file mode 100644 index 0000000000000000000000000000000000000000..de7f6c9d1166940486be2248e8b67f642903abe4 Binary files /dev/null and b/demo_image/demo_8.jpg differ diff --git a/demo_image/demo_9.jpg b/demo_image/demo_9.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b02945f52ca32ca90a950769c5e46157bc676823 Binary files /dev/null and b/demo_image/demo_9.jpg differ diff --git a/imagenet_c/__init__.py b/imagenet_c/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..23c0ed4799361da33c059727de04136f76eba009 --- /dev/null +++ b/imagenet_c/__init__.py @@ -0,0 +1,35 @@ +import numpy as np +from PIL import Image +from .corruptions import * + +corruption_tuple = (gaussian_noise, shot_noise, impulse_noise, defocus_blur, + glass_blur, motion_blur, zoom_blur, snow, frost, fog, + brightness, contrast, elastic_transform, pixelate, jpeg_compression, + speckle_noise, gaussian_blur, spatter, saturate) + +corruption_dict = {corr_func.__name__: corr_func for corr_func in corruption_tuple} + + +def corrupt(x, severity=1, corruption_name=None, corruption_number=-1): + """ + :param x: image to corrupt; a 224x224x3 numpy array in [0, 255] + :param severity: strength with which to corrupt x; an integer in (0, 5] + :param corruption_name: specifies which corruption function to call; + must be one of 'gaussian_noise', 'shot_noise', 'impulse_noise', 'defocus_blur', + 'glass_blur', 'motion_blur', 'zoom_blur', 'snow', 'frost', 'fog', + 
'brightness', 'contrast', 'elastic_transform', 'pixelate', 'jpeg_compression', + 'speckle_noise', 'gaussian_blur', 'spatter', 'saturate'; + the last four are validation functions + :param corruption_number: the position of the corruption_name in the above list; + an integer in [0, 18]; useful for easy looping; 15, 16, 17, 18 are validation corruption numbers + :return: the image x corrupted by a corruption function at the given severity; same shape as input + """ + + if corruption_name: + x_corrupted = corruption_dict[corruption_name](Image.fromarray(x), severity) + elif corruption_number != -1: + x_corrupted = corruption_tuple[corruption_number](Image.fromarray(x), severity) + else: + raise ValueError("Either corruption_name or corruption_number must be passed") + + return np.uint8(x_corrupted) diff --git a/imagenet_c/corruptions.py b/imagenet_c/corruptions.py new file mode 100644 index 0000000000000000000000000000000000000000..f0bf2d495dc9514e8e6f800d1d2a753f781f9dd8 --- /dev/null +++ b/imagenet_c/corruptions.py @@ -0,0 +1,427 @@ +# -*- coding: utf-8 -*- + +import numpy as np +from PIL import Image + +# /////////////// Corruption Helpers /////////////// + +import skimage as sk +from skimage.filters import gaussian +from io import BytesIO +from wand.image import Image as WandImage +from wand.api import library as wandlibrary +import wand.color as WandColor +import ctypes +from PIL import Image as PILImage +import cv2 +from scipy.ndimage import zoom as scizoom +from scipy.ndimage.interpolation import map_coordinates +import warnings +import os +from pkg_resources import resource_filename + +warnings.simplefilter("ignore", UserWarning) + + +def disk(radius, alias_blur=0.1, dtype=np.float32): + if radius <= 8: + L = np.arange(-8, 8 + 1) + ksize = (3, 3) + else: + L = np.arange(-radius, radius + 1) + ksize = (5, 5) + X, Y = np.meshgrid(L, L) + aliased_disk = np.array((X ** 2 + Y ** 2) <= radius ** 2, dtype=dtype) + aliased_disk /= np.sum(aliased_disk) + + # supersample disk to antialias + return cv2.GaussianBlur(aliased_disk, ksize=ksize, sigmaX=alias_blur) + + +# Tell Python about the C method +wandlibrary.MagickMotionBlurImage.argtypes = (ctypes.c_void_p, # wand + ctypes.c_double, # radius + ctypes.c_double, # sigma + ctypes.c_double) # angle + + +# Extend wand.image.Image class to include method signature +class MotionImage(WandImage): + def motion_blur(self, radius=0.0, sigma=0.0, angle=0.0): + wandlibrary.MagickMotionBlurImage(self.wand, radius, sigma, angle) + + +# modification of https://github.com/FLHerne/mapgen/blob/master/diamondsquare.py +def plasma_fractal(mapsize=256, wibbledecay=3): + """ + Generate a heightmap using diamond-square algorithm. + Return square 2d array, side length 'mapsize', of floats in range 0-255. + 'mapsize' must be a power of two. 
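+
+    Note added for clarity: the map is normalised before being returned, so the
+    values actually lie in [0, 1]; fog() rescales them afterwards. Sketch of the
+    loop below: starting at stepsize = mapsize, fillsquares() and filldiamonds()
+    set midpoints to the mean of their neighbours plus noise of amplitude
+    'wibble', then stepsize is halved and 'wibble' divided by wibbledecay, so the
+    coarsest scales receive the strongest perturbations.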
+ """ + assert (mapsize & (mapsize - 1) == 0) + maparray = np.empty((mapsize, mapsize), dtype=np.float_) + maparray[0, 0] = 0 + stepsize = mapsize + wibble = 100 + + def wibbledmean(array): + return array / 4 + wibble * np.random.uniform(-wibble, wibble, array.shape) + + def fillsquares(): + """For each square of points stepsize apart, + calculate middle value as mean of points + wibble""" + cornerref = maparray[0:mapsize:stepsize, 0:mapsize:stepsize] + squareaccum = cornerref + np.roll(cornerref, shift=-1, axis=0) + squareaccum += np.roll(squareaccum, shift=-1, axis=1) + maparray[stepsize // 2:mapsize:stepsize, + stepsize // 2:mapsize:stepsize] = wibbledmean(squareaccum) + + def filldiamonds(): + """For each diamond of points stepsize apart, + calculate middle value as mean of points + wibble""" + mapsize = maparray.shape[0] + drgrid = maparray[stepsize // 2:mapsize:stepsize, stepsize // 2:mapsize:stepsize] + ulgrid = maparray[0:mapsize:stepsize, 0:mapsize:stepsize] + ldrsum = drgrid + np.roll(drgrid, 1, axis=0) + lulsum = ulgrid + np.roll(ulgrid, -1, axis=1) + ltsum = ldrsum + lulsum + maparray[0:mapsize:stepsize, stepsize // 2:mapsize:stepsize] = wibbledmean(ltsum) + tdrsum = drgrid + np.roll(drgrid, 1, axis=1) + tulsum = ulgrid + np.roll(ulgrid, -1, axis=0) + ttsum = tdrsum + tulsum + maparray[stepsize // 2:mapsize:stepsize, 0:mapsize:stepsize] = wibbledmean(ttsum) + + while stepsize >= 2: + fillsquares() + filldiamonds() + stepsize //= 2 + wibble /= wibbledecay + + maparray -= maparray.min() + return maparray / maparray.max() + + +def clipped_zoom(img, zoom_factor): + h = img.shape[0] + # ceil crop height(= crop width) + ch = int(np.ceil(h / float(zoom_factor))) + + top = (h - ch) // 2 + img = scizoom(img[top:top + ch, top:top + ch], (zoom_factor, zoom_factor, 1), order=1) + # trim off any extra pixels + trim_top = (img.shape[0] - h) // 2 + + return img[trim_top:trim_top + h, trim_top:trim_top + h] + + +# /////////////// End Corruption Helpers /////////////// + + +# /////////////// Corruptions /////////////// + +def gaussian_noise(x, severity=1): + c = [.08, .12, 0.18, 0.26, 0.38][severity - 1] + + x = np.array(x) / 255. + return np.clip(x + np.random.normal(size=x.shape, scale=c), 0, 1) * 255 + + +def shot_noise(x, severity=1): + c = [60, 25, 12, 5, 3][severity - 1] + + x = np.array(x) / 255. + return np.clip(np.random.poisson(x * c) / float(c), 0, 1) * 255 + + +def impulse_noise(x, severity=1): + c = [.03, .06, .09, 0.17, 0.27][severity - 1] + + x = sk.util.random_noise(np.array(x) / 255., mode='s&p', amount=c) + return np.clip(x, 0, 1) * 255 + + +def speckle_noise(x, severity=1): + c = [.15, .2, 0.35, 0.45, 0.6][severity - 1] + + x = np.array(x) / 255. + return np.clip(x + x * np.random.normal(size=x.shape, scale=c), 0, 1) * 255 + + +def fgsm(x, source_net, severity=1): + c = [8, 16, 32, 64, 128][severity - 1] + + x = V(x, requires_grad=True) + logits = source_net(x) + source_net.zero_grad() + loss = F.cross_entropy(logits, V(logits.data.max(1)[1].squeeze_()), size_average=False) + loss.backward() + + return standardize(torch.clamp(unstandardize(x.data) + c / 255. 
* unstandardize(torch.sign(x.grad.data)), 0, 1)) + + +def gaussian_blur(x, severity=1): + c = [1, 2, 3, 4, 6][severity - 1] + + x = gaussian(np.array(x) / 255., sigma=c, multichannel=True) + return np.clip(x, 0, 1) * 255 + + +def glass_blur(x, severity=1): + # sigma, max_delta, iterations + c = [(0.7, 1, 2), (0.9, 2, 1), (1, 2, 3), (1.1, 3, 2), (1.5, 4, 2)][severity - 1] + + x = np.uint8(gaussian(np.array(x) / 255., sigma=c[0], multichannel=True) * 255) + + # locally shuffle pixels + for i in range(c[2]): + for h in range(224 - c[1], c[1], -1): + for w in range(224 - c[1], c[1], -1): + dx, dy = np.random.randint(-c[1], c[1], size=(2,)) + h_prime, w_prime = h + dy, w + dx + # swap + x[h, w], x[h_prime, w_prime] = x[h_prime, w_prime], x[h, w] + + return np.clip(gaussian(x / 255., sigma=c[0], multichannel=True), 0, 1) * 255 + + +def defocus_blur(x, severity=1): + c = [(3, 0.1), (4, 0.5), (6, 0.5), (8, 0.5), (10, 0.5)][severity - 1] + + x = np.array(x) / 255. + kernel = disk(radius=c[0], alias_blur=c[1]) + + channels = [] + for d in range(3): + channels.append(cv2.filter2D(x[:, :, d], -1, kernel)) + channels = np.array(channels).transpose((1, 2, 0)) # 3x224x224 -> 224x224x3 + + return np.clip(channels, 0, 1) * 255 + + +def motion_blur(x, severity=1): + c = [(10, 3), (15, 5), (15, 8), (15, 12), (20, 15)][severity - 1] + + output = BytesIO() + x.save(output, format='PNG') + x = MotionImage(blob=output.getvalue()) + + x.motion_blur(radius=c[0], sigma=c[1], angle=np.random.uniform(-45, 45)) + + x = cv2.imdecode(np.fromstring(x.make_blob(), np.uint8), + cv2.IMREAD_UNCHANGED) + + if x.shape != (224, 224): + return np.clip(x[..., [2, 1, 0]], 0, 255) # BGR to RGB + else: # greyscale to RGB + return np.clip(np.array([x, x, x]).transpose((1, 2, 0)), 0, 255) + + +def zoom_blur(x, severity=1): + c = [np.arange(1, 1.11, 0.01), + np.arange(1, 1.16, 0.01), + np.arange(1, 1.21, 0.02), + np.arange(1, 1.26, 0.02), + np.arange(1, 1.31, 0.03)][severity - 1] + + x = (np.array(x) / 255.).astype(np.float32) + out = np.zeros_like(x) + for zoom_factor in c: + out += clipped_zoom(x, zoom_factor) + + x = (x + out) / (len(c) + 1) + return np.clip(x, 0, 1) * 255 + + +def fog(x, severity=1): + c = [(1.5, 2), (2., 2), (2.5, 1.7), (2.5, 1.5), (3., 1.4)][severity - 1] + + x = np.array(x) / 255. + max_val = x.max() + x += c[0] * plasma_fractal(wibbledecay=c[1])[:224, :224][..., np.newaxis] + return np.clip(x * max_val / (max_val + c[0]), 0, 1) * 255 + + +def frost(x, severity=1): + c = [(1, 0.4), + (0.8, 0.6), + (0.7, 0.7), + (0.65, 0.7), + (0.6, 0.75)][severity - 1] + idx = np.random.randint(5) + filename = [resource_filename(__name__, 'frost/frost1.png'), + resource_filename(__name__, 'frost/frost2.png'), + resource_filename(__name__, 'frost/frost3.png'), + resource_filename(__name__, 'frost/frost4.jpg'), + resource_filename(__name__, 'frost/frost5.jpg'), + resource_filename(__name__, 'frost/frost6.jpg')][idx] + frost = cv2.imread(filename) + # randomly crop and convert to rgb + x_start, y_start = np.random.randint(0, frost.shape[0] - 224), np.random.randint(0, frost.shape[1] - 224) + frost = frost[x_start:x_start + 224, y_start:y_start + 224][..., [2, 1, 0]] + + return np.clip(c[0] * np.array(x) + c[1] * frost, 0, 255) + + +def snow(x, severity=1): + c = [(0.1, 0.3, 3, 0.5, 10, 4, 0.8), + (0.2, 0.3, 2, 0.5, 12, 4, 0.7), + (0.55, 0.3, 4, 0.9, 12, 8, 0.7), + (0.55, 0.3, 4.5, 0.85, 12, 8, 0.65), + (0.55, 0.3, 2.5, 0.85, 12, 12, 0.55)][severity - 1] + + x = np.array(x, dtype=np.float32) / 255. 
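+    # What follows: a Gaussian noise field is zoomed and thresholded into sparse
+    # snowflakes, motion-blurred through the ImageMagick binding, decoded back with
+    # OpenCV, and composited (together with a 180-degree rotated copy) over a
+    # brightened version of the input.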
+ snow_layer = np.random.normal(size=x.shape[:2], loc=c[0], scale=c[1]) # [:2] for monochrome + + snow_layer = clipped_zoom(snow_layer[..., np.newaxis], c[2]) + snow_layer[snow_layer < c[3]] = 0 + + snow_layer = PILImage.fromarray((np.clip(snow_layer.squeeze(), 0, 1) * 255).astype(np.uint8), mode='L') + output = BytesIO() + snow_layer.save(output, format='PNG') + snow_layer = MotionImage(blob=output.getvalue()) + + snow_layer.motion_blur(radius=c[4], sigma=c[5], angle=np.random.uniform(-135, -45)) + + snow_layer = cv2.imdecode(np.fromstring(snow_layer.make_blob(), np.uint8), + cv2.IMREAD_UNCHANGED) / 255. + snow_layer = snow_layer[..., np.newaxis] + + x = c[6] * x + (1 - c[6]) * np.maximum(x, cv2.cvtColor(x, cv2.COLOR_RGB2GRAY).reshape(224, 224, 1) * 1.5 + 0.5) + return np.clip(x + snow_layer + np.rot90(snow_layer, k=2), 0, 1) * 255 + + +def spatter(x, severity=1): + c = [(0.65, 0.3, 4, 0.69, 0.6, 0), + (0.65, 0.3, 3, 0.68, 0.6, 0), + (0.65, 0.3, 2, 0.68, 0.5, 0), + (0.65, 0.3, 1, 0.65, 1.5, 1), + (0.67, 0.4, 1, 0.65, 1.5, 1)][severity - 1] + x = np.array(x, dtype=np.float32) / 255. + + liquid_layer = np.random.normal(size=x.shape[:2], loc=c[0], scale=c[1]) + + liquid_layer = gaussian(liquid_layer, sigma=c[2]) + liquid_layer[liquid_layer < c[3]] = 0 + if c[5] == 0: + liquid_layer = (liquid_layer * 255).astype(np.uint8) + dist = 255 - cv2.Canny(liquid_layer, 50, 150) + dist = cv2.distanceTransform(dist, cv2.DIST_L2, 5) + _, dist = cv2.threshold(dist, 20, 20, cv2.THRESH_TRUNC) + dist = cv2.blur(dist, (3, 3)).astype(np.uint8) + dist = cv2.equalizeHist(dist) + ker = np.array([[-2, -1, 0], [-1, 1, 1], [0, 1, 2]]) + dist = cv2.filter2D(dist, cv2.CV_8U, ker) + dist = cv2.blur(dist, (3, 3)).astype(np.float32) + + m = cv2.cvtColor(liquid_layer * dist, cv2.COLOR_GRAY2BGRA) + m /= np.max(m, axis=(0, 1)) + m *= c[4] + + # water is pale turqouise + color = np.concatenate((175 / 255. * np.ones_like(m[..., :1]), + 238 / 255. * np.ones_like(m[..., :1]), + 238 / 255. * np.ones_like(m[..., :1])), axis=2) + + color = cv2.cvtColor(color, cv2.COLOR_BGR2BGRA) + x = cv2.cvtColor(x, cv2.COLOR_BGR2BGRA) + + return cv2.cvtColor(np.clip(x + m * color, 0, 1), cv2.COLOR_BGRA2BGR) * 255 + else: + m = np.where(liquid_layer > c[3], 1, 0) + m = gaussian(m.astype(np.float32), sigma=c[4]) + m[m < 0.8] = 0 + + # mud brown + color = np.concatenate((63 / 255. * np.ones_like(x[..., :1]), + 42 / 255. * np.ones_like(x[..., :1]), + 20 / 255. * np.ones_like(x[..., :1])), axis=2) + + color *= m[..., np.newaxis] + x *= (1 - m[..., np.newaxis]) + + return np.clip(x + color, 0, 1) * 255 + + +def contrast(x, severity=1): + c = [0.4, .3, .2, .1, .05][severity - 1] + + x = np.array(x) / 255. + means = np.mean(x, axis=(0, 1), keepdims=True) + return np.clip((x - means) * c + means, 0, 1) * 255 + + +def brightness(x, severity=1): + c = [.1, .2, .3, .4, .5][severity - 1] + + x = np.array(x) / 255. + x = sk.color.rgb2hsv(x) + x[:, :, 2] = np.clip(x[:, :, 2] + c, 0, 1) + x = sk.color.hsv2rgb(x) + + return np.clip(x, 0, 1) * 255 + + +def saturate(x, severity=1): + c = [(0.3, 0), (0.1, 0), (2, 0), (5, 0.1), (20, 0.2)][severity - 1] + + x = np.array(x) / 255. 
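+    # Saturation is adjusted in HSV space: the S channel is scaled by c[0] and
+    # offset by c[1], then clipped to [0, 1] and converted back to RGB.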
+ x = sk.color.rgb2hsv(x) + x[:, :, 1] = np.clip(x[:, :, 1] * c[0] + c[1], 0, 1) + x = sk.color.hsv2rgb(x) + + return np.clip(x, 0, 1) * 255 + + +def jpeg_compression(x, severity=1): + c = [25, 18, 15, 10, 7][severity - 1] + + output = BytesIO() + x.save(output, 'JPEG', quality=c) + x = PILImage.open(output) + + return x + + +def pixelate(x, severity=1): + c = [0.6, 0.5, 0.4, 0.3, 0.25][severity - 1] + + x = x.resize((int(224 * c), int(224 * c)), PILImage.BOX) + x = x.resize((224, 224), PILImage.BOX) + + return x + + +# mod of https://gist.github.com/erniejunior/601cdf56d2b424757de5 +def elastic_transform(image, severity=1): + c = [(244 * 2, 244 * 0.7, 244 * 0.1), # 244 should have been 224, but ultimately nothing is incorrect + (244 * 2, 244 * 0.08, 244 * 0.2), + (244 * 0.05, 244 * 0.01, 244 * 0.02), + (244 * 0.07, 244 * 0.01, 244 * 0.02), + (244 * 0.12, 244 * 0.01, 244 * 0.02)][severity - 1] + + image = np.array(image, dtype=np.float32) / 255. + shape = image.shape + shape_size = shape[:2] + + # random affine + center_square = np.float32(shape_size) // 2 + square_size = min(shape_size) // 3 + pts1 = np.float32([center_square + square_size, + [center_square[0] + square_size, center_square[1] - square_size], + center_square - square_size]) + pts2 = pts1 + np.random.uniform(-c[2], c[2], size=pts1.shape).astype(np.float32) + M = cv2.getAffineTransform(pts1, pts2) + image = cv2.warpAffine(image, M, shape_size[::-1], borderMode=cv2.BORDER_REFLECT_101) + + dx = (gaussian(np.random.uniform(-1, 1, size=shape[:2]), + c[1], mode='reflect', truncate=3) * c[0]).astype(np.float32) + dy = (gaussian(np.random.uniform(-1, 1, size=shape[:2]), + c[1], mode='reflect', truncate=3) * c[0]).astype(np.float32) + dx, dy = dx[..., np.newaxis], dy[..., np.newaxis] + + x, y, z = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]), np.arange(shape[2])) + indices = np.reshape(y + dy, (-1, 1)), np.reshape(x + dx, (-1, 1)), np.reshape(z, (-1, 1)) + return np.clip(map_coordinates(image, indices, order=1, mode='reflect').reshape(shape), 0, 1) * 255 + + +# /////////////// End Corruptions /////////////// diff --git a/lime/__init__.py b/lime/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lime/bundle.js b/lime/bundle.js new file mode 100644 index 0000000000000000000000000000000000000000..87bf784a039d9f9b7f066ffbd33587c4b5fcd2f4 --- /dev/null +++ b/lime/bundle.js @@ -0,0 +1,37089 @@ +var lime = +/******/ (function(modules) { // webpackBootstrap +/******/ // The module cache +/******/ var installedModules = {}; +/******/ +/******/ // The require function +/******/ function __webpack_require__(moduleId) { +/******/ +/******/ // Check if module is in cache +/******/ if(installedModules[moduleId]) +/******/ return installedModules[moduleId].exports; +/******/ +/******/ // Create a new module (and put it into the cache) +/******/ var module = installedModules[moduleId] = { +/******/ exports: {}, +/******/ id: moduleId, +/******/ loaded: false +/******/ }; +/******/ +/******/ // Execute the module function +/******/ modules[moduleId].call(module.exports, module, module.exports, __webpack_require__); +/******/ +/******/ // Flag the module as loaded +/******/ module.loaded = true; +/******/ +/******/ // Return the exports of the module +/******/ return module.exports; +/******/ } +/******/ +/******/ +/******/ // expose the modules object (__webpack_modules__) +/******/ __webpack_require__.m = modules; +/******/ +/******/ // expose the module 
cache +/******/ __webpack_require__.c = installedModules; +/******/ +/******/ // __webpack_public_path__ +/******/ __webpack_require__.p = ""; +/******/ +/******/ // Load entry module and return exports +/******/ return __webpack_require__(0); +/******/ }) +/************************************************************************/ +/******/ ([ +/* 0 */ +/***/ (function(module, exports, __webpack_require__) { + + /* WEBPACK VAR INJECTION */(function(global) {'use strict'; + + Object.defineProperty(exports, "__esModule", { + value: true + }); + exports.PredictedValue = exports.PredictProba = exports.Barchart = exports.Explanation = undefined; + + var _explanation = __webpack_require__(1); + + var _explanation2 = _interopRequireDefault(_explanation); + + var _bar_chart = __webpack_require__(3); + + var _bar_chart2 = _interopRequireDefault(_bar_chart); + + var _predict_proba = __webpack_require__(6); + + var _predict_proba2 = _interopRequireDefault(_predict_proba); + + var _predicted_value = __webpack_require__(7); + + var _predicted_value2 = _interopRequireDefault(_predicted_value); + + function _interopRequireDefault(obj) { return obj && obj.__esModule ? obj : { default: obj }; } + + if (!global._babelPolyfill) { + __webpack_require__(8); + } + + __webpack_require__(339); + + exports.Explanation = _explanation2.default; + exports.Barchart = _bar_chart2.default; + exports.PredictProba = _predict_proba2.default; + exports.PredictedValue = _predicted_value2.default; + //require('style-loader'); + /* WEBPACK VAR INJECTION */}.call(exports, (function() { return this; }()))) + +/***/ }), +/* 1 */ +/***/ (function(module, exports, __webpack_require__) { + + 'use strict'; + + Object.defineProperty(exports, "__esModule", { + value: true + }); + + var _slicedToArray = function () { function sliceIterator(arr, i) { var _arr = []; var _n = true; var _d = false; var _e = undefined; try { for (var _i = arr[Symbol.iterator](), _s; !(_n = (_s = _i.next()).done); _n = true) { _arr.push(_s.value); if (i && _arr.length === i) break; } } catch (err) { _d = true; _e = err; } finally { try { if (!_n && _i["return"]) _i["return"](); } finally { if (_d) throw _e; } } return _arr; } return function (arr, i) { if (Array.isArray(arr)) { return arr; } else if (Symbol.iterator in Object(arr)) { return sliceIterator(arr, i); } else { throw new TypeError("Invalid attempt to destructure non-iterable instance"); } }; }(); + + var _d2 = __webpack_require__(2); + + var _d3 = _interopRequireDefault(_d2); + + var _bar_chart = __webpack_require__(3); + + var _bar_chart2 = _interopRequireDefault(_bar_chart); + + var _lodash = __webpack_require__(4); + + function _interopRequireDefault(obj) { return obj && obj.__esModule ? obj : { default: obj }; } + + function _classCallCheck(instance, Constructor) { if (!(instance instanceof Constructor)) { throw new TypeError("Cannot call a class as a function"); } } + + var Explanation = function () { + function Explanation(class_names) { + _classCallCheck(this, Explanation); + + this.names = class_names; + if (class_names.length < 10) { + this.colors = _d3.default.scale.category10().domain(this.names); + this.colors_i = _d3.default.scale.category10().domain((0, _lodash.range)(this.names.length)); + } else { + this.colors = _d3.default.scale.category20().domain(this.names); + this.colors_i = _d3.default.scale.category20().domain((0, _lodash.range)(this.names.length)); + } + } + // exp: [(feature-name, weight), ...] 
+ // label: int + // div: d3 selection + + + Explanation.prototype.show = function show(exp, label, div) { + var svg = div.append('svg').style('width', '100%'); + var colors = ['#5F9EA0', this.colors_i(label)]; + var names = ['NOT ' + this.names[label], this.names[label]]; + if (this.names.length == 2) { + colors = [this.colors_i(0), this.colors_i(1)]; + names = this.names; + } + var plot = new _bar_chart2.default(svg, exp, true, names, colors, true, 10); + svg.style('height', plot.svg_height + 'px'); + }; + // exp has all ocurrences of words, with start index and weight: + // exp = [('word', 132, -0.13), ('word3', 111, 1.3) + + + Explanation.prototype.show_raw_text = function show_raw_text(exp, label, raw, div) { + var opacity = arguments.length > 4 && arguments[4] !== undefined ? arguments[4] : true; + + //let colors=['#5F9EA0', this.colors(this.exp['class'])]; + var colors = ['#5F9EA0', this.colors_i(label)]; + if (this.names.length == 2) { + colors = [this.colors_i(0), this.colors_i(1)]; + } + var word_lists = [[], []]; + var max_weight = -1; + var _iteratorNormalCompletion = true; + var _didIteratorError = false; + var _iteratorError = undefined; + + try { + for (var _iterator = exp[Symbol.iterator](), _step; !(_iteratorNormalCompletion = (_step = _iterator.next()).done); _iteratorNormalCompletion = true) { + var _step$value = _slicedToArray(_step.value, 3), + word = _step$value[0], + start = _step$value[1], + weight = _step$value[2]; + + if (weight > 0) { + word_lists[1].push([start, start + word.length, weight]); + } else { + word_lists[0].push([start, start + word.length, -weight]); + } + max_weight = Math.max(max_weight, Math.abs(weight)); + } + } catch (err) { + _didIteratorError = true; + _iteratorError = err; + } finally { + try { + if (!_iteratorNormalCompletion && _iterator.return) { + _iterator.return(); + } + } finally { + if (_didIteratorError) { + throw _iteratorError; + } + } + } + + if (!opacity) { + max_weight = 0; + } + this.display_raw_text(div, raw, word_lists, colors, max_weight, true); + }; + // exp is list of (feature_name, value, weight) + + + Explanation.prototype.show_raw_tabular = function show_raw_tabular(exp, label, div) { + div.classed('lime', true).classed('table_div', true); + var colors = ['#5F9EA0', this.colors_i(label)]; + if (this.names.length == 2) { + colors = [this.colors_i(0), this.colors_i(1)]; + } + var table = div.append('table'); + var thead = table.append('tr'); + thead.append('td').text('Feature'); + thead.append('td').text('Value'); + thead.style('color', 'black').style('font-size', '20px'); + var _iteratorNormalCompletion2 = true; + var _didIteratorError2 = false; + var _iteratorError2 = undefined; + + try { + for (var _iterator2 = exp[Symbol.iterator](), _step2; !(_iteratorNormalCompletion2 = (_step2 = _iterator2.next()).done); _iteratorNormalCompletion2 = true) { + var _step2$value = _slicedToArray(_step2.value, 3), + fname = _step2$value[0], + value = _step2$value[1], + weight = _step2$value[2]; + + var tr = table.append('tr'); + tr.style('border-style', 'hidden'); + tr.append('td').text(fname); + tr.append('td').text(value); + if (weight > 0) { + tr.style('background-color', colors[1]); + } else if (weight < 0) { + tr.style('background-color', colors[0]); + } else { + tr.style('color', 'black'); + } + } + } catch (err) { + _didIteratorError2 = true; + _iteratorError2 = err; + } finally { + try { + if (!_iteratorNormalCompletion2 && _iterator2.return) { + _iterator2.return(); + } + } finally { + if (_didIteratorError2) { + throw 
_iteratorError2; + } + } + } + }; + + Explanation.prototype.hexToRgb = function hexToRgb(hex) { + var result = /^#?([a-f\d]{2})([a-f\d]{2})([a-f\d]{2})$/i.exec(hex); + return result ? { + r: parseInt(result[1], 16), + g: parseInt(result[2], 16), + b: parseInt(result[3], 16) + } : null; + }; + + Explanation.prototype.applyAlpha = function applyAlpha(hex, alpha) { + var components = this.hexToRgb(hex); + return 'rgba(' + components.r + "," + components.g + "," + components.b + "," + alpha.toFixed(3) + ")"; + }; + // sord_lists is an array of arrays, of length (colors). if with_positions is true, + // word_lists is an array of [start,end] positions instead + + + Explanation.prototype.display_raw_text = function display_raw_text(div, raw_text) { + var word_lists = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : []; + var colors = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : []; + var max_weight = arguments.length > 4 && arguments[4] !== undefined ? arguments[4] : 1; + var positions = arguments.length > 5 && arguments[5] !== undefined ? arguments[5] : false; + + div.classed('lime', true).classed('text_div', true); + div.append('h3').text('Text with highlighted words'); + var highlight_tag = 'span'; + var text_span = div.append('span').style('white-space', 'pre-wrap').text(raw_text); + var position_lists = word_lists; + if (!positions) { + position_lists = this.wordlists_to_positions(word_lists, raw_text); + } + var objects = []; + var _iteratorNormalCompletion3 = true; + var _didIteratorError3 = false; + var _iteratorError3 = undefined; + + try { + var _loop = function _loop() { + var i = _step3.value; + + position_lists[i].map(function (x) { + return objects.push({ 'label': i, 'start': x[0], 'end': x[1], 'alpha': max_weight === 0 ? 
1 : x[2] / max_weight }); + }); + }; + + for (var _iterator3 = (0, _lodash.range)(position_lists.length)[Symbol.iterator](), _step3; !(_iteratorNormalCompletion3 = (_step3 = _iterator3.next()).done); _iteratorNormalCompletion3 = true) { + _loop(); + } + } catch (err) { + _didIteratorError3 = true; + _iteratorError3 = err; + } finally { + try { + if (!_iteratorNormalCompletion3 && _iterator3.return) { + _iterator3.return(); + } + } finally { + if (_didIteratorError3) { + throw _iteratorError3; + } + } + } + + objects = (0, _lodash.sortBy)(objects, function (x) { + return x['start']; + }); + var node = text_span.node().childNodes[0]; + var subtract = 0; + var _iteratorNormalCompletion4 = true; + var _didIteratorError4 = false; + var _iteratorError4 = undefined; + + try { + for (var _iterator4 = objects[Symbol.iterator](), _step4; !(_iteratorNormalCompletion4 = (_step4 = _iterator4.next()).done); _iteratorNormalCompletion4 = true) { + var obj = _step4.value; + + var word = raw_text.slice(obj.start, obj.end); + var start = obj.start - subtract; + var end = obj.end - subtract; + var match = document.createElement(highlight_tag); + match.appendChild(document.createTextNode(word)); + match.style.backgroundColor = this.applyAlpha(colors[obj.label], obj.alpha); + var after = node.splitText(start); + after.nodeValue = after.nodeValue.substring(word.length); + node.parentNode.insertBefore(match, after); + subtract += end; + node = after; + } + } catch (err) { + _didIteratorError4 = true; + _iteratorError4 = err; + } finally { + try { + if (!_iteratorNormalCompletion4 && _iterator4.return) { + _iterator4.return(); + } + } finally { + if (_didIteratorError4) { + throw _iteratorError4; + } + } + } + }; + + Explanation.prototype.wordlists_to_positions = function wordlists_to_positions(word_lists, raw_text) { + var ret = []; + var _iteratorNormalCompletion5 = true; + var _didIteratorError5 = false; + var _iteratorError5 = undefined; + + try { + for (var _iterator5 = word_lists[Symbol.iterator](), _step5; !(_iteratorNormalCompletion5 = (_step5 = _iterator5.next()).done); _iteratorNormalCompletion5 = true) { + var words = _step5.value; + + if (words.length === 0) { + ret.push([]); + continue; + } + var re = new RegExp("\\b(" + words.join('|') + ")\\b", 'gm'); + var temp = void 0; + var list = []; + while ((temp = re.exec(raw_text)) !== null) { + list.push([temp.index, temp.index + temp[0].length]); + } + ret.push(list); + } + } catch (err) { + _didIteratorError5 = true; + _iteratorError5 = err; + } finally { + try { + if (!_iteratorNormalCompletion5 && _iterator5.return) { + _iterator5.return(); + } + } finally { + if (_didIteratorError5) { + throw _iteratorError5; + } + } + } + + return ret; + }; + + return Explanation; + }(); + + exports.default = Explanation; + +/***/ }), +/* 2 */ +/***/ (function(module, exports, __webpack_require__) { + + var __WEBPACK_AMD_DEFINE_FACTORY__, __WEBPACK_AMD_DEFINE_RESULT__;!function() { + var d3 = { + version: "3.5.17" + }; + var d3_arraySlice = [].slice, d3_array = function(list) { + return d3_arraySlice.call(list); + }; + var d3_document = this.document; + function d3_documentElement(node) { + return node && (node.ownerDocument || node.document || node).documentElement; + } + function d3_window(node) { + return node && (node.ownerDocument && node.ownerDocument.defaultView || node.document && node || node.defaultView); + } + if (d3_document) { + try { + d3_array(d3_document.documentElement.childNodes)[0].nodeType; + } catch (e) { + d3_array = function(list) { + var i = 
list.length, array = new Array(i); + while (i--) array[i] = list[i]; + return array; + }; + } + } + if (!Date.now) Date.now = function() { + return +new Date(); + }; + if (d3_document) { + try { + d3_document.createElement("DIV").style.setProperty("opacity", 0, ""); + } catch (error) { + var d3_element_prototype = this.Element.prototype, d3_element_setAttribute = d3_element_prototype.setAttribute, d3_element_setAttributeNS = d3_element_prototype.setAttributeNS, d3_style_prototype = this.CSSStyleDeclaration.prototype, d3_style_setProperty = d3_style_prototype.setProperty; + d3_element_prototype.setAttribute = function(name, value) { + d3_element_setAttribute.call(this, name, value + ""); + }; + d3_element_prototype.setAttributeNS = function(space, local, value) { + d3_element_setAttributeNS.call(this, space, local, value + ""); + }; + d3_style_prototype.setProperty = function(name, value, priority) { + d3_style_setProperty.call(this, name, value + "", priority); + }; + } + } + d3.ascending = d3_ascending; + function d3_ascending(a, b) { + return a < b ? -1 : a > b ? 1 : a >= b ? 0 : NaN; + } + d3.descending = function(a, b) { + return b < a ? -1 : b > a ? 1 : b >= a ? 0 : NaN; + }; + d3.min = function(array, f) { + var i = -1, n = array.length, a, b; + if (arguments.length === 1) { + while (++i < n) if ((b = array[i]) != null && b >= b) { + a = b; + break; + } + while (++i < n) if ((b = array[i]) != null && a > b) a = b; + } else { + while (++i < n) if ((b = f.call(array, array[i], i)) != null && b >= b) { + a = b; + break; + } + while (++i < n) if ((b = f.call(array, array[i], i)) != null && a > b) a = b; + } + return a; + }; + d3.max = function(array, f) { + var i = -1, n = array.length, a, b; + if (arguments.length === 1) { + while (++i < n) if ((b = array[i]) != null && b >= b) { + a = b; + break; + } + while (++i < n) if ((b = array[i]) != null && b > a) a = b; + } else { + while (++i < n) if ((b = f.call(array, array[i], i)) != null && b >= b) { + a = b; + break; + } + while (++i < n) if ((b = f.call(array, array[i], i)) != null && b > a) a = b; + } + return a; + }; + d3.extent = function(array, f) { + var i = -1, n = array.length, a, b, c; + if (arguments.length === 1) { + while (++i < n) if ((b = array[i]) != null && b >= b) { + a = c = b; + break; + } + while (++i < n) if ((b = array[i]) != null) { + if (a > b) a = b; + if (c < b) c = b; + } + } else { + while (++i < n) if ((b = f.call(array, array[i], i)) != null && b >= b) { + a = c = b; + break; + } + while (++i < n) if ((b = f.call(array, array[i], i)) != null) { + if (a > b) a = b; + if (c < b) c = b; + } + } + return [ a, c ]; + }; + function d3_number(x) { + return x === null ? NaN : +x; + } + function d3_numeric(x) { + return !isNaN(x); + } + d3.sum = function(array, f) { + var s = 0, n = array.length, a, i = -1; + if (arguments.length === 1) { + while (++i < n) if (d3_numeric(a = +array[i])) s += a; + } else { + while (++i < n) if (d3_numeric(a = +f.call(array, array[i], i))) s += a; + } + return s; + }; + d3.mean = function(array, f) { + var s = 0, n = array.length, a, i = -1, j = n; + if (arguments.length === 1) { + while (++i < n) if (d3_numeric(a = d3_number(array[i]))) s += a; else --j; + } else { + while (++i < n) if (d3_numeric(a = d3_number(f.call(array, array[i], i)))) s += a; else --j; + } + if (j) return s / j; + }; + d3.quantile = function(values, p) { + var H = (values.length - 1) * p + 1, h = Math.floor(H), v = +values[h - 1], e = H - h; + return e ? 
v + e * (values[h] - v) : v; + }; + d3.median = function(array, f) { + var numbers = [], n = array.length, a, i = -1; + if (arguments.length === 1) { + while (++i < n) if (d3_numeric(a = d3_number(array[i]))) numbers.push(a); + } else { + while (++i < n) if (d3_numeric(a = d3_number(f.call(array, array[i], i)))) numbers.push(a); + } + if (numbers.length) return d3.quantile(numbers.sort(d3_ascending), .5); + }; + d3.variance = function(array, f) { + var n = array.length, m = 0, a, d, s = 0, i = -1, j = 0; + if (arguments.length === 1) { + while (++i < n) { + if (d3_numeric(a = d3_number(array[i]))) { + d = a - m; + m += d / ++j; + s += d * (a - m); + } + } + } else { + while (++i < n) { + if (d3_numeric(a = d3_number(f.call(array, array[i], i)))) { + d = a - m; + m += d / ++j; + s += d * (a - m); + } + } + } + if (j > 1) return s / (j - 1); + }; + d3.deviation = function() { + var v = d3.variance.apply(this, arguments); + return v ? Math.sqrt(v) : v; + }; + function d3_bisector(compare) { + return { + left: function(a, x, lo, hi) { + if (arguments.length < 3) lo = 0; + if (arguments.length < 4) hi = a.length; + while (lo < hi) { + var mid = lo + hi >>> 1; + if (compare(a[mid], x) < 0) lo = mid + 1; else hi = mid; + } + return lo; + }, + right: function(a, x, lo, hi) { + if (arguments.length < 3) lo = 0; + if (arguments.length < 4) hi = a.length; + while (lo < hi) { + var mid = lo + hi >>> 1; + if (compare(a[mid], x) > 0) hi = mid; else lo = mid + 1; + } + return lo; + } + }; + } + var d3_bisect = d3_bisector(d3_ascending); + d3.bisectLeft = d3_bisect.left; + d3.bisect = d3.bisectRight = d3_bisect.right; + d3.bisector = function(f) { + return d3_bisector(f.length === 1 ? function(d, x) { + return d3_ascending(f(d), x); + } : f); + }; + d3.shuffle = function(array, i0, i1) { + if ((m = arguments.length) < 3) { + i1 = array.length; + if (m < 2) i0 = 0; + } + var m = i1 - i0, t, i; + while (m) { + i = Math.random() * m-- | 0; + t = array[m + i0], array[m + i0] = array[i + i0], array[i + i0] = t; + } + return array; + }; + d3.permute = function(array, indexes) { + var i = indexes.length, permutes = new Array(i); + while (i--) permutes[i] = array[indexes[i]]; + return permutes; + }; + d3.pairs = function(array) { + var i = 0, n = array.length - 1, p0, p1 = array[0], pairs = new Array(n < 0 ? 
0 : n); + while (i < n) pairs[i] = [ p0 = p1, p1 = array[++i] ]; + return pairs; + }; + d3.transpose = function(matrix) { + if (!(n = matrix.length)) return []; + for (var i = -1, m = d3.min(matrix, d3_transposeLength), transpose = new Array(m); ++i < m; ) { + for (var j = -1, n, row = transpose[i] = new Array(n); ++j < n; ) { + row[j] = matrix[j][i]; + } + } + return transpose; + }; + function d3_transposeLength(d) { + return d.length; + } + d3.zip = function() { + return d3.transpose(arguments); + }; + d3.keys = function(map) { + var keys = []; + for (var key in map) keys.push(key); + return keys; + }; + d3.values = function(map) { + var values = []; + for (var key in map) values.push(map[key]); + return values; + }; + d3.entries = function(map) { + var entries = []; + for (var key in map) entries.push({ + key: key, + value: map[key] + }); + return entries; + }; + d3.merge = function(arrays) { + var n = arrays.length, m, i = -1, j = 0, merged, array; + while (++i < n) j += arrays[i].length; + merged = new Array(j); + while (--n >= 0) { + array = arrays[n]; + m = array.length; + while (--m >= 0) { + merged[--j] = array[m]; + } + } + return merged; + }; + var abs = Math.abs; + d3.range = function(start, stop, step) { + if (arguments.length < 3) { + step = 1; + if (arguments.length < 2) { + stop = start; + start = 0; + } + } + if ((stop - start) / step === Infinity) throw new Error("infinite range"); + var range = [], k = d3_range_integerScale(abs(step)), i = -1, j; + start *= k, stop *= k, step *= k; + if (step < 0) while ((j = start + step * ++i) > stop) range.push(j / k); else while ((j = start + step * ++i) < stop) range.push(j / k); + return range; + }; + function d3_range_integerScale(x) { + var k = 1; + while (x * k % 1) k *= 10; + return k; + } + function d3_class(ctor, properties) { + for (var key in properties) { + Object.defineProperty(ctor.prototype, key, { + value: properties[key], + enumerable: false + }); + } + } + d3.map = function(object, f) { + var map = new d3_Map(); + if (object instanceof d3_Map) { + object.forEach(function(key, value) { + map.set(key, value); + }); + } else if (Array.isArray(object)) { + var i = -1, n = object.length, o; + if (arguments.length === 1) while (++i < n) map.set(i, object[i]); else while (++i < n) map.set(f.call(object, o = object[i], i), o); + } else { + for (var key in object) map.set(key, object[key]); + } + return map; + }; + function d3_Map() { + this._ = Object.create(null); + } + var d3_map_proto = "__proto__", d3_map_zero = "\x00"; + d3_class(d3_Map, { + has: d3_map_has, + get: function(key) { + return this._[d3_map_escape(key)]; + }, + set: function(key, value) { + return this._[d3_map_escape(key)] = value; + }, + remove: d3_map_remove, + keys: d3_map_keys, + values: function() { + var values = []; + for (var key in this._) values.push(this._[key]); + return values; + }, + entries: function() { + var entries = []; + for (var key in this._) entries.push({ + key: d3_map_unescape(key), + value: this._[key] + }); + return entries; + }, + size: d3_map_size, + empty: d3_map_empty, + forEach: function(f) { + for (var key in this._) f.call(this, d3_map_unescape(key), this._[key]); + } + }); + function d3_map_escape(key) { + return (key += "") === d3_map_proto || key[0] === d3_map_zero ? d3_map_zero + key : key; + } + function d3_map_unescape(key) { + return (key += "")[0] === d3_map_zero ? 
key.slice(1) : key; + } + function d3_map_has(key) { + return d3_map_escape(key) in this._; + } + function d3_map_remove(key) { + return (key = d3_map_escape(key)) in this._ && delete this._[key]; + } + function d3_map_keys() { + var keys = []; + for (var key in this._) keys.push(d3_map_unescape(key)); + return keys; + } + function d3_map_size() { + var size = 0; + for (var key in this._) ++size; + return size; + } + function d3_map_empty() { + for (var key in this._) return false; + return true; + } + d3.nest = function() { + var nest = {}, keys = [], sortKeys = [], sortValues, rollup; + function map(mapType, array, depth) { + if (depth >= keys.length) return rollup ? rollup.call(nest, array) : sortValues ? array.sort(sortValues) : array; + var i = -1, n = array.length, key = keys[depth++], keyValue, object, setter, valuesByKey = new d3_Map(), values; + while (++i < n) { + if (values = valuesByKey.get(keyValue = key(object = array[i]))) { + values.push(object); + } else { + valuesByKey.set(keyValue, [ object ]); + } + } + if (mapType) { + object = mapType(); + setter = function(keyValue, values) { + object.set(keyValue, map(mapType, values, depth)); + }; + } else { + object = {}; + setter = function(keyValue, values) { + object[keyValue] = map(mapType, values, depth); + }; + } + valuesByKey.forEach(setter); + return object; + } + function entries(map, depth) { + if (depth >= keys.length) return map; + var array = [], sortKey = sortKeys[depth++]; + map.forEach(function(key, keyMap) { + array.push({ + key: key, + values: entries(keyMap, depth) + }); + }); + return sortKey ? array.sort(function(a, b) { + return sortKey(a.key, b.key); + }) : array; + } + nest.map = function(array, mapType) { + return map(mapType, array, 0); + }; + nest.entries = function(array) { + return entries(map(d3.map, array, 0), 0); + }; + nest.key = function(d) { + keys.push(d); + return nest; + }; + nest.sortKeys = function(order) { + sortKeys[keys.length - 1] = order; + return nest; + }; + nest.sortValues = function(order) { + sortValues = order; + return nest; + }; + nest.rollup = function(f) { + rollup = f; + return nest; + }; + return nest; + }; + d3.set = function(array) { + var set = new d3_Set(); + if (array) for (var i = 0, n = array.length; i < n; ++i) set.add(array[i]); + return set; + }; + function d3_Set() { + this._ = Object.create(null); + } + d3_class(d3_Set, { + has: d3_map_has, + add: function(key) { + this._[d3_map_escape(key += "")] = true; + return key; + }, + remove: d3_map_remove, + values: d3_map_keys, + size: d3_map_size, + empty: d3_map_empty, + forEach: function(f) { + for (var key in this._) f.call(this, d3_map_unescape(key)); + } + }); + d3.behavior = {}; + function d3_identity(d) { + return d; + } + d3.rebind = function(target, source) { + var i = 1, n = arguments.length, method; + while (++i < n) target[method = arguments[i]] = d3_rebind(target, source, source[method]); + return target; + }; + function d3_rebind(target, source, method) { + return function() { + var value = method.apply(source, arguments); + return value === source ? 
target : value; + }; + } + function d3_vendorSymbol(object, name) { + if (name in object) return name; + name = name.charAt(0).toUpperCase() + name.slice(1); + for (var i = 0, n = d3_vendorPrefixes.length; i < n; ++i) { + var prefixName = d3_vendorPrefixes[i] + name; + if (prefixName in object) return prefixName; + } + } + var d3_vendorPrefixes = [ "webkit", "ms", "moz", "Moz", "o", "O" ]; + function d3_noop() {} + d3.dispatch = function() { + var dispatch = new d3_dispatch(), i = -1, n = arguments.length; + while (++i < n) dispatch[arguments[i]] = d3_dispatch_event(dispatch); + return dispatch; + }; + function d3_dispatch() {} + d3_dispatch.prototype.on = function(type, listener) { + var i = type.indexOf("."), name = ""; + if (i >= 0) { + name = type.slice(i + 1); + type = type.slice(0, i); + } + if (type) return arguments.length < 2 ? this[type].on(name) : this[type].on(name, listener); + if (arguments.length === 2) { + if (listener == null) for (type in this) { + if (this.hasOwnProperty(type)) this[type].on(name, null); + } + return this; + } + }; + function d3_dispatch_event(dispatch) { + var listeners = [], listenerByName = new d3_Map(); + function event() { + var z = listeners, i = -1, n = z.length, l; + while (++i < n) if (l = z[i].on) l.apply(this, arguments); + return dispatch; + } + event.on = function(name, listener) { + var l = listenerByName.get(name), i; + if (arguments.length < 2) return l && l.on; + if (l) { + l.on = null; + listeners = listeners.slice(0, i = listeners.indexOf(l)).concat(listeners.slice(i + 1)); + listenerByName.remove(name); + } + if (listener) listeners.push(listenerByName.set(name, { + on: listener + })); + return dispatch; + }; + return event; + } + d3.event = null; + function d3_eventPreventDefault() { + d3.event.preventDefault(); + } + function d3_eventSource() { + var e = d3.event, s; + while (s = e.sourceEvent) e = s; + return e; + } + function d3_eventDispatch(target) { + var dispatch = new d3_dispatch(), i = 0, n = arguments.length; + while (++i < n) dispatch[arguments[i]] = d3_dispatch_event(dispatch); + dispatch.of = function(thiz, argumentz) { + return function(e1) { + try { + var e0 = e1.sourceEvent = d3.event; + e1.target = target; + d3.event = e1; + dispatch[e1.type].apply(thiz, argumentz); + } finally { + d3.event = e0; + } + }; + }; + return dispatch; + } + d3.requote = function(s) { + return s.replace(d3_requote_re, "\\$&"); + }; + var d3_requote_re = /[\\\^\$\*\+\?\|\[\]\(\)\.\{\}]/g; + var d3_subclass = {}.__proto__ ? 
function(object, prototype) { + object.__proto__ = prototype; + } : function(object, prototype) { + for (var property in prototype) object[property] = prototype[property]; + }; + function d3_selection(groups) { + d3_subclass(groups, d3_selectionPrototype); + return groups; + } + var d3_select = function(s, n) { + return n.querySelector(s); + }, d3_selectAll = function(s, n) { + return n.querySelectorAll(s); + }, d3_selectMatches = function(n, s) { + var d3_selectMatcher = n.matches || n[d3_vendorSymbol(n, "matchesSelector")]; + d3_selectMatches = function(n, s) { + return d3_selectMatcher.call(n, s); + }; + return d3_selectMatches(n, s); + }; + if (typeof Sizzle === "function") { + d3_select = function(s, n) { + return Sizzle(s, n)[0] || null; + }; + d3_selectAll = Sizzle; + d3_selectMatches = Sizzle.matchesSelector; + } + d3.selection = function() { + return d3.select(d3_document.documentElement); + }; + var d3_selectionPrototype = d3.selection.prototype = []; + d3_selectionPrototype.select = function(selector) { + var subgroups = [], subgroup, subnode, group, node; + selector = d3_selection_selector(selector); + for (var j = -1, m = this.length; ++j < m; ) { + subgroups.push(subgroup = []); + subgroup.parentNode = (group = this[j]).parentNode; + for (var i = -1, n = group.length; ++i < n; ) { + if (node = group[i]) { + subgroup.push(subnode = selector.call(node, node.__data__, i, j)); + if (subnode && "__data__" in node) subnode.__data__ = node.__data__; + } else { + subgroup.push(null); + } + } + } + return d3_selection(subgroups); + }; + function d3_selection_selector(selector) { + return typeof selector === "function" ? selector : function() { + return d3_select(selector, this); + }; + } + d3_selectionPrototype.selectAll = function(selector) { + var subgroups = [], subgroup, node; + selector = d3_selection_selectorAll(selector); + for (var j = -1, m = this.length; ++j < m; ) { + for (var group = this[j], i = -1, n = group.length; ++i < n; ) { + if (node = group[i]) { + subgroups.push(subgroup = d3_array(selector.call(node, node.__data__, i, j))); + subgroup.parentNode = node; + } + } + } + return d3_selection(subgroups); + }; + function d3_selection_selectorAll(selector) { + return typeof selector === "function" ? selector : function() { + return d3_selectAll(selector, this); + }; + } + var d3_nsXhtml = "http://www.w3.org/1999/xhtml"; + var d3_nsPrefix = { + svg: "http://www.w3.org/2000/svg", + xhtml: d3_nsXhtml, + xlink: "http://www.w3.org/1999/xlink", + xml: "http://www.w3.org/XML/1998/namespace", + xmlns: "http://www.w3.org/2000/xmlns/" + }; + d3.ns = { + prefix: d3_nsPrefix, + qualify: function(name) { + var i = name.indexOf(":"), prefix = name; + if (i >= 0 && (prefix = name.slice(0, i)) !== "xmlns") name = name.slice(i + 1); + return d3_nsPrefix.hasOwnProperty(prefix) ? { + space: d3_nsPrefix[prefix], + local: name + } : name; + } + }; + d3_selectionPrototype.attr = function(name, value) { + if (arguments.length < 2) { + if (typeof name === "string") { + var node = this.node(); + name = d3.ns.qualify(name); + return name.local ? 
node.getAttributeNS(name.space, name.local) : node.getAttribute(name); + } + for (value in name) this.each(d3_selection_attr(value, name[value])); + return this; + } + return this.each(d3_selection_attr(name, value)); + }; + function d3_selection_attr(name, value) { + name = d3.ns.qualify(name); + function attrNull() { + this.removeAttribute(name); + } + function attrNullNS() { + this.removeAttributeNS(name.space, name.local); + } + function attrConstant() { + this.setAttribute(name, value); + } + function attrConstantNS() { + this.setAttributeNS(name.space, name.local, value); + } + function attrFunction() { + var x = value.apply(this, arguments); + if (x == null) this.removeAttribute(name); else this.setAttribute(name, x); + } + function attrFunctionNS() { + var x = value.apply(this, arguments); + if (x == null) this.removeAttributeNS(name.space, name.local); else this.setAttributeNS(name.space, name.local, x); + } + return value == null ? name.local ? attrNullNS : attrNull : typeof value === "function" ? name.local ? attrFunctionNS : attrFunction : name.local ? attrConstantNS : attrConstant; + } + function d3_collapse(s) { + return s.trim().replace(/\s+/g, " "); + } + d3_selectionPrototype.classed = function(name, value) { + if (arguments.length < 2) { + if (typeof name === "string") { + var node = this.node(), n = (name = d3_selection_classes(name)).length, i = -1; + if (value = node.classList) { + while (++i < n) if (!value.contains(name[i])) return false; + } else { + value = node.getAttribute("class"); + while (++i < n) if (!d3_selection_classedRe(name[i]).test(value)) return false; + } + return true; + } + for (value in name) this.each(d3_selection_classed(value, name[value])); + return this; + } + return this.each(d3_selection_classed(name, value)); + }; + function d3_selection_classedRe(name) { + return new RegExp("(?:^|\\s+)" + d3.requote(name) + "(?:\\s+|$)", "g"); + } + function d3_selection_classes(name) { + return (name + "").trim().split(/^|\s+/); + } + function d3_selection_classed(name, value) { + name = d3_selection_classes(name).map(d3_selection_classedName); + var n = name.length; + function classedConstant() { + var i = -1; + while (++i < n) name[i](this, value); + } + function classedFunction() { + var i = -1, x = value.apply(this, arguments); + while (++i < n) name[i](this, x); + } + return typeof value === "function" ? classedFunction : classedConstant; + } + function d3_selection_classedName(name) { + var re = d3_selection_classedRe(name); + return function(node, value) { + if (c = node.classList) return value ? 
c.add(name) : c.remove(name); + var c = node.getAttribute("class") || ""; + if (value) { + re.lastIndex = 0; + if (!re.test(c)) node.setAttribute("class", d3_collapse(c + " " + name)); + } else { + node.setAttribute("class", d3_collapse(c.replace(re, " "))); + } + }; + } + d3_selectionPrototype.style = function(name, value, priority) { + var n = arguments.length; + if (n < 3) { + if (typeof name !== "string") { + if (n < 2) value = ""; + for (priority in name) this.each(d3_selection_style(priority, name[priority], value)); + return this; + } + if (n < 2) { + var node = this.node(); + return d3_window(node).getComputedStyle(node, null).getPropertyValue(name); + } + priority = ""; + } + return this.each(d3_selection_style(name, value, priority)); + }; + function d3_selection_style(name, value, priority) { + function styleNull() { + this.style.removeProperty(name); + } + function styleConstant() { + this.style.setProperty(name, value, priority); + } + function styleFunction() { + var x = value.apply(this, arguments); + if (x == null) this.style.removeProperty(name); else this.style.setProperty(name, x, priority); + } + return value == null ? styleNull : typeof value === "function" ? styleFunction : styleConstant; + } + d3_selectionPrototype.property = function(name, value) { + if (arguments.length < 2) { + if (typeof name === "string") return this.node()[name]; + for (value in name) this.each(d3_selection_property(value, name[value])); + return this; + } + return this.each(d3_selection_property(name, value)); + }; + function d3_selection_property(name, value) { + function propertyNull() { + delete this[name]; + } + function propertyConstant() { + this[name] = value; + } + function propertyFunction() { + var x = value.apply(this, arguments); + if (x == null) delete this[name]; else this[name] = x; + } + return value == null ? propertyNull : typeof value === "function" ? propertyFunction : propertyConstant; + } + d3_selectionPrototype.text = function(value) { + return arguments.length ? this.each(typeof value === "function" ? function() { + var v = value.apply(this, arguments); + this.textContent = v == null ? "" : v; + } : value == null ? function() { + this.textContent = ""; + } : function() { + this.textContent = value; + }) : this.node().textContent; + }; + d3_selectionPrototype.html = function(value) { + return arguments.length ? this.each(typeof value === "function" ? function() { + var v = value.apply(this, arguments); + this.innerHTML = v == null ? "" : v; + } : value == null ? function() { + this.innerHTML = ""; + } : function() { + this.innerHTML = value; + }) : this.node().innerHTML; + }; + d3_selectionPrototype.append = function(name) { + name = d3_selection_creator(name); + return this.select(function() { + return this.appendChild(name.apply(this, arguments)); + }); + }; + function d3_selection_creator(name) { + function create() { + var document = this.ownerDocument, namespace = this.namespaceURI; + return namespace === d3_nsXhtml && document.documentElement.namespaceURI === d3_nsXhtml ? document.createElement(name) : document.createElementNS(namespace, name); + } + function createNS() { + return this.ownerDocument.createElementNS(name.space, name.local); + } + return typeof name === "function" ? name : (name = d3.ns.qualify(name)).local ? 
createNS : create; + } + d3_selectionPrototype.insert = function(name, before) { + name = d3_selection_creator(name); + before = d3_selection_selector(before); + return this.select(function() { + return this.insertBefore(name.apply(this, arguments), before.apply(this, arguments) || null); + }); + }; + d3_selectionPrototype.remove = function() { + return this.each(d3_selectionRemove); + }; + function d3_selectionRemove() { + var parent = this.parentNode; + if (parent) parent.removeChild(this); + } + d3_selectionPrototype.data = function(value, key) { + var i = -1, n = this.length, group, node; + if (!arguments.length) { + value = new Array(n = (group = this[0]).length); + while (++i < n) { + if (node = group[i]) { + value[i] = node.__data__; + } + } + return value; + } + function bind(group, groupData) { + var i, n = group.length, m = groupData.length, n0 = Math.min(n, m), updateNodes = new Array(m), enterNodes = new Array(m), exitNodes = new Array(n), node, nodeData; + if (key) { + var nodeByKeyValue = new d3_Map(), keyValues = new Array(n), keyValue; + for (i = -1; ++i < n; ) { + if (node = group[i]) { + if (nodeByKeyValue.has(keyValue = key.call(node, node.__data__, i))) { + exitNodes[i] = node; + } else { + nodeByKeyValue.set(keyValue, node); + } + keyValues[i] = keyValue; + } + } + for (i = -1; ++i < m; ) { + if (!(node = nodeByKeyValue.get(keyValue = key.call(groupData, nodeData = groupData[i], i)))) { + enterNodes[i] = d3_selection_dataNode(nodeData); + } else if (node !== true) { + updateNodes[i] = node; + node.__data__ = nodeData; + } + nodeByKeyValue.set(keyValue, true); + } + for (i = -1; ++i < n; ) { + if (i in keyValues && nodeByKeyValue.get(keyValues[i]) !== true) { + exitNodes[i] = group[i]; + } + } + } else { + for (i = -1; ++i < n0; ) { + node = group[i]; + nodeData = groupData[i]; + if (node) { + node.__data__ = nodeData; + updateNodes[i] = node; + } else { + enterNodes[i] = d3_selection_dataNode(nodeData); + } + } + for (;i < m; ++i) { + enterNodes[i] = d3_selection_dataNode(groupData[i]); + } + for (;i < n; ++i) { + exitNodes[i] = group[i]; + } + } + enterNodes.update = updateNodes; + enterNodes.parentNode = updateNodes.parentNode = exitNodes.parentNode = group.parentNode; + enter.push(enterNodes); + update.push(updateNodes); + exit.push(exitNodes); + } + var enter = d3_selection_enter([]), update = d3_selection([]), exit = d3_selection([]); + if (typeof value === "function") { + while (++i < n) { + bind(group = this[i], value.call(group, group.parentNode.__data__, i)); + } + } else { + while (++i < n) { + bind(group = this[i], value); + } + } + update.enter = function() { + return enter; + }; + update.exit = function() { + return exit; + }; + return update; + }; + function d3_selection_dataNode(data) { + return { + __data__: data + }; + } + d3_selectionPrototype.datum = function(value) { + return arguments.length ? 
this.property("__data__", value) : this.property("__data__"); + }; + d3_selectionPrototype.filter = function(filter) { + var subgroups = [], subgroup, group, node; + if (typeof filter !== "function") filter = d3_selection_filter(filter); + for (var j = 0, m = this.length; j < m; j++) { + subgroups.push(subgroup = []); + subgroup.parentNode = (group = this[j]).parentNode; + for (var i = 0, n = group.length; i < n; i++) { + if ((node = group[i]) && filter.call(node, node.__data__, i, j)) { + subgroup.push(node); + } + } + } + return d3_selection(subgroups); + }; + function d3_selection_filter(selector) { + return function() { + return d3_selectMatches(this, selector); + }; + } + d3_selectionPrototype.order = function() { + for (var j = -1, m = this.length; ++j < m; ) { + for (var group = this[j], i = group.length - 1, next = group[i], node; --i >= 0; ) { + if (node = group[i]) { + if (next && next !== node.nextSibling) next.parentNode.insertBefore(node, next); + next = node; + } + } + } + return this; + }; + d3_selectionPrototype.sort = function(comparator) { + comparator = d3_selection_sortComparator.apply(this, arguments); + for (var j = -1, m = this.length; ++j < m; ) this[j].sort(comparator); + return this.order(); + }; + function d3_selection_sortComparator(comparator) { + if (!arguments.length) comparator = d3_ascending; + return function(a, b) { + return a && b ? comparator(a.__data__, b.__data__) : !a - !b; + }; + } + d3_selectionPrototype.each = function(callback) { + return d3_selection_each(this, function(node, i, j) { + callback.call(node, node.__data__, i, j); + }); + }; + function d3_selection_each(groups, callback) { + for (var j = 0, m = groups.length; j < m; j++) { + for (var group = groups[j], i = 0, n = group.length, node; i < n; i++) { + if (node = group[i]) callback(node, i, j); + } + } + return groups; + } + d3_selectionPrototype.call = function(callback) { + var args = d3_array(arguments); + callback.apply(args[0] = this, args); + return this; + }; + d3_selectionPrototype.empty = function() { + return !this.node(); + }; + d3_selectionPrototype.node = function() { + for (var j = 0, m = this.length; j < m; j++) { + for (var group = this[j], i = 0, n = group.length; i < n; i++) { + var node = group[i]; + if (node) return node; + } + } + return null; + }; + d3_selectionPrototype.size = function() { + var n = 0; + d3_selection_each(this, function() { + ++n; + }); + return n; + }; + function d3_selection_enter(selection) { + d3_subclass(selection, d3_selection_enterPrototype); + return selection; + } + var d3_selection_enterPrototype = []; + d3.selection.enter = d3_selection_enter; + d3.selection.enter.prototype = d3_selection_enterPrototype; + d3_selection_enterPrototype.append = d3_selectionPrototype.append; + d3_selection_enterPrototype.empty = d3_selectionPrototype.empty; + d3_selection_enterPrototype.node = d3_selectionPrototype.node; + d3_selection_enterPrototype.call = d3_selectionPrototype.call; + d3_selection_enterPrototype.size = d3_selectionPrototype.size; + d3_selection_enterPrototype.select = function(selector) { + var subgroups = [], subgroup, subnode, upgroup, group, node; + for (var j = -1, m = this.length; ++j < m; ) { + upgroup = (group = this[j]).update; + subgroups.push(subgroup = []); + subgroup.parentNode = group.parentNode; + for (var i = -1, n = group.length; ++i < n; ) { + if (node = group[i]) { + subgroup.push(upgroup[i] = subnode = selector.call(group.parentNode, node.__data__, i, j)); + subnode.__data__ = node.__data__; + } else { + 
subgroup.push(null); + } + } + } + return d3_selection(subgroups); + }; + d3_selection_enterPrototype.insert = function(name, before) { + if (arguments.length < 2) before = d3_selection_enterInsertBefore(this); + return d3_selectionPrototype.insert.call(this, name, before); + }; + function d3_selection_enterInsertBefore(enter) { + var i0, j0; + return function(d, i, j) { + var group = enter[j].update, n = group.length, node; + if (j != j0) j0 = j, i0 = 0; + if (i >= i0) i0 = i + 1; + while (!(node = group[i0]) && ++i0 < n) ; + return node; + }; + } + d3.select = function(node) { + var group; + if (typeof node === "string") { + group = [ d3_select(node, d3_document) ]; + group.parentNode = d3_document.documentElement; + } else { + group = [ node ]; + group.parentNode = d3_documentElement(node); + } + return d3_selection([ group ]); + }; + d3.selectAll = function(nodes) { + var group; + if (typeof nodes === "string") { + group = d3_array(d3_selectAll(nodes, d3_document)); + group.parentNode = d3_document.documentElement; + } else { + group = d3_array(nodes); + group.parentNode = null; + } + return d3_selection([ group ]); + }; + d3_selectionPrototype.on = function(type, listener, capture) { + var n = arguments.length; + if (n < 3) { + if (typeof type !== "string") { + if (n < 2) listener = false; + for (capture in type) this.each(d3_selection_on(capture, type[capture], listener)); + return this; + } + if (n < 2) return (n = this.node()["__on" + type]) && n._; + capture = false; + } + return this.each(d3_selection_on(type, listener, capture)); + }; + function d3_selection_on(type, listener, capture) { + var name = "__on" + type, i = type.indexOf("."), wrap = d3_selection_onListener; + if (i > 0) type = type.slice(0, i); + var filter = d3_selection_onFilters.get(type); + if (filter) type = filter, wrap = d3_selection_onFilter; + function onRemove() { + var l = this[name]; + if (l) { + this.removeEventListener(type, l, l.$); + delete this[name]; + } + } + function onAdd() { + var l = wrap(listener, d3_array(arguments)); + onRemove.call(this); + this.addEventListener(type, this[name] = l, l.$ = capture); + l._ = listener; + } + function removeAll() { + var re = new RegExp("^__on([^.]+)" + d3.requote(type) + "$"), match; + for (var name in this) { + if (match = name.match(re)) { + var l = this[name]; + this.removeEventListener(match[1], l, l.$); + delete this[name]; + } + } + } + return i ? listener ? onAdd : onRemove : listener ? 
d3_noop : removeAll; + } + var d3_selection_onFilters = d3.map({ + mouseenter: "mouseover", + mouseleave: "mouseout" + }); + if (d3_document) { + d3_selection_onFilters.forEach(function(k) { + if ("on" + k in d3_document) d3_selection_onFilters.remove(k); + }); + } + function d3_selection_onListener(listener, argumentz) { + return function(e) { + var o = d3.event; + d3.event = e; + argumentz[0] = this.__data__; + try { + listener.apply(this, argumentz); + } finally { + d3.event = o; + } + }; + } + function d3_selection_onFilter(listener, argumentz) { + var l = d3_selection_onListener(listener, argumentz); + return function(e) { + var target = this, related = e.relatedTarget; + if (!related || related !== target && !(related.compareDocumentPosition(target) & 8)) { + l.call(target, e); + } + }; + } + var d3_event_dragSelect, d3_event_dragId = 0; + function d3_event_dragSuppress(node) { + var name = ".dragsuppress-" + ++d3_event_dragId, click = "click" + name, w = d3.select(d3_window(node)).on("touchmove" + name, d3_eventPreventDefault).on("dragstart" + name, d3_eventPreventDefault).on("selectstart" + name, d3_eventPreventDefault); + if (d3_event_dragSelect == null) { + d3_event_dragSelect = "onselectstart" in node ? false : d3_vendorSymbol(node.style, "userSelect"); + } + if (d3_event_dragSelect) { + var style = d3_documentElement(node).style, select = style[d3_event_dragSelect]; + style[d3_event_dragSelect] = "none"; + } + return function(suppressClick) { + w.on(name, null); + if (d3_event_dragSelect) style[d3_event_dragSelect] = select; + if (suppressClick) { + var off = function() { + w.on(click, null); + }; + w.on(click, function() { + d3_eventPreventDefault(); + off(); + }, true); + setTimeout(off, 0); + } + }; + } + d3.mouse = function(container) { + return d3_mousePoint(container, d3_eventSource()); + }; + var d3_mouse_bug44083 = this.navigator && /WebKit/.test(this.navigator.userAgent) ? 
-1 : 0; + function d3_mousePoint(container, e) { + if (e.changedTouches) e = e.changedTouches[0]; + var svg = container.ownerSVGElement || container; + if (svg.createSVGPoint) { + var point = svg.createSVGPoint(); + if (d3_mouse_bug44083 < 0) { + var window = d3_window(container); + if (window.scrollX || window.scrollY) { + svg = d3.select("body").append("svg").style({ + position: "absolute", + top: 0, + left: 0, + margin: 0, + padding: 0, + border: "none" + }, "important"); + var ctm = svg[0][0].getScreenCTM(); + d3_mouse_bug44083 = !(ctm.f || ctm.e); + svg.remove(); + } + } + if (d3_mouse_bug44083) point.x = e.pageX, point.y = e.pageY; else point.x = e.clientX, + point.y = e.clientY; + point = point.matrixTransform(container.getScreenCTM().inverse()); + return [ point.x, point.y ]; + } + var rect = container.getBoundingClientRect(); + return [ e.clientX - rect.left - container.clientLeft, e.clientY - rect.top - container.clientTop ]; + } + d3.touch = function(container, touches, identifier) { + if (arguments.length < 3) identifier = touches, touches = d3_eventSource().changedTouches; + if (touches) for (var i = 0, n = touches.length, touch; i < n; ++i) { + if ((touch = touches[i]).identifier === identifier) { + return d3_mousePoint(container, touch); + } + } + }; + d3.behavior.drag = function() { + var event = d3_eventDispatch(drag, "drag", "dragstart", "dragend"), origin = null, mousedown = dragstart(d3_noop, d3.mouse, d3_window, "mousemove", "mouseup"), touchstart = dragstart(d3_behavior_dragTouchId, d3.touch, d3_identity, "touchmove", "touchend"); + function drag() { + this.on("mousedown.drag", mousedown).on("touchstart.drag", touchstart); + } + function dragstart(id, position, subject, move, end) { + return function() { + var that = this, target = d3.event.target.correspondingElement || d3.event.target, parent = that.parentNode, dispatch = event.of(that, arguments), dragged = 0, dragId = id(), dragName = ".drag" + (dragId == null ? "" : "-" + dragId), dragOffset, dragSubject = d3.select(subject(target)).on(move + dragName, moved).on(end + dragName, ended), dragRestore = d3_event_dragSuppress(target), position0 = position(parent, dragId); + if (origin) { + dragOffset = origin.apply(that, arguments); + dragOffset = [ dragOffset.x - position0[0], dragOffset.y - position0[1] ]; + } else { + dragOffset = [ 0, 0 ]; + } + dispatch({ + type: "dragstart" + }); + function moved() { + var position1 = position(parent, dragId), dx, dy; + if (!position1) return; + dx = position1[0] - position0[0]; + dy = position1[1] - position0[1]; + dragged |= dx | dy; + position0 = position1; + dispatch({ + type: "drag", + x: position1[0] + dragOffset[0], + y: position1[1] + dragOffset[1], + dx: dx, + dy: dy + }); + } + function ended() { + if (!position(parent, dragId)) return; + dragSubject.on(move + dragName, null).on(end + dragName, null); + dragRestore(dragged); + dispatch({ + type: "dragend" + }); + } + }; + } + drag.origin = function(x) { + if (!arguments.length) return origin; + origin = x; + return drag; + }; + return d3.rebind(drag, event, "on"); + }; + function d3_behavior_dragTouchId() { + return d3.event.changedTouches[0].identifier; + } + d3.touches = function(container, touches) { + if (arguments.length < 2) touches = d3_eventSource().touches; + return touches ? 
d3_array(touches).map(function(touch) { + var point = d3_mousePoint(container, touch); + point.identifier = touch.identifier; + return point; + }) : []; + }; + var ε = 1e-6, ε2 = ε * ε, π = Math.PI, τ = 2 * π, τε = τ - ε, halfπ = π / 2, d3_radians = π / 180, d3_degrees = 180 / π; + function d3_sgn(x) { + return x > 0 ? 1 : x < 0 ? -1 : 0; + } + function d3_cross2d(a, b, c) { + return (b[0] - a[0]) * (c[1] - a[1]) - (b[1] - a[1]) * (c[0] - a[0]); + } + function d3_acos(x) { + return x > 1 ? 0 : x < -1 ? π : Math.acos(x); + } + function d3_asin(x) { + return x > 1 ? halfπ : x < -1 ? -halfπ : Math.asin(x); + } + function d3_sinh(x) { + return ((x = Math.exp(x)) - 1 / x) / 2; + } + function d3_cosh(x) { + return ((x = Math.exp(x)) + 1 / x) / 2; + } + function d3_tanh(x) { + return ((x = Math.exp(2 * x)) - 1) / (x + 1); + } + function d3_haversin(x) { + return (x = Math.sin(x / 2)) * x; + } + var ρ = Math.SQRT2, ρ2 = 2, ρ4 = 4; + d3.interpolateZoom = function(p0, p1) { + var ux0 = p0[0], uy0 = p0[1], w0 = p0[2], ux1 = p1[0], uy1 = p1[1], w1 = p1[2], dx = ux1 - ux0, dy = uy1 - uy0, d2 = dx * dx + dy * dy, i, S; + if (d2 < ε2) { + S = Math.log(w1 / w0) / ρ; + i = function(t) { + return [ ux0 + t * dx, uy0 + t * dy, w0 * Math.exp(ρ * t * S) ]; + }; + } else { + var d1 = Math.sqrt(d2), b0 = (w1 * w1 - w0 * w0 + ρ4 * d2) / (2 * w0 * ρ2 * d1), b1 = (w1 * w1 - w0 * w0 - ρ4 * d2) / (2 * w1 * ρ2 * d1), r0 = Math.log(Math.sqrt(b0 * b0 + 1) - b0), r1 = Math.log(Math.sqrt(b1 * b1 + 1) - b1); + S = (r1 - r0) / ρ; + i = function(t) { + var s = t * S, coshr0 = d3_cosh(r0), u = w0 / (ρ2 * d1) * (coshr0 * d3_tanh(ρ * s + r0) - d3_sinh(r0)); + return [ ux0 + u * dx, uy0 + u * dy, w0 * coshr0 / d3_cosh(ρ * s + r0) ]; + }; + } + i.duration = S * 1e3; + return i; + }; + d3.behavior.zoom = function() { + var view = { + x: 0, + y: 0, + k: 1 + }, translate0, center0, center, size = [ 960, 500 ], scaleExtent = d3_behavior_zoomInfinity, duration = 250, zooming = 0, mousedown = "mousedown.zoom", mousemove = "mousemove.zoom", mouseup = "mouseup.zoom", mousewheelTimer, touchstart = "touchstart.zoom", touchtime, event = d3_eventDispatch(zoom, "zoomstart", "zoom", "zoomend"), x0, x1, y0, y1; + if (!d3_behavior_zoomWheel) { + d3_behavior_zoomWheel = "onwheel" in d3_document ? (d3_behavior_zoomDelta = function() { + return -d3.event.deltaY * (d3.event.deltaMode ? 120 : 1); + }, "wheel") : "onmousewheel" in d3_document ? (d3_behavior_zoomDelta = function() { + return d3.event.wheelDelta; + }, "mousewheel") : (d3_behavior_zoomDelta = function() { + return -d3.event.detail; + }, "MozMousePixelScroll"); + } + function zoom(g) { + g.on(mousedown, mousedowned).on(d3_behavior_zoomWheel + ".zoom", mousewheeled).on("dblclick.zoom", dblclicked).on(touchstart, touchstarted); + } + zoom.event = function(g) { + g.each(function() { + var dispatch = event.of(this, arguments), view1 = view; + if (d3_transitionInheritId) { + d3.select(this).transition().each("start.zoom", function() { + view = this.__chart__ || { + x: 0, + y: 0, + k: 1 + }; + zoomstarted(dispatch); + }).tween("zoom:zoom", function() { + var dx = size[0], dy = size[1], cx = center0 ? center0[0] : dx / 2, cy = center0 ? 
center0[1] : dy / 2, i = d3.interpolateZoom([ (cx - view.x) / view.k, (cy - view.y) / view.k, dx / view.k ], [ (cx - view1.x) / view1.k, (cy - view1.y) / view1.k, dx / view1.k ]); + return function(t) { + var l = i(t), k = dx / l[2]; + this.__chart__ = view = { + x: cx - l[0] * k, + y: cy - l[1] * k, + k: k + }; + zoomed(dispatch); + }; + }).each("interrupt.zoom", function() { + zoomended(dispatch); + }).each("end.zoom", function() { + zoomended(dispatch); + }); + } else { + this.__chart__ = view; + zoomstarted(dispatch); + zoomed(dispatch); + zoomended(dispatch); + } + }); + }; + zoom.translate = function(_) { + if (!arguments.length) return [ view.x, view.y ]; + view = { + x: +_[0], + y: +_[1], + k: view.k + }; + rescale(); + return zoom; + }; + zoom.scale = function(_) { + if (!arguments.length) return view.k; + view = { + x: view.x, + y: view.y, + k: null + }; + scaleTo(+_); + rescale(); + return zoom; + }; + zoom.scaleExtent = function(_) { + if (!arguments.length) return scaleExtent; + scaleExtent = _ == null ? d3_behavior_zoomInfinity : [ +_[0], +_[1] ]; + return zoom; + }; + zoom.center = function(_) { + if (!arguments.length) return center; + center = _ && [ +_[0], +_[1] ]; + return zoom; + }; + zoom.size = function(_) { + if (!arguments.length) return size; + size = _ && [ +_[0], +_[1] ]; + return zoom; + }; + zoom.duration = function(_) { + if (!arguments.length) return duration; + duration = +_; + return zoom; + }; + zoom.x = function(z) { + if (!arguments.length) return x1; + x1 = z; + x0 = z.copy(); + view = { + x: 0, + y: 0, + k: 1 + }; + return zoom; + }; + zoom.y = function(z) { + if (!arguments.length) return y1; + y1 = z; + y0 = z.copy(); + view = { + x: 0, + y: 0, + k: 1 + }; + return zoom; + }; + function location(p) { + return [ (p[0] - view.x) / view.k, (p[1] - view.y) / view.k ]; + } + function point(l) { + return [ l[0] * view.k + view.x, l[1] * view.k + view.y ]; + } + function scaleTo(s) { + view.k = Math.max(scaleExtent[0], Math.min(scaleExtent[1], s)); + } + function translateTo(p, l) { + l = point(l); + view.x += p[0] - l[0]; + view.y += p[1] - l[1]; + } + function zoomTo(that, p, l, k) { + that.__chart__ = { + x: view.x, + y: view.y, + k: view.k + }; + scaleTo(Math.pow(2, k)); + translateTo(center0 = p, l); + that = d3.select(that); + if (duration > 0) that = that.transition().duration(duration); + that.call(zoom.event); + } + function rescale() { + if (x1) x1.domain(x0.range().map(function(x) { + return (x - view.x) / view.k; + }).map(x0.invert)); + if (y1) y1.domain(y0.range().map(function(y) { + return (y - view.y) / view.k; + }).map(y0.invert)); + } + function zoomstarted(dispatch) { + if (!zooming++) dispatch({ + type: "zoomstart" + }); + } + function zoomed(dispatch) { + rescale(); + dispatch({ + type: "zoom", + scale: view.k, + translate: [ view.x, view.y ] + }); + } + function zoomended(dispatch) { + if (!--zooming) dispatch({ + type: "zoomend" + }), center0 = null; + } + function mousedowned() { + var that = this, dispatch = event.of(that, arguments), dragged = 0, subject = d3.select(d3_window(that)).on(mousemove, moved).on(mouseup, ended), location0 = location(d3.mouse(that)), dragRestore = d3_event_dragSuppress(that); + d3_selection_interrupt.call(that); + zoomstarted(dispatch); + function moved() { + dragged = 1; + translateTo(d3.mouse(that), location0); + zoomed(dispatch); + } + function ended() { + subject.on(mousemove, null).on(mouseup, null); + dragRestore(dragged); + zoomended(dispatch); + } + } + function touchstarted() { + var that = this, 
dispatch = event.of(that, arguments), locations0 = {}, distance0 = 0, scale0, zoomName = ".zoom-" + d3.event.changedTouches[0].identifier, touchmove = "touchmove" + zoomName, touchend = "touchend" + zoomName, targets = [], subject = d3.select(that), dragRestore = d3_event_dragSuppress(that); + started(); + zoomstarted(dispatch); + subject.on(mousedown, null).on(touchstart, started); + function relocate() { + var touches = d3.touches(that); + scale0 = view.k; + touches.forEach(function(t) { + if (t.identifier in locations0) locations0[t.identifier] = location(t); + }); + return touches; + } + function started() { + var target = d3.event.target; + d3.select(target).on(touchmove, moved).on(touchend, ended); + targets.push(target); + var changed = d3.event.changedTouches; + for (var i = 0, n = changed.length; i < n; ++i) { + locations0[changed[i].identifier] = null; + } + var touches = relocate(), now = Date.now(); + if (touches.length === 1) { + if (now - touchtime < 500) { + var p = touches[0]; + zoomTo(that, p, locations0[p.identifier], Math.floor(Math.log(view.k) / Math.LN2) + 1); + d3_eventPreventDefault(); + } + touchtime = now; + } else if (touches.length > 1) { + var p = touches[0], q = touches[1], dx = p[0] - q[0], dy = p[1] - q[1]; + distance0 = dx * dx + dy * dy; + } + } + function moved() { + var touches = d3.touches(that), p0, l0, p1, l1; + d3_selection_interrupt.call(that); + for (var i = 0, n = touches.length; i < n; ++i, l1 = null) { + p1 = touches[i]; + if (l1 = locations0[p1.identifier]) { + if (l0) break; + p0 = p1, l0 = l1; + } + } + if (l1) { + var distance1 = (distance1 = p1[0] - p0[0]) * distance1 + (distance1 = p1[1] - p0[1]) * distance1, scale1 = distance0 && Math.sqrt(distance1 / distance0); + p0 = [ (p0[0] + p1[0]) / 2, (p0[1] + p1[1]) / 2 ]; + l0 = [ (l0[0] + l1[0]) / 2, (l0[1] + l1[1]) / 2 ]; + scaleTo(scale1 * scale0); + } + touchtime = null; + translateTo(p0, l0); + zoomed(dispatch); + } + function ended() { + if (d3.event.touches.length) { + var changed = d3.event.changedTouches; + for (var i = 0, n = changed.length; i < n; ++i) { + delete locations0[changed[i].identifier]; + } + for (var identifier in locations0) { + return void relocate(); + } + } + d3.selectAll(targets).on(zoomName, null); + subject.on(mousedown, mousedowned).on(touchstart, touchstarted); + dragRestore(); + zoomended(dispatch); + } + } + function mousewheeled() { + var dispatch = event.of(this, arguments); + if (mousewheelTimer) clearTimeout(mousewheelTimer); else d3_selection_interrupt.call(this), + translate0 = location(center0 = center || d3.mouse(this)), zoomstarted(dispatch); + mousewheelTimer = setTimeout(function() { + mousewheelTimer = null; + zoomended(dispatch); + }, 50); + d3_eventPreventDefault(); + scaleTo(Math.pow(2, d3_behavior_zoomDelta() * .002) * view.k); + translateTo(center0, translate0); + zoomed(dispatch); + } + function dblclicked() { + var p = d3.mouse(this), k = Math.log(view.k) / Math.LN2; + zoomTo(this, p, location(p), d3.event.shiftKey ? Math.ceil(k) - 1 : Math.floor(k) + 1); + } + return d3.rebind(zoom, event, "on"); + }; + var d3_behavior_zoomInfinity = [ 0, Infinity ], d3_behavior_zoomDelta, d3_behavior_zoomWheel; + d3.color = d3_color; + function d3_color() {} + d3_color.prototype.toString = function() { + return this.rgb() + ""; + }; + d3.hsl = d3_hsl; + function d3_hsl(h, s, l) { + return this instanceof d3_hsl ? void (this.h = +h, this.s = +s, this.l = +l) : arguments.length < 2 ? h instanceof d3_hsl ? 
new d3_hsl(h.h, h.s, h.l) : d3_rgb_parse("" + h, d3_rgb_hsl, d3_hsl) : new d3_hsl(h, s, l); + } + var d3_hslPrototype = d3_hsl.prototype = new d3_color(); + d3_hslPrototype.brighter = function(k) { + k = Math.pow(.7, arguments.length ? k : 1); + return new d3_hsl(this.h, this.s, this.l / k); + }; + d3_hslPrototype.darker = function(k) { + k = Math.pow(.7, arguments.length ? k : 1); + return new d3_hsl(this.h, this.s, k * this.l); + }; + d3_hslPrototype.rgb = function() { + return d3_hsl_rgb(this.h, this.s, this.l); + }; + function d3_hsl_rgb(h, s, l) { + var m1, m2; + h = isNaN(h) ? 0 : (h %= 360) < 0 ? h + 360 : h; + s = isNaN(s) ? 0 : s < 0 ? 0 : s > 1 ? 1 : s; + l = l < 0 ? 0 : l > 1 ? 1 : l; + m2 = l <= .5 ? l * (1 + s) : l + s - l * s; + m1 = 2 * l - m2; + function v(h) { + if (h > 360) h -= 360; else if (h < 0) h += 360; + if (h < 60) return m1 + (m2 - m1) * h / 60; + if (h < 180) return m2; + if (h < 240) return m1 + (m2 - m1) * (240 - h) / 60; + return m1; + } + function vv(h) { + return Math.round(v(h) * 255); + } + return new d3_rgb(vv(h + 120), vv(h), vv(h - 120)); + } + d3.hcl = d3_hcl; + function d3_hcl(h, c, l) { + return this instanceof d3_hcl ? void (this.h = +h, this.c = +c, this.l = +l) : arguments.length < 2 ? h instanceof d3_hcl ? new d3_hcl(h.h, h.c, h.l) : h instanceof d3_lab ? d3_lab_hcl(h.l, h.a, h.b) : d3_lab_hcl((h = d3_rgb_lab((h = d3.rgb(h)).r, h.g, h.b)).l, h.a, h.b) : new d3_hcl(h, c, l); + } + var d3_hclPrototype = d3_hcl.prototype = new d3_color(); + d3_hclPrototype.brighter = function(k) { + return new d3_hcl(this.h, this.c, Math.min(100, this.l + d3_lab_K * (arguments.length ? k : 1))); + }; + d3_hclPrototype.darker = function(k) { + return new d3_hcl(this.h, this.c, Math.max(0, this.l - d3_lab_K * (arguments.length ? k : 1))); + }; + d3_hclPrototype.rgb = function() { + return d3_hcl_lab(this.h, this.c, this.l).rgb(); + }; + function d3_hcl_lab(h, c, l) { + if (isNaN(h)) h = 0; + if (isNaN(c)) c = 0; + return new d3_lab(l, Math.cos(h *= d3_radians) * c, Math.sin(h) * c); + } + d3.lab = d3_lab; + function d3_lab(l, a, b) { + return this instanceof d3_lab ? void (this.l = +l, this.a = +a, this.b = +b) : arguments.length < 2 ? l instanceof d3_lab ? new d3_lab(l.l, l.a, l.b) : l instanceof d3_hcl ? d3_hcl_lab(l.h, l.c, l.l) : d3_rgb_lab((l = d3_rgb(l)).r, l.g, l.b) : new d3_lab(l, a, b); + } + var d3_lab_K = 18; + var d3_lab_X = .95047, d3_lab_Y = 1, d3_lab_Z = 1.08883; + var d3_labPrototype = d3_lab.prototype = new d3_color(); + d3_labPrototype.brighter = function(k) { + return new d3_lab(Math.min(100, this.l + d3_lab_K * (arguments.length ? k : 1)), this.a, this.b); + }; + d3_labPrototype.darker = function(k) { + return new d3_lab(Math.max(0, this.l - d3_lab_K * (arguments.length ? k : 1)), this.a, this.b); + }; + d3_labPrototype.rgb = function() { + return d3_lab_rgb(this.l, this.a, this.b); + }; + function d3_lab_rgb(l, a, b) { + var y = (l + 16) / 116, x = y + a / 500, z = y - b / 200; + x = d3_lab_xyz(x) * d3_lab_X; + y = d3_lab_xyz(y) * d3_lab_Y; + z = d3_lab_xyz(z) * d3_lab_Z; + return new d3_rgb(d3_xyz_rgb(3.2404542 * x - 1.5371385 * y - .4985314 * z), d3_xyz_rgb(-.969266 * x + 1.8760108 * y + .041556 * z), d3_xyz_rgb(.0556434 * x - .2040259 * y + 1.0572252 * z)); + } + function d3_lab_hcl(l, a, b) { + return l > 0 ? new d3_hcl(Math.atan2(b, a) * d3_degrees, Math.sqrt(a * a + b * b), l) : new d3_hcl(NaN, NaN, l); + } + function d3_lab_xyz(x) { + return x > .206893034 ? 
x * x * x : (x - 4 / 29) / 7.787037; + } + function d3_xyz_lab(x) { + return x > .008856 ? Math.pow(x, 1 / 3) : 7.787037 * x + 4 / 29; + } + function d3_xyz_rgb(r) { + return Math.round(255 * (r <= .00304 ? 12.92 * r : 1.055 * Math.pow(r, 1 / 2.4) - .055)); + } + d3.rgb = d3_rgb; + function d3_rgb(r, g, b) { + return this instanceof d3_rgb ? void (this.r = ~~r, this.g = ~~g, this.b = ~~b) : arguments.length < 2 ? r instanceof d3_rgb ? new d3_rgb(r.r, r.g, r.b) : d3_rgb_parse("" + r, d3_rgb, d3_hsl_rgb) : new d3_rgb(r, g, b); + } + function d3_rgbNumber(value) { + return new d3_rgb(value >> 16, value >> 8 & 255, value & 255); + } + function d3_rgbString(value) { + return d3_rgbNumber(value) + ""; + } + var d3_rgbPrototype = d3_rgb.prototype = new d3_color(); + d3_rgbPrototype.brighter = function(k) { + k = Math.pow(.7, arguments.length ? k : 1); + var r = this.r, g = this.g, b = this.b, i = 30; + if (!r && !g && !b) return new d3_rgb(i, i, i); + if (r && r < i) r = i; + if (g && g < i) g = i; + if (b && b < i) b = i; + return new d3_rgb(Math.min(255, r / k), Math.min(255, g / k), Math.min(255, b / k)); + }; + d3_rgbPrototype.darker = function(k) { + k = Math.pow(.7, arguments.length ? k : 1); + return new d3_rgb(k * this.r, k * this.g, k * this.b); + }; + d3_rgbPrototype.hsl = function() { + return d3_rgb_hsl(this.r, this.g, this.b); + }; + d3_rgbPrototype.toString = function() { + return "#" + d3_rgb_hex(this.r) + d3_rgb_hex(this.g) + d3_rgb_hex(this.b); + }; + function d3_rgb_hex(v) { + return v < 16 ? "0" + Math.max(0, v).toString(16) : Math.min(255, v).toString(16); + } + function d3_rgb_parse(format, rgb, hsl) { + var r = 0, g = 0, b = 0, m1, m2, color; + m1 = /([a-z]+)\((.*)\)/.exec(format = format.toLowerCase()); + if (m1) { + m2 = m1[2].split(","); + switch (m1[1]) { + case "hsl": + { + return hsl(parseFloat(m2[0]), parseFloat(m2[1]) / 100, parseFloat(m2[2]) / 100); + } + + case "rgb": + { + return rgb(d3_rgb_parseNumber(m2[0]), d3_rgb_parseNumber(m2[1]), d3_rgb_parseNumber(m2[2])); + } + } + } + if (color = d3_rgb_names.get(format)) { + return rgb(color.r, color.g, color.b); + } + if (format != null && format.charAt(0) === "#" && !isNaN(color = parseInt(format.slice(1), 16))) { + if (format.length === 4) { + r = (color & 3840) >> 4; + r = r >> 4 | r; + g = color & 240; + g = g >> 4 | g; + b = color & 15; + b = b << 4 | b; + } else if (format.length === 7) { + r = (color & 16711680) >> 16; + g = (color & 65280) >> 8; + b = color & 255; + } + } + return rgb(r, g, b); + } + function d3_rgb_hsl(r, g, b) { + var min = Math.min(r /= 255, g /= 255, b /= 255), max = Math.max(r, g, b), d = max - min, h, s, l = (max + min) / 2; + if (d) { + s = l < .5 ? d / (max + min) : d / (2 - max - min); + if (r == max) h = (g - b) / d + (g < b ? 6 : 0); else if (g == max) h = (b - r) / d + 2; else h = (r - g) / d + 4; + h *= 60; + } else { + h = NaN; + s = l > 0 && l < 1 ? 0 : h; + } + return new d3_hsl(h, s, l); + } + function d3_rgb_lab(r, g, b) { + r = d3_rgb_xyz(r); + g = d3_rgb_xyz(g); + b = d3_rgb_xyz(b); + var x = d3_xyz_lab((.4124564 * r + .3575761 * g + .1804375 * b) / d3_lab_X), y = d3_xyz_lab((.2126729 * r + .7151522 * g + .072175 * b) / d3_lab_Y), z = d3_xyz_lab((.0193339 * r + .119192 * g + .9503041 * b) / d3_lab_Z); + return d3_lab(116 * y - 16, 500 * (x - y), 200 * (y - z)); + } + function d3_rgb_xyz(r) { + return (r /= 255) <= .04045 ? r / 12.92 : Math.pow((r + .055) / 1.055, 2.4); + } + function d3_rgb_parseNumber(c) { + var f = parseFloat(c); + return c.charAt(c.length - 1) === "%" ? 
Math.round(f * 2.55) : f; + } + var d3_rgb_names = d3.map({ + aliceblue: 15792383, + antiquewhite: 16444375, + aqua: 65535, + aquamarine: 8388564, + azure: 15794175, + beige: 16119260, + bisque: 16770244, + black: 0, + blanchedalmond: 16772045, + blue: 255, + blueviolet: 9055202, + brown: 10824234, + burlywood: 14596231, + cadetblue: 6266528, + chartreuse: 8388352, + chocolate: 13789470, + coral: 16744272, + cornflowerblue: 6591981, + cornsilk: 16775388, + crimson: 14423100, + cyan: 65535, + darkblue: 139, + darkcyan: 35723, + darkgoldenrod: 12092939, + darkgray: 11119017, + darkgreen: 25600, + darkgrey: 11119017, + darkkhaki: 12433259, + darkmagenta: 9109643, + darkolivegreen: 5597999, + darkorange: 16747520, + darkorchid: 10040012, + darkred: 9109504, + darksalmon: 15308410, + darkseagreen: 9419919, + darkslateblue: 4734347, + darkslategray: 3100495, + darkslategrey: 3100495, + darkturquoise: 52945, + darkviolet: 9699539, + deeppink: 16716947, + deepskyblue: 49151, + dimgray: 6908265, + dimgrey: 6908265, + dodgerblue: 2003199, + firebrick: 11674146, + floralwhite: 16775920, + forestgreen: 2263842, + fuchsia: 16711935, + gainsboro: 14474460, + ghostwhite: 16316671, + gold: 16766720, + goldenrod: 14329120, + gray: 8421504, + green: 32768, + greenyellow: 11403055, + grey: 8421504, + honeydew: 15794160, + hotpink: 16738740, + indianred: 13458524, + indigo: 4915330, + ivory: 16777200, + khaki: 15787660, + lavender: 15132410, + lavenderblush: 16773365, + lawngreen: 8190976, + lemonchiffon: 16775885, + lightblue: 11393254, + lightcoral: 15761536, + lightcyan: 14745599, + lightgoldenrodyellow: 16448210, + lightgray: 13882323, + lightgreen: 9498256, + lightgrey: 13882323, + lightpink: 16758465, + lightsalmon: 16752762, + lightseagreen: 2142890, + lightskyblue: 8900346, + lightslategray: 7833753, + lightslategrey: 7833753, + lightsteelblue: 11584734, + lightyellow: 16777184, + lime: 65280, + limegreen: 3329330, + linen: 16445670, + magenta: 16711935, + maroon: 8388608, + mediumaquamarine: 6737322, + mediumblue: 205, + mediumorchid: 12211667, + mediumpurple: 9662683, + mediumseagreen: 3978097, + mediumslateblue: 8087790, + mediumspringgreen: 64154, + mediumturquoise: 4772300, + mediumvioletred: 13047173, + midnightblue: 1644912, + mintcream: 16121850, + mistyrose: 16770273, + moccasin: 16770229, + navajowhite: 16768685, + navy: 128, + oldlace: 16643558, + olive: 8421376, + olivedrab: 7048739, + orange: 16753920, + orangered: 16729344, + orchid: 14315734, + palegoldenrod: 15657130, + palegreen: 10025880, + paleturquoise: 11529966, + palevioletred: 14381203, + papayawhip: 16773077, + peachpuff: 16767673, + peru: 13468991, + pink: 16761035, + plum: 14524637, + powderblue: 11591910, + purple: 8388736, + rebeccapurple: 6697881, + red: 16711680, + rosybrown: 12357519, + royalblue: 4286945, + saddlebrown: 9127187, + salmon: 16416882, + sandybrown: 16032864, + seagreen: 3050327, + seashell: 16774638, + sienna: 10506797, + silver: 12632256, + skyblue: 8900331, + slateblue: 6970061, + slategray: 7372944, + slategrey: 7372944, + snow: 16775930, + springgreen: 65407, + steelblue: 4620980, + tan: 13808780, + teal: 32896, + thistle: 14204888, + tomato: 16737095, + turquoise: 4251856, + violet: 15631086, + wheat: 16113331, + white: 16777215, + whitesmoke: 16119285, + yellow: 16776960, + yellowgreen: 10145074 + }); + d3_rgb_names.forEach(function(key, value) { + d3_rgb_names.set(key, d3_rgbNumber(value)); + }); + function d3_functor(v) { + return typeof v === "function" ? 
v : function() { + return v; + }; + } + d3.functor = d3_functor; + d3.xhr = d3_xhrType(d3_identity); + function d3_xhrType(response) { + return function(url, mimeType, callback) { + if (arguments.length === 2 && typeof mimeType === "function") callback = mimeType, + mimeType = null; + return d3_xhr(url, mimeType, response, callback); + }; + } + function d3_xhr(url, mimeType, response, callback) { + var xhr = {}, dispatch = d3.dispatch("beforesend", "progress", "load", "error"), headers = {}, request = new XMLHttpRequest(), responseType = null; + if (this.XDomainRequest && !("withCredentials" in request) && /^(http(s)?:)?\/\//.test(url)) request = new XDomainRequest(); + "onload" in request ? request.onload = request.onerror = respond : request.onreadystatechange = function() { + request.readyState > 3 && respond(); + }; + function respond() { + var status = request.status, result; + if (!status && d3_xhrHasResponse(request) || status >= 200 && status < 300 || status === 304) { + try { + result = response.call(xhr, request); + } catch (e) { + dispatch.error.call(xhr, e); + return; + } + dispatch.load.call(xhr, result); + } else { + dispatch.error.call(xhr, request); + } + } + request.onprogress = function(event) { + var o = d3.event; + d3.event = event; + try { + dispatch.progress.call(xhr, request); + } finally { + d3.event = o; + } + }; + xhr.header = function(name, value) { + name = (name + "").toLowerCase(); + if (arguments.length < 2) return headers[name]; + if (value == null) delete headers[name]; else headers[name] = value + ""; + return xhr; + }; + xhr.mimeType = function(value) { + if (!arguments.length) return mimeType; + mimeType = value == null ? null : value + ""; + return xhr; + }; + xhr.responseType = function(value) { + if (!arguments.length) return responseType; + responseType = value; + return xhr; + }; + xhr.response = function(value) { + response = value; + return xhr; + }; + [ "get", "post" ].forEach(function(method) { + xhr[method] = function() { + return xhr.send.apply(xhr, [ method ].concat(d3_array(arguments))); + }; + }); + xhr.send = function(method, data, callback) { + if (arguments.length === 2 && typeof data === "function") callback = data, data = null; + request.open(method, url, true); + if (mimeType != null && !("accept" in headers)) headers["accept"] = mimeType + ",*/*"; + if (request.setRequestHeader) for (var name in headers) request.setRequestHeader(name, headers[name]); + if (mimeType != null && request.overrideMimeType) request.overrideMimeType(mimeType); + if (responseType != null) request.responseType = responseType; + if (callback != null) xhr.on("error", callback).on("load", function(request) { + callback(null, request); + }); + dispatch.beforesend.call(xhr, request); + request.send(data == null ? null : data); + return xhr; + }; + xhr.abort = function() { + request.abort(); + return xhr; + }; + d3.rebind(xhr, dispatch, "on"); + return callback == null ? xhr : xhr.get(d3_xhr_fixCallback(callback)); + } + function d3_xhr_fixCallback(callback) { + return callback.length === 1 ? function(error, request) { + callback(error == null ? request : null); + } : callback; + } + function d3_xhrHasResponse(request) { + var type = request.responseType; + return type && type !== "text" ? 
request.response : request.responseText; + } + d3.dsv = function(delimiter, mimeType) { + var reFormat = new RegExp('["' + delimiter + "\n]"), delimiterCode = delimiter.charCodeAt(0); + function dsv(url, row, callback) { + if (arguments.length < 3) callback = row, row = null; + var xhr = d3_xhr(url, mimeType, row == null ? response : typedResponse(row), callback); + xhr.row = function(_) { + return arguments.length ? xhr.response((row = _) == null ? response : typedResponse(_)) : row; + }; + return xhr; + } + function response(request) { + return dsv.parse(request.responseText); + } + function typedResponse(f) { + return function(request) { + return dsv.parse(request.responseText, f); + }; + } + dsv.parse = function(text, f) { + var o; + return dsv.parseRows(text, function(row, i) { + if (o) return o(row, i - 1); + var a = new Function("d", "return {" + row.map(function(name, i) { + return JSON.stringify(name) + ": d[" + i + "]"; + }).join(",") + "}"); + o = f ? function(row, i) { + return f(a(row), i); + } : a; + }); + }; + dsv.parseRows = function(text, f) { + var EOL = {}, EOF = {}, rows = [], N = text.length, I = 0, n = 0, t, eol; + function token() { + if (I >= N) return EOF; + if (eol) return eol = false, EOL; + var j = I; + if (text.charCodeAt(j) === 34) { + var i = j; + while (i++ < N) { + if (text.charCodeAt(i) === 34) { + if (text.charCodeAt(i + 1) !== 34) break; + ++i; + } + } + I = i + 2; + var c = text.charCodeAt(i + 1); + if (c === 13) { + eol = true; + if (text.charCodeAt(i + 2) === 10) ++I; + } else if (c === 10) { + eol = true; + } + return text.slice(j + 1, i).replace(/""/g, '"'); + } + while (I < N) { + var c = text.charCodeAt(I++), k = 1; + if (c === 10) eol = true; else if (c === 13) { + eol = true; + if (text.charCodeAt(I) === 10) ++I, ++k; + } else if (c !== delimiterCode) continue; + return text.slice(j, I - k); + } + return text.slice(j); + } + while ((t = token()) !== EOF) { + var a = []; + while (t !== EOL && t !== EOF) { + a.push(t); + t = token(); + } + if (f && (a = f(a, n++)) == null) continue; + rows.push(a); + } + return rows; + }; + dsv.format = function(rows) { + if (Array.isArray(rows[0])) return dsv.formatRows(rows); + var fieldSet = new d3_Set(), fields = []; + rows.forEach(function(row) { + for (var field in row) { + if (!fieldSet.has(field)) { + fields.push(fieldSet.add(field)); + } + } + }); + return [ fields.map(formatValue).join(delimiter) ].concat(rows.map(function(row) { + return fields.map(function(field) { + return formatValue(row[field]); + }).join(delimiter); + })).join("\n"); + }; + dsv.formatRows = function(rows) { + return rows.map(formatRow).join("\n"); + }; + function formatRow(row) { + return row.map(formatValue).join(delimiter); + } + function formatValue(text) { + return reFormat.test(text) ? 
'"' + text.replace(/\"/g, '""') + '"' : text; + } + return dsv; + }; + d3.csv = d3.dsv(",", "text/csv"); + d3.tsv = d3.dsv(" ", "text/tab-separated-values"); + var d3_timer_queueHead, d3_timer_queueTail, d3_timer_interval, d3_timer_timeout, d3_timer_frame = this[d3_vendorSymbol(this, "requestAnimationFrame")] || function(callback) { + setTimeout(callback, 17); + }; + d3.timer = function() { + d3_timer.apply(this, arguments); + }; + function d3_timer(callback, delay, then) { + var n = arguments.length; + if (n < 2) delay = 0; + if (n < 3) then = Date.now(); + var time = then + delay, timer = { + c: callback, + t: time, + n: null + }; + if (d3_timer_queueTail) d3_timer_queueTail.n = timer; else d3_timer_queueHead = timer; + d3_timer_queueTail = timer; + if (!d3_timer_interval) { + d3_timer_timeout = clearTimeout(d3_timer_timeout); + d3_timer_interval = 1; + d3_timer_frame(d3_timer_step); + } + return timer; + } + function d3_timer_step() { + var now = d3_timer_mark(), delay = d3_timer_sweep() - now; + if (delay > 24) { + if (isFinite(delay)) { + clearTimeout(d3_timer_timeout); + d3_timer_timeout = setTimeout(d3_timer_step, delay); + } + d3_timer_interval = 0; + } else { + d3_timer_interval = 1; + d3_timer_frame(d3_timer_step); + } + } + d3.timer.flush = function() { + d3_timer_mark(); + d3_timer_sweep(); + }; + function d3_timer_mark() { + var now = Date.now(), timer = d3_timer_queueHead; + while (timer) { + if (now >= timer.t && timer.c(now - timer.t)) timer.c = null; + timer = timer.n; + } + return now; + } + function d3_timer_sweep() { + var t0, t1 = d3_timer_queueHead, time = Infinity; + while (t1) { + if (t1.c) { + if (t1.t < time) time = t1.t; + t1 = (t0 = t1).n; + } else { + t1 = t0 ? t0.n = t1.n : d3_timer_queueHead = t1.n; + } + } + d3_timer_queueTail = t0; + return time; + } + function d3_format_precision(x, p) { + return p - (x ? Math.ceil(Math.log(x) / Math.LN10) : 1); + } + d3.round = function(x, n) { + return n ? Math.round(x * (n = Math.pow(10, n))) / n : Math.round(x); + }; + var d3_formatPrefixes = [ "y", "z", "a", "f", "p", "n", "µ", "m", "", "k", "M", "G", "T", "P", "E", "Z", "Y" ].map(d3_formatPrefix); + d3.formatPrefix = function(value, precision) { + var i = 0; + if (value = +value) { + if (value < 0) value *= -1; + if (precision) value = d3.round(value, d3_format_precision(value, precision)); + i = 1 + Math.floor(1e-12 + Math.log(value) / Math.LN10); + i = Math.max(-24, Math.min(24, Math.floor((i - 1) / 3) * 3)); + } + return d3_formatPrefixes[8 + i / 3]; + }; + function d3_formatPrefix(d, i) { + var k = Math.pow(10, abs(8 - i) * 3); + return { + scale: i > 8 ? function(d) { + return d / k; + } : function(d) { + return d * k; + }, + symbol: d + }; + } + function d3_locale_numberFormat(locale) { + var locale_decimal = locale.decimal, locale_thousands = locale.thousands, locale_grouping = locale.grouping, locale_currency = locale.currency, formatGroup = locale_grouping && locale_thousands ? 
function(value, width) { + var i = value.length, t = [], j = 0, g = locale_grouping[0], length = 0; + while (i > 0 && g > 0) { + if (length + g + 1 > width) g = Math.max(1, width - length); + t.push(value.substring(i -= g, i + g)); + if ((length += g + 1) > width) break; + g = locale_grouping[j = (j + 1) % locale_grouping.length]; + } + return t.reverse().join(locale_thousands); + } : d3_identity; + return function(specifier) { + var match = d3_format_re.exec(specifier), fill = match[1] || " ", align = match[2] || ">", sign = match[3] || "-", symbol = match[4] || "", zfill = match[5], width = +match[6], comma = match[7], precision = match[8], type = match[9], scale = 1, prefix = "", suffix = "", integer = false, exponent = true; + if (precision) precision = +precision.substring(1); + if (zfill || fill === "0" && align === "=") { + zfill = fill = "0"; + align = "="; + } + switch (type) { + case "n": + comma = true; + type = "g"; + break; + + case "%": + scale = 100; + suffix = "%"; + type = "f"; + break; + + case "p": + scale = 100; + suffix = "%"; + type = "r"; + break; + + case "b": + case "o": + case "x": + case "X": + if (symbol === "#") prefix = "0" + type.toLowerCase(); + + case "c": + exponent = false; + + case "d": + integer = true; + precision = 0; + break; + + case "s": + scale = -1; + type = "r"; + break; + } + if (symbol === "$") prefix = locale_currency[0], suffix = locale_currency[1]; + if (type == "r" && !precision) type = "g"; + if (precision != null) { + if (type == "g") precision = Math.max(1, Math.min(21, precision)); else if (type == "e" || type == "f") precision = Math.max(0, Math.min(20, precision)); + } + type = d3_format_types.get(type) || d3_format_typeDefault; + var zcomma = zfill && comma; + return function(value) { + var fullSuffix = suffix; + if (integer && value % 1) return ""; + var negative = value < 0 || value === 0 && 1 / value < 0 ? (value = -value, "-") : sign === "-" ? "" : sign; + if (scale < 0) { + var unit = d3.formatPrefix(value, precision); + value = unit.scale(value); + fullSuffix = unit.symbol + suffix; + } else { + value *= scale; + } + value = type(value, precision); + var i = value.lastIndexOf("."), before, after; + if (i < 0) { + var j = exponent ? value.lastIndexOf("e") : -1; + if (j < 0) before = value, after = ""; else before = value.substring(0, j), after = value.substring(j); + } else { + before = value.substring(0, i); + after = locale_decimal + value.substring(i + 1); + } + if (!zfill && comma) before = formatGroup(before, Infinity); + var length = prefix.length + before.length + after.length + (zcomma ? 0 : negative.length), padding = length < width ? new Array(length = width - length + 1).join(fill) : ""; + if (zcomma) before = formatGroup(padding + before, padding.length ? width - after.length : Infinity); + negative += prefix; + value = before + after; + return (align === "<" ? negative + value + padding : align === ">" ? padding + negative + value : align === "^" ? padding.substring(0, length >>= 1) + negative + value + padding.substring(length) : negative + (zcomma ? 
value : padding + value)) + fullSuffix; + }; + }; + } + var d3_format_re = /(?:([^{])?([<>=^]))?([+\- ])?([$#])?(0)?(\d+)?(,)?(\.-?\d+)?([a-z%])?/i; + var d3_format_types = d3.map({ + b: function(x) { + return x.toString(2); + }, + c: function(x) { + return String.fromCharCode(x); + }, + o: function(x) { + return x.toString(8); + }, + x: function(x) { + return x.toString(16); + }, + X: function(x) { + return x.toString(16).toUpperCase(); + }, + g: function(x, p) { + return x.toPrecision(p); + }, + e: function(x, p) { + return x.toExponential(p); + }, + f: function(x, p) { + return x.toFixed(p); + }, + r: function(x, p) { + return (x = d3.round(x, d3_format_precision(x, p))).toFixed(Math.max(0, Math.min(20, d3_format_precision(x * (1 + 1e-15), p)))); + } + }); + function d3_format_typeDefault(x) { + return x + ""; + } + var d3_time = d3.time = {}, d3_date = Date; + function d3_date_utc() { + this._ = new Date(arguments.length > 1 ? Date.UTC.apply(this, arguments) : arguments[0]); + } + d3_date_utc.prototype = { + getDate: function() { + return this._.getUTCDate(); + }, + getDay: function() { + return this._.getUTCDay(); + }, + getFullYear: function() { + return this._.getUTCFullYear(); + }, + getHours: function() { + return this._.getUTCHours(); + }, + getMilliseconds: function() { + return this._.getUTCMilliseconds(); + }, + getMinutes: function() { + return this._.getUTCMinutes(); + }, + getMonth: function() { + return this._.getUTCMonth(); + }, + getSeconds: function() { + return this._.getUTCSeconds(); + }, + getTime: function() { + return this._.getTime(); + }, + getTimezoneOffset: function() { + return 0; + }, + valueOf: function() { + return this._.valueOf(); + }, + setDate: function() { + d3_time_prototype.setUTCDate.apply(this._, arguments); + }, + setDay: function() { + d3_time_prototype.setUTCDay.apply(this._, arguments); + }, + setFullYear: function() { + d3_time_prototype.setUTCFullYear.apply(this._, arguments); + }, + setHours: function() { + d3_time_prototype.setUTCHours.apply(this._, arguments); + }, + setMilliseconds: function() { + d3_time_prototype.setUTCMilliseconds.apply(this._, arguments); + }, + setMinutes: function() { + d3_time_prototype.setUTCMinutes.apply(this._, arguments); + }, + setMonth: function() { + d3_time_prototype.setUTCMonth.apply(this._, arguments); + }, + setSeconds: function() { + d3_time_prototype.setUTCSeconds.apply(this._, arguments); + }, + setTime: function() { + d3_time_prototype.setTime.apply(this._, arguments); + } + }; + var d3_time_prototype = Date.prototype; + function d3_time_interval(local, step, number) { + function round(date) { + var d0 = local(date), d1 = offset(d0, 1); + return date - d0 < d1 - date ? 
d0 : d1; + } + function ceil(date) { + step(date = local(new d3_date(date - 1)), 1); + return date; + } + function offset(date, k) { + step(date = new d3_date(+date), k); + return date; + } + function range(t0, t1, dt) { + var time = ceil(t0), times = []; + if (dt > 1) { + while (time < t1) { + if (!(number(time) % dt)) times.push(new Date(+time)); + step(time, 1); + } + } else { + while (time < t1) times.push(new Date(+time)), step(time, 1); + } + return times; + } + function range_utc(t0, t1, dt) { + try { + d3_date = d3_date_utc; + var utc = new d3_date_utc(); + utc._ = t0; + return range(utc, t1, dt); + } finally { + d3_date = Date; + } + } + local.floor = local; + local.round = round; + local.ceil = ceil; + local.offset = offset; + local.range = range; + var utc = local.utc = d3_time_interval_utc(local); + utc.floor = utc; + utc.round = d3_time_interval_utc(round); + utc.ceil = d3_time_interval_utc(ceil); + utc.offset = d3_time_interval_utc(offset); + utc.range = range_utc; + return local; + } + function d3_time_interval_utc(method) { + return function(date, k) { + try { + d3_date = d3_date_utc; + var utc = new d3_date_utc(); + utc._ = date; + return method(utc, k)._; + } finally { + d3_date = Date; + } + }; + } + d3_time.year = d3_time_interval(function(date) { + date = d3_time.day(date); + date.setMonth(0, 1); + return date; + }, function(date, offset) { + date.setFullYear(date.getFullYear() + offset); + }, function(date) { + return date.getFullYear(); + }); + d3_time.years = d3_time.year.range; + d3_time.years.utc = d3_time.year.utc.range; + d3_time.day = d3_time_interval(function(date) { + var day = new d3_date(2e3, 0); + day.setFullYear(date.getFullYear(), date.getMonth(), date.getDate()); + return day; + }, function(date, offset) { + date.setDate(date.getDate() + offset); + }, function(date) { + return date.getDate() - 1; + }); + d3_time.days = d3_time.day.range; + d3_time.days.utc = d3_time.day.utc.range; + d3_time.dayOfYear = function(date) { + var year = d3_time.year(date); + return Math.floor((date - year - (date.getTimezoneOffset() - year.getTimezoneOffset()) * 6e4) / 864e5); + }; + [ "sunday", "monday", "tuesday", "wednesday", "thursday", "friday", "saturday" ].forEach(function(day, i) { + i = 7 - i; + var interval = d3_time[day] = d3_time_interval(function(date) { + (date = d3_time.day(date)).setDate(date.getDate() - (date.getDay() + i) % 7); + return date; + }, function(date, offset) { + date.setDate(date.getDate() + Math.floor(offset) * 7); + }, function(date) { + var day = d3_time.year(date).getDay(); + return Math.floor((d3_time.dayOfYear(date) + (day + i) % 7) / 7) - (day !== i); + }); + d3_time[day + "s"] = interval.range; + d3_time[day + "s"].utc = interval.utc.range; + d3_time[day + "OfYear"] = function(date) { + var day = d3_time.year(date).getDay(); + return Math.floor((d3_time.dayOfYear(date) + (day + i) % 7) / 7); + }; + }); + d3_time.week = d3_time.sunday; + d3_time.weeks = d3_time.sunday.range; + d3_time.weeks.utc = d3_time.sunday.utc.range; + d3_time.weekOfYear = d3_time.sundayOfYear; + function d3_locale_timeFormat(locale) { + var locale_dateTime = locale.dateTime, locale_date = locale.date, locale_time = locale.time, locale_periods = locale.periods, locale_days = locale.days, locale_shortDays = locale.shortDays, locale_months = locale.months, locale_shortMonths = locale.shortMonths; + function d3_time_format(template) { + var n = template.length; + function format(date) { + var string = [], i = -1, j = 0, c, p, f; + while (++i < n) { + if 
(template.charCodeAt(i) === 37) { + string.push(template.slice(j, i)); + if ((p = d3_time_formatPads[c = template.charAt(++i)]) != null) c = template.charAt(++i); + if (f = d3_time_formats[c]) c = f(date, p == null ? c === "e" ? " " : "0" : p); + string.push(c); + j = i + 1; + } + } + string.push(template.slice(j, i)); + return string.join(""); + } + format.parse = function(string) { + var d = { + y: 1900, + m: 0, + d: 1, + H: 0, + M: 0, + S: 0, + L: 0, + Z: null + }, i = d3_time_parse(d, template, string, 0); + if (i != string.length) return null; + if ("p" in d) d.H = d.H % 12 + d.p * 12; + var localZ = d.Z != null && d3_date !== d3_date_utc, date = new (localZ ? d3_date_utc : d3_date)(); + if ("j" in d) date.setFullYear(d.y, 0, d.j); else if ("W" in d || "U" in d) { + if (!("w" in d)) d.w = "W" in d ? 1 : 0; + date.setFullYear(d.y, 0, 1); + date.setFullYear(d.y, 0, "W" in d ? (d.w + 6) % 7 + d.W * 7 - (date.getDay() + 5) % 7 : d.w + d.U * 7 - (date.getDay() + 6) % 7); + } else date.setFullYear(d.y, d.m, d.d); + date.setHours(d.H + (d.Z / 100 | 0), d.M + d.Z % 100, d.S, d.L); + return localZ ? date._ : date; + }; + format.toString = function() { + return template; + }; + return format; + } + function d3_time_parse(date, template, string, j) { + var c, p, t, i = 0, n = template.length, m = string.length; + while (i < n) { + if (j >= m) return -1; + c = template.charCodeAt(i++); + if (c === 37) { + t = template.charAt(i++); + p = d3_time_parsers[t in d3_time_formatPads ? template.charAt(i++) : t]; + if (!p || (j = p(date, string, j)) < 0) return -1; + } else if (c != string.charCodeAt(j++)) { + return -1; + } + } + return j; + } + d3_time_format.utc = function(template) { + var local = d3_time_format(template); + function format(date) { + try { + d3_date = d3_date_utc; + var utc = new d3_date(); + utc._ = date; + return local(utc); + } finally { + d3_date = Date; + } + } + format.parse = function(string) { + try { + d3_date = d3_date_utc; + var date = local.parse(string); + return date && date._; + } finally { + d3_date = Date; + } + }; + format.toString = local.toString; + return format; + }; + d3_time_format.multi = d3_time_format.utc.multi = d3_time_formatMulti; + var d3_time_periodLookup = d3.map(), d3_time_dayRe = d3_time_formatRe(locale_days), d3_time_dayLookup = d3_time_formatLookup(locale_days), d3_time_dayAbbrevRe = d3_time_formatRe(locale_shortDays), d3_time_dayAbbrevLookup = d3_time_formatLookup(locale_shortDays), d3_time_monthRe = d3_time_formatRe(locale_months), d3_time_monthLookup = d3_time_formatLookup(locale_months), d3_time_monthAbbrevRe = d3_time_formatRe(locale_shortMonths), d3_time_monthAbbrevLookup = d3_time_formatLookup(locale_shortMonths); + locale_periods.forEach(function(p, i) { + d3_time_periodLookup.set(p.toLowerCase(), i); + }); + var d3_time_formats = { + a: function(d) { + return locale_shortDays[d.getDay()]; + }, + A: function(d) { + return locale_days[d.getDay()]; + }, + b: function(d) { + return locale_shortMonths[d.getMonth()]; + }, + B: function(d) { + return locale_months[d.getMonth()]; + }, + c: d3_time_format(locale_dateTime), + d: function(d, p) { + return d3_time_formatPad(d.getDate(), p, 2); + }, + e: function(d, p) { + return d3_time_formatPad(d.getDate(), p, 2); + }, + H: function(d, p) { + return d3_time_formatPad(d.getHours(), p, 2); + }, + I: function(d, p) { + return d3_time_formatPad(d.getHours() % 12 || 12, p, 2); + }, + j: function(d, p) { + return d3_time_formatPad(1 + d3_time.dayOfYear(d), p, 3); + }, + L: function(d, p) { + return 
d3_time_formatPad(d.getMilliseconds(), p, 3); + }, + m: function(d, p) { + return d3_time_formatPad(d.getMonth() + 1, p, 2); + }, + M: function(d, p) { + return d3_time_formatPad(d.getMinutes(), p, 2); + }, + p: function(d) { + return locale_periods[+(d.getHours() >= 12)]; + }, + S: function(d, p) { + return d3_time_formatPad(d.getSeconds(), p, 2); + }, + U: function(d, p) { + return d3_time_formatPad(d3_time.sundayOfYear(d), p, 2); + }, + w: function(d) { + return d.getDay(); + }, + W: function(d, p) { + return d3_time_formatPad(d3_time.mondayOfYear(d), p, 2); + }, + x: d3_time_format(locale_date), + X: d3_time_format(locale_time), + y: function(d, p) { + return d3_time_formatPad(d.getFullYear() % 100, p, 2); + }, + Y: function(d, p) { + return d3_time_formatPad(d.getFullYear() % 1e4, p, 4); + }, + Z: d3_time_zone, + "%": function() { + return "%"; + } + }; + var d3_time_parsers = { + a: d3_time_parseWeekdayAbbrev, + A: d3_time_parseWeekday, + b: d3_time_parseMonthAbbrev, + B: d3_time_parseMonth, + c: d3_time_parseLocaleFull, + d: d3_time_parseDay, + e: d3_time_parseDay, + H: d3_time_parseHour24, + I: d3_time_parseHour24, + j: d3_time_parseDayOfYear, + L: d3_time_parseMilliseconds, + m: d3_time_parseMonthNumber, + M: d3_time_parseMinutes, + p: d3_time_parseAmPm, + S: d3_time_parseSeconds, + U: d3_time_parseWeekNumberSunday, + w: d3_time_parseWeekdayNumber, + W: d3_time_parseWeekNumberMonday, + x: d3_time_parseLocaleDate, + X: d3_time_parseLocaleTime, + y: d3_time_parseYear, + Y: d3_time_parseFullYear, + Z: d3_time_parseZone, + "%": d3_time_parseLiteralPercent + }; + function d3_time_parseWeekdayAbbrev(date, string, i) { + d3_time_dayAbbrevRe.lastIndex = 0; + var n = d3_time_dayAbbrevRe.exec(string.slice(i)); + return n ? (date.w = d3_time_dayAbbrevLookup.get(n[0].toLowerCase()), i + n[0].length) : -1; + } + function d3_time_parseWeekday(date, string, i) { + d3_time_dayRe.lastIndex = 0; + var n = d3_time_dayRe.exec(string.slice(i)); + return n ? (date.w = d3_time_dayLookup.get(n[0].toLowerCase()), i + n[0].length) : -1; + } + function d3_time_parseMonthAbbrev(date, string, i) { + d3_time_monthAbbrevRe.lastIndex = 0; + var n = d3_time_monthAbbrevRe.exec(string.slice(i)); + return n ? (date.m = d3_time_monthAbbrevLookup.get(n[0].toLowerCase()), i + n[0].length) : -1; + } + function d3_time_parseMonth(date, string, i) { + d3_time_monthRe.lastIndex = 0; + var n = d3_time_monthRe.exec(string.slice(i)); + return n ? (date.m = d3_time_monthLookup.get(n[0].toLowerCase()), i + n[0].length) : -1; + } + function d3_time_parseLocaleFull(date, string, i) { + return d3_time_parse(date, d3_time_formats.c.toString(), string, i); + } + function d3_time_parseLocaleDate(date, string, i) { + return d3_time_parse(date, d3_time_formats.x.toString(), string, i); + } + function d3_time_parseLocaleTime(date, string, i) { + return d3_time_parse(date, d3_time_formats.X.toString(), string, i); + } + function d3_time_parseAmPm(date, string, i) { + var n = d3_time_periodLookup.get(string.slice(i, i += 2).toLowerCase()); + return n == null ? -1 : (date.p = n, i); + } + return d3_time_format; + } + var d3_time_formatPads = { + "-": "", + _: " ", + "0": "0" + }, d3_time_numberRe = /^\s*\d+/, d3_time_percentRe = /^%/; + function d3_time_formatPad(value, fill, width) { + var sign = value < 0 ? "-" : "", string = (sign ? -value : value) + "", length = string.length; + return sign + (length < width ? 
new Array(width - length + 1).join(fill) + string : string); + } + function d3_time_formatRe(names) { + return new RegExp("^(?:" + names.map(d3.requote).join("|") + ")", "i"); + } + function d3_time_formatLookup(names) { + var map = new d3_Map(), i = -1, n = names.length; + while (++i < n) map.set(names[i].toLowerCase(), i); + return map; + } + function d3_time_parseWeekdayNumber(date, string, i) { + d3_time_numberRe.lastIndex = 0; + var n = d3_time_numberRe.exec(string.slice(i, i + 1)); + return n ? (date.w = +n[0], i + n[0].length) : -1; + } + function d3_time_parseWeekNumberSunday(date, string, i) { + d3_time_numberRe.lastIndex = 0; + var n = d3_time_numberRe.exec(string.slice(i)); + return n ? (date.U = +n[0], i + n[0].length) : -1; + } + function d3_time_parseWeekNumberMonday(date, string, i) { + d3_time_numberRe.lastIndex = 0; + var n = d3_time_numberRe.exec(string.slice(i)); + return n ? (date.W = +n[0], i + n[0].length) : -1; + } + function d3_time_parseFullYear(date, string, i) { + d3_time_numberRe.lastIndex = 0; + var n = d3_time_numberRe.exec(string.slice(i, i + 4)); + return n ? (date.y = +n[0], i + n[0].length) : -1; + } + function d3_time_parseYear(date, string, i) { + d3_time_numberRe.lastIndex = 0; + var n = d3_time_numberRe.exec(string.slice(i, i + 2)); + return n ? (date.y = d3_time_expandYear(+n[0]), i + n[0].length) : -1; + } + function d3_time_parseZone(date, string, i) { + return /^[+-]\d{4}$/.test(string = string.slice(i, i + 5)) ? (date.Z = -string, + i + 5) : -1; + } + function d3_time_expandYear(d) { + return d + (d > 68 ? 1900 : 2e3); + } + function d3_time_parseMonthNumber(date, string, i) { + d3_time_numberRe.lastIndex = 0; + var n = d3_time_numberRe.exec(string.slice(i, i + 2)); + return n ? (date.m = n[0] - 1, i + n[0].length) : -1; + } + function d3_time_parseDay(date, string, i) { + d3_time_numberRe.lastIndex = 0; + var n = d3_time_numberRe.exec(string.slice(i, i + 2)); + return n ? (date.d = +n[0], i + n[0].length) : -1; + } + function d3_time_parseDayOfYear(date, string, i) { + d3_time_numberRe.lastIndex = 0; + var n = d3_time_numberRe.exec(string.slice(i, i + 3)); + return n ? (date.j = +n[0], i + n[0].length) : -1; + } + function d3_time_parseHour24(date, string, i) { + d3_time_numberRe.lastIndex = 0; + var n = d3_time_numberRe.exec(string.slice(i, i + 2)); + return n ? (date.H = +n[0], i + n[0].length) : -1; + } + function d3_time_parseMinutes(date, string, i) { + d3_time_numberRe.lastIndex = 0; + var n = d3_time_numberRe.exec(string.slice(i, i + 2)); + return n ? (date.M = +n[0], i + n[0].length) : -1; + } + function d3_time_parseSeconds(date, string, i) { + d3_time_numberRe.lastIndex = 0; + var n = d3_time_numberRe.exec(string.slice(i, i + 2)); + return n ? (date.S = +n[0], i + n[0].length) : -1; + } + function d3_time_parseMilliseconds(date, string, i) { + d3_time_numberRe.lastIndex = 0; + var n = d3_time_numberRe.exec(string.slice(i, i + 3)); + return n ? (date.L = +n[0], i + n[0].length) : -1; + } + function d3_time_zone(d) { + var z = d.getTimezoneOffset(), zs = z > 0 ? "-" : "+", zh = abs(z) / 60 | 0, zm = abs(z) % 60; + return zs + d3_time_formatPad(zh, "0", 2) + d3_time_formatPad(zm, "0", 2); + } + function d3_time_parseLiteralPercent(date, string, i) { + d3_time_percentRe.lastIndex = 0; + var n = d3_time_percentRe.exec(string.slice(i, i + 1)); + return n ? 
i + n[0].length : -1; + } + function d3_time_formatMulti(formats) { + var n = formats.length, i = -1; + while (++i < n) formats[i][0] = this(formats[i][0]); + return function(date) { + var i = 0, f = formats[i]; + while (!f[1](date)) f = formats[++i]; + return f[0](date); + }; + } + d3.locale = function(locale) { + return { + numberFormat: d3_locale_numberFormat(locale), + timeFormat: d3_locale_timeFormat(locale) + }; + }; + var d3_locale_enUS = d3.locale({ + decimal: ".", + thousands: ",", + grouping: [ 3 ], + currency: [ "$", "" ], + dateTime: "%a %b %e %X %Y", + date: "%m/%d/%Y", + time: "%H:%M:%S", + periods: [ "AM", "PM" ], + days: [ "Sunday", "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday" ], + shortDays: [ "Sun", "Mon", "Tue", "Wed", "Thu", "Fri", "Sat" ], + months: [ "January", "February", "March", "April", "May", "June", "July", "August", "September", "October", "November", "December" ], + shortMonths: [ "Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec" ] + }); + d3.format = d3_locale_enUS.numberFormat; + d3.geo = {}; + function d3_adder() {} + d3_adder.prototype = { + s: 0, + t: 0, + add: function(y) { + d3_adderSum(y, this.t, d3_adderTemp); + d3_adderSum(d3_adderTemp.s, this.s, this); + if (this.s) this.t += d3_adderTemp.t; else this.s = d3_adderTemp.t; + }, + reset: function() { + this.s = this.t = 0; + }, + valueOf: function() { + return this.s; + } + }; + var d3_adderTemp = new d3_adder(); + function d3_adderSum(a, b, o) { + var x = o.s = a + b, bv = x - a, av = x - bv; + o.t = a - av + (b - bv); + } + d3.geo.stream = function(object, listener) { + if (object && d3_geo_streamObjectType.hasOwnProperty(object.type)) { + d3_geo_streamObjectType[object.type](object, listener); + } else { + d3_geo_streamGeometry(object, listener); + } + }; + function d3_geo_streamGeometry(geometry, listener) { + if (geometry && d3_geo_streamGeometryType.hasOwnProperty(geometry.type)) { + d3_geo_streamGeometryType[geometry.type](geometry, listener); + } + } + var d3_geo_streamObjectType = { + Feature: function(feature, listener) { + d3_geo_streamGeometry(feature.geometry, listener); + }, + FeatureCollection: function(object, listener) { + var features = object.features, i = -1, n = features.length; + while (++i < n) d3_geo_streamGeometry(features[i].geometry, listener); + } + }; + var d3_geo_streamGeometryType = { + Sphere: function(object, listener) { + listener.sphere(); + }, + Point: function(object, listener) { + object = object.coordinates; + listener.point(object[0], object[1], object[2]); + }, + MultiPoint: function(object, listener) { + var coordinates = object.coordinates, i = -1, n = coordinates.length; + while (++i < n) object = coordinates[i], listener.point(object[0], object[1], object[2]); + }, + LineString: function(object, listener) { + d3_geo_streamLine(object.coordinates, listener, 0); + }, + MultiLineString: function(object, listener) { + var coordinates = object.coordinates, i = -1, n = coordinates.length; + while (++i < n) d3_geo_streamLine(coordinates[i], listener, 0); + }, + Polygon: function(object, listener) { + d3_geo_streamPolygon(object.coordinates, listener); + }, + MultiPolygon: function(object, listener) { + var coordinates = object.coordinates, i = -1, n = coordinates.length; + while (++i < n) d3_geo_streamPolygon(coordinates[i], listener); + }, + GeometryCollection: function(object, listener) { + var geometries = object.geometries, i = -1, n = geometries.length; + while (++i < n) 
d3_geo_streamGeometry(geometries[i], listener); + } + }; + function d3_geo_streamLine(coordinates, listener, closed) { + var i = -1, n = coordinates.length - closed, coordinate; + listener.lineStart(); + while (++i < n) coordinate = coordinates[i], listener.point(coordinate[0], coordinate[1], coordinate[2]); + listener.lineEnd(); + } + function d3_geo_streamPolygon(coordinates, listener) { + var i = -1, n = coordinates.length; + listener.polygonStart(); + while (++i < n) d3_geo_streamLine(coordinates[i], listener, 1); + listener.polygonEnd(); + } + d3.geo.area = function(object) { + d3_geo_areaSum = 0; + d3.geo.stream(object, d3_geo_area); + return d3_geo_areaSum; + }; + var d3_geo_areaSum, d3_geo_areaRingSum = new d3_adder(); + var d3_geo_area = { + sphere: function() { + d3_geo_areaSum += 4 * π; + }, + point: d3_noop, + lineStart: d3_noop, + lineEnd: d3_noop, + polygonStart: function() { + d3_geo_areaRingSum.reset(); + d3_geo_area.lineStart = d3_geo_areaRingStart; + }, + polygonEnd: function() { + var area = 2 * d3_geo_areaRingSum; + d3_geo_areaSum += area < 0 ? 4 * π + area : area; + d3_geo_area.lineStart = d3_geo_area.lineEnd = d3_geo_area.point = d3_noop; + } + }; + function d3_geo_areaRingStart() { + var λ00, φ00, λ0, cosφ0, sinφ0; + d3_geo_area.point = function(λ, φ) { + d3_geo_area.point = nextPoint; + λ0 = (λ00 = λ) * d3_radians, cosφ0 = Math.cos(φ = (φ00 = φ) * d3_radians / 2 + π / 4), + sinφ0 = Math.sin(φ); + }; + function nextPoint(λ, φ) { + λ *= d3_radians; + φ = φ * d3_radians / 2 + π / 4; + var dλ = λ - λ0, sdλ = dλ >= 0 ? 1 : -1, adλ = sdλ * dλ, cosφ = Math.cos(φ), sinφ = Math.sin(φ), k = sinφ0 * sinφ, u = cosφ0 * cosφ + k * Math.cos(adλ), v = k * sdλ * Math.sin(adλ); + d3_geo_areaRingSum.add(Math.atan2(v, u)); + λ0 = λ, cosφ0 = cosφ, sinφ0 = sinφ; + } + d3_geo_area.lineEnd = function() { + nextPoint(λ00, φ00); + }; + } + function d3_geo_cartesian(spherical) { + var λ = spherical[0], φ = spherical[1], cosφ = Math.cos(φ); + return [ cosφ * Math.cos(λ), cosφ * Math.sin(λ), Math.sin(φ) ]; + } + function d3_geo_cartesianDot(a, b) { + return a[0] * b[0] + a[1] * b[1] + a[2] * b[2]; + } + function d3_geo_cartesianCross(a, b) { + return [ a[1] * b[2] - a[2] * b[1], a[2] * b[0] - a[0] * b[2], a[0] * b[1] - a[1] * b[0] ]; + } + function d3_geo_cartesianAdd(a, b) { + a[0] += b[0]; + a[1] += b[1]; + a[2] += b[2]; + } + function d3_geo_cartesianScale(vector, k) { + return [ vector[0] * k, vector[1] * k, vector[2] * k ]; + } + function d3_geo_cartesianNormalize(d) { + var l = Math.sqrt(d[0] * d[0] + d[1] * d[1] + d[2] * d[2]); + d[0] /= l; + d[1] /= l; + d[2] /= l; + } + function d3_geo_spherical(cartesian) { + return [ Math.atan2(cartesian[1], cartesian[0]), d3_asin(cartesian[2]) ]; + } + function d3_geo_sphericalEqual(a, b) { + return abs(a[0] - b[0]) < ε && abs(a[1] - b[1]) < ε; + } + d3.geo.bounds = function() { + var λ0, φ0, λ1, φ1, λ_, λ__, φ__, p0, dλSum, ranges, range; + var bound = { + point: point, + lineStart: lineStart, + lineEnd: lineEnd, + polygonStart: function() { + bound.point = ringPoint; + bound.lineStart = ringStart; + bound.lineEnd = ringEnd; + dλSum = 0; + d3_geo_area.polygonStart(); + }, + polygonEnd: function() { + d3_geo_area.polygonEnd(); + bound.point = point; + bound.lineStart = lineStart; + bound.lineEnd = lineEnd; + if (d3_geo_areaRingSum < 0) λ0 = -(λ1 = 180), φ0 = -(φ1 = 90); else if (dλSum > ε) φ1 = 90; else if (dλSum < -ε) φ0 = -90; + range[0] = λ0, range[1] = λ1; + } + }; + function point(λ, φ) { + ranges.push(range = [ λ0 = λ, λ1 = λ ]); + if (φ < φ0) 
φ0 = φ; + if (φ > φ1) φ1 = φ; + } + function linePoint(λ, φ) { + var p = d3_geo_cartesian([ λ * d3_radians, φ * d3_radians ]); + if (p0) { + var normal = d3_geo_cartesianCross(p0, p), equatorial = [ normal[1], -normal[0], 0 ], inflection = d3_geo_cartesianCross(equatorial, normal); + d3_geo_cartesianNormalize(inflection); + inflection = d3_geo_spherical(inflection); + var dλ = λ - λ_, s = dλ > 0 ? 1 : -1, λi = inflection[0] * d3_degrees * s, antimeridian = abs(dλ) > 180; + if (antimeridian ^ (s * λ_ < λi && λi < s * λ)) { + var φi = inflection[1] * d3_degrees; + if (φi > φ1) φ1 = φi; + } else if (λi = (λi + 360) % 360 - 180, antimeridian ^ (s * λ_ < λi && λi < s * λ)) { + var φi = -inflection[1] * d3_degrees; + if (φi < φ0) φ0 = φi; + } else { + if (φ < φ0) φ0 = φ; + if (φ > φ1) φ1 = φ; + } + if (antimeridian) { + if (λ < λ_) { + if (angle(λ0, λ) > angle(λ0, λ1)) λ1 = λ; + } else { + if (angle(λ, λ1) > angle(λ0, λ1)) λ0 = λ; + } + } else { + if (λ1 >= λ0) { + if (λ < λ0) λ0 = λ; + if (λ > λ1) λ1 = λ; + } else { + if (λ > λ_) { + if (angle(λ0, λ) > angle(λ0, λ1)) λ1 = λ; + } else { + if (angle(λ, λ1) > angle(λ0, λ1)) λ0 = λ; + } + } + } + } else { + point(λ, φ); + } + p0 = p, λ_ = λ; + } + function lineStart() { + bound.point = linePoint; + } + function lineEnd() { + range[0] = λ0, range[1] = λ1; + bound.point = point; + p0 = null; + } + function ringPoint(λ, φ) { + if (p0) { + var dλ = λ - λ_; + dλSum += abs(dλ) > 180 ? dλ + (dλ > 0 ? 360 : -360) : dλ; + } else λ__ = λ, φ__ = φ; + d3_geo_area.point(λ, φ); + linePoint(λ, φ); + } + function ringStart() { + d3_geo_area.lineStart(); + } + function ringEnd() { + ringPoint(λ__, φ__); + d3_geo_area.lineEnd(); + if (abs(dλSum) > ε) λ0 = -(λ1 = 180); + range[0] = λ0, range[1] = λ1; + p0 = null; + } + function angle(λ0, λ1) { + return (λ1 -= λ0) < 0 ? λ1 + 360 : λ1; + } + function compareRanges(a, b) { + return a[0] - b[0]; + } + function withinRange(x, range) { + return range[0] <= range[1] ? range[0] <= x && x <= range[1] : x < range[0] || range[1] < x; + } + return function(feature) { + φ1 = λ1 = -(λ0 = φ0 = Infinity); + ranges = []; + d3.geo.stream(feature, bound); + var n = ranges.length; + if (n) { + ranges.sort(compareRanges); + for (var i = 1, a = ranges[0], b, merged = [ a ]; i < n; ++i) { + b = ranges[i]; + if (withinRange(b[0], a) || withinRange(b[1], a)) { + if (angle(a[0], b[1]) > angle(a[0], a[1])) a[1] = b[1]; + if (angle(b[0], a[1]) > angle(a[0], a[1])) a[0] = b[0]; + } else { + merged.push(a = b); + } + } + var best = -Infinity, dλ; + for (var n = merged.length - 1, i = 0, a = merged[n], b; i <= n; a = b, ++i) { + b = merged[i]; + if ((dλ = angle(a[1], b[0])) > best) best = dλ, λ0 = b[0], λ1 = a[1]; + } + } + ranges = range = null; + return λ0 === Infinity || φ0 === Infinity ? 
[ [ NaN, NaN ], [ NaN, NaN ] ] : [ [ λ0, φ0 ], [ λ1, φ1 ] ]; + }; + }(); + d3.geo.centroid = function(object) { + d3_geo_centroidW0 = d3_geo_centroidW1 = d3_geo_centroidX0 = d3_geo_centroidY0 = d3_geo_centroidZ0 = d3_geo_centroidX1 = d3_geo_centroidY1 = d3_geo_centroidZ1 = d3_geo_centroidX2 = d3_geo_centroidY2 = d3_geo_centroidZ2 = 0; + d3.geo.stream(object, d3_geo_centroid); + var x = d3_geo_centroidX2, y = d3_geo_centroidY2, z = d3_geo_centroidZ2, m = x * x + y * y + z * z; + if (m < ε2) { + x = d3_geo_centroidX1, y = d3_geo_centroidY1, z = d3_geo_centroidZ1; + if (d3_geo_centroidW1 < ε) x = d3_geo_centroidX0, y = d3_geo_centroidY0, z = d3_geo_centroidZ0; + m = x * x + y * y + z * z; + if (m < ε2) return [ NaN, NaN ]; + } + return [ Math.atan2(y, x) * d3_degrees, d3_asin(z / Math.sqrt(m)) * d3_degrees ]; + }; + var d3_geo_centroidW0, d3_geo_centroidW1, d3_geo_centroidX0, d3_geo_centroidY0, d3_geo_centroidZ0, d3_geo_centroidX1, d3_geo_centroidY1, d3_geo_centroidZ1, d3_geo_centroidX2, d3_geo_centroidY2, d3_geo_centroidZ2; + var d3_geo_centroid = { + sphere: d3_noop, + point: d3_geo_centroidPoint, + lineStart: d3_geo_centroidLineStart, + lineEnd: d3_geo_centroidLineEnd, + polygonStart: function() { + d3_geo_centroid.lineStart = d3_geo_centroidRingStart; + }, + polygonEnd: function() { + d3_geo_centroid.lineStart = d3_geo_centroidLineStart; + } + }; + function d3_geo_centroidPoint(λ, φ) { + λ *= d3_radians; + var cosφ = Math.cos(φ *= d3_radians); + d3_geo_centroidPointXYZ(cosφ * Math.cos(λ), cosφ * Math.sin(λ), Math.sin(φ)); + } + function d3_geo_centroidPointXYZ(x, y, z) { + ++d3_geo_centroidW0; + d3_geo_centroidX0 += (x - d3_geo_centroidX0) / d3_geo_centroidW0; + d3_geo_centroidY0 += (y - d3_geo_centroidY0) / d3_geo_centroidW0; + d3_geo_centroidZ0 += (z - d3_geo_centroidZ0) / d3_geo_centroidW0; + } + function d3_geo_centroidLineStart() { + var x0, y0, z0; + d3_geo_centroid.point = function(λ, φ) { + λ *= d3_radians; + var cosφ = Math.cos(φ *= d3_radians); + x0 = cosφ * Math.cos(λ); + y0 = cosφ * Math.sin(λ); + z0 = Math.sin(φ); + d3_geo_centroid.point = nextPoint; + d3_geo_centroidPointXYZ(x0, y0, z0); + }; + function nextPoint(λ, φ) { + λ *= d3_radians; + var cosφ = Math.cos(φ *= d3_radians), x = cosφ * Math.cos(λ), y = cosφ * Math.sin(λ), z = Math.sin(φ), w = Math.atan2(Math.sqrt((w = y0 * z - z0 * y) * w + (w = z0 * x - x0 * z) * w + (w = x0 * y - y0 * x) * w), x0 * x + y0 * y + z0 * z); + d3_geo_centroidW1 += w; + d3_geo_centroidX1 += w * (x0 + (x0 = x)); + d3_geo_centroidY1 += w * (y0 + (y0 = y)); + d3_geo_centroidZ1 += w * (z0 + (z0 = z)); + d3_geo_centroidPointXYZ(x0, y0, z0); + } + } + function d3_geo_centroidLineEnd() { + d3_geo_centroid.point = d3_geo_centroidPoint; + } + function d3_geo_centroidRingStart() { + var λ00, φ00, x0, y0, z0; + d3_geo_centroid.point = function(λ, φ) { + λ00 = λ, φ00 = φ; + d3_geo_centroid.point = nextPoint; + λ *= d3_radians; + var cosφ = Math.cos(φ *= d3_radians); + x0 = cosφ * Math.cos(λ); + y0 = cosφ * Math.sin(λ); + z0 = Math.sin(φ); + d3_geo_centroidPointXYZ(x0, y0, z0); + }; + d3_geo_centroid.lineEnd = function() { + nextPoint(λ00, φ00); + d3_geo_centroid.lineEnd = d3_geo_centroidLineEnd; + d3_geo_centroid.point = d3_geo_centroidPoint; + }; + function nextPoint(λ, φ) { + λ *= d3_radians; + var cosφ = Math.cos(φ *= d3_radians), x = cosφ * Math.cos(λ), y = cosφ * Math.sin(λ), z = Math.sin(φ), cx = y0 * z - z0 * y, cy = z0 * x - x0 * z, cz = x0 * y - y0 * x, m = Math.sqrt(cx * cx + cy * cy + cz * cz), u = x0 * x + y0 * y + z0 * z, v = m && 
-d3_acos(u) / m, w = Math.atan2(m, u); + d3_geo_centroidX2 += v * cx; + d3_geo_centroidY2 += v * cy; + d3_geo_centroidZ2 += v * cz; + d3_geo_centroidW1 += w; + d3_geo_centroidX1 += w * (x0 + (x0 = x)); + d3_geo_centroidY1 += w * (y0 + (y0 = y)); + d3_geo_centroidZ1 += w * (z0 + (z0 = z)); + d3_geo_centroidPointXYZ(x0, y0, z0); + } + } + function d3_geo_compose(a, b) { + function compose(x, y) { + return x = a(x, y), b(x[0], x[1]); + } + if (a.invert && b.invert) compose.invert = function(x, y) { + return x = b.invert(x, y), x && a.invert(x[0], x[1]); + }; + return compose; + } + function d3_true() { + return true; + } + function d3_geo_clipPolygon(segments, compare, clipStartInside, interpolate, listener) { + var subject = [], clip = []; + segments.forEach(function(segment) { + if ((n = segment.length - 1) <= 0) return; + var n, p0 = segment[0], p1 = segment[n]; + if (d3_geo_sphericalEqual(p0, p1)) { + listener.lineStart(); + for (var i = 0; i < n; ++i) listener.point((p0 = segment[i])[0], p0[1]); + listener.lineEnd(); + return; + } + var a = new d3_geo_clipPolygonIntersection(p0, segment, null, true), b = new d3_geo_clipPolygonIntersection(p0, null, a, false); + a.o = b; + subject.push(a); + clip.push(b); + a = new d3_geo_clipPolygonIntersection(p1, segment, null, false); + b = new d3_geo_clipPolygonIntersection(p1, null, a, true); + a.o = b; + subject.push(a); + clip.push(b); + }); + clip.sort(compare); + d3_geo_clipPolygonLinkCircular(subject); + d3_geo_clipPolygonLinkCircular(clip); + if (!subject.length) return; + for (var i = 0, entry = clipStartInside, n = clip.length; i < n; ++i) { + clip[i].e = entry = !entry; + } + var start = subject[0], points, point; + while (1) { + var current = start, isSubject = true; + while (current.v) if ((current = current.n) === start) return; + points = current.z; + listener.lineStart(); + do { + current.v = current.o.v = true; + if (current.e) { + if (isSubject) { + for (var i = 0, n = points.length; i < n; ++i) listener.point((point = points[i])[0], point[1]); + } else { + interpolate(current.x, current.n.x, 1, listener); + } + current = current.n; + } else { + if (isSubject) { + points = current.p.z; + for (var i = points.length - 1; i >= 0; --i) listener.point((point = points[i])[0], point[1]); + } else { + interpolate(current.x, current.p.x, -1, listener); + } + current = current.p; + } + current = current.o; + points = current.z; + isSubject = !isSubject; + } while (!current.v); + listener.lineEnd(); + } + } + function d3_geo_clipPolygonLinkCircular(array) { + if (!(n = array.length)) return; + var n, i = 0, a = array[0], b; + while (++i < n) { + a.n = b = array[i]; + b.p = a; + a = b; + } + a.n = b = array[0]; + b.p = a; + } + function d3_geo_clipPolygonIntersection(point, points, other, entry) { + this.x = point; + this.z = points; + this.o = other; + this.e = entry; + this.v = false; + this.n = this.p = null; + } + function d3_geo_clip(pointVisible, clipLine, interpolate, clipStart) { + return function(rotate, listener) { + var line = clipLine(listener), rotatedClipStart = rotate.invert(clipStart[0], clipStart[1]); + var clip = { + point: point, + lineStart: lineStart, + lineEnd: lineEnd, + polygonStart: function() { + clip.point = pointRing; + clip.lineStart = ringStart; + clip.lineEnd = ringEnd; + segments = []; + polygon = []; + }, + polygonEnd: function() { + clip.point = point; + clip.lineStart = lineStart; + clip.lineEnd = lineEnd; + segments = d3.merge(segments); + var clipStartInside = d3_geo_pointInPolygon(rotatedClipStart, polygon); 
+ if (segments.length) { + if (!polygonStarted) listener.polygonStart(), polygonStarted = true; + d3_geo_clipPolygon(segments, d3_geo_clipSort, clipStartInside, interpolate, listener); + } else if (clipStartInside) { + if (!polygonStarted) listener.polygonStart(), polygonStarted = true; + listener.lineStart(); + interpolate(null, null, 1, listener); + listener.lineEnd(); + } + if (polygonStarted) listener.polygonEnd(), polygonStarted = false; + segments = polygon = null; + }, + sphere: function() { + listener.polygonStart(); + listener.lineStart(); + interpolate(null, null, 1, listener); + listener.lineEnd(); + listener.polygonEnd(); + } + }; + function point(λ, φ) { + var point = rotate(λ, φ); + if (pointVisible(λ = point[0], φ = point[1])) listener.point(λ, φ); + } + function pointLine(λ, φ) { + var point = rotate(λ, φ); + line.point(point[0], point[1]); + } + function lineStart() { + clip.point = pointLine; + line.lineStart(); + } + function lineEnd() { + clip.point = point; + line.lineEnd(); + } + var segments; + var buffer = d3_geo_clipBufferListener(), ringListener = clipLine(buffer), polygonStarted = false, polygon, ring; + function pointRing(λ, φ) { + ring.push([ λ, φ ]); + var point = rotate(λ, φ); + ringListener.point(point[0], point[1]); + } + function ringStart() { + ringListener.lineStart(); + ring = []; + } + function ringEnd() { + pointRing(ring[0][0], ring[0][1]); + ringListener.lineEnd(); + var clean = ringListener.clean(), ringSegments = buffer.buffer(), segment, n = ringSegments.length; + ring.pop(); + polygon.push(ring); + ring = null; + if (!n) return; + if (clean & 1) { + segment = ringSegments[0]; + var n = segment.length - 1, i = -1, point; + if (n > 0) { + if (!polygonStarted) listener.polygonStart(), polygonStarted = true; + listener.lineStart(); + while (++i < n) listener.point((point = segment[i])[0], point[1]); + listener.lineEnd(); + } + return; + } + if (n > 1 && clean & 2) ringSegments.push(ringSegments.pop().concat(ringSegments.shift())); + segments.push(ringSegments.filter(d3_geo_clipSegmentLength1)); + } + return clip; + }; + } + function d3_geo_clipSegmentLength1(segment) { + return segment.length > 1; + } + function d3_geo_clipBufferListener() { + var lines = [], line; + return { + lineStart: function() { + lines.push(line = []); + }, + point: function(λ, φ) { + line.push([ λ, φ ]); + }, + lineEnd: d3_noop, + buffer: function() { + var buffer = lines; + lines = []; + line = null; + return buffer; + }, + rejoin: function() { + if (lines.length > 1) lines.push(lines.pop().concat(lines.shift())); + } + }; + } + function d3_geo_clipSort(a, b) { + return ((a = a.x)[0] < 0 ? a[1] - halfπ - ε : halfπ - a[1]) - ((b = b.x)[0] < 0 ? b[1] - halfπ - ε : halfπ - b[1]); + } + var d3_geo_clipAntimeridian = d3_geo_clip(d3_true, d3_geo_clipAntimeridianLine, d3_geo_clipAntimeridianInterpolate, [ -π, -π / 2 ]); + function d3_geo_clipAntimeridianLine(listener) { + var λ0 = NaN, φ0 = NaN, sλ0 = NaN, clean; + return { + lineStart: function() { + listener.lineStart(); + clean = 1; + }, + point: function(λ1, φ1) { + var sλ1 = λ1 > 0 ? π : -π, dλ = abs(λ1 - λ0); + if (abs(dλ - π) < ε) { + listener.point(λ0, φ0 = (φ0 + φ1) / 2 > 0 ? 
halfπ : -halfπ); + listener.point(sλ0, φ0); + listener.lineEnd(); + listener.lineStart(); + listener.point(sλ1, φ0); + listener.point(λ1, φ0); + clean = 0; + } else if (sλ0 !== sλ1 && dλ >= π) { + if (abs(λ0 - sλ0) < ε) λ0 -= sλ0 * ε; + if (abs(λ1 - sλ1) < ε) λ1 -= sλ1 * ε; + φ0 = d3_geo_clipAntimeridianIntersect(λ0, φ0, λ1, φ1); + listener.point(sλ0, φ0); + listener.lineEnd(); + listener.lineStart(); + listener.point(sλ1, φ0); + clean = 0; + } + listener.point(λ0 = λ1, φ0 = φ1); + sλ0 = sλ1; + }, + lineEnd: function() { + listener.lineEnd(); + λ0 = φ0 = NaN; + }, + clean: function() { + return 2 - clean; + } + }; + } + function d3_geo_clipAntimeridianIntersect(λ0, φ0, λ1, φ1) { + var cosφ0, cosφ1, sinλ0_λ1 = Math.sin(λ0 - λ1); + return abs(sinλ0_λ1) > ε ? Math.atan((Math.sin(φ0) * (cosφ1 = Math.cos(φ1)) * Math.sin(λ1) - Math.sin(φ1) * (cosφ0 = Math.cos(φ0)) * Math.sin(λ0)) / (cosφ0 * cosφ1 * sinλ0_λ1)) : (φ0 + φ1) / 2; + } + function d3_geo_clipAntimeridianInterpolate(from, to, direction, listener) { + var φ; + if (from == null) { + φ = direction * halfπ; + listener.point(-π, φ); + listener.point(0, φ); + listener.point(π, φ); + listener.point(π, 0); + listener.point(π, -φ); + listener.point(0, -φ); + listener.point(-π, -φ); + listener.point(-π, 0); + listener.point(-π, φ); + } else if (abs(from[0] - to[0]) > ε) { + var s = from[0] < to[0] ? π : -π; + φ = direction * s / 2; + listener.point(-s, φ); + listener.point(0, φ); + listener.point(s, φ); + } else { + listener.point(to[0], to[1]); + } + } + function d3_geo_pointInPolygon(point, polygon) { + var meridian = point[0], parallel = point[1], meridianNormal = [ Math.sin(meridian), -Math.cos(meridian), 0 ], polarAngle = 0, winding = 0; + d3_geo_areaRingSum.reset(); + for (var i = 0, n = polygon.length; i < n; ++i) { + var ring = polygon[i], m = ring.length; + if (!m) continue; + var point0 = ring[0], λ0 = point0[0], φ0 = point0[1] / 2 + π / 4, sinφ0 = Math.sin(φ0), cosφ0 = Math.cos(φ0), j = 1; + while (true) { + if (j === m) j = 0; + point = ring[j]; + var λ = point[0], φ = point[1] / 2 + π / 4, sinφ = Math.sin(φ), cosφ = Math.cos(φ), dλ = λ - λ0, sdλ = dλ >= 0 ? 1 : -1, adλ = sdλ * dλ, antimeridian = adλ > π, k = sinφ0 * sinφ; + d3_geo_areaRingSum.add(Math.atan2(k * sdλ * Math.sin(adλ), cosφ0 * cosφ + k * Math.cos(adλ))); + polarAngle += antimeridian ? dλ + sdλ * τ : dλ; + if (antimeridian ^ λ0 >= meridian ^ λ >= meridian) { + var arc = d3_geo_cartesianCross(d3_geo_cartesian(point0), d3_geo_cartesian(point)); + d3_geo_cartesianNormalize(arc); + var intersection = d3_geo_cartesianCross(meridianNormal, arc); + d3_geo_cartesianNormalize(intersection); + var φarc = (antimeridian ^ dλ >= 0 ? -1 : 1) * d3_asin(intersection[2]); + if (parallel > φarc || parallel === φarc && (arc[0] || arc[1])) { + winding += antimeridian ^ dλ >= 0 ? 1 : -1; + } + } + if (!j++) break; + λ0 = λ, sinφ0 = sinφ, cosφ0 = cosφ, point0 = point; + } + } + return (polarAngle < -ε || polarAngle < ε && d3_geo_areaRingSum < -ε) ^ winding & 1; + } + function d3_geo_clipCircle(radius) { + var cr = Math.cos(radius), smallRadius = cr > 0, notHemisphere = abs(cr) > ε, interpolate = d3_geo_circleInterpolate(radius, 6 * d3_radians); + return d3_geo_clip(visible, clipLine, interpolate, smallRadius ? 
[ 0, -radius ] : [ -π, radius - π ]); + function visible(λ, φ) { + return Math.cos(λ) * Math.cos(φ) > cr; + } + function clipLine(listener) { + var point0, c0, v0, v00, clean; + return { + lineStart: function() { + v00 = v0 = false; + clean = 1; + }, + point: function(λ, φ) { + var point1 = [ λ, φ ], point2, v = visible(λ, φ), c = smallRadius ? v ? 0 : code(λ, φ) : v ? code(λ + (λ < 0 ? π : -π), φ) : 0; + if (!point0 && (v00 = v0 = v)) listener.lineStart(); + if (v !== v0) { + point2 = intersect(point0, point1); + if (d3_geo_sphericalEqual(point0, point2) || d3_geo_sphericalEqual(point1, point2)) { + point1[0] += ε; + point1[1] += ε; + v = visible(point1[0], point1[1]); + } + } + if (v !== v0) { + clean = 0; + if (v) { + listener.lineStart(); + point2 = intersect(point1, point0); + listener.point(point2[0], point2[1]); + } else { + point2 = intersect(point0, point1); + listener.point(point2[0], point2[1]); + listener.lineEnd(); + } + point0 = point2; + } else if (notHemisphere && point0 && smallRadius ^ v) { + var t; + if (!(c & c0) && (t = intersect(point1, point0, true))) { + clean = 0; + if (smallRadius) { + listener.lineStart(); + listener.point(t[0][0], t[0][1]); + listener.point(t[1][0], t[1][1]); + listener.lineEnd(); + } else { + listener.point(t[1][0], t[1][1]); + listener.lineEnd(); + listener.lineStart(); + listener.point(t[0][0], t[0][1]); + } + } + } + if (v && (!point0 || !d3_geo_sphericalEqual(point0, point1))) { + listener.point(point1[0], point1[1]); + } + point0 = point1, v0 = v, c0 = c; + }, + lineEnd: function() { + if (v0) listener.lineEnd(); + point0 = null; + }, + clean: function() { + return clean | (v00 && v0) << 1; + } + }; + } + function intersect(a, b, two) { + var pa = d3_geo_cartesian(a), pb = d3_geo_cartesian(b); + var n1 = [ 1, 0, 0 ], n2 = d3_geo_cartesianCross(pa, pb), n2n2 = d3_geo_cartesianDot(n2, n2), n1n2 = n2[0], determinant = n2n2 - n1n2 * n1n2; + if (!determinant) return !two && a; + var c1 = cr * n2n2 / determinant, c2 = -cr * n1n2 / determinant, n1xn2 = d3_geo_cartesianCross(n1, n2), A = d3_geo_cartesianScale(n1, c1), B = d3_geo_cartesianScale(n2, c2); + d3_geo_cartesianAdd(A, B); + var u = n1xn2, w = d3_geo_cartesianDot(A, u), uu = d3_geo_cartesianDot(u, u), t2 = w * w - uu * (d3_geo_cartesianDot(A, A) - 1); + if (t2 < 0) return; + var t = Math.sqrt(t2), q = d3_geo_cartesianScale(u, (-w - t) / uu); + d3_geo_cartesianAdd(q, A); + q = d3_geo_spherical(q); + if (!two) return q; + var λ0 = a[0], λ1 = b[0], φ0 = a[1], φ1 = b[1], z; + if (λ1 < λ0) z = λ0, λ0 = λ1, λ1 = z; + var δλ = λ1 - λ0, polar = abs(δλ - π) < ε, meridian = polar || δλ < ε; + if (!polar && φ1 < φ0) z = φ0, φ0 = φ1, φ1 = z; + if (meridian ? polar ? φ0 + φ1 > 0 ^ q[1] < (abs(q[0] - λ0) < ε ? φ0 : φ1) : φ0 <= q[1] && q[1] <= φ1 : δλ > π ^ (λ0 <= q[0] && q[0] <= λ1)) { + var q1 = d3_geo_cartesianScale(u, (-w + t) / uu); + d3_geo_cartesianAdd(q1, A); + return [ q, d3_geo_spherical(q1) ]; + } + } + function code(λ, φ) { + var r = smallRadius ? 
radius : π - radius, code = 0; + if (λ < -r) code |= 1; else if (λ > r) code |= 2; + if (φ < -r) code |= 4; else if (φ > r) code |= 8; + return code; + } + } + function d3_geom_clipLine(x0, y0, x1, y1) { + return function(line) { + var a = line.a, b = line.b, ax = a.x, ay = a.y, bx = b.x, by = b.y, t0 = 0, t1 = 1, dx = bx - ax, dy = by - ay, r; + r = x0 - ax; + if (!dx && r > 0) return; + r /= dx; + if (dx < 0) { + if (r < t0) return; + if (r < t1) t1 = r; + } else if (dx > 0) { + if (r > t1) return; + if (r > t0) t0 = r; + } + r = x1 - ax; + if (!dx && r < 0) return; + r /= dx; + if (dx < 0) { + if (r > t1) return; + if (r > t0) t0 = r; + } else if (dx > 0) { + if (r < t0) return; + if (r < t1) t1 = r; + } + r = y0 - ay; + if (!dy && r > 0) return; + r /= dy; + if (dy < 0) { + if (r < t0) return; + if (r < t1) t1 = r; + } else if (dy > 0) { + if (r > t1) return; + if (r > t0) t0 = r; + } + r = y1 - ay; + if (!dy && r < 0) return; + r /= dy; + if (dy < 0) { + if (r > t1) return; + if (r > t0) t0 = r; + } else if (dy > 0) { + if (r < t0) return; + if (r < t1) t1 = r; + } + if (t0 > 0) line.a = { + x: ax + t0 * dx, + y: ay + t0 * dy + }; + if (t1 < 1) line.b = { + x: ax + t1 * dx, + y: ay + t1 * dy + }; + return line; + }; + } + var d3_geo_clipExtentMAX = 1e9; + d3.geo.clipExtent = function() { + var x0, y0, x1, y1, stream, clip, clipExtent = { + stream: function(output) { + if (stream) stream.valid = false; + stream = clip(output); + stream.valid = true; + return stream; + }, + extent: function(_) { + if (!arguments.length) return [ [ x0, y0 ], [ x1, y1 ] ]; + clip = d3_geo_clipExtent(x0 = +_[0][0], y0 = +_[0][1], x1 = +_[1][0], y1 = +_[1][1]); + if (stream) stream.valid = false, stream = null; + return clipExtent; + } + }; + return clipExtent.extent([ [ 0, 0 ], [ 960, 500 ] ]); + }; + function d3_geo_clipExtent(x0, y0, x1, y1) { + return function(listener) { + var listener_ = listener, bufferListener = d3_geo_clipBufferListener(), clipLine = d3_geom_clipLine(x0, y0, x1, y1), segments, polygon, ring; + var clip = { + point: point, + lineStart: lineStart, + lineEnd: lineEnd, + polygonStart: function() { + listener = bufferListener; + segments = []; + polygon = []; + clean = true; + }, + polygonEnd: function() { + listener = listener_; + segments = d3.merge(segments); + var clipStartInside = insidePolygon([ x0, y1 ]), inside = clean && clipStartInside, visible = segments.length; + if (inside || visible) { + listener.polygonStart(); + if (inside) { + listener.lineStart(); + interpolate(null, null, 1, listener); + listener.lineEnd(); + } + if (visible) { + d3_geo_clipPolygon(segments, compare, clipStartInside, interpolate, listener); + } + listener.polygonEnd(); + } + segments = polygon = ring = null; + } + }; + function insidePolygon(p) { + var wn = 0, n = polygon.length, y = p[1]; + for (var i = 0; i < n; ++i) { + for (var j = 1, v = polygon[i], m = v.length, a = v[0], b; j < m; ++j) { + b = v[j]; + if (a[1] <= y) { + if (b[1] > y && d3_cross2d(a, b, p) > 0) ++wn; + } else { + if (b[1] <= y && d3_cross2d(a, b, p) < 0) --wn; + } + a = b; + } + } + return wn !== 0; + } + function interpolate(from, to, direction, listener) { + var a = 0, a1 = 0; + if (from == null || (a = corner(from, direction)) !== (a1 = corner(to, direction)) || comparePoints(from, to) < 0 ^ direction > 0) { + do { + listener.point(a === 0 || a === 3 ? x0 : x1, a > 1 ? 
y1 : y0); + } while ((a = (a + direction + 4) % 4) !== a1); + } else { + listener.point(to[0], to[1]); + } + } + function pointVisible(x, y) { + return x0 <= x && x <= x1 && y0 <= y && y <= y1; + } + function point(x, y) { + if (pointVisible(x, y)) listener.point(x, y); + } + var x__, y__, v__, x_, y_, v_, first, clean; + function lineStart() { + clip.point = linePoint; + if (polygon) polygon.push(ring = []); + first = true; + v_ = false; + x_ = y_ = NaN; + } + function lineEnd() { + if (segments) { + linePoint(x__, y__); + if (v__ && v_) bufferListener.rejoin(); + segments.push(bufferListener.buffer()); + } + clip.point = point; + if (v_) listener.lineEnd(); + } + function linePoint(x, y) { + x = Math.max(-d3_geo_clipExtentMAX, Math.min(d3_geo_clipExtentMAX, x)); + y = Math.max(-d3_geo_clipExtentMAX, Math.min(d3_geo_clipExtentMAX, y)); + var v = pointVisible(x, y); + if (polygon) ring.push([ x, y ]); + if (first) { + x__ = x, y__ = y, v__ = v; + first = false; + if (v) { + listener.lineStart(); + listener.point(x, y); + } + } else { + if (v && v_) listener.point(x, y); else { + var l = { + a: { + x: x_, + y: y_ + }, + b: { + x: x, + y: y + } + }; + if (clipLine(l)) { + if (!v_) { + listener.lineStart(); + listener.point(l.a.x, l.a.y); + } + listener.point(l.b.x, l.b.y); + if (!v) listener.lineEnd(); + clean = false; + } else if (v) { + listener.lineStart(); + listener.point(x, y); + clean = false; + } + } + } + x_ = x, y_ = y, v_ = v; + } + return clip; + }; + function corner(p, direction) { + return abs(p[0] - x0) < ε ? direction > 0 ? 0 : 3 : abs(p[0] - x1) < ε ? direction > 0 ? 2 : 1 : abs(p[1] - y0) < ε ? direction > 0 ? 1 : 0 : direction > 0 ? 3 : 2; + } + function compare(a, b) { + return comparePoints(a.x, b.x); + } + function comparePoints(a, b) { + var ca = corner(a, 1), cb = corner(b, 1); + return ca !== cb ? ca - cb : ca === 0 ? b[1] - a[1] : ca === 1 ? a[0] - b[0] : ca === 2 ? 
a[1] - b[1] : b[0] - a[0]; + } + } + function d3_geo_conic(projectAt) { + var φ0 = 0, φ1 = π / 3, m = d3_geo_projectionMutator(projectAt), p = m(φ0, φ1); + p.parallels = function(_) { + if (!arguments.length) return [ φ0 / π * 180, φ1 / π * 180 ]; + return m(φ0 = _[0] * π / 180, φ1 = _[1] * π / 180); + }; + return p; + } + function d3_geo_conicEqualArea(φ0, φ1) { + var sinφ0 = Math.sin(φ0), n = (sinφ0 + Math.sin(φ1)) / 2, C = 1 + sinφ0 * (2 * n - sinφ0), ρ0 = Math.sqrt(C) / n; + function forward(λ, φ) { + var ρ = Math.sqrt(C - 2 * n * Math.sin(φ)) / n; + return [ ρ * Math.sin(λ *= n), ρ0 - ρ * Math.cos(λ) ]; + } + forward.invert = function(x, y) { + var ρ0_y = ρ0 - y; + return [ Math.atan2(x, ρ0_y) / n, d3_asin((C - (x * x + ρ0_y * ρ0_y) * n * n) / (2 * n)) ]; + }; + return forward; + } + (d3.geo.conicEqualArea = function() { + return d3_geo_conic(d3_geo_conicEqualArea); + }).raw = d3_geo_conicEqualArea; + d3.geo.albers = function() { + return d3.geo.conicEqualArea().rotate([ 96, 0 ]).center([ -.6, 38.7 ]).parallels([ 29.5, 45.5 ]).scale(1070); + }; + d3.geo.albersUsa = function() { + var lower48 = d3.geo.albers(); + var alaska = d3.geo.conicEqualArea().rotate([ 154, 0 ]).center([ -2, 58.5 ]).parallels([ 55, 65 ]); + var hawaii = d3.geo.conicEqualArea().rotate([ 157, 0 ]).center([ -3, 19.9 ]).parallels([ 8, 18 ]); + var point, pointStream = { + point: function(x, y) { + point = [ x, y ]; + } + }, lower48Point, alaskaPoint, hawaiiPoint; + function albersUsa(coordinates) { + var x = coordinates[0], y = coordinates[1]; + point = null; + (lower48Point(x, y), point) || (alaskaPoint(x, y), point) || hawaiiPoint(x, y); + return point; + } + albersUsa.invert = function(coordinates) { + var k = lower48.scale(), t = lower48.translate(), x = (coordinates[0] - t[0]) / k, y = (coordinates[1] - t[1]) / k; + return (y >= .12 && y < .234 && x >= -.425 && x < -.214 ? alaska : y >= .166 && y < .234 && x >= -.214 && x < -.115 ? 
hawaii : lower48).invert(coordinates); + }; + albersUsa.stream = function(stream) { + var lower48Stream = lower48.stream(stream), alaskaStream = alaska.stream(stream), hawaiiStream = hawaii.stream(stream); + return { + point: function(x, y) { + lower48Stream.point(x, y); + alaskaStream.point(x, y); + hawaiiStream.point(x, y); + }, + sphere: function() { + lower48Stream.sphere(); + alaskaStream.sphere(); + hawaiiStream.sphere(); + }, + lineStart: function() { + lower48Stream.lineStart(); + alaskaStream.lineStart(); + hawaiiStream.lineStart(); + }, + lineEnd: function() { + lower48Stream.lineEnd(); + alaskaStream.lineEnd(); + hawaiiStream.lineEnd(); + }, + polygonStart: function() { + lower48Stream.polygonStart(); + alaskaStream.polygonStart(); + hawaiiStream.polygonStart(); + }, + polygonEnd: function() { + lower48Stream.polygonEnd(); + alaskaStream.polygonEnd(); + hawaiiStream.polygonEnd(); + } + }; + }; + albersUsa.precision = function(_) { + if (!arguments.length) return lower48.precision(); + lower48.precision(_); + alaska.precision(_); + hawaii.precision(_); + return albersUsa; + }; + albersUsa.scale = function(_) { + if (!arguments.length) return lower48.scale(); + lower48.scale(_); + alaska.scale(_ * .35); + hawaii.scale(_); + return albersUsa.translate(lower48.translate()); + }; + albersUsa.translate = function(_) { + if (!arguments.length) return lower48.translate(); + var k = lower48.scale(), x = +_[0], y = +_[1]; + lower48Point = lower48.translate(_).clipExtent([ [ x - .455 * k, y - .238 * k ], [ x + .455 * k, y + .238 * k ] ]).stream(pointStream).point; + alaskaPoint = alaska.translate([ x - .307 * k, y + .201 * k ]).clipExtent([ [ x - .425 * k + ε, y + .12 * k + ε ], [ x - .214 * k - ε, y + .234 * k - ε ] ]).stream(pointStream).point; + hawaiiPoint = hawaii.translate([ x - .205 * k, y + .212 * k ]).clipExtent([ [ x - .214 * k + ε, y + .166 * k + ε ], [ x - .115 * k - ε, y + .234 * k - ε ] ]).stream(pointStream).point; + return albersUsa; + }; + return albersUsa.scale(1070); + }; + var d3_geo_pathAreaSum, d3_geo_pathAreaPolygon, d3_geo_pathArea = { + point: d3_noop, + lineStart: d3_noop, + lineEnd: d3_noop, + polygonStart: function() { + d3_geo_pathAreaPolygon = 0; + d3_geo_pathArea.lineStart = d3_geo_pathAreaRingStart; + }, + polygonEnd: function() { + d3_geo_pathArea.lineStart = d3_geo_pathArea.lineEnd = d3_geo_pathArea.point = d3_noop; + d3_geo_pathAreaSum += abs(d3_geo_pathAreaPolygon / 2); + } + }; + function d3_geo_pathAreaRingStart() { + var x00, y00, x0, y0; + d3_geo_pathArea.point = function(x, y) { + d3_geo_pathArea.point = nextPoint; + x00 = x0 = x, y00 = y0 = y; + }; + function nextPoint(x, y) { + d3_geo_pathAreaPolygon += y0 * x - x0 * y; + x0 = x, y0 = y; + } + d3_geo_pathArea.lineEnd = function() { + nextPoint(x00, y00); + }; + } + var d3_geo_pathBoundsX0, d3_geo_pathBoundsY0, d3_geo_pathBoundsX1, d3_geo_pathBoundsY1; + var d3_geo_pathBounds = { + point: d3_geo_pathBoundsPoint, + lineStart: d3_noop, + lineEnd: d3_noop, + polygonStart: d3_noop, + polygonEnd: d3_noop + }; + function d3_geo_pathBoundsPoint(x, y) { + if (x < d3_geo_pathBoundsX0) d3_geo_pathBoundsX0 = x; + if (x > d3_geo_pathBoundsX1) d3_geo_pathBoundsX1 = x; + if (y < d3_geo_pathBoundsY0) d3_geo_pathBoundsY0 = y; + if (y > d3_geo_pathBoundsY1) d3_geo_pathBoundsY1 = y; + } + function d3_geo_pathBuffer() { + var pointCircle = d3_geo_pathBufferCircle(4.5), buffer = []; + var stream = { + point: point, + lineStart: function() { + stream.point = pointLineStart; + }, + lineEnd: lineEnd, + polygonStart: 
function() { + stream.lineEnd = lineEndPolygon; + }, + polygonEnd: function() { + stream.lineEnd = lineEnd; + stream.point = point; + }, + pointRadius: function(_) { + pointCircle = d3_geo_pathBufferCircle(_); + return stream; + }, + result: function() { + if (buffer.length) { + var result = buffer.join(""); + buffer = []; + return result; + } + } + }; + function point(x, y) { + buffer.push("M", x, ",", y, pointCircle); + } + function pointLineStart(x, y) { + buffer.push("M", x, ",", y); + stream.point = pointLine; + } + function pointLine(x, y) { + buffer.push("L", x, ",", y); + } + function lineEnd() { + stream.point = point; + } + function lineEndPolygon() { + buffer.push("Z"); + } + return stream; + } + function d3_geo_pathBufferCircle(radius) { + return "m0," + radius + "a" + radius + "," + radius + " 0 1,1 0," + -2 * radius + "a" + radius + "," + radius + " 0 1,1 0," + 2 * radius + "z"; + } + var d3_geo_pathCentroid = { + point: d3_geo_pathCentroidPoint, + lineStart: d3_geo_pathCentroidLineStart, + lineEnd: d3_geo_pathCentroidLineEnd, + polygonStart: function() { + d3_geo_pathCentroid.lineStart = d3_geo_pathCentroidRingStart; + }, + polygonEnd: function() { + d3_geo_pathCentroid.point = d3_geo_pathCentroidPoint; + d3_geo_pathCentroid.lineStart = d3_geo_pathCentroidLineStart; + d3_geo_pathCentroid.lineEnd = d3_geo_pathCentroidLineEnd; + } + }; + function d3_geo_pathCentroidPoint(x, y) { + d3_geo_centroidX0 += x; + d3_geo_centroidY0 += y; + ++d3_geo_centroidZ0; + } + function d3_geo_pathCentroidLineStart() { + var x0, y0; + d3_geo_pathCentroid.point = function(x, y) { + d3_geo_pathCentroid.point = nextPoint; + d3_geo_pathCentroidPoint(x0 = x, y0 = y); + }; + function nextPoint(x, y) { + var dx = x - x0, dy = y - y0, z = Math.sqrt(dx * dx + dy * dy); + d3_geo_centroidX1 += z * (x0 + x) / 2; + d3_geo_centroidY1 += z * (y0 + y) / 2; + d3_geo_centroidZ1 += z; + d3_geo_pathCentroidPoint(x0 = x, y0 = y); + } + } + function d3_geo_pathCentroidLineEnd() { + d3_geo_pathCentroid.point = d3_geo_pathCentroidPoint; + } + function d3_geo_pathCentroidRingStart() { + var x00, y00, x0, y0; + d3_geo_pathCentroid.point = function(x, y) { + d3_geo_pathCentroid.point = nextPoint; + d3_geo_pathCentroidPoint(x00 = x0 = x, y00 = y0 = y); + }; + function nextPoint(x, y) { + var dx = x - x0, dy = y - y0, z = Math.sqrt(dx * dx + dy * dy); + d3_geo_centroidX1 += z * (x0 + x) / 2; + d3_geo_centroidY1 += z * (y0 + y) / 2; + d3_geo_centroidZ1 += z; + z = y0 * x - x0 * y; + d3_geo_centroidX2 += z * (x0 + x); + d3_geo_centroidY2 += z * (y0 + y); + d3_geo_centroidZ2 += z * 3; + d3_geo_pathCentroidPoint(x0 = x, y0 = y); + } + d3_geo_pathCentroid.lineEnd = function() { + nextPoint(x00, y00); + }; + } + function d3_geo_pathContext(context) { + var pointRadius = 4.5; + var stream = { + point: point, + lineStart: function() { + stream.point = pointLineStart; + }, + lineEnd: lineEnd, + polygonStart: function() { + stream.lineEnd = lineEndPolygon; + }, + polygonEnd: function() { + stream.lineEnd = lineEnd; + stream.point = point; + }, + pointRadius: function(_) { + pointRadius = _; + return stream; + }, + result: d3_noop + }; + function point(x, y) { + context.moveTo(x + pointRadius, y); + context.arc(x, y, pointRadius, 0, τ); + } + function pointLineStart(x, y) { + context.moveTo(x, y); + stream.point = pointLine; + } + function pointLine(x, y) { + context.lineTo(x, y); + } + function lineEnd() { + stream.point = point; + } + function lineEndPolygon() { + context.closePath(); + } + return stream; + } + function 
d3_geo_resample(project) { + var δ2 = .5, cosMinDistance = Math.cos(30 * d3_radians), maxDepth = 16; + function resample(stream) { + return (maxDepth ? resampleRecursive : resampleNone)(stream); + } + function resampleNone(stream) { + return d3_geo_transformPoint(stream, function(x, y) { + x = project(x, y); + stream.point(x[0], x[1]); + }); + } + function resampleRecursive(stream) { + var λ00, φ00, x00, y00, a00, b00, c00, λ0, x0, y0, a0, b0, c0; + var resample = { + point: point, + lineStart: lineStart, + lineEnd: lineEnd, + polygonStart: function() { + stream.polygonStart(); + resample.lineStart = ringStart; + }, + polygonEnd: function() { + stream.polygonEnd(); + resample.lineStart = lineStart; + } + }; + function point(x, y) { + x = project(x, y); + stream.point(x[0], x[1]); + } + function lineStart() { + x0 = NaN; + resample.point = linePoint; + stream.lineStart(); + } + function linePoint(λ, φ) { + var c = d3_geo_cartesian([ λ, φ ]), p = project(λ, φ); + resampleLineTo(x0, y0, λ0, a0, b0, c0, x0 = p[0], y0 = p[1], λ0 = λ, a0 = c[0], b0 = c[1], c0 = c[2], maxDepth, stream); + stream.point(x0, y0); + } + function lineEnd() { + resample.point = point; + stream.lineEnd(); + } + function ringStart() { + lineStart(); + resample.point = ringPoint; + resample.lineEnd = ringEnd; + } + function ringPoint(λ, φ) { + linePoint(λ00 = λ, φ00 = φ), x00 = x0, y00 = y0, a00 = a0, b00 = b0, c00 = c0; + resample.point = linePoint; + } + function ringEnd() { + resampleLineTo(x0, y0, λ0, a0, b0, c0, x00, y00, λ00, a00, b00, c00, maxDepth, stream); + resample.lineEnd = lineEnd; + lineEnd(); + } + return resample; + } + function resampleLineTo(x0, y0, λ0, a0, b0, c0, x1, y1, λ1, a1, b1, c1, depth, stream) { + var dx = x1 - x0, dy = y1 - y0, d2 = dx * dx + dy * dy; + if (d2 > 4 * δ2 && depth--) { + var a = a0 + a1, b = b0 + b1, c = c0 + c1, m = Math.sqrt(a * a + b * b + c * c), φ2 = Math.asin(c /= m), λ2 = abs(abs(c) - 1) < ε || abs(λ0 - λ1) < ε ? (λ0 + λ1) / 2 : Math.atan2(b, a), p = project(λ2, φ2), x2 = p[0], y2 = p[1], dx2 = x2 - x0, dy2 = y2 - y0, dz = dy * dx2 - dx * dy2; + if (dz * dz / d2 > δ2 || abs((dx * dx2 + dy * dy2) / d2 - .5) > .3 || a0 * a1 + b0 * b1 + c0 * c1 < cosMinDistance) { + resampleLineTo(x0, y0, λ0, a0, b0, c0, x2, y2, λ2, a /= m, b /= m, c, depth, stream); + stream.point(x2, y2); + resampleLineTo(x2, y2, λ2, a, b, c, x1, y1, λ1, a1, b1, c1, depth, stream); + } + } + } + resample.precision = function(_) { + if (!arguments.length) return Math.sqrt(δ2); + maxDepth = (δ2 = _ * _) > 0 && 16; + return resample; + }; + return resample; + } + d3.geo.path = function() { + var pointRadius = 4.5, projection, context, projectStream, contextStream, cacheStream; + function path(object) { + if (object) { + if (typeof pointRadius === "function") contextStream.pointRadius(+pointRadius.apply(this, arguments)); + if (!cacheStream || !cacheStream.valid) cacheStream = projectStream(contextStream); + d3.geo.stream(object, cacheStream); + } + return contextStream.result(); + } + path.area = function(object) { + d3_geo_pathAreaSum = 0; + d3.geo.stream(object, projectStream(d3_geo_pathArea)); + return d3_geo_pathAreaSum; + }; + path.centroid = function(object) { + d3_geo_centroidX0 = d3_geo_centroidY0 = d3_geo_centroidZ0 = d3_geo_centroidX1 = d3_geo_centroidY1 = d3_geo_centroidZ1 = d3_geo_centroidX2 = d3_geo_centroidY2 = d3_geo_centroidZ2 = 0; + d3.geo.stream(object, projectStream(d3_geo_pathCentroid)); + return d3_geo_centroidZ2 ? 
[ d3_geo_centroidX2 / d3_geo_centroidZ2, d3_geo_centroidY2 / d3_geo_centroidZ2 ] : d3_geo_centroidZ1 ? [ d3_geo_centroidX1 / d3_geo_centroidZ1, d3_geo_centroidY1 / d3_geo_centroidZ1 ] : d3_geo_centroidZ0 ? [ d3_geo_centroidX0 / d3_geo_centroidZ0, d3_geo_centroidY0 / d3_geo_centroidZ0 ] : [ NaN, NaN ]; + }; + path.bounds = function(object) { + d3_geo_pathBoundsX1 = d3_geo_pathBoundsY1 = -(d3_geo_pathBoundsX0 = d3_geo_pathBoundsY0 = Infinity); + d3.geo.stream(object, projectStream(d3_geo_pathBounds)); + return [ [ d3_geo_pathBoundsX0, d3_geo_pathBoundsY0 ], [ d3_geo_pathBoundsX1, d3_geo_pathBoundsY1 ] ]; + }; + path.projection = function(_) { + if (!arguments.length) return projection; + projectStream = (projection = _) ? _.stream || d3_geo_pathProjectStream(_) : d3_identity; + return reset(); + }; + path.context = function(_) { + if (!arguments.length) return context; + contextStream = (context = _) == null ? new d3_geo_pathBuffer() : new d3_geo_pathContext(_); + if (typeof pointRadius !== "function") contextStream.pointRadius(pointRadius); + return reset(); + }; + path.pointRadius = function(_) { + if (!arguments.length) return pointRadius; + pointRadius = typeof _ === "function" ? _ : (contextStream.pointRadius(+_), +_); + return path; + }; + function reset() { + cacheStream = null; + return path; + } + return path.projection(d3.geo.albersUsa()).context(null); + }; + function d3_geo_pathProjectStream(project) { + var resample = d3_geo_resample(function(x, y) { + return project([ x * d3_degrees, y * d3_degrees ]); + }); + return function(stream) { + return d3_geo_projectionRadians(resample(stream)); + }; + } + d3.geo.transform = function(methods) { + return { + stream: function(stream) { + var transform = new d3_geo_transform(stream); + for (var k in methods) transform[k] = methods[k]; + return transform; + } + }; + }; + function d3_geo_transform(stream) { + this.stream = stream; + } + d3_geo_transform.prototype = { + point: function(x, y) { + this.stream.point(x, y); + }, + sphere: function() { + this.stream.sphere(); + }, + lineStart: function() { + this.stream.lineStart(); + }, + lineEnd: function() { + this.stream.lineEnd(); + }, + polygonStart: function() { + this.stream.polygonStart(); + }, + polygonEnd: function() { + this.stream.polygonEnd(); + } + }; + function d3_geo_transformPoint(stream, point) { + return { + point: point, + sphere: function() { + stream.sphere(); + }, + lineStart: function() { + stream.lineStart(); + }, + lineEnd: function() { + stream.lineEnd(); + }, + polygonStart: function() { + stream.polygonStart(); + }, + polygonEnd: function() { + stream.polygonEnd(); + } + }; + } + d3.geo.projection = d3_geo_projection; + d3.geo.projectionMutator = d3_geo_projectionMutator; + function d3_geo_projection(project) { + return d3_geo_projectionMutator(function() { + return project; + })(); + } + function d3_geo_projectionMutator(projectAt) { + var project, rotate, projectRotate, projectResample = d3_geo_resample(function(x, y) { + x = project(x, y); + return [ x[0] * k + δx, δy - x[1] * k ]; + }), k = 150, x = 480, y = 250, λ = 0, φ = 0, δλ = 0, δφ = 0, δγ = 0, δx, δy, preclip = d3_geo_clipAntimeridian, postclip = d3_identity, clipAngle = null, clipExtent = null, stream; + function projection(point) { + point = projectRotate(point[0] * d3_radians, point[1] * d3_radians); + return [ point[0] * k + δx, δy - point[1] * k ]; + } + function invert(point) { + point = projectRotate.invert((point[0] - δx) / k, (δy - point[1]) / k); + return point && [ point[0] * d3_degrees, 
point[1] * d3_degrees ]; + } + projection.stream = function(output) { + if (stream) stream.valid = false; + stream = d3_geo_projectionRadians(preclip(rotate, projectResample(postclip(output)))); + stream.valid = true; + return stream; + }; + projection.clipAngle = function(_) { + if (!arguments.length) return clipAngle; + preclip = _ == null ? (clipAngle = _, d3_geo_clipAntimeridian) : d3_geo_clipCircle((clipAngle = +_) * d3_radians); + return invalidate(); + }; + projection.clipExtent = function(_) { + if (!arguments.length) return clipExtent; + clipExtent = _; + postclip = _ ? d3_geo_clipExtent(_[0][0], _[0][1], _[1][0], _[1][1]) : d3_identity; + return invalidate(); + }; + projection.scale = function(_) { + if (!arguments.length) return k; + k = +_; + return reset(); + }; + projection.translate = function(_) { + if (!arguments.length) return [ x, y ]; + x = +_[0]; + y = +_[1]; + return reset(); + }; + projection.center = function(_) { + if (!arguments.length) return [ λ * d3_degrees, φ * d3_degrees ]; + λ = _[0] % 360 * d3_radians; + φ = _[1] % 360 * d3_radians; + return reset(); + }; + projection.rotate = function(_) { + if (!arguments.length) return [ δλ * d3_degrees, δφ * d3_degrees, δγ * d3_degrees ]; + δλ = _[0] % 360 * d3_radians; + δφ = _[1] % 360 * d3_radians; + δγ = _.length > 2 ? _[2] % 360 * d3_radians : 0; + return reset(); + }; + d3.rebind(projection, projectResample, "precision"); + function reset() { + projectRotate = d3_geo_compose(rotate = d3_geo_rotation(δλ, δφ, δγ), project); + var center = project(λ, φ); + δx = x - center[0] * k; + δy = y + center[1] * k; + return invalidate(); + } + function invalidate() { + if (stream) stream.valid = false, stream = null; + return projection; + } + return function() { + project = projectAt.apply(this, arguments); + projection.invert = project.invert && invert; + return reset(); + }; + } + function d3_geo_projectionRadians(stream) { + return d3_geo_transformPoint(stream, function(x, y) { + stream.point(x * d3_radians, y * d3_radians); + }); + } + function d3_geo_equirectangular(λ, φ) { + return [ λ, φ ]; + } + (d3.geo.equirectangular = function() { + return d3_geo_projection(d3_geo_equirectangular); + }).raw = d3_geo_equirectangular.invert = d3_geo_equirectangular; + d3.geo.rotation = function(rotate) { + rotate = d3_geo_rotation(rotate[0] % 360 * d3_radians, rotate[1] * d3_radians, rotate.length > 2 ? rotate[2] * d3_radians : 0); + function forward(coordinates) { + coordinates = rotate(coordinates[0] * d3_radians, coordinates[1] * d3_radians); + return coordinates[0] *= d3_degrees, coordinates[1] *= d3_degrees, coordinates; + } + forward.invert = function(coordinates) { + coordinates = rotate.invert(coordinates[0] * d3_radians, coordinates[1] * d3_radians); + return coordinates[0] *= d3_degrees, coordinates[1] *= d3_degrees, coordinates; + }; + return forward; + }; + function d3_geo_identityRotation(λ, φ) { + return [ λ > π ? λ - τ : λ < -π ? λ + τ : λ, φ ]; + } + d3_geo_identityRotation.invert = d3_geo_equirectangular; + function d3_geo_rotation(δλ, δφ, δγ) { + return δλ ? δφ || δγ ? d3_geo_compose(d3_geo_rotationλ(δλ), d3_geo_rotationφγ(δφ, δγ)) : d3_geo_rotationλ(δλ) : δφ || δγ ? d3_geo_rotationφγ(δφ, δγ) : d3_geo_identityRotation; + } + function d3_geo_forwardRotationλ(δλ) { + return function(λ, φ) { + return λ += δλ, [ λ > π ? λ - τ : λ < -π ? 
λ + τ : λ, φ ]; + }; + } + function d3_geo_rotationλ(δλ) { + var rotation = d3_geo_forwardRotationλ(δλ); + rotation.invert = d3_geo_forwardRotationλ(-δλ); + return rotation; + } + function d3_geo_rotationφγ(δφ, δγ) { + var cosδφ = Math.cos(δφ), sinδφ = Math.sin(δφ), cosδγ = Math.cos(δγ), sinδγ = Math.sin(δγ); + function rotation(λ, φ) { + var cosφ = Math.cos(φ), x = Math.cos(λ) * cosφ, y = Math.sin(λ) * cosφ, z = Math.sin(φ), k = z * cosδφ + x * sinδφ; + return [ Math.atan2(y * cosδγ - k * sinδγ, x * cosδφ - z * sinδφ), d3_asin(k * cosδγ + y * sinδγ) ]; + } + rotation.invert = function(λ, φ) { + var cosφ = Math.cos(φ), x = Math.cos(λ) * cosφ, y = Math.sin(λ) * cosφ, z = Math.sin(φ), k = z * cosδγ - y * sinδγ; + return [ Math.atan2(y * cosδγ + z * sinδγ, x * cosδφ + k * sinδφ), d3_asin(k * cosδφ - x * sinδφ) ]; + }; + return rotation; + } + d3.geo.circle = function() { + var origin = [ 0, 0 ], angle, precision = 6, interpolate; + function circle() { + var center = typeof origin === "function" ? origin.apply(this, arguments) : origin, rotate = d3_geo_rotation(-center[0] * d3_radians, -center[1] * d3_radians, 0).invert, ring = []; + interpolate(null, null, 1, { + point: function(x, y) { + ring.push(x = rotate(x, y)); + x[0] *= d3_degrees, x[1] *= d3_degrees; + } + }); + return { + type: "Polygon", + coordinates: [ ring ] + }; + } + circle.origin = function(x) { + if (!arguments.length) return origin; + origin = x; + return circle; + }; + circle.angle = function(x) { + if (!arguments.length) return angle; + interpolate = d3_geo_circleInterpolate((angle = +x) * d3_radians, precision * d3_radians); + return circle; + }; + circle.precision = function(_) { + if (!arguments.length) return precision; + interpolate = d3_geo_circleInterpolate(angle * d3_radians, (precision = +_) * d3_radians); + return circle; + }; + return circle.angle(90); + }; + function d3_geo_circleInterpolate(radius, precision) { + var cr = Math.cos(radius), sr = Math.sin(radius); + return function(from, to, direction, listener) { + var step = direction * precision; + if (from != null) { + from = d3_geo_circleAngle(cr, from); + to = d3_geo_circleAngle(cr, to); + if (direction > 0 ? from < to : from > to) from += direction * τ; + } else { + from = radius + direction * τ; + to = radius - .5 * step; + } + for (var point, t = from; direction > 0 ? t > to : t < to; t -= step) { + listener.point((point = d3_geo_spherical([ cr, -sr * Math.cos(t), -sr * Math.sin(t) ]))[0], point[1]); + } + }; + } + function d3_geo_circleAngle(cr, point) { + var a = d3_geo_cartesian(point); + a[0] -= cr; + d3_geo_cartesianNormalize(a); + var angle = d3_acos(-a[1]); + return ((-a[2] < 0 ? 
-angle : angle) + 2 * Math.PI - ε) % (2 * Math.PI); + } + d3.geo.distance = function(a, b) { + var Δλ = (b[0] - a[0]) * d3_radians, φ0 = a[1] * d3_radians, φ1 = b[1] * d3_radians, sinΔλ = Math.sin(Δλ), cosΔλ = Math.cos(Δλ), sinφ0 = Math.sin(φ0), cosφ0 = Math.cos(φ0), sinφ1 = Math.sin(φ1), cosφ1 = Math.cos(φ1), t; + return Math.atan2(Math.sqrt((t = cosφ1 * sinΔλ) * t + (t = cosφ0 * sinφ1 - sinφ0 * cosφ1 * cosΔλ) * t), sinφ0 * sinφ1 + cosφ0 * cosφ1 * cosΔλ); + }; + d3.geo.graticule = function() { + var x1, x0, X1, X0, y1, y0, Y1, Y0, dx = 10, dy = dx, DX = 90, DY = 360, x, y, X, Y, precision = 2.5; + function graticule() { + return { + type: "MultiLineString", + coordinates: lines() + }; + } + function lines() { + return d3.range(Math.ceil(X0 / DX) * DX, X1, DX).map(X).concat(d3.range(Math.ceil(Y0 / DY) * DY, Y1, DY).map(Y)).concat(d3.range(Math.ceil(x0 / dx) * dx, x1, dx).filter(function(x) { + return abs(x % DX) > ε; + }).map(x)).concat(d3.range(Math.ceil(y0 / dy) * dy, y1, dy).filter(function(y) { + return abs(y % DY) > ε; + }).map(y)); + } + graticule.lines = function() { + return lines().map(function(coordinates) { + return { + type: "LineString", + coordinates: coordinates + }; + }); + }; + graticule.outline = function() { + return { + type: "Polygon", + coordinates: [ X(X0).concat(Y(Y1).slice(1), X(X1).reverse().slice(1), Y(Y0).reverse().slice(1)) ] + }; + }; + graticule.extent = function(_) { + if (!arguments.length) return graticule.minorExtent(); + return graticule.majorExtent(_).minorExtent(_); + }; + graticule.majorExtent = function(_) { + if (!arguments.length) return [ [ X0, Y0 ], [ X1, Y1 ] ]; + X0 = +_[0][0], X1 = +_[1][0]; + Y0 = +_[0][1], Y1 = +_[1][1]; + if (X0 > X1) _ = X0, X0 = X1, X1 = _; + if (Y0 > Y1) _ = Y0, Y0 = Y1, Y1 = _; + return graticule.precision(precision); + }; + graticule.minorExtent = function(_) { + if (!arguments.length) return [ [ x0, y0 ], [ x1, y1 ] ]; + x0 = +_[0][0], x1 = +_[1][0]; + y0 = +_[0][1], y1 = +_[1][1]; + if (x0 > x1) _ = x0, x0 = x1, x1 = _; + if (y0 > y1) _ = y0, y0 = y1, y1 = _; + return graticule.precision(precision); + }; + graticule.step = function(_) { + if (!arguments.length) return graticule.minorStep(); + return graticule.majorStep(_).minorStep(_); + }; + graticule.majorStep = function(_) { + if (!arguments.length) return [ DX, DY ]; + DX = +_[0], DY = +_[1]; + return graticule; + }; + graticule.minorStep = function(_) { + if (!arguments.length) return [ dx, dy ]; + dx = +_[0], dy = +_[1]; + return graticule; + }; + graticule.precision = function(_) { + if (!arguments.length) return precision; + precision = +_; + x = d3_geo_graticuleX(y0, y1, 90); + y = d3_geo_graticuleY(x0, x1, precision); + X = d3_geo_graticuleX(Y0, Y1, 90); + Y = d3_geo_graticuleY(X0, X1, precision); + return graticule; + }; + return graticule.majorExtent([ [ -180, -90 + ε ], [ 180, 90 - ε ] ]).minorExtent([ [ -180, -80 - ε ], [ 180, 80 + ε ] ]); + }; + function d3_geo_graticuleX(y0, y1, dy) { + var y = d3.range(y0, y1 - ε, dy).concat(y1); + return function(x) { + return y.map(function(y) { + return [ x, y ]; + }); + }; + } + function d3_geo_graticuleY(x0, x1, dx) { + var x = d3.range(x0, x1 - ε, dx).concat(x1); + return function(y) { + return x.map(function(x) { + return [ x, y ]; + }); + }; + } + function d3_source(d) { + return d.source; + } + function d3_target(d) { + return d.target; + } + d3.geo.greatArc = function() { + var source = d3_source, source_, target = d3_target, target_; + function greatArc() { + return { + type: "LineString", + coordinates: [ 
source_ || source.apply(this, arguments), target_ || target.apply(this, arguments) ] + }; + } + greatArc.distance = function() { + return d3.geo.distance(source_ || source.apply(this, arguments), target_ || target.apply(this, arguments)); + }; + greatArc.source = function(_) { + if (!arguments.length) return source; + source = _, source_ = typeof _ === "function" ? null : _; + return greatArc; + }; + greatArc.target = function(_) { + if (!arguments.length) return target; + target = _, target_ = typeof _ === "function" ? null : _; + return greatArc; + }; + greatArc.precision = function() { + return arguments.length ? greatArc : 0; + }; + return greatArc; + }; + d3.geo.interpolate = function(source, target) { + return d3_geo_interpolate(source[0] * d3_radians, source[1] * d3_radians, target[0] * d3_radians, target[1] * d3_radians); + }; + function d3_geo_interpolate(x0, y0, x1, y1) { + var cy0 = Math.cos(y0), sy0 = Math.sin(y0), cy1 = Math.cos(y1), sy1 = Math.sin(y1), kx0 = cy0 * Math.cos(x0), ky0 = cy0 * Math.sin(x0), kx1 = cy1 * Math.cos(x1), ky1 = cy1 * Math.sin(x1), d = 2 * Math.asin(Math.sqrt(d3_haversin(y1 - y0) + cy0 * cy1 * d3_haversin(x1 - x0))), k = 1 / Math.sin(d); + var interpolate = d ? function(t) { + var B = Math.sin(t *= d) * k, A = Math.sin(d - t) * k, x = A * kx0 + B * kx1, y = A * ky0 + B * ky1, z = A * sy0 + B * sy1; + return [ Math.atan2(y, x) * d3_degrees, Math.atan2(z, Math.sqrt(x * x + y * y)) * d3_degrees ]; + } : function() { + return [ x0 * d3_degrees, y0 * d3_degrees ]; + }; + interpolate.distance = d; + return interpolate; + } + d3.geo.length = function(object) { + d3_geo_lengthSum = 0; + d3.geo.stream(object, d3_geo_length); + return d3_geo_lengthSum; + }; + var d3_geo_lengthSum; + var d3_geo_length = { + sphere: d3_noop, + point: d3_noop, + lineStart: d3_geo_lengthLineStart, + lineEnd: d3_noop, + polygonStart: d3_noop, + polygonEnd: d3_noop + }; + function d3_geo_lengthLineStart() { + var λ0, sinφ0, cosφ0; + d3_geo_length.point = function(λ, φ) { + λ0 = λ * d3_radians, sinφ0 = Math.sin(φ *= d3_radians), cosφ0 = Math.cos(φ); + d3_geo_length.point = nextPoint; + }; + d3_geo_length.lineEnd = function() { + d3_geo_length.point = d3_geo_length.lineEnd = d3_noop; + }; + function nextPoint(λ, φ) { + var sinφ = Math.sin(φ *= d3_radians), cosφ = Math.cos(φ), t = abs((λ *= d3_radians) - λ0), cosΔλ = Math.cos(t); + d3_geo_lengthSum += Math.atan2(Math.sqrt((t = cosφ * Math.sin(t)) * t + (t = cosφ0 * sinφ - sinφ0 * cosφ * cosΔλ) * t), sinφ0 * sinφ + cosφ0 * cosφ * cosΔλ); + λ0 = λ, sinφ0 = sinφ, cosφ0 = cosφ; + } + } + function d3_geo_azimuthal(scale, angle) { + function azimuthal(λ, φ) { + var cosλ = Math.cos(λ), cosφ = Math.cos(φ), k = scale(cosλ * cosφ); + return [ k * cosφ * Math.sin(λ), k * Math.sin(φ) ]; + } + azimuthal.invert = function(x, y) { + var ρ = Math.sqrt(x * x + y * y), c = angle(ρ), sinc = Math.sin(c), cosc = Math.cos(c); + return [ Math.atan2(x * sinc, ρ * cosc), Math.asin(ρ && y * sinc / ρ) ]; + }; + return azimuthal; + } + var d3_geo_azimuthalEqualArea = d3_geo_azimuthal(function(cosλcosφ) { + return Math.sqrt(2 / (1 + cosλcosφ)); + }, function(ρ) { + return 2 * Math.asin(ρ / 2); + }); + (d3.geo.azimuthalEqualArea = function() { + return d3_geo_projection(d3_geo_azimuthalEqualArea); + }).raw = d3_geo_azimuthalEqualArea; + var d3_geo_azimuthalEquidistant = d3_geo_azimuthal(function(cosλcosφ) { + var c = Math.acos(cosλcosφ); + return c && c / Math.sin(c); + }, d3_identity); + (d3.geo.azimuthalEquidistant = function() { + return 
d3_geo_projection(d3_geo_azimuthalEquidistant); + }).raw = d3_geo_azimuthalEquidistant; + function d3_geo_conicConformal(φ0, φ1) { + var cosφ0 = Math.cos(φ0), t = function(φ) { + return Math.tan(π / 4 + φ / 2); + }, n = φ0 === φ1 ? Math.sin(φ0) : Math.log(cosφ0 / Math.cos(φ1)) / Math.log(t(φ1) / t(φ0)), F = cosφ0 * Math.pow(t(φ0), n) / n; + if (!n) return d3_geo_mercator; + function forward(λ, φ) { + if (F > 0) { + if (φ < -halfπ + ε) φ = -halfπ + ε; + } else { + if (φ > halfπ - ε) φ = halfπ - ε; + } + var ρ = F / Math.pow(t(φ), n); + return [ ρ * Math.sin(n * λ), F - ρ * Math.cos(n * λ) ]; + } + forward.invert = function(x, y) { + var ρ0_y = F - y, ρ = d3_sgn(n) * Math.sqrt(x * x + ρ0_y * ρ0_y); + return [ Math.atan2(x, ρ0_y) / n, 2 * Math.atan(Math.pow(F / ρ, 1 / n)) - halfπ ]; + }; + return forward; + } + (d3.geo.conicConformal = function() { + return d3_geo_conic(d3_geo_conicConformal); + }).raw = d3_geo_conicConformal; + function d3_geo_conicEquidistant(φ0, φ1) { + var cosφ0 = Math.cos(φ0), n = φ0 === φ1 ? Math.sin(φ0) : (cosφ0 - Math.cos(φ1)) / (φ1 - φ0), G = cosφ0 / n + φ0; + if (abs(n) < ε) return d3_geo_equirectangular; + function forward(λ, φ) { + var ρ = G - φ; + return [ ρ * Math.sin(n * λ), G - ρ * Math.cos(n * λ) ]; + } + forward.invert = function(x, y) { + var ρ0_y = G - y; + return [ Math.atan2(x, ρ0_y) / n, G - d3_sgn(n) * Math.sqrt(x * x + ρ0_y * ρ0_y) ]; + }; + return forward; + } + (d3.geo.conicEquidistant = function() { + return d3_geo_conic(d3_geo_conicEquidistant); + }).raw = d3_geo_conicEquidistant; + var d3_geo_gnomonic = d3_geo_azimuthal(function(cosλcosφ) { + return 1 / cosλcosφ; + }, Math.atan); + (d3.geo.gnomonic = function() { + return d3_geo_projection(d3_geo_gnomonic); + }).raw = d3_geo_gnomonic; + function d3_geo_mercator(λ, φ) { + return [ λ, Math.log(Math.tan(π / 4 + φ / 2)) ]; + } + d3_geo_mercator.invert = function(x, y) { + return [ x, 2 * Math.atan(Math.exp(y)) - halfπ ]; + }; + function d3_geo_mercatorProjection(project) { + var m = d3_geo_projection(project), scale = m.scale, translate = m.translate, clipExtent = m.clipExtent, clipAuto; + m.scale = function() { + var v = scale.apply(m, arguments); + return v === m ? clipAuto ? m.clipExtent(null) : m : v; + }; + m.translate = function() { + var v = translate.apply(m, arguments); + return v === m ? clipAuto ? 
m.clipExtent(null) : m : v; + }; + m.clipExtent = function(_) { + var v = clipExtent.apply(m, arguments); + if (v === m) { + if (clipAuto = _ == null) { + var k = π * scale(), t = translate(); + clipExtent([ [ t[0] - k, t[1] - k ], [ t[0] + k, t[1] + k ] ]); + } + } else if (clipAuto) { + v = null; + } + return v; + }; + return m.clipExtent(null); + } + (d3.geo.mercator = function() { + return d3_geo_mercatorProjection(d3_geo_mercator); + }).raw = d3_geo_mercator; + var d3_geo_orthographic = d3_geo_azimuthal(function() { + return 1; + }, Math.asin); + (d3.geo.orthographic = function() { + return d3_geo_projection(d3_geo_orthographic); + }).raw = d3_geo_orthographic; + var d3_geo_stereographic = d3_geo_azimuthal(function(cosλcosφ) { + return 1 / (1 + cosλcosφ); + }, function(ρ) { + return 2 * Math.atan(ρ); + }); + (d3.geo.stereographic = function() { + return d3_geo_projection(d3_geo_stereographic); + }).raw = d3_geo_stereographic; + function d3_geo_transverseMercator(λ, φ) { + return [ Math.log(Math.tan(π / 4 + φ / 2)), -λ ]; + } + d3_geo_transverseMercator.invert = function(x, y) { + return [ -y, 2 * Math.atan(Math.exp(x)) - halfπ ]; + }; + (d3.geo.transverseMercator = function() { + var projection = d3_geo_mercatorProjection(d3_geo_transverseMercator), center = projection.center, rotate = projection.rotate; + projection.center = function(_) { + return _ ? center([ -_[1], _[0] ]) : (_ = center(), [ _[1], -_[0] ]); + }; + projection.rotate = function(_) { + return _ ? rotate([ _[0], _[1], _.length > 2 ? _[2] + 90 : 90 ]) : (_ = rotate(), + [ _[0], _[1], _[2] - 90 ]); + }; + return rotate([ 0, 0, 90 ]); + }).raw = d3_geo_transverseMercator; + d3.geom = {}; + function d3_geom_pointX(d) { + return d[0]; + } + function d3_geom_pointY(d) { + return d[1]; + } + d3.geom.hull = function(vertices) { + var x = d3_geom_pointX, y = d3_geom_pointY; + if (arguments.length) return hull(vertices); + function hull(data) { + if (data.length < 3) return []; + var fx = d3_functor(x), fy = d3_functor(y), i, n = data.length, points = [], flippedPoints = []; + for (i = 0; i < n; i++) { + points.push([ +fx.call(this, data[i], i), +fy.call(this, data[i], i), i ]); + } + points.sort(d3_geom_hullOrder); + for (i = 0; i < n; i++) flippedPoints.push([ points[i][0], -points[i][1] ]); + var upper = d3_geom_hullUpper(points), lower = d3_geom_hullUpper(flippedPoints); + var skipLeft = lower[0] === upper[0], skipRight = lower[lower.length - 1] === upper[upper.length - 1], polygon = []; + for (i = upper.length - 1; i >= 0; --i) polygon.push(data[points[upper[i]][2]]); + for (i = +skipLeft; i < lower.length - skipRight; ++i) polygon.push(data[points[lower[i]][2]]); + return polygon; + } + hull.x = function(_) { + return arguments.length ? (x = _, hull) : x; + }; + hull.y = function(_) { + return arguments.length ? 
(y = _, hull) : y; + }; + return hull; + }; + function d3_geom_hullUpper(points) { + var n = points.length, hull = [ 0, 1 ], hs = 2; + for (var i = 2; i < n; i++) { + while (hs > 1 && d3_cross2d(points[hull[hs - 2]], points[hull[hs - 1]], points[i]) <= 0) --hs; + hull[hs++] = i; + } + return hull.slice(0, hs); + } + function d3_geom_hullOrder(a, b) { + return a[0] - b[0] || a[1] - b[1]; + } + d3.geom.polygon = function(coordinates) { + d3_subclass(coordinates, d3_geom_polygonPrototype); + return coordinates; + }; + var d3_geom_polygonPrototype = d3.geom.polygon.prototype = []; + d3_geom_polygonPrototype.area = function() { + var i = -1, n = this.length, a, b = this[n - 1], area = 0; + while (++i < n) { + a = b; + b = this[i]; + area += a[1] * b[0] - a[0] * b[1]; + } + return area * .5; + }; + d3_geom_polygonPrototype.centroid = function(k) { + var i = -1, n = this.length, x = 0, y = 0, a, b = this[n - 1], c; + if (!arguments.length) k = -1 / (6 * this.area()); + while (++i < n) { + a = b; + b = this[i]; + c = a[0] * b[1] - b[0] * a[1]; + x += (a[0] + b[0]) * c; + y += (a[1] + b[1]) * c; + } + return [ x * k, y * k ]; + }; + d3_geom_polygonPrototype.clip = function(subject) { + var input, closed = d3_geom_polygonClosed(subject), i = -1, n = this.length - d3_geom_polygonClosed(this), j, m, a = this[n - 1], b, c, d; + while (++i < n) { + input = subject.slice(); + subject.length = 0; + b = this[i]; + c = input[(m = input.length - closed) - 1]; + j = -1; + while (++j < m) { + d = input[j]; + if (d3_geom_polygonInside(d, a, b)) { + if (!d3_geom_polygonInside(c, a, b)) { + subject.push(d3_geom_polygonIntersect(c, d, a, b)); + } + subject.push(d); + } else if (d3_geom_polygonInside(c, a, b)) { + subject.push(d3_geom_polygonIntersect(c, d, a, b)); + } + c = d; + } + if (closed) subject.push(subject[0]); + a = b; + } + return subject; + }; + function d3_geom_polygonInside(p, a, b) { + return (b[0] - a[0]) * (p[1] - a[1]) < (b[1] - a[1]) * (p[0] - a[0]); + } + function d3_geom_polygonIntersect(c, d, a, b) { + var x1 = c[0], x3 = a[0], x21 = d[0] - x1, x43 = b[0] - x3, y1 = c[1], y3 = a[1], y21 = d[1] - y1, y43 = b[1] - y3, ua = (x43 * (y1 - y3) - y43 * (x1 - x3)) / (y43 * x21 - x43 * y21); + return [ x1 + ua * x21, y1 + ua * y21 ]; + } + function d3_geom_polygonClosed(coordinates) { + var a = coordinates[0], b = coordinates[coordinates.length - 1]; + return !(a[0] - b[0] || a[1] - b[1]); + } + var d3_geom_voronoiEdges, d3_geom_voronoiCells, d3_geom_voronoiBeaches, d3_geom_voronoiBeachPool = [], d3_geom_voronoiFirstCircle, d3_geom_voronoiCircles, d3_geom_voronoiCirclePool = []; + function d3_geom_voronoiBeach() { + d3_geom_voronoiRedBlackNode(this); + this.edge = this.site = this.circle = null; + } + function d3_geom_voronoiCreateBeach(site) { + var beach = d3_geom_voronoiBeachPool.pop() || new d3_geom_voronoiBeach(); + beach.site = site; + return beach; + } + function d3_geom_voronoiDetachBeach(beach) { + d3_geom_voronoiDetachCircle(beach); + d3_geom_voronoiBeaches.remove(beach); + d3_geom_voronoiBeachPool.push(beach); + d3_geom_voronoiRedBlackNode(beach); + } + function d3_geom_voronoiRemoveBeach(beach) { + var circle = beach.circle, x = circle.x, y = circle.cy, vertex = { + x: x, + y: y + }, previous = beach.P, next = beach.N, disappearing = [ beach ]; + d3_geom_voronoiDetachBeach(beach); + var lArc = previous; + while (lArc.circle && abs(x - lArc.circle.x) < ε && abs(y - lArc.circle.cy) < ε) { + previous = lArc.P; + disappearing.unshift(lArc); + d3_geom_voronoiDetachBeach(lArc); + lArc = 
previous; + } + disappearing.unshift(lArc); + d3_geom_voronoiDetachCircle(lArc); + var rArc = next; + while (rArc.circle && abs(x - rArc.circle.x) < ε && abs(y - rArc.circle.cy) < ε) { + next = rArc.N; + disappearing.push(rArc); + d3_geom_voronoiDetachBeach(rArc); + rArc = next; + } + disappearing.push(rArc); + d3_geom_voronoiDetachCircle(rArc); + var nArcs = disappearing.length, iArc; + for (iArc = 1; iArc < nArcs; ++iArc) { + rArc = disappearing[iArc]; + lArc = disappearing[iArc - 1]; + d3_geom_voronoiSetEdgeEnd(rArc.edge, lArc.site, rArc.site, vertex); + } + lArc = disappearing[0]; + rArc = disappearing[nArcs - 1]; + rArc.edge = d3_geom_voronoiCreateEdge(lArc.site, rArc.site, null, vertex); + d3_geom_voronoiAttachCircle(lArc); + d3_geom_voronoiAttachCircle(rArc); + } + function d3_geom_voronoiAddBeach(site) { + var x = site.x, directrix = site.y, lArc, rArc, dxl, dxr, node = d3_geom_voronoiBeaches._; + while (node) { + dxl = d3_geom_voronoiLeftBreakPoint(node, directrix) - x; + if (dxl > ε) node = node.L; else { + dxr = x - d3_geom_voronoiRightBreakPoint(node, directrix); + if (dxr > ε) { + if (!node.R) { + lArc = node; + break; + } + node = node.R; + } else { + if (dxl > -ε) { + lArc = node.P; + rArc = node; + } else if (dxr > -ε) { + lArc = node; + rArc = node.N; + } else { + lArc = rArc = node; + } + break; + } + } + } + var newArc = d3_geom_voronoiCreateBeach(site); + d3_geom_voronoiBeaches.insert(lArc, newArc); + if (!lArc && !rArc) return; + if (lArc === rArc) { + d3_geom_voronoiDetachCircle(lArc); + rArc = d3_geom_voronoiCreateBeach(lArc.site); + d3_geom_voronoiBeaches.insert(newArc, rArc); + newArc.edge = rArc.edge = d3_geom_voronoiCreateEdge(lArc.site, newArc.site); + d3_geom_voronoiAttachCircle(lArc); + d3_geom_voronoiAttachCircle(rArc); + return; + } + if (!rArc) { + newArc.edge = d3_geom_voronoiCreateEdge(lArc.site, newArc.site); + return; + } + d3_geom_voronoiDetachCircle(lArc); + d3_geom_voronoiDetachCircle(rArc); + var lSite = lArc.site, ax = lSite.x, ay = lSite.y, bx = site.x - ax, by = site.y - ay, rSite = rArc.site, cx = rSite.x - ax, cy = rSite.y - ay, d = 2 * (bx * cy - by * cx), hb = bx * bx + by * by, hc = cx * cx + cy * cy, vertex = { + x: (cy * hb - by * hc) / d + ax, + y: (bx * hc - cx * hb) / d + ay + }; + d3_geom_voronoiSetEdgeEnd(rArc.edge, lSite, rSite, vertex); + newArc.edge = d3_geom_voronoiCreateEdge(lSite, site, null, vertex); + rArc.edge = d3_geom_voronoiCreateEdge(site, rSite, null, vertex); + d3_geom_voronoiAttachCircle(lArc); + d3_geom_voronoiAttachCircle(rArc); + } + function d3_geom_voronoiLeftBreakPoint(arc, directrix) { + var site = arc.site, rfocx = site.x, rfocy = site.y, pby2 = rfocy - directrix; + if (!pby2) return rfocx; + var lArc = arc.P; + if (!lArc) return -Infinity; + site = lArc.site; + var lfocx = site.x, lfocy = site.y, plby2 = lfocy - directrix; + if (!plby2) return lfocx; + var hl = lfocx - rfocx, aby2 = 1 / pby2 - 1 / plby2, b = hl / plby2; + if (aby2) return (-b + Math.sqrt(b * b - 2 * aby2 * (hl * hl / (-2 * plby2) - lfocy + plby2 / 2 + rfocy - pby2 / 2))) / aby2 + rfocx; + return (rfocx + lfocx) / 2; + } + function d3_geom_voronoiRightBreakPoint(arc, directrix) { + var rArc = arc.N; + if (rArc) return d3_geom_voronoiLeftBreakPoint(rArc, directrix); + var site = arc.site; + return site.y === directrix ? 
site.x : Infinity; + } + function d3_geom_voronoiCell(site) { + this.site = site; + this.edges = []; + } + d3_geom_voronoiCell.prototype.prepare = function() { + var halfEdges = this.edges, iHalfEdge = halfEdges.length, edge; + while (iHalfEdge--) { + edge = halfEdges[iHalfEdge].edge; + if (!edge.b || !edge.a) halfEdges.splice(iHalfEdge, 1); + } + halfEdges.sort(d3_geom_voronoiHalfEdgeOrder); + return halfEdges.length; + }; + function d3_geom_voronoiCloseCells(extent) { + var x0 = extent[0][0], x1 = extent[1][0], y0 = extent[0][1], y1 = extent[1][1], x2, y2, x3, y3, cells = d3_geom_voronoiCells, iCell = cells.length, cell, iHalfEdge, halfEdges, nHalfEdges, start, end; + while (iCell--) { + cell = cells[iCell]; + if (!cell || !cell.prepare()) continue; + halfEdges = cell.edges; + nHalfEdges = halfEdges.length; + iHalfEdge = 0; + while (iHalfEdge < nHalfEdges) { + end = halfEdges[iHalfEdge].end(), x3 = end.x, y3 = end.y; + start = halfEdges[++iHalfEdge % nHalfEdges].start(), x2 = start.x, y2 = start.y; + if (abs(x3 - x2) > ε || abs(y3 - y2) > ε) { + halfEdges.splice(iHalfEdge, 0, new d3_geom_voronoiHalfEdge(d3_geom_voronoiCreateBorderEdge(cell.site, end, abs(x3 - x0) < ε && y1 - y3 > ε ? { + x: x0, + y: abs(x2 - x0) < ε ? y2 : y1 + } : abs(y3 - y1) < ε && x1 - x3 > ε ? { + x: abs(y2 - y1) < ε ? x2 : x1, + y: y1 + } : abs(x3 - x1) < ε && y3 - y0 > ε ? { + x: x1, + y: abs(x2 - x1) < ε ? y2 : y0 + } : abs(y3 - y0) < ε && x3 - x0 > ε ? { + x: abs(y2 - y0) < ε ? x2 : x0, + y: y0 + } : null), cell.site, null)); + ++nHalfEdges; + } + } + } + } + function d3_geom_voronoiHalfEdgeOrder(a, b) { + return b.angle - a.angle; + } + function d3_geom_voronoiCircle() { + d3_geom_voronoiRedBlackNode(this); + this.x = this.y = this.arc = this.site = this.cy = null; + } + function d3_geom_voronoiAttachCircle(arc) { + var lArc = arc.P, rArc = arc.N; + if (!lArc || !rArc) return; + var lSite = lArc.site, cSite = arc.site, rSite = rArc.site; + if (lSite === rSite) return; + var bx = cSite.x, by = cSite.y, ax = lSite.x - bx, ay = lSite.y - by, cx = rSite.x - bx, cy = rSite.y - by; + var d = 2 * (ax * cy - ay * cx); + if (d >= -ε2) return; + var ha = ax * ax + ay * ay, hc = cx * cx + cy * cy, x = (cy * ha - ay * hc) / d, y = (ax * hc - cx * ha) / d, cy = y + by; + var circle = d3_geom_voronoiCirclePool.pop() || new d3_geom_voronoiCircle(); + circle.arc = arc; + circle.site = cSite; + circle.x = x + bx; + circle.y = cy + Math.sqrt(x * x + y * y); + circle.cy = cy; + arc.circle = circle; + var before = null, node = d3_geom_voronoiCircles._; + while (node) { + if (circle.y < node.y || circle.y === node.y && circle.x <= node.x) { + if (node.L) node = node.L; else { + before = node.P; + break; + } + } else { + if (node.R) node = node.R; else { + before = node; + break; + } + } + } + d3_geom_voronoiCircles.insert(before, circle); + if (!before) d3_geom_voronoiFirstCircle = circle; + } + function d3_geom_voronoiDetachCircle(arc) { + var circle = arc.circle; + if (circle) { + if (!circle.P) d3_geom_voronoiFirstCircle = circle.N; + d3_geom_voronoiCircles.remove(circle); + d3_geom_voronoiCirclePool.push(circle); + d3_geom_voronoiRedBlackNode(circle); + arc.circle = null; + } + } + function d3_geom_voronoiClipEdges(extent) { + var edges = d3_geom_voronoiEdges, clip = d3_geom_clipLine(extent[0][0], extent[0][1], extent[1][0], extent[1][1]), i = edges.length, e; + while (i--) { + e = edges[i]; + if (!d3_geom_voronoiConnectEdge(e, extent) || !clip(e) || abs(e.a.x - e.b.x) < ε && abs(e.a.y - e.b.y) < ε) { + e.a = e.b = null; + 
edges.splice(i, 1); + } + } + } + function d3_geom_voronoiConnectEdge(edge, extent) { + var vb = edge.b; + if (vb) return true; + var va = edge.a, x0 = extent[0][0], x1 = extent[1][0], y0 = extent[0][1], y1 = extent[1][1], lSite = edge.l, rSite = edge.r, lx = lSite.x, ly = lSite.y, rx = rSite.x, ry = rSite.y, fx = (lx + rx) / 2, fy = (ly + ry) / 2, fm, fb; + if (ry === ly) { + if (fx < x0 || fx >= x1) return; + if (lx > rx) { + if (!va) va = { + x: fx, + y: y0 + }; else if (va.y >= y1) return; + vb = { + x: fx, + y: y1 + }; + } else { + if (!va) va = { + x: fx, + y: y1 + }; else if (va.y < y0) return; + vb = { + x: fx, + y: y0 + }; + } + } else { + fm = (lx - rx) / (ry - ly); + fb = fy - fm * fx; + if (fm < -1 || fm > 1) { + if (lx > rx) { + if (!va) va = { + x: (y0 - fb) / fm, + y: y0 + }; else if (va.y >= y1) return; + vb = { + x: (y1 - fb) / fm, + y: y1 + }; + } else { + if (!va) va = { + x: (y1 - fb) / fm, + y: y1 + }; else if (va.y < y0) return; + vb = { + x: (y0 - fb) / fm, + y: y0 + }; + } + } else { + if (ly < ry) { + if (!va) va = { + x: x0, + y: fm * x0 + fb + }; else if (va.x >= x1) return; + vb = { + x: x1, + y: fm * x1 + fb + }; + } else { + if (!va) va = { + x: x1, + y: fm * x1 + fb + }; else if (va.x < x0) return; + vb = { + x: x0, + y: fm * x0 + fb + }; + } + } + } + edge.a = va; + edge.b = vb; + return true; + } + function d3_geom_voronoiEdge(lSite, rSite) { + this.l = lSite; + this.r = rSite; + this.a = this.b = null; + } + function d3_geom_voronoiCreateEdge(lSite, rSite, va, vb) { + var edge = new d3_geom_voronoiEdge(lSite, rSite); + d3_geom_voronoiEdges.push(edge); + if (va) d3_geom_voronoiSetEdgeEnd(edge, lSite, rSite, va); + if (vb) d3_geom_voronoiSetEdgeEnd(edge, rSite, lSite, vb); + d3_geom_voronoiCells[lSite.i].edges.push(new d3_geom_voronoiHalfEdge(edge, lSite, rSite)); + d3_geom_voronoiCells[rSite.i].edges.push(new d3_geom_voronoiHalfEdge(edge, rSite, lSite)); + return edge; + } + function d3_geom_voronoiCreateBorderEdge(lSite, va, vb) { + var edge = new d3_geom_voronoiEdge(lSite, null); + edge.a = va; + edge.b = vb; + d3_geom_voronoiEdges.push(edge); + return edge; + } + function d3_geom_voronoiSetEdgeEnd(edge, lSite, rSite, vertex) { + if (!edge.a && !edge.b) { + edge.a = vertex; + edge.l = lSite; + edge.r = rSite; + } else if (edge.l === rSite) { + edge.b = vertex; + } else { + edge.a = vertex; + } + } + function d3_geom_voronoiHalfEdge(edge, lSite, rSite) { + var va = edge.a, vb = edge.b; + this.edge = edge; + this.site = lSite; + this.angle = rSite ? Math.atan2(rSite.y - lSite.y, rSite.x - lSite.x) : edge.l === lSite ? Math.atan2(vb.x - va.x, va.y - vb.y) : Math.atan2(va.x - vb.x, vb.y - va.y); + } + d3_geom_voronoiHalfEdge.prototype = { + start: function() { + return this.edge.l === this.site ? this.edge.a : this.edge.b; + }, + end: function() { + return this.edge.l === this.site ? 
this.edge.b : this.edge.a; + } + }; + function d3_geom_voronoiRedBlackTree() { + this._ = null; + } + function d3_geom_voronoiRedBlackNode(node) { + node.U = node.C = node.L = node.R = node.P = node.N = null; + } + d3_geom_voronoiRedBlackTree.prototype = { + insert: function(after, node) { + var parent, grandpa, uncle; + if (after) { + node.P = after; + node.N = after.N; + if (after.N) after.N.P = node; + after.N = node; + if (after.R) { + after = after.R; + while (after.L) after = after.L; + after.L = node; + } else { + after.R = node; + } + parent = after; + } else if (this._) { + after = d3_geom_voronoiRedBlackFirst(this._); + node.P = null; + node.N = after; + after.P = after.L = node; + parent = after; + } else { + node.P = node.N = null; + this._ = node; + parent = null; + } + node.L = node.R = null; + node.U = parent; + node.C = true; + after = node; + while (parent && parent.C) { + grandpa = parent.U; + if (parent === grandpa.L) { + uncle = grandpa.R; + if (uncle && uncle.C) { + parent.C = uncle.C = false; + grandpa.C = true; + after = grandpa; + } else { + if (after === parent.R) { + d3_geom_voronoiRedBlackRotateLeft(this, parent); + after = parent; + parent = after.U; + } + parent.C = false; + grandpa.C = true; + d3_geom_voronoiRedBlackRotateRight(this, grandpa); + } + } else { + uncle = grandpa.L; + if (uncle && uncle.C) { + parent.C = uncle.C = false; + grandpa.C = true; + after = grandpa; + } else { + if (after === parent.L) { + d3_geom_voronoiRedBlackRotateRight(this, parent); + after = parent; + parent = after.U; + } + parent.C = false; + grandpa.C = true; + d3_geom_voronoiRedBlackRotateLeft(this, grandpa); + } + } + parent = after.U; + } + this._.C = false; + }, + remove: function(node) { + if (node.N) node.N.P = node.P; + if (node.P) node.P.N = node.N; + node.N = node.P = null; + var parent = node.U, sibling, left = node.L, right = node.R, next, red; + if (!left) next = right; else if (!right) next = left; else next = d3_geom_voronoiRedBlackFirst(right); + if (parent) { + if (parent.L === node) parent.L = next; else parent.R = next; + } else { + this._ = next; + } + if (left && right) { + red = next.C; + next.C = node.C; + next.L = left; + left.U = next; + if (next !== right) { + parent = next.U; + next.U = node.U; + node = next.R; + parent.L = node; + next.R = right; + right.U = next; + } else { + next.U = parent; + parent = next; + node = next.R; + } + } else { + red = node.C; + node = next; + } + if (node) node.U = parent; + if (red) return; + if (node && node.C) { + node.C = false; + return; + } + do { + if (node === this._) break; + if (node === parent.L) { + sibling = parent.R; + if (sibling.C) { + sibling.C = false; + parent.C = true; + d3_geom_voronoiRedBlackRotateLeft(this, parent); + sibling = parent.R; + } + if (sibling.L && sibling.L.C || sibling.R && sibling.R.C) { + if (!sibling.R || !sibling.R.C) { + sibling.L.C = false; + sibling.C = true; + d3_geom_voronoiRedBlackRotateRight(this, sibling); + sibling = parent.R; + } + sibling.C = parent.C; + parent.C = sibling.R.C = false; + d3_geom_voronoiRedBlackRotateLeft(this, parent); + node = this._; + break; + } + } else { + sibling = parent.L; + if (sibling.C) { + sibling.C = false; + parent.C = true; + d3_geom_voronoiRedBlackRotateRight(this, parent); + sibling = parent.L; + } + if (sibling.L && sibling.L.C || sibling.R && sibling.R.C) { + if (!sibling.L || !sibling.L.C) { + sibling.R.C = false; + sibling.C = true; + d3_geom_voronoiRedBlackRotateLeft(this, sibling); + sibling = parent.L; + } + sibling.C = 
parent.C; + parent.C = sibling.L.C = false; + d3_geom_voronoiRedBlackRotateRight(this, parent); + node = this._; + break; + } + } + sibling.C = true; + node = parent; + parent = parent.U; + } while (!node.C); + if (node) node.C = false; + } + }; + function d3_geom_voronoiRedBlackRotateLeft(tree, node) { + var p = node, q = node.R, parent = p.U; + if (parent) { + if (parent.L === p) parent.L = q; else parent.R = q; + } else { + tree._ = q; + } + q.U = parent; + p.U = q; + p.R = q.L; + if (p.R) p.R.U = p; + q.L = p; + } + function d3_geom_voronoiRedBlackRotateRight(tree, node) { + var p = node, q = node.L, parent = p.U; + if (parent) { + if (parent.L === p) parent.L = q; else parent.R = q; + } else { + tree._ = q; + } + q.U = parent; + p.U = q; + p.L = q.R; + if (p.L) p.L.U = p; + q.R = p; + } + function d3_geom_voronoiRedBlackFirst(node) { + while (node.L) node = node.L; + return node; + } + function d3_geom_voronoi(sites, bbox) { + var site = sites.sort(d3_geom_voronoiVertexOrder).pop(), x0, y0, circle; + d3_geom_voronoiEdges = []; + d3_geom_voronoiCells = new Array(sites.length); + d3_geom_voronoiBeaches = new d3_geom_voronoiRedBlackTree(); + d3_geom_voronoiCircles = new d3_geom_voronoiRedBlackTree(); + while (true) { + circle = d3_geom_voronoiFirstCircle; + if (site && (!circle || site.y < circle.y || site.y === circle.y && site.x < circle.x)) { + if (site.x !== x0 || site.y !== y0) { + d3_geom_voronoiCells[site.i] = new d3_geom_voronoiCell(site); + d3_geom_voronoiAddBeach(site); + x0 = site.x, y0 = site.y; + } + site = sites.pop(); + } else if (circle) { + d3_geom_voronoiRemoveBeach(circle.arc); + } else { + break; + } + } + if (bbox) d3_geom_voronoiClipEdges(bbox), d3_geom_voronoiCloseCells(bbox); + var diagram = { + cells: d3_geom_voronoiCells, + edges: d3_geom_voronoiEdges + }; + d3_geom_voronoiBeaches = d3_geom_voronoiCircles = d3_geom_voronoiEdges = d3_geom_voronoiCells = null; + return diagram; + } + function d3_geom_voronoiVertexOrder(a, b) { + return b.y - a.y || b.x - a.x; + } + d3.geom.voronoi = function(points) { + var x = d3_geom_pointX, y = d3_geom_pointY, fx = x, fy = y, clipExtent = d3_geom_voronoiClipExtent; + if (points) return voronoi(points); + function voronoi(data) { + var polygons = new Array(data.length), x0 = clipExtent[0][0], y0 = clipExtent[0][1], x1 = clipExtent[1][0], y1 = clipExtent[1][1]; + d3_geom_voronoi(sites(data), clipExtent).cells.forEach(function(cell, i) { + var edges = cell.edges, site = cell.site, polygon = polygons[i] = edges.length ? edges.map(function(e) { + var s = e.start(); + return [ s.x, s.y ]; + }) : site.x >= x0 && site.x <= x1 && site.y >= y0 && site.y <= y1 ? [ [ x0, y1 ], [ x1, y1 ], [ x1, y0 ], [ x0, y0 ] ] : []; + polygon.point = data[i]; + }); + return polygons; + } + function sites(data) { + return data.map(function(d, i) { + return { + x: Math.round(fx(d, i) / ε) * ε, + y: Math.round(fy(d, i) / ε) * ε, + i: i + }; + }); + } + voronoi.links = function(data) { + return d3_geom_voronoi(sites(data)).edges.filter(function(edge) { + return edge.l && edge.r; + }).map(function(edge) { + return { + source: data[edge.l.i], + target: data[edge.r.i] + }; + }); + }; + voronoi.triangles = function(data) { + var triangles = []; + d3_geom_voronoi(sites(data)).cells.forEach(function(cell, i) { + var site = cell.site, edges = cell.edges.sort(d3_geom_voronoiHalfEdgeOrder), j = -1, m = edges.length, e0, s0, e1 = edges[m - 1].edge, s1 = e1.l === site ? e1.r : e1.l; + while (++j < m) { + e0 = e1; + s0 = s1; + e1 = edges[j].edge; + s1 = e1.l === site ? 
e1.r : e1.l; + if (i < s0.i && i < s1.i && d3_geom_voronoiTriangleArea(site, s0, s1) < 0) { + triangles.push([ data[i], data[s0.i], data[s1.i] ]); + } + } + }); + return triangles; + }; + voronoi.x = function(_) { + return arguments.length ? (fx = d3_functor(x = _), voronoi) : x; + }; + voronoi.y = function(_) { + return arguments.length ? (fy = d3_functor(y = _), voronoi) : y; + }; + voronoi.clipExtent = function(_) { + if (!arguments.length) return clipExtent === d3_geom_voronoiClipExtent ? null : clipExtent; + clipExtent = _ == null ? d3_geom_voronoiClipExtent : _; + return voronoi; + }; + voronoi.size = function(_) { + if (!arguments.length) return clipExtent === d3_geom_voronoiClipExtent ? null : clipExtent && clipExtent[1]; + return voronoi.clipExtent(_ && [ [ 0, 0 ], _ ]); + }; + return voronoi; + }; + var d3_geom_voronoiClipExtent = [ [ -1e6, -1e6 ], [ 1e6, 1e6 ] ]; + function d3_geom_voronoiTriangleArea(a, b, c) { + return (a.x - c.x) * (b.y - a.y) - (a.x - b.x) * (c.y - a.y); + } + d3.geom.delaunay = function(vertices) { + return d3.geom.voronoi().triangles(vertices); + }; + d3.geom.quadtree = function(points, x1, y1, x2, y2) { + var x = d3_geom_pointX, y = d3_geom_pointY, compat; + if (compat = arguments.length) { + x = d3_geom_quadtreeCompatX; + y = d3_geom_quadtreeCompatY; + if (compat === 3) { + y2 = y1; + x2 = x1; + y1 = x1 = 0; + } + return quadtree(points); + } + function quadtree(data) { + var d, fx = d3_functor(x), fy = d3_functor(y), xs, ys, i, n, x1_, y1_, x2_, y2_; + if (x1 != null) { + x1_ = x1, y1_ = y1, x2_ = x2, y2_ = y2; + } else { + x2_ = y2_ = -(x1_ = y1_ = Infinity); + xs = [], ys = []; + n = data.length; + if (compat) for (i = 0; i < n; ++i) { + d = data[i]; + if (d.x < x1_) x1_ = d.x; + if (d.y < y1_) y1_ = d.y; + if (d.x > x2_) x2_ = d.x; + if (d.y > y2_) y2_ = d.y; + xs.push(d.x); + ys.push(d.y); + } else for (i = 0; i < n; ++i) { + var x_ = +fx(d = data[i], i), y_ = +fy(d, i); + if (x_ < x1_) x1_ = x_; + if (y_ < y1_) y1_ = y_; + if (x_ > x2_) x2_ = x_; + if (y_ > y2_) y2_ = y_; + xs.push(x_); + ys.push(y_); + } + } + var dx = x2_ - x1_, dy = y2_ - y1_; + if (dx > dy) y2_ = y1_ + dx; else x2_ = x1_ + dy; + function insert(n, d, x, y, x1, y1, x2, y2) { + if (isNaN(x) || isNaN(y)) return; + if (n.leaf) { + var nx = n.x, ny = n.y; + if (nx != null) { + if (abs(nx - x) + abs(ny - y) < .01) { + insertChild(n, d, x, y, x1, y1, x2, y2); + } else { + var nPoint = n.point; + n.x = n.y = n.point = null; + insertChild(n, nPoint, nx, ny, x1, y1, x2, y2); + insertChild(n, d, x, y, x1, y1, x2, y2); + } + } else { + n.x = x, n.y = y, n.point = d; + } + } else { + insertChild(n, d, x, y, x1, y1, x2, y2); + } + } + function insertChild(n, d, x, y, x1, y1, x2, y2) { + var xm = (x1 + x2) * .5, ym = (y1 + y2) * .5, right = x >= xm, below = y >= ym, i = below << 1 | right; + n.leaf = false; + n = n.nodes[i] || (n.nodes[i] = d3_geom_quadtreeNode()); + if (right) x1 = xm; else x2 = xm; + if (below) y1 = ym; else y2 = ym; + insert(n, d, x, y, x1, y1, x2, y2); + } + var root = d3_geom_quadtreeNode(); + root.add = function(d) { + insert(root, d, +fx(d, ++i), +fy(d, i), x1_, y1_, x2_, y2_); + }; + root.visit = function(f) { + d3_geom_quadtreeVisit(f, root, x1_, y1_, x2_, y2_); + }; + root.find = function(point) { + return d3_geom_quadtreeFind(root, point[0], point[1], x1_, y1_, x2_, y2_); + }; + i = -1; + if (x1 == null) { + while (++i < n) { + insert(root, data[i], xs[i], ys[i], x1_, y1_, x2_, y2_); + } + --i; + } else data.forEach(root.add); + xs = ys = data = d = null; + return 
root; + } + quadtree.x = function(_) { + return arguments.length ? (x = _, quadtree) : x; + }; + quadtree.y = function(_) { + return arguments.length ? (y = _, quadtree) : y; + }; + quadtree.extent = function(_) { + if (!arguments.length) return x1 == null ? null : [ [ x1, y1 ], [ x2, y2 ] ]; + if (_ == null) x1 = y1 = x2 = y2 = null; else x1 = +_[0][0], y1 = +_[0][1], x2 = +_[1][0], + y2 = +_[1][1]; + return quadtree; + }; + quadtree.size = function(_) { + if (!arguments.length) return x1 == null ? null : [ x2 - x1, y2 - y1 ]; + if (_ == null) x1 = y1 = x2 = y2 = null; else x1 = y1 = 0, x2 = +_[0], y2 = +_[1]; + return quadtree; + }; + return quadtree; + }; + function d3_geom_quadtreeCompatX(d) { + return d.x; + } + function d3_geom_quadtreeCompatY(d) { + return d.y; + } + function d3_geom_quadtreeNode() { + return { + leaf: true, + nodes: [], + point: null, + x: null, + y: null + }; + } + function d3_geom_quadtreeVisit(f, node, x1, y1, x2, y2) { + if (!f(node, x1, y1, x2, y2)) { + var sx = (x1 + x2) * .5, sy = (y1 + y2) * .5, children = node.nodes; + if (children[0]) d3_geom_quadtreeVisit(f, children[0], x1, y1, sx, sy); + if (children[1]) d3_geom_quadtreeVisit(f, children[1], sx, y1, x2, sy); + if (children[2]) d3_geom_quadtreeVisit(f, children[2], x1, sy, sx, y2); + if (children[3]) d3_geom_quadtreeVisit(f, children[3], sx, sy, x2, y2); + } + } + function d3_geom_quadtreeFind(root, x, y, x0, y0, x3, y3) { + var minDistance2 = Infinity, closestPoint; + (function find(node, x1, y1, x2, y2) { + if (x1 > x3 || y1 > y3 || x2 < x0 || y2 < y0) return; + if (point = node.point) { + var point, dx = x - node.x, dy = y - node.y, distance2 = dx * dx + dy * dy; + if (distance2 < minDistance2) { + var distance = Math.sqrt(minDistance2 = distance2); + x0 = x - distance, y0 = y - distance; + x3 = x + distance, y3 = y + distance; + closestPoint = point; + } + } + var children = node.nodes, xm = (x1 + x2) * .5, ym = (y1 + y2) * .5, right = x >= xm, below = y >= ym; + for (var i = below << 1 | right, j = i + 4; i < j; ++i) { + if (node = children[i & 3]) switch (i & 3) { + case 0: + find(node, x1, y1, xm, ym); + break; + + case 1: + find(node, xm, y1, x2, ym); + break; + + case 2: + find(node, x1, ym, xm, y2); + break; + + case 3: + find(node, xm, ym, x2, y2); + break; + } + } + })(root, x0, y0, x3, y3); + return closestPoint; + } + d3.interpolateRgb = d3_interpolateRgb; + function d3_interpolateRgb(a, b) { + a = d3.rgb(a); + b = d3.rgb(b); + var ar = a.r, ag = a.g, ab = a.b, br = b.r - ar, bg = b.g - ag, bb = b.b - ab; + return function(t) { + return "#" + d3_rgb_hex(Math.round(ar + br * t)) + d3_rgb_hex(Math.round(ag + bg * t)) + d3_rgb_hex(Math.round(ab + bb * t)); + }; + } + d3.interpolateObject = d3_interpolateObject; + function d3_interpolateObject(a, b) { + var i = {}, c = {}, k; + for (k in a) { + if (k in b) { + i[k] = d3_interpolate(a[k], b[k]); + } else { + c[k] = a[k]; + } + } + for (k in b) { + if (!(k in a)) { + c[k] = b[k]; + } + } + return function(t) { + for (k in i) c[k] = i[k](t); + return c; + }; + } + d3.interpolateNumber = d3_interpolateNumber; + function d3_interpolateNumber(a, b) { + a = +a, b = +b; + return function(t) { + return a * (1 - t) + b * t; + }; + } + d3.interpolateString = d3_interpolateString; + function d3_interpolateString(a, b) { + var bi = d3_interpolate_numberA.lastIndex = d3_interpolate_numberB.lastIndex = 0, am, bm, bs, i = -1, s = [], q = []; + a = a + "", b = b + ""; + while ((am = d3_interpolate_numberA.exec(a)) && (bm = d3_interpolate_numberB.exec(b))) { + 
if ((bs = bm.index) > bi) { + bs = b.slice(bi, bs); + if (s[i]) s[i] += bs; else s[++i] = bs; + } + if ((am = am[0]) === (bm = bm[0])) { + if (s[i]) s[i] += bm; else s[++i] = bm; + } else { + s[++i] = null; + q.push({ + i: i, + x: d3_interpolateNumber(am, bm) + }); + } + bi = d3_interpolate_numberB.lastIndex; + } + if (bi < b.length) { + bs = b.slice(bi); + if (s[i]) s[i] += bs; else s[++i] = bs; + } + return s.length < 2 ? q[0] ? (b = q[0].x, function(t) { + return b(t) + ""; + }) : function() { + return b; + } : (b = q.length, function(t) { + for (var i = 0, o; i < b; ++i) s[(o = q[i]).i] = o.x(t); + return s.join(""); + }); + } + var d3_interpolate_numberA = /[-+]?(?:\d+\.?\d*|\.?\d+)(?:[eE][-+]?\d+)?/g, d3_interpolate_numberB = new RegExp(d3_interpolate_numberA.source, "g"); + d3.interpolate = d3_interpolate; + function d3_interpolate(a, b) { + var i = d3.interpolators.length, f; + while (--i >= 0 && !(f = d3.interpolators[i](a, b))) ; + return f; + } + d3.interpolators = [ function(a, b) { + var t = typeof b; + return (t === "string" ? d3_rgb_names.has(b.toLowerCase()) || /^(#|rgb\(|hsl\()/i.test(b) ? d3_interpolateRgb : d3_interpolateString : b instanceof d3_color ? d3_interpolateRgb : Array.isArray(b) ? d3_interpolateArray : t === "object" && isNaN(b) ? d3_interpolateObject : d3_interpolateNumber)(a, b); + } ]; + d3.interpolateArray = d3_interpolateArray; + function d3_interpolateArray(a, b) { + var x = [], c = [], na = a.length, nb = b.length, n0 = Math.min(a.length, b.length), i; + for (i = 0; i < n0; ++i) x.push(d3_interpolate(a[i], b[i])); + for (;i < na; ++i) c[i] = a[i]; + for (;i < nb; ++i) c[i] = b[i]; + return function(t) { + for (i = 0; i < n0; ++i) c[i] = x[i](t); + return c; + }; + } + var d3_ease_default = function() { + return d3_identity; + }; + var d3_ease = d3.map({ + linear: d3_ease_default, + poly: d3_ease_poly, + quad: function() { + return d3_ease_quad; + }, + cubic: function() { + return d3_ease_cubic; + }, + sin: function() { + return d3_ease_sin; + }, + exp: function() { + return d3_ease_exp; + }, + circle: function() { + return d3_ease_circle; + }, + elastic: d3_ease_elastic, + back: d3_ease_back, + bounce: function() { + return d3_ease_bounce; + } + }); + var d3_ease_mode = d3.map({ + "in": d3_identity, + out: d3_ease_reverse, + "in-out": d3_ease_reflect, + "out-in": function(f) { + return d3_ease_reflect(d3_ease_reverse(f)); + } + }); + d3.ease = function(name) { + var i = name.indexOf("-"), t = i >= 0 ? name.slice(0, i) : name, m = i >= 0 ? name.slice(i + 1) : "in"; + t = d3_ease.get(t) || d3_ease_default; + m = d3_ease_mode.get(m) || d3_identity; + return d3_ease_clamp(m(t.apply(null, d3_arraySlice.call(arguments, 1)))); + }; + function d3_ease_clamp(f) { + return function(t) { + return t <= 0 ? 0 : t >= 1 ? 1 : f(t); + }; + } + function d3_ease_reverse(f) { + return function(t) { + return 1 - f(1 - t); + }; + } + function d3_ease_reflect(f) { + return function(t) { + return .5 * (t < .5 ? f(2 * t) : 2 - f(2 - 2 * t)); + }; + } + function d3_ease_quad(t) { + return t * t; + } + function d3_ease_cubic(t) { + return t * t * t; + } + function d3_ease_cubicInOut(t) { + if (t <= 0) return 0; + if (t >= 1) return 1; + var t2 = t * t, t3 = t2 * t; + return 4 * (t < .5 ? 
t3 : 3 * (t - t2) + t3 - .75); + } + function d3_ease_poly(e) { + return function(t) { + return Math.pow(t, e); + }; + } + function d3_ease_sin(t) { + return 1 - Math.cos(t * halfπ); + } + function d3_ease_exp(t) { + return Math.pow(2, 10 * (t - 1)); + } + function d3_ease_circle(t) { + return 1 - Math.sqrt(1 - t * t); + } + function d3_ease_elastic(a, p) { + var s; + if (arguments.length < 2) p = .45; + if (arguments.length) s = p / τ * Math.asin(1 / a); else a = 1, s = p / 4; + return function(t) { + return 1 + a * Math.pow(2, -10 * t) * Math.sin((t - s) * τ / p); + }; + } + function d3_ease_back(s) { + if (!s) s = 1.70158; + return function(t) { + return t * t * ((s + 1) * t - s); + }; + } + function d3_ease_bounce(t) { + return t < 1 / 2.75 ? 7.5625 * t * t : t < 2 / 2.75 ? 7.5625 * (t -= 1.5 / 2.75) * t + .75 : t < 2.5 / 2.75 ? 7.5625 * (t -= 2.25 / 2.75) * t + .9375 : 7.5625 * (t -= 2.625 / 2.75) * t + .984375; + } + d3.interpolateHcl = d3_interpolateHcl; + function d3_interpolateHcl(a, b) { + a = d3.hcl(a); + b = d3.hcl(b); + var ah = a.h, ac = a.c, al = a.l, bh = b.h - ah, bc = b.c - ac, bl = b.l - al; + if (isNaN(bc)) bc = 0, ac = isNaN(ac) ? b.c : ac; + if (isNaN(bh)) bh = 0, ah = isNaN(ah) ? b.h : ah; else if (bh > 180) bh -= 360; else if (bh < -180) bh += 360; + return function(t) { + return d3_hcl_lab(ah + bh * t, ac + bc * t, al + bl * t) + ""; + }; + } + d3.interpolateHsl = d3_interpolateHsl; + function d3_interpolateHsl(a, b) { + a = d3.hsl(a); + b = d3.hsl(b); + var ah = a.h, as = a.s, al = a.l, bh = b.h - ah, bs = b.s - as, bl = b.l - al; + if (isNaN(bs)) bs = 0, as = isNaN(as) ? b.s : as; + if (isNaN(bh)) bh = 0, ah = isNaN(ah) ? b.h : ah; else if (bh > 180) bh -= 360; else if (bh < -180) bh += 360; + return function(t) { + return d3_hsl_rgb(ah + bh * t, as + bs * t, al + bl * t) + ""; + }; + } + d3.interpolateLab = d3_interpolateLab; + function d3_interpolateLab(a, b) { + a = d3.lab(a); + b = d3.lab(b); + var al = a.l, aa = a.a, ab = a.b, bl = b.l - al, ba = b.a - aa, bb = b.b - ab; + return function(t) { + return d3_lab_rgb(al + bl * t, aa + ba * t, ab + bb * t) + ""; + }; + } + d3.interpolateRound = d3_interpolateRound; + function d3_interpolateRound(a, b) { + b -= a; + return function(t) { + return Math.round(a + b * t); + }; + } + d3.transform = function(string) { + var g = d3_document.createElementNS(d3.ns.prefix.svg, "g"); + return (d3.transform = function(string) { + if (string != null) { + g.setAttribute("transform", string); + var t = g.transform.baseVal.consolidate(); + } + return new d3_transform(t ? t.matrix : d3_transformIdentity); + })(string); + }; + function d3_transform(m) { + var r0 = [ m.a, m.b ], r1 = [ m.c, m.d ], kx = d3_transformNormalize(r0), kz = d3_transformDot(r0, r1), ky = d3_transformNormalize(d3_transformCombine(r1, r0, -kz)) || 0; + if (r0[0] * r1[1] < r1[0] * r0[1]) { + r0[0] *= -1; + r0[1] *= -1; + kx *= -1; + kz *= -1; + } + this.rotate = (kx ? Math.atan2(r0[1], r0[0]) : Math.atan2(-r1[0], r1[1])) * d3_degrees; + this.translate = [ m.e, m.f ]; + this.scale = [ kx, ky ]; + this.skew = ky ? 
Math.atan2(kz, ky) * d3_degrees : 0; + } + d3_transform.prototype.toString = function() { + return "translate(" + this.translate + ")rotate(" + this.rotate + ")skewX(" + this.skew + ")scale(" + this.scale + ")"; + }; + function d3_transformDot(a, b) { + return a[0] * b[0] + a[1] * b[1]; + } + function d3_transformNormalize(a) { + var k = Math.sqrt(d3_transformDot(a, a)); + if (k) { + a[0] /= k; + a[1] /= k; + } + return k; + } + function d3_transformCombine(a, b, k) { + a[0] += k * b[0]; + a[1] += k * b[1]; + return a; + } + var d3_transformIdentity = { + a: 1, + b: 0, + c: 0, + d: 1, + e: 0, + f: 0 + }; + d3.interpolateTransform = d3_interpolateTransform; + function d3_interpolateTransformPop(s) { + return s.length ? s.pop() + "," : ""; + } + function d3_interpolateTranslate(ta, tb, s, q) { + if (ta[0] !== tb[0] || ta[1] !== tb[1]) { + var i = s.push("translate(", null, ",", null, ")"); + q.push({ + i: i - 4, + x: d3_interpolateNumber(ta[0], tb[0]) + }, { + i: i - 2, + x: d3_interpolateNumber(ta[1], tb[1]) + }); + } else if (tb[0] || tb[1]) { + s.push("translate(" + tb + ")"); + } + } + function d3_interpolateRotate(ra, rb, s, q) { + if (ra !== rb) { + if (ra - rb > 180) rb += 360; else if (rb - ra > 180) ra += 360; + q.push({ + i: s.push(d3_interpolateTransformPop(s) + "rotate(", null, ")") - 2, + x: d3_interpolateNumber(ra, rb) + }); + } else if (rb) { + s.push(d3_interpolateTransformPop(s) + "rotate(" + rb + ")"); + } + } + function d3_interpolateSkew(wa, wb, s, q) { + if (wa !== wb) { + q.push({ + i: s.push(d3_interpolateTransformPop(s) + "skewX(", null, ")") - 2, + x: d3_interpolateNumber(wa, wb) + }); + } else if (wb) { + s.push(d3_interpolateTransformPop(s) + "skewX(" + wb + ")"); + } + } + function d3_interpolateScale(ka, kb, s, q) { + if (ka[0] !== kb[0] || ka[1] !== kb[1]) { + var i = s.push(d3_interpolateTransformPop(s) + "scale(", null, ",", null, ")"); + q.push({ + i: i - 4, + x: d3_interpolateNumber(ka[0], kb[0]) + }, { + i: i - 2, + x: d3_interpolateNumber(ka[1], kb[1]) + }); + } else if (kb[0] !== 1 || kb[1] !== 1) { + s.push(d3_interpolateTransformPop(s) + "scale(" + kb + ")"); + } + } + function d3_interpolateTransform(a, b) { + var s = [], q = []; + a = d3.transform(a), b = d3.transform(b); + d3_interpolateTranslate(a.translate, b.translate, s, q); + d3_interpolateRotate(a.rotate, b.rotate, s, q); + d3_interpolateSkew(a.skew, b.skew, s, q); + d3_interpolateScale(a.scale, b.scale, s, q); + a = b = null; + return function(t) { + var i = -1, n = q.length, o; + while (++i < n) s[(o = q[i]).i] = o.x(t); + return s.join(""); + }; + } + function d3_uninterpolateNumber(a, b) { + b = (b -= a = +a) || 1 / b; + return function(x) { + return (x - a) / b; + }; + } + function d3_uninterpolateClamp(a, b) { + b = (b -= a = +a) || 1 / b; + return function(x) { + return Math.max(0, Math.min(1, (x - a) / b)); + }; + } + d3.layout = {}; + d3.layout.bundle = function() { + return function(links) { + var paths = [], i = -1, n = links.length; + while (++i < n) paths.push(d3_layout_bundlePath(links[i])); + return paths; + }; + }; + function d3_layout_bundlePath(link) { + var start = link.source, end = link.target, lca = d3_layout_bundleLeastCommonAncestor(start, end), points = [ start ]; + while (start !== lca) { + start = start.parent; + points.push(start); + } + var k = points.length; + while (end !== lca) { + points.splice(k, 0, end); + end = end.parent; + } + return points; + } + function d3_layout_bundleAncestors(node) { + var ancestors = [], parent = node.parent; + while (parent != null) 
{ + ancestors.push(node); + node = parent; + parent = parent.parent; + } + ancestors.push(node); + return ancestors; + } + function d3_layout_bundleLeastCommonAncestor(a, b) { + if (a === b) return a; + var aNodes = d3_layout_bundleAncestors(a), bNodes = d3_layout_bundleAncestors(b), aNode = aNodes.pop(), bNode = bNodes.pop(), sharedNode = null; + while (aNode === bNode) { + sharedNode = aNode; + aNode = aNodes.pop(); + bNode = bNodes.pop(); + } + return sharedNode; + } + d3.layout.chord = function() { + var chord = {}, chords, groups, matrix, n, padding = 0, sortGroups, sortSubgroups, sortChords; + function relayout() { + var subgroups = {}, groupSums = [], groupIndex = d3.range(n), subgroupIndex = [], k, x, x0, i, j; + chords = []; + groups = []; + k = 0, i = -1; + while (++i < n) { + x = 0, j = -1; + while (++j < n) { + x += matrix[i][j]; + } + groupSums.push(x); + subgroupIndex.push(d3.range(n)); + k += x; + } + if (sortGroups) { + groupIndex.sort(function(a, b) { + return sortGroups(groupSums[a], groupSums[b]); + }); + } + if (sortSubgroups) { + subgroupIndex.forEach(function(d, i) { + d.sort(function(a, b) { + return sortSubgroups(matrix[i][a], matrix[i][b]); + }); + }); + } + k = (τ - padding * n) / k; + x = 0, i = -1; + while (++i < n) { + x0 = x, j = -1; + while (++j < n) { + var di = groupIndex[i], dj = subgroupIndex[di][j], v = matrix[di][dj], a0 = x, a1 = x += v * k; + subgroups[di + "-" + dj] = { + index: di, + subindex: dj, + startAngle: a0, + endAngle: a1, + value: v + }; + } + groups[di] = { + index: di, + startAngle: x0, + endAngle: x, + value: groupSums[di] + }; + x += padding; + } + i = -1; + while (++i < n) { + j = i - 1; + while (++j < n) { + var source = subgroups[i + "-" + j], target = subgroups[j + "-" + i]; + if (source.value || target.value) { + chords.push(source.value < target.value ? 
{ + source: target, + target: source + } : { + source: source, + target: target + }); + } + } + } + if (sortChords) resort(); + } + function resort() { + chords.sort(function(a, b) { + return sortChords((a.source.value + a.target.value) / 2, (b.source.value + b.target.value) / 2); + }); + } + chord.matrix = function(x) { + if (!arguments.length) return matrix; + n = (matrix = x) && matrix.length; + chords = groups = null; + return chord; + }; + chord.padding = function(x) { + if (!arguments.length) return padding; + padding = x; + chords = groups = null; + return chord; + }; + chord.sortGroups = function(x) { + if (!arguments.length) return sortGroups; + sortGroups = x; + chords = groups = null; + return chord; + }; + chord.sortSubgroups = function(x) { + if (!arguments.length) return sortSubgroups; + sortSubgroups = x; + chords = null; + return chord; + }; + chord.sortChords = function(x) { + if (!arguments.length) return sortChords; + sortChords = x; + if (chords) resort(); + return chord; + }; + chord.chords = function() { + if (!chords) relayout(); + return chords; + }; + chord.groups = function() { + if (!groups) relayout(); + return groups; + }; + return chord; + }; + d3.layout.force = function() { + var force = {}, event = d3.dispatch("start", "tick", "end"), timer, size = [ 1, 1 ], drag, alpha, friction = .9, linkDistance = d3_layout_forceLinkDistance, linkStrength = d3_layout_forceLinkStrength, charge = -30, chargeDistance2 = d3_layout_forceChargeDistance2, gravity = .1, theta2 = .64, nodes = [], links = [], distances, strengths, charges; + function repulse(node) { + return function(quad, x1, _, x2) { + if (quad.point !== node) { + var dx = quad.cx - node.x, dy = quad.cy - node.y, dw = x2 - x1, dn = dx * dx + dy * dy; + if (dw * dw / theta2 < dn) { + if (dn < chargeDistance2) { + var k = quad.charge / dn; + node.px -= dx * k; + node.py -= dy * k; + } + return true; + } + if (quad.point && dn && dn < chargeDistance2) { + var k = quad.pointCharge / dn; + node.px -= dx * k; + node.py -= dy * k; + } + } + return !quad.charge; + }; + } + force.tick = function() { + if ((alpha *= .99) < .005) { + timer = null; + event.end({ + type: "end", + alpha: alpha = 0 + }); + return true; + } + var n = nodes.length, m = links.length, q, i, o, s, t, l, k, x, y; + for (i = 0; i < m; ++i) { + o = links[i]; + s = o.source; + t = o.target; + x = t.x - s.x; + y = t.y - s.y; + if (l = x * x + y * y) { + l = alpha * strengths[i] * ((l = Math.sqrt(l)) - distances[i]) / l; + x *= l; + y *= l; + t.x -= x * (k = s.weight + t.weight ? 
s.weight / (s.weight + t.weight) : .5); + t.y -= y * k; + s.x += x * (k = 1 - k); + s.y += y * k; + } + } + if (k = alpha * gravity) { + x = size[0] / 2; + y = size[1] / 2; + i = -1; + if (k) while (++i < n) { + o = nodes[i]; + o.x += (x - o.x) * k; + o.y += (y - o.y) * k; + } + } + if (charge) { + d3_layout_forceAccumulate(q = d3.geom.quadtree(nodes), alpha, charges); + i = -1; + while (++i < n) { + if (!(o = nodes[i]).fixed) { + q.visit(repulse(o)); + } + } + } + i = -1; + while (++i < n) { + o = nodes[i]; + if (o.fixed) { + o.x = o.px; + o.y = o.py; + } else { + o.x -= (o.px - (o.px = o.x)) * friction; + o.y -= (o.py - (o.py = o.y)) * friction; + } + } + event.tick({ + type: "tick", + alpha: alpha + }); + }; + force.nodes = function(x) { + if (!arguments.length) return nodes; + nodes = x; + return force; + }; + force.links = function(x) { + if (!arguments.length) return links; + links = x; + return force; + }; + force.size = function(x) { + if (!arguments.length) return size; + size = x; + return force; + }; + force.linkDistance = function(x) { + if (!arguments.length) return linkDistance; + linkDistance = typeof x === "function" ? x : +x; + return force; + }; + force.distance = force.linkDistance; + force.linkStrength = function(x) { + if (!arguments.length) return linkStrength; + linkStrength = typeof x === "function" ? x : +x; + return force; + }; + force.friction = function(x) { + if (!arguments.length) return friction; + friction = +x; + return force; + }; + force.charge = function(x) { + if (!arguments.length) return charge; + charge = typeof x === "function" ? x : +x; + return force; + }; + force.chargeDistance = function(x) { + if (!arguments.length) return Math.sqrt(chargeDistance2); + chargeDistance2 = x * x; + return force; + }; + force.gravity = function(x) { + if (!arguments.length) return gravity; + gravity = +x; + return force; + }; + force.theta = function(x) { + if (!arguments.length) return Math.sqrt(theta2); + theta2 = x * x; + return force; + }; + force.alpha = function(x) { + if (!arguments.length) return alpha; + x = +x; + if (alpha) { + if (x > 0) { + alpha = x; + } else { + timer.c = null, timer.t = NaN, timer = null; + event.end({ + type: "end", + alpha: alpha = 0 + }); + } + } else if (x > 0) { + event.start({ + type: "start", + alpha: alpha = x + }); + timer = d3_timer(force.tick); + } + return force; + }; + force.start = function() { + var i, n = nodes.length, m = links.length, w = size[0], h = size[1], neighbors, o; + for (i = 0; i < n; ++i) { + (o = nodes[i]).index = i; + o.weight = 0; + } + for (i = 0; i < m; ++i) { + o = links[i]; + if (typeof o.source == "number") o.source = nodes[o.source]; + if (typeof o.target == "number") o.target = nodes[o.target]; + ++o.source.weight; + ++o.target.weight; + } + for (i = 0; i < n; ++i) { + o = nodes[i]; + if (isNaN(o.x)) o.x = position("x", w); + if (isNaN(o.y)) o.y = position("y", h); + if (isNaN(o.px)) o.px = o.x; + if (isNaN(o.py)) o.py = o.y; + } + distances = []; + if (typeof linkDistance === "function") for (i = 0; i < m; ++i) distances[i] = +linkDistance.call(this, links[i], i); else for (i = 0; i < m; ++i) distances[i] = linkDistance; + strengths = []; + if (typeof linkStrength === "function") for (i = 0; i < m; ++i) strengths[i] = +linkStrength.call(this, links[i], i); else for (i = 0; i < m; ++i) strengths[i] = linkStrength; + charges = []; + if (typeof charge === "function") for (i = 0; i < n; ++i) charges[i] = +charge.call(this, nodes[i], i); else for (i = 0; i < n; ++i) charges[i] = charge; + function 
position(dimension, size) { + if (!neighbors) { + neighbors = new Array(n); + for (j = 0; j < n; ++j) { + neighbors[j] = []; + } + for (j = 0; j < m; ++j) { + var o = links[j]; + neighbors[o.source.index].push(o.target); + neighbors[o.target.index].push(o.source); + } + } + var candidates = neighbors[i], j = -1, l = candidates.length, x; + while (++j < l) if (!isNaN(x = candidates[j][dimension])) return x; + return Math.random() * size; + } + return force.resume(); + }; + force.resume = function() { + return force.alpha(.1); + }; + force.stop = function() { + return force.alpha(0); + }; + force.drag = function() { + if (!drag) drag = d3.behavior.drag().origin(d3_identity).on("dragstart.force", d3_layout_forceDragstart).on("drag.force", dragmove).on("dragend.force", d3_layout_forceDragend); + if (!arguments.length) return drag; + this.on("mouseover.force", d3_layout_forceMouseover).on("mouseout.force", d3_layout_forceMouseout).call(drag); + }; + function dragmove(d) { + d.px = d3.event.x, d.py = d3.event.y; + force.resume(); + } + return d3.rebind(force, event, "on"); + }; + function d3_layout_forceDragstart(d) { + d.fixed |= 2; + } + function d3_layout_forceDragend(d) { + d.fixed &= ~6; + } + function d3_layout_forceMouseover(d) { + d.fixed |= 4; + d.px = d.x, d.py = d.y; + } + function d3_layout_forceMouseout(d) { + d.fixed &= ~4; + } + function d3_layout_forceAccumulate(quad, alpha, charges) { + var cx = 0, cy = 0; + quad.charge = 0; + if (!quad.leaf) { + var nodes = quad.nodes, n = nodes.length, i = -1, c; + while (++i < n) { + c = nodes[i]; + if (c == null) continue; + d3_layout_forceAccumulate(c, alpha, charges); + quad.charge += c.charge; + cx += c.charge * c.cx; + cy += c.charge * c.cy; + } + } + if (quad.point) { + if (!quad.leaf) { + quad.point.x += Math.random() - .5; + quad.point.y += Math.random() - .5; + } + var k = alpha * charges[quad.point.index]; + quad.charge += quad.pointCharge = k; + cx += k * quad.point.x; + cy += k * quad.point.y; + } + quad.cx = cx / quad.charge; + quad.cy = cy / quad.charge; + } + var d3_layout_forceLinkDistance = 20, d3_layout_forceLinkStrength = 1, d3_layout_forceChargeDistance2 = Infinity; + d3.layout.hierarchy = function() { + var sort = d3_layout_hierarchySort, children = d3_layout_hierarchyChildren, value = d3_layout_hierarchyValue; + function hierarchy(root) { + var stack = [ root ], nodes = [], node; + root.depth = 0; + while ((node = stack.pop()) != null) { + nodes.push(node); + if ((childs = children.call(hierarchy, node, node.depth)) && (n = childs.length)) { + var n, childs, child; + while (--n >= 0) { + stack.push(child = childs[n]); + child.parent = node; + child.depth = node.depth + 1; + } + if (value) node.value = 0; + node.children = childs; + } else { + if (value) node.value = +value.call(hierarchy, node, node.depth) || 0; + delete node.children; + } + } + d3_layout_hierarchyVisitAfter(root, function(node) { + var childs, parent; + if (sort && (childs = node.children)) childs.sort(sort); + if (value && (parent = node.parent)) parent.value += node.value; + }); + return nodes; + } + hierarchy.sort = function(x) { + if (!arguments.length) return sort; + sort = x; + return hierarchy; + }; + hierarchy.children = function(x) { + if (!arguments.length) return children; + children = x; + return hierarchy; + }; + hierarchy.value = function(x) { + if (!arguments.length) return value; + value = x; + return hierarchy; + }; + hierarchy.revalue = function(root) { + if (value) { + d3_layout_hierarchyVisitBefore(root, function(node) { + if 
(node.children) node.value = 0; + }); + d3_layout_hierarchyVisitAfter(root, function(node) { + var parent; + if (!node.children) node.value = +value.call(hierarchy, node, node.depth) || 0; + if (parent = node.parent) parent.value += node.value; + }); + } + return root; + }; + return hierarchy; + }; + function d3_layout_hierarchyRebind(object, hierarchy) { + d3.rebind(object, hierarchy, "sort", "children", "value"); + object.nodes = object; + object.links = d3_layout_hierarchyLinks; + return object; + } + function d3_layout_hierarchyVisitBefore(node, callback) { + var nodes = [ node ]; + while ((node = nodes.pop()) != null) { + callback(node); + if ((children = node.children) && (n = children.length)) { + var n, children; + while (--n >= 0) nodes.push(children[n]); + } + } + } + function d3_layout_hierarchyVisitAfter(node, callback) { + var nodes = [ node ], nodes2 = []; + while ((node = nodes.pop()) != null) { + nodes2.push(node); + if ((children = node.children) && (n = children.length)) { + var i = -1, n, children; + while (++i < n) nodes.push(children[i]); + } + } + while ((node = nodes2.pop()) != null) { + callback(node); + } + } + function d3_layout_hierarchyChildren(d) { + return d.children; + } + function d3_layout_hierarchyValue(d) { + return d.value; + } + function d3_layout_hierarchySort(a, b) { + return b.value - a.value; + } + function d3_layout_hierarchyLinks(nodes) { + return d3.merge(nodes.map(function(parent) { + return (parent.children || []).map(function(child) { + return { + source: parent, + target: child + }; + }); + })); + } + d3.layout.partition = function() { + var hierarchy = d3.layout.hierarchy(), size = [ 1, 1 ]; + function position(node, x, dx, dy) { + var children = node.children; + node.x = x; + node.y = node.depth * dy; + node.dx = dx; + node.dy = dy; + if (children && (n = children.length)) { + var i = -1, n, c, d; + dx = node.value ? dx / node.value : 0; + while (++i < n) { + position(c = children[i], x, d = c.value * dx, dy); + x += d; + } + } + } + function depth(node) { + var children = node.children, d = 0; + if (children && (n = children.length)) { + var i = -1, n; + while (++i < n) d = Math.max(d, depth(children[i])); + } + return 1 + d; + } + function partition(d, i) { + var nodes = hierarchy.call(this, d, i); + position(nodes[0], 0, size[0], size[1] / depth(nodes[0])); + return nodes; + } + partition.size = function(x) { + if (!arguments.length) return size; + size = x; + return partition; + }; + return d3_layout_hierarchyRebind(partition, hierarchy); + }; + d3.layout.pie = function() { + var value = Number, sort = d3_layout_pieSortByValue, startAngle = 0, endAngle = τ, padAngle = 0; + function pie(data) { + var n = data.length, values = data.map(function(d, i) { + return +value.call(pie, d, i); + }), a = +(typeof startAngle === "function" ? startAngle.apply(this, arguments) : startAngle), da = (typeof endAngle === "function" ? endAngle.apply(this, arguments) : endAngle) - a, p = Math.min(Math.abs(da) / n, +(typeof padAngle === "function" ? padAngle.apply(this, arguments) : padAngle)), pa = p * (da < 0 ? -1 : 1), sum = d3.sum(values), k = sum ? (da - n * pa) / sum : 0, index = d3.range(n), arcs = [], v; + if (sort != null) index.sort(sort === d3_layout_pieSortByValue ? 
function(i, j) { + return values[j] - values[i]; + } : function(i, j) { + return sort(data[i], data[j]); + }); + index.forEach(function(i) { + arcs[i] = { + data: data[i], + value: v = values[i], + startAngle: a, + endAngle: a += v * k + pa, + padAngle: p + }; + }); + return arcs; + } + pie.value = function(_) { + if (!arguments.length) return value; + value = _; + return pie; + }; + pie.sort = function(_) { + if (!arguments.length) return sort; + sort = _; + return pie; + }; + pie.startAngle = function(_) { + if (!arguments.length) return startAngle; + startAngle = _; + return pie; + }; + pie.endAngle = function(_) { + if (!arguments.length) return endAngle; + endAngle = _; + return pie; + }; + pie.padAngle = function(_) { + if (!arguments.length) return padAngle; + padAngle = _; + return pie; + }; + return pie; + }; + var d3_layout_pieSortByValue = {}; + d3.layout.stack = function() { + var values = d3_identity, order = d3_layout_stackOrderDefault, offset = d3_layout_stackOffsetZero, out = d3_layout_stackOut, x = d3_layout_stackX, y = d3_layout_stackY; + function stack(data, index) { + if (!(n = data.length)) return data; + var series = data.map(function(d, i) { + return values.call(stack, d, i); + }); + var points = series.map(function(d) { + return d.map(function(v, i) { + return [ x.call(stack, v, i), y.call(stack, v, i) ]; + }); + }); + var orders = order.call(stack, points, index); + series = d3.permute(series, orders); + points = d3.permute(points, orders); + var offsets = offset.call(stack, points, index); + var m = series[0].length, n, i, j, o; + for (j = 0; j < m; ++j) { + out.call(stack, series[0][j], o = offsets[j], points[0][j][1]); + for (i = 1; i < n; ++i) { + out.call(stack, series[i][j], o += points[i - 1][j][1], points[i][j][1]); + } + } + return data; + } + stack.values = function(x) { + if (!arguments.length) return values; + values = x; + return stack; + }; + stack.order = function(x) { + if (!arguments.length) return order; + order = typeof x === "function" ? x : d3_layout_stackOrders.get(x) || d3_layout_stackOrderDefault; + return stack; + }; + stack.offset = function(x) { + if (!arguments.length) return offset; + offset = typeof x === "function" ? 
x : d3_layout_stackOffsets.get(x) || d3_layout_stackOffsetZero; + return stack; + }; + stack.x = function(z) { + if (!arguments.length) return x; + x = z; + return stack; + }; + stack.y = function(z) { + if (!arguments.length) return y; + y = z; + return stack; + }; + stack.out = function(z) { + if (!arguments.length) return out; + out = z; + return stack; + }; + return stack; + }; + function d3_layout_stackX(d) { + return d.x; + } + function d3_layout_stackY(d) { + return d.y; + } + function d3_layout_stackOut(d, y0, y) { + d.y0 = y0; + d.y = y; + } + var d3_layout_stackOrders = d3.map({ + "inside-out": function(data) { + var n = data.length, i, j, max = data.map(d3_layout_stackMaxIndex), sums = data.map(d3_layout_stackReduceSum), index = d3.range(n).sort(function(a, b) { + return max[a] - max[b]; + }), top = 0, bottom = 0, tops = [], bottoms = []; + for (i = 0; i < n; ++i) { + j = index[i]; + if (top < bottom) { + top += sums[j]; + tops.push(j); + } else { + bottom += sums[j]; + bottoms.push(j); + } + } + return bottoms.reverse().concat(tops); + }, + reverse: function(data) { + return d3.range(data.length).reverse(); + }, + "default": d3_layout_stackOrderDefault + }); + var d3_layout_stackOffsets = d3.map({ + silhouette: function(data) { + var n = data.length, m = data[0].length, sums = [], max = 0, i, j, o, y0 = []; + for (j = 0; j < m; ++j) { + for (i = 0, o = 0; i < n; i++) o += data[i][j][1]; + if (o > max) max = o; + sums.push(o); + } + for (j = 0; j < m; ++j) { + y0[j] = (max - sums[j]) / 2; + } + return y0; + }, + wiggle: function(data) { + var n = data.length, x = data[0], m = x.length, i, j, k, s1, s2, s3, dx, o, o0, y0 = []; + y0[0] = o = o0 = 0; + for (j = 1; j < m; ++j) { + for (i = 0, s1 = 0; i < n; ++i) s1 += data[i][j][1]; + for (i = 0, s2 = 0, dx = x[j][0] - x[j - 1][0]; i < n; ++i) { + for (k = 0, s3 = (data[i][j][1] - data[i][j - 1][1]) / (2 * dx); k < i; ++k) { + s3 += (data[k][j][1] - data[k][j - 1][1]) / dx; + } + s2 += s3 * data[i][j][1]; + } + y0[j] = o -= s1 ? s2 / s1 * dx : 0; + if (o < o0) o0 = o; + } + for (j = 0; j < m; ++j) y0[j] -= o0; + return y0; + }, + expand: function(data) { + var n = data.length, m = data[0].length, k = 1 / n, i, j, o, y0 = []; + for (j = 0; j < m; ++j) { + for (i = 0, o = 0; i < n; i++) o += data[i][j][1]; + if (o) for (i = 0; i < n; i++) data[i][j][1] /= o; else for (i = 0; i < n; i++) data[i][j][1] = k; + } + for (j = 0; j < m; ++j) y0[j] = 0; + return y0; + }, + zero: d3_layout_stackOffsetZero + }); + function d3_layout_stackOrderDefault(data) { + return d3.range(data.length); + } + function d3_layout_stackOffsetZero(data) { + var j = -1, m = data[0].length, y0 = []; + while (++j < m) y0[j] = 0; + return y0; + } + function d3_layout_stackMaxIndex(array) { + var i = 1, j = 0, v = array[0][1], k, n = array.length; + for (;i < n; ++i) { + if ((k = array[i][1]) > v) { + j = i; + v = k; + } + } + return j; + } + function d3_layout_stackReduceSum(d) { + return d.reduce(d3_layout_stackSum, 0); + } + function d3_layout_stackSum(p, d) { + return p + d[1]; + } + d3.layout.histogram = function() { + var frequency = true, valuer = Number, ranger = d3_layout_histogramRange, binner = d3_layout_histogramBinSturges; + function histogram(data, i) { + var bins = [], values = data.map(valuer, this), range = ranger.call(this, values, i), thresholds = binner.call(this, range, values, i), bin, i = -1, n = values.length, m = thresholds.length - 1, k = frequency ? 
1 : 1 / n, x; + while (++i < m) { + bin = bins[i] = []; + bin.dx = thresholds[i + 1] - (bin.x = thresholds[i]); + bin.y = 0; + } + if (m > 0) { + i = -1; + while (++i < n) { + x = values[i]; + if (x >= range[0] && x <= range[1]) { + bin = bins[d3.bisect(thresholds, x, 1, m) - 1]; + bin.y += k; + bin.push(data[i]); + } + } + } + return bins; + } + histogram.value = function(x) { + if (!arguments.length) return valuer; + valuer = x; + return histogram; + }; + histogram.range = function(x) { + if (!arguments.length) return ranger; + ranger = d3_functor(x); + return histogram; + }; + histogram.bins = function(x) { + if (!arguments.length) return binner; + binner = typeof x === "number" ? function(range) { + return d3_layout_histogramBinFixed(range, x); + } : d3_functor(x); + return histogram; + }; + histogram.frequency = function(x) { + if (!arguments.length) return frequency; + frequency = !!x; + return histogram; + }; + return histogram; + }; + function d3_layout_histogramBinSturges(range, values) { + return d3_layout_histogramBinFixed(range, Math.ceil(Math.log(values.length) / Math.LN2 + 1)); + } + function d3_layout_histogramBinFixed(range, n) { + var x = -1, b = +range[0], m = (range[1] - b) / n, f = []; + while (++x <= n) f[x] = m * x + b; + return f; + } + function d3_layout_histogramRange(values) { + return [ d3.min(values), d3.max(values) ]; + } + d3.layout.pack = function() { + var hierarchy = d3.layout.hierarchy().sort(d3_layout_packSort), padding = 0, size = [ 1, 1 ], radius; + function pack(d, i) { + var nodes = hierarchy.call(this, d, i), root = nodes[0], w = size[0], h = size[1], r = radius == null ? Math.sqrt : typeof radius === "function" ? radius : function() { + return radius; + }; + root.x = root.y = 0; + d3_layout_hierarchyVisitAfter(root, function(d) { + d.r = +r(d.value); + }); + d3_layout_hierarchyVisitAfter(root, d3_layout_packSiblings); + if (padding) { + var dr = padding * (radius ? 1 : Math.max(2 * root.r / w, 2 * root.r / h)) / 2; + d3_layout_hierarchyVisitAfter(root, function(d) { + d.r += dr; + }); + d3_layout_hierarchyVisitAfter(root, d3_layout_packSiblings); + d3_layout_hierarchyVisitAfter(root, function(d) { + d.r -= dr; + }); + } + d3_layout_packTransform(root, w / 2, h / 2, radius ? 1 : 1 / Math.max(2 * root.r / w, 2 * root.r / h)); + return nodes; + } + pack.size = function(_) { + if (!arguments.length) return size; + size = _; + return pack; + }; + pack.radius = function(_) { + if (!arguments.length) return radius; + radius = _ == null || typeof _ === "function" ? 
_ : +_; + return pack; + }; + pack.padding = function(_) { + if (!arguments.length) return padding; + padding = +_; + return pack; + }; + return d3_layout_hierarchyRebind(pack, hierarchy); + }; + function d3_layout_packSort(a, b) { + return a.value - b.value; + } + function d3_layout_packInsert(a, b) { + var c = a._pack_next; + a._pack_next = b; + b._pack_prev = a; + b._pack_next = c; + c._pack_prev = b; + } + function d3_layout_packSplice(a, b) { + a._pack_next = b; + b._pack_prev = a; + } + function d3_layout_packIntersects(a, b) { + var dx = b.x - a.x, dy = b.y - a.y, dr = a.r + b.r; + return .999 * dr * dr > dx * dx + dy * dy; + } + function d3_layout_packSiblings(node) { + if (!(nodes = node.children) || !(n = nodes.length)) return; + var nodes, xMin = Infinity, xMax = -Infinity, yMin = Infinity, yMax = -Infinity, a, b, c, i, j, k, n; + function bound(node) { + xMin = Math.min(node.x - node.r, xMin); + xMax = Math.max(node.x + node.r, xMax); + yMin = Math.min(node.y - node.r, yMin); + yMax = Math.max(node.y + node.r, yMax); + } + nodes.forEach(d3_layout_packLink); + a = nodes[0]; + a.x = -a.r; + a.y = 0; + bound(a); + if (n > 1) { + b = nodes[1]; + b.x = b.r; + b.y = 0; + bound(b); + if (n > 2) { + c = nodes[2]; + d3_layout_packPlace(a, b, c); + bound(c); + d3_layout_packInsert(a, c); + a._pack_prev = c; + d3_layout_packInsert(c, b); + b = a._pack_next; + for (i = 3; i < n; i++) { + d3_layout_packPlace(a, b, c = nodes[i]); + var isect = 0, s1 = 1, s2 = 1; + for (j = b._pack_next; j !== b; j = j._pack_next, s1++) { + if (d3_layout_packIntersects(j, c)) { + isect = 1; + break; + } + } + if (isect == 1) { + for (k = a._pack_prev; k !== j._pack_prev; k = k._pack_prev, s2++) { + if (d3_layout_packIntersects(k, c)) { + break; + } + } + } + if (isect) { + if (s1 < s2 || s1 == s2 && b.r < a.r) d3_layout_packSplice(a, b = j); else d3_layout_packSplice(a = k, b); + i--; + } else { + d3_layout_packInsert(a, c); + b = c; + bound(c); + } + } + } + } + var cx = (xMin + xMax) / 2, cy = (yMin + yMax) / 2, cr = 0; + for (i = 0; i < n; i++) { + c = nodes[i]; + c.x -= cx; + c.y -= cy; + cr = Math.max(cr, c.r + Math.sqrt(c.x * c.x + c.y * c.y)); + } + node.r = cr; + nodes.forEach(d3_layout_packUnlink); + } + function d3_layout_packLink(node) { + node._pack_next = node._pack_prev = node; + } + function d3_layout_packUnlink(node) { + delete node._pack_next; + delete node._pack_prev; + } + function d3_layout_packTransform(node, x, y, k) { + var children = node.children; + node.x = x += k * node.x; + node.y = y += k * node.y; + node.r *= k; + if (children) { + var i = -1, n = children.length; + while (++i < n) d3_layout_packTransform(children[i], x, y, k); + } + } + function d3_layout_packPlace(a, b, c) { + var db = a.r + c.r, dx = b.x - a.x, dy = b.y - a.y; + if (db && (dx || dy)) { + var da = b.r + c.r, dc = dx * dx + dy * dy; + da *= da; + db *= db; + var x = .5 + (db - da) / (2 * dc), y = Math.sqrt(Math.max(0, 2 * da * (db + dc) - (db -= dc) * db - da * da)) / (2 * dc); + c.x = a.x + x * dx + y * dy; + c.y = a.y + x * dy - y * dx; + } else { + c.x = a.x + db; + c.y = a.y; + } + } + d3.layout.tree = function() { + var hierarchy = d3.layout.hierarchy().sort(null).value(null), separation = d3_layout_treeSeparation, size = [ 1, 1 ], nodeSize = null; + function tree(d, i) { + var nodes = hierarchy.call(this, d, i), root0 = nodes[0], root1 = wrapTree(root0); + d3_layout_hierarchyVisitAfter(root1, firstWalk), root1.parent.m = -root1.z; + d3_layout_hierarchyVisitBefore(root1, secondWalk); + if (nodeSize) 
d3_layout_hierarchyVisitBefore(root0, sizeNode); else { + var left = root0, right = root0, bottom = root0; + d3_layout_hierarchyVisitBefore(root0, function(node) { + if (node.x < left.x) left = node; + if (node.x > right.x) right = node; + if (node.depth > bottom.depth) bottom = node; + }); + var tx = separation(left, right) / 2 - left.x, kx = size[0] / (right.x + separation(right, left) / 2 + tx), ky = size[1] / (bottom.depth || 1); + d3_layout_hierarchyVisitBefore(root0, function(node) { + node.x = (node.x + tx) * kx; + node.y = node.depth * ky; + }); + } + return nodes; + } + function wrapTree(root0) { + var root1 = { + A: null, + children: [ root0 ] + }, queue = [ root1 ], node1; + while ((node1 = queue.pop()) != null) { + for (var children = node1.children, child, i = 0, n = children.length; i < n; ++i) { + queue.push((children[i] = child = { + _: children[i], + parent: node1, + children: (child = children[i].children) && child.slice() || [], + A: null, + a: null, + z: 0, + m: 0, + c: 0, + s: 0, + t: null, + i: i + }).a = child); + } + } + return root1.children[0]; + } + function firstWalk(v) { + var children = v.children, siblings = v.parent.children, w = v.i ? siblings[v.i - 1] : null; + if (children.length) { + d3_layout_treeShift(v); + var midpoint = (children[0].z + children[children.length - 1].z) / 2; + if (w) { + v.z = w.z + separation(v._, w._); + v.m = v.z - midpoint; + } else { + v.z = midpoint; + } + } else if (w) { + v.z = w.z + separation(v._, w._); + } + v.parent.A = apportion(v, w, v.parent.A || siblings[0]); + } + function secondWalk(v) { + v._.x = v.z + v.parent.m; + v.m += v.parent.m; + } + function apportion(v, w, ancestor) { + if (w) { + var vip = v, vop = v, vim = w, vom = vip.parent.children[0], sip = vip.m, sop = vop.m, sim = vim.m, som = vom.m, shift; + while (vim = d3_layout_treeRight(vim), vip = d3_layout_treeLeft(vip), vim && vip) { + vom = d3_layout_treeLeft(vom); + vop = d3_layout_treeRight(vop); + vop.a = v; + shift = vim.z + sim - vip.z - sip + separation(vim._, vip._); + if (shift > 0) { + d3_layout_treeMove(d3_layout_treeAncestor(vim, v, ancestor), v, shift); + sip += shift; + sop += shift; + } + sim += vim.m; + sip += vip.m; + som += vom.m; + sop += vop.m; + } + if (vim && !d3_layout_treeRight(vop)) { + vop.t = vim; + vop.m += sim - sop; + } + if (vip && !d3_layout_treeLeft(vom)) { + vom.t = vip; + vom.m += sip - som; + ancestor = v; + } + } + return ancestor; + } + function sizeNode(node) { + node.x *= size[0]; + node.y = node.depth * size[1]; + } + tree.separation = function(x) { + if (!arguments.length) return separation; + separation = x; + return tree; + }; + tree.size = function(x) { + if (!arguments.length) return nodeSize ? null : size; + nodeSize = (size = x) == null ? sizeNode : null; + return tree; + }; + tree.nodeSize = function(x) { + if (!arguments.length) return nodeSize ? size : null; + nodeSize = (size = x) == null ? null : sizeNode; + return tree; + }; + return d3_layout_hierarchyRebind(tree, hierarchy); + }; + function d3_layout_treeSeparation(a, b) { + return a.parent == b.parent ? 1 : 2; + } + function d3_layout_treeLeft(v) { + var children = v.children; + return children.length ? children[0] : v.t; + } + function d3_layout_treeRight(v) { + var children = v.children, n; + return (n = children.length) ? 
children[n - 1] : v.t; + } + function d3_layout_treeMove(wm, wp, shift) { + var change = shift / (wp.i - wm.i); + wp.c -= change; + wp.s += shift; + wm.c += change; + wp.z += shift; + wp.m += shift; + } + function d3_layout_treeShift(v) { + var shift = 0, change = 0, children = v.children, i = children.length, w; + while (--i >= 0) { + w = children[i]; + w.z += shift; + w.m += shift; + shift += w.s + (change += w.c); + } + } + function d3_layout_treeAncestor(vim, v, ancestor) { + return vim.a.parent === v.parent ? vim.a : ancestor; + } + d3.layout.cluster = function() { + var hierarchy = d3.layout.hierarchy().sort(null).value(null), separation = d3_layout_treeSeparation, size = [ 1, 1 ], nodeSize = false; + function cluster(d, i) { + var nodes = hierarchy.call(this, d, i), root = nodes[0], previousNode, x = 0; + d3_layout_hierarchyVisitAfter(root, function(node) { + var children = node.children; + if (children && children.length) { + node.x = d3_layout_clusterX(children); + node.y = d3_layout_clusterY(children); + } else { + node.x = previousNode ? x += separation(node, previousNode) : 0; + node.y = 0; + previousNode = node; + } + }); + var left = d3_layout_clusterLeft(root), right = d3_layout_clusterRight(root), x0 = left.x - separation(left, right) / 2, x1 = right.x + separation(right, left) / 2; + d3_layout_hierarchyVisitAfter(root, nodeSize ? function(node) { + node.x = (node.x - root.x) * size[0]; + node.y = (root.y - node.y) * size[1]; + } : function(node) { + node.x = (node.x - x0) / (x1 - x0) * size[0]; + node.y = (1 - (root.y ? node.y / root.y : 1)) * size[1]; + }); + return nodes; + } + cluster.separation = function(x) { + if (!arguments.length) return separation; + separation = x; + return cluster; + }; + cluster.size = function(x) { + if (!arguments.length) return nodeSize ? null : size; + nodeSize = (size = x) == null; + return cluster; + }; + cluster.nodeSize = function(x) { + if (!arguments.length) return nodeSize ? size : null; + nodeSize = (size = x) != null; + return cluster; + }; + return d3_layout_hierarchyRebind(cluster, hierarchy); + }; + function d3_layout_clusterY(children) { + return 1 + d3.max(children, function(child) { + return child.y; + }); + } + function d3_layout_clusterX(children) { + return children.reduce(function(x, child) { + return x + child.x; + }, 0) / children.length; + } + function d3_layout_clusterLeft(node) { + var children = node.children; + return children && children.length ? d3_layout_clusterLeft(children[0]) : node; + } + function d3_layout_clusterRight(node) { + var children = node.children, n; + return children && (n = children.length) ? d3_layout_clusterRight(children[n - 1]) : node; + } + d3.layout.treemap = function() { + var hierarchy = d3.layout.hierarchy(), round = Math.round, size = [ 1, 1 ], padding = null, pad = d3_layout_treemapPadNull, sticky = false, stickies, mode = "squarify", ratio = .5 * (1 + Math.sqrt(5)); + function scale(children, k) { + var i = -1, n = children.length, child, area; + while (++i < n) { + area = (child = children[i]).value * (k < 0 ? 0 : k); + child.area = isNaN(area) || area <= 0 ? 0 : area; + } + } + function squarify(node) { + var children = node.children; + if (children && children.length) { + var rect = pad(node), row = [], remaining = children.slice(), child, best = Infinity, score, u = mode === "slice" ? rect.dx : mode === "dice" ? rect.dy : mode === "slice-dice" ? node.depth & 1 ? 
rect.dy : rect.dx : Math.min(rect.dx, rect.dy), n; + scale(remaining, rect.dx * rect.dy / node.value); + row.area = 0; + while ((n = remaining.length) > 0) { + row.push(child = remaining[n - 1]); + row.area += child.area; + if (mode !== "squarify" || (score = worst(row, u)) <= best) { + remaining.pop(); + best = score; + } else { + row.area -= row.pop().area; + position(row, u, rect, false); + u = Math.min(rect.dx, rect.dy); + row.length = row.area = 0; + best = Infinity; + } + } + if (row.length) { + position(row, u, rect, true); + row.length = row.area = 0; + } + children.forEach(squarify); + } + } + function stickify(node) { + var children = node.children; + if (children && children.length) { + var rect = pad(node), remaining = children.slice(), child, row = []; + scale(remaining, rect.dx * rect.dy / node.value); + row.area = 0; + while (child = remaining.pop()) { + row.push(child); + row.area += child.area; + if (child.z != null) { + position(row, child.z ? rect.dx : rect.dy, rect, !remaining.length); + row.length = row.area = 0; + } + } + children.forEach(stickify); + } + } + function worst(row, u) { + var s = row.area, r, rmax = 0, rmin = Infinity, i = -1, n = row.length; + while (++i < n) { + if (!(r = row[i].area)) continue; + if (r < rmin) rmin = r; + if (r > rmax) rmax = r; + } + s *= s; + u *= u; + return s ? Math.max(u * rmax * ratio / s, s / (u * rmin * ratio)) : Infinity; + } + function position(row, u, rect, flush) { + var i = -1, n = row.length, x = rect.x, y = rect.y, v = u ? round(row.area / u) : 0, o; + if (u == rect.dx) { + if (flush || v > rect.dy) v = rect.dy; + while (++i < n) { + o = row[i]; + o.x = x; + o.y = y; + o.dy = v; + x += o.dx = Math.min(rect.x + rect.dx - x, v ? round(o.area / v) : 0); + } + o.z = true; + o.dx += rect.x + rect.dx - x; + rect.y += v; + rect.dy -= v; + } else { + if (flush || v > rect.dx) v = rect.dx; + while (++i < n) { + o = row[i]; + o.x = x; + o.y = y; + o.dx = v; + y += o.dy = Math.min(rect.y + rect.dy - y, v ? round(o.area / v) : 0); + } + o.z = false; + o.dy += rect.y + rect.dy - y; + rect.x += v; + rect.dx -= v; + } + } + function treemap(d) { + var nodes = stickies || hierarchy(d), root = nodes[0]; + root.x = root.y = 0; + if (root.value) root.dx = size[0], root.dy = size[1]; else root.dx = root.dy = 0; + if (stickies) hierarchy.revalue(root); + scale([ root ], root.dx * root.dy / root.value); + (stickies ? stickify : squarify)(root); + if (sticky) stickies = nodes; + return nodes; + } + treemap.size = function(x) { + if (!arguments.length) return size; + size = x; + return treemap; + }; + treemap.padding = function(x) { + if (!arguments.length) return padding; + function padFunction(node) { + var p = x.call(treemap, node, node.depth); + return p == null ? d3_layout_treemapPadNull(node) : d3_layout_treemapPad(node, typeof p === "number" ? [ p, p, p, p ] : p); + } + function padConstant(node) { + return d3_layout_treemapPad(node, x); + } + var type; + pad = (padding = x) == null ? d3_layout_treemapPadNull : (type = typeof x) === "function" ? padFunction : type === "number" ? (x = [ x, x, x, x ], + padConstant) : padConstant; + return treemap; + }; + treemap.round = function(x) { + if (!arguments.length) return round != Number; + round = x ? 
Math.round : Number; + return treemap; + }; + treemap.sticky = function(x) { + if (!arguments.length) return sticky; + sticky = x; + stickies = null; + return treemap; + }; + treemap.ratio = function(x) { + if (!arguments.length) return ratio; + ratio = x; + return treemap; + }; + treemap.mode = function(x) { + if (!arguments.length) return mode; + mode = x + ""; + return treemap; + }; + return d3_layout_hierarchyRebind(treemap, hierarchy); + }; + function d3_layout_treemapPadNull(node) { + return { + x: node.x, + y: node.y, + dx: node.dx, + dy: node.dy + }; + } + function d3_layout_treemapPad(node, padding) { + var x = node.x + padding[3], y = node.y + padding[0], dx = node.dx - padding[1] - padding[3], dy = node.dy - padding[0] - padding[2]; + if (dx < 0) { + x += dx / 2; + dx = 0; + } + if (dy < 0) { + y += dy / 2; + dy = 0; + } + return { + x: x, + y: y, + dx: dx, + dy: dy + }; + } + d3.random = { + normal: function(µ, σ) { + var n = arguments.length; + if (n < 2) σ = 1; + if (n < 1) µ = 0; + return function() { + var x, y, r; + do { + x = Math.random() * 2 - 1; + y = Math.random() * 2 - 1; + r = x * x + y * y; + } while (!r || r > 1); + return µ + σ * x * Math.sqrt(-2 * Math.log(r) / r); + }; + }, + logNormal: function() { + var random = d3.random.normal.apply(d3, arguments); + return function() { + return Math.exp(random()); + }; + }, + bates: function(m) { + var random = d3.random.irwinHall(m); + return function() { + return random() / m; + }; + }, + irwinHall: function(m) { + return function() { + for (var s = 0, j = 0; j < m; j++) s += Math.random(); + return s; + }; + } + }; + d3.scale = {}; + function d3_scaleExtent(domain) { + var start = domain[0], stop = domain[domain.length - 1]; + return start < stop ? [ start, stop ] : [ stop, start ]; + } + function d3_scaleRange(scale) { + return scale.rangeExtent ? scale.rangeExtent() : d3_scaleExtent(scale.range()); + } + function d3_scale_bilinear(domain, range, uninterpolate, interpolate) { + var u = uninterpolate(domain[0], domain[1]), i = interpolate(range[0], range[1]); + return function(x) { + return i(u(x)); + }; + } + function d3_scale_nice(domain, nice) { + var i0 = 0, i1 = domain.length - 1, x0 = domain[i0], x1 = domain[i1], dx; + if (x1 < x0) { + dx = i0, i0 = i1, i1 = dx; + dx = x0, x0 = x1, x1 = dx; + } + domain[i0] = nice.floor(x0); + domain[i1] = nice.ceil(x1); + return domain; + } + function d3_scale_niceStep(step) { + return step ? { + floor: function(x) { + return Math.floor(x / step) * step; + }, + ceil: function(x) { + return Math.ceil(x / step) * step; + } + } : d3_scale_niceIdentity; + } + var d3_scale_niceIdentity = { + floor: d3_identity, + ceil: d3_identity + }; + function d3_scale_polylinear(domain, range, uninterpolate, interpolate) { + var u = [], i = [], j = 0, k = Math.min(domain.length, range.length) - 1; + if (domain[k] < domain[0]) { + domain = domain.slice().reverse(); + range = range.slice().reverse(); + } + while (++j <= k) { + u.push(uninterpolate(domain[j - 1], domain[j])); + i.push(interpolate(range[j - 1], range[j])); + } + return function(x) { + var j = d3.bisect(domain, x, 1, k) - 1; + return i[j](u[j](x)); + }; + } + d3.scale.linear = function() { + return d3_scale_linear([ 0, 1 ], [ 0, 1 ], d3_interpolate, false); + }; + function d3_scale_linear(domain, range, interpolate, clamp) { + var output, input; + function rescale() { + var linear = Math.min(domain.length, range.length) > 2 ? d3_scale_polylinear : d3_scale_bilinear, uninterpolate = clamp ? 
d3_uninterpolateClamp : d3_uninterpolateNumber; + output = linear(domain, range, uninterpolate, interpolate); + input = linear(range, domain, uninterpolate, d3_interpolate); + return scale; + } + function scale(x) { + return output(x); + } + scale.invert = function(y) { + return input(y); + }; + scale.domain = function(x) { + if (!arguments.length) return domain; + domain = x.map(Number); + return rescale(); + }; + scale.range = function(x) { + if (!arguments.length) return range; + range = x; + return rescale(); + }; + scale.rangeRound = function(x) { + return scale.range(x).interpolate(d3_interpolateRound); + }; + scale.clamp = function(x) { + if (!arguments.length) return clamp; + clamp = x; + return rescale(); + }; + scale.interpolate = function(x) { + if (!arguments.length) return interpolate; + interpolate = x; + return rescale(); + }; + scale.ticks = function(m) { + return d3_scale_linearTicks(domain, m); + }; + scale.tickFormat = function(m, format) { + return d3_scale_linearTickFormat(domain, m, format); + }; + scale.nice = function(m) { + d3_scale_linearNice(domain, m); + return rescale(); + }; + scale.copy = function() { + return d3_scale_linear(domain, range, interpolate, clamp); + }; + return rescale(); + } + function d3_scale_linearRebind(scale, linear) { + return d3.rebind(scale, linear, "range", "rangeRound", "interpolate", "clamp"); + } + function d3_scale_linearNice(domain, m) { + d3_scale_nice(domain, d3_scale_niceStep(d3_scale_linearTickRange(domain, m)[2])); + d3_scale_nice(domain, d3_scale_niceStep(d3_scale_linearTickRange(domain, m)[2])); + return domain; + } + function d3_scale_linearTickRange(domain, m) { + if (m == null) m = 10; + var extent = d3_scaleExtent(domain), span = extent[1] - extent[0], step = Math.pow(10, Math.floor(Math.log(span / m) / Math.LN10)), err = m / span * step; + if (err <= .15) step *= 10; else if (err <= .35) step *= 5; else if (err <= .75) step *= 2; + extent[0] = Math.ceil(extent[0] / step) * step; + extent[1] = Math.floor(extent[1] / step) * step + step * .5; + extent[2] = step; + return extent; + } + function d3_scale_linearTicks(domain, m) { + return d3.range.apply(d3, d3_scale_linearTickRange(domain, m)); + } + function d3_scale_linearTickFormat(domain, m, format) { + var range = d3_scale_linearTickRange(domain, m); + if (format) { + var match = d3_format_re.exec(format); + match.shift(); + if (match[8] === "s") { + var prefix = d3.formatPrefix(Math.max(abs(range[0]), abs(range[1]))); + if (!match[7]) match[7] = "." + d3_scale_linearPrecision(prefix.scale(range[2])); + match[8] = "f"; + format = d3.format(match.join("")); + return function(d) { + return format(prefix.scale(d)) + prefix.symbol; + }; + } + if (!match[7]) match[7] = "." + d3_scale_linearFormatPrecision(match[8], range); + format = match.join(""); + } else { + format = ",." + d3_scale_linearPrecision(range[2]) + "f"; + } + return d3.format(format); + } + var d3_scale_linearFormatSignificant = { + s: 1, + g: 1, + p: 1, + r: 1, + e: 1 + }; + function d3_scale_linearPrecision(value) { + return -Math.floor(Math.log(value) / Math.LN10 + .01); + } + function d3_scale_linearFormatPrecision(type, range) { + var p = d3_scale_linearPrecision(range[2]); + return type in d3_scale_linearFormatSignificant ? 
Math.abs(p - d3_scale_linearPrecision(Math.max(abs(range[0]), abs(range[1])))) + +(type !== "e") : p - (type === "%") * 2; + } + d3.scale.log = function() { + return d3_scale_log(d3.scale.linear().domain([ 0, 1 ]), 10, true, [ 1, 10 ]); + }; + function d3_scale_log(linear, base, positive, domain) { + function log(x) { + return (positive ? Math.log(x < 0 ? 0 : x) : -Math.log(x > 0 ? 0 : -x)) / Math.log(base); + } + function pow(x) { + return positive ? Math.pow(base, x) : -Math.pow(base, -x); + } + function scale(x) { + return linear(log(x)); + } + scale.invert = function(x) { + return pow(linear.invert(x)); + }; + scale.domain = function(x) { + if (!arguments.length) return domain; + positive = x[0] >= 0; + linear.domain((domain = x.map(Number)).map(log)); + return scale; + }; + scale.base = function(_) { + if (!arguments.length) return base; + base = +_; + linear.domain(domain.map(log)); + return scale; + }; + scale.nice = function() { + var niced = d3_scale_nice(domain.map(log), positive ? Math : d3_scale_logNiceNegative); + linear.domain(niced); + domain = niced.map(pow); + return scale; + }; + scale.ticks = function() { + var extent = d3_scaleExtent(domain), ticks = [], u = extent[0], v = extent[1], i = Math.floor(log(u)), j = Math.ceil(log(v)), n = base % 1 ? 2 : base; + if (isFinite(j - i)) { + if (positive) { + for (;i < j; i++) for (var k = 1; k < n; k++) ticks.push(pow(i) * k); + ticks.push(pow(i)); + } else { + ticks.push(pow(i)); + for (;i++ < j; ) for (var k = n - 1; k > 0; k--) ticks.push(pow(i) * k); + } + for (i = 0; ticks[i] < u; i++) {} + for (j = ticks.length; ticks[j - 1] > v; j--) {} + ticks = ticks.slice(i, j); + } + return ticks; + }; + scale.tickFormat = function(n, format) { + if (!arguments.length) return d3_scale_logFormat; + if (arguments.length < 2) format = d3_scale_logFormat; else if (typeof format !== "function") format = d3.format(format); + var k = Math.max(1, base * n / scale.ticks().length); + return function(d) { + var i = d / pow(Math.round(log(d))); + if (i * base < base - .5) i *= base; + return i <= k ? 
format(d) : ""; + }; + }; + scale.copy = function() { + return d3_scale_log(linear.copy(), base, positive, domain); + }; + return d3_scale_linearRebind(scale, linear); + } + var d3_scale_logFormat = d3.format(".0e"), d3_scale_logNiceNegative = { + floor: function(x) { + return -Math.ceil(-x); + }, + ceil: function(x) { + return -Math.floor(-x); + } + }; + d3.scale.pow = function() { + return d3_scale_pow(d3.scale.linear(), 1, [ 0, 1 ]); + }; + function d3_scale_pow(linear, exponent, domain) { + var powp = d3_scale_powPow(exponent), powb = d3_scale_powPow(1 / exponent); + function scale(x) { + return linear(powp(x)); + } + scale.invert = function(x) { + return powb(linear.invert(x)); + }; + scale.domain = function(x) { + if (!arguments.length) return domain; + linear.domain((domain = x.map(Number)).map(powp)); + return scale; + }; + scale.ticks = function(m) { + return d3_scale_linearTicks(domain, m); + }; + scale.tickFormat = function(m, format) { + return d3_scale_linearTickFormat(domain, m, format); + }; + scale.nice = function(m) { + return scale.domain(d3_scale_linearNice(domain, m)); + }; + scale.exponent = function(x) { + if (!arguments.length) return exponent; + powp = d3_scale_powPow(exponent = x); + powb = d3_scale_powPow(1 / exponent); + linear.domain(domain.map(powp)); + return scale; + }; + scale.copy = function() { + return d3_scale_pow(linear.copy(), exponent, domain); + }; + return d3_scale_linearRebind(scale, linear); + } + function d3_scale_powPow(e) { + return function(x) { + return x < 0 ? -Math.pow(-x, e) : Math.pow(x, e); + }; + } + d3.scale.sqrt = function() { + return d3.scale.pow().exponent(.5); + }; + d3.scale.ordinal = function() { + return d3_scale_ordinal([], { + t: "range", + a: [ [] ] + }); + }; + function d3_scale_ordinal(domain, ranger) { + var index, range, rangeBand; + function scale(x) { + return range[((index.get(x) || (ranger.t === "range" ? index.set(x, domain.push(x)) : NaN)) - 1) % range.length]; + } + function steps(start, step) { + return d3.range(domain.length).map(function(i) { + return start + step * i; + }); + } + scale.domain = function(x) { + if (!arguments.length) return domain; + domain = []; + index = new d3_Map(); + var i = -1, n = x.length, xi; + while (++i < n) if (!index.has(xi = x[i])) index.set(xi, domain.push(xi)); + return scale[ranger.t].apply(scale, ranger.a); + }; + scale.range = function(x) { + if (!arguments.length) return range; + range = x; + rangeBand = 0; + ranger = { + t: "range", + a: arguments + }; + return scale; + }; + scale.rangePoints = function(x, padding) { + if (arguments.length < 2) padding = 0; + var start = x[0], stop = x[1], step = domain.length < 2 ? (start = (start + stop) / 2, + 0) : (stop - start) / (domain.length - 1 + padding); + range = steps(start + step * padding / 2, step); + rangeBand = 0; + ranger = { + t: "rangePoints", + a: arguments + }; + return scale; + }; + scale.rangeRoundPoints = function(x, padding) { + if (arguments.length < 2) padding = 0; + var start = x[0], stop = x[1], step = domain.length < 2 ? 
(start = stop = Math.round((start + stop) / 2), + 0) : (stop - start) / (domain.length - 1 + padding) | 0; + range = steps(start + Math.round(step * padding / 2 + (stop - start - (domain.length - 1 + padding) * step) / 2), step); + rangeBand = 0; + ranger = { + t: "rangeRoundPoints", + a: arguments + }; + return scale; + }; + scale.rangeBands = function(x, padding, outerPadding) { + if (arguments.length < 2) padding = 0; + if (arguments.length < 3) outerPadding = padding; + var reverse = x[1] < x[0], start = x[reverse - 0], stop = x[1 - reverse], step = (stop - start) / (domain.length - padding + 2 * outerPadding); + range = steps(start + step * outerPadding, step); + if (reverse) range.reverse(); + rangeBand = step * (1 - padding); + ranger = { + t: "rangeBands", + a: arguments + }; + return scale; + }; + scale.rangeRoundBands = function(x, padding, outerPadding) { + if (arguments.length < 2) padding = 0; + if (arguments.length < 3) outerPadding = padding; + var reverse = x[1] < x[0], start = x[reverse - 0], stop = x[1 - reverse], step = Math.floor((stop - start) / (domain.length - padding + 2 * outerPadding)); + range = steps(start + Math.round((stop - start - (domain.length - padding) * step) / 2), step); + if (reverse) range.reverse(); + rangeBand = Math.round(step * (1 - padding)); + ranger = { + t: "rangeRoundBands", + a: arguments + }; + return scale; + }; + scale.rangeBand = function() { + return rangeBand; + }; + scale.rangeExtent = function() { + return d3_scaleExtent(ranger.a[0]); + }; + scale.copy = function() { + return d3_scale_ordinal(domain, ranger); + }; + return scale.domain(domain); + } + d3.scale.category10 = function() { + return d3.scale.ordinal().range(d3_category10); + }; + d3.scale.category20 = function() { + return d3.scale.ordinal().range(d3_category20); + }; + d3.scale.category20b = function() { + return d3.scale.ordinal().range(d3_category20b); + }; + d3.scale.category20c = function() { + return d3.scale.ordinal().range(d3_category20c); + }; + var d3_category10 = [ 2062260, 16744206, 2924588, 14034728, 9725885, 9197131, 14907330, 8355711, 12369186, 1556175 ].map(d3_rgbString); + var d3_category20 = [ 2062260, 11454440, 16744206, 16759672, 2924588, 10018698, 14034728, 16750742, 9725885, 12955861, 9197131, 12885140, 14907330, 16234194, 8355711, 13092807, 12369186, 14408589, 1556175, 10410725 ].map(d3_rgbString); + var d3_category20b = [ 3750777, 5395619, 7040719, 10264286, 6519097, 9216594, 11915115, 13556636, 9202993, 12426809, 15186514, 15190932, 8666169, 11356490, 14049643, 15177372, 8077683, 10834324, 13528509, 14589654 ].map(d3_rgbString); + var d3_category20c = [ 3244733, 7057110, 10406625, 13032431, 15095053, 16616764, 16625259, 16634018, 3253076, 7652470, 10607003, 13101504, 7695281, 10394312, 12369372, 14342891, 6513507, 9868950, 12434877, 14277081 ].map(d3_rgbString); + d3.scale.quantile = function() { + return d3_scale_quantile([], []); + }; + function d3_scale_quantile(domain, range) { + var thresholds; + function rescale() { + var k = 0, q = range.length; + thresholds = []; + while (++k < q) thresholds[k - 1] = d3.quantile(domain, k / q); + return scale; + } + function scale(x) { + if (!isNaN(x = +x)) return range[d3.bisect(thresholds, x)]; + } + scale.domain = function(x) { + if (!arguments.length) return domain; + domain = x.map(d3_number).filter(d3_numeric).sort(d3_ascending); + return rescale(); + }; + scale.range = function(x) { + if (!arguments.length) return range; + range = x; + return rescale(); + }; + scale.quantiles = function() { + return 
thresholds; + }; + scale.invertExtent = function(y) { + y = range.indexOf(y); + return y < 0 ? [ NaN, NaN ] : [ y > 0 ? thresholds[y - 1] : domain[0], y < thresholds.length ? thresholds[y] : domain[domain.length - 1] ]; + }; + scale.copy = function() { + return d3_scale_quantile(domain, range); + }; + return rescale(); + } + d3.scale.quantize = function() { + return d3_scale_quantize(0, 1, [ 0, 1 ]); + }; + function d3_scale_quantize(x0, x1, range) { + var kx, i; + function scale(x) { + return range[Math.max(0, Math.min(i, Math.floor(kx * (x - x0))))]; + } + function rescale() { + kx = range.length / (x1 - x0); + i = range.length - 1; + return scale; + } + scale.domain = function(x) { + if (!arguments.length) return [ x0, x1 ]; + x0 = +x[0]; + x1 = +x[x.length - 1]; + return rescale(); + }; + scale.range = function(x) { + if (!arguments.length) return range; + range = x; + return rescale(); + }; + scale.invertExtent = function(y) { + y = range.indexOf(y); + y = y < 0 ? NaN : y / kx + x0; + return [ y, y + 1 / kx ]; + }; + scale.copy = function() { + return d3_scale_quantize(x0, x1, range); + }; + return rescale(); + } + d3.scale.threshold = function() { + return d3_scale_threshold([ .5 ], [ 0, 1 ]); + }; + function d3_scale_threshold(domain, range) { + function scale(x) { + if (x <= x) return range[d3.bisect(domain, x)]; + } + scale.domain = function(_) { + if (!arguments.length) return domain; + domain = _; + return scale; + }; + scale.range = function(_) { + if (!arguments.length) return range; + range = _; + return scale; + }; + scale.invertExtent = function(y) { + y = range.indexOf(y); + return [ domain[y - 1], domain[y] ]; + }; + scale.copy = function() { + return d3_scale_threshold(domain, range); + }; + return scale; + } + d3.scale.identity = function() { + return d3_scale_identity([ 0, 1 ]); + }; + function d3_scale_identity(domain) { + function identity(x) { + return +x; + } + identity.invert = identity; + identity.domain = identity.range = function(x) { + if (!arguments.length) return domain; + domain = x.map(identity); + return identity; + }; + identity.ticks = function(m) { + return d3_scale_linearTicks(domain, m); + }; + identity.tickFormat = function(m, format) { + return d3_scale_linearTickFormat(domain, m, format); + }; + identity.copy = function() { + return d3_scale_identity(domain); + }; + return identity; + } + d3.svg = {}; + function d3_zero() { + return 0; + } + d3.svg.arc = function() { + var innerRadius = d3_svg_arcInnerRadius, outerRadius = d3_svg_arcOuterRadius, cornerRadius = d3_zero, padRadius = d3_svg_arcAuto, startAngle = d3_svg_arcStartAngle, endAngle = d3_svg_arcEndAngle, padAngle = d3_svg_arcPadAngle; + function arc() { + var r0 = Math.max(0, +innerRadius.apply(this, arguments)), r1 = Math.max(0, +outerRadius.apply(this, arguments)), a0 = startAngle.apply(this, arguments) - halfπ, a1 = endAngle.apply(this, arguments) - halfπ, da = Math.abs(a1 - a0), cw = a0 > a1 ? 0 : 1; + if (r1 < r0) rc = r1, r1 = r0, r0 = rc; + if (da >= τε) return circleSegment(r1, cw) + (r0 ? circleSegment(r0, 1 - cw) : "") + "Z"; + var rc, cr, rp, ap, p0 = 0, p1 = 0, x0, y0, x1, y1, x2, y2, x3, y3, path = []; + if (ap = (+padAngle.apply(this, arguments) || 0) / 2) { + rp = padRadius === d3_svg_arcAuto ? 
Math.sqrt(r0 * r0 + r1 * r1) : +padRadius.apply(this, arguments); + if (!cw) p1 *= -1; + if (r1) p1 = d3_asin(rp / r1 * Math.sin(ap)); + if (r0) p0 = d3_asin(rp / r0 * Math.sin(ap)); + } + if (r1) { + x0 = r1 * Math.cos(a0 + p1); + y0 = r1 * Math.sin(a0 + p1); + x1 = r1 * Math.cos(a1 - p1); + y1 = r1 * Math.sin(a1 - p1); + var l1 = Math.abs(a1 - a0 - 2 * p1) <= π ? 0 : 1; + if (p1 && d3_svg_arcSweep(x0, y0, x1, y1) === cw ^ l1) { + var h1 = (a0 + a1) / 2; + x0 = r1 * Math.cos(h1); + y0 = r1 * Math.sin(h1); + x1 = y1 = null; + } + } else { + x0 = y0 = 0; + } + if (r0) { + x2 = r0 * Math.cos(a1 - p0); + y2 = r0 * Math.sin(a1 - p0); + x3 = r0 * Math.cos(a0 + p0); + y3 = r0 * Math.sin(a0 + p0); + var l0 = Math.abs(a0 - a1 + 2 * p0) <= π ? 0 : 1; + if (p0 && d3_svg_arcSweep(x2, y2, x3, y3) === 1 - cw ^ l0) { + var h0 = (a0 + a1) / 2; + x2 = r0 * Math.cos(h0); + y2 = r0 * Math.sin(h0); + x3 = y3 = null; + } + } else { + x2 = y2 = 0; + } + if (da > ε && (rc = Math.min(Math.abs(r1 - r0) / 2, +cornerRadius.apply(this, arguments))) > .001) { + cr = r0 < r1 ^ cw ? 0 : 1; + var rc1 = rc, rc0 = rc; + if (da < π) { + var oc = x3 == null ? [ x2, y2 ] : x1 == null ? [ x0, y0 ] : d3_geom_polygonIntersect([ x0, y0 ], [ x3, y3 ], [ x1, y1 ], [ x2, y2 ]), ax = x0 - oc[0], ay = y0 - oc[1], bx = x1 - oc[0], by = y1 - oc[1], kc = 1 / Math.sin(Math.acos((ax * bx + ay * by) / (Math.sqrt(ax * ax + ay * ay) * Math.sqrt(bx * bx + by * by))) / 2), lc = Math.sqrt(oc[0] * oc[0] + oc[1] * oc[1]); + rc0 = Math.min(rc, (r0 - lc) / (kc - 1)); + rc1 = Math.min(rc, (r1 - lc) / (kc + 1)); + } + if (x1 != null) { + var t30 = d3_svg_arcCornerTangents(x3 == null ? [ x2, y2 ] : [ x3, y3 ], [ x0, y0 ], r1, rc1, cw), t12 = d3_svg_arcCornerTangents([ x1, y1 ], [ x2, y2 ], r1, rc1, cw); + if (rc === rc1) { + path.push("M", t30[0], "A", rc1, ",", rc1, " 0 0,", cr, " ", t30[1], "A", r1, ",", r1, " 0 ", 1 - cw ^ d3_svg_arcSweep(t30[1][0], t30[1][1], t12[1][0], t12[1][1]), ",", cw, " ", t12[1], "A", rc1, ",", rc1, " 0 0,", cr, " ", t12[0]); + } else { + path.push("M", t30[0], "A", rc1, ",", rc1, " 0 1,", cr, " ", t12[0]); + } + } else { + path.push("M", x0, ",", y0); + } + if (x3 != null) { + var t03 = d3_svg_arcCornerTangents([ x0, y0 ], [ x3, y3 ], r0, -rc0, cw), t21 = d3_svg_arcCornerTangents([ x2, y2 ], x1 == null ? 
[ x0, y0 ] : [ x1, y1 ], r0, -rc0, cw); + if (rc === rc0) { + path.push("L", t21[0], "A", rc0, ",", rc0, " 0 0,", cr, " ", t21[1], "A", r0, ",", r0, " 0 ", cw ^ d3_svg_arcSweep(t21[1][0], t21[1][1], t03[1][0], t03[1][1]), ",", 1 - cw, " ", t03[1], "A", rc0, ",", rc0, " 0 0,", cr, " ", t03[0]); + } else { + path.push("L", t21[0], "A", rc0, ",", rc0, " 0 0,", cr, " ", t03[0]); + } + } else { + path.push("L", x2, ",", y2); + } + } else { + path.push("M", x0, ",", y0); + if (x1 != null) path.push("A", r1, ",", r1, " 0 ", l1, ",", cw, " ", x1, ",", y1); + path.push("L", x2, ",", y2); + if (x3 != null) path.push("A", r0, ",", r0, " 0 ", l0, ",", 1 - cw, " ", x3, ",", y3); + } + path.push("Z"); + return path.join(""); + } + function circleSegment(r1, cw) { + return "M0," + r1 + "A" + r1 + "," + r1 + " 0 1," + cw + " 0," + -r1 + "A" + r1 + "," + r1 + " 0 1," + cw + " 0," + r1; + } + arc.innerRadius = function(v) { + if (!arguments.length) return innerRadius; + innerRadius = d3_functor(v); + return arc; + }; + arc.outerRadius = function(v) { + if (!arguments.length) return outerRadius; + outerRadius = d3_functor(v); + return arc; + }; + arc.cornerRadius = function(v) { + if (!arguments.length) return cornerRadius; + cornerRadius = d3_functor(v); + return arc; + }; + arc.padRadius = function(v) { + if (!arguments.length) return padRadius; + padRadius = v == d3_svg_arcAuto ? d3_svg_arcAuto : d3_functor(v); + return arc; + }; + arc.startAngle = function(v) { + if (!arguments.length) return startAngle; + startAngle = d3_functor(v); + return arc; + }; + arc.endAngle = function(v) { + if (!arguments.length) return endAngle; + endAngle = d3_functor(v); + return arc; + }; + arc.padAngle = function(v) { + if (!arguments.length) return padAngle; + padAngle = d3_functor(v); + return arc; + }; + arc.centroid = function() { + var r = (+innerRadius.apply(this, arguments) + +outerRadius.apply(this, arguments)) / 2, a = (+startAngle.apply(this, arguments) + +endAngle.apply(this, arguments)) / 2 - halfπ; + return [ Math.cos(a) * r, Math.sin(a) * r ]; + }; + return arc; + }; + var d3_svg_arcAuto = "auto"; + function d3_svg_arcInnerRadius(d) { + return d.innerRadius; + } + function d3_svg_arcOuterRadius(d) { + return d.outerRadius; + } + function d3_svg_arcStartAngle(d) { + return d.startAngle; + } + function d3_svg_arcEndAngle(d) { + return d.endAngle; + } + function d3_svg_arcPadAngle(d) { + return d && d.padAngle; + } + function d3_svg_arcSweep(x0, y0, x1, y1) { + return (x0 - x1) * y0 - (y0 - y1) * x0 > 0 ? 0 : 1; + } + function d3_svg_arcCornerTangents(p0, p1, r1, rc, cw) { + var x01 = p0[0] - p1[0], y01 = p0[1] - p1[1], lo = (cw ? rc : -rc) / Math.sqrt(x01 * x01 + y01 * y01), ox = lo * y01, oy = -lo * x01, x1 = p0[0] + ox, y1 = p0[1] + oy, x2 = p1[0] + ox, y2 = p1[1] + oy, x3 = (x1 + x2) / 2, y3 = (y1 + y2) / 2, dx = x2 - x1, dy = y2 - y1, d2 = dx * dx + dy * dy, r = r1 - rc, D = x1 * y2 - x2 * y1, d = (dy < 0 ? 
-1 : 1) * Math.sqrt(Math.max(0, r * r * d2 - D * D)), cx0 = (D * dy - dx * d) / d2, cy0 = (-D * dx - dy * d) / d2, cx1 = (D * dy + dx * d) / d2, cy1 = (-D * dx + dy * d) / d2, dx0 = cx0 - x3, dy0 = cy0 - y3, dx1 = cx1 - x3, dy1 = cy1 - y3; + if (dx0 * dx0 + dy0 * dy0 > dx1 * dx1 + dy1 * dy1) cx0 = cx1, cy0 = cy1; + return [ [ cx0 - ox, cy0 - oy ], [ cx0 * r1 / r, cy0 * r1 / r ] ]; + } + function d3_svg_line(projection) { + var x = d3_geom_pointX, y = d3_geom_pointY, defined = d3_true, interpolate = d3_svg_lineLinear, interpolateKey = interpolate.key, tension = .7; + function line(data) { + var segments = [], points = [], i = -1, n = data.length, d, fx = d3_functor(x), fy = d3_functor(y); + function segment() { + segments.push("M", interpolate(projection(points), tension)); + } + while (++i < n) { + if (defined.call(this, d = data[i], i)) { + points.push([ +fx.call(this, d, i), +fy.call(this, d, i) ]); + } else if (points.length) { + segment(); + points = []; + } + } + if (points.length) segment(); + return segments.length ? segments.join("") : null; + } + line.x = function(_) { + if (!arguments.length) return x; + x = _; + return line; + }; + line.y = function(_) { + if (!arguments.length) return y; + y = _; + return line; + }; + line.defined = function(_) { + if (!arguments.length) return defined; + defined = _; + return line; + }; + line.interpolate = function(_) { + if (!arguments.length) return interpolateKey; + if (typeof _ === "function") interpolateKey = interpolate = _; else interpolateKey = (interpolate = d3_svg_lineInterpolators.get(_) || d3_svg_lineLinear).key; + return line; + }; + line.tension = function(_) { + if (!arguments.length) return tension; + tension = _; + return line; + }; + return line; + } + d3.svg.line = function() { + return d3_svg_line(d3_identity); + }; + var d3_svg_lineInterpolators = d3.map({ + linear: d3_svg_lineLinear, + "linear-closed": d3_svg_lineLinearClosed, + step: d3_svg_lineStep, + "step-before": d3_svg_lineStepBefore, + "step-after": d3_svg_lineStepAfter, + basis: d3_svg_lineBasis, + "basis-open": d3_svg_lineBasisOpen, + "basis-closed": d3_svg_lineBasisClosed, + bundle: d3_svg_lineBundle, + cardinal: d3_svg_lineCardinal, + "cardinal-open": d3_svg_lineCardinalOpen, + "cardinal-closed": d3_svg_lineCardinalClosed, + monotone: d3_svg_lineMonotone + }); + d3_svg_lineInterpolators.forEach(function(key, value) { + value.key = key; + value.closed = /-closed$/.test(key); + }); + function d3_svg_lineLinear(points) { + return points.length > 1 ? points.join("L") : points + "Z"; + } + function d3_svg_lineLinearClosed(points) { + return points.join("L") + "Z"; + } + function d3_svg_lineStep(points) { + var i = 0, n = points.length, p = points[0], path = [ p[0], ",", p[1] ]; + while (++i < n) path.push("H", (p[0] + (p = points[i])[0]) / 2, "V", p[1]); + if (n > 1) path.push("H", p[0]); + return path.join(""); + } + function d3_svg_lineStepBefore(points) { + var i = 0, n = points.length, p = points[0], path = [ p[0], ",", p[1] ]; + while (++i < n) path.push("V", (p = points[i])[1], "H", p[0]); + return path.join(""); + } + function d3_svg_lineStepAfter(points) { + var i = 0, n = points.length, p = points[0], path = [ p[0], ",", p[1] ]; + while (++i < n) path.push("H", (p = points[i])[0], "V", p[1]); + return path.join(""); + } + function d3_svg_lineCardinalOpen(points, tension) { + return points.length < 4 ? 
d3_svg_lineLinear(points) : points[1] + d3_svg_lineHermite(points.slice(1, -1), d3_svg_lineCardinalTangents(points, tension)); + } + function d3_svg_lineCardinalClosed(points, tension) { + return points.length < 3 ? d3_svg_lineLinearClosed(points) : points[0] + d3_svg_lineHermite((points.push(points[0]), + points), d3_svg_lineCardinalTangents([ points[points.length - 2] ].concat(points, [ points[1] ]), tension)); + } + function d3_svg_lineCardinal(points, tension) { + return points.length < 3 ? d3_svg_lineLinear(points) : points[0] + d3_svg_lineHermite(points, d3_svg_lineCardinalTangents(points, tension)); + } + function d3_svg_lineHermite(points, tangents) { + if (tangents.length < 1 || points.length != tangents.length && points.length != tangents.length + 2) { + return d3_svg_lineLinear(points); + } + var quad = points.length != tangents.length, path = "", p0 = points[0], p = points[1], t0 = tangents[0], t = t0, pi = 1; + if (quad) { + path += "Q" + (p[0] - t0[0] * 2 / 3) + "," + (p[1] - t0[1] * 2 / 3) + "," + p[0] + "," + p[1]; + p0 = points[1]; + pi = 2; + } + if (tangents.length > 1) { + t = tangents[1]; + p = points[pi]; + pi++; + path += "C" + (p0[0] + t0[0]) + "," + (p0[1] + t0[1]) + "," + (p[0] - t[0]) + "," + (p[1] - t[1]) + "," + p[0] + "," + p[1]; + for (var i = 2; i < tangents.length; i++, pi++) { + p = points[pi]; + t = tangents[i]; + path += "S" + (p[0] - t[0]) + "," + (p[1] - t[1]) + "," + p[0] + "," + p[1]; + } + } + if (quad) { + var lp = points[pi]; + path += "Q" + (p[0] + t[0] * 2 / 3) + "," + (p[1] + t[1] * 2 / 3) + "," + lp[0] + "," + lp[1]; + } + return path; + } + function d3_svg_lineCardinalTangents(points, tension) { + var tangents = [], a = (1 - tension) / 2, p0, p1 = points[0], p2 = points[1], i = 1, n = points.length; + while (++i < n) { + p0 = p1; + p1 = p2; + p2 = points[i]; + tangents.push([ a * (p2[0] - p0[0]), a * (p2[1] - p0[1]) ]); + } + return tangents; + } + function d3_svg_lineBasis(points) { + if (points.length < 3) return d3_svg_lineLinear(points); + var i = 1, n = points.length, pi = points[0], x0 = pi[0], y0 = pi[1], px = [ x0, x0, x0, (pi = points[1])[0] ], py = [ y0, y0, y0, pi[1] ], path = [ x0, ",", y0, "L", d3_svg_lineDot4(d3_svg_lineBasisBezier3, px), ",", d3_svg_lineDot4(d3_svg_lineBasisBezier3, py) ]; + points.push(points[n - 1]); + while (++i <= n) { + pi = points[i]; + px.shift(); + px.push(pi[0]); + py.shift(); + py.push(pi[1]); + d3_svg_lineBasisBezier(path, px, py); + } + points.pop(); + path.push("L", pi); + return path.join(""); + } + function d3_svg_lineBasisOpen(points) { + if (points.length < 4) return d3_svg_lineLinear(points); + var path = [], i = -1, n = points.length, pi, px = [ 0 ], py = [ 0 ]; + while (++i < 3) { + pi = points[i]; + px.push(pi[0]); + py.push(pi[1]); + } + path.push(d3_svg_lineDot4(d3_svg_lineBasisBezier3, px) + "," + d3_svg_lineDot4(d3_svg_lineBasisBezier3, py)); + --i; + while (++i < n) { + pi = points[i]; + px.shift(); + px.push(pi[0]); + py.shift(); + py.push(pi[1]); + d3_svg_lineBasisBezier(path, px, py); + } + return path.join(""); + } + function d3_svg_lineBasisClosed(points) { + var path, i = -1, n = points.length, m = n + 4, pi, px = [], py = []; + while (++i < 4) { + pi = points[i % n]; + px.push(pi[0]); + py.push(pi[1]); + } + path = [ d3_svg_lineDot4(d3_svg_lineBasisBezier3, px), ",", d3_svg_lineDot4(d3_svg_lineBasisBezier3, py) ]; + --i; + while (++i < m) { + pi = points[i % n]; + px.shift(); + px.push(pi[0]); + py.shift(); + py.push(pi[1]); + d3_svg_lineBasisBezier(path, px, py); + } + return 
path.join(""); + } + function d3_svg_lineBundle(points, tension) { + var n = points.length - 1; + if (n) { + var x0 = points[0][0], y0 = points[0][1], dx = points[n][0] - x0, dy = points[n][1] - y0, i = -1, p, t; + while (++i <= n) { + p = points[i]; + t = i / n; + p[0] = tension * p[0] + (1 - tension) * (x0 + t * dx); + p[1] = tension * p[1] + (1 - tension) * (y0 + t * dy); + } + } + return d3_svg_lineBasis(points); + } + function d3_svg_lineDot4(a, b) { + return a[0] * b[0] + a[1] * b[1] + a[2] * b[2] + a[3] * b[3]; + } + var d3_svg_lineBasisBezier1 = [ 0, 2 / 3, 1 / 3, 0 ], d3_svg_lineBasisBezier2 = [ 0, 1 / 3, 2 / 3, 0 ], d3_svg_lineBasisBezier3 = [ 0, 1 / 6, 2 / 3, 1 / 6 ]; + function d3_svg_lineBasisBezier(path, x, y) { + path.push("C", d3_svg_lineDot4(d3_svg_lineBasisBezier1, x), ",", d3_svg_lineDot4(d3_svg_lineBasisBezier1, y), ",", d3_svg_lineDot4(d3_svg_lineBasisBezier2, x), ",", d3_svg_lineDot4(d3_svg_lineBasisBezier2, y), ",", d3_svg_lineDot4(d3_svg_lineBasisBezier3, x), ",", d3_svg_lineDot4(d3_svg_lineBasisBezier3, y)); + } + function d3_svg_lineSlope(p0, p1) { + return (p1[1] - p0[1]) / (p1[0] - p0[0]); + } + function d3_svg_lineFiniteDifferences(points) { + var i = 0, j = points.length - 1, m = [], p0 = points[0], p1 = points[1], d = m[0] = d3_svg_lineSlope(p0, p1); + while (++i < j) { + m[i] = (d + (d = d3_svg_lineSlope(p0 = p1, p1 = points[i + 1]))) / 2; + } + m[i] = d; + return m; + } + function d3_svg_lineMonotoneTangents(points) { + var tangents = [], d, a, b, s, m = d3_svg_lineFiniteDifferences(points), i = -1, j = points.length - 1; + while (++i < j) { + d = d3_svg_lineSlope(points[i], points[i + 1]); + if (abs(d) < ε) { + m[i] = m[i + 1] = 0; + } else { + a = m[i] / d; + b = m[i + 1] / d; + s = a * a + b * b; + if (s > 9) { + s = d * 3 / Math.sqrt(s); + m[i] = s * a; + m[i + 1] = s * b; + } + } + } + i = -1; + while (++i <= j) { + s = (points[Math.min(j, i + 1)][0] - points[Math.max(0, i - 1)][0]) / (6 * (1 + m[i] * m[i])); + tangents.push([ s || 0, m[i] * s || 0 ]); + } + return tangents; + } + function d3_svg_lineMonotone(points) { + return points.length < 3 ? d3_svg_lineLinear(points) : points[0] + d3_svg_lineHermite(points, d3_svg_lineMonotoneTangents(points)); + } + d3.svg.line.radial = function() { + var line = d3_svg_line(d3_svg_lineRadial); + line.radius = line.x, delete line.x; + line.angle = line.y, delete line.y; + return line; + }; + function d3_svg_lineRadial(points) { + var point, i = -1, n = points.length, r, a; + while (++i < n) { + point = points[i]; + r = point[0]; + a = point[1] - halfπ; + point[0] = r * Math.cos(a); + point[1] = r * Math.sin(a); + } + return points; + } + function d3_svg_area(projection) { + var x0 = d3_geom_pointX, x1 = d3_geom_pointX, y0 = 0, y1 = d3_geom_pointY, defined = d3_true, interpolate = d3_svg_lineLinear, interpolateKey = interpolate.key, interpolateReverse = interpolate, L = "L", tension = .7; + function area(data) { + var segments = [], points0 = [], points1 = [], i = -1, n = data.length, d, fx0 = d3_functor(x0), fy0 = d3_functor(y0), fx1 = x0 === x1 ? function() { + return x; + } : d3_functor(x1), fy1 = y0 === y1 ? 
function() { + return y; + } : d3_functor(y1), x, y; + function segment() { + segments.push("M", interpolate(projection(points1), tension), L, interpolateReverse(projection(points0.reverse()), tension), "Z"); + } + while (++i < n) { + if (defined.call(this, d = data[i], i)) { + points0.push([ x = +fx0.call(this, d, i), y = +fy0.call(this, d, i) ]); + points1.push([ +fx1.call(this, d, i), +fy1.call(this, d, i) ]); + } else if (points0.length) { + segment(); + points0 = []; + points1 = []; + } + } + if (points0.length) segment(); + return segments.length ? segments.join("") : null; + } + area.x = function(_) { + if (!arguments.length) return x1; + x0 = x1 = _; + return area; + }; + area.x0 = function(_) { + if (!arguments.length) return x0; + x0 = _; + return area; + }; + area.x1 = function(_) { + if (!arguments.length) return x1; + x1 = _; + return area; + }; + area.y = function(_) { + if (!arguments.length) return y1; + y0 = y1 = _; + return area; + }; + area.y0 = function(_) { + if (!arguments.length) return y0; + y0 = _; + return area; + }; + area.y1 = function(_) { + if (!arguments.length) return y1; + y1 = _; + return area; + }; + area.defined = function(_) { + if (!arguments.length) return defined; + defined = _; + return area; + }; + area.interpolate = function(_) { + if (!arguments.length) return interpolateKey; + if (typeof _ === "function") interpolateKey = interpolate = _; else interpolateKey = (interpolate = d3_svg_lineInterpolators.get(_) || d3_svg_lineLinear).key; + interpolateReverse = interpolate.reverse || interpolate; + L = interpolate.closed ? "M" : "L"; + return area; + }; + area.tension = function(_) { + if (!arguments.length) return tension; + tension = _; + return area; + }; + return area; + } + d3_svg_lineStepBefore.reverse = d3_svg_lineStepAfter; + d3_svg_lineStepAfter.reverse = d3_svg_lineStepBefore; + d3.svg.area = function() { + return d3_svg_area(d3_identity); + }; + d3.svg.area.radial = function() { + var area = d3_svg_area(d3_svg_lineRadial); + area.radius = area.x, delete area.x; + area.innerRadius = area.x0, delete area.x0; + area.outerRadius = area.x1, delete area.x1; + area.angle = area.y, delete area.y; + area.startAngle = area.y0, delete area.y0; + area.endAngle = area.y1, delete area.y1; + return area; + }; + d3.svg.chord = function() { + var source = d3_source, target = d3_target, radius = d3_svg_chordRadius, startAngle = d3_svg_arcStartAngle, endAngle = d3_svg_arcEndAngle; + function chord(d, i) { + var s = subgroup(this, source, d, i), t = subgroup(this, target, d, i); + return "M" + s.p0 + arc(s.r, s.p1, s.a1 - s.a0) + (equals(s, t) ? 
curve(s.r, s.p1, s.r, s.p0) : curve(s.r, s.p1, t.r, t.p0) + arc(t.r, t.p1, t.a1 - t.a0) + curve(t.r, t.p1, s.r, s.p0)) + "Z"; + } + function subgroup(self, f, d, i) { + var subgroup = f.call(self, d, i), r = radius.call(self, subgroup, i), a0 = startAngle.call(self, subgroup, i) - halfπ, a1 = endAngle.call(self, subgroup, i) - halfπ; + return { + r: r, + a0: a0, + a1: a1, + p0: [ r * Math.cos(a0), r * Math.sin(a0) ], + p1: [ r * Math.cos(a1), r * Math.sin(a1) ] + }; + } + function equals(a, b) { + return a.a0 == b.a0 && a.a1 == b.a1; + } + function arc(r, p, a) { + return "A" + r + "," + r + " 0 " + +(a > π) + ",1 " + p; + } + function curve(r0, p0, r1, p1) { + return "Q 0,0 " + p1; + } + chord.radius = function(v) { + if (!arguments.length) return radius; + radius = d3_functor(v); + return chord; + }; + chord.source = function(v) { + if (!arguments.length) return source; + source = d3_functor(v); + return chord; + }; + chord.target = function(v) { + if (!arguments.length) return target; + target = d3_functor(v); + return chord; + }; + chord.startAngle = function(v) { + if (!arguments.length) return startAngle; + startAngle = d3_functor(v); + return chord; + }; + chord.endAngle = function(v) { + if (!arguments.length) return endAngle; + endAngle = d3_functor(v); + return chord; + }; + return chord; + }; + function d3_svg_chordRadius(d) { + return d.radius; + } + d3.svg.diagonal = function() { + var source = d3_source, target = d3_target, projection = d3_svg_diagonalProjection; + function diagonal(d, i) { + var p0 = source.call(this, d, i), p3 = target.call(this, d, i), m = (p0.y + p3.y) / 2, p = [ p0, { + x: p0.x, + y: m + }, { + x: p3.x, + y: m + }, p3 ]; + p = p.map(projection); + return "M" + p[0] + "C" + p[1] + " " + p[2] + " " + p[3]; + } + diagonal.source = function(x) { + if (!arguments.length) return source; + source = d3_functor(x); + return diagonal; + }; + diagonal.target = function(x) { + if (!arguments.length) return target; + target = d3_functor(x); + return diagonal; + }; + diagonal.projection = function(x) { + if (!arguments.length) return projection; + projection = x; + return diagonal; + }; + return diagonal; + }; + function d3_svg_diagonalProjection(d) { + return [ d.x, d.y ]; + } + d3.svg.diagonal.radial = function() { + var diagonal = d3.svg.diagonal(), projection = d3_svg_diagonalProjection, projection_ = diagonal.projection; + diagonal.projection = function(x) { + return arguments.length ? 
projection_(d3_svg_diagonalRadialProjection(projection = x)) : projection; + }; + return diagonal; + }; + function d3_svg_diagonalRadialProjection(projection) { + return function() { + var d = projection.apply(this, arguments), r = d[0], a = d[1] - halfπ; + return [ r * Math.cos(a), r * Math.sin(a) ]; + }; + } + d3.svg.symbol = function() { + var type = d3_svg_symbolType, size = d3_svg_symbolSize; + function symbol(d, i) { + return (d3_svg_symbols.get(type.call(this, d, i)) || d3_svg_symbolCircle)(size.call(this, d, i)); + } + symbol.type = function(x) { + if (!arguments.length) return type; + type = d3_functor(x); + return symbol; + }; + symbol.size = function(x) { + if (!arguments.length) return size; + size = d3_functor(x); + return symbol; + }; + return symbol; + }; + function d3_svg_symbolSize() { + return 64; + } + function d3_svg_symbolType() { + return "circle"; + } + function d3_svg_symbolCircle(size) { + var r = Math.sqrt(size / π); + return "M0," + r + "A" + r + "," + r + " 0 1,1 0," + -r + "A" + r + "," + r + " 0 1,1 0," + r + "Z"; + } + var d3_svg_symbols = d3.map({ + circle: d3_svg_symbolCircle, + cross: function(size) { + var r = Math.sqrt(size / 5) / 2; + return "M" + -3 * r + "," + -r + "H" + -r + "V" + -3 * r + "H" + r + "V" + -r + "H" + 3 * r + "V" + r + "H" + r + "V" + 3 * r + "H" + -r + "V" + r + "H" + -3 * r + "Z"; + }, + diamond: function(size) { + var ry = Math.sqrt(size / (2 * d3_svg_symbolTan30)), rx = ry * d3_svg_symbolTan30; + return "M0," + -ry + "L" + rx + ",0" + " 0," + ry + " " + -rx + ",0" + "Z"; + }, + square: function(size) { + var r = Math.sqrt(size) / 2; + return "M" + -r + "," + -r + "L" + r + "," + -r + " " + r + "," + r + " " + -r + "," + r + "Z"; + }, + "triangle-down": function(size) { + var rx = Math.sqrt(size / d3_svg_symbolSqrt3), ry = rx * d3_svg_symbolSqrt3 / 2; + return "M0," + ry + "L" + rx + "," + -ry + " " + -rx + "," + -ry + "Z"; + }, + "triangle-up": function(size) { + var rx = Math.sqrt(size / d3_svg_symbolSqrt3), ry = rx * d3_svg_symbolSqrt3 / 2; + return "M0," + -ry + "L" + rx + "," + ry + " " + -rx + "," + ry + "Z"; + } + }); + d3.svg.symbolTypes = d3_svg_symbols.keys(); + var d3_svg_symbolSqrt3 = Math.sqrt(3), d3_svg_symbolTan30 = Math.tan(30 * d3_radians); + d3_selectionPrototype.transition = function(name) { + var id = d3_transitionInheritId || ++d3_transitionId, ns = d3_transitionNamespace(name), subgroups = [], subgroup, node, transition = d3_transitionInherit || { + time: Date.now(), + ease: d3_ease_cubicInOut, + delay: 0, + duration: 250 + }; + for (var j = -1, m = this.length; ++j < m; ) { + subgroups.push(subgroup = []); + for (var group = this[j], i = -1, n = group.length; ++i < n; ) { + if (node = group[i]) d3_transitionNode(node, i, ns, id, transition); + subgroup.push(node); + } + } + return d3_transition(subgroups, ns, id); + }; + d3_selectionPrototype.interrupt = function(name) { + return this.each(name == null ? 
d3_selection_interrupt : d3_selection_interruptNS(d3_transitionNamespace(name))); + }; + var d3_selection_interrupt = d3_selection_interruptNS(d3_transitionNamespace()); + function d3_selection_interruptNS(ns) { + return function() { + var lock, activeId, active; + if ((lock = this[ns]) && (active = lock[activeId = lock.active])) { + active.timer.c = null; + active.timer.t = NaN; + if (--lock.count) delete lock[activeId]; else delete this[ns]; + lock.active += .5; + active.event && active.event.interrupt.call(this, this.__data__, active.index); + } + }; + } + function d3_transition(groups, ns, id) { + d3_subclass(groups, d3_transitionPrototype); + groups.namespace = ns; + groups.id = id; + return groups; + } + var d3_transitionPrototype = [], d3_transitionId = 0, d3_transitionInheritId, d3_transitionInherit; + d3_transitionPrototype.call = d3_selectionPrototype.call; + d3_transitionPrototype.empty = d3_selectionPrototype.empty; + d3_transitionPrototype.node = d3_selectionPrototype.node; + d3_transitionPrototype.size = d3_selectionPrototype.size; + d3.transition = function(selection, name) { + return selection && selection.transition ? d3_transitionInheritId ? selection.transition(name) : selection : d3.selection().transition(selection); + }; + d3.transition.prototype = d3_transitionPrototype; + d3_transitionPrototype.select = function(selector) { + var id = this.id, ns = this.namespace, subgroups = [], subgroup, subnode, node; + selector = d3_selection_selector(selector); + for (var j = -1, m = this.length; ++j < m; ) { + subgroups.push(subgroup = []); + for (var group = this[j], i = -1, n = group.length; ++i < n; ) { + if ((node = group[i]) && (subnode = selector.call(node, node.__data__, i, j))) { + if ("__data__" in node) subnode.__data__ = node.__data__; + d3_transitionNode(subnode, i, ns, id, node[ns][id]); + subgroup.push(subnode); + } else { + subgroup.push(null); + } + } + } + return d3_transition(subgroups, ns, id); + }; + d3_transitionPrototype.selectAll = function(selector) { + var id = this.id, ns = this.namespace, subgroups = [], subgroup, subnodes, node, subnode, transition; + selector = d3_selection_selectorAll(selector); + for (var j = -1, m = this.length; ++j < m; ) { + for (var group = this[j], i = -1, n = group.length; ++i < n; ) { + if (node = group[i]) { + transition = node[ns][id]; + subnodes = selector.call(node, node.__data__, i, j); + subgroups.push(subgroup = []); + for (var k = -1, o = subnodes.length; ++k < o; ) { + if (subnode = subnodes[k]) d3_transitionNode(subnode, k, ns, id, transition); + subgroup.push(subnode); + } + } + } + } + return d3_transition(subgroups, ns, id); + }; + d3_transitionPrototype.filter = function(filter) { + var subgroups = [], subgroup, group, node; + if (typeof filter !== "function") filter = d3_selection_filter(filter); + for (var j = 0, m = this.length; j < m; j++) { + subgroups.push(subgroup = []); + for (var group = this[j], i = 0, n = group.length; i < n; i++) { + if ((node = group[i]) && filter.call(node, node.__data__, i, j)) { + subgroup.push(node); + } + } + } + return d3_transition(subgroups, this.namespace, this.id); + }; + d3_transitionPrototype.tween = function(name, tween) { + var id = this.id, ns = this.namespace; + if (arguments.length < 2) return this.node()[ns][id].tween.get(name); + return d3_selection_each(this, tween == null ? 
function(node) { + node[ns][id].tween.remove(name); + } : function(node) { + node[ns][id].tween.set(name, tween); + }); + }; + function d3_transition_tween(groups, name, value, tween) { + var id = groups.id, ns = groups.namespace; + return d3_selection_each(groups, typeof value === "function" ? function(node, i, j) { + node[ns][id].tween.set(name, tween(value.call(node, node.__data__, i, j))); + } : (value = tween(value), function(node) { + node[ns][id].tween.set(name, value); + })); + } + d3_transitionPrototype.attr = function(nameNS, value) { + if (arguments.length < 2) { + for (value in nameNS) this.attr(value, nameNS[value]); + return this; + } + var interpolate = nameNS == "transform" ? d3_interpolateTransform : d3_interpolate, name = d3.ns.qualify(nameNS); + function attrNull() { + this.removeAttribute(name); + } + function attrNullNS() { + this.removeAttributeNS(name.space, name.local); + } + function attrTween(b) { + return b == null ? attrNull : (b += "", function() { + var a = this.getAttribute(name), i; + return a !== b && (i = interpolate(a, b), function(t) { + this.setAttribute(name, i(t)); + }); + }); + } + function attrTweenNS(b) { + return b == null ? attrNullNS : (b += "", function() { + var a = this.getAttributeNS(name.space, name.local), i; + return a !== b && (i = interpolate(a, b), function(t) { + this.setAttributeNS(name.space, name.local, i(t)); + }); + }); + } + return d3_transition_tween(this, "attr." + nameNS, value, name.local ? attrTweenNS : attrTween); + }; + d3_transitionPrototype.attrTween = function(nameNS, tween) { + var name = d3.ns.qualify(nameNS); + function attrTween(d, i) { + var f = tween.call(this, d, i, this.getAttribute(name)); + return f && function(t) { + this.setAttribute(name, f(t)); + }; + } + function attrTweenNS(d, i) { + var f = tween.call(this, d, i, this.getAttributeNS(name.space, name.local)); + return f && function(t) { + this.setAttributeNS(name.space, name.local, f(t)); + }; + } + return this.tween("attr." + nameNS, name.local ? attrTweenNS : attrTween); + }; + d3_transitionPrototype.style = function(name, value, priority) { + var n = arguments.length; + if (n < 3) { + if (typeof name !== "string") { + if (n < 2) value = ""; + for (priority in name) this.style(priority, name[priority], value); + return this; + } + priority = ""; + } + function styleNull() { + this.style.removeProperty(name); + } + function styleString(b) { + return b == null ? styleNull : (b += "", function() { + var a = d3_window(this).getComputedStyle(this, null).getPropertyValue(name), i; + return a !== b && (i = d3_interpolate(a, b), function(t) { + this.style.setProperty(name, i(t), priority); + }); + }); + } + return d3_transition_tween(this, "style." + name, value, styleString); + }; + d3_transitionPrototype.styleTween = function(name, tween, priority) { + if (arguments.length < 3) priority = ""; + function styleTween(d, i) { + var f = tween.call(this, d, i, d3_window(this).getComputedStyle(this, null).getPropertyValue(name)); + return f && function(t) { + this.style.setProperty(name, f(t), priority); + }; + } + return this.tween("style." 
+ name, styleTween); + }; + d3_transitionPrototype.text = function(value) { + return d3_transition_tween(this, "text", value, d3_transition_text); + }; + function d3_transition_text(b) { + if (b == null) b = ""; + return function() { + this.textContent = b; + }; + } + d3_transitionPrototype.remove = function() { + var ns = this.namespace; + return this.each("end.transition", function() { + var p; + if (this[ns].count < 2 && (p = this.parentNode)) p.removeChild(this); + }); + }; + d3_transitionPrototype.ease = function(value) { + var id = this.id, ns = this.namespace; + if (arguments.length < 1) return this.node()[ns][id].ease; + if (typeof value !== "function") value = d3.ease.apply(d3, arguments); + return d3_selection_each(this, function(node) { + node[ns][id].ease = value; + }); + }; + d3_transitionPrototype.delay = function(value) { + var id = this.id, ns = this.namespace; + if (arguments.length < 1) return this.node()[ns][id].delay; + return d3_selection_each(this, typeof value === "function" ? function(node, i, j) { + node[ns][id].delay = +value.call(node, node.__data__, i, j); + } : (value = +value, function(node) { + node[ns][id].delay = value; + })); + }; + d3_transitionPrototype.duration = function(value) { + var id = this.id, ns = this.namespace; + if (arguments.length < 1) return this.node()[ns][id].duration; + return d3_selection_each(this, typeof value === "function" ? function(node, i, j) { + node[ns][id].duration = Math.max(1, value.call(node, node.__data__, i, j)); + } : (value = Math.max(1, value), function(node) { + node[ns][id].duration = value; + })); + }; + d3_transitionPrototype.each = function(type, listener) { + var id = this.id, ns = this.namespace; + if (arguments.length < 2) { + var inherit = d3_transitionInherit, inheritId = d3_transitionInheritId; + try { + d3_transitionInheritId = id; + d3_selection_each(this, function(node, i, j) { + d3_transitionInherit = node[ns][id]; + type.call(node, node.__data__, i, j); + }); + } finally { + d3_transitionInherit = inherit; + d3_transitionInheritId = inheritId; + } + } else { + d3_selection_each(this, function(node) { + var transition = node[ns][id]; + (transition.event || (transition.event = d3.dispatch("start", "end", "interrupt"))).on(type, listener); + }); + } + return this; + }; + d3_transitionPrototype.transition = function() { + var id0 = this.id, id1 = ++d3_transitionId, ns = this.namespace, subgroups = [], subgroup, group, node, transition; + for (var j = 0, m = this.length; j < m; j++) { + subgroups.push(subgroup = []); + for (var group = this[j], i = 0, n = group.length; i < n; i++) { + if (node = group[i]) { + transition = node[ns][id0]; + d3_transitionNode(node, i, ns, id1, { + time: transition.time, + ease: transition.ease, + delay: transition.delay + transition.duration, + duration: transition.duration + }); + } + subgroup.push(node); + } + } + return d3_transition(subgroups, ns, id1); + }; + function d3_transitionNamespace(name) { + return name == null ? 
"__transition__" : "__transition_" + name + "__"; + } + function d3_transitionNode(node, i, ns, id, inherit) { + var lock = node[ns] || (node[ns] = { + active: 0, + count: 0 + }), transition = lock[id], time, timer, duration, ease, tweens; + function schedule(elapsed) { + var delay = transition.delay; + timer.t = delay + time; + if (delay <= elapsed) return start(elapsed - delay); + timer.c = start; + } + function start(elapsed) { + var activeId = lock.active, active = lock[activeId]; + if (active) { + active.timer.c = null; + active.timer.t = NaN; + --lock.count; + delete lock[activeId]; + active.event && active.event.interrupt.call(node, node.__data__, active.index); + } + for (var cancelId in lock) { + if (+cancelId < id) { + var cancel = lock[cancelId]; + cancel.timer.c = null; + cancel.timer.t = NaN; + --lock.count; + delete lock[cancelId]; + } + } + timer.c = tick; + d3_timer(function() { + if (timer.c && tick(elapsed || 1)) { + timer.c = null; + timer.t = NaN; + } + return 1; + }, 0, time); + lock.active = id; + transition.event && transition.event.start.call(node, node.__data__, i); + tweens = []; + transition.tween.forEach(function(key, value) { + if (value = value.call(node, node.__data__, i)) { + tweens.push(value); + } + }); + ease = transition.ease; + duration = transition.duration; + } + function tick(elapsed) { + var t = elapsed / duration, e = ease(t), n = tweens.length; + while (n > 0) { + tweens[--n].call(node, e); + } + if (t >= 1) { + transition.event && transition.event.end.call(node, node.__data__, i); + if (--lock.count) delete lock[id]; else delete node[ns]; + return 1; + } + } + if (!transition) { + time = inherit.time; + timer = d3_timer(schedule, 0, time); + transition = lock[id] = { + tween: new d3_Map(), + time: time, + timer: timer, + delay: inherit.delay, + duration: inherit.duration, + ease: inherit.ease, + index: i + }; + inherit = null; + ++lock.count; + } + } + d3.svg.axis = function() { + var scale = d3.scale.linear(), orient = d3_svg_axisDefaultOrient, innerTickSize = 6, outerTickSize = 6, tickPadding = 3, tickArguments_ = [ 10 ], tickValues = null, tickFormat_; + function axis(g) { + g.each(function() { + var g = d3.select(this); + var scale0 = this.__chart__ || scale, scale1 = this.__chart__ = scale.copy(); + var ticks = tickValues == null ? scale1.ticks ? scale1.ticks.apply(scale1, tickArguments_) : scale1.domain() : tickValues, tickFormat = tickFormat_ == null ? scale1.tickFormat ? scale1.tickFormat.apply(scale1, tickArguments_) : d3_identity : tickFormat_, tick = g.selectAll(".tick").data(ticks, scale1), tickEnter = tick.enter().insert("g", ".domain").attr("class", "tick").style("opacity", ε), tickExit = d3.transition(tick.exit()).style("opacity", ε).remove(), tickUpdate = d3.transition(tick.order()).style("opacity", 1), tickSpacing = Math.max(innerTickSize, 0) + tickPadding, tickTransform; + var range = d3_scaleRange(scale1), path = g.selectAll(".domain").data([ 0 ]), pathUpdate = (path.enter().append("path").attr("class", "domain"), + d3.transition(path)); + tickEnter.append("line"); + tickEnter.append("text"); + var lineEnter = tickEnter.select("line"), lineUpdate = tickUpdate.select("line"), text = tick.select("text").text(tickFormat), textEnter = tickEnter.select("text"), textUpdate = tickUpdate.select("text"), sign = orient === "top" || orient === "left" ? -1 : 1, x1, x2, y1, y2; + if (orient === "bottom" || orient === "top") { + tickTransform = d3_svg_axisX, x1 = "x", y1 = "y", x2 = "x2", y2 = "y2"; + text.attr("dy", sign < 0 ? 
"0em" : ".71em").style("text-anchor", "middle"); + pathUpdate.attr("d", "M" + range[0] + "," + sign * outerTickSize + "V0H" + range[1] + "V" + sign * outerTickSize); + } else { + tickTransform = d3_svg_axisY, x1 = "y", y1 = "x", x2 = "y2", y2 = "x2"; + text.attr("dy", ".32em").style("text-anchor", sign < 0 ? "end" : "start"); + pathUpdate.attr("d", "M" + sign * outerTickSize + "," + range[0] + "H0V" + range[1] + "H" + sign * outerTickSize); + } + lineEnter.attr(y2, sign * innerTickSize); + textEnter.attr(y1, sign * tickSpacing); + lineUpdate.attr(x2, 0).attr(y2, sign * innerTickSize); + textUpdate.attr(x1, 0).attr(y1, sign * tickSpacing); + if (scale1.rangeBand) { + var x = scale1, dx = x.rangeBand() / 2; + scale0 = scale1 = function(d) { + return x(d) + dx; + }; + } else if (scale0.rangeBand) { + scale0 = scale1; + } else { + tickExit.call(tickTransform, scale1, scale0); + } + tickEnter.call(tickTransform, scale0, scale1); + tickUpdate.call(tickTransform, scale1, scale1); + }); + } + axis.scale = function(x) { + if (!arguments.length) return scale; + scale = x; + return axis; + }; + axis.orient = function(x) { + if (!arguments.length) return orient; + orient = x in d3_svg_axisOrients ? x + "" : d3_svg_axisDefaultOrient; + return axis; + }; + axis.ticks = function() { + if (!arguments.length) return tickArguments_; + tickArguments_ = d3_array(arguments); + return axis; + }; + axis.tickValues = function(x) { + if (!arguments.length) return tickValues; + tickValues = x; + return axis; + }; + axis.tickFormat = function(x) { + if (!arguments.length) return tickFormat_; + tickFormat_ = x; + return axis; + }; + axis.tickSize = function(x) { + var n = arguments.length; + if (!n) return innerTickSize; + innerTickSize = +x; + outerTickSize = +arguments[n - 1]; + return axis; + }; + axis.innerTickSize = function(x) { + if (!arguments.length) return innerTickSize; + innerTickSize = +x; + return axis; + }; + axis.outerTickSize = function(x) { + if (!arguments.length) return outerTickSize; + outerTickSize = +x; + return axis; + }; + axis.tickPadding = function(x) { + if (!arguments.length) return tickPadding; + tickPadding = +x; + return axis; + }; + axis.tickSubdivide = function() { + return arguments.length && axis; + }; + return axis; + }; + var d3_svg_axisDefaultOrient = "bottom", d3_svg_axisOrients = { + top: 1, + right: 1, + bottom: 1, + left: 1 + }; + function d3_svg_axisX(selection, x0, x1) { + selection.attr("transform", function(d) { + var v0 = x0(d); + return "translate(" + (isFinite(v0) ? v0 : x1(d)) + ",0)"; + }); + } + function d3_svg_axisY(selection, y0, y1) { + selection.attr("transform", function(d) { + var v0 = y0(d); + return "translate(0," + (isFinite(v0) ? 
v0 : y1(d)) + ")"; + }); + } + d3.svg.brush = function() { + var event = d3_eventDispatch(brush, "brushstart", "brush", "brushend"), x = null, y = null, xExtent = [ 0, 0 ], yExtent = [ 0, 0 ], xExtentDomain, yExtentDomain, xClamp = true, yClamp = true, resizes = d3_svg_brushResizes[0]; + function brush(g) { + g.each(function() { + var g = d3.select(this).style("pointer-events", "all").style("-webkit-tap-highlight-color", "rgba(0,0,0,0)").on("mousedown.brush", brushstart).on("touchstart.brush", brushstart); + var background = g.selectAll(".background").data([ 0 ]); + background.enter().append("rect").attr("class", "background").style("visibility", "hidden").style("cursor", "crosshair"); + g.selectAll(".extent").data([ 0 ]).enter().append("rect").attr("class", "extent").style("cursor", "move"); + var resize = g.selectAll(".resize").data(resizes, d3_identity); + resize.exit().remove(); + resize.enter().append("g").attr("class", function(d) { + return "resize " + d; + }).style("cursor", function(d) { + return d3_svg_brushCursor[d]; + }).append("rect").attr("x", function(d) { + return /[ew]$/.test(d) ? -3 : null; + }).attr("y", function(d) { + return /^[ns]/.test(d) ? -3 : null; + }).attr("width", 6).attr("height", 6).style("visibility", "hidden"); + resize.style("display", brush.empty() ? "none" : null); + var gUpdate = d3.transition(g), backgroundUpdate = d3.transition(background), range; + if (x) { + range = d3_scaleRange(x); + backgroundUpdate.attr("x", range[0]).attr("width", range[1] - range[0]); + redrawX(gUpdate); + } + if (y) { + range = d3_scaleRange(y); + backgroundUpdate.attr("y", range[0]).attr("height", range[1] - range[0]); + redrawY(gUpdate); + } + redraw(gUpdate); + }); + } + brush.event = function(g) { + g.each(function() { + var event_ = event.of(this, arguments), extent1 = { + x: xExtent, + y: yExtent, + i: xExtentDomain, + j: yExtentDomain + }, extent0 = this.__chart__ || extent1; + this.__chart__ = extent1; + if (d3_transitionInheritId) { + d3.select(this).transition().each("start.brush", function() { + xExtentDomain = extent0.i; + yExtentDomain = extent0.j; + xExtent = extent0.x; + yExtent = extent0.y; + event_({ + type: "brushstart" + }); + }).tween("brush:brush", function() { + var xi = d3_interpolateArray(xExtent, extent1.x), yi = d3_interpolateArray(yExtent, extent1.y); + xExtentDomain = yExtentDomain = null; + return function(t) { + xExtent = extent1.x = xi(t); + yExtent = extent1.y = yi(t); + event_({ + type: "brush", + mode: "resize" + }); + }; + }).each("end.brush", function() { + xExtentDomain = extent1.i; + yExtentDomain = extent1.j; + event_({ + type: "brush", + mode: "resize" + }); + event_({ + type: "brushend" + }); + }); + } else { + event_({ + type: "brushstart" + }); + event_({ + type: "brush", + mode: "resize" + }); + event_({ + type: "brushend" + }); + } + }); + }; + function redraw(g) { + g.selectAll(".resize").attr("transform", function(d) { + return "translate(" + xExtent[+/e$/.test(d)] + "," + yExtent[+/^s/.test(d)] + ")"; + }); + } + function redrawX(g) { + g.select(".extent").attr("x", xExtent[0]); + g.selectAll(".extent,.n>rect,.s>rect").attr("width", xExtent[1] - xExtent[0]); + } + function redrawY(g) { + g.select(".extent").attr("y", yExtent[0]); + g.selectAll(".extent,.e>rect,.w>rect").attr("height", yExtent[1] - yExtent[0]); + } + function brushstart() { + var target = this, eventTarget = d3.select(d3.event.target), event_ = event.of(target, arguments), g = d3.select(target), resizing = eventTarget.datum(), resizingX = 
!/^(n|s)$/.test(resizing) && x, resizingY = !/^(e|w)$/.test(resizing) && y, dragging = eventTarget.classed("extent"), dragRestore = d3_event_dragSuppress(target), center, origin = d3.mouse(target), offset; + var w = d3.select(d3_window(target)).on("keydown.brush", keydown).on("keyup.brush", keyup); + if (d3.event.changedTouches) { + w.on("touchmove.brush", brushmove).on("touchend.brush", brushend); + } else { + w.on("mousemove.brush", brushmove).on("mouseup.brush", brushend); + } + g.interrupt().selectAll("*").interrupt(); + if (dragging) { + origin[0] = xExtent[0] - origin[0]; + origin[1] = yExtent[0] - origin[1]; + } else if (resizing) { + var ex = +/w$/.test(resizing), ey = +/^n/.test(resizing); + offset = [ xExtent[1 - ex] - origin[0], yExtent[1 - ey] - origin[1] ]; + origin[0] = xExtent[ex]; + origin[1] = yExtent[ey]; + } else if (d3.event.altKey) center = origin.slice(); + g.style("pointer-events", "none").selectAll(".resize").style("display", null); + d3.select("body").style("cursor", eventTarget.style("cursor")); + event_({ + type: "brushstart" + }); + brushmove(); + function keydown() { + if (d3.event.keyCode == 32) { + if (!dragging) { + center = null; + origin[0] -= xExtent[1]; + origin[1] -= yExtent[1]; + dragging = 2; + } + d3_eventPreventDefault(); + } + } + function keyup() { + if (d3.event.keyCode == 32 && dragging == 2) { + origin[0] += xExtent[1]; + origin[1] += yExtent[1]; + dragging = 0; + d3_eventPreventDefault(); + } + } + function brushmove() { + var point = d3.mouse(target), moved = false; + if (offset) { + point[0] += offset[0]; + point[1] += offset[1]; + } + if (!dragging) { + if (d3.event.altKey) { + if (!center) center = [ (xExtent[0] + xExtent[1]) / 2, (yExtent[0] + yExtent[1]) / 2 ]; + origin[0] = xExtent[+(point[0] < center[0])]; + origin[1] = yExtent[+(point[1] < center[1])]; + } else center = null; + } + if (resizingX && move1(point, x, 0)) { + redrawX(g); + moved = true; + } + if (resizingY && move1(point, y, 1)) { + redrawY(g); + moved = true; + } + if (moved) { + redraw(g); + event_({ + type: "brush", + mode: dragging ? "move" : "resize" + }); + } + } + function move1(point, scale, i) { + var range = d3_scaleRange(scale), r0 = range[0], r1 = range[1], position = origin[i], extent = i ? yExtent : xExtent, size = extent[1] - extent[0], min, max; + if (dragging) { + r0 -= position; + r1 -= size + position; + } + min = (i ? yClamp : xClamp) ? Math.max(r0, Math.min(r1, point[i])) : point[i]; + if (dragging) { + max = (min += position) + size; + } else { + if (center) position = Math.max(r0, Math.min(r1, 2 * center[i] - min)); + if (position < min) { + max = min; + min = position; + } else { + max = position; + } + } + if (extent[0] != min || extent[1] != max) { + if (i) yExtentDomain = null; else xExtentDomain = null; + extent[0] = min; + extent[1] = max; + return true; + } + } + function brushend() { + brushmove(); + g.style("pointer-events", "all").selectAll(".resize").style("display", brush.empty() ? 
"none" : null); + d3.select("body").style("cursor", null); + w.on("mousemove.brush", null).on("mouseup.brush", null).on("touchmove.brush", null).on("touchend.brush", null).on("keydown.brush", null).on("keyup.brush", null); + dragRestore(); + event_({ + type: "brushend" + }); + } + } + brush.x = function(z) { + if (!arguments.length) return x; + x = z; + resizes = d3_svg_brushResizes[!x << 1 | !y]; + return brush; + }; + brush.y = function(z) { + if (!arguments.length) return y; + y = z; + resizes = d3_svg_brushResizes[!x << 1 | !y]; + return brush; + }; + brush.clamp = function(z) { + if (!arguments.length) return x && y ? [ xClamp, yClamp ] : x ? xClamp : y ? yClamp : null; + if (x && y) xClamp = !!z[0], yClamp = !!z[1]; else if (x) xClamp = !!z; else if (y) yClamp = !!z; + return brush; + }; + brush.extent = function(z) { + var x0, x1, y0, y1, t; + if (!arguments.length) { + if (x) { + if (xExtentDomain) { + x0 = xExtentDomain[0], x1 = xExtentDomain[1]; + } else { + x0 = xExtent[0], x1 = xExtent[1]; + if (x.invert) x0 = x.invert(x0), x1 = x.invert(x1); + if (x1 < x0) t = x0, x0 = x1, x1 = t; + } + } + if (y) { + if (yExtentDomain) { + y0 = yExtentDomain[0], y1 = yExtentDomain[1]; + } else { + y0 = yExtent[0], y1 = yExtent[1]; + if (y.invert) y0 = y.invert(y0), y1 = y.invert(y1); + if (y1 < y0) t = y0, y0 = y1, y1 = t; + } + } + return x && y ? [ [ x0, y0 ], [ x1, y1 ] ] : x ? [ x0, x1 ] : y && [ y0, y1 ]; + } + if (x) { + x0 = z[0], x1 = z[1]; + if (y) x0 = x0[0], x1 = x1[0]; + xExtentDomain = [ x0, x1 ]; + if (x.invert) x0 = x(x0), x1 = x(x1); + if (x1 < x0) t = x0, x0 = x1, x1 = t; + if (x0 != xExtent[0] || x1 != xExtent[1]) xExtent = [ x0, x1 ]; + } + if (y) { + y0 = z[0], y1 = z[1]; + if (x) y0 = y0[1], y1 = y1[1]; + yExtentDomain = [ y0, y1 ]; + if (y.invert) y0 = y(y0), y1 = y(y1); + if (y1 < y0) t = y0, y0 = y1, y1 = t; + if (y0 != yExtent[0] || y1 != yExtent[1]) yExtent = [ y0, y1 ]; + } + return brush; + }; + brush.clear = function() { + if (!brush.empty()) { + xExtent = [ 0, 0 ], yExtent = [ 0, 0 ]; + xExtentDomain = yExtentDomain = null; + } + return brush; + }; + brush.empty = function() { + return !!x && xExtent[0] == xExtent[1] || !!y && yExtent[0] == yExtent[1]; + }; + return d3.rebind(brush, event, "on"); + }; + var d3_svg_brushCursor = { + n: "ns-resize", + e: "ew-resize", + s: "ns-resize", + w: "ew-resize", + nw: "nwse-resize", + ne: "nesw-resize", + se: "nwse-resize", + sw: "nesw-resize" + }; + var d3_svg_brushResizes = [ [ "n", "e", "s", "w", "nw", "ne", "se", "sw" ], [ "e", "w" ], [ "n", "s" ], [] ]; + var d3_time_format = d3_time.format = d3_locale_enUS.timeFormat; + var d3_time_formatUtc = d3_time_format.utc; + var d3_time_formatIso = d3_time_formatUtc("%Y-%m-%dT%H:%M:%S.%LZ"); + d3_time_format.iso = Date.prototype.toISOString && +new Date("2000-01-01T00:00:00.000Z") ? d3_time_formatIsoNative : d3_time_formatIso; + function d3_time_formatIsoNative(date) { + return date.toISOString(); + } + d3_time_formatIsoNative.parse = function(string) { + var date = new Date(string); + return isNaN(date) ? 
null : date; + }; + d3_time_formatIsoNative.toString = d3_time_formatIso.toString; + d3_time.second = d3_time_interval(function(date) { + return new d3_date(Math.floor(date / 1e3) * 1e3); + }, function(date, offset) { + date.setTime(date.getTime() + Math.floor(offset) * 1e3); + }, function(date) { + return date.getSeconds(); + }); + d3_time.seconds = d3_time.second.range; + d3_time.seconds.utc = d3_time.second.utc.range; + d3_time.minute = d3_time_interval(function(date) { + return new d3_date(Math.floor(date / 6e4) * 6e4); + }, function(date, offset) { + date.setTime(date.getTime() + Math.floor(offset) * 6e4); + }, function(date) { + return date.getMinutes(); + }); + d3_time.minutes = d3_time.minute.range; + d3_time.minutes.utc = d3_time.minute.utc.range; + d3_time.hour = d3_time_interval(function(date) { + var timezone = date.getTimezoneOffset() / 60; + return new d3_date((Math.floor(date / 36e5 - timezone) + timezone) * 36e5); + }, function(date, offset) { + date.setTime(date.getTime() + Math.floor(offset) * 36e5); + }, function(date) { + return date.getHours(); + }); + d3_time.hours = d3_time.hour.range; + d3_time.hours.utc = d3_time.hour.utc.range; + d3_time.month = d3_time_interval(function(date) { + date = d3_time.day(date); + date.setDate(1); + return date; + }, function(date, offset) { + date.setMonth(date.getMonth() + offset); + }, function(date) { + return date.getMonth(); + }); + d3_time.months = d3_time.month.range; + d3_time.months.utc = d3_time.month.utc.range; + function d3_time_scale(linear, methods, format) { + function scale(x) { + return linear(x); + } + scale.invert = function(x) { + return d3_time_scaleDate(linear.invert(x)); + }; + scale.domain = function(x) { + if (!arguments.length) return linear.domain().map(d3_time_scaleDate); + linear.domain(x); + return scale; + }; + function tickMethod(extent, count) { + var span = extent[1] - extent[0], target = span / count, i = d3.bisect(d3_time_scaleSteps, target); + return i == d3_time_scaleSteps.length ? [ methods.year, d3_scale_linearTickRange(extent.map(function(d) { + return d / 31536e6; + }), count)[2] ] : !i ? [ d3_time_scaleMilliseconds, d3_scale_linearTickRange(extent, count)[2] ] : methods[target / d3_time_scaleSteps[i - 1] < d3_time_scaleSteps[i] / target ? i - 1 : i]; + } + scale.nice = function(interval, skip) { + var domain = scale.domain(), extent = d3_scaleExtent(domain), method = interval == null ? tickMethod(extent, 10) : typeof interval === "number" && tickMethod(extent, interval); + if (method) interval = method[0], skip = method[1]; + function skipped(date) { + return !isNaN(date) && !interval.range(date, d3_time_scaleDate(+date + 1), skip).length; + } + return scale.domain(d3_scale_nice(domain, skip > 1 ? { + floor: function(date) { + while (skipped(date = interval.floor(date))) date = d3_time_scaleDate(date - 1); + return date; + }, + ceil: function(date) { + while (skipped(date = interval.ceil(date))) date = d3_time_scaleDate(+date + 1); + return date; + } + } : interval)); + }; + scale.ticks = function(interval, skip) { + var extent = d3_scaleExtent(scale.domain()), method = interval == null ? tickMethod(extent, 10) : typeof interval === "number" ? tickMethod(extent, interval) : !interval.range && [ { + range: interval + }, skip ]; + if (method) interval = method[0], skip = method[1]; + return interval.range(extent[0], d3_time_scaleDate(+extent[1] + 1), skip < 1 ? 
1 : skip); + }; + scale.tickFormat = function() { + return format; + }; + scale.copy = function() { + return d3_time_scale(linear.copy(), methods, format); + }; + return d3_scale_linearRebind(scale, linear); + } + function d3_time_scaleDate(t) { + return new Date(t); + } + var d3_time_scaleSteps = [ 1e3, 5e3, 15e3, 3e4, 6e4, 3e5, 9e5, 18e5, 36e5, 108e5, 216e5, 432e5, 864e5, 1728e5, 6048e5, 2592e6, 7776e6, 31536e6 ]; + var d3_time_scaleLocalMethods = [ [ d3_time.second, 1 ], [ d3_time.second, 5 ], [ d3_time.second, 15 ], [ d3_time.second, 30 ], [ d3_time.minute, 1 ], [ d3_time.minute, 5 ], [ d3_time.minute, 15 ], [ d3_time.minute, 30 ], [ d3_time.hour, 1 ], [ d3_time.hour, 3 ], [ d3_time.hour, 6 ], [ d3_time.hour, 12 ], [ d3_time.day, 1 ], [ d3_time.day, 2 ], [ d3_time.week, 1 ], [ d3_time.month, 1 ], [ d3_time.month, 3 ], [ d3_time.year, 1 ] ]; + var d3_time_scaleLocalFormat = d3_time_format.multi([ [ ".%L", function(d) { + return d.getMilliseconds(); + } ], [ ":%S", function(d) { + return d.getSeconds(); + } ], [ "%I:%M", function(d) { + return d.getMinutes(); + } ], [ "%I %p", function(d) { + return d.getHours(); + } ], [ "%a %d", function(d) { + return d.getDay() && d.getDate() != 1; + } ], [ "%b %d", function(d) { + return d.getDate() != 1; + } ], [ "%B", function(d) { + return d.getMonth(); + } ], [ "%Y", d3_true ] ]); + var d3_time_scaleMilliseconds = { + range: function(start, stop, step) { + return d3.range(Math.ceil(start / step) * step, +stop, step).map(d3_time_scaleDate); + }, + floor: d3_identity, + ceil: d3_identity + }; + d3_time_scaleLocalMethods.year = d3_time.year; + d3_time.scale = function() { + return d3_time_scale(d3.scale.linear(), d3_time_scaleLocalMethods, d3_time_scaleLocalFormat); + }; + var d3_time_scaleUtcMethods = d3_time_scaleLocalMethods.map(function(m) { + return [ m[0].utc, m[1] ]; + }); + var d3_time_scaleUtcFormat = d3_time_formatUtc.multi([ [ ".%L", function(d) { + return d.getUTCMilliseconds(); + } ], [ ":%S", function(d) { + return d.getUTCSeconds(); + } ], [ "%I:%M", function(d) { + return d.getUTCMinutes(); + } ], [ "%I %p", function(d) { + return d.getUTCHours(); + } ], [ "%a %d", function(d) { + return d.getUTCDay() && d.getUTCDate() != 1; + } ], [ "%b %d", function(d) { + return d.getUTCDate() != 1; + } ], [ "%B", function(d) { + return d.getUTCMonth(); + } ], [ "%Y", d3_true ] ]); + d3_time_scaleUtcMethods.year = d3_time.year.utc; + d3_time.scale.utc = function() { + return d3_time_scale(d3.scale.linear(), d3_time_scaleUtcMethods, d3_time_scaleUtcFormat); + }; + d3.text = d3_xhrType(function(request) { + return request.responseText; + }); + d3.json = function(url, callback) { + return d3_xhr(url, "application/json", d3_json, callback); + }; + function d3_json(request) { + return JSON.parse(request.responseText); + } + d3.html = function(url, callback) { + return d3_xhr(url, "text/html", d3_html, callback); + }; + function d3_html(request) { + var range = d3_document.createRange(); + range.selectNode(d3_document.body); + return range.createContextualFragment(request.responseText); + } + d3.xml = d3_xhrType(function(request) { + return request.responseXML; + }); + if (true) this.d3 = d3, !(__WEBPACK_AMD_DEFINE_FACTORY__ = (d3), __WEBPACK_AMD_DEFINE_RESULT__ = (typeof __WEBPACK_AMD_DEFINE_FACTORY__ === 'function' ? 
(__WEBPACK_AMD_DEFINE_FACTORY__.call(exports, __webpack_require__, exports, module)) : __WEBPACK_AMD_DEFINE_FACTORY__), __WEBPACK_AMD_DEFINE_RESULT__ !== undefined && (module.exports = __WEBPACK_AMD_DEFINE_RESULT__)); else if (typeof module === "object" && module.exports) module.exports = d3; else this.d3 = d3; + }(); + +/***/ }), +/* 3 */ +/***/ (function(module, exports, __webpack_require__) { + + 'use strict'; + + Object.defineProperty(exports, "__esModule", { + value: true + }); + + var _d = __webpack_require__(2); + + var _d2 = _interopRequireDefault(_d); + + function _interopRequireDefault(obj) { return obj && obj.__esModule ? obj : { default: obj }; } + + function _toConsumableArray(arr) { if (Array.isArray(arr)) { for (var i = 0, arr2 = Array(arr.length); i < arr.length; i++) { arr2[i] = arr[i]; } return arr2; } else { return Array.from(arr); } } + + function _classCallCheck(instance, Constructor) { if (!(instance instanceof Constructor)) { throw new TypeError("Cannot call a class as a function"); } } + + var Barchart = + // svg: d3 object with the svg in question + // exp_array: list of (feature_name, weight) + function Barchart(svg, exp_array) { + var two_sided = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : true; + var titles = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : undefined; + var colors = arguments.length > 4 && arguments[4] !== undefined ? arguments[4] : ['red', 'green']; + var show_numbers = arguments.length > 5 && arguments[5] !== undefined ? arguments[5] : false; + var bar_height = arguments.length > 6 && arguments[6] !== undefined ? arguments[6] : 5; + + _classCallCheck(this, Barchart); + + var svg_width = Math.min(600, parseInt(svg.style('width'))); + var bar_width = two_sided ? svg_width / 2 : svg_width; + if (titles === undefined) { + titles = two_sided ? ['Cons', 'Pros'] : 'Pros'; + } + if (show_numbers) { + bar_width = bar_width - 30; + } + var x_offset = two_sided ? svg_width / 2 : 10; + // 13.1 is +- the width of W, the widest letter. + if (two_sided && titles.length == 2) { + svg.append('text').attr('x', svg_width / 4).attr('y', 15).attr('font-size', '20').attr('text-anchor', 'middle').style('fill', colors[0]).text(titles[0]); + + svg.append('text').attr('x', svg_width / 4 * 3).attr('y', 15).attr('font-size', '20').attr('text-anchor', 'middle').style('fill', colors[1]).text(titles[1]); + } else { + var pos = two_sided ? svg_width / 2 : x_offset; + var anchor = two_sided ? 
'middle' : 'begin'; + svg.append('text').attr('x', pos).attr('y', 15).attr('font-size', '20').attr('text-anchor', anchor).text(titles); + } + var yshift = 20; + var space_between_bars = 0; + var text_height = 16; + var space_between_bar_and_text = 3; + var total_bar_height = text_height + space_between_bar_and_text + bar_height + space_between_bars; + var total_height = total_bar_height * exp_array.length; + this.svg_height = total_height + yshift; + var yscale = _d2.default.scale.linear().domain([0, exp_array.length]).range([yshift, yshift + total_height]); + var names = exp_array.map(function (v) { + return v[0]; + }); + var weights = exp_array.map(function (v) { + return v[1]; + }); + var max_weight = Math.max.apply(Math, _toConsumableArray(weights.map(function (v) { + return Math.abs(v); + }))); + var xscale = _d2.default.scale.linear().domain([0, Math.max(1, max_weight)]).range([0, bar_width]); + + for (var i = 0; i < exp_array.length; ++i) { + var name = names[i]; + var weight = weights[i]; + var size = xscale(Math.abs(weight)); + var to_the_right = weight > 0 || !two_sided; + var text = svg.append('text').attr('x', to_the_right ? x_offset + 2 : x_offset - 2).attr('y', yscale(i) + text_height).attr('text-anchor', to_the_right ? 'begin' : 'end').attr('font-size', '14').text(name); + while (text.node().getBBox()['width'] + 1 > bar_width) { + var cur_text = text.text().slice(0, text.text().length - 5); + text.text(cur_text + '...'); + if (text === '...') { + break; + } + } + var bar = svg.append('rect').attr('height', bar_height).attr('x', to_the_right ? x_offset : x_offset - size).attr('y', text_height + yscale(i) + space_between_bar_and_text) // + bar_height) + .attr('width', size).style('fill', weight > 0 ? colors[1] : colors[0]); + if (show_numbers) { + var bartext = svg.append('text').attr('x', to_the_right ? x_offset + size + 1 : x_offset - size - 1).attr('text-anchor', weight > 0 || !two_sided ? 'begin' : 'end').attr('y', bar_height + yscale(i) + text_height + space_between_bar_and_text).attr('font-size', '10').text(Math.abs(weight).toFixed(2)); + } + } + var line = svg.append("line").attr("x1", x_offset).attr("x2", x_offset).attr("y1", bar_height + yshift).attr("y2", Math.max(bar_height, yscale(exp_array.length))).style("stroke-width", 2).style("stroke", "black"); + }; + + exports.default = Barchart; + +/***/ }), +/* 4 */ +/***/ (function(module, exports, __webpack_require__) { + + var __WEBPACK_AMD_DEFINE_RESULT__;/* WEBPACK VAR INJECTION */(function(global, module) {/** + * @license + * Lodash + * Copyright JS Foundation and other contributors + * Released under MIT license + * Based on Underscore.js 1.8.3 + * Copyright Jeremy Ashkenas, DocumentCloud and Investigative Reporters & Editors + */ + ;(function() { + + /** Used as a safe reference for `undefined` in pre-ES5 environments. */ + var undefined; + + /** Used as the semantic version number. */ + var VERSION = '4.17.11'; + + /** Used as the size to enable large array optimizations. */ + var LARGE_ARRAY_SIZE = 200; + + /** Error message constants. */ + var CORE_ERROR_TEXT = 'Unsupported core-js use. Try https://npms.io/search?q=ponyfill.', + FUNC_ERROR_TEXT = 'Expected a function'; + + /** Used to stand-in for `undefined` hash values. */ + var HASH_UNDEFINED = '__lodash_hash_undefined__'; + + /** Used as the maximum memoize cache size. */ + var MAX_MEMOIZE_SIZE = 500; + + /** Used as the internal argument placeholder. */ + var PLACEHOLDER = '__lodash_placeholder__'; + + /** Used to compose bitmasks for cloning. 
*/ + var CLONE_DEEP_FLAG = 1, + CLONE_FLAT_FLAG = 2, + CLONE_SYMBOLS_FLAG = 4; + + /** Used to compose bitmasks for value comparisons. */ + var COMPARE_PARTIAL_FLAG = 1, + COMPARE_UNORDERED_FLAG = 2; + + /** Used to compose bitmasks for function metadata. */ + var WRAP_BIND_FLAG = 1, + WRAP_BIND_KEY_FLAG = 2, + WRAP_CURRY_BOUND_FLAG = 4, + WRAP_CURRY_FLAG = 8, + WRAP_CURRY_RIGHT_FLAG = 16, + WRAP_PARTIAL_FLAG = 32, + WRAP_PARTIAL_RIGHT_FLAG = 64, + WRAP_ARY_FLAG = 128, + WRAP_REARG_FLAG = 256, + WRAP_FLIP_FLAG = 512; + + /** Used as default options for `_.truncate`. */ + var DEFAULT_TRUNC_LENGTH = 30, + DEFAULT_TRUNC_OMISSION = '...'; + + /** Used to detect hot functions by number of calls within a span of milliseconds. */ + var HOT_COUNT = 800, + HOT_SPAN = 16; + + /** Used to indicate the type of lazy iteratees. */ + var LAZY_FILTER_FLAG = 1, + LAZY_MAP_FLAG = 2, + LAZY_WHILE_FLAG = 3; + + /** Used as references for various `Number` constants. */ + var INFINITY = 1 / 0, + MAX_SAFE_INTEGER = 9007199254740991, + MAX_INTEGER = 1.7976931348623157e+308, + NAN = 0 / 0; + + /** Used as references for the maximum length and index of an array. */ + var MAX_ARRAY_LENGTH = 4294967295, + MAX_ARRAY_INDEX = MAX_ARRAY_LENGTH - 1, + HALF_MAX_ARRAY_LENGTH = MAX_ARRAY_LENGTH >>> 1; + + /** Used to associate wrap methods with their bit flags. */ + var wrapFlags = [ + ['ary', WRAP_ARY_FLAG], + ['bind', WRAP_BIND_FLAG], + ['bindKey', WRAP_BIND_KEY_FLAG], + ['curry', WRAP_CURRY_FLAG], + ['curryRight', WRAP_CURRY_RIGHT_FLAG], + ['flip', WRAP_FLIP_FLAG], + ['partial', WRAP_PARTIAL_FLAG], + ['partialRight', WRAP_PARTIAL_RIGHT_FLAG], + ['rearg', WRAP_REARG_FLAG] + ]; + + /** `Object#toString` result references. */ + var argsTag = '[object Arguments]', + arrayTag = '[object Array]', + asyncTag = '[object AsyncFunction]', + boolTag = '[object Boolean]', + dateTag = '[object Date]', + domExcTag = '[object DOMException]', + errorTag = '[object Error]', + funcTag = '[object Function]', + genTag = '[object GeneratorFunction]', + mapTag = '[object Map]', + numberTag = '[object Number]', + nullTag = '[object Null]', + objectTag = '[object Object]', + promiseTag = '[object Promise]', + proxyTag = '[object Proxy]', + regexpTag = '[object RegExp]', + setTag = '[object Set]', + stringTag = '[object String]', + symbolTag = '[object Symbol]', + undefinedTag = '[object Undefined]', + weakMapTag = '[object WeakMap]', + weakSetTag = '[object WeakSet]'; + + var arrayBufferTag = '[object ArrayBuffer]', + dataViewTag = '[object DataView]', + float32Tag = '[object Float32Array]', + float64Tag = '[object Float64Array]', + int8Tag = '[object Int8Array]', + int16Tag = '[object Int16Array]', + int32Tag = '[object Int32Array]', + uint8Tag = '[object Uint8Array]', + uint8ClampedTag = '[object Uint8ClampedArray]', + uint16Tag = '[object Uint16Array]', + uint32Tag = '[object Uint32Array]'; + + /** Used to match empty string literals in compiled template source. */ + var reEmptyStringLeading = /\b__p \+= '';/g, + reEmptyStringMiddle = /\b(__p \+=) '' \+/g, + reEmptyStringTrailing = /(__e\(.*?\)|\b__t\)) \+\n'';/g; + + /** Used to match HTML entities and HTML characters. */ + var reEscapedHtml = /&(?:amp|lt|gt|quot|#39);/g, + reUnescapedHtml = /[&<>"']/g, + reHasEscapedHtml = RegExp(reEscapedHtml.source), + reHasUnescapedHtml = RegExp(reUnescapedHtml.source); + + /** Used to match template delimiters. 
*/ + var reEscape = /<%-([\s\S]+?)%>/g, + reEvaluate = /<%([\s\S]+?)%>/g, + reInterpolate = /<%=([\s\S]+?)%>/g; + + /** Used to match property names within property paths. */ + var reIsDeepProp = /\.|\[(?:[^[\]]*|(["'])(?:(?!\1)[^\\]|\\.)*?\1)\]/, + reIsPlainProp = /^\w*$/, + rePropName = /[^.[\]]+|\[(?:(-?\d+(?:\.\d+)?)|(["'])((?:(?!\2)[^\\]|\\.)*?)\2)\]|(?=(?:\.|\[\])(?:\.|\[\]|$))/g; + + /** + * Used to match `RegExp` + * [syntax characters](http://ecma-international.org/ecma-262/7.0/#sec-patterns). + */ + var reRegExpChar = /[\\^$.*+?()[\]{}|]/g, + reHasRegExpChar = RegExp(reRegExpChar.source); + + /** Used to match leading and trailing whitespace. */ + var reTrim = /^\s+|\s+$/g, + reTrimStart = /^\s+/, + reTrimEnd = /\s+$/; + + /** Used to match wrap detail comments. */ + var reWrapComment = /\{(?:\n\/\* \[wrapped with .+\] \*\/)?\n?/, + reWrapDetails = /\{\n\/\* \[wrapped with (.+)\] \*/, + reSplitDetails = /,? & /; + + /** Used to match words composed of alphanumeric characters. */ + var reAsciiWord = /[^\x00-\x2f\x3a-\x40\x5b-\x60\x7b-\x7f]+/g; + + /** Used to match backslashes in property paths. */ + var reEscapeChar = /\\(\\)?/g; + + /** + * Used to match + * [ES template delimiters](http://ecma-international.org/ecma-262/7.0/#sec-template-literal-lexical-components). + */ + var reEsTemplate = /\$\{([^\\}]*(?:\\.[^\\}]*)*)\}/g; + + /** Used to match `RegExp` flags from their coerced string values. */ + var reFlags = /\w*$/; + + /** Used to detect bad signed hexadecimal string values. */ + var reIsBadHex = /^[-+]0x[0-9a-f]+$/i; + + /** Used to detect binary string values. */ + var reIsBinary = /^0b[01]+$/i; + + /** Used to detect host constructors (Safari). */ + var reIsHostCtor = /^\[object .+?Constructor\]$/; + + /** Used to detect octal string values. */ + var reIsOctal = /^0o[0-7]+$/i; + + /** Used to detect unsigned integer values. */ + var reIsUint = /^(?:0|[1-9]\d*)$/; + + /** Used to match Latin Unicode letters (excluding mathematical operators). */ + var reLatin = /[\xc0-\xd6\xd8-\xf6\xf8-\xff\u0100-\u017f]/g; + + /** Used to ensure capturing order of template delimiters. */ + var reNoMatch = /($^)/; + + /** Used to match unescaped characters in compiled string literals. */ + var reUnescapedString = /['\n\r\u2028\u2029\\]/g; + + /** Used to compose unicode character classes. */ + var rsAstralRange = '\\ud800-\\udfff', + rsComboMarksRange = '\\u0300-\\u036f', + reComboHalfMarksRange = '\\ufe20-\\ufe2f', + rsComboSymbolsRange = '\\u20d0-\\u20ff', + rsComboRange = rsComboMarksRange + reComboHalfMarksRange + rsComboSymbolsRange, + rsDingbatRange = '\\u2700-\\u27bf', + rsLowerRange = 'a-z\\xdf-\\xf6\\xf8-\\xff', + rsMathOpRange = '\\xac\\xb1\\xd7\\xf7', + rsNonCharRange = '\\x00-\\x2f\\x3a-\\x40\\x5b-\\x60\\x7b-\\xbf', + rsPunctuationRange = '\\u2000-\\u206f', + rsSpaceRange = ' \\t\\x0b\\f\\xa0\\ufeff\\n\\r\\u2028\\u2029\\u1680\\u180e\\u2000\\u2001\\u2002\\u2003\\u2004\\u2005\\u2006\\u2007\\u2008\\u2009\\u200a\\u202f\\u205f\\u3000', + rsUpperRange = 'A-Z\\xc0-\\xd6\\xd8-\\xde', + rsVarRange = '\\ufe0e\\ufe0f', + rsBreakRange = rsMathOpRange + rsNonCharRange + rsPunctuationRange + rsSpaceRange; + + /** Used to compose unicode capture groups. 
*/ + var rsApos = "['\u2019]", + rsAstral = '[' + rsAstralRange + ']', + rsBreak = '[' + rsBreakRange + ']', + rsCombo = '[' + rsComboRange + ']', + rsDigits = '\\d+', + rsDingbat = '[' + rsDingbatRange + ']', + rsLower = '[' + rsLowerRange + ']', + rsMisc = '[^' + rsAstralRange + rsBreakRange + rsDigits + rsDingbatRange + rsLowerRange + rsUpperRange + ']', + rsFitz = '\\ud83c[\\udffb-\\udfff]', + rsModifier = '(?:' + rsCombo + '|' + rsFitz + ')', + rsNonAstral = '[^' + rsAstralRange + ']', + rsRegional = '(?:\\ud83c[\\udde6-\\uddff]){2}', + rsSurrPair = '[\\ud800-\\udbff][\\udc00-\\udfff]', + rsUpper = '[' + rsUpperRange + ']', + rsZWJ = '\\u200d'; + + /** Used to compose unicode regexes. */ + var rsMiscLower = '(?:' + rsLower + '|' + rsMisc + ')', + rsMiscUpper = '(?:' + rsUpper + '|' + rsMisc + ')', + rsOptContrLower = '(?:' + rsApos + '(?:d|ll|m|re|s|t|ve))?', + rsOptContrUpper = '(?:' + rsApos + '(?:D|LL|M|RE|S|T|VE))?', + reOptMod = rsModifier + '?', + rsOptVar = '[' + rsVarRange + ']?', + rsOptJoin = '(?:' + rsZWJ + '(?:' + [rsNonAstral, rsRegional, rsSurrPair].join('|') + ')' + rsOptVar + reOptMod + ')*', + rsOrdLower = '\\d*(?:1st|2nd|3rd|(?![123])\\dth)(?=\\b|[A-Z_])', + rsOrdUpper = '\\d*(?:1ST|2ND|3RD|(?![123])\\dTH)(?=\\b|[a-z_])', + rsSeq = rsOptVar + reOptMod + rsOptJoin, + rsEmoji = '(?:' + [rsDingbat, rsRegional, rsSurrPair].join('|') + ')' + rsSeq, + rsSymbol = '(?:' + [rsNonAstral + rsCombo + '?', rsCombo, rsRegional, rsSurrPair, rsAstral].join('|') + ')'; + + /** Used to match apostrophes. */ + var reApos = RegExp(rsApos, 'g'); + + /** + * Used to match [combining diacritical marks](https://en.wikipedia.org/wiki/Combining_Diacritical_Marks) and + * [combining diacritical marks for symbols](https://en.wikipedia.org/wiki/Combining_Diacritical_Marks_for_Symbols). + */ + var reComboMark = RegExp(rsCombo, 'g'); + + /** Used to match [string symbols](https://mathiasbynens.be/notes/javascript-unicode). */ + var reUnicode = RegExp(rsFitz + '(?=' + rsFitz + ')|' + rsSymbol + rsSeq, 'g'); + + /** Used to match complex or compound words. */ + var reUnicodeWord = RegExp([ + rsUpper + '?' + rsLower + '+' + rsOptContrLower + '(?=' + [rsBreak, rsUpper, '$'].join('|') + ')', + rsMiscUpper + '+' + rsOptContrUpper + '(?=' + [rsBreak, rsUpper + rsMiscLower, '$'].join('|') + ')', + rsUpper + '?' + rsMiscLower + '+' + rsOptContrLower, + rsUpper + '+' + rsOptContrUpper, + rsOrdUpper, + rsOrdLower, + rsDigits, + rsEmoji + ].join('|'), 'g'); + + /** Used to detect strings with [zero-width joiners or code points from the astral planes](http://eev.ee/blog/2015/09/12/dark-corners-of-unicode/). */ + var reHasUnicode = RegExp('[' + rsZWJ + rsAstralRange + rsComboRange + rsVarRange + ']'); + + /** Used to detect strings that need a more robust regexp to match words. */ + var reHasUnicodeWord = /[a-z][A-Z]|[A-Z]{2}[a-z]|[0-9][a-zA-Z]|[a-zA-Z][0-9]|[^a-zA-Z0-9 ]/; + + /** Used to assign default `context` object properties. */ + var contextProps = [ + 'Array', 'Buffer', 'DataView', 'Date', 'Error', 'Float32Array', 'Float64Array', + 'Function', 'Int8Array', 'Int16Array', 'Int32Array', 'Map', 'Math', 'Object', + 'Promise', 'RegExp', 'Set', 'String', 'Symbol', 'TypeError', 'Uint8Array', + 'Uint8ClampedArray', 'Uint16Array', 'Uint32Array', 'WeakMap', + '_', 'clearTimeout', 'isFinite', 'parseInt', 'setTimeout' + ]; + + /** Used to make template sourceURLs easier to identify. */ + var templateCounter = -1; + + /** Used to identify `toStringTag` values of typed arrays. 
*/ + var typedArrayTags = {}; + typedArrayTags[float32Tag] = typedArrayTags[float64Tag] = + typedArrayTags[int8Tag] = typedArrayTags[int16Tag] = + typedArrayTags[int32Tag] = typedArrayTags[uint8Tag] = + typedArrayTags[uint8ClampedTag] = typedArrayTags[uint16Tag] = + typedArrayTags[uint32Tag] = true; + typedArrayTags[argsTag] = typedArrayTags[arrayTag] = + typedArrayTags[arrayBufferTag] = typedArrayTags[boolTag] = + typedArrayTags[dataViewTag] = typedArrayTags[dateTag] = + typedArrayTags[errorTag] = typedArrayTags[funcTag] = + typedArrayTags[mapTag] = typedArrayTags[numberTag] = + typedArrayTags[objectTag] = typedArrayTags[regexpTag] = + typedArrayTags[setTag] = typedArrayTags[stringTag] = + typedArrayTags[weakMapTag] = false; + + /** Used to identify `toStringTag` values supported by `_.clone`. */ + var cloneableTags = {}; + cloneableTags[argsTag] = cloneableTags[arrayTag] = + cloneableTags[arrayBufferTag] = cloneableTags[dataViewTag] = + cloneableTags[boolTag] = cloneableTags[dateTag] = + cloneableTags[float32Tag] = cloneableTags[float64Tag] = + cloneableTags[int8Tag] = cloneableTags[int16Tag] = + cloneableTags[int32Tag] = cloneableTags[mapTag] = + cloneableTags[numberTag] = cloneableTags[objectTag] = + cloneableTags[regexpTag] = cloneableTags[setTag] = + cloneableTags[stringTag] = cloneableTags[symbolTag] = + cloneableTags[uint8Tag] = cloneableTags[uint8ClampedTag] = + cloneableTags[uint16Tag] = cloneableTags[uint32Tag] = true; + cloneableTags[errorTag] = cloneableTags[funcTag] = + cloneableTags[weakMapTag] = false; + + /** Used to map Latin Unicode letters to basic Latin letters. */ + var deburredLetters = { + // Latin-1 Supplement block. + '\xc0': 'A', '\xc1': 'A', '\xc2': 'A', '\xc3': 'A', '\xc4': 'A', '\xc5': 'A', + '\xe0': 'a', '\xe1': 'a', '\xe2': 'a', '\xe3': 'a', '\xe4': 'a', '\xe5': 'a', + '\xc7': 'C', '\xe7': 'c', + '\xd0': 'D', '\xf0': 'd', + '\xc8': 'E', '\xc9': 'E', '\xca': 'E', '\xcb': 'E', + '\xe8': 'e', '\xe9': 'e', '\xea': 'e', '\xeb': 'e', + '\xcc': 'I', '\xcd': 'I', '\xce': 'I', '\xcf': 'I', + '\xec': 'i', '\xed': 'i', '\xee': 'i', '\xef': 'i', + '\xd1': 'N', '\xf1': 'n', + '\xd2': 'O', '\xd3': 'O', '\xd4': 'O', '\xd5': 'O', '\xd6': 'O', '\xd8': 'O', + '\xf2': 'o', '\xf3': 'o', '\xf4': 'o', '\xf5': 'o', '\xf6': 'o', '\xf8': 'o', + '\xd9': 'U', '\xda': 'U', '\xdb': 'U', '\xdc': 'U', + '\xf9': 'u', '\xfa': 'u', '\xfb': 'u', '\xfc': 'u', + '\xdd': 'Y', '\xfd': 'y', '\xff': 'y', + '\xc6': 'Ae', '\xe6': 'ae', + '\xde': 'Th', '\xfe': 'th', + '\xdf': 'ss', + // Latin Extended-A block. 
+ '\u0100': 'A', '\u0102': 'A', '\u0104': 'A', + '\u0101': 'a', '\u0103': 'a', '\u0105': 'a', + '\u0106': 'C', '\u0108': 'C', '\u010a': 'C', '\u010c': 'C', + '\u0107': 'c', '\u0109': 'c', '\u010b': 'c', '\u010d': 'c', + '\u010e': 'D', '\u0110': 'D', '\u010f': 'd', '\u0111': 'd', + '\u0112': 'E', '\u0114': 'E', '\u0116': 'E', '\u0118': 'E', '\u011a': 'E', + '\u0113': 'e', '\u0115': 'e', '\u0117': 'e', '\u0119': 'e', '\u011b': 'e', + '\u011c': 'G', '\u011e': 'G', '\u0120': 'G', '\u0122': 'G', + '\u011d': 'g', '\u011f': 'g', '\u0121': 'g', '\u0123': 'g', + '\u0124': 'H', '\u0126': 'H', '\u0125': 'h', '\u0127': 'h', + '\u0128': 'I', '\u012a': 'I', '\u012c': 'I', '\u012e': 'I', '\u0130': 'I', + '\u0129': 'i', '\u012b': 'i', '\u012d': 'i', '\u012f': 'i', '\u0131': 'i', + '\u0134': 'J', '\u0135': 'j', + '\u0136': 'K', '\u0137': 'k', '\u0138': 'k', + '\u0139': 'L', '\u013b': 'L', '\u013d': 'L', '\u013f': 'L', '\u0141': 'L', + '\u013a': 'l', '\u013c': 'l', '\u013e': 'l', '\u0140': 'l', '\u0142': 'l', + '\u0143': 'N', '\u0145': 'N', '\u0147': 'N', '\u014a': 'N', + '\u0144': 'n', '\u0146': 'n', '\u0148': 'n', '\u014b': 'n', + '\u014c': 'O', '\u014e': 'O', '\u0150': 'O', + '\u014d': 'o', '\u014f': 'o', '\u0151': 'o', + '\u0154': 'R', '\u0156': 'R', '\u0158': 'R', + '\u0155': 'r', '\u0157': 'r', '\u0159': 'r', + '\u015a': 'S', '\u015c': 'S', '\u015e': 'S', '\u0160': 'S', + '\u015b': 's', '\u015d': 's', '\u015f': 's', '\u0161': 's', + '\u0162': 'T', '\u0164': 'T', '\u0166': 'T', + '\u0163': 't', '\u0165': 't', '\u0167': 't', + '\u0168': 'U', '\u016a': 'U', '\u016c': 'U', '\u016e': 'U', '\u0170': 'U', '\u0172': 'U', + '\u0169': 'u', '\u016b': 'u', '\u016d': 'u', '\u016f': 'u', '\u0171': 'u', '\u0173': 'u', + '\u0174': 'W', '\u0175': 'w', + '\u0176': 'Y', '\u0177': 'y', '\u0178': 'Y', + '\u0179': 'Z', '\u017b': 'Z', '\u017d': 'Z', + '\u017a': 'z', '\u017c': 'z', '\u017e': 'z', + '\u0132': 'IJ', '\u0133': 'ij', + '\u0152': 'Oe', '\u0153': 'oe', + '\u0149': "'n", '\u017f': 's' + }; + + /** Used to map characters to HTML entities. */ + var htmlEscapes = { + '&': '&amp;', + '<': '&lt;', + '>': '&gt;', + '"': '&quot;', + "'": '&#39;' + }; + + /** Used to map HTML entities to characters. */ + var htmlUnescapes = { + '&amp;': '&', + '&lt;': '<', + '&gt;': '>', + '&quot;': '"', + '&#39;': "'" + }; + + /** Used to escape characters for inclusion in compiled string literals. */ + var stringEscapes = { + '\\': '\\', + "'": "'", + '\n': 'n', + '\r': 'r', + '\u2028': 'u2028', + '\u2029': 'u2029' + }; + + /** Built-in method references without a dependency on `root`. */ + var freeParseFloat = parseFloat, + freeParseInt = parseInt; + + /** Detect free variable `global` from Node.js. */ + var freeGlobal = typeof global == 'object' && global && global.Object === Object && global; + + /** Detect free variable `self`. */ + var freeSelf = typeof self == 'object' && self && self.Object === Object && self; + + /** Used as a reference to the global object. */ + var root = freeGlobal || freeSelf || Function('return this')(); + + /** Detect free variable `exports`. */ + var freeExports = typeof exports == 'object' && exports && !exports.nodeType && exports; + + /** Detect free variable `module`. */ + var freeModule = freeExports && typeof module == 'object' && module && !module.nodeType && module; + + /** Detect the popular CommonJS extension `module.exports`. */ + var moduleExports = freeModule && freeModule.exports === freeExports; + + /** Detect free variable `process` from Node.js. 
*/ + var freeProcess = moduleExports && freeGlobal.process; + + /** Used to access faster Node.js helpers. */ + var nodeUtil = (function() { + try { + // Use `util.types` for Node.js 10+. + var types = freeModule && freeModule.require && freeModule.require('util').types; + + if (types) { + return types; + } + + // Legacy `process.binding('util')` for Node.js < 10. + return freeProcess && freeProcess.binding && freeProcess.binding('util'); + } catch (e) {} + }()); + + /* Node.js helper references. */ + var nodeIsArrayBuffer = nodeUtil && nodeUtil.isArrayBuffer, + nodeIsDate = nodeUtil && nodeUtil.isDate, + nodeIsMap = nodeUtil && nodeUtil.isMap, + nodeIsRegExp = nodeUtil && nodeUtil.isRegExp, + nodeIsSet = nodeUtil && nodeUtil.isSet, + nodeIsTypedArray = nodeUtil && nodeUtil.isTypedArray; + + /*--------------------------------------------------------------------------*/ + + /** + * A faster alternative to `Function#apply`, this function invokes `func` + * with the `this` binding of `thisArg` and the arguments of `args`. + * + * @private + * @param {Function} func The function to invoke. + * @param {*} thisArg The `this` binding of `func`. + * @param {Array} args The arguments to invoke `func` with. + * @returns {*} Returns the result of `func`. + */ + function apply(func, thisArg, args) { + switch (args.length) { + case 0: return func.call(thisArg); + case 1: return func.call(thisArg, args[0]); + case 2: return func.call(thisArg, args[0], args[1]); + case 3: return func.call(thisArg, args[0], args[1], args[2]); + } + return func.apply(thisArg, args); + } + + /** + * A specialized version of `baseAggregator` for arrays. + * + * @private + * @param {Array} [array] The array to iterate over. + * @param {Function} setter The function to set `accumulator` values. + * @param {Function} iteratee The iteratee to transform keys. + * @param {Object} accumulator The initial aggregated object. + * @returns {Function} Returns `accumulator`. + */ + function arrayAggregator(array, setter, iteratee, accumulator) { + var index = -1, + length = array == null ? 0 : array.length; + + while (++index < length) { + var value = array[index]; + setter(accumulator, value, iteratee(value), array); + } + return accumulator; + } + + /** + * A specialized version of `_.forEach` for arrays without support for + * iteratee shorthands. + * + * @private + * @param {Array} [array] The array to iterate over. + * @param {Function} iteratee The function invoked per iteration. + * @returns {Array} Returns `array`. + */ + function arrayEach(array, iteratee) { + var index = -1, + length = array == null ? 0 : array.length; + + while (++index < length) { + if (iteratee(array[index], index, array) === false) { + break; + } + } + return array; + } + + /** + * A specialized version of `_.forEachRight` for arrays without support for + * iteratee shorthands. + * + * @private + * @param {Array} [array] The array to iterate over. + * @param {Function} iteratee The function invoked per iteration. + * @returns {Array} Returns `array`. + */ + function arrayEachRight(array, iteratee) { + var length = array == null ? 0 : array.length; + + while (length--) { + if (iteratee(array[length], length, array) === false) { + break; + } + } + return array; + } + + /** + * A specialized version of `_.every` for arrays without support for + * iteratee shorthands. + * + * @private + * @param {Array} [array] The array to iterate over. + * @param {Function} predicate The function invoked per iteration. 
+ * @returns {boolean} Returns `true` if all elements pass the predicate check, + * else `false`. + */ + function arrayEvery(array, predicate) { + var index = -1, + length = array == null ? 0 : array.length; + + while (++index < length) { + if (!predicate(array[index], index, array)) { + return false; + } + } + return true; + } + + /** + * A specialized version of `_.filter` for arrays without support for + * iteratee shorthands. + * + * @private + * @param {Array} [array] The array to iterate over. + * @param {Function} predicate The function invoked per iteration. + * @returns {Array} Returns the new filtered array. + */ + function arrayFilter(array, predicate) { + var index = -1, + length = array == null ? 0 : array.length, + resIndex = 0, + result = []; + + while (++index < length) { + var value = array[index]; + if (predicate(value, index, array)) { + result[resIndex++] = value; + } + } + return result; + } + + /** + * A specialized version of `_.includes` for arrays without support for + * specifying an index to search from. + * + * @private + * @param {Array} [array] The array to inspect. + * @param {*} target The value to search for. + * @returns {boolean} Returns `true` if `target` is found, else `false`. + */ + function arrayIncludes(array, value) { + var length = array == null ? 0 : array.length; + return !!length && baseIndexOf(array, value, 0) > -1; + } + + /** + * This function is like `arrayIncludes` except that it accepts a comparator. + * + * @private + * @param {Array} [array] The array to inspect. + * @param {*} target The value to search for. + * @param {Function} comparator The comparator invoked per element. + * @returns {boolean} Returns `true` if `target` is found, else `false`. + */ + function arrayIncludesWith(array, value, comparator) { + var index = -1, + length = array == null ? 0 : array.length; + + while (++index < length) { + if (comparator(value, array[index])) { + return true; + } + } + return false; + } + + /** + * A specialized version of `_.map` for arrays without support for iteratee + * shorthands. + * + * @private + * @param {Array} [array] The array to iterate over. + * @param {Function} iteratee The function invoked per iteration. + * @returns {Array} Returns the new mapped array. + */ + function arrayMap(array, iteratee) { + var index = -1, + length = array == null ? 0 : array.length, + result = Array(length); + + while (++index < length) { + result[index] = iteratee(array[index], index, array); + } + return result; + } + + /** + * Appends the elements of `values` to `array`. + * + * @private + * @param {Array} array The array to modify. + * @param {Array} values The values to append. + * @returns {Array} Returns `array`. + */ + function arrayPush(array, values) { + var index = -1, + length = values.length, + offset = array.length; + + while (++index < length) { + array[offset + index] = values[index]; + } + return array; + } + + /** + * A specialized version of `_.reduce` for arrays without support for + * iteratee shorthands. + * + * @private + * @param {Array} [array] The array to iterate over. + * @param {Function} iteratee The function invoked per iteration. + * @param {*} [accumulator] The initial value. + * @param {boolean} [initAccum] Specify using the first element of `array` as + * the initial value. + * @returns {*} Returns the accumulated value. + */ + function arrayReduce(array, iteratee, accumulator, initAccum) { + var index = -1, + length = array == null ? 
0 : array.length; + + if (initAccum && length) { + accumulator = array[++index]; + } + while (++index < length) { + accumulator = iteratee(accumulator, array[index], index, array); + } + return accumulator; + } + + /** + * A specialized version of `_.reduceRight` for arrays without support for + * iteratee shorthands. + * + * @private + * @param {Array} [array] The array to iterate over. + * @param {Function} iteratee The function invoked per iteration. + * @param {*} [accumulator] The initial value. + * @param {boolean} [initAccum] Specify using the last element of `array` as + * the initial value. + * @returns {*} Returns the accumulated value. + */ + function arrayReduceRight(array, iteratee, accumulator, initAccum) { + var length = array == null ? 0 : array.length; + if (initAccum && length) { + accumulator = array[--length]; + } + while (length--) { + accumulator = iteratee(accumulator, array[length], length, array); + } + return accumulator; + } + + /** + * A specialized version of `_.some` for arrays without support for iteratee + * shorthands. + * + * @private + * @param {Array} [array] The array to iterate over. + * @param {Function} predicate The function invoked per iteration. + * @returns {boolean} Returns `true` if any element passes the predicate check, + * else `false`. + */ + function arraySome(array, predicate) { + var index = -1, + length = array == null ? 0 : array.length; + + while (++index < length) { + if (predicate(array[index], index, array)) { + return true; + } + } + return false; + } + + /** + * Gets the size of an ASCII `string`. + * + * @private + * @param {string} string The string inspect. + * @returns {number} Returns the string size. + */ + var asciiSize = baseProperty('length'); + + /** + * Converts an ASCII `string` to an array. + * + * @private + * @param {string} string The string to convert. + * @returns {Array} Returns the converted array. + */ + function asciiToArray(string) { + return string.split(''); + } + + /** + * Splits an ASCII `string` into an array of its words. + * + * @private + * @param {string} The string to inspect. + * @returns {Array} Returns the words of `string`. + */ + function asciiWords(string) { + return string.match(reAsciiWord) || []; + } + + /** + * The base implementation of methods like `_.findKey` and `_.findLastKey`, + * without support for iteratee shorthands, which iterates over `collection` + * using `eachFunc`. + * + * @private + * @param {Array|Object} collection The collection to inspect. + * @param {Function} predicate The function invoked per iteration. + * @param {Function} eachFunc The function to iterate over `collection`. + * @returns {*} Returns the found element or its key, else `undefined`. + */ + function baseFindKey(collection, predicate, eachFunc) { + var result; + eachFunc(collection, function(value, key, collection) { + if (predicate(value, key, collection)) { + result = key; + return false; + } + }); + return result; + } + + /** + * The base implementation of `_.findIndex` and `_.findLastIndex` without + * support for iteratee shorthands. + * + * @private + * @param {Array} array The array to inspect. + * @param {Function} predicate The function invoked per iteration. + * @param {number} fromIndex The index to search from. + * @param {boolean} [fromRight] Specify iterating from right to left. + * @returns {number} Returns the index of the matched value, else `-1`. + */ + function baseFindIndex(array, predicate, fromIndex, fromRight) { + var length = array.length, + index = fromIndex + (fromRight ? 
1 : -1); + + while ((fromRight ? index-- : ++index < length)) { + if (predicate(array[index], index, array)) { + return index; + } + } + return -1; + } + + /** + * The base implementation of `_.indexOf` without `fromIndex` bounds checks. + * + * @private + * @param {Array} array The array to inspect. + * @param {*} value The value to search for. + * @param {number} fromIndex The index to search from. + * @returns {number} Returns the index of the matched value, else `-1`. + */ + function baseIndexOf(array, value, fromIndex) { + return value === value + ? strictIndexOf(array, value, fromIndex) + : baseFindIndex(array, baseIsNaN, fromIndex); + } + + /** + * This function is like `baseIndexOf` except that it accepts a comparator. + * + * @private + * @param {Array} array The array to inspect. + * @param {*} value The value to search for. + * @param {number} fromIndex The index to search from. + * @param {Function} comparator The comparator invoked per element. + * @returns {number} Returns the index of the matched value, else `-1`. + */ + function baseIndexOfWith(array, value, fromIndex, comparator) { + var index = fromIndex - 1, + length = array.length; + + while (++index < length) { + if (comparator(array[index], value)) { + return index; + } + } + return -1; + } + + /** + * The base implementation of `_.isNaN` without support for number objects. + * + * @private + * @param {*} value The value to check. + * @returns {boolean} Returns `true` if `value` is `NaN`, else `false`. + */ + function baseIsNaN(value) { + return value !== value; + } + + /** + * The base implementation of `_.mean` and `_.meanBy` without support for + * iteratee shorthands. + * + * @private + * @param {Array} array The array to iterate over. + * @param {Function} iteratee The function invoked per iteration. + * @returns {number} Returns the mean. + */ + function baseMean(array, iteratee) { + var length = array == null ? 0 : array.length; + return length ? (baseSum(array, iteratee) / length) : NAN; + } + + /** + * The base implementation of `_.property` without support for deep paths. + * + * @private + * @param {string} key The key of the property to get. + * @returns {Function} Returns the new accessor function. + */ + function baseProperty(key) { + return function(object) { + return object == null ? undefined : object[key]; + }; + } + + /** + * The base implementation of `_.propertyOf` without support for deep paths. + * + * @private + * @param {Object} object The object to query. + * @returns {Function} Returns the new accessor function. + */ + function basePropertyOf(object) { + return function(key) { + return object == null ? undefined : object[key]; + }; + } + + /** + * The base implementation of `_.reduce` and `_.reduceRight`, without support + * for iteratee shorthands, which iterates over `collection` using `eachFunc`. + * + * @private + * @param {Array|Object} collection The collection to iterate over. + * @param {Function} iteratee The function invoked per iteration. + * @param {*} accumulator The initial value. + * @param {boolean} initAccum Specify using the first or last element of + * `collection` as the initial value. + * @param {Function} eachFunc The function to iterate over `collection`. + * @returns {*} Returns the accumulated value. + */ + function baseReduce(collection, iteratee, accumulator, initAccum, eachFunc) { + eachFunc(collection, function(value, index, collection) { + accumulator = initAccum + ? 
(initAccum = false, value) + : iteratee(accumulator, value, index, collection); + }); + return accumulator; + } + + /** + * The base implementation of `_.sortBy` which uses `comparer` to define the + * sort order of `array` and replaces criteria objects with their corresponding + * values. + * + * @private + * @param {Array} array The array to sort. + * @param {Function} comparer The function to define sort order. + * @returns {Array} Returns `array`. + */ + function baseSortBy(array, comparer) { + var length = array.length; + + array.sort(comparer); + while (length--) { + array[length] = array[length].value; + } + return array; + } + + /** + * The base implementation of `_.sum` and `_.sumBy` without support for + * iteratee shorthands. + * + * @private + * @param {Array} array The array to iterate over. + * @param {Function} iteratee The function invoked per iteration. + * @returns {number} Returns the sum. + */ + function baseSum(array, iteratee) { + var result, + index = -1, + length = array.length; + + while (++index < length) { + var current = iteratee(array[index]); + if (current !== undefined) { + result = result === undefined ? current : (result + current); + } + } + return result; + } + + /** + * The base implementation of `_.times` without support for iteratee shorthands + * or max array length checks. + * + * @private + * @param {number} n The number of times to invoke `iteratee`. + * @param {Function} iteratee The function invoked per iteration. + * @returns {Array} Returns the array of results. + */ + function baseTimes(n, iteratee) { + var index = -1, + result = Array(n); + + while (++index < n) { + result[index] = iteratee(index); + } + return result; + } + + /** + * The base implementation of `_.toPairs` and `_.toPairsIn` which creates an array + * of key-value pairs for `object` corresponding to the property names of `props`. + * + * @private + * @param {Object} object The object to query. + * @param {Array} props The property names to get values for. + * @returns {Object} Returns the key-value pairs. + */ + function baseToPairs(object, props) { + return arrayMap(props, function(key) { + return [key, object[key]]; + }); + } + + /** + * The base implementation of `_.unary` without support for storing metadata. + * + * @private + * @param {Function} func The function to cap arguments for. + * @returns {Function} Returns the new capped function. + */ + function baseUnary(func) { + return function(value) { + return func(value); + }; + } + + /** + * The base implementation of `_.values` and `_.valuesIn` which creates an + * array of `object` property values corresponding to the property names + * of `props`. + * + * @private + * @param {Object} object The object to query. + * @param {Array} props The property names to get values for. + * @returns {Object} Returns the array of property values. + */ + function baseValues(object, props) { + return arrayMap(props, function(key) { + return object[key]; + }); + } + + /** + * Checks if a `cache` value for `key` exists. + * + * @private + * @param {Object} cache The cache to query. + * @param {string} key The key of the entry to check. + * @returns {boolean} Returns `true` if an entry for `key` exists, else `false`. + */ + function cacheHas(cache, key) { + return cache.has(key); + } + + /** + * Used by `_.trim` and `_.trimStart` to get the index of the first string symbol + * that is not found in the character symbols. + * + * @private + * @param {Array} strSymbols The string symbols to inspect. 
+ * @param {Array} chrSymbols The character symbols to find. + * @returns {number} Returns the index of the first unmatched string symbol. + */ + function charsStartIndex(strSymbols, chrSymbols) { + var index = -1, + length = strSymbols.length; + + while (++index < length && baseIndexOf(chrSymbols, strSymbols[index], 0) > -1) {} + return index; + } + + /** + * Used by `_.trim` and `_.trimEnd` to get the index of the last string symbol + * that is not found in the character symbols. + * + * @private + * @param {Array} strSymbols The string symbols to inspect. + * @param {Array} chrSymbols The character symbols to find. + * @returns {number} Returns the index of the last unmatched string symbol. + */ + function charsEndIndex(strSymbols, chrSymbols) { + var index = strSymbols.length; + + while (index-- && baseIndexOf(chrSymbols, strSymbols[index], 0) > -1) {} + return index; + } + + /** + * Gets the number of `placeholder` occurrences in `array`. + * + * @private + * @param {Array} array The array to inspect. + * @param {*} placeholder The placeholder to search for. + * @returns {number} Returns the placeholder count. + */ + function countHolders(array, placeholder) { + var length = array.length, + result = 0; + + while (length--) { + if (array[length] === placeholder) { + ++result; + } + } + return result; + } + + /** + * Used by `_.deburr` to convert Latin-1 Supplement and Latin Extended-A + * letters to basic Latin letters. + * + * @private + * @param {string} letter The matched letter to deburr. + * @returns {string} Returns the deburred letter. + */ + var deburrLetter = basePropertyOf(deburredLetters); + + /** + * Used by `_.escape` to convert characters to HTML entities. + * + * @private + * @param {string} chr The matched character to escape. + * @returns {string} Returns the escaped character. + */ + var escapeHtmlChar = basePropertyOf(htmlEscapes); + + /** + * Used by `_.template` to escape characters for inclusion in compiled string literals. + * + * @private + * @param {string} chr The matched character to escape. + * @returns {string} Returns the escaped character. + */ + function escapeStringChar(chr) { + return '\\' + stringEscapes[chr]; + } + + /** + * Gets the value at `key` of `object`. + * + * @private + * @param {Object} [object] The object to query. + * @param {string} key The key of the property to get. + * @returns {*} Returns the property value. + */ + function getValue(object, key) { + return object == null ? undefined : object[key]; + } + + /** + * Checks if `string` contains Unicode symbols. + * + * @private + * @param {string} string The string to inspect. + * @returns {boolean} Returns `true` if a symbol is found, else `false`. + */ + function hasUnicode(string) { + return reHasUnicode.test(string); + } + + /** + * Checks if `string` contains a word composed of Unicode symbols. + * + * @private + * @param {string} string The string to inspect. + * @returns {boolean} Returns `true` if a word is found, else `false`. + */ + function hasUnicodeWord(string) { + return reHasUnicodeWord.test(string); + } + + /** + * Converts `iterator` to an array. + * + * @private + * @param {Object} iterator The iterator to convert. + * @returns {Array} Returns the converted array. + */ + function iteratorToArray(iterator) { + var data, + result = []; + + while (!(data = iterator.next()).done) { + result.push(data.value); + } + return result; + } + + /** + * Converts `map` to its key-value pairs. + * + * @private + * @param {Object} map The map to convert. 
+ * @returns {Array} Returns the key-value pairs. + */ + function mapToArray(map) { + var index = -1, + result = Array(map.size); + + map.forEach(function(value, key) { + result[++index] = [key, value]; + }); + return result; + } + + /** + * Creates a unary function that invokes `func` with its argument transformed. + * + * @private + * @param {Function} func The function to wrap. + * @param {Function} transform The argument transform. + * @returns {Function} Returns the new function. + */ + function overArg(func, transform) { + return function(arg) { + return func(transform(arg)); + }; + } + + /** + * Replaces all `placeholder` elements in `array` with an internal placeholder + * and returns an array of their indexes. + * + * @private + * @param {Array} array The array to modify. + * @param {*} placeholder The placeholder to replace. + * @returns {Array} Returns the new array of placeholder indexes. + */ + function replaceHolders(array, placeholder) { + var index = -1, + length = array.length, + resIndex = 0, + result = []; + + while (++index < length) { + var value = array[index]; + if (value === placeholder || value === PLACEHOLDER) { + array[index] = PLACEHOLDER; + result[resIndex++] = index; + } + } + return result; + } + + /** + * Converts `set` to an array of its values. + * + * @private + * @param {Object} set The set to convert. + * @returns {Array} Returns the values. + */ + function setToArray(set) { + var index = -1, + result = Array(set.size); + + set.forEach(function(value) { + result[++index] = value; + }); + return result; + } + + /** + * Converts `set` to its value-value pairs. + * + * @private + * @param {Object} set The set to convert. + * @returns {Array} Returns the value-value pairs. + */ + function setToPairs(set) { + var index = -1, + result = Array(set.size); + + set.forEach(function(value) { + result[++index] = [value, value]; + }); + return result; + } + + /** + * A specialized version of `_.indexOf` which performs strict equality + * comparisons of values, i.e. `===`. + * + * @private + * @param {Array} array The array to inspect. + * @param {*} value The value to search for. + * @param {number} fromIndex The index to search from. + * @returns {number} Returns the index of the matched value, else `-1`. + */ + function strictIndexOf(array, value, fromIndex) { + var index = fromIndex - 1, + length = array.length; + + while (++index < length) { + if (array[index] === value) { + return index; + } + } + return -1; + } + + /** + * A specialized version of `_.lastIndexOf` which performs strict equality + * comparisons of values, i.e. `===`. + * + * @private + * @param {Array} array The array to inspect. + * @param {*} value The value to search for. + * @param {number} fromIndex The index to search from. + * @returns {number} Returns the index of the matched value, else `-1`. + */ + function strictLastIndexOf(array, value, fromIndex) { + var index = fromIndex + 1; + while (index--) { + if (array[index] === value) { + return index; + } + } + return index; + } + + /** + * Gets the number of symbols in `string`. + * + * @private + * @param {string} string The string to inspect. + * @returns {number} Returns the string size. + */ + function stringSize(string) { + return hasUnicode(string) + ? unicodeSize(string) + : asciiSize(string); + } + + /** + * Converts `string` to an array. + * + * @private + * @param {string} string The string to convert. + * @returns {Array} Returns the converted array. + */ + function stringToArray(string) { + return hasUnicode(string) + ? 
unicodeToArray(string) + : asciiToArray(string); + } + + /** + * Used by `_.unescape` to convert HTML entities to characters. + * + * @private + * @param {string} chr The matched character to unescape. + * @returns {string} Returns the unescaped character. + */ + var unescapeHtmlChar = basePropertyOf(htmlUnescapes); + + /** + * Gets the size of a Unicode `string`. + * + * @private + * @param {string} string The string inspect. + * @returns {number} Returns the string size. + */ + function unicodeSize(string) { + var result = reUnicode.lastIndex = 0; + while (reUnicode.test(string)) { + ++result; + } + return result; + } + + /** + * Converts a Unicode `string` to an array. + * + * @private + * @param {string} string The string to convert. + * @returns {Array} Returns the converted array. + */ + function unicodeToArray(string) { + return string.match(reUnicode) || []; + } + + /** + * Splits a Unicode `string` into an array of its words. + * + * @private + * @param {string} The string to inspect. + * @returns {Array} Returns the words of `string`. + */ + function unicodeWords(string) { + return string.match(reUnicodeWord) || []; + } + + /*--------------------------------------------------------------------------*/ + + /** + * Create a new pristine `lodash` function using the `context` object. + * + * @static + * @memberOf _ + * @since 1.1.0 + * @category Util + * @param {Object} [context=root] The context object. + * @returns {Function} Returns a new `lodash` function. + * @example + * + * _.mixin({ 'foo': _.constant('foo') }); + * + * var lodash = _.runInContext(); + * lodash.mixin({ 'bar': lodash.constant('bar') }); + * + * _.isFunction(_.foo); + * // => true + * _.isFunction(_.bar); + * // => false + * + * lodash.isFunction(lodash.foo); + * // => false + * lodash.isFunction(lodash.bar); + * // => true + * + * // Create a suped-up `defer` in Node.js. + * var defer = _.runInContext({ 'setTimeout': setImmediate }).defer; + */ + var runInContext = (function runInContext(context) { + context = context == null ? root : _.defaults(root.Object(), context, _.pick(root, contextProps)); + + /** Built-in constructor references. */ + var Array = context.Array, + Date = context.Date, + Error = context.Error, + Function = context.Function, + Math = context.Math, + Object = context.Object, + RegExp = context.RegExp, + String = context.String, + TypeError = context.TypeError; + + /** Used for built-in method references. */ + var arrayProto = Array.prototype, + funcProto = Function.prototype, + objectProto = Object.prototype; + + /** Used to detect overreaching core-js shims. */ + var coreJsData = context['__core-js_shared__']; + + /** Used to resolve the decompiled source of functions. */ + var funcToString = funcProto.toString; + + /** Used to check objects for own properties. */ + var hasOwnProperty = objectProto.hasOwnProperty; + + /** Used to generate unique IDs. */ + var idCounter = 0; + + /** Used to detect methods masquerading as native. */ + var maskSrcKey = (function() { + var uid = /[^.]+$/.exec(coreJsData && coreJsData.keys && coreJsData.keys.IE_PROTO || ''); + return uid ? ('Symbol(src)_1.' + uid) : ''; + }()); + + /** + * Used to resolve the + * [`toStringTag`](http://ecma-international.org/ecma-262/7.0/#sec-object.prototype.tostring) + * of values. + */ + var nativeObjectToString = objectProto.toString; + + /** Used to infer the `Object` constructor. */ + var objectCtorString = funcToString.call(Object); + + /** Used to restore the original `_` reference in `_.noConflict`. 
*/ + var oldDash = root._; + + /** Used to detect if a method is native. */ + var reIsNative = RegExp('^' + + funcToString.call(hasOwnProperty).replace(reRegExpChar, '\\$&') + .replace(/hasOwnProperty|(function).*?(?=\\\()| for .+?(?=\\\])/g, '$1.*?') + '$' + ); + + /** Built-in value references. */ + var Buffer = moduleExports ? context.Buffer : undefined, + Symbol = context.Symbol, + Uint8Array = context.Uint8Array, + allocUnsafe = Buffer ? Buffer.allocUnsafe : undefined, + getPrototype = overArg(Object.getPrototypeOf, Object), + objectCreate = Object.create, + propertyIsEnumerable = objectProto.propertyIsEnumerable, + splice = arrayProto.splice, + spreadableSymbol = Symbol ? Symbol.isConcatSpreadable : undefined, + symIterator = Symbol ? Symbol.iterator : undefined, + symToStringTag = Symbol ? Symbol.toStringTag : undefined; + + var defineProperty = (function() { + try { + var func = getNative(Object, 'defineProperty'); + func({}, '', {}); + return func; + } catch (e) {} + }()); + + /** Mocked built-ins. */ + var ctxClearTimeout = context.clearTimeout !== root.clearTimeout && context.clearTimeout, + ctxNow = Date && Date.now !== root.Date.now && Date.now, + ctxSetTimeout = context.setTimeout !== root.setTimeout && context.setTimeout; + + /* Built-in method references for those with the same name as other `lodash` methods. */ + var nativeCeil = Math.ceil, + nativeFloor = Math.floor, + nativeGetSymbols = Object.getOwnPropertySymbols, + nativeIsBuffer = Buffer ? Buffer.isBuffer : undefined, + nativeIsFinite = context.isFinite, + nativeJoin = arrayProto.join, + nativeKeys = overArg(Object.keys, Object), + nativeMax = Math.max, + nativeMin = Math.min, + nativeNow = Date.now, + nativeParseInt = context.parseInt, + nativeRandom = Math.random, + nativeReverse = arrayProto.reverse; + + /* Built-in method references that are verified to be native. */ + var DataView = getNative(context, 'DataView'), + Map = getNative(context, 'Map'), + Promise = getNative(context, 'Promise'), + Set = getNative(context, 'Set'), + WeakMap = getNative(context, 'WeakMap'), + nativeCreate = getNative(Object, 'create'); + + /** Used to store function metadata. */ + var metaMap = WeakMap && new WeakMap; + + /** Used to lookup unminified function names. */ + var realNames = {}; + + /** Used to detect maps, sets, and weakmaps. */ + var dataViewCtorString = toSource(DataView), + mapCtorString = toSource(Map), + promiseCtorString = toSource(Promise), + setCtorString = toSource(Set), + weakMapCtorString = toSource(WeakMap); + + /** Used to convert symbols to primitives and strings. */ + var symbolProto = Symbol ? Symbol.prototype : undefined, + symbolValueOf = symbolProto ? symbolProto.valueOf : undefined, + symbolToString = symbolProto ? symbolProto.toString : undefined; + + /*------------------------------------------------------------------------*/ + + /** + * Creates a `lodash` object which wraps `value` to enable implicit method + * chain sequences. Methods that operate on and return arrays, collections, + * and functions can be chained together. Methods that retrieve a single value + * or may return a primitive value will automatically end the chain sequence + * and return the unwrapped value. Otherwise, the value must be unwrapped + * with `_#value`. + * + * Explicit chain sequences, which must be unwrapped with `_#value`, may be + * enabled using `_.chain`. + * + * The execution of chained methods is lazy, that is, it's deferred until + * `_#value` is implicitly or explicitly called. 
+ * + * Lazy evaluation allows several methods to support shortcut fusion. + * Shortcut fusion is an optimization to merge iteratee calls; this avoids + * the creation of intermediate arrays and can greatly reduce the number of + * iteratee executions. Sections of a chain sequence qualify for shortcut + * fusion if the section is applied to an array and iteratees accept only + * one argument. The heuristic for whether a section qualifies for shortcut + * fusion is subject to change. + * + * Chaining is supported in custom builds as long as the `_#value` method is + * directly or indirectly included in the build. + * + * In addition to lodash methods, wrappers have `Array` and `String` methods. + * + * The wrapper `Array` methods are: + * `concat`, `join`, `pop`, `push`, `shift`, `sort`, `splice`, and `unshift` + * + * The wrapper `String` methods are: + * `replace` and `split` + * + * The wrapper methods that support shortcut fusion are: + * `at`, `compact`, `drop`, `dropRight`, `dropWhile`, `filter`, `find`, + * `findLast`, `head`, `initial`, `last`, `map`, `reject`, `reverse`, `slice`, + * `tail`, `take`, `takeRight`, `takeRightWhile`, `takeWhile`, and `toArray` + * + * The chainable wrapper methods are: + * `after`, `ary`, `assign`, `assignIn`, `assignInWith`, `assignWith`, `at`, + * `before`, `bind`, `bindAll`, `bindKey`, `castArray`, `chain`, `chunk`, + * `commit`, `compact`, `concat`, `conforms`, `constant`, `countBy`, `create`, + * `curry`, `debounce`, `defaults`, `defaultsDeep`, `defer`, `delay`, + * `difference`, `differenceBy`, `differenceWith`, `drop`, `dropRight`, + * `dropRightWhile`, `dropWhile`, `extend`, `extendWith`, `fill`, `filter`, + * `flatMap`, `flatMapDeep`, `flatMapDepth`, `flatten`, `flattenDeep`, + * `flattenDepth`, `flip`, `flow`, `flowRight`, `fromPairs`, `functions`, + * `functionsIn`, `groupBy`, `initial`, `intersection`, `intersectionBy`, + * `intersectionWith`, `invert`, `invertBy`, `invokeMap`, `iteratee`, `keyBy`, + * `keys`, `keysIn`, `map`, `mapKeys`, `mapValues`, `matches`, `matchesProperty`, + * `memoize`, `merge`, `mergeWith`, `method`, `methodOf`, `mixin`, `negate`, + * `nthArg`, `omit`, `omitBy`, `once`, `orderBy`, `over`, `overArgs`, + * `overEvery`, `overSome`, `partial`, `partialRight`, `partition`, `pick`, + * `pickBy`, `plant`, `property`, `propertyOf`, `pull`, `pullAll`, `pullAllBy`, + * `pullAllWith`, `pullAt`, `push`, `range`, `rangeRight`, `rearg`, `reject`, + * `remove`, `rest`, `reverse`, `sampleSize`, `set`, `setWith`, `shuffle`, + * `slice`, `sort`, `sortBy`, `splice`, `spread`, `tail`, `take`, `takeRight`, + * `takeRightWhile`, `takeWhile`, `tap`, `throttle`, `thru`, `toArray`, + * `toPairs`, `toPairsIn`, `toPath`, `toPlainObject`, `transform`, `unary`, + * `union`, `unionBy`, `unionWith`, `uniq`, `uniqBy`, `uniqWith`, `unset`, + * `unshift`, `unzip`, `unzipWith`, `update`, `updateWith`, `values`, + * `valuesIn`, `without`, `wrap`, `xor`, `xorBy`, `xorWith`, `zip`, + * `zipObject`, `zipObjectDeep`, and `zipWith` + * + * The wrapper methods that are **not** chainable by default are: + * `add`, `attempt`, `camelCase`, `capitalize`, `ceil`, `clamp`, `clone`, + * `cloneDeep`, `cloneDeepWith`, `cloneWith`, `conformsTo`, `deburr`, + * `defaultTo`, `divide`, `each`, `eachRight`, `endsWith`, `eq`, `escape`, + * `escapeRegExp`, `every`, `find`, `findIndex`, `findKey`, `findLast`, + * `findLastIndex`, `findLastKey`, `first`, `floor`, `forEach`, `forEachRight`, + * `forIn`, `forInRight`, `forOwn`, `forOwnRight`, `get`, `gt`, `gte`, `has`, + * 
`hasIn`, `head`, `identity`, `includes`, `indexOf`, `inRange`, `invoke`, + * `isArguments`, `isArray`, `isArrayBuffer`, `isArrayLike`, `isArrayLikeObject`, + * `isBoolean`, `isBuffer`, `isDate`, `isElement`, `isEmpty`, `isEqual`, + * `isEqualWith`, `isError`, `isFinite`, `isFunction`, `isInteger`, `isLength`, + * `isMap`, `isMatch`, `isMatchWith`, `isNaN`, `isNative`, `isNil`, `isNull`, + * `isNumber`, `isObject`, `isObjectLike`, `isPlainObject`, `isRegExp`, + * `isSafeInteger`, `isSet`, `isString`, `isUndefined`, `isTypedArray`, + * `isWeakMap`, `isWeakSet`, `join`, `kebabCase`, `last`, `lastIndexOf`, + * `lowerCase`, `lowerFirst`, `lt`, `lte`, `max`, `maxBy`, `mean`, `meanBy`, + * `min`, `minBy`, `multiply`, `noConflict`, `noop`, `now`, `nth`, `pad`, + * `padEnd`, `padStart`, `parseInt`, `pop`, `random`, `reduce`, `reduceRight`, + * `repeat`, `result`, `round`, `runInContext`, `sample`, `shift`, `size`, + * `snakeCase`, `some`, `sortedIndex`, `sortedIndexBy`, `sortedLastIndex`, + * `sortedLastIndexBy`, `startCase`, `startsWith`, `stubArray`, `stubFalse`, + * `stubObject`, `stubString`, `stubTrue`, `subtract`, `sum`, `sumBy`, + * `template`, `times`, `toFinite`, `toInteger`, `toJSON`, `toLength`, + * `toLower`, `toNumber`, `toSafeInteger`, `toString`, `toUpper`, `trim`, + * `trimEnd`, `trimStart`, `truncate`, `unescape`, `uniqueId`, `upperCase`, + * `upperFirst`, `value`, and `words` + * + * @name _ + * @constructor + * @category Seq + * @param {*} value The value to wrap in a `lodash` instance. + * @returns {Object} Returns the new `lodash` wrapper instance. + * @example + * + * function square(n) { + * return n * n; + * } + * + * var wrapped = _([1, 2, 3]); + * + * // Returns an unwrapped value. + * wrapped.reduce(_.add); + * // => 6 + * + * // Returns a wrapped value. + * var squares = wrapped.map(square); + * + * _.isArray(squares); + * // => false + * + * _.isArray(squares.value()); + * // => true + */ + function lodash(value) { + if (isObjectLike(value) && !isArray(value) && !(value instanceof LazyWrapper)) { + if (value instanceof LodashWrapper) { + return value; + } + if (hasOwnProperty.call(value, '__wrapped__')) { + return wrapperClone(value); + } + } + return new LodashWrapper(value); + } + + /** + * The base implementation of `_.create` without support for assigning + * properties to the created object. + * + * @private + * @param {Object} proto The object to inherit from. + * @returns {Object} Returns the new object. + */ + var baseCreate = (function() { + function object() {} + return function(proto) { + if (!isObject(proto)) { + return {}; + } + if (objectCreate) { + return objectCreate(proto); + } + object.prototype = proto; + var result = new object; + object.prototype = undefined; + return result; + }; + }()); + + /** + * The function whose prototype chain sequence wrappers inherit from. + * + * @private + */ + function baseLodash() { + // No operation performed. + } + + /** + * The base constructor for creating `lodash` wrapper objects. + * + * @private + * @param {*} value The value to wrap. + * @param {boolean} [chainAll] Enable explicit method chain sequences. + */ + function LodashWrapper(value, chainAll) { + this.__wrapped__ = value; + this.__actions__ = []; + this.__chain__ = !!chainAll; + this.__index__ = 0; + this.__values__ = undefined; + } + + /** + * By default, the template delimiters used by lodash are like those in + * embedded Ruby (ERB) as well as ES2015 template strings. Change the + * following template settings to use alternative delimiters. 
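+ *
+ * For instance (editor's illustrative sketch), Mustache-style interpolation can
+ * be enabled with `_.templateSettings.interpolate = /{{([\s\S]+?)}}/g;`.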
+ * + * @static + * @memberOf _ + * @type {Object} + */ + lodash.templateSettings = { + + /** + * Used to detect `data` property values to be HTML-escaped. + * + * @memberOf _.templateSettings + * @type {RegExp} + */ + 'escape': reEscape, + + /** + * Used to detect code to be evaluated. + * + * @memberOf _.templateSettings + * @type {RegExp} + */ + 'evaluate': reEvaluate, + + /** + * Used to detect `data` property values to inject. + * + * @memberOf _.templateSettings + * @type {RegExp} + */ + 'interpolate': reInterpolate, + + /** + * Used to reference the data object in the template text. + * + * @memberOf _.templateSettings + * @type {string} + */ + 'variable': '', + + /** + * Used to import variables into the compiled template. + * + * @memberOf _.templateSettings + * @type {Object} + */ + 'imports': { + + /** + * A reference to the `lodash` function. + * + * @memberOf _.templateSettings.imports + * @type {Function} + */ + '_': lodash + } + }; + + // Ensure wrappers are instances of `baseLodash`. + lodash.prototype = baseLodash.prototype; + lodash.prototype.constructor = lodash; + + LodashWrapper.prototype = baseCreate(baseLodash.prototype); + LodashWrapper.prototype.constructor = LodashWrapper; + + /*------------------------------------------------------------------------*/ + + /** + * Creates a lazy wrapper object which wraps `value` to enable lazy evaluation. + * + * @private + * @constructor + * @param {*} value The value to wrap. + */ + function LazyWrapper(value) { + this.__wrapped__ = value; + this.__actions__ = []; + this.__dir__ = 1; + this.__filtered__ = false; + this.__iteratees__ = []; + this.__takeCount__ = MAX_ARRAY_LENGTH; + this.__views__ = []; + } + + /** + * Creates a clone of the lazy wrapper object. + * + * @private + * @name clone + * @memberOf LazyWrapper + * @returns {Object} Returns the cloned `LazyWrapper` object. + */ + function lazyClone() { + var result = new LazyWrapper(this.__wrapped__); + result.__actions__ = copyArray(this.__actions__); + result.__dir__ = this.__dir__; + result.__filtered__ = this.__filtered__; + result.__iteratees__ = copyArray(this.__iteratees__); + result.__takeCount__ = this.__takeCount__; + result.__views__ = copyArray(this.__views__); + return result; + } + + /** + * Reverses the direction of lazy iteration. + * + * @private + * @name reverse + * @memberOf LazyWrapper + * @returns {Object} Returns the new reversed `LazyWrapper` object. + */ + function lazyReverse() { + if (this.__filtered__) { + var result = new LazyWrapper(this); + result.__dir__ = -1; + result.__filtered__ = true; + } else { + result = this.clone(); + result.__dir__ *= -1; + } + return result; + } + + /** + * Extracts the unwrapped value from its lazy wrapper. + * + * @private + * @name value + * @memberOf LazyWrapper + * @returns {*} Returns the unwrapped value. + */ + function lazyValue() { + var array = this.__wrapped__.value(), + dir = this.__dir__, + isArr = isArray(array), + isRight = dir < 0, + arrLength = isArr ? array.length : 0, + view = getView(0, arrLength, this.__views__), + start = view.start, + end = view.end, + length = end - start, + index = isRight ? 
end : (start - 1), + iteratees = this.__iteratees__, + iterLength = iteratees.length, + resIndex = 0, + takeCount = nativeMin(length, this.__takeCount__); + + if (!isArr || (!isRight && arrLength == length && takeCount == length)) { + return baseWrapperValue(array, this.__actions__); + } + var result = []; + + outer: + while (length-- && resIndex < takeCount) { + index += dir; + + var iterIndex = -1, + value = array[index]; + + while (++iterIndex < iterLength) { + var data = iteratees[iterIndex], + iteratee = data.iteratee, + type = data.type, + computed = iteratee(value); + + if (type == LAZY_MAP_FLAG) { + value = computed; + } else if (!computed) { + if (type == LAZY_FILTER_FLAG) { + continue outer; + } else { + break outer; + } + } + } + result[resIndex++] = value; + } + return result; + } + + // Ensure `LazyWrapper` is an instance of `baseLodash`. + LazyWrapper.prototype = baseCreate(baseLodash.prototype); + LazyWrapper.prototype.constructor = LazyWrapper; + + /*------------------------------------------------------------------------*/ + + /** + * Creates a hash object. + * + * @private + * @constructor + * @param {Array} [entries] The key-value pairs to cache. + */ + function Hash(entries) { + var index = -1, + length = entries == null ? 0 : entries.length; + + this.clear(); + while (++index < length) { + var entry = entries[index]; + this.set(entry[0], entry[1]); + } + } + + /** + * Removes all key-value entries from the hash. + * + * @private + * @name clear + * @memberOf Hash + */ + function hashClear() { + this.__data__ = nativeCreate ? nativeCreate(null) : {}; + this.size = 0; + } + + /** + * Removes `key` and its value from the hash. + * + * @private + * @name delete + * @memberOf Hash + * @param {Object} hash The hash to modify. + * @param {string} key The key of the value to remove. + * @returns {boolean} Returns `true` if the entry was removed, else `false`. + */ + function hashDelete(key) { + var result = this.has(key) && delete this.__data__[key]; + this.size -= result ? 1 : 0; + return result; + } + + /** + * Gets the hash value for `key`. + * + * @private + * @name get + * @memberOf Hash + * @param {string} key The key of the value to get. + * @returns {*} Returns the entry value. + */ + function hashGet(key) { + var data = this.__data__; + if (nativeCreate) { + var result = data[key]; + return result === HASH_UNDEFINED ? undefined : result; + } + return hasOwnProperty.call(data, key) ? data[key] : undefined; + } + + /** + * Checks if a hash value for `key` exists. + * + * @private + * @name has + * @memberOf Hash + * @param {string} key The key of the entry to check. + * @returns {boolean} Returns `true` if an entry for `key` exists, else `false`. + */ + function hashHas(key) { + var data = this.__data__; + return nativeCreate ? (data[key] !== undefined) : hasOwnProperty.call(data, key); + } + + /** + * Sets the hash `key` to `value`. + * + * @private + * @name set + * @memberOf Hash + * @param {string} key The key of the value to set. + * @param {*} value The value to set. + * @returns {Object} Returns the hash instance. + */ + function hashSet(key, value) { + var data = this.__data__; + this.size += this.has(key) ? 0 : 1; + data[key] = (nativeCreate && value === undefined) ? HASH_UNDEFINED : value; + return this; + } + + // Add methods to `Hash`. 
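+ // Editor's illustrative sketch (not part of lodash): once the methods below are
+ // attached, a `Hash` cache behaves roughly like:
+ //   var cache = new Hash([['a', 1]]);
+ //   cache.get('a');          // => 1
+ //   cache.has('b');          // => false
+ //   cache.set('b', 2).size;  // => 2
+ //   cache['delete']('a');    // => true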
+ Hash.prototype.clear = hashClear; + Hash.prototype['delete'] = hashDelete; + Hash.prototype.get = hashGet; + Hash.prototype.has = hashHas; + Hash.prototype.set = hashSet; + + /*------------------------------------------------------------------------*/ + + /** + * Creates an list cache object. + * + * @private + * @constructor + * @param {Array} [entries] The key-value pairs to cache. + */ + function ListCache(entries) { + var index = -1, + length = entries == null ? 0 : entries.length; + + this.clear(); + while (++index < length) { + var entry = entries[index]; + this.set(entry[0], entry[1]); + } + } + + /** + * Removes all key-value entries from the list cache. + * + * @private + * @name clear + * @memberOf ListCache + */ + function listCacheClear() { + this.__data__ = []; + this.size = 0; + } + + /** + * Removes `key` and its value from the list cache. + * + * @private + * @name delete + * @memberOf ListCache + * @param {string} key The key of the value to remove. + * @returns {boolean} Returns `true` if the entry was removed, else `false`. + */ + function listCacheDelete(key) { + var data = this.__data__, + index = assocIndexOf(data, key); + + if (index < 0) { + return false; + } + var lastIndex = data.length - 1; + if (index == lastIndex) { + data.pop(); + } else { + splice.call(data, index, 1); + } + --this.size; + return true; + } + + /** + * Gets the list cache value for `key`. + * + * @private + * @name get + * @memberOf ListCache + * @param {string} key The key of the value to get. + * @returns {*} Returns the entry value. + */ + function listCacheGet(key) { + var data = this.__data__, + index = assocIndexOf(data, key); + + return index < 0 ? undefined : data[index][1]; + } + + /** + * Checks if a list cache value for `key` exists. + * + * @private + * @name has + * @memberOf ListCache + * @param {string} key The key of the entry to check. + * @returns {boolean} Returns `true` if an entry for `key` exists, else `false`. + */ + function listCacheHas(key) { + return assocIndexOf(this.__data__, key) > -1; + } + + /** + * Sets the list cache `key` to `value`. + * + * @private + * @name set + * @memberOf ListCache + * @param {string} key The key of the value to set. + * @param {*} value The value to set. + * @returns {Object} Returns the list cache instance. + */ + function listCacheSet(key, value) { + var data = this.__data__, + index = assocIndexOf(data, key); + + if (index < 0) { + ++this.size; + data.push([key, value]); + } else { + data[index][1] = value; + } + return this; + } + + // Add methods to `ListCache`. + ListCache.prototype.clear = listCacheClear; + ListCache.prototype['delete'] = listCacheDelete; + ListCache.prototype.get = listCacheGet; + ListCache.prototype.has = listCacheHas; + ListCache.prototype.set = listCacheSet; + + /*------------------------------------------------------------------------*/ + + /** + * Creates a map cache object to store key-value pairs. + * + * @private + * @constructor + * @param {Array} [entries] The key-value pairs to cache. + */ + function MapCache(entries) { + var index = -1, + length = entries == null ? 0 : entries.length; + + this.clear(); + while (++index < length) { + var entry = entries[index]; + this.set(entry[0], entry[1]); + } + } + + /** + * Removes all key-value entries from the map. 
+ * + * @private + * @name clear + * @memberOf MapCache + */ + function mapCacheClear() { + this.size = 0; + this.__data__ = { + 'hash': new Hash, + 'map': new (Map || ListCache), + 'string': new Hash + }; + } + + /** + * Removes `key` and its value from the map. + * + * @private + * @name delete + * @memberOf MapCache + * @param {string} key The key of the value to remove. + * @returns {boolean} Returns `true` if the entry was removed, else `false`. + */ + function mapCacheDelete(key) { + var result = getMapData(this, key)['delete'](key); + this.size -= result ? 1 : 0; + return result; + } + + /** + * Gets the map value for `key`. + * + * @private + * @name get + * @memberOf MapCache + * @param {string} key The key of the value to get. + * @returns {*} Returns the entry value. + */ + function mapCacheGet(key) { + return getMapData(this, key).get(key); + } + + /** + * Checks if a map value for `key` exists. + * + * @private + * @name has + * @memberOf MapCache + * @param {string} key The key of the entry to check. + * @returns {boolean} Returns `true` if an entry for `key` exists, else `false`. + */ + function mapCacheHas(key) { + return getMapData(this, key).has(key); + } + + /** + * Sets the map `key` to `value`. + * + * @private + * @name set + * @memberOf MapCache + * @param {string} key The key of the value to set. + * @param {*} value The value to set. + * @returns {Object} Returns the map cache instance. + */ + function mapCacheSet(key, value) { + var data = getMapData(this, key), + size = data.size; + + data.set(key, value); + this.size += data.size == size ? 0 : 1; + return this; + } + + // Add methods to `MapCache`. + MapCache.prototype.clear = mapCacheClear; + MapCache.prototype['delete'] = mapCacheDelete; + MapCache.prototype.get = mapCacheGet; + MapCache.prototype.has = mapCacheHas; + MapCache.prototype.set = mapCacheSet; + + /*------------------------------------------------------------------------*/ + + /** + * + * Creates an array cache object to store unique values. + * + * @private + * @constructor + * @param {Array} [values] The values to cache. + */ + function SetCache(values) { + var index = -1, + length = values == null ? 0 : values.length; + + this.__data__ = new MapCache; + while (++index < length) { + this.add(values[index]); + } + } + + /** + * Adds `value` to the array cache. + * + * @private + * @name add + * @memberOf SetCache + * @alias push + * @param {*} value The value to cache. + * @returns {Object} Returns the cache instance. + */ + function setCacheAdd(value) { + this.__data__.set(value, HASH_UNDEFINED); + return this; + } + + /** + * Checks if `value` is in the array cache. + * + * @private + * @name has + * @memberOf SetCache + * @param {*} value The value to search for. + * @returns {number} Returns `true` if `value` is found, else `false`. + */ + function setCacheHas(value) { + return this.__data__.has(value); + } + + // Add methods to `SetCache`. + SetCache.prototype.add = SetCache.prototype.push = setCacheAdd; + SetCache.prototype.has = setCacheHas; + + /*------------------------------------------------------------------------*/ + + /** + * Creates a stack cache object to store key-value pairs. + * + * @private + * @constructor + * @param {Array} [entries] The key-value pairs to cache. + */ + function Stack(entries) { + var data = this.__data__ = new ListCache(entries); + this.size = data.size; + } + + /** + * Removes all key-value entries from the stack. 
+ * + * @private + * @name clear + * @memberOf Stack + */ + function stackClear() { + this.__data__ = new ListCache; + this.size = 0; + } + + /** + * Removes `key` and its value from the stack. + * + * @private + * @name delete + * @memberOf Stack + * @param {string} key The key of the value to remove. + * @returns {boolean} Returns `true` if the entry was removed, else `false`. + */ + function stackDelete(key) { + var data = this.__data__, + result = data['delete'](key); + + this.size = data.size; + return result; + } + + /** + * Gets the stack value for `key`. + * + * @private + * @name get + * @memberOf Stack + * @param {string} key The key of the value to get. + * @returns {*} Returns the entry value. + */ + function stackGet(key) { + return this.__data__.get(key); + } + + /** + * Checks if a stack value for `key` exists. + * + * @private + * @name has + * @memberOf Stack + * @param {string} key The key of the entry to check. + * @returns {boolean} Returns `true` if an entry for `key` exists, else `false`. + */ + function stackHas(key) { + return this.__data__.has(key); + } + + /** + * Sets the stack `key` to `value`. + * + * @private + * @name set + * @memberOf Stack + * @param {string} key The key of the value to set. + * @param {*} value The value to set. + * @returns {Object} Returns the stack cache instance. + */ + function stackSet(key, value) { + var data = this.__data__; + if (data instanceof ListCache) { + var pairs = data.__data__; + if (!Map || (pairs.length < LARGE_ARRAY_SIZE - 1)) { + pairs.push([key, value]); + this.size = ++data.size; + return this; + } + data = this.__data__ = new MapCache(pairs); + } + data.set(key, value); + this.size = data.size; + return this; + } + + // Add methods to `Stack`. + Stack.prototype.clear = stackClear; + Stack.prototype['delete'] = stackDelete; + Stack.prototype.get = stackGet; + Stack.prototype.has = stackHas; + Stack.prototype.set = stackSet; + + /*------------------------------------------------------------------------*/ + + /** + * Creates an array of the enumerable property names of the array-like `value`. + * + * @private + * @param {*} value The value to query. + * @param {boolean} inherited Specify returning inherited property names. + * @returns {Array} Returns the array of property names. + */ + function arrayLikeKeys(value, inherited) { + var isArr = isArray(value), + isArg = !isArr && isArguments(value), + isBuff = !isArr && !isArg && isBuffer(value), + isType = !isArr && !isArg && !isBuff && isTypedArray(value), + skipIndexes = isArr || isArg || isBuff || isType, + result = skipIndexes ? baseTimes(value.length, String) : [], + length = result.length; + + for (var key in value) { + if ((inherited || hasOwnProperty.call(value, key)) && + !(skipIndexes && ( + // Safari 9 has enumerable `arguments.length` in strict mode. + key == 'length' || + // Node.js 0.10 has enumerable non-index properties on buffers. + (isBuff && (key == 'offset' || key == 'parent')) || + // PhantomJS 2 has enumerable non-index properties on typed arrays. + (isType && (key == 'buffer' || key == 'byteLength' || key == 'byteOffset')) || + // Skip index properties. + isIndex(key, length) + ))) { + result.push(key); + } + } + return result; + } + + /** + * A specialized version of `_.sample` for arrays. + * + * @private + * @param {Array} array The array to sample. + * @returns {*} Returns the random element. + */ + function arraySample(array) { + var length = array.length; + return length ? 
array[baseRandom(0, length - 1)] : undefined; + } + + /** + * A specialized version of `_.sampleSize` for arrays. + * + * @private + * @param {Array} array The array to sample. + * @param {number} n The number of elements to sample. + * @returns {Array} Returns the random elements. + */ + function arraySampleSize(array, n) { + return shuffleSelf(copyArray(array), baseClamp(n, 0, array.length)); + } + + /** + * A specialized version of `_.shuffle` for arrays. + * + * @private + * @param {Array} array The array to shuffle. + * @returns {Array} Returns the new shuffled array. + */ + function arrayShuffle(array) { + return shuffleSelf(copyArray(array)); + } + + /** + * This function is like `assignValue` except that it doesn't assign + * `undefined` values. + * + * @private + * @param {Object} object The object to modify. + * @param {string} key The key of the property to assign. + * @param {*} value The value to assign. + */ + function assignMergeValue(object, key, value) { + if ((value !== undefined && !eq(object[key], value)) || + (value === undefined && !(key in object))) { + baseAssignValue(object, key, value); + } + } + + /** + * Assigns `value` to `key` of `object` if the existing value is not equivalent + * using [`SameValueZero`](http://ecma-international.org/ecma-262/7.0/#sec-samevaluezero) + * for equality comparisons. + * + * @private + * @param {Object} object The object to modify. + * @param {string} key The key of the property to assign. + * @param {*} value The value to assign. + */ + function assignValue(object, key, value) { + var objValue = object[key]; + if (!(hasOwnProperty.call(object, key) && eq(objValue, value)) || + (value === undefined && !(key in object))) { + baseAssignValue(object, key, value); + } + } + + /** + * Gets the index at which the `key` is found in `array` of key-value pairs. + * + * @private + * @param {Array} array The array to inspect. + * @param {*} key The key to search for. + * @returns {number} Returns the index of the matched value, else `-1`. + */ + function assocIndexOf(array, key) { + var length = array.length; + while (length--) { + if (eq(array[length][0], key)) { + return length; + } + } + return -1; + } + + /** + * Aggregates elements of `collection` on `accumulator` with keys transformed + * by `iteratee` and values set by `setter`. + * + * @private + * @param {Array|Object} collection The collection to iterate over. + * @param {Function} setter The function to set `accumulator` values. + * @param {Function} iteratee The iteratee to transform keys. + * @param {Object} accumulator The initial aggregated object. + * @returns {Function} Returns `accumulator`. + */ + function baseAggregator(collection, setter, iteratee, accumulator) { + baseEach(collection, function(value, key, collection) { + setter(accumulator, value, iteratee(value), collection); + }); + return accumulator; + } + + /** + * The base implementation of `_.assign` without support for multiple sources + * or `customizer` functions. + * + * @private + * @param {Object} object The destination object. + * @param {Object} source The source object. + * @returns {Object} Returns `object`. + */ + function baseAssign(object, source) { + return object && copyObject(source, keys(source), object); + } + + /** + * The base implementation of `_.assignIn` without support for multiple sources + * or `customizer` functions. + * + * @private + * @param {Object} object The destination object. + * @param {Object} source The source object. + * @returns {Object} Returns `object`. 
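+ * @example
+ *
+ * // Editor's illustrative sketch (not in the original source); unlike
+ * // `baseAssign`, inherited enumerable keys of `source` are copied too:
+ * baseAssignIn({}, Object.create({ 'a': 1 }));
+ * // => { 'a': 1 }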
+ */ + function baseAssignIn(object, source) { + return object && copyObject(source, keysIn(source), object); + } + + /** + * The base implementation of `assignValue` and `assignMergeValue` without + * value checks. + * + * @private + * @param {Object} object The object to modify. + * @param {string} key The key of the property to assign. + * @param {*} value The value to assign. + */ + function baseAssignValue(object, key, value) { + if (key == '__proto__' && defineProperty) { + defineProperty(object, key, { + 'configurable': true, + 'enumerable': true, + 'value': value, + 'writable': true + }); + } else { + object[key] = value; + } + } + + /** + * The base implementation of `_.at` without support for individual paths. + * + * @private + * @param {Object} object The object to iterate over. + * @param {string[]} paths The property paths to pick. + * @returns {Array} Returns the picked elements. + */ + function baseAt(object, paths) { + var index = -1, + length = paths.length, + result = Array(length), + skip = object == null; + + while (++index < length) { + result[index] = skip ? undefined : get(object, paths[index]); + } + return result; + } + + /** + * The base implementation of `_.clamp` which doesn't coerce arguments. + * + * @private + * @param {number} number The number to clamp. + * @param {number} [lower] The lower bound. + * @param {number} upper The upper bound. + * @returns {number} Returns the clamped number. + */ + function baseClamp(number, lower, upper) { + if (number === number) { + if (upper !== undefined) { + number = number <= upper ? number : upper; + } + if (lower !== undefined) { + number = number >= lower ? number : lower; + } + } + return number; + } + + /** + * The base implementation of `_.clone` and `_.cloneDeep` which tracks + * traversed objects. + * + * @private + * @param {*} value The value to clone. + * @param {boolean} bitmask The bitmask flags. + * 1 - Deep clone + * 2 - Flatten inherited properties + * 4 - Clone symbols + * @param {Function} [customizer] The function to customize cloning. + * @param {string} [key] The key of `value`. + * @param {Object} [object] The parent object of `value`. + * @param {Object} [stack] Tracks traversed objects and their clone counterparts. + * @returns {*} Returns the cloned value. + */ + function baseClone(value, bitmask, customizer, key, object, stack) { + var result, + isDeep = bitmask & CLONE_DEEP_FLAG, + isFlat = bitmask & CLONE_FLAT_FLAG, + isFull = bitmask & CLONE_SYMBOLS_FLAG; + + if (customizer) { + result = object ? customizer(value, key, object, stack) : customizer(value); + } + if (result !== undefined) { + return result; + } + if (!isObject(value)) { + return value; + } + var isArr = isArray(value); + if (isArr) { + result = initCloneArray(value); + if (!isDeep) { + return copyArray(value, result); + } + } else { + var tag = getTag(value), + isFunc = tag == funcTag || tag == genTag; + + if (isBuffer(value)) { + return cloneBuffer(value, isDeep); + } + if (tag == objectTag || tag == argsTag || (isFunc && !object)) { + result = (isFlat || isFunc) ? {} : initCloneObject(value); + if (!isDeep) { + return isFlat + ? copySymbolsIn(value, baseAssignIn(result, value)) + : copySymbols(value, baseAssign(result, value)); + } + } else { + if (!cloneableTags[tag]) { + return object ? value : {}; + } + result = initCloneByTag(value, tag, isDeep); + } + } + // Check for circular references and return its corresponding clone. 
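+ // Editor's descriptive note: the `Stack` below maps each traversed source value
+ // to its in-progress clone, so a self-referential object such as `o.self = o`
+ // reuses the clone already created for `o` instead of recursing indefinitely.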
+ stack || (stack = new Stack); + var stacked = stack.get(value); + if (stacked) { + return stacked; + } + stack.set(value, result); + + if (isSet(value)) { + value.forEach(function(subValue) { + result.add(baseClone(subValue, bitmask, customizer, subValue, value, stack)); + }); + + return result; + } + + if (isMap(value)) { + value.forEach(function(subValue, key) { + result.set(key, baseClone(subValue, bitmask, customizer, key, value, stack)); + }); + + return result; + } + + var keysFunc = isFull + ? (isFlat ? getAllKeysIn : getAllKeys) + : (isFlat ? keysIn : keys); + + var props = isArr ? undefined : keysFunc(value); + arrayEach(props || value, function(subValue, key) { + if (props) { + key = subValue; + subValue = value[key]; + } + // Recursively populate clone (susceptible to call stack limits). + assignValue(result, key, baseClone(subValue, bitmask, customizer, key, value, stack)); + }); + return result; + } + + /** + * The base implementation of `_.conforms` which doesn't clone `source`. + * + * @private + * @param {Object} source The object of property predicates to conform to. + * @returns {Function} Returns the new spec function. + */ + function baseConforms(source) { + var props = keys(source); + return function(object) { + return baseConformsTo(object, source, props); + }; + } + + /** + * The base implementation of `_.conformsTo` which accepts `props` to check. + * + * @private + * @param {Object} object The object to inspect. + * @param {Object} source The object of property predicates to conform to. + * @returns {boolean} Returns `true` if `object` conforms, else `false`. + */ + function baseConformsTo(object, source, props) { + var length = props.length; + if (object == null) { + return !length; + } + object = Object(object); + while (length--) { + var key = props[length], + predicate = source[key], + value = object[key]; + + if ((value === undefined && !(key in object)) || !predicate(value)) { + return false; + } + } + return true; + } + + /** + * The base implementation of `_.delay` and `_.defer` which accepts `args` + * to provide to `func`. + * + * @private + * @param {Function} func The function to delay. + * @param {number} wait The number of milliseconds to delay invocation. + * @param {Array} args The arguments to provide to `func`. + * @returns {number|Object} Returns the timer id or timeout object. + */ + function baseDelay(func, wait, args) { + if (typeof func != 'function') { + throw new TypeError(FUNC_ERROR_TEXT); + } + return setTimeout(function() { func.apply(undefined, args); }, wait); + } + + /** + * The base implementation of methods like `_.difference` without support + * for excluding multiple arrays or iteratee shorthands. + * + * @private + * @param {Array} array The array to inspect. + * @param {Array} values The values to exclude. + * @param {Function} [iteratee] The iteratee invoked per element. + * @param {Function} [comparator] The comparator invoked per element. + * @returns {Array} Returns the new array of filtered values. 
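+ * @example
+ *
+ * // Editor's illustrative sketch (not in the original source):
+ * baseDifference([2, 1], [2, 3]);
+ * // => [1]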
+ */ + function baseDifference(array, values, iteratee, comparator) { + var index = -1, + includes = arrayIncludes, + isCommon = true, + length = array.length, + result = [], + valuesLength = values.length; + + if (!length) { + return result; + } + if (iteratee) { + values = arrayMap(values, baseUnary(iteratee)); + } + if (comparator) { + includes = arrayIncludesWith; + isCommon = false; + } + else if (values.length >= LARGE_ARRAY_SIZE) { + includes = cacheHas; + isCommon = false; + values = new SetCache(values); + } + outer: + while (++index < length) { + var value = array[index], + computed = iteratee == null ? value : iteratee(value); + + value = (comparator || value !== 0) ? value : 0; + if (isCommon && computed === computed) { + var valuesIndex = valuesLength; + while (valuesIndex--) { + if (values[valuesIndex] === computed) { + continue outer; + } + } + result.push(value); + } + else if (!includes(values, computed, comparator)) { + result.push(value); + } + } + return result; + } + + /** + * The base implementation of `_.forEach` without support for iteratee shorthands. + * + * @private + * @param {Array|Object} collection The collection to iterate over. + * @param {Function} iteratee The function invoked per iteration. + * @returns {Array|Object} Returns `collection`. + */ + var baseEach = createBaseEach(baseForOwn); + + /** + * The base implementation of `_.forEachRight` without support for iteratee shorthands. + * + * @private + * @param {Array|Object} collection The collection to iterate over. + * @param {Function} iteratee The function invoked per iteration. + * @returns {Array|Object} Returns `collection`. + */ + var baseEachRight = createBaseEach(baseForOwnRight, true); + + /** + * The base implementation of `_.every` without support for iteratee shorthands. + * + * @private + * @param {Array|Object} collection The collection to iterate over. + * @param {Function} predicate The function invoked per iteration. + * @returns {boolean} Returns `true` if all elements pass the predicate check, + * else `false` + */ + function baseEvery(collection, predicate) { + var result = true; + baseEach(collection, function(value, index, collection) { + result = !!predicate(value, index, collection); + return result; + }); + return result; + } + + /** + * The base implementation of methods like `_.max` and `_.min` which accepts a + * `comparator` to determine the extremum value. + * + * @private + * @param {Array} array The array to iterate over. + * @param {Function} iteratee The iteratee invoked per iteration. + * @param {Function} comparator The comparator used to compare values. + * @returns {*} Returns the extremum value. + */ + function baseExtremum(array, iteratee, comparator) { + var index = -1, + length = array.length; + + while (++index < length) { + var value = array[index], + current = iteratee(value); + + if (current != null && (computed === undefined + ? (current === current && !isSymbol(current)) + : comparator(current, computed) + )) { + var computed = current, + result = value; + } + } + return result; + } + + /** + * The base implementation of `_.fill` without an iteratee call guard. + * + * @private + * @param {Array} array The array to fill. + * @param {*} value The value to fill `array` with. + * @param {number} [start=0] The start position. + * @param {number} [end=array.length] The end position. + * @returns {Array} Returns `array`. 
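+ * @example
+ *
+ * // Editor's illustrative sketch (not in the original source):
+ * baseFill([4, 6, 8, 10], '*', 1, 3);
+ * // => [4, '*', '*', 10]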
+ */ + function baseFill(array, value, start, end) { + var length = array.length; + + start = toInteger(start); + if (start < 0) { + start = -start > length ? 0 : (length + start); + } + end = (end === undefined || end > length) ? length : toInteger(end); + if (end < 0) { + end += length; + } + end = start > end ? 0 : toLength(end); + while (start < end) { + array[start++] = value; + } + return array; + } + + /** + * The base implementation of `_.filter` without support for iteratee shorthands. + * + * @private + * @param {Array|Object} collection The collection to iterate over. + * @param {Function} predicate The function invoked per iteration. + * @returns {Array} Returns the new filtered array. + */ + function baseFilter(collection, predicate) { + var result = []; + baseEach(collection, function(value, index, collection) { + if (predicate(value, index, collection)) { + result.push(value); + } + }); + return result; + } + + /** + * The base implementation of `_.flatten` with support for restricting flattening. + * + * @private + * @param {Array} array The array to flatten. + * @param {number} depth The maximum recursion depth. + * @param {boolean} [predicate=isFlattenable] The function invoked per iteration. + * @param {boolean} [isStrict] Restrict to values that pass `predicate` checks. + * @param {Array} [result=[]] The initial result value. + * @returns {Array} Returns the new flattened array. + */ + function baseFlatten(array, depth, predicate, isStrict, result) { + var index = -1, + length = array.length; + + predicate || (predicate = isFlattenable); + result || (result = []); + + while (++index < length) { + var value = array[index]; + if (depth > 0 && predicate(value)) { + if (depth > 1) { + // Recursively flatten arrays (susceptible to call stack limits). + baseFlatten(value, depth - 1, predicate, isStrict, result); + } else { + arrayPush(result, value); + } + } else if (!isStrict) { + result[result.length] = value; + } + } + return result; + } + + /** + * The base implementation of `baseForOwn` which iterates over `object` + * properties returned by `keysFunc` and invokes `iteratee` for each property. + * Iteratee functions may exit iteration early by explicitly returning `false`. + * + * @private + * @param {Object} object The object to iterate over. + * @param {Function} iteratee The function invoked per iteration. + * @param {Function} keysFunc The function to get the keys of `object`. + * @returns {Object} Returns `object`. + */ + var baseFor = createBaseFor(); + + /** + * This function is like `baseFor` except that it iterates over properties + * in the opposite order. + * + * @private + * @param {Object} object The object to iterate over. + * @param {Function} iteratee The function invoked per iteration. + * @param {Function} keysFunc The function to get the keys of `object`. + * @returns {Object} Returns `object`. + */ + var baseForRight = createBaseFor(true); + + /** + * The base implementation of `_.forOwn` without support for iteratee shorthands. + * + * @private + * @param {Object} object The object to iterate over. + * @param {Function} iteratee The function invoked per iteration. + * @returns {Object} Returns `object`. + */ + function baseForOwn(object, iteratee) { + return object && baseFor(object, iteratee, keys); + } + + /** + * The base implementation of `_.forOwnRight` without support for iteratee shorthands. + * + * @private + * @param {Object} object The object to iterate over. + * @param {Function} iteratee The function invoked per iteration. 
+ * @returns {Object} Returns `object`. + */ + function baseForOwnRight(object, iteratee) { + return object && baseForRight(object, iteratee, keys); + } + + /** + * The base implementation of `_.functions` which creates an array of + * `object` function property names filtered from `props`. + * + * @private + * @param {Object} object The object to inspect. + * @param {Array} props The property names to filter. + * @returns {Array} Returns the function names. + */ + function baseFunctions(object, props) { + return arrayFilter(props, function(key) { + return isFunction(object[key]); + }); + } + + /** + * The base implementation of `_.get` without support for default values. + * + * @private + * @param {Object} object The object to query. + * @param {Array|string} path The path of the property to get. + * @returns {*} Returns the resolved value. + */ + function baseGet(object, path) { + path = castPath(path, object); + + var index = 0, + length = path.length; + + while (object != null && index < length) { + object = object[toKey(path[index++])]; + } + return (index && index == length) ? object : undefined; + } + + /** + * The base implementation of `getAllKeys` and `getAllKeysIn` which uses + * `keysFunc` and `symbolsFunc` to get the enumerable property names and + * symbols of `object`. + * + * @private + * @param {Object} object The object to query. + * @param {Function} keysFunc The function to get the keys of `object`. + * @param {Function} symbolsFunc The function to get the symbols of `object`. + * @returns {Array} Returns the array of property names and symbols. + */ + function baseGetAllKeys(object, keysFunc, symbolsFunc) { + var result = keysFunc(object); + return isArray(object) ? result : arrayPush(result, symbolsFunc(object)); + } + + /** + * The base implementation of `getTag` without fallbacks for buggy environments. + * + * @private + * @param {*} value The value to query. + * @returns {string} Returns the `toStringTag`. + */ + function baseGetTag(value) { + if (value == null) { + return value === undefined ? undefinedTag : nullTag; + } + return (symToStringTag && symToStringTag in Object(value)) + ? getRawTag(value) + : objectToString(value); + } + + /** + * The base implementation of `_.gt` which doesn't coerce arguments. + * + * @private + * @param {*} value The value to compare. + * @param {*} other The other value to compare. + * @returns {boolean} Returns `true` if `value` is greater than `other`, + * else `false`. + */ + function baseGt(value, other) { + return value > other; + } + + /** + * The base implementation of `_.has` without support for deep paths. + * + * @private + * @param {Object} [object] The object to query. + * @param {Array|string} key The key to check. + * @returns {boolean} Returns `true` if `key` exists, else `false`. + */ + function baseHas(object, key) { + return object != null && hasOwnProperty.call(object, key); + } + + /** + * The base implementation of `_.hasIn` without support for deep paths. + * + * @private + * @param {Object} [object] The object to query. + * @param {Array|string} key The key to check. + * @returns {boolean} Returns `true` if `key` exists, else `false`. + */ + function baseHasIn(object, key) { + return object != null && key in Object(object); + } + + /** + * The base implementation of `_.inRange` which doesn't coerce arguments. + * + * @private + * @param {number} number The number to check. + * @param {number} start The start of the range. + * @param {number} end The end of the range. 
+ * @returns {boolean} Returns `true` if `number` is in the range, else `false`. + */ + function baseInRange(number, start, end) { + return number >= nativeMin(start, end) && number < nativeMax(start, end); + } + + /** + * The base implementation of methods like `_.intersection`, without support + * for iteratee shorthands, that accepts an array of arrays to inspect. + * + * @private + * @param {Array} arrays The arrays to inspect. + * @param {Function} [iteratee] The iteratee invoked per element. + * @param {Function} [comparator] The comparator invoked per element. + * @returns {Array} Returns the new array of shared values. + */ + function baseIntersection(arrays, iteratee, comparator) { + var includes = comparator ? arrayIncludesWith : arrayIncludes, + length = arrays[0].length, + othLength = arrays.length, + othIndex = othLength, + caches = Array(othLength), + maxLength = Infinity, + result = []; + + while (othIndex--) { + var array = arrays[othIndex]; + if (othIndex && iteratee) { + array = arrayMap(array, baseUnary(iteratee)); + } + maxLength = nativeMin(array.length, maxLength); + caches[othIndex] = !comparator && (iteratee || (length >= 120 && array.length >= 120)) + ? new SetCache(othIndex && array) + : undefined; + } + array = arrays[0]; + + var index = -1, + seen = caches[0]; + + outer: + while (++index < length && result.length < maxLength) { + var value = array[index], + computed = iteratee ? iteratee(value) : value; + + value = (comparator || value !== 0) ? value : 0; + if (!(seen + ? cacheHas(seen, computed) + : includes(result, computed, comparator) + )) { + othIndex = othLength; + while (--othIndex) { + var cache = caches[othIndex]; + if (!(cache + ? cacheHas(cache, computed) + : includes(arrays[othIndex], computed, comparator)) + ) { + continue outer; + } + } + if (seen) { + seen.push(computed); + } + result.push(value); + } + } + return result; + } + + /** + * The base implementation of `_.invert` and `_.invertBy` which inverts + * `object` with values transformed by `iteratee` and set by `setter`. + * + * @private + * @param {Object} object The object to iterate over. + * @param {Function} setter The function to set `accumulator` values. + * @param {Function} iteratee The iteratee to transform values. + * @param {Object} accumulator The initial inverted object. + * @returns {Function} Returns `accumulator`. + */ + function baseInverter(object, setter, iteratee, accumulator) { + baseForOwn(object, function(value, key, object) { + setter(accumulator, iteratee(value), key, object); + }); + return accumulator; + } + + /** + * The base implementation of `_.invoke` without support for individual + * method arguments. + * + * @private + * @param {Object} object The object to query. + * @param {Array|string} path The path of the method to invoke. + * @param {Array} args The arguments to invoke the method with. + * @returns {*} Returns the result of the invoked method. + */ + function baseInvoke(object, path, args) { + path = castPath(path, object); + object = parent(object, path); + var func = object == null ? object : object[toKey(last(path))]; + return func == null ? undefined : apply(func, object, args); + } + + /** + * The base implementation of `_.isArguments`. + * + * @private + * @param {*} value The value to check. 
+ * @returns {boolean} Returns `true` if `value` is an `arguments` object, + */ + function baseIsArguments(value) { + return isObjectLike(value) && baseGetTag(value) == argsTag; + } + + /** + * The base implementation of `_.isArrayBuffer` without Node.js optimizations. + * + * @private + * @param {*} value The value to check. + * @returns {boolean} Returns `true` if `value` is an array buffer, else `false`. + */ + function baseIsArrayBuffer(value) { + return isObjectLike(value) && baseGetTag(value) == arrayBufferTag; + } + + /** + * The base implementation of `_.isDate` without Node.js optimizations. + * + * @private + * @param {*} value The value to check. + * @returns {boolean} Returns `true` if `value` is a date object, else `false`. + */ + function baseIsDate(value) { + return isObjectLike(value) && baseGetTag(value) == dateTag; + } + + /** + * The base implementation of `_.isEqual` which supports partial comparisons + * and tracks traversed objects. + * + * @private + * @param {*} value The value to compare. + * @param {*} other The other value to compare. + * @param {boolean} bitmask The bitmask flags. + * 1 - Unordered comparison + * 2 - Partial comparison + * @param {Function} [customizer] The function to customize comparisons. + * @param {Object} [stack] Tracks traversed `value` and `other` objects. + * @returns {boolean} Returns `true` if the values are equivalent, else `false`. + */ + function baseIsEqual(value, other, bitmask, customizer, stack) { + if (value === other) { + return true; + } + if (value == null || other == null || (!isObjectLike(value) && !isObjectLike(other))) { + return value !== value && other !== other; + } + return baseIsEqualDeep(value, other, bitmask, customizer, baseIsEqual, stack); + } + + /** + * A specialized version of `baseIsEqual` for arrays and objects which performs + * deep comparisons and tracks traversed objects enabling objects with circular + * references to be compared. + * + * @private + * @param {Object} object The object to compare. + * @param {Object} other The other object to compare. + * @param {number} bitmask The bitmask flags. See `baseIsEqual` for more details. + * @param {Function} customizer The function to customize comparisons. + * @param {Function} equalFunc The function to determine equivalents of values. + * @param {Object} [stack] Tracks traversed `object` and `other` objects. + * @returns {boolean} Returns `true` if the objects are equivalent, else `false`. + */ + function baseIsEqualDeep(object, other, bitmask, customizer, equalFunc, stack) { + var objIsArr = isArray(object), + othIsArr = isArray(other), + objTag = objIsArr ? arrayTag : getTag(object), + othTag = othIsArr ? arrayTag : getTag(other); + + objTag = objTag == argsTag ? objectTag : objTag; + othTag = othTag == argsTag ? objectTag : othTag; + + var objIsObj = objTag == objectTag, + othIsObj = othTag == objectTag, + isSameTag = objTag == othTag; + + if (isSameTag && isBuffer(object)) { + if (!isBuffer(other)) { + return false; + } + objIsArr = true; + objIsObj = false; + } + if (isSameTag && !objIsObj) { + stack || (stack = new Stack); + return (objIsArr || isTypedArray(object)) + ? 
equalArrays(object, other, bitmask, customizer, equalFunc, stack) + : equalByTag(object, other, objTag, bitmask, customizer, equalFunc, stack); + } + if (!(bitmask & COMPARE_PARTIAL_FLAG)) { + var objIsWrapped = objIsObj && hasOwnProperty.call(object, '__wrapped__'), + othIsWrapped = othIsObj && hasOwnProperty.call(other, '__wrapped__'); + + if (objIsWrapped || othIsWrapped) { + var objUnwrapped = objIsWrapped ? object.value() : object, + othUnwrapped = othIsWrapped ? other.value() : other; + + stack || (stack = new Stack); + return equalFunc(objUnwrapped, othUnwrapped, bitmask, customizer, stack); + } + } + if (!isSameTag) { + return false; + } + stack || (stack = new Stack); + return equalObjects(object, other, bitmask, customizer, equalFunc, stack); + } + + /** + * The base implementation of `_.isMap` without Node.js optimizations. + * + * @private + * @param {*} value The value to check. + * @returns {boolean} Returns `true` if `value` is a map, else `false`. + */ + function baseIsMap(value) { + return isObjectLike(value) && getTag(value) == mapTag; + } + + /** + * The base implementation of `_.isMatch` without support for iteratee shorthands. + * + * @private + * @param {Object} object The object to inspect. + * @param {Object} source The object of property values to match. + * @param {Array} matchData The property names, values, and compare flags to match. + * @param {Function} [customizer] The function to customize comparisons. + * @returns {boolean} Returns `true` if `object` is a match, else `false`. + */ + function baseIsMatch(object, source, matchData, customizer) { + var index = matchData.length, + length = index, + noCustomizer = !customizer; + + if (object == null) { + return !length; + } + object = Object(object); + while (index--) { + var data = matchData[index]; + if ((noCustomizer && data[2]) + ? data[1] !== object[data[0]] + : !(data[0] in object) + ) { + return false; + } + } + while (++index < length) { + data = matchData[index]; + var key = data[0], + objValue = object[key], + srcValue = data[1]; + + if (noCustomizer && data[2]) { + if (objValue === undefined && !(key in object)) { + return false; + } + } else { + var stack = new Stack; + if (customizer) { + var result = customizer(objValue, srcValue, key, object, source, stack); + } + if (!(result === undefined + ? baseIsEqual(srcValue, objValue, COMPARE_PARTIAL_FLAG | COMPARE_UNORDERED_FLAG, customizer, stack) + : result + )) { + return false; + } + } + } + return true; + } + + /** + * The base implementation of `_.isNative` without bad shim checks. + * + * @private + * @param {*} value The value to check. + * @returns {boolean} Returns `true` if `value` is a native function, + * else `false`. + */ + function baseIsNative(value) { + if (!isObject(value) || isMasked(value)) { + return false; + } + var pattern = isFunction(value) ? reIsNative : reIsHostCtor; + return pattern.test(toSource(value)); + } + + /** + * The base implementation of `_.isRegExp` without Node.js optimizations. + * + * @private + * @param {*} value The value to check. + * @returns {boolean} Returns `true` if `value` is a regexp, else `false`. + */ + function baseIsRegExp(value) { + return isObjectLike(value) && baseGetTag(value) == regexpTag; + } + + /** + * The base implementation of `_.isSet` without Node.js optimizations. + * + * @private + * @param {*} value The value to check. + * @returns {boolean} Returns `true` if `value` is a set, else `false`. 
+ */ + function baseIsSet(value) { + return isObjectLike(value) && getTag(value) == setTag; + } + + /** + * The base implementation of `_.isTypedArray` without Node.js optimizations. + * + * @private + * @param {*} value The value to check. + * @returns {boolean} Returns `true` if `value` is a typed array, else `false`. + */ + function baseIsTypedArray(value) { + return isObjectLike(value) && + isLength(value.length) && !!typedArrayTags[baseGetTag(value)]; + } + + /** + * The base implementation of `_.iteratee`. + * + * @private + * @param {*} [value=_.identity] The value to convert to an iteratee. + * @returns {Function} Returns the iteratee. + */ + function baseIteratee(value) { + // Don't store the `typeof` result in a variable to avoid a JIT bug in Safari 9. + // See https://bugs.webkit.org/show_bug.cgi?id=156034 for more details. + if (typeof value == 'function') { + return value; + } + if (value == null) { + return identity; + } + if (typeof value == 'object') { + return isArray(value) + ? baseMatchesProperty(value[0], value[1]) + : baseMatches(value); + } + return property(value); + } + + /** + * The base implementation of `_.keys` which doesn't treat sparse arrays as dense. + * + * @private + * @param {Object} object The object to query. + * @returns {Array} Returns the array of property names. + */ + function baseKeys(object) { + if (!isPrototype(object)) { + return nativeKeys(object); + } + var result = []; + for (var key in Object(object)) { + if (hasOwnProperty.call(object, key) && key != 'constructor') { + result.push(key); + } + } + return result; + } + + /** + * The base implementation of `_.keysIn` which doesn't treat sparse arrays as dense. + * + * @private + * @param {Object} object The object to query. + * @returns {Array} Returns the array of property names. + */ + function baseKeysIn(object) { + if (!isObject(object)) { + return nativeKeysIn(object); + } + var isProto = isPrototype(object), + result = []; + + for (var key in object) { + if (!(key == 'constructor' && (isProto || !hasOwnProperty.call(object, key)))) { + result.push(key); + } + } + return result; + } + + /** + * The base implementation of `_.lt` which doesn't coerce arguments. + * + * @private + * @param {*} value The value to compare. + * @param {*} other The other value to compare. + * @returns {boolean} Returns `true` if `value` is less than `other`, + * else `false`. + */ + function baseLt(value, other) { + return value < other; + } + + /** + * The base implementation of `_.map` without support for iteratee shorthands. + * + * @private + * @param {Array|Object} collection The collection to iterate over. + * @param {Function} iteratee The function invoked per iteration. + * @returns {Array} Returns the new mapped array. + */ + function baseMap(collection, iteratee) { + var index = -1, + result = isArrayLike(collection) ? Array(collection.length) : []; + + baseEach(collection, function(value, key, collection) { + result[++index] = iteratee(value, key, collection); + }); + return result; + } + + /** + * The base implementation of `_.matches` which doesn't clone `source`. + * + * @private + * @param {Object} source The object of property values to match. + * @returns {Function} Returns the new spec function. 
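+ * @example
+ *
+ * // Editor's illustrative sketch (not in the original source):
+ * var isUnpaid = baseMatches({ 'paid': false });
+ * isUnpaid({ 'user': 'fred', 'paid': false });
+ * // => true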
+ */ + function baseMatches(source) { + var matchData = getMatchData(source); + if (matchData.length == 1 && matchData[0][2]) { + return matchesStrictComparable(matchData[0][0], matchData[0][1]); + } + return function(object) { + return object === source || baseIsMatch(object, source, matchData); + }; + } + + /** + * The base implementation of `_.matchesProperty` which doesn't clone `srcValue`. + * + * @private + * @param {string} path The path of the property to get. + * @param {*} srcValue The value to match. + * @returns {Function} Returns the new spec function. + */ + function baseMatchesProperty(path, srcValue) { + if (isKey(path) && isStrictComparable(srcValue)) { + return matchesStrictComparable(toKey(path), srcValue); + } + return function(object) { + var objValue = get(object, path); + return (objValue === undefined && objValue === srcValue) + ? hasIn(object, path) + : baseIsEqual(srcValue, objValue, COMPARE_PARTIAL_FLAG | COMPARE_UNORDERED_FLAG); + }; + } + + /** + * The base implementation of `_.merge` without support for multiple sources. + * + * @private + * @param {Object} object The destination object. + * @param {Object} source The source object. + * @param {number} srcIndex The index of `source`. + * @param {Function} [customizer] The function to customize merged values. + * @param {Object} [stack] Tracks traversed source values and their merged + * counterparts. + */ + function baseMerge(object, source, srcIndex, customizer, stack) { + if (object === source) { + return; + } + baseFor(source, function(srcValue, key) { + if (isObject(srcValue)) { + stack || (stack = new Stack); + baseMergeDeep(object, source, key, srcIndex, baseMerge, customizer, stack); + } + else { + var newValue = customizer + ? customizer(safeGet(object, key), srcValue, (key + ''), object, source, stack) + : undefined; + + if (newValue === undefined) { + newValue = srcValue; + } + assignMergeValue(object, key, newValue); + } + }, keysIn); + } + + /** + * A specialized version of `baseMerge` for arrays and objects which performs + * deep merges and tracks traversed objects enabling objects with circular + * references to be merged. + * + * @private + * @param {Object} object The destination object. + * @param {Object} source The source object. + * @param {string} key The key of the value to merge. + * @param {number} srcIndex The index of `source`. + * @param {Function} mergeFunc The function to merge values. + * @param {Function} [customizer] The function to customize assigned values. + * @param {Object} [stack] Tracks traversed source values and their merged + * counterparts. + */ + function baseMergeDeep(object, source, key, srcIndex, mergeFunc, customizer, stack) { + var objValue = safeGet(object, key), + srcValue = safeGet(source, key), + stacked = stack.get(srcValue); + + if (stacked) { + assignMergeValue(object, key, stacked); + return; + } + var newValue = customizer + ? 
customizer(objValue, srcValue, (key + ''), object, source, stack) + : undefined; + + var isCommon = newValue === undefined; + + if (isCommon) { + var isArr = isArray(srcValue), + isBuff = !isArr && isBuffer(srcValue), + isTyped = !isArr && !isBuff && isTypedArray(srcValue); + + newValue = srcValue; + if (isArr || isBuff || isTyped) { + if (isArray(objValue)) { + newValue = objValue; + } + else if (isArrayLikeObject(objValue)) { + newValue = copyArray(objValue); + } + else if (isBuff) { + isCommon = false; + newValue = cloneBuffer(srcValue, true); + } + else if (isTyped) { + isCommon = false; + newValue = cloneTypedArray(srcValue, true); + } + else { + newValue = []; + } + } + else if (isPlainObject(srcValue) || isArguments(srcValue)) { + newValue = objValue; + if (isArguments(objValue)) { + newValue = toPlainObject(objValue); + } + else if (!isObject(objValue) || isFunction(objValue)) { + newValue = initCloneObject(srcValue); + } + } + else { + isCommon = false; + } + } + if (isCommon) { + // Recursively merge objects and arrays (susceptible to call stack limits). + stack.set(srcValue, newValue); + mergeFunc(newValue, srcValue, srcIndex, customizer, stack); + stack['delete'](srcValue); + } + assignMergeValue(object, key, newValue); + } + + /** + * The base implementation of `_.nth` which doesn't coerce arguments. + * + * @private + * @param {Array} array The array to query. + * @param {number} n The index of the element to return. + * @returns {*} Returns the nth element of `array`. + */ + function baseNth(array, n) { + var length = array.length; + if (!length) { + return; + } + n += n < 0 ? length : 0; + return isIndex(n, length) ? array[n] : undefined; + } + + /** + * The base implementation of `_.orderBy` without param guards. + * + * @private + * @param {Array|Object} collection The collection to iterate over. + * @param {Function[]|Object[]|string[]} iteratees The iteratees to sort by. + * @param {string[]} orders The sort orders of `iteratees`. + * @returns {Array} Returns the new sorted array. + */ + function baseOrderBy(collection, iteratees, orders) { + var index = -1; + iteratees = arrayMap(iteratees.length ? iteratees : [identity], baseUnary(getIteratee())); + + var result = baseMap(collection, function(value, key, collection) { + var criteria = arrayMap(iteratees, function(iteratee) { + return iteratee(value); + }); + return { 'criteria': criteria, 'index': ++index, 'value': value }; + }); + + return baseSortBy(result, function(object, other) { + return compareMultiple(object, other, orders); + }); + } + + /** + * The base implementation of `_.pick` without support for individual + * property identifiers. + * + * @private + * @param {Object} object The source object. + * @param {string[]} paths The property paths to pick. + * @returns {Object} Returns the new object. + */ + function basePick(object, paths) { + return basePickBy(object, paths, function(value, path) { + return hasIn(object, path); + }); + } + + /** + * The base implementation of `_.pickBy` without support for iteratee shorthands. + * + * @private + * @param {Object} object The source object. + * @param {string[]} paths The property paths to pick. + * @param {Function} predicate The function invoked per property. + * @returns {Object} Returns the new object. 
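+ * @example
+ *
+ * // Editor's illustrative sketch (not in the original source):
+ * basePickBy({ 'a': 1, 'b': 2 }, ['a', 'b'], function(value) { return value > 1; });
+ * // => { 'b': 2 }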
+ */ + function basePickBy(object, paths, predicate) { + var index = -1, + length = paths.length, + result = {}; + + while (++index < length) { + var path = paths[index], + value = baseGet(object, path); + + if (predicate(value, path)) { + baseSet(result, castPath(path, object), value); + } + } + return result; + } + + /** + * A specialized version of `baseProperty` which supports deep paths. + * + * @private + * @param {Array|string} path The path of the property to get. + * @returns {Function} Returns the new accessor function. + */ + function basePropertyDeep(path) { + return function(object) { + return baseGet(object, path); + }; + } + + /** + * The base implementation of `_.pullAllBy` without support for iteratee + * shorthands. + * + * @private + * @param {Array} array The array to modify. + * @param {Array} values The values to remove. + * @param {Function} [iteratee] The iteratee invoked per element. + * @param {Function} [comparator] The comparator invoked per element. + * @returns {Array} Returns `array`. + */ + function basePullAll(array, values, iteratee, comparator) { + var indexOf = comparator ? baseIndexOfWith : baseIndexOf, + index = -1, + length = values.length, + seen = array; + + if (array === values) { + values = copyArray(values); + } + if (iteratee) { + seen = arrayMap(array, baseUnary(iteratee)); + } + while (++index < length) { + var fromIndex = 0, + value = values[index], + computed = iteratee ? iteratee(value) : value; + + while ((fromIndex = indexOf(seen, computed, fromIndex, comparator)) > -1) { + if (seen !== array) { + splice.call(seen, fromIndex, 1); + } + splice.call(array, fromIndex, 1); + } + } + return array; + } + + /** + * The base implementation of `_.pullAt` without support for individual + * indexes or capturing the removed elements. + * + * @private + * @param {Array} array The array to modify. + * @param {number[]} indexes The indexes of elements to remove. + * @returns {Array} Returns `array`. + */ + function basePullAt(array, indexes) { + var length = array ? indexes.length : 0, + lastIndex = length - 1; + + while (length--) { + var index = indexes[length]; + if (length == lastIndex || index !== previous) { + var previous = index; + if (isIndex(index)) { + splice.call(array, index, 1); + } else { + baseUnset(array, index); + } + } + } + return array; + } + + /** + * The base implementation of `_.random` without support for returning + * floating-point numbers. + * + * @private + * @param {number} lower The lower bound. + * @param {number} upper The upper bound. + * @returns {number} Returns the random number. + */ + function baseRandom(lower, upper) { + return lower + nativeFloor(nativeRandom() * (upper - lower + 1)); + } + + /** + * The base implementation of `_.range` and `_.rangeRight` which doesn't + * coerce arguments. + * + * @private + * @param {number} start The start of the range. + * @param {number} end The end of the range. + * @param {number} step The value to increment or decrement by. + * @param {boolean} [fromRight] Specify iterating from right to left. + * @returns {Array} Returns the range of numbers. + */ + function baseRange(start, end, step, fromRight) { + var index = -1, + length = nativeMax(nativeCeil((end - start) / (step || 1)), 0), + result = Array(length); + + while (length--) { + result[fromRight ? length : ++index] = start; + start += step; + } + return result; + } + + /** + * The base implementation of `_.repeat` which doesn't coerce arguments. + * + * @private + * @param {string} string The string to repeat. 
+ * @param {number} n The number of times to repeat the string. + * @returns {string} Returns the repeated string. + */ + function baseRepeat(string, n) { + var result = ''; + if (!string || n < 1 || n > MAX_SAFE_INTEGER) { + return result; + } + // Leverage the exponentiation by squaring algorithm for a faster repeat. + // See https://en.wikipedia.org/wiki/Exponentiation_by_squaring for more details. + do { + if (n % 2) { + result += string; + } + n = nativeFloor(n / 2); + if (n) { + string += string; + } + } while (n); + + return result; + } + + /** + * The base implementation of `_.rest` which doesn't validate or coerce arguments. + * + * @private + * @param {Function} func The function to apply a rest parameter to. + * @param {number} [start=func.length-1] The start position of the rest parameter. + * @returns {Function} Returns the new function. + */ + function baseRest(func, start) { + return setToString(overRest(func, start, identity), func + ''); + } + + /** + * The base implementation of `_.sample`. + * + * @private + * @param {Array|Object} collection The collection to sample. + * @returns {*} Returns the random element. + */ + function baseSample(collection) { + return arraySample(values(collection)); + } + + /** + * The base implementation of `_.sampleSize` without param guards. + * + * @private + * @param {Array|Object} collection The collection to sample. + * @param {number} n The number of elements to sample. + * @returns {Array} Returns the random elements. + */ + function baseSampleSize(collection, n) { + var array = values(collection); + return shuffleSelf(array, baseClamp(n, 0, array.length)); + } + + /** + * The base implementation of `_.set`. + * + * @private + * @param {Object} object The object to modify. + * @param {Array|string} path The path of the property to set. + * @param {*} value The value to set. + * @param {Function} [customizer] The function to customize path creation. + * @returns {Object} Returns `object`. + */ + function baseSet(object, path, value, customizer) { + if (!isObject(object)) { + return object; + } + path = castPath(path, object); + + var index = -1, + length = path.length, + lastIndex = length - 1, + nested = object; + + while (nested != null && ++index < length) { + var key = toKey(path[index]), + newValue = value; + + if (index != lastIndex) { + var objValue = nested[key]; + newValue = customizer ? customizer(objValue, key, nested) : undefined; + if (newValue === undefined) { + newValue = isObject(objValue) + ? objValue + : (isIndex(path[index + 1]) ? [] : {}); + } + } + assignValue(nested, key, newValue); + nested = nested[key]; + } + return object; + } + + /** + * The base implementation of `setData` without support for hot loop shorting. + * + * @private + * @param {Function} func The function to associate metadata with. + * @param {*} data The metadata. + * @returns {Function} Returns `func`. + */ + var baseSetData = !metaMap ? identity : function(func, data) { + metaMap.set(func, data); + return func; + }; + + /** + * The base implementation of `setToString` without support for hot loop shorting. + * + * @private + * @param {Function} func The function to modify. + * @param {Function} string The `toString` result. + * @returns {Function} Returns `func`. + */ + var baseSetToString = !defineProperty ? identity : function(func, string) { + return defineProperty(func, 'toString', { + 'configurable': true, + 'enumerable': false, + 'value': constant(string), + 'writable': true + }); + }; + + /** + * The base implementation of `_.shuffle`. 
+ * + * @private + * @param {Array|Object} collection The collection to shuffle. + * @returns {Array} Returns the new shuffled array. + */ + function baseShuffle(collection) { + return shuffleSelf(values(collection)); + } + + /** + * The base implementation of `_.slice` without an iteratee call guard. + * + * @private + * @param {Array} array The array to slice. + * @param {number} [start=0] The start position. + * @param {number} [end=array.length] The end position. + * @returns {Array} Returns the slice of `array`. + */ + function baseSlice(array, start, end) { + var index = -1, + length = array.length; + + if (start < 0) { + start = -start > length ? 0 : (length + start); + } + end = end > length ? length : end; + if (end < 0) { + end += length; + } + length = start > end ? 0 : ((end - start) >>> 0); + start >>>= 0; + + var result = Array(length); + while (++index < length) { + result[index] = array[index + start]; + } + return result; + } + + /** + * The base implementation of `_.some` without support for iteratee shorthands. + * + * @private + * @param {Array|Object} collection The collection to iterate over. + * @param {Function} predicate The function invoked per iteration. + * @returns {boolean} Returns `true` if any element passes the predicate check, + * else `false`. + */ + function baseSome(collection, predicate) { + var result; + + baseEach(collection, function(value, index, collection) { + result = predicate(value, index, collection); + return !result; + }); + return !!result; + } + + /** + * The base implementation of `_.sortedIndex` and `_.sortedLastIndex` which + * performs a binary search of `array` to determine the index at which `value` + * should be inserted into `array` in order to maintain its sort order. + * + * @private + * @param {Array} array The sorted array to inspect. + * @param {*} value The value to evaluate. + * @param {boolean} [retHighest] Specify returning the highest qualified index. + * @returns {number} Returns the index at which `value` should be inserted + * into `array`. + */ + function baseSortedIndex(array, value, retHighest) { + var low = 0, + high = array == null ? low : array.length; + + if (typeof value == 'number' && value === value && high <= HALF_MAX_ARRAY_LENGTH) { + while (low < high) { + var mid = (low + high) >>> 1, + computed = array[mid]; + + if (computed !== null && !isSymbol(computed) && + (retHighest ? (computed <= value) : (computed < value))) { + low = mid + 1; + } else { + high = mid; + } + } + return high; + } + return baseSortedIndexBy(array, value, identity, retHighest); + } + + /** + * The base implementation of `_.sortedIndexBy` and `_.sortedLastIndexBy` + * which invokes `iteratee` for `value` and each element of `array` to compute + * their sort ranking. The iteratee is invoked with one argument; (value). + * + * @private + * @param {Array} array The sorted array to inspect. + * @param {*} value The value to evaluate. + * @param {Function} iteratee The iteratee invoked per element. + * @param {boolean} [retHighest] Specify returning the highest qualified index. + * @returns {number} Returns the index at which `value` should be inserted + * into `array`. + */ + function baseSortedIndexBy(array, value, iteratee, retHighest) { + value = iteratee(value); + + var low = 0, + high = array == null ? 
0 : array.length, + valIsNaN = value !== value, + valIsNull = value === null, + valIsSymbol = isSymbol(value), + valIsUndefined = value === undefined; + + while (low < high) { + var mid = nativeFloor((low + high) / 2), + computed = iteratee(array[mid]), + othIsDefined = computed !== undefined, + othIsNull = computed === null, + othIsReflexive = computed === computed, + othIsSymbol = isSymbol(computed); + + if (valIsNaN) { + var setLow = retHighest || othIsReflexive; + } else if (valIsUndefined) { + setLow = othIsReflexive && (retHighest || othIsDefined); + } else if (valIsNull) { + setLow = othIsReflexive && othIsDefined && (retHighest || !othIsNull); + } else if (valIsSymbol) { + setLow = othIsReflexive && othIsDefined && !othIsNull && (retHighest || !othIsSymbol); + } else if (othIsNull || othIsSymbol) { + setLow = false; + } else { + setLow = retHighest ? (computed <= value) : (computed < value); + } + if (setLow) { + low = mid + 1; + } else { + high = mid; + } + } + return nativeMin(high, MAX_ARRAY_INDEX); + } + + /** + * The base implementation of `_.sortedUniq` and `_.sortedUniqBy` without + * support for iteratee shorthands. + * + * @private + * @param {Array} array The array to inspect. + * @param {Function} [iteratee] The iteratee invoked per element. + * @returns {Array} Returns the new duplicate free array. + */ + function baseSortedUniq(array, iteratee) { + var index = -1, + length = array.length, + resIndex = 0, + result = []; + + while (++index < length) { + var value = array[index], + computed = iteratee ? iteratee(value) : value; + + if (!index || !eq(computed, seen)) { + var seen = computed; + result[resIndex++] = value === 0 ? 0 : value; + } + } + return result; + } + + /** + * The base implementation of `_.toNumber` which doesn't ensure correct + * conversions of binary, hexadecimal, or octal string values. + * + * @private + * @param {*} value The value to process. + * @returns {number} Returns the number. + */ + function baseToNumber(value) { + if (typeof value == 'number') { + return value; + } + if (isSymbol(value)) { + return NAN; + } + return +value; + } + + /** + * The base implementation of `_.toString` which doesn't convert nullish + * values to empty strings. + * + * @private + * @param {*} value The value to process. + * @returns {string} Returns the string. + */ + function baseToString(value) { + // Exit early for strings to avoid a performance hit in some environments. + if (typeof value == 'string') { + return value; + } + if (isArray(value)) { + // Recursively convert values (susceptible to call stack limits). + return arrayMap(value, baseToString) + ''; + } + if (isSymbol(value)) { + return symbolToString ? symbolToString.call(value) : ''; + } + var result = (value + ''); + return (result == '0' && (1 / value) == -INFINITY) ? '-0' : result; + } + + /** + * The base implementation of `_.uniqBy` without support for iteratee shorthands. + * + * @private + * @param {Array} array The array to inspect. + * @param {Function} [iteratee] The iteratee invoked per element. + * @param {Function} [comparator] The comparator invoked per element. + * @returns {Array} Returns the new duplicate free array. + */ + function baseUniq(array, iteratee, comparator) { + var index = -1, + includes = arrayIncludes, + length = array.length, + isCommon = true, + result = [], + seen = result; + + if (comparator) { + isCommon = false; + includes = arrayIncludesWith; + } + else if (length >= LARGE_ARRAY_SIZE) { + var set = iteratee ? 
null : createSet(array); + if (set) { + return setToArray(set); + } + isCommon = false; + includes = cacheHas; + seen = new SetCache; + } + else { + seen = iteratee ? [] : result; + } + outer: + while (++index < length) { + var value = array[index], + computed = iteratee ? iteratee(value) : value; + + value = (comparator || value !== 0) ? value : 0; + if (isCommon && computed === computed) { + var seenIndex = seen.length; + while (seenIndex--) { + if (seen[seenIndex] === computed) { + continue outer; + } + } + if (iteratee) { + seen.push(computed); + } + result.push(value); + } + else if (!includes(seen, computed, comparator)) { + if (seen !== result) { + seen.push(computed); + } + result.push(value); + } + } + return result; + } + + /** + * The base implementation of `_.unset`. + * + * @private + * @param {Object} object The object to modify. + * @param {Array|string} path The property path to unset. + * @returns {boolean} Returns `true` if the property is deleted, else `false`. + */ + function baseUnset(object, path) { + path = castPath(path, object); + object = parent(object, path); + return object == null || delete object[toKey(last(path))]; + } + + /** + * The base implementation of `_.update`. + * + * @private + * @param {Object} object The object to modify. + * @param {Array|string} path The path of the property to update. + * @param {Function} updater The function to produce the updated value. + * @param {Function} [customizer] The function to customize path creation. + * @returns {Object} Returns `object`. + */ + function baseUpdate(object, path, updater, customizer) { + return baseSet(object, path, updater(baseGet(object, path)), customizer); + } + + /** + * The base implementation of methods like `_.dropWhile` and `_.takeWhile` + * without support for iteratee shorthands. + * + * @private + * @param {Array} array The array to query. + * @param {Function} predicate The function invoked per iteration. + * @param {boolean} [isDrop] Specify dropping elements instead of taking them. + * @param {boolean} [fromRight] Specify iterating from right to left. + * @returns {Array} Returns the slice of `array`. + */ + function baseWhile(array, predicate, isDrop, fromRight) { + var length = array.length, + index = fromRight ? length : -1; + + while ((fromRight ? index-- : ++index < length) && + predicate(array[index], index, array)) {} + + return isDrop + ? baseSlice(array, (fromRight ? 0 : index), (fromRight ? index + 1 : length)) + : baseSlice(array, (fromRight ? index + 1 : 0), (fromRight ? length : index)); + } + + /** + * The base implementation of `wrapperValue` which returns the result of + * performing a sequence of actions on the unwrapped `value`, where each + * successive action is supplied the return value of the previous. + * + * @private + * @param {*} value The unwrapped value. + * @param {Array} actions Actions to perform to resolve the unwrapped value. + * @returns {*} Returns the resolved value. + */ + function baseWrapperValue(value, actions) { + var result = value; + if (result instanceof LazyWrapper) { + result = result.value(); + } + return arrayReduce(actions, function(result, action) { + return action.func.apply(action.thisArg, arrayPush([result], action.args)); + }, result); + } + + /** + * The base implementation of methods like `_.xor`, without support for + * iteratee shorthands, that accepts an array of arrays to inspect. + * + * @private + * @param {Array} arrays The arrays to inspect. + * @param {Function} [iteratee] The iteratee invoked per element. 
+ * @param {Function} [comparator] The comparator invoked per element. + * @returns {Array} Returns the new array of values. + */ + function baseXor(arrays, iteratee, comparator) { + var length = arrays.length; + if (length < 2) { + return length ? baseUniq(arrays[0]) : []; + } + var index = -1, + result = Array(length); + + while (++index < length) { + var array = arrays[index], + othIndex = -1; + + while (++othIndex < length) { + if (othIndex != index) { + result[index] = baseDifference(result[index] || array, arrays[othIndex], iteratee, comparator); + } + } + } + return baseUniq(baseFlatten(result, 1), iteratee, comparator); + } + + /** + * This base implementation of `_.zipObject` which assigns values using `assignFunc`. + * + * @private + * @param {Array} props The property identifiers. + * @param {Array} values The property values. + * @param {Function} assignFunc The function to assign values. + * @returns {Object} Returns the new object. + */ + function baseZipObject(props, values, assignFunc) { + var index = -1, + length = props.length, + valsLength = values.length, + result = {}; + + while (++index < length) { + var value = index < valsLength ? values[index] : undefined; + assignFunc(result, props[index], value); + } + return result; + } + + /** + * Casts `value` to an empty array if it's not an array like object. + * + * @private + * @param {*} value The value to inspect. + * @returns {Array|Object} Returns the cast array-like object. + */ + function castArrayLikeObject(value) { + return isArrayLikeObject(value) ? value : []; + } + + /** + * Casts `value` to `identity` if it's not a function. + * + * @private + * @param {*} value The value to inspect. + * @returns {Function} Returns cast function. + */ + function castFunction(value) { + return typeof value == 'function' ? value : identity; + } + + /** + * Casts `value` to a path array if it's not one. + * + * @private + * @param {*} value The value to inspect. + * @param {Object} [object] The object to query keys on. + * @returns {Array} Returns the cast property path array. + */ + function castPath(value, object) { + if (isArray(value)) { + return value; + } + return isKey(value, object) ? [value] : stringToPath(toString(value)); + } + + /** + * A `baseRest` alias which can be replaced with `identity` by module + * replacement plugins. + * + * @private + * @type {Function} + * @param {Function} func The function to apply a rest parameter to. + * @returns {Function} Returns the new function. + */ + var castRest = baseRest; + + /** + * Casts `array` to a slice if it's needed. + * + * @private + * @param {Array} array The array to inspect. + * @param {number} start The start position. + * @param {number} [end=array.length] The end position. + * @returns {Array} Returns the cast slice. + */ + function castSlice(array, start, end) { + var length = array.length; + end = end === undefined ? length : end; + return (!start && end >= length) ? array : baseSlice(array, start, end); + } + + /** + * A simple wrapper around the global [`clearTimeout`](https://mdn.io/clearTimeout). + * + * @private + * @param {number|Object} id The timer id or timeout object of the timer to clear. + */ + var clearTimeout = ctxClearTimeout || function(id) { + return root.clearTimeout(id); + }; + + /** + * Creates a clone of `buffer`. + * + * @private + * @param {Buffer} buffer The buffer to clone. + * @param {boolean} [isDeep] Specify a deep clone. + * @returns {Buffer} Returns the cloned buffer. 
+ */ + function cloneBuffer(buffer, isDeep) { + if (isDeep) { + return buffer.slice(); + } + var length = buffer.length, + result = allocUnsafe ? allocUnsafe(length) : new buffer.constructor(length); + + buffer.copy(result); + return result; + } + + /** + * Creates a clone of `arrayBuffer`. + * + * @private + * @param {ArrayBuffer} arrayBuffer The array buffer to clone. + * @returns {ArrayBuffer} Returns the cloned array buffer. + */ + function cloneArrayBuffer(arrayBuffer) { + var result = new arrayBuffer.constructor(arrayBuffer.byteLength); + new Uint8Array(result).set(new Uint8Array(arrayBuffer)); + return result; + } + + /** + * Creates a clone of `dataView`. + * + * @private + * @param {Object} dataView The data view to clone. + * @param {boolean} [isDeep] Specify a deep clone. + * @returns {Object} Returns the cloned data view. + */ + function cloneDataView(dataView, isDeep) { + var buffer = isDeep ? cloneArrayBuffer(dataView.buffer) : dataView.buffer; + return new dataView.constructor(buffer, dataView.byteOffset, dataView.byteLength); + } + + /** + * Creates a clone of `regexp`. + * + * @private + * @param {Object} regexp The regexp to clone. + * @returns {Object} Returns the cloned regexp. + */ + function cloneRegExp(regexp) { + var result = new regexp.constructor(regexp.source, reFlags.exec(regexp)); + result.lastIndex = regexp.lastIndex; + return result; + } + + /** + * Creates a clone of the `symbol` object. + * + * @private + * @param {Object} symbol The symbol object to clone. + * @returns {Object} Returns the cloned symbol object. + */ + function cloneSymbol(symbol) { + return symbolValueOf ? Object(symbolValueOf.call(symbol)) : {}; + } + + /** + * Creates a clone of `typedArray`. + * + * @private + * @param {Object} typedArray The typed array to clone. + * @param {boolean} [isDeep] Specify a deep clone. + * @returns {Object} Returns the cloned typed array. + */ + function cloneTypedArray(typedArray, isDeep) { + var buffer = isDeep ? cloneArrayBuffer(typedArray.buffer) : typedArray.buffer; + return new typedArray.constructor(buffer, typedArray.byteOffset, typedArray.length); + } + + /** + * Compares values to sort them in ascending order. + * + * @private + * @param {*} value The value to compare. + * @param {*} other The other value to compare. + * @returns {number} Returns the sort order indicator for `value`. + */ + function compareAscending(value, other) { + if (value !== other) { + var valIsDefined = value !== undefined, + valIsNull = value === null, + valIsReflexive = value === value, + valIsSymbol = isSymbol(value); + + var othIsDefined = other !== undefined, + othIsNull = other === null, + othIsReflexive = other === other, + othIsSymbol = isSymbol(other); + + if ((!othIsNull && !othIsSymbol && !valIsSymbol && value > other) || + (valIsSymbol && othIsDefined && othIsReflexive && !othIsNull && !othIsSymbol) || + (valIsNull && othIsDefined && othIsReflexive) || + (!valIsDefined && othIsReflexive) || + !valIsReflexive) { + return 1; + } + if ((!valIsNull && !valIsSymbol && !othIsSymbol && value < other) || + (othIsSymbol && valIsDefined && valIsReflexive && !valIsNull && !valIsSymbol) || + (othIsNull && valIsDefined && valIsReflexive) || + (!othIsDefined && valIsReflexive) || + !othIsReflexive) { + return -1; + } + } + return 0; + } + + /** + * Used by `_.orderBy` to compare multiple properties of a value to another + * and stable sort them. + * + * If `orders` is unspecified, all values are sorted in ascending order. 
Otherwise, + * specify an order of "desc" for descending or "asc" for ascending sort order + * of corresponding values. + * + * @private + * @param {Object} object The object to compare. + * @param {Object} other The other object to compare. + * @param {boolean[]|string[]} orders The order to sort by for each property. + * @returns {number} Returns the sort order indicator for `object`. + */ + function compareMultiple(object, other, orders) { + var index = -1, + objCriteria = object.criteria, + othCriteria = other.criteria, + length = objCriteria.length, + ordersLength = orders.length; + + while (++index < length) { + var result = compareAscending(objCriteria[index], othCriteria[index]); + if (result) { + if (index >= ordersLength) { + return result; + } + var order = orders[index]; + return result * (order == 'desc' ? -1 : 1); + } + } + // Fixes an `Array#sort` bug in the JS engine embedded in Adobe applications + // that causes it, under certain circumstances, to provide the same value for + // `object` and `other`. See https://github.com/jashkenas/underscore/pull/1247 + // for more details. + // + // This also ensures a stable sort in V8 and other engines. + // See https://bugs.chromium.org/p/v8/issues/detail?id=90 for more details. + return object.index - other.index; + } + + /** + * Creates an array that is the composition of partially applied arguments, + * placeholders, and provided arguments into a single array of arguments. + * + * @private + * @param {Array} args The provided arguments. + * @param {Array} partials The arguments to prepend to those provided. + * @param {Array} holders The `partials` placeholder indexes. + * @params {boolean} [isCurried] Specify composing for a curried function. + * @returns {Array} Returns the new array of composed arguments. + */ + function composeArgs(args, partials, holders, isCurried) { + var argsIndex = -1, + argsLength = args.length, + holdersLength = holders.length, + leftIndex = -1, + leftLength = partials.length, + rangeLength = nativeMax(argsLength - holdersLength, 0), + result = Array(leftLength + rangeLength), + isUncurried = !isCurried; + + while (++leftIndex < leftLength) { + result[leftIndex] = partials[leftIndex]; + } + while (++argsIndex < holdersLength) { + if (isUncurried || argsIndex < argsLength) { + result[holders[argsIndex]] = args[argsIndex]; + } + } + while (rangeLength--) { + result[leftIndex++] = args[argsIndex++]; + } + return result; + } + + /** + * This function is like `composeArgs` except that the arguments composition + * is tailored for `_.partialRight`. + * + * @private + * @param {Array} args The provided arguments. + * @param {Array} partials The arguments to append to those provided. + * @param {Array} holders The `partials` placeholder indexes. + * @params {boolean} [isCurried] Specify composing for a curried function. + * @returns {Array} Returns the new array of composed arguments. 
+ */ + function composeArgsRight(args, partials, holders, isCurried) { + var argsIndex = -1, + argsLength = args.length, + holdersIndex = -1, + holdersLength = holders.length, + rightIndex = -1, + rightLength = partials.length, + rangeLength = nativeMax(argsLength - holdersLength, 0), + result = Array(rangeLength + rightLength), + isUncurried = !isCurried; + + while (++argsIndex < rangeLength) { + result[argsIndex] = args[argsIndex]; + } + var offset = argsIndex; + while (++rightIndex < rightLength) { + result[offset + rightIndex] = partials[rightIndex]; + } + while (++holdersIndex < holdersLength) { + if (isUncurried || argsIndex < argsLength) { + result[offset + holders[holdersIndex]] = args[argsIndex++]; + } + } + return result; + } + + /** + * Copies the values of `source` to `array`. + * + * @private + * @param {Array} source The array to copy values from. + * @param {Array} [array=[]] The array to copy values to. + * @returns {Array} Returns `array`. + */ + function copyArray(source, array) { + var index = -1, + length = source.length; + + array || (array = Array(length)); + while (++index < length) { + array[index] = source[index]; + } + return array; + } + + /** + * Copies properties of `source` to `object`. + * + * @private + * @param {Object} source The object to copy properties from. + * @param {Array} props The property identifiers to copy. + * @param {Object} [object={}] The object to copy properties to. + * @param {Function} [customizer] The function to customize copied values. + * @returns {Object} Returns `object`. + */ + function copyObject(source, props, object, customizer) { + var isNew = !object; + object || (object = {}); + + var index = -1, + length = props.length; + + while (++index < length) { + var key = props[index]; + + var newValue = customizer + ? customizer(object[key], source[key], key, object, source) + : undefined; + + if (newValue === undefined) { + newValue = source[key]; + } + if (isNew) { + baseAssignValue(object, key, newValue); + } else { + assignValue(object, key, newValue); + } + } + return object; + } + + /** + * Copies own symbols of `source` to `object`. + * + * @private + * @param {Object} source The object to copy symbols from. + * @param {Object} [object={}] The object to copy symbols to. + * @returns {Object} Returns `object`. + */ + function copySymbols(source, object) { + return copyObject(source, getSymbols(source), object); + } + + /** + * Copies own and inherited symbols of `source` to `object`. + * + * @private + * @param {Object} source The object to copy symbols from. + * @param {Object} [object={}] The object to copy symbols to. + * @returns {Object} Returns `object`. + */ + function copySymbolsIn(source, object) { + return copyObject(source, getSymbolsIn(source), object); + } + + /** + * Creates a function like `_.groupBy`. + * + * @private + * @param {Function} setter The function to set accumulator values. + * @param {Function} [initializer] The accumulator object initializer. + * @returns {Function} Returns the new aggregator function. + */ + function createAggregator(setter, initializer) { + return function(collection, iteratee) { + var func = isArray(collection) ? arrayAggregator : baseAggregator, + accumulator = initializer ? initializer() : {}; + + return func(collection, setter, getIteratee(iteratee, 2), accumulator); + }; + } + + /** + * Creates a function like `_.assign`. + * + * @private + * @param {Function} assigner The function to assign values. + * @returns {Function} Returns the new assigner function. 
+ */ + function createAssigner(assigner) { + return baseRest(function(object, sources) { + var index = -1, + length = sources.length, + customizer = length > 1 ? sources[length - 1] : undefined, + guard = length > 2 ? sources[2] : undefined; + + customizer = (assigner.length > 3 && typeof customizer == 'function') + ? (length--, customizer) + : undefined; + + if (guard && isIterateeCall(sources[0], sources[1], guard)) { + customizer = length < 3 ? undefined : customizer; + length = 1; + } + object = Object(object); + while (++index < length) { + var source = sources[index]; + if (source) { + assigner(object, source, index, customizer); + } + } + return object; + }); + } + + /** + * Creates a `baseEach` or `baseEachRight` function. + * + * @private + * @param {Function} eachFunc The function to iterate over a collection. + * @param {boolean} [fromRight] Specify iterating from right to left. + * @returns {Function} Returns the new base function. + */ + function createBaseEach(eachFunc, fromRight) { + return function(collection, iteratee) { + if (collection == null) { + return collection; + } + if (!isArrayLike(collection)) { + return eachFunc(collection, iteratee); + } + var length = collection.length, + index = fromRight ? length : -1, + iterable = Object(collection); + + while ((fromRight ? index-- : ++index < length)) { + if (iteratee(iterable[index], index, iterable) === false) { + break; + } + } + return collection; + }; + } + + /** + * Creates a base function for methods like `_.forIn` and `_.forOwn`. + * + * @private + * @param {boolean} [fromRight] Specify iterating from right to left. + * @returns {Function} Returns the new base function. + */ + function createBaseFor(fromRight) { + return function(object, iteratee, keysFunc) { + var index = -1, + iterable = Object(object), + props = keysFunc(object), + length = props.length; + + while (length--) { + var key = props[fromRight ? length : ++index]; + if (iteratee(iterable[key], key, iterable) === false) { + break; + } + } + return object; + }; + } + + /** + * Creates a function that wraps `func` to invoke it with the optional `this` + * binding of `thisArg`. + * + * @private + * @param {Function} func The function to wrap. + * @param {number} bitmask The bitmask flags. See `createWrap` for more details. + * @param {*} [thisArg] The `this` binding of `func`. + * @returns {Function} Returns the new wrapped function. + */ + function createBind(func, bitmask, thisArg) { + var isBind = bitmask & WRAP_BIND_FLAG, + Ctor = createCtor(func); + + function wrapper() { + var fn = (this && this !== root && this instanceof wrapper) ? Ctor : func; + return fn.apply(isBind ? thisArg : this, arguments); + } + return wrapper; + } + + /** + * Creates a function like `_.lowerFirst`. + * + * @private + * @param {string} methodName The name of the `String` case method to use. + * @returns {Function} Returns the new case function. + */ + function createCaseFirst(methodName) { + return function(string) { + string = toString(string); + + var strSymbols = hasUnicode(string) + ? stringToArray(string) + : undefined; + + var chr = strSymbols + ? strSymbols[0] + : string.charAt(0); + + var trailing = strSymbols + ? castSlice(strSymbols, 1).join('') + : string.slice(1); + + return chr[methodName]() + trailing; + }; + } + + /** + * Creates a function like `_.camelCase`. + * + * @private + * @param {Function} callback The function to combine each word. + * @returns {Function} Returns the new compounder function. 
+ */ + function createCompounder(callback) { + return function(string) { + return arrayReduce(words(deburr(string).replace(reApos, '')), callback, ''); + }; + } + + /** + * Creates a function that produces an instance of `Ctor` regardless of + * whether it was invoked as part of a `new` expression or by `call` or `apply`. + * + * @private + * @param {Function} Ctor The constructor to wrap. + * @returns {Function} Returns the new wrapped function. + */ + function createCtor(Ctor) { + return function() { + // Use a `switch` statement to work with class constructors. See + // http://ecma-international.org/ecma-262/7.0/#sec-ecmascript-function-objects-call-thisargument-argumentslist + // for more details. + var args = arguments; + switch (args.length) { + case 0: return new Ctor; + case 1: return new Ctor(args[0]); + case 2: return new Ctor(args[0], args[1]); + case 3: return new Ctor(args[0], args[1], args[2]); + case 4: return new Ctor(args[0], args[1], args[2], args[3]); + case 5: return new Ctor(args[0], args[1], args[2], args[3], args[4]); + case 6: return new Ctor(args[0], args[1], args[2], args[3], args[4], args[5]); + case 7: return new Ctor(args[0], args[1], args[2], args[3], args[4], args[5], args[6]); + } + var thisBinding = baseCreate(Ctor.prototype), + result = Ctor.apply(thisBinding, args); + + // Mimic the constructor's `return` behavior. + // See https://es5.github.io/#x13.2.2 for more details. + return isObject(result) ? result : thisBinding; + }; + } + + /** + * Creates a function that wraps `func` to enable currying. + * + * @private + * @param {Function} func The function to wrap. + * @param {number} bitmask The bitmask flags. See `createWrap` for more details. + * @param {number} arity The arity of `func`. + * @returns {Function} Returns the new wrapped function. + */ + function createCurry(func, bitmask, arity) { + var Ctor = createCtor(func); + + function wrapper() { + var length = arguments.length, + args = Array(length), + index = length, + placeholder = getHolder(wrapper); + + while (index--) { + args[index] = arguments[index]; + } + var holders = (length < 3 && args[0] !== placeholder && args[length - 1] !== placeholder) + ? [] + : replaceHolders(args, placeholder); + + length -= holders.length; + if (length < arity) { + return createRecurry( + func, bitmask, createHybrid, wrapper.placeholder, undefined, + args, holders, undefined, undefined, arity - length); + } + var fn = (this && this !== root && this instanceof wrapper) ? Ctor : func; + return apply(fn, this, args); + } + return wrapper; + } + + /** + * Creates a `_.find` or `_.findLast` function. + * + * @private + * @param {Function} findIndexFunc The function to find the collection index. + * @returns {Function} Returns the new find function. + */ + function createFind(findIndexFunc) { + return function(collection, predicate, fromIndex) { + var iterable = Object(collection); + if (!isArrayLike(collection)) { + var iteratee = getIteratee(predicate, 3); + collection = keys(collection); + predicate = function(key) { return iteratee(iterable[key], key, iterable); }; + } + var index = findIndexFunc(collection, predicate, fromIndex); + return index > -1 ? iterable[iteratee ? collection[index] : index] : undefined; + }; + } + + /** + * Creates a `_.flow` or `_.flowRight` function. + * + * @private + * @param {boolean} [fromRight] Specify iterating from right to left. + * @returns {Function} Returns the new flow function. 
+ */ + function createFlow(fromRight) { + return flatRest(function(funcs) { + var length = funcs.length, + index = length, + prereq = LodashWrapper.prototype.thru; + + if (fromRight) { + funcs.reverse(); + } + while (index--) { + var func = funcs[index]; + if (typeof func != 'function') { + throw new TypeError(FUNC_ERROR_TEXT); + } + if (prereq && !wrapper && getFuncName(func) == 'wrapper') { + var wrapper = new LodashWrapper([], true); + } + } + index = wrapper ? index : length; + while (++index < length) { + func = funcs[index]; + + var funcName = getFuncName(func), + data = funcName == 'wrapper' ? getData(func) : undefined; + + if (data && isLaziable(data[0]) && + data[1] == (WRAP_ARY_FLAG | WRAP_CURRY_FLAG | WRAP_PARTIAL_FLAG | WRAP_REARG_FLAG) && + !data[4].length && data[9] == 1 + ) { + wrapper = wrapper[getFuncName(data[0])].apply(wrapper, data[3]); + } else { + wrapper = (func.length == 1 && isLaziable(func)) + ? wrapper[funcName]() + : wrapper.thru(func); + } + } + return function() { + var args = arguments, + value = args[0]; + + if (wrapper && args.length == 1 && isArray(value)) { + return wrapper.plant(value).value(); + } + var index = 0, + result = length ? funcs[index].apply(this, args) : value; + + while (++index < length) { + result = funcs[index].call(this, result); + } + return result; + }; + }); + } + + /** + * Creates a function that wraps `func` to invoke it with optional `this` + * binding of `thisArg`, partial application, and currying. + * + * @private + * @param {Function|string} func The function or method name to wrap. + * @param {number} bitmask The bitmask flags. See `createWrap` for more details. + * @param {*} [thisArg] The `this` binding of `func`. + * @param {Array} [partials] The arguments to prepend to those provided to + * the new function. + * @param {Array} [holders] The `partials` placeholder indexes. + * @param {Array} [partialsRight] The arguments to append to those provided + * to the new function. + * @param {Array} [holdersRight] The `partialsRight` placeholder indexes. + * @param {Array} [argPos] The argument positions of the new function. + * @param {number} [ary] The arity cap of `func`. + * @param {number} [arity] The arity of `func`. + * @returns {Function} Returns the new wrapped function. + */ + function createHybrid(func, bitmask, thisArg, partials, holders, partialsRight, holdersRight, argPos, ary, arity) { + var isAry = bitmask & WRAP_ARY_FLAG, + isBind = bitmask & WRAP_BIND_FLAG, + isBindKey = bitmask & WRAP_BIND_KEY_FLAG, + isCurried = bitmask & (WRAP_CURRY_FLAG | WRAP_CURRY_RIGHT_FLAG), + isFlip = bitmask & WRAP_FLIP_FLAG, + Ctor = isBindKey ? undefined : createCtor(func); + + function wrapper() { + var length = arguments.length, + args = Array(length), + index = length; + + while (index--) { + args[index] = arguments[index]; + } + if (isCurried) { + var placeholder = getHolder(wrapper), + holdersCount = countHolders(args, placeholder); + } + if (partials) { + args = composeArgs(args, partials, holders, isCurried); + } + if (partialsRight) { + args = composeArgsRight(args, partialsRight, holdersRight, isCurried); + } + length -= holdersCount; + if (isCurried && length < arity) { + var newHolders = replaceHolders(args, placeholder); + return createRecurry( + func, bitmask, createHybrid, wrapper.placeholder, thisArg, + args, newHolders, argPos, ary, arity - length + ); + } + var thisBinding = isBind ? thisArg : this, + fn = isBindKey ? 
thisBinding[func] : func; + + length = args.length; + if (argPos) { + args = reorder(args, argPos); + } else if (isFlip && length > 1) { + args.reverse(); + } + if (isAry && ary < length) { + args.length = ary; + } + if (this && this !== root && this instanceof wrapper) { + fn = Ctor || createCtor(fn); + } + return fn.apply(thisBinding, args); + } + return wrapper; + } + + /** + * Creates a function like `_.invertBy`. + * + * @private + * @param {Function} setter The function to set accumulator values. + * @param {Function} toIteratee The function to resolve iteratees. + * @returns {Function} Returns the new inverter function. + */ + function createInverter(setter, toIteratee) { + return function(object, iteratee) { + return baseInverter(object, setter, toIteratee(iteratee), {}); + }; + } + + /** + * Creates a function that performs a mathematical operation on two values. + * + * @private + * @param {Function} operator The function to perform the operation. + * @param {number} [defaultValue] The value used for `undefined` arguments. + * @returns {Function} Returns the new mathematical operation function. + */ + function createMathOperation(operator, defaultValue) { + return function(value, other) { + var result; + if (value === undefined && other === undefined) { + return defaultValue; + } + if (value !== undefined) { + result = value; + } + if (other !== undefined) { + if (result === undefined) { + return other; + } + if (typeof value == 'string' || typeof other == 'string') { + value = baseToString(value); + other = baseToString(other); + } else { + value = baseToNumber(value); + other = baseToNumber(other); + } + result = operator(value, other); + } + return result; + }; + } + + /** + * Creates a function like `_.over`. + * + * @private + * @param {Function} arrayFunc The function to iterate over iteratees. + * @returns {Function} Returns the new over function. + */ + function createOver(arrayFunc) { + return flatRest(function(iteratees) { + iteratees = arrayMap(iteratees, baseUnary(getIteratee())); + return baseRest(function(args) { + var thisArg = this; + return arrayFunc(iteratees, function(iteratee) { + return apply(iteratee, thisArg, args); + }); + }); + }); + } + + /** + * Creates the padding for `string` based on `length`. The `chars` string + * is truncated if the number of characters exceeds `length`. + * + * @private + * @param {number} length The padding length. + * @param {string} [chars=' '] The string used as padding. + * @returns {string} Returns the padding for `string`. + */ + function createPadding(length, chars) { + chars = chars === undefined ? ' ' : baseToString(chars); + + var charsLength = chars.length; + if (charsLength < 2) { + return charsLength ? baseRepeat(chars, length) : chars; + } + var result = baseRepeat(chars, nativeCeil(length / stringSize(chars))); + return hasUnicode(chars) + ? castSlice(stringToArray(result), 0, length).join('') + : result.slice(0, length); + } + + /** + * Creates a function that wraps `func` to invoke it with the `this` binding + * of `thisArg` and `partials` prepended to the arguments it receives. + * + * @private + * @param {Function} func The function to wrap. + * @param {number} bitmask The bitmask flags. See `createWrap` for more details. + * @param {*} thisArg The `this` binding of `func`. + * @param {Array} partials The arguments to prepend to those provided to + * the new function. + * @returns {Function} Returns the new wrapped function. 
+ */ + function createPartial(func, bitmask, thisArg, partials) { + var isBind = bitmask & WRAP_BIND_FLAG, + Ctor = createCtor(func); + + function wrapper() { + var argsIndex = -1, + argsLength = arguments.length, + leftIndex = -1, + leftLength = partials.length, + args = Array(leftLength + argsLength), + fn = (this && this !== root && this instanceof wrapper) ? Ctor : func; + + while (++leftIndex < leftLength) { + args[leftIndex] = partials[leftIndex]; + } + while (argsLength--) { + args[leftIndex++] = arguments[++argsIndex]; + } + return apply(fn, isBind ? thisArg : this, args); + } + return wrapper; + } + + /** + * Creates a `_.range` or `_.rangeRight` function. + * + * @private + * @param {boolean} [fromRight] Specify iterating from right to left. + * @returns {Function} Returns the new range function. + */ + function createRange(fromRight) { + return function(start, end, step) { + if (step && typeof step != 'number' && isIterateeCall(start, end, step)) { + end = step = undefined; + } + // Ensure the sign of `-0` is preserved. + start = toFinite(start); + if (end === undefined) { + end = start; + start = 0; + } else { + end = toFinite(end); + } + step = step === undefined ? (start < end ? 1 : -1) : toFinite(step); + return baseRange(start, end, step, fromRight); + }; + } + + /** + * Creates a function that performs a relational operation on two values. + * + * @private + * @param {Function} operator The function to perform the operation. + * @returns {Function} Returns the new relational operation function. + */ + function createRelationalOperation(operator) { + return function(value, other) { + if (!(typeof value == 'string' && typeof other == 'string')) { + value = toNumber(value); + other = toNumber(other); + } + return operator(value, other); + }; + } + + /** + * Creates a function that wraps `func` to continue currying. + * + * @private + * @param {Function} func The function to wrap. + * @param {number} bitmask The bitmask flags. See `createWrap` for more details. + * @param {Function} wrapFunc The function to create the `func` wrapper. + * @param {*} placeholder The placeholder value. + * @param {*} [thisArg] The `this` binding of `func`. + * @param {Array} [partials] The arguments to prepend to those provided to + * the new function. + * @param {Array} [holders] The `partials` placeholder indexes. + * @param {Array} [argPos] The argument positions of the new function. + * @param {number} [ary] The arity cap of `func`. + * @param {number} [arity] The arity of `func`. + * @returns {Function} Returns the new wrapped function. + */ + function createRecurry(func, bitmask, wrapFunc, placeholder, thisArg, partials, holders, argPos, ary, arity) { + var isCurry = bitmask & WRAP_CURRY_FLAG, + newHolders = isCurry ? holders : undefined, + newHoldersRight = isCurry ? undefined : holders, + newPartials = isCurry ? partials : undefined, + newPartialsRight = isCurry ? undefined : partials; + + bitmask |= (isCurry ? WRAP_PARTIAL_FLAG : WRAP_PARTIAL_RIGHT_FLAG); + bitmask &= ~(isCurry ? 
WRAP_PARTIAL_RIGHT_FLAG : WRAP_PARTIAL_FLAG); + + if (!(bitmask & WRAP_CURRY_BOUND_FLAG)) { + bitmask &= ~(WRAP_BIND_FLAG | WRAP_BIND_KEY_FLAG); + } + var newData = [ + func, bitmask, thisArg, newPartials, newHolders, newPartialsRight, + newHoldersRight, argPos, ary, arity + ]; + + var result = wrapFunc.apply(undefined, newData); + if (isLaziable(func)) { + setData(result, newData); + } + result.placeholder = placeholder; + return setWrapToString(result, func, bitmask); + } + + /** + * Creates a function like `_.round`. + * + * @private + * @param {string} methodName The name of the `Math` method to use when rounding. + * @returns {Function} Returns the new round function. + */ + function createRound(methodName) { + var func = Math[methodName]; + return function(number, precision) { + number = toNumber(number); + precision = precision == null ? 0 : nativeMin(toInteger(precision), 292); + if (precision) { + // Shift with exponential notation to avoid floating-point issues. + // See [MDN](https://mdn.io/round#Examples) for more details. + var pair = (toString(number) + 'e').split('e'), + value = func(pair[0] + 'e' + (+pair[1] + precision)); + + pair = (toString(value) + 'e').split('e'); + return +(pair[0] + 'e' + (+pair[1] - precision)); + } + return func(number); + }; + } + + /** + * Creates a set object of `values`. + * + * @private + * @param {Array} values The values to add to the set. + * @returns {Object} Returns the new set. + */ + var createSet = !(Set && (1 / setToArray(new Set([,-0]))[1]) == INFINITY) ? noop : function(values) { + return new Set(values); + }; + + /** + * Creates a `_.toPairs` or `_.toPairsIn` function. + * + * @private + * @param {Function} keysFunc The function to get the keys of a given object. + * @returns {Function} Returns the new pairs function. + */ + function createToPairs(keysFunc) { + return function(object) { + var tag = getTag(object); + if (tag == mapTag) { + return mapToArray(object); + } + if (tag == setTag) { + return setToPairs(object); + } + return baseToPairs(object, keysFunc(object)); + }; + } + + /** + * Creates a function that either curries or invokes `func` with optional + * `this` binding and partially applied arguments. + * + * @private + * @param {Function|string} func The function or method name to wrap. + * @param {number} bitmask The bitmask flags. + * 1 - `_.bind` + * 2 - `_.bindKey` + * 4 - `_.curry` or `_.curryRight` of a bound function + * 8 - `_.curry` + * 16 - `_.curryRight` + * 32 - `_.partial` + * 64 - `_.partialRight` + * 128 - `_.rearg` + * 256 - `_.ary` + * 512 - `_.flip` + * @param {*} [thisArg] The `this` binding of `func`. + * @param {Array} [partials] The arguments to be partially applied. + * @param {Array} [holders] The `partials` placeholder indexes. + * @param {Array} [argPos] The argument positions of the new function. + * @param {number} [ary] The arity cap of `func`. + * @param {number} [arity] The arity of `func`. + * @returns {Function} Returns the new wrapped function. + */ + function createWrap(func, bitmask, thisArg, partials, holders, argPos, ary, arity) { + var isBindKey = bitmask & WRAP_BIND_KEY_FLAG; + if (!isBindKey && typeof func != 'function') { + throw new TypeError(FUNC_ERROR_TEXT); + } + var length = partials ? partials.length : 0; + if (!length) { + bitmask &= ~(WRAP_PARTIAL_FLAG | WRAP_PARTIAL_RIGHT_FLAG); + partials = holders = undefined; + } + ary = ary === undefined ? ary : nativeMax(toInteger(ary), 0); + arity = arity === undefined ? arity : toInteger(arity); + length -= holders ? 
holders.length : 0; + + if (bitmask & WRAP_PARTIAL_RIGHT_FLAG) { + var partialsRight = partials, + holdersRight = holders; + + partials = holders = undefined; + } + var data = isBindKey ? undefined : getData(func); + + var newData = [ + func, bitmask, thisArg, partials, holders, partialsRight, holdersRight, + argPos, ary, arity + ]; + + if (data) { + mergeData(newData, data); + } + func = newData[0]; + bitmask = newData[1]; + thisArg = newData[2]; + partials = newData[3]; + holders = newData[4]; + arity = newData[9] = newData[9] === undefined + ? (isBindKey ? 0 : func.length) + : nativeMax(newData[9] - length, 0); + + if (!arity && bitmask & (WRAP_CURRY_FLAG | WRAP_CURRY_RIGHT_FLAG)) { + bitmask &= ~(WRAP_CURRY_FLAG | WRAP_CURRY_RIGHT_FLAG); + } + if (!bitmask || bitmask == WRAP_BIND_FLAG) { + var result = createBind(func, bitmask, thisArg); + } else if (bitmask == WRAP_CURRY_FLAG || bitmask == WRAP_CURRY_RIGHT_FLAG) { + result = createCurry(func, bitmask, arity); + } else if ((bitmask == WRAP_PARTIAL_FLAG || bitmask == (WRAP_BIND_FLAG | WRAP_PARTIAL_FLAG)) && !holders.length) { + result = createPartial(func, bitmask, thisArg, partials); + } else { + result = createHybrid.apply(undefined, newData); + } + var setter = data ? baseSetData : setData; + return setWrapToString(setter(result, newData), func, bitmask); + } + + /** + * Used by `_.defaults` to customize its `_.assignIn` use to assign properties + * of source objects to the destination object for all destination properties + * that resolve to `undefined`. + * + * @private + * @param {*} objValue The destination value. + * @param {*} srcValue The source value. + * @param {string} key The key of the property to assign. + * @param {Object} object The parent object of `objValue`. + * @returns {*} Returns the value to assign. + */ + function customDefaultsAssignIn(objValue, srcValue, key, object) { + if (objValue === undefined || + (eq(objValue, objectProto[key]) && !hasOwnProperty.call(object, key))) { + return srcValue; + } + return objValue; + } + + /** + * Used by `_.defaultsDeep` to customize its `_.merge` use to merge source + * objects into destination objects that are passed thru. + * + * @private + * @param {*} objValue The destination value. + * @param {*} srcValue The source value. + * @param {string} key The key of the property to merge. + * @param {Object} object The parent object of `objValue`. + * @param {Object} source The parent object of `srcValue`. + * @param {Object} [stack] Tracks traversed source values and their merged + * counterparts. + * @returns {*} Returns the value to assign. + */ + function customDefaultsMerge(objValue, srcValue, key, object, source, stack) { + if (isObject(objValue) && isObject(srcValue)) { + // Recursively merge objects and arrays (susceptible to call stack limits). + stack.set(srcValue, objValue); + baseMerge(objValue, srcValue, undefined, customDefaultsMerge, stack); + stack['delete'](srcValue); + } + return objValue; + } + + /** + * Used by `_.omit` to customize its `_.cloneDeep` use to only clone plain + * objects. + * + * @private + * @param {*} value The value to inspect. + * @param {string} key The key of the property to inspect. + * @returns {*} Returns the uncloned value or `undefined` to defer cloning to `_.cloneDeep`. + */ + function customOmitClone(value) { + return isPlainObject(value) ? undefined : value; + } + + /** + * A specialized version of `baseIsEqualDeep` for arrays with support for + * partial deep comparisons. 
+ * + * @private + * @param {Array} array The array to compare. + * @param {Array} other The other array to compare. + * @param {number} bitmask The bitmask flags. See `baseIsEqual` for more details. + * @param {Function} customizer The function to customize comparisons. + * @param {Function} equalFunc The function to determine equivalents of values. + * @param {Object} stack Tracks traversed `array` and `other` objects. + * @returns {boolean} Returns `true` if the arrays are equivalent, else `false`. + */ + function equalArrays(array, other, bitmask, customizer, equalFunc, stack) { + var isPartial = bitmask & COMPARE_PARTIAL_FLAG, + arrLength = array.length, + othLength = other.length; + + if (arrLength != othLength && !(isPartial && othLength > arrLength)) { + return false; + } + // Assume cyclic values are equal. + var stacked = stack.get(array); + if (stacked && stack.get(other)) { + return stacked == other; + } + var index = -1, + result = true, + seen = (bitmask & COMPARE_UNORDERED_FLAG) ? new SetCache : undefined; + + stack.set(array, other); + stack.set(other, array); + + // Ignore non-index properties. + while (++index < arrLength) { + var arrValue = array[index], + othValue = other[index]; + + if (customizer) { + var compared = isPartial + ? customizer(othValue, arrValue, index, other, array, stack) + : customizer(arrValue, othValue, index, array, other, stack); + } + if (compared !== undefined) { + if (compared) { + continue; + } + result = false; + break; + } + // Recursively compare arrays (susceptible to call stack limits). + if (seen) { + if (!arraySome(other, function(othValue, othIndex) { + if (!cacheHas(seen, othIndex) && + (arrValue === othValue || equalFunc(arrValue, othValue, bitmask, customizer, stack))) { + return seen.push(othIndex); + } + })) { + result = false; + break; + } + } else if (!( + arrValue === othValue || + equalFunc(arrValue, othValue, bitmask, customizer, stack) + )) { + result = false; + break; + } + } + stack['delete'](array); + stack['delete'](other); + return result; + } + + /** + * A specialized version of `baseIsEqualDeep` for comparing objects of + * the same `toStringTag`. + * + * **Note:** This function only supports comparing values with tags of + * `Boolean`, `Date`, `Error`, `Number`, `RegExp`, or `String`. + * + * @private + * @param {Object} object The object to compare. + * @param {Object} other The other object to compare. + * @param {string} tag The `toStringTag` of the objects to compare. + * @param {number} bitmask The bitmask flags. See `baseIsEqual` for more details. + * @param {Function} customizer The function to customize comparisons. + * @param {Function} equalFunc The function to determine equivalents of values. + * @param {Object} stack Tracks traversed `object` and `other` objects. + * @returns {boolean} Returns `true` if the objects are equivalent, else `false`. + */ + function equalByTag(object, other, tag, bitmask, customizer, equalFunc, stack) { + switch (tag) { + case dataViewTag: + if ((object.byteLength != other.byteLength) || + (object.byteOffset != other.byteOffset)) { + return false; + } + object = object.buffer; + other = other.buffer; + + case arrayBufferTag: + if ((object.byteLength != other.byteLength) || + !equalFunc(new Uint8Array(object), new Uint8Array(other))) { + return false; + } + return true; + + case boolTag: + case dateTag: + case numberTag: + // Coerce booleans to `1` or `0` and dates to milliseconds. + // Invalid dates are coerced to `NaN`. 
+ return eq(+object, +other); + + case errorTag: + return object.name == other.name && object.message == other.message; + + case regexpTag: + case stringTag: + // Coerce regexes to strings and treat strings, primitives and objects, + // as equal. See http://www.ecma-international.org/ecma-262/7.0/#sec-regexp.prototype.tostring + // for more details. + return object == (other + ''); + + case mapTag: + var convert = mapToArray; + + case setTag: + var isPartial = bitmask & COMPARE_PARTIAL_FLAG; + convert || (convert = setToArray); + + if (object.size != other.size && !isPartial) { + return false; + } + // Assume cyclic values are equal. + var stacked = stack.get(object); + if (stacked) { + return stacked == other; + } + bitmask |= COMPARE_UNORDERED_FLAG; + + // Recursively compare objects (susceptible to call stack limits). + stack.set(object, other); + var result = equalArrays(convert(object), convert(other), bitmask, customizer, equalFunc, stack); + stack['delete'](object); + return result; + + case symbolTag: + if (symbolValueOf) { + return symbolValueOf.call(object) == symbolValueOf.call(other); + } + } + return false; + } + + /** + * A specialized version of `baseIsEqualDeep` for objects with support for + * partial deep comparisons. + * + * @private + * @param {Object} object The object to compare. + * @param {Object} other The other object to compare. + * @param {number} bitmask The bitmask flags. See `baseIsEqual` for more details. + * @param {Function} customizer The function to customize comparisons. + * @param {Function} equalFunc The function to determine equivalents of values. + * @param {Object} stack Tracks traversed `object` and `other` objects. + * @returns {boolean} Returns `true` if the objects are equivalent, else `false`. + */ + function equalObjects(object, other, bitmask, customizer, equalFunc, stack) { + var isPartial = bitmask & COMPARE_PARTIAL_FLAG, + objProps = getAllKeys(object), + objLength = objProps.length, + othProps = getAllKeys(other), + othLength = othProps.length; + + if (objLength != othLength && !isPartial) { + return false; + } + var index = objLength; + while (index--) { + var key = objProps[index]; + if (!(isPartial ? key in other : hasOwnProperty.call(other, key))) { + return false; + } + } + // Assume cyclic values are equal. + var stacked = stack.get(object); + if (stacked && stack.get(other)) { + return stacked == other; + } + var result = true; + stack.set(object, other); + stack.set(other, object); + + var skipCtor = isPartial; + while (++index < objLength) { + key = objProps[index]; + var objValue = object[key], + othValue = other[key]; + + if (customizer) { + var compared = isPartial + ? customizer(othValue, objValue, key, other, object, stack) + : customizer(objValue, othValue, key, object, other, stack); + } + // Recursively compare objects (susceptible to call stack limits). + if (!(compared === undefined + ? (objValue === othValue || equalFunc(objValue, othValue, bitmask, customizer, stack)) + : compared + )) { + result = false; + break; + } + skipCtor || (skipCtor = key == 'constructor'); + } + if (result && !skipCtor) { + var objCtor = object.constructor, + othCtor = other.constructor; + + // Non `Object` object instances with different constructors are not equal. 
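+ // Constructors that are functions and instances of themselves (e.g. `Object`)
+ // are exempted below, which keeps cross-realm plain objects comparable.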
+ if (objCtor != othCtor && + ('constructor' in object && 'constructor' in other) && + !(typeof objCtor == 'function' && objCtor instanceof objCtor && + typeof othCtor == 'function' && othCtor instanceof othCtor)) { + result = false; + } + } + stack['delete'](object); + stack['delete'](other); + return result; + } + + /** + * A specialized version of `baseRest` which flattens the rest array. + * + * @private + * @param {Function} func The function to apply a rest parameter to. + * @returns {Function} Returns the new function. + */ + function flatRest(func) { + return setToString(overRest(func, undefined, flatten), func + ''); + } + + /** + * Creates an array of own enumerable property names and symbols of `object`. + * + * @private + * @param {Object} object The object to query. + * @returns {Array} Returns the array of property names and symbols. + */ + function getAllKeys(object) { + return baseGetAllKeys(object, keys, getSymbols); + } + + /** + * Creates an array of own and inherited enumerable property names and + * symbols of `object`. + * + * @private + * @param {Object} object The object to query. + * @returns {Array} Returns the array of property names and symbols. + */ + function getAllKeysIn(object) { + return baseGetAllKeys(object, keysIn, getSymbolsIn); + } + + /** + * Gets metadata for `func`. + * + * @private + * @param {Function} func The function to query. + * @returns {*} Returns the metadata for `func`. + */ + var getData = !metaMap ? noop : function(func) { + return metaMap.get(func); + }; + + /** + * Gets the name of `func`. + * + * @private + * @param {Function} func The function to query. + * @returns {string} Returns the function name. + */ + function getFuncName(func) { + var result = (func.name + ''), + array = realNames[result], + length = hasOwnProperty.call(realNames, result) ? array.length : 0; + + while (length--) { + var data = array[length], + otherFunc = data.func; + if (otherFunc == null || otherFunc == func) { + return data.name; + } + } + return result; + } + + /** + * Gets the argument placeholder value for `func`. + * + * @private + * @param {Function} func The function to inspect. + * @returns {*} Returns the placeholder value. + */ + function getHolder(func) { + var object = hasOwnProperty.call(lodash, 'placeholder') ? lodash : func; + return object.placeholder; + } + + /** + * Gets the appropriate "iteratee" function. If `_.iteratee` is customized, + * this function returns the custom method, otherwise it returns `baseIteratee`. + * If arguments are provided, the chosen function is invoked with them and + * its result is returned. + * + * @private + * @param {*} [value] The value to convert to an iteratee. + * @param {number} [arity] The arity of the created iteratee. + * @returns {Function} Returns the chosen function or its result. + */ + function getIteratee() { + var result = lodash.iteratee || iteratee; + result = result === iteratee ? baseIteratee : result; + return arguments.length ? result(arguments[0], arguments[1]) : result; + } + + /** + * Gets the data for `map`. + * + * @private + * @param {Object} map The map to query. + * @param {string} key The reference key. + * @returns {*} Returns the map data. + */ + function getMapData(map, key) { + var data = map.__data__; + return isKeyable(key) + ? data[typeof key == 'string' ? 'string' : 'hash'] + : data.map; + } + + /** + * Gets the property names, values, and compare flags of `object`. + * + * @private + * @param {Object} object The object to query. 
+ * @returns {Array} Returns the match data of `object`. + */ + function getMatchData(object) { + var result = keys(object), + length = result.length; + + while (length--) { + var key = result[length], + value = object[key]; + + result[length] = [key, value, isStrictComparable(value)]; + } + return result; + } + + /** + * Gets the native function at `key` of `object`. + * + * @private + * @param {Object} object The object to query. + * @param {string} key The key of the method to get. + * @returns {*} Returns the function if it's native, else `undefined`. + */ + function getNative(object, key) { + var value = getValue(object, key); + return baseIsNative(value) ? value : undefined; + } + + /** + * A specialized version of `baseGetTag` which ignores `Symbol.toStringTag` values. + * + * @private + * @param {*} value The value to query. + * @returns {string} Returns the raw `toStringTag`. + */ + function getRawTag(value) { + var isOwn = hasOwnProperty.call(value, symToStringTag), + tag = value[symToStringTag]; + + try { + value[symToStringTag] = undefined; + var unmasked = true; + } catch (e) {} + + var result = nativeObjectToString.call(value); + if (unmasked) { + if (isOwn) { + value[symToStringTag] = tag; + } else { + delete value[symToStringTag]; + } + } + return result; + } + + /** + * Creates an array of the own enumerable symbols of `object`. + * + * @private + * @param {Object} object The object to query. + * @returns {Array} Returns the array of symbols. + */ + var getSymbols = !nativeGetSymbols ? stubArray : function(object) { + if (object == null) { + return []; + } + object = Object(object); + return arrayFilter(nativeGetSymbols(object), function(symbol) { + return propertyIsEnumerable.call(object, symbol); + }); + }; + + /** + * Creates an array of the own and inherited enumerable symbols of `object`. + * + * @private + * @param {Object} object The object to query. + * @returns {Array} Returns the array of symbols. + */ + var getSymbolsIn = !nativeGetSymbols ? stubArray : function(object) { + var result = []; + while (object) { + arrayPush(result, getSymbols(object)); + object = getPrototype(object); + } + return result; + }; + + /** + * Gets the `toStringTag` of `value`. + * + * @private + * @param {*} value The value to query. + * @returns {string} Returns the `toStringTag`. + */ + var getTag = baseGetTag; + + // Fallback for data views, maps, sets, and weak maps in IE 11 and promises in Node.js < 6. + if ((DataView && getTag(new DataView(new ArrayBuffer(1))) != dataViewTag) || + (Map && getTag(new Map) != mapTag) || + (Promise && getTag(Promise.resolve()) != promiseTag) || + (Set && getTag(new Set) != setTag) || + (WeakMap && getTag(new WeakMap) != weakMapTag)) { + getTag = function(value) { + var result = baseGetTag(value), + Ctor = result == objectTag ? value.constructor : undefined, + ctorString = Ctor ? toSource(Ctor) : ''; + + if (ctorString) { + switch (ctorString) { + case dataViewCtorString: return dataViewTag; + case mapCtorString: return mapTag; + case promiseCtorString: return promiseTag; + case setCtorString: return setTag; + case weakMapCtorString: return weakMapTag; + } + } + return result; + }; + } + + /** + * Gets the view, applying any `transforms` to the `start` and `end` positions. + * + * @private + * @param {number} start The start of the view. + * @param {number} end The end of the view. + * @param {Array} transforms The transformations to apply to the view. + * @returns {Object} Returns an object containing the `start` and `end` + * positions of the view. 
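+ * @example (illustrative sketch; not from the upstream lodash docs)
+ *
+ * // Drop 2 from the start, then take 5 of what remains.
+ * getView(0, 10, [{ 'type': 'drop', 'size': 2 }, { 'type': 'take', 'size': 5 }]);
+ * // => { 'start': 2, 'end': 7 }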
+ */ + function getView(start, end, transforms) { + var index = -1, + length = transforms.length; + + while (++index < length) { + var data = transforms[index], + size = data.size; + + switch (data.type) { + case 'drop': start += size; break; + case 'dropRight': end -= size; break; + case 'take': end = nativeMin(end, start + size); break; + case 'takeRight': start = nativeMax(start, end - size); break; + } + } + return { 'start': start, 'end': end }; + } + + /** + * Extracts wrapper details from the `source` body comment. + * + * @private + * @param {string} source The source to inspect. + * @returns {Array} Returns the wrapper details. + */ + function getWrapDetails(source) { + var match = source.match(reWrapDetails); + return match ? match[1].split(reSplitDetails) : []; + } + + /** + * Checks if `path` exists on `object`. + * + * @private + * @param {Object} object The object to query. + * @param {Array|string} path The path to check. + * @param {Function} hasFunc The function to check properties. + * @returns {boolean} Returns `true` if `path` exists, else `false`. + */ + function hasPath(object, path, hasFunc) { + path = castPath(path, object); + + var index = -1, + length = path.length, + result = false; + + while (++index < length) { + var key = toKey(path[index]); + if (!(result = object != null && hasFunc(object, key))) { + break; + } + object = object[key]; + } + if (result || ++index != length) { + return result; + } + length = object == null ? 0 : object.length; + return !!length && isLength(length) && isIndex(key, length) && + (isArray(object) || isArguments(object)); + } + + /** + * Initializes an array clone. + * + * @private + * @param {Array} array The array to clone. + * @returns {Array} Returns the initialized clone. + */ + function initCloneArray(array) { + var length = array.length, + result = new array.constructor(length); + + // Add properties assigned by `RegExp#exec`. + if (length && typeof array[0] == 'string' && hasOwnProperty.call(array, 'index')) { + result.index = array.index; + result.input = array.input; + } + return result; + } + + /** + * Initializes an object clone. + * + * @private + * @param {Object} object The object to clone. + * @returns {Object} Returns the initialized clone. + */ + function initCloneObject(object) { + return (typeof object.constructor == 'function' && !isPrototype(object)) + ? baseCreate(getPrototype(object)) + : {}; + } + + /** + * Initializes an object clone based on its `toStringTag`. + * + * **Note:** This function only supports cloning values with tags of + * `Boolean`, `Date`, `Error`, `Map`, `Number`, `RegExp`, `Set`, or `String`. + * + * @private + * @param {Object} object The object to clone. + * @param {string} tag The `toStringTag` of the object to clone. + * @param {boolean} [isDeep] Specify a deep clone. + * @returns {Object} Returns the initialized clone. 
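+ * @example (illustrative sketch; not from the upstream lodash docs)
+ *
+ * initCloneByTag(new Date(0), dateTag);
+ * // => a fresh `Date` with the same timestamp, via `new Ctor(+object)`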
+ */ + function initCloneByTag(object, tag, isDeep) { + var Ctor = object.constructor; + switch (tag) { + case arrayBufferTag: + return cloneArrayBuffer(object); + + case boolTag: + case dateTag: + return new Ctor(+object); + + case dataViewTag: + return cloneDataView(object, isDeep); + + case float32Tag: case float64Tag: + case int8Tag: case int16Tag: case int32Tag: + case uint8Tag: case uint8ClampedTag: case uint16Tag: case uint32Tag: + return cloneTypedArray(object, isDeep); + + case mapTag: + return new Ctor; + + case numberTag: + case stringTag: + return new Ctor(object); + + case regexpTag: + return cloneRegExp(object); + + case setTag: + return new Ctor; + + case symbolTag: + return cloneSymbol(object); + } + } + + /** + * Inserts wrapper `details` in a comment at the top of the `source` body. + * + * @private + * @param {string} source The source to modify. + * @returns {Array} details The details to insert. + * @returns {string} Returns the modified source. + */ + function insertWrapDetails(source, details) { + var length = details.length; + if (!length) { + return source; + } + var lastIndex = length - 1; + details[lastIndex] = (length > 1 ? '& ' : '') + details[lastIndex]; + details = details.join(length > 2 ? ', ' : ' '); + return source.replace(reWrapComment, '{\n/* [wrapped with ' + details + '] */\n'); + } + + /** + * Checks if `value` is a flattenable `arguments` object or array. + * + * @private + * @param {*} value The value to check. + * @returns {boolean} Returns `true` if `value` is flattenable, else `false`. + */ + function isFlattenable(value) { + return isArray(value) || isArguments(value) || + !!(spreadableSymbol && value && value[spreadableSymbol]); + } + + /** + * Checks if `value` is a valid array-like index. + * + * @private + * @param {*} value The value to check. + * @param {number} [length=MAX_SAFE_INTEGER] The upper bounds of a valid index. + * @returns {boolean} Returns `true` if `value` is a valid index, else `false`. + */ + function isIndex(value, length) { + var type = typeof value; + length = length == null ? MAX_SAFE_INTEGER : length; + + return !!length && + (type == 'number' || + (type != 'symbol' && reIsUint.test(value))) && + (value > -1 && value % 1 == 0 && value < length); + } + + /** + * Checks if the given arguments are from an iteratee call. + * + * @private + * @param {*} value The potential iteratee value argument. + * @param {*} index The potential iteratee index or key argument. + * @param {*} object The potential iteratee object argument. + * @returns {boolean} Returns `true` if the arguments are from an iteratee call, + * else `false`. + */ + function isIterateeCall(value, index, object) { + if (!isObject(object)) { + return false; + } + var type = typeof index; + if (type == 'number' + ? (isArrayLike(object) && isIndex(index, object.length)) + : (type == 'string' && index in object) + ) { + return eq(object[index], value); + } + return false; + } + + /** + * Checks if `value` is a property name and not a property path. + * + * @private + * @param {*} value The value to check. + * @param {Object} [object] The object to query keys on. + * @returns {boolean} Returns `true` if `value` is a property name, else `false`. 
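+ * @example (illustrative sketch; not from the upstream lodash docs)
+ *
+ * isKey('a');                 // => true
+ * isKey('a.b');               // => false (parsed as a property path)
+ * isKey('a.b', { 'a.b': 1 }); // => true (`'a.b'` is a key of `object`)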
+ */ + function isKey(value, object) { + if (isArray(value)) { + return false; + } + var type = typeof value; + if (type == 'number' || type == 'symbol' || type == 'boolean' || + value == null || isSymbol(value)) { + return true; + } + return reIsPlainProp.test(value) || !reIsDeepProp.test(value) || + (object != null && value in Object(object)); + } + + /** + * Checks if `value` is suitable for use as unique object key. + * + * @private + * @param {*} value The value to check. + * @returns {boolean} Returns `true` if `value` is suitable, else `false`. + */ + function isKeyable(value) { + var type = typeof value; + return (type == 'string' || type == 'number' || type == 'symbol' || type == 'boolean') + ? (value !== '__proto__') + : (value === null); + } + + /** + * Checks if `func` has a lazy counterpart. + * + * @private + * @param {Function} func The function to check. + * @returns {boolean} Returns `true` if `func` has a lazy counterpart, + * else `false`. + */ + function isLaziable(func) { + var funcName = getFuncName(func), + other = lodash[funcName]; + + if (typeof other != 'function' || !(funcName in LazyWrapper.prototype)) { + return false; + } + if (func === other) { + return true; + } + var data = getData(other); + return !!data && func === data[0]; + } + + /** + * Checks if `func` has its source masked. + * + * @private + * @param {Function} func The function to check. + * @returns {boolean} Returns `true` if `func` is masked, else `false`. + */ + function isMasked(func) { + return !!maskSrcKey && (maskSrcKey in func); + } + + /** + * Checks if `func` is capable of being masked. + * + * @private + * @param {*} value The value to check. + * @returns {boolean} Returns `true` if `func` is maskable, else `false`. + */ + var isMaskable = coreJsData ? isFunction : stubFalse; + + /** + * Checks if `value` is likely a prototype object. + * + * @private + * @param {*} value The value to check. + * @returns {boolean} Returns `true` if `value` is a prototype, else `false`. + */ + function isPrototype(value) { + var Ctor = value && value.constructor, + proto = (typeof Ctor == 'function' && Ctor.prototype) || objectProto; + + return value === proto; + } + + /** + * Checks if `value` is suitable for strict equality comparisons, i.e. `===`. + * + * @private + * @param {*} value The value to check. + * @returns {boolean} Returns `true` if `value` if suitable for strict + * equality comparisons, else `false`. + */ + function isStrictComparable(value) { + return value === value && !isObject(value); + } + + /** + * A specialized version of `matchesProperty` for source values suitable + * for strict equality comparisons, i.e. `===`. + * + * @private + * @param {string} key The key of the property to get. + * @param {*} srcValue The value to match. + * @returns {Function} Returns the new spec function. + */ + function matchesStrictComparable(key, srcValue) { + return function(object) { + if (object == null) { + return false; + } + return object[key] === srcValue && + (srcValue !== undefined || (key in Object(object))); + }; + } + + /** + * A specialized version of `_.memoize` which clears the memoized function's + * cache when it exceeds `MAX_MEMOIZE_SIZE`. + * + * @private + * @param {Function} func The function to have its output memoized. + * @returns {Function} Returns the new memoized function. 
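+ * @example (illustrative sketch; not from the upstream lodash docs)
+ *
+ * var capped = memoizeCapped(function(string) { return string.split('.'); });
+ * capped('a.b'); // => ['a', 'b'], cached under the key 'a.b'
+ * // Once the cache holds `MAX_MEMOIZE_SIZE` entries it is cleared before the
+ * // next key is added, bounding memory use.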
+ */ + function memoizeCapped(func) { + var result = memoize(func, function(key) { + if (cache.size === MAX_MEMOIZE_SIZE) { + cache.clear(); + } + return key; + }); + + var cache = result.cache; + return result; + } + + /** + * Merges the function metadata of `source` into `data`. + * + * Merging metadata reduces the number of wrappers used to invoke a function. + * This is possible because methods like `_.bind`, `_.curry`, and `_.partial` + * may be applied regardless of execution order. Methods like `_.ary` and + * `_.rearg` modify function arguments, making the order in which they are + * executed important, preventing the merging of metadata. However, we make + * an exception for a safe combined case where curried functions have `_.ary` + * and or `_.rearg` applied. + * + * @private + * @param {Array} data The destination metadata. + * @param {Array} source The source metadata. + * @returns {Array} Returns `data`. + */ + function mergeData(data, source) { + var bitmask = data[1], + srcBitmask = source[1], + newBitmask = bitmask | srcBitmask, + isCommon = newBitmask < (WRAP_BIND_FLAG | WRAP_BIND_KEY_FLAG | WRAP_ARY_FLAG); + + var isCombo = + ((srcBitmask == WRAP_ARY_FLAG) && (bitmask == WRAP_CURRY_FLAG)) || + ((srcBitmask == WRAP_ARY_FLAG) && (bitmask == WRAP_REARG_FLAG) && (data[7].length <= source[8])) || + ((srcBitmask == (WRAP_ARY_FLAG | WRAP_REARG_FLAG)) && (source[7].length <= source[8]) && (bitmask == WRAP_CURRY_FLAG)); + + // Exit early if metadata can't be merged. + if (!(isCommon || isCombo)) { + return data; + } + // Use source `thisArg` if available. + if (srcBitmask & WRAP_BIND_FLAG) { + data[2] = source[2]; + // Set when currying a bound function. + newBitmask |= bitmask & WRAP_BIND_FLAG ? 0 : WRAP_CURRY_BOUND_FLAG; + } + // Compose partial arguments. + var value = source[3]; + if (value) { + var partials = data[3]; + data[3] = partials ? composeArgs(partials, value, source[4]) : value; + data[4] = partials ? replaceHolders(data[3], PLACEHOLDER) : source[4]; + } + // Compose partial right arguments. + value = source[5]; + if (value) { + partials = data[5]; + data[5] = partials ? composeArgsRight(partials, value, source[6]) : value; + data[6] = partials ? replaceHolders(data[5], PLACEHOLDER) : source[6]; + } + // Use source `argPos` if available. + value = source[7]; + if (value) { + data[7] = value; + } + // Use source `ary` if it's smaller. + if (srcBitmask & WRAP_ARY_FLAG) { + data[8] = data[8] == null ? source[8] : nativeMin(data[8], source[8]); + } + // Use source `arity` if one is not provided. + if (data[9] == null) { + data[9] = source[9]; + } + // Use source `func` and merge bitmasks. + data[0] = source[0]; + data[1] = newBitmask; + + return data; + } + + /** + * This function is like + * [`Object.keys`](http://ecma-international.org/ecma-262/7.0/#sec-object.keys) + * except that it includes inherited enumerable properties. + * + * @private + * @param {Object} object The object to query. + * @returns {Array} Returns the array of property names. + */ + function nativeKeysIn(object) { + var result = []; + if (object != null) { + for (var key in Object(object)) { + result.push(key); + } + } + return result; + } + + /** + * Converts `value` to a string using `Object.prototype.toString`. + * + * @private + * @param {*} value The value to convert. + * @returns {string} Returns the converted string. + */ + function objectToString(value) { + return nativeObjectToString.call(value); + } + + /** + * A specialized version of `baseRest` which transforms the rest array. 
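+ * The arguments gathered from position `start` onward are passed through
+ * `transform` (e.g. `flatten` in `flatRest`) before `func` is applied.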
+ * + * @private + * @param {Function} func The function to apply a rest parameter to. + * @param {number} [start=func.length-1] The start position of the rest parameter. + * @param {Function} transform The rest array transform. + * @returns {Function} Returns the new function. + */ + function overRest(func, start, transform) { + start = nativeMax(start === undefined ? (func.length - 1) : start, 0); + return function() { + var args = arguments, + index = -1, + length = nativeMax(args.length - start, 0), + array = Array(length); + + while (++index < length) { + array[index] = args[start + index]; + } + index = -1; + var otherArgs = Array(start + 1); + while (++index < start) { + otherArgs[index] = args[index]; + } + otherArgs[start] = transform(array); + return apply(func, this, otherArgs); + }; + } + + /** + * Gets the parent value at `path` of `object`. + * + * @private + * @param {Object} object The object to query. + * @param {Array} path The path to get the parent value of. + * @returns {*} Returns the parent value. + */ + function parent(object, path) { + return path.length < 2 ? object : baseGet(object, baseSlice(path, 0, -1)); + } + + /** + * Reorder `array` according to the specified indexes where the element at + * the first index is assigned as the first element, the element at + * the second index is assigned as the second element, and so on. + * + * @private + * @param {Array} array The array to reorder. + * @param {Array} indexes The arranged array indexes. + * @returns {Array} Returns `array`. + */ + function reorder(array, indexes) { + var arrLength = array.length, + length = nativeMin(indexes.length, arrLength), + oldArray = copyArray(array); + + while (length--) { + var index = indexes[length]; + array[length] = isIndex(index, arrLength) ? oldArray[index] : undefined; + } + return array; + } + + /** + * Gets the value at `key`, unless `key` is "__proto__". + * + * @private + * @param {Object} object The object to query. + * @param {string} key The key of the property to get. + * @returns {*} Returns the property value. + */ + function safeGet(object, key) { + if (key == '__proto__') { + return; + } + + return object[key]; + } + + /** + * Sets metadata for `func`. + * + * **Note:** If this function becomes hot, i.e. is invoked a lot in a short + * period of time, it will trip its breaker and transition to an identity + * function to avoid garbage collection pauses in V8. See + * [V8 issue 2070](https://bugs.chromium.org/p/v8/issues/detail?id=2070) + * for more details. + * + * @private + * @param {Function} func The function to associate metadata with. + * @param {*} data The metadata. + * @returns {Function} Returns `func`. + */ + var setData = shortOut(baseSetData); + + /** + * A simple wrapper around the global [`setTimeout`](https://mdn.io/setTimeout). + * + * @private + * @param {Function} func The function to delay. + * @param {number} wait The number of milliseconds to delay invocation. + * @returns {number|Object} Returns the timer id or timeout object. + */ + var setTimeout = ctxSetTimeout || function(func, wait) { + return root.setTimeout(func, wait); + }; + + /** + * Sets the `toString` method of `func` to return `string`. + * + * @private + * @param {Function} func The function to modify. + * @param {Function} string The `toString` result. + * @returns {Function} Returns `func`. 
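+ * @example (illustrative sketch; not from the upstream lodash docs)
+ *
+ * var wrapped = setToString(function() {}, 'custom source');
+ * (wrapped + ''); // => 'custom source' where `Object.defineProperty` is available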
+ */ + var setToString = shortOut(baseSetToString); + + /** + * Sets the `toString` method of `wrapper` to mimic the source of `reference` + * with wrapper details in a comment at the top of the source body. + * + * @private + * @param {Function} wrapper The function to modify. + * @param {Function} reference The reference function. + * @param {number} bitmask The bitmask flags. See `createWrap` for more details. + * @returns {Function} Returns `wrapper`. + */ + function setWrapToString(wrapper, reference, bitmask) { + var source = (reference + ''); + return setToString(wrapper, insertWrapDetails(source, updateWrapDetails(getWrapDetails(source), bitmask))); + } + + /** + * Creates a function that'll short out and invoke `identity` instead + * of `func` when it's called `HOT_COUNT` or more times in `HOT_SPAN` + * milliseconds. + * + * @private + * @param {Function} func The function to restrict. + * @returns {Function} Returns the new shortable function. + */ + function shortOut(func) { + var count = 0, + lastCalled = 0; + + return function() { + var stamp = nativeNow(), + remaining = HOT_SPAN - (stamp - lastCalled); + + lastCalled = stamp; + if (remaining > 0) { + if (++count >= HOT_COUNT) { + return arguments[0]; + } + } else { + count = 0; + } + return func.apply(undefined, arguments); + }; + } + + /** + * A specialized version of `_.shuffle` which mutates and sets the size of `array`. + * + * @private + * @param {Array} array The array to shuffle. + * @param {number} [size=array.length] The size of `array`. + * @returns {Array} Returns `array`. + */ + function shuffleSelf(array, size) { + var index = -1, + length = array.length, + lastIndex = length - 1; + + size = size === undefined ? length : size; + while (++index < size) { + var rand = baseRandom(index, lastIndex), + value = array[rand]; + + array[rand] = array[index]; + array[index] = value; + } + array.length = size; + return array; + } + + /** + * Converts `string` to a property path array. + * + * @private + * @param {string} string The string to convert. + * @returns {Array} Returns the property path array. + */ + var stringToPath = memoizeCapped(function(string) { + var result = []; + if (string.charCodeAt(0) === 46 /* . */) { + result.push(''); + } + string.replace(rePropName, function(match, number, quote, subString) { + result.push(quote ? subString.replace(reEscapeChar, '$1') : (number || match)); + }); + return result; + }); + + /** + * Converts `value` to a string key if it's not a string or symbol. + * + * @private + * @param {*} value The value to inspect. + * @returns {string|symbol} Returns the key. + */ + function toKey(value) { + if (typeof value == 'string' || isSymbol(value)) { + return value; + } + var result = (value + ''); + return (result == '0' && (1 / value) == -INFINITY) ? '-0' : result; + } + + /** + * Converts `func` to its source code. + * + * @private + * @param {Function} func The function to convert. + * @returns {string} Returns the source code. + */ + function toSource(func) { + if (func != null) { + try { + return funcToString.call(func); + } catch (e) {} + try { + return (func + ''); + } catch (e) {} + } + return ''; + } + + /** + * Updates wrapper `details` based on `bitmask` flags. + * + * @private + * @returns {Array} details The details to modify. + * @param {number} bitmask The bitmask flags. See `createWrap` for more details. + * @returns {Array} Returns `details`. + */ + function updateWrapDetails(details, bitmask) { + arrayEach(wrapFlags, function(pair) { + var value = '_.' 
+ pair[0]; + if ((bitmask & pair[1]) && !arrayIncludes(details, value)) { + details.push(value); + } + }); + return details.sort(); + } + + /** + * Creates a clone of `wrapper`. + * + * @private + * @param {Object} wrapper The wrapper to clone. + * @returns {Object} Returns the cloned wrapper. + */ + function wrapperClone(wrapper) { + if (wrapper instanceof LazyWrapper) { + return wrapper.clone(); + } + var result = new LodashWrapper(wrapper.__wrapped__, wrapper.__chain__); + result.__actions__ = copyArray(wrapper.__actions__); + result.__index__ = wrapper.__index__; + result.__values__ = wrapper.__values__; + return result; + } + + /*------------------------------------------------------------------------*/ + + /** + * Creates an array of elements split into groups the length of `size`. + * If `array` can't be split evenly, the final chunk will be the remaining + * elements. + * + * @static + * @memberOf _ + * @since 3.0.0 + * @category Array + * @param {Array} array The array to process. + * @param {number} [size=1] The length of each chunk + * @param- {Object} [guard] Enables use as an iteratee for methods like `_.map`. + * @returns {Array} Returns the new array of chunks. + * @example + * + * _.chunk(['a', 'b', 'c', 'd'], 2); + * // => [['a', 'b'], ['c', 'd']] + * + * _.chunk(['a', 'b', 'c', 'd'], 3); + * // => [['a', 'b', 'c'], ['d']] + */ + function chunk(array, size, guard) { + if ((guard ? isIterateeCall(array, size, guard) : size === undefined)) { + size = 1; + } else { + size = nativeMax(toInteger(size), 0); + } + var length = array == null ? 0 : array.length; + if (!length || size < 1) { + return []; + } + var index = 0, + resIndex = 0, + result = Array(nativeCeil(length / size)); + + while (index < length) { + result[resIndex++] = baseSlice(array, index, (index += size)); + } + return result; + } + + /** + * Creates an array with all falsey values removed. The values `false`, `null`, + * `0`, `""`, `undefined`, and `NaN` are falsey. + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Array + * @param {Array} array The array to compact. + * @returns {Array} Returns the new array of filtered values. + * @example + * + * _.compact([0, 1, false, 2, '', 3]); + * // => [1, 2, 3] + */ + function compact(array) { + var index = -1, + length = array == null ? 0 : array.length, + resIndex = 0, + result = []; + + while (++index < length) { + var value = array[index]; + if (value) { + result[resIndex++] = value; + } + } + return result; + } + + /** + * Creates a new array concatenating `array` with any additional arrays + * and/or values. + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Array + * @param {Array} array The array to concatenate. + * @param {...*} [values] The values to concatenate. + * @returns {Array} Returns the new concatenated array. + * @example + * + * var array = [1]; + * var other = _.concat(array, 2, [3], [[4]]); + * + * console.log(other); + * // => [1, 2, 3, [4]] + * + * console.log(array); + * // => [1] + */ + function concat() { + var length = arguments.length; + if (!length) { + return []; + } + var args = Array(length - 1), + array = arguments[0], + index = length; + + while (index--) { + args[index - 1] = arguments[index]; + } + return arrayPush(isArray(array) ? copyArray(array) : [array], baseFlatten(args, 1)); + } + + /** + * Creates an array of `array` values not included in the other given arrays + * using [`SameValueZero`](http://ecma-international.org/ecma-262/7.0/#sec-samevaluezero) + * for equality comparisons. 
The order and references of result values are + * determined by the first array. + * + * **Note:** Unlike `_.pullAll`, this method returns a new array. + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Array + * @param {Array} array The array to inspect. + * @param {...Array} [values] The values to exclude. + * @returns {Array} Returns the new array of filtered values. + * @see _.without, _.xor + * @example + * + * _.difference([2, 1], [2, 3]); + * // => [1] + */ + var difference = baseRest(function(array, values) { + return isArrayLikeObject(array) + ? baseDifference(array, baseFlatten(values, 1, isArrayLikeObject, true)) + : []; + }); + + /** + * This method is like `_.difference` except that it accepts `iteratee` which + * is invoked for each element of `array` and `values` to generate the criterion + * by which they're compared. The order and references of result values are + * determined by the first array. The iteratee is invoked with one argument: + * (value). + * + * **Note:** Unlike `_.pullAllBy`, this method returns a new array. + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Array + * @param {Array} array The array to inspect. + * @param {...Array} [values] The values to exclude. + * @param {Function} [iteratee=_.identity] The iteratee invoked per element. + * @returns {Array} Returns the new array of filtered values. + * @example + * + * _.differenceBy([2.1, 1.2], [2.3, 3.4], Math.floor); + * // => [1.2] + * + * // The `_.property` iteratee shorthand. + * _.differenceBy([{ 'x': 2 }, { 'x': 1 }], [{ 'x': 1 }], 'x'); + * // => [{ 'x': 2 }] + */ + var differenceBy = baseRest(function(array, values) { + var iteratee = last(values); + if (isArrayLikeObject(iteratee)) { + iteratee = undefined; + } + return isArrayLikeObject(array) + ? baseDifference(array, baseFlatten(values, 1, isArrayLikeObject, true), getIteratee(iteratee, 2)) + : []; + }); + + /** + * This method is like `_.difference` except that it accepts `comparator` + * which is invoked to compare elements of `array` to `values`. The order and + * references of result values are determined by the first array. The comparator + * is invoked with two arguments: (arrVal, othVal). + * + * **Note:** Unlike `_.pullAllWith`, this method returns a new array. + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Array + * @param {Array} array The array to inspect. + * @param {...Array} [values] The values to exclude. + * @param {Function} [comparator] The comparator invoked per element. + * @returns {Array} Returns the new array of filtered values. + * @example + * + * var objects = [{ 'x': 1, 'y': 2 }, { 'x': 2, 'y': 1 }]; + * + * _.differenceWith(objects, [{ 'x': 1, 'y': 2 }], _.isEqual); + * // => [{ 'x': 2, 'y': 1 }] + */ + var differenceWith = baseRest(function(array, values) { + var comparator = last(values); + if (isArrayLikeObject(comparator)) { + comparator = undefined; + } + return isArrayLikeObject(array) + ? baseDifference(array, baseFlatten(values, 1, isArrayLikeObject, true), undefined, comparator) + : []; + }); + + /** + * Creates a slice of `array` with `n` elements dropped from the beginning. + * + * @static + * @memberOf _ + * @since 0.5.0 + * @category Array + * @param {Array} array The array to query. + * @param {number} [n=1] The number of elements to drop. + * @param- {Object} [guard] Enables use as an iteratee for methods like `_.map`. + * @returns {Array} Returns the slice of `array`. 
+ * @example + * + * _.drop([1, 2, 3]); + * // => [2, 3] + * + * _.drop([1, 2, 3], 2); + * // => [3] + * + * _.drop([1, 2, 3], 5); + * // => [] + * + * _.drop([1, 2, 3], 0); + * // => [1, 2, 3] + */ + function drop(array, n, guard) { + var length = array == null ? 0 : array.length; + if (!length) { + return []; + } + n = (guard || n === undefined) ? 1 : toInteger(n); + return baseSlice(array, n < 0 ? 0 : n, length); + } + + /** + * Creates a slice of `array` with `n` elements dropped from the end. + * + * @static + * @memberOf _ + * @since 3.0.0 + * @category Array + * @param {Array} array The array to query. + * @param {number} [n=1] The number of elements to drop. + * @param- {Object} [guard] Enables use as an iteratee for methods like `_.map`. + * @returns {Array} Returns the slice of `array`. + * @example + * + * _.dropRight([1, 2, 3]); + * // => [1, 2] + * + * _.dropRight([1, 2, 3], 2); + * // => [1] + * + * _.dropRight([1, 2, 3], 5); + * // => [] + * + * _.dropRight([1, 2, 3], 0); + * // => [1, 2, 3] + */ + function dropRight(array, n, guard) { + var length = array == null ? 0 : array.length; + if (!length) { + return []; + } + n = (guard || n === undefined) ? 1 : toInteger(n); + n = length - n; + return baseSlice(array, 0, n < 0 ? 0 : n); + } + + /** + * Creates a slice of `array` excluding elements dropped from the end. + * Elements are dropped until `predicate` returns falsey. The predicate is + * invoked with three arguments: (value, index, array). + * + * @static + * @memberOf _ + * @since 3.0.0 + * @category Array + * @param {Array} array The array to query. + * @param {Function} [predicate=_.identity] The function invoked per iteration. + * @returns {Array} Returns the slice of `array`. + * @example + * + * var users = [ + * { 'user': 'barney', 'active': true }, + * { 'user': 'fred', 'active': false }, + * { 'user': 'pebbles', 'active': false } + * ]; + * + * _.dropRightWhile(users, function(o) { return !o.active; }); + * // => objects for ['barney'] + * + * // The `_.matches` iteratee shorthand. + * _.dropRightWhile(users, { 'user': 'pebbles', 'active': false }); + * // => objects for ['barney', 'fred'] + * + * // The `_.matchesProperty` iteratee shorthand. + * _.dropRightWhile(users, ['active', false]); + * // => objects for ['barney'] + * + * // The `_.property` iteratee shorthand. + * _.dropRightWhile(users, 'active'); + * // => objects for ['barney', 'fred', 'pebbles'] + */ + function dropRightWhile(array, predicate) { + return (array && array.length) + ? baseWhile(array, getIteratee(predicate, 3), true, true) + : []; + } + + /** + * Creates a slice of `array` excluding elements dropped from the beginning. + * Elements are dropped until `predicate` returns falsey. The predicate is + * invoked with three arguments: (value, index, array). + * + * @static + * @memberOf _ + * @since 3.0.0 + * @category Array + * @param {Array} array The array to query. + * @param {Function} [predicate=_.identity] The function invoked per iteration. + * @returns {Array} Returns the slice of `array`. + * @example + * + * var users = [ + * { 'user': 'barney', 'active': false }, + * { 'user': 'fred', 'active': false }, + * { 'user': 'pebbles', 'active': true } + * ]; + * + * _.dropWhile(users, function(o) { return !o.active; }); + * // => objects for ['pebbles'] + * + * // The `_.matches` iteratee shorthand. + * _.dropWhile(users, { 'user': 'barney', 'active': false }); + * // => objects for ['fred', 'pebbles'] + * + * // The `_.matchesProperty` iteratee shorthand. 
+ * _.dropWhile(users, ['active', false]); + * // => objects for ['pebbles'] + * + * // The `_.property` iteratee shorthand. + * _.dropWhile(users, 'active'); + * // => objects for ['barney', 'fred', 'pebbles'] + */ + function dropWhile(array, predicate) { + return (array && array.length) + ? baseWhile(array, getIteratee(predicate, 3), true) + : []; + } + + /** + * Fills elements of `array` with `value` from `start` up to, but not + * including, `end`. + * + * **Note:** This method mutates `array`. + * + * @static + * @memberOf _ + * @since 3.2.0 + * @category Array + * @param {Array} array The array to fill. + * @param {*} value The value to fill `array` with. + * @param {number} [start=0] The start position. + * @param {number} [end=array.length] The end position. + * @returns {Array} Returns `array`. + * @example + * + * var array = [1, 2, 3]; + * + * _.fill(array, 'a'); + * console.log(array); + * // => ['a', 'a', 'a'] + * + * _.fill(Array(3), 2); + * // => [2, 2, 2] + * + * _.fill([4, 6, 8, 10], '*', 1, 3); + * // => [4, '*', '*', 10] + */ + function fill(array, value, start, end) { + var length = array == null ? 0 : array.length; + if (!length) { + return []; + } + if (start && typeof start != 'number' && isIterateeCall(array, value, start)) { + start = 0; + end = length; + } + return baseFill(array, value, start, end); + } + + /** + * This method is like `_.find` except that it returns the index of the first + * element `predicate` returns truthy for instead of the element itself. + * + * @static + * @memberOf _ + * @since 1.1.0 + * @category Array + * @param {Array} array The array to inspect. + * @param {Function} [predicate=_.identity] The function invoked per iteration. + * @param {number} [fromIndex=0] The index to search from. + * @returns {number} Returns the index of the found element, else `-1`. + * @example + * + * var users = [ + * { 'user': 'barney', 'active': false }, + * { 'user': 'fred', 'active': false }, + * { 'user': 'pebbles', 'active': true } + * ]; + * + * _.findIndex(users, function(o) { return o.user == 'barney'; }); + * // => 0 + * + * // The `_.matches` iteratee shorthand. + * _.findIndex(users, { 'user': 'fred', 'active': false }); + * // => 1 + * + * // The `_.matchesProperty` iteratee shorthand. + * _.findIndex(users, ['active', false]); + * // => 0 + * + * // The `_.property` iteratee shorthand. + * _.findIndex(users, 'active'); + * // => 2 + */ + function findIndex(array, predicate, fromIndex) { + var length = array == null ? 0 : array.length; + if (!length) { + return -1; + } + var index = fromIndex == null ? 0 : toInteger(fromIndex); + if (index < 0) { + index = nativeMax(length + index, 0); + } + return baseFindIndex(array, getIteratee(predicate, 3), index); + } + + /** + * This method is like `_.findIndex` except that it iterates over elements + * of `collection` from right to left. + * + * @static + * @memberOf _ + * @since 2.0.0 + * @category Array + * @param {Array} array The array to inspect. + * @param {Function} [predicate=_.identity] The function invoked per iteration. + * @param {number} [fromIndex=array.length-1] The index to search from. + * @returns {number} Returns the index of the found element, else `-1`. + * @example + * + * var users = [ + * { 'user': 'barney', 'active': true }, + * { 'user': 'fred', 'active': false }, + * { 'user': 'pebbles', 'active': false } + * ]; + * + * _.findLastIndex(users, function(o) { return o.user == 'pebbles'; }); + * // => 2 + * + * // The `_.matches` iteratee shorthand. 
+ * _.findLastIndex(users, { 'user': 'barney', 'active': true }); + * // => 0 + * + * // The `_.matchesProperty` iteratee shorthand. + * _.findLastIndex(users, ['active', false]); + * // => 2 + * + * // The `_.property` iteratee shorthand. + * _.findLastIndex(users, 'active'); + * // => 0 + */ + function findLastIndex(array, predicate, fromIndex) { + var length = array == null ? 0 : array.length; + if (!length) { + return -1; + } + var index = length - 1; + if (fromIndex !== undefined) { + index = toInteger(fromIndex); + index = fromIndex < 0 + ? nativeMax(length + index, 0) + : nativeMin(index, length - 1); + } + return baseFindIndex(array, getIteratee(predicate, 3), index, true); + } + + /** + * Flattens `array` a single level deep. + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Array + * @param {Array} array The array to flatten. + * @returns {Array} Returns the new flattened array. + * @example + * + * _.flatten([1, [2, [3, [4]], 5]]); + * // => [1, 2, [3, [4]], 5] + */ + function flatten(array) { + var length = array == null ? 0 : array.length; + return length ? baseFlatten(array, 1) : []; + } + + /** + * Recursively flattens `array`. + * + * @static + * @memberOf _ + * @since 3.0.0 + * @category Array + * @param {Array} array The array to flatten. + * @returns {Array} Returns the new flattened array. + * @example + * + * _.flattenDeep([1, [2, [3, [4]], 5]]); + * // => [1, 2, 3, 4, 5] + */ + function flattenDeep(array) { + var length = array == null ? 0 : array.length; + return length ? baseFlatten(array, INFINITY) : []; + } + + /** + * Recursively flatten `array` up to `depth` times. + * + * @static + * @memberOf _ + * @since 4.4.0 + * @category Array + * @param {Array} array The array to flatten. + * @param {number} [depth=1] The maximum recursion depth. + * @returns {Array} Returns the new flattened array. + * @example + * + * var array = [1, [2, [3, [4]], 5]]; + * + * _.flattenDepth(array, 1); + * // => [1, 2, [3, [4]], 5] + * + * _.flattenDepth(array, 2); + * // => [1, 2, 3, [4], 5] + */ + function flattenDepth(array, depth) { + var length = array == null ? 0 : array.length; + if (!length) { + return []; + } + depth = depth === undefined ? 1 : toInteger(depth); + return baseFlatten(array, depth); + } + + /** + * The inverse of `_.toPairs`; this method returns an object composed + * from key-value `pairs`. + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Array + * @param {Array} pairs The key-value pairs. + * @returns {Object} Returns the new object. + * @example + * + * _.fromPairs([['a', 1], ['b', 2]]); + * // => { 'a': 1, 'b': 2 } + */ + function fromPairs(pairs) { + var index = -1, + length = pairs == null ? 0 : pairs.length, + result = {}; + + while (++index < length) { + var pair = pairs[index]; + result[pair[0]] = pair[1]; + } + return result; + } + + /** + * Gets the first element of `array`. + * + * @static + * @memberOf _ + * @since 0.1.0 + * @alias first + * @category Array + * @param {Array} array The array to query. + * @returns {*} Returns the first element of `array`. + * @example + * + * _.head([1, 2, 3]); + * // => 1 + * + * _.head([]); + * // => undefined + */ + function head(array) { + return (array && array.length) ? array[0] : undefined; + } + + /** + * Gets the index at which the first occurrence of `value` is found in `array` + * using [`SameValueZero`](http://ecma-international.org/ecma-262/7.0/#sec-samevaluezero) + * for equality comparisons. If `fromIndex` is negative, it's used as the + * offset from the end of `array`. 
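+ * For example, `_.indexOf([1, 2, 1, 2], 2, -2)` starts the search at index
+ * `2` and returns `3`.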
+ * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Array + * @param {Array} array The array to inspect. + * @param {*} value The value to search for. + * @param {number} [fromIndex=0] The index to search from. + * @returns {number} Returns the index of the matched value, else `-1`. + * @example + * + * _.indexOf([1, 2, 1, 2], 2); + * // => 1 + * + * // Search from the `fromIndex`. + * _.indexOf([1, 2, 1, 2], 2, 2); + * // => 3 + */ + function indexOf(array, value, fromIndex) { + var length = array == null ? 0 : array.length; + if (!length) { + return -1; + } + var index = fromIndex == null ? 0 : toInteger(fromIndex); + if (index < 0) { + index = nativeMax(length + index, 0); + } + return baseIndexOf(array, value, index); + } + + /** + * Gets all but the last element of `array`. + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Array + * @param {Array} array The array to query. + * @returns {Array} Returns the slice of `array`. + * @example + * + * _.initial([1, 2, 3]); + * // => [1, 2] + */ + function initial(array) { + var length = array == null ? 0 : array.length; + return length ? baseSlice(array, 0, -1) : []; + } + + /** + * Creates an array of unique values that are included in all given arrays + * using [`SameValueZero`](http://ecma-international.org/ecma-262/7.0/#sec-samevaluezero) + * for equality comparisons. The order and references of result values are + * determined by the first array. + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Array + * @param {...Array} [arrays] The arrays to inspect. + * @returns {Array} Returns the new array of intersecting values. + * @example + * + * _.intersection([2, 1], [2, 3]); + * // => [2] + */ + var intersection = baseRest(function(arrays) { + var mapped = arrayMap(arrays, castArrayLikeObject); + return (mapped.length && mapped[0] === arrays[0]) + ? baseIntersection(mapped) + : []; + }); + + /** + * This method is like `_.intersection` except that it accepts `iteratee` + * which is invoked for each element of each `arrays` to generate the criterion + * by which they're compared. The order and references of result values are + * determined by the first array. The iteratee is invoked with one argument: + * (value). + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Array + * @param {...Array} [arrays] The arrays to inspect. + * @param {Function} [iteratee=_.identity] The iteratee invoked per element. + * @returns {Array} Returns the new array of intersecting values. + * @example + * + * _.intersectionBy([2.1, 1.2], [2.3, 3.4], Math.floor); + * // => [2.1] + * + * // The `_.property` iteratee shorthand. + * _.intersectionBy([{ 'x': 1 }], [{ 'x': 2 }, { 'x': 1 }], 'x'); + * // => [{ 'x': 1 }] + */ + var intersectionBy = baseRest(function(arrays) { + var iteratee = last(arrays), + mapped = arrayMap(arrays, castArrayLikeObject); + + if (iteratee === last(mapped)) { + iteratee = undefined; + } else { + mapped.pop(); + } + return (mapped.length && mapped[0] === arrays[0]) + ? baseIntersection(mapped, getIteratee(iteratee, 2)) + : []; + }); + + /** + * This method is like `_.intersection` except that it accepts `comparator` + * which is invoked to compare elements of `arrays`. The order and references + * of result values are determined by the first array. The comparator is + * invoked with two arguments: (arrVal, othVal). + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Array + * @param {...Array} [arrays] The arrays to inspect. 
+ * @param {Function} [comparator] The comparator invoked per element. + * @returns {Array} Returns the new array of intersecting values. + * @example + * + * var objects = [{ 'x': 1, 'y': 2 }, { 'x': 2, 'y': 1 }]; + * var others = [{ 'x': 1, 'y': 1 }, { 'x': 1, 'y': 2 }]; + * + * _.intersectionWith(objects, others, _.isEqual); + * // => [{ 'x': 1, 'y': 2 }] + */ + var intersectionWith = baseRest(function(arrays) { + var comparator = last(arrays), + mapped = arrayMap(arrays, castArrayLikeObject); + + comparator = typeof comparator == 'function' ? comparator : undefined; + if (comparator) { + mapped.pop(); + } + return (mapped.length && mapped[0] === arrays[0]) + ? baseIntersection(mapped, undefined, comparator) + : []; + }); + + /** + * Converts all elements in `array` into a string separated by `separator`. + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Array + * @param {Array} array The array to convert. + * @param {string} [separator=','] The element separator. + * @returns {string} Returns the joined string. + * @example + * + * _.join(['a', 'b', 'c'], '~'); + * // => 'a~b~c' + */ + function join(array, separator) { + return array == null ? '' : nativeJoin.call(array, separator); + } + + /** + * Gets the last element of `array`. + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Array + * @param {Array} array The array to query. + * @returns {*} Returns the last element of `array`. + * @example + * + * _.last([1, 2, 3]); + * // => 3 + */ + function last(array) { + var length = array == null ? 0 : array.length; + return length ? array[length - 1] : undefined; + } + + /** + * This method is like `_.indexOf` except that it iterates over elements of + * `array` from right to left. + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Array + * @param {Array} array The array to inspect. + * @param {*} value The value to search for. + * @param {number} [fromIndex=array.length-1] The index to search from. + * @returns {number} Returns the index of the matched value, else `-1`. + * @example + * + * _.lastIndexOf([1, 2, 1, 2], 2); + * // => 3 + * + * // Search from the `fromIndex`. + * _.lastIndexOf([1, 2, 1, 2], 2, 2); + * // => 1 + */ + function lastIndexOf(array, value, fromIndex) { + var length = array == null ? 0 : array.length; + if (!length) { + return -1; + } + var index = length; + if (fromIndex !== undefined) { + index = toInteger(fromIndex); + index = index < 0 ? nativeMax(length + index, 0) : nativeMin(index, length - 1); + } + return value === value + ? strictLastIndexOf(array, value, index) + : baseFindIndex(array, baseIsNaN, index, true); + } + + /** + * Gets the element at index `n` of `array`. If `n` is negative, the nth + * element from the end is returned. + * + * @static + * @memberOf _ + * @since 4.11.0 + * @category Array + * @param {Array} array The array to query. + * @param {number} [n=0] The index of the element to return. + * @returns {*} Returns the nth element of `array`. + * @example + * + * var array = ['a', 'b', 'c', 'd']; + * + * _.nth(array, 1); + * // => 'b' + * + * _.nth(array, -2); + * // => 'c'; + */ + function nth(array, n) { + return (array && array.length) ? baseNth(array, toInteger(n)) : undefined; + } + + /** + * Removes all given values from `array` using + * [`SameValueZero`](http://ecma-international.org/ecma-262/7.0/#sec-samevaluezero) + * for equality comparisons. + * + * **Note:** Unlike `_.without`, this method mutates `array`. Use `_.remove` + * to remove elements from an array by predicate. 
+ * + * @static + * @memberOf _ + * @since 2.0.0 + * @category Array + * @param {Array} array The array to modify. + * @param {...*} [values] The values to remove. + * @returns {Array} Returns `array`. + * @example + * + * var array = ['a', 'b', 'c', 'a', 'b', 'c']; + * + * _.pull(array, 'a', 'c'); + * console.log(array); + * // => ['b', 'b'] + */ + var pull = baseRest(pullAll); + + /** + * This method is like `_.pull` except that it accepts an array of values to remove. + * + * **Note:** Unlike `_.difference`, this method mutates `array`. + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Array + * @param {Array} array The array to modify. + * @param {Array} values The values to remove. + * @returns {Array} Returns `array`. + * @example + * + * var array = ['a', 'b', 'c', 'a', 'b', 'c']; + * + * _.pullAll(array, ['a', 'c']); + * console.log(array); + * // => ['b', 'b'] + */ + function pullAll(array, values) { + return (array && array.length && values && values.length) + ? basePullAll(array, values) + : array; + } + + /** + * This method is like `_.pullAll` except that it accepts `iteratee` which is + * invoked for each element of `array` and `values` to generate the criterion + * by which they're compared. The iteratee is invoked with one argument: (value). + * + * **Note:** Unlike `_.differenceBy`, this method mutates `array`. + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Array + * @param {Array} array The array to modify. + * @param {Array} values The values to remove. + * @param {Function} [iteratee=_.identity] The iteratee invoked per element. + * @returns {Array} Returns `array`. + * @example + * + * var array = [{ 'x': 1 }, { 'x': 2 }, { 'x': 3 }, { 'x': 1 }]; + * + * _.pullAllBy(array, [{ 'x': 1 }, { 'x': 3 }], 'x'); + * console.log(array); + * // => [{ 'x': 2 }] + */ + function pullAllBy(array, values, iteratee) { + return (array && array.length && values && values.length) + ? basePullAll(array, values, getIteratee(iteratee, 2)) + : array; + } + + /** + * This method is like `_.pullAll` except that it accepts `comparator` which + * is invoked to compare elements of `array` to `values`. The comparator is + * invoked with two arguments: (arrVal, othVal). + * + * **Note:** Unlike `_.differenceWith`, this method mutates `array`. + * + * @static + * @memberOf _ + * @since 4.6.0 + * @category Array + * @param {Array} array The array to modify. + * @param {Array} values The values to remove. + * @param {Function} [comparator] The comparator invoked per element. + * @returns {Array} Returns `array`. + * @example + * + * var array = [{ 'x': 1, 'y': 2 }, { 'x': 3, 'y': 4 }, { 'x': 5, 'y': 6 }]; + * + * _.pullAllWith(array, [{ 'x': 3, 'y': 4 }], _.isEqual); + * console.log(array); + * // => [{ 'x': 1, 'y': 2 }, { 'x': 5, 'y': 6 }] + */ + function pullAllWith(array, values, comparator) { + return (array && array.length && values && values.length) + ? basePullAll(array, values, undefined, comparator) + : array; + } + + /** + * Removes elements from `array` corresponding to `indexes` and returns an + * array of removed elements. + * + * **Note:** Unlike `_.at`, this method mutates `array`. + * + * @static + * @memberOf _ + * @since 3.0.0 + * @category Array + * @param {Array} array The array to modify. + * @param {...(number|number[])} [indexes] The indexes of elements to remove. + * @returns {Array} Returns the new array of removed elements. 
+ * @example + * + * var array = ['a', 'b', 'c', 'd']; + * var pulled = _.pullAt(array, [1, 3]); + * + * console.log(array); + * // => ['a', 'c'] + * + * console.log(pulled); + * // => ['b', 'd'] + */ + var pullAt = flatRest(function(array, indexes) { + var length = array == null ? 0 : array.length, + result = baseAt(array, indexes); + + basePullAt(array, arrayMap(indexes, function(index) { + return isIndex(index, length) ? +index : index; + }).sort(compareAscending)); + + return result; + }); + + /** + * Removes all elements from `array` that `predicate` returns truthy for + * and returns an array of the removed elements. The predicate is invoked + * with three arguments: (value, index, array). + * + * **Note:** Unlike `_.filter`, this method mutates `array`. Use `_.pull` + * to pull elements from an array by value. + * + * @static + * @memberOf _ + * @since 2.0.0 + * @category Array + * @param {Array} array The array to modify. + * @param {Function} [predicate=_.identity] The function invoked per iteration. + * @returns {Array} Returns the new array of removed elements. + * @example + * + * var array = [1, 2, 3, 4]; + * var evens = _.remove(array, function(n) { + * return n % 2 == 0; + * }); + * + * console.log(array); + * // => [1, 3] + * + * console.log(evens); + * // => [2, 4] + */ + function remove(array, predicate) { + var result = []; + if (!(array && array.length)) { + return result; + } + var index = -1, + indexes = [], + length = array.length; + + predicate = getIteratee(predicate, 3); + while (++index < length) { + var value = array[index]; + if (predicate(value, index, array)) { + result.push(value); + indexes.push(index); + } + } + basePullAt(array, indexes); + return result; + } + + /** + * Reverses `array` so that the first element becomes the last, the second + * element becomes the second to last, and so on. + * + * **Note:** This method mutates `array` and is based on + * [`Array#reverse`](https://mdn.io/Array/reverse). + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Array + * @param {Array} array The array to modify. + * @returns {Array} Returns `array`. + * @example + * + * var array = [1, 2, 3]; + * + * _.reverse(array); + * // => [3, 2, 1] + * + * console.log(array); + * // => [3, 2, 1] + */ + function reverse(array) { + return array == null ? array : nativeReverse.call(array); + } + + /** + * Creates a slice of `array` from `start` up to, but not including, `end`. + * + * **Note:** This method is used instead of + * [`Array#slice`](https://mdn.io/Array/slice) to ensure dense arrays are + * returned. + * + * @static + * @memberOf _ + * @since 3.0.0 + * @category Array + * @param {Array} array The array to slice. + * @param {number} [start=0] The start position. + * @param {number} [end=array.length] The end position. + * @returns {Array} Returns the slice of `array`. + */ + function slice(array, start, end) { + var length = array == null ? 0 : array.length; + if (!length) { + return []; + } + if (end && typeof end != 'number' && isIterateeCall(array, start, end)) { + start = 0; + end = length; + } + else { + start = start == null ? 0 : toInteger(start); + end = end === undefined ? length : toInteger(end); + } + return baseSlice(array, start, end); + } + + /** + * Uses a binary search to determine the lowest index at which `value` + * should be inserted into `array` in order to maintain its sort order. + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Array + * @param {Array} array The sorted array to inspect. 
+ * @param {*} value The value to evaluate. + * @returns {number} Returns the index at which `value` should be inserted + * into `array`. + * @example + * + * _.sortedIndex([30, 50], 40); + * // => 1 + */ + function sortedIndex(array, value) { + return baseSortedIndex(array, value); + } + + /** + * This method is like `_.sortedIndex` except that it accepts `iteratee` + * which is invoked for `value` and each element of `array` to compute their + * sort ranking. The iteratee is invoked with one argument: (value). + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Array + * @param {Array} array The sorted array to inspect. + * @param {*} value The value to evaluate. + * @param {Function} [iteratee=_.identity] The iteratee invoked per element. + * @returns {number} Returns the index at which `value` should be inserted + * into `array`. + * @example + * + * var objects = [{ 'x': 4 }, { 'x': 5 }]; + * + * _.sortedIndexBy(objects, { 'x': 4 }, function(o) { return o.x; }); + * // => 0 + * + * // The `_.property` iteratee shorthand. + * _.sortedIndexBy(objects, { 'x': 4 }, 'x'); + * // => 0 + */ + function sortedIndexBy(array, value, iteratee) { + return baseSortedIndexBy(array, value, getIteratee(iteratee, 2)); + } + + /** + * This method is like `_.indexOf` except that it performs a binary + * search on a sorted `array`. + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Array + * @param {Array} array The array to inspect. + * @param {*} value The value to search for. + * @returns {number} Returns the index of the matched value, else `-1`. + * @example + * + * _.sortedIndexOf([4, 5, 5, 5, 6], 5); + * // => 1 + */ + function sortedIndexOf(array, value) { + var length = array == null ? 0 : array.length; + if (length) { + var index = baseSortedIndex(array, value); + if (index < length && eq(array[index], value)) { + return index; + } + } + return -1; + } + + /** + * This method is like `_.sortedIndex` except that it returns the highest + * index at which `value` should be inserted into `array` in order to + * maintain its sort order. + * + * @static + * @memberOf _ + * @since 3.0.0 + * @category Array + * @param {Array} array The sorted array to inspect. + * @param {*} value The value to evaluate. + * @returns {number} Returns the index at which `value` should be inserted + * into `array`. + * @example + * + * _.sortedLastIndex([4, 5, 5, 5, 6], 5); + * // => 4 + */ + function sortedLastIndex(array, value) { + return baseSortedIndex(array, value, true); + } + + /** + * This method is like `_.sortedLastIndex` except that it accepts `iteratee` + * which is invoked for `value` and each element of `array` to compute their + * sort ranking. The iteratee is invoked with one argument: (value). + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Array + * @param {Array} array The sorted array to inspect. + * @param {*} value The value to evaluate. + * @param {Function} [iteratee=_.identity] The iteratee invoked per element. + * @returns {number} Returns the index at which `value` should be inserted + * into `array`. + * @example + * + * var objects = [{ 'x': 4 }, { 'x': 5 }]; + * + * _.sortedLastIndexBy(objects, { 'x': 4 }, function(o) { return o.x; }); + * // => 1 + * + * // The `_.property` iteratee shorthand. 
+ * _.sortedLastIndexBy(objects, { 'x': 4 }, 'x'); + * // => 1 + */ + function sortedLastIndexBy(array, value, iteratee) { + return baseSortedIndexBy(array, value, getIteratee(iteratee, 2), true); + } + + /** + * This method is like `_.lastIndexOf` except that it performs a binary + * search on a sorted `array`. + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Array + * @param {Array} array The array to inspect. + * @param {*} value The value to search for. + * @returns {number} Returns the index of the matched value, else `-1`. + * @example + * + * _.sortedLastIndexOf([4, 5, 5, 5, 6], 5); + * // => 3 + */ + function sortedLastIndexOf(array, value) { + var length = array == null ? 0 : array.length; + if (length) { + var index = baseSortedIndex(array, value, true) - 1; + if (eq(array[index], value)) { + return index; + } + } + return -1; + } + + /** + * This method is like `_.uniq` except that it's designed and optimized + * for sorted arrays. + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Array + * @param {Array} array The array to inspect. + * @returns {Array} Returns the new duplicate free array. + * @example + * + * _.sortedUniq([1, 1, 2]); + * // => [1, 2] + */ + function sortedUniq(array) { + return (array && array.length) + ? baseSortedUniq(array) + : []; + } + + /** + * This method is like `_.uniqBy` except that it's designed and optimized + * for sorted arrays. + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Array + * @param {Array} array The array to inspect. + * @param {Function} [iteratee] The iteratee invoked per element. + * @returns {Array} Returns the new duplicate free array. + * @example + * + * _.sortedUniqBy([1.1, 1.2, 2.3, 2.4], Math.floor); + * // => [1.1, 2.3] + */ + function sortedUniqBy(array, iteratee) { + return (array && array.length) + ? baseSortedUniq(array, getIteratee(iteratee, 2)) + : []; + } + + /** + * Gets all but the first element of `array`. + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Array + * @param {Array} array The array to query. + * @returns {Array} Returns the slice of `array`. + * @example + * + * _.tail([1, 2, 3]); + * // => [2, 3] + */ + function tail(array) { + var length = array == null ? 0 : array.length; + return length ? baseSlice(array, 1, length) : []; + } + + /** + * Creates a slice of `array` with `n` elements taken from the beginning. + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Array + * @param {Array} array The array to query. + * @param {number} [n=1] The number of elements to take. + * @param- {Object} [guard] Enables use as an iteratee for methods like `_.map`. + * @returns {Array} Returns the slice of `array`. + * @example + * + * _.take([1, 2, 3]); + * // => [1] + * + * _.take([1, 2, 3], 2); + * // => [1, 2] + * + * _.take([1, 2, 3], 5); + * // => [1, 2, 3] + * + * _.take([1, 2, 3], 0); + * // => [] + */ + function take(array, n, guard) { + if (!(array && array.length)) { + return []; + } + n = (guard || n === undefined) ? 1 : toInteger(n); + return baseSlice(array, 0, n < 0 ? 0 : n); + } + + /** + * Creates a slice of `array` with `n` elements taken from the end. + * + * @static + * @memberOf _ + * @since 3.0.0 + * @category Array + * @param {Array} array The array to query. + * @param {number} [n=1] The number of elements to take. + * @param- {Object} [guard] Enables use as an iteratee for methods like `_.map`. + * @returns {Array} Returns the slice of `array`. 
+ * @example + * + * _.takeRight([1, 2, 3]); + * // => [3] + * + * _.takeRight([1, 2, 3], 2); + * // => [2, 3] + * + * _.takeRight([1, 2, 3], 5); + * // => [1, 2, 3] + * + * _.takeRight([1, 2, 3], 0); + * // => [] + */ + function takeRight(array, n, guard) { + var length = array == null ? 0 : array.length; + if (!length) { + return []; + } + n = (guard || n === undefined) ? 1 : toInteger(n); + n = length - n; + return baseSlice(array, n < 0 ? 0 : n, length); + } + + /** + * Creates a slice of `array` with elements taken from the end. Elements are + * taken until `predicate` returns falsey. The predicate is invoked with + * three arguments: (value, index, array). + * + * @static + * @memberOf _ + * @since 3.0.0 + * @category Array + * @param {Array} array The array to query. + * @param {Function} [predicate=_.identity] The function invoked per iteration. + * @returns {Array} Returns the slice of `array`. + * @example + * + * var users = [ + * { 'user': 'barney', 'active': true }, + * { 'user': 'fred', 'active': false }, + * { 'user': 'pebbles', 'active': false } + * ]; + * + * _.takeRightWhile(users, function(o) { return !o.active; }); + * // => objects for ['fred', 'pebbles'] + * + * // The `_.matches` iteratee shorthand. + * _.takeRightWhile(users, { 'user': 'pebbles', 'active': false }); + * // => objects for ['pebbles'] + * + * // The `_.matchesProperty` iteratee shorthand. + * _.takeRightWhile(users, ['active', false]); + * // => objects for ['fred', 'pebbles'] + * + * // The `_.property` iteratee shorthand. + * _.takeRightWhile(users, 'active'); + * // => [] + */ + function takeRightWhile(array, predicate) { + return (array && array.length) + ? baseWhile(array, getIteratee(predicate, 3), false, true) + : []; + } + + /** + * Creates a slice of `array` with elements taken from the beginning. Elements + * are taken until `predicate` returns falsey. The predicate is invoked with + * three arguments: (value, index, array). + * + * @static + * @memberOf _ + * @since 3.0.0 + * @category Array + * @param {Array} array The array to query. + * @param {Function} [predicate=_.identity] The function invoked per iteration. + * @returns {Array} Returns the slice of `array`. + * @example + * + * var users = [ + * { 'user': 'barney', 'active': false }, + * { 'user': 'fred', 'active': false }, + * { 'user': 'pebbles', 'active': true } + * ]; + * + * _.takeWhile(users, function(o) { return !o.active; }); + * // => objects for ['barney', 'fred'] + * + * // The `_.matches` iteratee shorthand. + * _.takeWhile(users, { 'user': 'barney', 'active': false }); + * // => objects for ['barney'] + * + * // The `_.matchesProperty` iteratee shorthand. + * _.takeWhile(users, ['active', false]); + * // => objects for ['barney', 'fred'] + * + * // The `_.property` iteratee shorthand. + * _.takeWhile(users, 'active'); + * // => [] + */ + function takeWhile(array, predicate) { + return (array && array.length) + ? baseWhile(array, getIteratee(predicate, 3)) + : []; + } + + /** + * Creates an array of unique values, in order, from all given arrays using + * [`SameValueZero`](http://ecma-international.org/ecma-262/7.0/#sec-samevaluezero) + * for equality comparisons. + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Array + * @param {...Array} [arrays] The arrays to inspect. + * @returns {Array} Returns the new array of combined values. 
+ * @example + * + * _.union([2], [1, 2]); + * // => [2, 1] + */ + var union = baseRest(function(arrays) { + return baseUniq(baseFlatten(arrays, 1, isArrayLikeObject, true)); + }); + + /** + * This method is like `_.union` except that it accepts `iteratee` which is + * invoked for each element of each `arrays` to generate the criterion by + * which uniqueness is computed. Result values are chosen from the first + * array in which the value occurs. The iteratee is invoked with one argument: + * (value). + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Array + * @param {...Array} [arrays] The arrays to inspect. + * @param {Function} [iteratee=_.identity] The iteratee invoked per element. + * @returns {Array} Returns the new array of combined values. + * @example + * + * _.unionBy([2.1], [1.2, 2.3], Math.floor); + * // => [2.1, 1.2] + * + * // The `_.property` iteratee shorthand. + * _.unionBy([{ 'x': 1 }], [{ 'x': 2 }, { 'x': 1 }], 'x'); + * // => [{ 'x': 1 }, { 'x': 2 }] + */ + var unionBy = baseRest(function(arrays) { + var iteratee = last(arrays); + if (isArrayLikeObject(iteratee)) { + iteratee = undefined; + } + return baseUniq(baseFlatten(arrays, 1, isArrayLikeObject, true), getIteratee(iteratee, 2)); + }); + + /** + * This method is like `_.union` except that it accepts `comparator` which + * is invoked to compare elements of `arrays`. Result values are chosen from + * the first array in which the value occurs. The comparator is invoked + * with two arguments: (arrVal, othVal). + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Array + * @param {...Array} [arrays] The arrays to inspect. + * @param {Function} [comparator] The comparator invoked per element. + * @returns {Array} Returns the new array of combined values. + * @example + * + * var objects = [{ 'x': 1, 'y': 2 }, { 'x': 2, 'y': 1 }]; + * var others = [{ 'x': 1, 'y': 1 }, { 'x': 1, 'y': 2 }]; + * + * _.unionWith(objects, others, _.isEqual); + * // => [{ 'x': 1, 'y': 2 }, { 'x': 2, 'y': 1 }, { 'x': 1, 'y': 1 }] + */ + var unionWith = baseRest(function(arrays) { + var comparator = last(arrays); + comparator = typeof comparator == 'function' ? comparator : undefined; + return baseUniq(baseFlatten(arrays, 1, isArrayLikeObject, true), undefined, comparator); + }); + + /** + * Creates a duplicate-free version of an array, using + * [`SameValueZero`](http://ecma-international.org/ecma-262/7.0/#sec-samevaluezero) + * for equality comparisons, in which only the first occurrence of each element + * is kept. The order of result values is determined by the order they occur + * in the array. + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Array + * @param {Array} array The array to inspect. + * @returns {Array} Returns the new duplicate free array. + * @example + * + * _.uniq([2, 1, 2]); + * // => [2, 1] + */ + function uniq(array) { + return (array && array.length) ? baseUniq(array) : []; + } + + /** + * This method is like `_.uniq` except that it accepts `iteratee` which is + * invoked for each element in `array` to generate the criterion by which + * uniqueness is computed. The order of result values is determined by the + * order they occur in the array. The iteratee is invoked with one argument: + * (value). + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Array + * @param {Array} array The array to inspect. + * @param {Function} [iteratee=_.identity] The iteratee invoked per element. + * @returns {Array} Returns the new duplicate free array. 
+ * @example + * + * _.uniqBy([2.1, 1.2, 2.3], Math.floor); + * // => [2.1, 1.2] + * + * // The `_.property` iteratee shorthand. + * _.uniqBy([{ 'x': 1 }, { 'x': 2 }, { 'x': 1 }], 'x'); + * // => [{ 'x': 1 }, { 'x': 2 }] + */ + function uniqBy(array, iteratee) { + return (array && array.length) ? baseUniq(array, getIteratee(iteratee, 2)) : []; + } + + /** + * This method is like `_.uniq` except that it accepts `comparator` which + * is invoked to compare elements of `array`. The order of result values is + * determined by the order they occur in the array.The comparator is invoked + * with two arguments: (arrVal, othVal). + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Array + * @param {Array} array The array to inspect. + * @param {Function} [comparator] The comparator invoked per element. + * @returns {Array} Returns the new duplicate free array. + * @example + * + * var objects = [{ 'x': 1, 'y': 2 }, { 'x': 2, 'y': 1 }, { 'x': 1, 'y': 2 }]; + * + * _.uniqWith(objects, _.isEqual); + * // => [{ 'x': 1, 'y': 2 }, { 'x': 2, 'y': 1 }] + */ + function uniqWith(array, comparator) { + comparator = typeof comparator == 'function' ? comparator : undefined; + return (array && array.length) ? baseUniq(array, undefined, comparator) : []; + } + + /** + * This method is like `_.zip` except that it accepts an array of grouped + * elements and creates an array regrouping the elements to their pre-zip + * configuration. + * + * @static + * @memberOf _ + * @since 1.2.0 + * @category Array + * @param {Array} array The array of grouped elements to process. + * @returns {Array} Returns the new array of regrouped elements. + * @example + * + * var zipped = _.zip(['a', 'b'], [1, 2], [true, false]); + * // => [['a', 1, true], ['b', 2, false]] + * + * _.unzip(zipped); + * // => [['a', 'b'], [1, 2], [true, false]] + */ + function unzip(array) { + if (!(array && array.length)) { + return []; + } + var length = 0; + array = arrayFilter(array, function(group) { + if (isArrayLikeObject(group)) { + length = nativeMax(group.length, length); + return true; + } + }); + return baseTimes(length, function(index) { + return arrayMap(array, baseProperty(index)); + }); + } + + /** + * This method is like `_.unzip` except that it accepts `iteratee` to specify + * how regrouped values should be combined. The iteratee is invoked with the + * elements of each group: (...group). + * + * @static + * @memberOf _ + * @since 3.8.0 + * @category Array + * @param {Array} array The array of grouped elements to process. + * @param {Function} [iteratee=_.identity] The function to combine + * regrouped values. + * @returns {Array} Returns the new array of regrouped elements. + * @example + * + * var zipped = _.zip([1, 2], [10, 20], [100, 200]); + * // => [[1, 10, 100], [2, 20, 200]] + * + * _.unzipWith(zipped, _.add); + * // => [3, 30, 300] + */ + function unzipWith(array, iteratee) { + if (!(array && array.length)) { + return []; + } + var result = unzip(array); + if (iteratee == null) { + return result; + } + return arrayMap(result, function(group) { + return apply(iteratee, undefined, group); + }); + } + + /** + * Creates an array excluding all given values using + * [`SameValueZero`](http://ecma-international.org/ecma-262/7.0/#sec-samevaluezero) + * for equality comparisons. + * + * **Note:** Unlike `_.pull`, this method returns a new array. + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Array + * @param {Array} array The array to inspect. + * @param {...*} [values] The values to exclude. 
+ * @returns {Array} Returns the new array of filtered values. + * @see _.difference, _.xor + * @example + * + * _.without([2, 1, 2, 3], 1, 2); + * // => [3] + */ + var without = baseRest(function(array, values) { + return isArrayLikeObject(array) + ? baseDifference(array, values) + : []; + }); + + /** + * Creates an array of unique values that is the + * [symmetric difference](https://en.wikipedia.org/wiki/Symmetric_difference) + * of the given arrays. The order of result values is determined by the order + * they occur in the arrays. + * + * @static + * @memberOf _ + * @since 2.4.0 + * @category Array + * @param {...Array} [arrays] The arrays to inspect. + * @returns {Array} Returns the new array of filtered values. + * @see _.difference, _.without + * @example + * + * _.xor([2, 1], [2, 3]); + * // => [1, 3] + */ + var xor = baseRest(function(arrays) { + return baseXor(arrayFilter(arrays, isArrayLikeObject)); + }); + + /** + * This method is like `_.xor` except that it accepts `iteratee` which is + * invoked for each element of each `arrays` to generate the criterion by + * which by which they're compared. The order of result values is determined + * by the order they occur in the arrays. The iteratee is invoked with one + * argument: (value). + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Array + * @param {...Array} [arrays] The arrays to inspect. + * @param {Function} [iteratee=_.identity] The iteratee invoked per element. + * @returns {Array} Returns the new array of filtered values. + * @example + * + * _.xorBy([2.1, 1.2], [2.3, 3.4], Math.floor); + * // => [1.2, 3.4] + * + * // The `_.property` iteratee shorthand. + * _.xorBy([{ 'x': 1 }], [{ 'x': 2 }, { 'x': 1 }], 'x'); + * // => [{ 'x': 2 }] + */ + var xorBy = baseRest(function(arrays) { + var iteratee = last(arrays); + if (isArrayLikeObject(iteratee)) { + iteratee = undefined; + } + return baseXor(arrayFilter(arrays, isArrayLikeObject), getIteratee(iteratee, 2)); + }); + + /** + * This method is like `_.xor` except that it accepts `comparator` which is + * invoked to compare elements of `arrays`. The order of result values is + * determined by the order they occur in the arrays. The comparator is invoked + * with two arguments: (arrVal, othVal). + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Array + * @param {...Array} [arrays] The arrays to inspect. + * @param {Function} [comparator] The comparator invoked per element. + * @returns {Array} Returns the new array of filtered values. + * @example + * + * var objects = [{ 'x': 1, 'y': 2 }, { 'x': 2, 'y': 1 }]; + * var others = [{ 'x': 1, 'y': 1 }, { 'x': 1, 'y': 2 }]; + * + * _.xorWith(objects, others, _.isEqual); + * // => [{ 'x': 2, 'y': 1 }, { 'x': 1, 'y': 1 }] + */ + var xorWith = baseRest(function(arrays) { + var comparator = last(arrays); + comparator = typeof comparator == 'function' ? comparator : undefined; + return baseXor(arrayFilter(arrays, isArrayLikeObject), undefined, comparator); + }); + + /** + * Creates an array of grouped elements, the first of which contains the + * first elements of the given arrays, the second of which contains the + * second elements of the given arrays, and so on. + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Array + * @param {...Array} [arrays] The arrays to process. + * @returns {Array} Returns the new array of grouped elements. 
+ * @example + * + * _.zip(['a', 'b'], [1, 2], [true, false]); + * // => [['a', 1, true], ['b', 2, false]] + */ + var zip = baseRest(unzip); + + /** + * This method is like `_.fromPairs` except that it accepts two arrays, + * one of property identifiers and one of corresponding values. + * + * @static + * @memberOf _ + * @since 0.4.0 + * @category Array + * @param {Array} [props=[]] The property identifiers. + * @param {Array} [values=[]] The property values. + * @returns {Object} Returns the new object. + * @example + * + * _.zipObject(['a', 'b'], [1, 2]); + * // => { 'a': 1, 'b': 2 } + */ + function zipObject(props, values) { + return baseZipObject(props || [], values || [], assignValue); + } + + /** + * This method is like `_.zipObject` except that it supports property paths. + * + * @static + * @memberOf _ + * @since 4.1.0 + * @category Array + * @param {Array} [props=[]] The property identifiers. + * @param {Array} [values=[]] The property values. + * @returns {Object} Returns the new object. + * @example + * + * _.zipObjectDeep(['a.b[0].c', 'a.b[1].d'], [1, 2]); + * // => { 'a': { 'b': [{ 'c': 1 }, { 'd': 2 }] } } + */ + function zipObjectDeep(props, values) { + return baseZipObject(props || [], values || [], baseSet); + } + + /** + * This method is like `_.zip` except that it accepts `iteratee` to specify + * how grouped values should be combined. The iteratee is invoked with the + * elements of each group: (...group). + * + * @static + * @memberOf _ + * @since 3.8.0 + * @category Array + * @param {...Array} [arrays] The arrays to process. + * @param {Function} [iteratee=_.identity] The function to combine + * grouped values. + * @returns {Array} Returns the new array of grouped elements. + * @example + * + * _.zipWith([1, 2], [10, 20], [100, 200], function(a, b, c) { + * return a + b + c; + * }); + * // => [111, 222] + */ + var zipWith = baseRest(function(arrays) { + var length = arrays.length, + iteratee = length > 1 ? arrays[length - 1] : undefined; + + iteratee = typeof iteratee == 'function' ? (arrays.pop(), iteratee) : undefined; + return unzipWith(arrays, iteratee); + }); + + /*------------------------------------------------------------------------*/ + + /** + * Creates a `lodash` wrapper instance that wraps `value` with explicit method + * chain sequences enabled. The result of such sequences must be unwrapped + * with `_#value`. + * + * @static + * @memberOf _ + * @since 1.3.0 + * @category Seq + * @param {*} value The value to wrap. + * @returns {Object} Returns the new `lodash` wrapper instance. + * @example + * + * var users = [ + * { 'user': 'barney', 'age': 36 }, + * { 'user': 'fred', 'age': 40 }, + * { 'user': 'pebbles', 'age': 1 } + * ]; + * + * var youngest = _ + * .chain(users) + * .sortBy('age') + * .map(function(o) { + * return o.user + ' is ' + o.age; + * }) + * .head() + * .value(); + * // => 'pebbles is 1' + */ + function chain(value) { + var result = lodash(value); + result.__chain__ = true; + return result; + } + + /** + * This method invokes `interceptor` and returns `value`. The interceptor + * is invoked with one argument; (value). The purpose of this method is to + * "tap into" a method chain sequence in order to modify intermediate results. + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Seq + * @param {*} value The value to provide to `interceptor`. + * @param {Function} interceptor The function to invoke. + * @returns {*} Returns `value`. + * @example + * + * _([1, 2, 3]) + * .tap(function(array) { + * // Mutate input array. 
+ * array.pop(); + * }) + * .reverse() + * .value(); + * // => [2, 1] + */ + function tap(value, interceptor) { + interceptor(value); + return value; + } + + /** + * This method is like `_.tap` except that it returns the result of `interceptor`. + * The purpose of this method is to "pass thru" values replacing intermediate + * results in a method chain sequence. + * + * @static + * @memberOf _ + * @since 3.0.0 + * @category Seq + * @param {*} value The value to provide to `interceptor`. + * @param {Function} interceptor The function to invoke. + * @returns {*} Returns the result of `interceptor`. + * @example + * + * _(' abc ') + * .chain() + * .trim() + * .thru(function(value) { + * return [value]; + * }) + * .value(); + * // => ['abc'] + */ + function thru(value, interceptor) { + return interceptor(value); + } + + /** + * This method is the wrapper version of `_.at`. + * + * @name at + * @memberOf _ + * @since 1.0.0 + * @category Seq + * @param {...(string|string[])} [paths] The property paths to pick. + * @returns {Object} Returns the new `lodash` wrapper instance. + * @example + * + * var object = { 'a': [{ 'b': { 'c': 3 } }, 4] }; + * + * _(object).at(['a[0].b.c', 'a[1]']).value(); + * // => [3, 4] + */ + var wrapperAt = flatRest(function(paths) { + var length = paths.length, + start = length ? paths[0] : 0, + value = this.__wrapped__, + interceptor = function(object) { return baseAt(object, paths); }; + + if (length > 1 || this.__actions__.length || + !(value instanceof LazyWrapper) || !isIndex(start)) { + return this.thru(interceptor); + } + value = value.slice(start, +start + (length ? 1 : 0)); + value.__actions__.push({ + 'func': thru, + 'args': [interceptor], + 'thisArg': undefined + }); + return new LodashWrapper(value, this.__chain__).thru(function(array) { + if (length && !array.length) { + array.push(undefined); + } + return array; + }); + }); + + /** + * Creates a `lodash` wrapper instance with explicit method chain sequences enabled. + * + * @name chain + * @memberOf _ + * @since 0.1.0 + * @category Seq + * @returns {Object} Returns the new `lodash` wrapper instance. + * @example + * + * var users = [ + * { 'user': 'barney', 'age': 36 }, + * { 'user': 'fred', 'age': 40 } + * ]; + * + * // A sequence without explicit chaining. + * _(users).head(); + * // => { 'user': 'barney', 'age': 36 } + * + * // A sequence with explicit chaining. + * _(users) + * .chain() + * .head() + * .pick('user') + * .value(); + * // => { 'user': 'barney' } + */ + function wrapperChain() { + return chain(this); + } + + /** + * Executes the chain sequence and returns the wrapped result. + * + * @name commit + * @memberOf _ + * @since 3.2.0 + * @category Seq + * @returns {Object} Returns the new `lodash` wrapper instance. + * @example + * + * var array = [1, 2]; + * var wrapped = _(array).push(3); + * + * console.log(array); + * // => [1, 2] + * + * wrapped = wrapped.commit(); + * console.log(array); + * // => [1, 2, 3] + * + * wrapped.last(); + * // => 3 + * + * console.log(array); + * // => [1, 2, 3] + */ + function wrapperCommit() { + return new LodashWrapper(this.value(), this.__chain__); + } + + /** + * Gets the next value on a wrapped object following the + * [iterator protocol](https://mdn.io/iteration_protocols#iterator). + * + * @name next + * @memberOf _ + * @since 4.0.0 + * @category Seq + * @returns {Object} Returns the next iterator value. 
+ * @example + * + * var wrapped = _([1, 2]); + * + * wrapped.next(); + * // => { 'done': false, 'value': 1 } + * + * wrapped.next(); + * // => { 'done': false, 'value': 2 } + * + * wrapped.next(); + * // => { 'done': true, 'value': undefined } + */ + function wrapperNext() { + if (this.__values__ === undefined) { + this.__values__ = toArray(this.value()); + } + var done = this.__index__ >= this.__values__.length, + value = done ? undefined : this.__values__[this.__index__++]; + + return { 'done': done, 'value': value }; + } + + /** + * Enables the wrapper to be iterable. + * + * @name Symbol.iterator + * @memberOf _ + * @since 4.0.0 + * @category Seq + * @returns {Object} Returns the wrapper object. + * @example + * + * var wrapped = _([1, 2]); + * + * wrapped[Symbol.iterator]() === wrapped; + * // => true + * + * Array.from(wrapped); + * // => [1, 2] + */ + function wrapperToIterator() { + return this; + } + + /** + * Creates a clone of the chain sequence planting `value` as the wrapped value. + * + * @name plant + * @memberOf _ + * @since 3.2.0 + * @category Seq + * @param {*} value The value to plant. + * @returns {Object} Returns the new `lodash` wrapper instance. + * @example + * + * function square(n) { + * return n * n; + * } + * + * var wrapped = _([1, 2]).map(square); + * var other = wrapped.plant([3, 4]); + * + * other.value(); + * // => [9, 16] + * + * wrapped.value(); + * // => [1, 4] + */ + function wrapperPlant(value) { + var result, + parent = this; + + while (parent instanceof baseLodash) { + var clone = wrapperClone(parent); + clone.__index__ = 0; + clone.__values__ = undefined; + if (result) { + previous.__wrapped__ = clone; + } else { + result = clone; + } + var previous = clone; + parent = parent.__wrapped__; + } + previous.__wrapped__ = value; + return result; + } + + /** + * This method is the wrapper version of `_.reverse`. + * + * **Note:** This method mutates the wrapped array. + * + * @name reverse + * @memberOf _ + * @since 0.1.0 + * @category Seq + * @returns {Object} Returns the new `lodash` wrapper instance. + * @example + * + * var array = [1, 2, 3]; + * + * _(array).reverse().value() + * // => [3, 2, 1] + * + * console.log(array); + * // => [3, 2, 1] + */ + function wrapperReverse() { + var value = this.__wrapped__; + if (value instanceof LazyWrapper) { + var wrapped = value; + if (this.__actions__.length) { + wrapped = new LazyWrapper(this); + } + wrapped = wrapped.reverse(); + wrapped.__actions__.push({ + 'func': thru, + 'args': [reverse], + 'thisArg': undefined + }); + return new LodashWrapper(wrapped, this.__chain__); + } + return this.thru(reverse); + } + + /** + * Executes the chain sequence to resolve the unwrapped value. + * + * @name value + * @memberOf _ + * @since 0.1.0 + * @alias toJSON, valueOf + * @category Seq + * @returns {*} Returns the resolved unwrapped value. + * @example + * + * _([1, 2, 3]).value(); + * // => [1, 2, 3] + */ + function wrapperValue() { + return baseWrapperValue(this.__wrapped__, this.__actions__); + } + + /*------------------------------------------------------------------------*/ + + /** + * Creates an object composed of keys generated from the results of running + * each element of `collection` thru `iteratee`. The corresponding value of + * each key is the number of times the key was returned by `iteratee`. The + * iteratee is invoked with one argument: (value). + * + * @static + * @memberOf _ + * @since 0.5.0 + * @category Collection + * @param {Array|Object} collection The collection to iterate over. 
+ * @param {Function} [iteratee=_.identity] The iteratee to transform keys. + * @returns {Object} Returns the composed aggregate object. + * @example + * + * _.countBy([6.1, 4.2, 6.3], Math.floor); + * // => { '4': 1, '6': 2 } + * + * // The `_.property` iteratee shorthand. + * _.countBy(['one', 'two', 'three'], 'length'); + * // => { '3': 2, '5': 1 } + */ + var countBy = createAggregator(function(result, value, key) { + if (hasOwnProperty.call(result, key)) { + ++result[key]; + } else { + baseAssignValue(result, key, 1); + } + }); + + /** + * Checks if `predicate` returns truthy for **all** elements of `collection`. + * Iteration is stopped once `predicate` returns falsey. The predicate is + * invoked with three arguments: (value, index|key, collection). + * + * **Note:** This method returns `true` for + * [empty collections](https://en.wikipedia.org/wiki/Empty_set) because + * [everything is true](https://en.wikipedia.org/wiki/Vacuous_truth) of + * elements of empty collections. + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Collection + * @param {Array|Object} collection The collection to iterate over. + * @param {Function} [predicate=_.identity] The function invoked per iteration. + * @param- {Object} [guard] Enables use as an iteratee for methods like `_.map`. + * @returns {boolean} Returns `true` if all elements pass the predicate check, + * else `false`. + * @example + * + * _.every([true, 1, null, 'yes'], Boolean); + * // => false + * + * var users = [ + * { 'user': 'barney', 'age': 36, 'active': false }, + * { 'user': 'fred', 'age': 40, 'active': false } + * ]; + * + * // The `_.matches` iteratee shorthand. + * _.every(users, { 'user': 'barney', 'active': false }); + * // => false + * + * // The `_.matchesProperty` iteratee shorthand. + * _.every(users, ['active', false]); + * // => true + * + * // The `_.property` iteratee shorthand. + * _.every(users, 'active'); + * // => false + */ + function every(collection, predicate, guard) { + var func = isArray(collection) ? arrayEvery : baseEvery; + if (guard && isIterateeCall(collection, predicate, guard)) { + predicate = undefined; + } + return func(collection, getIteratee(predicate, 3)); + } + + /** + * Iterates over elements of `collection`, returning an array of all elements + * `predicate` returns truthy for. The predicate is invoked with three + * arguments: (value, index|key, collection). + * + * **Note:** Unlike `_.remove`, this method returns a new array. + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Collection + * @param {Array|Object} collection The collection to iterate over. + * @param {Function} [predicate=_.identity] The function invoked per iteration. + * @returns {Array} Returns the new filtered array. + * @see _.reject + * @example + * + * var users = [ + * { 'user': 'barney', 'age': 36, 'active': true }, + * { 'user': 'fred', 'age': 40, 'active': false } + * ]; + * + * _.filter(users, function(o) { return !o.active; }); + * // => objects for ['fred'] + * + * // The `_.matches` iteratee shorthand. + * _.filter(users, { 'age': 36, 'active': true }); + * // => objects for ['barney'] + * + * // The `_.matchesProperty` iteratee shorthand. + * _.filter(users, ['active', false]); + * // => objects for ['fred'] + * + * // The `_.property` iteratee shorthand. + * _.filter(users, 'active'); + * // => objects for ['barney'] + */ + function filter(collection, predicate) { + var func = isArray(collection) ? 
arrayFilter : baseFilter; + return func(collection, getIteratee(predicate, 3)); + } + + /** + * Iterates over elements of `collection`, returning the first element + * `predicate` returns truthy for. The predicate is invoked with three + * arguments: (value, index|key, collection). + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Collection + * @param {Array|Object} collection The collection to inspect. + * @param {Function} [predicate=_.identity] The function invoked per iteration. + * @param {number} [fromIndex=0] The index to search from. + * @returns {*} Returns the matched element, else `undefined`. + * @example + * + * var users = [ + * { 'user': 'barney', 'age': 36, 'active': true }, + * { 'user': 'fred', 'age': 40, 'active': false }, + * { 'user': 'pebbles', 'age': 1, 'active': true } + * ]; + * + * _.find(users, function(o) { return o.age < 40; }); + * // => object for 'barney' + * + * // The `_.matches` iteratee shorthand. + * _.find(users, { 'age': 1, 'active': true }); + * // => object for 'pebbles' + * + * // The `_.matchesProperty` iteratee shorthand. + * _.find(users, ['active', false]); + * // => object for 'fred' + * + * // The `_.property` iteratee shorthand. + * _.find(users, 'active'); + * // => object for 'barney' + */ + var find = createFind(findIndex); + + /** + * This method is like `_.find` except that it iterates over elements of + * `collection` from right to left. + * + * @static + * @memberOf _ + * @since 2.0.0 + * @category Collection + * @param {Array|Object} collection The collection to inspect. + * @param {Function} [predicate=_.identity] The function invoked per iteration. + * @param {number} [fromIndex=collection.length-1] The index to search from. + * @returns {*} Returns the matched element, else `undefined`. + * @example + * + * _.findLast([1, 2, 3, 4], function(n) { + * return n % 2 == 1; + * }); + * // => 3 + */ + var findLast = createFind(findLastIndex); + + /** + * Creates a flattened array of values by running each element in `collection` + * thru `iteratee` and flattening the mapped results. The iteratee is invoked + * with three arguments: (value, index|key, collection). + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Collection + * @param {Array|Object} collection The collection to iterate over. + * @param {Function} [iteratee=_.identity] The function invoked per iteration. + * @returns {Array} Returns the new flattened array. + * @example + * + * function duplicate(n) { + * return [n, n]; + * } + * + * _.flatMap([1, 2], duplicate); + * // => [1, 1, 2, 2] + */ + function flatMap(collection, iteratee) { + return baseFlatten(map(collection, iteratee), 1); + } + + /** + * This method is like `_.flatMap` except that it recursively flattens the + * mapped results. + * + * @static + * @memberOf _ + * @since 4.7.0 + * @category Collection + * @param {Array|Object} collection The collection to iterate over. + * @param {Function} [iteratee=_.identity] The function invoked per iteration. + * @returns {Array} Returns the new flattened array. + * @example + * + * function duplicate(n) { + * return [[[n, n]]]; + * } + * + * _.flatMapDeep([1, 2], duplicate); + * // => [1, 1, 2, 2] + */ + function flatMapDeep(collection, iteratee) { + return baseFlatten(map(collection, iteratee), INFINITY); + } + + /** + * This method is like `_.flatMap` except that it recursively flattens the + * mapped results up to `depth` times. 
+ * + * @static + * @memberOf _ + * @since 4.7.0 + * @category Collection + * @param {Array|Object} collection The collection to iterate over. + * @param {Function} [iteratee=_.identity] The function invoked per iteration. + * @param {number} [depth=1] The maximum recursion depth. + * @returns {Array} Returns the new flattened array. + * @example + * + * function duplicate(n) { + * return [[[n, n]]]; + * } + * + * _.flatMapDepth([1, 2], duplicate, 2); + * // => [[1, 1], [2, 2]] + */ + function flatMapDepth(collection, iteratee, depth) { + depth = depth === undefined ? 1 : toInteger(depth); + return baseFlatten(map(collection, iteratee), depth); + } + + /** + * Iterates over elements of `collection` and invokes `iteratee` for each element. + * The iteratee is invoked with three arguments: (value, index|key, collection). + * Iteratee functions may exit iteration early by explicitly returning `false`. + * + * **Note:** As with other "Collections" methods, objects with a "length" + * property are iterated like arrays. To avoid this behavior use `_.forIn` + * or `_.forOwn` for object iteration. + * + * @static + * @memberOf _ + * @since 0.1.0 + * @alias each + * @category Collection + * @param {Array|Object} collection The collection to iterate over. + * @param {Function} [iteratee=_.identity] The function invoked per iteration. + * @returns {Array|Object} Returns `collection`. + * @see _.forEachRight + * @example + * + * _.forEach([1, 2], function(value) { + * console.log(value); + * }); + * // => Logs `1` then `2`. + * + * _.forEach({ 'a': 1, 'b': 2 }, function(value, key) { + * console.log(key); + * }); + * // => Logs 'a' then 'b' (iteration order is not guaranteed). + */ + function forEach(collection, iteratee) { + var func = isArray(collection) ? arrayEach : baseEach; + return func(collection, getIteratee(iteratee, 3)); + } + + /** + * This method is like `_.forEach` except that it iterates over elements of + * `collection` from right to left. + * + * @static + * @memberOf _ + * @since 2.0.0 + * @alias eachRight + * @category Collection + * @param {Array|Object} collection The collection to iterate over. + * @param {Function} [iteratee=_.identity] The function invoked per iteration. + * @returns {Array|Object} Returns `collection`. + * @see _.forEach + * @example + * + * _.forEachRight([1, 2], function(value) { + * console.log(value); + * }); + * // => Logs `2` then `1`. + */ + function forEachRight(collection, iteratee) { + var func = isArray(collection) ? arrayEachRight : baseEachRight; + return func(collection, getIteratee(iteratee, 3)); + } + + /** + * Creates an object composed of keys generated from the results of running + * each element of `collection` thru `iteratee`. The order of grouped values + * is determined by the order they occur in `collection`. The corresponding + * value of each key is an array of elements responsible for generating the + * key. The iteratee is invoked with one argument: (value). + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Collection + * @param {Array|Object} collection The collection to iterate over. + * @param {Function} [iteratee=_.identity] The iteratee to transform keys. + * @returns {Object} Returns the composed aggregate object. + * @example + * + * _.groupBy([6.1, 4.2, 6.3], Math.floor); + * // => { '4': [4.2], '6': [6.1, 6.3] } + * + * // The `_.property` iteratee shorthand. 
+ * _.groupBy(['one', 'two', 'three'], 'length'); + * // => { '3': ['one', 'two'], '5': ['three'] } + */ + var groupBy = createAggregator(function(result, value, key) { + if (hasOwnProperty.call(result, key)) { + result[key].push(value); + } else { + baseAssignValue(result, key, [value]); + } + }); + + /** + * Checks if `value` is in `collection`. If `collection` is a string, it's + * checked for a substring of `value`, otherwise + * [`SameValueZero`](http://ecma-international.org/ecma-262/7.0/#sec-samevaluezero) + * is used for equality comparisons. If `fromIndex` is negative, it's used as + * the offset from the end of `collection`. + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Collection + * @param {Array|Object|string} collection The collection to inspect. + * @param {*} value The value to search for. + * @param {number} [fromIndex=0] The index to search from. + * @param- {Object} [guard] Enables use as an iteratee for methods like `_.reduce`. + * @returns {boolean} Returns `true` if `value` is found, else `false`. + * @example + * + * _.includes([1, 2, 3], 1); + * // => true + * + * _.includes([1, 2, 3], 1, 2); + * // => false + * + * _.includes({ 'a': 1, 'b': 2 }, 1); + * // => true + * + * _.includes('abcd', 'bc'); + * // => true + */ + function includes(collection, value, fromIndex, guard) { + collection = isArrayLike(collection) ? collection : values(collection); + fromIndex = (fromIndex && !guard) ? toInteger(fromIndex) : 0; + + var length = collection.length; + if (fromIndex < 0) { + fromIndex = nativeMax(length + fromIndex, 0); + } + return isString(collection) + ? (fromIndex <= length && collection.indexOf(value, fromIndex) > -1) + : (!!length && baseIndexOf(collection, value, fromIndex) > -1); + } + + /** + * Invokes the method at `path` of each element in `collection`, returning + * an array of the results of each invoked method. Any additional arguments + * are provided to each invoked method. If `path` is a function, it's invoked + * for, and `this` bound to, each element in `collection`. + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Collection + * @param {Array|Object} collection The collection to iterate over. + * @param {Array|Function|string} path The path of the method to invoke or + * the function invoked per iteration. + * @param {...*} [args] The arguments to invoke each method with. + * @returns {Array} Returns the array of results. + * @example + * + * _.invokeMap([[5, 1, 7], [3, 2, 1]], 'sort'); + * // => [[1, 5, 7], [1, 2, 3]] + * + * _.invokeMap([123, 456], String.prototype.split, ''); + * // => [['1', '2', '3'], ['4', '5', '6']] + */ + var invokeMap = baseRest(function(collection, path, args) { + var index = -1, + isFunc = typeof path == 'function', + result = isArrayLike(collection) ? Array(collection.length) : []; + + baseEach(collection, function(value) { + result[++index] = isFunc ? apply(path, value, args) : baseInvoke(value, path, args); + }); + return result; + }); + + /** + * Creates an object composed of keys generated from the results of running + * each element of `collection` thru `iteratee`. The corresponding value of + * each key is the last element responsible for generating the key. The + * iteratee is invoked with one argument: (value). + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Collection + * @param {Array|Object} collection The collection to iterate over. + * @param {Function} [iteratee=_.identity] The iteratee to transform keys. 
+ * @returns {Object} Returns the composed aggregate object. + * @example + * + * var array = [ + * { 'dir': 'left', 'code': 97 }, + * { 'dir': 'right', 'code': 100 } + * ]; + * + * _.keyBy(array, function(o) { + * return String.fromCharCode(o.code); + * }); + * // => { 'a': { 'dir': 'left', 'code': 97 }, 'd': { 'dir': 'right', 'code': 100 } } + * + * _.keyBy(array, 'dir'); + * // => { 'left': { 'dir': 'left', 'code': 97 }, 'right': { 'dir': 'right', 'code': 100 } } + */ + var keyBy = createAggregator(function(result, value, key) { + baseAssignValue(result, key, value); + }); + + /** + * Creates an array of values by running each element in `collection` thru + * `iteratee`. The iteratee is invoked with three arguments: + * (value, index|key, collection). + * + * Many lodash methods are guarded to work as iteratees for methods like + * `_.every`, `_.filter`, `_.map`, `_.mapValues`, `_.reject`, and `_.some`. + * + * The guarded methods are: + * `ary`, `chunk`, `curry`, `curryRight`, `drop`, `dropRight`, `every`, + * `fill`, `invert`, `parseInt`, `random`, `range`, `rangeRight`, `repeat`, + * `sampleSize`, `slice`, `some`, `sortBy`, `split`, `take`, `takeRight`, + * `template`, `trim`, `trimEnd`, `trimStart`, and `words` + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Collection + * @param {Array|Object} collection The collection to iterate over. + * @param {Function} [iteratee=_.identity] The function invoked per iteration. + * @returns {Array} Returns the new mapped array. + * @example + * + * function square(n) { + * return n * n; + * } + * + * _.map([4, 8], square); + * // => [16, 64] + * + * _.map({ 'a': 4, 'b': 8 }, square); + * // => [16, 64] (iteration order is not guaranteed) + * + * var users = [ + * { 'user': 'barney' }, + * { 'user': 'fred' } + * ]; + * + * // The `_.property` iteratee shorthand. + * _.map(users, 'user'); + * // => ['barney', 'fred'] + */ + function map(collection, iteratee) { + var func = isArray(collection) ? arrayMap : baseMap; + return func(collection, getIteratee(iteratee, 3)); + } + + /** + * This method is like `_.sortBy` except that it allows specifying the sort + * orders of the iteratees to sort by. If `orders` is unspecified, all values + * are sorted in ascending order. Otherwise, specify an order of "desc" for + * descending or "asc" for ascending sort order of corresponding values. + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Collection + * @param {Array|Object} collection The collection to iterate over. + * @param {Array[]|Function[]|Object[]|string[]} [iteratees=[_.identity]] + * The iteratees to sort by. + * @param {string[]} [orders] The sort orders of `iteratees`. + * @param- {Object} [guard] Enables use as an iteratee for methods like `_.reduce`. + * @returns {Array} Returns the new sorted array. + * @example + * + * var users = [ + * { 'user': 'fred', 'age': 48 }, + * { 'user': 'barney', 'age': 34 }, + * { 'user': 'fred', 'age': 40 }, + * { 'user': 'barney', 'age': 36 } + * ]; + * + * // Sort by `user` in ascending order and by `age` in descending order. + * _.orderBy(users, ['user', 'age'], ['asc', 'desc']); + * // => objects for [['barney', 36], ['barney', 34], ['fred', 48], ['fred', 40]] + */ + function orderBy(collection, iteratees, orders, guard) { + if (collection == null) { + return []; + } + if (!isArray(iteratees)) { + iteratees = iteratees == null ? [] : [iteratees]; + } + orders = guard ? undefined : orders; + if (!isArray(orders)) { + orders = orders == null ? 
[] : [orders]; + } + return baseOrderBy(collection, iteratees, orders); + } + + /** + * Creates an array of elements split into two groups, the first of which + * contains elements `predicate` returns truthy for, the second of which + * contains elements `predicate` returns falsey for. The predicate is + * invoked with one argument: (value). + * + * @static + * @memberOf _ + * @since 3.0.0 + * @category Collection + * @param {Array|Object} collection The collection to iterate over. + * @param {Function} [predicate=_.identity] The function invoked per iteration. + * @returns {Array} Returns the array of grouped elements. + * @example + * + * var users = [ + * { 'user': 'barney', 'age': 36, 'active': false }, + * { 'user': 'fred', 'age': 40, 'active': true }, + * { 'user': 'pebbles', 'age': 1, 'active': false } + * ]; + * + * _.partition(users, function(o) { return o.active; }); + * // => objects for [['fred'], ['barney', 'pebbles']] + * + * // The `_.matches` iteratee shorthand. + * _.partition(users, { 'age': 1, 'active': false }); + * // => objects for [['pebbles'], ['barney', 'fred']] + * + * // The `_.matchesProperty` iteratee shorthand. + * _.partition(users, ['active', false]); + * // => objects for [['barney', 'pebbles'], ['fred']] + * + * // The `_.property` iteratee shorthand. + * _.partition(users, 'active'); + * // => objects for [['fred'], ['barney', 'pebbles']] + */ + var partition = createAggregator(function(result, value, key) { + result[key ? 0 : 1].push(value); + }, function() { return [[], []]; }); + + /** + * Reduces `collection` to a value which is the accumulated result of running + * each element in `collection` thru `iteratee`, where each successive + * invocation is supplied the return value of the previous. If `accumulator` + * is not given, the first element of `collection` is used as the initial + * value. The iteratee is invoked with four arguments: + * (accumulator, value, index|key, collection). + * + * Many lodash methods are guarded to work as iteratees for methods like + * `_.reduce`, `_.reduceRight`, and `_.transform`. + * + * The guarded methods are: + * `assign`, `defaults`, `defaultsDeep`, `includes`, `merge`, `orderBy`, + * and `sortBy` + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Collection + * @param {Array|Object} collection The collection to iterate over. + * @param {Function} [iteratee=_.identity] The function invoked per iteration. + * @param {*} [accumulator] The initial value. + * @returns {*} Returns the accumulated value. + * @see _.reduceRight + * @example + * + * _.reduce([1, 2], function(sum, n) { + * return sum + n; + * }, 0); + * // => 3 + * + * _.reduce({ 'a': 1, 'b': 2, 'c': 1 }, function(result, value, key) { + * (result[value] || (result[value] = [])).push(key); + * return result; + * }, {}); + * // => { '1': ['a', 'c'], '2': ['b'] } (iteration order is not guaranteed) + */ + function reduce(collection, iteratee, accumulator) { + var func = isArray(collection) ? arrayReduce : baseReduce, + initAccum = arguments.length < 3; + + return func(collection, getIteratee(iteratee, 4), accumulator, initAccum, baseEach); + } + + /** + * This method is like `_.reduce` except that it iterates over elements of + * `collection` from right to left. + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Collection + * @param {Array|Object} collection The collection to iterate over. + * @param {Function} [iteratee=_.identity] The function invoked per iteration. + * @param {*} [accumulator] The initial value. 
+ * @returns {*} Returns the accumulated value. + * @see _.reduce + * @example + * + * var array = [[0, 1], [2, 3], [4, 5]]; + * + * _.reduceRight(array, function(flattened, other) { + * return flattened.concat(other); + * }, []); + * // => [4, 5, 2, 3, 0, 1] + */ + function reduceRight(collection, iteratee, accumulator) { + var func = isArray(collection) ? arrayReduceRight : baseReduce, + initAccum = arguments.length < 3; + + return func(collection, getIteratee(iteratee, 4), accumulator, initAccum, baseEachRight); + } + + /** + * The opposite of `_.filter`; this method returns the elements of `collection` + * that `predicate` does **not** return truthy for. + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Collection + * @param {Array|Object} collection The collection to iterate over. + * @param {Function} [predicate=_.identity] The function invoked per iteration. + * @returns {Array} Returns the new filtered array. + * @see _.filter + * @example + * + * var users = [ + * { 'user': 'barney', 'age': 36, 'active': false }, + * { 'user': 'fred', 'age': 40, 'active': true } + * ]; + * + * _.reject(users, function(o) { return !o.active; }); + * // => objects for ['fred'] + * + * // The `_.matches` iteratee shorthand. + * _.reject(users, { 'age': 40, 'active': true }); + * // => objects for ['barney'] + * + * // The `_.matchesProperty` iteratee shorthand. + * _.reject(users, ['active', false]); + * // => objects for ['fred'] + * + * // The `_.property` iteratee shorthand. + * _.reject(users, 'active'); + * // => objects for ['barney'] + */ + function reject(collection, predicate) { + var func = isArray(collection) ? arrayFilter : baseFilter; + return func(collection, negate(getIteratee(predicate, 3))); + } + + /** + * Gets a random element from `collection`. + * + * @static + * @memberOf _ + * @since 2.0.0 + * @category Collection + * @param {Array|Object} collection The collection to sample. + * @returns {*} Returns the random element. + * @example + * + * _.sample([1, 2, 3, 4]); + * // => 2 + */ + function sample(collection) { + var func = isArray(collection) ? arraySample : baseSample; + return func(collection); + } + + /** + * Gets `n` random elements at unique keys from `collection` up to the + * size of `collection`. + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Collection + * @param {Array|Object} collection The collection to sample. + * @param {number} [n=1] The number of elements to sample. + * @param- {Object} [guard] Enables use as an iteratee for methods like `_.map`. + * @returns {Array} Returns the random elements. + * @example + * + * _.sampleSize([1, 2, 3], 2); + * // => [3, 1] + * + * _.sampleSize([1, 2, 3], 4); + * // => [2, 3, 1] + */ + function sampleSize(collection, n, guard) { + if ((guard ? isIterateeCall(collection, n, guard) : n === undefined)) { + n = 1; + } else { + n = toInteger(n); + } + var func = isArray(collection) ? arraySampleSize : baseSampleSize; + return func(collection, n); + } + + /** + * Creates an array of shuffled values, using a version of the + * [Fisher-Yates shuffle](https://en.wikipedia.org/wiki/Fisher-Yates_shuffle). + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Collection + * @param {Array|Object} collection The collection to shuffle. + * @returns {Array} Returns the new shuffled array. + * @example + * + * _.shuffle([1, 2, 3, 4]); + * // => [4, 1, 3, 2] + */ + function shuffle(collection) { + var func = isArray(collection) ? 
arrayShuffle : baseShuffle; + return func(collection); + } + + /** + * Gets the size of `collection` by returning its length for array-like + * values or the number of own enumerable string keyed properties for objects. + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Collection + * @param {Array|Object|string} collection The collection to inspect. + * @returns {number} Returns the collection size. + * @example + * + * _.size([1, 2, 3]); + * // => 3 + * + * _.size({ 'a': 1, 'b': 2 }); + * // => 2 + * + * _.size('pebbles'); + * // => 7 + */ + function size(collection) { + if (collection == null) { + return 0; + } + if (isArrayLike(collection)) { + return isString(collection) ? stringSize(collection) : collection.length; + } + var tag = getTag(collection); + if (tag == mapTag || tag == setTag) { + return collection.size; + } + return baseKeys(collection).length; + } + + /** + * Checks if `predicate` returns truthy for **any** element of `collection`. + * Iteration is stopped once `predicate` returns truthy. The predicate is + * invoked with three arguments: (value, index|key, collection). + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Collection + * @param {Array|Object} collection The collection to iterate over. + * @param {Function} [predicate=_.identity] The function invoked per iteration. + * @param- {Object} [guard] Enables use as an iteratee for methods like `_.map`. + * @returns {boolean} Returns `true` if any element passes the predicate check, + * else `false`. + * @example + * + * _.some([null, 0, 'yes', false], Boolean); + * // => true + * + * var users = [ + * { 'user': 'barney', 'active': true }, + * { 'user': 'fred', 'active': false } + * ]; + * + * // The `_.matches` iteratee shorthand. + * _.some(users, { 'user': 'barney', 'active': false }); + * // => false + * + * // The `_.matchesProperty` iteratee shorthand. + * _.some(users, ['active', false]); + * // => true + * + * // The `_.property` iteratee shorthand. + * _.some(users, 'active'); + * // => true + */ + function some(collection, predicate, guard) { + var func = isArray(collection) ? arraySome : baseSome; + if (guard && isIterateeCall(collection, predicate, guard)) { + predicate = undefined; + } + return func(collection, getIteratee(predicate, 3)); + } + + /** + * Creates an array of elements, sorted in ascending order by the results of + * running each element in a collection thru each iteratee. This method + * performs a stable sort, that is, it preserves the original sort order of + * equal elements. The iteratees are invoked with one argument: (value). + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Collection + * @param {Array|Object} collection The collection to iterate over. + * @param {...(Function|Function[])} [iteratees=[_.identity]] + * The iteratees to sort by. + * @returns {Array} Returns the new sorted array. 
+ * @example + * + * var users = [ + * { 'user': 'fred', 'age': 48 }, + * { 'user': 'barney', 'age': 36 }, + * { 'user': 'fred', 'age': 40 }, + * { 'user': 'barney', 'age': 34 } + * ]; + * + * _.sortBy(users, [function(o) { return o.user; }]); + * // => objects for [['barney', 36], ['barney', 34], ['fred', 48], ['fred', 40]] + * + * _.sortBy(users, ['user', 'age']); + * // => objects for [['barney', 34], ['barney', 36], ['fred', 40], ['fred', 48]] + */ + var sortBy = baseRest(function(collection, iteratees) { + if (collection == null) { + return []; + } + var length = iteratees.length; + if (length > 1 && isIterateeCall(collection, iteratees[0], iteratees[1])) { + iteratees = []; + } else if (length > 2 && isIterateeCall(iteratees[0], iteratees[1], iteratees[2])) { + iteratees = [iteratees[0]]; + } + return baseOrderBy(collection, baseFlatten(iteratees, 1), []); + }); + + /*------------------------------------------------------------------------*/ + + /** + * Gets the timestamp of the number of milliseconds that have elapsed since + * the Unix epoch (1 January 1970 00:00:00 UTC). + * + * @static + * @memberOf _ + * @since 2.4.0 + * @category Date + * @returns {number} Returns the timestamp. + * @example + * + * _.defer(function(stamp) { + * console.log(_.now() - stamp); + * }, _.now()); + * // => Logs the number of milliseconds it took for the deferred invocation. + */ + var now = ctxNow || function() { + return root.Date.now(); + }; + + /*------------------------------------------------------------------------*/ + + /** + * The opposite of `_.before`; this method creates a function that invokes + * `func` once it's called `n` or more times. + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Function + * @param {number} n The number of calls before `func` is invoked. + * @param {Function} func The function to restrict. + * @returns {Function} Returns the new restricted function. + * @example + * + * var saves = ['profile', 'settings']; + * + * var done = _.after(saves.length, function() { + * console.log('done saving!'); + * }); + * + * _.forEach(saves, function(type) { + * asyncSave({ 'type': type, 'complete': done }); + * }); + * // => Logs 'done saving!' after the two async saves have completed. + */ + function after(n, func) { + if (typeof func != 'function') { + throw new TypeError(FUNC_ERROR_TEXT); + } + n = toInteger(n); + return function() { + if (--n < 1) { + return func.apply(this, arguments); + } + }; + } + + /** + * Creates a function that invokes `func`, with up to `n` arguments, + * ignoring any additional arguments. + * + * @static + * @memberOf _ + * @since 3.0.0 + * @category Function + * @param {Function} func The function to cap arguments for. + * @param {number} [n=func.length] The arity cap. + * @param- {Object} [guard] Enables use as an iteratee for methods like `_.map`. + * @returns {Function} Returns the new capped function. + * @example + * + * _.map(['6', '8', '10'], _.ary(parseInt, 1)); + * // => [6, 8, 10] + */ + function ary(func, n, guard) { + n = guard ? undefined : n; + n = (func && n == null) ? func.length : n; + return createWrap(func, WRAP_ARY_FLAG, undefined, undefined, undefined, undefined, n); + } + + /** + * Creates a function that invokes `func`, with the `this` binding and arguments + * of the created function, while it's called less than `n` times. Subsequent + * calls to the created function return the result of the last `func` invocation. 
+ * + * @static + * @memberOf _ + * @since 3.0.0 + * @category Function + * @param {number} n The number of calls at which `func` is no longer invoked. + * @param {Function} func The function to restrict. + * @returns {Function} Returns the new restricted function. + * @example + * + * jQuery(element).on('click', _.before(5, addContactToList)); + * // => Allows adding up to 4 contacts to the list. + */ + function before(n, func) { + var result; + if (typeof func != 'function') { + throw new TypeError(FUNC_ERROR_TEXT); + } + n = toInteger(n); + return function() { + if (--n > 0) { + result = func.apply(this, arguments); + } + if (n <= 1) { + func = undefined; + } + return result; + }; + } + + /** + * Creates a function that invokes `func` with the `this` binding of `thisArg` + * and `partials` prepended to the arguments it receives. + * + * The `_.bind.placeholder` value, which defaults to `_` in monolithic builds, + * may be used as a placeholder for partially applied arguments. + * + * **Note:** Unlike native `Function#bind`, this method doesn't set the "length" + * property of bound functions. + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Function + * @param {Function} func The function to bind. + * @param {*} thisArg The `this` binding of `func`. + * @param {...*} [partials] The arguments to be partially applied. + * @returns {Function} Returns the new bound function. + * @example + * + * function greet(greeting, punctuation) { + * return greeting + ' ' + this.user + punctuation; + * } + * + * var object = { 'user': 'fred' }; + * + * var bound = _.bind(greet, object, 'hi'); + * bound('!'); + * // => 'hi fred!' + * + * // Bound with placeholders. + * var bound = _.bind(greet, object, _, '!'); + * bound('hi'); + * // => 'hi fred!' + */ + var bind = baseRest(function(func, thisArg, partials) { + var bitmask = WRAP_BIND_FLAG; + if (partials.length) { + var holders = replaceHolders(partials, getHolder(bind)); + bitmask |= WRAP_PARTIAL_FLAG; + } + return createWrap(func, bitmask, thisArg, partials, holders); + }); + + /** + * Creates a function that invokes the method at `object[key]` with `partials` + * prepended to the arguments it receives. + * + * This method differs from `_.bind` by allowing bound functions to reference + * methods that may be redefined or don't yet exist. See + * [Peter Michaux's article](http://peter.michaux.ca/articles/lazy-function-definition-pattern) + * for more details. + * + * The `_.bindKey.placeholder` value, which defaults to `_` in monolithic + * builds, may be used as a placeholder for partially applied arguments. + * + * @static + * @memberOf _ + * @since 0.10.0 + * @category Function + * @param {Object} object The object to invoke the method on. + * @param {string} key The key of the method. + * @param {...*} [partials] The arguments to be partially applied. + * @returns {Function} Returns the new bound function. + * @example + * + * var object = { + * 'user': 'fred', + * 'greet': function(greeting, punctuation) { + * return greeting + ' ' + this.user + punctuation; + * } + * }; + * + * var bound = _.bindKey(object, 'greet', 'hi'); + * bound('!'); + * // => 'hi fred!' + * + * object.greet = function(greeting, punctuation) { + * return greeting + 'ya ' + this.user + punctuation; + * }; + * + * bound('!'); + * // => 'hiya fred!' + * + * // Bound with placeholders. + * var bound = _.bindKey(object, 'greet', _, '!'); + * bound('hi'); + * // => 'hiya fred!' 
+ */ + var bindKey = baseRest(function(object, key, partials) { + var bitmask = WRAP_BIND_FLAG | WRAP_BIND_KEY_FLAG; + if (partials.length) { + var holders = replaceHolders(partials, getHolder(bindKey)); + bitmask |= WRAP_PARTIAL_FLAG; + } + return createWrap(key, bitmask, object, partials, holders); + }); + + /** + * Creates a function that accepts arguments of `func` and either invokes + * `func` returning its result, if at least `arity` number of arguments have + * been provided, or returns a function that accepts the remaining `func` + * arguments, and so on. The arity of `func` may be specified if `func.length` + * is not sufficient. + * + * The `_.curry.placeholder` value, which defaults to `_` in monolithic builds, + * may be used as a placeholder for provided arguments. + * + * **Note:** This method doesn't set the "length" property of curried functions. + * + * @static + * @memberOf _ + * @since 2.0.0 + * @category Function + * @param {Function} func The function to curry. + * @param {number} [arity=func.length] The arity of `func`. + * @param- {Object} [guard] Enables use as an iteratee for methods like `_.map`. + * @returns {Function} Returns the new curried function. + * @example + * + * var abc = function(a, b, c) { + * return [a, b, c]; + * }; + * + * var curried = _.curry(abc); + * + * curried(1)(2)(3); + * // => [1, 2, 3] + * + * curried(1, 2)(3); + * // => [1, 2, 3] + * + * curried(1, 2, 3); + * // => [1, 2, 3] + * + * // Curried with placeholders. + * curried(1)(_, 3)(2); + * // => [1, 2, 3] + */ + function curry(func, arity, guard) { + arity = guard ? undefined : arity; + var result = createWrap(func, WRAP_CURRY_FLAG, undefined, undefined, undefined, undefined, undefined, arity); + result.placeholder = curry.placeholder; + return result; + } + + /** + * This method is like `_.curry` except that arguments are applied to `func` + * in the manner of `_.partialRight` instead of `_.partial`. + * + * The `_.curryRight.placeholder` value, which defaults to `_` in monolithic + * builds, may be used as a placeholder for provided arguments. + * + * **Note:** This method doesn't set the "length" property of curried functions. + * + * @static + * @memberOf _ + * @since 3.0.0 + * @category Function + * @param {Function} func The function to curry. + * @param {number} [arity=func.length] The arity of `func`. + * @param- {Object} [guard] Enables use as an iteratee for methods like `_.map`. + * @returns {Function} Returns the new curried function. + * @example + * + * var abc = function(a, b, c) { + * return [a, b, c]; + * }; + * + * var curried = _.curryRight(abc); + * + * curried(3)(2)(1); + * // => [1, 2, 3] + * + * curried(2, 3)(1); + * // => [1, 2, 3] + * + * curried(1, 2, 3); + * // => [1, 2, 3] + * + * // Curried with placeholders. + * curried(3)(1, _)(2); + * // => [1, 2, 3] + */ + function curryRight(func, arity, guard) { + arity = guard ? undefined : arity; + var result = createWrap(func, WRAP_CURRY_RIGHT_FLAG, undefined, undefined, undefined, undefined, undefined, arity); + result.placeholder = curryRight.placeholder; + return result; + } + + /** + * Creates a debounced function that delays invoking `func` until after `wait` + * milliseconds have elapsed since the last time the debounced function was + * invoked. The debounced function comes with a `cancel` method to cancel + * delayed `func` invocations and a `flush` method to immediately invoke them. 
+ * Provide `options` to indicate whether `func` should be invoked on the + * leading and/or trailing edge of the `wait` timeout. The `func` is invoked + * with the last arguments provided to the debounced function. Subsequent + * calls to the debounced function return the result of the last `func` + * invocation. + * + * **Note:** If `leading` and `trailing` options are `true`, `func` is + * invoked on the trailing edge of the timeout only if the debounced function + * is invoked more than once during the `wait` timeout. + * + * If `wait` is `0` and `leading` is `false`, `func` invocation is deferred + * until to the next tick, similar to `setTimeout` with a timeout of `0`. + * + * See [David Corbacho's article](https://css-tricks.com/debouncing-throttling-explained-examples/) + * for details over the differences between `_.debounce` and `_.throttle`. + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Function + * @param {Function} func The function to debounce. + * @param {number} [wait=0] The number of milliseconds to delay. + * @param {Object} [options={}] The options object. + * @param {boolean} [options.leading=false] + * Specify invoking on the leading edge of the timeout. + * @param {number} [options.maxWait] + * The maximum time `func` is allowed to be delayed before it's invoked. + * @param {boolean} [options.trailing=true] + * Specify invoking on the trailing edge of the timeout. + * @returns {Function} Returns the new debounced function. + * @example + * + * // Avoid costly calculations while the window size is in flux. + * jQuery(window).on('resize', _.debounce(calculateLayout, 150)); + * + * // Invoke `sendMail` when clicked, debouncing subsequent calls. + * jQuery(element).on('click', _.debounce(sendMail, 300, { + * 'leading': true, + * 'trailing': false + * })); + * + * // Ensure `batchLog` is invoked once after 1 second of debounced calls. + * var debounced = _.debounce(batchLog, 250, { 'maxWait': 1000 }); + * var source = new EventSource('/stream'); + * jQuery(source).on('message', debounced); + * + * // Cancel the trailing debounced invocation. + * jQuery(window).on('popstate', debounced.cancel); + */ + function debounce(func, wait, options) { + var lastArgs, + lastThis, + maxWait, + result, + timerId, + lastCallTime, + lastInvokeTime = 0, + leading = false, + maxing = false, + trailing = true; + + if (typeof func != 'function') { + throw new TypeError(FUNC_ERROR_TEXT); + } + wait = toNumber(wait) || 0; + if (isObject(options)) { + leading = !!options.leading; + maxing = 'maxWait' in options; + maxWait = maxing ? nativeMax(toNumber(options.maxWait) || 0, wait) : maxWait; + trailing = 'trailing' in options ? !!options.trailing : trailing; + } + + function invokeFunc(time) { + var args = lastArgs, + thisArg = lastThis; + + lastArgs = lastThis = undefined; + lastInvokeTime = time; + result = func.apply(thisArg, args); + return result; + } + + function leadingEdge(time) { + // Reset any `maxWait` timer. + lastInvokeTime = time; + // Start the timer for the trailing edge. + timerId = setTimeout(timerExpired, wait); + // Invoke the leading edge. + return leading ? invokeFunc(time) : result; + } + + function remainingWait(time) { + var timeSinceLastCall = time - lastCallTime, + timeSinceLastInvoke = time - lastInvokeTime, + timeWaiting = wait - timeSinceLastCall; + + return maxing + ? 
nativeMin(timeWaiting, maxWait - timeSinceLastInvoke) + : timeWaiting; + } + + function shouldInvoke(time) { + var timeSinceLastCall = time - lastCallTime, + timeSinceLastInvoke = time - lastInvokeTime; + + // Either this is the first call, activity has stopped and we're at the + // trailing edge, the system time has gone backwards and we're treating + // it as the trailing edge, or we've hit the `maxWait` limit. + return (lastCallTime === undefined || (timeSinceLastCall >= wait) || + (timeSinceLastCall < 0) || (maxing && timeSinceLastInvoke >= maxWait)); + } + + function timerExpired() { + var time = now(); + if (shouldInvoke(time)) { + return trailingEdge(time); + } + // Restart the timer. + timerId = setTimeout(timerExpired, remainingWait(time)); + } + + function trailingEdge(time) { + timerId = undefined; + + // Only invoke if we have `lastArgs` which means `func` has been + // debounced at least once. + if (trailing && lastArgs) { + return invokeFunc(time); + } + lastArgs = lastThis = undefined; + return result; + } + + function cancel() { + if (timerId !== undefined) { + clearTimeout(timerId); + } + lastInvokeTime = 0; + lastArgs = lastCallTime = lastThis = timerId = undefined; + } + + function flush() { + return timerId === undefined ? result : trailingEdge(now()); + } + + function debounced() { + var time = now(), + isInvoking = shouldInvoke(time); + + lastArgs = arguments; + lastThis = this; + lastCallTime = time; + + if (isInvoking) { + if (timerId === undefined) { + return leadingEdge(lastCallTime); + } + if (maxing) { + // Handle invocations in a tight loop. + timerId = setTimeout(timerExpired, wait); + return invokeFunc(lastCallTime); + } + } + if (timerId === undefined) { + timerId = setTimeout(timerExpired, wait); + } + return result; + } + debounced.cancel = cancel; + debounced.flush = flush; + return debounced; + } + + /** + * Defers invoking the `func` until the current call stack has cleared. Any + * additional arguments are provided to `func` when it's invoked. + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Function + * @param {Function} func The function to defer. + * @param {...*} [args] The arguments to invoke `func` with. + * @returns {number} Returns the timer id. + * @example + * + * _.defer(function(text) { + * console.log(text); + * }, 'deferred'); + * // => Logs 'deferred' after one millisecond. + */ + var defer = baseRest(function(func, args) { + return baseDelay(func, 1, args); + }); + + /** + * Invokes `func` after `wait` milliseconds. Any additional arguments are + * provided to `func` when it's invoked. + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Function + * @param {Function} func The function to delay. + * @param {number} wait The number of milliseconds to delay invocation. + * @param {...*} [args] The arguments to invoke `func` with. + * @returns {number} Returns the timer id. + * @example + * + * _.delay(function(text) { + * console.log(text); + * }, 1000, 'later'); + * // => Logs 'later' after one second. + */ + var delay = baseRest(function(func, wait, args) { + return baseDelay(func, toNumber(wait) || 0, args); + }); + + /** + * Creates a function that invokes `func` with arguments reversed. + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Function + * @param {Function} func The function to flip arguments for. + * @returns {Function} Returns the new flipped function. 
+ * @example + * + * var flipped = _.flip(function() { + * return _.toArray(arguments); + * }); + * + * flipped('a', 'b', 'c', 'd'); + * // => ['d', 'c', 'b', 'a'] + */ + function flip(func) { + return createWrap(func, WRAP_FLIP_FLAG); + } + + /** + * Creates a function that memoizes the result of `func`. If `resolver` is + * provided, it determines the cache key for storing the result based on the + * arguments provided to the memoized function. By default, the first argument + * provided to the memoized function is used as the map cache key. The `func` + * is invoked with the `this` binding of the memoized function. + * + * **Note:** The cache is exposed as the `cache` property on the memoized + * function. Its creation may be customized by replacing the `_.memoize.Cache` + * constructor with one whose instances implement the + * [`Map`](http://ecma-international.org/ecma-262/7.0/#sec-properties-of-the-map-prototype-object) + * method interface of `clear`, `delete`, `get`, `has`, and `set`. + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Function + * @param {Function} func The function to have its output memoized. + * @param {Function} [resolver] The function to resolve the cache key. + * @returns {Function} Returns the new memoized function. + * @example + * + * var object = { 'a': 1, 'b': 2 }; + * var other = { 'c': 3, 'd': 4 }; + * + * var values = _.memoize(_.values); + * values(object); + * // => [1, 2] + * + * values(other); + * // => [3, 4] + * + * object.a = 2; + * values(object); + * // => [1, 2] + * + * // Modify the result cache. + * values.cache.set(object, ['a', 'b']); + * values(object); + * // => ['a', 'b'] + * + * // Replace `_.memoize.Cache`. + * _.memoize.Cache = WeakMap; + */ + function memoize(func, resolver) { + if (typeof func != 'function' || (resolver != null && typeof resolver != 'function')) { + throw new TypeError(FUNC_ERROR_TEXT); + } + var memoized = function() { + var args = arguments, + key = resolver ? resolver.apply(this, args) : args[0], + cache = memoized.cache; + + if (cache.has(key)) { + return cache.get(key); + } + var result = func.apply(this, args); + memoized.cache = cache.set(key, result) || cache; + return result; + }; + memoized.cache = new (memoize.Cache || MapCache); + return memoized; + } + + // Expose `MapCache`. + memoize.Cache = MapCache; + + /** + * Creates a function that negates the result of the predicate `func`. The + * `func` predicate is invoked with the `this` binding and arguments of the + * created function. + * + * @static + * @memberOf _ + * @since 3.0.0 + * @category Function + * @param {Function} predicate The predicate to negate. + * @returns {Function} Returns the new negated function. + * @example + * + * function isEven(n) { + * return n % 2 == 0; + * } + * + * _.filter([1, 2, 3, 4, 5, 6], _.negate(isEven)); + * // => [1, 3, 5] + */ + function negate(predicate) { + if (typeof predicate != 'function') { + throw new TypeError(FUNC_ERROR_TEXT); + } + return function() { + var args = arguments; + switch (args.length) { + case 0: return !predicate.call(this); + case 1: return !predicate.call(this, args[0]); + case 2: return !predicate.call(this, args[0], args[1]); + case 3: return !predicate.call(this, args[0], args[1], args[2]); + } + return !predicate.apply(this, args); + }; + } + + /** + * Creates a function that is restricted to invoking `func` once. Repeat calls + * to the function return the value of the first invocation. 
The `func` is + * invoked with the `this` binding and arguments of the created function. + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Function + * @param {Function} func The function to restrict. + * @returns {Function} Returns the new restricted function. + * @example + * + * var initialize = _.once(createApplication); + * initialize(); + * initialize(); + * // => `createApplication` is invoked once + */ + function once(func) { + return before(2, func); + } + + /** + * Creates a function that invokes `func` with its arguments transformed. + * + * @static + * @since 4.0.0 + * @memberOf _ + * @category Function + * @param {Function} func The function to wrap. + * @param {...(Function|Function[])} [transforms=[_.identity]] + * The argument transforms. + * @returns {Function} Returns the new function. + * @example + * + * function doubled(n) { + * return n * 2; + * } + * + * function square(n) { + * return n * n; + * } + * + * var func = _.overArgs(function(x, y) { + * return [x, y]; + * }, [square, doubled]); + * + * func(9, 3); + * // => [81, 6] + * + * func(10, 5); + * // => [100, 10] + */ + var overArgs = castRest(function(func, transforms) { + transforms = (transforms.length == 1 && isArray(transforms[0])) + ? arrayMap(transforms[0], baseUnary(getIteratee())) + : arrayMap(baseFlatten(transforms, 1), baseUnary(getIteratee())); + + var funcsLength = transforms.length; + return baseRest(function(args) { + var index = -1, + length = nativeMin(args.length, funcsLength); + + while (++index < length) { + args[index] = transforms[index].call(this, args[index]); + } + return apply(func, this, args); + }); + }); + + /** + * Creates a function that invokes `func` with `partials` prepended to the + * arguments it receives. This method is like `_.bind` except it does **not** + * alter the `this` binding. + * + * The `_.partial.placeholder` value, which defaults to `_` in monolithic + * builds, may be used as a placeholder for partially applied arguments. + * + * **Note:** This method doesn't set the "length" property of partially + * applied functions. + * + * @static + * @memberOf _ + * @since 0.2.0 + * @category Function + * @param {Function} func The function to partially apply arguments to. + * @param {...*} [partials] The arguments to be partially applied. + * @returns {Function} Returns the new partially applied function. + * @example + * + * function greet(greeting, name) { + * return greeting + ' ' + name; + * } + * + * var sayHelloTo = _.partial(greet, 'hello'); + * sayHelloTo('fred'); + * // => 'hello fred' + * + * // Partially applied with placeholders. + * var greetFred = _.partial(greet, _, 'fred'); + * greetFred('hi'); + * // => 'hi fred' + */ + var partial = baseRest(function(func, partials) { + var holders = replaceHolders(partials, getHolder(partial)); + return createWrap(func, WRAP_PARTIAL_FLAG, undefined, partials, holders); + }); + + /** + * This method is like `_.partial` except that partially applied arguments + * are appended to the arguments it receives. + * + * The `_.partialRight.placeholder` value, which defaults to `_` in monolithic + * builds, may be used as a placeholder for partially applied arguments. + * + * **Note:** This method doesn't set the "length" property of partially + * applied functions. + * + * @static + * @memberOf _ + * @since 1.0.0 + * @category Function + * @param {Function} func The function to partially apply arguments to. + * @param {...*} [partials] The arguments to be partially applied. 
+ * @returns {Function} Returns the new partially applied function. + * @example + * + * function greet(greeting, name) { + * return greeting + ' ' + name; + * } + * + * var greetFred = _.partialRight(greet, 'fred'); + * greetFred('hi'); + * // => 'hi fred' + * + * // Partially applied with placeholders. + * var sayHelloTo = _.partialRight(greet, 'hello', _); + * sayHelloTo('fred'); + * // => 'hello fred' + */ + var partialRight = baseRest(function(func, partials) { + var holders = replaceHolders(partials, getHolder(partialRight)); + return createWrap(func, WRAP_PARTIAL_RIGHT_FLAG, undefined, partials, holders); + }); + + /** + * Creates a function that invokes `func` with arguments arranged according + * to the specified `indexes` where the argument value at the first index is + * provided as the first argument, the argument value at the second index is + * provided as the second argument, and so on. + * + * @static + * @memberOf _ + * @since 3.0.0 + * @category Function + * @param {Function} func The function to rearrange arguments for. + * @param {...(number|number[])} indexes The arranged argument indexes. + * @returns {Function} Returns the new function. + * @example + * + * var rearged = _.rearg(function(a, b, c) { + * return [a, b, c]; + * }, [2, 0, 1]); + * + * rearged('b', 'c', 'a') + * // => ['a', 'b', 'c'] + */ + var rearg = flatRest(function(func, indexes) { + return createWrap(func, WRAP_REARG_FLAG, undefined, undefined, undefined, indexes); + }); + + /** + * Creates a function that invokes `func` with the `this` binding of the + * created function and arguments from `start` and beyond provided as + * an array. + * + * **Note:** This method is based on the + * [rest parameter](https://mdn.io/rest_parameters). + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Function + * @param {Function} func The function to apply a rest parameter to. + * @param {number} [start=func.length-1] The start position of the rest parameter. + * @returns {Function} Returns the new function. + * @example + * + * var say = _.rest(function(what, names) { + * return what + ' ' + _.initial(names).join(', ') + + * (_.size(names) > 1 ? ', & ' : '') + _.last(names); + * }); + * + * say('hello', 'fred', 'barney', 'pebbles'); + * // => 'hello fred, barney, & pebbles' + */ + function rest(func, start) { + if (typeof func != 'function') { + throw new TypeError(FUNC_ERROR_TEXT); + } + start = start === undefined ? start : toInteger(start); + return baseRest(func, start); + } + + /** + * Creates a function that invokes `func` with the `this` binding of the + * create function and an array of arguments much like + * [`Function#apply`](http://www.ecma-international.org/ecma-262/7.0/#sec-function.prototype.apply). + * + * **Note:** This method is based on the + * [spread operator](https://mdn.io/spread_operator). + * + * @static + * @memberOf _ + * @since 3.2.0 + * @category Function + * @param {Function} func The function to spread arguments over. + * @param {number} [start=0] The start position of the spread. + * @returns {Function} Returns the new function. 
+ * @example + * + * var say = _.spread(function(who, what) { + * return who + ' says ' + what; + * }); + * + * say(['fred', 'hello']); + * // => 'fred says hello' + * + * var numbers = Promise.all([ + * Promise.resolve(40), + * Promise.resolve(36) + * ]); + * + * numbers.then(_.spread(function(x, y) { + * return x + y; + * })); + * // => a Promise of 76 + */ + function spread(func, start) { + if (typeof func != 'function') { + throw new TypeError(FUNC_ERROR_TEXT); + } + start = start == null ? 0 : nativeMax(toInteger(start), 0); + return baseRest(function(args) { + var array = args[start], + otherArgs = castSlice(args, 0, start); + + if (array) { + arrayPush(otherArgs, array); + } + return apply(func, this, otherArgs); + }); + } + + /** + * Creates a throttled function that only invokes `func` at most once per + * every `wait` milliseconds. The throttled function comes with a `cancel` + * method to cancel delayed `func` invocations and a `flush` method to + * immediately invoke them. Provide `options` to indicate whether `func` + * should be invoked on the leading and/or trailing edge of the `wait` + * timeout. The `func` is invoked with the last arguments provided to the + * throttled function. Subsequent calls to the throttled function return the + * result of the last `func` invocation. + * + * **Note:** If `leading` and `trailing` options are `true`, `func` is + * invoked on the trailing edge of the timeout only if the throttled function + * is invoked more than once during the `wait` timeout. + * + * If `wait` is `0` and `leading` is `false`, `func` invocation is deferred + * until to the next tick, similar to `setTimeout` with a timeout of `0`. + * + * See [David Corbacho's article](https://css-tricks.com/debouncing-throttling-explained-examples/) + * for details over the differences between `_.throttle` and `_.debounce`. + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Function + * @param {Function} func The function to throttle. + * @param {number} [wait=0] The number of milliseconds to throttle invocations to. + * @param {Object} [options={}] The options object. + * @param {boolean} [options.leading=true] + * Specify invoking on the leading edge of the timeout. + * @param {boolean} [options.trailing=true] + * Specify invoking on the trailing edge of the timeout. + * @returns {Function} Returns the new throttled function. + * @example + * + * // Avoid excessively updating the position while scrolling. + * jQuery(window).on('scroll', _.throttle(updatePosition, 100)); + * + * // Invoke `renewToken` when the click event is fired, but not more than once every 5 minutes. + * var throttled = _.throttle(renewToken, 300000, { 'trailing': false }); + * jQuery(element).on('click', throttled); + * + * // Cancel the trailing throttled invocation. + * jQuery(window).on('popstate', throttled.cancel); + */ + function throttle(func, wait, options) { + var leading = true, + trailing = true; + + if (typeof func != 'function') { + throw new TypeError(FUNC_ERROR_TEXT); + } + if (isObject(options)) { + leading = 'leading' in options ? !!options.leading : leading; + trailing = 'trailing' in options ? !!options.trailing : trailing; + } + return debounce(func, wait, { + 'leading': leading, + 'maxWait': wait, + 'trailing': trailing + }); + } + + /** + * Creates a function that accepts up to one argument, ignoring any + * additional arguments. + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Function + * @param {Function} func The function to cap arguments for. 
+ * @returns {Function} Returns the new capped function. + * @example + * + * _.map(['6', '8', '10'], _.unary(parseInt)); + * // => [6, 8, 10] + */ + function unary(func) { + return ary(func, 1); + } + + /** + * Creates a function that provides `value` to `wrapper` as its first + * argument. Any additional arguments provided to the function are appended + * to those provided to the `wrapper`. The wrapper is invoked with the `this` + * binding of the created function. + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Function + * @param {*} value The value to wrap. + * @param {Function} [wrapper=identity] The wrapper function. + * @returns {Function} Returns the new function. + * @example + * + * var p = _.wrap(_.escape, function(func, text) { + * return '

<p>' + func(text) + '</p>'; + * }); + * + * p('fred, barney, & pebbles'); + * // => '<p>fred, barney, &amp; pebbles</p>
' + */ + function wrap(value, wrapper) { + return partial(castFunction(wrapper), value); + } + + /*------------------------------------------------------------------------*/ + + /** + * Casts `value` as an array if it's not one. + * + * @static + * @memberOf _ + * @since 4.4.0 + * @category Lang + * @param {*} value The value to inspect. + * @returns {Array} Returns the cast array. + * @example + * + * _.castArray(1); + * // => [1] + * + * _.castArray({ 'a': 1 }); + * // => [{ 'a': 1 }] + * + * _.castArray('abc'); + * // => ['abc'] + * + * _.castArray(null); + * // => [null] + * + * _.castArray(undefined); + * // => [undefined] + * + * _.castArray(); + * // => [] + * + * var array = [1, 2, 3]; + * console.log(_.castArray(array) === array); + * // => true + */ + function castArray() { + if (!arguments.length) { + return []; + } + var value = arguments[0]; + return isArray(value) ? value : [value]; + } + + /** + * Creates a shallow clone of `value`. + * + * **Note:** This method is loosely based on the + * [structured clone algorithm](https://mdn.io/Structured_clone_algorithm) + * and supports cloning arrays, array buffers, booleans, date objects, maps, + * numbers, `Object` objects, regexes, sets, strings, symbols, and typed + * arrays. The own enumerable properties of `arguments` objects are cloned + * as plain objects. An empty object is returned for uncloneable values such + * as error objects, functions, DOM nodes, and WeakMaps. + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Lang + * @param {*} value The value to clone. + * @returns {*} Returns the cloned value. + * @see _.cloneDeep + * @example + * + * var objects = [{ 'a': 1 }, { 'b': 2 }]; + * + * var shallow = _.clone(objects); + * console.log(shallow[0] === objects[0]); + * // => true + */ + function clone(value) { + return baseClone(value, CLONE_SYMBOLS_FLAG); + } + + /** + * This method is like `_.clone` except that it accepts `customizer` which + * is invoked to produce the cloned value. If `customizer` returns `undefined`, + * cloning is handled by the method instead. The `customizer` is invoked with + * up to four arguments; (value [, index|key, object, stack]). + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Lang + * @param {*} value The value to clone. + * @param {Function} [customizer] The function to customize cloning. + * @returns {*} Returns the cloned value. + * @see _.cloneDeepWith + * @example + * + * function customizer(value) { + * if (_.isElement(value)) { + * return value.cloneNode(false); + * } + * } + * + * var el = _.cloneWith(document.body, customizer); + * + * console.log(el === document.body); + * // => false + * console.log(el.nodeName); + * // => 'BODY' + * console.log(el.childNodes.length); + * // => 0 + */ + function cloneWith(value, customizer) { + customizer = typeof customizer == 'function' ? customizer : undefined; + return baseClone(value, CLONE_SYMBOLS_FLAG, customizer); + } + + /** + * This method is like `_.clone` except that it recursively clones `value`. + * + * @static + * @memberOf _ + * @since 1.0.0 + * @category Lang + * @param {*} value The value to recursively clone. + * @returns {*} Returns the deep cloned value. 
+ * @see _.clone + * @example + * + * var objects = [{ 'a': 1 }, { 'b': 2 }]; + * + * var deep = _.cloneDeep(objects); + * console.log(deep[0] === objects[0]); + * // => false + */ + function cloneDeep(value) { + return baseClone(value, CLONE_DEEP_FLAG | CLONE_SYMBOLS_FLAG); + } + + /** + * This method is like `_.cloneWith` except that it recursively clones `value`. + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Lang + * @param {*} value The value to recursively clone. + * @param {Function} [customizer] The function to customize cloning. + * @returns {*} Returns the deep cloned value. + * @see _.cloneWith + * @example + * + * function customizer(value) { + * if (_.isElement(value)) { + * return value.cloneNode(true); + * } + * } + * + * var el = _.cloneDeepWith(document.body, customizer); + * + * console.log(el === document.body); + * // => false + * console.log(el.nodeName); + * // => 'BODY' + * console.log(el.childNodes.length); + * // => 20 + */ + function cloneDeepWith(value, customizer) { + customizer = typeof customizer == 'function' ? customizer : undefined; + return baseClone(value, CLONE_DEEP_FLAG | CLONE_SYMBOLS_FLAG, customizer); + } + + /** + * Checks if `object` conforms to `source` by invoking the predicate + * properties of `source` with the corresponding property values of `object`. + * + * **Note:** This method is equivalent to `_.conforms` when `source` is + * partially applied. + * + * @static + * @memberOf _ + * @since 4.14.0 + * @category Lang + * @param {Object} object The object to inspect. + * @param {Object} source The object of property predicates to conform to. + * @returns {boolean} Returns `true` if `object` conforms, else `false`. + * @example + * + * var object = { 'a': 1, 'b': 2 }; + * + * _.conformsTo(object, { 'b': function(n) { return n > 1; } }); + * // => true + * + * _.conformsTo(object, { 'b': function(n) { return n > 2; } }); + * // => false + */ + function conformsTo(object, source) { + return source == null || baseConformsTo(object, source, keys(source)); + } + + /** + * Performs a + * [`SameValueZero`](http://ecma-international.org/ecma-262/7.0/#sec-samevaluezero) + * comparison between two values to determine if they are equivalent. + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Lang + * @param {*} value The value to compare. + * @param {*} other The other value to compare. + * @returns {boolean} Returns `true` if the values are equivalent, else `false`. + * @example + * + * var object = { 'a': 1 }; + * var other = { 'a': 1 }; + * + * _.eq(object, object); + * // => true + * + * _.eq(object, other); + * // => false + * + * _.eq('a', 'a'); + * // => true + * + * _.eq('a', Object('a')); + * // => false + * + * _.eq(NaN, NaN); + * // => true + */ + function eq(value, other) { + return value === other || (value !== value && other !== other); + } + + /** + * Checks if `value` is greater than `other`. + * + * @static + * @memberOf _ + * @since 3.9.0 + * @category Lang + * @param {*} value The value to compare. + * @param {*} other The other value to compare. + * @returns {boolean} Returns `true` if `value` is greater than `other`, + * else `false`. + * @see _.lt + * @example + * + * _.gt(3, 1); + * // => true + * + * _.gt(3, 3); + * // => false + * + * _.gt(1, 3); + * // => false + */ + var gt = createRelationalOperation(baseGt); + + /** + * Checks if `value` is greater than or equal to `other`. + * + * @static + * @memberOf _ + * @since 3.9.0 + * @category Lang + * @param {*} value The value to compare. 
+ * @param {*} other The other value to compare. + * @returns {boolean} Returns `true` if `value` is greater than or equal to + * `other`, else `false`. + * @see _.lte + * @example + * + * _.gte(3, 1); + * // => true + * + * _.gte(3, 3); + * // => true + * + * _.gte(1, 3); + * // => false + */ + var gte = createRelationalOperation(function(value, other) { + return value >= other; + }); + + /** + * Checks if `value` is likely an `arguments` object. + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Lang + * @param {*} value The value to check. + * @returns {boolean} Returns `true` if `value` is an `arguments` object, + * else `false`. + * @example + * + * _.isArguments(function() { return arguments; }()); + * // => true + * + * _.isArguments([1, 2, 3]); + * // => false + */ + var isArguments = baseIsArguments(function() { return arguments; }()) ? baseIsArguments : function(value) { + return isObjectLike(value) && hasOwnProperty.call(value, 'callee') && + !propertyIsEnumerable.call(value, 'callee'); + }; + + /** + * Checks if `value` is classified as an `Array` object. + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Lang + * @param {*} value The value to check. + * @returns {boolean} Returns `true` if `value` is an array, else `false`. + * @example + * + * _.isArray([1, 2, 3]); + * // => true + * + * _.isArray(document.body.children); + * // => false + * + * _.isArray('abc'); + * // => false + * + * _.isArray(_.noop); + * // => false + */ + var isArray = Array.isArray; + + /** + * Checks if `value` is classified as an `ArrayBuffer` object. + * + * @static + * @memberOf _ + * @since 4.3.0 + * @category Lang + * @param {*} value The value to check. + * @returns {boolean} Returns `true` if `value` is an array buffer, else `false`. + * @example + * + * _.isArrayBuffer(new ArrayBuffer(2)); + * // => true + * + * _.isArrayBuffer(new Array(2)); + * // => false + */ + var isArrayBuffer = nodeIsArrayBuffer ? baseUnary(nodeIsArrayBuffer) : baseIsArrayBuffer; + + /** + * Checks if `value` is array-like. A value is considered array-like if it's + * not a function and has a `value.length` that's an integer greater than or + * equal to `0` and less than or equal to `Number.MAX_SAFE_INTEGER`. + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Lang + * @param {*} value The value to check. + * @returns {boolean} Returns `true` if `value` is array-like, else `false`. + * @example + * + * _.isArrayLike([1, 2, 3]); + * // => true + * + * _.isArrayLike(document.body.children); + * // => true + * + * _.isArrayLike('abc'); + * // => true + * + * _.isArrayLike(_.noop); + * // => false + */ + function isArrayLike(value) { + return value != null && isLength(value.length) && !isFunction(value); + } + + /** + * This method is like `_.isArrayLike` except that it also checks if `value` + * is an object. + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Lang + * @param {*} value The value to check. + * @returns {boolean} Returns `true` if `value` is an array-like object, + * else `false`. + * @example + * + * _.isArrayLikeObject([1, 2, 3]); + * // => true + * + * _.isArrayLikeObject(document.body.children); + * // => true + * + * _.isArrayLikeObject('abc'); + * // => false + * + * _.isArrayLikeObject(_.noop); + * // => false + */ + function isArrayLikeObject(value) { + return isObjectLike(value) && isArrayLike(value); + } + + /** + * Checks if `value` is classified as a boolean primitive or object. 
+ * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Lang + * @param {*} value The value to check. + * @returns {boolean} Returns `true` if `value` is a boolean, else `false`. + * @example + * + * _.isBoolean(false); + * // => true + * + * _.isBoolean(null); + * // => false + */ + function isBoolean(value) { + return value === true || value === false || + (isObjectLike(value) && baseGetTag(value) == boolTag); + } + + /** + * Checks if `value` is a buffer. + * + * @static + * @memberOf _ + * @since 4.3.0 + * @category Lang + * @param {*} value The value to check. + * @returns {boolean} Returns `true` if `value` is a buffer, else `false`. + * @example + * + * _.isBuffer(new Buffer(2)); + * // => true + * + * _.isBuffer(new Uint8Array(2)); + * // => false + */ + var isBuffer = nativeIsBuffer || stubFalse; + + /** + * Checks if `value` is classified as a `Date` object. + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Lang + * @param {*} value The value to check. + * @returns {boolean} Returns `true` if `value` is a date object, else `false`. + * @example + * + * _.isDate(new Date); + * // => true + * + * _.isDate('Mon April 23 2012'); + * // => false + */ + var isDate = nodeIsDate ? baseUnary(nodeIsDate) : baseIsDate; + + /** + * Checks if `value` is likely a DOM element. + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Lang + * @param {*} value The value to check. + * @returns {boolean} Returns `true` if `value` is a DOM element, else `false`. + * @example + * + * _.isElement(document.body); + * // => true + * + * _.isElement(''); + * // => false + */ + function isElement(value) { + return isObjectLike(value) && value.nodeType === 1 && !isPlainObject(value); + } + + /** + * Checks if `value` is an empty object, collection, map, or set. + * + * Objects are considered empty if they have no own enumerable string keyed + * properties. + * + * Array-like values such as `arguments` objects, arrays, buffers, strings, or + * jQuery-like collections are considered empty if they have a `length` of `0`. + * Similarly, maps and sets are considered empty if they have a `size` of `0`. + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Lang + * @param {*} value The value to check. + * @returns {boolean} Returns `true` if `value` is empty, else `false`. + * @example + * + * _.isEmpty(null); + * // => true + * + * _.isEmpty(true); + * // => true + * + * _.isEmpty(1); + * // => true + * + * _.isEmpty([1, 2, 3]); + * // => false + * + * _.isEmpty({ 'a': 1 }); + * // => false + */ + function isEmpty(value) { + if (value == null) { + return true; + } + if (isArrayLike(value) && + (isArray(value) || typeof value == 'string' || typeof value.splice == 'function' || + isBuffer(value) || isTypedArray(value) || isArguments(value))) { + return !value.length; + } + var tag = getTag(value); + if (tag == mapTag || tag == setTag) { + return !value.size; + } + if (isPrototype(value)) { + return !baseKeys(value).length; + } + for (var key in value) { + if (hasOwnProperty.call(value, key)) { + return false; + } + } + return true; + } + + /** + * Performs a deep comparison between two values to determine if they are + * equivalent. + * + * **Note:** This method supports comparing arrays, array buffers, booleans, + * date objects, error objects, maps, numbers, `Object` objects, regexes, + * sets, strings, symbols, and typed arrays. `Object` objects are compared + * by their own, not inherited, enumerable properties. 
Functions and DOM + * nodes are compared by strict equality, i.e. `===`. + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Lang + * @param {*} value The value to compare. + * @param {*} other The other value to compare. + * @returns {boolean} Returns `true` if the values are equivalent, else `false`. + * @example + * + * var object = { 'a': 1 }; + * var other = { 'a': 1 }; + * + * _.isEqual(object, other); + * // => true + * + * object === other; + * // => false + */ + function isEqual(value, other) { + return baseIsEqual(value, other); + } + + /** + * This method is like `_.isEqual` except that it accepts `customizer` which + * is invoked to compare values. If `customizer` returns `undefined`, comparisons + * are handled by the method instead. The `customizer` is invoked with up to + * six arguments: (objValue, othValue [, index|key, object, other, stack]). + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Lang + * @param {*} value The value to compare. + * @param {*} other The other value to compare. + * @param {Function} [customizer] The function to customize comparisons. + * @returns {boolean} Returns `true` if the values are equivalent, else `false`. + * @example + * + * function isGreeting(value) { + * return /^h(?:i|ello)$/.test(value); + * } + * + * function customizer(objValue, othValue) { + * if (isGreeting(objValue) && isGreeting(othValue)) { + * return true; + * } + * } + * + * var array = ['hello', 'goodbye']; + * var other = ['hi', 'goodbye']; + * + * _.isEqualWith(array, other, customizer); + * // => true + */ + function isEqualWith(value, other, customizer) { + customizer = typeof customizer == 'function' ? customizer : undefined; + var result = customizer ? customizer(value, other) : undefined; + return result === undefined ? baseIsEqual(value, other, undefined, customizer) : !!result; + } + + /** + * Checks if `value` is an `Error`, `EvalError`, `RangeError`, `ReferenceError`, + * `SyntaxError`, `TypeError`, or `URIError` object. + * + * @static + * @memberOf _ + * @since 3.0.0 + * @category Lang + * @param {*} value The value to check. + * @returns {boolean} Returns `true` if `value` is an error object, else `false`. + * @example + * + * _.isError(new Error); + * // => true + * + * _.isError(Error); + * // => false + */ + function isError(value) { + if (!isObjectLike(value)) { + return false; + } + var tag = baseGetTag(value); + return tag == errorTag || tag == domExcTag || + (typeof value.message == 'string' && typeof value.name == 'string' && !isPlainObject(value)); + } + + /** + * Checks if `value` is a finite primitive number. + * + * **Note:** This method is based on + * [`Number.isFinite`](https://mdn.io/Number/isFinite). + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Lang + * @param {*} value The value to check. + * @returns {boolean} Returns `true` if `value` is a finite number, else `false`. + * @example + * + * _.isFinite(3); + * // => true + * + * _.isFinite(Number.MIN_VALUE); + * // => true + * + * _.isFinite(Infinity); + * // => false + * + * _.isFinite('3'); + * // => false + */ + function isFinite(value) { + return typeof value == 'number' && nativeIsFinite(value); + } + + /** + * Checks if `value` is classified as a `Function` object. + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Lang + * @param {*} value The value to check. + * @returns {boolean} Returns `true` if `value` is a function, else `false`. 
+ * @example + * + * _.isFunction(_); + * // => true + * + * _.isFunction(/abc/); + * // => false + */ + function isFunction(value) { + if (!isObject(value)) { + return false; + } + // The use of `Object#toString` avoids issues with the `typeof` operator + // in Safari 9 which returns 'object' for typed arrays and other constructors. + var tag = baseGetTag(value); + return tag == funcTag || tag == genTag || tag == asyncTag || tag == proxyTag; + } + + /** + * Checks if `value` is an integer. + * + * **Note:** This method is based on + * [`Number.isInteger`](https://mdn.io/Number/isInteger). + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Lang + * @param {*} value The value to check. + * @returns {boolean} Returns `true` if `value` is an integer, else `false`. + * @example + * + * _.isInteger(3); + * // => true + * + * _.isInteger(Number.MIN_VALUE); + * // => false + * + * _.isInteger(Infinity); + * // => false + * + * _.isInteger('3'); + * // => false + */ + function isInteger(value) { + return typeof value == 'number' && value == toInteger(value); + } + + /** + * Checks if `value` is a valid array-like length. + * + * **Note:** This method is loosely based on + * [`ToLength`](http://ecma-international.org/ecma-262/7.0/#sec-tolength). + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Lang + * @param {*} value The value to check. + * @returns {boolean} Returns `true` if `value` is a valid length, else `false`. + * @example + * + * _.isLength(3); + * // => true + * + * _.isLength(Number.MIN_VALUE); + * // => false + * + * _.isLength(Infinity); + * // => false + * + * _.isLength('3'); + * // => false + */ + function isLength(value) { + return typeof value == 'number' && + value > -1 && value % 1 == 0 && value <= MAX_SAFE_INTEGER; + } + + /** + * Checks if `value` is the + * [language type](http://www.ecma-international.org/ecma-262/7.0/#sec-ecmascript-language-types) + * of `Object`. (e.g. arrays, functions, objects, regexes, `new Number(0)`, and `new String('')`) + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Lang + * @param {*} value The value to check. + * @returns {boolean} Returns `true` if `value` is an object, else `false`. + * @example + * + * _.isObject({}); + * // => true + * + * _.isObject([1, 2, 3]); + * // => true + * + * _.isObject(_.noop); + * // => true + * + * _.isObject(null); + * // => false + */ + function isObject(value) { + var type = typeof value; + return value != null && (type == 'object' || type == 'function'); + } + + /** + * Checks if `value` is object-like. A value is object-like if it's not `null` + * and has a `typeof` result of "object". + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Lang + * @param {*} value The value to check. + * @returns {boolean} Returns `true` if `value` is object-like, else `false`. + * @example + * + * _.isObjectLike({}); + * // => true + * + * _.isObjectLike([1, 2, 3]); + * // => true + * + * _.isObjectLike(_.noop); + * // => false + * + * _.isObjectLike(null); + * // => false + */ + function isObjectLike(value) { + return value != null && typeof value == 'object'; + } + + /** + * Checks if `value` is classified as a `Map` object. + * + * @static + * @memberOf _ + * @since 4.3.0 + * @category Lang + * @param {*} value The value to check. + * @returns {boolean} Returns `true` if `value` is a map, else `false`. + * @example + * + * _.isMap(new Map); + * // => true + * + * _.isMap(new WeakMap); + * // => false + */ + var isMap = nodeIsMap ? 
baseUnary(nodeIsMap) : baseIsMap; + + /** + * Performs a partial deep comparison between `object` and `source` to + * determine if `object` contains equivalent property values. + * + * **Note:** This method is equivalent to `_.matches` when `source` is + * partially applied. + * + * Partial comparisons will match empty array and empty object `source` + * values against any array or object value, respectively. See `_.isEqual` + * for a list of supported value comparisons. + * + * @static + * @memberOf _ + * @since 3.0.0 + * @category Lang + * @param {Object} object The object to inspect. + * @param {Object} source The object of property values to match. + * @returns {boolean} Returns `true` if `object` is a match, else `false`. + * @example + * + * var object = { 'a': 1, 'b': 2 }; + * + * _.isMatch(object, { 'b': 2 }); + * // => true + * + * _.isMatch(object, { 'b': 1 }); + * // => false + */ + function isMatch(object, source) { + return object === source || baseIsMatch(object, source, getMatchData(source)); + } + + /** + * This method is like `_.isMatch` except that it accepts `customizer` which + * is invoked to compare values. If `customizer` returns `undefined`, comparisons + * are handled by the method instead. The `customizer` is invoked with five + * arguments: (objValue, srcValue, index|key, object, source). + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Lang + * @param {Object} object The object to inspect. + * @param {Object} source The object of property values to match. + * @param {Function} [customizer] The function to customize comparisons. + * @returns {boolean} Returns `true` if `object` is a match, else `false`. + * @example + * + * function isGreeting(value) { + * return /^h(?:i|ello)$/.test(value); + * } + * + * function customizer(objValue, srcValue) { + * if (isGreeting(objValue) && isGreeting(srcValue)) { + * return true; + * } + * } + * + * var object = { 'greeting': 'hello' }; + * var source = { 'greeting': 'hi' }; + * + * _.isMatchWith(object, source, customizer); + * // => true + */ + function isMatchWith(object, source, customizer) { + customizer = typeof customizer == 'function' ? customizer : undefined; + return baseIsMatch(object, source, getMatchData(source), customizer); + } + + /** + * Checks if `value` is `NaN`. + * + * **Note:** This method is based on + * [`Number.isNaN`](https://mdn.io/Number/isNaN) and is not the same as + * global [`isNaN`](https://mdn.io/isNaN) which returns `true` for + * `undefined` and other non-number values. + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Lang + * @param {*} value The value to check. + * @returns {boolean} Returns `true` if `value` is `NaN`, else `false`. + * @example + * + * _.isNaN(NaN); + * // => true + * + * _.isNaN(new Number(NaN)); + * // => true + * + * isNaN(undefined); + * // => true + * + * _.isNaN(undefined); + * // => false + */ + function isNaN(value) { + // An `NaN` primitive is the only value that is not equal to itself. + // Perform the `toStringTag` check first to avoid errors with some + // ActiveX objects in IE. + return isNumber(value) && value != +value; + } + + /** + * Checks if `value` is a pristine native function. + * + * **Note:** This method can't reliably detect native functions in the presence + * of the core-js package because core-js circumvents this kind of detection. + * Despite multiple requests, the core-js maintainer has made it clear: any + * attempt to fix the detection will be obstructed. 
As a result, we're left + * with little choice but to throw an error. Unfortunately, this also affects + * packages, like [babel-polyfill](https://www.npmjs.com/package/babel-polyfill), + * which rely on core-js. + * + * @static + * @memberOf _ + * @since 3.0.0 + * @category Lang + * @param {*} value The value to check. + * @returns {boolean} Returns `true` if `value` is a native function, + * else `false`. + * @example + * + * _.isNative(Array.prototype.push); + * // => true + * + * _.isNative(_); + * // => false + */ + function isNative(value) { + if (isMaskable(value)) { + throw new Error(CORE_ERROR_TEXT); + } + return baseIsNative(value); + } + + /** + * Checks if `value` is `null`. + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Lang + * @param {*} value The value to check. + * @returns {boolean} Returns `true` if `value` is `null`, else `false`. + * @example + * + * _.isNull(null); + * // => true + * + * _.isNull(void 0); + * // => false + */ + function isNull(value) { + return value === null; + } + + /** + * Checks if `value` is `null` or `undefined`. + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Lang + * @param {*} value The value to check. + * @returns {boolean} Returns `true` if `value` is nullish, else `false`. + * @example + * + * _.isNil(null); + * // => true + * + * _.isNil(void 0); + * // => true + * + * _.isNil(NaN); + * // => false + */ + function isNil(value) { + return value == null; + } + + /** + * Checks if `value` is classified as a `Number` primitive or object. + * + * **Note:** To exclude `Infinity`, `-Infinity`, and `NaN`, which are + * classified as numbers, use the `_.isFinite` method. + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Lang + * @param {*} value The value to check. + * @returns {boolean} Returns `true` if `value` is a number, else `false`. + * @example + * + * _.isNumber(3); + * // => true + * + * _.isNumber(Number.MIN_VALUE); + * // => true + * + * _.isNumber(Infinity); + * // => true + * + * _.isNumber('3'); + * // => false + */ + function isNumber(value) { + return typeof value == 'number' || + (isObjectLike(value) && baseGetTag(value) == numberTag); + } + + /** + * Checks if `value` is a plain object, that is, an object created by the + * `Object` constructor or one with a `[[Prototype]]` of `null`. + * + * @static + * @memberOf _ + * @since 0.8.0 + * @category Lang + * @param {*} value The value to check. + * @returns {boolean} Returns `true` if `value` is a plain object, else `false`. + * @example + * + * function Foo() { + * this.a = 1; + * } + * + * _.isPlainObject(new Foo); + * // => false + * + * _.isPlainObject([1, 2, 3]); + * // => false + * + * _.isPlainObject({ 'x': 0, 'y': 0 }); + * // => true + * + * _.isPlainObject(Object.create(null)); + * // => true + */ + function isPlainObject(value) { + if (!isObjectLike(value) || baseGetTag(value) != objectTag) { + return false; + } + var proto = getPrototype(value); + if (proto === null) { + return true; + } + var Ctor = hasOwnProperty.call(proto, 'constructor') && proto.constructor; + return typeof Ctor == 'function' && Ctor instanceof Ctor && + funcToString.call(Ctor) == objectCtorString; + } + + /** + * Checks if `value` is classified as a `RegExp` object. + * + * @static + * @memberOf _ + * @since 0.1.0 + * @category Lang + * @param {*} value The value to check. + * @returns {boolean} Returns `true` if `value` is a regexp, else `false`. 
+ * @example + * + * _.isRegExp(/abc/); + * // => true + * + * _.isRegExp('/abc/'); + * // => false + */ + var isRegExp = nodeIsRegExp ? baseUnary(nodeIsRegExp) : baseIsRegExp; + + /** + * Checks if `value` is a safe integer. An integer is safe if it's an IEEE-754 + * double precision number which isn't the result of a rounded unsafe integer. + * + * **Note:** This method is based on + * [`Number.isSafeInteger`](https://mdn.io/Number/isSafeInteger). + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Lang + * @param {*} value The value to check. + * @returns {boolean} Returns `true` if `value` is a safe integer, else `false`. + * @example + * + * _.isSafeInteger(3); + * // => true + * + * _.isSafeInteger(Number.MIN_VALUE); + * // => false + * + * _.isSafeInteger(Infinity); + * // => false + * + * _.isSafeInteger('3'); + * // => false + */ + function isSafeInteger(value) { + return isInteger(value) && value >= -MAX_SAFE_INTEGER && value <= MAX_SAFE_INTEGER; + } + + /** + * Checks if `value` is classified as a `Set` object. + * + * @static + * @memberOf _ + * @since 4.3.0 + * @category Lang + * @param {*} value The value to check. + * @returns {boolean} Returns `true` if `value` is a set, else `false`. + * @example + * + * _.isSet(new Set); + * // => true + * + * _.isSet(new WeakSet); + * // => false + */ + var isSet = nodeIsSet ? baseUnary(nodeIsSet) : baseIsSet; + + /** + * Checks if `value` is classified as a `String` primitive or object. + * + * @static + * @since 0.1.0 + * @memberOf _ + * @category Lang + * @param {*} value The value to check. + * @returns {boolean} Returns `true` if `value` is a string, else `false`. + * @example + * + * _.isString('abc'); + * // => true + * + * _.isString(1); + * // => false + */ + function isString(value) { + return typeof value == 'string' || + (!isArray(value) && isObjectLike(value) && baseGetTag(value) == stringTag); + } + + /** + * Checks if `value` is classified as a `Symbol` primitive or object. + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Lang + * @param {*} value The value to check. + * @returns {boolean} Returns `true` if `value` is a symbol, else `false`. + * @example + * + * _.isSymbol(Symbol.iterator); + * // => true + * + * _.isSymbol('abc'); + * // => false + */ + function isSymbol(value) { + return typeof value == 'symbol' || + (isObjectLike(value) && baseGetTag(value) == symbolTag); + } + + /** + * Checks if `value` is classified as a typed array. + * + * @static + * @memberOf _ + * @since 3.0.0 + * @category Lang + * @param {*} value The value to check. + * @returns {boolean} Returns `true` if `value` is a typed array, else `false`. + * @example + * + * _.isTypedArray(new Uint8Array); + * // => true + * + * _.isTypedArray([]); + * // => false + */ + var isTypedArray = nodeIsTypedArray ? baseUnary(nodeIsTypedArray) : baseIsTypedArray; + + /** + * Checks if `value` is `undefined`. + * + * @static + * @since 0.1.0 + * @memberOf _ + * @category Lang + * @param {*} value The value to check. + * @returns {boolean} Returns `true` if `value` is `undefined`, else `false`. + * @example + * + * _.isUndefined(void 0); + * // => true + * + * _.isUndefined(null); + * // => false + */ + function isUndefined(value) { + return value === undefined; + } + + /** + * Checks if `value` is classified as a `WeakMap` object. + * + * @static + * @memberOf _ + * @since 4.3.0 + * @category Lang + * @param {*} value The value to check. + * @returns {boolean} Returns `true` if `value` is a weak map, else `false`. 
+ * @example + * + * _.isWeakMap(new WeakMap); + * // => true + * + * _.isWeakMap(new Map); + * // => false + */ + function isWeakMap(value) { + return isObjectLike(value) && getTag(value) == weakMapTag; + } + + /** + * Checks if `value` is classified as a `WeakSet` object. + * + * @static + * @memberOf _ + * @since 4.3.0 + * @category Lang + * @param {*} value The value to check. + * @returns {boolean} Returns `true` if `value` is a weak set, else `false`. + * @example + * + * _.isWeakSet(new WeakSet); + * // => true + * + * _.isWeakSet(new Set); + * // => false + */ + function isWeakSet(value) { + return isObjectLike(value) && baseGetTag(value) == weakSetTag; + } + + /** + * Checks if `value` is less than `other`. + * + * @static + * @memberOf _ + * @since 3.9.0 + * @category Lang + * @param {*} value The value to compare. + * @param {*} other The other value to compare. + * @returns {boolean} Returns `true` if `value` is less than `other`, + * else `false`. + * @see _.gt + * @example + * + * _.lt(1, 3); + * // => true + * + * _.lt(3, 3); + * // => false + * + * _.lt(3, 1); + * // => false + */ + var lt = createRelationalOperation(baseLt); + + /** + * Checks if `value` is less than or equal to `other`. + * + * @static + * @memberOf _ + * @since 3.9.0 + * @category Lang + * @param {*} value The value to compare. + * @param {*} other The other value to compare. + * @returns {boolean} Returns `true` if `value` is less than or equal to + * `other`, else `false`. + * @see _.gte + * @example + * + * _.lte(1, 3); + * // => true + * + * _.lte(3, 3); + * // => true + * + * _.lte(3, 1); + * // => false + */ + var lte = createRelationalOperation(function(value, other) { + return value <= other; + }); + + /** + * Converts `value` to an array. + * + * @static + * @since 0.1.0 + * @memberOf _ + * @category Lang + * @param {*} value The value to convert. + * @returns {Array} Returns the converted array. + * @example + * + * _.toArray({ 'a': 1, 'b': 2 }); + * // => [1, 2] + * + * _.toArray('abc'); + * // => ['a', 'b', 'c'] + * + * _.toArray(1); + * // => [] + * + * _.toArray(null); + * // => [] + */ + function toArray(value) { + if (!value) { + return []; + } + if (isArrayLike(value)) { + return isString(value) ? stringToArray(value) : copyArray(value); + } + if (symIterator && value[symIterator]) { + return iteratorToArray(value[symIterator]()); + } + var tag = getTag(value), + func = tag == mapTag ? mapToArray : (tag == setTag ? setToArray : values); + + return func(value); + } + + /** + * Converts `value` to a finite number. + * + * @static + * @memberOf _ + * @since 4.12.0 + * @category Lang + * @param {*} value The value to convert. + * @returns {number} Returns the converted number. + * @example + * + * _.toFinite(3.2); + * // => 3.2 + * + * _.toFinite(Number.MIN_VALUE); + * // => 5e-324 + * + * _.toFinite(Infinity); + * // => 1.7976931348623157e+308 + * + * _.toFinite('3.2'); + * // => 3.2 + */ + function toFinite(value) { + if (!value) { + return value === 0 ? value : 0; + } + value = toNumber(value); + if (value === INFINITY || value === -INFINITY) { + var sign = (value < 0 ? -1 : 1); + return sign * MAX_INTEGER; + } + return value === value ? value : 0; + } + + /** + * Converts `value` to an integer. + * + * **Note:** This method is loosely based on + * [`ToInteger`](http://www.ecma-international.org/ecma-262/7.0/#sec-tointeger). + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Lang + * @param {*} value The value to convert. + * @returns {number} Returns the converted integer. 
+ * @example + * + * _.toInteger(3.2); + * // => 3 + * + * _.toInteger(Number.MIN_VALUE); + * // => 0 + * + * _.toInteger(Infinity); + * // => 1.7976931348623157e+308 + * + * _.toInteger('3.2'); + * // => 3 + */ + function toInteger(value) { + var result = toFinite(value), + remainder = result % 1; + + return result === result ? (remainder ? result - remainder : result) : 0; + } + + /** + * Converts `value` to an integer suitable for use as the length of an + * array-like object. + * + * **Note:** This method is based on + * [`ToLength`](http://ecma-international.org/ecma-262/7.0/#sec-tolength). + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Lang + * @param {*} value The value to convert. + * @returns {number} Returns the converted integer. + * @example + * + * _.toLength(3.2); + * // => 3 + * + * _.toLength(Number.MIN_VALUE); + * // => 0 + * + * _.toLength(Infinity); + * // => 4294967295 + * + * _.toLength('3.2'); + * // => 3 + */ + function toLength(value) { + return value ? baseClamp(toInteger(value), 0, MAX_ARRAY_LENGTH) : 0; + } + + /** + * Converts `value` to a number. + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Lang + * @param {*} value The value to process. + * @returns {number} Returns the number. + * @example + * + * _.toNumber(3.2); + * // => 3.2 + * + * _.toNumber(Number.MIN_VALUE); + * // => 5e-324 + * + * _.toNumber(Infinity); + * // => Infinity + * + * _.toNumber('3.2'); + * // => 3.2 + */ + function toNumber(value) { + if (typeof value == 'number') { + return value; + } + if (isSymbol(value)) { + return NAN; + } + if (isObject(value)) { + var other = typeof value.valueOf == 'function' ? value.valueOf() : value; + value = isObject(other) ? (other + '') : other; + } + if (typeof value != 'string') { + return value === 0 ? value : +value; + } + value = value.replace(reTrim, ''); + var isBinary = reIsBinary.test(value); + return (isBinary || reIsOctal.test(value)) + ? freeParseInt(value.slice(2), isBinary ? 2 : 8) + : (reIsBadHex.test(value) ? NAN : +value); + } + + /** + * Converts `value` to a plain object flattening inherited enumerable string + * keyed properties of `value` to own properties of the plain object. + * + * @static + * @memberOf _ + * @since 3.0.0 + * @category Lang + * @param {*} value The value to convert. + * @returns {Object} Returns the converted plain object. + * @example + * + * function Foo() { + * this.b = 2; + * } + * + * Foo.prototype.c = 3; + * + * _.assign({ 'a': 1 }, new Foo); + * // => { 'a': 1, 'b': 2 } + * + * _.assign({ 'a': 1 }, _.toPlainObject(new Foo)); + * // => { 'a': 1, 'b': 2, 'c': 3 } + */ + function toPlainObject(value) { + return copyObject(value, keysIn(value)); + } + + /** + * Converts `value` to a safe integer. A safe integer can be compared and + * represented correctly. + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Lang + * @param {*} value The value to convert. + * @returns {number} Returns the converted integer. + * @example + * + * _.toSafeInteger(3.2); + * // => 3 + * + * _.toSafeInteger(Number.MIN_VALUE); + * // => 0 + * + * _.toSafeInteger(Infinity); + * // => 9007199254740991 + * + * _.toSafeInteger('3.2'); + * // => 3 + */ + function toSafeInteger(value) { + return value + ? baseClamp(toInteger(value), -MAX_SAFE_INTEGER, MAX_SAFE_INTEGER) + : (value === 0 ? value : 0); + } + + /** + * Converts `value` to a string. An empty string is returned for `null` + * and `undefined` values. The sign of `-0` is preserved. 
+ * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Lang + * @param {*} value The value to convert. + * @returns {string} Returns the converted string. + * @example + * + * _.toString(null); + * // => '' + * + * _.toString(-0); + * // => '-0' + * + * _.toString([1, 2, 3]); + * // => '1,2,3' + */ + function toString(value) { + return value == null ? '' : baseToString(value); + } + + /*------------------------------------------------------------------------*/ + + /** + * Assigns own enumerable string keyed properties of source objects to the + * destination object. Source objects are applied from left to right. + * Subsequent sources overwrite property assignments of previous sources. + * + * **Note:** This method mutates `object` and is loosely based on + * [`Object.assign`](https://mdn.io/Object/assign). + * + * @static + * @memberOf _ + * @since 0.10.0 + * @category Object + * @param {Object} object The destination object. + * @param {...Object} [sources] The source objects. + * @returns {Object} Returns `object`. + * @see _.assignIn + * @example + * + * function Foo() { + * this.a = 1; + * } + * + * function Bar() { + * this.c = 3; + * } + * + * Foo.prototype.b = 2; + * Bar.prototype.d = 4; + * + * _.assign({ 'a': 0 }, new Foo, new Bar); + * // => { 'a': 1, 'c': 3 } + */ + var assign = createAssigner(function(object, source) { + if (isPrototype(source) || isArrayLike(source)) { + copyObject(source, keys(source), object); + return; + } + for (var key in source) { + if (hasOwnProperty.call(source, key)) { + assignValue(object, key, source[key]); + } + } + }); + + /** + * This method is like `_.assign` except that it iterates over own and + * inherited source properties. + * + * **Note:** This method mutates `object`. + * + * @static + * @memberOf _ + * @since 4.0.0 + * @alias extend + * @category Object + * @param {Object} object The destination object. + * @param {...Object} [sources] The source objects. + * @returns {Object} Returns `object`. + * @see _.assign + * @example + * + * function Foo() { + * this.a = 1; + * } + * + * function Bar() { + * this.c = 3; + * } + * + * Foo.prototype.b = 2; + * Bar.prototype.d = 4; + * + * _.assignIn({ 'a': 0 }, new Foo, new Bar); + * // => { 'a': 1, 'b': 2, 'c': 3, 'd': 4 } + */ + var assignIn = createAssigner(function(object, source) { + copyObject(source, keysIn(source), object); + }); + + /** + * This method is like `_.assignIn` except that it accepts `customizer` + * which is invoked to produce the assigned values. If `customizer` returns + * `undefined`, assignment is handled by the method instead. The `customizer` + * is invoked with five arguments: (objValue, srcValue, key, object, source). + * + * **Note:** This method mutates `object`. + * + * @static + * @memberOf _ + * @since 4.0.0 + * @alias extendWith + * @category Object + * @param {Object} object The destination object. + * @param {...Object} sources The source objects. + * @param {Function} [customizer] The function to customize assigned values. + * @returns {Object} Returns `object`. + * @see _.assignWith + * @example + * + * function customizer(objValue, srcValue) { + * return _.isUndefined(objValue) ? 
srcValue : objValue; + * } + * + * var defaults = _.partialRight(_.assignInWith, customizer); + * + * defaults({ 'a': 1 }, { 'b': 2 }, { 'a': 3 }); + * // => { 'a': 1, 'b': 2 } + */ + var assignInWith = createAssigner(function(object, source, srcIndex, customizer) { + copyObject(source, keysIn(source), object, customizer); + }); + + /** + * This method is like `_.assign` except that it accepts `customizer` + * which is invoked to produce the assigned values. If `customizer` returns + * `undefined`, assignment is handled by the method instead. The `customizer` + * is invoked with five arguments: (objValue, srcValue, key, object, source). + * + * **Note:** This method mutates `object`. + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Object + * @param {Object} object The destination object. + * @param {...Object} sources The source objects. + * @param {Function} [customizer] The function to customize assigned values. + * @returns {Object} Returns `object`. + * @see _.assignInWith + * @example + * + * function customizer(objValue, srcValue) { + * return _.isUndefined(objValue) ? srcValue : objValue; + * } + * + * var defaults = _.partialRight(_.assignWith, customizer); + * + * defaults({ 'a': 1 }, { 'b': 2 }, { 'a': 3 }); + * // => { 'a': 1, 'b': 2 } + */ + var assignWith = createAssigner(function(object, source, srcIndex, customizer) { + copyObject(source, keys(source), object, customizer); + }); + + /** + * Creates an array of values corresponding to `paths` of `object`. + * + * @static + * @memberOf _ + * @since 1.0.0 + * @category Object + * @param {Object} object The object to iterate over. + * @param {...(string|string[])} [paths] The property paths to pick. + * @returns {Array} Returns the picked values. + * @example + * + * var object = { 'a': [{ 'b': { 'c': 3 } }, 4] }; + * + * _.at(object, ['a[0].b.c', 'a[1]']); + * // => [3, 4] + */ + var at = flatRest(baseAt); + + /** + * Creates an object that inherits from the `prototype` object. If a + * `properties` object is given, its own enumerable string keyed properties + * are assigned to the created object. + * + * @static + * @memberOf _ + * @since 2.3.0 + * @category Object + * @param {Object} prototype The object to inherit from. + * @param {Object} [properties] The properties to assign to the object. + * @returns {Object} Returns the new object. + * @example + * + * function Shape() { + * this.x = 0; + * this.y = 0; + * } + * + * function Circle() { + * Shape.call(this); + * } + * + * Circle.prototype = _.create(Shape.prototype, { + * 'constructor': Circle + * }); + * + * var circle = new Circle; + * circle instanceof Circle; + * // => true + * + * circle instanceof Shape; + * // => true + */ + function create(prototype, properties) { + var result = baseCreate(prototype); + return properties == null ? result : baseAssign(result, properties); + } + + /** + * Assigns own and inherited enumerable string keyed properties of source + * objects to the destination object for all destination properties that + * resolve to `undefined`. Source objects are applied from left to right. + * Once a property is set, additional values of the same property are ignored. + * + * **Note:** This method mutates `object`. + * + * @static + * @since 0.1.0 + * @memberOf _ + * @category Object + * @param {Object} object The destination object. + * @param {...Object} [sources] The source objects. + * @returns {Object} Returns `object`. 
+ * @see _.defaultsDeep + * @example + * + * _.defaults({ 'a': 1 }, { 'b': 2 }, { 'a': 3 }); + * // => { 'a': 1, 'b': 2 } + */ + var defaults = baseRest(function(object, sources) { + object = Object(object); + + var index = -1; + var length = sources.length; + var guard = length > 2 ? sources[2] : undefined; + + if (guard && isIterateeCall(sources[0], sources[1], guard)) { + length = 1; + } + + while (++index < length) { + var source = sources[index]; + var props = keysIn(source); + var propsIndex = -1; + var propsLength = props.length; + + while (++propsIndex < propsLength) { + var key = props[propsIndex]; + var value = object[key]; + + if (value === undefined || + (eq(value, objectProto[key]) && !hasOwnProperty.call(object, key))) { + object[key] = source[key]; + } + } + } + + return object; + }); + + /** + * This method is like `_.defaults` except that it recursively assigns + * default properties. + * + * **Note:** This method mutates `object`. + * + * @static + * @memberOf _ + * @since 3.10.0 + * @category Object + * @param {Object} object The destination object. + * @param {...Object} [sources] The source objects. + * @returns {Object} Returns `object`. + * @see _.defaults + * @example + * + * _.defaultsDeep({ 'a': { 'b': 2 } }, { 'a': { 'b': 1, 'c': 3 } }); + * // => { 'a': { 'b': 2, 'c': 3 } } + */ + var defaultsDeep = baseRest(function(args) { + args.push(undefined, customDefaultsMerge); + return apply(mergeWith, undefined, args); + }); + + /** + * This method is like `_.find` except that it returns the key of the first + * element `predicate` returns truthy for instead of the element itself. + * + * @static + * @memberOf _ + * @since 1.1.0 + * @category Object + * @param {Object} object The object to inspect. + * @param {Function} [predicate=_.identity] The function invoked per iteration. + * @returns {string|undefined} Returns the key of the matched element, + * else `undefined`. + * @example + * + * var users = { + * 'barney': { 'age': 36, 'active': true }, + * 'fred': { 'age': 40, 'active': false }, + * 'pebbles': { 'age': 1, 'active': true } + * }; + * + * _.findKey(users, function(o) { return o.age < 40; }); + * // => 'barney' (iteration order is not guaranteed) + * + * // The `_.matches` iteratee shorthand. + * _.findKey(users, { 'age': 1, 'active': true }); + * // => 'pebbles' + * + * // The `_.matchesProperty` iteratee shorthand. + * _.findKey(users, ['active', false]); + * // => 'fred' + * + * // The `_.property` iteratee shorthand. + * _.findKey(users, 'active'); + * // => 'barney' + */ + function findKey(object, predicate) { + return baseFindKey(object, getIteratee(predicate, 3), baseForOwn); + } + + /** + * This method is like `_.findKey` except that it iterates over elements of + * a collection in the opposite order. + * + * @static + * @memberOf _ + * @since 2.0.0 + * @category Object + * @param {Object} object The object to inspect. + * @param {Function} [predicate=_.identity] The function invoked per iteration. + * @returns {string|undefined} Returns the key of the matched element, + * else `undefined`. + * @example + * + * var users = { + * 'barney': { 'age': 36, 'active': true }, + * 'fred': { 'age': 40, 'active': false }, + * 'pebbles': { 'age': 1, 'active': true } + * }; + * + * _.findLastKey(users, function(o) { return o.age < 40; }); + * // => returns 'pebbles' assuming `_.findKey` returns 'barney' + * + * // The `_.matches` iteratee shorthand. 
+ * _.findLastKey(users, { 'age': 36, 'active': true }); + * // => 'barney' + * + * // The `_.matchesProperty` iteratee shorthand. + * _.findLastKey(users, ['active', false]); + * // => 'fred' + * + * // The `_.property` iteratee shorthand. + * _.findLastKey(users, 'active'); + * // => 'pebbles' + */ + function findLastKey(object, predicate) { + return baseFindKey(object, getIteratee(predicate, 3), baseForOwnRight); + } + + /** + * Iterates over own and inherited enumerable string keyed properties of an + * object and invokes `iteratee` for each property. The iteratee is invoked + * with three arguments: (value, key, object). Iteratee functions may exit + * iteration early by explicitly returning `false`. + * + * @static + * @memberOf _ + * @since 0.3.0 + * @category Object + * @param {Object} object The object to iterate over. + * @param {Function} [iteratee=_.identity] The function invoked per iteration. + * @returns {Object} Returns `object`. + * @see _.forInRight + * @example + * + * function Foo() { + * this.a = 1; + * this.b = 2; + * } + * + * Foo.prototype.c = 3; + * + * _.forIn(new Foo, function(value, key) { + * console.log(key); + * }); + * // => Logs 'a', 'b', then 'c' (iteration order is not guaranteed). + */ + function forIn(object, iteratee) { + return object == null + ? object + : baseFor(object, getIteratee(iteratee, 3), keysIn); + } + + /** + * This method is like `_.forIn` except that it iterates over properties of + * `object` in the opposite order. + * + * @static + * @memberOf _ + * @since 2.0.0 + * @category Object + * @param {Object} object The object to iterate over. + * @param {Function} [iteratee=_.identity] The function invoked per iteration. + * @returns {Object} Returns `object`. + * @see _.forIn + * @example + * + * function Foo() { + * this.a = 1; + * this.b = 2; + * } + * + * Foo.prototype.c = 3; + * + * _.forInRight(new Foo, function(value, key) { + * console.log(key); + * }); + * // => Logs 'c', 'b', then 'a' assuming `_.forIn` logs 'a', 'b', then 'c'. + */ + function forInRight(object, iteratee) { + return object == null + ? object + : baseForRight(object, getIteratee(iteratee, 3), keysIn); + } + + /** + * Iterates over own enumerable string keyed properties of an object and + * invokes `iteratee` for each property. The iteratee is invoked with three + * arguments: (value, key, object). Iteratee functions may exit iteration + * early by explicitly returning `false`. + * + * @static + * @memberOf _ + * @since 0.3.0 + * @category Object + * @param {Object} object The object to iterate over. + * @param {Function} [iteratee=_.identity] The function invoked per iteration. + * @returns {Object} Returns `object`. + * @see _.forOwnRight + * @example + * + * function Foo() { + * this.a = 1; + * this.b = 2; + * } + * + * Foo.prototype.c = 3; + * + * _.forOwn(new Foo, function(value, key) { + * console.log(key); + * }); + * // => Logs 'a' then 'b' (iteration order is not guaranteed). + */ + function forOwn(object, iteratee) { + return object && baseForOwn(object, getIteratee(iteratee, 3)); + } + + /** + * This method is like `_.forOwn` except that it iterates over properties of + * `object` in the opposite order. + * + * @static + * @memberOf _ + * @since 2.0.0 + * @category Object + * @param {Object} object The object to iterate over. + * @param {Function} [iteratee=_.identity] The function invoked per iteration. + * @returns {Object} Returns `object`. 
+ * @see _.forOwn + * @example + * + * function Foo() { + * this.a = 1; + * this.b = 2; + * } + * + * Foo.prototype.c = 3; + * + * _.forOwnRight(new Foo, function(value, key) { + * console.log(key); + * }); + * // => Logs 'b' then 'a' assuming `_.forOwn` logs 'a' then 'b'. + */ + function forOwnRight(object, iteratee) { + return object && baseForOwnRight(object, getIteratee(iteratee, 3)); + } + + /** + * Creates an array of function property names from own enumerable properties + * of `object`. + * + * @static + * @since 0.1.0 + * @memberOf _ + * @category Object + * @param {Object} object The object to inspect. + * @returns {Array} Returns the function names. + * @see _.functionsIn + * @example + * + * function Foo() { + * this.a = _.constant('a'); + * this.b = _.constant('b'); + * } + * + * Foo.prototype.c = _.constant('c'); + * + * _.functions(new Foo); + * // => ['a', 'b'] + */ + function functions(object) { + return object == null ? [] : baseFunctions(object, keys(object)); + } + + /** + * Creates an array of function property names from own and inherited + * enumerable properties of `object`. + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Object + * @param {Object} object The object to inspect. + * @returns {Array} Returns the function names. + * @see _.functions + * @example + * + * function Foo() { + * this.a = _.constant('a'); + * this.b = _.constant('b'); + * } + * + * Foo.prototype.c = _.constant('c'); + * + * _.functionsIn(new Foo); + * // => ['a', 'b', 'c'] + */ + function functionsIn(object) { + return object == null ? [] : baseFunctions(object, keysIn(object)); + } + + /** + * Gets the value at `path` of `object`. If the resolved value is + * `undefined`, the `defaultValue` is returned in its place. + * + * @static + * @memberOf _ + * @since 3.7.0 + * @category Object + * @param {Object} object The object to query. + * @param {Array|string} path The path of the property to get. + * @param {*} [defaultValue] The value returned for `undefined` resolved values. + * @returns {*} Returns the resolved value. + * @example + * + * var object = { 'a': [{ 'b': { 'c': 3 } }] }; + * + * _.get(object, 'a[0].b.c'); + * // => 3 + * + * _.get(object, ['a', '0', 'b', 'c']); + * // => 3 + * + * _.get(object, 'a.b.c', 'default'); + * // => 'default' + */ + function get(object, path, defaultValue) { + var result = object == null ? undefined : baseGet(object, path); + return result === undefined ? defaultValue : result; + } + + /** + * Checks if `path` is a direct property of `object`. + * + * @static + * @since 0.1.0 + * @memberOf _ + * @category Object + * @param {Object} object The object to query. + * @param {Array|string} path The path to check. + * @returns {boolean} Returns `true` if `path` exists, else `false`. + * @example + * + * var object = { 'a': { 'b': 2 } }; + * var other = _.create({ 'a': _.create({ 'b': 2 }) }); + * + * _.has(object, 'a'); + * // => true + * + * _.has(object, 'a.b'); + * // => true + * + * _.has(object, ['a', 'b']); + * // => true + * + * _.has(other, 'a'); + * // => false + */ + function has(object, path) { + return object != null && hasPath(object, path, baseHas); + } + + /** + * Checks if `path` is a direct or inherited property of `object`. + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Object + * @param {Object} object The object to query. + * @param {Array|string} path The path to check. + * @returns {boolean} Returns `true` if `path` exists, else `false`. 
+ * @example + * + * var object = _.create({ 'a': _.create({ 'b': 2 }) }); + * + * _.hasIn(object, 'a'); + * // => true + * + * _.hasIn(object, 'a.b'); + * // => true + * + * _.hasIn(object, ['a', 'b']); + * // => true + * + * _.hasIn(object, 'b'); + * // => false + */ + function hasIn(object, path) { + return object != null && hasPath(object, path, baseHasIn); + } + + /** + * Creates an object composed of the inverted keys and values of `object`. + * If `object` contains duplicate values, subsequent values overwrite + * property assignments of previous values. + * + * @static + * @memberOf _ + * @since 0.7.0 + * @category Object + * @param {Object} object The object to invert. + * @returns {Object} Returns the new inverted object. + * @example + * + * var object = { 'a': 1, 'b': 2, 'c': 1 }; + * + * _.invert(object); + * // => { '1': 'c', '2': 'b' } + */ + var invert = createInverter(function(result, value, key) { + if (value != null && + typeof value.toString != 'function') { + value = nativeObjectToString.call(value); + } + + result[value] = key; + }, constant(identity)); + + /** + * This method is like `_.invert` except that the inverted object is generated + * from the results of running each element of `object` thru `iteratee`. The + * corresponding inverted value of each inverted key is an array of keys + * responsible for generating the inverted value. The iteratee is invoked + * with one argument: (value). + * + * @static + * @memberOf _ + * @since 4.1.0 + * @category Object + * @param {Object} object The object to invert. + * @param {Function} [iteratee=_.identity] The iteratee invoked per element. + * @returns {Object} Returns the new inverted object. + * @example + * + * var object = { 'a': 1, 'b': 2, 'c': 1 }; + * + * _.invertBy(object); + * // => { '1': ['a', 'c'], '2': ['b'] } + * + * _.invertBy(object, function(value) { + * return 'group' + value; + * }); + * // => { 'group1': ['a', 'c'], 'group2': ['b'] } + */ + var invertBy = createInverter(function(result, value, key) { + if (value != null && + typeof value.toString != 'function') { + value = nativeObjectToString.call(value); + } + + if (hasOwnProperty.call(result, value)) { + result[value].push(key); + } else { + result[value] = [key]; + } + }, getIteratee); + + /** + * Invokes the method at `path` of `object`. + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Object + * @param {Object} object The object to query. + * @param {Array|string} path The path of the method to invoke. + * @param {...*} [args] The arguments to invoke the method with. + * @returns {*} Returns the result of the invoked method. + * @example + * + * var object = { 'a': [{ 'b': { 'c': [1, 2, 3, 4] } }] }; + * + * _.invoke(object, 'a[0].b.c.slice', 1, 3); + * // => [2, 3] + */ + var invoke = baseRest(baseInvoke); + + /** + * Creates an array of the own enumerable property names of `object`. + * + * **Note:** Non-object values are coerced to objects. See the + * [ES spec](http://ecma-international.org/ecma-262/7.0/#sec-object.keys) + * for more details. + * + * @static + * @since 0.1.0 + * @memberOf _ + * @category Object + * @param {Object} object The object to query. + * @returns {Array} Returns the array of property names. + * @example + * + * function Foo() { + * this.a = 1; + * this.b = 2; + * } + * + * Foo.prototype.c = 3; + * + * _.keys(new Foo); + * // => ['a', 'b'] (iteration order is not guaranteed) + * + * _.keys('hi'); + * // => ['0', '1'] + */ + function keys(object) { + return isArrayLike(object) ? 
arrayLikeKeys(object) : baseKeys(object); + } + + /** + * Creates an array of the own and inherited enumerable property names of `object`. + * + * **Note:** Non-object values are coerced to objects. + * + * @static + * @memberOf _ + * @since 3.0.0 + * @category Object + * @param {Object} object The object to query. + * @returns {Array} Returns the array of property names. + * @example + * + * function Foo() { + * this.a = 1; + * this.b = 2; + * } + * + * Foo.prototype.c = 3; + * + * _.keysIn(new Foo); + * // => ['a', 'b', 'c'] (iteration order is not guaranteed) + */ + function keysIn(object) { + return isArrayLike(object) ? arrayLikeKeys(object, true) : baseKeysIn(object); + } + + /** + * The opposite of `_.mapValues`; this method creates an object with the + * same values as `object` and keys generated by running each own enumerable + * string keyed property of `object` thru `iteratee`. The iteratee is invoked + * with three arguments: (value, key, object). + * + * @static + * @memberOf _ + * @since 3.8.0 + * @category Object + * @param {Object} object The object to iterate over. + * @param {Function} [iteratee=_.identity] The function invoked per iteration. + * @returns {Object} Returns the new mapped object. + * @see _.mapValues + * @example + * + * _.mapKeys({ 'a': 1, 'b': 2 }, function(value, key) { + * return key + value; + * }); + * // => { 'a1': 1, 'b2': 2 } + */ + function mapKeys(object, iteratee) { + var result = {}; + iteratee = getIteratee(iteratee, 3); + + baseForOwn(object, function(value, key, object) { + baseAssignValue(result, iteratee(value, key, object), value); + }); + return result; + } + + /** + * Creates an object with the same keys as `object` and values generated + * by running each own enumerable string keyed property of `object` thru + * `iteratee`. The iteratee is invoked with three arguments: + * (value, key, object). + * + * @static + * @memberOf _ + * @since 2.4.0 + * @category Object + * @param {Object} object The object to iterate over. + * @param {Function} [iteratee=_.identity] The function invoked per iteration. + * @returns {Object} Returns the new mapped object. + * @see _.mapKeys + * @example + * + * var users = { + * 'fred': { 'user': 'fred', 'age': 40 }, + * 'pebbles': { 'user': 'pebbles', 'age': 1 } + * }; + * + * _.mapValues(users, function(o) { return o.age; }); + * // => { 'fred': 40, 'pebbles': 1 } (iteration order is not guaranteed) + * + * // The `_.property` iteratee shorthand. + * _.mapValues(users, 'age'); + * // => { 'fred': 40, 'pebbles': 1 } (iteration order is not guaranteed) + */ + function mapValues(object, iteratee) { + var result = {}; + iteratee = getIteratee(iteratee, 3); + + baseForOwn(object, function(value, key, object) { + baseAssignValue(result, key, iteratee(value, key, object)); + }); + return result; + } + + /** + * This method is like `_.assign` except that it recursively merges own and + * inherited enumerable string keyed properties of source objects into the + * destination object. Source properties that resolve to `undefined` are + * skipped if a destination value exists. Array and plain object properties + * are merged recursively. Other objects and value types are overridden by + * assignment. Source objects are applied from left to right. Subsequent + * sources overwrite property assignments of previous sources. + * + * **Note:** This method mutates `object`. + * + * @static + * @memberOf _ + * @since 0.5.0 + * @category Object + * @param {Object} object The destination object. 
+ * @param {...Object} [sources] The source objects. + * @returns {Object} Returns `object`. + * @example + * + * var object = { + * 'a': [{ 'b': 2 }, { 'd': 4 }] + * }; + * + * var other = { + * 'a': [{ 'c': 3 }, { 'e': 5 }] + * }; + * + * _.merge(object, other); + * // => { 'a': [{ 'b': 2, 'c': 3 }, { 'd': 4, 'e': 5 }] } + */ + var merge = createAssigner(function(object, source, srcIndex) { + baseMerge(object, source, srcIndex); + }); + + /** + * This method is like `_.merge` except that it accepts `customizer` which + * is invoked to produce the merged values of the destination and source + * properties. If `customizer` returns `undefined`, merging is handled by the + * method instead. The `customizer` is invoked with six arguments: + * (objValue, srcValue, key, object, source, stack). + * + * **Note:** This method mutates `object`. + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Object + * @param {Object} object The destination object. + * @param {...Object} sources The source objects. + * @param {Function} customizer The function to customize assigned values. + * @returns {Object} Returns `object`. + * @example + * + * function customizer(objValue, srcValue) { + * if (_.isArray(objValue)) { + * return objValue.concat(srcValue); + * } + * } + * + * var object = { 'a': [1], 'b': [2] }; + * var other = { 'a': [3], 'b': [4] }; + * + * _.mergeWith(object, other, customizer); + * // => { 'a': [1, 3], 'b': [2, 4] } + */ + var mergeWith = createAssigner(function(object, source, srcIndex, customizer) { + baseMerge(object, source, srcIndex, customizer); + }); + + /** + * The opposite of `_.pick`; this method creates an object composed of the + * own and inherited enumerable property paths of `object` that are not omitted. + * + * **Note:** This method is considerably slower than `_.pick`. + * + * @static + * @since 0.1.0 + * @memberOf _ + * @category Object + * @param {Object} object The source object. + * @param {...(string|string[])} [paths] The property paths to omit. + * @returns {Object} Returns the new object. + * @example + * + * var object = { 'a': 1, 'b': '2', 'c': 3 }; + * + * _.omit(object, ['a', 'c']); + * // => { 'b': '2' } + */ + var omit = flatRest(function(object, paths) { + var result = {}; + if (object == null) { + return result; + } + var isDeep = false; + paths = arrayMap(paths, function(path) { + path = castPath(path, object); + isDeep || (isDeep = path.length > 1); + return path; + }); + copyObject(object, getAllKeysIn(object), result); + if (isDeep) { + result = baseClone(result, CLONE_DEEP_FLAG | CLONE_FLAT_FLAG | CLONE_SYMBOLS_FLAG, customOmitClone); + } + var length = paths.length; + while (length--) { + baseUnset(result, paths[length]); + } + return result; + }); + + /** + * The opposite of `_.pickBy`; this method creates an object composed of + * the own and inherited enumerable string keyed properties of `object` that + * `predicate` doesn't return truthy for. The predicate is invoked with two + * arguments: (value, key). + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Object + * @param {Object} object The source object. + * @param {Function} [predicate=_.identity] The function invoked per property. + * @returns {Object} Returns the new object. 
+ * @example + * + * var object = { 'a': 1, 'b': '2', 'c': 3 }; + * + * _.omitBy(object, _.isNumber); + * // => { 'b': '2' } + */ + function omitBy(object, predicate) { + return pickBy(object, negate(getIteratee(predicate))); + } + + /** + * Creates an object composed of the picked `object` properties. + * + * @static + * @since 0.1.0 + * @memberOf _ + * @category Object + * @param {Object} object The source object. + * @param {...(string|string[])} [paths] The property paths to pick. + * @returns {Object} Returns the new object. + * @example + * + * var object = { 'a': 1, 'b': '2', 'c': 3 }; + * + * _.pick(object, ['a', 'c']); + * // => { 'a': 1, 'c': 3 } + */ + var pick = flatRest(function(object, paths) { + return object == null ? {} : basePick(object, paths); + }); + + /** + * Creates an object composed of the `object` properties `predicate` returns + * truthy for. The predicate is invoked with two arguments: (value, key). + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Object + * @param {Object} object The source object. + * @param {Function} [predicate=_.identity] The function invoked per property. + * @returns {Object} Returns the new object. + * @example + * + * var object = { 'a': 1, 'b': '2', 'c': 3 }; + * + * _.pickBy(object, _.isNumber); + * // => { 'a': 1, 'c': 3 } + */ + function pickBy(object, predicate) { + if (object == null) { + return {}; + } + var props = arrayMap(getAllKeysIn(object), function(prop) { + return [prop]; + }); + predicate = getIteratee(predicate); + return basePickBy(object, props, function(value, path) { + return predicate(value, path[0]); + }); + } + + /** + * This method is like `_.get` except that if the resolved value is a + * function it's invoked with the `this` binding of its parent object and + * its result is returned. + * + * @static + * @since 0.1.0 + * @memberOf _ + * @category Object + * @param {Object} object The object to query. + * @param {Array|string} path The path of the property to resolve. + * @param {*} [defaultValue] The value returned for `undefined` resolved values. + * @returns {*} Returns the resolved value. + * @example + * + * var object = { 'a': [{ 'b': { 'c1': 3, 'c2': _.constant(4) } }] }; + * + * _.result(object, 'a[0].b.c1'); + * // => 3 + * + * _.result(object, 'a[0].b.c2'); + * // => 4 + * + * _.result(object, 'a[0].b.c3', 'default'); + * // => 'default' + * + * _.result(object, 'a[0].b.c3', _.constant('default')); + * // => 'default' + */ + function result(object, path, defaultValue) { + path = castPath(path, object); + + var index = -1, + length = path.length; + + // Ensure the loop is entered when path is empty. + if (!length) { + length = 1; + object = undefined; + } + while (++index < length) { + var value = object == null ? undefined : object[toKey(path[index])]; + if (value === undefined) { + index = length; + value = defaultValue; + } + object = isFunction(value) ? value.call(object) : value; + } + return object; + } + + /** + * Sets the value at `path` of `object`. If a portion of `path` doesn't exist, + * it's created. Arrays are created for missing index properties while objects + * are created for all other missing properties. Use `_.setWith` to customize + * `path` creation. + * + * **Note:** This method mutates `object`. + * + * @static + * @memberOf _ + * @since 3.7.0 + * @category Object + * @param {Object} object The object to modify. + * @param {Array|string} path The path of the property to set. + * @param {*} value The value to set. + * @returns {Object} Returns `object`. 
+ * @example + * + * var object = { 'a': [{ 'b': { 'c': 3 } }] }; + * + * _.set(object, 'a[0].b.c', 4); + * console.log(object.a[0].b.c); + * // => 4 + * + * _.set(object, ['x', '0', 'y', 'z'], 5); + * console.log(object.x[0].y.z); + * // => 5 + */ + function set(object, path, value) { + return object == null ? object : baseSet(object, path, value); + } + + /** + * This method is like `_.set` except that it accepts `customizer` which is + * invoked to produce the objects of `path`. If `customizer` returns `undefined` + * path creation is handled by the method instead. The `customizer` is invoked + * with three arguments: (nsValue, key, nsObject). + * + * **Note:** This method mutates `object`. + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Object + * @param {Object} object The object to modify. + * @param {Array|string} path The path of the property to set. + * @param {*} value The value to set. + * @param {Function} [customizer] The function to customize assigned values. + * @returns {Object} Returns `object`. + * @example + * + * var object = {}; + * + * _.setWith(object, '[0][1]', 'a', Object); + * // => { '0': { '1': 'a' } } + */ + function setWith(object, path, value, customizer) { + customizer = typeof customizer == 'function' ? customizer : undefined; + return object == null ? object : baseSet(object, path, value, customizer); + } + + /** + * Creates an array of own enumerable string keyed-value pairs for `object` + * which can be consumed by `_.fromPairs`. If `object` is a map or set, its + * entries are returned. + * + * @static + * @memberOf _ + * @since 4.0.0 + * @alias entries + * @category Object + * @param {Object} object The object to query. + * @returns {Array} Returns the key-value pairs. + * @example + * + * function Foo() { + * this.a = 1; + * this.b = 2; + * } + * + * Foo.prototype.c = 3; + * + * _.toPairs(new Foo); + * // => [['a', 1], ['b', 2]] (iteration order is not guaranteed) + */ + var toPairs = createToPairs(keys); + + /** + * Creates an array of own and inherited enumerable string keyed-value pairs + * for `object` which can be consumed by `_.fromPairs`. If `object` is a map + * or set, its entries are returned. + * + * @static + * @memberOf _ + * @since 4.0.0 + * @alias entriesIn + * @category Object + * @param {Object} object The object to query. + * @returns {Array} Returns the key-value pairs. + * @example + * + * function Foo() { + * this.a = 1; + * this.b = 2; + * } + * + * Foo.prototype.c = 3; + * + * _.toPairsIn(new Foo); + * // => [['a', 1], ['b', 2], ['c', 3]] (iteration order is not guaranteed) + */ + var toPairsIn = createToPairs(keysIn); + + /** + * An alternative to `_.reduce`; this method transforms `object` to a new + * `accumulator` object which is the result of running each of its own + * enumerable string keyed properties thru `iteratee`, with each invocation + * potentially mutating the `accumulator` object. If `accumulator` is not + * provided, a new object with the same `[[Prototype]]` will be used. The + * iteratee is invoked with four arguments: (accumulator, value, key, object). + * Iteratee functions may exit iteration early by explicitly returning `false`. + * + * @static + * @memberOf _ + * @since 1.3.0 + * @category Object + * @param {Object} object The object to iterate over. + * @param {Function} [iteratee=_.identity] The function invoked per iteration. + * @param {*} [accumulator] The custom accumulator value. + * @returns {*} Returns the accumulated value. 
+ * @example + * + * _.transform([2, 3, 4], function(result, n) { + * result.push(n *= n); + * return n % 2 == 0; + * }, []); + * // => [4, 9] + * + * _.transform({ 'a': 1, 'b': 2, 'c': 1 }, function(result, value, key) { + * (result[value] || (result[value] = [])).push(key); + * }, {}); + * // => { '1': ['a', 'c'], '2': ['b'] } + */ + function transform(object, iteratee, accumulator) { + var isArr = isArray(object), + isArrLike = isArr || isBuffer(object) || isTypedArray(object); + + iteratee = getIteratee(iteratee, 4); + if (accumulator == null) { + var Ctor = object && object.constructor; + if (isArrLike) { + accumulator = isArr ? new Ctor : []; + } + else if (isObject(object)) { + accumulator = isFunction(Ctor) ? baseCreate(getPrototype(object)) : {}; + } + else { + accumulator = {}; + } + } + (isArrLike ? arrayEach : baseForOwn)(object, function(value, index, object) { + return iteratee(accumulator, value, index, object); + }); + return accumulator; + } + + /** + * Removes the property at `path` of `object`. + * + * **Note:** This method mutates `object`. + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Object + * @param {Object} object The object to modify. + * @param {Array|string} path The path of the property to unset. + * @returns {boolean} Returns `true` if the property is deleted, else `false`. + * @example + * + * var object = { 'a': [{ 'b': { 'c': 7 } }] }; + * _.unset(object, 'a[0].b.c'); + * // => true + * + * console.log(object); + * // => { 'a': [{ 'b': {} }] }; + * + * _.unset(object, ['a', '0', 'b', 'c']); + * // => true + * + * console.log(object); + * // => { 'a': [{ 'b': {} }] }; + */ + function unset(object, path) { + return object == null ? true : baseUnset(object, path); + } + + /** + * This method is like `_.set` except that accepts `updater` to produce the + * value to set. Use `_.updateWith` to customize `path` creation. The `updater` + * is invoked with one argument: (value). + * + * **Note:** This method mutates `object`. + * + * @static + * @memberOf _ + * @since 4.6.0 + * @category Object + * @param {Object} object The object to modify. + * @param {Array|string} path The path of the property to set. + * @param {Function} updater The function to produce the updated value. + * @returns {Object} Returns `object`. + * @example + * + * var object = { 'a': [{ 'b': { 'c': 3 } }] }; + * + * _.update(object, 'a[0].b.c', function(n) { return n * n; }); + * console.log(object.a[0].b.c); + * // => 9 + * + * _.update(object, 'x[0].y.z', function(n) { return n ? n + 1 : 0; }); + * console.log(object.x[0].y.z); + * // => 0 + */ + function update(object, path, updater) { + return object == null ? object : baseUpdate(object, path, castFunction(updater)); + } + + /** + * This method is like `_.update` except that it accepts `customizer` which is + * invoked to produce the objects of `path`. If `customizer` returns `undefined` + * path creation is handled by the method instead. The `customizer` is invoked + * with three arguments: (nsValue, key, nsObject). + * + * **Note:** This method mutates `object`. + * + * @static + * @memberOf _ + * @since 4.6.0 + * @category Object + * @param {Object} object The object to modify. + * @param {Array|string} path The path of the property to set. + * @param {Function} updater The function to produce the updated value. + * @param {Function} [customizer] The function to customize assigned values. + * @returns {Object} Returns `object`. 
+ * @example + * + * var object = {}; + * + * _.updateWith(object, '[0][1]', _.constant('a'), Object); + * // => { '0': { '1': 'a' } } + */ + function updateWith(object, path, updater, customizer) { + customizer = typeof customizer == 'function' ? customizer : undefined; + return object == null ? object : baseUpdate(object, path, castFunction(updater), customizer); + } + + /** + * Creates an array of the own enumerable string keyed property values of `object`. + * + * **Note:** Non-object values are coerced to objects. + * + * @static + * @since 0.1.0 + * @memberOf _ + * @category Object + * @param {Object} object The object to query. + * @returns {Array} Returns the array of property values. + * @example + * + * function Foo() { + * this.a = 1; + * this.b = 2; + * } + * + * Foo.prototype.c = 3; + * + * _.values(new Foo); + * // => [1, 2] (iteration order is not guaranteed) + * + * _.values('hi'); + * // => ['h', 'i'] + */ + function values(object) { + return object == null ? [] : baseValues(object, keys(object)); + } + + /** + * Creates an array of the own and inherited enumerable string keyed property + * values of `object`. + * + * **Note:** Non-object values are coerced to objects. + * + * @static + * @memberOf _ + * @since 3.0.0 + * @category Object + * @param {Object} object The object to query. + * @returns {Array} Returns the array of property values. + * @example + * + * function Foo() { + * this.a = 1; + * this.b = 2; + * } + * + * Foo.prototype.c = 3; + * + * _.valuesIn(new Foo); + * // => [1, 2, 3] (iteration order is not guaranteed) + */ + function valuesIn(object) { + return object == null ? [] : baseValues(object, keysIn(object)); + } + + /*------------------------------------------------------------------------*/ + + /** + * Clamps `number` within the inclusive `lower` and `upper` bounds. + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category Number + * @param {number} number The number to clamp. + * @param {number} [lower] The lower bound. + * @param {number} upper The upper bound. + * @returns {number} Returns the clamped number. + * @example + * + * _.clamp(-10, -5, 5); + * // => -5 + * + * _.clamp(10, -5, 5); + * // => 5 + */ + function clamp(number, lower, upper) { + if (upper === undefined) { + upper = lower; + lower = undefined; + } + if (upper !== undefined) { + upper = toNumber(upper); + upper = upper === upper ? upper : 0; + } + if (lower !== undefined) { + lower = toNumber(lower); + lower = lower === lower ? lower : 0; + } + return baseClamp(toNumber(number), lower, upper); + } + + /** + * Checks if `n` is between `start` and up to, but not including, `end`. If + * `end` is not specified, it's set to `start` with `start` then set to `0`. + * If `start` is greater than `end` the params are swapped to support + * negative ranges. + * + * @static + * @memberOf _ + * @since 3.3.0 + * @category Number + * @param {number} number The number to check. + * @param {number} [start=0] The start of the range. + * @param {number} end The end of the range. + * @returns {boolean} Returns `true` if `number` is in the range, else `false`. 
+ * @see _.range, _.rangeRight + * @example + * + * _.inRange(3, 2, 4); + * // => true + * + * _.inRange(4, 8); + * // => true + * + * _.inRange(4, 2); + * // => false + * + * _.inRange(2, 2); + * // => false + * + * _.inRange(1.2, 2); + * // => true + * + * _.inRange(5.2, 4); + * // => false + * + * _.inRange(-3, -2, -6); + * // => true + */ + function inRange(number, start, end) { + start = toFinite(start); + if (end === undefined) { + end = start; + start = 0; + } else { + end = toFinite(end); + } + number = toNumber(number); + return baseInRange(number, start, end); + } + + /** + * Produces a random number between the inclusive `lower` and `upper` bounds. + * If only one argument is provided a number between `0` and the given number + * is returned. If `floating` is `true`, or either `lower` or `upper` are + * floats, a floating-point number is returned instead of an integer. + * + * **Note:** JavaScript follows the IEEE-754 standard for resolving + * floating-point values which can produce unexpected results. + * + * @static + * @memberOf _ + * @since 0.7.0 + * @category Number + * @param {number} [lower=0] The lower bound. + * @param {number} [upper=1] The upper bound. + * @param {boolean} [floating] Specify returning a floating-point number. + * @returns {number} Returns the random number. + * @example + * + * _.random(0, 5); + * // => an integer between 0 and 5 + * + * _.random(5); + * // => also an integer between 0 and 5 + * + * _.random(5, true); + * // => a floating-point number between 0 and 5 + * + * _.random(1.2, 5.2); + * // => a floating-point number between 1.2 and 5.2 + */ + function random(lower, upper, floating) { + if (floating && typeof floating != 'boolean' && isIterateeCall(lower, upper, floating)) { + upper = floating = undefined; + } + if (floating === undefined) { + if (typeof upper == 'boolean') { + floating = upper; + upper = undefined; + } + else if (typeof lower == 'boolean') { + floating = lower; + lower = undefined; + } + } + if (lower === undefined && upper === undefined) { + lower = 0; + upper = 1; + } + else { + lower = toFinite(lower); + if (upper === undefined) { + upper = lower; + lower = 0; + } else { + upper = toFinite(upper); + } + } + if (lower > upper) { + var temp = lower; + lower = upper; + upper = temp; + } + if (floating || lower % 1 || upper % 1) { + var rand = nativeRandom(); + return nativeMin(lower + (rand * (upper - lower + freeParseFloat('1e-' + ((rand + '').length - 1)))), upper); + } + return baseRandom(lower, upper); + } + + /*------------------------------------------------------------------------*/ + + /** + * Converts `string` to [camel case](https://en.wikipedia.org/wiki/CamelCase). + * + * @static + * @memberOf _ + * @since 3.0.0 + * @category String + * @param {string} [string=''] The string to convert. + * @returns {string} Returns the camel cased string. + * @example + * + * _.camelCase('Foo Bar'); + * // => 'fooBar' + * + * _.camelCase('--foo-bar--'); + * // => 'fooBar' + * + * _.camelCase('__FOO_BAR__'); + * // => 'fooBar' + */ + var camelCase = createCompounder(function(result, word, index) { + word = word.toLowerCase(); + return result + (index ? capitalize(word) : word); + }); + + /** + * Converts the first character of `string` to upper case and the remaining + * to lower case. + * + * @static + * @memberOf _ + * @since 3.0.0 + * @category String + * @param {string} [string=''] The string to capitalize. + * @returns {string} Returns the capitalized string. 
+ * @example + * + * _.capitalize('FRED'); + * // => 'Fred' + */ + function capitalize(string) { + return upperFirst(toString(string).toLowerCase()); + } + + /** + * Deburrs `string` by converting + * [Latin-1 Supplement](https://en.wikipedia.org/wiki/Latin-1_Supplement_(Unicode_block)#Character_table) + * and [Latin Extended-A](https://en.wikipedia.org/wiki/Latin_Extended-A) + * letters to basic Latin letters and removing + * [combining diacritical marks](https://en.wikipedia.org/wiki/Combining_Diacritical_Marks). + * + * @static + * @memberOf _ + * @since 3.0.0 + * @category String + * @param {string} [string=''] The string to deburr. + * @returns {string} Returns the deburred string. + * @example + * + * _.deburr('déjà vu'); + * // => 'deja vu' + */ + function deburr(string) { + string = toString(string); + return string && string.replace(reLatin, deburrLetter).replace(reComboMark, ''); + } + + /** + * Checks if `string` ends with the given target string. + * + * @static + * @memberOf _ + * @since 3.0.0 + * @category String + * @param {string} [string=''] The string to inspect. + * @param {string} [target] The string to search for. + * @param {number} [position=string.length] The position to search up to. + * @returns {boolean} Returns `true` if `string` ends with `target`, + * else `false`. + * @example + * + * _.endsWith('abc', 'c'); + * // => true + * + * _.endsWith('abc', 'b'); + * // => false + * + * _.endsWith('abc', 'b', 2); + * // => true + */ + function endsWith(string, target, position) { + string = toString(string); + target = baseToString(target); + + var length = string.length; + position = position === undefined + ? length + : baseClamp(toInteger(position), 0, length); + + var end = position; + position -= target.length; + return position >= 0 && string.slice(position, end) == target; + } + + /** + * Converts the characters "&", "<", ">", '"', and "'" in `string` to their + * corresponding HTML entities. + * + * **Note:** No other characters are escaped. To escape additional + * characters use a third-party library like [_he_](https://mths.be/he). + * + * Though the ">" character is escaped for symmetry, characters like + * ">" and "/" don't need escaping in HTML and have no special meaning + * unless they're part of a tag or unquoted attribute value. See + * [Mathias Bynens's article](https://mathiasbynens.be/notes/ambiguous-ampersands) + * (under "semi-related fun fact") for more details. + * + * When working with HTML you should always + * [quote attribute values](http://wonko.com/post/html-escaping) to reduce + * XSS vectors. + * + * @static + * @since 0.1.0 + * @memberOf _ + * @category String + * @param {string} [string=''] The string to escape. + * @returns {string} Returns the escaped string. + * @example + * + * _.escape('fred, barney, & pebbles'); + * // => 'fred, barney, & pebbles' + */ + function escape(string) { + string = toString(string); + return (string && reHasUnescapedHtml.test(string)) + ? string.replace(reUnescapedHtml, escapeHtmlChar) + : string; + } + + /** + * Escapes the `RegExp` special characters "^", "$", "\", ".", "*", "+", + * "?", "(", ")", "[", "]", "{", "}", and "|" in `string`. + * + * @static + * @memberOf _ + * @since 3.0.0 + * @category String + * @param {string} [string=''] The string to escape. + * @returns {string} Returns the escaped string. 
+ * @example + * + * _.escapeRegExp('[lodash](https://lodash.com/)'); + * // => '\[lodash\]\(https://lodash\.com/\)' + */ + function escapeRegExp(string) { + string = toString(string); + return (string && reHasRegExpChar.test(string)) + ? string.replace(reRegExpChar, '\\$&') + : string; + } + + /** + * Converts `string` to + * [kebab case](https://en.wikipedia.org/wiki/Letter_case#Special_case_styles). + * + * @static + * @memberOf _ + * @since 3.0.0 + * @category String + * @param {string} [string=''] The string to convert. + * @returns {string} Returns the kebab cased string. + * @example + * + * _.kebabCase('Foo Bar'); + * // => 'foo-bar' + * + * _.kebabCase('fooBar'); + * // => 'foo-bar' + * + * _.kebabCase('__FOO_BAR__'); + * // => 'foo-bar' + */ + var kebabCase = createCompounder(function(result, word, index) { + return result + (index ? '-' : '') + word.toLowerCase(); + }); + + /** + * Converts `string`, as space separated words, to lower case. + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category String + * @param {string} [string=''] The string to convert. + * @returns {string} Returns the lower cased string. + * @example + * + * _.lowerCase('--Foo-Bar--'); + * // => 'foo bar' + * + * _.lowerCase('fooBar'); + * // => 'foo bar' + * + * _.lowerCase('__FOO_BAR__'); + * // => 'foo bar' + */ + var lowerCase = createCompounder(function(result, word, index) { + return result + (index ? ' ' : '') + word.toLowerCase(); + }); + + /** + * Converts the first character of `string` to lower case. + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category String + * @param {string} [string=''] The string to convert. + * @returns {string} Returns the converted string. + * @example + * + * _.lowerFirst('Fred'); + * // => 'fred' + * + * _.lowerFirst('FRED'); + * // => 'fRED' + */ + var lowerFirst = createCaseFirst('toLowerCase'); + + /** + * Pads `string` on the left and right sides if it's shorter than `length`. + * Padding characters are truncated if they can't be evenly divided by `length`. + * + * @static + * @memberOf _ + * @since 3.0.0 + * @category String + * @param {string} [string=''] The string to pad. + * @param {number} [length=0] The padding length. + * @param {string} [chars=' '] The string used as padding. + * @returns {string} Returns the padded string. + * @example + * + * _.pad('abc', 8); + * // => ' abc ' + * + * _.pad('abc', 8, '_-'); + * // => '_-abc_-_' + * + * _.pad('abc', 3); + * // => 'abc' + */ + function pad(string, length, chars) { + string = toString(string); + length = toInteger(length); + + var strLength = length ? stringSize(string) : 0; + if (!length || strLength >= length) { + return string; + } + var mid = (length - strLength) / 2; + return ( + createPadding(nativeFloor(mid), chars) + + string + + createPadding(nativeCeil(mid), chars) + ); + } + + /** + * Pads `string` on the right side if it's shorter than `length`. Padding + * characters are truncated if they exceed `length`. + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category String + * @param {string} [string=''] The string to pad. + * @param {number} [length=0] The padding length. + * @param {string} [chars=' '] The string used as padding. + * @returns {string} Returns the padded string. + * @example + * + * _.padEnd('abc', 6); + * // => 'abc ' + * + * _.padEnd('abc', 6, '_-'); + * // => 'abc_-_' + * + * _.padEnd('abc', 3); + * // => 'abc' + */ + function padEnd(string, length, chars) { + string = toString(string); + length = toInteger(length); + + var strLength = length ? 
stringSize(string) : 0; + return (length && strLength < length) + ? (string + createPadding(length - strLength, chars)) + : string; + } + + /** + * Pads `string` on the left side if it's shorter than `length`. Padding + * characters are truncated if they exceed `length`. + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category String + * @param {string} [string=''] The string to pad. + * @param {number} [length=0] The padding length. + * @param {string} [chars=' '] The string used as padding. + * @returns {string} Returns the padded string. + * @example + * + * _.padStart('abc', 6); + * // => ' abc' + * + * _.padStart('abc', 6, '_-'); + * // => '_-_abc' + * + * _.padStart('abc', 3); + * // => 'abc' + */ + function padStart(string, length, chars) { + string = toString(string); + length = toInteger(length); + + var strLength = length ? stringSize(string) : 0; + return (length && strLength < length) + ? (createPadding(length - strLength, chars) + string) + : string; + } + + /** + * Converts `string` to an integer of the specified radix. If `radix` is + * `undefined` or `0`, a `radix` of `10` is used unless `value` is a + * hexadecimal, in which case a `radix` of `16` is used. + * + * **Note:** This method aligns with the + * [ES5 implementation](https://es5.github.io/#x15.1.2.2) of `parseInt`. + * + * @static + * @memberOf _ + * @since 1.1.0 + * @category String + * @param {string} string The string to convert. + * @param {number} [radix=10] The radix to interpret `value` by. + * @param- {Object} [guard] Enables use as an iteratee for methods like `_.map`. + * @returns {number} Returns the converted integer. + * @example + * + * _.parseInt('08'); + * // => 8 + * + * _.map(['6', '08', '10'], _.parseInt); + * // => [6, 8, 10] + */ + function parseInt(string, radix, guard) { + if (guard || radix == null) { + radix = 0; + } else if (radix) { + radix = +radix; + } + return nativeParseInt(toString(string).replace(reTrimStart, ''), radix || 0); + } + + /** + * Repeats the given string `n` times. + * + * @static + * @memberOf _ + * @since 3.0.0 + * @category String + * @param {string} [string=''] The string to repeat. + * @param {number} [n=1] The number of times to repeat the string. + * @param- {Object} [guard] Enables use as an iteratee for methods like `_.map`. + * @returns {string} Returns the repeated string. + * @example + * + * _.repeat('*', 3); + * // => '***' + * + * _.repeat('abc', 2); + * // => 'abcabc' + * + * _.repeat('abc', 0); + * // => '' + */ + function repeat(string, n, guard) { + if ((guard ? isIterateeCall(string, n, guard) : n === undefined)) { + n = 1; + } else { + n = toInteger(n); + } + return baseRepeat(toString(string), n); + } + + /** + * Replaces matches for `pattern` in `string` with `replacement`. + * + * **Note:** This method is based on + * [`String#replace`](https://mdn.io/String/replace). + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category String + * @param {string} [string=''] The string to modify. + * @param {RegExp|string} pattern The pattern to replace. + * @param {Function|string} replacement The match replacement. + * @returns {string} Returns the modified string. + * @example + * + * _.replace('Hi Fred', 'Fred', 'Barney'); + * // => 'Hi Barney' + */ + function replace() { + var args = arguments, + string = toString(args[0]); + + return args.length < 3 ? string : string.replace(args[1], args[2]); + } + + /** + * Converts `string` to + * [snake case](https://en.wikipedia.org/wiki/Snake_case). 
+ * + * @static + * @memberOf _ + * @since 3.0.0 + * @category String + * @param {string} [string=''] The string to convert. + * @returns {string} Returns the snake cased string. + * @example + * + * _.snakeCase('Foo Bar'); + * // => 'foo_bar' + * + * _.snakeCase('fooBar'); + * // => 'foo_bar' + * + * _.snakeCase('--FOO-BAR--'); + * // => 'foo_bar' + */ + var snakeCase = createCompounder(function(result, word, index) { + return result + (index ? '_' : '') + word.toLowerCase(); + }); + + /** + * Splits `string` by `separator`. + * + * **Note:** This method is based on + * [`String#split`](https://mdn.io/String/split). + * + * @static + * @memberOf _ + * @since 4.0.0 + * @category String + * @param {string} [string=''] The string to split. + * @param {RegExp|string} separator The separator pattern to split by. + * @param {number} [limit] The length to truncate results to. + * @returns {Array} Returns the string segments. + * @example + * + * _.split('a-b-c', '-', 2); + * // => ['a', 'b'] + */ + function split(string, separator, limit) { + if (limit && typeof limit != 'number' && isIterateeCall(string, separator, limit)) { + separator = limit = undefined; + } + limit = limit === undefined ? MAX_ARRAY_LENGTH : limit >>> 0; + if (!limit) { + return []; + } + string = toString(string); + if (string && ( + typeof separator == 'string' || + (separator != null && !isRegExp(separator)) + )) { + separator = baseToString(separator); + if (!separator && hasUnicode(string)) { + return castSlice(stringToArray(string), 0, limit); + } + } + return string.split(separator, limit); + } + + /** + * Converts `string` to + * [start case](https://en.wikipedia.org/wiki/Letter_case#Stylistic_or_specialised_usage). + * + * @static + * @memberOf _ + * @since 3.1.0 + * @category String + * @param {string} [string=''] The string to convert. + * @returns {string} Returns the start cased string. + * @example + * + * _.startCase('--foo-bar--'); + * // => 'Foo Bar' + * + * _.startCase('fooBar'); + * // => 'Foo Bar' + * + * _.startCase('__FOO_BAR__'); + * // => 'FOO BAR' + */ + var startCase = createCompounder(function(result, word, index) { + return result + (index ? ' ' : '') + upperFirst(word); + }); + + /** + * Checks if `string` starts with the given target string. + * + * @static + * @memberOf _ + * @since 3.0.0 + * @category String + * @param {string} [string=''] The string to inspect. + * @param {string} [target] The string to search for. + * @param {number} [position=0] The position to search from. + * @returns {boolean} Returns `true` if `string` starts with `target`, + * else `false`. + * @example + * + * _.startsWith('abc', 'a'); + * // => true + * + * _.startsWith('abc', 'b'); + * // => false + * + * _.startsWith('abc', 'b', 1); + * // => true + */ + function startsWith(string, target, position) { + string = toString(string); + position = position == null + ? 0 + : baseClamp(toInteger(position), 0, string.length); + + target = baseToString(target); + return string.slice(position, position + target.length) == target; + } + + /** + * Creates a compiled template function that can interpolate data properties + * in "interpolate" delimiters, HTML-escape interpolated data properties in + * "escape" delimiters, and execute JavaScript in "evaluate" delimiters. Data + * properties may be accessed as free variables in the template. If a setting + * object is given, it takes precedence over `_.templateSettings` values. 
+ * + * **Note:** In the development build `_.template` utilizes + * [sourceURLs](http://www.html5rocks.com/en/tutorials/developertools/sourcemaps/#toc-sourceurl) + * for easier debugging. + * + * For more information on precompiling templates see + * [lodash's custom builds documentation](https://lodash.com/custom-builds). + * + * For more information on Chrome extension sandboxes see + * [Chrome's extensions documentation](https://developer.chrome.com/extensions/sandboxingEval). + * + * @static + * @since 0.1.0 + * @memberOf _ + * @category String + * @param {string} [string=''] The template string. + * @param {Object} [options={}] The options object. + * @param {RegExp} [options.escape=_.templateSettings.escape] + * The HTML "escape" delimiter. + * @param {RegExp} [options.evaluate=_.templateSettings.evaluate] + * The "evaluate" delimiter. + * @param {Object} [options.imports=_.templateSettings.imports] + * An object to import into the template as free variables. + * @param {RegExp} [options.interpolate=_.templateSettings.interpolate] + * The "interpolate" delimiter. + * @param {string} [options.sourceURL='lodash.templateSources[n]'] + * The sourceURL of the compiled template. + * @param {string} [options.variable='obj'] + * The data object variable name. + * @param- {Object} [guard] Enables use as an iteratee for methods like `_.map`. + * @returns {Function} Returns the compiled template function. + * @example + * + * // Use the "interpolate" delimiter to create a compiled template. + * var compiled = _.template('hello <%= user %>!'); + * compiled({ 'user': 'fred' }); + * // => 'hello fred!' + * + * // Use the HTML "escape" delimiter to escape data property values. + * var compiled = _.template('<%- value %>'); + * compiled({ 'value': '''' % bundle + random_id = id_generator(size=15, random_state=check_random_state(self.random_state)) + out += u''' +
+ ''' % random_id + + predict_proba_js = '' + if self.mode == "classification" and predict_proba: + predict_proba_js = u''' + var pp_div = top_div.append('div') + .classed('lime predict_proba', true); + var pp_svg = pp_div.append('svg').style('width', '100%%'); + var pp = new lime.PredictProba(pp_svg, %s, %s); + ''' % (jsonize([str(x) for x in self.class_names]), + jsonize(list(self.predict_proba.astype(float)))) + + predict_value_js = '' + if self.mode == "regression" and show_predicted_value: + # reference self.predicted_value + # (svg, predicted_value, min_value, max_value) + predict_value_js = u''' + var pp_div = top_div.append('div') + .classed('lime predicted_value', true); + var pp_svg = pp_div.append('svg').style('width', '100%%'); + var pp = new lime.PredictedValue(pp_svg, %s, %s, %s); + ''' % (jsonize(float(self.predicted_value)), + jsonize(float(self.min_value)), + jsonize(float(self.max_value))) + + exp_js = '''var exp_div; + var exp = new lime.Explanation(%s); + ''' % (jsonize([str(x) for x in self.class_names])) + + if self.mode == "classification": + for label in labels: + exp = jsonize(self.as_list(label)) + exp_js += u''' + exp_div = top_div.append('div').classed('lime explanation', true); + exp.show(%s, %d, exp_div); + ''' % (exp, label) + else: + exp = jsonize(self.as_list()) + exp_js += u''' + exp_div = top_div.append('div').classed('lime explanation', true); + exp.show(%s, %s, exp_div); + ''' % (exp, self.dummy_label) + + raw_js = '''var raw_div = top_div.append('div');''' + + if self.mode == "classification": + html_data = self.local_exp[labels[0]] + else: + html_data = self.local_exp[self.dummy_label] + + raw_js += self.domain_mapper.visualize_instance_html( + html_data, + labels[0] if self.mode == "classification" else self.dummy_label, + 'raw_div', + 'exp', + **kwargs) + out += u''' + + ''' % (random_id, predict_proba_js, predict_value_js, exp_js, raw_js) + out += u'' + + return out diff --git a/lime/lime_base.py b/lime/lime_base.py new file mode 100644 index 0000000000000000000000000000000000000000..c63474b6cdd70beec73f93b2d1297eedf0a3522d --- /dev/null +++ b/lime/lime_base.py @@ -0,0 +1,210 @@ +""" +Contains abstract functionality for learning locally linear sparse model. +""" +import numpy as np +import scipy as sp +from sklearn.linear_model import Ridge, lars_path +from sklearn.utils import check_random_state + + +class LimeBase(object): + """Class for learning a locally linear sparse model from perturbed data""" + def __init__(self, + kernel_fn, + verbose=False, + random_state=None): + """Init function + + Args: + kernel_fn: function that transforms an array of distances into an + array of proximity values (floats). + verbose: if true, print local prediction values from linear model. + random_state: an integer or numpy.RandomState that will be used to + generate random numbers. If None, the random state will be + initialized using the internal numpy seed. + """ + self.kernel_fn = kernel_fn + self.verbose = verbose + self.random_state = check_random_state(random_state) + + @staticmethod + def generate_lars_path(weighted_data, weighted_labels): + """Generates the lars path for weighted data. 
+ + Args: + weighted_data: data that has been weighted by kernel + weighted_label: labels, weighted by kernel + + Returns: + (alphas, coefs), both are arrays corresponding to the + regularization parameter and coefficients, respectively + """ + x_vector = weighted_data + alphas, _, coefs = lars_path(x_vector, + weighted_labels, + method='lasso', + verbose=False) + return alphas, coefs + + def forward_selection(self, data, labels, weights, num_features): + """Iteratively adds features to the model""" + clf = Ridge(alpha=0, fit_intercept=True, random_state=self.random_state) + used_features = [] + for _ in range(min(num_features, data.shape[1])): + max_ = -100000000 + best = 0 + for feature in range(data.shape[1]): + if feature in used_features: + continue + clf.fit(data[:, used_features + [feature]], labels, + sample_weight=weights) + score = clf.score(data[:, used_features + [feature]], + labels, + sample_weight=weights) + if score > max_: + best = feature + max_ = score + used_features.append(best) + return np.array(used_features) + + def feature_selection(self, data, labels, weights, num_features, method): + """Selects features for the model. see explain_instance_with_data to + understand the parameters.""" + if method == 'none': + return np.array(range(data.shape[1])) + elif method == 'forward_selection': + return self.forward_selection(data, labels, weights, num_features) + elif method == 'highest_weights': + clf = Ridge(alpha=0.01, fit_intercept=True, + random_state=self.random_state) + # print("data shape: ", data.shape) + # print("labels shape: ", labels.shape) + # assert(False) + clf.fit(data, labels, sample_weight=weights) + + coef = clf.coef_ + if sp.sparse.issparse(data): + coef = sp.sparse.csr_matrix(clf.coef_) + weighted_data = coef.multiply(data[0]) + # Note: most efficient to slice the data before reversing + sdata = len(weighted_data.data) + argsort_data = np.abs(weighted_data.data).argsort() + # Edge case where data is more sparse than requested number of feature importances + # In that case, we just pad with zero-valued features + if sdata < num_features: + nnz_indexes = argsort_data[::-1] + indices = weighted_data.indices[nnz_indexes] + num_to_pad = num_features - sdata + indices = np.concatenate((indices, np.zeros(num_to_pad, dtype=indices.dtype))) + indices_set = set(indices) + pad_counter = 0 + for i in range(data.shape[1]): + if i not in indices_set: + indices[pad_counter + sdata] = i + pad_counter += 1 + if pad_counter >= num_to_pad: + break + else: + nnz_indexes = argsort_data[sdata - num_features:sdata][::-1] + indices = weighted_data.indices[nnz_indexes] + return indices + else: + weighted_data = coef * data[0] + feature_weights = sorted( + zip(range(data.shape[1]), weighted_data), + key=lambda x: np.abs(x[1]), + reverse=True) + return np.array([x[0] for x in feature_weights[:num_features]]) + elif method == 'lasso_path': + weighted_data = ((data - np.average(data, axis=0, weights=weights)) + * np.sqrt(weights[:, np.newaxis])) + weighted_labels = ((labels - np.average(labels, weights=weights)) + * np.sqrt(weights)) + nonzero = range(weighted_data.shape[1]) + _, coefs = self.generate_lars_path(weighted_data, + weighted_labels) + for i in range(len(coefs.T) - 1, 0, -1): + nonzero = coefs.T[i].nonzero()[0] + if len(nonzero) <= num_features: + break + used_features = nonzero + return used_features + elif method == 'auto': + if num_features <= 6: + n_method = 'forward_selection' + else: + n_method = 'highest_weights' + return self.feature_selection(data, labels, 
weights, + num_features, n_method) + + def explain_instance_with_data(self, + neighborhood_data, + neighborhood_labels, + distances, + label, + num_features, + feature_selection='auto', + model_regressor=None): + """Takes perturbed data, labels and distances, returns explanation. + + Args: + neighborhood_data: perturbed data, 2d array. first element is + assumed to be the original data point. + neighborhood_labels: corresponding perturbed labels. should have as + many columns as the number of possible labels. + distances: distances to original data point. + label: label for which we want an explanation + num_features: maximum number of features in explanation + feature_selection: how to select num_features. options are: + 'forward_selection': iteratively add features to the model. + This is costly when num_features is high + 'highest_weights': selects the features that have the highest + product of absolute weight * original data point when + learning with all the features + 'lasso_path': chooses features based on the lasso + regularization path + 'none': uses all features, ignores num_features + 'auto': uses forward_selection if num_features <= 6, and + 'highest_weights' otherwise. + model_regressor: sklearn regressor to use in explanation. + Defaults to Ridge regression if None. Must have + model_regressor.coef_ and 'sample_weight' as a parameter + to model_regressor.fit() + + Returns: + (intercept, exp, score, local_pred): + intercept is a float. + exp is a sorted list of tuples, where each tuple (x,y) corresponds + to the feature id (x) and the local weight (y). The list is sorted + by decreasing absolute value of y. + score is the R^2 value of the returned explanation + local_pred is the prediction of the explanation model on the original instance + """ + + weights = self.kernel_fn(distances) + labels_column = neighborhood_labels[:, label] + used_features = self.feature_selection(neighborhood_data, + labels_column, + weights, + num_features, + feature_selection) + if model_regressor is None: + model_regressor = Ridge(alpha=1, fit_intercept=True, + random_state=self.random_state) + easy_model = model_regressor + easy_model.fit(neighborhood_data[:, used_features], + labels_column, sample_weight=weights) + prediction_score = easy_model.score( + neighborhood_data[:, used_features], + labels_column, sample_weight=weights) + + local_pred = easy_model.predict(neighborhood_data[0, used_features].reshape(1, -1)) + + if self.verbose: + print('Intercept', easy_model.intercept_) + print('Prediction_local', local_pred,) + print('Right:', neighborhood_labels[0, label]) + return (easy_model.intercept_, + sorted(zip(used_features, easy_model.coef_), + key=lambda x: np.abs(x[1]), reverse=True), + prediction_score, local_pred) diff --git a/lime/lime_image.py b/lime/lime_image.py new file mode 100644 index 0000000000000000000000000000000000000000..d907f2a3f3253855013e07b34778435139464d7b --- /dev/null +++ b/lime/lime_image.py @@ -0,0 +1,298 @@ +""" +Functions for explaining classifiers that use Image data. +""" +import copy +from functools import partial + +import numpy as np +import sklearn +import sklearn.preprocessing +from sklearn.utils import check_random_state +from skimage.color import gray2rgb +from tqdm.auto import tqdm + + +from . import lime_base +from .wrappers.scikit_image import SegmentationAlgorithm + + +class ImageExplanation(object): + def __init__(self, image, segments): + """Init function. 
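The explain_instance_with_data routine above fits a distance-weighted Ridge surrogate on the perturbed neighborhood and returns the intercept, sorted feature weights, local R^2 and local prediction. A minimal, hypothetical driver for it is sketched below; the toy neighborhood, the 0.25 kernel width and the `from lime import lime_base` import path are illustrative assumptions, not values taken from this patch.

# A minimal, hypothetical driver for LimeBase.explain_instance_with_data;
# the toy neighborhood, kernel width and import path are illustrative assumptions.
import numpy as np
from lime import lime_base

kernel_fn = lambda d: np.sqrt(np.exp(-(d ** 2) / 0.25 ** 2))   # exponential kernel, width 0.25

rng = np.random.RandomState(0)
neighborhood_data = rng.randint(0, 2, size=(200, 10)).astype(float)
neighborhood_data[0, :] = 1                                     # row 0 = the original instance
true_w = rng.randn(10)
neighborhood_labels = (neighborhood_data @ true_w)[:, np.newaxis]   # one pseudo-class column
distances = np.linalg.norm(neighborhood_data - neighborhood_data[0], axis=1)

base = lime_base.LimeBase(kernel_fn, verbose=False)
intercept, exp, score, local_pred = base.explain_instance_with_data(
    neighborhood_data, neighborhood_labels, distances,
    label=0, num_features=5, feature_selection='highest_weights')
print(round(score, 3), exp[:3])   # local R^2 and the top weighted feature ids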
+ + Args: + image: 3d numpy array + segments: 2d numpy array, with the output from skimage.segmentation + """ + self.image = image + self.segments = segments + self.intercept = {} + self.local_exp = {} + self.local_pred = None + + def get_image_and_mask(self, label, positive_only=True, negative_only=False, hide_rest=False, + num_features=5, min_weight=0.): + """Init function. + + Args: + label: label to explain + positive_only: if True, only take superpixels that positively contribute to + the prediction of the label. + negative_only: if True, only take superpixels that negatively contribute to + the prediction of the label. If false, and so is positive_only, then both + negativey and positively contributions will be taken. + Both can't be True at the same time + hide_rest: if True, make the non-explanation part of the return + image gray + num_features: number of superpixels to include in explanation + min_weight: minimum weight of the superpixels to include in explanation + + Returns: + (image, mask), where image is a 3d numpy array and mask is a 2d + numpy array that can be used with + skimage.segmentation.mark_boundaries + """ + if label not in self.local_exp: + raise KeyError('Label not in explanation') + if positive_only & negative_only: + raise ValueError("Positive_only and negative_only cannot be true at the same time.") + segments = self.segments + image = self.image + exp = self.local_exp[label] + mask = np.zeros(segments.shape, segments.dtype) + if hide_rest: + temp = np.zeros(self.image.shape) + else: + temp = self.image.copy() + if positive_only: + fs = [x[0] for x in exp + if x[1] > 0 and x[1] > min_weight][:num_features] + if negative_only: + fs = [x[0] for x in exp + if x[1] < 0 and abs(x[1]) > min_weight][:num_features] + if positive_only or negative_only: + for f in fs: + temp[segments == f] = image[segments == f].copy() + mask[segments == f] = 1 + return temp, mask + else: + for f, w in exp[:num_features]: + if np.abs(w) < min_weight: + continue + c = 0 if w < 0 else 1 + mask[segments == f] = -1 if w < 0 else 1 + temp[segments == f] = image[segments == f].copy() + temp[segments == f, c] = np.max(image) + return temp, mask + + +class LimeImageExplainer(object): + """Explains predictions on Image (i.e. matrix) data. + For numerical features, perturb them by sampling from a Normal(0,1) and + doing the inverse operation of mean-centering and scaling, according to the + means and stds in the training data. For categorical features, perturb by + sampling according to the training distribution, and making a binary + feature that is 1 when the value is the same as the instance being + explained.""" + + def __init__(self, kernel_width=.25, kernel=None, verbose=False, + feature_selection='auto', random_state=None): + """Init function. + + Args: + kernel_width: kernel width for the exponential kernel. + If None, defaults to sqrt(number of columns) * 0.75. + kernel: similarity kernel that takes euclidean distances and kernel + width as input and outputs weights in (0,1). If None, defaults to + an exponential kernel. + verbose: if true, print local prediction values from linear model + feature_selection: feature selection method. can be + 'forward_selection', 'lasso_path', 'none' or 'auto'. + See function 'explain_instance_with_data' in lime_base.py for + details on what each of the options does. + random_state: an integer or numpy.RandomState that will be used to + generate random numbers. If None, the random state will be + initialized using the internal numpy seed. 
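get_image_and_mask above keeps only the top positively (or negatively) weighted superpixels before building the visualization mask. A tiny runnable sketch of that filtering step follows; the 4x4 segments map and the weights are fabricated for illustration only.

# Illustrative filtering, mirroring the positive_only branch of get_image_and_mask.
import numpy as np

segments = np.repeat(np.arange(4), 4).reshape(4, 4)       # four horizontal superpixels
exp = [(2, 0.9), (0, -0.7), (3, 0.2), (1, -0.1)]          # (superpixel id, weight), sorted by |weight|
min_weight, num_features = 0.0, 2

fs = [sp for sp, w in exp if w > 0 and w > min_weight][:num_features]
mask = np.isin(segments, fs).astype(segments.dtype)
print(fs)      # -> [2, 3]
print(mask)    # 1 on the kept superpixels, 0 elsewhere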
+ """ + kernel_width = float(kernel_width) + + if kernel is None: + def kernel(d, kernel_width): + return np.sqrt(np.exp(-(d ** 2) / kernel_width ** 2)) + + kernel_fn = partial(kernel, kernel_width=kernel_width) + + self.random_state = check_random_state(random_state) + self.feature_selection = feature_selection + self.base = lime_base.LimeBase(kernel_fn, verbose, random_state=self.random_state) + + ### Custom function to acquire segmentation only, same as in the explain_instance() function + def acquireSegmOnly(self, img): + random_seed = self.random_state.randint(0, high=1000) + segmentation_fn = SegmentationAlgorithm('quickshift', kernel_size=4, + max_dist=200, ratio=0.2, + random_seed=random_seed) + segments = segmentation_fn(img) + return segments + + def explain_instance(self, image, inputImg, classifier_fn, labels=(1,), + hide_color=None, + top_labels=5, num_features=100000, num_samples=1000, + batch_size=10, + segmentation_fn=None, + distance_metric='cosine', + model_regressor=None, + random_seed=None, + squaredSegm=None, + loadedSegmData=None): + """Generates explanations for a prediction. + + First, we generate neighborhood data by randomly perturbing features + from the instance (see __data_inverse). We then learn locally weighted + linear models on this neighborhood data to explain each of the classes + in an interpretable way (see lime_base.py). + + Args: + image: 3 dimension RGB image. If this is only two dimensional, + we will assume it's a grayscale image and call gray2rgb. + classifier_fn: classifier prediction probability function, which + takes a numpy array and outputs prediction probabilities. For + ScikitClassifiers , this is classifier.predict_proba. + labels: iterable with labels to be explained. + hide_color: TODO + top_labels: if not None, ignore labels and produce explanations for + the K labels with highest prediction probabilities, where K is + this parameter. + num_features: maximum number of features present in explanation + num_samples: size of the neighborhood to learn the linear model + batch_size: TODO + distance_metric: the distance metric to use for weights. + model_regressor: sklearn regressor to use in explanation. Defaults + to Ridge regression in LimeBase. Must have model_regressor.coef_ + and 'sample_weight' as a parameter to model_regressor.fit() + segmentation_fn: SegmentationAlgorithm, wrapped skimage + segmentation function + random_seed: integer used as random seed for the segmentation + algorithm. If None, a random integer, between 0 and 1000, + will be generated using the internal random number generator. + squaredSegm: integer or None (default): + + Returns: + An ImageExplanation object (see lime_image.py) with the corresponding + explanations. 
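The default similarity kernel defined in the constructor above is the exponential kernel; the short sketch below (with made-up distances) only illustrates how kernel_width controls how quickly neighborhood weights decay.

# Illustrative only: how kernel_width shapes the locality weights of the
# default exponential kernel defined above (distances are made up).
import numpy as np

def exponential_kernel(d, kernel_width):
    return np.sqrt(np.exp(-(d ** 2) / kernel_width ** 2))

distances = np.array([0.0, 0.1, 0.25, 0.5, 1.0])
for kw in (0.25, 1.0):
    print(kw, np.round(exponential_kernel(distances, kw), 3))
# A smaller kernel_width down-weights distant perturbations more aggressively,
# so the surrogate model is fitted more locally around the explained instance.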
+ """ + if len(image.shape) == 2: + image = gray2rgb(image) + if random_seed is None: + random_seed = self.random_state.randint(0, high=1000) + + if squaredSegm == 4: + segments = np.zeros((image.shape[0], image.shape[1]), dtype=np.int64) + imgW = image.shape[1] + halfW1 = 1*imgW//4 + halfW2 = 2*imgW//4 + halfW3 = 3*imgW//4 + segments[:,0:halfW1] = 0 + segments[:,halfW1:halfW2] = 1 + segments[:,halfW2:halfW3] = 2 + segments[:,halfW3:imgW] = 3 + elif squaredSegm == -2: ### Use to load custom resized segm data + segments = loadedSegmData + else: + if segmentation_fn is None: + segmentation_fn = SegmentationAlgorithm('quickshift', kernel_size=4, + max_dist=200, ratio=0.2, + random_seed=random_seed) + try: + segments = segmentation_fn(image) + except ValueError as e: + raise e + + fudged_image = image.copy() + if hide_color is None: + for x in np.unique(segments): + fudged_image[segments == x] = ( + np.mean(image[segments == x][:, 0]), + np.mean(image[segments == x][:, 1]), + np.mean(image[segments == x][:, 2])) + else: + fudged_image[:] = hide_color + + top = labels + + data, labels = self.data_labels(image, inputImg, fudged_image, segments, + classifier_fn, num_samples, + batch_size=batch_size) + + distances = sklearn.metrics.pairwise_distances( + data, + data[0].reshape(1, -1), + metric=distance_metric + ).ravel() + + ret_exp = ImageExplanation(image, segments) + if top_labels: + top = np.argsort(labels[0])[-top_labels:] + ret_exp.top_labels = list(top) + ret_exp.top_labels.reverse() + for label in top: + (ret_exp.intercept[label], + ret_exp.local_exp[label], + ret_exp.score, ret_exp.local_pred) = self.base.explain_instance_with_data( + data, labels, distances, label, num_features, + model_regressor=model_regressor, + feature_selection=self.feature_selection) + return ret_exp + + def data_labels(self, + image, + inputImg, + fudged_image, + segments, + classifier_fn, + num_samples, + batch_size=10): + """Generates images and predictions in the neighborhood of this image. + + Args: + image: 3d numpy array, the image + fudged_image: 3d numpy array, image to replace original image when + superpixel is turned off + segments: segmentation of the image + classifier_fn: function that takes a list of images and returns a + matrix of prediction probabilities + num_samples: size of the neighborhood to learn the linear model + batch_size: classifier_fn will be called on batches of this size. 
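When squaredSegm == 4, the branch above bypasses quickshift and slices the image into four equal-width vertical superpixels. A tiny runnable sketch of that layout follows; the 2x8 dummy image is an assumption for illustration.

# Illustrative reproduction of the squaredSegm == 4 branch on a dummy 2x8 image.
import numpy as np

image = np.zeros((2, 8, 3))
segments = np.zeros((image.shape[0], image.shape[1]), dtype=np.int64)
imgW = image.shape[1]
for k in range(4):                      # four equal vertical slices, labelled 0..3
    segments[:, k * imgW // 4:(k + 1) * imgW // 4] = k
print(segments[0])                      # -> [0 0 1 1 2 2 3 3]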
+ + Returns: + A tuple (data, labels), where: + data: dense num_samples * num_superpixels + labels: prediction probabilities matrix + """ + n_features = np.unique(segments).shape[0] + data = self.random_state.randint(0, 2, num_samples * n_features)\ + .reshape((num_samples, n_features)) + labels = [] + data[0, :] = 1 + imgs = [] + # print("data new shape: ", data.shape) + # assert(False) + # for row in tqdm(data): + for row in data: + temp = copy.deepcopy(image) + zeros = np.where(row == 0)[0] + mask = np.zeros(segments.shape).astype(bool) + for z in zeros: + mask[segments == z] = True + temp[mask] = fudged_image[mask] + imgs.append(temp) + if len(imgs) == batch_size: + preds = classifier_fn(inputImg) + preds = preds.cpu().detach().numpy() + labels.extend(preds) + imgs = [] + if len(imgs) > 0: + preds = classifier_fn(inputImg) + preds = preds.cpu().detach().numpy() + labels.extend(preds) + return data, np.array(labels) diff --git a/lime/lime_tabular.py b/lime/lime_tabular.py new file mode 100644 index 0000000000000000000000000000000000000000..f7883151b2d8b868c8c771c57d671c3962738f75 --- /dev/null +++ b/lime/lime_tabular.py @@ -0,0 +1,702 @@ +""" +Functions for explaining classifiers that use tabular data (matrices). +""" +import collections +import copy +from functools import partial +import json +import warnings + +import numpy as np +import scipy as sp +import sklearn +import sklearn.preprocessing +from sklearn.utils import check_random_state + +from lime.discretize import QuartileDiscretizer +from lime.discretize import DecileDiscretizer +from lime.discretize import EntropyDiscretizer +from lime.discretize import BaseDiscretizer +from lime.discretize import StatsDiscretizer +from . import explanation +from . import lime_base + + +class TableDomainMapper(explanation.DomainMapper): + """Maps feature ids to names, generates table views, etc""" + + def __init__(self, feature_names, feature_values, scaled_row, + categorical_features, discretized_feature_names=None, + feature_indexes=None): + """Init. + + Args: + feature_names: list of feature names, in order + feature_values: list of strings with the values of the original row + scaled_row: scaled row + categorical_features: list of categorical features ids (ints) + feature_indexes: optional feature indexes used in the sparse case + """ + self.exp_feature_names = feature_names + self.discretized_feature_names = discretized_feature_names + self.feature_names = feature_names + self.feature_values = feature_values + self.feature_indexes = feature_indexes + self.scaled_row = scaled_row + if sp.sparse.issparse(scaled_row): + self.all_categorical = False + else: + self.all_categorical = len(categorical_features) == len(scaled_row) + self.categorical_features = categorical_features + + def map_exp_ids(self, exp): + """Maps ids to feature names. + + Args: + exp: list of tuples [(id, weight), (id,weight)] + + Returns: + list of tuples (feature_name, weight) + """ + names = self.exp_feature_names + if self.discretized_feature_names is not None: + names = self.discretized_feature_names + return [(names[x[0]], x[1]) for x in exp] + + def visualize_instance_html(self, + exp, + label, + div_name, + exp_object_name, + show_table=True, + show_all=False): + """Shows the current example in a table format. + + Args: + exp: list of tuples [(id, weight), (id,weight)] + label: label id (integer) + div_name: name of div object to be used for rendering(in js) + exp_object_name: name of js explanation object + show_table: if False, don't show table visualization. 
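data_labels above turns each binary row into a perturbed image by replacing the switched-off superpixels with the fudged (mean-colour) image. A minimal numpy sketch of that masking step is given below; the shapes and values are illustrative, not taken from the patch.

# Illustrative masking step from data_labels: row == 0 means "hide this superpixel".
import numpy as np

segments = np.array([[0, 0, 1, 1],
                     [2, 2, 3, 3]])
image = np.arange(8).reshape(2, 4).astype(float)
fudged_image = np.full_like(image, -1.0)            # stand-in for the mean-colour image
row = np.array([1, 0, 1, 0])                        # superpixels 1 and 3 switched off

mask = np.zeros(segments.shape, dtype=bool)
for z in np.where(row == 0)[0]:
    mask[segments == z] = True
perturbed = image.copy()
perturbed[mask] = fudged_image[mask]
print(perturbed)
# [[ 0.  1. -1. -1.]
#  [ 4.  5. -1. -1.]]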
+ show_all: if True, show zero-weighted features in the table. + """ + if not show_table: + return '' + weights = [0] * len(self.feature_names) + for x in exp: + weights[x[0]] = x[1] + if self.feature_indexes is not None: + # Sparse case: only display the non-zero values and importances + fnames = [self.exp_feature_names[i] for i in self.feature_indexes] + fweights = [weights[i] for i in self.feature_indexes] + if show_all: + out_list = list(zip(fnames, + self.feature_values, + fweights)) + else: + out_dict = dict(map(lambda x: (x[0], (x[1], x[2], x[3])), + zip(self.feature_indexes, + fnames, + self.feature_values, + fweights))) + out_list = [out_dict.get(x[0], (str(x[0]), 0.0, 0.0)) for x in exp] + else: + out_list = list(zip(self.exp_feature_names, + self.feature_values, + weights)) + if not show_all: + out_list = [out_list[x[0]] for x in exp] + ret = u''' + %s.show_raw_tabular(%s, %d, %s); + ''' % (exp_object_name, json.dumps(out_list, ensure_ascii=False), label, div_name) + return ret + + +class LimeTabularExplainer(object): + """Explains predictions on tabular (i.e. matrix) data. + For numerical features, perturb them by sampling from a Normal(0,1) and + doing the inverse operation of mean-centering and scaling, according to the + means and stds in the training data. For categorical features, perturb by + sampling according to the training distribution, and making a binary + feature that is 1 when the value is the same as the instance being + explained.""" + + def __init__(self, + training_data, + mode="classification", + training_labels=None, + feature_names=None, + categorical_features=None, + categorical_names=None, + kernel_width=None, + kernel=None, + verbose=False, + class_names=None, + feature_selection='auto', + discretize_continuous=True, + discretizer='quartile', + sample_around_instance=False, + random_state=None, + training_data_stats=None): + """Init function. + + Args: + training_data: numpy 2d array + mode: "classification" or "regression" + training_labels: labels for training data. Not required, but may be + used by discretizer. + feature_names: list of names (strings) corresponding to the columns + in the training data. + categorical_features: list of indices (ints) corresponding to the + categorical columns. Everything else will be considered + continuous. Values in these columns MUST be integers. + categorical_names: map from int to list of names, where + categorical_names[x][y] represents the name of the yth value of + column x. + kernel_width: kernel width for the exponential kernel. + If None, defaults to sqrt (number of columns) * 0.75 + kernel: similarity kernel that takes euclidean distances and kernel + width as input and outputs weights in (0,1). If None, defaults to + an exponential kernel. + verbose: if true, print local prediction values from linear model + class_names: list of class names, ordered according to whatever the + classifier is using. If not present, class names will be '0', + '1', ... + feature_selection: feature selection method. can be + 'forward_selection', 'lasso_path', 'none' or 'auto'. + See function 'explain_instance_with_data' in lime_base.py for + details on what each of the options does. + discretize_continuous: if True, all non-categorical features will + be discretized into quartiles. + discretizer: only matters if discretize_continuous is True + and data is not sparse. Options are 'quartile', 'decile', + 'entropy' or a BaseDiscretizer instance. 
+ sample_around_instance: if True, will sample continuous features + in perturbed samples from a normal centered at the instance + being explained. Otherwise, the normal is centered on the mean + of the feature data. + random_state: an integer or numpy.RandomState that will be used to + generate random numbers. If None, the random state will be + initialized using the internal numpy seed. + training_data_stats: a dict object having the details of training data + statistics. If None, training data information will be used, only matters + if discretize_continuous is True. Must have the following keys: + means", "mins", "maxs", "stds", "feature_values", + "feature_frequencies" + """ + self.random_state = check_random_state(random_state) + self.mode = mode + self.categorical_names = categorical_names or {} + self.sample_around_instance = sample_around_instance + self.training_data_stats = training_data_stats + + # Check and raise proper error in stats are supplied in non-descritized path + if self.training_data_stats: + self.validate_training_data_stats(self.training_data_stats) + + if categorical_features is None: + categorical_features = [] + if feature_names is None: + feature_names = [str(i) for i in range(training_data.shape[1])] + + self.categorical_features = list(categorical_features) + self.feature_names = list(feature_names) + + self.discretizer = None + if discretize_continuous and not sp.sparse.issparse(training_data): + # Set the discretizer if training data stats are provided + if self.training_data_stats: + discretizer = StatsDiscretizer(training_data, self.categorical_features, + self.feature_names, labels=training_labels, + data_stats=self.training_data_stats, + random_state=self.random_state) + + if discretizer == 'quartile': + self.discretizer = QuartileDiscretizer( + training_data, self.categorical_features, + self.feature_names, labels=training_labels, + random_state=self.random_state) + elif discretizer == 'decile': + self.discretizer = DecileDiscretizer( + training_data, self.categorical_features, + self.feature_names, labels=training_labels, + random_state=self.random_state) + elif discretizer == 'entropy': + self.discretizer = EntropyDiscretizer( + training_data, self.categorical_features, + self.feature_names, labels=training_labels, + random_state=self.random_state) + elif isinstance(discretizer, BaseDiscretizer): + self.discretizer = discretizer + else: + raise ValueError('''Discretizer must be 'quartile',''' + + ''' 'decile', 'entropy' or a''' + + ''' BaseDiscretizer instance''') + self.categorical_features = list(range(training_data.shape[1])) + + # Get the discretized_training_data when the stats are not provided + if(self.training_data_stats is None): + discretized_training_data = self.discretizer.discretize( + training_data) + + if kernel_width is None: + kernel_width = np.sqrt(training_data.shape[1]) * .75 + kernel_width = float(kernel_width) + + if kernel is None: + def kernel(d, kernel_width): + return np.sqrt(np.exp(-(d ** 2) / kernel_width ** 2)) + + kernel_fn = partial(kernel, kernel_width=kernel_width) + + self.feature_selection = feature_selection + self.base = lime_base.LimeBase(kernel_fn, verbose, random_state=self.random_state) + self.class_names = class_names + + # Though set has no role to play if training data stats are provided + self.scaler = sklearn.preprocessing.StandardScaler(with_mean=False) + self.scaler.fit(training_data) + self.feature_values = {} + self.feature_frequencies = {} + + for feature in self.categorical_features: + if 
training_data_stats is None: + if self.discretizer is not None: + column = discretized_training_data[:, feature] + else: + column = training_data[:, feature] + + feature_count = collections.Counter(column) + values, frequencies = map(list, zip(*(sorted(feature_count.items())))) + else: + values = training_data_stats["feature_values"][feature] + frequencies = training_data_stats["feature_frequencies"][feature] + + self.feature_values[feature] = values + self.feature_frequencies[feature] = (np.array(frequencies) / + float(sum(frequencies))) + self.scaler.mean_[feature] = 0 + self.scaler.scale_[feature] = 1 + + @staticmethod + def convert_and_round(values): + return ['%.2f' % v for v in values] + + @staticmethod + def validate_training_data_stats(training_data_stats): + """ + Method to validate the structure of training data stats + """ + stat_keys = list(training_data_stats.keys()) + valid_stat_keys = ["means", "mins", "maxs", "stds", "feature_values", "feature_frequencies"] + missing_keys = list(set(valid_stat_keys) - set(stat_keys)) + if len(missing_keys) > 0: + raise Exception("Missing keys in training_data_stats. Details: %s" % (missing_keys)) + + def explain_instance(self, + data_row, + predict_fn, + labels=(1,), + top_labels=None, + num_features=10, + num_samples=5000, + distance_metric='euclidean', + model_regressor=None): + """Generates explanations for a prediction. + + First, we generate neighborhood data by randomly perturbing features + from the instance (see __data_inverse). We then learn locally weighted + linear models on this neighborhood data to explain each of the classes + in an interpretable way (see lime_base.py). + + Args: + data_row: 1d numpy array or scipy.sparse matrix, corresponding to a row + predict_fn: prediction function. For classifiers, this should be a + function that takes a numpy array and outputs prediction + probabilities. For regressors, this takes a numpy array and + returns the predictions. For ScikitClassifiers, this is + `classifier.predict_proba()`. For ScikitRegressors, this + is `regressor.predict()`. The prediction function needs to work + on multiple feature vectors (the vectors randomly perturbed + from the data_row). + labels: iterable with labels to be explained. + top_labels: if not None, ignore labels and produce explanations for + the K labels with highest prediction probabilities, where K is + this parameter. + num_features: maximum number of features present in explanation + num_samples: size of the neighborhood to learn the linear model + distance_metric: the distance metric to use for weights. + model_regressor: sklearn regressor to use in explanation. Defaults + to Ridge regression in LimeBase. Must have model_regressor.coef_ + and 'sample_weight' as a parameter to model_regressor.fit() + + Returns: + An Explanation object (see explanation.py) with the corresponding + explanations. 
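A hedged end-to-end usage sketch of LimeTabularExplainer with a scikit-learn classifier follows; the iris dataset and the random forest are illustrative assumptions, not part of this patch.

# Illustrative end-to-end call; any probabilistic sklearn classifier works here.
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from lime.lime_tabular import LimeTabularExplainer

iris = load_iris()
clf = RandomForestClassifier(n_estimators=50, random_state=0).fit(iris.data, iris.target)

explainer = LimeTabularExplainer(
    iris.data,
    feature_names=iris.feature_names,
    class_names=iris.target_names,
    discretize_continuous=True,
    random_state=0)
exp = explainer.explain_instance(
    iris.data[0], clf.predict_proba, num_features=3, top_labels=1)
print(exp.as_list(exp.top_labels[0]))   # [(feature description, local weight), ...]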
+ """ + if sp.sparse.issparse(data_row) and not sp.sparse.isspmatrix_csr(data_row): + # Preventative code: if sparse, convert to csr format if not in csr format already + data_row = data_row.tocsr() + data, inverse = self.__data_inverse(data_row, num_samples) + if sp.sparse.issparse(data): + # Note in sparse case we don't subtract mean since data would become dense + scaled_data = data.multiply(self.scaler.scale_) + # Multiplying with csr matrix can return a coo sparse matrix + if not sp.sparse.isspmatrix_csr(scaled_data): + scaled_data = scaled_data.tocsr() + else: + scaled_data = (data - self.scaler.mean_) / self.scaler.scale_ + distances = sklearn.metrics.pairwise_distances( + scaled_data, + scaled_data[0].reshape(1, -1), + metric=distance_metric + ).ravel() + + yss = predict_fn(inverse) + + # for classification, the model needs to provide a list of tuples - classes + # along with prediction probabilities + if self.mode == "classification": + if len(yss.shape) == 1: + raise NotImplementedError("LIME does not currently support " + "classifier models without probability " + "scores. If this conflicts with your " + "use case, please let us know: " + "https://github.com/datascienceinc/lime/issues/16") + elif len(yss.shape) == 2: + if self.class_names is None: + self.class_names = [str(x) for x in range(yss[0].shape[0])] + else: + self.class_names = list(self.class_names) + if not np.allclose(yss.sum(axis=1), 1.0): + warnings.warn(""" + Prediction probabilties do not sum to 1, and + thus does not constitute a probability space. + Check that you classifier outputs probabilities + (Not log probabilities, or actual class predictions). + """) + else: + raise ValueError("Your model outputs " + "arrays with {} dimensions".format(len(yss.shape))) + + # for regression, the output should be a one-dimensional array of predictions + else: + try: + if len(yss.shape) != 1 and len(yss[0].shape) == 1: + yss = np.array([v[0] for v in yss]) + assert isinstance(yss, np.ndarray) and len(yss.shape) == 1 + except AssertionError: + raise ValueError("Your model needs to output single-dimensional \ + numpyarrays, not arrays of {} dimensions".format(yss.shape)) + + predicted_value = yss[0] + min_y = min(yss) + max_y = max(yss) + + # add a dimension to be compatible with downstream machinery + yss = yss[:, np.newaxis] + + feature_names = copy.deepcopy(self.feature_names) + if feature_names is None: + feature_names = [str(x) for x in range(data_row.shape[0])] + + if sp.sparse.issparse(data_row): + values = self.convert_and_round(data_row.data) + feature_indexes = data_row.indices + else: + values = self.convert_and_round(data_row) + feature_indexes = None + + for i in self.categorical_features: + if self.discretizer is not None and i in self.discretizer.lambdas: + continue + name = int(data_row[i]) + if i in self.categorical_names: + name = self.categorical_names[i][name] + feature_names[i] = '%s=%s' % (feature_names[i], name) + values[i] = 'True' + categorical_features = self.categorical_features + + discretized_feature_names = None + if self.discretizer is not None: + categorical_features = range(data.shape[1]) + discretized_instance = self.discretizer.discretize(data_row) + discretized_feature_names = copy.deepcopy(feature_names) + for f in self.discretizer.names: + discretized_feature_names[f] = self.discretizer.names[f][int( + discretized_instance[f])] + + domain_mapper = TableDomainMapper(feature_names, + values, + scaled_data[0], + categorical_features=categorical_features, + 
discretized_feature_names=discretized_feature_names, + feature_indexes=feature_indexes) + ret_exp = explanation.Explanation(domain_mapper, + mode=self.mode, + class_names=self.class_names) + if self.mode == "classification": + ret_exp.predict_proba = yss[0] + if top_labels: + labels = np.argsort(yss[0])[-top_labels:] + ret_exp.top_labels = list(labels) + ret_exp.top_labels.reverse() + else: + ret_exp.predicted_value = predicted_value + ret_exp.min_value = min_y + ret_exp.max_value = max_y + labels = [0] + for label in labels: + (ret_exp.intercept[label], + ret_exp.local_exp[label], + ret_exp.score, ret_exp.local_pred) = self.base.explain_instance_with_data( + scaled_data, + yss, + distances, + label, + num_features, + model_regressor=model_regressor, + feature_selection=self.feature_selection) + + if self.mode == "regression": + ret_exp.intercept[1] = ret_exp.intercept[0] + ret_exp.local_exp[1] = [x for x in ret_exp.local_exp[0]] + ret_exp.local_exp[0] = [(i, -1 * j) for i, j in ret_exp.local_exp[1]] + + return ret_exp + + def __data_inverse(self, + data_row, + num_samples): + """Generates a neighborhood around a prediction. + + For numerical features, perturb them by sampling from a Normal(0,1) and + doing the inverse operation of mean-centering and scaling, according to + the means and stds in the training data. For categorical features, + perturb by sampling according to the training distribution, and making + a binary feature that is 1 when the value is the same as the instance + being explained. + + Args: + data_row: 1d numpy array, corresponding to a row + num_samples: size of the neighborhood to learn the linear model + + Returns: + A tuple (data, inverse), where: + data: dense num_samples * K matrix, where categorical features + are encoded with either 0 (not equal to the corresponding value + in data_row) or 1. The first row is the original instance. 
+ inverse: same as data, except the categorical features are not + binary, but categorical (as the original data) + """ + is_sparse = sp.sparse.issparse(data_row) + if is_sparse: + num_cols = data_row.shape[1] + data = sp.sparse.csr_matrix((num_samples, num_cols), dtype=data_row.dtype) + else: + num_cols = data_row.shape[0] + data = np.zeros((num_samples, num_cols)) + categorical_features = range(num_cols) + if self.discretizer is None: + instance_sample = data_row + scale = self.scaler.scale_ + mean = self.scaler.mean_ + if is_sparse: + # Perturb only the non-zero values + non_zero_indexes = data_row.nonzero()[1] + num_cols = len(non_zero_indexes) + instance_sample = data_row[:, non_zero_indexes] + scale = scale[non_zero_indexes] + mean = mean[non_zero_indexes] + data = self.random_state.normal( + 0, 1, num_samples * num_cols).reshape( + num_samples, num_cols) + if self.sample_around_instance: + data = data * scale + instance_sample + else: + data = data * scale + mean + if is_sparse: + if num_cols == 0: + data = sp.sparse.csr_matrix((num_samples, + data_row.shape[1]), + dtype=data_row.dtype) + else: + indexes = np.tile(non_zero_indexes, num_samples) + indptr = np.array( + range(0, len(non_zero_indexes) * (num_samples + 1), + len(non_zero_indexes))) + data_1d_shape = data.shape[0] * data.shape[1] + data_1d = data.reshape(data_1d_shape) + data = sp.sparse.csr_matrix( + (data_1d, indexes, indptr), + shape=(num_samples, data_row.shape[1])) + categorical_features = self.categorical_features + first_row = data_row + else: + first_row = self.discretizer.discretize(data_row) + data[0] = data_row.copy() + inverse = data.copy() + for column in categorical_features: + values = self.feature_values[column] + freqs = self.feature_frequencies[column] + inverse_column = self.random_state.choice(values, size=num_samples, + replace=True, p=freqs) + binary_column = (inverse_column == first_row[column]).astype(int) + binary_column[0] = 1 + inverse_column[0] = data[0, column] + data[:, column] = binary_column + inverse[:, column] = inverse_column + if self.discretizer is not None: + inverse[1:] = self.discretizer.undiscretize(inverse[1:]) + inverse[0] = data_row + return data, inverse + + +class RecurrentTabularExplainer(LimeTabularExplainer): + """ + An explainer for keras-style recurrent neural networks, where the + input shape is (n_samples, n_timesteps, n_features). This class + just extends the LimeTabularExplainer class and reshapes the training + data and feature names such that they become something like + + (val1_t1, val1_t2, val1_t3, ..., val2_t1, ..., valn_tn) + + Each of the methods that take data reshape it appropriately, + so you can pass in the training/testing data exactly as you + would to the recurrent neural network. + + """ + + def __init__(self, training_data, mode="classification", + training_labels=None, feature_names=None, + categorical_features=None, categorical_names=None, + kernel_width=None, kernel=None, verbose=False, class_names=None, + feature_selection='auto', discretize_continuous=True, + discretizer='quartile', random_state=None): + """ + Args: + training_data: numpy 3d array with shape + (n_samples, n_timesteps, n_features) + mode: "classification" or "regression" + training_labels: labels for training data. Not required, but may be + used by discretizer. + feature_names: list of names (strings) corresponding to the columns + in the training data. + categorical_features: list of indices (ints) corresponding to the + categorical columns. 
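__data_inverse above encodes each categorical perturbation twice: `inverse` holds the sampled category values while `data` holds 1 where the sample matches the instance being explained. A small runnable sketch of that encoding for one column follows; the toy value/frequency lists are assumptions.

# Illustrative encoding of one categorical column, mirroring __data_inverse.
import numpy as np

rng = np.random.RandomState(0)
values, freqs = [0, 1, 2], [0.5, 0.3, 0.2]       # training values and their frequencies
instance_value = 2                               # value in the row being explained

inverse_column = rng.choice(values, size=6, replace=True, p=freqs)
binary_column = (inverse_column == instance_value).astype(int)
binary_column[0] = 1                             # first row is always the original instance
inverse_column[0] = instance_value
print(inverse_column)   # sampled categories; first entry reset to the instance's own value
print(binary_column)    # 1 where the sample equals the instance's value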
Everything else will be considered + continuous. Values in these columns MUST be integers. + categorical_names: map from int to list of names, where + categorical_names[x][y] represents the name of the yth value of + column x. + kernel_width: kernel width for the exponential kernel. + If None, defaults to sqrt(number of columns) * 0.75 + kernel: similarity kernel that takes euclidean distances and kernel + width as input and outputs weights in (0,1). If None, defaults to + an exponential kernel. + verbose: if true, print local prediction values from linear model + class_names: list of class names, ordered according to whatever the + classifier is using. If not present, class names will be '0', + '1', ... + feature_selection: feature selection method. can be + 'forward_selection', 'lasso_path', 'none' or 'auto'. + See function 'explain_instance_with_data' in lime_base.py for + details on what each of the options does. + discretize_continuous: if True, all non-categorical features will + be discretized into quartiles. + discretizer: only matters if discretize_continuous is True. Options + are 'quartile', 'decile', 'entropy' or a BaseDiscretizer + instance. + random_state: an integer or numpy.RandomState that will be used to + generate random numbers. If None, the random state will be + initialized using the internal numpy seed. + """ + + # Reshape X + n_samples, n_timesteps, n_features = training_data.shape + training_data = np.transpose(training_data, axes=(0, 2, 1)).reshape( + n_samples, n_timesteps * n_features) + self.n_timesteps = n_timesteps + self.n_features = n_features + + # Update the feature names + feature_names = ['{}_t-{}'.format(n, n_timesteps - (i + 1)) + for n in feature_names for i in range(n_timesteps)] + + # Send off the the super class to do its magic. + super(RecurrentTabularExplainer, self).__init__( + training_data, + mode=mode, + training_labels=training_labels, + feature_names=feature_names, + categorical_features=categorical_features, + categorical_names=categorical_names, + kernel_width=kernel_width, + kernel=kernel, + verbose=verbose, + class_names=class_names, + feature_selection=feature_selection, + discretize_continuous=discretize_continuous, + discretizer=discretizer, + random_state=random_state) + + def _make_predict_proba(self, func): + """ + The predict_proba method will expect 3d arrays, but we are reshaping + them to 2D so that LIME works correctly. This wraps the function + you give in explain_instance to first reshape the data to have + the shape the the keras-style network expects. + """ + + def predict_proba(X): + n_samples = X.shape[0] + new_shape = (n_samples, self.n_features, self.n_timesteps) + X = np.transpose(X.reshape(new_shape), axes=(0, 2, 1)) + return func(X) + + return predict_proba + + def explain_instance(self, data_row, classifier_fn, labels=(1,), + top_labels=None, num_features=10, num_samples=5000, + distance_metric='euclidean', model_regressor=None): + """Generates explanations for a prediction. + + First, we generate neighborhood data by randomly perturbing features + from the instance (see __data_inverse). We then learn locally weighted + linear models on this neighborhood data to explain each of the classes + in an interpretable way (see lime_base.py). + + Args: + data_row: 2d numpy array, corresponding to a row + classifier_fn: classifier prediction probability function, which + takes a numpy array and outputs prediction probabilities. For + ScikitClassifiers , this is classifier.predict_proba. 
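# --- Editorial sketch, not part of the patch: the 3D <-> 2D reshaping that
# RecurrentTabularExplainer performs, shown on a toy array so the round trip
# used by _make_predict_proba is easy to verify.
import numpy as np

n_samples, n_timesteps, n_features = 2, 3, 4
X3d = np.arange(n_samples * n_timesteps * n_features).reshape(
    n_samples, n_timesteps, n_features)

# flatten exactly as __init__ does
X2d = np.transpose(X3d, axes=(0, 2, 1)).reshape(n_samples, n_timesteps * n_features)

# restore exactly as the wrapped predict_proba does
X3d_back = np.transpose(X2d.reshape(n_samples, n_features, n_timesteps), axes=(0, 2, 1))
assert np.array_equal(X3d, X3d_back)

# feature names follow the same (feature, timestep) ordering
names = ['f{}'.format(j) for j in range(n_features)]
flat_names = ['{}_t-{}'.format(n, n_timesteps - (i + 1))
              for n in names for i in range(n_timesteps)]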
+ labels: iterable with labels to be explained. + top_labels: if not None, ignore labels and produce explanations for + the K labels with highest prediction probabilities, where K is + this parameter. + num_features: maximum number of features present in explanation + num_samples: size of the neighborhood to learn the linear model + distance_metric: the distance metric to use for weights. + model_regressor: sklearn regressor to use in explanation. Defaults + to Ridge regression in LimeBase. Must have + model_regressor.coef_ and 'sample_weight' as a parameter + to model_regressor.fit() + + Returns: + An Explanation object (see explanation.py) with the corresponding + explanations. + """ + + # Flatten input so that the normal explainer can handle it + data_row = data_row.T.reshape(self.n_timesteps * self.n_features) + + # Wrap the classifier to reshape input + classifier_fn = self._make_predict_proba(classifier_fn) + return super(RecurrentTabularExplainer, self).explain_instance( + data_row, classifier_fn, + labels=labels, + top_labels=top_labels, + num_features=num_features, + num_samples=num_samples, + distance_metric=distance_metric, + model_regressor=model_regressor) diff --git a/lime/lime_text.py b/lime/lime_text.py new file mode 100644 index 0000000000000000000000000000000000000000..b9716b7f710402a0e87b370846043a9ac7029716 --- /dev/null +++ b/lime/lime_text.py @@ -0,0 +1,484 @@ +""" +Functions for explaining text classifiers. +""" +from functools import partial +import itertools +import json +import re + +import numpy as np +import scipy as sp +import sklearn +from sklearn.utils import check_random_state + +from . import explanation +from . import lime_base + + +class TextDomainMapper(explanation.DomainMapper): + """Maps feature ids to words or word-positions""" + + def __init__(self, indexed_string): + """Initializer. + + Args: + indexed_string: lime_text.IndexedString, original string + """ + self.indexed_string = indexed_string + + def map_exp_ids(self, exp, positions=False): + """Maps ids to words or word-position strings. + + Args: + exp: list of tuples [(id, weight), (id,weight)] + positions: if True, also return word positions + + Returns: + list of tuples (word, weight), or (word_positions, weight) if + examples: ('bad', 1) or ('bad_3-6-12', 1) + """ + if positions: + exp = [('%s_%s' % ( + self.indexed_string.word(x[0]), + '-'.join( + map(str, + self.indexed_string.string_position(x[0])))), x[1]) + for x in exp] + else: + exp = [(self.indexed_string.word(x[0]), x[1]) for x in exp] + return exp + + def visualize_instance_html(self, exp, label, div_name, exp_object_name, + text=True, opacity=True): + """Adds text with highlighted words to visualization. 
+ + Args: + exp: list of tuples [(id, weight), (id,weight)] + label: label id (integer) + div_name: name of div object to be used for rendering(in js) + exp_object_name: name of js explanation object + text: if False, return empty + opacity: if True, fade colors according to weight + """ + if not text: + return u'' + text = (self.indexed_string.raw_string() + .encode('utf-8', 'xmlcharrefreplace').decode('utf-8')) + text = re.sub(r'[<>&]', '|', text) + exp = [(self.indexed_string.word(x[0]), + self.indexed_string.string_position(x[0]), + x[1]) for x in exp] + all_occurrences = list(itertools.chain.from_iterable( + [itertools.product([x[0]], x[1], [x[2]]) for x in exp])) + all_occurrences = [(x[0], int(x[1]), x[2]) for x in all_occurrences] + ret = ''' + %s.show_raw_text(%s, %d, %s, %s, %s); + ''' % (exp_object_name, json.dumps(all_occurrences), label, + json.dumps(text), div_name, json.dumps(opacity)) + return ret + + +class IndexedString(object): + """String with various indexes.""" + + def __init__(self, raw_string, split_expression=r'\W+', bow=True, + mask_string=None): + """Initializer. + + Args: + raw_string: string with raw text in it + split_expression: Regex string or callable. If regex string, will be used with re.split. + If callable, the function should return a list of tokens. + bow: if True, a word is the same everywhere in the text - i.e. we + will index multiple occurrences of the same word. If False, + order matters, so that the same word will have different ids + according to position. + mask_string: If not None, replace words with this if bow=False + if None, default value is UNKWORDZ + """ + self.raw = raw_string + self.mask_string = 'UNKWORDZ' if mask_string is None else mask_string + + if callable(split_expression): + tokens = split_expression(self.raw) + self.as_list = self._segment_with_tokens(self.raw, tokens) + tokens = set(tokens) + + def non_word(string): + return string not in tokens + + else: + # with the split_expression as a non-capturing group (?:), we don't need to filter out + # the separator character from the split results. 
+ splitter = re.compile(r'(%s)|$' % split_expression) + self.as_list = [s for s in splitter.split(self.raw) if s] + non_word = splitter.match + + self.as_np = np.array(self.as_list) + self.string_start = np.hstack( + ([0], np.cumsum([len(x) for x in self.as_np[:-1]]))) + vocab = {} + self.inverse_vocab = [] + self.positions = [] + self.bow = bow + non_vocab = set() + for i, word in enumerate(self.as_np): + if word in non_vocab: + continue + if non_word(word): + non_vocab.add(word) + continue + if bow: + if word not in vocab: + vocab[word] = len(vocab) + self.inverse_vocab.append(word) + self.positions.append([]) + idx_word = vocab[word] + self.positions[idx_word].append(i) + else: + self.inverse_vocab.append(word) + self.positions.append(i) + if not bow: + self.positions = np.array(self.positions) + + def raw_string(self): + """Returns the original raw string""" + return self.raw + + def num_words(self): + """Returns the number of tokens in the vocabulary for this document.""" + return len(self.inverse_vocab) + + def word(self, id_): + """Returns the word that corresponds to id_ (int)""" + return self.inverse_vocab[id_] + + def string_position(self, id_): + """Returns a np array with indices to id_ (int) occurrences""" + if self.bow: + return self.string_start[self.positions[id_]] + else: + return self.string_start[[self.positions[id_]]] + + def inverse_removing(self, words_to_remove): + """Returns a string after removing the appropriate words. + + If self.bow is false, replaces word with UNKWORDZ instead of removing + it. + + Args: + words_to_remove: list of ids (ints) to remove + + Returns: + original raw string with appropriate words removed. + """ + mask = np.ones(self.as_np.shape[0], dtype='bool') + mask[self.__get_idxs(words_to_remove)] = False + if not self.bow: + return ''.join( + [self.as_list[i] if mask[i] else self.mask_string + for i in range(mask.shape[0])]) + return ''.join([self.as_list[v] for v in mask.nonzero()[0]]) + + @staticmethod + def _segment_with_tokens(text, tokens): + """Segment a string around the tokens created by a passed-in tokenizer""" + list_form = [] + text_ptr = 0 + for token in tokens: + inter_token_string = [] + while not text[text_ptr:].startswith(token): + inter_token_string.append(text[text_ptr]) + text_ptr += 1 + if text_ptr >= len(text): + raise ValueError("Tokenization produced tokens that do not belong in string!") + text_ptr += len(token) + if inter_token_string: + list_form.append(''.join(inter_token_string)) + list_form.append(token) + if text_ptr < len(text): + list_form.append(text[text_ptr:]) + return list_form + + def __get_idxs(self, words): + """Returns indexes to appropriate words.""" + if self.bow: + return list(itertools.chain.from_iterable( + [self.positions[z] for z in words])) + else: + return self.positions[words] + + +class IndexedCharacters(object): + """String with various indexes.""" + + def __init__(self, raw_string, bow=True, mask_string=None): + """Initializer. + + Args: + raw_string: string with raw text in it + bow: if True, a char is the same everywhere in the text - i.e. we + will index multiple occurrences of the same character. If False, + order matters, so that the same word will have different ids + according to position. 
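# --- Editorial sketch, not part of the patch: how IndexedString behaves in the
# bow=True case, using the same example string as lime/tests/test_lime_text.py
# further down in this patch. Expected outputs are noted in the comments.
from lime.lime_text import IndexedString

s = 'Please, take your time. Please'
indexed = IndexedString(s)                 # bow=True by default
print(indexed.num_words())                 # 4 -> Please, take, your, time
print(indexed.word(0))                     # 'Please'
print(indexed.string_position(0))          # both occurrences: [ 0 24]
print(indexed.inverse_removing([0]))       # ', take your time. '

# With bow=False every token position gets its own id, and removed tokens are
# replaced by the mask string ('UNKWORDZ') instead of being dropped.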
+ mask_string: If not None, replace characters with this if bow=False + if None, default value is chr(0) + """ + self.raw = raw_string + self.as_list = list(self.raw) + self.as_np = np.array(self.as_list) + self.mask_string = chr(0) if mask_string is None else mask_string + self.string_start = np.arange(len(self.raw)) + vocab = {} + self.inverse_vocab = [] + self.positions = [] + self.bow = bow + non_vocab = set() + for i, char in enumerate(self.as_np): + if char in non_vocab: + continue + if bow: + if char not in vocab: + vocab[char] = len(vocab) + self.inverse_vocab.append(char) + self.positions.append([]) + idx_char = vocab[char] + self.positions[idx_char].append(i) + else: + self.inverse_vocab.append(char) + self.positions.append(i) + if not bow: + self.positions = np.array(self.positions) + + def raw_string(self): + """Returns the original raw string""" + return self.raw + + def num_words(self): + """Returns the number of tokens in the vocabulary for this document.""" + return len(self.inverse_vocab) + + def word(self, id_): + """Returns the word that corresponds to id_ (int)""" + return self.inverse_vocab[id_] + + def string_position(self, id_): + """Returns a np array with indices to id_ (int) occurrences""" + if self.bow: + return self.string_start[self.positions[id_]] + else: + return self.string_start[[self.positions[id_]]] + + def inverse_removing(self, words_to_remove): + """Returns a string after removing the appropriate words. + + If self.bow is false, replaces word with UNKWORDZ instead of removing + it. + + Args: + words_to_remove: list of ids (ints) to remove + + Returns: + original raw string with appropriate words removed. + """ + mask = np.ones(self.as_np.shape[0], dtype='bool') + mask[self.__get_idxs(words_to_remove)] = False + if not self.bow: + return ''.join( + [self.as_list[i] if mask[i] else self.mask_string + for i in range(mask.shape[0])]) + return ''.join([self.as_list[v] for v in mask.nonzero()[0]]) + + def __get_idxs(self, words): + """Returns indexes to appropriate words.""" + if self.bow: + return list(itertools.chain.from_iterable( + [self.positions[z] for z in words])) + else: + return self.positions[words] + + +class LimeTextExplainer(object): + """Explains text classifiers. + Currently, we are using an exponential kernel on cosine distance, and + restricting explanations to words that are present in documents.""" + + def __init__(self, + kernel_width=25, + kernel=None, + verbose=False, + class_names=None, + feature_selection='auto', + split_expression=r'\W+', + bow=True, + mask_string=None, + random_state=None, + char_level=False): + """Init function. + + Args: + kernel_width: kernel width for the exponential kernel. + kernel: similarity kernel that takes euclidean distances and kernel + width as input and outputs weights in (0,1). If None, defaults to + an exponential kernel. + verbose: if true, print local prediction values from linear model + class_names: list of class names, ordered according to whatever the + classifier is using. If not present, class names will be '0', + '1', ... + feature_selection: feature selection method. can be + 'forward_selection', 'lasso_path', 'none' or 'auto'. + See function 'explain_instance_with_data' in lime_base.py for + details on what each of the options does. + split_expression: Regex string or callable. If regex string, will be used with re.split. + If callable, the function should return a list of tokens. 
+ bow: if True (bag of words), will perturb input data by removing + all occurrences of individual words or characters. + Explanations will be in terms of these words. Otherwise, will + explain in terms of word-positions, so that a word may be + important the first time it appears and unimportant the second. + Only set to false if the classifier uses word order in some way + (bigrams, etc), or if you set char_level=True. + mask_string: String used to mask tokens or characters if bow=False + if None, will be 'UNKWORDZ' if char_level=False, chr(0) + otherwise. + random_state: an integer or numpy.RandomState that will be used to + generate random numbers. If None, the random state will be + initialized using the internal numpy seed. + char_level: an boolean identifying that we treat each character + as an independent occurence in the string + """ + + if kernel is None: + def kernel(d, kernel_width): + return np.sqrt(np.exp(-(d ** 2) / kernel_width ** 2)) + + kernel_fn = partial(kernel, kernel_width=kernel_width) + + self.random_state = check_random_state(random_state) + self.base = lime_base.LimeBase(kernel_fn, verbose, + random_state=self.random_state) + self.class_names = class_names + self.vocabulary = None + self.feature_selection = feature_selection + self.bow = bow + self.mask_string = mask_string + self.split_expression = split_expression + self.char_level = char_level + + def explain_instance(self, + text_instance, + classifier_fn, + labels=(1,), + top_labels=None, + num_features=10, + num_samples=5000, + distance_metric='cosine', + model_regressor=None): + """Generates explanations for a prediction. + + First, we generate neighborhood data by randomly hiding features from + the instance (see __data_labels_distance_mapping). We then learn + locally weighted linear models on this neighborhood data to explain + each of the classes in an interpretable way (see lime_base.py). + + Args: + text_instance: raw text string to be explained. + classifier_fn: classifier prediction probability function, which + takes a list of d strings and outputs a (d, k) numpy array with + prediction probabilities, where k is the number of classes. + For ScikitClassifiers , this is classifier.predict_proba. + labels: iterable with labels to be explained. + top_labels: if not None, ignore labels and produce explanations for + the K labels with highest prediction probabilities, where K is + this parameter. + num_features: maximum number of features present in explanation + num_samples: size of the neighborhood to learn the linear model + distance_metric: the distance metric to use for sample weighting, + defaults to cosine similarity + model_regressor: sklearn regressor to use in explanation. Defaults + to Ridge regression in LimeBase. Must have model_regressor.coef_ + and 'sample_weight' as a parameter to model_regressor.fit() + Returns: + An Explanation object (see explanation.py) with the corresponding + explanations. 
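# --- Editorial sketch, not part of the patch: the default similarity kernel
# defined in LimeTextExplainer.__init__ above, evaluated on a few cosine
# distances (already scaled by 100, as __data_labels_distances does).
import numpy as np

def exponential_kernel(d, kernel_width=25):
    return np.sqrt(np.exp(-(d ** 2) / kernel_width ** 2))

distances = np.array([0.0, 10.0, 25.0, 50.0, 100.0])
print(exponential_kernel(distances))
# weights shrink from 1.0 towards 0 as the perturbed text moves away from the original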
+ """ + + indexed_string = (IndexedCharacters( + text_instance, bow=self.bow, mask_string=self.mask_string) + if self.char_level else + IndexedString(text_instance, bow=self.bow, + split_expression=self.split_expression, + mask_string=self.mask_string)) + domain_mapper = TextDomainMapper(indexed_string) + data, yss, distances = self.__data_labels_distances( + indexed_string, classifier_fn, num_samples, + distance_metric=distance_metric) + if self.class_names is None: + self.class_names = [str(x) for x in range(yss[0].shape[0])] + ret_exp = explanation.Explanation(domain_mapper=domain_mapper, + class_names=self.class_names, + random_state=self.random_state) + ret_exp.predict_proba = yss[0] + if top_labels: + labels = np.argsort(yss[0])[-top_labels:] + ret_exp.top_labels = list(labels) + ret_exp.top_labels.reverse() + for label in labels: + (ret_exp.intercept[label], + ret_exp.local_exp[label], + ret_exp.score, ret_exp.local_pred) = self.base.explain_instance_with_data( + data, yss, distances, label, num_features, + model_regressor=model_regressor, + feature_selection=self.feature_selection) + return ret_exp + + def __data_labels_distances(self, + indexed_string, + classifier_fn, + num_samples, + distance_metric='cosine'): + """Generates a neighborhood around a prediction. + + Generates neighborhood data by randomly removing words from + the instance, and predicting with the classifier. Uses cosine distance + to compute distances between original and perturbed instances. + Args: + indexed_string: document (IndexedString) to be explained, + classifier_fn: classifier prediction probability function, which + takes a string and outputs prediction probabilities. For + ScikitClassifier, this is classifier.predict_proba. + num_samples: size of the neighborhood to learn the linear model + distance_metric: the distance metric to use for sample weighting, + defaults to cosine similarity. + + + Returns: + A tuple (data, labels, distances), where: + data: dense num_samples * K binary matrix, where K is the + number of tokens in indexed_string. The first row is the + original instance, and thus a row of ones. + labels: num_samples * L matrix, where L is the number of target + labels + distances: cosine distance between the original instance and + each perturbed instance (computed in the binary 'data' + matrix), times 100. 
+ """ + + def distance_fn(x): + return sklearn.metrics.pairwise.pairwise_distances( + x, x[0], metric=distance_metric).ravel() * 100 + + doc_size = indexed_string.num_words() + sample = self.random_state.randint(1, doc_size + 1, num_samples - 1) + data = np.ones((num_samples, doc_size)) + data[0] = np.ones(doc_size) + features_range = range(doc_size) + inverse_data = [indexed_string.raw_string()] + for i, size in enumerate(sample, start=1): + inactive = self.random_state.choice(features_range, size, + replace=False) + data[i, inactive] = 0 + inverse_data.append(indexed_string.inverse_removing(inactive)) + labels = classifier_fn(inverse_data) + distances = distance_fn(sp.sparse.csr_matrix(data)) + return data, labels, distances diff --git a/lime/submodular_pick.py b/lime/submodular_pick.py new file mode 100644 index 0000000000000000000000000000000000000000..b8ae62fd0472d4df6481be5a660600872e50746a --- /dev/null +++ b/lime/submodular_pick.py @@ -0,0 +1,128 @@ +import numpy as np +import warnings + + +class SubmodularPick(object): + """Class for submodular pick + + Saves a representative sample of explanation objects using SP-LIME, + as well as saving all generated explanations + + First, a collection of candidate explanations are generated + (see explain_instance). From these candidates, num_exps_desired are + chosen using submodular pick. (see marcotcr et al paper).""" + + def __init__(self, + explainer, + data, + predict_fn, + method='sample', + sample_size=1000, + num_exps_desired=5, + num_features=10, + **kwargs): + + """ + Args: + data: a numpy array where each row is a single input into predict_fn + predict_fn: prediction function. For classifiers, this should be a + function that takes a numpy array and outputs prediction + probabilities. For regressors, this takes a numpy array and + returns the predictions. For ScikitClassifiers, this is + `classifier.predict_proba()`. For ScikitRegressors, this + is `regressor.predict()`. The prediction function needs to work + on multiple feature vectors (the vectors randomly perturbed + from the data_row). + method: The method to use to generate candidate explanations + method == 'sample' will sample the data uniformly at + random. The sample size is given by sample_size. Otherwise + if method == 'full' then explanations will be generated for the + entire data. l + sample_size: The number of instances to explain if method == 'sample' + num_exps_desired: The number of explanation objects returned + num_features: maximum number of features present in explanation + + + Sets value: + sp_explanations: A list of explanation objects that has a high coverage + explanations: All the candidate explanations saved for potential future use. + """ + + top_labels = kwargs.get('top_labels', 1) + if 'top_labels' in kwargs: + del kwargs['top_labels'] + # Parse args + if method == 'sample': + if sample_size > len(data): + warnings.warn("""Requested sample size larger than + size of input data. 
Using all data""") + sample_size = len(data) + all_indices = np.arange(len(data)) + np.random.shuffle(all_indices) + sample_indices = all_indices[:sample_size] + elif method == 'full': + sample_indices = np.arange(len(data)) + else: + raise ValueError('Method must be \'sample\' or \'full\'') + + # Generate Explanations + self.explanations = [] + for i in sample_indices: + self.explanations.append( + explainer.explain_instance( + data[i], predict_fn, num_features=num_features, + top_labels=top_labels, + **kwargs)) + # Error handling + try: + num_exps_desired = int(num_exps_desired) + except TypeError: + return("Requested number of explanations should be an integer") + if num_exps_desired > len(self.explanations): + warnings.warn("""Requested number of explanations larger than + total number of explanations, returning all + explanations instead.""") + num_exps_desired = min(num_exps_desired, len(self.explanations)) + + # Find all the explanation model features used. Defines the dimension d' + features_dict = {} + feature_iter = 0 + for exp in self.explanations: + labels = exp.available_labels() if exp.mode == 'classification' else [1] + for label in labels: + for feature, _ in exp.as_list(label=label): + if feature not in features_dict.keys(): + features_dict[feature] = (feature_iter) + feature_iter += 1 + d_prime = len(features_dict.keys()) + + # Create the n x d' dimensional 'explanation matrix', W + W = np.zeros((len(self.explanations), d_prime)) + for i, exp in enumerate(self.explanations): + labels = exp.available_labels() if exp.mode == 'classification' else [1] + for label in labels: + for feature, value in exp.as_list(label): + W[i, features_dict[feature]] += value + + # Create the global importance vector, I_j described in the paper + importance = np.sum(abs(W), axis=0)**.5 + + # Now run the SP-LIME greedy algorithm + remaining_indices = set(range(len(self.explanations))) + V = [] + for _ in range(num_exps_desired): + best = 0 + best_ind = None + current = 0 + for i in remaining_indices: + current = np.dot( + (np.sum(abs(W)[V + [i]], axis=0) > 0), importance + ) # coverage function + if current >= best: + best = current + best_ind = i + V.append(best_ind) + remaining_indices -= {best_ind} + + self.sp_explanations = [self.explanations[i] for i in V] + self.V = V diff --git a/lime/tests/__init__.py b/lime/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lime/tests/test_discretize.py b/lime/tests/test_discretize.py new file mode 100644 index 0000000000000000000000000000000000000000..815f588b180b6f502b853e4822862663d38093b6 --- /dev/null +++ b/lime/tests/test_discretize.py @@ -0,0 +1,177 @@ +import unittest +from unittest import TestCase + +import numpy as np + +from sklearn.datasets import load_iris + +from lime.discretize import QuartileDiscretizer, DecileDiscretizer, EntropyDiscretizer + + +class TestDiscretize(TestCase): + + def setUp(self): + iris = load_iris() + + self.feature_names = iris.feature_names + self.x = iris.data + self.y = iris.target + + def check_random_state_for_discretizer_class(self, DiscretizerClass): + # ---------------------------------------------------------------------- + # -----------Check if the same random_state produces the same----------- + # -------------results for different discretizer instances.------------- + # ---------------------------------------------------------------------- + discretizer = DiscretizerClass(self.x, [], self.feature_names, self.y, + 
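# --- Editorial sketch, not part of the patch: the SP-LIME greedy coverage step
# implemented above, run on a tiny hand-made explanation matrix W
# (rows = explanations, columns = explanation-model features).
import numpy as np

W = np.array([[0.9, 0.0, 0.1],
              [0.0, 0.8, 0.0],
              [0.5, 0.4, 0.0]])
importance = np.sum(np.abs(W), axis=0) ** 0.5     # global importance I_j

remaining, V = set(range(len(W))), []
for _ in range(2):                                # pick 2 representatives
    best, best_ind = 0, None
    for i in remaining:
        covered = np.sum(np.abs(W)[V + [i]], axis=0) > 0
        score = np.dot(covered, importance)       # coverage function
        if score >= best:
            best, best_ind = score, i
    V.append(best_ind)
    remaining -= {best_ind}
# V now indexes the explanations that jointly cover the most important features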
random_state=10) + x_1 = discretizer.undiscretize(discretizer.discretize(self.x)) + + discretizer = DiscretizerClass(self.x, [], self.feature_names, self.y, + random_state=10) + x_2 = discretizer.undiscretize(discretizer.discretize(self.x)) + + self.assertEqual((x_1 == x_2).sum(), x_1.shape[0] * x_1.shape[1]) + + discretizer = DiscretizerClass(self.x, [], self.feature_names, self.y, + random_state=np.random.RandomState(10)) + x_1 = discretizer.undiscretize(discretizer.discretize(self.x)) + + discretizer = DiscretizerClass(self.x, [], self.feature_names, self.y, + random_state=np.random.RandomState(10)) + x_2 = discretizer.undiscretize(discretizer.discretize(self.x)) + + self.assertEqual((x_1 == x_2).sum(), x_1.shape[0] * x_1.shape[1]) + + # ---------------------------------------------------------------------- + # ---------Check if two different random_state values produces---------- + # -------different results for different discretizers instances.-------- + # ---------------------------------------------------------------------- + discretizer = DiscretizerClass(self.x, [], self.feature_names, self.y, + random_state=10) + x_1 = discretizer.undiscretize(discretizer.discretize(self.x)) + + discretizer = DiscretizerClass(self.x, [], self.feature_names, self.y, + random_state=20) + x_2 = discretizer.undiscretize(discretizer.discretize(self.x)) + + self.assertFalse((x_1 == x_2).sum() == x_1.shape[0] * x_1.shape[1]) + + discretizer = DiscretizerClass(self.x, [], self.feature_names, self.y, + random_state=np.random.RandomState(10)) + x_1 = discretizer.undiscretize(discretizer.discretize(self.x)) + + discretizer = DiscretizerClass(self.x, [], self.feature_names, self.y, + random_state=np.random.RandomState(20)) + x_2 = discretizer.undiscretize(discretizer.discretize(self.x)) + + self.assertFalse((x_1 == x_2).sum() == x_1.shape[0] * x_1.shape[1]) + + def test_random_state(self): + self.check_random_state_for_discretizer_class(QuartileDiscretizer) + + self.check_random_state_for_discretizer_class(DecileDiscretizer) + + self.check_random_state_for_discretizer_class(EntropyDiscretizer) + + def test_feature_names_1(self): + self.maxDiff = None + discretizer = QuartileDiscretizer(self.x, [], self.feature_names, + self.y, random_state=10) + self.assertDictEqual( + {0: ['sepal length (cm) <= 5.10', + '5.10 < sepal length (cm) <= 5.80', + '5.80 < sepal length (cm) <= 6.40', + 'sepal length (cm) > 6.40'], + 1: ['sepal width (cm) <= 2.80', + '2.80 < sepal width (cm) <= 3.00', + '3.00 < sepal width (cm) <= 3.30', + 'sepal width (cm) > 3.30'], + 2: ['petal length (cm) <= 1.60', + '1.60 < petal length (cm) <= 4.35', + '4.35 < petal length (cm) <= 5.10', + 'petal length (cm) > 5.10'], + 3: ['petal width (cm) <= 0.30', + '0.30 < petal width (cm) <= 1.30', + '1.30 < petal width (cm) <= 1.80', + 'petal width (cm) > 1.80']}, + discretizer.names) + + def test_feature_names_2(self): + self.maxDiff = None + discretizer = DecileDiscretizer(self.x, [], self.feature_names, self.y, + random_state=10) + self.assertDictEqual( + {0: ['sepal length (cm) <= 4.80', + '4.80 < sepal length (cm) <= 5.00', + '5.00 < sepal length (cm) <= 5.27', + '5.27 < sepal length (cm) <= 5.60', + '5.60 < sepal length (cm) <= 5.80', + '5.80 < sepal length (cm) <= 6.10', + '6.10 < sepal length (cm) <= 6.30', + '6.30 < sepal length (cm) <= 6.52', + '6.52 < sepal length (cm) <= 6.90', + 'sepal length (cm) > 6.90'], + 1: ['sepal width (cm) <= 2.50', + '2.50 < sepal width (cm) <= 2.70', + '2.70 < sepal width (cm) <= 2.80', + '2.80 < sepal width (cm) 
<= 3.00', + '3.00 < sepal width (cm) <= 3.10', + '3.10 < sepal width (cm) <= 3.20', + '3.20 < sepal width (cm) <= 3.40', + '3.40 < sepal width (cm) <= 3.61', + 'sepal width (cm) > 3.61'], + 2: ['petal length (cm) <= 1.40', + '1.40 < petal length (cm) <= 1.50', + '1.50 < petal length (cm) <= 1.70', + '1.70 < petal length (cm) <= 3.90', + '3.90 < petal length (cm) <= 4.35', + '4.35 < petal length (cm) <= 4.64', + '4.64 < petal length (cm) <= 5.00', + '5.00 < petal length (cm) <= 5.32', + '5.32 < petal length (cm) <= 5.80', + 'petal length (cm) > 5.80'], + 3: ['petal width (cm) <= 0.20', + '0.20 < petal width (cm) <= 0.40', + '0.40 < petal width (cm) <= 1.16', + '1.16 < petal width (cm) <= 1.30', + '1.30 < petal width (cm) <= 1.50', + '1.50 < petal width (cm) <= 1.80', + '1.80 < petal width (cm) <= 1.90', + '1.90 < petal width (cm) <= 2.20', + 'petal width (cm) > 2.20']}, + discretizer.names) + + def test_feature_names_3(self): + self.maxDiff = None + discretizer = EntropyDiscretizer(self.x, [], self.feature_names, + self.y, random_state=10) + self.assertDictEqual( + {0: ['sepal length (cm) <= 4.85', + '4.85 < sepal length (cm) <= 5.45', + '5.45 < sepal length (cm) <= 5.55', + '5.55 < sepal length (cm) <= 5.85', + '5.85 < sepal length (cm) <= 6.15', + '6.15 < sepal length (cm) <= 7.05', + 'sepal length (cm) > 7.05'], + 1: ['sepal width (cm) <= 2.45', + '2.45 < sepal width (cm) <= 2.95', + '2.95 < sepal width (cm) <= 3.05', + '3.05 < sepal width (cm) <= 3.35', + '3.35 < sepal width (cm) <= 3.45', + '3.45 < sepal width (cm) <= 3.55', + 'sepal width (cm) > 3.55'], + 2: ['petal length (cm) <= 2.45', + '2.45 < petal length (cm) <= 4.45', + '4.45 < petal length (cm) <= 4.75', + '4.75 < petal length (cm) <= 5.15', + 'petal length (cm) > 5.15'], + 3: ['petal width (cm) <= 0.80', + '0.80 < petal width (cm) <= 1.35', + '1.35 < petal width (cm) <= 1.75', + '1.75 < petal width (cm) <= 1.85', + 'petal width (cm) > 1.85']}, + discretizer.names) + + +if __name__ == '__main__': + unittest.main() diff --git a/lime/tests/test_generic_utils.py b/lime/tests/test_generic_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..02380ad7d9220225aecae128f96c59971f41c04f --- /dev/null +++ b/lime/tests/test_generic_utils.py @@ -0,0 +1,93 @@ +import unittest +import sys +from lime.utils.generic_utils import has_arg + + +class TestGenericUtils(unittest.TestCase): + + def test_has_arg(self): + # fn is callable / is not callable + + class FooNotCallable: + + def __init__(self, word): + self.message = word + + class FooCallable: + + def __init__(self, word): + self.message = word + + def __call__(self, message): + return message + + def positional_argument_call(self, arg1): + return self.message + + def multiple_positional_arguments_call(self, *args): + res = [] + for a in args: + res.append(a) + return res + + def keyword_argument_call(self, filter_=True): + res = self.message + if filter_: + res = 'KO' + return res + + def multiple_keyword_arguments_call(self, arg1='1', arg2='2'): + return self.message + arg1 + arg2 + + def undefined_keyword_arguments_call(self, **kwargs): + res = self.message + for a in kwargs: + res = res + a + return a + + foo_callable = FooCallable('OK') + self.assertTrue(has_arg(foo_callable, 'message')) + + if sys.version_info < (3,): + foo_not_callable = FooNotCallable('KO') + self.assertFalse(has_arg(foo_not_callable, 'message')) + elif sys.version_info < (3, 6): + with self.assertRaises(TypeError): + foo_not_callable = FooNotCallable('KO') + has_arg(foo_not_callable, 
'message') + + # Python 2, argument in / not in valid arguments / keyword arguments + if sys.version_info < (3,): + self.assertFalse(has_arg(foo_callable, 'invalid_arg')) + self.assertTrue(has_arg(foo_callable.positional_argument_call, 'arg1')) + self.assertFalse(has_arg(foo_callable.multiple_positional_arguments_call, 'argX')) + self.assertFalse(has_arg(foo_callable.keyword_argument_call, 'argX')) + self.assertTrue(has_arg(foo_callable.keyword_argument_call, 'filter_')) + self.assertTrue(has_arg(foo_callable.multiple_keyword_arguments_call, 'arg2')) + self.assertFalse(has_arg(foo_callable.multiple_keyword_arguments_call, 'arg3')) + self.assertFalse(has_arg(foo_callable.undefined_keyword_arguments_call, 'argX')) + # Python 3, argument in / not in valid arguments / keyword arguments + elif sys.version_info < (3, 6): + self.assertFalse(has_arg(foo_callable, 'invalid_arg')) + self.assertTrue(has_arg(foo_callable.positional_argument_call, 'arg1')) + self.assertFalse(has_arg(foo_callable.multiple_positional_arguments_call, 'argX')) + self.assertFalse(has_arg(foo_callable.keyword_argument_call, 'argX')) + self.assertTrue(has_arg(foo_callable.keyword_argument_call, 'filter_')) + self.assertTrue(has_arg(foo_callable.multiple_keyword_arguments_call, 'arg2')) + self.assertFalse(has_arg(foo_callable.multiple_keyword_arguments_call, 'arg3')) + self.assertFalse(has_arg(foo_callable.undefined_keyword_arguments_call, 'argX')) + else: + self.assertFalse(has_arg(foo_callable, 'invalid_arg')) + self.assertTrue(has_arg(foo_callable.positional_argument_call, 'arg1')) + self.assertFalse(has_arg(foo_callable.multiple_positional_arguments_call, 'argX')) + self.assertFalse(has_arg(foo_callable.keyword_argument_call, 'argX')) + self.assertTrue(has_arg(foo_callable.keyword_argument_call, 'filter_')) + self.assertTrue(has_arg(foo_callable.multiple_keyword_arguments_call, 'arg2')) + self.assertFalse(has_arg(foo_callable.multiple_keyword_arguments_call, 'arg3')) + self.assertFalse(has_arg(foo_callable.undefined_keyword_arguments_call, 'argX')) + # argname is None + self.assertFalse(has_arg(foo_callable, None)) + + +if __name__ == '__main__': + unittest.main() diff --git a/lime/tests/test_lime_tabular.py b/lime/tests/test_lime_tabular.py new file mode 100644 index 0000000000000000000000000000000000000000..426079b48b8a54f4376e989b8c1d463bb0636eaa --- /dev/null +++ b/lime/tests/test_lime_tabular.py @@ -0,0 +1,651 @@ +import unittest + +import numpy as np +import collections +import sklearn # noqa +import sklearn.datasets +import sklearn.ensemble +import sklearn.linear_model # noqa +from numpy.testing import assert_array_equal +from sklearn.datasets import load_iris, make_classification, make_multilabel_classification +from sklearn.ensemble import RandomForestClassifier +from sklearn.linear_model import LinearRegression +from lime.discretize import QuartileDiscretizer, DecileDiscretizer, EntropyDiscretizer + + +try: + from sklearn.model_selection import train_test_split +except ImportError: + # Deprecated in scikit-learn version 0.18, removed in 0.20 + from sklearn.cross_validation import train_test_split + +from lime.lime_tabular import LimeTabularExplainer + + +class TestLimeTabular(unittest.TestCase): + + def setUp(self): + iris = load_iris() + + self.feature_names = iris.feature_names + self.target_names = iris.target_names + + (self.train, + self.test, + self.labels_train, + self.labels_test) = train_test_split(iris.data, iris.target, train_size=0.80) + + def test_lime_explainer_good_regressor(self): + 
np.random.seed(1) + rf = RandomForestClassifier(n_estimators=500) + rf.fit(self.train, self.labels_train) + i = np.random.randint(0, self.test.shape[0]) + + explainer = LimeTabularExplainer(self.train, + mode="classification", + feature_names=self.feature_names, + class_names=self.target_names, + discretize_continuous=True) + + exp = explainer.explain_instance(self.test[i], + rf.predict_proba, + num_features=2, + model_regressor=LinearRegression()) + + self.assertIsNotNone(exp) + keys = [x[0] for x in exp.as_list()] + self.assertEqual(1, + sum([1 if 'petal width' in x else 0 for x in keys]), + "Petal Width is a major feature") + self.assertEqual(1, + sum([1 if 'petal length' in x else 0 for x in keys]), + "Petal Length is a major feature") + + def test_lime_explainer_good_regressor_synthetic_data(self): + X, y = make_classification(n_samples=1000, + n_features=20, + n_informative=2, + n_redundant=2, + random_state=10) + + rf = RandomForestClassifier(n_estimators=500) + rf.fit(X, y) + instance = np.random.randint(0, X.shape[0]) + feature_names = ["feature" + str(i) for i in range(20)] + explainer = LimeTabularExplainer(X, + feature_names=feature_names, + discretize_continuous=True) + + exp = explainer.explain_instance(X[instance], rf.predict_proba) + + self.assertIsNotNone(exp) + self.assertEqual(10, len(exp.as_list())) + + def test_lime_explainer_sparse_synthetic_data(self): + n_features = 20 + X, y = make_multilabel_classification(n_samples=100, + sparse=True, + n_features=n_features, + n_classes=1, + n_labels=2) + rf = RandomForestClassifier(n_estimators=500) + rf.fit(X, y) + instance = np.random.randint(0, X.shape[0]) + feature_names = ["feature" + str(i) for i in range(n_features)] + explainer = LimeTabularExplainer(X, + feature_names=feature_names, + discretize_continuous=True) + + exp = explainer.explain_instance(X[instance], rf.predict_proba) + + self.assertIsNotNone(exp) + self.assertEqual(10, len(exp.as_list())) + + def test_lime_explainer_no_regressor(self): + np.random.seed(1) + + rf = RandomForestClassifier(n_estimators=500) + rf.fit(self.train, self.labels_train) + i = np.random.randint(0, self.test.shape[0]) + + explainer = LimeTabularExplainer(self.train, + feature_names=self.feature_names, + class_names=self.target_names, + discretize_continuous=True) + + exp = explainer.explain_instance(self.test[i], + rf.predict_proba, + num_features=2) + self.assertIsNotNone(exp) + keys = [x[0] for x in exp.as_list()] + self.assertEqual(1, + sum([1 if 'petal width' in x else 0 for x in keys]), + "Petal Width is a major feature") + self.assertEqual(1, + sum([1 if 'petal length' in x else 0 for x in keys]), + "Petal Length is a major feature") + + def test_lime_explainer_entropy_discretizer(self): + np.random.seed(1) + + rf = RandomForestClassifier(n_estimators=500) + rf.fit(self.train, self.labels_train) + i = np.random.randint(0, self.test.shape[0]) + + explainer = LimeTabularExplainer(self.train, + feature_names=self.feature_names, + class_names=self.target_names, + training_labels=self.labels_train, + discretize_continuous=True, + discretizer='entropy') + + exp = explainer.explain_instance(self.test[i], + rf.predict_proba, + num_features=2) + self.assertIsNotNone(exp) + keys = [x[0] for x in exp.as_list()] + print(keys) + self.assertEqual(1, + sum([1 if 'petal width' in x else 0 for x in keys]), + "Petal Width is a major feature") + self.assertEqual(1, + sum([1 if 'petal length' in x else 0 for x in keys]), + "Petal Length is a major feature") + + def 
test_lime_tabular_explainer_equal_random_state(self): + X, y = make_classification(n_samples=1000, + n_features=20, + n_informative=2, + n_redundant=2, + random_state=10) + + rf = RandomForestClassifier(n_estimators=500, random_state=10) + rf.fit(X, y) + instance = np.random.RandomState(10).randint(0, X.shape[0]) + feature_names = ["feature" + str(i) for i in range(20)] + + # ---------------------------------------------------------------------- + # -------------------------Quartile Discretizer------------------------- + # ---------------------------------------------------------------------- + discretizer = QuartileDiscretizer(X, [], feature_names, y, + random_state=10) + explainer_1 = LimeTabularExplainer(X, + feature_names=feature_names, + discretize_continuous=True, + discretizer=discretizer, + random_state=10) + exp_1 = explainer_1.explain_instance(X[instance], rf.predict_proba, + num_samples=500) + + discretizer = QuartileDiscretizer(X, [], feature_names, y, + random_state=10) + explainer_2 = LimeTabularExplainer(X, + feature_names=feature_names, + discretize_continuous=True, + discretizer=discretizer, + random_state=10) + exp_2 = explainer_2.explain_instance(X[instance], rf.predict_proba, + num_samples=500) + + self.assertDictEqual(exp_1.as_map(), exp_2.as_map()) + + # ---------------------------------------------------------------------- + # --------------------------Decile Discretizer-------------------------- + # ---------------------------------------------------------------------- + discretizer = DecileDiscretizer(X, [], feature_names, y, + random_state=10) + explainer_1 = LimeTabularExplainer(X, + feature_names=feature_names, + discretize_continuous=True, + discretizer=discretizer, + random_state=10) + exp_1 = explainer_1.explain_instance(X[instance], rf.predict_proba, + num_samples=500) + + discretizer = DecileDiscretizer(X, [], feature_names, y, + random_state=10) + explainer_2 = LimeTabularExplainer(X, + feature_names=feature_names, + discretize_continuous=True, + discretizer=discretizer, + random_state=10) + exp_2 = explainer_2.explain_instance(X[instance], rf.predict_proba, + num_samples=500) + + self.assertDictEqual(exp_1.as_map(), exp_2.as_map()) + + # ---------------------------------------------------------------------- + # -------------------------Entropy Discretizer-------------------------- + # ---------------------------------------------------------------------- + discretizer = EntropyDiscretizer(X, [], feature_names, y, + random_state=10) + explainer_1 = LimeTabularExplainer(X, + feature_names=feature_names, + discretize_continuous=True, + discretizer=discretizer, + random_state=10) + exp_1 = explainer_1.explain_instance(X[instance], rf.predict_proba, + num_samples=500) + + discretizer = EntropyDiscretizer(X, [], feature_names, y, + random_state=10) + explainer_2 = LimeTabularExplainer(X, + feature_names=feature_names, + discretize_continuous=True, + discretizer=discretizer, + random_state=10) + exp_2 = explainer_2.explain_instance(X[instance], rf.predict_proba, + num_samples=500) + + self.assertDictEqual(exp_1.as_map(), exp_2.as_map()) + + def test_lime_tabular_explainer_not_equal_random_state(self): + X, y = make_classification(n_samples=1000, + n_features=20, + n_informative=2, + n_redundant=2, + random_state=10) + + rf = RandomForestClassifier(n_estimators=500, random_state=10) + rf.fit(X, y) + instance = np.random.RandomState(10).randint(0, X.shape[0]) + feature_names = ["feature" + str(i) for i in range(20)] + + # 
---------------------------------------------------------------------- + # -------------------------Quartile Discretizer------------------------- + # ---------------------------------------------------------------------- + + # ---------------------------------[1]---------------------------------- + discretizer = QuartileDiscretizer(X, [], feature_names, y, + random_state=20) + explainer_1 = LimeTabularExplainer(X, + feature_names=feature_names, + discretize_continuous=True, + discretizer=discretizer, + random_state=10) + exp_1 = explainer_1.explain_instance(X[instance], rf.predict_proba, + num_samples=500) + + discretizer = QuartileDiscretizer(X, [], feature_names, y, + random_state=10) + explainer_2 = LimeTabularExplainer(X, + feature_names=feature_names, + discretize_continuous=True, + discretizer=discretizer, + random_state=10) + exp_2 = explainer_2.explain_instance(X[instance], rf.predict_proba, + num_samples=500) + + self.assertTrue(exp_1.as_map() != exp_2.as_map()) + + # ---------------------------------[2]---------------------------------- + discretizer = QuartileDiscretizer(X, [], feature_names, y, + random_state=20) + explainer_1 = LimeTabularExplainer(X, + feature_names=feature_names, + discretize_continuous=True, + discretizer=discretizer, + random_state=20) + exp_1 = explainer_1.explain_instance(X[instance], rf.predict_proba, + num_samples=500) + + discretizer = QuartileDiscretizer(X, [], feature_names, y, + random_state=10) + explainer_2 = LimeTabularExplainer(X, + feature_names=feature_names, + discretize_continuous=True, + discretizer=discretizer, + random_state=10) + exp_2 = explainer_2.explain_instance(X[instance], rf.predict_proba, + num_samples=500) + + self.assertTrue(exp_1.as_map() != exp_2.as_map()) + + # ---------------------------------[3]---------------------------------- + discretizer = QuartileDiscretizer(X, [], feature_names, y, + random_state=20) + explainer_1 = LimeTabularExplainer(X, + feature_names=feature_names, + discretize_continuous=True, + discretizer=discretizer, + random_state=20) + exp_1 = explainer_1.explain_instance(X[instance], rf.predict_proba, + num_samples=500) + + discretizer = QuartileDiscretizer(X, [], feature_names, y, + random_state=20) + explainer_2 = LimeTabularExplainer(X, + feature_names=feature_names, + discretize_continuous=True, + discretizer=discretizer, + random_state=10) + exp_2 = explainer_2.explain_instance(X[instance], rf.predict_proba, + num_samples=500) + + self.assertTrue(exp_1.as_map() != exp_2.as_map()) + + # ---------------------------------[4]---------------------------------- + discretizer = QuartileDiscretizer(X, [], feature_names, y, + random_state=20) + explainer_1 = LimeTabularExplainer(X, + feature_names=feature_names, + discretize_continuous=True, + discretizer=discretizer, + random_state=20) + exp_1 = explainer_1.explain_instance(X[instance], rf.predict_proba, + num_samples=500) + + discretizer = QuartileDiscretizer(X, [], feature_names, y, + random_state=20) + explainer_2 = LimeTabularExplainer(X, + feature_names=feature_names, + discretize_continuous=True, + discretizer=discretizer, + random_state=20) + exp_2 = explainer_2.explain_instance(X[instance], rf.predict_proba, + num_samples=500) + + self.assertFalse(exp_1.as_map() != exp_2.as_map()) + + # ---------------------------------------------------------------------- + # --------------------------Decile Discretizer-------------------------- + # ---------------------------------------------------------------------- + + # 
---------------------------------[1]---------------------------------- + discretizer = DecileDiscretizer(X, [], feature_names, y, + random_state=20) + explainer_1 = LimeTabularExplainer(X, + feature_names=feature_names, + discretize_continuous=True, + discretizer=discretizer, + random_state=10) + exp_1 = explainer_1.explain_instance(X[instance], rf.predict_proba, + num_samples=500) + + discretizer = DecileDiscretizer(X, [], feature_names, y, + random_state=10) + explainer_2 = LimeTabularExplainer(X, + feature_names=feature_names, + discretize_continuous=True, + discretizer=discretizer, + random_state=10) + exp_2 = explainer_2.explain_instance(X[instance], rf.predict_proba, + num_samples=500) + + self.assertTrue(exp_1.as_map() != exp_2.as_map()) + + # ---------------------------------[2]---------------------------------- + discretizer = DecileDiscretizer(X, [], feature_names, y, + random_state=20) + explainer_1 = LimeTabularExplainer(X, + feature_names=feature_names, + discretize_continuous=True, + discretizer=discretizer, + random_state=20) + exp_1 = explainer_1.explain_instance(X[instance], rf.predict_proba, + num_samples=500) + + discretizer = DecileDiscretizer(X, [], feature_names, y, + random_state=10) + explainer_2 = LimeTabularExplainer(X, + feature_names=feature_names, + discretize_continuous=True, + discretizer=discretizer, + random_state=10) + exp_2 = explainer_2.explain_instance(X[instance], rf.predict_proba, + num_samples=500) + + self.assertTrue(exp_1.as_map() != exp_2.as_map()) + + # ---------------------------------[3]---------------------------------- + discretizer = DecileDiscretizer(X, [], feature_names, y, + random_state=20) + explainer_1 = LimeTabularExplainer(X, + feature_names=feature_names, + discretize_continuous=True, + discretizer=discretizer, + random_state=20) + exp_1 = explainer_1.explain_instance(X[instance], rf.predict_proba, + num_samples=500) + + discretizer = DecileDiscretizer(X, [], feature_names, y, + random_state=20) + explainer_2 = LimeTabularExplainer(X, + feature_names=feature_names, + discretize_continuous=True, + discretizer=discretizer, + random_state=10) + exp_2 = explainer_2.explain_instance(X[instance], rf.predict_proba, + num_samples=500) + + self.assertTrue(exp_1.as_map() != exp_2.as_map()) + + # ---------------------------------[4]---------------------------------- + discretizer = DecileDiscretizer(X, [], feature_names, y, + random_state=20) + explainer_1 = LimeTabularExplainer(X, + feature_names=feature_names, + discretize_continuous=True, + discretizer=discretizer, + random_state=20) + exp_1 = explainer_1.explain_instance(X[instance], rf.predict_proba, + num_samples=500) + + discretizer = DecileDiscretizer(X, [], feature_names, y, + random_state=20) + explainer_2 = LimeTabularExplainer(X, + feature_names=feature_names, + discretize_continuous=True, + discretizer=discretizer, + random_state=20) + exp_2 = explainer_2.explain_instance(X[instance], rf.predict_proba, + num_samples=500) + + self.assertFalse(exp_1.as_map() != exp_2.as_map()) + + # ---------------------------------------------------------------------- + # --------------------------Entropy Discretizer------------------------- + # ---------------------------------------------------------------------- + + # ---------------------------------[1]---------------------------------- + discretizer = EntropyDiscretizer(X, [], feature_names, y, + random_state=20) + explainer_1 = LimeTabularExplainer(X, + feature_names=feature_names, + discretize_continuous=True, + discretizer=discretizer, + 
random_state=10) + exp_1 = explainer_1.explain_instance(X[instance], rf.predict_proba, + num_samples=500) + + discretizer = EntropyDiscretizer(X, [], feature_names, y, + random_state=10) + explainer_2 = LimeTabularExplainer(X, + feature_names=feature_names, + discretize_continuous=True, + discretizer=discretizer, + random_state=10) + exp_2 = explainer_2.explain_instance(X[instance], rf.predict_proba, + num_samples=500) + + self.assertTrue(exp_1.as_map() != exp_2.as_map()) + + # ---------------------------------[2]---------------------------------- + discretizer = EntropyDiscretizer(X, [], feature_names, y, + random_state=20) + explainer_1 = LimeTabularExplainer(X, + feature_names=feature_names, + discretize_continuous=True, + discretizer=discretizer, + random_state=20) + exp_1 = explainer_1.explain_instance(X[instance], rf.predict_proba, + num_samples=500) + + discretizer = EntropyDiscretizer(X, [], feature_names, y, + random_state=10) + explainer_2 = LimeTabularExplainer(X, + feature_names=feature_names, + discretize_continuous=True, + discretizer=discretizer, + random_state=10) + exp_2 = explainer_2.explain_instance(X[instance], rf.predict_proba, + num_samples=500) + + self.assertTrue(exp_1.as_map() != exp_2.as_map()) + + # ---------------------------------[3]---------------------------------- + discretizer = EntropyDiscretizer(X, [], feature_names, y, + random_state=20) + explainer_1 = LimeTabularExplainer(X, + feature_names=feature_names, + discretize_continuous=True, + discretizer=discretizer, + random_state=20) + exp_1 = explainer_1.explain_instance(X[instance], rf.predict_proba, + num_samples=500) + + discretizer = EntropyDiscretizer(X, [], feature_names, y, + random_state=20) + explainer_2 = LimeTabularExplainer(X, + feature_names=feature_names, + discretize_continuous=True, + discretizer=discretizer, + random_state=10) + exp_2 = explainer_2.explain_instance(X[instance], rf.predict_proba, + num_samples=500) + + self.assertTrue(exp_1.as_map() != exp_2.as_map()) + + # ---------------------------------[4]---------------------------------- + discretizer = EntropyDiscretizer(X, [], feature_names, y, + random_state=20) + explainer_1 = LimeTabularExplainer(X, + feature_names=feature_names, + discretize_continuous=True, + discretizer=discretizer, + random_state=20) + exp_1 = explainer_1.explain_instance(X[instance], rf.predict_proba, + num_samples=500) + + discretizer = EntropyDiscretizer(X, [], feature_names, y, + random_state=20) + explainer_2 = LimeTabularExplainer(X, + feature_names=feature_names, + discretize_continuous=True, + discretizer=discretizer, + random_state=20) + exp_2 = explainer_2.explain_instance(X[instance], rf.predict_proba, + num_samples=500) + + self.assertFalse(exp_1.as_map() != exp_2.as_map()) + + def testFeatureNamesAndCategoricalFeats(self): + training_data = np.array([[0., 1.], [1., 0.]]) + + explainer = LimeTabularExplainer(training_data=training_data) + self.assertEqual(explainer.feature_names, ['0', '1']) + self.assertEqual(explainer.categorical_features, [0, 1]) + + explainer = LimeTabularExplainer( + training_data=training_data, + feature_names=np.array(['one', 'two']) + ) + self.assertEqual(explainer.feature_names, ['one', 'two']) + + explainer = LimeTabularExplainer( + training_data=training_data, + categorical_features=np.array([0]), + discretize_continuous=False + ) + self.assertEqual(explainer.categorical_features, [0]) + + def testFeatureValues(self): + training_data = np.array([ + [0, 0, 2], + [1, 1, 0], + [0, 2, 2], + [1, 3, 0] + ]) + + explainer = 
LimeTabularExplainer( + training_data=training_data, + categorical_features=[0, 1, 2] + ) + + self.assertEqual(set(explainer.feature_values[0]), {0, 1}) + self.assertEqual(set(explainer.feature_values[1]), {0, 1, 2, 3}) + self.assertEqual(set(explainer.feature_values[2]), {0, 2}) + + assert_array_equal(explainer.feature_frequencies[0], np.array([.5, .5])) + assert_array_equal(explainer.feature_frequencies[1], np.array([.25, .25, .25, .25])) + assert_array_equal(explainer.feature_frequencies[2], np.array([.5, .5])) + + def test_lime_explainer_with_data_stats(self): + np.random.seed(1) + + rf = RandomForestClassifier(n_estimators=500) + rf.fit(self.train, self.labels_train) + i = np.random.randint(0, self.test.shape[0]) + + # Generate stats using a quartile descritizer + descritizer = QuartileDiscretizer(self.train, [], self.feature_names, self.target_names, + random_state=20) + + d_means = descritizer.means + d_stds = descritizer.stds + d_mins = descritizer.mins + d_maxs = descritizer.maxs + d_bins = descritizer.bins(self.train, self.target_names) + + # Compute feature values and frequencies of all columns + cat_features = np.arange(self.train.shape[1]) + discretized_training_data = descritizer.discretize(self.train) + + feature_values = {} + feature_frequencies = {} + for feature in cat_features: + column = discretized_training_data[:, feature] + feature_count = collections.Counter(column) + values, frequencies = map(list, zip(*(feature_count.items()))) + feature_values[feature] = values + feature_frequencies[feature] = frequencies + + # Convert bins to list from array + d_bins_revised = {} + index = 0 + for bin in d_bins: + d_bins_revised[index] = bin.tolist() + index = index+1 + + # Descritized stats + data_stats = {} + data_stats["means"] = d_means + data_stats["stds"] = d_stds + data_stats["maxs"] = d_maxs + data_stats["mins"] = d_mins + data_stats["bins"] = d_bins_revised + data_stats["feature_values"] = feature_values + data_stats["feature_frequencies"] = feature_frequencies + + data = np.zeros((2, len(self.feature_names))) + explainer = LimeTabularExplainer( + data, feature_names=self.feature_names, random_state=10, + training_data_stats=data_stats, training_labels=self.target_names) + + exp = explainer.explain_instance(self.test[i], + rf.predict_proba, + num_features=2, + model_regressor=LinearRegression()) + + self.assertIsNotNone(exp) + keys = [x[0] for x in exp.as_list()] + self.assertEqual(1, + sum([1 if 'petal width' in x else 0 for x in keys]), + "Petal Width is a major feature") + self.assertEqual(1, + sum([1 if 'petal length' in x else 0 for x in keys]), + "Petal Length is a major feature") + + +if __name__ == '__main__': + unittest.main() diff --git a/lime/tests/test_lime_text.py b/lime/tests/test_lime_text.py new file mode 100644 index 0000000000000000000000000000000000000000..d0d2f728e457c5f8fa9b19f447ab8d9701621a43 --- /dev/null +++ b/lime/tests/test_lime_text.py @@ -0,0 +1,167 @@ +import re +import unittest + +import sklearn # noqa +from sklearn.datasets import fetch_20newsgroups +from sklearn.feature_extraction.text import TfidfVectorizer +from sklearn.metrics import f1_score +from sklearn.naive_bayes import MultinomialNB +from sklearn.pipeline import make_pipeline + +import numpy as np + +from lime.lime_text import LimeTextExplainer +from lime.lime_text import IndexedCharacters, IndexedString + + +class TestLimeText(unittest.TestCase): + + def test_lime_text_explainer_good_regressor(self): + categories = ['alt.atheism', 'soc.religion.christian'] + newsgroups_train = 
fetch_20newsgroups(subset='train', + categories=categories) + newsgroups_test = fetch_20newsgroups(subset='test', + categories=categories) + class_names = ['atheism', 'christian'] + vectorizer = TfidfVectorizer(lowercase=False) + train_vectors = vectorizer.fit_transform(newsgroups_train.data) + test_vectors = vectorizer.transform(newsgroups_test.data) + nb = MultinomialNB(alpha=.01) + nb.fit(train_vectors, newsgroups_train.target) + pred = nb.predict(test_vectors) + f1_score(newsgroups_test.target, pred, average='weighted') + c = make_pipeline(vectorizer, nb) + explainer = LimeTextExplainer(class_names=class_names) + idx = 83 + exp = explainer.explain_instance(newsgroups_test.data[idx], + c.predict_proba, num_features=6) + self.assertIsNotNone(exp) + self.assertEqual(6, len(exp.as_list())) + + def test_lime_text_tabular_equal_random_state(self): + categories = ['alt.atheism', 'soc.religion.christian'] + newsgroups_train = fetch_20newsgroups(subset='train', + categories=categories) + newsgroups_test = fetch_20newsgroups(subset='test', + categories=categories) + class_names = ['atheism', 'christian'] + vectorizer = TfidfVectorizer(lowercase=False) + train_vectors = vectorizer.fit_transform(newsgroups_train.data) + test_vectors = vectorizer.transform(newsgroups_test.data) + nb = MultinomialNB(alpha=.01) + nb.fit(train_vectors, newsgroups_train.target) + pred = nb.predict(test_vectors) + f1_score(newsgroups_test.target, pred, average='weighted') + c = make_pipeline(vectorizer, nb) + + explainer = LimeTextExplainer(class_names=class_names, random_state=10) + exp_1 = explainer.explain_instance(newsgroups_test.data[83], + c.predict_proba, num_features=6) + + explainer = LimeTextExplainer(class_names=class_names, random_state=10) + exp_2 = explainer.explain_instance(newsgroups_test.data[83], + c.predict_proba, num_features=6) + + self.assertTrue(exp_1.as_map() == exp_2.as_map()) + + def test_lime_text_tabular_not_equal_random_state(self): + categories = ['alt.atheism', 'soc.religion.christian'] + newsgroups_train = fetch_20newsgroups(subset='train', + categories=categories) + newsgroups_test = fetch_20newsgroups(subset='test', + categories=categories) + class_names = ['atheism', 'christian'] + vectorizer = TfidfVectorizer(lowercase=False) + train_vectors = vectorizer.fit_transform(newsgroups_train.data) + test_vectors = vectorizer.transform(newsgroups_test.data) + nb = MultinomialNB(alpha=.01) + nb.fit(train_vectors, newsgroups_train.target) + pred = nb.predict(test_vectors) + f1_score(newsgroups_test.target, pred, average='weighted') + c = make_pipeline(vectorizer, nb) + + explainer = LimeTextExplainer( + class_names=class_names, random_state=10) + exp_1 = explainer.explain_instance(newsgroups_test.data[83], + c.predict_proba, num_features=6) + + explainer = LimeTextExplainer( + class_names=class_names, random_state=20) + exp_2 = explainer.explain_instance(newsgroups_test.data[83], + c.predict_proba, num_features=6) + + self.assertFalse(exp_1.as_map() == exp_2.as_map()) + + def test_indexed_characters_bow(self): + s = 'Please, take your time' + inverse_vocab = ['P', 'l', 'e', 'a', 's', ',', ' ', 't', 'k', 'y', 'o', 'u', 'r', 'i', 'm'] + positions = [[0], [1], [2, 5, 11, 21], [3, 9], + [4], [6], [7, 12, 17], [8, 18], [10], + [13], [14], [15], [16], [19], [20]] + ic = IndexedCharacters(s) + + self.assertTrue(np.array_equal(ic.as_np, np.array(list(s)))) + self.assertTrue(np.array_equal(ic.string_start, np.arange(len(s)))) + self.assertTrue(ic.inverse_vocab == inverse_vocab) + 
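+ # With bow=True, IndexedCharacters keeps one vocabulary entry per distinct character, so e.g. 'e' maps to every position where it occurs (2, 5, 11 and 21 above); the next assert checks exactly that grouping.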
self.assertTrue(ic.positions == positions) + + def test_indexed_characters_not_bow(self): + s = 'Please, take your time' + + ic = IndexedCharacters(s, bow=False) + + self.assertTrue(np.array_equal(ic.as_np, np.array(list(s)))) + self.assertTrue(np.array_equal(ic.string_start, np.arange(len(s)))) + self.assertTrue(ic.inverse_vocab == list(s)) + self.assertTrue(np.array_equal(ic.positions, np.arange(len(s)))) + + def test_indexed_string_regex(self): + s = 'Please, take your time. Please' + tokenized_string = np.array( + ['Please', ', ', 'take', ' ', 'your', ' ', 'time', '. ', 'Please']) + inverse_vocab = ['Please', 'take', 'your', 'time'] + start_positions = [0, 6, 8, 12, 13, 17, 18, 22, 24] + positions = [[0, 8], [2], [4], [6]] + indexed_string = IndexedString(s) + + self.assertTrue(np.array_equal(indexed_string.as_np, tokenized_string)) + self.assertTrue(np.array_equal(indexed_string.string_start, start_positions)) + self.assertTrue(indexed_string.inverse_vocab == inverse_vocab) + self.assertTrue(np.array_equal(indexed_string.positions, positions)) + + def test_indexed_string_callable(self): + s = 'aabbccddaa' + + def tokenizer(string): + return [string[i] + string[i + 1] for i in range(0, len(string) - 1, 2)] + + tokenized_string = np.array(['aa', 'bb', 'cc', 'dd', 'aa']) + inverse_vocab = ['aa', 'bb', 'cc', 'dd'] + start_positions = [0, 2, 4, 6, 8] + positions = [[0, 4], [1], [2], [3]] + indexed_string = IndexedString(s, tokenizer) + + self.assertTrue(np.array_equal(indexed_string.as_np, tokenized_string)) + self.assertTrue(np.array_equal(indexed_string.string_start, start_positions)) + self.assertTrue(indexed_string.inverse_vocab == inverse_vocab) + self.assertTrue(np.array_equal(indexed_string.positions, positions)) + + def test_indexed_string_inverse_removing_tokenizer(self): + s = 'This is a good movie. This, it is a great movie.' + + def tokenizer(string): + return re.split(r'(?:\W+)|$', string) + + indexed_string = IndexedString(s, tokenizer) + + self.assertEqual(s, indexed_string.inverse_removing([])) + + def test_indexed_string_inverse_removing_regex(self): + s = 'This is a good movie. 
This is a great movie' + indexed_string = IndexedString(s) + + self.assertEqual(s, indexed_string.inverse_removing([])) + + +if __name__ == '__main__': + unittest.main() diff --git a/lime/tests/test_scikit_image.py b/lime/tests/test_scikit_image.py new file mode 100644 index 0000000000000000000000000000000000000000..7c18040cceeddd25b7cd2867a8d0f59da04ca974 --- /dev/null +++ b/lime/tests/test_scikit_image.py @@ -0,0 +1,128 @@ +import unittest +from lime.wrappers.scikit_image import BaseWrapper +from lime.wrappers.scikit_image import SegmentationAlgorithm +from skimage.segmentation import quickshift +from skimage.data import chelsea +from skimage.util import img_as_float +import numpy as np + + +class TestBaseWrapper(unittest.TestCase): + + def test_base_wrapper(self): + + obj_with_params = BaseWrapper(a=10, b='message') + obj_without_params = BaseWrapper() + + def foo_fn(): + return 'bar' + + obj_with_fn = BaseWrapper(foo_fn) + self.assertEqual(obj_with_params.target_params, {'a': 10, 'b': 'message'}) + self.assertEqual(obj_without_params.target_params, {}) + self.assertEqual(obj_with_fn.target_fn(), 'bar') + + def test__check_params(self): + + def bar_fn(a): + return str(a) + + class Pipo(): + + def __init__(self): + self.name = 'pipo' + + def __call__(self, message): + return message + + pipo = Pipo() + obj_with_valid_fn = BaseWrapper(bar_fn, a=10, b='message') + obj_with_valid_callable_fn = BaseWrapper(pipo, c=10, d='message') + obj_with_invalid_fn = BaseWrapper([1, 2, 3], fn_name='invalid') + + # target_fn is not a callable or function/method + with self.assertRaises(AttributeError): + obj_with_invalid_fn._check_params('fn_name') + + # parameters is not in target_fn args + with self.assertRaises(ValueError): + obj_with_valid_fn._check_params(['c']) + obj_with_valid_callable_fn._check_params(['e']) + + # params is in target_fn args + try: + obj_with_valid_fn._check_params(['a']) + obj_with_valid_callable_fn._check_params(['message']) + except Exception: + self.fail("_check_params() raised an unexpected exception") + + # params is not a dict or list + with self.assertRaises(TypeError): + obj_with_valid_fn._check_params(None) + with self.assertRaises(TypeError): + obj_with_valid_fn._check_params('param_name') + + def test_set_params(self): + + class Pipo(): + + def __init__(self): + self.name = 'pipo' + + def __call__(self, message): + return message + pipo = Pipo() + obj = BaseWrapper(pipo) + + # argument is set accordingly + obj.set_params(message='OK') + self.assertEqual(obj.target_params, {'message': 'OK'}) + self.assertEqual(obj.target_fn(**obj.target_params), 'OK') + + # invalid argument is passed + try: + obj = BaseWrapper(Pipo()) + obj.set_params(invalid='KO') + except Exception: + self.assertEqual(obj.target_params, {}) + + def test_filter_params(self): + + # right arguments are kept and wrong dismmissed + def baz_fn(a, b, c=True): + if c: + return a + b + else: + return a + obj_ = BaseWrapper(baz_fn, a=10, b=100, d=1000) + self.assertEqual(obj_.filter_params(baz_fn), {'a': 10, 'b': 100}) + + # target_params is overriden using 'override' argument + self.assertEqual(obj_.filter_params(baz_fn, override={'c': False}), + {'a': 10, 'b': 100, 'c': False}) + + +class TestSegmentationAlgorithm(unittest.TestCase): + + def test_instanciate_segmentation_algorithm(self): + img = img_as_float(chelsea()[::2, ::2]) + + # wrapped functions provide the same result + fn = SegmentationAlgorithm('quickshift', kernel_size=3, max_dist=6, + ratio=0.5, random_seed=133) + fn_result = fn(img) + 
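+ # The wrapped callable should behave exactly like calling skimage's quickshift directly with the same keyword arguments, which the direct call below verifies.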
original_result = quickshift(img, kernel_size=3, max_dist=6, ratio=0.5, + random_seed=133) + + # same segments + self.assertTrue(np.array_equal(fn_result, original_result)) + + def test_instanciate_slic(self): + pass + + def test_instanciate_felzenszwalb(self): + pass + + +if __name__ == '__main__': + unittest.main() diff --git a/lime/utils/__init__.py b/lime/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lime/utils/generic_utils.py b/lime/utils/generic_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ea7eca06b04cd7e5178dc112ffa60eb10e643158 --- /dev/null +++ b/lime/utils/generic_utils.py @@ -0,0 +1,39 @@ +import sys +import inspect +import types + + +def has_arg(fn, arg_name): + """Checks if a callable accepts a given keyword argument. + + Args: + fn: callable to inspect + arg_name: string, keyword argument name to check + + Returns: + bool, whether `fn` accepts a `arg_name` keyword argument. + """ + if sys.version_info < (3,): + if isinstance(fn, types.FunctionType) or isinstance(fn, types.MethodType): + arg_spec = inspect.getargspec(fn) + else: + try: + arg_spec = inspect.getargspec(fn.__call__) + except AttributeError: + return False + return (arg_name in arg_spec.args) + elif sys.version_info < (3, 6): + arg_spec = inspect.getfullargspec(fn) + return (arg_name in arg_spec.args or + arg_name in arg_spec.kwonlyargs) + else: + try: + signature = inspect.signature(fn) + except ValueError: + # handling Cython + signature = inspect.signature(fn.__call__) + parameter = signature.parameters.get(arg_name) + if parameter is None: + return False + return (parameter.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY)) diff --git a/lime/webpack.config.js b/lime/webpack.config.js new file mode 100644 index 0000000000000000000000000000000000000000..b37b258bbfb2c1ebe454c03fa01016073f2eb00b --- /dev/null +++ b/lime/webpack.config.js @@ -0,0 +1,40 @@ +var path = require('path'); +var webpack = require('webpack'); + +module.exports = { + entry: './js/main.js', + output: { + path: __dirname, + filename: 'bundle.js', + library: 'lime' + }, + module: { + loaders: [ + { + loader: 'babel-loader', + test: path.join(__dirname, 'js'), + query: { + presets: 'es2015-ie', + }, + + }, + { + test: /\.css$/, + loaders: ['style-loader', 'css-loader'], + + } + + ] + }, + plugins: [ + // Avoid publishing files when compilation fails + new webpack.NoErrorsPlugin() + ], + stats: { + // Nice colored output + colors: true + }, + // Create Sourcemaps for the bundle + devtool: 'source-map', +}; + diff --git a/lime/wrappers/__init__.py b/lime/wrappers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lime/wrappers/scikit_image.py b/lime/wrappers/scikit_image.py new file mode 100644 index 0000000000000000000000000000000000000000..1b6228786f42541f5e4a8dcfaa88b343b036bed4 --- /dev/null +++ b/lime/wrappers/scikit_image.py @@ -0,0 +1,117 @@ +import types +from lime.utils.generic_utils import has_arg +from skimage.segmentation import felzenszwalb, slic, quickshift + + +class BaseWrapper(object): + """Base class for LIME Scikit-Image wrapper + + + Args: + target_fn: callable function or class instance + target_params: dict, parameters to pass to the target_fn + + + 'target_params' takes parameters required to instanciate the + desired Scikit-Image class/model + """ + + def __init__(self, 
target_fn=None, **target_params): + self.target_fn = target_fn + self.target_params = target_params + + def _check_params(self, parameters): + """Checks for mistakes in 'parameters' + + Args : + parameters: dict, parameters to be checked + + Raises : + ValueError: if any parameter is not a valid argument for the target function + or the target function is not defined + TypeError: if argument parameters is not iterable + """ + a_valid_fn = [] + if self.target_fn is None: + if callable(self): + a_valid_fn.append(self.__call__) + else: + raise TypeError('invalid argument: tested object is not callable,\ + please provide a valid target_fn') + elif isinstance(self.target_fn, types.FunctionType) \ + or isinstance(self.target_fn, types.MethodType): + a_valid_fn.append(self.target_fn) + else: + a_valid_fn.append(self.target_fn.__call__) + + if not isinstance(parameters, str): + for p in parameters: + for fn in a_valid_fn: + if has_arg(fn, p): + pass + else: + raise ValueError('{} is not a valid parameter'.format(p)) + else: + raise TypeError('invalid argument: list or dictionary expected') + + def set_params(self, **params): + """Sets the parameters of this estimator. + Args: + **params: Dictionary of parameter names mapped to their values. + + Raises : + ValueError: if any parameter is not a valid argument + for the target function + """ + self._check_params(params) + self.target_params = params + + def filter_params(self, fn, override=None): + """Filters `target_params` and returns those in `fn`'s arguments. + Args: + fn : arbitrary function + override: dict, values to override target_params + Returns: + result : dict, dictionary containing variables + in both target_params and fn's arguments. 
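+ + Illustrative example (mirrors the behaviour exercised in tests/test_scikit_image.py): + wrapper = BaseWrapper(quickshift, kernel_size=3, bad_arg=1) + wrapper.filter_params(quickshift) # -> {'kernel_size': 3} + wrapper.filter_params(quickshift, override={'ratio': 0.5}) # -> {'kernel_size': 3, 'ratio': 0.5}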
+ """ + override = override or {} + result = {} + for name, value in self.target_params.items(): + if has_arg(fn, name): + result.update({name: value}) + result.update(override) + return result + + +class SegmentationAlgorithm(BaseWrapper): + """ Define the image segmentation function based on Scikit-Image + implementation and a set of provided parameters + + Args: + algo_type: string, segmentation algorithm among the following: + 'quickshift', 'slic', 'felzenszwalb' + target_params: dict, algorithm parameters (valid model paramters + as define in Scikit-Image documentation) + """ + + def __init__(self, algo_type, **target_params): + self.algo_type = algo_type + if (self.algo_type == 'quickshift'): + BaseWrapper.__init__(self, quickshift, **target_params) + kwargs = self.filter_params(quickshift) + self.set_params(**kwargs) + elif (self.algo_type == 'felzenszwalb'): + BaseWrapper.__init__(self, felzenszwalb, **target_params) + kwargs = self.filter_params(felzenszwalb) + self.set_params(**kwargs) + elif (self.algo_type == 'slic'): + BaseWrapper.__init__(self, slic, **target_params) + kwargs = self.filter_params(slic) + self.set_params(**kwargs) + + def __call__(self, *args): + return self.target_fn(args[0], **self.target_params) diff --git a/loader/__init__.py b/loader/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/loader/data_loader.py b/loader/data_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..7c277f2e62d5837cfc4107345aa3da71d9a5188e --- /dev/null +++ b/loader/data_loader.py @@ -0,0 +1,799 @@ +from functools import partial +import numpy +import os +import re +import random +import signal +import csv +from PIL import Image +import settings +import numpy as np +from collections import OrderedDict +import cv2 +# from scipy.misc import imread +from multiprocessing import Pool, cpu_count +from multiprocessing.pool import ThreadPool +from scipy.ndimage.interpolation import zoom +import sys +import pickle + +def load_csv(filename, readfields=None): + def convert(value): + if re.match(r'^-?\d+$', value): + try: + return int(value) + except: + pass + if re.match(r'^-?[\.\d]+(?:e[+=]\d+)$', value): + try: + return float(value) + except: + pass + return value + + with open(filename) as f: + reader = csv.DictReader(f) + result = [{k: convert(v) for k, v in row.items()} for row in reader] + if readfields is not None: + readfields.extend(reader.fieldnames) + return result + +class AbstractSegmentation: + def all_names(self, category, j): + raise NotImplementedError + def size(self, split=None): + return 0 + def filename(self, i): + raise NotImplementedError + def metadata(self, i): + return self.filename(i) + @classmethod + def resolve_segmentation(cls, m): + return {} + + def name(self, category, i): + ''' + Default implemtnation for segmentation_data, + utilizing all_names. + ''' + all_names = self.all_names(category, i) + return all_names[0] if len(all_names) else '' + + def segmentation_data(self, category, i, c=0, full=False): + ''' + Default implemtnation for segmentation_data, + utilizing metadata and resolve_segmentation. 
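+ Returns 0 when the requested category is absent from the resolved segmentation, and only the first channel of multi-channel data unless full=True.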
+ ''' + segs = self.resolve_segmentation( + self.metadata(i), categories=[category]) + if category not in segs: + return 0 + data = segs[category] + if not full and len(data.shape) >= 3: + return data[0] + return data + + +class SegmentationData(AbstractSegmentation): + ''' + Represents and loads a multi-channel segmentation represented with + a series of csv files: index.csv lists the images together with + any label data avilable in each category; category.csv lists + the categories of segmentations available; and label.csv lists the + numbers used to describe each label class. In addition, the categories + each have a separate c_*.csv file describing a dense coding of labels. + + isImageSet - if True, duplicate rgb images in index.csv will be removed + ''' + + def __init__(self, directory, categories=None, require_all=False, isImageSet=False): + directory = os.path.expanduser(directory) + self.directory = directory + with open(os.path.join(directory, settings.INDEX_FILE)) as f: + self.image = [decode_index_dict(r) for r in csv.DictReader(f)] + # self.actualFeatIdx = None ### use this dict in tally to map duplicates idx to nonduplicates + # if isImageSet is True: + # self.actualFeatIdx = {} + # self.newImgSet = [] + # self.duplicateDict = {} + # for imgRGBIdx, imgRGBData in enumerate(self.image): + # if imgRGBData["image"] not in self.duplicateDict: + # self.newImgSet.append(imgRGBData) + # self.duplicateDict[imgRGBData["image"]] = len(self.newImgSet) - 1 + # # print("self.duplicateDict[imgRGBData[image]]: ", self.duplicateDict[imgRGBData["image"]]) + # self.actualFeatIdx[imgRGBIdx] = self.duplicateDict[imgRGBData["image"]] ### Start at 0 + # self.image = self.newImgSet + # print("data_set.actualFeatIdx: ", self.actualFeatIdx) + # sys.exit() + # print("self.image: ", self.image) ### list + # print("type self.image: ", type(self.image)) ### list + # print("len self.image: ", len(self.image)) ### total rows in index.csv + # sys.exit() + with open(os.path.join(directory, 'category.csv')) as f: + self.category = OrderedDict() + for row in csv.DictReader(f): + if categories and row['name'] in categories: + self.category[row['name']] = row + categories = self.category.keys() + with open(os.path.join(directory, 'label.csv')) as f: + label_data = [decode_label_dict(r) for r in csv.DictReader(f)] + self.label = build_dense_label_array(label_data) ### Len is label_data+1 (from csv), 0 index is None + # print("self.label[0]: ", self.label[0]) ### None value (no class specified) + # print("self.label: ", self.label) + # sys.exit() + # Filter out images with insufficient data + filter_fn = partial( + index_has_all_data if require_all else index_has_any_data, + categories=categories) + self.image = [row for row in self.image if filter_fn(row)] + # Build dense remapping arrays for labels, so that you can + # get dense ranges of labels for each category. 
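+ # category_map[cat] maps a global label number to that category's dense code, while category_unmap[cat] maps a dense code back to the global label number; both come from the per-category c_<cat>.csv files read below.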
+ self.category_map = {} + self.category_unmap = {} + self.category_label = {} + for cat in self.category: + with open(os.path.join(directory, 'c_%s.csv' % cat)) as f: + c_data = [decode_label_dict(r) for r in csv.DictReader(f)] + self.category_unmap[cat], self.category_map[cat] = ( + build_numpy_category_map(c_data)) + self.category_label[cat] = build_dense_label_array( + c_data, key='code') + # print("category_unmap: ", self.category_unmap) + self.labelcat = self.onehot(self.primary_categories_per_index()) # (480,1) + ### labelcat - all ones + + def primary_categories_per_index(ds): + ''' + Returns an array of primary category numbers for each label, where the + first category listed in ds.category_names is given category number 0. + ''' + catmap = {} + categories = ds.category_names() + for cat in categories: + imap = ds.category_index_map(cat) + if len(imap) < ds.label_size(None): + imap = np.concatenate((imap, np.zeros( + ds.label_size(None) - len(imap), dtype=imap.dtype))) + catmap[cat] = imap + result = [] + for i in range(ds.label_size(None)): + maxcov, maxcat = max( + (ds.coverage(cat, catmap[cat][i]) if catmap[cat][i] else 0, ic) + for ic, cat in enumerate(categories)) + result.append(maxcat) + return np.array(result) + + def onehot(self, arr, minlength=None): + ''' + Expands an array of integers in one-hot encoding by adding a new last + dimension, leaving zeros everywhere except for the nth dimension, where + the original array contained the integer n. The minlength parameter is + used to indcate the minimum size of the new dimension. + ''' + length = np.amax(arr) + 1 + if minlength is not None: + length = max(minlength, length) + result = np.zeros(arr.shape + (length,)) + result[list(np.indices(arr.shape)) + [arr]] = 1 + return result + + + def all_names(self, category, j): + '''All English synonyms for the given label''' + if category is not None: + j = self.category_unmap[category][j] + return [self.label[j]['name']] + self.label[j]['syns'] + + def size(self, split=None): + '''The number of images in this data set.''' + if split is None: + return len(self.image) + return len([im for im in self.image if im['split'] == split]) + + def filename(self, i): + '''The filename of the ith jpeg (original image).''' + return os.path.join(self.directory, 'images', self.image[i]['image']) + + def split(self, i): + '''Which split contains item i.''' + return self.image[i]['split'] + + def metadata(self, i): + '''Extract metadata for image i, For efficient data loading.''' + return self.directory, self.image[i] + + meta_categories = ['image', 'split', 'ih', 'iw', 'sh', 'sw'] + + @classmethod + def resolve_segmentation(cls, m, categories=None, segm_to_label=None): + ''' + Resolves a full segmentation, potentially in a differenct process, + for efficient multiprocess data loading. 
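+ Returns a (result, shape) pair: result maps each requested category to an int16 array of shape (channels, sh, sw) (whole-image integer labels are passed through as-is) and collects a 'seg_label' list via the segm_to_label lookup; shape is (sh, sw).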
+ ''' + directory, row = m + result = {} + for cat, d in row.items(): + if cat in cls.meta_categories: + continue + if not wants(cat, categories): + continue + if all(isinstance(data, int) for data in d): + result[cat] = d + continue + out = numpy.empty((len(d), row['sh'], row['sw']), dtype=numpy.int16) + for i, channel in enumerate(d): + if isinstance(channel, int): + out[i] = channel + else: + segmFilenameSplit = channel.split('/') + segmFileToLabelName = segmFilenameSplit[-2] + "/" + segmFilenameSplit[-1] + if 'seg_label' not in result: + result['seg_label'] = [] + result['seg_label'].append(segm_to_label[segmFileToLabelName]) + # print("os.path.join(directory, 'images', channel)): ", os.path.join(directory, 'images', channel)) + rgb = cv2.resize(cv2.imread(os.path.join(directory, 'images', channel)), (settings.SEGM_SIZE, settings.SEGM_SIZE)) + rgb[:,:,0] = 0 + rgb[:,:,2] = np.where(rgb[:,:,2]>0, 235, 0) + rgb[:,:,1] = np.where(rgb[:,:,2]>0, 1, 0) + out[i] = rgb[:,:,0] + rgb[:,:,1] * 256 + result[cat] = out + return result, (row['sh'], row['sw']) + + def label_size(self, category=None): + ''' + Returns the number of distinct labels (plus zero), i.e., one + more than the maximum label number. If a category is specified, + returns the number of distinct labels within that category. + ''' + if category is None: + return len(self.label) + else: + return len(self.category_unmap[category]) + + def name(self, category, j): + ''' + Returns an English name for the jth label. If a category is + specified, returns the name for the category-specific nubmer j. + If category=None, then treats j as a fully unified index number. + ''' + if category is not None: + j = self.category_unmap[category][j] + return self.label[j]['name'] + + def frequency(self, category, j): + ''' + Returns the number of images for which the label appears. + ''' + if category is not None: + return self.category_label[category][j]['frequency'] + return self.label[j]['frequency'] + + def coverage(self, category, j): + ''' + Returns the pixel coverage of the label in units of whole-images. + ''' + if category is not None: + return self.category_label[category][j]['coverage'] + return self.label[j]['coverage'] + + def category_names(self): + ''' + Returns the set of category names. + ''' + return list(self.category.keys()) + + def category_frequency(self, category): + ''' + Returns the number of images touched by a category. + ''' + return float(self.category[category]['frequency']) + + def primary_categories_per_index(self, categories=None): + ''' + Returns an array of primary category numbers for each label, where + catagories are indexed according to the list of categories passed, + or self.category_names() if none. + ''' + if categories is None: + categories = self.category_names() + # Make lists which are nonzero for labels in a category + catmap = {} + for cat in categories: + imap = self.category_index_map(cat) + if len(imap) < self.label_size(None): + imap = numpy.concatenate((imap, numpy.zeros( + self.label_size(None) - len(imap), dtype=imap.dtype))) + catmap[cat] = imap + # For each label, find the category with maximum coverage. + result = [] + for i in range(self.label_size(None)): + maxcov, maxcat = max( + (self.coverage(cat, catmap[cat][i]) + if catmap[cat][i] else 0, ic) + for ic, cat in enumerate(categories)) + result.append(maxcat) + # Return the max-coverage cateogry for each label. 
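+ # e.g. with categories ['object', 'part'], a value of 0 at index i means label i draws most of its coverage from 'object' segmentations (category names here are only illustrative).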
+ return numpy.array(result) + + + def segmentation_data(self, category, i, c=0, full=False, out=None): + ''' + Returns a 2-d numpy matrix with segmentation data for the ith image, + restricted to the given category. By default, maps all label numbers + to the category-specific dense mapping described in the c_*.csv + listing; but can be asked to expose the fully unique indexing by + using full=True. + ''' + row = self.image[i] + data_channels = row.get(category, ()) + if c >= len(data_channels): + channel = 0 # Deal with unlabeled data in this category + else: + channel = data_channels[c] + if out is None: + out = numpy.empty((row['sh'], row['sw']), dtype=numpy.int16) + if isinstance(channel, int): + if not full: + channel = self.category_map[category][channel] + out[:,:] = channel # Single-label for the whole image + return out + png = cv2.resize(cv2.imread(os.path.join(self.directory, 'images', channel)), (settings.SEGM_SIZE, settings.SEGM_SIZE)) + png[:,:,0] = 0 + png[:,:,2] = np.where(png[:,:,2]>0, 235, 0) + png[:,:,1] = np.where(png[:,:,2]>0, 1, 0) + if full: + # Full case: just combine png channels. + out[...] = png[:,:,0] + png[:,:,1] * 256 + else: + # Dense case: combine png channels and apply the category map. + catmap = self.category_map[category] + out[...] = catmap[png[:,:,0] + png[:,:,1] * 256] + return out + + def full_segmentation_data(self, i, + categories=None, max_depth=None, out=None): + ''' + Returns a 3-d numpy tensor with segmentation data for the ith image, + with multiple layers represnting multiple lables for each pixel. + The depth is variable depending on available data but can be + limited to max_depth. + ''' + row = self.image[i] + if categories: + groups = [d for cat, d in row.items() if cat in categories and d] + else: + groups = [d for cat, d in row.items() if d and ( + cat not in self.meta_categories)] + depth = sum(len(c) for c in groups) + if max_depth is not None: + depth = min(depth, max_depth) + # Allocate an array if not already allocated. + if out is None: + out = numpy.empty((depth, row['sh'], row['sw']), dtype=numpy.int16) + i = 0 + # Stack up the result segmentation one channel at a time + for group in groups: + for channel in group: + if isinstance(channel, int): + out[i] = channel + else: + png = cv2.resize(cv2.imread(os.path.join(self.directory, 'images', channel)), (settings.SEGM_SIZE, settings.SEGM_SIZE)) + png[:,:,0] = 0 + png[:,:,2] = np.where(png[:,:,2]>0, 235, 0) + png[:,:,1] = np.where(png[:,:,2]>0, 1, 0) + out[i] = png[:,:,0] + png[:,:,1] * 256 + i += 1 + if i == depth: + return out + # Return above when we get up to depth + assert False + + def category_index_map(self, category): + return numpy.array(self.category_map[category]) + +def build_dense_label_array(label_data, key='number', allow_none=False): + ''' + Input: set of rows with 'number' fields (or another field name key). + Output: array such that a[number] = the row with the given number. + ''' + result = [None] * (max([d[key] for d in label_data]) + 1) + for d in label_data: + result[d[key]] = d + # Fill in none + if not allow_none: + example = label_data[0] + def make_empty(k): + return dict((c, k if c is key else type(v)()) + for c, v in example.items()) + for i, d in enumerate(result): + if d is None: + result[i] = dict(make_empty(i)) + return result + +def build_numpy_category_map(map_data, key1='code', key2='number'): + ''' + Input: set of rows with 'number' fields (or another field name key). + Output: array such that a[number] = the row with the given number. 
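+ Concretely, returns two int16 arrays (code-to-number, number-to-code); e.g. for an illustrative row {'code': 2, 'number': 40}, the first array maps 2 -> 40 and the second maps 40 -> 2.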
+ ''' + results = list(numpy.zeros((max([d[key] for d in map_data]) + 1), + dtype=numpy.int16) for key in (key1, key2)) + for d in map_data: + results[0][d[key1]] = d[key2] + results[1][d[key2]] = d[key1] + return results + +def decode_label_dict(row): + result = {} + for key, val in row.items(): + if key == 'category': + result[key] = dict((c, int(n)) + for c, n in [re.match('^([^(]*)\(([^)]*)\)$', f).groups() + for f in val.split(';')]) + elif key == 'name': + result[key] = val + elif key == 'syns': + result[key] = val.split(';') + elif re.match('^\d+$', val): + result[key] = int(val) + elif re.match('^\d+\.\d*$', val): + result[key] = float(val) + else: + result[key] = val + return result + +def decode_index_dict(row): + result = {} + for key, val in row.items(): + if key in ['image', 'split']: + result[key] = val + elif key in ['sw', 'sh', 'iw', 'ih']: + result[key] = int(val) + else: + item = [s for s in val.split(';') if s] + for i, v in enumerate(item): + if re.match('^\d+$', v): + item[i] = int(v) + result[key] = item + return result + +def index_has_any_data(row, categories): + for c in categories: + for data in row[c]: + if data: return True + return False + +def index_has_all_data(row, categories): + for c in categories: + cat_has = False + for data in row[c]: + if data: + cat_has = True + break + if not cat_has: + return False + return True + +class SegmentationPrefetcher: + ''' + SegmentationPrefetcher will prefetch a bunch of segmentation + images using a multiprocessing pool, so you do not have to wait + around while the files get opened and decoded. Just request + batches of images and segmentations calling fetch_batch(). + ''' + def __init__(self, segmentation, split=None, randomize=False, + segmentation_shape=None, categories=None, once=False, + start=None, end=None, batch_size=4, ahead=4, thread=False): + ''' + Constructor arguments: + segmentation: The AbstractSegmentation to load. + split: None for no filtering, or 'train' or 'val' etc. + randomize: True to randomly shuffle order, or a random seed. + categories: a list of categories to include in each batch. + batch_size: number of data items for each batch. + ahead: the number of data items to prefetch ahead. 
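+ once: True to stop after a single pass over the data. + start, end: optional index range restricting which images are used. + segmentation_shape: if set, segmentations are rescaled to this (height, width). + thread: True to use a thread pool instead of worker processes.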
+ ''' + self.segmentation = segmentation + self.segm_to_label = None + with open(settings.SEGM_TO_LABEL_PKL, 'rb') as f: + self.segm_to_label = pickle.load(f) + self.split = split + self.randomize = randomize + self.random = random.Random() + if randomize is not True: + self.random.seed(randomize) + self.categories = categories + self.once = once + self.batch_size = batch_size + self.ahead = ahead + # Initialize the multiprocessing pool + n_procs = cpu_count() + if thread: + self.pool = ThreadPool(processes=n_procs) + else: + original_sigint_handler = setup_sigint() + self.pool = Pool(processes=n_procs, initializer=setup_sigint) + restore_sigint(original_sigint_handler) + # Prefilter the image indexes of interest + if start is None: + start = 0 + if end is None: + end = segmentation.size() + self.indexes = range(start, end) + if split: + self.indexes = [i for i in self.indexes + if segmentation.split(i) == split] + if self.randomize: + self.random.shuffle(self.indexes) + self.index = 0 + self.result_queue = [] + self.segmentation_shape = segmentation_shape + # Get dense catmaps + self.catmaps = [ + segmentation.category_index_map(cat) if cat != 'image' else None + for cat in categories] + + def next_job(self): + if self.index < 0: + return None + j = self.indexes[self.index] + result = (j, + self.segmentation.__class__, + self.segmentation.metadata(j), + self.segmentation.filename(j), + self.categories, + self.segm_to_label, + self.segmentation_shape) + self.index += 1 + if self.index >= len(self.indexes): + if self.once: + self.index = -1 + else: + self.index = 0 + if self.randomize: + # Reshuffle every time through + self.random.shuffle(self.indexes) + return result + + def batches(self): + '''Iterator for all batches''' + while True: + batch = self.fetch_batch() + if batch is None: + break + else: + yield batch + # def batches(self): + # '''Iterator for all batches''' + # while True: + # batch = self.fetch_batch() + # if batch is None: + # raise StopIteration + # yield batch + def fetch_batch(self): + '''Returns a single batch as an array of dictionaries.''' + try: + self.refill_tasks() + if len(self.result_queue) == 0: + return None + result = self.result_queue.pop(0) + return result.get(31536000) + except KeyboardInterrupt: + print("Caught KeyboardInterrupt, terminating workers") + self.pool.terminate() + raise + + def fetch_tensor_batch(self, bgr_mean=None, global_labels=False): + '''Iterator for batches as arrays of tensors.''' + batch = self.fetch_batch() + return self.form_caffe_tensors(batch, bgr_mean, global_labels) + + def tensor_batches(self, bgr_mean=None, global_labels=False): + '''Returns a single batch as an array of tensors, one per category.''' + while True: + batch = self.fetch_tensor_batch( + bgr_mean=bgr_mean, global_labels=global_labels) + if batch is None: + break + else: + yield batch + + def form_caffe_tensors(self, batch, bgr_mean=None, global_labels=False): + # Assemble a batch in [{'cat': data,..},..] format into + # an array of batch tensors, the first for the image, and the + # remaining for each category in self.categories, in order. 
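+ # Each per-category list is finally stacked along a new batch axis, so the caller receives [image_tensor, cat_1_tensor, ...], one tensor per entry in self.categories.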
+ # This also applies a random flip if needed + if batch is None: + return None + batches = [[] for c in self.categories] + for record in batch: + default_shape = (1, record['sh'], record['sw']) + for c, cat in enumerate(self.categories): + if cat == 'image': + # Normalize image with right RGB order and mean + batches[c].append(normalize_image( + record[cat], bgr_mean)) + elif global_labels: + batches[c].append(normalize_label( + record[cat], default_shape, flatten=True)) + else: + catmap = self.catmaps[c] + batches[c].append(catmap[normalize_label( + record[cat], default_shape, flatten=True)]) + return [numpy.concatenate(tuple(m[numpy.newaxis] for m in b)) + for b in batches] + + def refill_tasks(self): + # It will call the sequencer to ask for a sequence + # of batch_size jobs (indexes with categories) + # Then it will call pool.map_async + while len(self.result_queue) < self.ahead: + data = [] + while len(data) < self.batch_size: + job = self.next_job() + if job is None: + break + data.append(job) + if len(data) == 0: + return + self.result_queue.append(self.pool.map_async(prefetch_worker, data)) + + def close(self): + while len(self.result_queue): + result = self.result_queue.pop(0) + if result is not None: + result.wait(0.001) + self.pool.close() + self.poool.cancel_join_thread() + +def prefetch_worker(d): + if d is None: + return None + j, typ, m, fn, categories, segm_to_label, segmentation_shape = d + segs, shape = typ.resolve_segmentation(m, categories=categories, segm_to_label=segm_to_label) + if segmentation_shape is not None: + for k, v in segs.items(): + print("k: ", k) + segs[k] = scale_segmentation(v, segmentation_shape) + shape = segmentation_shape + # Some additional metadata to provide + segs['sh'], segs['sw'] = shape + segs['i'] = j + segs['fn'] = fn + if categories is None or 'image' in categories: + segs['image'] = np.asarray(Image.open(fn).convert('L').resize((settings.IMG_SIZE, settings.IMG_SIZE))) + return segs +# def convertRGBToGray(rgbImg): +# return np.dot(rgbImg[...,:3], [0.114, 0.587, 0.299]) +def scale_segmentation(segmentation, dims, crop=False): + ''' + Zooms a 2d or 3d segmentation to the given dims, using nearest neighbor. + ''' + shape = numpy.shape(segmentation) + if len(shape) < 2 or shape[-2:] == dims: + return segmentation + peel = (len(shape) == 2) + if peel: + segmentation = segmentation[numpy.newaxis] + levels = segmentation.shape[0] + result = numpy.zeros((levels, ) + dims, + dtype=segmentation.dtype) + ratio = (1,) + tuple(res / float(orig) + for res, orig in zip(result.shape[1:], segmentation.shape[1:])) + if not crop: + safezoom(segmentation, ratio, output=result, order=0) + else: + ratio = max(ratio[1:]) + height = int(round(dims[0] / ratio)) + hmargin = (segmentation.shape[0] - height) // 2 + width = int(round(dims[1] / ratio)) + wmargin = (segmentation.shape[1] - height) // 2 + safezoom(segmentation[:, hmargin:hmargin+height, + wmargin:wmargin+width], + (1, ratio, ratio), output=result, order=0) + if peel: + result = result[0] + return result + +def safezoom(array, ratio, output=None, order=0): + '''Like numpy.zoom, but does not crash when the first dimension + of the array is of size 1, as happens often with segmentations''' + dtype = array.dtype + if array.dtype == numpy.float16: + array = array.astype(numpy.float32) + if array.shape[0] == 1: + if output is not None: + output = output[0,...] 
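+ # Zoom only the spatial axes: scipy's zoom misbehaves with a leading singleton dimension, so that axis is peeled off here and restored afterwards.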
+ result = zoom(array[0,...], ratio[1:], + output=output, order=order) + if output is None: + output = result[numpy.newaxis] + else: + result = zoom(array, ratio, output=output, order=order) + if output is None: + output = result + return output.astype(dtype) + +def setup_sigint(): + import threading + if not isinstance(threading.current_thread(), threading._MainThread): + return None + return signal.signal(signal.SIGINT, signal.SIG_IGN) + +def restore_sigint(original): + import threading + if not isinstance(threading.current_thread(), threading._MainThread): + return + if original is None: + original = signal.SIG_DFL + signal.signal(signal.SIGINT, original) + +def wants(what, option): + if option is None: + return True + return what in option + +def normalize_image(rgb_image, bgr_mean): + """ + Load input image and preprocess for Caffe: + - cast to float + - switch channels RGB -> BGR + - subtract mean + - transpose to channel x height x width order + """ + # img = numpy.array(rgb_image, dtype=numpy.float32) + # if (img.ndim == 2): + # img = numpy.repeat(img[:,:,None], 3, axis = 2) + # # img = img[:,:,::-1] + # if bgr_mean is not None: + # img -= bgr_mean + rgb_image = np.expand_dims(rgb_image, axis=0) + # rgb_image = rgb_image.transpose((2,0,1)) + # print("rgb_image shape: ", rgb_image.shape) + return rgb_image + +### Original code +# def normalize_image(rgb_image, bgr_mean): +# """ +# Load input image and preprocess for Caffe: +# - cast to float +# - switch channels RGB -> BGR +# - subtract mean +# - transpose to channel x height x width order +# """ +# img = numpy.array(rgb_image, dtype=numpy.float32) +# if (img.ndim == 2): +# img = numpy.repeat(img[:,:,None], 3, axis = 2) +# img = img[:,:,::-1] +# if bgr_mean is not None: +# img -= bgr_mean +# img = img.transpose((2,0,1)) +# return img + +def normalize_label(label_data, shape, flatten=False): + """ + Given a 0, 1, 2, or 3-dimensional label_data and a default + shape of the form (1, y, x), returns a 3d tensor by + """ + dims = len(numpy.shape(label_data)) + if dims <= 2: + # Scalar data on this channel: fill shape + if dims == 1: + if flatten: + label_data = label_data[0] if len(label_data) else 0 + else: + return (numpy.ones(shape, dtype=numpy.int16) * + numpy.asarray(label_data, dtype=numpy.int16) + [:, numpy.newaxis, numpy.newaxis]) + return numpy.full(shape, label_data, dtype=numpy.int16) + else: + if dims == 3: + if flatten: + label_data = label_data[0] + else: + return label_data + return label_data[numpy.newaxis] + +if __name__ == '__main__': + data = SegmentationData('broden1_227') + pd = SegmentationPrefetcher(data,categories=data.category_names()+['image'],once=True) + bs = pd.batches().next() diff --git a/losses.py b/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..949ca21f59dfc48468b5fc1ff5886a73024712b9 --- /dev/null +++ b/losses.py @@ -0,0 +1,72 @@ +from fastai.vision import * + +from modules_abinet.model import Model + + +class MultiLosses(nn.Module): + def __init__(self, one_hot=True): + super().__init__() + self.ce = SoftCrossEntropyLoss() if one_hot else torch.nn.CrossEntropyLoss() + self.bce = torch.nn.BCELoss() + + @property + def last_losses(self): + return self.losses + + def _flatten(self, sources, lengths): + return torch.cat([t[:l] for t, l in zip(sources, lengths)]) + + def _merge_list(self, all_res): + if not isinstance(all_res, (list, tuple)): + return all_res + def merge(items): + if isinstance(items[0], torch.Tensor): return torch.cat(items, dim=0) + else: return items[0] + 
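+ # Merge the per-iteration result dicts into one: tensors are concatenated along the batch axis, any other value is taken from the first entry.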
res = dict() + for key in all_res[0].keys(): + items = [r[key] for r in all_res] + res[key] = merge(items) + return res + + def _ce_loss(self, output, gt_labels, gt_lengths, idx=None, record=True): + loss_name = output.get('name') + pt_logits, weight = output['logits'], output['loss_weight'] + + assert pt_logits.shape[0] % gt_labels.shape[0] == 0 + iter_size = pt_logits.shape[0] // gt_labels.shape[0] + if iter_size > 1: + gt_labels = gt_labels.repeat(3, 1, 1) + gt_lengths = gt_lengths.repeat(3) + flat_gt_labels = self._flatten(gt_labels, gt_lengths) + flat_pt_logits = self._flatten(pt_logits, gt_lengths) + + nll = output.get('nll') + if nll is not None: + loss = self.ce(flat_pt_logits, flat_gt_labels, softmax=False) * weight + else: + loss = self.ce(flat_pt_logits, flat_gt_labels) * weight + if record and loss_name is not None: self.losses[f'{loss_name}_loss'] = loss + + return loss + + def forward(self, outputs, *args): + self.losses = {} + if isinstance(outputs, (tuple, list)): + outputs = [self._merge_list(o) for o in outputs] + return sum([self._ce_loss(o, *args) for o in outputs if o['loss_weight'] > 0.]) + else: + return self._ce_loss(outputs, *args, record=False) + + +class SoftCrossEntropyLoss(nn.Module): + def __init__(self, reduction="mean"): + super().__init__() + self.reduction = reduction + + def forward(self, input, target, softmax=True): + if softmax: log_prob = F.log_softmax(input, dim=-1) + else: log_prob = torch.log(input) + loss = -(target * log_prob).sum(dim=-1) + if self.reduction == "mean": return loss.mean() + elif self.reduction == "sum": return loss.sum() + else: return loss diff --git a/losses_matrn.py b/losses_matrn.py new file mode 100644 index 0000000000000000000000000000000000000000..e6410fb088d9f948e42b7346ae1c6e3714baf521 --- /dev/null +++ b/losses_matrn.py @@ -0,0 +1,72 @@ +from fastai.vision import * + +from modules_matrn.model import Model + + +class MultiLosses(nn.Module): + def __init__(self, one_hot=True): + super().__init__() + self.ce = SoftCrossEntropyLoss() if one_hot else torch.nn.CrossEntropyLoss() + self.bce = torch.nn.BCELoss() + + @property + def last_losses(self): + return self.losses + + def _flatten(self, sources, lengths): + return torch.cat([t[:l] for t, l in zip(sources, lengths)]) + + def _merge_list(self, all_res): + if not isinstance(all_res, (list, tuple)): + return all_res + def merge(items): + if isinstance(items[0], torch.Tensor): return torch.cat(items, dim=0) + else: return items[0] + res = dict() + for key in all_res[0].keys(): + items = [r[key] for r in all_res] + res[key] = merge(items) + return res + + def _ce_loss(self, output, gt_labels, gt_lengths, idx=None, record=True): + loss_name = output.get('name') + pt_logits, weight = output['logits'], output['loss_weight'] + + assert pt_logits.shape[0] % gt_labels.shape[0] == 0 + iter_size = pt_logits.shape[0] // gt_labels.shape[0] + if iter_size > 1: + gt_labels = gt_labels.repeat(iter_size, 1, 1) + gt_lengths = gt_lengths.repeat(iter_size) + flat_gt_labels = self._flatten(gt_labels, gt_lengths) + flat_pt_logits = self._flatten(pt_logits, gt_lengths) + + nll = output.get('nll') + if nll is not None: + loss = self.ce(flat_pt_logits, flat_gt_labels, softmax=False) * weight + else: + loss = self.ce(flat_pt_logits, flat_gt_labels) * weight + if record and loss_name is not None: self.losses[f'{loss_name}_loss'] = loss + + return loss + + def forward(self, outputs, *args): + self.losses = {} + if isinstance(outputs, (tuple, list)): + outputs = [self._merge_list(o) for o in outputs] + 
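+ # Sum the weighted cross-entropy of every decoder output whose loss_weight is positive; outputs with zero weight are skipped.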
return sum([self._ce_loss(o, *args) for o in outputs if o['loss_weight'] > 0.]) + else: + return self._ce_loss(outputs, *args, record=False) + + +class SoftCrossEntropyLoss(nn.Module): + def __init__(self, reduction="mean"): + super().__init__() + self.reduction = reduction + + def forward(self, input, target, softmax=True): + if softmax: log_prob = F.log_softmax(input, dim=-1) + else: log_prob = torch.log(input) + loss = -(target * log_prob).sum(dim=-1) + if self.reduction == "mean": return loss.mean() + elif self.reduction == "sum": return loss.sum() + else: return loss diff --git a/model.py b/model.py new file mode 100644 index 0000000000000000000000000000000000000000..46e5a492a887a1f64beab5186ea70f909d59b37b --- /dev/null +++ b/model.py @@ -0,0 +1,302 @@ +""" +Copyright (c) 2019-present NAVER Corp. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modules.transformation import TPS_SpatialTransformerNetwork +from modules.feature_extraction import VGG_FeatureExtractor, RCNN_FeatureExtractor, ResNet_FeatureExtractor +from modules.sequence_modeling import BidirectionalLSTM +from modules.prediction import Attention +from modules.vitstr import create_vitstr + +import math +import sys +import settings + +# singleChar - if -1 then STRScore outputs all char, however if +# 0 - N, then it will output the single character confidence of the index 0 to N +class STRScore(nn.Module): + def __init__(self, opt, converter, device, gtStr="", enableSingleCharAttrAve=False, model=None): + super(STRScore, self).__init__() + self.enableSingleCharAttrAve = enableSingleCharAttrAve + self.singleChar = -1 + self.opt = opt + self.converter = converter + self.device = device + self.gtStr = gtStr + self.model = model # Pass here if you want to use it + self.blank = torch.tensor([-1], dtype=torch.float).to(self.device) + self.separator = torch.tensor([-2], dtype=torch.float).to(self.device) + + # singleChar - if >=0, then the output of STRScore will only be a single character + # instead of a whole. The char index will be equal to the parameter "singleChar". + def setSingleCharOutput(self, singleChar): + self.singleChar = singleChar + + def forward(self, preds): + bs = preds.shape[0] + # text_for_loss, length_for_loss = self.converter.encode(labels, batch_max_length=self.opt.batch_max_length) + text_for_loss_length = self.opt.batch_max_length + 1 + length_for_pred = torch.IntTensor([self.opt.batch_max_length] * bs).to(self.device) + if 'CTC' in self.opt.Prediction: + # Calculate evaluation loss for CTC decoder. 
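+ # Greedy CTC decoding: take the most probable class at each time step and let the converter collapse repeated characters and blanks into the predicted string.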
+ preds_size = torch.FloatTensor([preds.size(1)] * bs) + if self.opt.baiduCTC: + _, preds_index = preds.max(2) + preds_index = preds_index.view(-1) + else: + _, preds_index = preds.max(2) + # print("preds_index shape: ", preds_index.shape) + preds_str = self.converter.decode(preds_index.data, preds_size.data) + # preds_str = self.converter.decode(preds_index, length_for_pred) + preds = preds.log_softmax(2).permute(1, 0, 2) + elif self.opt.Transformer: + # preds_index = preds_index.view(-1, self.converter.batch_max_length) + # print("preds shape: ", preds.shape) + # print("preds_index: ", preds_index) + # preds_str = self.converter.decode(preds_index, length_for_pred) + if settings.MODEL == 'vitstr': + _, preds_index = preds.topk(1, dim=-1, largest=True, sorted=True) + preds_str = self.converter.decode(preds_index[:, 1:], length_for_pred) + elif settings.MODEL == 'parseq': + preds_str, confidence = self.model.tokenizer.decode(preds) + # print("preds_str: ", preds_str) + else: + preds = preds[:, :text_for_loss_length, :] + + # select max probabilty (greedy decoding) then decode index to character + _, preds_index = preds.max(2) + # print("preds shape: ", preds.shape) + # print("preds_index: ", preds_index) + preds_str = self.converter.decode(preds_index, length_for_pred) + # print("preds_str: ", preds_str) + + # Confidence score + # ARGMAX calculation + sum = torch.FloatTensor([0]*bs).to(self.device) + if self.enableSingleCharAttrAve: + sum = torch.zeros((bs, preds.shape[2])).to(self.device) + if self.opt.confidence_mode == 0: + preds_prob = F.softmax(preds, dim=2) + # preds_prob shape: torch.Size([1, 25, 96]) + preds_max_prob, preds_max_idx = preds_prob.max(dim=2) + # preds_max_prob shape: torch.Size([1, 25]) + confidence_score_list = [] + count = 0 + for one_hot_preds, pred, pred_max_prob in zip(preds_prob, preds_str, preds_max_prob): + if self.opt.Transformer: + if settings.MODEL == 'vitstr': + if self.enableSingleCharAttrAve: + one_hot = one_hot_preds[self.singleChar, :] + pred = pred[self.singleChar] + pred_max_prob = pred_max_prob[self.singleChar] + else: + pred_EOS = pred.find('[s]') + pred = pred[:pred_EOS] + pred_max_prob = pred_max_prob[:pred_EOS] + + # if pred_max_prob.shape[0] == 0: continue + if self.enableSingleCharAttrAve: + sum[count] = one_hot + # sum = one_hot + # sum shape: torch.Size([96]) + # sum = sum.unsqueeze(0) + else: + if self.opt.scorer == "cumprod": + confidence_score = pred_max_prob.cumprod(dim=0)[-1] ### Maximum is 1 + sum[count] += confidence_score + elif self.opt.scorer == "mean": + confidence_score = torch.mean(pred_max_prob) ### Maximum is 1 + sum[count] += confidence_score + sum = sum.unsqueeze(1) + + elif settings.MODEL == 'parseq': + if self.enableSingleCharAttrAve: + one_hot = one_hot_preds[self.singleChar, :] + # pred = pred[self.singleChar] + pred_max_prob = pred_max_prob[self.singleChar] + else: + pred_EOS = len(pred) # Predition string already has no EOS, fully intact + pred_max_prob = pred_max_prob[:pred_EOS] + + # if pred_max_prob.shape[0] == 0: continue + if self.enableSingleCharAttrAve: + sum[count] = one_hot + # sum shape: torch.Size([96]) + # sum = sum.unsqueeze(0) + else: + if self.opt.scorer == "cumprod": + confidence_score = pred_max_prob.cumprod(dim=0)[-1] ### Maximum is 1 + sum[count] += confidence_score + elif self.opt.scorer == "mean": + confidence_score = torch.mean(pred_max_prob) ### Maximum is 1 + sum[count] += confidence_score + sum = sum.unsqueeze(1) + elif 'Attn' in self.opt.Prediction: + # if pred_max_prob.shape[0] == 0: continue 
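+ # Attention decoding: either score a single character's class probabilities (enableSingleCharAttrAve) or the whole word truncated at the '[s]' end-of-sequence token.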
+ if self.enableSingleCharAttrAve: + one_hot = one_hot_preds[self.singleChar, :] + sum[count] = one_hot + else: + pred_EOS = pred.find('[s]') + pred = pred[:pred_EOS] + pred_max_prob = pred_max_prob[:pred_EOS] ### Use score of all letters + # pred_max_prob = pred_max_prob[0:1] ### Use score of first letter only + if pred_max_prob.shape[0] == 0: continue + confidence_score = pred_max_prob.cumprod(dim=0)[-1] ### Maximum is 1 + sum[count] += confidence_score + sum = sum.unsqueeze(1) + elif 'CTC' in self.opt.Prediction: + confidence_score = pred_max_prob.cumprod(dim=0)[-1] + sum[count] += confidence_score + sum = sum.unsqueeze(1) + count += 1 + # return sum.detach().cpu().numpy() + # print("sumshape: ", sum.shape) + elif self.opt.confidence_mode == 1: + preds_prob = F.softmax(preds, dim=2) + ### Predicted indices + preds_max_prob = torch.argmax(preds_prob, 2) + # print("preds_max_prob shape: ", preds_max_prob.shape) + ### Ground truth indices + gtIndices, _ = self.converter.encode([self.gtStr for i in range(0,preds_prob.shape[0])], batch_max_length=self.opt.batch_max_length-1) + # print("gtIndices shape: ", gtIndices.shape) + ### Acquire levenstein distance + m = torch.tensor([preds_prob.shape[1] for i in range(0, gtIndices.shape[0])], dtype=torch.float).to(self.device) + n = torch.tensor([preds_prob.shape[1] for i in range(0, gtIndices.shape[0])], dtype=torch.float).to(self.device) + # print("m: ", m) + # print("preds_max_prob dtype: ", preds_max_prob.dtype) + # print("gtIndices dtype: ", gtIndices.dtype) + preds_max_prob = preds_max_prob.type(torch.float) + gtIndices = gtIndices.type(torch.float) + r = levenshtein_distance(preds_max_prob.to(self.device), gtIndices.to(self.device), n, m, torch.cat([self.blank, self.separator]), torch.empty([], dtype=torch.float).to(self.device)) + # print("r shape: ", r.shape) + # confidence_score_list = [] + # count = 0 + # for pred, pred_max_prob in zip(preds_str, preds_max_prob): + # if 'Attn' in self.opt.Prediction: + # pred_EOS = pred.find('[s]') + # pred = pred[:pred_EOS] + # pred_max_prob = pred_max_prob[:pred_EOS] ### Use score of all letters + # # pred_max_prob = pred_max_prob[0:1] ### Use score of first letter only + # if pred_max_prob.shape[0] == 0: continue + # confidence_score = pred_max_prob.cumprod(dim=0)[-1] + # sum[count] += confidence_score + # count += 1 + # return sum.detach().cpu().numpy() + # print("sumshape: ", sum.shape) + # sum = sum.unsqueeze(1) + rSoft = F.softmax(r[:,2].type(torch.float)) + # rSoft = rSoft.contiguous() + rNorm = rSoft.max()-rSoft + sum = rNorm.unsqueeze(1) + return sum + +class Model(nn.Module): + + def __init__(self, opt, device=None, converter=None, gt_text=""): + super(Model, self).__init__() + self.opt = opt + self.device = device + self.converter = converter + self.gt_text = gt_text + self.stages = {'Trans': opt.Transformation, 'Feat': opt.FeatureExtraction, + 'Seq': opt.SequenceModeling, 'Pred': opt.Prediction, + 'ViTSTR': opt.Transformer} + + """ Transformation """ + if opt.Transformation == 'TPS': + self.Transformation = TPS_SpatialTransformerNetwork( + F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW), I_r_size=(opt.imgH, opt.imgW), I_channel_num=opt.input_channel) + else: + print('No Transformation module specified') + + if opt.Transformer: + self.vitstr = create_vitstr(num_tokens=opt.num_class, model=opt.TransformerModel) + return + + """ FeatureExtraction """ + if opt.FeatureExtraction == 'VGG': + self.FeatureExtraction = VGG_FeatureExtractor(opt.input_channel, opt.output_channel) + elif 
opt.FeatureExtraction == 'RCNN': + self.FeatureExtraction = RCNN_FeatureExtractor(opt.input_channel, opt.output_channel) + elif opt.FeatureExtraction == 'ResNet': + self.FeatureExtraction = ResNet_FeatureExtractor(opt.input_channel, opt.output_channel) + else: + raise Exception('No FeatureExtraction module specified') + self.FeatureExtraction_output = opt.output_channel # int(imgH/16-1) * 512 + self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1)) # Transform final (imgH/16-1) -> 1 + + """ Sequence modeling""" + if opt.SequenceModeling == 'BiLSTM': + self.SequenceModeling = nn.Sequential( + BidirectionalLSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size), + BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size)) + self.SequenceModeling_output = opt.hidden_size + else: + print('No SequenceModeling module specified') + self.SequenceModeling_output = self.FeatureExtraction_output + + """ Prediction """ + if opt.Prediction == 'CTC': + self.Prediction = nn.Linear(self.SequenceModeling_output, opt.num_class) + elif opt.Prediction == 'Attn': + self.Prediction = Attention(self.SequenceModeling_output, opt.hidden_size, opt.num_class) + else: + raise Exception('Prediction is neither CTC or Attn') + def set_labels(self, labels): + self.labels = labels + def patch_embed_func(self): + if self.opt.Transformer: + return self.vitstr.patch_embed_func() + return None + def setGTText(self, text): + self.gt_text = text + def forward(self, input, text="", seqlen=25, is_train=False): + # text = torch.FloatTensor(input.shape[0], self.opt.batch_max_length + 1).fill_(0).to(self.device) + # text = self.converter.encode(self.labels) + if settings.MODEL == 'trba': + text = self.gt_text + if not self.stages['ViTSTR']: + assert(len(text)>0) + """ Transformation stage """ + if not self.stages['Trans'] == "None": + input = self.Transformation(input) + + if self.stages['ViTSTR']: + prediction = self.vitstr(input, seqlen=seqlen) + return prediction + + """ Feature extraction stage """ + visual_feature = self.FeatureExtraction(input) + visual_feature = self.AdaptiveAvgPool(visual_feature.permute(0, 3, 1, 2)) # [b, c, h, w] -> [b, w, c, h] + visual_feature = visual_feature.squeeze(3) + + """ Sequence modeling stage """ + if self.stages['Seq'] == 'BiLSTM': + contextual_feature = self.SequenceModeling(visual_feature) + else: + contextual_feature = visual_feature # for convenience. this is NOT contextually modeled by BiLSTM + + """ Prediction stage """ + if self.stages['Pred'] == 'CTC': + prediction = self.Prediction(contextual_feature.contiguous()) + else: + prediction = self.Prediction(contextual_feature.contiguous(), text, is_train, batch_max_length=self.opt.batch_max_length) + + return prediction diff --git a/model_abinet.py b/model_abinet.py new file mode 100644 index 0000000000000000000000000000000000000000..1e59eafa1569c1f7913641e0469dfcdda67536a6 --- /dev/null +++ b/model_abinet.py @@ -0,0 +1,80 @@ +""" +Copyright (c) 2019-present NAVER Corp. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import sys +import math + +class STRScore(nn.Module): + def __init__(self, config, charsetMapper, postprocessFunc, device, enableSingleCharAttrAve=False): + super(STRScore, self).__init__() + self.config = config + self.charsetMapper = charsetMapper + self.postprocess = postprocessFunc + self.device = device + self.enableSingleCharAttrAve = enableSingleCharAttrAve + + # singleChar - if >=0, then the output of STRScore will only be a single character + # instead of a whole. The char index will be equal to the parameter "singleChar". + def setSingleCharOutput(self, singleChar): + self.singleChar = singleChar + + ### Output of ABINET model + ### Shape with 1 batchsize: torch.Size([1, 26, 37]) + def forward(self, preds): + # Acquire predicted text + pt_text, _, __ = self.postprocess(preds[0], self.charsetMapper, self.config.model_eval) + preds = preds[0]["logits"] + # preds shape: torch.Size([50, 26, 37]) + # Confidence score + bs = preds.shape[0] + # ARGMAX calculation + sum = torch.FloatTensor([0]*len(preds)).to(self.device) + preds_prob = F.softmax(preds, dim=2) + preds_max_prob, preds_max_index = preds_prob.max(dim=2) + if self.enableSingleCharAttrAve: + preds_max_prob = preds_max_prob[:,self.singleChar] + preds_max_prob = preds_max_prob.unsqueeze(0) + if self.enableSingleCharAttrAve: + sum = torch.zeros((bs, len(self.config.character)-1)).to(self.device) + # print("preds_max_prob shape: ", preds_max_prob.shape) (1,26) + confidence_score_list = [] + count = 0 + for one_hot_preds, pred, pred_max_prob in zip(preds_prob, pt_text, preds_max_prob): + if self.enableSingleCharAttrAve: + one_hot = one_hot_preds[self.singleChar, :] + sum[count] = one_hot + # sum = sum.unsqueeze(0) + else: + pred_EOS = len(pred) + # pred = pred[:pred_EOS] + pred_max_prob = pred_max_prob[:pred_EOS] ### Use score of all letters excluding null char + # pred_max_prob = pred_max_prob[0:1] ### Use score of first letter only + if pred_max_prob.shape[0] == 0: continue + if self.config.scorer == "cumprod": + confidence_score = pred_max_prob.cumprod(dim=0)[-1] ### Maximum is 1 + elif self.config.scorer == "mean": + confidence_score = torch.mean(pred_max_prob) ### Maximum is 1 + sum[count] += confidence_score + count += 1 + if self.enableSingleCharAttrAve: + pass + else: + sum = sum.unsqueeze(1) + return sum diff --git a/model_matrn.py b/model_matrn.py new file mode 100644 index 0000000000000000000000000000000000000000..1e59eafa1569c1f7913641e0469dfcdda67536a6 --- /dev/null +++ b/model_matrn.py @@ -0,0 +1,80 @@ +""" +Copyright (c) 2019-present NAVER Corp. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import sys +import math + +class STRScore(nn.Module): + def __init__(self, config, charsetMapper, postprocessFunc, device, enableSingleCharAttrAve=False): + super(STRScore, self).__init__() + self.config = config + self.charsetMapper = charsetMapper + self.postprocess = postprocessFunc + self.device = device + self.enableSingleCharAttrAve = enableSingleCharAttrAve + + # singleChar - if >=0, then the output of STRScore will only be a single character + # instead of a whole. The char index will be equal to the parameter "singleChar". + def setSingleCharOutput(self, singleChar): + self.singleChar = singleChar + + ### Output of ABINET model + ### Shape with 1 batchsize: torch.Size([1, 26, 37]) + def forward(self, preds): + # Acquire predicted text + pt_text, _, __ = self.postprocess(preds[0], self.charsetMapper, self.config.model_eval) + preds = preds[0]["logits"] + # preds shape: torch.Size([50, 26, 37]) + # Confidence score + bs = preds.shape[0] + # ARGMAX calculation + sum = torch.FloatTensor([0]*len(preds)).to(self.device) + preds_prob = F.softmax(preds, dim=2) + preds_max_prob, preds_max_index = preds_prob.max(dim=2) + if self.enableSingleCharAttrAve: + preds_max_prob = preds_max_prob[:,self.singleChar] + preds_max_prob = preds_max_prob.unsqueeze(0) + if self.enableSingleCharAttrAve: + sum = torch.zeros((bs, len(self.config.character)-1)).to(self.device) + # print("preds_max_prob shape: ", preds_max_prob.shape) (1,26) + confidence_score_list = [] + count = 0 + for one_hot_preds, pred, pred_max_prob in zip(preds_prob, pt_text, preds_max_prob): + if self.enableSingleCharAttrAve: + one_hot = one_hot_preds[self.singleChar, :] + sum[count] = one_hot + # sum = sum.unsqueeze(0) + else: + pred_EOS = len(pred) + # pred = pred[:pred_EOS] + pred_max_prob = pred_max_prob[:pred_EOS] ### Use score of all letters excluding null char + # pred_max_prob = pred_max_prob[0:1] ### Use score of first letter only + if pred_max_prob.shape[0] == 0: continue + if self.config.scorer == "cumprod": + confidence_score = pred_max_prob.cumprod(dim=0)[-1] ### Maximum is 1 + elif self.config.scorer == "mean": + confidence_score = torch.mean(pred_max_prob) ### Maximum is 1 + sum[count] += confidence_score + count += 1 + if self.enableSingleCharAttrAve: + pass + else: + sum = sum.unsqueeze(1) + return sum diff --git a/model_srn.py b/model_srn.py new file mode 100644 index 0000000000000000000000000000000000000000..93130ab4bf1e14aa1702643cdcda25293e8e549d --- /dev/null +++ b/model_srn.py @@ -0,0 +1,280 @@ +""" +Copyright (c) 2019-present NAVER Corp. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modules_srn.transformation import TPS_SpatialTransformerNetwork +from modules_srn.feature_extraction import VGG_FeatureExtractor, RCNN_FeatureExtractor, ResNet_FeatureExtractor +from modules_srn.sequence_modeling import BidirectionalLSTM +from modules_srn.prediction import Attention +from modules_srn.resnet_aster import ResNet_ASTER + +from modules_srn.bert import Bert_Ocr +from modules_srn.bert import Config + +from modules_srn.SRN_modules import Transforme_Encoder, SRN_Decoder, Torch_transformer_encoder +from modules_srn.resnet_fpn import ResNet_FPN +import settings +import sys + +# singleChar - if -1 then STRScore outputs all char, however if +# 0 - N, then it will output the single character confidence of the index 0 to N +class STRScore(nn.Module): + def __init__(self, opt, converter, device, gtStr="", enableSingleCharAttrAve=False, model=None): + super(STRScore, self).__init__() + self.enableSingleCharAttrAve = enableSingleCharAttrAve + self.singleChar = -1 + self.opt = opt + self.converter = converter + self.device = device + self.gtStr = gtStr + self.model = model # Pass here if you want to use it + self.blank = torch.tensor([-1], dtype=torch.float).to(self.device) + self.separator = torch.tensor([-2], dtype=torch.float).to(self.device) + + # singleChar - if >=0, then the output of STRScore will only be a single character + # instead of a whole. The char index will be equal to the parameter "singleChar". + def setSingleCharOutput(self, singleChar): + self.singleChar = singleChar + + def forward(self, preds): + preds = preds[2] # Access second index + bs = preds.shape[0] + # text_for_loss, length_for_loss = self.converter.encode(labels, batch_max_length=self.opt.batch_max_length) + text_for_loss_length = self.opt.batch_max_length + 1 + + # _, preds_index = preds.topk(1, dim=-1, largest=True, sorted=True) + # preds_index = preds_index.view(-1, self.converter.batch_max_length) + # print("preds shape: ", preds.shape) + # print("preds_index: ", preds_index) + # preds_str = self.converter.decode(preds_index, length_for_pred) + if settings.MODEL == 'vitstr': + preds_str = self.converter.decode(preds_index[:, 1:], length_for_pred) + elif settings.MODEL == 'srn': + _, preds_index = preds.max(2) + length_for_pred = torch.IntTensor([self.opt.batch_max_length] * bs).to(self.device) + preds_str = self.converter.decode(preds_index, length_for_pred) + # sys.exit() + elif settings.MODEL == 'parseq': + preds_str, confidence = self.model.tokenizer.decode(preds) + + # Confidence score + # ARGMAX calculation + sum = torch.FloatTensor([0]*bs).to(self.device) + if self.opt.confidence_mode == 0: + preds_prob = F.softmax(preds, dim=2) + # preds_prob shape: torch.Size([1, 25, 96]) + preds_max_prob, preds_max_idx = preds_prob.max(dim=2) + # preds_max_prob shape: torch.Size([1, 25]) + confidence_score_list = [] + count = 0 + for one_hot_preds, pred, pred_max_prob in zip(preds_prob, preds_str, preds_max_prob): + if settings.MODEL == 'vitstr' or settings.MODEL == 'srn': + if self.enableSingleCharAttrAve: + one_hot = one_hot_preds[self.singleChar, :] + # pred = pred[self.singleChar] + pred_max_prob = pred_max_prob[self.singleChar] + else: + pred_EOS = pred.find('[s]') + pred = pred[:pred_EOS] + pred_max_prob = pred_max_prob[:pred_EOS] + + # if pred_max_prob.shape[0] == 0: continue + if self.enableSingleCharAttrAve: + sum = one_hot + # sum shape: torch.Size([96]) + sum = sum.unsqueeze(0) + else: + if self.opt.scorer == "cumprod": 
+ confidence_score = pred_max_prob.cumprod(dim=0)[-1] ### Maximum is 1 + sum[count] += confidence_score + elif self.opt.scorer == "mean": + confidence_score = torch.mean(pred_max_prob) ### Maximum is 1 + sum[count] += confidence_score + sum = sum.unsqueeze(1) + + elif settings.MODEL == 'parseq': + if self.enableSingleCharAttrAve: + one_hot = one_hot_preds[self.singleChar, :] + # pred = pred[self.singleChar] + pred_max_prob = pred_max_prob[self.singleChar] + else: + pred_EOS = len(pred) # Predition string already has no EOS, fully intact + pred_max_prob = pred_max_prob[:pred_EOS] + + # if pred_max_prob.shape[0] == 0: continue + if self.enableSingleCharAttrAve: + sum = one_hot + # sum shape: torch.Size([96]) + sum = sum.unsqueeze(0) + else: + if self.opt.scorer == "cumprod": + confidence_score = pred_max_prob.cumprod(dim=0)[-1] ### Maximum is 1 + sum[count] += confidence_score + elif self.opt.scorer == "mean": + confidence_score = torch.mean(pred_max_prob) ### Maximum is 1 + sum[count] += confidence_score + sum = sum.unsqueeze(1) + count += 1 + # return sum.detach().cpu().numpy() + # print("sumshape: ", sum.shape) + elif self.opt.confidence_mode == 1: + preds_prob = F.softmax(preds, dim=2) + ### Predicted indices + preds_max_prob = torch.argmax(preds_prob, 2) + # print("preds_max_prob shape: ", preds_max_prob.shape) + ### Ground truth indices + gtIndices, _ = self.converter.encode([self.gtStr for i in range(0,preds_prob.shape[0])], batch_max_length=self.opt.batch_max_length-1) + # print("gtIndices shape: ", gtIndices.shape) + ### Acquire levenstein distance + m = torch.tensor([preds_prob.shape[1] for i in range(0, gtIndices.shape[0])], dtype=torch.float).to(self.device) + n = torch.tensor([preds_prob.shape[1] for i in range(0, gtIndices.shape[0])], dtype=torch.float).to(self.device) + # print("m: ", m) + # print("preds_max_prob dtype: ", preds_max_prob.dtype) + # print("gtIndices dtype: ", gtIndices.dtype) + preds_max_prob = preds_max_prob.type(torch.float) + gtIndices = gtIndices.type(torch.float) + r = levenshtein_distance(preds_max_prob.to(self.device), gtIndices.to(self.device), n, m, torch.cat([self.blank, self.separator]), torch.empty([], dtype=torch.float).to(self.device)) + # print("r shape: ", r.shape) + # confidence_score_list = [] + # count = 0 + # for pred, pred_max_prob in zip(preds_str, preds_max_prob): + # if 'Attn' in self.opt.Prediction: + # pred_EOS = pred.find('[s]') + # pred = pred[:pred_EOS] + # pred_max_prob = pred_max_prob[:pred_EOS] ### Use score of all letters + # # pred_max_prob = pred_max_prob[0:1] ### Use score of first letter only + # if pred_max_prob.shape[0] == 0: continue + # confidence_score = pred_max_prob.cumprod(dim=0)[-1] + # sum[count] += confidence_score + # count += 1 + # return sum.detach().cpu().numpy() + # print("sumshape: ", sum.shape) + # sum = sum.unsqueeze(1) + rSoft = F.softmax(r[:,2].type(torch.float)) + # rSoft = rSoft.contiguous() + rNorm = rSoft.max()-rSoft + sum = rNorm.unsqueeze(1) + return sum + +class Model(nn.Module): + + def __init__(self, opt): + super(Model, self).__init__() + self.opt = opt + self.stages = {'Trans': opt.Transformation, 'Feat': opt.FeatureExtraction, + 'Seq': opt.SequenceModeling, 'Pred': opt.Prediction} + + """ Transformation """ + if opt.Transformation == 'TPS': + self.Transformation = TPS_SpatialTransformerNetwork( + F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW), I_r_size=(opt.imgH, opt.imgW), I_channel_num=opt.input_channel) + else: + print('No Transformation module specified') + + """ FeatureExtraction """ + 
if opt.FeatureExtraction == 'VGG': + self.FeatureExtraction = VGG_FeatureExtractor(opt.input_channel, opt.output_channel) + elif opt.FeatureExtraction == 'RCNN': + self.FeatureExtraction = RCNN_FeatureExtractor(opt.input_channel, opt.output_channel) + elif opt.FeatureExtraction == 'ResNet': + self.FeatureExtraction = ResNet_FeatureExtractor(opt.input_channel, opt.output_channel) + self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1)) # Transform final (imgH/16-1) -> 1 + elif opt.FeatureExtraction == 'AsterRes': + self.FeatureExtraction = ResNet_ASTER(opt.input_channel, opt.output_channel) + elif opt.FeatureExtraction == 'ResnetFpn': + self.FeatureExtraction = ResNet_FPN() + else: + raise Exception('No FeatureExtraction module specified') + self.FeatureExtraction_output = opt.output_channel # int(imgH/16-1) * 512 + + + """ Sequence modeling""" + if opt.SequenceModeling == 'BiLSTM': + self.SequenceModeling = nn.Sequential( + BidirectionalLSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size), + BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size)) + self.SequenceModeling_output = opt.hidden_size + elif opt.SequenceModeling == 'Bert': + cfg = Config() + cfg.dim = opt.output_channel; cfg.dim_c = opt.output_channel # 降维减少计算量 + cfg.p_dim = opt.position_dim # 一张图片cnn编码之后的特征序列长度 + cfg.max_vocab_size = opt.batch_max_length + 1 # 一张图片中最多的文字个数, +1 for EOS + cfg.len_alphabet = opt.alphabet_size # 文字的类别个数 + self.SequenceModeling = Bert_Ocr(cfg) + elif opt.SequenceModeling == 'SRN': + self.SequenceModeling = Transforme_Encoder(n_layers=2, n_position=opt.position_dim) + # self.SequenceModeling = Torch_transformer_encoder(n_layers=2, n_position=opt.position_dim) + self.SequenceModeling_output = 512 + else: + print('No SequenceModeling module specified') + self.SequenceModeling_output = self.FeatureExtraction_output + + """ Prediction """ + if opt.Prediction == 'CTC': + self.Prediction = nn.Linear(self.SequenceModeling_output, opt.num_class) + elif opt.Prediction == 'Attn': + self.Prediction = Attention(self.SequenceModeling_output, opt.hidden_size, opt.num_class) + elif opt.Prediction == 'Bert_pred': + pass + elif opt.Prediction == 'SRN': + self.Prediction = SRN_Decoder(n_position=opt.position_dim, N_max_character=opt.batch_max_character + 1, n_class=opt.alphabet_size) + else: + raise Exception('Prediction is neither CTC or Attn') + + def forward(self, input, text=None, is_train=True): + """ Transformation stage """ + if not self.stages['Trans'] == "None": + input = self.Transformation(input) + + + """ Feature extraction stage """ + visual_feature = self.FeatureExtraction(input) + # if self.stages['Feat'] == 'AsterRes' or self.stages['Feat'] == 'ResnetFpn': + if self.stages['Feat'] == 'AsterRes' or self.stages['Feat'] == 'ResnetFpn': + b, c, h, w = visual_feature.shape + visual_feature = visual_feature.permute(0, 1, 3, 2) + visual_feature = visual_feature.contiguous().view(b, c, -1) + visual_feature = visual_feature.permute(0, 2, 1) # batch, seq, feature + else: + visual_feature = self.AdaptiveAvgPool(visual_feature.permute(0, 3, 1, 2)) # [b, c, h, w] -> [b, w, c, h] + visual_feature = visual_feature.squeeze(3) + + + """ Sequence modeling stage """ + if self.stages['Seq'] == 'BiLSTM': + contextual_feature = self.SequenceModeling(visual_feature) + elif self.stages['Seq'] == 'Bert': + pad_mask = text + contextual_feature = self.SequenceModeling(visual_feature, pad_mask) + elif self.stages['Seq'] == 'SRN': + contextual_feature = self.SequenceModeling(visual_feature, 
src_mask=None)[0] + else: + contextual_feature = visual_feature # for convenience. this is NOT contextually modeled by BiLSTM + + + """ Prediction stage """ + if self.stages['Pred'] == 'CTC': + prediction = self.Prediction(contextual_feature.contiguous()) + elif self.stages['Pred'] == 'Bert_pred': + prediction = contextual_feature + elif self.stages['Pred'] == 'SRN': + prediction = self.Prediction(contextual_feature) + else: + prediction = self.Prediction(contextual_feature.contiguous(), text, is_train, batch_max_length=self.opt.batch_max_length) + + return prediction diff --git a/model_trba.py b/model_trba.py new file mode 100644 index 0000000000000000000000000000000000000000..1a3509992cee15e86120d097ae70a63f050cdcab --- /dev/null +++ b/model_trba.py @@ -0,0 +1,328 @@ +""" +Copyright (c) 2019-present NAVER Corp. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch.nn as nn + +from modules_trba.transformation import TPS_SpatialTransformerNetwork +from modules_trba.feature_extraction import VGG_FeatureExtractor, RCNN_FeatureExtractor, ResNet_FeatureExtractor +from modules_trba.sequence_modeling import BidirectionalLSTM +from modules_trba.prediction import Attention +import numpy as np +import torch +import torch.nn.functional as F +import random +import copy +# from torch_edit_distance import levenshtein_distance + +class STRScore(nn.Module): + def __init__(self, opt, converter, device, gtStr="", enableSingleCharAttrAve=False): + super(STRScore, self).__init__() + self.opt = opt + self.converter = converter + self.device = device + self.gtStr = gtStr + self.enableSingleCharAttrAve = enableSingleCharAttrAve + self.blank = torch.tensor([-1], dtype=torch.float).to(self.device) + self.separator = torch.tensor([-2], dtype=torch.float).to(self.device) + + # singleChar - if >=0, then the output of STRScore will only be a single character + # instead of a whole. The char index will be equal to the parameter "singleChar". + def setSingleCharOutput(self, singleChar): + self.singleChar = singleChar + + def forward(self, preds): + bs = preds.shape[0] + # text_for_loss, length_for_loss = self.converter.encode(labels, batch_max_length=self.opt.batch_max_length) + text_for_loss_length = self.opt.batch_max_length + 1 + length_for_pred = torch.IntTensor([self.opt.batch_max_length] * bs).to(self.device) + if 'CTC' in self.opt.Prediction: + # Calculate evaluation loss for CTC decoder. 
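+            # Greedy (best-path) decoding sketch: argmax over classes at every time step;
+            # preds_size keeps the full time length per sample so the label converter
+            # can collapse repeated characters and blank tokens when building strings.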
+ preds_size = torch.FloatTensor([preds.size(1)] * bs) + if self.opt.baiduCTC: + _, preds_index = preds.max(2) + preds_index = preds_index.view(-1) + else: + _, preds_index = preds.max(2) + # print("preds_index shape: ", preds_index.shape) + preds_str = self.converter.decode(preds_index.data, preds_size.data) + # preds_str = self.converter.decode(preds_index, length_for_pred) + preds = preds.log_softmax(2).permute(1, 0, 2) + else: + preds = preds[:, :text_for_loss_length, :] + + # select max probabilty (greedy decoding) then decode index to character + _, preds_index = preds.max(2) + # print("preds shape: ", preds.shape) + # print("preds_index: ", preds_index) + preds_str = self.converter.decode(preds_index, length_for_pred) + # print("preds_str: ", preds_str) + + # Confidence score + # ARGMAX calculation + sum = torch.FloatTensor([0]*bs).to(self.device) + if self.enableSingleCharAttrAve: + sum = torch.zeros((bs, preds.shape[2])).to(self.device) + if self.opt.confidence_mode == 0: + preds_prob = F.softmax(preds, dim=2) + # print("preds_prob shape: ", preds_prob.shape) + preds_max_prob, _ = preds_prob.max(dim=2) + # print("preds_max_prob shape: ", preds_max_prob.shape) + confidence_score_list = [] + count = 0 + for one_hot_preds, pred, pred_max_prob in zip(preds_prob, preds_str, preds_max_prob): + if 'Attn' in self.opt.Prediction: + if self.enableSingleCharAttrAve: + one_hot = one_hot_preds[self.singleChar, :] + sum[count] = one_hot + else: + pred_EOS = pred.find('[s]') + pred = pred[:pred_EOS] + pred_max_prob = pred_max_prob[:pred_EOS] ### Use score of all letters + # pred_max_prob = pred_max_prob[0:1] ### Use score of first letter only + if pred_max_prob.shape[0] == 0: continue + if self.opt.scorer == "cumprod": + confidence_score = pred_max_prob.cumprod(dim=0)[-1] ### Maximum is 1 + elif self.opt.scorer == "mean": + confidence_score = torch.mean(pred_max_prob) ### Maximum is 1 + sum[count] += confidence_score + sum = sum.unsqueeze(1) + elif 'CTC' in self.opt.Prediction: + if self.opt.scorer == "cumprod": + confidence_score = pred_max_prob.cumprod(dim=0)[-1] ### Maximum is 1 + elif self.opt.scorer == "mean": + confidence_score = torch.mean(pred_max_prob) ### Maximum is 1 + sum[count] += confidence_score + sum = sum.unsqueeze(1) + count += 1 + # return sum.detach().cpu().numpy() + # print("sumshape: ", sum.shape) + elif self.opt.confidence_mode == 1: + preds_prob = F.softmax(preds, dim=2) + ### Predicted indices + preds_max_prob = torch.argmax(preds_prob, 2) + # print("preds_max_prob shape: ", preds_max_prob.shape) + ### Ground truth indices + gtIndices, _ = self.converter.encode([self.gtStr for i in range(0,preds_prob.shape[0])], batch_max_length=self.opt.batch_max_length-1) + # print("gtIndices shape: ", gtIndices.shape) + ### Acquire levenstein distance + m = torch.tensor([preds_prob.shape[1] for i in range(0, gtIndices.shape[0])], dtype=torch.float).to(self.device) + n = torch.tensor([preds_prob.shape[1] for i in range(0, gtIndices.shape[0])], dtype=torch.float).to(self.device) + # print("m: ", m) + # print("preds_max_prob dtype: ", preds_max_prob.dtype) + # print("gtIndices dtype: ", gtIndices.dtype) + preds_max_prob = preds_max_prob.type(torch.float) + gtIndices = gtIndices.type(torch.float) + r = levenshtein_distance(preds_max_prob.to(self.device), gtIndices.to(self.device), n, m, torch.cat([self.blank, self.separator]), torch.empty([], dtype=torch.float).to(self.device)) + # print("r shape: ", r.shape) + # confidence_score_list = [] + # count = 0 + # for pred, pred_max_prob in 
zip(preds_str, preds_max_prob): + # if 'Attn' in self.opt.Prediction: + # pred_EOS = pred.find('[s]') + # pred = pred[:pred_EOS] + # pred_max_prob = pred_max_prob[:pred_EOS] ### Use score of all letters + # # pred_max_prob = pred_max_prob[0:1] ### Use score of first letter only + # if pred_max_prob.shape[0] == 0: continue + # confidence_score = pred_max_prob.cumprod(dim=0)[-1] + # sum[count] += confidence_score + # count += 1 + # return sum.detach().cpu().numpy() + # print("sumshape: ", sum.shape) + # sum = sum.unsqueeze(1) + rSoft = F.softmax(r[:,2].type(torch.float)) + # rSoft = rSoft.contiguous() + rNorm = rSoft.max()-rSoft + sum = rNorm.unsqueeze(1) + print("sum shape: ", sum.shape) + return sum + +class SuperPixler(nn.Module): + def __init__(self, n_super_pixel, imageList, super_pixel_width, super_pixel_height, opt): + super(SuperPixler, self).__init__() + self.opt = opt + self.imageList = imageList + self.n_super_pixel = n_super_pixel + # self.image = image + # self.image = image.transpose(2, 0, 1) # model expects images in BRG, not RGB, so transpose color channels + # self.mean_color = self.image.mean() + # self.image = np.expand_dims(self.image, axis=0) + self.super_pixel_width = super_pixel_width + self.super_pixel_height = super_pixel_height + # def setImage(self, image): + # self.image = image + # self.image_height = image.shape[2] + # self.image_width = image.shape[3] + def sampleImages(self, size): + newImgList = [] + for i in range(0, size): + randIdx = random.randint(0, len(self.imageList)-1) + newImgList.append(copy.deepcopy(self.imageList[randIdx])) + return np.array(newImgList) + def forward(self, x): + """ + In the forward step we accept the super pixel masks and transform them to a batch of images + """ + # x = self.sampleMasks(image.shape[0]) + image = self.sampleImages(x.shape[0]) + self.image = image + self.image_height = image.shape[2] + self.image_width = image.shape[3] + self.mean_color = self.image.mean() + # self.mean_color = self.image.mean(axis=(1,2,3)) + # pixeled_image = np.repeat(self.image.copy(), x.shape[0], axis=0)# WARNING: + pixeled_image = self.image.copy() + # print("pixeled_image shape: ", pixeled_image.shape) + # print("x shape: ", x.shape) + for i, super_pixel in enumerate(x.T): + images_to_pixelate = [bool(p) for p in super_pixel] + # print("super_pixel shape: ", super_pixel.shape) + # print("images_to_pixelate len: ", len(images_to_pixelate)) + # print("i: {}, superPix: {}, images_to_pixelate: {}".format(i, super_pixel, images_to_pixelate)) + x = (i*self.super_pixel_height//self.image_height)*self.super_pixel_width + y = i*self.super_pixel_height%self.image_height + ### Reshape image means since it has n-dim size, not a single number. Need to repeat sideways. 
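+            # Each column of the mask x corresponds to one super-pixel; batch entries
+            # flagged with 1 get that (x, y) patch overwritten with the mean colour,
+            # yielding the occluded image batch used for perturbation-based attribution.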
+ # origShapeToApply = pixeled_image[images_to_pixelate,:,y:y+self.super_pixel_height,x:x+self.super_pixel_width].shape + # print("origShapeToApply: ", origShapeToApply) + # mean_color_spec = np.tile(self.mean_color, origShapeToApply[1:]) # + # mean_color_spec = np.reshape(mean_color_spec, origShapeToApply[::-1]).T ### reshape to reversed + ### Apply image means + pixeled_image[images_to_pixelate,:,y:y+self.super_pixel_height,x:x+self.super_pixel_width] = self.mean_color + return pixeled_image + +class CastNumpy(nn.Module): + def __init__(self, device): + super(CastNumpy, self).__init__() + self.device = device + + def forward(self, image): + """ + In the forward function we accept the inputs and cast them to a pytorch tensor + """ + + image = np.ascontiguousarray(image) + image = torch.from_numpy(image).to(self.device) + if image.ndimension() == 3: + image = image.unsqueeze(0) + image_half = image.half() + return image_half.float() + +class Model(nn.Module): + + def __init__(self, opt, device, feature_ext_outputs=None): + super(Model, self).__init__() + self.opt = opt + self.device = device + self.gtText = None + self.stages = {'Trans': opt.Transformation, 'Feat': opt.FeatureExtraction, + 'Seq': opt.SequenceModeling, 'Pred': opt.Prediction} + + """ Transformation """ + if opt.Transformation == 'TPS': + self.Transformation = TPS_SpatialTransformerNetwork( + F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW), I_r_size=(opt.imgH, opt.imgW), I_channel_num=opt.input_channel) + else: + print('No Transformation module specified') + + """ FeatureExtraction """ + if opt.FeatureExtraction == 'VGG': + self.FeatureExtraction = VGG_FeatureExtractor(opt.input_channel, opt.output_channel) + elif opt.FeatureExtraction == 'RCNN': + self.FeatureExtraction = RCNN_FeatureExtractor(opt.input_channel, opt.output_channel) + elif opt.FeatureExtraction == 'ResNet': + self.FeatureExtraction = ResNet_FeatureExtractor(opt.input_channel, opt.output_channel) + else: + raise Exception('No FeatureExtraction module specified') + self.FeatureExtraction_output = opt.output_channel # int(imgH/16-1) * 512 + self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1)) # Transform final (imgH/16-1) -> 1 + + """ Sequence modeling""" + if opt.SequenceModeling == 'BiLSTM': + self.SequenceModeling = nn.Sequential( + BidirectionalLSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size), + BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size)) + self.SequenceModeling_output = opt.hidden_size + else: + print('No SequenceModeling module specified') + self.SequenceModeling_output = self.FeatureExtraction_output + + """ Prediction """ + if opt.Prediction == 'CTC': + self.Prediction = nn.Linear(self.SequenceModeling_output, opt.num_class) + elif opt.Prediction == 'Attn': + self.Prediction = Attention(self.SequenceModeling_output, opt.hidden_size, opt.num_class) + else: + raise Exception('Prediction is neither CTC or Attn') + + ### Set feature map outputter modules + + if opt.output_feat_maps: + feature_ext_outputs.set_feature_ext(self.FeatureExtraction) + ### Define hooks + feature_ext_outputs = feature_ext_outputs + totalCNNLayers = 0 + idxToOutput = [] + layersList = [] + + layerCount = 0 + # print("list(self.FeatureExtraction._modules.items()): ", list(self.FeatureExtraction._modules.items())) + # print("list(self.FeatureExtraction.ConvNet_modules.items())[0][1]: ", list(self.FeatureExtraction.ConvNet._modules.items())[0][1]) + first_layer = list(self.FeatureExtraction.ConvNet._modules.items())[0][1] + 
first_layer.register_backward_hook(feature_ext_outputs.append_first_grads) + for layer in self.FeatureExtraction.modules(): + if isinstance(layer, nn.Conv2d): + layerCount += 1 + if layerCount >= opt.min_layer_out and layerCount <= opt.max_layer_out: + layer.register_forward_hook(feature_ext_outputs.append_layer_out) + layer.register_backward_hook(feature_ext_outputs.append_grad_out) + # def get_feature_ext(self): + # return self.FeatureExtraction + def setGTText(self, text): + self.gtText = text + def forward(self, input, text="", is_train=True): + if self.opt.is_shap: + text = torch.LongTensor(input.shape[0], self.opt.batch_max_length + 1).fill_(0).to(self.device) + elif self.gtText is not None: + text = self.gtText + else: + text = torch.LongTensor(input.shape[0], self.opt.batch_max_length + 1).fill_(0).to(self.device) + # print("text shape: ", text.shape) (1,26) + tpsOut = input.contiguous() + """ Transformation stage """ + if not self.stages['Trans'] == "None": + tpsOut = self.Transformation(tpsOut) + # print("Transformation feature shape: ", input.shape) + + """ Feature extraction stage """ + visual_feature = self.FeatureExtraction(tpsOut) + visual_feature = self.AdaptiveAvgPool(visual_feature.permute(0, 3, 1, 2)) # [b, c, h, w] -> [b, w, c, h] + visual_feature = visual_feature.squeeze(3) + # print("visual feature shape: ", visual_feature.shape) + + """ Sequence modeling stage """ + if self.stages['Seq'] == 'BiLSTM': + contextual_feature = self.SequenceModeling(visual_feature) + else: + contextual_feature = visual_feature # for convenience. this is NOT contextually modeled by BiLSTM + # print("Sequence feature shape: ", contextual_feature.shape) + + """ Prediction stage """ + if self.stages['Pred'] == 'CTC': + prediction = self.Prediction(contextual_feature.contiguous()) + else: + prediction = self.Prediction(contextual_feature.contiguous(), text, is_train, batch_max_length=self.opt.batch_max_length) + # print("prediction feature shape: ", prediction.shape) + # return prediction, tpsOut + return prediction diff --git a/modules/feature_extraction.py b/modules/feature_extraction.py new file mode 100644 index 0000000000000000000000000000000000000000..4316165eef38b14f8495226548e18684658d3c44 --- /dev/null +++ b/modules/feature_extraction.py @@ -0,0 +1,253 @@ +import torch.nn as nn +import torch.nn.functional as F + + +class VGG_FeatureExtractor(nn.Module): + """ FeatureExtractor of CRNN (https://arxiv.org/pdf/1507.05717.pdf) """ + + def __init__(self, input_channel, output_channel=512): + super(VGG_FeatureExtractor, self).__init__() + self.output_channel = [int(output_channel / 8), int(output_channel / 4), + int(output_channel / 2), output_channel] # [64, 128, 256, 512] + self.ConvNet = nn.Sequential( + nn.Conv2d(input_channel, self.output_channel[0], 3, 1, 1), nn.ReLU(True), + nn.MaxPool2d(2, 2), # 64x16x50 + nn.Conv2d(self.output_channel[0], self.output_channel[1], 3, 1, 1), nn.ReLU(True), + nn.MaxPool2d(2, 2), # 128x8x25 + nn.Conv2d(self.output_channel[1], self.output_channel[2], 3, 1, 1), nn.ReLU(True), # 256x8x25 + nn.Conv2d(self.output_channel[2], self.output_channel[2], 3, 1, 1), nn.ReLU(True), + nn.MaxPool2d((2, 1), (2, 1)), # 256x4x25 + nn.Conv2d(self.output_channel[2], self.output_channel[3], 3, 1, 1, bias=False), + nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True), # 512x4x25 + nn.Conv2d(self.output_channel[3], self.output_channel[3], 3, 1, 1, bias=False), + nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True), + nn.MaxPool2d((2, 1), (2, 1)), # 512x2x25 + 
nn.Conv2d(self.output_channel[3], self.output_channel[3], 2, 1, 0), nn.ReLU(True)) # 512x1x24 + + def forward(self, input): + return self.ConvNet(input) + + +class RCNN_FeatureExtractor(nn.Module): + """ FeatureExtractor of GRCNN (https://papers.nips.cc/paper/6637-gated-recurrent-convolution-neural-network-for-ocr.pdf) """ + + def __init__(self, input_channel, output_channel=512): + super(RCNN_FeatureExtractor, self).__init__() + self.output_channel = [int(output_channel / 8), int(output_channel / 4), + int(output_channel / 2), output_channel] # [64, 128, 256, 512] + self.ConvNet = nn.Sequential( + nn.Conv2d(input_channel, self.output_channel[0], 3, 1, 1), nn.ReLU(True), + nn.MaxPool2d(2, 2), # 64 x 16 x 50 + GRCL(self.output_channel[0], self.output_channel[0], num_iteration=5, kernel_size=3, pad=1), + nn.MaxPool2d(2, 2), # 64 x 8 x 25 + GRCL(self.output_channel[0], self.output_channel[1], num_iteration=5, kernel_size=3, pad=1), + nn.MaxPool2d(2, (2, 1), (0, 1)), # 128 x 4 x 26 + GRCL(self.output_channel[1], self.output_channel[2], num_iteration=5, kernel_size=3, pad=1), + nn.MaxPool2d(2, (2, 1), (0, 1)), # 256 x 2 x 27 + nn.Conv2d(self.output_channel[2], self.output_channel[3], 2, 1, 0, bias=False), + nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True)) # 512 x 1 x 26 + + def forward(self, input): + return self.ConvNet(input) + + +class ResNet_FeatureExtractor(nn.Module): + """ FeatureExtractor of FAN (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf) """ + + def __init__(self, input_channel, output_channel=512): + super(ResNet_FeatureExtractor, self).__init__() + self.ConvNet = ResNet(input_channel, output_channel, BasicBlock, [1, 2, 5, 3]) + + def forward(self, input): + return self.ConvNet(input) + + +# For Gated RCNN +class GRCL(nn.Module): + + def __init__(self, input_channel, output_channel, num_iteration, kernel_size, pad): + super(GRCL, self).__init__() + self.wgf_u = nn.Conv2d(input_channel, output_channel, 1, 1, 0, bias=False) + self.wgr_x = nn.Conv2d(output_channel, output_channel, 1, 1, 0, bias=False) + self.wf_u = nn.Conv2d(input_channel, output_channel, kernel_size, 1, pad, bias=False) + self.wr_x = nn.Conv2d(output_channel, output_channel, kernel_size, 1, pad, bias=False) + + self.BN_x_init = nn.BatchNorm2d(output_channel) + + self.num_iteration = num_iteration + self.GRCL = [GRCL_unit(output_channel) for _ in range(num_iteration)] + self.GRCL = nn.Sequential(*self.GRCL) + + def forward(self, input): + """ The input of GRCL is consistant over time t, which is denoted by u(0) + thus wgf_u / wf_u is also consistant over time t. 
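+        A sketch of the gated update computed by GRCL_unit below (each BN is a
+        separate BatchNorm2d, and wgr_x / wr_x are convolutions of x(t-1)):
+            G(t) = sigmoid( BN(wgf_u) + BN(wgr_x * x(t-1)) )
+            x(t) = relu( BN(wf_u) + BN( BN(wr_x * x(t-1)) * G(t) ) )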
+ """ + wgf_u = self.wgf_u(input) + wf_u = self.wf_u(input) + x = F.relu(self.BN_x_init(wf_u)) + + for i in range(self.num_iteration): + x = self.GRCL[i](wgf_u, self.wgr_x(x), wf_u, self.wr_x(x)) + + return x + + +class GRCL_unit(nn.Module): + + def __init__(self, output_channel): + super(GRCL_unit, self).__init__() + self.BN_gfu = nn.BatchNorm2d(output_channel) + self.BN_grx = nn.BatchNorm2d(output_channel) + self.BN_fu = nn.BatchNorm2d(output_channel) + self.BN_rx = nn.BatchNorm2d(output_channel) + self.BN_Gx = nn.BatchNorm2d(output_channel) + + def forward(self, wgf_u, wgr_x, wf_u, wr_x): + G_first_term = self.BN_gfu(wgf_u) + G_second_term = self.BN_grx(wgr_x) + G = F.sigmoid(G_first_term + G_second_term) + + x_first_term = self.BN_fu(wf_u) + x_second_term = self.BN_Gx(self.BN_rx(wr_x) * G) + x = F.relu(x_first_term + x_second_term) + + return x + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = self._conv3x3(inplanes, planes) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = self._conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU() + self.relu2 = nn.ReLU() + self.downsample = downsample + self.stride = stride + + def _conv3x3(self, in_planes, out_planes, stride=1): + "3x3 convolution with padding" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + out += residual + out = self.relu2(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, input_channel, output_channel, block, layers): + super(ResNet, self).__init__() + + self.output_channel_block = [int(output_channel / 4), int(output_channel / 2), output_channel, output_channel] + + self.inplanes = int(output_channel / 8) + self.conv0_1 = nn.Conv2d(input_channel, int(output_channel / 16), + kernel_size=3, stride=1, padding=1, bias=False) + self.bn0_1 = nn.BatchNorm2d(int(output_channel / 16)) + self.conv0_2 = nn.Conv2d(int(output_channel / 16), self.inplanes, + kernel_size=3, stride=1, padding=1, bias=False) + self.bn0_2 = nn.BatchNorm2d(self.inplanes) + self.relu = nn.ReLU() + self.relu1 = nn.ReLU() + self.relu2 = nn.ReLU() + self.relu3 = nn.ReLU() + self.relu4 = nn.ReLU() + self.relu5 = nn.ReLU() + self.relu6 = nn.ReLU() + + self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) + self.layer1 = self._make_layer(block, self.output_channel_block[0], layers[0]) + self.conv1 = nn.Conv2d(self.output_channel_block[0], self.output_channel_block[ + 0], kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(self.output_channel_block[0]) + + self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) + self.layer2 = self._make_layer(block, self.output_channel_block[1], layers[1], stride=1) + self.conv2 = nn.Conv2d(self.output_channel_block[1], self.output_channel_block[ + 1], kernel_size=3, stride=1, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(self.output_channel_block[1]) + + self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=(2, 1), padding=(0, 1)) + self.layer3 = self._make_layer(block, self.output_channel_block[2], layers[2], stride=1) + self.conv3 = nn.Conv2d(self.output_channel_block[2], self.output_channel_block[ + 2], kernel_size=3, stride=1, padding=1, 
bias=False) + self.bn3 = nn.BatchNorm2d(self.output_channel_block[2]) + + self.layer4 = self._make_layer(block, self.output_channel_block[3], layers[3], stride=1) + self.conv4_1 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[ + 3], kernel_size=2, stride=(2, 1), padding=(0, 1), bias=False) + self.bn4_1 = nn.BatchNorm2d(self.output_channel_block[3]) + self.conv4_2 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[ + 3], kernel_size=2, stride=1, padding=0, bias=False) + self.bn4_2 = nn.BatchNorm2d(self.output_channel_block[3]) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv0_1(x) + x = self.bn0_1(x) + x = self.relu(x) + x = self.conv0_2(x) + x = self.bn0_2(x) + x = self.relu1(x) + + x = self.maxpool1(x) + x = self.layer1(x) + x = self.conv1(x) + x = self.bn1(x) + x = self.relu2(x) + + x = self.maxpool2(x) + x = self.layer2(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.relu3(x) + + x = self.maxpool3(x) + x = self.layer3(x) + x = self.conv3(x) + x = self.bn3(x) + x = self.relu4(x) + + x = self.layer4(x) + x = self.conv4_1(x) + x = self.bn4_1(x) + x = self.relu5(x) + x = self.conv4_2(x) + x = self.bn4_2(x) + x = self.relu6(x) + + return x diff --git a/modules/prediction.py b/modules/prediction.py new file mode 100644 index 0000000000000000000000000000000000000000..b6c3cb32541ea4ea5f22377ce5062e5ea34ff2f7 --- /dev/null +++ b/modules/prediction.py @@ -0,0 +1,81 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + +class Attention(nn.Module): + + def __init__(self, input_size, hidden_size, num_classes): + super(Attention, self).__init__() + self.attention_cell = AttentionCell(input_size, hidden_size, num_classes) + self.hidden_size = hidden_size + self.num_classes = num_classes + self.generator = nn.Linear(hidden_size, num_classes) + + def _char_to_onehot(self, input_char, onehot_dim=38): + input_char = input_char.unsqueeze(1) + batch_size = input_char.size(0) + one_hot = torch.FloatTensor(batch_size, onehot_dim).zero_().to(device) + one_hot = one_hot.scatter_(1, input_char, 1) + return one_hot + + def forward(self, batch_H, text, is_train=True, batch_max_length=25): + """ + input: + batch_H : contextual_feature H = hidden state of encoder. [batch_size x num_steps x contextual_feature_channels] + text : the text-index of each image. [batch_size x (max_length+1)]. +1 for [GO] token. text[:, 0] = [GO]. + output: probability distribution at each step [batch_size x num_steps x num_classes] + """ + batch_size = batch_H.size(0) + num_steps = batch_max_length + 1 # +1 for [s] at end of sentence. + + output_hiddens = torch.FloatTensor(batch_size, num_steps, self.hidden_size).fill_(0).to(device) + hidden = (torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(device), + torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(device)) + + if is_train: + for i in range(num_steps): + # one-hot vectors for a i-th char. 
in a batch + char_onehots = self._char_to_onehot(text[:, i], onehot_dim=self.num_classes) + # hidden : decoder's hidden s_{t-1}, batch_H : encoder's hidden H, char_onehots : one-hot(y_{t-1}) + hidden, alpha = self.attention_cell(hidden, batch_H, char_onehots) + output_hiddens[:, i, :] = hidden[0] # LSTM hidden index (0: hidden, 1: Cell) + probs = self.generator(output_hiddens) + + else: + targets = torch.LongTensor(batch_size).fill_(0).to(device) # [GO] token + probs = torch.FloatTensor(batch_size, num_steps, self.num_classes).fill_(0).to(device) + + for i in range(num_steps): + char_onehots = self._char_to_onehot(targets, onehot_dim=self.num_classes) + hidden, alpha = self.attention_cell(hidden, batch_H, char_onehots) + probs_step = self.generator(hidden[0]) + probs[:, i, :] = probs_step + _, next_input = probs_step.max(1) + targets = next_input + + return probs # batch_size x num_steps x num_classes + + +class AttentionCell(nn.Module): + + def __init__(self, input_size, hidden_size, num_embeddings): + super(AttentionCell, self).__init__() + self.i2h = nn.Linear(input_size, hidden_size, bias=False) + self.h2h = nn.Linear(hidden_size, hidden_size) # either i2i or h2h should have bias + self.score = nn.Linear(hidden_size, 1, bias=False) + self.rnn = nn.LSTMCell(input_size + num_embeddings, hidden_size) + self.hidden_size = hidden_size + + def forward(self, prev_hidden, batch_H, char_onehots): + # [batch_size x num_encoder_step x num_channel] -> [batch_size x num_encoder_step x hidden_size] + batch_H_proj = self.i2h(batch_H) + prev_hidden_proj = self.h2h(prev_hidden[0]).unsqueeze(1) + e = self.score(torch.tanh(batch_H_proj + prev_hidden_proj)) # batch_size x num_encoder_step * 1 + + alpha = F.softmax(e, dim=1) + context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1) # batch_size x num_channel + concat_context = torch.cat([context, char_onehots], 1) # batch_size x (num_channel + num_embedding) + cur_hidden = self.rnn(concat_context, prev_hidden) + return cur_hidden, alpha diff --git a/modules/sequence_modeling.py b/modules/sequence_modeling.py new file mode 100644 index 0000000000000000000000000000000000000000..af32c59b2cc981be1b43412ddf4ac853d0611210 --- /dev/null +++ b/modules/sequence_modeling.py @@ -0,0 +1,19 @@ +import torch.nn as nn + + +class BidirectionalLSTM(nn.Module): + + def __init__(self, input_size, hidden_size, output_size): + super(BidirectionalLSTM, self).__init__() + self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True) + self.linear = nn.Linear(hidden_size * 2, output_size) + + def forward(self, input): + """ + input : visual feature [batch_size x T x input_size] + output : contextual feature [batch_size x T x output_size] + """ + self.rnn.flatten_parameters() + recurrent, _ = self.rnn(input) # batch_size x T x input_size -> batch_size x T x (2*hidden_size) + output = self.linear(recurrent) # batch_size x T x output_size + return output diff --git a/modules/transformation.py b/modules/transformation.py new file mode 100644 index 0000000000000000000000000000000000000000..716e8085970bf43443930ff2f22dacb5f8af3d18 --- /dev/null +++ b/modules/transformation.py @@ -0,0 +1,163 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +class TPS_SpatialTransformerNetwork(nn.Module): + """ Rectification Network of RARE, namely TPS based STN """ + + def __init__(self, F, I_size, I_r_size, I_channel_num=1): + """ Based on RARE TPS + input: + 
batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width] + I_size : (height, width) of the input image I + I_r_size : (height, width) of the rectified image I_r + I_channel_num : the number of channels of the input image I + output: + batch_I_r: rectified image [batch_size x I_channel_num x I_r_height x I_r_width] + """ + super(TPS_SpatialTransformerNetwork, self).__init__() + self.F = F + self.I_size = I_size + self.I_r_size = I_r_size # = (I_r_height, I_r_width) + self.I_channel_num = I_channel_num + self.LocalizationNetwork = LocalizationNetwork(self.F, self.I_channel_num) + self.GridGenerator = GridGenerator(self.F, self.I_r_size) + + def forward(self, batch_I): + batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2 + build_P_prime = self.GridGenerator.build_P_prime(batch_C_prime) # batch_size x n (= I_r_width x I_r_height) x 2 + build_P_prime_reshape = build_P_prime.reshape([build_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2]) + + if torch.__version__ > "1.2.0": + batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border', align_corners=True) + else: + batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border') + + return batch_I_r + + +class LocalizationNetwork(nn.Module): + """ Localization Network of RARE, which predicts C' (K x 2) from I (I_width x I_height) """ + + def __init__(self, F, I_channel_num): + super(LocalizationNetwork, self).__init__() + self.F = F + self.I_channel_num = I_channel_num + self.conv = nn.Sequential( + nn.Conv2d(in_channels=self.I_channel_num, out_channels=64, kernel_size=3, stride=1, padding=1, + bias=False), nn.BatchNorm2d(64), nn.ReLU(), + nn.MaxPool2d(2, 2), # batch_size x 64 x I_height/2 x I_width/2 + nn.Conv2d(64, 128, 3, 1, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(), + nn.MaxPool2d(2, 2), # batch_size x 128 x I_height/4 x I_width/4 + nn.Conv2d(128, 256, 3, 1, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(), + nn.MaxPool2d(2, 2), # batch_size x 256 x I_height/8 x I_width/8 + nn.Conv2d(256, 512, 3, 1, 1, bias=False), nn.BatchNorm2d(512), nn.ReLU(), + nn.AdaptiveAvgPool2d(1) # batch_size x 512 + ) + + self.localization_fc1 = nn.Sequential(nn.Linear(512, 256), nn.ReLU()) + self.localization_fc2 = nn.Linear(256, self.F * 2) + + # Init fc2 in LocalizationNetwork + self.localization_fc2.weight.data.fill_(0) + """ see RARE paper Fig. 
6 (a) """ + ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2)) + ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(F / 2)) + ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(F / 2)) + ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) + ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) + initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) + self.localization_fc2.bias.data = torch.from_numpy(initial_bias).float().view(-1) + + def forward(self, batch_I): + """ + input: batch_I : Batch Input Image [batch_size x I_channel_num x I_height x I_width] + output: batch_C_prime : Predicted coordinates of fiducial points for input batch [batch_size x F x 2] + """ + batch_size = batch_I.size(0) + features = self.conv(batch_I).view(batch_size, -1) + batch_C_prime = self.localization_fc2(self.localization_fc1(features)).view(batch_size, self.F, 2) + return batch_C_prime + + +class GridGenerator(nn.Module): + """ Grid Generator of RARE, which produces P_prime by multipling T with P """ + + def __init__(self, F, I_r_size): + """ Generate P_hat and inv_delta_C for later """ + super(GridGenerator, self).__init__() + self.eps = 1e-6 + self.I_r_height, self.I_r_width = I_r_size + self.F = F + self.C = self._build_C(self.F) # F x 2 + self.P = self._build_P(self.I_r_width, self.I_r_height) + ## for multi-gpu, you need register buffer + self.register_buffer("inv_delta_C", torch.tensor(self._build_inv_delta_C(self.F, self.C)).float()) # F+3 x F+3 + self.register_buffer("P_hat", torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float()) # n x F+3 + ## for fine-tuning with different image width, you may use below instead of self.register_buffer + #self.inv_delta_C = torch.tensor(self._build_inv_delta_C(self.F, self.C)).float().cuda() # F+3 x F+3 + #self.P_hat = torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float().cuda() # n x F+3 + + def _build_C(self, F): + """ Return coordinates of fiducial points in I_r; C """ + ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2)) + ctrl_pts_y_top = -1 * np.ones(int(F / 2)) + ctrl_pts_y_bottom = np.ones(int(F / 2)) + ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) + ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) + C = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) + return C # F x 2 + + def _build_inv_delta_C(self, F, C): + """ Return inv_delta_C which is needed to calculate T """ + hat_C = np.zeros((F, F), dtype=float) # F x F + for i in range(0, F): + for j in range(i, F): + r = np.linalg.norm(C[i] - C[j]) + hat_C[i, j] = r + hat_C[j, i] = r + np.fill_diagonal(hat_C, 1) + hat_C = (hat_C ** 2) * np.log(hat_C) + # print(C.shape, hat_C.shape) + delta_C = np.concatenate( # F+3 x F+3 + [ + np.concatenate([np.ones((F, 1)), C, hat_C], axis=1), # F x F+3 + np.concatenate([np.zeros((2, 3)), np.transpose(C)], axis=1), # 2 x F+3 + np.concatenate([np.zeros((1, 3)), np.ones((1, F))], axis=1) # 1 x F+3 + ], + axis=0 + ) + inv_delta_C = np.linalg.inv(delta_C) + return inv_delta_C # F+3 x F+3 + + def _build_P(self, I_r_width, I_r_height): + I_r_grid_x = (np.arange(-I_r_width, I_r_width, 2) + 1.0) / I_r_width # self.I_r_width + I_r_grid_y = (np.arange(-I_r_height, I_r_height, 2) + 1.0) / I_r_height # self.I_r_height + P = np.stack( # self.I_r_width x self.I_r_height x 2 + np.meshgrid(I_r_grid_x, I_r_grid_y), + axis=2 + ) + return P.reshape([-1, 2]) # n (= self.I_r_width x self.I_r_height) x 2 + + def _build_P_hat(self, F, C, P): + n = P.shape[0] # n (= self.I_r_width x self.I_r_height) + P_tile = 
np.tile(np.expand_dims(P, axis=1), (1, F, 1)) # n x 2 -> n x 1 x 2 -> n x F x 2 + C_tile = np.expand_dims(C, axis=0) # 1 x F x 2 + P_diff = P_tile - C_tile # n x F x 2 + rbf_norm = np.linalg.norm(P_diff, ord=2, axis=2, keepdims=False) # n x F + rbf = np.multiply(np.square(rbf_norm), np.log(rbf_norm + self.eps)) # n x F + P_hat = np.concatenate([np.ones((n, 1)), P, rbf], axis=1) + return P_hat # n x F+3 + + def build_P_prime(self, batch_C_prime): + """ Generate Grid from batch_C_prime [batch_size x F x 2] """ + batch_size = batch_C_prime.size(0) + batch_inv_delta_C = self.inv_delta_C.repeat(batch_size, 1, 1) + batch_P_hat = self.P_hat.repeat(batch_size, 1, 1) + batch_C_prime_with_zeros = torch.cat((batch_C_prime, torch.zeros( + batch_size, 3, 2).float().to(device)), dim=1) # batch_size x F+3 x 2 + batch_T = torch.bmm(batch_inv_delta_C, batch_C_prime_with_zeros) # batch_size x F+3 x 2 + batch_P_prime = torch.bmm(batch_P_hat, batch_T) # batch_size x n x 2 + return batch_P_prime # batch_size x n x 2 diff --git a/modules/vitstr.py b/modules/vitstr.py new file mode 100644 index 0000000000000000000000000000000000000000..01d4f590a2fa02472ae74c81263922d0e73f10d9 --- /dev/null +++ b/modules/vitstr.py @@ -0,0 +1,240 @@ +''' +Implementation of ViTSTR based on timm VisionTransformer. + +TODO: +1) distilled deit backbone +2) base deit backbone + +Copyright 2021 Rowel Atienza +''' + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import torch +import torch.nn as nn +import logging +import torch.utils.model_zoo as model_zoo + +from copy import deepcopy +from functools import partial +from timm.models.vision_transformer import VisionTransformer, _cfg +from timm.models.registry import register_model +from timm.models import create_model + +_logger = logging.getLogger(__name__) + +__all__ = [ + 'vitstr_tiny_patch16_224', + 'vitstr_small_patch16_224', + 'vitstr_base_patch16_224', + #'vitstr_tiny_distilled_patch16_224', + #'vitstr_small_distilled_patch16_224', + #'vitstr_base_distilled_patch16_224', +] + +def create_vitstr(num_tokens, model=None, checkpoint_path=''): + vitstr = create_model( + model, + pretrained=True, + num_classes=num_tokens, + checkpoint_path=checkpoint_path) + + # might need to run to get zero init head for transfer learning + vitstr.reset_classifier(num_classes=num_tokens) + + return vitstr + +class ViTSTR(VisionTransformer): + ''' + ViTSTR is basically a ViT that uses DeiT weights. + Modified head to support a sequence of characters prediction for STR. + ''' + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def reset_classifier(self, num_classes): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def patch_embed_func(self): + return self.patch_embed + + def forward_features(self, x): + B = x.shape[0] + # print("prevx shape: ", x.shape) ### (1, 224, 224) + x = self.patch_embed(x) + # print("new x shape: ", x.shape) ### (1, 196, 768) + # patchsize is 16X16 so there are 14X14 grids=196. 
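+        # 14x14 = 196 patch tokens assumes the default 224x224 input with 16x16 patches;
+        # prepending the class token below yields the 197-token sequence expected by
+        # self.pos_embed.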
+ # 768 - embedding size + # self.cls_token shape: torch.Size([1, 1, 768]) + cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + # self.pos_embed shape: torch.Size([1, 197, 768])] + x = x + self.pos_embed + # + self.pos_embed shape: torch.Size([1, 197, 768]) + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + # blocks shape: torch.Size([1, 197, 768]) ALLL + x = self.norm(x) + # norm shape: torch.Size([1, 197, 768]) + return x + + def forward(self, x, seqlen=25): + x = self.forward_features(x) + x = x[:, :seqlen] + # seqlen shape: torch.Size([1, 25, 768]) + + # batch, seqlen, embsize + b, s, e = x.size() + x = x.reshape(b*s, e) + # reshaped shape: torch.Size([25, 768]) + x = self.head(x).view(b, s, self.num_classes) + return x + + +def load_pretrained(model, cfg=None, num_classes=1000, in_chans=1, filter_fn=None, strict=True): + ''' + Loads a pretrained checkpoint + From an older version of timm + ''' + if cfg is None: + cfg = getattr(model, 'default_cfg') + if cfg is None or 'url' not in cfg or not cfg['url']: + _logger.warning("Pretrained model URL is invalid, using random initialization.") + return + + state_dict = model_zoo.load_url(cfg['url'], progress=True, map_location='cpu') + if "model" in state_dict.keys(): + state_dict = state_dict["model"] + + if filter_fn is not None: + state_dict = filter_fn(state_dict) + + if in_chans == 1: + conv1_name = cfg['first_conv'] + _logger.info('Converting first conv (%s) pretrained weights from 3 to 1 channel' % conv1_name) + key = conv1_name + '.weight' + if key in state_dict.keys(): + _logger.info('(%s) key found in state_dict' % key) + conv1_weight = state_dict[conv1_name + '.weight'] + else: + _logger.info('(%s) key NOT found in state_dict' % key) + return + # Some weights are in torch.half, ensure it's float for sum on CPU + conv1_type = conv1_weight.dtype + conv1_weight = conv1_weight.float() + O, I, J, K = conv1_weight.shape + if I > 3: + assert conv1_weight.shape[1] % 3 == 0 + # For models with space2depth stems + conv1_weight = conv1_weight.reshape(O, I // 3, 3, J, K) + conv1_weight = conv1_weight.sum(dim=2, keepdim=False) + else: + conv1_weight = conv1_weight.sum(dim=1, keepdim=True) + conv1_weight = conv1_weight.to(conv1_type) + state_dict[conv1_name + '.weight'] = conv1_weight + + classifier_name = cfg['classifier'] + if num_classes == 1000 and cfg['num_classes'] == 1001: + # special case for imagenet trained models with extra background class in pretrained weights + classifier_weight = state_dict[classifier_name + '.weight'] + state_dict[classifier_name + '.weight'] = classifier_weight[1:] + classifier_bias = state_dict[classifier_name + '.bias'] + state_dict[classifier_name + '.bias'] = classifier_bias[1:] + elif num_classes != cfg['num_classes']: + # completely discard fully connected for all other differences between pretrained and created model + del state_dict[classifier_name + '.weight'] + del state_dict[classifier_name + '.bias'] + strict = False + + print("Loading pre-trained vision transformer weights from %s ..." 
% cfg['url']) + model.load_state_dict(state_dict, strict=strict) + + +def _conv_filter(state_dict, patch_size=16): + """ convert patch embedding weight from manual patchify + linear proj to conv""" + out_dict = {} + for k, v in state_dict.items(): + if 'patch_embed.proj.weight' in k: + v = v.reshape((v.shape[0], 3, patch_size, patch_size)) + out_dict[k] = v + return out_dict + +@register_model +def vitstr_tiny_patch16_224(pretrained=False, **kwargs): + kwargs['in_chans'] = 1 + model = ViTSTR( + patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, **kwargs) + + model.default_cfg = _cfg( + #url='https://github.com/roatienza/public/releases/download/v0.1-deit-tiny/deit_tiny_patch16_224-a1311bcf.pth' + url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth' + ) + + if pretrained: + load_pretrained( + model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 1), filter_fn=_conv_filter) + return model + +@register_model +def vitstr_small_patch16_224(pretrained=False, **kwargs): + kwargs['in_chans'] = 1 + model = ViTSTR( + patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, **kwargs) + model.default_cfg = _cfg( + #url="https://github.com/roatienza/public/releases/download/v0.1-deit-small/deit_small_patch16_224-cd65a155.pth" + url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth" + ) + if pretrained: + load_pretrained( + model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 1), filter_fn=_conv_filter) + return model + +@register_model +def vitstr_base_patch16_224(pretrained=False, **kwargs): + kwargs['in_chans'] = 1 + model = ViTSTR( + patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, **kwargs) + model.default_cfg = _cfg( + #url='https://github.com/roatienza/public/releases/download/v0.1-deit-base/deit_base_patch16_224-b5f2ef4d.pth' + url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth' + ) + if pretrained: + load_pretrained( + model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 1), filter_fn=_conv_filter) + return model + +# below is work in progress +@register_model +def vitstr_tiny_distilled_patch16_224(pretrained=False, **kwargs): + kwargs['in_chans'] = 1 + #kwargs['distilled'] = True + model = ViTSTR( + patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, **kwargs) + model.default_cfg = _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth' + ) + + if pretrained: + load_pretrained( + model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 1), filter_fn=_conv_filter) + return model + + +@register_model +def vitstr_small_distilled_patch16_224(pretrained=False, **kwargs): + kwargs['in_chans'] = 1 + kwargs['distilled'] = True + model = ViTSTR( + patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, **kwargs) + model.default_cfg = _cfg( + url="https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth" + ) + if pretrained: + load_pretrained( + model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 1), filter_fn=_conv_filter) + return model diff --git a/modules_abinet/__init__.py b/modules_abinet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/modules_abinet/attention.py b/modules_abinet/attention.py new file mode 100644 index 
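# --- Illustrative sketch, not part of the diff: the two weight conversions applied to a
# --- DeiT checkpoint above, shown on a dummy tensor. _conv_filter reshapes a flattened
# --- patchify weight back to conv form, and load_pretrained then sums the RGB filters into
# --- a single channel because ViTSTR consumes 1-channel (grayscale) crops.
import torch
w = torch.randn(768, 3 * 16 * 16)        # flattened patch-embedding weight (older checkpoints)
w = w.reshape(w.shape[0], 3, 16, 16)     # _conv_filter: back to conv form (O, 3, 16, 16)
w = w.sum(dim=1, keepdim=True)           # load_pretrained, in_chans == 1: 3 -> 1 input channel
assert w.shape == (768, 1, 16, 16)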
0000000000000000000000000000000000000000..15f38657df601c65527e6708f7acbdacd0866ae6 --- /dev/null +++ b/modules_abinet/attention.py @@ -0,0 +1,97 @@ +import torch +import torch.nn as nn +from .transformer import PositionalEncoding + +class Attention(nn.Module): + def __init__(self, in_channels=512, max_length=25, n_feature=256): + super().__init__() + self.max_length = max_length + + self.f0_embedding = nn.Embedding(max_length, in_channels) + self.w0 = nn.Linear(max_length, n_feature) + self.wv = nn.Linear(in_channels, in_channels) + self.we = nn.Linear(in_channels, max_length) + + self.active = nn.Tanh() + self.softmax = nn.Softmax(dim=2) + + def forward(self, enc_output): + enc_output = enc_output.permute(0, 2, 3, 1).flatten(1, 2) + reading_order = torch.arange(self.max_length, dtype=torch.long, device=enc_output.device) + reading_order = reading_order.unsqueeze(0).expand(enc_output.size(0), -1) # (S,) -> (B, S) + reading_order_embed = self.f0_embedding(reading_order) # b,25,512 + + t = self.w0(reading_order_embed.permute(0, 2, 1)) # b,512,256 + t = self.active(t.permute(0, 2, 1) + self.wv(enc_output)) # b,256,512 + + attn = self.we(t) # b,256,25 + attn = self.softmax(attn.permute(0, 2, 1)) # b,25,256 + g_output = torch.bmm(attn, enc_output) # b,25,512 + return g_output, attn.view(*attn.shape[:2], 8, 32) + + +def encoder_layer(in_c, out_c, k=3, s=2, p=1): + return nn.Sequential(nn.Conv2d(in_c, out_c, k, s, p), + nn.BatchNorm2d(out_c), + nn.ReLU()) + +def decoder_layer(in_c, out_c, k=3, s=1, p=1, mode='nearest', scale_factor=None, size=None): + align_corners = None if mode=='nearest' else True + return nn.Sequential(nn.Upsample(size=size, scale_factor=scale_factor, + mode=mode, align_corners=align_corners), + nn.Conv2d(in_c, out_c, k, s, p), + nn.BatchNorm2d(out_c), + nn.ReLU()) + + +class PositionAttention(nn.Module): + def __init__(self, max_length, in_channels=512, num_channels=64, + h=8, w=32, mode='nearest', **kwargs): + super().__init__() + self.max_length = max_length + self.k_encoder = nn.Sequential( + encoder_layer(in_channels, num_channels, s=(1, 2)), + encoder_layer(num_channels, num_channels, s=(2, 2)), + encoder_layer(num_channels, num_channels, s=(2, 2)), + encoder_layer(num_channels, num_channels, s=(2, 2)) + ) + self.k_decoder = nn.Sequential( + decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode), + decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode), + decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode), + decoder_layer(num_channels, in_channels, size=(h, w), mode=mode) + ) + + self.pos_encoder = PositionalEncoding(in_channels, dropout=0, max_len=max_length) + self.project = nn.Linear(in_channels, in_channels) + + def forward(self, x): + N, E, H, W = x.size() + k, v = x, x # (N, E, H, W) + + # calculate key vector + features = [] + for i in range(0, len(self.k_encoder)): + k = self.k_encoder[i](k) + features.append(k) + for i in range(0, len(self.k_decoder) - 1): + k = self.k_decoder[i](k) + k = k + features[len(self.k_decoder) - 2 - i] + k = self.k_decoder[-1](k) + + # calculate query vector + # TODO q=f(q,k) + zeros = x.new_zeros((self.max_length, N, E)) # (T, N, E) + q = self.pos_encoder(zeros) # (T, N, E) + q = q.permute(1, 0, 2) # (N, T, E) + q = self.project(q) # (N, T, E) + + # calculate attention + attn_scores = torch.bmm(q, k.flatten(2, 3)) # (N, T, (H*W)) + attn_scores = attn_scores / (E ** 0.5) + attn_scores = torch.softmax(attn_scores, dim=-1) + + v = v.permute(0, 2, 3, 1).view(N, -1, E) # (N, (H*W), E) + 
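# --- Illustrative sketch, not part of the diff: the attention computation at the heart of
# --- PositionAttention, on dummy tensors. Keys come from the mini U-Net above, queries from
# --- a projected positional encoding of zeros; shapes assume E=512, H=8, W=32, T=26.
import torch
N, E, H, W, T = 2, 512, 8, 32, 26
k = torch.randn(N, E, H, W)              # stand-in for the decoded key feature map
q = torch.randn(N, T, E)                 # stand-in for project(pos_encoder(zeros))
v = torch.randn(N, E, H, W)              # values are the raw input features
scores = torch.softmax(torch.bmm(q, k.flatten(2, 3)) / (E ** 0.5), dim=-1)   # (N, T, H*W)
vecs = torch.bmm(scores, v.permute(0, 2, 3, 1).reshape(N, H * W, E))         # (N, T, E)
assert vecs.shape == (N, T, E)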
attn_vecs = torch.bmm(attn_scores, v) # (N, T, E) + + return attn_vecs, attn_scores.view(N, -1, H, W) diff --git a/modules_abinet/backbone.py b/modules_abinet/backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..5430ca47acb14e3956802888694b27debc5013bf --- /dev/null +++ b/modules_abinet/backbone.py @@ -0,0 +1,36 @@ +import torch +import torch.nn as nn +from fastai.vision import * + +from modules_abinet.model import _default_tfmer_cfg +from modules_abinet.resnet import resnet45 +from modules_abinet.transformer import (PositionalEncoding, + TransformerEncoder, + TransformerEncoderLayer) + + +class ResTranformer(nn.Module): + def __init__(self, config): + super().__init__() + self.resnet = resnet45() + + self.d_model = ifnone(config.model_vision_d_model, _default_tfmer_cfg['d_model']) + nhead = ifnone(config.model_vision_nhead, _default_tfmer_cfg['nhead']) + d_inner = ifnone(config.model_vision_d_inner, _default_tfmer_cfg['d_inner']) + dropout = ifnone(config.model_vision_dropout, _default_tfmer_cfg['dropout']) + activation = ifnone(config.model_vision_activation, _default_tfmer_cfg['activation']) + num_layers = ifnone(config.model_vision_backbone_ln, 2) + + self.pos_encoder = PositionalEncoding(self.d_model, max_len=8*32) + encoder_layer = TransformerEncoderLayer(d_model=self.d_model, nhead=nhead, + dim_feedforward=d_inner, dropout=dropout, activation=activation) + self.transformer = TransformerEncoder(encoder_layer, num_layers) + + def forward(self, images): + feature = self.resnet(images) + n, c, h, w = feature.shape + feature = feature.view(n, c, -1).permute(2, 0, 1) + feature = self.pos_encoder(feature) + feature = self.transformer(feature) + feature = feature.permute(1, 2, 0).view(n, c, h, w) + return feature diff --git a/modules_abinet/model.py b/modules_abinet/model.py new file mode 100644 index 0000000000000000000000000000000000000000..f71136d94f4b994ab34dd52521b892c234b9902c --- /dev/null +++ b/modules_abinet/model.py @@ -0,0 +1,50 @@ +import torch +import torch.nn as nn + +from utils_abinet import CharsetMapper + + +_default_tfmer_cfg = dict(d_model=512, nhead=8, d_inner=2048, # 1024 + dropout=0.1, activation='relu') + +class Model(nn.Module): + + def __init__(self, config): + super().__init__() + self.max_length = config.dataset_max_length + 1 + self.charset = CharsetMapper(config.dataset_charset_path, max_length=self.max_length) + + def load(self, source, device=None, strict=True): + state = torch.load(source, map_location=device) + self.load_state_dict(state['model'], strict=strict) + + def _get_length(self, logit, dim=-1): + """ Greed decoder to obtain length from logit""" + out = (logit.argmax(dim=-1) == self.charset.null_label) + abn = out.any(dim) + out = ((out.cumsum(dim) == 1) & out).max(dim)[1] + out = out + 1 # additional end token + out = torch.where(abn, out, out.new_tensor(logit.shape[1])) + return out + + @staticmethod + def _get_padding_mask(length, max_length): + length = length.unsqueeze(-1) + grid = torch.arange(0, max_length, device=length.device).unsqueeze(0) + return grid >= length + + @staticmethod + def _get_square_subsequent_mask(sz, device, diagonal=0, fw=True): + r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf'). + Unmasked positions are filled with float(0.0). 
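# --- Illustrative sketch, not part of the diff: what Model._get_length computes, applied to
# --- the argmax of a toy prediction. Assuming null_label is 0, the length is the index of
# --- the first predicted null token plus one (the end token itself), or T if none appears.
import torch
null_label = 0                              # assumed charset null label
pred = torch.tensor([[3, 5, 0, 0, 7]])      # logits.argmax(-1) for one sample, T = 5
is_null = pred == null_label
first_null = ((is_null.cumsum(-1) == 1) & is_null).max(-1)[1]   # index of first null -> 2
length = torch.where(is_null.any(-1), first_null + 1, torch.tensor(pred.shape[1]))
assert length.item() == 3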
+ """ + mask = (torch.triu(torch.ones(sz, sz, device=device), diagonal=diagonal) == 1) + if fw: mask = mask.transpose(0, 1) + mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) + return mask + + @staticmethod + def _get_location_mask(sz, device=None): + mask = torch.eye(sz, device=device) + mask = mask.float().masked_fill(mask == 1, float('-inf')) + return mask diff --git a/modules_abinet/model_abinet.py b/modules_abinet/model_abinet.py new file mode 100644 index 0000000000000000000000000000000000000000..34c37b64ac4814b868483e3027d6ecf88b62c1bb --- /dev/null +++ b/modules_abinet/model_abinet.py @@ -0,0 +1,30 @@ +import torch +import torch.nn as nn +from fastai.vision import * + +from .model_vision import BaseVision +from .model_language import BCNLanguage +from .model_alignment import BaseAlignment + + +class ABINetModel(nn.Module): + def __init__(self, config): + super().__init__() + self.use_alignment = ifnone(config.model_use_alignment, True) + self.max_length = config.dataset_max_length + 1 # additional stop token + self.vision = BaseVision(config) + self.language = BCNLanguage(config) + if self.use_alignment: self.alignment = BaseAlignment(config) + + def forward(self, images, *args): + v_res = self.vision(images) + v_tokens = torch.softmax(v_res['logits'], dim=-1) + v_lengths = v_res['pt_lengths'].clamp_(2, self.max_length) # TODO:move to langauge model + + l_res = self.language(v_tokens, v_lengths) + if not self.use_alignment: + return l_res, v_res + l_feature, v_feature = l_res['feature'], v_res['feature'] + + a_res = self.alignment(l_feature, v_feature) + return a_res, l_res, v_res diff --git a/modules_abinet/model_abinet_iter.py b/modules_abinet/model_abinet_iter.py new file mode 100644 index 0000000000000000000000000000000000000000..6588890570a1180ea32f7969cefd4b59c25409a7 --- /dev/null +++ b/modules_abinet/model_abinet_iter.py @@ -0,0 +1,34 @@ +import torch +import torch.nn as nn +from fastai.vision import * + +from .model_vision import BaseVision +from .model_language import BCNLanguage +from .model_alignment import BaseAlignment + + +class ABINetIterModel(nn.Module): + def __init__(self, config): + super().__init__() + self.iter_size = ifnone(config.model_iter_size, 1) + self.max_length = config.dataset_max_length + 1 # additional stop token + self.vision = BaseVision(config) + self.language = BCNLanguage(config) + self.alignment = BaseAlignment(config) + + def forward(self, images, *args): + v_res = self.vision(images) + a_res = v_res + all_l_res, all_a_res = [], [] + for _ in range(self.iter_size): + tokens = torch.softmax(a_res['logits'], dim=-1) + lengths = a_res['pt_lengths'] + lengths.clamp_(2, self.max_length) # TODO:move to langauge model + l_res = self.language(tokens, lengths) + all_l_res.append(l_res) + a_res = self.alignment(l_res['feature'], v_res['feature']) + all_a_res.append(a_res) + if self.training: + return all_a_res, all_l_res, v_res + else: + return a_res, all_l_res[-1], v_res diff --git a/modules_abinet/model_alignment.py b/modules_abinet/model_alignment.py new file mode 100644 index 0000000000000000000000000000000000000000..00779fc504b23597b1e38e3f0cb15a05c650c54c --- /dev/null +++ b/modules_abinet/model_alignment.py @@ -0,0 +1,34 @@ +import torch +import torch.nn as nn +from fastai.vision import * + +from modules_abinet.model import Model, _default_tfmer_cfg + + +class BaseAlignment(Model): + def __init__(self, config): + super().__init__(config) + d_model = ifnone(config.model_alignment_d_model, 
_default_tfmer_cfg['d_model']) + + self.loss_weight = ifnone(config.model_alignment_loss_weight, 1.0) + self.max_length = config.dataset_max_length + 1 # additional stop token + self.w_att = nn.Linear(2 * d_model, d_model) + self.cls = nn.Linear(d_model, self.charset.num_classes) + + def forward(self, l_feature, v_feature): + """ + Args: + l_feature: (N, T, E) where T is length, N is batch size and d is dim of model + v_feature: (N, T, E) shape the same as l_feature + l_lengths: (N,) + v_lengths: (N,) + """ + f = torch.cat((l_feature, v_feature), dim=2) + f_att = torch.sigmoid(self.w_att(f)) + output = f_att * v_feature + (1 - f_att) * l_feature + + logits = self.cls(output) # (N, T, C) + pt_lengths = self._get_length(logits) + + return {'logits': logits, 'pt_lengths': pt_lengths, 'loss_weight':self.loss_weight, + 'name': 'alignment'} diff --git a/modules_abinet/model_language.py b/modules_abinet/model_language.py new file mode 100644 index 0000000000000000000000000000000000000000..96682e91ba5cb9bddb5514fee31552dd60fc9558 --- /dev/null +++ b/modules_abinet/model_language.py @@ -0,0 +1,67 @@ +import logging +import torch.nn as nn +from fastai.vision import * + +from modules_abinet.model import _default_tfmer_cfg +from modules_abinet.model import Model +from modules_abinet.transformer import (PositionalEncoding, + TransformerDecoder, + TransformerDecoderLayer) + + +class BCNLanguage(Model): + def __init__(self, config): + super().__init__(config) + d_model = ifnone(config.model_language_d_model, _default_tfmer_cfg['d_model']) + nhead = ifnone(config.model_language_nhead, _default_tfmer_cfg['nhead']) + d_inner = ifnone(config.model_language_d_inner, _default_tfmer_cfg['d_inner']) + dropout = ifnone(config.model_language_dropout, _default_tfmer_cfg['dropout']) + activation = ifnone(config.model_language_activation, _default_tfmer_cfg['activation']) + num_layers = ifnone(config.model_language_num_layers, 4) + self.d_model = d_model + self.detach = ifnone(config.model_language_detach, True) + self.use_self_attn = ifnone(config.model_language_use_self_attn, False) + self.loss_weight = ifnone(config.model_language_loss_weight, 1.0) + self.max_length = config.dataset_max_length + 1 # additional stop token + self.debug = ifnone(config.global_debug, False) + + self.proj = nn.Linear(self.charset.num_classes, d_model, False) + self.token_encoder = PositionalEncoding(d_model, max_len=self.max_length) + self.pos_encoder = PositionalEncoding(d_model, dropout=0, max_len=self.max_length) + decoder_layer = TransformerDecoderLayer(d_model, nhead, d_inner, dropout, + activation, self_attn=self.use_self_attn, debug=self.debug) + self.model = TransformerDecoder(decoder_layer, num_layers) + + self.cls = nn.Linear(d_model, self.charset.num_classes) + + if config.model_language_checkpoint is not None: + logging.info(f'Read language model from {config.model_language_checkpoint}.') + self.load(config.model_language_checkpoint) + + def forward(self, tokens, lengths): + """ + Args: + tokens: (N, T, C) where T is length, N is batch size and C is classes number + lengths: (N,) + """ + if self.detach: tokens = tokens.detach() + embed = self.proj(tokens) # (N, T, E) + embed = embed.permute(1, 0, 2) # (T, N, E) + embed = self.token_encoder(embed) # (T, N, E) + padding_mask = self._get_padding_mask(lengths, self.max_length) + + zeros = embed.new_zeros(*embed.shape) + qeury = self.pos_encoder(zeros) + location_mask = self._get_location_mask(self.max_length, tokens.device) + output = self.model(qeury, embed, + 
tgt_key_padding_mask=padding_mask, + memory_mask=location_mask, + memory_key_padding_mask=padding_mask) # (T, N, E) + output = output.permute(1, 0, 2) # (N, T, E) + + logits = self.cls(output) # (N, T, C) + pt_lengths = self._get_length(logits) + + res = {'feature': output, 'logits': logits, 'pt_lengths': pt_lengths, + 'loss_weight':self.loss_weight, 'name': 'language'} + return res diff --git a/modules_abinet/model_vision.py b/modules_abinet/model_vision.py new file mode 100644 index 0000000000000000000000000000000000000000..5b2ecd1681cfc459a39a23caab2409f5791c3ccd --- /dev/null +++ b/modules_abinet/model_vision.py @@ -0,0 +1,47 @@ +import logging +import torch.nn as nn +from fastai.vision import * + +from modules_abinet.attention import * +from modules_abinet.backbone import ResTranformer +from modules_abinet.model import Model +from modules_abinet.resnet import resnet45 + + +class BaseVision(Model): + def __init__(self, config): + super().__init__(config) + self.loss_weight = ifnone(config.model_vision_loss_weight, 1.0) + self.out_channels = ifnone(config.model_vision_d_model, 512) + + if config.model_vision_backbone == 'transformer': + self.backbone = ResTranformer(config) + else: self.backbone = resnet45() + + if config.model_vision_attention == 'position': + mode = ifnone(config.model_vision_attention_mode, 'nearest') + self.attention = PositionAttention( + max_length=config.dataset_max_length + 1, # additional stop token + mode=mode, + ) + elif config.model_vision_attention == 'attention': + self.attention = Attention( + max_length=config.dataset_max_length + 1, # additional stop token + n_feature=8*32, + ) + else: + raise Exception(f'{config.model_vision_attention} is not valid.') + self.cls = nn.Linear(self.out_channels, self.charset.num_classes) + + if config.model_vision_checkpoint is not None: + logging.info(f'Read vision model from {config.model_vision_checkpoint}.') + self.load(config.model_vision_checkpoint) + + def forward(self, images, *args): + features = self.backbone(images) # (N, E, H, W) + attn_vecs, attn_scores = self.attention(features) # (N, T, E), (N, T, H, W) + logits = self.cls(attn_vecs) # (N, T, C) + pt_lengths = self._get_length(logits) + + return {'feature': attn_vecs, 'logits': logits, 'pt_lengths': pt_lengths, + 'attn_scores': attn_scores, 'loss_weight':self.loss_weight, 'name': 'vision'} diff --git a/modules_abinet/resnet.py b/modules_abinet/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..d990b272479866df5bc4bbf6a26cf487aa43097a --- /dev/null +++ b/modules_abinet/resnet.py @@ -0,0 +1,104 @@ +import math + +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.model_zoo as model_zoo + + +def conv1x1(in_planes, out_planes, stride=1): + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +def conv3x3(in_planes, out_planes, stride=1): + "3x3 convolution with padding" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv1x1(inplanes, planes) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes, stride) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = F.relu_(out.clone()) 
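# --- Illustrative sketch, not part of the diff: the two masks BCNLanguage hands to its
# --- transformer decoder above, for an assumed max_length T = 5 and one sample of length 3.
# --- The padding mask hides positions past the predicted length; the diagonal "location"
# --- mask puts -inf at (i, i) so query position i cannot attend to the vision token at the
# --- same position in memory and must reconstruct it from context.
import torch
T = 5
lengths = torch.tensor([3])
padding_mask = torch.arange(T).unsqueeze(0) >= lengths.unsqueeze(-1)          # [[F, F, F, T, T]]
location_mask = torch.eye(T).masked_fill(torch.eye(T) == 1, float('-inf'))    # -inf on the diagonal
assert padding_mask.tolist() == [[False, False, False, True, True]]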
+ + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = F.relu_(out.clone()) + + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers): + self.inplanes = 32 + super(ResNet, self).__init__() + self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, + bias=False) + self.bn1 = nn.BatchNorm2d(32) + self.relu = nn.ReLU(inplace=True) + + self.layer1 = self._make_layer(block, 32, layers[0], stride=2) + self.layer2 = self._make_layer(block, 64, layers[1], stride=1) + self.layer3 = self._make_layer(block, 128, layers[2], stride=2) + self.layer4 = self._make_layer(block, 256, layers[3], stride=1) + self.layer5 = self._make_layer(block, 512, layers[4], stride=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = F.relu_(x.clone()) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.layer5(x) + return x + + +def resnet45(): + return ResNet(BasicBlock, [3, 4, 6, 6, 3]) diff --git a/modules_abinet/transformer.py b/modules_abinet/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..96017382156895754e2c3da8d46e4f5de6fe3516 --- /dev/null +++ b/modules_abinet/transformer.py @@ -0,0 +1,901 @@ +# pytorch 1.5.0 +import copy +import math +import warnings +from typing import Optional + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn import Dropout, LayerNorm, Linear, Module, ModuleList, Parameter +from torch.nn import functional as F +from torch.nn.init import constant_, xavier_uniform_ + + +def multi_head_attention_forward(query, # type: Tensor + key, # type: Tensor + value, # type: Tensor + embed_dim_to_check, # type: int + num_heads, # type: int + in_proj_weight, # type: Tensor + in_proj_bias, # type: Tensor + bias_k, # type: Optional[Tensor] + bias_v, # type: Optional[Tensor] + add_zero_attn, # type: bool + dropout_p, # type: float + out_proj_weight, # type: Tensor + out_proj_bias, # type: Tensor + training=True, # type: bool + key_padding_mask=None, # type: Optional[Tensor] + need_weights=True, # type: bool + attn_mask=None, # type: Optional[Tensor] + use_separate_proj_weight=False, # type: bool + q_proj_weight=None, # type: Optional[Tensor] + k_proj_weight=None, # type: Optional[Tensor] + v_proj_weight=None, # type: Optional[Tensor] + static_k=None, # type: Optional[Tensor] + static_v=None # type: Optional[Tensor] + ): + # type: (...) -> Tuple[Tensor, Optional[Tensor]] + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + See "Attention Is All You Need" for more details. + embed_dim_to_check: total dimension of the model. 
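# --- Illustrative sketch, not part of the diff: resnet45 downsampling. Stride 2 appears only
# --- in layer1 and layer3 above, so the spatial size shrinks by 4 overall; assuming the
# --- 32x128 inputs ABINet typically uses, the backbone emits an 8x32 map with 512 channels,
# --- matching PositionalEncoding(max_len=8*32) in backbone.py and h=8, w=32 in attention.py.
strides = [1, 2, 1, 2, 1, 1]      # conv1, layer1 .. layer5, read off the code above
factor = 1
for s in strides:
    factor *= s
h, w = 32 // factor, 128 // factor
assert (factor, h, w) == (4, 8, 32)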
+ num_heads: parallel attention heads. + in_proj_weight, in_proj_bias: input projection weight and bias. + bias_k, bias_v: bias of the key and value sequences to be added at dim=0. + add_zero_attn: add a new batch of zeros to the key and + value sequences at dim=1. + dropout_p: probability of an element to be zeroed. + out_proj_weight, out_proj_bias: the output projection weight and bias. + training: apply dropout if is ``True``. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + use_separate_proj_weight: the function accept the proj. weights for query, key, + and value in different forms. If false, in_proj_weight will be used, which is + a combination of q_proj_weight, k_proj_weight, v_proj_weight. + q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias. + static_k, static_v: static key and value used for attention operators. + Shape: + Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions + will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, + N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. + - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, + N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. + Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. 
+ """ + # if not torch.jit.is_scripting(): + # tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, + # out_proj_weight, out_proj_bias) + # if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): + # return handle_torch_function( + # multi_head_attention_forward, tens_ops, query, key, value, + # embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, + # bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, + # out_proj_bias, training=training, key_padding_mask=key_padding_mask, + # need_weights=need_weights, attn_mask=attn_mask, + # use_separate_proj_weight=use_separate_proj_weight, + # q_proj_weight=q_proj_weight, k_proj_weight=k_proj_weight, + # v_proj_weight=v_proj_weight, static_k=static_k, static_v=static_v) + tgt_len, bsz, embed_dim = query.size() + assert embed_dim == embed_dim_to_check + assert key.size() == value.size() + + head_dim = embed_dim // num_heads + assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads" + scaling = float(head_dim) ** -0.5 + + if not use_separate_proj_weight: + if torch.equal(query, key) and torch.equal(key, value): + # self-attention + q, k, v = F.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1) + + elif torch.equal(key, value): + # encoder-decoder attention + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = F.linear(query, _w, _b) + + if key is None: + assert value is None + k = None + v = None + else: + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + k, v = F.linear(key, _w, _b).chunk(2, dim=-1) + + else: + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = F.linear(query, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = embed_dim * 2 + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + k = F.linear(key, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim * 2 + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + v = F.linear(value, _w, _b) + else: + q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight) + len1, len2 = q_proj_weight_non_opt.size() + assert len1 == embed_dim and len2 == query.size(-1) + + k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight) + len1, len2 = k_proj_weight_non_opt.size() + assert len1 == embed_dim and len2 == key.size(-1) + + v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight) + len1, len2 = v_proj_weight_non_opt.size() + assert len1 == embed_dim and len2 == value.size(-1) + + if in_proj_bias is not None: + q = F.linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim]) + k = F.linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim:(embed_dim * 2)]) + v = F.linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):]) + else: + q = F.linear(query, q_proj_weight_non_opt, in_proj_bias) + k = F.linear(key, k_proj_weight_non_opt, in_proj_bias) + v = F.linear(value, 
v_proj_weight_non_opt, in_proj_bias) + q = q * scaling + + if attn_mask is not None: + assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \ + attn_mask.dtype == torch.float16 or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool, \ + 'Only float, byte, and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype) + if attn_mask.dtype == torch.uint8: + warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.") + attn_mask = attn_mask.to(torch.bool) + + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0) + if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: + raise RuntimeError('The size of the 2D attn_mask is not correct.') + elif attn_mask.dim() == 3: + if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]: + raise RuntimeError('The size of the 3D attn_mask is not correct.') + else: + raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim())) + # attn_mask's dim is 3 now. + + # # convert ByteTensor key_padding_mask to bool + # if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + # warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.") + # key_padding_mask = key_padding_mask.to(torch.bool) + + if bias_k is not None and bias_v is not None: + if static_k is None and static_v is None: + k = torch.cat([k, bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = pad(key_padding_mask, (0, 1)) + else: + assert static_k is None, "bias cannot be added to static key." + assert static_v is None, "bias cannot be added to static value." 
+ else: + assert bias_k is None + assert bias_v is None + + q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) + if k is not None: + k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + if v is not None: + v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + + if static_k is not None: + assert static_k.size(0) == bsz * num_heads + assert static_k.size(2) == head_dim + k = static_k + + if static_v is not None: + assert static_v.size(0) == bsz * num_heads + assert static_v.size(2) == head_dim + v = static_v + + src_len = k.size(1) + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if add_zero_attn: + src_len += 1 + k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1) + v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1) + if attn_mask is not None: + attn_mask = pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = pad(key_padding_mask, (0, 1)) + + attn_output_weights = torch.bmm(q, k.transpose(1, 2)) + assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len] + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_output_weights.masked_fill_(attn_mask, float('-inf')) + else: + attn_output_weights += attn_mask + + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float('-inf'), + ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len) + + attn_output_weights = F.softmax( + attn_output_weights, dim=-1) + attn_output_weights = F.dropout(attn_output_weights, p=dropout_p, training=training) + + attn_output = torch.bmm(attn_output_weights, v) + assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] + attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias) + + if need_weights: + # average attention weights over heads + attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) + return attn_output, attn_output_weights.sum(dim=1) / num_heads + else: + return attn_output, None + +class MultiheadAttention(Module): + r"""Allows the model to jointly attend to information + from different representation subspaces. + See reference: Attention Is All You Need + .. math:: + \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O + \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) + Args: + embed_dim: total dimension of the model. + num_heads: parallel attention heads. + dropout: a Dropout layer on attn_output_weights. Default: 0.0. + bias: add bias as module parameter. Default: True. + add_bias_kv: add bias to the key and value sequences at dim=0. + add_zero_attn: add a new batch of zeros to the key and + value sequences at dim=1. + kdim: total number of features in key. Default: None. + vdim: total number of features in value. Default: None. + Note: if kdim and vdim are None, they will be set to embed_dim such that + query, key, and value have the same number of features. 
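# --- Illustrative sketch, not part of the diff: the per-head attention performed above, on
# --- dummy tensors, after q/k/v have been reshaped to (bsz*num_heads, seq_len, head_dim)
# --- and q has already been scaled by head_dim ** -0.5.
import torch
bsz, num_heads, tgt_len, src_len, head_dim = 2, 8, 4, 6, 64
q = torch.randn(bsz * num_heads, tgt_len, head_dim) * (head_dim ** -0.5)
k = torch.randn(bsz * num_heads, src_len, head_dim)
v = torch.randn(bsz * num_heads, src_len, head_dim)
weights = torch.softmax(torch.bmm(q, k.transpose(1, 2)), dim=-1)   # (B*h, L, S)
out = torch.bmm(weights, v)                                        # (B*h, L, head_dim)
avg_weights = weights.view(bsz, num_heads, tgt_len, src_len).sum(dim=1) / num_heads  # returned when need_weights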
+ Examples:: + >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value) + """ + # __annotations__ = { + # 'bias_k': torch._jit_internal.Optional[torch.Tensor], + # 'bias_v': torch._jit_internal.Optional[torch.Tensor], + # } + __constants__ = ['q_proj_weight', 'k_proj_weight', 'v_proj_weight', 'in_proj_weight'] + + def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None): + super(MultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" + + if self._qkv_same_embed_dim is False: + self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim)) + self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim)) + self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim)) + self.register_parameter('in_proj_weight', None) + else: + self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim)) + self.register_parameter('q_proj_weight', None) + self.register_parameter('k_proj_weight', None) + self.register_parameter('v_proj_weight', None) + + if bias: + self.in_proj_bias = Parameter(torch.empty(3 * embed_dim)) + else: + self.register_parameter('in_proj_bias', None) + self.out_proj = Linear(embed_dim, embed_dim, bias=bias) + + if add_bias_kv: + self.bias_k = Parameter(torch.empty(1, 1, embed_dim)) + self.bias_v = Parameter(torch.empty(1, 1, embed_dim)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self._reset_parameters() + + def _reset_parameters(self): + if self._qkv_same_embed_dim: + xavier_uniform_(self.in_proj_weight) + else: + xavier_uniform_(self.q_proj_weight) + xavier_uniform_(self.k_proj_weight) + xavier_uniform_(self.v_proj_weight) + + if self.in_proj_bias is not None: + constant_(self.in_proj_bias, 0.) + constant_(self.out_proj.bias, 0.) + if self.bias_k is not None: + xavier_normal_(self.bias_k) + if self.bias_v is not None: + xavier_normal_(self.bias_v) + + def __setstate__(self, state): + # Support loading old MultiheadAttention checkpoints generated by v1.1.0 + if '_qkv_same_embed_dim' not in state: + state['_qkv_same_embed_dim'] = True + + super(MultiheadAttention, self).__setstate__(state) + + def forward(self, query, key, value, key_padding_mask=None, + need_weights=True, attn_mask=None): + # type: (Tensor, Tensor, Tensor, Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]] + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + See "Attention Is All You Need" for more details. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. 
+ Shape: + - Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the position + with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + - Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. + """ + if not self._qkv_same_embed_dim: + return multi_head_attention_forward( + query, key, value, self.embed_dim, self.num_heads, + self.in_proj_weight, self.in_proj_bias, + self.bias_k, self.bias_v, self.add_zero_attn, + self.dropout, self.out_proj.weight, self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, need_weights=need_weights, + attn_mask=attn_mask, use_separate_proj_weight=True, + q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight, + v_proj_weight=self.v_proj_weight) + else: + return multi_head_attention_forward( + query, key, value, self.embed_dim, self.num_heads, + self.in_proj_weight, self.in_proj_bias, + self.bias_k, self.bias_v, self.add_zero_attn, + self.dropout, self.out_proj.weight, self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, need_weights=need_weights, + attn_mask=attn_mask) + + +class Transformer(Module): + r"""A transformer model. User is able to modify the attributes as needed. The architecture + is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer, + Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and + Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information + Processing Systems, pages 6000-6010. Users can build the BERT(https://arxiv.org/abs/1810.04805) + model with corresponding parameters. + + Args: + d_model: the number of expected features in the encoder/decoder inputs (default=512). + nhead: the number of heads in the multiheadattention models (default=8). + num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6). + num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6). + dim_feedforward: the dimension of the feedforward network model (default=2048). 
+ dropout: the dropout value (default=0.1). + activation: the activation function of encoder/decoder intermediate layer, relu or gelu (default=relu). + custom_encoder: custom encoder (default=None). + custom_decoder: custom decoder (default=None). + + Examples:: + >>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12) + >>> src = torch.rand((10, 32, 512)) + >>> tgt = torch.rand((20, 32, 512)) + >>> out = transformer_model(src, tgt) + + Note: A full example to apply nn.Transformer module for the word language model is available in + https://github.com/pytorch/examples/tree/master/word_language_model + """ + + def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, + num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, + activation="relu", custom_encoder=None, custom_decoder=None): + super(Transformer, self).__init__() + + if custom_encoder is not None: + self.encoder = custom_encoder + else: + encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation) + encoder_norm = LayerNorm(d_model) + self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) + + if custom_decoder is not None: + self.decoder = custom_decoder + else: + decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, activation) + decoder_norm = LayerNorm(d_model) + self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm) + + self._reset_parameters() + + self.d_model = d_model + self.nhead = nhead + + def forward(self, src, tgt, src_mask=None, tgt_mask=None, + memory_mask=None, src_key_padding_mask=None, + tgt_key_padding_mask=None, memory_key_padding_mask=None): + # type: (Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tensor # noqa + r"""Take in and process masked source/target sequences. + + Args: + src: the sequence to the encoder (required). + tgt: the sequence to the decoder (required). + src_mask: the additive mask for the src sequence (optional). + tgt_mask: the additive mask for the tgt sequence (optional). + memory_mask: the additive mask for the encoder output (optional). + src_key_padding_mask: the ByteTensor mask for src keys per batch (optional). + tgt_key_padding_mask: the ByteTensor mask for tgt keys per batch (optional). + memory_key_padding_mask: the ByteTensor mask for memory keys per batch (optional). + + Shape: + - src: :math:`(S, N, E)`. + - tgt: :math:`(T, N, E)`. + - src_mask: :math:`(S, S)`. + - tgt_mask: :math:`(T, T)`. + - memory_mask: :math:`(T, S)`. + - src_key_padding_mask: :math:`(N, S)`. + - tgt_key_padding_mask: :math:`(N, T)`. + - memory_key_padding_mask: :math:`(N, S)`. + + Note: [src/tgt/memory]_mask ensures that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + [src/tgt/memory]_key_padding_mask provides specified elements in the key to be ignored by + the attention. If a ByteTensor is provided, the non-zero positions will be ignored while the zero + positions will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. 
+ + - output: :math:`(T, N, E)`. + + Note: Due to the multi-head attention architecture in the transformer model, + the output sequence length of a transformer is same as the input sequence + (i.e. target) length of the decode. + + where S is the source sequence length, T is the target sequence length, N is the + batch size, E is the feature number + + Examples: + >>> output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask) + """ + + if src.size(1) != tgt.size(1): + raise RuntimeError("the batch number of src and tgt must be equal") + + if src.size(2) != self.d_model or tgt.size(2) != self.d_model: + raise RuntimeError("the feature number of src and tgt must be equal to d_model") + + memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask) + output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask) + return output + + def generate_square_subsequent_mask(self, sz): + r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf'). + Unmasked positions are filled with float(0.0). + """ + mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) + mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) + return mask + + def _reset_parameters(self): + r"""Initiate parameters in the transformer model.""" + + for p in self.parameters(): + if p.dim() > 1: + xavier_uniform_(p) + + +class TransformerEncoder(Module): + r"""TransformerEncoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the TransformerEncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + norm: the layer normalization component (optional). + + Examples:: + >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8) + >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> out = transformer_encoder(src) + """ + __constants__ = ['norm'] + + def __init__(self, encoder_layer, num_layers, norm=None): + super(TransformerEncoder, self).__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward(self, src, mask=None, src_key_padding_mask=None): + # type: (Tensor, Optional[Tensor], Optional[Tensor]) -> Tensor + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required). + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + see the docs in Transformer class. + """ + output = src + + for i, mod in enumerate(self.layers): + output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class TransformerDecoder(Module): + r"""TransformerDecoder is a stack of N decoder layers + + Args: + decoder_layer: an instance of the TransformerDecoderLayer() class (required). + num_layers: the number of sub-decoder-layers in the decoder (required). + norm: the layer normalization component (optional). 
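# --- Illustrative sketch, not part of the diff: generate_square_subsequent_mask for sz=3.
# --- Row i may attend to columns <= i; future positions get -inf and vanish after softmax:
# ---     [[0., -inf, -inf],
# ---      [0.,   0., -inf],
# ---      [0.,   0.,   0.]]
import torch
mask = (torch.triu(torch.ones(3, 3)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, 0.0)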
+ + Examples:: + >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) + >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6) + >>> memory = torch.rand(10, 32, 512) + >>> tgt = torch.rand(20, 32, 512) + >>> out = transformer_decoder(tgt, memory) + """ + __constants__ = ['norm'] + + def __init__(self, decoder_layer, num_layers, norm=None): + super(TransformerDecoder, self).__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward(self, tgt, memory, memory2=None, tgt_mask=None, + memory_mask=None, memory_mask2=None, tgt_key_padding_mask=None, + memory_key_padding_mask=None, memory_key_padding_mask2=None): + # type: (Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tensor + r"""Pass the inputs (and mask) through the decoder layer in turn. + + Args: + tgt: the sequence to the decoder (required). + memory: the sequence from the last layer of the encoder (required). + tgt_mask: the mask for the tgt sequence (optional). + memory_mask: the mask for the memory sequence (optional). + tgt_key_padding_mask: the mask for the tgt keys per batch (optional). + memory_key_padding_mask: the mask for the memory keys per batch (optional). + + Shape: + see the docs in Transformer class. + """ + output = tgt + + for mod in self.layers: + output = mod(output, memory, memory2=memory2, tgt_mask=tgt_mask, + memory_mask=memory_mask, memory_mask2=memory_mask2, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + memory_key_padding_mask2=memory_key_padding_mask2) + + if self.norm is not None: + output = self.norm(output) + + return output + +class TransformerEncoderLayer(Module): + r"""TransformerEncoderLayer is made up of self-attn and feedforward network. + This standard encoder layer is based on the paper "Attention Is All You Need". + Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, + Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in + Neural Information Processing Systems, pages 6000-6010. Users may modify or implement + in a different way during application. + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + activation: the activation function of intermediate layer, relu or gelu (default=relu). 
+ + Examples:: + >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> out = encoder_layer(src) + """ + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu", debug=False): + super(TransformerEncoderLayer, self).__init__() + self.debug = debug + self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = Linear(d_model, dim_feedforward) + self.dropout = Dropout(dropout) + self.linear2 = Linear(dim_feedforward, d_model) + + self.norm1 = LayerNorm(d_model) + self.norm2 = LayerNorm(d_model) + self.dropout1 = Dropout(dropout) + self.dropout2 = Dropout(dropout) + + self.activation = _get_activation_fn(activation) + + def __setstate__(self, state): + if 'activation' not in state: + state['activation'] = F.relu + super(TransformerEncoderLayer, self).__setstate__(state) + + def forward(self, src, src_mask=None, src_key_padding_mask=None): + # type: (Tensor, Optional[Tensor], Optional[Tensor]) -> Tensor + r"""Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + see the docs in Transformer class. + """ + src2, attn = self.self_attn(src, src, src, attn_mask=src_mask, + key_padding_mask=src_key_padding_mask) + if self.debug: self.attn = attn + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + + return src + + +class TransformerDecoderLayer(Module): + r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network. + This standard decoder layer is based on the paper "Attention Is All You Need". + Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, + Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in + Neural Information Processing Systems, pages 6000-6010. Users may modify or implement + in a different way during application. + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + activation: the activation function of intermediate layer, relu or gelu (default=relu). 
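# --- Illustrative sketch, not part of the diff: exercising this file's encoder stack the way
# --- backbone.py does (d_model=512, nhead=8, dim_feedforward=2048, 2 layers, post-norm
# --- residuals: x = norm(x + attn(x)); x = norm(x + ffn(x))). Assumes the classes defined
# --- above are in scope.
import torch
layer = TransformerEncoderLayer(d_model=512, nhead=8, dim_feedforward=2048, dropout=0.1)
encoder = TransformerEncoder(layer, num_layers=2)
feat = torch.rand(8 * 32, 2, 512)      # (S, N, E): flattened 8x32 feature map, batch of 2
out = encoder(feat)                    # (256, 2, 512)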
+ + Examples:: + >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) + >>> memory = torch.rand(10, 32, 512) + >>> tgt = torch.rand(20, 32, 512) + >>> out = decoder_layer(tgt, memory) + """ + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu", self_attn=True, siamese=False, debug=False): + super(TransformerDecoderLayer, self).__init__() + self.has_self_attn, self.siamese = self_attn, siamese + self.debug = debug + if self.has_self_attn: + self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout) + self.norm1 = LayerNorm(d_model) + self.dropout1 = Dropout(dropout) + self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = Linear(d_model, dim_feedforward) + self.dropout = Dropout(dropout) + self.linear2 = Linear(dim_feedforward, d_model) + + self.norm2 = LayerNorm(d_model) + self.norm3 = LayerNorm(d_model) + self.dropout2 = Dropout(dropout) + self.dropout3 = Dropout(dropout) + if self.siamese: + self.multihead_attn2 = MultiheadAttention(d_model, nhead, dropout=dropout) + + self.activation = _get_activation_fn(activation) + + def __setstate__(self, state): + if 'activation' not in state: + state['activation'] = F.relu + super(TransformerDecoderLayer, self).__setstate__(state) + + def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, + tgt_key_padding_mask=None, memory_key_padding_mask=None, + memory2=None, memory_mask2=None, memory_key_padding_mask2=None): + # type: (Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tensor + r"""Pass the inputs (and mask) through the decoder layer. + + Args: + tgt: the sequence to the decoder layer (required). + memory: the sequence from the last layer of the encoder (required). + tgt_mask: the mask for the tgt sequence (optional). + memory_mask: the mask for the memory sequence (optional). + tgt_key_padding_mask: the mask for the tgt keys per batch (optional). + memory_key_padding_mask: the mask for the memory keys per batch (optional). + + Shape: + see the docs in Transformer class. + """ + if self.has_self_attn: + tgt2, attn = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask) + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + if self.debug: self.attn = attn + tgt2, attn2 = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask) + if self.debug: self.attn2 = attn2 + + if self.siamese: + tgt3, attn3 = self.multihead_attn2(tgt, memory2, memory2, attn_mask=memory_mask2, + key_padding_mask=memory_key_padding_mask2) + tgt = tgt + self.dropout2(tgt3) + if self.debug: self.attn3 = attn3 + + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + + return tgt + + +def _get_clones(module, N): + return ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def _get_activation_fn(activation): + if activation == "relu": + return F.relu + elif activation == "gelu": + return F.gelu + + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) + + +class PositionalEncoding(nn.Module): + r"""Inject some information about the relative or absolute position of the tokens + in the sequence. The positional encodings have the same dimension as + the embeddings, so that the two can be summed. 
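# --- Illustrative sketch, not part of the diff: checking the sinusoidal table built below
# --- against the closed form in the docstring, pe[pos, 2i] = sin(pos / 10000**(2i/d_model)).
import math
import torch
d_model, max_len = 512, 100
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe_even = torch.sin(position * div_term)                  # the pe[:, 0::2] entries
pos, i = 5, 2
assert torch.isclose(pe_even[pos, i], torch.tensor(math.sin(pos / 10000 ** (2 * i / d_model))))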
Here, we use sine and cosine + functions of different frequencies. + .. math:: + \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model)) + \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model)) + \text{where pos is the word position and i is the embed idx) + Args: + d_model: the embed dim (required). + dropout: the dropout value (default=0.1). + max_len: the max. length of the incoming sequence (default=5000). + Examples: + >>> pos_encoder = PositionalEncoding(d_model) + """ + + def __init__(self, d_model, dropout=0.1, max_len=5000): + super(PositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0).transpose(0, 1) + self.register_buffer('pe', pe) + + def forward(self, x): + r"""Inputs of forward function + Args: + x: the sequence fed to the positional encoder model (required). + Shape: + x: [sequence length, batch size, embed dim] + output: [sequence length, batch size, embed dim] + Examples: + >>> output = pos_encoder(x) + """ + + x = x + self.pe[:x.size(0), :] + return self.dropout(x) + + +if __name__ == '__main__': + transformer_model = Transformer(nhead=16, num_encoder_layers=12) + src = torch.rand((10, 32, 512)) + tgt = torch.rand((20, 32, 512)) + out = transformer_model(src, tgt) + print(out) diff --git a/modules_matrn/attention.py b/modules_matrn/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..a0a5f0bbab93b21d54662fd5cc14f7dc545b72bf --- /dev/null +++ b/modules_matrn/attention.py @@ -0,0 +1,97 @@ +import torch +import torch.nn as nn +from .transformer import PositionalEncoding + +class Attention(nn.Module): + def __init__(self, in_channels=512, max_length=25, n_feature=256): + super().__init__() + self.max_length = max_length + + self.f0_embedding = nn.Embedding(max_length, in_channels) + self.w0 = nn.Linear(max_length, n_feature) + self.wv = nn.Linear(in_channels, in_channels) + self.we = nn.Linear(in_channels, max_length) + + self.active = nn.Tanh() + self.softmax = nn.Softmax(dim=2) + + def forward(self, enc_output): + enc_output = enc_output.permute(0, 2, 3, 1).flatten(1, 2) + reading_order = torch.arange(self.max_length, dtype=torch.long, device=enc_output.device) + reading_order = reading_order.unsqueeze(0).expand(enc_output.size(0), -1) # (S,) -> (B, S) + reading_order_embed = self.f0_embedding(reading_order) # b,25,512 + + t = self.w0(reading_order_embed.permute(0, 2, 1)) # b,512,256 + t = self.active(t.permute(0, 2, 1) + self.wv(enc_output)) # b,256,512 + + attn = self.we(t) # b,256,25 + attn = self.softmax(attn.permute(0, 2, 1)) # b,25,256 + g_output = torch.bmm(attn, enc_output) # b,25,512 + return g_output, attn.view(*attn.shape[:2], 8, 32) + + +def encoder_layer(in_c, out_c, k=3, s=2, p=1): + return nn.Sequential(nn.Conv2d(in_c, out_c, k, s, p), + nn.BatchNorm2d(out_c), + nn.ReLU()) + +def decoder_layer(in_c, out_c, k=3, s=1, p=1, mode='nearest', scale_factor=None, size=None): + align_corners = None if mode=='nearest' else True + return nn.Sequential(nn.Upsample(size=size, scale_factor=scale_factor, + mode=mode, align_corners=align_corners), + nn.Conv2d(in_c, out_c, k, s, p), + nn.BatchNorm2d(out_c), + nn.ReLU()) + + +class PositionAttention(nn.Module): + def __init__(self, max_length, 
in_channels=512, num_channels=64, + h=8, w=32, mode='nearest', **kwargs): + super().__init__() + self.max_length = max_length + self.k_encoder = nn.Sequential( + encoder_layer(in_channels, num_channels, s=(1, 2)), + encoder_layer(num_channels, num_channels, s=(2, 2)), + encoder_layer(num_channels, num_channels, s=(2, 2)), + encoder_layer(num_channels, num_channels, s=(2, 2)) + ) + self.k_decoder = nn.Sequential( + decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode), + decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode), + decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode), + decoder_layer(num_channels, in_channels, size=(h, w), mode=mode) + ) + self.pos_encoder = PositionalEncoding(in_channels, dropout=0, max_len=max_length) + self.project = nn.Linear(in_channels, in_channels) + + def forward(self, x, q=None): + N, E, H, W = x.size() + k, v = x, x # (N, E, H, W) + + # calculate key vector + features = [] + for i in range(0, len(self.k_encoder)): + k = self.k_encoder[i](k) + features.append(k) + for i in range(0, len(self.k_decoder) - 1): + k = self.k_decoder[i](k) + k = k + features[len(self.k_decoder) - 2 - i] + k = self.k_decoder[-1](k) + + # calculate query vector + # TODO q=f(q,k) + if q is None: + zeros = x.new_zeros((self.max_length, N, E)) # (T, N, E) + q = self.pos_encoder(zeros) # (T, N, E) + q = q.permute(1, 0, 2) # (N, T, E) + q = self.project(q) # (N, T, E) + + # calculate attention + attn_scores = torch.bmm(q, k.flatten(2, 3)) # (N, T, (H*W)) + attn_scores = attn_scores / (E ** 0.5) + attn_scores = torch.softmax(attn_scores, dim=-1) + + v = v.permute(0, 2, 3, 1).view(N, -1, E) # (N, (H*W), E) + attn_vecs = torch.bmm(attn_scores, v) # (N, T, E) + + return attn_vecs, attn_scores.view(N, -1, H, W) diff --git a/modules_matrn/backbone.py b/modules_matrn/backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..a932b6885ce282eb4e3436be82d9ece1da2aea64 --- /dev/null +++ b/modules_matrn/backbone.py @@ -0,0 +1,53 @@ +import torch +import torch.nn as nn +from fastai.vision import * + +from modules_matrn.model import _default_tfmer_cfg +from modules_matrn.resnet import resnet45 +from modules_matrn.transformer import (PositionalEncoding, + TransformerEncoder, + TransformerEncoderLayer) + + +class ResTranformer(nn.Module): + def __init__(self, config): + super().__init__() + self.resnet = resnet45() + + self.d_model = ifnone(config.model_vision_d_model, _default_tfmer_cfg['d_model']) + nhead = ifnone(config.model_vision_nhead, _default_tfmer_cfg['nhead']) + d_inner = ifnone(config.model_vision_d_inner, _default_tfmer_cfg['d_inner']) + dropout = ifnone(config.model_vision_dropout, _default_tfmer_cfg['dropout']) + activation = ifnone(config.model_vision_activation, _default_tfmer_cfg['activation']) + num_layers = ifnone(config.model_vision_backbone_ln, 2) + + self.pos_encoder = PositionalEncoding(self.d_model, max_len=8*32) + encoder_layer = TransformerEncoderLayer(d_model=self.d_model, nhead=nhead, + dim_feedforward=d_inner, dropout=dropout, activation=activation) + self.transformer = TransformerEncoder(encoder_layer, num_layers) + + def forward(self, images): + feature = self.resnet(images) + n, c, h, w = feature.shape + feature = feature.view(n, c, -1).permute(2, 0, 1) + feature = self.pos_encoder(feature) + feature = self.transformer(feature) + feature = feature.permute(1, 2, 0).view(n, c, h, w) + return feature + + +class ResNetWithPosEnc(nn.Module): + def __init__(self, config): + super().__init__() + 
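The PositionAttention module above refines the key map with a small conv encoder/decoder and builds its queries from positionally encoded zeros; the attention step itself is a plain scaled dot-product over the H*W spatial positions. A toy restatement of just that step (the sizes N=2, T=26 are assumptions, and the conv refinement is skipped):

```python
import torch

# T positional queries attend over H*W visual features with scaled dot-product,
# mirroring the tail of PositionAttention.forward.
N, E, H, W, T = 2, 512, 8, 32, 26
k = torch.rand(N, E, H, W)     # key map (random here, in place of the conv-refined one)
v = torch.rand(N, E, H, W)     # values = backbone features
q = torch.rand(N, T, E)        # projected positional queries

attn_scores = torch.bmm(q, k.flatten(2, 3)) / (E ** 0.5)                  # (N, T, H*W)
attn_scores = torch.softmax(attn_scores, dim=-1)
attn_vecs = torch.bmm(attn_scores, v.permute(0, 2, 3, 1).view(N, -1, E))  # (N, T, E)
print(attn_vecs.shape, attn_scores.view(N, T, H, W).shape)
```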
        self.resnet = resnet45()
+
+        self.d_model = ifnone(config.model_vision_d_model, _default_tfmer_cfg['d_model'])
+        self.pos_encoder = PositionalEncoding(self.d_model, max_len=8*32)
+
+    def forward(self, images):
+        feature = self.resnet(images)
+        n, c, h, w = feature.shape
+        feature = feature.view(n, c, -1).permute(2, 0, 1)
+        feature = self.pos_encoder(feature)
+        feature = feature.permute(1, 2, 0).view(n, c, h, w)
+        return feature
diff --git a/modules_matrn/model.py b/modules_matrn/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..d18a71b7c26522f237e0d725dcc2657bbd93e38e
--- /dev/null
+++ b/modules_matrn/model.py
@@ -0,0 +1,50 @@
+import torch
+import torch.nn as nn
+
+from utils_matrn import CharsetMapper
+
+
+_default_tfmer_cfg = dict(d_model=512, nhead=8, d_inner=2048,  # 1024
+                          dropout=0.1, activation='relu')
+
+class Model(nn.Module):
+
+    def __init__(self, config):
+        super().__init__()
+        self.max_length = config.dataset_max_length + 1
+        self.charset = CharsetMapper(config.dataset_charset_path, max_length=self.max_length)
+
+    def load(self, source, device=None, strict=True):
+        state = torch.load(source, map_location=device)
+        self.load_state_dict(state['model'], strict=strict)
+
+    def _get_length(self, logit, dim=-1):
+        """ Greedy decoder to obtain length from logit"""
+        out = (logit.argmax(dim=-1) == self.charset.null_label)
+        abn = out.any(dim)
+        out = ((out.cumsum(dim) == 1) & out).max(dim)[1]
+        out = out + 1  # additional end token
+        out = torch.where(abn, out, out.new_tensor(logit.shape[1]))
+        return out
+
+    @staticmethod
+    def _get_padding_mask(length, max_length):
+        length = length.unsqueeze(-1)
+        grid = torch.arange(0, max_length, device=length.device).unsqueeze(0)
+        return grid >= length
+
+    @staticmethod
+    def _get_square_subsequent_mask(sz, device, diagonal=0, fw=True):
+        r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
+        Unmasked positions are filled with float(0.0).
+ """ + mask = (torch.triu(torch.ones(sz, sz, device=device), diagonal=diagonal) == 1) + if fw: mask = mask.transpose(0, 1) + mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) + return mask + + @staticmethod + def _get_location_mask(sz, device=None): + mask = torch.eye(sz, device=device) + mask = mask.float().masked_fill(mask == 1, float('-inf')) + return mask diff --git a/modules_matrn/model_abinet.py b/modules_matrn/model_abinet.py new file mode 100644 index 0000000000000000000000000000000000000000..34c37b64ac4814b868483e3027d6ecf88b62c1bb --- /dev/null +++ b/modules_matrn/model_abinet.py @@ -0,0 +1,30 @@ +import torch +import torch.nn as nn +from fastai.vision import * + +from .model_vision import BaseVision +from .model_language import BCNLanguage +from .model_alignment import BaseAlignment + + +class ABINetModel(nn.Module): + def __init__(self, config): + super().__init__() + self.use_alignment = ifnone(config.model_use_alignment, True) + self.max_length = config.dataset_max_length + 1 # additional stop token + self.vision = BaseVision(config) + self.language = BCNLanguage(config) + if self.use_alignment: self.alignment = BaseAlignment(config) + + def forward(self, images, *args): + v_res = self.vision(images) + v_tokens = torch.softmax(v_res['logits'], dim=-1) + v_lengths = v_res['pt_lengths'].clamp_(2, self.max_length) # TODO:move to langauge model + + l_res = self.language(v_tokens, v_lengths) + if not self.use_alignment: + return l_res, v_res + l_feature, v_feature = l_res['feature'], v_res['feature'] + + a_res = self.alignment(l_feature, v_feature) + return a_res, l_res, v_res diff --git a/modules_matrn/model_abinet_iter.py b/modules_matrn/model_abinet_iter.py new file mode 100644 index 0000000000000000000000000000000000000000..6588890570a1180ea32f7969cefd4b59c25409a7 --- /dev/null +++ b/modules_matrn/model_abinet_iter.py @@ -0,0 +1,34 @@ +import torch +import torch.nn as nn +from fastai.vision import * + +from .model_vision import BaseVision +from .model_language import BCNLanguage +from .model_alignment import BaseAlignment + + +class ABINetIterModel(nn.Module): + def __init__(self, config): + super().__init__() + self.iter_size = ifnone(config.model_iter_size, 1) + self.max_length = config.dataset_max_length + 1 # additional stop token + self.vision = BaseVision(config) + self.language = BCNLanguage(config) + self.alignment = BaseAlignment(config) + + def forward(self, images, *args): + v_res = self.vision(images) + a_res = v_res + all_l_res, all_a_res = [], [] + for _ in range(self.iter_size): + tokens = torch.softmax(a_res['logits'], dim=-1) + lengths = a_res['pt_lengths'] + lengths.clamp_(2, self.max_length) # TODO:move to langauge model + l_res = self.language(tokens, lengths) + all_l_res.append(l_res) + a_res = self.alignment(l_res['feature'], v_res['feature']) + all_a_res.append(a_res) + if self.training: + return all_a_res, all_l_res, v_res + else: + return a_res, all_l_res[-1], v_res diff --git a/modules_matrn/model_alignment.py b/modules_matrn/model_alignment.py new file mode 100644 index 0000000000000000000000000000000000000000..f2bf5fa95c8b36f8d613def48bc5beaad7c22955 --- /dev/null +++ b/modules_matrn/model_alignment.py @@ -0,0 +1,34 @@ +import torch +import torch.nn as nn +from fastai.vision import * + +from modules_matrn.model import Model, _default_tfmer_cfg + + +class BaseAlignment(Model): + def __init__(self, config): + super().__init__(config) + d_model = ifnone(config.model_alignment_d_model, 
                          _default_tfmer_cfg['d_model'])
+
+        self.loss_weight = ifnone(config.model_alignment_loss_weight, 1.0)
+        self.max_length = config.dataset_max_length + 1  # additional stop token
+        self.w_att = nn.Linear(2 * d_model, d_model)
+        self.cls = nn.Linear(d_model, self.charset.num_classes)
+
+    def forward(self, l_feature, v_feature):
+        """
+        Args:
+            l_feature: (N, T, E) where T is length, N is batch size and E is dim of model
+            v_feature: (N, T, E) shape the same as l_feature
+            l_lengths: (N,)
+            v_lengths: (N,)
+        """
+        f = torch.cat((l_feature, v_feature), dim=2)
+        f_att = torch.sigmoid(self.w_att(f))
+        output = f_att * v_feature + (1 - f_att) * l_feature
+
+        logits = self.cls(output)  # (N, T, C)
+        pt_lengths = self._get_length(logits)
+
+        return {'logits': logits, 'pt_lengths': pt_lengths, 'loss_weight':self.loss_weight,
+                'name': 'alignment'}
diff --git a/modules_matrn/model_language.py b/modules_matrn/model_language.py
new file mode 100644
index 0000000000000000000000000000000000000000..030f3135529cb3fef53bc607b1a80c64592cf02a
--- /dev/null
+++ b/modules_matrn/model_language.py
@@ -0,0 +1,67 @@
+import logging
+import torch.nn as nn
+from fastai.vision import *
+
+from modules_matrn.model import _default_tfmer_cfg
+from modules_matrn.model import Model
+from modules_matrn.transformer import (PositionalEncoding,
+                                       TransformerDecoder,
+                                       TransformerDecoderLayer)
+
+
+class BCNLanguage(Model):
+    def __init__(self, config):
+        super().__init__(config)
+        d_model = ifnone(config.model_language_d_model, _default_tfmer_cfg['d_model'])
+        nhead = ifnone(config.model_language_nhead, _default_tfmer_cfg['nhead'])
+        d_inner = ifnone(config.model_language_d_inner, _default_tfmer_cfg['d_inner'])
+        dropout = ifnone(config.model_language_dropout, _default_tfmer_cfg['dropout'])
+        activation = ifnone(config.model_language_activation, _default_tfmer_cfg['activation'])
+        num_layers = ifnone(config.model_language_num_layers, 4)
+        self.d_model = d_model
+        self.detach = ifnone(config.model_language_detach, True)
+        self.use_self_attn = ifnone(config.model_language_use_self_attn, False)
+        self.loss_weight = ifnone(config.model_language_loss_weight, 1.0)
+        self.max_length = config.dataset_max_length + 1  # additional stop token
+        self.debug = ifnone(config.global_debug, False)
+
+        self.proj = nn.Linear(self.charset.num_classes, d_model, False)
+        self.token_encoder = PositionalEncoding(d_model, max_len=self.max_length)
+        self.pos_encoder = PositionalEncoding(d_model, dropout=0, max_len=self.max_length)
+        decoder_layer = TransformerDecoderLayer(d_model, nhead, d_inner, dropout,
+                                                activation, self_attn=self.use_self_attn, debug=self.debug)
+        self.model = TransformerDecoder(decoder_layer, num_layers)
+
+        self.cls = nn.Linear(d_model, self.charset.num_classes)
+
+        if config.model_language_checkpoint is not None:
+            logging.info(f'Read language model from {config.model_language_checkpoint}.')
+            self.load(config.model_language_checkpoint)
+
+    def forward(self, tokens, lengths):
+        """
+        Args:
+            tokens: (N, T, C) where T is length, N is batch size and C is classes number
+            lengths: (N,)
+        """
+        if self.detach: tokens = tokens.detach()
+        embed = self.proj(tokens)  # (N, T, E)
+        embed = embed.permute(1, 0, 2)  # (T, N, E)
+        embed = self.token_encoder(embed)  # (T, N, E)
+        padding_mask = self._get_padding_mask(lengths, self.max_length)
+
+        zeros = embed.new_zeros(*embed.shape)
+        query = self.pos_encoder(zeros)
+        location_mask = self._get_location_mask(self.max_length, tokens.device)
+        output = self.model(query, embed,
tgt_key_padding_mask=padding_mask, + memory_mask=location_mask, + memory_key_padding_mask=padding_mask) # (T, N, E) + output = output.permute(1, 0, 2) # (N, T, E) + + logits = self.cls(output) # (N, T, C) + pt_lengths = self._get_length(logits) + + res = {'feature': output, 'logits': logits, 'pt_lengths': pt_lengths, + 'loss_weight':self.loss_weight, 'name': 'language'} + return res diff --git a/modules_matrn/model_matrn_iter.py b/modules_matrn/model_matrn_iter.py new file mode 100644 index 0000000000000000000000000000000000000000..fd9198a7014271bdd7394ba39b2191d056e90ddf --- /dev/null +++ b/modules_matrn/model_matrn_iter.py @@ -0,0 +1,60 @@ +import torch +import torch.nn as nn +from fastai.vision import * + +from .model_vision import BaseVision +from .model_language import BCNLanguage +from .model_semantic_visual_backbone_feature import BaseSemanticVisual_backbone_feature + + +class MATRN(nn.Module): + def __init__(self, config): + super().__init__() + self.iter_size = ifnone(config.model_iter_size, 1) + self.test_bh = ifnone(config.test_bh, None) + self.max_length = config.dataset_max_length + 1 # additional stop token + self.vision = BaseVision(config) + self.language = BCNLanguage(config) + self.semantic_visual = BaseSemanticVisual_backbone_feature(config) + + # def forward(self, images, *args): + def forward(self, images, texts=None): + v_res = self.vision(images) + a_res = v_res + all_l_res, all_a_res = [], [] + for _ in range(self.iter_size): + tokens = torch.softmax(a_res['logits'], dim=-1) + lengths = a_res['pt_lengths'] + lengths.clamp_(2, self.max_length) + l_res = self.language(tokens, lengths) + all_l_res.append(l_res) + lengths_l = l_res['pt_lengths'] + lengths_l.clamp_(2, self.max_length) + + v_attn_input = v_res['attn_scores'].clone().detach() + l_logits_input = None + texts_input = None + + a_res = self.semantic_visual(l_res['feature'], v_res['backbone_feature'], lengths_l=lengths_l, v_attn=v_attn_input, l_logits=l_logits_input, texts=texts_input, training=self.training) + + a_v_res = {'logits': a_res['v_logits'], 'pt_lengths': a_res['pt_v_lengths'], 'loss_weight': a_res['loss_weight'], + 'name': 'alignment'} + all_a_res.append(a_v_res) + a_s_res = {'logits': a_res['s_logits'], 'pt_lengths': a_res['pt_s_lengths'], 'loss_weight': a_res['loss_weight'], + 'name': 'alignment'} + all_a_res.append(a_s_res) + all_a_res.append(a_res) + + if self.training: + return all_a_res, all_l_res, v_res + else: + if self.test_bh is None: + return a_res, all_l_res[-1], v_res + elif self.test_bh == 'final': + return a_res, all_l_res[-1], v_res + elif self.test_bh == 'semantic': + return all_a_res[-2], all_l_res[-1], v_res + elif self.test_bh == 'visual': + return all_a_res[-3], all_l_res[-1], v_res + else: + raise NotImplementedError diff --git a/modules_matrn/model_semantic_visual_backbone_feature.py b/modules_matrn/model_semantic_visual_backbone_feature.py new file mode 100644 index 0000000000000000000000000000000000000000..7d4ecf85e3ba923bcdfead6b1ec65459e02bf0e1 --- /dev/null +++ b/modules_matrn/model_semantic_visual_backbone_feature.py @@ -0,0 +1,119 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F +import numpy as np +from fastai.vision import * + +from modules_matrn.attention import * +from modules_matrn.model import Model, _default_tfmer_cfg +from modules_matrn.transformer import (PositionalEncoding, + TransformerEncoder, + TransformerEncoderLayer) + + +class BaseSemanticVisual_backbone_feature(Model): + def __init__(self, config): + super().__init__(config) 
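BCNLanguage above drives the decoder with two masks built by the Model helpers: a key-padding mask derived from the (predicted) lengths, and a "location" mask with -inf on the diagonal so position i cannot attend to its own input token. A tiny illustration of the same logic on made-up lengths:

```python
import torch

# Same construction as Model._get_padding_mask and Model._get_location_mask, on toy inputs.
max_length = 6
lengths = torch.tensor([3, 5])
padding_mask = torch.arange(max_length).unsqueeze(0) >= lengths.unsqueeze(-1)
print(padding_mask)
# tensor([[False, False, False,  True,  True,  True],
#         [False, False, False, False, False,  True]])

location_mask = torch.eye(max_length)
location_mask = location_mask.masked_fill(location_mask == 1, float('-inf'))
print(location_mask[:3, :3])   # -inf on the diagonal, 0 elsewhere
```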
+ d_model = ifnone(config.model_alignment_d_model, _default_tfmer_cfg['d_model']) + nhead = ifnone(config.model_alignment_nhead, _default_tfmer_cfg['nhead']) + d_inner = ifnone(config.model_alignment_d_inner, _default_tfmer_cfg['d_inner']) + dropout = ifnone(config.model_alignmentl_dropout, _default_tfmer_cfg['dropout']) + activation = ifnone(config.model_alignment_activation, _default_tfmer_cfg['activation']) + num_layers = ifnone(config.model_alignment_num_layers, 2) + + self.mask_example_prob = ifnone(config.model_alignment_mask_example_prob, 0.9) + self.mask_candidate_prob = ifnone(config.model_alignment_mask_candidate_prob, 0.9) + self.num_vis_mask = ifnone(config.model_alignment_num_vis_mask, 10) + self.nhead = nhead + + self.d_model = d_model + self.use_self_attn = ifnone(config.model_alignment_use_self_attn, False) + self.loss_weight = ifnone(config.model_alignment_loss_weight, 1.0) + self.max_length = config.dataset_max_length + 1 # additional stop token + self.debug = ifnone(config.global_debug, False) + + encoder_layer = TransformerEncoderLayer(d_model=d_model, nhead=nhead, + dim_feedforward=d_inner, dropout=dropout, activation=activation) + self.model1 = TransformerEncoder(encoder_layer, num_layers) + self.pos_encoder_tfm = PositionalEncoding(d_model, dropout=0, max_len=8*32) + + mode = ifnone(config.model_alignment_attention_mode, 'nearest') + self.model2_vis = PositionAttention( + max_length=config.dataset_max_length + 1, # additional stop token + mode=mode + ) + self.cls_vis = nn.Linear(d_model, self.charset.num_classes) + self.cls_sem = nn.Linear(d_model, self.charset.num_classes) + self.w_att = nn.Linear(2 * d_model, d_model) + + v_token = torch.empty((1, d_model)) + self.v_token = nn.Parameter(v_token) + torch.nn.init.uniform_(self.v_token, -0.001, 0.001) + + self.cls = nn.Linear(d_model, self.charset.num_classes) + + def forward(self, l_feature, v_feature, lengths_l=None, v_attn=None, l_logits=None, texts=None, training=True): + """ + Args: + l_feature: (N, T, E) where T is length, N is batch size and d is dim of model + v_feature: (N, E, H, W) + lengths_l: (N,) + v_attn: (N, T, H, W) + l_logits: (N, T, C) + texts: (N, T, C) + """ + padding_mask = self._get_padding_mask(lengths_l, self.max_length) + + l_feature = l_feature.permute(1, 0, 2) # (T, N, E) + N, E, H, W = v_feature.size() + v_feature = v_feature.view(N, E, H*W).contiguous().permute(2, 0, 1) # (H*W, N, E) + + if training: + n, t, h, w = v_attn.shape + v_attn = v_attn.view(n, t, -1) # (N, T, H*W) + for idx, length in enumerate(lengths_l): + if np.random.random() <= self.mask_example_prob: + l_idx = np.random.randint(int(length)) + v_random_idx = v_attn[idx, l_idx].argsort(descending=True).cpu().numpy()[:self.num_vis_mask,] + v_random_idx = v_random_idx[np.random.random(v_random_idx.shape) <= self.mask_candidate_prob] + v_feature[v_random_idx, idx] = self.v_token + + if len(v_attn.shape) == 4: + n, t, h, w = v_attn.shape + v_attn = v_attn.view(n, t, -1) # (N, T, H*W) + + zeros = v_feature.new_zeros((h*w, n, E)) # (H*W, N, E) + base_pos = self.pos_encoder_tfm(zeros) # (H*W, N, E) + base_pos = base_pos.permute(1, 0, 2) # (N, H*W, E) + + base_pos = torch.bmm(v_attn, base_pos) # (N, T, E) + base_pos = base_pos.permute(1, 0, 2) # (T, N, E) + + l_feature = l_feature + base_pos + + sv_feature = torch.cat((v_feature, l_feature), dim=0) # (H*W+T, N, E) + sv_feature = self.model1(sv_feature) # (H*W+T, N, E) + + sv_to_v_feature = sv_feature[:H*W] # (H*W, N, E) + sv_to_s_feature = sv_feature[H*W:] # (T, N, E) + + 
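The fusion above is sequence concatenation: the H*W visual tokens and the T semantic tokens are stacked along the sequence axis, encoded jointly by a transformer encoder, then split back into a visual branch and a semantic branch. A shape-level sketch with stock torch.nn modules standing in for this file's TransformerEncoder (toy sizes assumed):

```python
import torch
import torch.nn as nn

N, E, H, W, T = 2, 512, 8, 32, 26
v_feature = torch.rand(H * W, N, E)    # (H*W, N, E) visual tokens
l_feature = torch.rand(T, N, E)        # (T, N, E) semantic tokens (plus attention-weighted positions)
encoder = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=E, nhead=8), num_layers=2)

sv = encoder(torch.cat((v_feature, l_feature), dim=0))   # (H*W + T, N, E) joint encoding
sv_to_v, sv_to_s = sv[:H * W], sv[H * W:]                # split back into the two branches
print(sv_to_v.shape, sv_to_s.shape)                      # (256, 2, 512) and (26, 2, 512)
```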
sv_to_v_feature = sv_to_v_feature.permute(1, 2, 0).view(N, E, H, W) + sv_to_v_feature, _ = self.model2_vis(sv_to_v_feature) # (N, T, E) + sv_to_v_logits = self.cls_vis(sv_to_v_feature) # (N, T, C) + pt_v_lengths = self._get_length(sv_to_v_logits) # (N,) + + sv_to_s_feature = sv_to_s_feature.permute(1, 0, 2) # (N, T, E) + sv_to_s_logits = self.cls_sem(sv_to_s_feature) # (N, T, C) + pt_s_lengths = self._get_length(sv_to_s_logits) # (N,) + + f = torch.cat((sv_to_v_feature, sv_to_s_feature), dim=2) + f_att = torch.sigmoid(self.w_att(f)) + output = f_att * sv_to_v_feature + (1 - f_att) * sv_to_s_feature + + logits = self.cls(output) # (N, T, C) + pt_lengths = self._get_length(logits) + + return {'logits': logits, 'pt_lengths': pt_lengths, 'loss_weight':self.loss_weight*3, + 'v_logits': sv_to_v_logits, 'pt_v_lengths': pt_v_lengths, + 's_logits': sv_to_s_logits, 'pt_s_lengths': pt_s_lengths, + 'name': 'alignment'} diff --git a/modules_matrn/model_vision.py b/modules_matrn/model_vision.py new file mode 100644 index 0000000000000000000000000000000000000000..c61d6cec067857c9bf1973e419a9deec3dd6e699 --- /dev/null +++ b/modules_matrn/model_vision.py @@ -0,0 +1,64 @@ +import logging +import torch.nn as nn +from fastai.vision import * + +from modules_matrn.attention import * +from modules_matrn.backbone import ResTranformer, ResNetWithPosEnc +from modules_matrn.model import Model + + +class BaseVision(Model): + def __init__(self, config): + super().__init__(config) + self.loss_weight = ifnone(config.model_vision_loss_weight, 1.0) + self.out_channels = ifnone(config.model_vision_d_model, 512) + + self.num_more_attention = ifnone(config.model_vision_num_more_attention, -1) + + if config.model_vision_backbone == 'transformer': + self.backbone = ResTranformer(config) + else: self.backbone = ResNetWithPosEnc(config) + + if config.model_vision_attention == 'position': + mode = ifnone(config.model_vision_attention_mode, 'nearest') + self.attention = PositionAttention( + max_length=config.dataset_max_length + 1, # additional stop token + mode=mode, + ) + elif config.model_vision_attention == 'attention': + self.attention = Attention( + max_length=config.dataset_max_length + 1, # additional stop token + n_feature=8*32, + ) + else: + raise Exception(f'{config.model_vision_attention} is not valid.') + self.cls = nn.Linear(self.out_channels, self.charset.num_classes) + + if config.model_vision_checkpoint is not None: + logging.info(f'Read vision model from {config.model_vision_checkpoint}.') + self.load(config.model_vision_checkpoint) + + if self.num_more_attention > 0: + mode = ifnone(config.model_vision_attention_mode, 'nearest') + self.more_attention = nn.ModuleList([ + PositionAttention( + max_length=config.dataset_max_length + 1, # additional stop token + mode=mode, + ) for _ in range(self.num_more_attention) + ]) + + + def forward(self, images, *args): + features = self.backbone(images) # (N, E, H, W) + attn_vecs, attn_scores = self.attention(features) # (N, T, E), (N, T, H, W) + + if self.num_more_attention > 0: + for attn in self.more_attention: + attn_vecs, attn_scores = attn(features, attn_vecs) + + logits = self.cls(attn_vecs) # (N, T, C) + pt_lengths = self._get_length(logits) + + return {'feature': attn_vecs, 'logits': logits, 'pt_lengths': pt_lengths, + 'attn_scores': attn_scores, 'loss_weight':self.loss_weight, 'name': 'vision', + 'backbone_feature': features} diff --git a/modules_matrn/resnet.py b/modules_matrn/resnet.py new file mode 100644 index 
0000000000000000000000000000000000000000..119a83edc78a5e9aa05be079d628553fa7008406 --- /dev/null +++ b/modules_matrn/resnet.py @@ -0,0 +1,105 @@ +import math + +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.model_zoo as model_zoo + + +def conv1x1(in_planes, out_planes, stride=1): + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +def conv3x3(in_planes, out_planes, stride=1): + "3x3 convolution with padding" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv1x1(inplanes, planes) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU() + self.relu2 = nn.ReLU() + self.conv2 = conv3x3(planes, planes, stride) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu2(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers): + self.inplanes = 32 + super(ResNet, self).__init__() + self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, + bias=False) + self.bn1 = nn.BatchNorm2d(32) + self.relu = nn.ReLU() + + self.layer1 = self._make_layer(block, 32, layers[0], stride=2) + self.layer2 = self._make_layer(block, 64, layers[1], stride=1) + self.layer3 = self._make_layer(block, 128, layers[2], stride=2) + self.layer4 = self._make_layer(block, 256, layers[3], stride=1) + self.layer5 = self._make_layer(block, 512, layers[4], stride=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. 
/ n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.layer5(x) + return x + + +def resnet45(): + return ResNet(BasicBlock, [3, 4, 6, 6, 3]) diff --git a/modules_matrn/transformer.py b/modules_matrn/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..42717880a5b62b60c88a58ad6ee27ec53118e578 --- /dev/null +++ b/modules_matrn/transformer.py @@ -0,0 +1,956 @@ +# pytorch 1.5.0 +import copy +import math +import warnings +from typing import Optional + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn import Dropout, LayerNorm, Linear, Module, ModuleList, Parameter +from torch.nn import functional as F +from torch.nn.init import constant_, xavier_uniform_ + + +def multi_head_attention_forward(query, # type: Tensor + key, # type: Tensor + value, # type: Tensor + embed_dim_to_check, # type: int + num_heads, # type: int + in_proj_weight, # type: Tensor + in_proj_bias, # type: Tensor + bias_k, # type: Optional[Tensor] + bias_v, # type: Optional[Tensor] + add_zero_attn, # type: bool + dropout_p, # type: float + out_proj_weight, # type: Tensor + out_proj_bias, # type: Tensor + training=True, # type: bool + key_padding_mask=None, # type: Optional[Tensor] + need_weights=True, # type: bool + attn_mask=None, # type: Optional[Tensor] + use_separate_proj_weight=False, # type: bool + q_proj_weight=None, # type: Optional[Tensor] + k_proj_weight=None, # type: Optional[Tensor] + v_proj_weight=None, # type: Optional[Tensor] + static_k=None, # type: Optional[Tensor] + static_v=None # type: Optional[Tensor] + ): + # type: (...) -> Tuple[Tensor, Optional[Tensor]] + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + See "Attention Is All You Need" for more details. + embed_dim_to_check: total dimension of the model. + num_heads: parallel attention heads. + in_proj_weight, in_proj_bias: input projection weight and bias. + bias_k, bias_v: bias of the key and value sequences to be added at dim=0. + add_zero_attn: add a new batch of zeros to the key and + value sequences at dim=1. + dropout_p: probability of an element to be zeroed. + out_proj_weight, out_proj_bias: the output projection weight and bias. + training: apply dropout if is ``True``. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. 
+ use_separate_proj_weight: the function accept the proj. weights for query, key, + and value in different forms. If false, in_proj_weight will be used, which is + a combination of q_proj_weight, k_proj_weight, v_proj_weight. + q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias. + static_k, static_v: static key and value used for attention operators. + Shape: + Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions + will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, + N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. + - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, + N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. + Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. 
+ """ + # if not torch.jit.is_scripting(): + # tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, + # out_proj_weight, out_proj_bias) + # if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): + # return handle_torch_function( + # multi_head_attention_forward, tens_ops, query, key, value, + # embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, + # bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, + # out_proj_bias, training=training, key_padding_mask=key_padding_mask, + # need_weights=need_weights, attn_mask=attn_mask, + # use_separate_proj_weight=use_separate_proj_weight, + # q_proj_weight=q_proj_weight, k_proj_weight=k_proj_weight, + # v_proj_weight=v_proj_weight, static_k=static_k, static_v=static_v) + tgt_len, bsz, embed_dim = query.size() + assert embed_dim == embed_dim_to_check + assert key.size() == value.size() + + head_dim = embed_dim // num_heads + assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads" + scaling = float(head_dim) ** -0.5 + + if not use_separate_proj_weight: + if torch.equal(query, key) and torch.equal(key, value): + # self-attention + q, k, v = F.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1) + + elif torch.equal(key, value): + # encoder-decoder attention + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = F.linear(query, _w, _b) + + if key is None: + assert value is None + k = None + v = None + else: + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + k, v = F.linear(key, _w, _b).chunk(2, dim=-1) + + else: + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = F.linear(query, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = embed_dim * 2 + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + k = F.linear(key, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim * 2 + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + v = F.linear(value, _w, _b) + else: + q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight) + len1, len2 = q_proj_weight_non_opt.size() + assert len1 == embed_dim and len2 == query.size(-1) + + k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight) + len1, len2 = k_proj_weight_non_opt.size() + assert len1 == embed_dim and len2 == key.size(-1) + + v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight) + len1, len2 = v_proj_weight_non_opt.size() + assert len1 == embed_dim and len2 == value.size(-1) + + if in_proj_bias is not None: + q = F.linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim]) + k = F.linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim:(embed_dim * 2)]) + v = F.linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):]) + else: + q = F.linear(query, q_proj_weight_non_opt, in_proj_bias) + k = F.linear(key, k_proj_weight_non_opt, in_proj_bias) + v = F.linear(value, 
v_proj_weight_non_opt, in_proj_bias) + q = q * scaling + + if attn_mask is not None: + assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \ + attn_mask.dtype == torch.float16 or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool, \ + 'Only float, byte, and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype) + if attn_mask.dtype == torch.uint8: + warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.") + attn_mask = attn_mask.to(torch.bool) + + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0) + if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: + raise RuntimeError('The size of the 2D attn_mask is not correct.') + elif attn_mask.dim() == 3: + if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]: + raise RuntimeError('The size of the 3D attn_mask is not correct.') + else: + raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim())) + # attn_mask's dim is 3 now. + + # # convert ByteTensor key_padding_mask to bool + # if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + # warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.") + # key_padding_mask = key_padding_mask.to(torch.bool) + + if bias_k is not None and bias_v is not None: + if static_k is None and static_v is None: + k = torch.cat([k, bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = pad(key_padding_mask, (0, 1)) + else: + assert static_k is None, "bias cannot be added to static key." + assert static_v is None, "bias cannot be added to static value." 
+ else: + assert bias_k is None + assert bias_v is None + + q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) + if k is not None: + k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + if v is not None: + v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + + if static_k is not None: + assert static_k.size(0) == bsz * num_heads + assert static_k.size(2) == head_dim + k = static_k + + if static_v is not None: + assert static_v.size(0) == bsz * num_heads + assert static_v.size(2) == head_dim + v = static_v + + src_len = k.size(1) + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if add_zero_attn: + src_len += 1 + k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1) + v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1) + if attn_mask is not None: + attn_mask = pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = pad(key_padding_mask, (0, 1)) + + attn_output_weights = torch.bmm(q, k.transpose(1, 2)) + assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len] + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_output_weights.masked_fill_(attn_mask, float('-inf')) + else: + attn_output_weights += attn_mask + + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float('-inf'), + ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len) + + attn_output_weights = F.softmax( + attn_output_weights, dim=-1) + attn_output_weights = F.dropout(attn_output_weights, p=dropout_p, training=training) + + attn_output = torch.bmm(attn_output_weights, v) + assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] + attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias) + + if need_weights: + # average attention weights over heads + attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) + return attn_output, attn_output_weights.sum(dim=1) / num_heads + else: + return attn_output, None + +class MultiheadAttention(Module): + r"""Allows the model to jointly attend to information + from different representation subspaces. + See reference: Attention Is All You Need + .. math:: + \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O + \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) + Args: + embed_dim: total dimension of the model. + num_heads: parallel attention heads. + dropout: a Dropout layer on attn_output_weights. Default: 0.0. + bias: add bias as module parameter. Default: True. + add_bias_kv: add bias to the key and value sequences at dim=0. + add_zero_attn: add a new batch of zeros to the key and + value sequences at dim=1. + kdim: total number of features in key. Default: None. + vdim: total number of features in value. Default: None. + Note: if kdim and vdim are None, they will be set to embed_dim such that + query, key, and value have the same number of features. 
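The core of multi_head_attention_forward above is the head split: the E-dimensional projections are reshaped into num_heads chunks of head_dim, scores are computed per head with a batched matmul, and the outputs are recombined, with the returned weights averaged over heads. A toy walk-through of just those reshapes (input/output projections omitted, sizes are assumptions):

```python
import torch
import torch.nn.functional as F

L, S, N, E, num_heads = 5, 7, 2, 512, 8
head_dim = E // num_heads
q = torch.rand(L, N, E) * head_dim ** -0.5     # scaling applied to q, as in the function above
k, v = torch.rand(S, N, E), torch.rand(S, N, E)

q = q.contiguous().view(L, N * num_heads, head_dim).transpose(0, 1)   # (N*h, L, d)
k = k.contiguous().view(S, N * num_heads, head_dim).transpose(0, 1)   # (N*h, S, d)
v = v.contiguous().view(S, N * num_heads, head_dim).transpose(0, 1)   # (N*h, S, d)

weights = F.softmax(torch.bmm(q, k.transpose(1, 2)), dim=-1)          # (N*h, L, S)
out = torch.bmm(weights, v).transpose(0, 1).contiguous().view(L, N, E)
avg_weights = weights.view(N, num_heads, L, S).sum(dim=1) / num_heads # (N, L, S)
print(out.shape, avg_weights.shape)
```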
+ Examples:: + >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value) + """ + # __annotations__ = { + # 'bias_k': torch._jit_internal.Optional[torch.Tensor], + # 'bias_v': torch._jit_internal.Optional[torch.Tensor], + # } + __constants__ = ['q_proj_weight', 'k_proj_weight', 'v_proj_weight', 'in_proj_weight'] + + def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None): + super(MultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" + + if self._qkv_same_embed_dim is False: + self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim)) + self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim)) + self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim)) + self.register_parameter('in_proj_weight', None) + else: + self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim)) + self.register_parameter('q_proj_weight', None) + self.register_parameter('k_proj_weight', None) + self.register_parameter('v_proj_weight', None) + + if bias: + self.in_proj_bias = Parameter(torch.empty(3 * embed_dim)) + else: + self.register_parameter('in_proj_bias', None) + self.out_proj = Linear(embed_dim, embed_dim, bias=bias) + + if add_bias_kv: + self.bias_k = Parameter(torch.empty(1, 1, embed_dim)) + self.bias_v = Parameter(torch.empty(1, 1, embed_dim)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self._reset_parameters() + + def _reset_parameters(self): + if self._qkv_same_embed_dim: + xavier_uniform_(self.in_proj_weight) + else: + xavier_uniform_(self.q_proj_weight) + xavier_uniform_(self.k_proj_weight) + xavier_uniform_(self.v_proj_weight) + + if self.in_proj_bias is not None: + constant_(self.in_proj_bias, 0.) + constant_(self.out_proj.bias, 0.) + if self.bias_k is not None: + xavier_normal_(self.bias_k) + if self.bias_v is not None: + xavier_normal_(self.bias_v) + + def __setstate__(self, state): + # Support loading old MultiheadAttention checkpoints generated by v1.1.0 + if '_qkv_same_embed_dim' not in state: + state['_qkv_same_embed_dim'] = True + + super(MultiheadAttention, self).__setstate__(state) + + def forward(self, query, key, value, key_padding_mask=None, + need_weights=True, attn_mask=None): + # type: (Tensor, Tensor, Tensor, Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]] + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + See "Attention Is All You Need" for more details. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. 
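A short usage example of the mask conventions described above, shown with the stock torch.nn.MultiheadAttention (this class keeps the same interface): boolean key_padding_mask entries that are True are ignored, and a float attn_mask is added to the attention logits.

```python
import torch
import torch.nn as nn

L, S, N, E = 4, 6, 2, 512
mha = nn.MultiheadAttention(E, num_heads=8)
query, key = torch.rand(L, N, E), torch.rand(S, N, E)

key_padding_mask = torch.zeros(N, S, dtype=torch.bool)
key_padding_mask[:, -2:] = True           # last two source positions are padding
attn_mask = torch.zeros(L, S)             # float mask: added to the logits (all-zero = no-op)

out, weights = mha(query, key, key,
                   key_padding_mask=key_padding_mask, attn_mask=attn_mask)
print(out.shape, weights.shape)           # (4, 2, 512) and (2, 4, 6)
```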
+ Shape: + - Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the position + with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + - Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. + """ + if not self._qkv_same_embed_dim: + return multi_head_attention_forward( + query, key, value, self.embed_dim, self.num_heads, + self.in_proj_weight, self.in_proj_bias, + self.bias_k, self.bias_v, self.add_zero_attn, + self.dropout, self.out_proj.weight, self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, need_weights=need_weights, + attn_mask=attn_mask, use_separate_proj_weight=True, + q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight, + v_proj_weight=self.v_proj_weight) + else: + return multi_head_attention_forward( + query, key, value, self.embed_dim, self.num_heads, + self.in_proj_weight, self.in_proj_bias, + self.bias_k, self.bias_v, self.add_zero_attn, + self.dropout, self.out_proj.weight, self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, need_weights=need_weights, + attn_mask=attn_mask) + + +class Transformer(Module): + r"""A transformer model. User is able to modify the attributes as needed. The architecture + is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer, + Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and + Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information + Processing Systems, pages 6000-6010. Users can build the BERT(https://arxiv.org/abs/1810.04805) + model with corresponding parameters. + + Args: + d_model: the number of expected features in the encoder/decoder inputs (default=512). + nhead: the number of heads in the multiheadattention models (default=8). + num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6). + num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6). + dim_feedforward: the dimension of the feedforward network model (default=2048). 
+ dropout: the dropout value (default=0.1). + activation: the activation function of encoder/decoder intermediate layer, relu or gelu (default=relu). + custom_encoder: custom encoder (default=None). + custom_decoder: custom decoder (default=None). + + Examples:: + >>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12) + >>> src = torch.rand((10, 32, 512)) + >>> tgt = torch.rand((20, 32, 512)) + >>> out = transformer_model(src, tgt) + + Note: A full example to apply nn.Transformer module for the word language model is available in + https://github.com/pytorch/examples/tree/master/word_language_model + """ + + def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, + num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, + activation="relu", custom_encoder=None, custom_decoder=None): + super(Transformer, self).__init__() + + if custom_encoder is not None: + self.encoder = custom_encoder + else: + encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation) + encoder_norm = LayerNorm(d_model) + self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) + + if custom_decoder is not None: + self.decoder = custom_decoder + else: + decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, activation) + decoder_norm = LayerNorm(d_model) + self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm) + + self._reset_parameters() + + self.d_model = d_model + self.nhead = nhead + + def forward(self, src, tgt, src_mask=None, tgt_mask=None, + memory_mask=None, src_key_padding_mask=None, + tgt_key_padding_mask=None, memory_key_padding_mask=None): + # type: (Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tensor # noqa + r"""Take in and process masked source/target sequences. + + Args: + src: the sequence to the encoder (required). + tgt: the sequence to the decoder (required). + src_mask: the additive mask for the src sequence (optional). + tgt_mask: the additive mask for the tgt sequence (optional). + memory_mask: the additive mask for the encoder output (optional). + src_key_padding_mask: the ByteTensor mask for src keys per batch (optional). + tgt_key_padding_mask: the ByteTensor mask for tgt keys per batch (optional). + memory_key_padding_mask: the ByteTensor mask for memory keys per batch (optional). + + Shape: + - src: :math:`(S, N, E)`. + - tgt: :math:`(T, N, E)`. + - src_mask: :math:`(S, S)`. + - tgt_mask: :math:`(T, T)`. + - memory_mask: :math:`(T, S)`. + - src_key_padding_mask: :math:`(N, S)`. + - tgt_key_padding_mask: :math:`(N, T)`. + - memory_key_padding_mask: :math:`(N, S)`. + + Note: [src/tgt/memory]_mask ensures that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + [src/tgt/memory]_key_padding_mask provides specified elements in the key to be ignored by + the attention. If a ByteTensor is provided, the non-zero positions will be ignored while the zero + positions will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. 
+ + - output: :math:`(T, N, E)`. + + Note: Due to the multi-head attention architecture in the transformer model, + the output sequence length of a transformer is same as the input sequence + (i.e. target) length of the decode. + + where S is the source sequence length, T is the target sequence length, N is the + batch size, E is the feature number + + Examples: + >>> output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask) + """ + + if src.size(1) != tgt.size(1): + raise RuntimeError("the batch number of src and tgt must be equal") + + if src.size(2) != self.d_model or tgt.size(2) != self.d_model: + raise RuntimeError("the feature number of src and tgt must be equal to d_model") + + memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask) + output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask) + return output + + def generate_square_subsequent_mask(self, sz): + r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf'). + Unmasked positions are filled with float(0.0). + """ + mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) + mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) + return mask + + def _reset_parameters(self): + r"""Initiate parameters in the transformer model.""" + + for p in self.parameters(): + if p.dim() > 1: + xavier_uniform_(p) + + +class TransformerEncoder(Module): + r"""TransformerEncoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the TransformerEncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + norm: the layer normalization component (optional). + + Examples:: + >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8) + >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> out = transformer_encoder(src) + """ + __constants__ = ['norm'] + + def __init__(self, encoder_layer, num_layers, norm=None): + super(TransformerEncoder, self).__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward(self, src, mask=None, src_key_padding_mask=None): + # type: (Tensor, Optional[Tensor], Optional[Tensor]) -> Tensor + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required). + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + see the docs in Transformer class. + """ + output = src + + for i, mod in enumerate(self.layers): + output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class TransformerDecoder(Module): + r"""TransformerDecoder is a stack of N decoder layers + + Args: + decoder_layer: an instance of the TransformerDecoderLayer() class (required). + num_layers: the number of sub-decoder-layers in the decoder (required). + norm: the layer normalization component (optional). 
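For reference, generate_square_subsequent_mask above produces an additive causal mask: 0 on and below the diagonal (positions that may be attended) and -inf above it. For sz=4:

```python
import torch

sz = 4
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])
```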
+ + Examples:: + >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) + >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6) + >>> memory = torch.rand(10, 32, 512) + >>> tgt = torch.rand(20, 32, 512) + >>> out = transformer_decoder(tgt, memory) + """ + __constants__ = ['norm'] + + def __init__(self, decoder_layer, num_layers, norm=None): + super(TransformerDecoder, self).__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward(self, tgt, memory, memory2=None, tgt_mask=None, + memory_mask=None, memory_mask2=None, tgt_key_padding_mask=None, + memory_key_padding_mask=None, memory_key_padding_mask2=None): + # type: (Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tensor + r"""Pass the inputs (and mask) through the decoder layer in turn. + + Args: + tgt: the sequence to the decoder (required). + memory: the sequence from the last layer of the encoder (required). + tgt_mask: the mask for the tgt sequence (optional). + memory_mask: the mask for the memory sequence (optional). + tgt_key_padding_mask: the mask for the tgt keys per batch (optional). + memory_key_padding_mask: the mask for the memory keys per batch (optional). + + Shape: + see the docs in Transformer class. + """ + output = tgt + + for mod in self.layers: + output = mod(output, memory, memory2=memory2, tgt_mask=tgt_mask, + memory_mask=memory_mask, memory_mask2=memory_mask2, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + memory_key_padding_mask2=memory_key_padding_mask2) + + if self.norm is not None: + output = self.norm(output) + + return output + +class TransformerDecoder_DiffKV(Module): + r"""TransformerDecoder is a stack of N decoder layers + + Args: + decoder_layer: an instance of the TransformerDecoderLayer() class (required). + num_layers: the number of sub-decoder-layers in the decoder (required). + norm: the layer normalization component (optional). + + Examples:: + >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) + >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6) + >>> memory = torch.rand(10, 32, 512) + >>> tgt = torch.rand(20, 32, 512) + >>> out = transformer_decoder(tgt, memory) + """ + __constants__ = ['norm'] + + def __init__(self, decoder_layer, num_layers, norm=None): + super(TransformerDecoder_DiffKV, self).__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward(self, tgt, memory, memory2=None, tgt_mask=None, + memory_mask=None, memory_mask2=None, tgt_key_padding_mask=None, + memory_key_padding_mask=None, memory_key_padding_mask2=None): + # type: (Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tensor + r"""Pass the inputs (and mask) through the decoder layer in turn. + + Args: + tgt: the sequence to the decoder (required). + memory: the sequence from the last layer of the encoder (required). + tgt_mask: the mask for the tgt sequence (optional). + memory_mask: the mask for the memory sequence (optional). + tgt_key_padding_mask: the mask for the tgt keys per batch (optional). + memory_key_padding_mask: the mask for the memory keys per batch (optional). + + Shape: + see the docs in Transformer class. 
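+
+        Note: in this variant ``memory``, ``tgt_key_padding_mask`` and ``memory_key_padding_mask``
+        must be indexable per layer (e.g. lists with one entry per decoder layer); layer ``i``
+        cross-attends to ``memory[i]``.
+
+        Example (a minimal sketch with all-False padding masks, assuming the bundled
+        ``TransformerDecoderLayer`` and ``MultiheadAttention`` follow the shapes documented above)::
+
+            >>> layer = TransformerDecoderLayer(d_model=512, nhead=8)
+            >>> decoder = TransformerDecoder_DiffKV(layer, num_layers=2)
+            >>> tgt = torch.rand(20, 32, 512)
+            >>> memories = [torch.rand(10, 32, 512) for _ in range(2)]
+            >>> tgt_pads = [torch.zeros(32, 20, dtype=torch.bool) for _ in range(2)]
+            >>> mem_pads = [torch.zeros(32, 10, dtype=torch.bool) for _ in range(2)]
+            >>> out = decoder(tgt, memories, tgt_key_padding_mask=tgt_pads,
+            ...               memory_key_padding_mask=mem_pads)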
+ """ + output = tgt + cnt = 0 + + for mod in self.layers: + output = mod(output, memory[cnt], memory2=memory2, tgt_mask=tgt_mask, + memory_mask=memory_mask, memory_mask2=memory_mask2, + tgt_key_padding_mask=tgt_key_padding_mask[cnt], + memory_key_padding_mask=memory_key_padding_mask[cnt], + memory_key_padding_mask2=memory_key_padding_mask2) + cnt = cnt + 1 + if self.norm is not None: + output = self.norm(output) + + return output + +class TransformerEncoderLayer(Module): + r"""TransformerEncoderLayer is made up of self-attn and feedforward network. + This standard encoder layer is based on the paper "Attention Is All You Need". + Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, + Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in + Neural Information Processing Systems, pages 6000-6010. Users may modify or implement + in a different way during application. + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + activation: the activation function of intermediate layer, relu or gelu (default=relu). + + Examples:: + >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> out = encoder_layer(src) + """ + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu", debug=False): + super(TransformerEncoderLayer, self).__init__() + self.debug = debug + self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = Linear(d_model, dim_feedforward) + self.dropout = Dropout(dropout) + self.linear2 = Linear(dim_feedforward, d_model) + + self.norm1 = LayerNorm(d_model) + self.norm2 = LayerNorm(d_model) + self.dropout1 = Dropout(dropout) + self.dropout2 = Dropout(dropout) + + self.activation = _get_activation_fn(activation) + + def __setstate__(self, state): + if 'activation' not in state: + state['activation'] = F.relu + super(TransformerEncoderLayer, self).__setstate__(state) + + def forward(self, src, src_mask=None, src_key_padding_mask=None): + # type: (Tensor, Optional[Tensor], Optional[Tensor]) -> Tensor + r"""Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + see the docs in Transformer class. + """ + src2, attn = self.self_attn(src, src, src, attn_mask=src_mask, + key_padding_mask=src_key_padding_mask) + if self.debug: self.attn = attn + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + + return src + + +class TransformerDecoderLayer(Module): + r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network. + This standard decoder layer is based on the paper "Attention Is All You Need". + Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, + Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in + Neural Information Processing Systems, pages 6000-6010. Users may modify or implement + in a different way during application. 
+ + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + activation: the activation function of intermediate layer, relu or gelu (default=relu). + + Examples:: + >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) + >>> memory = torch.rand(10, 32, 512) + >>> tgt = torch.rand(20, 32, 512) + >>> out = decoder_layer(tgt, memory) + """ + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu", self_attn=True, siamese=False, debug=False): + super(TransformerDecoderLayer, self).__init__() + self.has_self_attn, self.siamese = self_attn, siamese + self.debug = debug + if self.has_self_attn: + self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout) + self.norm1 = LayerNorm(d_model) + self.dropout1 = Dropout(dropout) + self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = Linear(d_model, dim_feedforward) + self.dropout = Dropout(dropout) + self.linear2 = Linear(dim_feedforward, d_model) + + self.norm2 = LayerNorm(d_model) + self.norm3 = LayerNorm(d_model) + self.dropout2 = Dropout(dropout) + self.dropout3 = Dropout(dropout) + if self.siamese: + self.multihead_attn2 = MultiheadAttention(d_model, nhead, dropout=dropout) + + self.activation = _get_activation_fn(activation) + + def __setstate__(self, state): + if 'activation' not in state: + state['activation'] = F.relu + super(TransformerDecoderLayer, self).__setstate__(state) + + def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, + tgt_key_padding_mask=None, memory_key_padding_mask=None, + memory2=None, memory_mask2=None, memory_key_padding_mask2=None): + # type: (Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tensor + r"""Pass the inputs (and mask) through the decoder layer. + + Args: + tgt: the sequence to the decoder layer (required). + memory: the sequence from the last layer of the encoder (required). + tgt_mask: the mask for the tgt sequence (optional). + memory_mask: the mask for the memory sequence (optional). + tgt_key_padding_mask: the mask for the tgt keys per batch (optional). + memory_key_padding_mask: the mask for the memory keys per batch (optional). + + Shape: + see the docs in Transformer class. 
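+
+        Note: when the layer was built with ``siamese=True``, a second cross-attention over
+        ``memory2`` (masked by ``memory_mask2`` / ``memory_key_padding_mask2``) is also added
+        to the output.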
+ """ + if self.has_self_attn: + tgt2, attn = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask) + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + if self.debug: self.attn = attn + tgt2, attn2 = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask) + if self.debug: self.attn2 = attn2 + + if self.siamese: + tgt3, attn3 = self.multihead_attn2(tgt, memory2, memory2, attn_mask=memory_mask2, + key_padding_mask=memory_key_padding_mask2) + tgt = tgt + self.dropout2(tgt3) + if self.debug: self.attn3 = attn3 + + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + + return tgt + + +def _get_clones(module, N): + return ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def _get_activation_fn(activation): + if activation == "relu": + return F.relu + elif activation == "gelu": + return F.gelu + + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) + + +class PositionalEncoding(nn.Module): + r"""Inject some information about the relative or absolute position of the tokens + in the sequence. The positional encodings have the same dimension as + the embeddings, so that the two can be summed. Here, we use sine and cosine + functions of different frequencies. + .. math:: + \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model)) + \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model)) + \text{where pos is the word position and i is the embed idx) + Args: + d_model: the embed dim (required). + dropout: the dropout value (default=0.1). + max_len: the max. length of the incoming sequence (default=5000). + Examples: + >>> pos_encoder = PositionalEncoding(d_model) + """ + + def __init__(self, d_model, dropout=0.1, max_len=5000): + super(PositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0).transpose(0, 1) + self.register_buffer('pe', pe) + + def forward(self, x): + r"""Inputs of forward function + Args: + x: the sequence fed to the positional encoder model (required). 
+ Shape: + x: [sequence length, batch size, embed dim] + output: [sequence length, batch size, embed dim] + Examples: + >>> output = pos_encoder(x) + """ + + x = x + self.pe[:x.size(0), :] + return self.dropout(x) + + +if __name__ == '__main__': + transformer_model = Transformer(nhead=16, num_encoder_layers=12) + src = torch.rand((10, 32, 512)) + tgt = torch.rand((20, 32, 512)) + out = transformer_model(src, tgt) + print(out) diff --git a/modules_srn/SRN_Resnet.py b/modules_srn/SRN_Resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..fe6ef4d8a65d5314e7b3b6ecdc2c4db261b218a2 --- /dev/null +++ b/modules_srn/SRN_Resnet.py @@ -0,0 +1,255 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class VGG_FeatureExtractor(nn.Module): + """ FeatureExtractor of CRNN (https://arxiv.org/pdf/1507.05717.pdf) """ + + def __init__(self, input_channel, output_channel=512): + super(VGG_FeatureExtractor, self).__init__() + self.output_channel = [int(output_channel / 8), int(output_channel / 4), + int(output_channel / 2), output_channel] # [64, 128, 256, 512] + self.ConvNet = nn.Sequential( + nn.Conv2d(input_channel, self.output_channel[0], 3, 1, 1), nn.ReLU(True), + nn.MaxPool2d(2, 2), # 64x16x50 + nn.Conv2d(self.output_channel[0], self.output_channel[1], 3, 1, 1), nn.ReLU(True), + nn.MaxPool2d(2, 2), # 128x8x25 + nn.Conv2d(self.output_channel[1], self.output_channel[2], 3, 1, 1), nn.ReLU(True), # 256x8x25 + nn.Conv2d(self.output_channel[2], self.output_channel[2], 3, 1, 1), nn.ReLU(True), + nn.MaxPool2d((2, 1), (2, 1)), # 256x4x25 + nn.Conv2d(self.output_channel[2], self.output_channel[3], 3, 1, 1, bias=False), + nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True), # 512x4x25 + nn.Conv2d(self.output_channel[3], self.output_channel[3], 3, 1, 1, bias=False), + nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True), + nn.MaxPool2d((2, 1), (2, 1)), # 512x2x25 + nn.Conv2d(self.output_channel[3], self.output_channel[3], 2, 1, 0), nn.ReLU(True)) # 512x1x24 + + def forward(self, input): + return self.ConvNet(input) + + +class RCNN_FeatureExtractor(nn.Module): + """ FeatureExtractor of GRCNN (https://papers.nips.cc/paper/6637-gated-recurrent-convolution-neural-network-for-ocr.pdf) """ + + def __init__(self, input_channel, output_channel=512): + super(RCNN_FeatureExtractor, self).__init__() + self.output_channel = [int(output_channel / 8), int(output_channel / 4), + int(output_channel / 2), output_channel] # [64, 128, 256, 512] + self.ConvNet = nn.Sequential( + nn.Conv2d(input_channel, self.output_channel[0], 3, 1, 1), nn.ReLU(True), + nn.MaxPool2d(2, 2), # 64 x 16 x 50 + GRCL(self.output_channel[0], self.output_channel[0], num_iteration=5, kernel_size=3, pad=1), + nn.MaxPool2d(2, 2), # 64 x 8 x 25 + GRCL(self.output_channel[0], self.output_channel[1], num_iteration=5, kernel_size=3, pad=1), + nn.MaxPool2d(2, (2, 1), (0, 1)), # 128 x 4 x 26 + GRCL(self.output_channel[1], self.output_channel[2], num_iteration=5, kernel_size=3, pad=1), + nn.MaxPool2d(2, (2, 1), (0, 1)), # 256 x 2 x 27 + nn.Conv2d(self.output_channel[2], self.output_channel[3], 2, 1, 0, bias=False), + nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True)) # 512 x 1 x 26 + + def forward(self, input): + return self.ConvNet(input) + + +class ResNet_FeatureExtractor(nn.Module): + """ FeatureExtractor of FAN (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf) """ + + def __init__(self, input_channel, output_channel=512): + super(ResNet_FeatureExtractor, 
self).__init__() + self.ConvNet = ResNet(input_channel, output_channel, BasicBlock, [1, 2, 5, 3]) + + def forward(self, input): + return self.ConvNet(input) + + +# For Gated RCNN +class GRCL(nn.Module): + + def __init__(self, input_channel, output_channel, num_iteration, kernel_size, pad): + super(GRCL, self).__init__() + self.wgf_u = nn.Conv2d(input_channel, output_channel, 1, 1, 0, bias=False) + self.wgr_x = nn.Conv2d(output_channel, output_channel, 1, 1, 0, bias=False) + self.wf_u = nn.Conv2d(input_channel, output_channel, kernel_size, 1, pad, bias=False) + self.wr_x = nn.Conv2d(output_channel, output_channel, kernel_size, 1, pad, bias=False) + + self.BN_x_init = nn.BatchNorm2d(output_channel) + + self.num_iteration = num_iteration + self.GRCL = [GRCL_unit(output_channel) for _ in range(num_iteration)] + self.GRCL = nn.Sequential(*self.GRCL) + + def forward(self, input): + """ The input of GRCL is consistant over time t, which is denoted by u(0) + thus wgf_u / wf_u is also consistant over time t. + """ + wgf_u = self.wgf_u(input) + wf_u = self.wf_u(input) + x = F.relu(self.BN_x_init(wf_u)) + + for i in range(self.num_iteration): + x = self.GRCL[i](wgf_u, self.wgr_x(x), wf_u, self.wr_x(x)) + + return x + + +class GRCL_unit(nn.Module): + + def __init__(self, output_channel): + super(GRCL_unit, self).__init__() + self.BN_gfu = nn.BatchNorm2d(output_channel) + self.BN_grx = nn.BatchNorm2d(output_channel) + self.BN_fu = nn.BatchNorm2d(output_channel) + self.BN_rx = nn.BatchNorm2d(output_channel) + self.BN_Gx = nn.BatchNorm2d(output_channel) + + def forward(self, wgf_u, wgr_x, wf_u, wr_x): + G_first_term = self.BN_gfu(wgf_u) + G_second_term = self.BN_grx(wgr_x) + G = F.sigmoid(G_first_term + G_second_term) + + x_first_term = self.BN_fu(wf_u) + x_second_term = self.BN_Gx(self.BN_rx(wr_x) * G) + x = F.relu(x_first_term + x_second_term) + + return x + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = self._conv3x3(inplanes, planes) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = self._conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def _conv3x3(self, in_planes, out_planes, stride=1): + "3x3 convolution with padding" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + out += residual + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, input_channel, output_channel, block, layers): + super(ResNet, self).__init__() + + self.output_channel_block = [int(output_channel / 4), int(output_channel / 2), output_channel, output_channel] + + self.inplanes = int(output_channel / 8) + self.conv0_1 = nn.Conv2d(input_channel, int(output_channel / 16), + kernel_size=3, stride=1, padding=1, bias=False) + self.bn0_1 = nn.BatchNorm2d(int(output_channel / 16)) + self.conv0_2 = nn.Conv2d(int(output_channel / 16), self.inplanes, + kernel_size=3, stride=1, padding=1, bias=False) + self.bn0_2 = nn.BatchNorm2d(self.inplanes) + self.relu = nn.ReLU(inplace=True) + + self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) + self.layer1 = self._make_layer(block, 
self.output_channel_block[0], layers[0]) + self.conv1 = nn.Conv2d(self.output_channel_block[0], self.output_channel_block[ + 0], kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(self.output_channel_block[0]) + + self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) + self.layer2 = self._make_layer(block, self.output_channel_block[1], layers[1], stride=1) + self.conv2 = nn.Conv2d(self.output_channel_block[1], self.output_channel_block[ + 1], kernel_size=3, stride=1, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(self.output_channel_block[1]) + + self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=(2, 1), padding=(0, 1)) + self.layer3 = self._make_layer(block, self.output_channel_block[2], layers[2], stride=1) + self.conv3 = nn.Conv2d(self.output_channel_block[2], self.output_channel_block[ + 2], kernel_size=3, stride=1, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(self.output_channel_block[2]) + + self.layer4 = self._make_layer(block, self.output_channel_block[3], layers[3], stride=1) + self.conv4_1 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[ + 3], kernel_size=2, stride=(2, 1), padding=(0, 1), bias=False) + self.bn4_1 = nn.BatchNorm2d(self.output_channel_block[3]) + self.conv4_2 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[ + 3], kernel_size=2, stride=1, padding=0, bias=False) + self.bn4_2 = nn.BatchNorm2d(self.output_channel_block[3]) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv0_1(x) + x = self.bn0_1(x) + x = self.relu(x) + x = self.conv0_2(x) + x = self.bn0_2(x) + x = self.relu(x) + + x = self.maxpool1(x) + x = self.layer1(x) + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + + x = self.maxpool2(x) + x = self.layer2(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + + x = self.maxpool3(x) + x = self.layer3(x) + x = self.conv3(x) + x = self.bn3(x) + x = self.relu(x) + + x = self.layer4(x) + x = self.conv4_1(x) + x = self.bn4_1(x) + x = self.relu(x) + x = self.conv4_2(x) + x = self.bn4_2(x) + x = self.relu(x) + + return x + +if __name__=='__main__': + x = torch.rand(4, 1, 32, 100) + # x = x.cuda() + model = ResNet_FeatureExtractor(1, 5112) + # model = model.cuda() + y = model(x) + print(y.shape) \ No newline at end of file diff --git a/modules_srn/SRN_modules.py b/modules_srn/SRN_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..a1754086a5c2a2ca1314c5e4cd5cd826c83849de --- /dev/null +++ b/modules_srn/SRN_modules.py @@ -0,0 +1,518 @@ +# coding:utf-8 +# chenjun +# date:2020-04-18 +import torch.nn as nn +import torch +import torch.nn.functional as F +import numpy as np + + +# def get_non_pad_mask(seq, PAD): +# assert seq.dim() == 2 +# return seq.ne(PAD).type(torch.float).unsqueeze(-1) + +def get_pad_mask(seq, pad_idx): + return (seq == pad_idx).unsqueeze(-2) + + +def get_subsequent_mask(seq): + ''' For masking out the subsequent info. 
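+    Returns a ``(batch, len_s, len_s)`` uint8 mask whose entries above the diagonal are 1,
+    so that, once the mask is applied, position ``i`` can only attend to positions ``<= i``.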
''' + + sz_b, len_s = seq.size() + subsequent_mask = torch.triu( + torch.ones((len_s, len_s), device=seq.device, dtype=torch.uint8), diagonal=1) # 返回上三角矩阵 + subsequent_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1) # b x ls x ls + + return subsequent_mask + + +def get_attn_key_pad_mask(seq_k, seq_q, PAD): + ''' For masking out the padding part of key sequence. + seq_k:src_seq + seq_q:tgt_seq + ''' + + # Expand to fit the shape of key query attention matrix. + len_q = seq_q.size(1) # 目标序列 + padding_mask = seq_k.eq(PAD) # 源序列 + padding_mask = padding_mask.unsqueeze(1).expand(-1, len_q, -1) # b x lq x lk + + return padding_mask + + +class PositionalEncoding(nn.Module): + + def __init__(self, d_hid, n_position=200): + super(PositionalEncoding, self).__init__() + + # Not a parameter + self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid)) + + def _get_sinusoid_encoding_table(self, n_position, d_hid): + ''' Sinusoid position encoding table ''' + # TODO: make it with torch instead of numpy + + def get_position_angle_vec(position): + return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] + + sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) + sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i + sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 + + return torch.FloatTensor(sinusoid_table).unsqueeze(0) + + def forward(self, x): + return x + self.pos_table[:, :x.size(1)].clone().detach() + + +class ScaledDotProductAttention(nn.Module): + ''' Scaled Dot-Product Attention ''' + + def __init__(self, temperature, attn_dropout=0.1): + super(ScaledDotProductAttention, self).__init__() + self.temperature = temperature + self.dropout = nn.Dropout(attn_dropout) + self.softmax = nn.Softmax(dim=2) + + def forward(self, q, k, v, mask=None): + + attn = torch.bmm(q, k.transpose(1, 2)) + attn = attn / self.temperature + + if mask is not None: + # print(mask.shape, attn.shape, v.shape) + attn = attn.masked_fill(mask, -1e9) + + attn = self.softmax(attn) # 第3个维度为权重 + attn = self.dropout(attn) + output = torch.bmm(attn, v) + + return output, attn + + +class MultiHeadAttention(nn.Module): + ''' Multi-Head Attention module ''' + + def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): + super(MultiHeadAttention, self).__init__() + + self.n_head = n_head + self.d_k = d_k + self.d_v = d_v + + self.w_qs = nn.Linear(d_model, n_head * d_k) + self.w_ks = nn.Linear(d_model, n_head * d_k) + self.w_vs = nn.Linear(d_model, n_head * d_v) + nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) + nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) + nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v))) + + self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5)) + self.layer_norm = nn.LayerNorm(d_model) + + self.fc = nn.Linear(n_head * d_v, d_model) + nn.init.xavier_normal_(self.fc.weight) + + self.dropout = nn.Dropout(dropout) + + + def forward(self, q, k, v, mask=None): + + d_k, d_v, n_head = self.d_k, self.d_v, self.n_head + + sz_b, len_q, _ = q.size() + sz_b, len_k, _ = k.size() + sz_b, len_v, _ = v.size() + + residual = q + + q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) # 4*21*512 ---- 4*21*8*64 + k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) + v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) + + q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk + k = 
k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk + v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv + + mask = mask.repeat(n_head, 1, 1) if mask is not None else None # (n*b) x .. x .. + output, attn = self.attention(q, k, v, mask=mask) + + output = output.view(n_head, sz_b, len_q, d_v) + output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv) + + output = self.dropout(self.fc(output)) + output = self.layer_norm(output + residual) + + return output, attn + +class PositionwiseFeedForward(nn.Module): + ''' A two-feed-forward-layer module ''' + + def __init__(self, d_in, d_hid, dropout=0.1): + super(PositionwiseFeedForward, self).__init__() + self.w_1 = nn.Conv1d(d_in, d_hid, 1) # position-wise + self.w_2 = nn.Conv1d(d_hid, d_in, 1) # position-wise + self.layer_norm = nn.LayerNorm(d_in) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + residual = x + output = x.transpose(1, 2) + output = self.w_2(F.relu(self.w_1(output))) + output = output.transpose(1, 2) + output = self.dropout(output) + output = self.layer_norm(output + residual) + return output + + +class EncoderLayer(nn.Module): + ''' Compose with two layers ''' + + def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): + super(EncoderLayer, self).__init__() + self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) + self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout) + + def forward(self, enc_input, slf_attn_mask=None): + enc_output, enc_slf_attn = self.slf_attn( + enc_input, enc_input, enc_input, mask=slf_attn_mask) + enc_output = self.pos_ffn(enc_output) + return enc_output, enc_slf_attn + + +class Torch_transformer_encoder(nn.Module): + ''' + use pytorch transformer for sequence learning + + ''' + def __init__(self, d_word_vec=512, n_layers=2, n_head=8, d_model=512, dim_feedforward=1024, n_position=256): + super(Torch_transformer_encoder, self).__init__() + + self.position_enc = PositionalEncoding(d_word_vec, n_position=n_position) + encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_head, dim_feedforward=dim_feedforward) + self.layer_norm = nn.LayerNorm(d_model) + self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers, norm=self.layer_norm) + self.dropout = nn.Dropout(p=0.1) + + def forward(self, cnn_feature, src_mask=None, return_attns=False): + enc_slf_attn_list = [] + + # -- Forward + enc_output = self.dropout(self.position_enc(cnn_feature)) # position embeding + + enc_output = self.encoder(enc_output) + + enc_output = self.layer_norm(enc_output) + + if return_attns: + return enc_output, enc_slf_attn_list + return enc_output, + + + +class Transforme_Encoder(nn.Module): + ''' to capture the global spatial dependencies''' + ''' + d_word_vec: 位置编码,特征空间维度 + n_layers: transformer的层数 + n_head:多头数量 + d_k: 64 + d_v: 64 + d_model: 512, + d_inner: 1024 + n_position: 位置编码的最大值 + ''' + def __init__( + self, d_word_vec=512, n_layers=2, n_head=8, d_k=64, d_v=64, + d_model=512, d_inner=1024, dropout=0.1, n_position=256): + + super().__init__() + + self.position_enc = PositionalEncoding(d_word_vec, n_position=n_position) + self.dropout = nn.Dropout(p=dropout) + self.layer_stack = nn.ModuleList([ + EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout) + for _ in range(n_layers)]) + self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) + + def forward(self, cnn_feature, src_mask, return_attns=False): + + enc_slf_attn_list = [] + + # -- Forward + 
enc_output = self.dropout(self.position_enc(cnn_feature)) # position embeding + + for enc_layer in self.layer_stack: + enc_output, enc_slf_attn = enc_layer(enc_output, slf_attn_mask=src_mask) + enc_slf_attn_list += [enc_slf_attn] if return_attns else [] + + enc_output = self.layer_norm(enc_output) + + if return_attns: + return enc_output, enc_slf_attn_list + return enc_output, + + +class PVAM(nn.Module): + ''' Parallel Visual attention module 平行解码''' + ''' + n_dim:512,阅读顺序序列编码的空间维度 + N_max_character: 25,单张图片最多有多少个字符 + n_position: cnn出来之后特征的序列长度 + ''' + def __init__(self, n_dim=512, N_max_character=25, n_position=256): + + super(PVAM, self).__init__() + self.character_len = N_max_character + + self.f0_embedding = nn.Embedding(N_max_character, n_dim) + + self.w0 = nn.Linear(N_max_character, n_position) + self.wv = nn.Linear(n_dim, n_dim) + # first linear(512,25) + self.we = nn.Linear(n_dim, N_max_character) + + self.active = nn.Tanh() + self.softmax = nn.Softmax(dim=2) + + def forward(self, enc_output): + reading_order = torch.arange(self.character_len, dtype=torch.long, device=enc_output.device) + reading_order = reading_order.unsqueeze(0).expand(enc_output.size(0), -1) # (S,) -> (B, S) + reading_order_embed = self.f0_embedding(reading_order) # b,25,512 + + t = self.w0(reading_order_embed.permute(0,2,1)) # b,512,256 + t = self.active(t.permute(0,2,1) + self.wv(enc_output)) # b,256,512 + # first linear(512,25) + attn = self.we(t) # b,256,25 + + attn = self.softmax(attn.permute(0,2,1)) # b,25,256 + + g_output = torch.bmm(attn, enc_output) # b,25,512 + return g_output + + +class GSRM(nn.Module): + # global semantic reasoning module + ''' + n_dim:embed编码的特征空间维度 + n_class:embedding需要用到 + PAD:计算mask用到 + ''' + def __init__(self, n_dim=512, n_class=37, PAD=37-1, n_layers=4, n_position=25): + + super(GSRM, self).__init__() + + self.PAD = PAD + self.argmax_embed = nn.Embedding(n_class, n_dim) + + self.transformer_units = Transforme_Encoder(n_layers=n_layers, n_position=n_position) # for global context information + # self.transformer_units = Torch_transformer_encoder(n_layers=n_layers, n_position=n_position) + + def forward(self, e_out): + ''' + e_out: b,25,37 | the output from PVAM3 + ''' + e_argmax = e_out.argmax(dim=-1) # b, 25 + e = self.argmax_embed(e_argmax) # b,25,512 + + e_mask = get_pad_mask(e_argmax, self.PAD) # b,25,1 + s = self.transformer_units(e, None) # b,25,512 + + return s + + +class SRN_Decoder(nn.Module): + # the wrapper of decoder layers + ''' + n_dim: 特征空间维度 + n_class:字符种类 + N_max_character: 单张图最多只25个字符 + n_position:cnn输出的特征序列长度 + 整个有三个部分的输出 + ''' + def __init__(self, n_dim=512, n_class=37, N_max_character=25, n_position=256, GSRM_layer=4 ): + + super(SRN_Decoder, self).__init__() + + self.pvam = PVAM(N_max_character=N_max_character, n_position=n_position) + self.w_e = nn.Linear(n_dim, n_class) # output layer + + self.GSRM = GSRM(n_class=n_class, PAD=n_class-1, n_dim=n_dim, n_position=N_max_character, n_layers=GSRM_layer) + self.w_s = nn.Linear(n_dim, n_class) # output layer + + self.w_f = nn.Linear(n_dim, n_class) # output layer + + def forward(self, cnn_feature ): + '''cnn_feature: b,256,512 | the output from cnn''' + + g_output = self.pvam(cnn_feature) # b,25,512 + e_out = self.w_e(g_output) # b,25,37 ----> cross entropy loss | 第一个输出 + + s = self.GSRM(e_out)[0] # b,25,512 + s_out = self.w_s(s) # b,25,37f + + # TODO:change the add to gated unit + f = g_output + s # b,25,512 + f_out = self.w_f(f) + + return e_out, s_out, f_out + + +def cal_performance(preds, gold, mask=None, 
smoothing='1'): + ''' Apply label smoothing if needed ''' + + loss = 0. + n_correct = 0 + weights = [1.0, 0.15, 2.0] + for ori_pred, weight in zip(preds, weights): + pred = ori_pred.view(-1, ori_pred.shape[-1]) + # debug show + t_gold = gold.view(ori_pred.shape[0], -1) + t_pred_index = ori_pred.max(2)[1] + + mask = mask.view(-1) + non_pad_mask = mask.ne(0) if mask is not None else None + tloss = cal_loss(pred, gold, non_pad_mask, smoothing) + if torch.isnan(tloss): + print('have nan loss') + continue + else: + loss += tloss * weight + + pred = pred.max(1)[1] + gold = gold.contiguous().view(-1) + n_correct = pred.eq(gold) + n_correct = n_correct.masked_select(non_pad_mask).sum().item() if mask is not None else None + + return loss, n_correct + + +def cal_loss(pred, gold, mask, smoothing): + ''' Calculate cross entropy loss, apply label smoothing if needed. ''' + + gold = gold.contiguous().view(-1) + + if smoothing=='0': + eps = 0.1 + n_class = pred.size(1) + + one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1) + one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1) + log_prb = F.log_softmax(pred, dim=1) + + non_pad_mask = gold.ne(0) + loss = -(one_hot * log_prb).sum(dim=1) + loss = loss.masked_select(non_pad_mask).sum() # average later + elif smoothing == '1': + if mask is not None: + loss = F.cross_entropy(pred, gold, reduction='none') + loss = loss.masked_select(mask) + loss = loss.sum() / mask.sum() + else: + loss = F.cross_entropy(pred, gold) + else: + # loss = F.cross_entropy(pred, gold, ignore_index=PAD) + loss = F.cross_entropy(pred, gold) + + return loss + + +def cal_performance2(preds, gold, PAD, smoothing='1'): + ''' Apply label smoothing if needed ''' + + loss = 0. + n_correct = 0 + weights = [1.0, 0.15, 2.0] + for ori_pred, weight in zip(preds, weights): + pred = ori_pred.view(-1, ori_pred.shape[-1]) + # debug show + t_gold = gold.view(ori_pred.shape[0], -1) + t_pred_index = ori_pred.max(2)[1] + + tloss = cal_loss2(pred, gold, PAD, smoothing=smoothing) + if torch.isnan(tloss): + print('have nan loss') + continue + else: + loss += tloss * weight + + pred = pred.max(1)[1] + gold = gold.contiguous().view(-1) + n_correct = pred.eq(gold) + non_pad_mask = gold.ne(PAD) + n_correct = n_correct.masked_select(non_pad_mask).sum().item() + + return loss, n_correct + + +def cal_loss2(pred, gold, PAD, smoothing='1'): + ''' Calculate cross entropy loss, apply label smoothing if needed. 
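+    smoothing='0' applies label smoothing with eps=0.1 (ignoring positions where the target
+    index is 0); smoothing='1' uses standard cross entropy with ignore_index=PAD; any other
+    value falls back to plain cross entropy.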
''' + + gold = gold.contiguous().view(-1) + + if smoothing=='0': + eps = 0.1 + n_class = pred.size(1) + + one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1) + one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1) + log_prb = F.log_softmax(pred, dim=1) + + non_pad_mask = gold.ne(0) + loss = -(one_hot * log_prb).sum(dim=1) + loss = loss.masked_select(non_pad_mask).sum() # average later + elif smoothing == '1': + loss = F.cross_entropy(pred, gold, ignore_index=PAD) + else: + # loss = F.cross_entropy(pred, gold, ignore_index=PAD) + loss = F.cross_entropy(pred, gold) + + return loss + + +if __name__=='__main__': + cnn_feature = torch.rand((2,256,512)) + model1 = Transforme_Encoder() + image = model1(cnn_feature,src_mask=None)[0] + model = SRN_Decoder(N_max_character=30) + + outs = model(image) + for out in outs: + print(out.shape) + + # image = torch.rand((4,3,32,60)) + # tgt_seq = torch.tensor([[ 2, 24, 2176, 882, 2480, 612, 1525, 480, 875, 147, 1700, 715, + # 1465, 3], + # [ 2, 369, 1781, 882, 703, 879, 2855, 2415, 502, 1154, 833, 1465, + # 3, 0], + # [ 2, 2943, 334, 328, 480, 330, 1644, 1449, 163, 147, 1823, 1184, + # 1465, 3], + # [ 2, 24, 396, 480, 703, 1646, 897, 1711, 1508, 703, 2321, 147, + # 642, 1465]], device='cuda:0') + # tgt_pos = torch.tensor([[ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], + # [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 0], + # [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], + # [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]], + # device='cuda:0') + # src_seq = torch.tensor([[ 2, 598, 2088, 822, 2802, 1156, 157, 1099, 1000, 598, 1707, 1345, + # 3, 0, 0, 0], + # [ 2, 598, 2348, 822, 598, 1222, 471, 948, 986, 423, 1345, 3, + # 0, 0, 0, 0], + # [ 2, 2437, 2470, 901, 2473, 598, 1735, 84, 1, 2277, 1979, 499, + # 962, 1345, 3, 0], + # [ 2, 598, 186, 1904, 598, 868, 1339, 1604, 84, 598, 608, 1728, + # 1345, 3, 0, 0]], device='cuda:0') + + # device = torch.device('cuda') + # image = image.cuda() + # transformer = Transformer() + # transformer = transformer.to(device) + # transformer.train() + # out = transformer(image, tgt_seq, tgt_pos, src_seq) + + # gold = tgt_seq[:, 1:] # 从第二列开始 + + # # backward + # loss, n_correct = cal_performance(out, gold, smoothing=True) + # print(loss, n_correct) + # a = 1 \ No newline at end of file diff --git a/modules_srn/__init__.py b/modules_srn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/modules_srn/bert.py b/modules_srn/bert.py new file mode 100644 index 0000000000000000000000000000000000000000..adf6d0cbea55ecc2d5bf7d07277c61d164138268 --- /dev/null +++ b/modules_srn/bert.py @@ -0,0 +1,284 @@ +# Copyright 2018 Dong-Hyun Lee, Kakao Brain. 
+# (Strongly inspired by original Google BERT code and Hugging Face's code) + +""" Transformer Model Classes & Config Class """ + +import math +import json +from typing import NamedTuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def split_last(x, shape): + "split the last dimension to given shape" + shape = list(shape) + assert shape.count(-1) <= 1 + if -1 in shape: + shape[shape.index(-1)] = int(x.size(-1) / -np.prod(shape)) + return x.view(*x.size()[:-1], *shape) + + +def merge_last(x, n_dims): + "merge the last n_dims to a dimension" + s = x.size() + assert n_dims > 1 and n_dims < len(s) + return x.view(*s[:-n_dims], -1) + + +def gelu(x): + "Implementation of the gelu activation function by Hugging Face" + return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + + +class LayerNorm(nn.Module): + "A layernorm module in the TF style (epsilon inside the square root)." + def __init__(self, cfg, variance_epsilon=1e-12): + super().__init__() + self.gamma = nn.Parameter(torch.ones(cfg.dim)) + self.beta = nn.Parameter(torch.zeros(cfg.dim)) + self.variance_epsilon = variance_epsilon + + def forward(self, x): + u = x.mean(-1, keepdim=True) + s = (x - u).pow(2).mean(-1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.variance_epsilon) + return self.gamma * x + self.beta + + +class Embeddings(nn.Module): + "The embedding module from word, position and token_type embeddings." + def __init__(self, cfg): + super().__init__() + self.pos_embed = nn.Embedding(cfg.p_dim, cfg.dim) # position embedding + self.norm = LayerNorm(cfg) + self.drop = nn.Dropout(cfg.p_drop_hidden) + + def forward(self, x): + seq_len = x.size(1) + pos = torch.arange(seq_len, dtype=torch.long, device=x.device) + pos = pos.unsqueeze(0).expand(x.size(0), -1) # (S,) -> (B, S) + + e = x + self.pos_embed(pos) + return self.drop(self.norm(e)) + + +class MultiHeadedSelfAttention(nn.Module): + """ Multi-Headed Dot Product Attention """ + def __init__(self, cfg): + super().__init__() + self.proj_q = nn.Linear(cfg.dim, cfg.dim) + self.proj_k = nn.Linear(cfg.dim, cfg.dim) + self.proj_v = nn.Linear(cfg.dim, cfg.dim) + self.drop = nn.Dropout(cfg.p_drop_attn) + self.scores = None # for visualization + self.n_heads = cfg.n_heads + + def forward(self, x, mask): + """ + x, q(query), k(key), v(value) : (B(batch_size), S(seq_len), D(dim)) + mask : (B(batch_size) x S(seq_len)) + * split D(dim) into (H(n_heads), W(width of head)) ; D = H * W + """ + # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W) + q, k, v = self.proj_q(x), self.proj_k(x), self.proj_v(x) + q, k, v = (split_last(x, (self.n_heads, -1)).transpose(1, 2) + for x in [q, k, v]) + # (B, H, S, W) @ (B, H, W, S) -> (B, H, S, S) -softmax-> (B, H, S, S) + scores = q @ k.transpose(-2, -1) / np.sqrt(k.size(-1)) + if mask is not None: + mask = mask[:, None, None, :].float() + scores -= 10000.0 * (1.0 - mask) + scores = self.drop(F.softmax(scores, dim=-1)) + # (B, H, S, S) @ (B, H, S, W) -> (B, H, S, W) -trans-> (B, S, H, W) + h = (scores @ v).transpose(1, 2).contiguous() + # -merge-> (B, S, D) + h = merge_last(h, 2) + self.scores = scores + return h + + +class PositionWiseFeedForward(nn.Module): + """ FeedForward Neural Networks for each position """ + def __init__(self, cfg): + super().__init__() + self.fc1 = nn.Linear(cfg.dim, cfg.dim_ff) + self.fc2 = nn.Linear(cfg.dim_ff, cfg.dim) + #self.activ = lambda x: activ_fn(cfg.activ_fn, x) + + def forward(self, x): + # (B, S, D) -> (B, S, D_ff) -> (B, S, D) + return 
self.fc2(gelu(self.fc1(x))) + + +class Block(nn.Module): + """ Transformer Block """ + def __init__(self, cfg): + super().__init__() + self.attn = MultiHeadedSelfAttention(cfg) + self.proj = nn.Linear(cfg.dim, cfg.dim) + self.norm1 = LayerNorm(cfg) + self.pwff = PositionWiseFeedForward(cfg) + self.norm2 = LayerNorm(cfg) + self.drop = nn.Dropout(cfg.p_drop_hidden) + + def forward(self, x, mask): + h = self.attn(x, mask) + h = self.norm1(x + self.drop(self.proj(h))) + h = self.norm2(h + self.drop(self.pwff(h))) + return h + + +class Transformer(nn.Module): + """ Transformer with Self-Attentive Blocks""" + def __init__(self, cfg, n_layers): + super().__init__() + self.embed = Embeddings(cfg) + self.blocks = nn.ModuleList([Block(cfg) for _ in range(n_layers)]) + + def forward(self, x, mask): + h = self.embed(x) + for block in self.blocks: + h = block(h, mask) + return h + + +class Parallel_Attention(nn.Module): + ''' the Parallel Attention Module for 2D attention + reference the origin paper: https://arxiv.org/abs/1906.05708 + ''' + def __init__(self, cfg): + super().__init__() + self.atten_w1 = nn.Linear(cfg.dim_c, cfg.dim_c) + self.atten_w2 = nn.Linear(cfg.dim_c, cfg.max_vocab_size) + self.activ_fn = nn.Tanh() + self.soft = nn.Softmax(dim=1) + self.drop = nn.Dropout(0.1) + + def forward(self, origin_I, bert_out, mask=None): + bert_out = self.activ_fn(self.drop(self.atten_w1(bert_out))) + atten_w = self.soft(self.atten_w2(bert_out)) # b*200*94 + x = torch.bmm(origin_I.transpose(1,2), atten_w) # b*512*94 + return x + + +class MultiHeadAttention(nn.Module): + ''' Multi-Head Attention module ''' + + def __init__(self, n_head=8, d_k=64, d_model=128, max_vocab_size=94, dropout=0.1): + ''' d_k: the attention dim + d_model: the encoder output feature + max_vocab_size: the output maxium length of sequence + ''' + super(MultiHeadAttention, self).__init__() + + self.n_head, self.d_k = n_head, d_k + self.temperature = np.power(d_k, 0.5) + self.max_vocab_size = max_vocab_size + + self.w_encoder = nn.Linear(d_model, n_head * d_k) + self.w_atten = nn.Linear(d_model, n_head * max_vocab_size) + self.w_out = nn.Linear(n_head * d_k, d_model) + self.activ_fn = nn.Tanh() + + self.softmax = nn.Softmax(dim=1) # at the d_in dimension + self.dropout = nn.Dropout(dropout) + + nn.init.normal_(self.w_encoder.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) + nn.init.normal_(self.w_atten.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) + nn.init.xavier_normal_(self.w_out.weight) + + + def forward(self, encoder_feature, bert_out, mask=None): + + d_k, n_head, max_vocab_size = self.d_k, self.n_head, self.max_vocab_size + + sz_b, d_in, _ = encoder_feature.size() + + # 原始特征 + encoder_feature = encoder_feature.view(sz_b, d_in, n_head, d_k) + encoder_feature = encoder_feature.permute(2, 0, 1, 3).contiguous().view(-1, d_in, d_k) # 32*200*64 + + # 求解权值 + alpha = self.activ_fn(self.dropout(self.w_encoder(bert_out))) + alpha = self.w_atten(alpha).view(sz_b, d_in, n_head, max_vocab_size) # 4*200*8*94 + alpha = alpha.permute(2, 0, 1, 3).contiguous().view(-1, d_in, max_vocab_size) # 32*200*94 + alpha = alpha / self.temperature + alpha = self.dropout(self.softmax(alpha)) # 32*200*94 + + # 输出部分 + output = torch.bmm(encoder_feature.transpose(1,2), alpha) # 32*64*94 + output = output.view(n_head, sz_b, d_k, max_vocab_size) + output = output.permute(1, 3, 0, 2).contiguous().view(sz_b, max_vocab_size, -1) # 4*94*512 + output = self.dropout(self.w_out(output)) + output = output.transpose(1,2) + + return output + + +class 
Two_Stage_Decoder(nn.Module): + def __init__(self, cfg): + super().__init__() + self.out_w = nn.Linear(cfg.dim_c, cfg.len_alphabet) + self.relation_attention = Transformer(cfg, cfg.decoder_atten_layers) + self.out_w1 = nn.Linear(cfg.dim_c, cfg.len_alphabet) + + def forward(self, x): + x1 = self.out_w(x) + x2 = self.relation_attention(x, mask=None) + x2 = self.out_w1(x2) # 两个分支的输出部分采用不同的网络 + + return x1, x2 + + +class Bert_Ocr(nn.Module): + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + self.transformer = Transformer(cfg, cfg.attention_layers) + self.attention = Parallel_Attention(cfg) +# self.attention = MultiHeadAttention(d_model=cfg.dim, max_vocab_size=cfg.max_vocab_size) + self.decoder = Two_Stage_Decoder(cfg) + + def forward(self, encoder_feature, mask=None): + bert_out = self.transformer(encoder_feature, mask) # 做一个self_attention//4*200*512 + glimpses = self.attention(encoder_feature, bert_out, mask) # 原始序列和目标序列的转化//4*512*94 + res = self.decoder(glimpses.transpose(1,2)) + return res + + +class Config(object): + '''参数设置''' + """ Relation Attention Module """ + p_drop_attn = 0.1 + p_drop_hidden = 0.1 + dim = 512 # the encode output feature + attention_layers = 2 # the layers of transformer + n_heads = 8 + dim_ff = 1024 * 2 # 位置前向传播的隐含层维度 + + ''' Parallel Attention Module ''' + dim_c = dim + max_vocab_size = 26 # 一张图片含有字符的最大长度 + + """ Two-stage Decoder """ + len_alphabet = 39 # 字符类别数量 + decoder_atten_layers = 2 + + +def numel(model): + return sum(p.numel() for p in model.parameters()) + + +if __name__ == '__main__': + + cfg = Config() + mask = None + x = torch.randn(4, 200, cfg.dim) + net = Bert_Ocr(cfg) + res1, res2 = net(x, mask) + print(res1.shape, res2.shape) + print('参数总量为:', numel(net)) diff --git a/modules_srn/feature_extraction.py b/modules_srn/feature_extraction.py new file mode 100644 index 0000000000000000000000000000000000000000..eb0d22afc3f5c91713e85d22e3ca6b37a2061a9d --- /dev/null +++ b/modules_srn/feature_extraction.py @@ -0,0 +1,248 @@ +# coding:utf-8 +# 2020-05-11 +import torch.nn as nn +import torch.nn.functional as F + + +class VGG_FeatureExtractor(nn.Module): + """ FeatureExtractor of CRNN (https://arxiv.org/pdf/1507.05717.pdf) """ + + def __init__(self, input_channel, output_channel=512): + super(VGG_FeatureExtractor, self).__init__() + self.output_channel = [int(output_channel / 8), int(output_channel / 4), + int(output_channel / 2), output_channel] # [64, 128, 256, 512] + self.ConvNet = nn.Sequential( + nn.Conv2d(input_channel, self.output_channel[0], 3, 1, 1), nn.ReLU(True), + nn.MaxPool2d(2, 2), # 64x16x50 + nn.Conv2d(self.output_channel[0], self.output_channel[1], 3, 1, 1), nn.ReLU(True), + nn.MaxPool2d(2, 2), # 128x8x25 + nn.Conv2d(self.output_channel[1], self.output_channel[2], 3, 1, 1), nn.ReLU(True), # 256x8x25 + nn.Conv2d(self.output_channel[2], self.output_channel[2], 3, 1, 1), nn.ReLU(True), + nn.MaxPool2d((2, 1), (2, 1)), # 256x4x25 + nn.Conv2d(self.output_channel[2], self.output_channel[3], 3, 1, 1, bias=False), + nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True), # 512x4x25 + nn.Conv2d(self.output_channel[3], self.output_channel[3], 3, 1, 1, bias=False), + nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True), + nn.MaxPool2d((2, 1), (2, 1)), # 512x2x25 + nn.Conv2d(self.output_channel[3], self.output_channel[3], 2, 1, 0), nn.ReLU(True)) # 512x1x24 + + def forward(self, input): + return self.ConvNet(input) + + +class RCNN_FeatureExtractor(nn.Module): + """ FeatureExtractor of GRCNN 
(https://papers.nips.cc/paper/6637-gated-recurrent-convolution-neural-network-for-ocr.pdf) """ + + def __init__(self, input_channel, output_channel=512): + super(RCNN_FeatureExtractor, self).__init__() + self.output_channel = [int(output_channel / 8), int(output_channel / 4), + int(output_channel / 2), output_channel] # [64, 128, 256, 512] + self.ConvNet = nn.Sequential( + nn.Conv2d(input_channel, self.output_channel[0], 3, 1, 1), nn.ReLU(True), + nn.MaxPool2d(2, 2), # 64 x 16 x 50 + GRCL(self.output_channel[0], self.output_channel[0], num_iteration=5, kernel_size=3, pad=1), + nn.MaxPool2d(2, 2), # 64 x 8 x 25 + GRCL(self.output_channel[0], self.output_channel[1], num_iteration=5, kernel_size=3, pad=1), + nn.MaxPool2d(2, (2, 1), (0, 1)), # 128 x 4 x 26 + GRCL(self.output_channel[1], self.output_channel[2], num_iteration=5, kernel_size=3, pad=1), + nn.MaxPool2d(2, (2, 1), (0, 1)), # 256 x 2 x 27 + nn.Conv2d(self.output_channel[2], self.output_channel[3], 2, 1, 0, bias=False), + nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True)) # 512 x 1 x 26 + + def forward(self, input): + return self.ConvNet(input) + + +class ResNet_FeatureExtractor(nn.Module): + """ FeatureExtractor of FAN (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf) """ + + def __init__(self, input_channel, output_channel=512): + super(ResNet_FeatureExtractor, self).__init__() + self.ConvNet = ResNet(input_channel, output_channel, BasicBlock, [1, 2, 5, 3]) + + def forward(self, input): + return self.ConvNet(input) + + +# For Gated RCNN +class GRCL(nn.Module): + + def __init__(self, input_channel, output_channel, num_iteration, kernel_size, pad): + super(GRCL, self).__init__() + self.wgf_u = nn.Conv2d(input_channel, output_channel, 1, 1, 0, bias=False) + self.wgr_x = nn.Conv2d(output_channel, output_channel, 1, 1, 0, bias=False) + self.wf_u = nn.Conv2d(input_channel, output_channel, kernel_size, 1, pad, bias=False) + self.wr_x = nn.Conv2d(output_channel, output_channel, kernel_size, 1, pad, bias=False) + + self.BN_x_init = nn.BatchNorm2d(output_channel) + + self.num_iteration = num_iteration + self.GRCL = [GRCL_unit(output_channel) for _ in range(num_iteration)] + self.GRCL = nn.Sequential(*self.GRCL) + + def forward(self, input): + """ The input of GRCL is consistant over time t, which is denoted by u(0) + thus wgf_u / wf_u is also consistant over time t. 
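+        That is, the gate input ``wgf_u`` and the feed-forward input ``wf_u`` are computed once
+        from the input and reused across all ``num_iteration`` recurrent steps.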
+ """ + wgf_u = self.wgf_u(input) + wf_u = self.wf_u(input) + x = F.relu(self.BN_x_init(wf_u)) + + for i in range(self.num_iteration): + x = self.GRCL[i](wgf_u, self.wgr_x(x), wf_u, self.wr_x(x)) + + return x + + +class GRCL_unit(nn.Module): + + def __init__(self, output_channel): + super(GRCL_unit, self).__init__() + self.BN_gfu = nn.BatchNorm2d(output_channel) + self.BN_grx = nn.BatchNorm2d(output_channel) + self.BN_fu = nn.BatchNorm2d(output_channel) + self.BN_rx = nn.BatchNorm2d(output_channel) + self.BN_Gx = nn.BatchNorm2d(output_channel) + + def forward(self, wgf_u, wgr_x, wf_u, wr_x): + G_first_term = self.BN_gfu(wgf_u) + G_second_term = self.BN_grx(wgr_x) + G = F.sigmoid(G_first_term + G_second_term) + + x_first_term = self.BN_fu(wf_u) + x_second_term = self.BN_Gx(self.BN_rx(wr_x) * G) + x = F.relu(x_first_term + x_second_term) + + return x + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = self._conv3x3(inplanes, planes) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = self._conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def _conv3x3(self, in_planes, out_planes, stride=1): + "3x3 convolution with padding" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = F.relu_(out.clone()) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + out += residual + out = F.relu_(out.clone()) + + return out + + +class ResNet(nn.Module): + + def __init__(self, input_channel, output_channel, block, layers): + super(ResNet, self).__init__() + + self.output_channel_block = [int(output_channel / 4), int(output_channel / 2), output_channel, output_channel] + + self.inplanes = int(output_channel / 8) + self.conv0_1 = nn.Conv2d(input_channel, int(output_channel / 16), + kernel_size=3, stride=1, padding=1, bias=False) + self.bn0_1 = nn.BatchNorm2d(int(output_channel / 16)) + self.conv0_2 = nn.Conv2d(int(output_channel / 16), self.inplanes, + kernel_size=3, stride=1, padding=1, bias=False) + self.bn0_2 = nn.BatchNorm2d(self.inplanes) + self.relu = nn.ReLU(inplace=True) + + self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) + self.layer1 = self._make_layer(block, self.output_channel_block[0], layers[0]) + self.conv1 = nn.Conv2d(self.output_channel_block[0], self.output_channel_block[ + 0], kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(self.output_channel_block[0]) + + self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) + self.layer2 = self._make_layer(block, self.output_channel_block[1], layers[1], stride=1) + self.conv2 = nn.Conv2d(self.output_channel_block[1], self.output_channel_block[ + 1], kernel_size=3, stride=1, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(self.output_channel_block[1]) + + self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=(2, 1), padding=(0, 1)) + self.layer3 = self._make_layer(block, self.output_channel_block[2], layers[2], stride=1) + self.conv3 = nn.Conv2d(self.output_channel_block[2], self.output_channel_block[ + 2], kernel_size=3, stride=1, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(self.output_channel_block[2]) + + self.layer4 = self._make_layer(block, self.output_channel_block[3], 
layers[3], stride=1) + self.conv4_1 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[ + 3], kernel_size=2, stride=(2, 1), padding=(0, 1), bias=False) + self.bn4_1 = nn.BatchNorm2d(self.output_channel_block[3]) + self.conv4_2 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[ + 3], kernel_size=2, stride=1, padding=0, bias=False) + self.bn4_2 = nn.BatchNorm2d(self.output_channel_block[3]) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv0_1(x) + x = self.bn0_1(x) + x = F.relu_(x.clone()) + x = self.conv0_2(x) + x = self.bn0_2(x) + x = F.relu_(x.clone()) + + x = self.maxpool1(x) + x = self.layer1(x) + x = self.conv1(x) + x = self.bn1(x) + x = F.relu_(x.clone()) + + x = self.maxpool2(x) + x = self.layer2(x) + x = self.conv2(x) + x = self.bn2(x) + x = F.relu_(x.clone()) + + x = self.maxpool3(x) + x = self.layer3(x) + x = self.conv3(x) + x = self.bn3(x) + x = F.relu_(x.clone()) + + x = self.layer4(x) + x = self.conv4_1(x) + x = self.bn4_1(x) + x = F.relu_(x.clone()) + x = self.conv4_2(x) + x = self.bn4_2(x) + x = F.relu_(x.clone()) + + return x diff --git a/modules_srn/optimizer/__init__.py b/modules_srn/optimizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/modules_srn/optimizer/ranger.py b/modules_srn/optimizer/ranger.py new file mode 100644 index 0000000000000000000000000000000000000000..5702d014e7b5321e6461c46445fd2cebb32b737e --- /dev/null +++ b/modules_srn/optimizer/ranger.py @@ -0,0 +1,143 @@ +#Ranger deep learning optimizer - RAdam + Lookahead combined. +#https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer + +import math +import torch +from torch.optim.optimizer import Optimizer, required +import itertools as it +#from torch.optim import Optimizer +#credit - Lookahead implementation from LonePatient - https://github.com/lonePatient/lookahead_pytorch/blob/master/optimizer.py +#credit2 - RAdam code by https://github.com/LiyuanLucasLiu/RAdam/blob/master/radam.py +#changes 8/31/19 - fix references to *self*.N_sma_threshold; + #changed eps to 1e-5 as better default than 1e-8. + +class Ranger(Optimizer): + + def __init__(self, params, lr=1e-3, alpha=0.5, k=6, N_sma_threshhold=5, betas=(.95,0.999), eps=1e-5, weight_decay=0): + #parameter checks + if not 0.0 <= alpha <= 1.0: + raise ValueError(f'Invalid slow update rate: {alpha}') + if not 1 <= k: + raise ValueError(f'Invalid lookahead steps: {k}') + if not lr > 0: + raise ValueError(f'Invalid Learning Rate: {lr}') + if not eps > 0: + raise ValueError(f'Invalid eps: {eps}') + + #parameter comments: + # beta1 (momentum) of .95 seems to work better than .90... + #N_sma_threshold of 5 seems better in testing than 4. + #In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you. 
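+
+        #usage sketch (hedged; `model` below is only an illustrative placeholder):
+        #   optimizer = Ranger(model.parameters(), lr=1e-3, alpha=0.5, k=6)
+        #   then call optimizer.step() after loss.backward(), as with any torch optimizer.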
+ + #prep defaults and init torch.optim base + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + super().__init__(params,defaults) + + #adjustable threshold + self.N_sma_threshhold = N_sma_threshhold + + #now we can get to work... + for group in self.param_groups: + group["step_counter"] = 0 + #print("group step counter init") + + #look ahead params + self.alpha = alpha + self.k = k + + #radam buffer for state + self.radam_buffer = [[None,None,None] for ind in range(10)] + + #lookahead weights + self.slow_weights = [[p.clone().detach() for p in group['params']] + for group in self.param_groups] + + #don't use grad for lookahead weights + for w in it.chain(*self.slow_weights): + w.requires_grad = False + + def __setstate__(self, state): + print("set state called") + super(Ranger, self).__setstate__(state) + + + def step(self, closure=None): + loss = None + #note - below is commented out b/c I have other work that passes back the loss as a float, and thus not a callable closure. + #Uncomment if you need to use the actual closure... + + #if closure is not None: + #loss = closure() + + #------------ radam + for group in self.param_groups: + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data.float() + if grad.is_sparse: + raise RuntimeError('RAdam does not support sparse gradients') + + p_data_fp32 = p.data.float() + + state = self.state[p] + + if len(state) == 0: + state['step'] = 0 + state['exp_avg'] = torch.zeros_like(p_data_fp32) + state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) + else: + state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) + state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + exp_avg.mul_(beta1).add_(1 - beta1, grad) + + state['step'] += 1 + buffered = self.radam_buffer[int(state['step'] % 10)] + if state['step'] == buffered[0]: + N_sma, step_size = buffered[1], buffered[2] + else: + buffered[0] = state['step'] + beta2_t = beta2 ** state['step'] + N_sma_max = 2 / (1 - beta2) - 1 + N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) + buffered[1] = N_sma + if N_sma > self.N_sma_threshhold: + step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) + else: + step_size = 1.0 / (1 - beta1 ** state['step']) + buffered[2] = step_size + + if group['weight_decay'] != 0: + p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) + + if N_sma > self.N_sma_threshhold: + denom = exp_avg_sq.sqrt().add_(group['eps']) + p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) + else: + p_data_fp32.add_(-step_size * group['lr'], exp_avg) + + p.data.copy_(p_data_fp32) + + + #---------------- end radam step + + #look ahead tracking and updating if latest batch = k + for group,slow_weights in zip(self.param_groups,self.slow_weights): + group['step_counter'] += 1 + if group['step_counter'] % self.k != 0: + continue + for p,q in zip(group['params'],slow_weights): + if p.grad is None: + continue + q.data.add_(self.alpha,p.data - q.data) + p.data.copy_(q.data) + + + + return loss \ No newline at end of file diff --git a/modules_srn/prediction.py b/modules_srn/prediction.py new file mode 100644 index 0000000000000000000000000000000000000000..37afab4c06a855e5bf3e83a64c1ee02a127a34d1 --- /dev/null +++ b/modules_srn/prediction.py @@ -0,0 +1,80 @@ 
+import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Attention(nn.Module): + + def __init__(self, input_size, hidden_size, num_classes): + super(Attention, self).__init__() + self.attention_cell = AttentionCell(input_size, hidden_size, num_classes) + self.hidden_size = hidden_size + self.num_classes = num_classes + self.generator = nn.Linear(hidden_size, num_classes) + + def _char_to_onehot(self, input_char, onehot_dim=38): + input_char = input_char.unsqueeze(1) + batch_size = input_char.size(0) + one_hot = torch.cuda.FloatTensor(batch_size, onehot_dim).zero_() + one_hot = one_hot.scatter_(1, input_char, 1) + return one_hot + + def forward(self, batch_H, text, is_train=True, batch_max_length=25): + """ + input: + batch_H : contextual_feature H = hidden state of encoder. [batch_size x num_steps x num_classes] + text : the text-index of each image. [batch_size x (max_length+1)]. +1 for [GO] token. text[:, 0] = [GO]. + output: probability distribution at each step [batch_size x num_steps x num_classes] + """ + batch_size = batch_H.size(0) + num_steps = batch_max_length + 1 # +1 for [s] at end of sentence. + + output_hiddens = torch.cuda.FloatTensor(batch_size, num_steps, self.hidden_size).fill_(0) + hidden = (torch.cuda.FloatTensor(batch_size, self.hidden_size).fill_(0), + torch.cuda.FloatTensor(batch_size, self.hidden_size).fill_(0)) + + if is_train: + for i in range(num_steps): + # one-hot vectors for a i-th char. in a batch + char_onehots = self._char_to_onehot(text[:, i], onehot_dim=self.num_classes) + # hidden : decoder's hidden s_{t-1}, batch_H : encoder's hidden H, char_onehots : one-hot(y_{t-1}) + hidden, alpha = self.attention_cell(hidden, batch_H, char_onehots) + output_hiddens[:, i, :] = hidden[0] # LSTM hidden index (0: hidden, 1: Cell) + probs = self.generator(output_hiddens) + + else: + targets = torch.cuda.LongTensor(batch_size).fill_(0) # [GO] token + probs = torch.cuda.FloatTensor(batch_size, num_steps, self.num_classes).fill_(0) + + for i in range(num_steps): + char_onehots = self._char_to_onehot(targets, onehot_dim=self.num_classes) + hidden, alpha = self.attention_cell(hidden, batch_H, char_onehots) + probs_step = self.generator(hidden[0]) + probs[:, i, :] = probs_step + _, next_input = probs_step.max(1) + targets = next_input + + return probs # batch_size x num_steps x num_classes + + +class AttentionCell(nn.Module): + + def __init__(self, input_size, hidden_size, num_embeddings): + super(AttentionCell, self).__init__() + self.i2h = nn.Linear(input_size, hidden_size, bias=False) + self.h2h = nn.Linear(hidden_size, hidden_size) # either i2i or h2h should have bias + self.score = nn.Linear(hidden_size, 1, bias=False) + self.rnn = nn.LSTMCell(input_size + num_embeddings, hidden_size) + self.hidden_size = hidden_size + + def forward(self, prev_hidden, batch_H, char_onehots): + # [batch_size x num_encoder_step x num_channel] -> [batch_size x num_encoder_step x hidden_size] + batch_H_proj = self.i2h(batch_H) + prev_hidden_proj = self.h2h(prev_hidden[0]).unsqueeze(1) + e = self.score(torch.tanh(batch_H_proj + prev_hidden_proj)) # batch_size x num_encoder_step * 1 + + alpha = F.softmax(e, dim=1) + context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1) # batch_size x num_channel + concat_context = torch.cat([context, char_onehots], 1) # batch_size x (num_channel + num_embedding) + cur_hidden = self.rnn(concat_context, prev_hidden) + return cur_hidden, alpha diff --git a/modules_srn/resnet_aster.py b/modules_srn/resnet_aster.py new file mode 
100644 index 0000000000000000000000000000000000000000..10229efd7249986875169f9a5a2c840a0a523a88 --- /dev/null +++ b/modules_srn/resnet_aster.py @@ -0,0 +1,128 @@ +import torch +import torch.nn as nn +import torchvision + +import sys +import math + +# from config import get_args +# global_args = get_args(sys.argv[1:]) + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +def get_sinusoid_encoding(n_position, feat_dim, wave_length=10000): + # [n_position] + positions = torch.arange(0, n_position)#.cuda() + # [feat_dim] + dim_range = torch.arange(0, feat_dim)#.cuda() + dim_range = torch.pow(wave_length, 2 * (dim_range // 2) / feat_dim) + # [n_position, feat_dim] + angles = positions.unsqueeze(1) / dim_range.unsqueeze(0) + angles = angles.float() + angles[:, 0::2] = torch.sin(angles[:, 0::2]) + angles[:, 1::2] = torch.cos(angles[:, 1::2]) + return angles + + +class AsterBlock(nn.Module): + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(AsterBlock, self).__init__() + self.conv1 = conv1x1(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + out += residual + out = self.relu(out) + return out + + +class ResNet_ASTER(nn.Module): + """For aster or crnn + borrowed from: https://github.com/ayumiymk/aster.pytorch + """ + def __init__(self, in_channels=1, out_channel=512, n_group=1): + super(ResNet_ASTER, self).__init__() + self.n_group = n_group + + in_channels = in_channels + self.layer0 = nn.Sequential( + nn.Conv2d(in_channels, 32, kernel_size=(3, 3), stride=1, padding=1, bias=False), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True)) + + self.inplanes = 32 + self.layer1 = self._make_layer(32, 3, [2, 2]) # [16, 50] + self.layer2 = self._make_layer(64, 4, [2, 2]) # [8, 25] + self.layer3 = self._make_layer(128, 6, [2, 2]) # [4, 25] + self.layer4 = self._make_layer(256, 6, [1, ]) # [2, 25] + self.layer5 = self._make_layer(out_channel, 3, [1, 1]) # [1, 25] + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, planes, blocks, stride): + downsample = None + if stride != [1, 1] or self.inplanes != planes: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes, stride), + nn.BatchNorm2d(planes)) + + layers = [] + layers.append(AsterBlock(self.inplanes, planes, stride, downsample)) + self.inplanes = planes + for _ in range(1, blocks): + layers.append(AsterBlock(self.inplanes, planes)) + return nn.Sequential(*layers) + + def forward(self, x): + + x0 = self.layer0(x) + x1 = self.layer1(x0) + x2 = self.layer2(x1) + x3 = self.layer3(x2) + x4 = self.layer4(x3) + x5 = self.layer5(x4) + + return x5 + + +def numel(model): + return sum(p.numel() for p in model.parameters()) + +if __name__ == "__main__": 
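+    # Quick smoke test: run a random 64x256 grayscale batch through ResNet_ASTER,
+    # print the output feature size and the total parameter count.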
+ x = torch.randn(3, 1, 64, 256) + net = ResNet_ASTER() + encoder_feat = net(x) + print(encoder_feat.size()) # 3*512*h/4*w/4 + + num_params = numel(net) + print(f'Number of parameters: {num_params}') \ No newline at end of file diff --git a/modules_srn/resnet_fpn.py b/modules_srn/resnet_fpn.py new file mode 100644 index 0000000000000000000000000000000000000000..0684478aed876cea9b9d4871afa3b7f6faf0399a --- /dev/null +++ b/modules_srn/resnet_fpn.py @@ -0,0 +1,321 @@ +# -------------------------------------------------------- +# Pytorch Faster R-CNN and FPN +# Licensed under The MIT License [see LICENSE for details] +# Written by Zheqi He and Xinlei Chen, Yixiao Ge +# https://github.com/yxgeee/pytorch-FPN/blob/master/lib/nets/resnet_v1.py +# -------------------------------------------------------- +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +import torch.utils.model_zoo as model_zoo + + +__all__ = [ + 'ResNet_FPN', + 'ResNet', + 'resnet18', + 'resnet34', + 'resnet50', + 'resnet101', + 'resnet152'] + + +model_urls = { + 'resnet18': 'https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', + 'resnet34': 'https://s3.amazonaws.com/pytorch/models/resnet34-333f7ec4.pth', + 'resnet50': 'https://s3.amazonaws.com/pytorch/models/resnet50-19c8e357.pth', + 'resnet101': 'https://s3.amazonaws.com/pytorch/models/resnet101-5d3b4d8f.pth', + 'resnet152': 'https://s3.amazonaws.com/pytorch/models/resnet152-b121ed2d.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1): + "3x3 convolution with padding" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d( + inplanes, + planes, + kernel_size=1, + stride=stride, + bias=False) # change + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, # change + padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class BuildBlock(nn.Module): + def __init__(self, planes=512): + super(BuildBlock, self).__init__() + + 
self.planes = planes + # Top-down layers, use nn.ConvTranspose2d to replace + # nn.Conv2d+F.upsample? + self.toplayer1 = nn.Conv2d( + 2048, + planes, + kernel_size=1, + stride=1, + padding=0) # Reduce channels + self.toplayer2 = nn.Conv2d( + 512, planes, kernel_size=3, stride=1, padding=1) + self.toplayer3 = nn.Conv2d( + 512, planes, kernel_size=3, stride=1, padding=1) + + # Lateral layers + self.latlayer1 = nn.Conv2d( + 1024, planes, kernel_size=1, stride=1, padding=0) + self.latlayer2 = nn.Conv2d( + 512, planes, kernel_size=1, stride=1, padding=0) + + def _upsample_add(self, x, y): + _, _, H, W = y.size() + return F.upsample( + x, + size=( + H, + W), + mode='bilinear', + align_corners=True) + y + + def forward(self, c3, c4, c5): + # Top-down + p5 = self.toplayer1(c5) + p4 = self._upsample_add(p5, self.latlayer1(c4)) + p4 = self.toplayer2(p4) + p3 = self._upsample_add(p4, self.latlayer2(c3)) + p3 = self.toplayer3(p3) + + return p3, p4, p5 + + +class ResNet(nn.Module): + def __init__(self, block, layers, num_classes=1000): + self.inplanes = 64 + super(ResNet, self).__init__() + # the symbol is referred to fots. + # Conv1 /2 + self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + # Pool1 /4 + # maxpool different from pytorch-resnet, to match tf-faster-rcnn + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer( + block, 64, layers[0], stride=1) # Res2 /4 + self.layer2 = self._make_layer( + block, 128, layers[1], stride=2) # Res3 /8 + self.layer3 = self._make_layer( + block, 256, layers[2], stride=2) # Res4 /16 + # use stride 1 for the last conv4 layer (same as tf-faster-rcnn) + self.layer4 = self._make_layer( + block, 512, layers[3], stride=2) # Res5 /32 + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + +def resnet18(pretrained=False): + """Constructs a ResNet-18 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(BasicBlock, [2, 2, 2, 2]) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) + return model + + +def resnet34(pretrained=False): + """Constructs a ResNet-34 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(BasicBlock, [3, 4, 6, 3]) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) + return model + + +def resnet50(pretrained=False): + """Constructs a ResNet-50 model. 
+ Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 4, 6, 3]) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) + return model + + +def resnet101(pretrained=False): + """Constructs a ResNet-101 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 4, 23, 3]) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) + return model + + +def resnet152(pretrained=False): + """Constructs a ResNet-152 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 8, 36, 3]) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) + return model + + +class ResNet_FPN(nn.Module): + def __init__(self, num_layers=50): + super(ResNet_FPN, self).__init__() + self._num_layers = num_layers + self._layers = {} + + self._init_head_tail() + self.out_planes = self.fpn.planes + + def forward(self, x): + c2 = self.head1(x) + c3 = self.head2(c2) + c4 = self.head3(c3) + c5 = self.head4(c4) + p3, p4, p5 = self.fpn( c3, c4, c5) + # net_conv = [p2, p3, p4, p5] + + # return p2, [x, self.resnet.conv1(x), c2] + return p3 + + def _init_head_tail(self): + # choose different blocks for different number of layers + if self._num_layers == 50: + self.resnet = resnet50() + + elif self._num_layers == 101: + self.resnet = resnet101() + + elif self._num_layers == 152: + self.resnet = resnet152() + + else: + # other numbers are not supported + raise NotImplementedError + + # Build Building Block for FPN + self.fpn = BuildBlock() + self.head1 = nn.Sequential( + self.resnet.conv1, + self.resnet.bn1, + self.resnet.relu, + self.resnet.maxpool, + self.resnet.layer1) # /4 + self.head2 = nn.Sequential(self.resnet.layer2) # /8 + self.head3 = nn.Sequential(self.resnet.layer3) # /16 + self.head4 = nn.Sequential(self.resnet.layer4) # /32 + + +if __name__=='__main__': + model = ResNet_FPN() + + x = torch.randn((2,1,64,256)) + y = model(x) + print(y.shape) \ No newline at end of file diff --git a/modules_srn/sequence_modeling.py b/modules_srn/sequence_modeling.py new file mode 100644 index 0000000000000000000000000000000000000000..af32c59b2cc981be1b43412ddf4ac853d0611210 --- /dev/null +++ b/modules_srn/sequence_modeling.py @@ -0,0 +1,19 @@ +import torch.nn as nn + + +class BidirectionalLSTM(nn.Module): + + def __init__(self, input_size, hidden_size, output_size): + super(BidirectionalLSTM, self).__init__() + self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True) + self.linear = nn.Linear(hidden_size * 2, output_size) + + def forward(self, input): + """ + input : visual feature [batch_size x T x input_size] + output : contextual feature [batch_size x T x output_size] + """ + self.rnn.flatten_parameters() + recurrent, _ = self.rnn(input) # batch_size x T x input_size -> batch_size x T x (2*hidden_size) + output = self.linear(recurrent) # batch_size x T x output_size + return output diff --git a/modules_srn/transformation.py b/modules_srn/transformation.py new file mode 100644 index 0000000000000000000000000000000000000000..893147d0141c59341b0a83089ed708c5ec32677a --- /dev/null +++ b/modules_srn/transformation.py @@ -0,0 +1,155 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class TPS_SpatialTransformerNetwork(nn.Module): + """ Rectification Network of RARE, namely TPS based STN """ + 
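+    # How the rectification works (see forward below): LocalizationNetwork predicts the F
+    # fiducial points C' from batch_I, GridGenerator.build_P_prime turns C' into a sampling
+    # grid over the rectified image, and F.grid_sample warps batch_I onto that grid.
+    # Minimal construction sketch -- the concrete values are illustrative, not fixed by this module:
+    #   tps = TPS_SpatialTransformerNetwork(F=20, I_size=(32, 100), I_r_size=(32, 100), I_channel_num=1)
+    #   batch_I_r = tps(batch_I)   # batch_I: [batch_size x I_channel_num x I_height x I_width]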
+ def __init__(self, F, I_size, I_r_size, I_channel_num=1): + """ Based on RARE TPS + input: + batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width] + I_size : (height, width) of the input image I + I_r_size : (height, width) of the rectified image I_r + I_channel_num : the number of channels of the input image I + output: + batch_I_r: rectified image [batch_size x I_channel_num x I_r_height x I_r_width] + """ + super(TPS_SpatialTransformerNetwork, self).__init__() + self.F = F + self.I_size = I_size + self.I_r_size = I_r_size # = (I_r_height, I_r_width) + self.I_channel_num = I_channel_num + self.LocalizationNetwork = LocalizationNetwork(self.F, self.I_channel_num) + self.GridGenerator = GridGenerator(self.F, self.I_r_size) + + def forward(self, batch_I): + batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2 + build_P_prime = self.GridGenerator.build_P_prime(batch_C_prime) # batch_size x n (= I_r_width x I_r_height) x 2 + build_P_prime_reshape = build_P_prime.reshape([build_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2]) + batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border') + + return batch_I_r + + +class LocalizationNetwork(nn.Module): + """ Localization Network of RARE, which predicts C' (K x 2) from I (I_width x I_height) """ + + def __init__(self, F, I_channel_num): + super(LocalizationNetwork, self).__init__() + self.F = F + self.I_channel_num = I_channel_num + self.conv = nn.Sequential( + nn.Conv2d(in_channels=self.I_channel_num, out_channels=64, kernel_size=3, stride=1, padding=1, + bias=False), nn.BatchNorm2d(64), nn.ReLU(True), + nn.MaxPool2d(2, 2), # batch_size x 64 x I_height/2 x I_width/2 + nn.Conv2d(64, 128, 3, 1, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(True), + nn.MaxPool2d(2, 2), # batch_size x 128 x I_height/4 x I_width/4 + nn.Conv2d(128, 256, 3, 1, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(True), + nn.MaxPool2d(2, 2), # batch_size x 256 x I_height/8 x I_width/8 + nn.Conv2d(256, 512, 3, 1, 1, bias=False), nn.BatchNorm2d(512), nn.ReLU(True), + nn.AdaptiveAvgPool2d(1) # batch_size x 512 + ) + + self.localization_fc1 = nn.Sequential(nn.Linear(512, 256), nn.ReLU(True)) + self.localization_fc2 = nn.Linear(256, self.F * 2) + + # Init fc2 in LocalizationNetwork + self.localization_fc2.weight.data.fill_(0) + """ see RARE paper Fig. 
6 (a) """ + ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2)) + ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(F / 2)) + ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(F / 2)) + ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) + ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) + initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) + self.localization_fc2.bias.data = torch.from_numpy(initial_bias).float().view(-1) + + def forward(self, batch_I): + """ + input: batch_I : Batch Input Image [batch_size x I_channel_num x I_height x I_width] + output: batch_C_prime : Predicted coordinates of fiducial points for input batch [batch_size x F x 2] + """ + batch_size = batch_I.size(0) + features = self.conv(batch_I).view(batch_size, -1) + batch_C_prime = self.localization_fc2(self.localization_fc1(features)).view(batch_size, self.F, 2) + return batch_C_prime + + +class GridGenerator(nn.Module): + """ Grid Generator of RARE, which produces P_prime by multipling T with P """ + + def __init__(self, F, I_r_size): + """ Generate P_hat and inv_delta_C for later """ + super(GridGenerator, self).__init__() + self.eps = 1e-6 + self.I_r_height, self.I_r_width = I_r_size + self.F = F + self.C = self._build_C(self.F) # F x 2 + self.P = self._build_P(self.I_r_width, self.I_r_height) + self.register_buffer("inv_delta_C", torch.tensor(self._build_inv_delta_C(self.F, self.C)).float()) # F+3 x F+3 + self.register_buffer("P_hat", torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float()) # n x F+3 + + def _build_C(self, F): + """ Return coordinates of fiducial points in I_r; C """ + ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2)) + ctrl_pts_y_top = -1 * np.ones(int(F / 2)) + ctrl_pts_y_bottom = np.ones(int(F / 2)) + ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) + ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) + C = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) + return C # F x 2 + + def _build_inv_delta_C(self, F, C): + """ Return inv_delta_C which is needed to calculate T """ + hat_C = np.zeros((F, F), dtype=float) # F x F + for i in range(0, F): + for j in range(i, F): + r = np.linalg.norm(C[i] - C[j]) + hat_C[i, j] = r + hat_C[j, i] = r + np.fill_diagonal(hat_C, 1) + hat_C = (hat_C ** 2) * np.log(hat_C) + # print(C.shape, hat_C.shape) + delta_C = np.concatenate( # F+3 x F+3 + [ + np.concatenate([np.ones((F, 1)), C, hat_C], axis=1), # F x F+3 + np.concatenate([np.zeros((2, 3)), np.transpose(C)], axis=1), # 2 x F+3 + np.concatenate([np.zeros((1, 3)), np.ones((1, F))], axis=1) # 1 x F+3 + ], + axis=0 + ) + inv_delta_C = np.linalg.inv(delta_C) + return inv_delta_C # F+3 x F+3 + + def _build_P(self, I_r_width, I_r_height): + I_r_grid_x = (np.arange(-I_r_width, I_r_width, 2) + 1.0) / I_r_width # self.I_r_width + I_r_grid_y = (np.arange(-I_r_height, I_r_height, 2) + 1.0) / I_r_height # self.I_r_height + P = np.stack( # self.I_r_width x self.I_r_height x 2 + np.meshgrid(I_r_grid_x, I_r_grid_y), + axis=2 + ) + return P.reshape([-1, 2]) # n (= self.I_r_width x self.I_r_height) x 2 + + def _build_P_hat(self, F, C, P): + n = P.shape[0] # n (= self.I_r_width x self.I_r_height) + P_tile = np.tile(np.expand_dims(P, axis=1), (1, F, 1)) # n x 2 -> n x 1 x 2 -> n x F x 2 + C_tile = np.expand_dims(C, axis=0) # 1 x F x 2 + P_diff = P_tile - C_tile # n x F x 2 + rbf_norm = np.linalg.norm(P_diff, ord=2, axis=2, keepdims=False) # n x F + rbf = np.multiply(np.square(rbf_norm), np.log(rbf_norm + self.eps)) # n x F + P_hat = 
np.concatenate([np.ones((n, 1)), P, rbf], axis=1) + return P_hat # n x F+3 + + def build_P_prime(self, batch_C_prime): + """ Generate Grid from batch_C_prime [batch_size x F x 2] """ + batch_size = batch_C_prime.size(0) + batch_inv_delta_C = self.inv_delta_C.repeat(batch_size, 1, 1) + batch_P_hat = self.P_hat.repeat(batch_size, 1, 1) + batch_C_prime_with_zeros = torch.cat((batch_C_prime, torch.zeros( + batch_size, 3, 2).float().cuda()), dim=1) # batch_size x F+3 x 2 + batch_T = torch.bmm(batch_inv_delta_C, batch_C_prime_with_zeros) # batch_size x F+3 x 2 + batch_P_prime = torch.bmm(batch_P_hat, batch_T) # batch_size x n x 2 + return batch_P_prime # batch_size x n x 2 diff --git a/modules_trba/extras.py b/modules_trba/extras.py new file mode 100644 index 0000000000000000000000000000000000000000..0075ab61dd2c4f55b23091766a9106a4723d9670 --- /dev/null +++ b/modules_trba/extras.py @@ -0,0 +1,182 @@ +import matplotlib.pyplot as plt +import matplotlib.cm as mpl_color_map +import copy +import numpy as np +import torch +import torch.nn.functional as F +from PIL import Image +from modules.guided_backprop import GuidedBackprop +import sys,os +sys.path.append(os.getcwd()) + +def apply_colormap_on_image(org_im, activation, colormap_name): + """ + Apply heatmap on image + Args: + org_img (PIL img): Original image + activation_map (numpy arr): Activation map (grayscale) 0-255 + colormap_name (str): Name of the colormap + """ + # Get colormap + color_map = mpl_color_map.get_cmap(colormap_name) + no_trans_heatmap = color_map(activation) + # Change alpha channel in colormap to make sure original image is displayed + heatmap = copy.copy(no_trans_heatmap) + heatmap[:, :, 3] = 0.4 + heatmap = Image.fromarray((heatmap*255).astype(np.uint8)) + no_trans_heatmap = Image.fromarray((no_trans_heatmap*255).astype(np.uint8)) + + # Apply heatmap on iamge + org_im = np.uint8(org_im.detach().to("cpu").numpy()[0][0]*255) + org_im = Image.fromarray(org_im) + heatmap_on_image = Image.new("RGBA", org_im.size) + heatmap_on_image = Image.alpha_composite(heatmap_on_image, org_im.convert('RGBA')) + heatmap_on_image = Image.alpha_composite(heatmap_on_image, heatmap) + return no_trans_heatmap, heatmap_on_image + +def save_gradient_images(gradient, file_name): + """ + Exports the original gradient image + + Args: + gradient (np arr): Numpy array of the gradient with shape (3, 224, 224) + file_name (str): Full filename including directory and png + """ + if not os.path.exists('../results'): + os.makedirs('../results') + # Normalize + gradient = gradient - gradient.min() + gradient /= gradient.max() + # Save image + path_to_file = file_name + # print("gradient save shape: ", gradient.shape) + save_image(gradient, path_to_file) + +def format_np_output(np_arr): + """ + This is a (kind of) bandaid fix to streamline saving procedure. + It converts all the outputs to the same format which is 3xWxH + with using sucecssive if clauses. 
+ Args: + im_as_arr (Numpy array): Matrix of shape 1xWxH or WxH or 3xWxH + """ + # Phase/Case 1: The np arr only has 2 dimensions + # Result: Add a dimension at the beginning + if len(np_arr.shape) == 2: + np_arr = np.expand_dims(np_arr, axis=0) + # Phase/Case 2: Np arr has only 1 channel (assuming first dim is channel) + # Result: Repeat first channel and convert 1xWxH to 3xWxH + if np_arr.shape[0] == 1: + np_arr = np.repeat(np_arr, 3, axis=0) + # Phase/Case 3: Np arr is of shape 3xWxH + # Result: Convert it to WxHx3 in order to make it saveable by PIL + if np_arr.shape[0] == 3: + np_arr = np_arr.transpose(1, 2, 0) + # Phase/Case 4: NP arr is normalized between 0-1 + # Result: Multiply with 255 and change type to make it saveable by PIL + if np.max(np_arr) <= 1: + np_arr = (np_arr*255).astype(np.uint8) + return np_arr + +def save_image(im, path): + """ + Saves a numpy matrix or PIL image as an image + Args: + im_as_arr (Numpy array): Matrix of shape DxWxH + path (str): Path to the image + """ + if isinstance(im, (np.ndarray, np.generic)): + im = format_np_output(im) + im = Image.fromarray(im) + im.save(path) + +def module_output_to_numpy(tensor): + return tensor.data.to('cpu').numpy() + +def convert_to_grayscale(im_as_arr): + """ + Converts 3d image to grayscale + + Args: + im_as_arr (numpy arr): RGB image with shape (D,W,H) + + returns: + grayscale_im (numpy_arr): Grayscale image with shape (1,W,D) + """ + grayscale_im = np.sum(np.abs(im_as_arr), axis=0) + im_max = np.percentile(grayscale_im, 99) + im_min = np.min(grayscale_im) + grayscale_im = (np.clip((grayscale_im - im_min) / (im_max - im_min), 0, 1)) + grayscale_im = np.expand_dims(grayscale_im, axis=0) + return grayscale_im +class SaveOutput: + def __init__(self, totalFeatMaps): + self.layer_outputs = [] + self.grad_outputs = [] + self.first_grads = [] + self.totalFeatMaps = totalFeatMaps + self.feature_ext = None + ### Used on register_forward_hook + ### Output up to totalFeatMaps + def append_layer_out(self, module, input, output): + self.layer_outputs.append(output[0]) ### Appending with earlier index pertaining to earlier layers + ### Used on register_backward_hook + ### Output up to totalFeatMaps + def append_grad_out(self, module, grad_input, grad_output): + self.grad_outputs.append(grad_output[0][0]) ### Appending with last-to-first index pertaining to first-to-last layers + ### Used as guided backprop mask + def append_first_grads(self, module, grad_in, grad_out): + self.first_grads.append(grad_in[0]) + def clear(self): + self.layer_outputs = [] + self.grad_outputs = [] + self.first_grads = [] + def set_feature_ext(self, feature_ext): + self.feature_ext = feature_ext + def getGuidedGradImg(self, layerNum, input_img): + # print("layer outputs shape: ", self.layer_outputs[0].shape) + # print("layer grad_outputs shape: ", self.grad_outputs[0].shape) + conv_output_img = module_output_to_numpy(self.layer_outputs[layerNum]) + grad_output_img = module_output_to_numpy(self.grad_outputs[len(self.grad_outputs)-layerNum-1]) + first_grad_output = self.first_grads[0].data.to('cpu').numpy()[0] + print("conv_output_img output shape: ", conv_output_img.shape) + print("grad_output_img output shape: ", grad_output_img.shape) + print("first_grad_output output shape: ", first_grad_output.shape) + print("target min max: {}, {}".format(conv_output_img.min(), conv_output_img.max())) + print("guided_gradients min max: {}, {}".format(grad_output_img.min(), grad_output_img.max())) + weights = np.mean(grad_output_img, axis=(1, 2)) # Take averages for 
each gradient + print("weights shape: ", weights.shape) + print("weights min max1: {}, {}".format(weights.min(), weights.max())) + # Create empty numpy array for cam + # conv_output_img = np.clip(conv_output_img, 0, conv_output_img.max()) + cam = np.ones(conv_output_img.shape[1:], dtype=np.float32) + print("cam min max1: {}, {}".format(cam.min(), cam.max())) + # Multiply each weight with its conv output and then, sum + for i, w in enumerate(weights): + cam += w * conv_output_img[i, :, :] + # cam = np.maximum(cam, 0) + print("cam min max2: {}, {}".format(cam.min(), cam.max())) + cam = (cam - np.min(cam)) / (np.max(cam) - np.min(cam)) # Normalize between 0-1 + cam = np.uint8(cam * 255) # Scale between 0-255 to visualize + cam = np.uint8(Image.fromarray(cam).resize((input_img.shape[3], + input_img.shape[2]), Image.ANTIALIAS))/255 + # cam_gb = np.multiply(cam, first_grad_output) + # grayscale_cam_gb = convert_to_grayscale(cam) + return cam + def getGuidedGradTimesImg(self, layerNum, input_img): + grad_output_img = module_output_to_numpy(self.grad_outputs[len(self.grad_outputs)-layerNum-1]) + print("grad_output_img output shape: ", grad_output_img.shape) + grad_times_image = grad_output_img[0]*input_img.detach().to("cpu").numpy()[0] + return grad_times_image + ### target_output -- pass a created output tensor with one hot (1s) already in placed, used for guided gradients (first layer) + def output_feature_maps(self, targetDir, input_img): + # GBP = GuidedBackprop(self.feature_ext, 'resnet34') + # guided_grads = GBP.generate_gradients(input_img, one_hot_output_guided, text_for_pred) + # print("guided_grads shape: ", guided_grads.shape) + for layerNum in range(self.totalFeatMaps): + grad_times_image = self.getGuidedGradTimesImg(layerNum, input_img) + # save_gradient_images(cam_gb, targetDir + 'GGrad_Cam_Layer{}.jpg'.format(layerNum)) + # save_gradient_images(grayscale_cam_gb, targetDir + 'GGrad_Cam_Gray_Layer{}.jpg'.format(layerNum)) + ### Output heatmaps + grayscale_vanilla_grads = convert_to_grayscale(grad_times_image) + save_gradient_images(grayscale_vanilla_grads, targetDir + 'Vanilla_grad_times_image_gray{}.jpg'.format(layerNum)) diff --git a/modules_trba/feature_extraction.py b/modules_trba/feature_extraction.py new file mode 100644 index 0000000000000000000000000000000000000000..41e93c17090841a6f0296543fdbfce39828e35d5 --- /dev/null +++ b/modules_trba/feature_extraction.py @@ -0,0 +1,259 @@ +import torch.nn as nn +import torch.nn.functional as F + + +class VGG_FeatureExtractor(nn.Module): + """ FeatureExtractor of CRNN (https://arxiv.org/pdf/1507.05717.pdf) """ + + def __init__(self, input_channel, output_channel=512): + super(VGG_FeatureExtractor, self).__init__() + self.output_channel = [int(output_channel / 8), int(output_channel / 4), + int(output_channel / 2), output_channel] # [64, 128, 256, 512] + self.ConvNet = nn.Sequential( + nn.Conv2d(input_channel, self.output_channel[0], 3, 1, 1), nn.ReLU(True), + nn.MaxPool2d(2, 2), # 64x16x50 + nn.Conv2d(self.output_channel[0], self.output_channel[1], 3, 1, 1), nn.ReLU(True), + nn.MaxPool2d(2, 2), # 128x8x25 + nn.Conv2d(self.output_channel[1], self.output_channel[2], 3, 1, 1), nn.ReLU(True), # 256x8x25 + nn.Conv2d(self.output_channel[2], self.output_channel[2], 3, 1, 1), nn.ReLU(True), + nn.MaxPool2d((2, 1), (2, 1)), # 256x4x25 + nn.Conv2d(self.output_channel[2], self.output_channel[3], 3, 1, 1, bias=False), + nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True), # 512x4x25 + nn.Conv2d(self.output_channel[3], self.output_channel[3], 3, 1, 
1, bias=False), + nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True), + nn.MaxPool2d((2, 1), (2, 1)), # 512x2x25 + nn.Conv2d(self.output_channel[3], self.output_channel[3], 2, 1, 0), nn.ReLU(True)) # 512x1x24 + + def forward(self, input): + return self.ConvNet(input) + + +class RCNN_FeatureExtractor(nn.Module): + """ FeatureExtractor of GRCNN (https://papers.nips.cc/paper/6637-gated-recurrent-convolution-neural-network-for-ocr.pdf) """ + + def __init__(self, input_channel, output_channel=512): + super(RCNN_FeatureExtractor, self).__init__() + self.output_channel = [int(output_channel / 8), int(output_channel / 4), + int(output_channel / 2), output_channel] # [64, 128, 256, 512] + self.ConvNet = nn.Sequential( + nn.Conv2d(input_channel, self.output_channel[0], 3, 1, 1), nn.ReLU(True), + nn.MaxPool2d(2, 2), # 64 x 16 x 50 + GRCL(self.output_channel[0], self.output_channel[0], num_iteration=5, kernel_size=3, pad=1), + nn.MaxPool2d(2, 2), # 64 x 8 x 25 + GRCL(self.output_channel[0], self.output_channel[1], num_iteration=5, kernel_size=3, pad=1), + nn.MaxPool2d(2, (2, 1), (0, 1)), # 128 x 4 x 26 + GRCL(self.output_channel[1], self.output_channel[2], num_iteration=5, kernel_size=3, pad=1), + nn.MaxPool2d(2, (2, 1), (0, 1)), # 256 x 2 x 27 + nn.Conv2d(self.output_channel[2], self.output_channel[3], 2, 1, 0, bias=False), + nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True)) # 512 x 1 x 26 + + def forward(self, input): + return self.ConvNet(input) + + +class ResNet_FeatureExtractor(nn.Module): + """ FeatureExtractor of FAN (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf) """ + + def __init__(self, input_channel, output_channel=512): + super(ResNet_FeatureExtractor, self).__init__() + self.ConvNet = ResNet(input_channel, output_channel, BasicBlock, [1, 2, 5, 3]) + + def forward(self, input): + return self.ConvNet(input) + + +# For Gated RCNN +class GRCL(nn.Module): + + def __init__(self, input_channel, output_channel, num_iteration, kernel_size, pad): + super(GRCL, self).__init__() + self.wgf_u = nn.Conv2d(input_channel, output_channel, 1, 1, 0, bias=False) + self.wgr_x = nn.Conv2d(output_channel, output_channel, 1, 1, 0, bias=False) + self.wf_u = nn.Conv2d(input_channel, output_channel, kernel_size, 1, pad, bias=False) + self.wr_x = nn.Conv2d(output_channel, output_channel, kernel_size, 1, pad, bias=False) + + self.BN_x_init = nn.BatchNorm2d(output_channel) + + self.num_iteration = num_iteration + self.GRCL = [GRCL_unit(output_channel) for _ in range(num_iteration)] + self.GRCL = nn.Sequential(*self.GRCL) + + def forward(self, input): + """ The input of GRCL is consistant over time t, which is denoted by u(0) + thus wgf_u / wf_u is also consistant over time t. 
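+        At each of the num_iteration steps the gate is G = sigmoid(BN(wgf_u) + BN(wgr_x)) and the
+        state is updated as x = relu(BN(wf_u) + BN(BN(wr_x) * G)), where wgr_x and wr_x are the
+        1x1 and kernel_size convolutions of the previous x (see GRCL_unit.forward below).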
+ """ + wgf_u = self.wgf_u(input) + wf_u = self.wf_u(input) + x = F.relu(self.BN_x_init(wf_u)) + + for i in range(self.num_iteration): + x = self.GRCL[i](wgf_u, self.wgr_x(x), wf_u, self.wr_x(x)) + + return x + + +class GRCL_unit(nn.Module): + + def __init__(self, output_channel): + super(GRCL_unit, self).__init__() + self.BN_gfu = nn.BatchNorm2d(output_channel) + self.BN_grx = nn.BatchNorm2d(output_channel) + self.BN_fu = nn.BatchNorm2d(output_channel) + self.BN_rx = nn.BatchNorm2d(output_channel) + self.BN_Gx = nn.BatchNorm2d(output_channel) + + def forward(self, wgf_u, wgr_x, wf_u, wr_x): + G_first_term = self.BN_gfu(wgf_u) + G_second_term = self.BN_grx(wgr_x) + G = F.sigmoid(G_first_term + G_second_term) + + x_first_term = self.BN_fu(wf_u) + x_second_term = self.BN_Gx(self.BN_rx(wr_x) * G) + x = F.relu(x_first_term + x_second_term) + + return x + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = self._conv3x3(inplanes, planes) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = self._conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU() + self.relu1 = nn.ReLU() + self.downsample = downsample + self.stride = stride + + def _conv3x3(self, in_planes, out_planes, stride=1): + "3x3 convolution with padding" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + out += residual + out = self.relu1(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, input_channel, output_channel, block, layers): + super(ResNet, self).__init__() + + self.output_channel_block = [int(output_channel / 4), int(output_channel / 2), output_channel, output_channel] + self.inplanes = int(output_channel / 8) + self.conv0_1 = nn.Conv2d(input_channel, int(output_channel / 16), + kernel_size=3, stride=1, padding=1, bias=False) + self.bn0_1 = nn.BatchNorm2d(int(output_channel / 16)) + self.conv0_2 = nn.Conv2d(int(output_channel / 16), self.inplanes, + kernel_size=3, stride=1, padding=1, bias=False) + self.bn0_2 = nn.BatchNorm2d(self.inplanes) + self.relu = nn.ReLU() + self.relu1 = nn.ReLU() + self.relu2 = nn.ReLU() + self.relu3 = nn.ReLU() + self.relu4 = nn.ReLU() + self.relu5 = nn.ReLU() + self.relu6 = nn.ReLU() + + self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) + self.layer1 = self._make_layer(block, self.output_channel_block[0], layers[0]) + self.conv1 = nn.Conv2d(self.output_channel_block[0], self.output_channel_block[ + 0], kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(self.output_channel_block[0]) + + self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) + self.layer2 = self._make_layer(block, self.output_channel_block[1], layers[1], stride=1) + self.conv2 = nn.Conv2d(self.output_channel_block[1], self.output_channel_block[ + 1], kernel_size=3, stride=1, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(self.output_channel_block[1]) + + self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=(2, 1), padding=(0, 1)) + self.layer3 = self._make_layer(block, self.output_channel_block[2], layers[2], stride=1) + self.conv3 = nn.Conv2d(self.output_channel_block[2], self.output_channel_block[ + 2], kernel_size=3, stride=1, padding=1, 
bias=False) + self.bn3 = nn.BatchNorm2d(self.output_channel_block[2]) + + self.layer4 = self._make_layer(block, self.output_channel_block[3], layers[3], stride=1) + self.conv4_1 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[ + 3], kernel_size=2, stride=(2, 1), padding=(0, 1), bias=False) + self.bn4_1 = nn.BatchNorm2d(self.output_channel_block[3]) + self.conv4_2 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[ + 3], kernel_size=2, stride=1, padding=0, bias=False) + self.bn4_2 = nn.BatchNorm2d(self.output_channel_block[3]) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv0_1(x) + x = self.bn0_1(x) + # print("x0 shape: ", x.shape) + x = self.relu(x) + x = self.conv0_2(x) + x = self.bn0_2(x) + # print("x1 shape: ", x.shape) + x = self.relu1(x) + + x = self.maxpool1(x) + x = self.layer1(x) + x = self.conv1(x) + x = self.bn1(x) + # print("x2 shape: ", x.shape) + x = self.relu2(x) + + x = self.maxpool2(x) + x = self.layer2(x) + x = self.conv2(x) + x = self.bn2(x) + # print("x3 shape: ", x.shape) + x = self.relu3(x) + + x = self.maxpool3(x) + x = self.layer3(x) + x = self.conv3(x) + x = self.bn3(x) + # print("x4 shape: ", x.shape) + x = self.relu4(x) + + x = self.layer4(x) + x = self.conv4_1(x) + x = self.bn4_1(x) + # print("x5 shape: ", x.shape) + x = self.relu5(x) + x = self.conv4_2(x) + x = self.bn4_2(x) + # print("x6 shape: ", x.shape) + x = self.relu6(x) + + return x diff --git a/modules_trba/guided_backprop.py b/modules_trba/guided_backprop.py new file mode 100644 index 0000000000000000000000000000000000000000..ad3533600973d957d122a7eb4a11a5dd6100fa26 --- /dev/null +++ b/modules_trba/guided_backprop.py @@ -0,0 +1,101 @@ +""" +Created on Thu Oct 26 11:23:47 2017 + +@author: Utku Ozbulak - github.com/utkuozbulak +""" +import torch +from torch.nn import ReLU + +class GuidedBackprop(): + """ + Produces gradients generated with guided back propagation from the given image + """ + def __init__(self, model, arch): + self.model = model + self.arch = arch + self.gradients = None + self.forward_relu_outputs = [] + # Put model in evaluation mode + self.model.train() + self.update_relus() + self.hook_layers() + + def hook_layers(self): + def hook_function(module, grad_in, grad_out): + self.gradients = grad_in[0] + # Register hook to the first layer + if 'alexnet' in self.arch: + first_layer = list(self.model.features._modules.items())[0][1] + elif 'resnet' in self.arch: + first_layer = list(self.model._modules.items())[0][1] + first_layer.register_backward_hook(hook_function) + + def update_relus(self): + """ + Updates relu activation functions so that + 1- stores output in forward pass + 2- imputes zero for gradient values that are less than zero + """ + def relu_backward_hook_function(module, grad_in, grad_out): + """ + If there is a negative gradient, change it to zero + """ + # Get last forward output + corresponding_forward_output = self.forward_relu_outputs[-1] + corresponding_forward_output[corresponding_forward_output > 
0] = 1 + modified_grad_out = corresponding_forward_output * torch.clamp(grad_in[0], min=0.0) + del self.forward_relu_outputs[-1] # Remove last forward output + return (modified_grad_out,) + + def relu_forward_hook_function(module, ten_in, ten_out): + """ + Store results of forward pass + """ + self.forward_relu_outputs.append(ten_out) + + # Loop through layers, hook up ReLUs + if 'alexnet' in self.arch: + for pos, module in self.model.features._modules.items(): + if isinstance(module, ReLU): + module.register_backward_hook(relu_backward_hook_function) + module.register_forward_hook(relu_forward_hook_function) + elif 'resnet' in self.arch: + for module in self.model.modules(): + if isinstance(module, ReLU): + module.register_backward_hook(relu_backward_hook_function) + module.register_forward_hook(relu_forward_hook_function) + + def generate_gradients(self, input_image, one_hot_output_guided, text_for_pred): + # Forward pass + model_output = self.model(input_image, text_for_pred, is_train=False) + # Zero gradients + self.model.zero_grad() + # Backward pass + model_output.backward(gradient=one_hot_output_guided) + # Convert Pytorch variable to numpy array + # [0] to get rid of the first channel (1,3,224,224) + gradients_as_arr = self.gradients.data.to('cpu').numpy()[0] + print("gradients_as_arr shape: ", gradients_as_arr.shape) + return gradients_as_arr + + +if __name__ == '__main__': + target_example = 0 # Snake + (original_image, prep_img, target_class, file_name_to_export, pretrained_model) =\ + get_example_params(target_example) + + # Guided backprop + GBP = GuidedBackprop(pretrained_model) + # Get gradients + guided_grads = GBP.generate_gradients(prep_img, target_class) + # Save colored gradients + save_gradient_images(guided_grads, file_name_to_export + '_Guided_BP_color') + # Convert to grayscale + grayscale_guided_grads = convert_to_grayscale(guided_grads) + # Save grayscale gradients + save_gradient_images(grayscale_guided_grads, file_name_to_export + '_Guided_BP_gray') + # Positive and negative saliency maps + pos_sal, neg_sal = get_positive_negative_saliency(guided_grads) + save_gradient_images(pos_sal, file_name_to_export + '_pos_sal') + save_gradient_images(neg_sal, file_name_to_export + '_neg_sal') + print('Guided backprop completed') diff --git a/modules_trba/prediction.py b/modules_trba/prediction.py new file mode 100644 index 0000000000000000000000000000000000000000..b6c3cb32541ea4ea5f22377ce5062e5ea34ff2f7 --- /dev/null +++ b/modules_trba/prediction.py @@ -0,0 +1,81 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + +class Attention(nn.Module): + + def __init__(self, input_size, hidden_size, num_classes): + super(Attention, self).__init__() + self.attention_cell = AttentionCell(input_size, hidden_size, num_classes) + self.hidden_size = hidden_size + self.num_classes = num_classes + self.generator = nn.Linear(hidden_size, num_classes) + + def _char_to_onehot(self, input_char, onehot_dim=38): + input_char = input_char.unsqueeze(1) + batch_size = input_char.size(0) + one_hot = torch.FloatTensor(batch_size, onehot_dim).zero_().to(device) + one_hot = one_hot.scatter_(1, input_char, 1) + return one_hot + + def forward(self, batch_H, text, is_train=True, batch_max_length=25): + """ + input: + batch_H : contextual_feature H = hidden state of encoder. [batch_size x num_steps x contextual_feature_channels] + text : the text-index of each image. [batch_size x (max_length+1)]. 
+1 for [GO] token. text[:, 0] = [GO]. + output: probability distribution at each step [batch_size x num_steps x num_classes] + """ + batch_size = batch_H.size(0) + num_steps = batch_max_length + 1 # +1 for [s] at end of sentence. + + output_hiddens = torch.FloatTensor(batch_size, num_steps, self.hidden_size).fill_(0).to(device) + hidden = (torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(device), + torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(device)) + + if is_train: + for i in range(num_steps): + # one-hot vectors for a i-th char. in a batch + char_onehots = self._char_to_onehot(text[:, i], onehot_dim=self.num_classes) + # hidden : decoder's hidden s_{t-1}, batch_H : encoder's hidden H, char_onehots : one-hot(y_{t-1}) + hidden, alpha = self.attention_cell(hidden, batch_H, char_onehots) + output_hiddens[:, i, :] = hidden[0] # LSTM hidden index (0: hidden, 1: Cell) + probs = self.generator(output_hiddens) + + else: + targets = torch.LongTensor(batch_size).fill_(0).to(device) # [GO] token + probs = torch.FloatTensor(batch_size, num_steps, self.num_classes).fill_(0).to(device) + + for i in range(num_steps): + char_onehots = self._char_to_onehot(targets, onehot_dim=self.num_classes) + hidden, alpha = self.attention_cell(hidden, batch_H, char_onehots) + probs_step = self.generator(hidden[0]) + probs[:, i, :] = probs_step + _, next_input = probs_step.max(1) + targets = next_input + + return probs # batch_size x num_steps x num_classes + + +class AttentionCell(nn.Module): + + def __init__(self, input_size, hidden_size, num_embeddings): + super(AttentionCell, self).__init__() + self.i2h = nn.Linear(input_size, hidden_size, bias=False) + self.h2h = nn.Linear(hidden_size, hidden_size) # either i2i or h2h should have bias + self.score = nn.Linear(hidden_size, 1, bias=False) + self.rnn = nn.LSTMCell(input_size + num_embeddings, hidden_size) + self.hidden_size = hidden_size + + def forward(self, prev_hidden, batch_H, char_onehots): + # [batch_size x num_encoder_step x num_channel] -> [batch_size x num_encoder_step x hidden_size] + batch_H_proj = self.i2h(batch_H) + prev_hidden_proj = self.h2h(prev_hidden[0]).unsqueeze(1) + e = self.score(torch.tanh(batch_H_proj + prev_hidden_proj)) # batch_size x num_encoder_step * 1 + + alpha = F.softmax(e, dim=1) + context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1) # batch_size x num_channel + concat_context = torch.cat([context, char_onehots], 1) # batch_size x (num_channel + num_embedding) + cur_hidden = self.rnn(concat_context, prev_hidden) + return cur_hidden, alpha diff --git a/modules_trba/sequence_modeling.py b/modules_trba/sequence_modeling.py new file mode 100644 index 0000000000000000000000000000000000000000..af32c59b2cc981be1b43412ddf4ac853d0611210 --- /dev/null +++ b/modules_trba/sequence_modeling.py @@ -0,0 +1,19 @@ +import torch.nn as nn + + +class BidirectionalLSTM(nn.Module): + + def __init__(self, input_size, hidden_size, output_size): + super(BidirectionalLSTM, self).__init__() + self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True) + self.linear = nn.Linear(hidden_size * 2, output_size) + + def forward(self, input): + """ + input : visual feature [batch_size x T x input_size] + output : contextual feature [batch_size x T x output_size] + """ + self.rnn.flatten_parameters() + recurrent, _ = self.rnn(input) # batch_size x T x input_size -> batch_size x T x (2*hidden_size) + output = self.linear(recurrent) # batch_size x T x output_size + return output diff --git 
a/modules_trba/transformation.py b/modules_trba/transformation.py new file mode 100644 index 0000000000000000000000000000000000000000..e003cd4196ff5f34fe777448a062f9a081fedaea --- /dev/null +++ b/modules_trba/transformation.py @@ -0,0 +1,164 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + +class TPS_SpatialTransformerNetwork(nn.Module): + """ Rectification Network of RARE, namely TPS based STN """ + + def __init__(self, F, I_size, I_r_size, I_channel_num=1): + """ Based on RARE TPS + input: + batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width] + I_size : (height, width) of the input image I + I_r_size : (height, width) of the rectified image I_r + I_channel_num : the number of channels of the input image I + output: + batch_I_r: rectified image [batch_size x I_channel_num x I_r_height x I_r_width] + """ + super(TPS_SpatialTransformerNetwork, self).__init__() + self.F = F + self.I_size = I_size + self.I_r_size = I_r_size # = (I_r_height, I_r_width) + self.I_channel_num = I_channel_num + self.LocalizationNetwork = LocalizationNetwork(self.F, self.I_channel_num) + self.GridGenerator = GridGenerator(self.F, self.I_r_size) + + def forward(self, batch_I): + batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2 + build_P_prime = self.GridGenerator.build_P_prime(batch_C_prime) # batch_size x n (= I_r_width x I_r_height) x 2 + build_P_prime_reshape = build_P_prime.reshape([build_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2]) + + if torch.__version__ > "1.2.0": + batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border', align_corners=True) + else: + batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border') + + return batch_I_r + + +class LocalizationNetwork(nn.Module): + """ Localization Network of RARE, which predicts C' (K x 2) from I (I_width x I_height) """ + + def __init__(self, F, I_channel_num): + super(LocalizationNetwork, self).__init__() + self.F = F + self.I_channel_num = I_channel_num + self.conv = nn.Sequential( + nn.Conv2d(in_channels=self.I_channel_num, out_channels=64, kernel_size=3, stride=1, padding=1, + bias=False), nn.BatchNorm2d(64), nn.ReLU(), + nn.MaxPool2d(2, 2), # batch_size x 64 x I_height/2 x I_width/2 + nn.Conv2d(64, 128, 3, 1, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(), + nn.MaxPool2d(2, 2), # batch_size x 128 x I_height/4 x I_width/4 + nn.Conv2d(128, 256, 3, 1, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(), + nn.MaxPool2d(2, 2), # batch_size x 256 x I_height/8 x I_width/8 + nn.Conv2d(256, 512, 3, 1, 1, bias=False), nn.BatchNorm2d(512), nn.ReLU(), + nn.AdaptiveAvgPool2d(1) # batch_size x 512 + ) + + self.localization_fc1 = nn.Sequential(nn.Linear(512, 256), nn.ReLU()) + self.localization_fc2 = nn.Linear(256, self.F * 2) + + # Init fc2 in LocalizationNetwork + self.localization_fc2.weight.data.fill_(0) + """ see RARE paper Fig. 
6 (a) """ + ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2)) + ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(F / 2)) + ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(F / 2)) + ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) + ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) + initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) + self.localization_fc2.bias.data = torch.from_numpy(initial_bias).float().view(-1) + + def forward(self, batch_I): + """ + input: batch_I : Batch Input Image [batch_size x I_channel_num x I_height x I_width] + output: batch_C_prime : Predicted coordinates of fiducial points for input batch [batch_size x F x 2] + """ + batch_size = batch_I.size(0) + features = self.conv(batch_I).view(batch_size, -1) + batch_C_prime = self.localization_fc2(self.localization_fc1(features)).view(batch_size, self.F, 2) + return batch_C_prime + + +class GridGenerator(nn.Module): + """ Grid Generator of RARE, which produces P_prime by multipling T with P """ + + def __init__(self, F, I_r_size): + """ Generate P_hat and inv_delta_C for later """ + super(GridGenerator, self).__init__() + self.eps = 1e-6 + self.I_r_height, self.I_r_width = I_r_size + self.F = F + self.C = self._build_C(self.F) # F x 2 + self.P = self._build_P(self.I_r_width, self.I_r_height) + ## for multi-gpu, you need register buffer + self.register_buffer("inv_delta_C", torch.tensor(self._build_inv_delta_C(self.F, self.C)).float()) # F+3 x F+3 + self.register_buffer("P_hat", torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float()) # n x F+3 + ## for fine-tuning with different image width, you may use below instead of self.register_buffer + #self.inv_delta_C = torch.tensor(self._build_inv_delta_C(self.F, self.C)).float().cuda() # F+3 x F+3 + #self.P_hat = torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float().cuda() # n x F+3 + + def _build_C(self, F): + """ Return coordinates of fiducial points in I_r; C """ + ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2)) + ctrl_pts_y_top = -1 * np.ones(int(F / 2)) + ctrl_pts_y_bottom = np.ones(int(F / 2)) + ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) + ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) + C = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) + return C # F x 2 + + def _build_inv_delta_C(self, F, C): + """ Return inv_delta_C which is needed to calculate T """ + hat_C = np.zeros((F, F), dtype=float) # F x F + for i in range(0, F): + for j in range(i, F): + r = np.linalg.norm(C[i] - C[j]) + hat_C[i, j] = r + hat_C[j, i] = r + np.fill_diagonal(hat_C, 1) + hat_C = (hat_C ** 2) * np.log(hat_C) + # print(C.shape, hat_C.shape) + delta_C = np.concatenate( # F+3 x F+3 + [ + np.concatenate([np.ones((F, 1)), C, hat_C], axis=1), # F x F+3 + np.concatenate([np.zeros((2, 3)), np.transpose(C)], axis=1), # 2 x F+3 + np.concatenate([np.zeros((1, 3)), np.ones((1, F))], axis=1) # 1 x F+3 + ], + axis=0 + ) + inv_delta_C = np.linalg.inv(delta_C) + return inv_delta_C # F+3 x F+3 + + def _build_P(self, I_r_width, I_r_height): + I_r_grid_x = (np.arange(-I_r_width, I_r_width, 2) + 1.0) / I_r_width # self.I_r_width + I_r_grid_y = (np.arange(-I_r_height, I_r_height, 2) + 1.0) / I_r_height # self.I_r_height + P = np.stack( # self.I_r_width x self.I_r_height x 2 + np.meshgrid(I_r_grid_x, I_r_grid_y), + axis=2 + ) + return P.reshape([-1, 2]) # n (= self.I_r_width x self.I_r_height) x 2 + + def _build_P_hat(self, F, C, P): + n = P.shape[0] # n (= self.I_r_width x self.I_r_height) + P_tile = 
np.tile(np.expand_dims(P, axis=1), (1, F, 1)) # n x 2 -> n x 1 x 2 -> n x F x 2 + C_tile = np.expand_dims(C, axis=0) # 1 x F x 2 + P_diff = P_tile - C_tile # n x F x 2 + rbf_norm = np.linalg.norm(P_diff, ord=2, axis=2, keepdims=False) # n x F + rbf = np.multiply(np.square(rbf_norm), np.log(rbf_norm + self.eps)) # n x F + P_hat = np.concatenate([np.ones((n, 1)), P, rbf], axis=1) + return P_hat # n x F+3 + + def build_P_prime(self, batch_C_prime): + """ Generate Grid from batch_C_prime [batch_size x F x 2] """ + batch_size = batch_C_prime.size(0) + batch_inv_delta_C = self.inv_delta_C.repeat(batch_size, 1, 1) + batch_P_hat = self.P_hat.repeat(batch_size, 1, 1) + batch_C_prime_with_zeros = torch.cat((batch_C_prime, torch.zeros( + batch_size, 3, 2).float().to(device)), dim=1) # batch_size x F+3 x 2 + batch_T = torch.bmm(batch_inv_delta_C, batch_C_prime_with_zeros) # batch_size x F+3 x 2 + batch_P_prime = torch.bmm(batch_P_hat, batch_T) # batch_size x n x 2 + return batch_P_prime # batch_size x n x 2 diff --git a/settings.py b/settings.py new file mode 100644 index 0000000000000000000000000000000000000000..9b9f76ac112de0f05e0f910daf225035165b05b4 --- /dev/null +++ b/settings.py @@ -0,0 +1,4 @@ +######### global settings ######### +MODEL = 'vitstr' # model arch: vitstr, parseq, srn, abinet, trba, matrn +SEGM_DIR = "./datasets/segmentations" # segmentation directory of the real test sets +TARGET_DATASET = "SVTP" # 'IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', 'IC13_857', 'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80' diff --git a/str_exp_demo.py b/str_exp_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..351729dffdd42a508a5377c1e4813d7f439e8775 --- /dev/null +++ b/str_exp_demo.py @@ -0,0 +1,493 @@ +import settings +import captum +import numpy as np +import torch +import torch.nn.functional as F +import torch.backends.cudnn as cudnn +from utils import get_args +from utils import CTCLabelConverter, AttnLabelConverter, Averager, TokenLabelConverter +import string +import time +import sys +from dataset import hierarchical_dataset, AlignCollate +import validators +from model import Model, STRScore +from PIL import Image +from lime.wrappers.scikit_image import SegmentationAlgorithm +from captum._utils.models.linear_model import SkLearnLinearModel, SkLearnRidge +import random +import os +from skimage.color import gray2rgb +import pickle +from train_shap_corr import getPredAndConf +import re +from captum_test import acquire_average_auc, saveAttrData +import copy +from skimage.color import gray2rgb +from matplotlib import pyplot as plt +from torchvision import transforms + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +from captum.attr import ( + GradientShap, + DeepLift, + DeepLiftShap, + IntegratedGradients, + LayerConductance, + NeuronConductance, + NoiseTunnel, + Saliency, + InputXGradient, + GuidedBackprop, + Deconvolution, + GuidedGradCam, + FeatureAblation, + ShapleyValueSampling, + Lime, + KernelShap +) + +from captum.metrics import ( + infidelity, + sensitivity_max +) + +from captum.attr._utils.visualization import visualize_image_attr + +### Acquire pixelwise attributions and replace them with ranked numbers averaged +### across segmentation with the largest contribution having the largest number +### and the smallest set to 1, which is the minimum number. 
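+### Worked example (illustrative): if the per-segment attribution means are
+### {segment 2: 0.07, segment 0: 0.01, segment 1: -0.03}, then every pixel of
+### segment 2 is set to 3, segment 0 to 2, and segment 1 to 1.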
+### attr - original attribution +### segm - image segmentations +def rankedAttributionsBySegm(attr, segm): + aveSegmentations, sortedDict = averageSegmentsOut(attr[0,0], segm) + totalSegm = len(sortedDict.keys()) # total segmentations + sortedKeys = [k for k, v in sorted(sortedDict.items(), key=lambda item: item[1])] + sortedKeys = sortedKeys[::-1] ### A list that should contain largest to smallest score + currentRank = totalSegm + rankedSegmImg = torch.clone(attr) + for totalSegToHide in range(0, len(sortedKeys)): + currentSegmentToHide = sortedKeys[totalSegToHide] + rankedSegmImg[0,0][segm == currentSegmentToHide] = currentRank + currentRank -= 1 + return rankedSegmImg + +### Returns the mean for each segmentation having shape as the same as the input +### This function can only one attribution image at a time +def averageSegmentsOut(attr, segments): + averagedInput = torch.clone(attr) + sortedDict = {} + for x in np.unique(segments): + segmentMean = torch.mean(attr[segments == x][:]) + sortedDict[x] = float(segmentMean.detach().cpu().numpy()) + averagedInput[segments == x] = segmentMean + return averagedInput, sortedDict + +### Output and save segmentations only for one dataset only +def outputSegmOnly(opt): + ### targetDataset - one dataset only, SVTP-645, CUTE80-288images + targetDataset = "CUTE80" # ['IIIT5k_3000', 'SVT', 'IC03_867', 'IC13_1015', 'IC15_2077', 'SVTP', 'CUTE80'] + segmRootDir = "/home/uclpc1/Documents/STR/datasets/segmentations/224X224/{}/".format(targetDataset) + + if not os.path.exists(segmRootDir): + os.makedirs(segmRootDir) + + opt.eval = True + ### Only IIIT5k_3000 + if opt.fast_acc: + # # To easily compute the total accuracy of our paper. + eval_data_list = [targetDataset] + else: + # The evaluation datasets, dataset order is same with Table 1 in our paper. 
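+        # (Both branches reduce to the single target dataset, since segmentations are
+        # written out for one dataset at a time.)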
+ eval_data_list = [targetDataset] + + ### Taken from LIME + segmentation_fn = SegmentationAlgorithm('quickshift', kernel_size=4, + max_dist=200, ratio=0.2, + random_seed=random.randint(0, 1000)) + + for eval_data in eval_data_list: + eval_data_path = os.path.join(opt.eval_data, eval_data) + AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD, opt=opt) + eval_data, eval_data_log = hierarchical_dataset(root=eval_data_path, opt=opt) + evaluation_loader = torch.utils.data.DataLoader( + eval_data, batch_size=1, + shuffle=False, + num_workers=int(opt.workers), + collate_fn=AlignCollate_evaluation, pin_memory=True) + for i, (image_tensors, labels) in enumerate(evaluation_loader): + imgDataDict = {} + img_numpy = image_tensors.cpu().detach().numpy()[0] ### Need to set batch size to 1 only + if img_numpy.shape[0] == 1: + img_numpy = gray2rgb(img_numpy[0]) + # print("img_numpy shape: ", img_numpy.shape) # (224,224,3) + segmOutput = segmentation_fn(img_numpy) + imgDataDict['segdata'] = segmOutput + imgDataDict['label'] = labels[0] + outputPickleFile = segmRootDir + "{}.pkl".format(i) + with open(outputPickleFile, 'wb') as f: + pickle.dump(imgDataDict, f) + +def acquireSelectivityHit(origImg, attributions, segmentations, model, converter, labels, scoring): + # print("segmentations unique len: ", np.unique(segmentations)) + aveSegmentations, sortedDict = averageSegmentsOut(attributions[0,0], segmentations) + sortedKeys = [k for k, v in sorted(sortedDict.items(), key=lambda item: item[1])] + sortedKeys = sortedKeys[::-1] ### A list that should contain largest to smallest score + # print("sortedDict: ", sortedDict) # {0: -5.51e-06, 1: -1.469e-05, 2: -3.06e-05,...} + # print("aveSegmentations unique len: ", np.unique(aveSegmentations)) + # print("aveSegmentations device: ", aveSegmentations.device) # cuda:0 + # print("aveSegmentations shape: ", aveSegmentations.shape) # (224,224) + # print("aveSegmentations: ", aveSegmentations) + + n_correct = [] + confidenceList = [] # First index is one feature removed, second index two features removed, and so on... + clonedImg = torch.clone(origImg) + gt = str(labels) + for totalSegToHide in range(0, len(sortedKeys)): + ### Acquire LIME prediction result + currentSegmentToHide = sortedKeys[totalSegToHide] + clonedImg[0,0][segmentations == currentSegmentToHide] = 0.0 + pred, confScore = getPredAndConf(opt, model, scoring, clonedImg, converter, np.array([gt])) + # To evaluate 'case sensitive model' with alphanumeric and case insensitve setting. 
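+            # Here both the prediction and the ground truth are lowercased and stripped
+            # down to [0-9a-z] before comparison.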
+ if opt.sensitive and opt.data_filtering_off: + pred = pred.lower() + gt = gt.lower() + alphanumeric_case_insensitve = '0123456789abcdefghijklmnopqrstuvwxyz' + out_of_alphanumeric_case_insensitve = f'[^{alphanumeric_case_insensitve}]' + pred = re.sub(out_of_alphanumeric_case_insensitve, '', pred) + gt = re.sub(out_of_alphanumeric_case_insensitve, '', gt) + if pred == gt: + n_correct.append(1) + else: + n_correct.append(0) + confScore = confScore[0][0]*100 + confidenceList.append(confScore) + return n_correct, confidenceList + +### Once you have the selectivity_eval_results.pkl file, +def acquire_selectivity_auc(opt, pkl_filename=None): + if pkl_filename is None: + pkl_filename = "/home/goo/str/str_vit_dataexplain_lambda/metrics_sensitivity_eval_results_CUTE80.pkl" # VITSTR + accKeys = [] + + with open(pkl_filename, 'rb') as f: + selectivity_data = pickle.load(f) + + for resDictIdx, resDict in enumerate(selectivity_data): + keylistAcc = [] + keylistConf = [] + metricsKeys = resDict.keys() + for keyStr in resDict.keys(): + if "_acc" in keyStr: keylistAcc.append(keyStr) + if "_conf" in keyStr: keylistConf.append(keyStr) + # Need to check if network correctly predicted the image + for metrics_accStr in keylistAcc: + if 1 not in resDict[metrics_accStr]: print("resDictIdx") + +# Single directory STRExp explanations output demo +def sampleDemo(opt): + targetDataset = "SVTP" + demoImgDir = "demo_image/" + outputDir = "/data/goo/demo_image_output/" + + if not os.path.exists(outputDir): + os.makedirs(outputDir) + + segmentation_fn = SegmentationAlgorithm('quickshift', kernel_size=4, + max_dist=200, ratio=0.2, + random_seed=random.randint(0, 1000)) + + """ model configuration """ + if opt.Transformer: + converter = TokenLabelConverter(opt) + elif 'CTC' in opt.Prediction: + converter = CTCLabelConverter(opt.character) + else: + converter = AttnLabelConverter(opt.character) + opt.num_class = len(converter.character) + + if opt.rgb: + opt.input_channel = 3 + model_obj = Model(opt) + + print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel, + opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction, + opt.SequenceModeling, opt.Prediction) + model = torch.nn.DataParallel(model_obj).to(device) + + modelCopy = copy.deepcopy(model) + + """ evaluation """ + scoring_singlechar = STRScore(opt=opt, converter=converter, device=device, enableSingleCharAttrAve=True) + super_pixel_model_singlechar = torch.nn.Sequential( + # super_pixler, + # numpy2torch_converter, + modelCopy, + scoring_singlechar + ).to(device) + modelCopy.eval() + scoring_singlechar.eval() + super_pixel_model_singlechar.eval() + + # Single Char Attribution Averaging + # enableSingleCharAttrAve - set to True + scoring = STRScore(opt=opt, converter=converter, device=device) + super_pixel_model = torch.nn.Sequential( + # super_pixler, + # numpy2torch_converter, + model, + scoring + ).to(device) + model.eval() + scoring.eval() + super_pixel_model.eval() + + if opt.blackbg: + shapImgLs = np.zeros(shape=(1, 1, 224, 224)).astype(np.float32) + trainList = np.array(shapImgLs) + background = torch.from_numpy(trainList).to(device) + + opt.eval = True + for path, subdirs, files in os.walk(demoImgDir): + for name in files: + nameNoExt = name.split('.')[0] + labels = nameNoExt + fullfilename = os.path.join(demoImgDir, name) # Value + # fullfilename: /data/goo/strattr/attributionData/trba/CUTE80/66_featablt.pkl + pilImg = Image.open(fullfilename) + + if 
settings.MODEL=="vitstr": + pilImg = pilImg.resize((224, 224)) + + orig_img_tensors = transforms.ToTensor()(pilImg) + orig_img_tensors = torch.mean(orig_img_tensors, dim=0).unsqueeze(0).unsqueeze(0) + image_tensors = ((torch.clone(orig_img_tensors) + 1.0) / 2.0) * 255.0 + imgDataDict = {} + img_numpy = image_tensors.cpu().detach().numpy()[0] ### Need to set batch size to 1 only + if img_numpy.shape[0] == 1: + img_numpy = gray2rgb(img_numpy[0]) + # print("img_numpy shape: ", img_numpy.shape) # (32,100,3) + segmOutput = segmentation_fn(img_numpy) + # print("orig_img_tensors shape: ", orig_img_tensors.shape) # (3, 224, 224) + # print("orig_img_tensors max: ", orig_img_tensors.max()) # 0.6824 (1) + # print("orig_img_tensors min: ", orig_img_tensors.min()) # 0.0235 (0) + # sys.exit() + + results_dict = {} + aveAttr = [] + aveAttr_charContrib = [] + # segmData, labels = segAndLabels[0] + target = converter.encode([labels]) + + # labels: RONALDO + segmDataNP = segmOutput + segmTensor = torch.from_numpy(segmDataNP).unsqueeze(0).unsqueeze(0) + # print("segmTensor min: ", segmTensor.min()) # 0 starting segmentation + segmTensor = segmTensor.to(device) + # print("segmTensor shape: ", segmTensor.shape) + # img1 = np.asarray(imgPIL.convert('L')) + # sys.exit() + # img1 = img1 / 255.0 + # img1 = torch.from_numpy(img1).unsqueeze(0).unsqueeze(0).type(torch.FloatTensor).to(device) + img1 = orig_img_tensors.to(device) + img1.requires_grad = True + bgImg = torch.zeros(img1.shape).to(device) + + ### Single char averaging + charOffset = 1 + + # preds = model(img1, seqlen=converter.batch_max_length) + input = img1 + origImgNP = torch.clone(orig_img_tensors).detach().cpu().numpy()[0][0] # (1, 1, 224, 224) + origImgNP = gray2rgb(origImgNP) + + ### Local explanations only + collectedAttributions = [] + for charIdx in range(0, len(labels)): + scoring_singlechar.setSingleCharOutput(charIdx + charOffset) + gtClassNum = target[0][charIdx + charOffset] + + ### Shapley Value Sampling + svs = ShapleyValueSampling(super_pixel_model_singlechar) + # attr = svs.attribute(input, target=0, n_samples=200) ### Individual pixels, too long to calculate + attributions = svs.attribute(input, target=gtClassNum, feature_mask=segmTensor) + collectedAttributions.append(attributions) + aveAttributions = torch.mean(torch.cat(collectedAttributions,dim=0), dim=0).unsqueeze(0) + if not torch.isnan(aveAttributions).any(): + rankedAttr = rankedAttributionsBySegm(aveAttributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map', cmap='RdYlGn') + mplotfig.savefig(outputDir + '{}_shapley_l.png'.format(nameNoExt)) + mplotfig.clear() + plt.close(mplotfig) + + ### Shapley Value Sampling + svs = ShapleyValueSampling(super_pixel_model) + # attr = svs.attribute(input, target=0, n_samples=200) ### Individual pixels, too long to calculate + attributions = svs.attribute(input, target=0, feature_mask=segmTensor) + if not torch.isnan(attributions).any(): + collectedAttributions.append(attributions) + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map', cmap='RdYlGn') + mplotfig.savefig(outputDir + '{}_shapley.png'.format(nameNoExt)) + mplotfig.clear() + plt.close(mplotfig) + + ### Global + Local context + aveAttributions = 
torch.mean(torch.cat(collectedAttributions,dim=0), dim=0).unsqueeze(0) + if not torch.isnan(aveAttributions).any(): + rankedAttr = rankedAttributionsBySegm(aveAttributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map', cmap='RdYlGn') + mplotfig.savefig(outputDir + '{}_shapley_gl.png'.format(nameNoExt)) + mplotfig.clear() + plt.close(mplotfig) + + ### BASELINE Evaluations + + ### Integrated Gradients + ig = IntegratedGradients(super_pixel_model) + attributions = ig.attribute(input, target=0) + if not torch.isnan(attributions).any(): + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map', cmap='RdYlGn') + mplotfig.savefig(outputDir + '{}_intgrad.png'.format(nameNoExt)) + mplotfig.clear() + plt.close(mplotfig) + + ### Gradient SHAP using zero-background + gs = GradientShap(super_pixel_model) + # We define a distribution of baselines and draw `n_samples` from that + # distribution in order to estimate the expectations of gradients across all baselines + baseline_dist = torch.zeros((1, 1, 224, 224)) + baseline_dist = baseline_dist.to(device) + attributions = gs.attribute(input, baselines=baseline_dist, target=0) + if not torch.isnan(attributions).any(): + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map', cmap='RdYlGn') + mplotfig.savefig(outputDir + '{}_gradshap.png'.format(nameNoExt)) + mplotfig.clear() + plt.close(mplotfig) + + ### DeepLift using zero-background + dl = DeepLift(super_pixel_model) + attributions = dl.attribute(input, target=0) + if not torch.isnan(attributions).any(): + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map', cmap='RdYlGn') + mplotfig.savefig(outputDir + '{}_deeplift.png'.format(nameNoExt)) + mplotfig.clear() + plt.close(mplotfig) + + ### Saliency + saliency = Saliency(super_pixel_model) + attributions = saliency.attribute(input, target=0) ### target=class0 + if not torch.isnan(attributions).any(): + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map', cmap='RdYlGn') + mplotfig.savefig(outputDir + '{}_saliency.png'.format(nameNoExt)) + mplotfig.clear() + plt.close(mplotfig) + + ### InputXGradient + input_x_gradient = InputXGradient(super_pixel_model) + attributions = input_x_gradient.attribute(input, target=0) + if not torch.isnan(attributions).any(): + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map', cmap='RdYlGn') + mplotfig.savefig(outputDir + '{}_inpxgrad.png'.format(nameNoExt)) + mplotfig.clear() + plt.close(mplotfig) + + ### GuidedBackprop + gbp = GuidedBackprop(super_pixel_model) + 
attributions = gbp.attribute(input, target=0) + if not torch.isnan(attributions).any(): + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map', cmap='RdYlGn') + mplotfig.savefig(outputDir + '{}_guidedbp.png'.format(nameNoExt)) + mplotfig.clear() + plt.close(mplotfig) + + ### Deconvolution + deconv = Deconvolution(super_pixel_model) + attributions = deconv.attribute(input, target=0) + if not torch.isnan(attributions).any(): + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map', cmap='RdYlGn') + mplotfig.savefig(outputDir + '{}_deconv.png'.format(nameNoExt)) + mplotfig.clear() + plt.close(mplotfig) + + ### Feature ablator + ablator = FeatureAblation(super_pixel_model) + attributions = ablator.attribute(input, target=0, feature_mask=segmTensor) + if not torch.isnan(attributions).any(): + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map', cmap='RdYlGn') + mplotfig.savefig(outputDir + '{}_featablt.png'.format(nameNoExt)) + mplotfig.clear() + plt.close(mplotfig) + + ## LIME + interpretable_model = SkLearnRidge(alpha=1, fit_intercept=True) ### This is the default used by LIME + lime = Lime(super_pixel_model, interpretable_model=interpretable_model) + attributions = lime.attribute(input, target=0, feature_mask=segmTensor) + if not torch.isnan(attributions).any(): + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map', cmap='RdYlGn') + mplotfig.savefig(outputDir + '{}_lime.png'.format(nameNoExt)) + mplotfig.clear() + plt.close(mplotfig) + + ### KernelSHAP + ks = KernelShap(super_pixel_model) + attributions = ks.attribute(input, target=0, feature_mask=segmTensor) + if not torch.isnan(attributions).any(): + rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP) + rankedAttr = rankedAttr.detach().cpu().numpy()[0][0] + rankedAttr = gray2rgb(rankedAttr) + mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map', cmap='RdYlGn') + mplotfig.savefig(outputDir + '{}_kernelshap.png'.format(nameNoExt)) + mplotfig.clear() + plt.close(mplotfig) + +if __name__ == '__main__': + # deleteInf() + opt = get_args(is_train=False) + + """ vocab / character number configuration """ + if opt.sensitive: + opt.character = string.printable[:-6] # same with ASTER setting (use 94 char). 
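+        # string.printable is digits + ascii_letters + punctuation + whitespace (100 chars);
+        # slicing off the trailing 6 whitespace characters leaves the 94-character charset.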
+ + cudnn.benchmark = True + cudnn.deterministic = True + opt.num_gpu = torch.cuda.device_count() + + # combineBestDataXAI(opt) + # acquire_average_auc(opt) + # acquireSingleCharAttrAve(opt) + sampleDemo(opt) diff --git a/strhub/__init__.py b/strhub/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/strhub/data/__init__.py b/strhub/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/strhub/data/aa_overrides.py b/strhub/data/aa_overrides.py new file mode 100644 index 0000000000000000000000000000000000000000..7dcba717676180d61bee82c0404b5d3b3a63d339 --- /dev/null +++ b/strhub/data/aa_overrides.py @@ -0,0 +1,46 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Extends default ops to accept optional parameters.""" +from functools import partial + +from timm.data.auto_augment import _LEVEL_DENOM, _randomly_negate, LEVEL_TO_ARG, NAME_TO_OP, rotate + + +def rotate_expand(img, degrees, **kwargs): + """Rotate operation with expand=True to avoid cutting off the characters""" + kwargs['expand'] = True + return rotate(img, degrees, **kwargs) + + +def _level_to_arg(level, hparams, key, default): + magnitude = hparams.get(key, default) + level = (level / _LEVEL_DENOM) * magnitude + level = _randomly_negate(level) + return level, + + +def apply(): + # Overrides + NAME_TO_OP.update({ + 'Rotate': rotate_expand + }) + LEVEL_TO_ARG.update({ + 'Rotate': partial(_level_to_arg, key='rotate_deg', default=30.), + 'ShearX': partial(_level_to_arg, key='shear_x_pct', default=0.3), + 'ShearY': partial(_level_to_arg, key='shear_y_pct', default=0.3), + 'TranslateXRel': partial(_level_to_arg, key='translate_x_pct', default=0.45), + 'TranslateYRel': partial(_level_to_arg, key='translate_y_pct', default=0.45), + }) diff --git a/strhub/data/augment.py b/strhub/data/augment.py new file mode 100644 index 0000000000000000000000000000000000000000..e5c1fb5a5ee5cc307f056d9ea0c78e91514078e6 --- /dev/null +++ b/strhub/data/augment.py @@ -0,0 +1,111 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
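+
+# STR-oriented RandAugment policy: timm's rand-augment ops extended with Gaussian blur and
+# Poisson noise, with SharpnessIncreasing removed because it interferes with the blur ops.
+# Minimal usage sketch (illustrative):
+#   transform = rand_augment_transform()  # timm RandAugment callable
+#   augmented_img = transform(pil_img)    # pil_img: a PIL.Image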
+ +from functools import partial + +import imgaug.augmenters as iaa +import numpy as np +from PIL import ImageFilter, Image +from timm.data import auto_augment + +from strhub.data import aa_overrides + +aa_overrides.apply() + +_OP_CACHE = {} + + +def _get_op(key, factory): + try: + op = _OP_CACHE[key] + except KeyError: + op = factory() + _OP_CACHE[key] = op + return op + + +def _get_param(level, img, max_dim_factor, min_level=1): + max_level = max(min_level, max_dim_factor * max(img.size)) + return round(min(level, max_level)) + + +def gaussian_blur(img, radius, **__): + radius = _get_param(radius, img, 0.02) + key = 'gaussian_blur_' + str(radius) + op = _get_op(key, lambda: ImageFilter.GaussianBlur(radius)) + return img.filter(op) + + +def motion_blur(img, k, **__): + k = _get_param(k, img, 0.08, 3) | 1 # bin to odd values + key = 'motion_blur_' + str(k) + op = _get_op(key, lambda: iaa.MotionBlur(k)) + return Image.fromarray(op(image=np.asarray(img))) + + +def gaussian_noise(img, scale, **_): + scale = _get_param(scale, img, 0.25) | 1 # bin to odd values + key = 'gaussian_noise_' + str(scale) + op = _get_op(key, lambda: iaa.AdditiveGaussianNoise(scale=scale)) + return Image.fromarray(op(image=np.asarray(img))) + + +def poisson_noise(img, lam, **_): + lam = _get_param(lam, img, 0.2) | 1 # bin to odd values + key = 'poisson_noise_' + str(lam) + op = _get_op(key, lambda: iaa.AdditivePoissonNoise(lam)) + return Image.fromarray(op(image=np.asarray(img))) + + +def _level_to_arg(level, _hparams, max): + level = max * level / auto_augment._LEVEL_DENOM + return level, + + +_RAND_TRANSFORMS = auto_augment._RAND_INCREASING_TRANSFORMS.copy() +_RAND_TRANSFORMS.remove('SharpnessIncreasing') # remove, interferes with *blur ops +_RAND_TRANSFORMS.extend([ + 'GaussianBlur', + # 'MotionBlur', + # 'GaussianNoise', + 'PoissonNoise' +]) +auto_augment.LEVEL_TO_ARG.update({ + 'GaussianBlur': partial(_level_to_arg, max=4), + 'MotionBlur': partial(_level_to_arg, max=20), + 'GaussianNoise': partial(_level_to_arg, max=0.1 * 255), + 'PoissonNoise': partial(_level_to_arg, max=40) +}) +auto_augment.NAME_TO_OP.update({ + 'GaussianBlur': gaussian_blur, + 'MotionBlur': motion_blur, + 'GaussianNoise': gaussian_noise, + 'PoissonNoise': poisson_noise +}) + + +def rand_augment_transform(magnitude=5, num_layers=3): + # These are tuned for magnitude=5, which means that effective magnitudes are half of these values. + hparams = { + 'rotate_deg': 30, + 'shear_x_pct': 0.9, + 'shear_y_pct': 0.2, + 'translate_x_pct': 0.10, + 'translate_y_pct': 0.30 + } + ra_ops = auto_augment.rand_augment_ops(magnitude, hparams, transforms=_RAND_TRANSFORMS) + # Supply weights to disable replacement in random selection (i.e. avoid applying the same op twice) + choice_weights = [1. / len(ra_ops) for _ in range(len(ra_ops))] + return auto_augment.RandAugment(ra_ops, num_layers, choice_weights) diff --git a/strhub/data/dataset.py b/strhub/data/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e1da774dad124346f1fecbdab4ab7885a2cd8cdd --- /dev/null +++ b/strhub/data/dataset.py @@ -0,0 +1,137 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import glob +import io +import logging +import unicodedata +from pathlib import Path, PurePath +from typing import Callable, Optional, Union + +import lmdb +from PIL import Image +from torch.utils.data import Dataset, ConcatDataset + +from strhub.data.utils import CharsetAdapter + +log = logging.getLogger(__name__) + + +def build_tree_dataset(root: Union[PurePath, str], *args, **kwargs): + try: + kwargs.pop('root') # prevent 'root' from being passed via kwargs + except KeyError: + pass + root = Path(root).absolute() + log.info(f'dataset root:\t{root}') + datasets = [] + for mdb in glob.glob(str(root / '**/data.mdb'), recursive=True): + mdb = Path(mdb) + ds_name = str(mdb.parent.relative_to(root)) + ds_root = str(mdb.parent.absolute()) + dataset = LmdbDataset(ds_root, *args, **kwargs) + log.info(f'\tlmdb:\t{ds_name}\tnum samples: {len(dataset)}') + datasets.append(dataset) + return ConcatDataset(datasets) + + +class LmdbDataset(Dataset): + """Dataset interface to an LMDB database. + + It supports both labelled and unlabelled datasets. For unlabelled datasets, the image index itself is returned + as the label. Unicode characters are normalized by default. Case-sensitivity is inferred from the charset. + Labels are transformed according to the charset. + """ + + def __init__(self, root: str, charset: str, max_label_len: int, min_image_dim: int = 0, + remove_whitespace: bool = True, normalize_unicode: bool = True, + unlabelled: bool = False, transform: Optional[Callable] = None): + self._env = None + self.root = root + self.unlabelled = unlabelled + self.transform = transform + self.labels = [] + self.filtered_index_list = [] + self.num_samples = self._preprocess_labels(charset, remove_whitespace, normalize_unicode, + max_label_len, min_image_dim) + + def __del__(self): + if self._env is not None: + self._env.close() + self._env = None + + def _create_env(self): + return lmdb.open(self.root, max_readers=1, readonly=True, create=False, + readahead=False, meminit=False, lock=False) + + @property + def env(self): + if self._env is None: + self._env = self._create_env() + return self._env + + def _preprocess_labels(self, charset, remove_whitespace, normalize_unicode, max_label_len, min_image_dim): + charset_adapter = CharsetAdapter(charset) + with self._create_env() as env, env.begin() as txn: + num_samples = int(txn.get('num-samples'.encode())) + if self.unlabelled: + return num_samples + for index in range(num_samples): + index += 1 # lmdb starts with 1 + label_key = f'label-{index:09d}'.encode() + label = txn.get(label_key).decode() + # Normally, whitespace is removed from the labels. + if remove_whitespace: + label = ''.join(label.split()) + # Normalize unicode composites (if any) and convert to compatible ASCII characters + if normalize_unicode: + label = unicodedata.normalize('NFKD', label).encode('ascii', 'ignore').decode() + # Filter by length before removing unsupported characters. The original label might be too long. 
+ if len(label) > max_label_len: + continue + label = charset_adapter(label) + # We filter out samples which don't contain any supported characters + if not label: + continue + # Filter images that are too small. + if min_image_dim > 0: + img_key = f'image-{index:09d}'.encode() + buf = io.BytesIO(txn.get(img_key)) + w, h = Image.open(buf).size + if w < min_image_dim or h < min_image_dim: + continue + self.labels.append(label) + self.filtered_index_list.append(index) + return len(self.labels) + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + if self.unlabelled: + label = index + else: + label = self.labels[index] + index = self.filtered_index_list[index] + + img_key = f'image-{index:09d}'.encode() + with self.env.begin() as txn: + imgbuf = txn.get(img_key) + buf = io.BytesIO(imgbuf) + img = Image.open(buf).convert('RGB') + + if self.transform is not None: + img = self.transform(img) + + return img, label diff --git a/strhub/data/module.py b/strhub/data/module.py new file mode 100644 index 0000000000000000000000000000000000000000..f85c855426039e85c3c04f4c1fbc5e0c0620168b --- /dev/null +++ b/strhub/data/module.py @@ -0,0 +1,107 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
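+
+# Lightning DataModule that wires build_tree_dataset/LmdbDataset into train/val/test dataloaders.
+# Minimal usage sketch (the paths, charset and batch settings below are illustrative assumptions):
+#   dm = SceneTextDataModule(root_dir='data', train_dir='real', img_size=(32, 128),
+#                            max_label_length=25,
+#                            charset_train='0123456789abcdefghijklmnopqrstuvwxyz',
+#                            charset_test='0123456789abcdefghijklmnopqrstuvwxyz',
+#                            batch_size=256, num_workers=4, augment=True)
+#   train_loader = dm.train_dataloader()
+#   test_loaders = dm.test_dataloaders(SceneTextDataModule.TEST_BENCHMARK_SUB)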
+ +from pathlib import PurePath +from typing import Optional, Callable, Sequence, Tuple + +import pytorch_lightning as pl +from torch.utils.data import DataLoader +from torchvision import transforms as T + +from .dataset import build_tree_dataset, LmdbDataset + + +class SceneTextDataModule(pl.LightningDataModule): + TEST_BENCHMARK_SUB = ('IIIT5k', 'SVT', 'IC13_857', 'IC15_1811', 'SVTP', 'CUTE80') + TEST_BENCHMARK = ('IIIT5k', 'SVT', 'IC13_1015', 'IC15_2077', 'SVTP', 'CUTE80') + TEST_NEW = ('ArT', 'COCOv1.4', 'Uber') + TEST_ALL = tuple(set(TEST_BENCHMARK_SUB + TEST_BENCHMARK + TEST_NEW)) + + def __init__(self, root_dir: str, train_dir: str, img_size: Sequence[int], max_label_length: int, + charset_train: str, charset_test: str, batch_size: int, num_workers: int, augment: bool, + remove_whitespace: bool = True, normalize_unicode: bool = True, + min_image_dim: int = 0, rotation: int = 0, collate_fn: Optional[Callable] = None): + super().__init__() + self.root_dir = root_dir + self.train_dir = train_dir + self.img_size = tuple(img_size) + self.max_label_length = max_label_length + self.charset_train = charset_train + self.charset_test = charset_test + self.batch_size = batch_size + self.num_workers = num_workers + self.augment = augment + self.remove_whitespace = remove_whitespace + self.normalize_unicode = normalize_unicode + self.min_image_dim = min_image_dim + self.rotation = rotation + self.collate_fn = collate_fn + self._train_dataset = None + self._val_dataset = None + + @staticmethod + def get_transform(img_size: Tuple[int], augment: bool = False, rotation: int = 0): + transforms = [] + if augment: + from .augment import rand_augment_transform + transforms.append(rand_augment_transform()) + if rotation: + transforms.append(lambda img: img.rotate(rotation, expand=True)) + transforms.extend([ + # T.Resize(img_size, T.InterpolationMode.BICUBIC), + # T.ToTensor(), + T.Normalize(0.5, 0.5) + ]) + return T.Compose(transforms) + + @property + def train_dataset(self): + if self._train_dataset is None: + transform = self.get_transform(self.img_size, self.augment) + root = PurePath(self.root_dir, 'train', self.train_dir) + self._train_dataset = build_tree_dataset(root, self.charset_train, self.max_label_length, + self.min_image_dim, self.remove_whitespace, self.normalize_unicode, + transform=transform) + return self._train_dataset + + @property + def val_dataset(self): + if self._val_dataset is None: + transform = self.get_transform(self.img_size) + root = PurePath(self.root_dir, 'val') + self._val_dataset = build_tree_dataset(root, self.charset_test, self.max_label_length, + self.min_image_dim, self.remove_whitespace, self.normalize_unicode, + transform=transform) + return self._val_dataset + + def train_dataloader(self): + return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, + num_workers=self.num_workers, persistent_workers=self.num_workers > 0, + pin_memory=True, collate_fn=self.collate_fn) + + def val_dataloader(self): + return DataLoader(self.val_dataset, batch_size=self.batch_size, + num_workers=self.num_workers, persistent_workers=self.num_workers > 0, + pin_memory=True, collate_fn=self.collate_fn) + + def test_dataloaders(self, subset): + transform = self.get_transform(self.img_size, rotation=self.rotation) + root = PurePath(self.root_dir, 'test') + datasets = {s: LmdbDataset(str(root / s), self.charset_test, self.max_label_length, + self.min_image_dim, self.remove_whitespace, self.normalize_unicode, + transform=transform) for s in subset} + return {k: 
DataLoader(v, batch_size=self.batch_size, num_workers=self.num_workers, + pin_memory=True, collate_fn=self.collate_fn) + for k, v in datasets.items()} diff --git a/strhub/data/utils.py b/strhub/data/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..52a5b56fb5ccacf07285b786a33a3e8102a99028 --- /dev/null +++ b/strhub/data/utils.py @@ -0,0 +1,148 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from abc import ABC, abstractmethod +from itertools import groupby +from typing import List, Optional, Tuple + +import torch +from torch import Tensor +from torch.nn.utils.rnn import pad_sequence + + +class CharsetAdapter: + """Transforms labels according to the target charset.""" + + def __init__(self, target_charset) -> None: + super().__init__() + self.lowercase_only = target_charset == target_charset.lower() + self.uppercase_only = target_charset == target_charset.upper() + self.unsupported = f'[^{re.escape(target_charset)}]' + + def __call__(self, label): + if self.lowercase_only: + label = label.lower() + elif self.uppercase_only: + label = label.upper() + # Remove unsupported characters + label = re.sub(self.unsupported, '', label) + return label + + +class BaseTokenizer(ABC): + + def __init__(self, charset: str, specials_first: tuple = (), specials_last: tuple = ()) -> None: + self._itos = specials_first + tuple(charset) + specials_last + self._stoi = {s: i for i, s in enumerate(self._itos)} + + def __len__(self): + return len(self._itos) + + def _tok2ids(self, tokens: str) -> List[int]: + return [self._stoi[s] for s in tokens] + + def _ids2tok(self, token_ids: List[int], join: bool = True) -> str: + tokens = [self._itos[i] for i in token_ids] + return ''.join(tokens) if join else tokens + + @abstractmethod + def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor: + """Encode a batch of labels to a representation suitable for the model. + + Args: + labels: List of labels. Each can be of arbitrary length. + device: Create tensor on this device. + + Returns: + Batched tensor representation padded to the max label length. Shape: N, L + """ + raise NotImplementedError + + @abstractmethod + def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]: + """Internal method which performs the necessary filtering prior to decoding.""" + raise NotImplementedError + + def decode(self, token_dists: Tensor, raw: bool = False) -> Tuple[List[str], List[Tensor]]: + """Decode a batch of token distributions. + + Args: + token_dists: softmax probabilities over the token distribution. 
Shape: N, L, C + raw: return unprocessed labels (will return list of list of strings) + + Returns: + list of string labels (arbitrary length) and + their corresponding sequence probabilities as a list of Tensors + """ + batch_tokens = [] + batch_probs = [] + for dist in token_dists: + probs, ids = dist.max(-1) # greedy selection + if not raw: + probs, ids = self._filter(probs, ids) + tokens = self._ids2tok(ids, not raw) + batch_tokens.append(tokens) + batch_probs.append(probs) + return batch_tokens, batch_probs + + +class Tokenizer(BaseTokenizer): + BOS = '[B]' + EOS = '[E]' + PAD = '[P]' + + def __init__(self, charset: str) -> None: + specials_first = (self.EOS,) + specials_last = (self.BOS, self.PAD) + super().__init__(charset, specials_first, specials_last) + self.eos_id, self.bos_id, self.pad_id = [self._stoi[s] for s in specials_first + specials_last] + + def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor: + batch = [torch.as_tensor([self.bos_id] + self._tok2ids(y) + [self.eos_id], dtype=torch.long, device=device) + for y in labels] + return pad_sequence(batch, batch_first=True, padding_value=self.pad_id) + + def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]: + ids = ids.tolist() + try: + eos_idx = ids.index(self.eos_id) + except ValueError: + eos_idx = len(ids) # Nothing to truncate. + # Truncate after EOS + ids = ids[:eos_idx] + probs = probs[:eos_idx + 1] # but include prob. for EOS (if it exists) + return probs, ids + + +class CTCTokenizer(BaseTokenizer): + BLANK = '[B]' + + def __init__(self, charset: str) -> None: + # BLANK uses index == 0 by default + super().__init__(charset, specials_first=(self.BLANK,)) + self.blank_id = self._stoi[self.BLANK] + + def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor: + # We use a padded representation since we don't want to use CUDNN's CTC implementation + batch = [torch.as_tensor(self._tok2ids(y), dtype=torch.long, device=device) for y in labels] + return pad_sequence(batch, batch_first=True, padding_value=self.blank_id) + + def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]: + # Best path decoding: + ids = list(zip(*groupby(ids.tolist())))[0] # Remove duplicate tokens + ids = [x for x in ids if x != self.blank_id] # Remove BLANKs + # `probs` is just pass-through since all positions are considered part of the path + return probs, ids diff --git a/strhub/models/__init__.py b/strhub/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/strhub/models/abinet/LICENSE b/strhub/models/abinet/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..2f1d4adb4889b2719f13ed6edf56aed10246a516 --- /dev/null +++ b/strhub/models/abinet/LICENSE @@ -0,0 +1,25 @@ +ABINet for non-commercial purposes + +Copyright (c) 2021, USTC +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. 
+ +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/strhub/models/abinet/__init__.py b/strhub/models/abinet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..604811036fda52d8485eecfebd4ffeb7f7176042 --- /dev/null +++ b/strhub/models/abinet/__init__.py @@ -0,0 +1,13 @@ +r""" +Fang, Shancheng, Hongtao, Xie, Yuxin, Wang, Zhendong, Mao, and Yongdong, Zhang. +"Read Like Humans: Autonomous, Bidirectional and Iterative Language Modeling for Scene Text Recognition." . +In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) (pp. 7098-7107).2021. + +https://arxiv.org/abs/2103.06495 + +All source files, except `system.py`, are based on the implementation listed below, +and hence are released under the license of the original. + +Source: https://github.com/FangShancheng/ABINet +License: 2-clause BSD License (see included LICENSE file) +""" diff --git a/strhub/models/abinet/attention.py b/strhub/models/abinet/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..cc8fba0638e7444fdffe964f72d0566c1a5bb818 --- /dev/null +++ b/strhub/models/abinet/attention.py @@ -0,0 +1,100 @@ +import torch +import torch.nn as nn + +from .transformer import PositionalEncoding + + +class Attention(nn.Module): + def __init__(self, in_channels=512, max_length=25, n_feature=256): + super().__init__() + self.max_length = max_length + + self.f0_embedding = nn.Embedding(max_length, in_channels) + self.w0 = nn.Linear(max_length, n_feature) + self.wv = nn.Linear(in_channels, in_channels) + self.we = nn.Linear(in_channels, max_length) + + self.active = nn.Tanh() + self.softmax = nn.Softmax(dim=2) + + def forward(self, enc_output): + enc_output = enc_output.permute(0, 2, 3, 1).flatten(1, 2) + reading_order = torch.arange(self.max_length, dtype=torch.long, device=enc_output.device) + reading_order = reading_order.unsqueeze(0).expand(enc_output.size(0), -1) # (S,) -> (B, S) + reading_order_embed = self.f0_embedding(reading_order) # b,25,512 + + t = self.w0(reading_order_embed.permute(0, 2, 1)) # b,512,256 + t = self.active(t.permute(0, 2, 1) + self.wv(enc_output)) # b,256,512 + + attn = self.we(t) # b,256,25 + attn = self.softmax(attn.permute(0, 2, 1)) # b,25,256 + g_output = torch.bmm(attn, enc_output) # b,25,512 + return g_output, attn.view(*attn.shape[:2], 8, 32) + + +def encoder_layer(in_c, out_c, k=3, s=2, p=1): + return nn.Sequential(nn.Conv2d(in_c, out_c, k, s, p), + nn.BatchNorm2d(out_c), + nn.ReLU(True)) + + +def decoder_layer(in_c, out_c, k=3, s=1, p=1, mode='nearest', scale_factor=None, size=None): + align_corners = None if mode == 'nearest' else True + return nn.Sequential(nn.Upsample(size=size, scale_factor=scale_factor, + mode=mode, align_corners=align_corners), + nn.Conv2d(in_c, out_c, k, s, p), + 
nn.BatchNorm2d(out_c), + nn.ReLU(True)) + + +class PositionAttention(nn.Module): + def __init__(self, max_length, in_channels=512, num_channels=64, + h=8, w=32, mode='nearest', **kwargs): + super().__init__() + self.max_length = max_length + self.k_encoder = nn.Sequential( + encoder_layer(in_channels, num_channels, s=(1, 2)), + encoder_layer(num_channels, num_channels, s=(2, 2)), + encoder_layer(num_channels, num_channels, s=(2, 2)), + encoder_layer(num_channels, num_channels, s=(2, 2)) + ) + self.k_decoder = nn.Sequential( + decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode), + decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode), + decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode), + decoder_layer(num_channels, in_channels, size=(h, w), mode=mode) + ) + + self.pos_encoder = PositionalEncoding(in_channels, dropout=0., max_len=max_length) + self.project = nn.Linear(in_channels, in_channels) + + def forward(self, x): + N, E, H, W = x.size() + k, v = x, x # (N, E, H, W) + + # calculate key vector + features = [] + for i in range(0, len(self.k_encoder)): + k = self.k_encoder[i](k) + features.append(k) + for i in range(0, len(self.k_decoder) - 1): + k = self.k_decoder[i](k) + k = k + features[len(self.k_decoder) - 2 - i] + k = self.k_decoder[-1](k) + + # calculate query vector + # TODO q=f(q,k) + zeros = x.new_zeros((self.max_length, N, E)) # (T, N, E) + q = self.pos_encoder(zeros) # (T, N, E) + q = q.permute(1, 0, 2) # (N, T, E) + q = self.project(q) # (N, T, E) + + # calculate attention + attn_scores = torch.bmm(q, k.flatten(2, 3)) # (N, T, (H*W)) + attn_scores = attn_scores / (E ** 0.5) + attn_scores = torch.softmax(attn_scores, dim=-1) + + v = v.permute(0, 2, 3, 1).view(N, -1, E) # (N, (H*W), E) + attn_vecs = torch.bmm(attn_scores, v) # (N, T, E) + + return attn_vecs, attn_scores.view(N, -1, H, W) diff --git a/strhub/models/abinet/backbone.py b/strhub/models/abinet/backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..debcabd7f115db0e698a55175a01a0ff0131e10f --- /dev/null +++ b/strhub/models/abinet/backbone.py @@ -0,0 +1,24 @@ +import torch.nn as nn +from torch.nn import TransformerEncoderLayer, TransformerEncoder + +from .resnet import resnet45 +from .transformer import PositionalEncoding + + +class ResTranformer(nn.Module): + def __init__(self, d_model=512, nhead=8, d_inner=2048, dropout=0.1, activation='relu', backbone_ln=2): + super().__init__() + self.resnet = resnet45() + self.pos_encoder = PositionalEncoding(d_model, max_len=8 * 32) + encoder_layer = TransformerEncoderLayer(d_model=d_model, nhead=nhead, + dim_feedforward=d_inner, dropout=dropout, activation=activation) + self.transformer = TransformerEncoder(encoder_layer, backbone_ln) + + def forward(self, images): + feature = self.resnet(images) + n, c, h, w = feature.shape + feature = feature.view(n, c, -1).permute(2, 0, 1) + feature = self.pos_encoder(feature) + feature = self.transformer(feature) + feature = feature.permute(1, 2, 0).view(n, c, h, w) + return feature diff --git a/strhub/models/abinet/model.py b/strhub/models/abinet/model.py new file mode 100644 index 0000000000000000000000000000000000000000..cc0cd143d324822c57b897b6e5749024d857fd30 --- /dev/null +++ b/strhub/models/abinet/model.py @@ -0,0 +1,31 @@ +import torch +import torch.nn as nn + + +class Model(nn.Module): + + def __init__(self, dataset_max_length: int, null_label: int): + super().__init__() + self.max_length = dataset_max_length + 1 # additional stop token + self.null_label = 
null_label + + def _get_length(self, logit, dim=-1): + """ Greed decoder to obtain length from logit""" + out = (logit.argmax(dim=-1) == self.null_label) + abn = out.any(dim) + out = ((out.cumsum(dim) == 1) & out).max(dim)[1] + out = out + 1 # additional end token + out = torch.where(abn, out, out.new_tensor(logit.shape[1], device=out.device)) + return out + + @staticmethod + def _get_padding_mask(length, max_length): + length = length.unsqueeze(-1) + grid = torch.arange(0, max_length, device=length.device).unsqueeze(0) + return grid >= length + + @staticmethod + def _get_location_mask(sz, device=None): + mask = torch.eye(sz, device=device) + mask = mask.float().masked_fill(mask == 1, float('-inf')) + return mask diff --git a/strhub/models/abinet/model_abinet_iter.py b/strhub/models/abinet/model_abinet_iter.py new file mode 100644 index 0000000000000000000000000000000000000000..1a8523ff6431f991037d56dc8dd72ae67c7bf242 --- /dev/null +++ b/strhub/models/abinet/model_abinet_iter.py @@ -0,0 +1,39 @@ +import torch +from torch import nn + +from .model_alignment import BaseAlignment +from .model_language import BCNLanguage +from .model_vision import BaseVision + + +class ABINetIterModel(nn.Module): + def __init__(self, dataset_max_length, null_label, num_classes, iter_size=1, + d_model=512, nhead=8, d_inner=2048, dropout=0.1, activation='relu', + v_loss_weight=1., v_attention='position', v_attention_mode='nearest', + v_backbone='transformer', v_num_layers=2, + l_loss_weight=1., l_num_layers=4, l_detach=True, l_use_self_attn=False, + a_loss_weight=1.): + super().__init__() + self.iter_size = iter_size + self.vision = BaseVision(dataset_max_length, null_label, num_classes, v_attention, v_attention_mode, + v_loss_weight, d_model, nhead, d_inner, dropout, activation, v_backbone, v_num_layers) + self.language = BCNLanguage(dataset_max_length, null_label, num_classes, d_model, nhead, d_inner, dropout, + activation, l_num_layers, l_detach, l_use_self_attn, l_loss_weight) + self.alignment = BaseAlignment(dataset_max_length, null_label, num_classes, d_model, a_loss_weight) + + def forward(self, images): + v_res = self.vision(images) + a_res = v_res + all_l_res, all_a_res = [], [] + for _ in range(self.iter_size): + tokens = torch.softmax(a_res['logits'], dim=-1) + lengths = a_res['pt_lengths'] + lengths.clamp_(2, self.language.max_length) # TODO:move to langauge model + l_res = self.language(tokens, lengths) + all_l_res.append(l_res) + a_res = self.alignment(l_res['feature'], v_res['feature']) + all_a_res.append(a_res) + if self.training: + return all_a_res, all_l_res, v_res + else: + return a_res, all_l_res[-1], v_res diff --git a/strhub/models/abinet/model_alignment.py b/strhub/models/abinet/model_alignment.py new file mode 100644 index 0000000000000000000000000000000000000000..9ccfa95e65dbd7176c8bcee693bb0bcb8ad13c69 --- /dev/null +++ b/strhub/models/abinet/model_alignment.py @@ -0,0 +1,28 @@ +import torch +import torch.nn as nn + +from .model import Model + + +class BaseAlignment(Model): + def __init__(self, dataset_max_length, null_label, num_classes, d_model=512, loss_weight=1.0): + super().__init__(dataset_max_length, null_label) + self.loss_weight = loss_weight + self.w_att = nn.Linear(2 * d_model, d_model) + self.cls = nn.Linear(d_model, num_classes) + + def forward(self, l_feature, v_feature): + """ + Args: + l_feature: (N, T, E) where T is length, N is batch size and d is dim of model + v_feature: (N, T, E) shape the same as l_feature + """ + f = torch.cat((l_feature, v_feature), dim=2) + f_att 
= torch.sigmoid(self.w_att(f)) + output = f_att * v_feature + (1 - f_att) * l_feature + + logits = self.cls(output) # (N, T, C) + pt_lengths = self._get_length(logits) + + return {'logits': logits, 'pt_lengths': pt_lengths, 'loss_weight': self.loss_weight, + 'name': 'alignment'} diff --git a/strhub/models/abinet/model_language.py b/strhub/models/abinet/model_language.py new file mode 100644 index 0000000000000000000000000000000000000000..659d446578915bd1cab945554be749b4f1b0dff3 --- /dev/null +++ b/strhub/models/abinet/model_language.py @@ -0,0 +1,50 @@ +import torch.nn as nn +from torch.nn import TransformerDecoder + +from .model import Model +from .transformer import PositionalEncoding, TransformerDecoderLayer + + +class BCNLanguage(Model): + def __init__(self, dataset_max_length, null_label, num_classes, d_model=512, nhead=8, d_inner=2048, dropout=0.1, + activation='relu', num_layers=4, detach=True, use_self_attn=False, loss_weight=1.0, + global_debug=False): + super().__init__(dataset_max_length, null_label) + self.detach = detach + self.loss_weight = loss_weight + self.proj = nn.Linear(num_classes, d_model, False) + self.token_encoder = PositionalEncoding(d_model, max_len=self.max_length) + self.pos_encoder = PositionalEncoding(d_model, dropout=0, max_len=self.max_length) + decoder_layer = TransformerDecoderLayer(d_model, nhead, d_inner, dropout, + activation, self_attn=use_self_attn, debug=global_debug) + self.model = TransformerDecoder(decoder_layer, num_layers) + self.cls = nn.Linear(d_model, num_classes) + + def forward(self, tokens, lengths): + """ + Args: + tokens: (N, T, C) where T is length, N is batch size and C is classes number + lengths: (N,) + """ + if self.detach: + tokens = tokens.detach() + embed = self.proj(tokens) # (N, T, E) + embed = embed.permute(1, 0, 2) # (T, N, E) + embed = self.token_encoder(embed) # (T, N, E) + padding_mask = self._get_padding_mask(lengths, self.max_length) + + zeros = embed.new_zeros(*embed.shape) + query = self.pos_encoder(zeros) + location_mask = self._get_location_mask(self.max_length, tokens.device) + output = self.model(query, embed, + tgt_key_padding_mask=padding_mask, + memory_mask=location_mask, + memory_key_padding_mask=padding_mask) # (T, N, E) + output = output.permute(1, 0, 2) # (N, T, E) + + logits = self.cls(output) # (N, T, C) + pt_lengths = self._get_length(logits) + + res = {'feature': output, 'logits': logits, 'pt_lengths': pt_lengths, + 'loss_weight': self.loss_weight, 'name': 'language'} + return res diff --git a/strhub/models/abinet/model_vision.py b/strhub/models/abinet/model_vision.py new file mode 100644 index 0000000000000000000000000000000000000000..bddb7d5f237854b81c388090e2e20fc26632c431 --- /dev/null +++ b/strhub/models/abinet/model_vision.py @@ -0,0 +1,45 @@ +from torch import nn + +from .attention import PositionAttention, Attention +from .backbone import ResTranformer +from .model import Model +from .resnet import resnet45 + + +class BaseVision(Model): + def __init__(self, dataset_max_length, null_label, num_classes, + attention='position', attention_mode='nearest', loss_weight=1.0, + d_model=512, nhead=8, d_inner=2048, dropout=0.1, activation='relu', + backbone='transformer', backbone_ln=2): + super().__init__(dataset_max_length, null_label) + self.loss_weight = loss_weight + self.out_channels = d_model + + if backbone == 'transformer': + self.backbone = ResTranformer(d_model, nhead, d_inner, dropout, activation, backbone_ln) + else: + self.backbone = resnet45() + + if attention == 'position': + self.attention
= PositionAttention( + max_length=self.max_length, + mode=attention_mode + ) + elif attention == 'attention': + self.attention = Attention( + max_length=self.max_length, + n_feature=8 * 32, + ) + else: + raise ValueError(f'invalid attention: {attention}') + + self.cls = nn.Linear(self.out_channels, num_classes) + + def forward(self, images): + features = self.backbone(images) # (N, E, H, W) + attn_vecs, attn_scores = self.attention(features) # (N, T, E), (N, T, H, W) + logits = self.cls(attn_vecs) # (N, T, C) + pt_lengths = self._get_length(logits) + + return {'feature': attn_vecs, 'logits': logits, 'pt_lengths': pt_lengths, + 'attn_scores': attn_scores, 'loss_weight': self.loss_weight, 'name': 'vision'} diff --git a/strhub/models/abinet/resnet.py b/strhub/models/abinet/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..59bf38896987b3560e254e8037426d29bcdd5844 --- /dev/null +++ b/strhub/models/abinet/resnet.py @@ -0,0 +1,72 @@ +import math +from typing import Optional, Callable + +import torch.nn as nn +from torchvision.models import resnet + + +class BasicBlock(resnet.BasicBlock): + + def __init__(self, inplanes: int, planes: int, stride: int = 1, downsample: Optional[nn.Module] = None, + groups: int = 1, base_width: int = 64, dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None) -> None: + super().__init__(inplanes, planes, stride, downsample, groups, base_width, dilation, norm_layer) + self.conv1 = resnet.conv1x1(inplanes, planes) + self.conv2 = resnet.conv3x3(planes, planes, stride) + + +class ResNet(nn.Module): + + def __init__(self, block, layers): + super().__init__() + self.inplanes = 32 + self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, + bias=False) + self.bn1 = nn.BatchNorm2d(32) + self.relu = nn.ReLU(inplace=True) + + self.layer1 = self._make_layer(block, 32, layers[0], stride=2) + self.layer2 = self._make_layer(block, 64, layers[1], stride=1) + self.layer3 = self._make_layer(block, 128, layers[2], stride=2) + self.layer4 = self._make_layer(block, 256, layers[3], stride=1) + self.layer5 = self._make_layer(block, 512, layers[4], stride=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. 
/ n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.layer5(x) + return x + + +def resnet45(): + return ResNet(BasicBlock, [3, 4, 6, 6, 3]) diff --git a/strhub/models/abinet/system.py b/strhub/models/abinet/system.py new file mode 100644 index 0000000000000000000000000000000000000000..fc8f2cd8d274f45007f97301a555d884398a2e97 --- /dev/null +++ b/strhub/models/abinet/system.py @@ -0,0 +1,172 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
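A quick shape check for the resnet45() backbone defined above may be useful; this sketch is not part of the diff, and the 32x128 input size is only illustrative (ABINet's usual image size).

import torch
from strhub.models.abinet.resnet import resnet45

backbone = resnet45()
feats = backbone(torch.randn(1, 3, 32, 128))
# Only layer1 and layer3 downsample (stride 2), so a 3x32x128 image yields a 512-channel 8x32 feature map.
print(feats.shape)  # torch.Size([1, 512, 8, 32])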
+ +import logging +import math +from typing import Any, Tuple, List, Optional + +import torch +import torch.nn.functional as F +from torch import Tensor, nn +from torch.optim import AdamW +from torch.optim.lr_scheduler import OneCycleLR + +from pytorch_lightning.utilities.types import STEP_OUTPUT +from timm.optim.optim_factory import param_groups_weight_decay + +from strhub.models.base import CrossEntropySystem +from strhub.models.utils import init_weights +from .model_abinet_iter import ABINetIterModel as Model + +log = logging.getLogger(__name__) + + +class ABINet(CrossEntropySystem): + + def __init__(self, charset_train: str, charset_test: str, max_label_length: int, + batch_size: int, lr: float, warmup_pct: float, weight_decay: float, + iter_size: int, d_model: int, nhead: int, d_inner: int, dropout: float, activation: str, + v_loss_weight: float, v_attention: str, v_attention_mode: str, v_backbone: str, v_num_layers: int, + l_loss_weight: float, l_num_layers: int, l_detach: bool, l_use_self_attn: bool, + l_lr: float, a_loss_weight: float, lm_only: bool = False, **kwargs) -> None: + super().__init__(charset_train, charset_test, batch_size, lr, warmup_pct, weight_decay) + self.scheduler = None + self.save_hyperparameters() + self.max_label_length = max_label_length + self.num_classes = len(self.tokenizer) - 2 # We don't predict nor + self.model = Model(max_label_length, self.eos_id, self.num_classes, iter_size, d_model, nhead, d_inner, + dropout, activation, v_loss_weight, v_attention, v_attention_mode, v_backbone, v_num_layers, + l_loss_weight, l_num_layers, l_detach, l_use_self_attn, a_loss_weight) + self.model.apply(init_weights) + # FIXME: doesn't support resumption from checkpoint yet + self._reset_alignment = True + self._reset_optimizers = True + self.l_lr = l_lr + self.lm_only = lm_only + # Train LM only. Freeze other submodels. + if lm_only: + self.l_lr = lr # for tuning + self.model.vision.requires_grad_(False) + self.model.alignment.requires_grad_(False) + + @property + def _pretraining(self): + # In the original work, VM was pretrained for 8 epochs while full model was trained for an additional 10 epochs. + total_steps = self.trainer.estimated_stepping_batches * self.trainer.accumulate_grad_batches + return self.global_step < (8 / (8 + 10)) * total_steps + + @torch.jit.ignore + def no_weight_decay(self): + return {'model.language.proj.weight'} + + def _add_weight_decay(self, model: nn.Module, skip_list=()): + if self.weight_decay: + return param_groups_weight_decay(model, self.weight_decay, skip_list) + else: + return [{'params': model.parameters()}] + + def configure_optimizers(self): + agb = self.trainer.accumulate_grad_batches + # Linear scaling so that the effective learning rate is constant regardless of the number of GPUs used with DDP. + lr_scale = agb * math.sqrt(self.trainer.num_devices) * self.batch_size / 256. + lr = lr_scale * self.lr + l_lr = lr_scale * self.l_lr + params = [] + params.extend(self._add_weight_decay(self.model.vision)) + params.extend(self._add_weight_decay(self.model.alignment)) + # We use a different learning rate for the LM. 
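As an aside, the linear LR scaling used in configure_optimizers above can be worked through with illustrative numbers (these are assumptions, not values taken from any config in this diff):

import math

agb, num_devices, batch_size = 1, 2, 384  # hypothetical settings
lr_scale = agb * math.sqrt(num_devices) * batch_size / 256.
print(round(lr_scale, 2))  # ~2.12 -> VM/alignment groups train at ~2.12*lr, the LM groups at ~2.12*l_lr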
+ for p in self._add_weight_decay(self.model.language, ('proj.weight',)): + p['lr'] = l_lr + params.append(p) + max_lr = [p.get('lr', lr) for p in params] + optim = AdamW(params, lr) + self.scheduler = OneCycleLR(optim, max_lr, self.trainer.estimated_stepping_batches, + pct_start=self.warmup_pct, cycle_momentum=False) + return {'optimizer': optim, 'lr_scheduler': {'scheduler': self.scheduler, 'interval': 'step'}} + + def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor: + max_length = self.max_label_length if max_length is None else min(max_length, self.max_label_length) + logits = self.model.forward(images)[0]['logits'] + return logits[:, :max_length + 1] # truncate + + def calc_loss(self, targets, *res_lists) -> Tensor: + total_loss = 0 + for res_list in res_lists: + loss = 0 + if isinstance(res_list, dict): + res_list = [res_list] + for res in res_list: + logits = res['logits'].flatten(end_dim=1) + loss += F.cross_entropy(logits, targets.flatten(), ignore_index=self.pad_id) + loss /= len(res_list) + self.log('loss_' + res_list[0]['name'], loss) + total_loss += res_list[0]['loss_weight'] * loss + return total_loss + + def on_train_batch_start(self, batch: Any, batch_idx: int) -> None: + if not self._pretraining and self._reset_optimizers: + log.info('Pretraining ends. Updating base LRs.') + self._reset_optimizers = False + # Make base_lr the same for all groups + base_lr = self.scheduler.base_lrs[0] # base_lr of group 0 - VM + self.scheduler.base_lrs = [base_lr] * len(self.scheduler.base_lrs) + + def _prepare_inputs_and_targets(self, labels): + # Use dummy label to ensure sequence length is constant. + dummy = ['0' * self.max_label_length] + targets = self.tokenizer.encode(dummy + list(labels), self.device)[1:] + targets = targets[:, 1:] # remove . Unused here. + # Inputs are padded with eos_id + inputs = torch.where(targets == self.pad_id, self.eos_id, targets) + inputs = F.one_hot(inputs, self.num_classes).float() + lengths = torch.as_tensor(list(map(len, labels)), device=self.device) + 1 # +1 for eos + return inputs, lengths, targets + + def training_step(self, batch, batch_idx) -> STEP_OUTPUT: + images, labels = batch + inputs, lengths, targets = self._prepare_inputs_and_targets(labels) + if self.lm_only: + l_res = self.model.language(inputs, lengths) + loss = self.calc_loss(targets, l_res) + # Pretrain submodels independently first + elif self._pretraining: + # Vision + v_res = self.model.vision(images) + # Language + l_res = self.model.language(inputs, lengths) + # We also train the alignment model to 'satisfy' DDP requirements (all parameters should be used). + # We'll reset its parameters prior to joint training. + a_res = self.model.alignment(l_res['feature'].detach(), v_res['feature'].detach()) + loss = self.calc_loss(targets, v_res, l_res, a_res) + else: + # Reset alignment model's parameters once prior to full model training. + if self._reset_alignment: + log.info('Pretraining ends. 
Resetting alignment model.') + self._reset_alignment = False + self.model.alignment.apply(init_weights) + all_a_res, all_l_res, v_res = self.model.forward(images) + loss = self.calc_loss(targets, v_res, all_l_res, all_a_res) + self.log('loss', loss) + return loss + + def forward_logits_loss(self, images: Tensor, labels: List[str]) -> Tuple[Tensor, Tensor, int]: + if self.lm_only: + inputs, lengths, targets = self._prepare_inputs_and_targets(labels) + l_res = self.model.language(inputs, lengths) + loss = self.calc_loss(targets, l_res) + loss_numel = (targets != self.pad_id).sum() + return l_res['logits'], loss, loss_numel + else: + return super().forward_logits_loss(images, labels) diff --git a/strhub/models/abinet/transformer.py b/strhub/models/abinet/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..a920805d67eea5671675f5623a47d06ad8af894b --- /dev/null +++ b/strhub/models/abinet/transformer.py @@ -0,0 +1,143 @@ +import math + +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn.modules.transformer import _get_activation_fn + + +class TransformerDecoderLayer(nn.Module): + r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network. + This standard decoder layer is based on the paper "Attention Is All You Need". + Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, + Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in + Neural Information Processing Systems, pages 6000-6010. Users may modify or implement + in a different way during application. + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + activation: the activation function of intermediate layer, relu or gelu (default=relu). 
+ + Examples:: + >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) + >>> memory = torch.rand(10, 32, 512) + >>> tgt = torch.rand(20, 32, 512) + >>> out = decoder_layer(tgt, memory) + """ + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu", self_attn=True, siamese=False, debug=False): + super().__init__() + self.has_self_attn, self.siamese = self_attn, siamese + self.debug = debug + if self.has_self_attn: + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + self.norm1 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + if self.siamese: + self.multihead_attn2 = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + + self.activation = _get_activation_fn(activation) + + def __setstate__(self, state): + if 'activation' not in state: + state['activation'] = F.relu + super().__setstate__(state) + + def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, + tgt_key_padding_mask=None, memory_key_padding_mask=None, + memory2=None, memory_mask2=None, memory_key_padding_mask2=None): + # type: (Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tensor + r"""Pass the inputs (and mask) through the decoder layer. + + Args: + tgt: the sequence to the decoder layer (required). + memory: the sequence from the last layer of the encoder (required). + tgt_mask: the mask for the tgt sequence (optional). + memory_mask: the mask for the memory sequence (optional). + tgt_key_padding_mask: the mask for the tgt keys per batch (optional). + memory_key_padding_mask: the mask for the memory keys per batch (optional). + + Shape: + see the docs in Transformer class. + """ + if self.has_self_attn: + tgt2, attn = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask) + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + if self.debug: self.attn = attn + tgt2, attn2 = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask) + if self.debug: self.attn2 = attn2 + + if self.siamese: + tgt3, attn3 = self.multihead_attn2(tgt, memory2, memory2, attn_mask=memory_mask2, + key_padding_mask=memory_key_padding_mask2) + tgt = tgt + self.dropout2(tgt3) + if self.debug: self.attn3 = attn3 + + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + + return tgt + + +class PositionalEncoding(nn.Module): + r"""Inject some information about the relative or absolute position of the tokens + in the sequence. The positional encodings have the same dimension as + the embeddings, so that the two can be summed. Here, we use sine and cosine + functions of different frequencies. + .. math:: + \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model)) + \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model)) + \text{where pos is the word position and i is the embed idx) + Args: + d_model: the embed dim (required). + dropout: the dropout value (default=0.1). + max_len: the max. 
length of the incoming sequence (default=5000). + Examples: + >>> pos_encoder = PositionalEncoding(d_model) + """ + + def __init__(self, d_model, dropout=0.1, max_len=5000): + super().__init__() + self.dropout = nn.Dropout(p=dropout) + + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0).transpose(0, 1) + self.register_buffer('pe', pe) + + def forward(self, x): + r"""Inputs of forward function + Args: + x: the sequence fed to the positional encoder model (required). + Shape: + x: [sequence length, batch size, embed dim] + output: [sequence length, batch size, embed dim] + Examples: + >>> output = pos_encoder(x) + """ + + x = x + self.pe[:x.size(0), :] + return self.dropout(x) diff --git a/strhub/models/base.py b/strhub/models/base.py new file mode 100644 index 0000000000000000000000000000000000000000..a9fafe450f615f027503522ff5c81443fe29aa4a --- /dev/null +++ b/strhub/models/base.py @@ -0,0 +1,202 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Optional, Tuple, List + +import pytorch_lightning as pl +import torch +import torch.nn.functional as F +from nltk import edit_distance +from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT +from timm.optim import create_optimizer_v2 +from torch import Tensor +from torch.optim import Optimizer +from torch.optim.lr_scheduler import OneCycleLR + +from strhub.data.utils import CharsetAdapter, CTCTokenizer, Tokenizer, BaseTokenizer + + +@dataclass +class BatchResult: + num_samples: int + correct: int + ned: float + confidence: float + label_length: int + loss: Tensor + loss_numel: int + + +class BaseSystem(pl.LightningModule, ABC): + + def __init__(self, tokenizer: BaseTokenizer, charset_test: str, + batch_size: int, lr: float, warmup_pct: float, weight_decay: float) -> None: + super().__init__() + self.tokenizer = tokenizer + self.charset_adapter = CharsetAdapter(charset_test) + self.batch_size = batch_size + self.lr = lr + self.warmup_pct = warmup_pct + self.weight_decay = weight_decay + + @abstractmethod + def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor: + """Inference + + Args: + images: Batch of images. Shape: N, Ch, H, W + max_length: Max sequence length of the output. If None, will use default. + + Returns: + logits: N, L, C (L = sequence length, C = number of classes, typically len(charset_train) + num specials) + """ + raise NotImplementedError + + @abstractmethod + def forward_logits_loss(self, images: Tensor, labels: List[str]) -> Tuple[Tensor, Tensor, int]: + """Like forward(), but also computes the loss (calls forward() internally). 
+ + Args: + images: Batch of images. Shape: N, Ch, H, W + labels: Text labels of the images + + Returns: + logits: N, L, C (L = sequence length, C = number of classes, typically len(charset_train) + num specials) + loss: mean loss for the batch + loss_numel: number of elements the loss was calculated from + """ + raise NotImplementedError + + def configure_optimizers(self): + agb = self.trainer.accumulate_grad_batches + # Linear scaling so that the effective learning rate is constant regardless of the number of GPUs used with DDP. + lr_scale = agb * math.sqrt(self.trainer.num_devices) * self.batch_size / 256. + lr = lr_scale * self.lr + optim = create_optimizer_v2(self, 'adamw', lr, self.weight_decay) + sched = OneCycleLR(optim, lr, self.trainer.estimated_stepping_batches, pct_start=self.warmup_pct, + cycle_momentum=False) + return {'optimizer': optim, 'lr_scheduler': {'scheduler': sched, 'interval': 'step'}} + + def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int): + optimizer.zero_grad(set_to_none=True) + + def _eval_step(self, batch, validation: bool) -> Optional[STEP_OUTPUT]: + images, labels = batch + + correct = 0 + total = 0 + ned = 0 + confidence = 0 + label_length = 0 + if validation: + logits, loss, loss_numel = self.forward_logits_loss(images, labels) + else: + # At test-time, we shouldn't specify a max_label_length because the test-time charset used + # might be different from the train-time charset. max_label_length in eval_logits_loss() is computed + # based on the transformed label, which could be wrong if the actual gt label contains characters existing + # in the train-time charset but not in the test-time charset. For example, "aishahaleyes.blogspot.com" + # is exactly 25 characters, but if processed by CharsetAdapter for the 36-char set, it becomes 23 characters + # long only, which sets max_label_length = 23. This will cause the model prediction to be truncated. + logits = self.forward(images) + loss = loss_numel = None # Only used for validation; not needed at test-time. + + probs = logits.softmax(-1) + preds, probs = self.tokenizer.decode(probs) + for pred, prob, gt in zip(preds, probs, labels): + confidence += prob.prod().item() + pred = self.charset_adapter(pred) + # Follow ICDAR 2019 definition of N.E.D. + ned += edit_distance(pred, gt) / max(len(pred), len(gt)) + if pred == gt: + correct += 1 + total += 1 + label_length += len(pred) + return dict(output=BatchResult(total, correct, ned, confidence, label_length, loss, loss_numel)) + + @staticmethod + def _aggregate_results(outputs: EPOCH_OUTPUT) -> Tuple[float, float, float]: + if not outputs: + return 0., 0., 0. 
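For reference, the per-sample N.E.D. term accumulated in _eval_step above behaves like this (the strings are made up, not from any dataset):

from nltk import edit_distance

pred, gt = "world", "word"
ned_term = edit_distance(pred, gt) / max(len(pred), len(gt))  # 1 / 5 = 0.2
# The aggregation continuing below reports 1 - mean(ned_term) as the NED metric, and weights each
# batch loss by loss_numel so the epoch loss is a proper per-token average.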
+ total_loss = 0 + total_loss_numel = 0 + total_n_correct = 0 + total_norm_ED = 0 + total_size = 0 + for result in outputs: + result = result['output'] + total_loss += result.loss_numel * result.loss + total_loss_numel += result.loss_numel + total_n_correct += result.correct + total_norm_ED += result.ned + total_size += result.num_samples + acc = total_n_correct / total_size + ned = (1 - total_norm_ED / total_size) + loss = total_loss / total_loss_numel + return acc, ned, loss + + def validation_step(self, batch, batch_idx) -> Optional[STEP_OUTPUT]: + return self._eval_step(batch, True) + + def validation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: + acc, ned, loss = self._aggregate_results(outputs) + self.log('val_accuracy', 100 * acc, sync_dist=True) + self.log('val_NED', 100 * ned, sync_dist=True) + self.log('val_loss', loss, sync_dist=True) + self.log('hp_metric', acc, sync_dist=True) + + def test_step(self, batch, batch_idx) -> Optional[STEP_OUTPUT]: + return self._eval_step(batch, False) + + +class CrossEntropySystem(BaseSystem): + + def __init__(self, charset_train: str, charset_test: str, + batch_size: int, lr: float, warmup_pct: float, weight_decay: float) -> None: + tokenizer = Tokenizer(charset_train) + super().__init__(tokenizer, charset_test, batch_size, lr, warmup_pct, weight_decay) + self.bos_id = tokenizer.bos_id + self.eos_id = tokenizer.eos_id + self.pad_id = tokenizer.pad_id + + def forward_logits_loss(self, images: Tensor, labels: List[str]) -> Tuple[Tensor, Tensor, int]: + targets = self.tokenizer.encode(labels, self.device) + targets = targets[:, 1:] # Discard <bos> + max_len = targets.shape[1] - 1 # exclude <eos> from count + logits = self.forward(images, max_len) + loss = F.cross_entropy(logits.flatten(end_dim=1), targets.flatten(), ignore_index=self.pad_id) + loss_numel = (targets != self.pad_id).sum() + return logits, loss, loss_numel + + +class CTCSystem(BaseSystem): + + def __init__(self, charset_train: str, charset_test: str, + batch_size: int, lr: float, warmup_pct: float, weight_decay: float) -> None: + tokenizer = CTCTokenizer(charset_train) + super().__init__(tokenizer, charset_test, batch_size, lr, warmup_pct, weight_decay) + self.blank_id = tokenizer.blank_id + + def forward_logits_loss(self, images: Tensor, labels: List[str]) -> Tuple[Tensor, Tensor, int]: + targets = self.tokenizer.encode(labels, self.device) + logits = self.forward(images) + log_probs = logits.log_softmax(-1).transpose(0, 1) # swap batch and seq.
dims + T, N, _ = log_probs.shape + input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long, device=self.device) + target_lengths = torch.as_tensor(list(map(len, labels)), dtype=torch.long, device=self.device) + loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=self.blank_id, zero_infinity=True) + return logits, loss, N diff --git a/strhub/models/crnn/LICENSE b/strhub/models/crnn/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..f98687be392fdce266708e79885aadaa4991b67f --- /dev/null +++ b/strhub/models/crnn/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2017 Jieru Mei + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/strhub/models/crnn/__init__.py b/strhub/models/crnn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a4535947d9233c8fb0a85e9c22b151697d37f410 --- /dev/null +++ b/strhub/models/crnn/__init__.py @@ -0,0 +1,13 @@ +r""" +Shi, Baoguang, Xiang Bai, and Cong Yao. +"An end-to-end trainable neural network for image-based sequence recognition and its application to scene text recognition." +IEEE transactions on pattern analysis and machine intelligence 39, no. 11 (2016): 2298-2304. + +https://arxiv.org/abs/1507.05717 + +All source files, except `system.py`, are based on the implementation listed below, +and hence are released under the license of the original. 
+ +Source: https://github.com/meijieru/crnn.pytorch +License: MIT License (see included LICENSE file) +""" diff --git a/strhub/models/crnn/model.py b/strhub/models/crnn/model.py new file mode 100644 index 0000000000000000000000000000000000000000..1a71845fba242c3c63c15a79cf43134a35807453 --- /dev/null +++ b/strhub/models/crnn/model.py @@ -0,0 +1,62 @@ +import torch.nn as nn + +from strhub.models.modules import BidirectionalLSTM + + +class CRNN(nn.Module): + + def __init__(self, img_h, nc, nclass, nh, leaky_relu=False): + super().__init__() + assert img_h % 16 == 0, 'img_h has to be a multiple of 16' + + ks = [3, 3, 3, 3, 3, 3, 2] + ps = [1, 1, 1, 1, 1, 1, 0] + ss = [1, 1, 1, 1, 1, 1, 1] + nm = [64, 128, 256, 256, 512, 512, 512] + + cnn = nn.Sequential() + + def convRelu(i, batchNormalization=False): + nIn = nc if i == 0 else nm[i - 1] + nOut = nm[i] + cnn.add_module('conv{0}'.format(i), + nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i], bias=not batchNormalization)) + if batchNormalization: + cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut)) + if leaky_relu: + cnn.add_module('relu{0}'.format(i), + nn.LeakyReLU(0.2, inplace=True)) + else: + cnn.add_module('relu{0}'.format(i), nn.ReLU(True)) + + convRelu(0) + cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2)) # 64x16x64 + convRelu(1) + cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2)) # 128x8x32 + convRelu(2, True) + convRelu(3) + cnn.add_module('pooling{0}'.format(2), + nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 256x4x16 + convRelu(4, True) + convRelu(5) + cnn.add_module('pooling{0}'.format(3), + nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 512x2x16 + convRelu(6, True) # 512x1x16 + + self.cnn = cnn + self.rnn = nn.Sequential( + BidirectionalLSTM(512, nh, nh), + BidirectionalLSTM(nh, nh, nclass)) + + def forward(self, input): + # conv features + conv = self.cnn(input) + b, c, h, w = conv.size() + assert h == 1, 'the height of conv must be 1' + conv = conv.squeeze(2) + conv = conv.transpose(1, 2) # [b, w, c] + + # rnn features + output = self.rnn(conv) + + return output diff --git a/strhub/models/crnn/system.py b/strhub/models/crnn/system.py new file mode 100644 index 0000000000000000000000000000000000000000..abcb28e5c29f4ca484b87e283d1e4615e56378f2 --- /dev/null +++ b/strhub/models/crnn/system.py @@ -0,0 +1,43 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
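A minimal usage sketch for the CRNN module defined above (not part of the diff; the charset size and batch values are illustrative):

import torch
from strhub.models.crnn.model import CRNN

model = CRNN(img_h=32, nc=3, nclass=37, nh=256)  # e.g. 36 charset symbols + 1 CTC blank
logits = model(torch.randn(2, 3, 32, 128))
# Height collapses to 1 and width to roughly W/4 + 1 time steps, so this gives (2, 33, 37) for CTC decoding.
print(logits.shape)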
+ +from typing import Sequence, Optional + +from pytorch_lightning.utilities.types import STEP_OUTPUT +from torch import Tensor + +from strhub.models.base import CTCSystem +from strhub.models.utils import init_weights +from .model import CRNN as Model + + +class CRNN(CTCSystem): + + def __init__(self, charset_train: str, charset_test: str, max_label_length: int, + batch_size: int, lr: float, warmup_pct: float, weight_decay: float, + img_size: Sequence[int], hidden_size: int, leaky_relu: bool, **kwargs) -> None: + super().__init__(charset_train, charset_test, batch_size, lr, warmup_pct, weight_decay) + self.save_hyperparameters() + self.model = Model(img_size[0], 3, len(self.tokenizer), hidden_size, leaky_relu) + self.model.apply(init_weights) + + def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor: + return self.model.forward(images) + + def training_step(self, batch, batch_idx) -> STEP_OUTPUT: + images, labels = batch + loss = self.forward_logits_loss(images, labels)[1] + self.log('loss', loss) + return loss diff --git a/strhub/models/modules.py b/strhub/models/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..a89d05f6afd67437f3cfa8aff6d2d8b12df3fafa --- /dev/null +++ b/strhub/models/modules.py @@ -0,0 +1,20 @@ +r"""Shared modules used by CRNN and TRBA""" +from torch import nn + + +class BidirectionalLSTM(nn.Module): + """Ref: https://github.com/clovaai/deep-text-recognition-benchmark/blob/master/modules/sequence_modeling.py""" + + def __init__(self, input_size, hidden_size, output_size): + super().__init__() + self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True) + self.linear = nn.Linear(hidden_size * 2, output_size) + + def forward(self, input): + """ + input : visual feature [batch_size x T x input_size], T = num_steps. + output : contextual feature [batch_size x T x output_size] + """ + recurrent, _ = self.rnn(input) # batch_size x T x input_size -> batch_size x T x (2*hidden_size) + output = self.linear(recurrent) # batch_size x T x output_size + return output diff --git a/strhub/models/parseq/__init__.py b/strhub/models/parseq/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/strhub/models/parseq/modules.py b/strhub/models/parseq/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..a7fdbd60bc5f05978448d4d00b73cb917bfea796 --- /dev/null +++ b/strhub/models/parseq/modules.py @@ -0,0 +1,126 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
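A shape sketch for the shared BidirectionalLSTM above (illustrative sizes, not values from the diff):

import torch
from strhub.models.modules import BidirectionalLSTM

rnn = BidirectionalLSTM(input_size=512, hidden_size=256, output_size=256)
out = rnn(torch.randn(4, 33, 512))  # (B, T, input_size)
# The bidirectional LSTM emits 2*hidden_size per step; the linear layer maps that back to output_size.
print(out.shape)                    # torch.Size([4, 33, 256])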
+ +import math +from typing import Optional + +import torch +from torch import nn as nn, Tensor +from torch.nn import functional as F +from torch.nn.modules import transformer + +from timm.models.vision_transformer import VisionTransformer, PatchEmbed + + +class DecoderLayer(nn.Module): + """A Transformer decoder layer supporting two-stream attention (XLNet) + This implements a pre-LN decoder, as opposed to the post-LN default in PyTorch.""" + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation='gelu', + layer_norm_eps=1e-5): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True) + self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.norm_q = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.norm_c = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = transformer._get_activation_fn(activation) + + def __setstate__(self, state): + if 'activation' not in state: + state['activation'] = F.gelu + super().__setstate__(state) + + def forward_stream(self, tgt: Tensor, tgt_norm: Tensor, tgt_kv: Tensor, memory: Tensor, tgt_mask: Optional[Tensor], + tgt_key_padding_mask: Optional[Tensor]): + """Forward pass for a single stream (i.e. content or query) + tgt_norm is just a LayerNorm'd tgt. Added as a separate parameter for efficiency. + Both tgt_kv and memory are expected to be LayerNorm'd too. + memory is LayerNorm'd by ViT. 
+ """ + tgt2, sa_weights = self.self_attn(tgt_norm, tgt_kv, tgt_kv, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask) + tgt = tgt + self.dropout1(tgt2) + + tgt2, ca_weights = self.cross_attn(self.norm1(tgt), memory, memory) + tgt = tgt + self.dropout2(tgt2) + + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(self.norm2(tgt))))) + tgt = tgt + self.dropout3(tgt2) + return tgt, sa_weights, ca_weights + + def forward(self, query, content, memory, query_mask: Optional[Tensor] = None, content_mask: Optional[Tensor] = None, + content_key_padding_mask: Optional[Tensor] = None, update_content: bool = True): + query_norm = self.norm_q(query) + content_norm = self.norm_c(content) + query = self.forward_stream(query, query_norm, content_norm, memory, query_mask, content_key_padding_mask)[0] + if update_content: + content = self.forward_stream(content, content_norm, content_norm, memory, content_mask, + content_key_padding_mask)[0] + return query, content + + +class Decoder(nn.Module): + __constants__ = ['norm'] + + def __init__(self, decoder_layer, num_layers, norm): + super().__init__() + self.layers = transformer._get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward(self, query, content, memory, query_mask: Optional[Tensor] = None, content_mask: Optional[Tensor] = None, + content_key_padding_mask: Optional[Tensor] = None): + for i, mod in enumerate(self.layers): + last = i == len(self.layers) - 1 + query, content = mod(query, content, memory, query_mask, content_mask, content_key_padding_mask, + update_content=not last) + query = self.norm(query) + return query + + +class Encoder(VisionTransformer): + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., + qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed): + super().__init__(img_size, patch_size, in_chans, embed_dim=embed_dim, depth=depth, num_heads=num_heads, + mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, embed_layer=embed_layer, + num_classes=0, global_pool='', class_token=False) # these disable the classifier head + + def forward(self, x): + # Return all tokens + return self.forward_features(x) + + +class TokenEmbedding(nn.Module): + + def __init__(self, charset_size: int, embed_dim: int): + super().__init__() + self.embedding = nn.Embedding(charset_size, embed_dim) + self.embed_dim = embed_dim + + def forward(self, tokens: torch.Tensor): + return math.sqrt(self.embed_dim) * self.embedding(tokens) diff --git a/strhub/models/parseq/system.py b/strhub/models/parseq/system.py new file mode 100644 index 0000000000000000000000000000000000000000..a7079d544bcd5dbbfb83c9d94491ad3bbfdfa3ea --- /dev/null +++ b/strhub/models/parseq/system.py @@ -0,0 +1,270 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math +from functools import partial +from itertools import permutations +from typing import Sequence, Any, Optional + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from pytorch_lightning.utilities.types import STEP_OUTPUT +from timm.models.helpers import named_apply + +from strhub.models.base import CrossEntropySystem +from strhub.models.utils import init_weights +from .modules import DecoderLayer, Decoder, Encoder, TokenEmbedding +import sys + + +class PARSeq(CrossEntropySystem): + + def __init__(self, charset_train: str, charset_test: str, max_label_length: int, + batch_size: int, lr: float, warmup_pct: float, weight_decay: float, + img_size: Sequence[int], patch_size: Sequence[int], embed_dim: int, + enc_num_heads: int, enc_mlp_ratio: int, enc_depth: int, + dec_num_heads: int, dec_mlp_ratio: int, dec_depth: int, + perm_num: int, perm_forward: bool, perm_mirrored: bool, + decode_ar: bool, refine_iters: int, dropout: float, **kwargs: Any) -> None: + super().__init__(charset_train, charset_test, batch_size, lr, warmup_pct, weight_decay) + self.save_hyperparameters() + + self.max_label_length = max_label_length + self.decode_ar = decode_ar + self.refine_iters = refine_iters + + self.encoder = Encoder(img_size, patch_size, embed_dim=embed_dim, depth=enc_depth, num_heads=enc_num_heads, + mlp_ratio=enc_mlp_ratio) + decoder_layer = DecoderLayer(embed_dim, dec_num_heads, embed_dim * dec_mlp_ratio, dropout) + self.decoder = Decoder(decoder_layer, num_layers=dec_depth, norm=nn.LayerNorm(embed_dim)) + + # Perm/attn mask stuff + self.rng = np.random.default_rng() + self.max_gen_perms = perm_num // 2 if perm_mirrored else perm_num + self.perm_forward = perm_forward + self.perm_mirrored = perm_mirrored + + # We don't predict <bos> nor <pad> + self.head = nn.Linear(embed_dim, len(self.tokenizer) - 2) + self.text_embed = TokenEmbedding(len(self.tokenizer), embed_dim) + + # +1 for <eos> + self.pos_queries = nn.Parameter(torch.Tensor(1, max_label_length + 1, embed_dim)) + self.dropout = nn.Dropout(p=dropout) + # Encoder has its own init. + named_apply(partial(init_weights, exclude=['encoder']), self) + nn.init.trunc_normal_(self.pos_queries, std=.02) + + @torch.jit.ignore + def no_weight_decay(self): + param_names = {'text_embed.embedding.weight', 'pos_queries'} + enc_param_names = {'encoder.' + n for n in self.encoder.no_weight_decay()} + return param_names.union(enc_param_names) + + def encode(self, img: torch.Tensor): + return self.encoder(img) + + def decode(self, tgt: torch.Tensor, memory: torch.Tensor, tgt_mask: Optional[Tensor] = None, + tgt_padding_mask: Optional[Tensor] = None, tgt_query: Optional[Tensor] = None, + tgt_query_mask: Optional[Tensor] = None): + N, L = tgt.shape + # <bos> stands for the null context. We only supply position information for characters after <bos>. + null_ctx = self.text_embed(tgt[:, :1]) + tgt_emb = self.pos_queries[:, :L - 1] + self.text_embed(tgt[:, 1:]) + tgt_emb = self.dropout(torch.cat([null_ctx, tgt_emb], dim=1)) + if tgt_query is None: + tgt_query = self.pos_queries[:, :L].expand(N, -1, -1) + tgt_query = self.dropout(tgt_query) + return self.decoder(tgt_query, tgt_emb, memory, tgt_query_mask, tgt_mask, tgt_padding_mask) + + def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor: + testing = max_length is None + max_length = self.max_label_length if max_length is None else min(max_length, self.max_label_length) + bs = images.shape[0] + # +1 for <eos> at end of sequence.
+ num_steps = max_length + 1 + # images shape: torch.Size([1, 3, 32, 128]) + memory = self.encode(images) + # memory shape: torch.Size([1, 128, 384]) + + # Query positions up to `num_steps` + pos_queries = self.pos_queries[:, :num_steps].expand(bs, -1, -1) + # pos_queries shape: torch.Size([1, 26, 384]) + # self.decode_ar: True + # sys.exit() + + # Special case for the forward permutation. Faster than using `generate_attn_masks()` + tgt_mask = query_mask = torch.triu(torch.full((num_steps, num_steps), float('-inf'), device=self._device), 1) + # num_steps: 26 + if self.decode_ar: + tgt_in = torch.full((bs, num_steps), self.pad_id, dtype=torch.long, device=self._device) + tgt_in[:, 0] = self.bos_id + + logits = [] + for i in range(num_steps): + j = i + 1 # next token index + # Efficient decoding: + # Input the context up to the ith token. We use only one query (at position = i) at a time. + # This works because of the lookahead masking effect of the canonical (forward) AR context. + # Past tokens have no access to future tokens, hence are fixed once computed. + tgt_out = self.decode(tgt_in[:, :j], memory, tgt_mask[:j, :j], tgt_query=pos_queries[:, i:j], + tgt_query_mask=query_mask[i:j, :j]) + # tgt_out shape: torch.Size([1, 1, 384]) + + # the next token probability is in the output's ith token position + p_i = self.head(tgt_out) + logits.append(p_i) + if j < num_steps: + # greedy decode. add the next token index to the target input + tgt_in[:, j] = p_i.squeeze().argmax(-1) + # Efficient batch decoding: If all output words have at least one EOS token, end decoding. + if testing and (tgt_in == self.eos_id).any(dim=-1).all(): + break + + logits = torch.cat(logits, dim=1) + else: + # No prior context, so input is just . We query all positions. + tgt_in = torch.full((bs, 1), self.bos_id, dtype=torch.long, device=self._device) + tgt_out = self.decode(tgt_in, memory, tgt_query=pos_queries) + logits = self.head(tgt_out) + # print("logits before: ", logits.shape) torch.Size([1, 6, 95]) + # self.refine_iters: 1 + if self.refine_iters: + # For iterative refinement, we always use a 'cloze' mask. + # We can derive it from the AR forward mask by unmasking the token context to the right. + query_mask[torch.triu(torch.ones(num_steps, num_steps, dtype=torch.bool, device=self._device), 2)] = 0 + bos = torch.full((bs, 1), self.bos_id, dtype=torch.long, device=self._device) + for i in range(self.refine_iters): + # Prior context is the previous output. + tgt_in = torch.cat([bos, logits[:, :-1].argmax(-1)], dim=1) + # print("tgt_in 2: ", tgt_in.shape) torch.Size([1, 6]) + tgt_padding_mask = ((tgt_in == self.eos_id).cumsum(-1) > 0) # mask tokens beyond the first EOS token. + tgt_out = self.decode(tgt_in, memory, tgt_mask, tgt_padding_mask, + tgt_query=pos_queries, tgt_query_mask=query_mask[:, :tgt_in.shape[1]]) + # print("tgt_out 2: ", tgt_out.shape) torch.Size([1, 26, 384]) + logits = self.head(tgt_out) + + return logits + + def gen_tgt_perms(self, tgt): + """Generate shared permutations for the whole batch. + This works because the same attention mask can be used for the shorter sequences + because of the padding mask. 
+ """ + # We don't permute the position of BOS, we permute EOS separately + max_num_chars = tgt.shape[1] - 2 + # Special handling for 1-character sequences + if max_num_chars == 1: + return torch.arange(3, device=self._device).unsqueeze(0) + perms = [torch.arange(max_num_chars, device=self._device)] if self.perm_forward else [] + # Additional permutations if needed + max_perms = math.factorial(max_num_chars) + if self.perm_mirrored: + max_perms //= 2 + num_gen_perms = min(self.max_gen_perms, max_perms) + # For 4-char sequences and shorter, we generate all permutations and sample from the pool to avoid collisions + # Note that this code path might NEVER get executed since the labels in a mini-batch typically exceed 4 chars. + if max_num_chars < 5: + # Pool of permutations to sample from. We only need the first half (if complementary option is selected) + # Special handling for max_num_chars == 4 which correctly divides the pool into the flipped halves + if max_num_chars == 4 and self.perm_mirrored: + selector = [0, 3, 4, 6, 9, 10, 12, 16, 17, 18, 19, 21] + else: + selector = list(range(max_perms)) + perm_pool = torch.as_tensor(list(permutations(range(max_num_chars), max_num_chars)), device=self._device)[selector] + # If the forward permutation is always selected, no need to add it to the pool for sampling + if self.perm_forward: + perm_pool = perm_pool[1:] + perms = torch.stack(perms) + if len(perm_pool): + i = self.rng.choice(len(perm_pool), size=num_gen_perms - len(perms), replace=False) + perms = torch.cat([perms, perm_pool[i]]) + else: + perms.extend([torch.randperm(max_num_chars, device=self._device) for _ in range(num_gen_perms - len(perms))]) + perms = torch.stack(perms) + if self.perm_mirrored: + # Add complementary pairs + comp = perms.flip(-1) + # Stack in such a way that the pairs are next to each other. + perms = torch.stack([perms, comp]).transpose(0, 1).reshape(-1, max_num_chars) + # NOTE: + # The only meaningful way of permuting the EOS position is by moving it one character position at a time. + # However, since the number of permutations = T! and number of EOS positions = T + 1, the number of possible EOS + # positions will always be much less than the number of permutations (unless a low perm_num is set). + # Thus, it would be simpler to just train EOS using the full and null contexts rather than trying to evenly + # distribute it across the chosen number of permutations. + # Add position indices of BOS and EOS + bos_idx = perms.new_zeros((len(perms), 1)) + eos_idx = perms.new_full((len(perms), 1), max_num_chars + 1) + perms = torch.cat([bos_idx, perms + 1, eos_idx], dim=1) + # Special handling for the reverse direction. This does two things: + # 1. Reverse context for the characters + # 2. Null context for [EOS] (required for learning to predict [EOS] in NAR mode) + if len(perms) > 1: + perms[1, 1:] = max_num_chars + 1 - torch.arange(max_num_chars + 1, device=self._device) + return perms + + def generate_attn_masks(self, perm): + """Generate attention masks given a sequence permutation (includes pos. for bos and eos tokens) + :param perm: the permutation sequence. 
i = 0 is always the BOS + :return: lookahead attention masks + """ + sz = perm.shape[0] + mask = torch.zeros((sz, sz), device=self._device) + for i in range(sz): + query_idx = perm[i] + masked_keys = perm[i + 1:] + mask[query_idx, masked_keys] = float('-inf') + content_mask = mask[:-1, :-1].clone() + mask[torch.eye(sz, dtype=torch.bool, device=self._device)] = float('-inf') # mask "self" + query_mask = mask[1:, :-1] + return content_mask, query_mask + + def training_step(self, batch, batch_idx) -> STEP_OUTPUT: + images, labels = batch + tgt = self.tokenizer.encode(labels, self._device) + + # Encode the source sequence (i.e. the image codes) + memory = self.encode(images) + + # Prepare the target sequences (input and output) + tgt_perms = self.gen_tgt_perms(tgt) + tgt_in = tgt[:, :-1] + tgt_out = tgt[:, 1:] + # The [EOS] token is not depended upon by any other token in any permutation ordering + tgt_padding_mask = (tgt_in == self.pad_id) | (tgt_in == self.eos_id) + + loss = 0 + loss_numel = 0 + n = (tgt_out != self.pad_id).sum().item() + for i, perm in enumerate(tgt_perms): + tgt_mask, query_mask = self.generate_attn_masks(perm) + out = self.decode(tgt_in, memory, tgt_mask, tgt_padding_mask, tgt_query_mask=query_mask) + logits = self.head(out).flatten(end_dim=1) + loss += n * F.cross_entropy(logits, tgt_out.flatten(), ignore_index=self.pad_id) + loss_numel += n + # After the second iteration (i.e. done with canonical and reverse orderings), + # remove the [EOS] tokens for the succeeding perms + if i == 1: + tgt_out = torch.where(tgt_out == self.eos_id, self.pad_id, tgt_out) + n = (tgt_out != self.pad_id).sum().item() + loss /= loss_numel + + self.log('loss', loss) + return loss diff --git a/strhub/models/trba/__init__.py b/strhub/models/trba/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a574a8af95e7f1ffaa05c45b4cd22f4a3cc0a5c0 --- /dev/null +++ b/strhub/models/trba/__init__.py @@ -0,0 +1,13 @@ +r""" +Baek, Jeonghun, Geewook Kim, Junyeop Lee, Sungrae Park, Dongyoon Han, Sangdoo Yun, Seong Joon Oh, and Hwalsuk Lee. +"What is wrong with scene text recognition model comparisons? dataset and model analysis." +In Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 4715-4723. 2019. + +https://arxiv.org/abs/1904.01906 + +All source files, except `system.py`, are based on the implementation listed below, +and hence are released under the license of the original. 
+ +Source: https://github.com/clovaai/deep-text-recognition-benchmark +License: Apache License 2.0 (see LICENSE file in project root) +""" diff --git a/strhub/models/trba/feature_extraction.py b/strhub/models/trba/feature_extraction.py new file mode 100644 index 0000000000000000000000000000000000000000..17646e3ff83ad28c1021237824a838e38c3b6345 --- /dev/null +++ b/strhub/models/trba/feature_extraction.py @@ -0,0 +1,110 @@ +import torch.nn as nn + +from torchvision.models.resnet import BasicBlock + + +class ResNet_FeatureExtractor(nn.Module): + """ FeatureExtractor of FAN (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf) """ + + def __init__(self, input_channel, output_channel=512): + super().__init__() + self.ConvNet = ResNet(input_channel, output_channel, BasicBlock, [1, 2, 5, 3]) + + def forward(self, input): + return self.ConvNet(input) + + +class ResNet(nn.Module): + + def __init__(self, input_channel, output_channel, block, layers): + super().__init__() + + self.output_channel_block = [int(output_channel / 4), int(output_channel / 2), output_channel, output_channel] + + self.inplanes = int(output_channel / 8) + self.conv0_1 = nn.Conv2d(input_channel, int(output_channel / 16), + kernel_size=3, stride=1, padding=1, bias=False) + self.bn0_1 = nn.BatchNorm2d(int(output_channel / 16)) + self.conv0_2 = nn.Conv2d(int(output_channel / 16), self.inplanes, + kernel_size=3, stride=1, padding=1, bias=False) + self.bn0_2 = nn.BatchNorm2d(self.inplanes) + self.relu = nn.ReLU(inplace=True) + + self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) + self.layer1 = self._make_layer(block, self.output_channel_block[0], layers[0]) + self.conv1 = nn.Conv2d(self.output_channel_block[0], self.output_channel_block[ + 0], kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(self.output_channel_block[0]) + + self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) + self.layer2 = self._make_layer(block, self.output_channel_block[1], layers[1], stride=1) + self.conv2 = nn.Conv2d(self.output_channel_block[1], self.output_channel_block[ + 1], kernel_size=3, stride=1, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(self.output_channel_block[1]) + + self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=(2, 1), padding=(0, 1)) + self.layer3 = self._make_layer(block, self.output_channel_block[2], layers[2], stride=1) + self.conv3 = nn.Conv2d(self.output_channel_block[2], self.output_channel_block[ + 2], kernel_size=3, stride=1, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(self.output_channel_block[2]) + + self.layer4 = self._make_layer(block, self.output_channel_block[3], layers[3], stride=1) + self.conv4_1 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[ + 3], kernel_size=2, stride=(2, 1), padding=(0, 1), bias=False) + self.bn4_1 = nn.BatchNorm2d(self.output_channel_block[3]) + self.conv4_2 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[ + 3], kernel_size=2, stride=1, padding=0, bias=False) + self.bn4_2 = nn.BatchNorm2d(self.output_channel_block[3]) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * 
block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv0_1(x) + x = self.bn0_1(x) + x = self.relu(x) + x = self.conv0_2(x) + x = self.bn0_2(x) + x = self.relu(x) + + x = self.maxpool1(x) + x = self.layer1(x) + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + + x = self.maxpool2(x) + x = self.layer2(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + + x = self.maxpool3(x) + x = self.layer3(x) + x = self.conv3(x) + x = self.bn3(x) + x = self.relu(x) + + x = self.layer4(x) + x = self.conv4_1(x) + x = self.bn4_1(x) + x = self.relu(x) + x = self.conv4_2(x) + x = self.bn4_2(x) + x = self.relu(x) + + return x diff --git a/strhub/models/trba/model.py b/strhub/models/trba/model.py new file mode 100644 index 0000000000000000000000000000000000000000..41161a4df4e2ff368bfe1c62f681c6964510a0c0 --- /dev/null +++ b/strhub/models/trba/model.py @@ -0,0 +1,55 @@ +import torch.nn as nn + +from strhub.models.modules import BidirectionalLSTM +from .feature_extraction import ResNet_FeatureExtractor +from .prediction import Attention +from .transformation import TPS_SpatialTransformerNetwork + + +class TRBA(nn.Module): + + def __init__(self, img_h, img_w, num_class, num_fiducial=20, input_channel=3, output_channel=512, hidden_size=256, + use_ctc=False): + super().__init__() + """ Transformation """ + self.Transformation = TPS_SpatialTransformerNetwork( + F=num_fiducial, I_size=(img_h, img_w), I_r_size=(img_h, img_w), + I_channel_num=input_channel) + + """ FeatureExtraction """ + self.FeatureExtraction = ResNet_FeatureExtractor(input_channel, output_channel) + self.FeatureExtraction_output = output_channel + self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1)) # Transform final (imgH/16-1) -> 1 + + """ Sequence modeling""" + self.SequenceModeling = nn.Sequential( + BidirectionalLSTM(self.FeatureExtraction_output, hidden_size, hidden_size), + BidirectionalLSTM(hidden_size, hidden_size, hidden_size)) + self.SequenceModeling_output = hidden_size + + """ Prediction """ + if use_ctc: + self.Prediction = nn.Linear(self.SequenceModeling_output, num_class) + else: + self.Prediction = Attention(self.SequenceModeling_output, hidden_size, num_class) + + def forward(self, image, max_label_length, text=None): + """ Transformation stage """ + image = self.Transformation(image) + + """ Feature extraction stage """ + visual_feature = self.FeatureExtraction(image) + visual_feature = visual_feature.permute(0, 3, 1, 2) # [b, c, h, w] -> [b, w, c, h] + visual_feature = self.AdaptiveAvgPool(visual_feature) # [b, w, c, h] -> [b, w, c, 1] + visual_feature = visual_feature.squeeze(3) # [b, w, c, 1] -> [b, w, c] + + """ Sequence modeling stage """ + contextual_feature = self.SequenceModeling(visual_feature) # [b, num_steps, hidden_size] + + """ Prediction stage """ + if isinstance(self.Prediction, Attention): + prediction = self.Prediction(contextual_feature.contiguous(), text, max_label_length) + else: + prediction = self.Prediction(contextual_feature.contiguous()) # CTC + + return prediction # [b, num_steps, num_class] diff --git a/strhub/models/trba/prediction.py b/strhub/models/trba/prediction.py new file mode 100644 index 0000000000000000000000000000000000000000..5609398a28ef5288d3f3971786c2cebc2e574336 --- /dev/null +++ b/strhub/models/trba/prediction.py @@ -0,0 +1,73 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Attention(nn.Module): + + def __init__(self, 
input_size, hidden_size, num_class, num_char_embeddings=256): + super().__init__() + self.attention_cell = AttentionCell(input_size, hidden_size, num_char_embeddings) + self.hidden_size = hidden_size + self.num_class = num_class + self.generator = nn.Linear(hidden_size, num_class) + self.char_embeddings = nn.Embedding(num_class, num_char_embeddings) + + def forward(self, batch_H, text, max_label_length=25): + """ + input: + batch_H : contextual_feature H = hidden state of encoder. [batch_size x num_steps x num_class] + text : the text-index of each image. [batch_size x (max_length+1)]. +1 for [SOS] token. text[:, 0] = [SOS]. + output: probability distribution at each step [batch_size x num_steps x num_class] + """ + batch_size = batch_H.size(0) + num_steps = max_label_length + 1 # +1 for [EOS] at end of sentence. + + output_hiddens = batch_H.new_zeros((batch_size, num_steps, self.hidden_size), dtype=torch.float) + hidden = (batch_H.new_zeros((batch_size, self.hidden_size), dtype=torch.float), + batch_H.new_zeros((batch_size, self.hidden_size), dtype=torch.float)) + + if self.training: + for i in range(num_steps): + char_embeddings = self.char_embeddings(text[:, i]) + # hidden : decoder's hidden s_{t-1}, batch_H : encoder's hidden H, char_embeddings : f(y_{t-1}) + hidden, alpha = self.attention_cell(hidden, batch_H, char_embeddings) + output_hiddens[:, i, :] = hidden[0] # LSTM hidden index (0: hidden, 1: Cell) + probs = self.generator(output_hiddens) + + else: + targets = text[0].expand(batch_size) # should be fill with [SOS] token + probs = batch_H.new_zeros((batch_size, num_steps, self.num_class), dtype=torch.float) + + for i in range(num_steps): + char_embeddings = self.char_embeddings(targets) + hidden, alpha = self.attention_cell(hidden, batch_H, char_embeddings) + probs_step = self.generator(hidden[0]) + probs[:, i, :] = probs_step + _, next_input = probs_step.max(1) + targets = next_input + + return probs # batch_size x num_steps x num_class + + +class AttentionCell(nn.Module): + + def __init__(self, input_size, hidden_size, num_embeddings): + super().__init__() + self.i2h = nn.Linear(input_size, hidden_size, bias=False) + self.h2h = nn.Linear(hidden_size, hidden_size) # either i2i or h2h should have bias + self.score = nn.Linear(hidden_size, 1, bias=False) + self.rnn = nn.LSTMCell(input_size + num_embeddings, hidden_size) + self.hidden_size = hidden_size + + def forward(self, prev_hidden, batch_H, char_embeddings): + # [batch_size x num_encoder_step x num_channel] -> [batch_size x num_encoder_step x hidden_size] + batch_H_proj = self.i2h(batch_H) + prev_hidden_proj = self.h2h(prev_hidden[0]).unsqueeze(1) + e = self.score(torch.tanh(batch_H_proj + prev_hidden_proj)) # batch_size x num_encoder_step * 1 + + alpha = F.softmax(e, dim=1) + context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1) # batch_size x num_channel + concat_context = torch.cat([context, char_embeddings], 1) # batch_size x (num_channel + num_embedding) + cur_hidden = self.rnn(concat_context, prev_hidden) + return cur_hidden, alpha diff --git a/strhub/models/trba/system.py b/strhub/models/trba/system.py new file mode 100644 index 0000000000000000000000000000000000000000..31bbb6d44e6eabf47402ae998ffbe4de8fc427a2 --- /dev/null +++ b/strhub/models/trba/system.py @@ -0,0 +1,87 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial +from typing import Sequence, Any, Optional + +import torch +import torch.nn.functional as F +from pytorch_lightning.utilities.types import STEP_OUTPUT +from timm.models.helpers import named_apply +from torch import Tensor + +from strhub.models.base import CrossEntropySystem, CTCSystem +from strhub.models.utils import init_weights +from .model import TRBA as Model + + +class TRBA(CrossEntropySystem): + + def __init__(self, charset_train: str, charset_test: str, max_label_length: int, + batch_size: int, lr: float, warmup_pct: float, weight_decay: float, + img_size: Sequence[int], num_fiducial: int, output_channel: int, hidden_size: int, + **kwargs: Any) -> None: + super().__init__(charset_train, charset_test, batch_size, lr, warmup_pct, weight_decay) + self.save_hyperparameters() + self.max_label_length = max_label_length + img_h, img_w = img_size + self.model = Model(img_h, img_w, len(self.tokenizer), num_fiducial, + output_channel=output_channel, hidden_size=hidden_size, use_ctc=False) + named_apply(partial(init_weights, exclude=['Transformation.LocalizationNetwork.localization_fc2']), self.model) + + @torch.jit.ignore + def no_weight_decay(self): + return {'model.Prediction.char_embeddings.weight'} + + def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor: + max_length = self.max_label_length if max_length is None else min(max_length, self.max_label_length) + text = images.new_full([1], self.bos_id, dtype=torch.long) + return self.model.forward(images, max_length, text) + + def training_step(self, batch, batch_idx) -> STEP_OUTPUT: + images, labels = batch + encoded = self.tokenizer.encode(labels, self.device) + inputs = encoded[:, :-1] # remove + targets = encoded[:, 1:] # remove + max_length = encoded.shape[1] - 2 # exclude and from count + logits = self.model.forward(images, max_length, inputs) + loss = F.cross_entropy(logits.flatten(end_dim=1), targets.flatten(), ignore_index=self.pad_id) + self.log('loss', loss) + return loss + + +class TRBC(CTCSystem): + + def __init__(self, charset_train: str, charset_test: str, max_label_length: int, + batch_size: int, lr: float, warmup_pct: float, weight_decay: float, + img_size: Sequence[int], num_fiducial: int, output_channel: int, hidden_size: int, + **kwargs: Any) -> None: + super().__init__(charset_train, charset_test, batch_size, lr, warmup_pct, weight_decay) + self.save_hyperparameters() + self.max_label_length = max_label_length + img_h, img_w = img_size + self.model = Model(img_h, img_w, len(self.tokenizer), num_fiducial, + output_channel=output_channel, hidden_size=hidden_size, use_ctc=True) + named_apply(partial(init_weights, exclude=['Transformation.LocalizationNetwork.localization_fc2']), self.model) + + def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor: + # max_label_length is unused in CTC prediction + return self.model.forward(images, None) + + def training_step(self, batch, batch_idx) -> STEP_OUTPUT: + images, labels = batch + loss = self.forward_logits_loss(images, labels)[1] + self.log('loss', loss) + return loss diff --git 
a/strhub/models/trba/transformation.py b/strhub/models/trba/transformation.py new file mode 100644 index 0000000000000000000000000000000000000000..960419d135ec878aaaa3297c3ff5c22e998ef6be --- /dev/null +++ b/strhub/models/trba/transformation.py @@ -0,0 +1,169 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class TPS_SpatialTransformerNetwork(nn.Module): + """ Rectification Network of RARE, namely TPS based STN """ + + def __init__(self, F, I_size, I_r_size, I_channel_num=1): + """ Based on RARE TPS + input: + batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width] + I_size : (height, width) of the input image I + I_r_size : (height, width) of the rectified image I_r + I_channel_num : the number of channels of the input image I + output: + batch_I_r: rectified image [batch_size x I_channel_num x I_r_height x I_r_width] + """ + super().__init__() + self.F = F + self.I_size = I_size + self.I_r_size = I_r_size # = (I_r_height, I_r_width) + self.I_channel_num = I_channel_num + self.LocalizationNetwork = LocalizationNetwork(self.F, self.I_channel_num) + self.GridGenerator = GridGenerator(self.F, self.I_r_size) + + def forward(self, batch_I): + batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2 + # batch_size x n (= I_r_width x I_r_height) x 2 + build_P_prime = self.GridGenerator.build_P_prime(batch_C_prime) + build_P_prime_reshape = build_P_prime.reshape([build_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2]) + + if torch.__version__ > "1.2.0": + batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border', align_corners=True) + else: + batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border') + + return batch_I_r + + +class LocalizationNetwork(nn.Module): + """ Localization Network of RARE, which predicts C' (K x 2) from I (I_width x I_height) """ + + def __init__(self, F, I_channel_num): + super().__init__() + self.F = F + self.I_channel_num = I_channel_num + self.conv = nn.Sequential( + nn.Conv2d(in_channels=self.I_channel_num, out_channels=64, kernel_size=3, stride=1, padding=1, + bias=False), nn.BatchNorm2d(64), nn.ReLU(True), + nn.MaxPool2d(2, 2), # batch_size x 64 x I_height/2 x I_width/2 + nn.Conv2d(64, 128, 3, 1, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(True), + nn.MaxPool2d(2, 2), # batch_size x 128 x I_height/4 x I_width/4 + nn.Conv2d(128, 256, 3, 1, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(True), + nn.MaxPool2d(2, 2), # batch_size x 256 x I_height/8 x I_width/8 + nn.Conv2d(256, 512, 3, 1, 1, bias=False), nn.BatchNorm2d(512), nn.ReLU(True), + nn.AdaptiveAvgPool2d(1) # batch_size x 512 + ) + + self.localization_fc1 = nn.Sequential(nn.Linear(512, 256), nn.ReLU(True)) + self.localization_fc2 = nn.Linear(256, self.F * 2) + + # Init fc2 in LocalizationNetwork + self.localization_fc2.weight.data.fill_(0) + """ see RARE paper Fig. 
6 (a) """ + ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2)) + ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(F / 2)) + ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(F / 2)) + ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) + ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) + initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) + self.localization_fc2.bias.data = torch.from_numpy(initial_bias).float().view(-1) + + def forward(self, batch_I): + """ + input: batch_I : Batch Input Image [batch_size x I_channel_num x I_height x I_width] + output: batch_C_prime : Predicted coordinates of fiducial points for input batch [batch_size x F x 2] + """ + batch_size = batch_I.size(0) + features = self.conv(batch_I).view(batch_size, -1) + batch_C_prime = self.localization_fc2(self.localization_fc1(features)).view(batch_size, self.F, 2) + return batch_C_prime + + +class GridGenerator(nn.Module): + """ Grid Generator of RARE, which produces P_prime by multipling T with P """ + + def __init__(self, F, I_r_size): + """ Generate P_hat and inv_delta_C for later """ + super().__init__() + self.eps = 1e-6 + self.I_r_height, self.I_r_width = I_r_size + self.F = F + self.C = self._build_C(self.F) # F x 2 + self.P = self._build_P(self.I_r_width, self.I_r_height) + + # num_gpu = torch.cuda.device_count() + # if num_gpu > 1: + # for multi-gpu, you may need register buffer + self.register_buffer("inv_delta_C", torch.tensor( + self._build_inv_delta_C(self.F, self.C)).float()) # F+3 x F+3 + self.register_buffer("P_hat", torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float()) # n x F+3 + # else: + # # for fine-tuning with different image width, you may use below instead of self.register_buffer + # self.inv_delta_C = torch.tensor(self._build_inv_delta_C(self.F, self.C)).float() # F+3 x F+3 + # self.P_hat = torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float() # n x F+3 + + def _build_C(self, F): + """ Return coordinates of fiducial points in I_r; C """ + ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2)) + ctrl_pts_y_top = -1 * np.ones(int(F / 2)) + ctrl_pts_y_bottom = np.ones(int(F / 2)) + ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) + ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) + C = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) + return C # F x 2 + + def _build_inv_delta_C(self, F, C): + """ Return inv_delta_C which is needed to calculate T """ + hat_C = np.zeros((F, F), dtype=float) # F x F + for i in range(0, F): + for j in range(i, F): + r = np.linalg.norm(C[i] - C[j]) + hat_C[i, j] = r + hat_C[j, i] = r + np.fill_diagonal(hat_C, 1) + hat_C = (hat_C ** 2) * np.log(hat_C) + # print(C.shape, hat_C.shape) + delta_C = np.concatenate( # F+3 x F+3 + [ + np.concatenate([np.ones((F, 1)), C, hat_C], axis=1), # F x F+3 + np.concatenate([np.zeros((2, 3)), np.transpose(C)], axis=1), # 2 x F+3 + np.concatenate([np.zeros((1, 3)), np.ones((1, F))], axis=1) # 1 x F+3 + ], + axis=0 + ) + inv_delta_C = np.linalg.inv(delta_C) + return inv_delta_C # F+3 x F+3 + + def _build_P(self, I_r_width, I_r_height): + I_r_grid_x = (np.arange(-I_r_width, I_r_width, 2) + 1.0) / I_r_width # self.I_r_width + I_r_grid_y = (np.arange(-I_r_height, I_r_height, 2) + 1.0) / I_r_height # self.I_r_height + P = np.stack( # self.I_r_width x self.I_r_height x 2 + np.meshgrid(I_r_grid_x, I_r_grid_y), + axis=2 + ) + return P.reshape([-1, 2]) # n (= self.I_r_width x self.I_r_height) x 2 + + def _build_P_hat(self, F, C, P): + n = P.shape[0] # n 
(= self.I_r_width x self.I_r_height) + P_tile = np.tile(np.expand_dims(P, axis=1), (1, F, 1)) # n x 2 -> n x 1 x 2 -> n x F x 2 + C_tile = np.expand_dims(C, axis=0) # 1 x F x 2 + P_diff = P_tile - C_tile # n x F x 2 + rbf_norm = np.linalg.norm(P_diff, ord=2, axis=2, keepdims=False) # n x F + rbf = np.multiply(np.square(rbf_norm), np.log(rbf_norm + self.eps)) # n x F + P_hat = np.concatenate([np.ones((n, 1)), P, rbf], axis=1) + return P_hat # n x F+3 + + def build_P_prime(self, batch_C_prime): + """ Generate Grid from batch_C_prime [batch_size x F x 2] """ + batch_size = batch_C_prime.size(0) + batch_inv_delta_C = self.inv_delta_C.repeat(batch_size, 1, 1) + batch_P_hat = self.P_hat.repeat(batch_size, 1, 1) + batch_C_prime_with_zeros = torch.cat((batch_C_prime, batch_C_prime.new_zeros( + (batch_size, 3, 2), dtype=torch.float)), dim=1) # batch_size x F+3 x 2 + batch_T = torch.bmm(batch_inv_delta_C, batch_C_prime_with_zeros) # batch_size x F+3 x 2 + batch_P_prime = torch.bmm(batch_P_hat, batch_T) # batch_size x n x 2 + return batch_P_prime # batch_size x n x 2 diff --git a/strhub/models/utils.py b/strhub/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9b3a50b0420ce84ec12de560894990243de3d687 --- /dev/null +++ b/strhub/models/utils.py @@ -0,0 +1,118 @@ +from pathlib import PurePath +from typing import Sequence + +import torch +from torch import nn + +import yaml + + +class InvalidModelError(RuntimeError): + """Exception raised for any model-related error (creation, loading)""" + + +_WEIGHTS_URL = { + 'parseq-tiny': 'https://github.com/baudm/parseq/releases/download/v1.0.0/parseq_tiny-e7a21b54.pt', + 'parseq': 'https://github.com/baudm/parseq/releases/download/v1.0.0/parseq-bb5792a6.pt', + 'abinet': 'https://github.com/baudm/parseq/releases/download/v1.0.0/abinet-1d1e373e.pt', + 'trba': 'https://github.com/baudm/parseq/releases/download/v1.0.0/trba-cfaed284.pt', + 'vitstr': 'https://github.com/baudm/parseq/releases/download/v1.0.0/vitstr-26d0fcf4.pt', + 'crnn': 'https://github.com/baudm/parseq/releases/download/v1.0.0/crnn-679d0e31.pt', +} + + +def _get_config(experiment: str, **kwargs): + """Emulates hydra config resolution""" + root = PurePath(__file__).parents[2] + with open(root / 'configs/main.yaml', 'r') as f: + config = yaml.load(f, yaml.Loader)['model'] + with open(root / f'configs/charset/94_full.yaml', 'r') as f: + config.update(yaml.load(f, yaml.Loader)['model']) + with open(root / f'configs/experiment/{experiment}.yaml', 'r') as f: + exp = yaml.load(f, yaml.Loader) + # Apply base model config + model = exp['defaults'][0]['override /model'] + with open(root / f'configs/model/{model}.yaml', 'r') as f: + config.update(yaml.load(f, yaml.Loader)) + # Apply experiment config + if 'model' in exp: + config.update(exp['model']) + config.update(kwargs) + return config + + +def _get_model_class(key): + if 'abinet' in key: + from .abinet.system import ABINet as ModelClass + elif 'crnn' in key: + from .crnn.system import CRNN as ModelClass + elif 'parseq' in key: + from .parseq.system import PARSeq as ModelClass + elif 'trba' in key: + from .trba.system import TRBA as ModelClass + elif 'trbc' in key: + from .trba.system import TRBC as ModelClass + elif 'vitstr' in key: + from .vitstr.system import ViTSTR as ModelClass + else: + raise InvalidModelError("Unable to find model class for '{}'".format(key)) + return ModelClass + + +def create_model(experiment: str, pretrained: bool = False, **kwargs): + try: + config = _get_config(experiment, **kwargs) + except 
FileNotFoundError: + raise InvalidModelError("No configuration found for '{}'".format(experiment)) from None + ModelClass = _get_model_class(experiment) + model = ModelClass(**config) + if pretrained: + try: + url = _WEIGHTS_URL[experiment] + except KeyError: + raise InvalidModelError("No pretrained weights found for '{}'".format(experiment)) from None + checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location='cpu', check_hash=True) + model.load_state_dict(checkpoint) + return model + + +def load_from_checkpoint(checkpoint_path: str, **kwargs): + if checkpoint_path.startswith('pretrained='): + model_id = checkpoint_path.split('=', maxsplit=1)[1] + model = create_model(model_id, True, **kwargs) + else: + ModelClass = _get_model_class(checkpoint_path) + model = ModelClass.load_from_checkpoint(checkpoint_path, **kwargs) + return model + + +def parse_model_args(args): + kwargs = {} + arg_types = {t.__name__: t for t in [int, float, str]} + arg_types['bool'] = lambda v: v.lower() == 'true' # special handling for bool + for arg in args: + name, value = arg.split('=', maxsplit=1) + name, arg_type = name.split(':', maxsplit=1) + kwargs[name] = arg_types[arg_type](value) + return kwargs + + +def init_weights(module: nn.Module, name: str = '', exclude: Sequence[str] = ()): + """Initialize the weights using the typical initialization schemes used in SOTA models.""" + if any(map(name.startswith, exclude)): + return + if isinstance(module, nn.Linear): + nn.init.trunc_normal_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.trunc_normal_(module.weight, std=.02) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.Conv2d): + nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)): + nn.init.ones_(module.weight) + nn.init.zeros_(module.bias) diff --git a/strhub/models/vitstr/__init__.py b/strhub/models/vitstr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..19e985679da1fcaa6deb306697993fd601892d6c --- /dev/null +++ b/strhub/models/vitstr/__init__.py @@ -0,0 +1,12 @@ +r""" +Atienza, Rowel. "Vision Transformer for Fast and Efficient Scene Text Recognition." +In International Conference on Document Analysis and Recognition (ICDAR). 2021. + +https://arxiv.org/abs/2105.08582 + +All source files, except `system.py`, are based on the implementation listed below, +and hence are released under the license of the original. + +Source: https://github.com/roatienza/deep-text-recognition-benchmark +License: Apache License 2.0 (see LICENSE file in project root) +""" diff --git a/strhub/models/vitstr/model.py b/strhub/models/vitstr/model.py new file mode 100644 index 0000000000000000000000000000000000000000..62c5d551626c325243a4f0d055869384a59b3910 --- /dev/null +++ b/strhub/models/vitstr/model.py @@ -0,0 +1,28 @@ +""" +Implementation of ViTSTR based on timm VisionTransformer. + +TODO: +1) distilled deit backbone +2) base deit backbone + +Copyright 2021 Rowel Atienza +""" + +from timm.models.vision_transformer import VisionTransformer + + +class ViTSTR(VisionTransformer): + """ + ViTSTR is basically a ViT that uses DeiT weights. + Modified head to support a sequence of characters prediction for STR. 
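+ forward() keeps only the first `seqlen` token embeddings and runs the classification
+ head on every kept position, returning [batch, seqlen, num_classes].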
+ """ + + def forward(self, x, seqlen: int = 25): + x = self.forward_features(x) + x = x[:, :seqlen] + + # batch, seqlen, embsize + b, s, e = x.size() + x = x.reshape(b * s, e) + x = self.head(x).view(b, s, self.num_classes) + return x diff --git a/strhub/models/vitstr/system.py b/strhub/models/vitstr/system.py new file mode 100644 index 0000000000000000000000000000000000000000..f5cedc4d2e1ef08430df743c42150c5cc84220dc --- /dev/null +++ b/strhub/models/vitstr/system.py @@ -0,0 +1,58 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Sequence, Any, Optional + +import torch +from pytorch_lightning.utilities.types import STEP_OUTPUT +from torch import Tensor + +from strhub.models.base import CrossEntropySystem +from strhub.models.utils import init_weights +from .model import ViTSTR as Model + + +class ViTSTR(CrossEntropySystem): + + def __init__(self, charset_train: str, charset_test: str, max_label_length: int, + batch_size: int, lr: float, warmup_pct: float, weight_decay: float, + img_size: Sequence[int], patch_size: Sequence[int], embed_dim: int, num_heads: int, + **kwargs: Any) -> None: + super().__init__(charset_train, charset_test, batch_size, lr, warmup_pct, weight_decay) + self.save_hyperparameters() + self.max_label_length = max_label_length + # We don't predict nor + self.model = Model(img_size=img_size, patch_size=patch_size, depth=12, mlp_ratio=4, qkv_bias=True, + embed_dim=embed_dim, num_heads=num_heads, num_classes=len(self.tokenizer) - 2) + # Non-zero weight init for the head + self.model.head.apply(init_weights) + + @torch.jit.ignore + def no_weight_decay(self): + return {'model.' + n for n in self.model.no_weight_decay()} + + def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor: + max_length = self.max_label_length if max_length is None else min(max_length, self.max_label_length) + logits = self.model.forward(images, max_length + 2) # +2 tokens for [GO] and [s] + # Truncate to conform to other models. [GO] in ViTSTR is actually used as the padding (therefore, ignored). + # First position corresponds to the class token, which is unused and ignored in the original work. 
+ logits = logits[:, 1:] + return logits + + def training_step(self, batch, batch_idx) -> STEP_OUTPUT: + images, labels = batch + loss = self.forward_logits_loss(images, labels)[1] + self.log('loss', loss) + return loss diff --git a/train_shap_corr.py b/train_shap_corr.py new file mode 100644 index 0000000000000000000000000000000000000000..579b1071d4e7700fe4e686ed2fa374803c53610a --- /dev/null +++ b/train_shap_corr.py @@ -0,0 +1,90 @@ +import os +import time +import string +import argparse +import re +import validators +import sys + +import torch +import torch.backends.cudnn as cudnn +import torch.utils.data +import torch.nn.functional as F +import numpy as np +from nltk.metrics.distance import edit_distance +import pickle + +from utils import CTCLabelConverter, AttnLabelConverter, Averager, TokenLabelConverter +from dataset import hierarchical_dataset, AlignCollate, Batch_Balanced_Dataset +from model import Model, STRScore +from utils import get_args, AccuracyMeter +import matplotlib.pyplot as plt +import settings + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +def getPredAndConf(opt, model, scoring, image, converter, labels): + batch_size = image.size(0) + length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device) + text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device) + if not opt.Transformer: + text_for_loss, length_for_loss = converter.encode(labels, batch_max_length=opt.batch_max_length) + + if settings.MODEL=="vitstr": + target = converter.encode(labels) + preds = model(image, text=target, seqlen=converter.batch_max_length) + + confScore = scoring(preds) + confScore = confScore.detach().cpu().numpy() + + _, preds_index = preds.topk(1, dim=-1, largest=True, sorted=True) + preds_index = preds_index.view(-1, converter.batch_max_length) + + length_for_pred = torch.IntTensor([converter.batch_max_length - 1] * batch_size).to(device) + preds_str = converter.decode(preds_index[:, 1:], length_for_pred) + preds_str = preds_str[0] + preds_str = preds_str[:preds_str.find('[s]')] + + elif settings.MODEL=="trba": + preds = model(image) + confScore = scoring(preds) + _, preds_index = preds.max(2) + preds_str = converter.decode(preds_index, length_for_pred) + # print("preds_str: ", preds_str) # ['ronaldo[s] + preds_str = preds_str[0] + preds_str = preds_str[:preds_str.find('[s]')] + + elif settings.MODEL=="srn": + target = converter.encode(labels) + preds = model(image, None) + + _, preds_index = preds[2].max(2) + + confScore = scoring(preds) + confScore = confScore.detach().cpu().numpy() + + length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device) + # length_for_pred = torch.IntTensor([converter.batch_max_length - 1] * batch_size).to(device) + preds_str = converter.decode(preds_index, length_for_pred) + preds_str = preds_str[0] + # preds_str = preds_str[:preds_str.find('[s]')] + preds = preds[2] + + elif settings.MODEL=="parseq": + target = converter.encode(labels) + preds = model(image) + + predStr, confidence = model.tokenizer.decode(preds) + + confScore = scoring(preds) + confScore = confScore.detach().cpu().numpy() + + # _, preds_index = preds.topk(1, dim=-1, largest=True, sorted=True) + # preds_index = preds_index.view(-1, converter.batch_max_length) + # + # length_for_pred = torch.IntTensor([converter.batch_max_length - 1] * batch_size).to(device) + # preds_str = converter.decode(preds_index[:, 0:], length_for_pred) + preds_str = predStr[0] + # preds_str = 
preds_str[:preds_str.find('[s]')] + # pred = pred[:pred_EOS] + return preds_str, confScore diff --git a/transforms.py b/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..5a7042f3368bc832566d5c22d1e18abe5d8547f5 --- /dev/null +++ b/transforms.py @@ -0,0 +1,329 @@ +import math +import numbers +import random + +import cv2 +import numpy as np +from PIL import Image +from torchvision import transforms +from torchvision.transforms import Compose + + +def sample_asym(magnitude, size=None): + return np.random.beta(1, 4, size) * magnitude + +def sample_sym(magnitude, size=None): + return (np.random.beta(4, 4, size=size) - 0.5) * 2 * magnitude + +def sample_uniform(low, high, size=None): + return np.random.uniform(low, high, size=size) + +def get_interpolation(type='random'): + if type == 'random': + choice = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA] + interpolation = choice[random.randint(0, len(choice)-1)] + elif type == 'nearest': interpolation = cv2.INTER_NEAREST + elif type == 'linear': interpolation = cv2.INTER_LINEAR + elif type == 'cubic': interpolation = cv2.INTER_CUBIC + elif type == 'area': interpolation = cv2.INTER_AREA + else: raise TypeError('Interpolation types only nearest, linear, cubic, area are supported!') + return interpolation + +class CVRandomRotation(object): + def __init__(self, degrees=15): + assert isinstance(degrees, numbers.Number), "degree should be a single number." + assert degrees >= 0, "degree must be positive." + self.degrees = degrees + + @staticmethod + def get_params(degrees): + return sample_sym(degrees) + + def __call__(self, img): + angle = self.get_params(self.degrees) + src_h, src_w = img.shape[:2] + M = cv2.getRotationMatrix2D(center=(src_w/2, src_h/2), angle=angle, scale=1.0) + abs_cos, abs_sin = abs(M[0,0]), abs(M[0,1]) + dst_w = int(src_h * abs_sin + src_w * abs_cos) + dst_h = int(src_h * abs_cos + src_w * abs_sin) + M[0, 2] += (dst_w - src_w)/2 + M[1, 2] += (dst_h - src_h)/2 + + flags = get_interpolation() + return cv2.warpAffine(img, M, (dst_w, dst_h), flags=flags, borderMode=cv2.BORDER_REPLICATE) + +class CVRandomAffine(object): + def __init__(self, degrees, translate=None, scale=None, shear=None): + assert isinstance(degrees, numbers.Number), "degree should be a single number." + assert degrees >= 0, "degree must be positive." + self.degrees = degrees + + if translate is not None: + assert isinstance(translate, (tuple, list)) and len(translate) == 2, \ + "translate should be a list or tuple and it must be of length 2." + for t in translate: + if not (0.0 <= t <= 1.0): + raise ValueError("translation values should be between 0 and 1") + self.translate = translate + + if scale is not None: + assert isinstance(scale, (tuple, list)) and len(scale) == 2, \ + "scale should be a list or tuple and it must be of length 2." + for s in scale: + if s <= 0: + raise ValueError("scale values should be positive") + self.scale = scale + + if shear is not None: + if isinstance(shear, numbers.Number): + if shear < 0: + raise ValueError("If shear is a single number, it must be positive.") + self.shear = [shear] + else: + assert isinstance(shear, (tuple, list)) and (len(shear) == 2), \ + "shear should be a list or tuple and it must be of length 2." 
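+ # A 2-element shear is read as (x_shear, y_shear) in degrees; a single number
+ # (handled in the branch above) shears along x only -- see get_params() below.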
+ self.shear = shear + else: + self.shear = shear + + def _get_inverse_affine_matrix(self, center, angle, translate, scale, shear): + # https://github.com/pytorch/vision/blob/v0.4.0/torchvision/transforms/functional.py#L717 + from numpy import sin, cos, tan + + if isinstance(shear, numbers.Number): + shear = [shear, 0] + + if not isinstance(shear, (tuple, list)) and len(shear) == 2: + raise ValueError( + "Shear should be a single value or a tuple/list containing " + + "two values. Got {}".format(shear)) + + rot = math.radians(angle) + sx, sy = [math.radians(s) for s in shear] + + cx, cy = center + tx, ty = translate + + # RSS without scaling + a = cos(rot - sy) / cos(sy) + b = -cos(rot - sy) * tan(sx) / cos(sy) - sin(rot) + c = sin(rot - sy) / cos(sy) + d = -sin(rot - sy) * tan(sx) / cos(sy) + cos(rot) + + # Inverted rotation matrix with scale and shear + # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1 + M = [d, -b, 0, + -c, a, 0] + M = [x / scale for x in M] + + # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1 + M[2] += M[0] * (-cx - tx) + M[1] * (-cy - ty) + M[5] += M[3] * (-cx - tx) + M[4] * (-cy - ty) + + # Apply center translation: C * RSS^-1 * C^-1 * T^-1 + M[2] += cx + M[5] += cy + return M + + @staticmethod + def get_params(degrees, translate, scale_ranges, shears, height): + angle = sample_sym(degrees) + if translate is not None: + max_dx = translate[0] * height + max_dy = translate[1] * height + translations = (np.round(sample_sym(max_dx)), np.round(sample_sym(max_dy))) + else: + translations = (0, 0) + + if scale_ranges is not None: + scale = sample_uniform(scale_ranges[0], scale_ranges[1]) + else: + scale = 1.0 + + if shears is not None: + if len(shears) == 1: + shear = [sample_sym(shears[0]), 0.] 
+ elif len(shears) == 2: + shear = [sample_sym(shears[0]), sample_sym(shears[1])] + else: + shear = 0.0 + + return angle, translations, scale, shear + + + def __call__(self, img): + src_h, src_w = img.shape[:2] + angle, translate, scale, shear = self.get_params( + self.degrees, self.translate, self.scale, self.shear, src_h) + + M = self._get_inverse_affine_matrix((src_w/2, src_h/2), angle, (0, 0), scale, shear) + M = np.array(M).reshape(2,3) + + startpoints = [(0, 0), (src_w - 1, 0), (src_w - 1, src_h - 1), (0, src_h - 1)] + project = lambda x, y, a, b, c: int(a*x + b*y + c) + endpoints = [(project(x, y, *M[0]), project(x, y, *M[1])) for x, y in startpoints] + + rect = cv2.minAreaRect(np.array(endpoints)) + bbox = cv2.boxPoints(rect).astype(dtype=np.int) + max_x, max_y = bbox[:, 0].max(), bbox[:, 1].max() + min_x, min_y = bbox[:, 0].min(), bbox[:, 1].min() + + dst_w = int(max_x - min_x) + dst_h = int(max_y - min_y) + M[0, 2] += (dst_w - src_w) / 2 + M[1, 2] += (dst_h - src_h) / 2 + + # add translate + dst_w += int(abs(translate[0])) + dst_h += int(abs(translate[1])) + if translate[0] < 0: M[0, 2] += abs(translate[0]) + if translate[1] < 0: M[1, 2] += abs(translate[1]) + + flags = get_interpolation() + return cv2.warpAffine(img, M, (dst_w , dst_h), flags=flags, borderMode=cv2.BORDER_REPLICATE) + +class CVRandomPerspective(object): + def __init__(self, distortion=0.5): + self.distortion = distortion + + def get_params(self, width, height, distortion): + offset_h = sample_asym(distortion * height / 2, size=4).astype(dtype=np.int) + offset_w = sample_asym(distortion * width / 2, size=4).astype(dtype=np.int) + topleft = ( offset_w[0], offset_h[0]) + topright = (width - 1 - offset_w[1], offset_h[1]) + botright = (width - 1 - offset_w[2], height - 1 - offset_h[2]) + botleft = ( offset_w[3], height - 1 - offset_h[3]) + + startpoints = [(0, 0), (width - 1, 0), (width - 1, height - 1), (0, height - 1)] + endpoints = [topleft, topright, botright, botleft] + return np.array(startpoints, dtype=np.float32), np.array(endpoints, dtype=np.float32) + + def __call__(self, img): + height, width = img.shape[:2] + startpoints, endpoints = self.get_params(width, height, self.distortion) + M = cv2.getPerspectiveTransform(startpoints, endpoints) + + # TODO: more robust way to crop image + rect = cv2.minAreaRect(endpoints) + bbox = cv2.boxPoints(rect).astype(dtype=np.int) + max_x, max_y = bbox[:, 0].max(), bbox[:, 1].max() + min_x, min_y = bbox[:, 0].min(), bbox[:, 1].min() + min_x, min_y = max(min_x, 0), max(min_y, 0) + + flags = get_interpolation() + img = cv2.warpPerspective(img, M, (max_x, max_y), flags=flags, borderMode=cv2.BORDER_REPLICATE) + img = img[min_y:, min_x:] + return img + +class CVRescale(object): + + def __init__(self, factor=4, base_size=(128, 512)): + """ Define image scales using gaussian pyramid and rescale image to target scale. + + Args: + factor: the decayed factor from base size, factor=4 keeps target scale by default. 
+ base_size: base size the build the bottom layer of pyramid + """ + if isinstance(factor, numbers.Number): + self.factor = round(sample_uniform(0, factor)) + elif isinstance(factor, (tuple, list)) and len(factor) == 2: + self.factor = round(sample_uniform(factor[0], factor[1])) + else: + raise Exception('factor must be number or list with length 2') + # assert factor is valid + self.base_h, self.base_w = base_size[:2] + + def __call__(self, img): + if self.factor == 0: return img + src_h, src_w = img.shape[:2] + cur_w, cur_h = self.base_w, self.base_h + scale_img = cv2.resize(img, (cur_w, cur_h), interpolation=get_interpolation()) + for _ in range(self.factor): + scale_img = cv2.pyrDown(scale_img) + scale_img = cv2.resize(scale_img, (src_w, src_h), interpolation=get_interpolation()) + return scale_img + +class CVGaussianNoise(object): + def __init__(self, mean=0, var=20): + self.mean = mean + if isinstance(var, numbers.Number): + self.var = max(int(sample_asym(var)), 1) + elif isinstance(var, (tuple, list)) and len(var) == 2: + self.var = int(sample_uniform(var[0], var[1])) + else: + raise Exception('degree must be number or list with length 2') + + def __call__(self, img): + noise = np.random.normal(self.mean, self.var**0.5, img.shape) + img = np.clip(img + noise, 0, 255).astype(np.uint8) + return img + +class CVMotionBlur(object): + def __init__(self, degrees=12, angle=90): + if isinstance(degrees, numbers.Number): + self.degree = max(int(sample_asym(degrees)), 1) + elif isinstance(degrees, (tuple, list)) and len(degrees) == 2: + self.degree = int(sample_uniform(degrees[0], degrees[1])) + else: + raise Exception('degree must be number or list with length 2') + self.angle = sample_uniform(-angle, angle) + + def __call__(self, img): + M = cv2.getRotationMatrix2D((self.degree // 2, self.degree // 2), self.angle, 1) + motion_blur_kernel = np.zeros((self.degree, self.degree)) + motion_blur_kernel[self.degree // 2, :] = 1 + motion_blur_kernel = cv2.warpAffine(motion_blur_kernel, M, (self.degree, self.degree)) + motion_blur_kernel = motion_blur_kernel / self.degree + img = cv2.filter2D(img, -1, motion_blur_kernel) + img = np.clip(img, 0, 255).astype(np.uint8) + return img + +class CVGeometry(object): + def __init__(self, degrees=15, translate=(0.3, 0.3), scale=(0.5, 2.), + shear=(45, 15), distortion=0.5, p=0.5): + self.p = p + type_p = random.random() + if type_p < 0.33: + self.transforms = CVRandomRotation(degrees=degrees) + elif type_p < 0.66: + self.transforms = CVRandomAffine(degrees=degrees, translate=translate, scale=scale, shear=shear) + else: + self.transforms = CVRandomPerspective(distortion=distortion) + + def __call__(self, img): + if random.random() < self.p: + img = np.array(img) + return Image.fromarray(self.transforms(img)) + else: return img + +class CVDeterioration(object): + def __init__(self, var, degrees, factor, p=0.5): + self.p = p + transforms = [] + if var is not None: + transforms.append(CVGaussianNoise(var=var)) + if degrees is not None: + transforms.append(CVMotionBlur(degrees=degrees)) + if factor is not None: + transforms.append(CVRescale(factor=factor)) + + random.shuffle(transforms) + transforms = Compose(transforms) + self.transforms = transforms + + def __call__(self, img): + if random.random() < self.p: + img = np.array(img) + return Image.fromarray(self.transforms(img)) + else: return img + + +class CVColorJitter(object): + def __init__(self, brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, p=0.5): + self.p = p + self.transforms = 
transforms.ColorJitter(brightness=brightness, contrast=contrast, + saturation=saturation, hue=hue) + + def __call__(self, img): + if random.random() < self.p: return self.transforms(img) + else: return img diff --git a/util/__init__.py b/util/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/util/clean.py b/util/clean.py new file mode 100644 index 0000000000000000000000000000000000000000..c9b296e83be8b3f4a6a36d57e56b01974bb8e20a --- /dev/null +++ b/util/clean.py @@ -0,0 +1,7 @@ +import settings +import os + +def clean(): + filelist = [f for f in os.listdir(settings.OUTPUT_FOLDER) if f.endswith('mmap')] + for f in filelist: + os.remove(os.path.join(settings.OUTPUT_FOLDER, f)) diff --git a/util/upsample.py b/util/upsample.py new file mode 100644 index 0000000000000000000000000000000000000000..8f8e664e74ccdbeb3e65604aabb1b273e8d71ac7 --- /dev/null +++ b/util/upsample.py @@ -0,0 +1,368 @@ +from scipy.ndimage.filters import gaussian_filter +from scipy.interpolate import RectBivariateSpline +from scipy.ndimage.interpolation import zoom +import numpy + +def upsampleL(fieldmap, activation_data, reduction=1, shape=None, + scaleshape=None, out=None): + ''' + Applies a bilinear upsampling. + ''' + offset, size, step = fieldmap + input_count = activation_data.shape[0] + if len(activation_data.shape) == 2: + ay, ax = centered_arange(fieldmap, activation_data.shape, reduction) + if shape is None: + shape = upsampled_shape( + fieldmap, activation_data.shape, reduction) + else: + ay, ax = centered_arange(fieldmap, activation_data.shape[1:], reduction) + if shape is None: + shape = upsampled_shape( + fieldmap, activation_data.shape[1:], reduction) + if scaleshape is not None: + iy, ix = full_arange(scaleshape) + # TODO: consider treaing each point as a center of a pixel + iy *= shape[0] / scaleshape[0] + ix *= shape[1] / scaleshape[1] + else: + iy, ix = full_arange(shape) + if out is None: + out = numpy.empty((input_count, len(iy), len(ix)), + dtype=activation_data.dtype) + if len(activation_data.shape) == 2: + f = RectBivariateSpline(ay, ax, activation_data, kx=1, ky=1) + return f(iy, ix, grid=True) + else: + for z in range(input_count): + f = RectBivariateSpline(ay, ax, activation_data[z], kx=1, ky=1) + out[z] = f(iy, ix, grid=True) + return out + +def upsampleC(fieldmap, activation_data, shape=None, out=None): + ''' + Applies a bicubic upsampling. 
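+ Same call pattern as upsampleL, but fits cubic (kx=ky=3) splines and does not
+ take the reduction or scaleshape options.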
+ ''' + offset, size, step = fieldmap + input_count = activation_data.shape[0] + ay, ax = centered_arange(fieldmap, activation_data.shape[1:]) + if shape is None: + shape = upsampled_shape(fieldmap, activation_data.shape[1:]) + iy, ix = full_arange(shape) + if out is None: + out = numpy.empty((input_count,) + shape, + dtype=activation_data.dtype) + for z in range(input_count): + f = RectBivariateSpline(ay, ax, activation_data[z], kx=3, ky=3) + out[z] = f(iy, ix, grid=True) + return out + +def upsampleG(fieldmap, activation_data, shape=None): + ''' + Upsampling utility functions + ''' + offset, size, step = fieldmap + input_count = activation_data.shape[0] + if shape is None: + shape = upsampled_shape(fieldmap, activation_data.shape[1:]) + activations = numpy.zeros((input_count,) + shape) + activations[(slice(None),) + + centered_slice(fieldmap, activation_data.shape[1:])] = ( + activation_data * numpy.prod(step)) + blurred = gaussian_filter( + activations, + sigma=(0, ) + tuple(t // 1.414 for o, s, t in zip(*fieldmap)), + mode='constant') + return blurred + +def topo_sort(layers): + # First, build a links-from and also a links-to graph + links_from = {} + links_to = {} + for layer in layers: + for bot in layer.bottom: + if bot not in links_from: + links_from[bot] = [] + links_from[bot].append(layer) + for top in layer.top: + if top not in links_to: + links_to[top] = [] + links_to[top].append(layer) + # Now do a DFS to figure out the ordering (using links-from) + visited = set() + ordering = [] + stack = [] + for seed in links_from: + if seed not in visited: + stack.append((seed, True)) + stack.append((seed, False)) + visited.add(seed) + while stack: + (blob, completed) = stack.pop() + if completed: + ordering.append(blob) + elif blob in links_from: + for layer in links_from[blob]: + for t in layer.top: + if t not in visited: + stack.append((t, True)) + stack.append((t, False)) + visited.add(t) + # Return a result in front-to-back order, with incoming links for each + return list((blob, links_to[blob] if blob in links_to else []) + for blob in reversed(ordering)) + +def composed_fieldmap(layers, end): + ts = topo_sort(layers) + fm_record = {} + for blob, layers in ts: + # Compute fm's on all the edges that go to this blob. + all_fms = [ + (compose_fieldmap(fm_record[bot][0], layer_fieldmap(layer)), + fm_record[bot][1] + [(bot, layer)]) + for layer in layers for bot in layer.bottom if bot != blob] + # And take the max fieldmap. + fm_record[blob] = max_fieldmap(all_fms) + if blob == end: + return fm_record[blob] + +def max_fieldmap(maps): + biggest, bp = None, None + for fm, path in maps: + if biggest is None: + biggest, bp = fm, path + elif fm[1][0] > biggest[1][0]: + biggest, bp = fm, path + # When there is no biggest, for example when maps is the empty array, + # use the trivial identity fieldmap with no path. 
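+ # A fieldmap is an (offset, size, step) triple of per-axis pairs (see the
+ # compose_fieldmap docstring below); the identity value ((0, 0), (1, 1), (1, 1))
+ # returned here maps every output pixel onto exactly the same input pixel.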
+ if biggest is None: + return ((0, 0), (1, 1), (1, 1)), [] + return biggest, bp + +def shortest_layer_path(start, end, layers): + # First, build a blob-to-outgoing-layer graph + links_from = {} + for layer in layers: + for bot in layer.bottom: + if bot not in links_from: + links_from[bot] = [] + links_from[bot].append(layer) + # Then do a BFS on the graph to find the shortest path to 'end' + queue = [(s, []) for s in start] + visited = set(start) + while queue: + (blob, path) = queue.pop(0) + for layer in links_from[blob]: + for t in layer.top: + if t == end: + return path + [layer] + if t not in visited: + queue.append((t, path + [layer])) + visited.add(t) + return None + +def upsampled_shape(fieldmap, shape, reduction=1): + # Given the shape of a layer's activation and a fieldmap describing + # the transformation to original image space, returns the shape of + # the input size + return tuple(((w - 1) * t + s + 2 * o) // reduction + for (o, s, t), w in zip(zip(*fieldmap), shape)) + +def make_mask_set(image_shape, fieldmap, activation_data, + output=None, sigma=0.1, threshold=0.5, percentile=None): + """Creates a set of receptive field masks with uniform thresholds + over a range of inputs. + """ + offset, shape, step = fieldmap + input_count = activation_data.shape[0] + activations = numpy.zeros((input_count,) + image_shape) + activations[(slice(None),) + + centered_slice(fieldmap, activation_data.shape[1:])] = ( + activation_data) + blurred = gaussian_filter( + activations, + sigma=(0, ) + tuple(s * sigma for s in shape), + mode='constant') + if percentile is not None: + limit = blurred.ravel().percentile(percentile) + return blurred > limit + else: + maximum = blurred.ravel().max() + return (blurred > maximum * threshold) + +def safezoom(array, ratio, output=None, order=0): + '''Like numpy.zoom, but does not crash when the first dimension + of the array is of size 1, as happens often with segmentations''' + dtype = array.dtype + if array.dtype == numpy.float16: + array = array.astype(numpy.float32) + if array.shape[0] == 1: + if output is not None: + output = output[0,...] + result = zoom(array[0,...], ratio[1:], + output=output, order=order) + if output is None: + output = result[numpy.newaxis] + else: + result = zoom(array, ratio, output=output, order=order) + if output is None: + output = result + return output.astype(dtype) + +def receptive_field(location, fieldmap): + """Computes the receptive field of a specific location. + + Parameters + ---------- + location: tuple + The x-y position of the unit being queried. + fieldmap: + The (offset, size, step) tuple fieldmap representing the + receptive field map for the layer being queried. + """ + return compose_fieldmap(fieldmap, (location, (1, 1), (1, 1)))[:2] + + +def proto_getattr(p, a, d): + hf = True + # Try using HasField to detect the presence of a field; + # if there is no HasField, then just use getattr. + try: + hf = p.HasField(a) + except: + pass + if hf: + return getattr(p, a, d) + return d + +def wh_attr(layer, attrname, default=0, minval=0): + if not hasattr(default, '__len__'): + default = (default, default) + val = proto_getattr(layer, attrname, None) + if val is None or val == []: + h = max(minval, getattr(layer, attrname + '_h', default[0])) + w = max(minval, getattr(layer, attrname + '_w', default[1])) + elif hasattr(val, '__len__'): + h = val[0] + w = val[1] if len(val) >= 2 else h + else: + h = val + w = val + return (h, w) + +def layer_fieldmap(layer): + # Only convolutional and pooling layers affect geometry. 
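+ # Example (illustrative): a 3x3 convolution with stride 2 and pad 1 yields
+ # ((-1, -1), (3, 3), (2, 2)); by the start/limit rule in compose_fieldmap, output
+ # pixel i then covers input pixels 2*i-1 .. 2*i+1 along each axis.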
+ if layer.type == 'Convolution' or layer.type == 'Pooling': + if layer.type == 'Pooling': + config = layer.pooling_param + if config.global_pooling: + return ((0, 0), (None, None), (1, 1)) + else: + config = layer.convolution_param + size = wh_attr(config, 'kernel_size', wh_attr(config, 'kernel', 1)) + stride = wh_attr(config, 'stride', 1, minval=1) + padding = wh_attr(config, 'pad', 0) + neg_padding = tuple((-x) for x in padding) + return (neg_padding, size, stride) + # All other layers just pass through geometry unchanged. + return ((0, 0), (1, 1), (1, 1)) + +def layerarray_fieldmap(layerarray): + fieldmap = ((0, 0), (1, 1), (1, 1)) + for layer in layerarray: + fieldmap = compose_fieldmap(fieldmap, layer_fieldmap(layer)) + return fieldmap + +# rf1 is the lower layer, rf2 is the higher layer +def compose_fieldmap(rf1, rf2): + """Composes two stacked fieldmap maps. + + Field maps are represented as triples of (offset, size, step), + where each is an (x, y) pair. + + To find the pixel range corresponding to output pixel (x, y), just + do the following: + start_x = x * step[0] + offset[1] + limit_x = start_x + size[0] + start_y = y * step[1] + offset[1] + limit_y = start_y + size[1] + + Parameters + ---------- + rf1: tuple + The lower-layer receptive fieldmap, a tuple of (offset, size, step). + rf2: tuple + The higher-layer receptive fieldmap, a tuple of (offset, size, step). + """ + if rf1 == None: + import pdb; pdb.set_trace() + offset1, size1, step1 = rf1 + offset2, size2, step2 = rf2 + + size = tuple((size2c - 1) * step1c + size1c + for size1c, step1c, size2c in zip(size1, step1, size2)) + offset = tuple(offset2c * step1c + offset1c + for offset2c, step1c, offset1c in zip(offset2, step1, offset1)) + step = tuple(step2c * step1c + for step1c, step2c in zip(step1, step2)) + return (offset, size, step) + +def _cropped_slices(offset, size, limit): + corner = 0 + if offset < 0: + size += offset + offset = 0 + if limit - offset < size: + corner = limit - offset + size -= corner + return (slice(corner, corner + size), slice(offset, offset + size)) + +def crop_field(image_data, fieldmap, location): + """Crops image_data to the specified receptive field. + + Together fieldmap and location specify a receptive field on the image, + which may overlap the edge. This returns a crop to that shape, including + any zero padding necessary to fill out the shape beyond the image edge. 
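+ The crop keeps the full receptive-field size; any part that falls outside the
+ image is left zero-filled by crop_rectangle.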
+ """ + offset, size = receptive_field(fieldmap, location) + return crop_rectangle(image_data, offset, size) + +def crop_rectangle(image_data, offset, size): + coloraxis = 0 if image_data.size <= 2 else 1 + allcolors = () if not coloraxis else (slice(None),) * coloraxis + colordepth = () if not coloraxis else (image_data.size[0], ) + result = numpy.zeros(colordepth + size) + (xto, xfrom), (yto, yfrom) = (_cropped_slices( + o, s, l) for o, s, l in zip(offset, size, image_data.size[coloraxis:])) + result[allcolors + (xto, yto)] = image_data[allcolors + (xfrom, yfrom)] + return result + +def center_location(fieldmap, location): + if isinstance(location, numpy.ndarray): + offset, size, step = fieldmap + broadcast = (numpy.newaxis, ) * (len(location.shape) - 1) + ( + slice(None),) + step = numpy.array(step)[broadcast] + offset = numpy.array(offset)[broadcast] + size = numpy.array(size)[broadcast] + return location * step + offset + size // 2 + else: + offset, shape = receptive_field(location, fieldmap) + return tuple(o + s // 2 for o, s in zip(offset, shape)) + +def centered_slice(fieldmap, activation_shape, reduction=1): + offset, size, step = fieldmap + r = reduction + return tuple(slice((s // 2 + o) // r, (s // 2 + o + a * t) // r, t // r) + for o, s, t, a in zip(offset, size, step, activation_shape)) + +def centered_arange(fieldmap, activation_shape, reduction=1): + offset, size, step = fieldmap + r = reduction + return tuple(numpy.arange( + (s // 2 + o) // r, (s // 2 + o + a * t) // r, t // r)[:a] # Hack to avoid a+1 points + for o, s, t, a in zip(offset, size, step, activation_shape)) + +def full_arange(output_shape): + return tuple(numpy.arange(o) for o in output_shape) + diff --git a/util/vecquantile.py b/util/vecquantile.py new file mode 100644 index 0000000000000000000000000000000000000000..9ddbc6b4b577f91751894a9e780568ff987027c1 --- /dev/null +++ b/util/vecquantile.py @@ -0,0 +1,263 @@ +import numpy + +class QuantileVector: + """ + Streaming randomized quantile computation for numpy. + + Add any amount of data repeatedly via add(data). At any time, + quantile estimates (or old-style percentiles) can be read out using + quantiles(q) or percentiles(p). + + Accuracy scales according to resolution: the default is to + set resolution to be accurate to better than 0.1%, + while limiting storage to about 50,000 samples. + + Good for computing quantiles of huge data without using much memory. + Works well on arbitrary data with probability near 1. + + Based on the optimal KLL quantile algorithm by Karnin, Lang, and Liberty + from FOCS 2016. http://ieee-focs.org/FOCS-2016-Papers/3933a071.pdf + """ + + def __init__(self, depth=1, resolution=24 * 1024, buffersize=None, + dtype=None, seed=None): + self.resolution = resolution + self.depth = depth + # Default buffersize: 128 samples (and smaller than resolution). + if buffersize is None: + buffersize = min(128, (resolution + 7) // 8) + self.buffersize = buffersize + self.samplerate = 1.0 + self.data = [numpy.zeros(shape=(depth, resolution), dtype=dtype)] + self.firstfree = [0] + self.random = numpy.random.RandomState(seed) + self.extremes = numpy.empty(shape=(depth, 2), dtype=dtype) + self.extremes.fill(numpy.NaN) + self.size = 0 + + def add(self, incoming): + assert len(incoming.shape) == 2 + assert incoming.shape[1] == self.depth + self.size += incoming.shape[0] + # Convert to a flat numpy array. + if self.samplerate >= 1.0: + self._add_every(incoming) + return + # If we are sampling, then subsample a large chunk at a time. 
+ self._scan_extremes(incoming) + chunksize = numpy.ceil[self.buffersize / self.samplerate] + for index in range(0, len(incoming), chunksize): + batch = incoming[index:index+chunksize] + sample = batch[self.random.binomial(1, self.samplerate, len(batch))] + self._add_every(sample) + + def _add_every(self, incoming): + supplied = len(incoming) + index = 0 + while index < supplied: + ff = self.firstfree[0] + available = self.data[0].shape[1] - ff + if available == 0: + if not self._shift(): + # If we shifted by subsampling, then subsample. + incoming = incoming[index:] + if self.samplerate >= 0.5: + print('SAMPLING') + self._scan_extremes(incoming) + incoming = incoming[self.random.binomial(1, 0.5, + len(incoming - index))] + index = 0 + supplied = len(incoming) + ff = self.firstfree[0] + available = self.data[0].shape[1] - ff + copycount = min(available, supplied - index) + self.data[0][:,ff:ff + copycount] = numpy.transpose( + incoming[index:index + copycount,:]) + self.firstfree[0] += copycount + index += copycount + + def _shift(self): + index = 0 + # If remaining space at the current layer is less than half prev + # buffer size (rounding up), then we need to shift it up to ensure + # enough space for future shifting. + while self.data[index].shape[1] - self.firstfree[index] < ( + -(-self.data[index-1].shape[1] // 2) if index else 1): + if index + 1 >= len(self.data): + return self._expand() + data = self.data[index][:,0:self.firstfree[index]] + data.sort() + if index == 0 and self.samplerate >= 1.0: + self._update_extremes(data[:,0], data[:,-1]) + offset = self.random.binomial(1, 0.5) + position = self.firstfree[index + 1] + subset = data[:,offset::2] + self.data[index + 1][:,position:position + subset.shape[1]] = subset + self.firstfree[index] = 0 + self.firstfree[index + 1] += subset.shape[1] + index += 1 + return True + + def _scan_extremes(self, incoming): + # When sampling, we need to scan every item still to get extremes + self._update_extremes( + numpy.nanmin(incoming, axis=0), + numpy.nanmax(incoming, axis=0)) + + def _update_extremes(self, minr, maxr): + self.extremes[:,0] = numpy.nanmin( + [self.extremes[:, 0], minr], axis=0) + self.extremes[:,-1] = numpy.nanmax( + [self.extremes[:, -1], maxr], axis=0) + + def minmax(self): + if self.firstfree[0]: + self._scan_extremes(self.data[0][:,:self.firstfree[0]].transpose()) + return self.extremes.copy() + + def _expand(self): + cap = self._next_capacity() + if cap > 0: + # First, make a new layer of the proper capacity. + self.data.insert(0, numpy.empty( + shape=(self.depth, cap), dtype=self.data[-1].dtype)) + self.firstfree.insert(0, 0) + else: + # Unless we're so big we are just subsampling. + assert self.firstfree[0] == 0 + self.samplerate *= 0.5 + for index in range(1, len(self.data)): + # Scan for existing data that needs to be moved down a level. + amount = self.firstfree[index] + if amount == 0: + continue + position = self.firstfree[index-1] + # Move data down if it would leave enough empty space there + # This is the key invariant: enough empty space to fit half + # of the previous level's buffer size (rounding up) + if self.data[index-1].shape[1] - (amount + position) >= ( + -(-self.data[index-2].shape[1] // 2) if (index-1) else 1): + self.data[index-1][:,position:position + amount] = ( + self.data[index][:,:amount]) + self.firstfree[index-1] += amount + self.firstfree[index] = 0 + else: + # Scrunch the data if it would not. 
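+ # "Scrunching": sort the level's buffer and keep every other sample at a random
+ # even/odd offset, halving the item count so the level fits again at the cost of a
+ # small, randomized loss of resolution.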
+ data = self.data[index][:,:amount] + data.sort() + if index == 1: + self._update_extremes(data[:,0], data[:,-1]) + offset = self.random.binomial(1, 0.5) + scrunched = data[:,offset::2] + self.data[index][:,:scrunched.shape[1]] = scrunched + self.firstfree[index] = scrunched.shape[1] + return cap > 0 + + def _next_capacity(self): + cap = numpy.ceil(self.resolution * numpy.power(0.67, len(self.data))) + if cap < 2: + return 0 + return max(self.buffersize, int(cap)) + + def _weighted_summary(self, sort=True): + if self.firstfree[0]: + self._scan_extremes(self.data[0][:,:self.firstfree[0]].transpose()) + size = sum(self.firstfree) + 2 + weights = numpy.empty( + shape=(size), dtype='float32') # floating point + summary = numpy.empty( + shape=(self.depth, size), dtype=self.data[-1].dtype) + weights[0:2] = 0 + summary[:,0:2] = self.extremes + index = 2 + for level, ff in enumerate(self.firstfree): + if ff == 0: + continue + summary[:,index:index + ff] = self.data[level][:,:ff] + weights[index:index + ff] = numpy.power(2.0, level) + index += ff + assert index == summary.shape[1] + if sort: + order = numpy.argsort(summary) + summary = summary[numpy.arange(self.depth)[:,None], order] + weights = weights[order] + return (summary, weights) + + def quantiles(self, quantiles, old_style=False): + if self.size == 0: + return numpy.full((self.depth, len(quantiles)), numpy.nan) + summary, weights = self._weighted_summary() + cumweights = numpy.cumsum(weights, axis=-1) - weights / 2 + if old_style: + # To be convenient with numpy.percentile + cumweights -= cumweights[:,0:1] + cumweights /= cumweights[:,-1:] + else: + cumweights /= numpy.sum(weights, axis=-1, keepdims=True) + result = numpy.empty(shape=(self.depth, len(quantiles))) + for d in range(self.depth): + result[d] = numpy.interp(quantiles, cumweights[d], summary[d]) + return result + + def integrate(self, fun): + result = None + for level, ff in enumerate(self.firstfree): + if ff == 0: + continue + term = numpy.sum( + fun(self.data[level][:,:ff]) * numpy.power(2.0, level), + axis=-1) + if result is None: + result = term + else: + result += term + if result is not None: + result /= self.samplerate + return result + + def percentiles(self, percentiles): + return self.quantiles(percentiles, old_style=True) + + def readout(self, count, old_style=True): + return self.quantiles( + numpy.linspace(0.0, 1.0, count), old_style=old_style) + + +if __name__ == '__main__': + import time + # An adverarial case: we keep finding more numbers in the middle + # as the stream goes on. 
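
# Illustrative sketch, not part of the original code: typical QuantileVector
# usage is to stream (samples, depth) chunks through add() and read estimates
# back with quantiles() or readout(). When used on conv activations, a
# (channels, H, W) map is flattened to (H*W, channels) first so that depth
# equals the channel count. Sizes and values below are made up.

import numpy

qv = QuantileVector(depth=64, resolution=4 * 1024, seed=1)
rng = numpy.random.RandomState(1)
for _ in range(20):
    act = rng.normal(size=(64, 8, 25))          # hypothetical (C, H, W) features
    qv.add(act.reshape(64, -1).transpose())     # -> (H*W, C) observations
print(qv.quantiles([0.05, 0.5, 0.95]).shape)    # (64, 3): per-channel estimates
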
+ amount = 10000000 + percentiles = 1000 + data = numpy.arange(float(amount)) + data[1::2] = data[-1::-2] + (len(data) - 1) + data /= 2 + depth = 50 + alldata = data[:,None] + (numpy.arange(depth) * amount)[None, :] + actual_sum = numpy.sum(alldata * alldata, axis=0) + amt = amount // depth + for r in range(depth): + numpy.random.shuffle(alldata[r*amt:r*amt+amt,r]) + # data[::2] = data[-2::-2] + # numpy.random.shuffle(data) + starttime = time.time() + qc = QuantileVector(depth=depth, resolution=8 * 1024) + qc.add(alldata) + ro = qc.readout(1001) + endtime = time.time() + # print 'ro', ro + # print ro - numpy.linspace(0, amount, percentiles+1) + gt = numpy.linspace(0, amount, percentiles+1)[None,:] + ( + numpy.arange(qc.depth) * amount)[:,None] + print("Maximum relative deviation among %d perentiles:" % percentiles, ( + numpy.max(abs(ro - gt) / amount) * percentiles)) + print("Minmax eror %f, %f" % ( + max(abs(qc.minmax()[:,0] - numpy.arange(qc.depth) * amount)), + max(abs(qc.minmax()[:, -1] - (numpy.arange(qc.depth)+1) * amount + 1)))) + print("Integral error:", numpy.max(numpy.abs( + qc.integrate(lambda x: x * x) + - actual_sum) / actual_sum)) + print("Count error: ", (qc.integrate(lambda x: numpy.ones(x.shape[-1]) + ) - qc.size) / (0.0 + qc.size)) + print("Time", (endtime - starttime)) + diff --git a/utils.py b/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ede9230a93040afb41861e9073b1c115db95cd89 --- /dev/null +++ b/utils.py @@ -0,0 +1,411 @@ +import torch +import numpy as np +import argparse + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +class CTCLabelConverter(object): + """ Convert between text-label and text-index """ + + def __init__(self, character): + # character (str): set of the possible characters. + dict_character = list(character) + + self.dict = {} + for i, char in enumerate(dict_character): + # NOTE: 0 is reserved for 'CTCblank' token required by CTCLoss + self.dict[char] = i + 1 + + self.character = ['[CTCblank]'] + dict_character # dummy '[CTCblank]' token for CTCLoss (index 0) + + def encode(self, text, batch_max_length=25): + """convert text-label into text-index. + input: + text: text labels of each image. [batch_size] + batch_max_length: max length of text label in the batch. 25 by default + + output: + text: text index for CTCLoss. [batch_size, batch_max_length] + length: length of each text. [batch_size] + """ + length = [len(s) for s in text] + + # The index used for padding (=0) would not affect the CTC loss calculation. + batch_text = torch.LongTensor(len(text), batch_max_length).fill_(0) + for i, t in enumerate(text): + text = list(t) + text = [self.dict[char] for char in text] + batch_text[i][:len(text)] = torch.LongTensor(text) + return (batch_text.to(device), torch.IntTensor(length).to(device)) + + def decode(self, text_index, length): + """ convert text-index into text-label. """ + texts = [] + for index, l in enumerate(length): + t = text_index[index, :] + + char_list = [] + for i in range(l): + if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])): # removing repeated characters and blank. + char_list.append(self.character[t[i]]) + text = ''.join(char_list) + + texts.append(text) + return texts + + +class CTCLabelConverterForBaiduWarpctc(object): + """ Convert between text-label and text-index for baidu warpctc """ + + def __init__(self, character): + # character (str): set of the possible characters. 
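
# Illustrative usage sketch, not part of the original code: CTCLabelConverter
# (defined above) pads encoded labels with index 0, the [CTCblank] slot, and
# decode() collapses repeated indices and drops blanks. The import path and
# the fake per-step prediction below are assumptions for the example.

import torch
from utils import CTCLabelConverter   # assumes this file is importable as utils

converter = CTCLabelConverter('0123456789abcdefghijklmnopqrstuvwxyz')
text, length = converter.encode(['hello', 'str'])
print(text.shape, length.tolist())    # torch.Size([2, 25]) [5, 3]
pred = torch.tensor([[18, 18, 0, 19, 0]])            # 'h', 'h', blank, 'i', blank
print(converter.decode(pred, torch.IntTensor([5])))  # ['hi']
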
+ dict_character = list(character) + + self.dict = {} + for i, char in enumerate(dict_character): + # NOTE: 0 is reserved for 'CTCblank' token required by CTCLoss + self.dict[char] = i + 1 + + self.character = ['[CTCblank]'] + dict_character # dummy '[CTCblank]' token for CTCLoss (index 0) + + def encode(self, text, batch_max_length=25): + """convert text-label into text-index. + input: + text: text labels of each image. [batch_size] + output: + text: concatenated text index for CTCLoss. + [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)] + length: length of each text. [batch_size] + """ + length = [len(s) for s in text] + text = ''.join(text) + text = [self.dict[char] for char in text] + + return (torch.IntTensor(text), torch.IntTensor(length)) + + def decode(self, text_index, length): + """ convert text-index into text-label. """ + texts = [] + index = 0 + for l in length: + t = text_index[index:index + l] + + char_list = [] + for i in range(l): + if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])): # removing repeated characters and blank. + char_list.append(self.character[t[i]]) + text = ''.join(char_list) + + texts.append(text) + index += l + return texts + + +class AttnLabelConverter(object): + """ Convert between text-label and text-index """ + + def __init__(self, character): + # character (str): set of the possible characters. + # [GO] for the start token of the attention decoder. [s] for end-of-sentence token. + list_token = ['[GO]', '[s]'] # ['[s]','[UNK]','[PAD]','[GO]'] + list_character = list(character) + self.character = list_token + list_character + + self.dict = {} + for i, char in enumerate(self.character): + # print(i, char) + self.dict[char] = i + + def encode(self, text, batch_max_length=25): + """ convert text-label into text-index. + input: + text: text labels of each image. [batch_size] + batch_max_length: max length of text label in the batch. 25 by default + + output: + text : the input of attention decoder. [batch_size x (max_length+2)] +1 for [GO] token and +1 for [s] token. + text[:, 0] is [GO] token and text is padded with [GO] token after [s] token. + length : the length of output of attention decoder, which count [s] token also. [3, 7, ....] [batch_size] + """ + length = [len(s) + 1 for s in text] # +1 for [s] at end of sentence. + # batch_max_length = max(length) # this is not allowed for multi-gpu setting + batch_max_length += 1 + # additional +1 for [GO] at first step. batch_text is padded with [GO] token after [s] token. + batch_text = torch.LongTensor(len(text), batch_max_length + 1).fill_(0) + for i, t in enumerate(text): + text = list(t) + text.append('[s]') + text = [self.dict[char] for char in text] + batch_text[i][1:1 + len(text)] = torch.LongTensor(text) # batch_text[:, 0] = [GO] token + return (batch_text.to(device), torch.IntTensor(length).to(device)) + + def decode(self, text_index, length): + """ convert text-index into text-label. """ + texts = [] + for index, l in enumerate(length): + text = ''.join([self.character[i] for i in text_index[index, :]]) + texts.append(text) + return texts + + +class TokenLabelConverter(object): + """ Convert between text-label and text-index """ + + def __init__(self, opt): + # character (str): set of the possible characters. + # [GO] for the start token of the attention decoder. [s] for end-of-sentence token. 
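
# Illustrative usage sketch, not part of the original code: AttnLabelConverter
# (defined above) lays labels out as [GO] c_1 ... c_n [s] followed by [GO]
# padding, and length counts the characters plus the [s] token. The import
# path is an assumption.

from utils import AttnLabelConverter   # assumes this file is importable as utils

converter = AttnLabelConverter('0123456789abcdefghijklmnopqrstuvwxyz')
text, length = converter.encode(['cat'], batch_max_length=25)
print(text.shape)            # torch.Size([1, 27]): +1 for [GO], +1 for [s]
print(text[0, :6].tolist())  # [0, 14, 12, 31, 1, 0] -> [GO] c a t [s] [GO]
print(length.tolist())       # [4]  (3 characters + [s])
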
+ self.SPACE = '[s]' + self.GO = '[GO]' + #self.MASK = '[MASK]' + + #self.list_token = [self.GO, self.SPACE, self.MASK] + self.list_token = [self.GO, self.SPACE] + self.character = self.list_token + list(opt.character) + + self.dict = {word: i for i, word in enumerate(self.character)} + self.batch_max_length = opt.batch_max_length + len(self.list_token) + + def encode(self, text): + """ convert text-label into text-index. + """ + length = [len(s) + len(self.list_token) for s in text] # +2 for [GO] and [s] at end of sentence. + batch_text = torch.LongTensor(len(text), self.batch_max_length).fill_(self.dict[self.GO]) + for i, t in enumerate(text): + txt = [self.GO] + list(t) + [self.SPACE] + txt = [self.dict[char] for char in txt] + #prob = np.random.uniform() + #mask_len = round(len(list(t)) * 0.15) + #if is_train and mask_len > 0: + # for m in range(mask_len): + # index = np.random.randint(1, len(t) + 1) + # prob = np.random.uniform() + # if prob > 0.2: + # text[index] = self.dict[self.MASK] + # batch_weights[i][index] = 1. + # elif prob > 0.1: + # char_index = np.random.randint(len(self.list_token), len(self.character)) + # text[index] = self.dict[self.character[char_index]] + # batch_weights[i][index] = 1. + batch_text[i][:len(txt)] = torch.LongTensor(txt) # batch_text[:, 0] = [GO] token + return batch_text.to(device) + + def decode(self, text_index, length): + """ convert text-index into text-label. """ + texts = [] + for index, l in enumerate(length): + text = ''.join([self.character[i] for i in text_index[index, :]]) + texts.append(text) + return texts + +class SRNConverter(object): + """ Convert between text-label and text-index """ + + def __init__(self, character, PAD=36): + # character (str): set of the possible characters. + # [GO] for the start token of the attention decoder. [s] for end-of-sentence token. + # list_token = ['[GO]', '[s]'] # ['[s]','[UNK]','[PAD]','[GO]'] + list_character = list(character) + self.character = list_character + self.PAD = PAD + + self.dict = {} + for i, char in enumerate(self.character): + # print(i, char) + self.dict[char] = i + + def encode(self, text, batch_max_length=25): + """ convert text-label into text-index. + input: + text: text labels of each image. [batch_size] + batch_max_length: max length of text label in the batch. 25 by default + + output: + text : the input of attention decoder. [batch_size x (max_length+2)] +1 for [GO] token and +1 for [s] token. + text[:, 0] is [GO] token and text is padded with [GO] token after [s] token. + length : the length of output of attention decoder, which count [s] token also. [3, 7, ....] [batch_size] + """ + length = [len(s) + 1 for s in text] # +1 for [s] at end of sentence. + # additional +1 for [GO] at first step. batch_text is padded with [GO] token after [s] token. + batch_text = torch.cuda.LongTensor(len(text), batch_max_length + 1).fill_(self.PAD) + # mask_text = torch.cuda.LongTensor(len(text), batch_max_length).fill_(0) + for i, t in enumerate(text): + t = list(t + self.character[-2]) + text = [self.dict[char] for char in t] + # t_mask = [1 for i in range(len(text) + 1)] + batch_text[i][0:len(text)] = torch.cuda.LongTensor(text) # batch_text[:, len_text+1] = [EOS] token + # mask_text[i][0:len(text)+1] = torch.cuda.LongTensor(t_mask) + return (batch_text, torch.cuda.IntTensor(length)) + + def decode(self, text_index, length): + """ convert text-index into text-label. 
""" + texts = [] + for index, l in enumerate(length): + text = ''.join([self.character[i] for i in text_index[index, :]]) + idx = text.find('$') + texts.append(text[:idx]) + return texts + +class Averager(object): + """Compute average for torch.Tensor, used for loss average.""" + + def __init__(self): + self.reset() + + def add(self, v): + count = v.data.numel() + v = v.data.sum() + self.n_count += count + self.sum += v + + def reset(self): + self.n_count = 0 + self.sum = 0 + + def val(self): + res = 0 + if self.n_count != 0: + res = self.sum / float(self.n_count) + return res + +class AccuracyMeter(object): + def __init__(self): + self.hit = 0 + self.total = 0 + self.reset() + ### Important to call this after calling getAccuracy() + def reset(self): + self.hit = 0 + self.total = 0 + ### boolVal - determines if a condition is hit (true), then adds it + def applyHit(self, boolVal): + if boolVal: + self.hit += 1 + self.total += 1 + else: + self.total += 1 + def getAccuracy(self): + ### Returns accuracy in range (0-1) or (-1 of number of items = 0) + if self.total == 0: return -1 + return float(self.hit) / self.total + +def get_device(verbose=True): + use_cuda = torch.cuda.is_available() + device = torch.device("cuda" if use_cuda else "cpu") + if verbose: + print("Device:", device) + return device + + +def get_args(is_train=True): + parser = argparse.ArgumentParser(description='STR') + + # for test + parser.add_argument('--eval_data', required=not is_train, help='path to evaluation dataset') + parser.add_argument('--benchmark_all_eval', action='store_true', help='evaluate 10 benchmark evaluation datasets') + parser.add_argument('--calculate_infer_time', action='store_true', help='calculate inference timing') + parser.add_argument('--flops', action='store_true', help='calculates approx flops (may not work)') + + # for train + parser.add_argument('--exp_name', help='Where to store logs and models') + parser.add_argument('--train_data', required=is_train, help='path to training dataset') + parser.add_argument('--valid_data', required=is_train, help='path to validation dataset') + parser.add_argument('--manualSeed', type=int, default=1111, help='for random seed setting') + parser.add_argument('--workers', type=int, help='number of data loading workers. Use -1 to use all cores.', default=4) + parser.add_argument('--batch_size', type=int, default=192, help='input batch size') + parser.add_argument('--num_iter', type=int, default=300000, help='number of iterations to train for') + parser.add_argument('--valInterval', type=int, default=2000, help='Interval between each validation') + parser.add_argument('--saved_model', default='', help="path to model to continue training") + parser.add_argument('--FT', action='store_true', help='whether to do fine-tuning') + parser.add_argument('--sgd', action='store_true', help='Whether to use SGD (default is Adadelta)') + parser.add_argument('--adam', action='store_true', help='Whether to use adam (default is Adadelta)') + parser.add_argument('--lr', type=float, default=1, help='learning rate, default=1.0 for Adadelta') + parser.add_argument('--beta1', type=float, default=0.9, help='beta1 for adam. default=0.9') + parser.add_argument('--rho', type=float, default=0.95, help='decay rate rho for Adadelta. default=0.95') + parser.add_argument('--eps', type=float, default=1e-8, help='eps for Adadelta. default=1e-8') + parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping value. 
default=5') + parser.add_argument('--baiduCTC', action='store_true', help='for data_filtering_off mode') + """ Data processing """ + parser.add_argument('--select_data', type=str, default='MJ-ST', + help='select training data (default is MJ-ST, which means MJ and ST used as training data)') + parser.add_argument('--batch_ratio', type=str, default='0.5-0.5', + help='assign ratio for each selected data in the batch') + parser.add_argument('--total_data_usage_ratio', type=str, default='1.0', + help='total data usage ratio, this ratio is multiplied to total number of data.') + parser.add_argument('--inf_outdir', type=str, default='outdir', help='Specify output directory of influence function') + parser.add_argument('--inf_mode', type=str, default='Normal', help='Normal, VanGrad, SHAP') + parser.add_argument('--shap_pkl_root', type=str, default='', help='If Influence mode is SHAP, \ + this is a required argument. Remove last forward slash.') + parser.add_argument('--char_contrib_amnt', type=float, default=2.0, help='Multiplier on the first character for \ + contribution calculation. Min:1.0. Set to -1.0 to deactivate.') + # If --scorer is NA, then STR scorer will just output the single char index one-hot + parser.add_argument('--scorer', type=str, default='mean', help='See STRScore: cumprod | mean') + parser.add_argument('--blackbg', action='store_true', help='if True, background color for covering features will be black(0)') + parser.add_argument('--shap_eval', action='store_true', help='set always to true if you want to run test_shap.py') + parser.add_argument('--influence_train', action='store_true', help='if set to true, trains pretrained model with influence harmful/helpful') + parser.add_argument('--selective_sample_str', type=str, default='', \ + help='If =='', only sample images with string matching this (see --sensitive for case sensitivity)') + parser.add_argument('--max_selective_list', type=int, default=-1, help='if selective sample list has elements greater than this, autoclear list for batch selection') + parser.add_argument('--confidence_mode', type=int, default=0, help='0-sum of argmax; 1-edit distance') + parser.add_argument('--batch_max_length', type=int, default=25, help='maximum-label-length') + parser.add_argument('--imgH', type=int, default=32, help='the height of the input image') + parser.add_argument('--imgW', type=int, default=100, help='the width of the input image') + parser.add_argument('--rgb', action='store_true', help='use rgb input') + parser.add_argument('--character', type=str, + default='0123456789abcdefghijklmnopqrstuvwxyz', help='character label') + parser.add_argument('--sensitive', action='store_true', help='for sensitive character mode') + parser.add_argument('--ignore_case_sensitivity', action='store_true', help='use this only for shap testing') + parser.add_argument('--PAD', action='store_true', help='whether to keep ratio then pad for image resize') + parser.add_argument('--data_filtering_off', action='store_true', help='for data_filtering_off mode') + + """ Model Architecture """ + parser.add_argument('--Transformer', action='store_true', help='Use end-to-end transformer') + + choices = ["vitstr_tiny_patch16_224", "vitstr_small_patch16_224", "vitstr_base_patch16_224", "vitstr_tiny_distilled_patch16_224", "vitstr_small_distilled_patch16_224"] + parser.add_argument('--TransformerModel', default=choices[0], help='Which vit/deit transformer model', choices=choices) + parser.add_argument('--Transformation', type=str, required=True, help='Transformation 
stage. None|TPS') + parser.add_argument('--FeatureExtraction', type=str, required=True, + help='FeatureExtraction stage. VGG|RCNN|ResNet') + parser.add_argument('--SequenceModeling', type=str, required=True, help='SequenceModeling stage. None|BiLSTM') + parser.add_argument('--Prediction', type=str, required=True, help='Prediction stage. None|CTC|Attn') + parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN') + parser.add_argument('--input_channel', type=int, default=1, + help='the number of input channel of Feature extractor') + parser.add_argument('--output_channel', type=int, default=512, + help='the number of output channel of Feature extractor') + parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state') + + # selective augmentation (individual) + # can choose specific data augmentation + parser.add_argument('--issel_aug', action='store_true', help='Select augs') + parser.add_argument('--sel_prob', type=float, default=1., help='Probability of applying augmentation') + parser.add_argument('--pattern', action='store_true', help='Pattern group') + parser.add_argument('--warp', action='store_true', help='Warp group') + parser.add_argument('--geometry', action='store_true', help='Geometry group') + parser.add_argument('--weather', action='store_true', help='Weather group') + parser.add_argument('--noise', action='store_true', help='Noise group') + parser.add_argument('--blur', action='store_true', help='Blur group') + parser.add_argument('--camera', action='store_true', help='Camera group') + parser.add_argument('--process', action='store_true', help='Image processing routines') + parser.add_argument('--min_rand', type=int, default=0, help='minimum magnitude for aug (inclusive)') + parser.add_argument('--max_rand', type=int, default=3, help='maximum magnitude for aug (exclusive)') + + # use cosine learning rate decay + parser.add_argument('--scheduler', action='store_true', help='Use lr scheduler') + + parser.add_argument('--intact_prob', type=float, default=0.5, help='Probability of not applying augmentation') + parser.add_argument('--isrand_aug', action='store_true', help='Use RandAug') + parser.add_argument('--isshap_aug', action='store_true', help='Use SHAPAug') + parser.add_argument('--augs_num', type=int, default=3, help='Number of data augment groups to apply. 1 to 8.') + parser.add_argument('--augs_mag', type=int, default=None, help='Magnitude of data augment groups to apply. 
None if random.') + + # for comparison to other augmentations + parser.add_argument('--issemantic_aug', action='store_true', help='Use Semantic') + parser.add_argument('--isrotation_aug', action='store_true', help='Use ') + parser.add_argument('--isscatter_aug', action='store_true', help='Use ') + parser.add_argument('--islearning_aug', action='store_true', help='Use ') + + # orig paper uses this for fast benchmarking + parser.add_argument('--fast_acc', action='store_true', help='Fast average accuracy computation') + + args = parser.parse_args() + return args diff --git a/utils_abinet.py b/utils_abinet.py new file mode 100644 index 0000000000000000000000000000000000000000..ae5512e66716eb41af474e7b8756adfad4686cd4 --- /dev/null +++ b/utils_abinet.py @@ -0,0 +1,305 @@ +import logging +import os +import time + +import cv2 +import numpy as np +import torch +import yaml +from matplotlib import colors +from matplotlib import pyplot as plt +from torch import Tensor, nn +from torch.utils.data import ConcatDataset + +class CharsetMapper(object): + """A simple class to map ids into strings. + + It works only when the character set is 1:1 mapping between individual + characters and individual ids. + """ + + def __init__(self, + filename='', + max_length=30, + null_char=u'\u2591'): + """Creates a lookup table. + + Args: + filename: Path to charset file which maps characters to ids. + max_sequence_length: The max length of ids and string. + null_char: A unicode character used to replace '' character. + the default value is a light shade block '░'. + """ + self.null_char = null_char + self.max_length = max_length + + self.label_to_char = self._read_charset(filename) + self.char_to_label = dict(map(reversed, self.label_to_char.items())) + self.num_classes = len(self.label_to_char) + # print("self.num_classes: ", self.num_classes) ### 37 + + def _read_charset(self, filename): + """Reads a charset definition from a tab separated text file. + + Args: + filename: a path to the charset file. + + Returns: + a dictionary with keys equal to character codes and values - unicode + characters. + """ + import re + pattern = re.compile(r'(\d+)\t(.+)') + charset = {} + self.null_label = 0 + charset[self.null_label] = self.null_char + with open(filename, 'r') as f: + for i, line in enumerate(f): + m = pattern.match(line) + assert m, f'Incorrect charset file. line #{i}: {line}' + label = int(m.group(1)) + 1 + char = m.group(2) + charset[label] = char + return charset + + def trim(self, text): + assert isinstance(text, str) + return text.replace(self.null_char, '') + + def get_text(self, labels, length=None, padding=True, trim=False): + """ Returns a string corresponding to a sequence of character ids. + """ + length = length if length else self.max_length + labels = [l.item() if isinstance(l, Tensor) else int(l) for l in labels] + if padding: + labels = labels + [self.null_label] * (length-len(labels)) + text = ''.join([self.label_to_char[label] for label in labels]) + if trim: text = self.trim(text) + return text + + def get_labels(self, text, length=None, padding=True, case_sensitive=False): + """ Returns the labels of the corresponding text. 
+ """ + length = length if length else self.max_length + if padding: + text = text + self.null_char * (length - len(text)) + if not case_sensitive: + text = text.lower() + labels = [self.char_to_label[char] for char in text] + return labels + + def pad_labels(self, labels, length=None): + length = length if length else self.max_length + + return labels + [self.null_label] * (length - len(labels)) + + @property + def digits(self): + return '0123456789' + + @property + def digit_labels(self): + return self.get_labels(self.digits, padding=False) + + @property + def alphabets(self): + all_chars = list(self.char_to_label.keys()) + valid_chars = [] + for c in all_chars: + if c in 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ': + valid_chars.append(c) + return ''.join(valid_chars) + + @property + def alphabet_labels(self): + return self.get_labels(self.alphabets, padding=False) + + +class Timer(object): + """A simple timer.""" + def __init__(self): + self.data_time = 0. + self.data_diff = 0. + self.data_total_time = 0. + self.data_call = 0 + self.running_time = 0. + self.running_diff = 0. + self.running_total_time = 0. + self.running_call = 0 + + def tic(self): + self.start_time = time.time() + self.running_time = self.start_time + + def toc_data(self): + self.data_time = time.time() + self.data_diff = self.data_time - self.running_time + self.data_total_time += self.data_diff + self.data_call += 1 + + def toc_running(self): + self.running_time = time.time() + self.running_diff = self.running_time - self.data_time + self.running_total_time += self.running_diff + self.running_call += 1 + + def total_time(self): + return self.data_total_time + self.running_total_time + + def average_time(self): + return self.average_data_time() + self.average_running_time() + + def average_data_time(self): + return self.data_total_time / (self.data_call or 1) + + def average_running_time(self): + return self.running_total_time / (self.running_call or 1) + + +class Logger(object): + _handle = None + _root = None + + @staticmethod + def init(output_dir, name, phase): + format = '[%(asctime)s %(filename)s:%(lineno)d %(levelname)s {}] ' \ + '%(message)s'.format(name) + logging.basicConfig(level=logging.INFO, format=format) + + try: os.makedirs(output_dir) + except: pass + config_path = os.path.join(output_dir, f'{phase}.txt') + Logger._handle = logging.FileHandler(config_path) + Logger._root = logging.getLogger() + + @staticmethod + def enable_file(): + if Logger._handle is None or Logger._root is None: + raise Exception('Invoke Logger.init() first!') + Logger._root.addHandler(Logger._handle) + + @staticmethod + def disable_file(): + if Logger._handle is None or Logger._root is None: + raise Exception('Invoke Logger.init() first!') + Logger._root.removeHandler(Logger._handle) + + +class Config(object): + + def __init__(self, config_path, host=True): + def __dict2attr(d, prefix=''): + for k, v in d.items(): + if isinstance(v, dict): + __dict2attr(v, f'{prefix}{k}_') + else: + if k == 'phase': + assert v in ['train', 'test'] + if k == 'stage': + assert v in ['pretrain-vision', 'pretrain-language', + 'train-semi-super', 'train-super'] + self.__setattr__(f'{prefix}{k}', v) + + assert os.path.exists(config_path), '%s does not exists!' 
% config_path + with open(config_path) as file: + config_dict = yaml.load(file, Loader=yaml.FullLoader) + with open('configs/template.yaml') as file: + default_config_dict = yaml.load(file, Loader=yaml.FullLoader) + __dict2attr(default_config_dict) + __dict2attr(config_dict) + self.global_workdir = os.path.join(self.global_workdir, self.global_name) + + def __getattr__(self, item): + attr = self.__dict__.get(item) + if attr is None: + attr = dict() + prefix = f'{item}_' + for k, v in self.__dict__.items(): + if k.startswith(prefix): + n = k.replace(prefix, '') + attr[n] = v + return attr if len(attr) > 0 else None + else: + return attr + + def __repr__(self): + str = 'ModelConfig(\n' + for i, (k, v) in enumerate(sorted(vars(self).items())): + str += f'\t({i}): {k} = {v}\n' + str += ')' + return str + +def blend_mask(image, mask, alpha=0.5, cmap='jet', color='b', color_alpha=1.0): + # normalize mask + mask = (mask-mask.min()) / (mask.max() - mask.min() + np.finfo(float).eps) + if mask.shape != image.shape: + mask = cv2.resize(mask,(image.shape[1], image.shape[0])) + # get color map + color_map = plt.get_cmap(cmap) + mask = color_map(mask)[:,:,:3] + # convert float to uint8 + mask = (mask * 255).astype(dtype=np.uint8) + + # set the basic color + basic_color = np.array(colors.to_rgb(color)) * 255 + basic_color = np.tile(basic_color, [image.shape[0], image.shape[1], 1]) + basic_color = basic_color.astype(dtype=np.uint8) + # blend with basic color + blended_img = cv2.addWeighted(image, color_alpha, basic_color, 1-color_alpha, 0) + # blend with mask + blended_img = cv2.addWeighted(blended_img, alpha, mask, 1-alpha, 0) + + return blended_img + +def onehot(label, depth, device=None): + """ + Args: + label: shape (n1, n2, ..., ) + depth: a scalar + + Returns: + onehot: (n1, n2, ..., depth) + """ + if not isinstance(label, torch.Tensor): + label = torch.tensor(label, device=device) + onehot = torch.zeros(label.size() + torch.Size([depth]), device=device) + onehot = onehot.scatter_(-1, label.unsqueeze(-1), 1) + + return onehot + +class MyDataParallel(nn.DataParallel): + + def gather(self, outputs, target_device): + r""" + Gathers tensors from different GPUs on a specified device + (-1 means the CPU). + """ + def gather_map(outputs): + out = outputs[0] + if isinstance(out, (str, int, float)): + return out + if isinstance(out, list) and isinstance(out[0], str): + return [o for out in outputs for o in out] + if isinstance(out, torch.Tensor): + return torch.nn.parallel._functions.Gather.apply(target_device, self.dim, *outputs) + if out is None: + return None + if isinstance(out, dict): + if not all((len(out) == len(d) for d in outputs)): + raise ValueError('All dicts must have the same number of keys') + return type(out)(((k, gather_map([d[k] for d in outputs])) + for k in out)) + return type(out)(map(gather_map, zip(*outputs))) + + # Recursive function calls like this create reference cycles. + # Setting the function to None clears the refcycle. 
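
# Illustrative usage sketch, not part of the original code: onehot() (defined
# above) expands a tensor of class ids into a trailing one-hot dimension.
# The import path is an assumption.

import torch
from utils_abinet import onehot   # assumes this file is importable as utils_abinet

labels = torch.tensor([[0, 2, 1]])   # (batch=1, seq=3) class ids
oh = onehot(labels, depth=4)
print(oh.shape)      # torch.Size([1, 3, 4])
print(oh[0, 1])      # tensor([0., 0., 1., 0.])
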
+ try: + res = gather_map(outputs) + finally: + gather_map = None + return res + + +class MyConcatDataset(ConcatDataset): + def __getattr__(self, k): + return getattr(self.datasets[0], k) diff --git a/utils_matrn.py b/utils_matrn.py new file mode 100644 index 0000000000000000000000000000000000000000..0b99a4055646f20c2775e22c34cee49b2c3acc07 --- /dev/null +++ b/utils_matrn.py @@ -0,0 +1,303 @@ +import logging +import os +import time + +import cv2 +import numpy as np +import torch +import yaml +from matplotlib import colors +from matplotlib import pyplot as plt +from torch import Tensor, nn +from torch.utils.data import ConcatDataset + +class CharsetMapper(object): + """A simple class to map ids into strings. + + It works only when the character set is 1:1 mapping between individual + characters and individual ids. + """ + + def __init__(self, + filename='', + max_length=30, + null_char=u'\u2591'): + """Creates a lookup table. + + Args: + filename: Path to charset file which maps characters to ids. + max_sequence_length: The max length of ids and string. + null_char: A unicode character used to replace '' character. + the default value is a light shade block '░'. + """ + self.null_char = null_char + self.max_length = max_length + + self.label_to_char = self._read_charset(filename) + self.char_to_label = dict(map(reversed, self.label_to_char.items())) + self.num_classes = len(self.label_to_char) + + def _read_charset(self, filename): + """Reads a charset definition from a tab separated text file. + + Args: + filename: a path to the charset file. + + Returns: + a dictionary with keys equal to character codes and values - unicode + characters. + """ + import re + pattern = re.compile(r'(\d+)\t(.+)') + charset = {} + self.null_label = 0 + charset[self.null_label] = self.null_char + with open(filename, 'r') as f: + for i, line in enumerate(f): + m = pattern.match(line) + assert m, f'Incorrect charset file. line #{i}: {line}' + label = int(m.group(1)) + 1 + char = m.group(2) + charset[label] = char + return charset + + def trim(self, text): + assert isinstance(text, str) + return text.replace(self.null_char, '') + + def get_text(self, labels, length=None, padding=True, trim=False): + """ Returns a string corresponding to a sequence of character ids. + """ + length = length if length else self.max_length + labels = [l.item() if isinstance(l, Tensor) else int(l) for l in labels] + if padding: + labels = labels + [self.null_label] * (length-len(labels)) + text = ''.join([self.label_to_char[label] for label in labels]) + if trim: text = self.trim(text) + return text + + def get_labels(self, text, length=None, padding=True, case_sensitive=False): + """ Returns the labels of the corresponding text. 
+ """ + length = length if length else self.max_length + if padding: + text = text + self.null_char * (length - len(text)) + if not case_sensitive: + text = text.lower() + labels = [self.char_to_label[char] for char in text] + return labels + + def pad_labels(self, labels, length=None): + length = length if length else self.max_length + + return labels + [self.null_label] * (length - len(labels)) + + @property + def digits(self): + return '0123456789' + + @property + def digit_labels(self): + return self.get_labels(self.digits, padding=False) + + @property + def alphabets(self): + all_chars = list(self.char_to_label.keys()) + valid_chars = [] + for c in all_chars: + if c in 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ': + valid_chars.append(c) + return ''.join(valid_chars) + + @property + def alphabet_labels(self): + return self.get_labels(self.alphabets, padding=False) + + +class Timer(object): + """A simple timer.""" + def __init__(self): + self.data_time = 0. + self.data_diff = 0. + self.data_total_time = 0. + self.data_call = 0 + self.running_time = 0. + self.running_diff = 0. + self.running_total_time = 0. + self.running_call = 0 + + def tic(self): + self.start_time = time.time() + self.running_time = self.start_time + + def toc_data(self): + self.data_time = time.time() + self.data_diff = self.data_time - self.running_time + self.data_total_time += self.data_diff + self.data_call += 1 + + def toc_running(self): + self.running_time = time.time() + self.running_diff = self.running_time - self.data_time + self.running_total_time += self.running_diff + self.running_call += 1 + + def total_time(self): + return self.data_total_time + self.running_total_time + + def average_time(self): + return self.average_data_time() + self.average_running_time() + + def average_data_time(self): + return self.data_total_time / (self.data_call or 1) + + def average_running_time(self): + return self.running_total_time / (self.running_call or 1) + + +class Logger(object): + _handle = None + _root = None + + @staticmethod + def init(output_dir, name, phase): + format = '[%(asctime)s %(filename)s:%(lineno)d %(levelname)s {}] ' \ + '%(message)s'.format(name) + logging.basicConfig(level=logging.INFO, format=format) + + try: os.makedirs(output_dir) + except: pass + config_path = os.path.join(output_dir, f'{phase}.txt') + Logger._handle = logging.FileHandler(config_path) + Logger._root = logging.getLogger() + + @staticmethod + def enable_file(): + if Logger._handle is None or Logger._root is None: + raise Exception('Invoke Logger.init() first!') + Logger._root.addHandler(Logger._handle) + + @staticmethod + def disable_file(): + if Logger._handle is None or Logger._root is None: + raise Exception('Invoke Logger.init() first!') + Logger._root.removeHandler(Logger._handle) + + +class Config(object): + + def __init__(self, config_path, host=True): + def __dict2attr(d, prefix=''): + for k, v in d.items(): + if isinstance(v, dict): + __dict2attr(v, f'{prefix}{k}_') + else: + if k == 'phase': + assert v in ['train', 'test'] + if k == 'stage': + assert v in ['pretrain-vision', 'pretrain-language', + 'train-semi-super', 'train-super'] + self.__setattr__(f'{prefix}{k}', v) + + assert os.path.exists(config_path), '%s does not exists!' 
% config_path + with open(config_path) as file: + config_dict = yaml.load(file, Loader=yaml.FullLoader) + with open('configs/template.yaml') as file: + default_config_dict = yaml.load(file, Loader=yaml.FullLoader) + __dict2attr(default_config_dict) + __dict2attr(config_dict) + + def __getattr__(self, item): + attr = self.__dict__.get(item) + if attr is None: + attr = dict() + prefix = f'{item}_' + for k, v in self.__dict__.items(): + if k.startswith(prefix): + n = k.replace(prefix, '') + attr[n] = v + return attr if len(attr) > 0 else None + else: + return attr + + def __repr__(self): + str = 'ModelConfig(\n' + for i, (k, v) in enumerate(sorted(vars(self).items())): + str += f'\t({i}): {k} = {v}\n' + str += ')' + return str + +def blend_mask(image, mask, alpha=0.5, cmap='jet', color='b', color_alpha=1.0): + # normalize mask + mask = (mask-mask.min()) / (mask.max() - mask.min() + np.finfo(float).eps) + if mask.shape != image.shape: + mask = cv2.resize(mask,(image.shape[1], image.shape[0])) + # get color map + color_map = plt.get_cmap(cmap) + mask = color_map(mask)[:,:,:3] + # convert float to uint8 + mask = (mask * 255).astype(dtype=np.uint8) + + # set the basic color + basic_color = np.array(colors.to_rgb(color)) * 255 + basic_color = np.tile(basic_color, [image.shape[0], image.shape[1], 1]) + basic_color = basic_color.astype(dtype=np.uint8) + # blend with basic color + blended_img = cv2.addWeighted(image, color_alpha, basic_color, 1-color_alpha, 0) + # blend with mask + blended_img = cv2.addWeighted(blended_img, alpha, mask, 1-alpha, 0) + + return blended_img + +def onehot(label, depth, device=None): + """ + Args: + label: shape (n1, n2, ..., ) + depth: a scalar + + Returns: + onehot: (n1, n2, ..., depth) + """ + if not isinstance(label, torch.Tensor): + label = torch.tensor(label, device=device) + onehot = torch.zeros(label.size() + torch.Size([depth]), device=device) + onehot = onehot.scatter_(-1, label.unsqueeze(-1), 1) + + return onehot + +class MyDataParallel(nn.DataParallel): + + def gather(self, outputs, target_device): + r""" + Gathers tensors from different GPUs on a specified device + (-1 means the CPU). + """ + def gather_map(outputs): + out = outputs[0] + if isinstance(out, (str, int, float)): + return out + if isinstance(out, list) and isinstance(out[0], str): + return [o for out in outputs for o in out] + if isinstance(out, torch.Tensor): + return torch.nn.parallel._functions.Gather.apply(target_device, self.dim, *outputs) + if out is None: + return None + if isinstance(out, dict): + if not all((len(out) == len(d) for d in outputs)): + raise ValueError('All dicts must have the same number of keys') + return type(out)(((k, gather_map([d[k] for d in outputs])) + for k in out)) + return type(out)(map(gather_map, zip(*outputs))) + + # Recursive function calls like this create reference cycles. + # Setting the function to None clears the refcycle. + try: + res = gather_map(outputs) + finally: + gather_map = None + return res + + +class MyConcatDataset(ConcatDataset): + def __getattr__(self, k): + return getattr(self.datasets[0], k)
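
# Illustrative usage sketch, not part of the original code: blend_mask() in
# utils_abinet.py (and its near-duplicate in utils_matrn.py) overlays a
# normalized saliency/attribution map on an image, which is one way to render
# attribution heatmaps for word crops. The arrays, sizes, and output filename
# below are made up.

import cv2
import numpy as np
from utils_abinet import blend_mask   # assumes this file is importable

image = (np.random.rand(32, 100, 3) * 255).astype(np.uint8)   # stand-in word crop
attribution = np.random.rand(32, 100)                         # stand-in saliency map
overlay = blend_mask(image, attribution, alpha=0.5, cmap='jet')
cv2.imwrite('blend_demo.png', overlay)
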