# strexp / captum_test.py
# Uploaded by markytools in commit d61b9c7 ("added strexp"); original file size 15.9 kB.
import settings
import captum
import numpy as np
import torch
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from utils import get_args
from utils import CTCLabelConverter, AttnLabelConverter, Averager, TokenLabelConverter
import string
import time
import sys
from dataset import hierarchical_dataset, AlignCollate
import validators
from model import Model, STRScore
from PIL import Image
from lime.wrappers.scikit_image import SegmentationAlgorithm
from captum._utils.models.linear_model import SkLearnLinearModel, SkLearnRidge
import random
import os
from skimage.color import gray2rgb
import pickle
from train_shap_corr import getPredAndConf
import re
import copy
import statistics
# Prefer the GPU whenever CUDA is available; otherwise everything runs on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
from captum.attr import (
GradientShap,
DeepLift,
DeepLiftShap,
IntegratedGradients,
LayerConductance,
NeuronConductance,
NoiseTunnel,
Saliency,
InputXGradient,
GuidedBackprop,
Deconvolution,
GuidedGradCam,
FeatureAblation,
ShapleyValueSampling,
Lime,
KernelShap
)
from captum.metrics import (
infidelity,
sensitivity_max
)
### Returns the mean for each segmentation having shape as the same as the input
### This function can only one attribution image at a time
def averageSegmentsOut(attr, segments):
averagedInput = torch.clone(attr)
sortedDict = {}
for x in np.unique(segments):
segmentMean = torch.mean(attr[segments == x][:])
sortedDict[x] = float(segmentMean.detach().cpu().numpy())
averagedInput[segments == x] = segmentMean
return averagedInput, sortedDict
### Output and save segmentations only for one dataset only
def outputSegmOnly(opt):
### targetDataset - one dataset only, SVTP-645, CUTE80-288images
targetDataset = "CUTE80" # ['IIIT5k_3000', 'SVT', 'IC03_867', 'IC13_1015', 'IC15_2077', 'SVTP', 'CUTE80']
segmRootDir = "/home/uclpc1/Documents/STR/datasets/segmentations/224X224/{}/".format(targetDataset)
if not os.path.exists(segmRootDir):
os.makedirs(segmRootDir)
opt.eval = True
### Only IIIT5k_3000
if opt.fast_acc:
# # To easily compute the total accuracy of our paper.
eval_data_list = [targetDataset]
else:
# The evaluation datasets, dataset order is same with Table 1 in our paper.
eval_data_list = [targetDataset]
### Taken from LIME
segmentation_fn = SegmentationAlgorithm('quickshift', kernel_size=4,
max_dist=200, ratio=0.2,
random_seed=random.randint(0, 1000))
for eval_data in eval_data_list:
eval_data_path = os.path.join(opt.eval_data, eval_data)
AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD, opt=opt)
eval_data, eval_data_log = hierarchical_dataset(root=eval_data_path, opt=opt)
evaluation_loader = torch.utils.data.DataLoader(
eval_data, batch_size=1,
shuffle=False,
num_workers=int(opt.workers),
collate_fn=AlignCollate_evaluation, pin_memory=True)
for i, (image_tensors, labels) in enumerate(evaluation_loader):
imgDataDict = {}
img_numpy = image_tensors.cpu().detach().numpy()[0] ### Need to set batch size to 1 only
if img_numpy.shape[0] == 1:
img_numpy = gray2rgb(img_numpy[0])
# print("img_numpy shape: ", img_numpy.shape) # (224,224,3)
segmOutput = segmentation_fn(img_numpy)
imgDataDict['segdata'] = segmOutput
imgDataDict['label'] = labels[0]
outputPickleFile = segmRootDir + "{}.pkl".format(i)
with open(outputPickleFile, 'wb') as f:
pickle.dump(imgDataDict, f)
def acquireSelectivityHit(origImg, attributions, segmentations, model, converter, labels, scoring):
# print("segmentations unique len: ", np.unique(segmentations))
aveSegmentations, sortedDict = averageSegmentsOut(attributions[0,0], segmentations)
sortedKeys = [k for k, v in sorted(sortedDict.items(), key=lambda item: item[1])]
sortedKeys = sortedKeys[::-1] ### A list that should contain largest to smallest score
# print("sortedDict: ", sortedDict) # {0: -5.51e-06, 1: -1.469e-05, 2: -3.06e-05,...}
# print("aveSegmentations unique len: ", np.unique(aveSegmentations))
# print("aveSegmentations device: ", aveSegmentations.device) # cuda:0
# print("aveSegmentations shape: ", aveSegmentations.shape) # (224,224)
# print("aveSegmentations: ", aveSegmentations)
n_correct = []
confidenceList = [] # First index is one feature removed, second index two features removed, and so on...
clonedImg = torch.clone(origImg)
gt = str(labels)
for totalSegToHide in range(0, len(sortedKeys)):
### Acquire LIME prediction result
currentSegmentToHide = sortedKeys[totalSegToHide]
clonedImg[0,0][segmentations == currentSegmentToHide] = 0.0
pred, confScore = getPredAndConf(opt, model, scoring, clonedImg, converter, np.array([gt]))
# To evaluate 'case sensitive model' with alphanumeric and case insensitve setting.
if opt.sensitive and opt.data_filtering_off:
pred = pred.lower()
gt = gt.lower()
alphanumeric_case_insensitve = '0123456789abcdefghijklmnopqrstuvwxyz'
out_of_alphanumeric_case_insensitve = f'[^{alphanumeric_case_insensitve}]'
pred = re.sub(out_of_alphanumeric_case_insensitve, '', pred)
gt = re.sub(out_of_alphanumeric_case_insensitve, '', gt)
if pred == gt:
n_correct.append(1)
else:
n_correct.append(0)
confScore = confScore[0][0]*100
confidenceList.append(confScore)
return n_correct, confidenceList
### Once you have the selectivity_eval_results.pkl file,
def acquire_selectivity_auc(opt, pkl_filename=None):
if pkl_filename is None:
pkl_filename = "/home/goo/str/str_vit_dataexplain_lambda/metrics_sensitivity_eval_results_CUTE80.pkl" # VITSTR
accKeys = []
with open(pkl_filename, 'rb') as f:
selectivity_data = pickle.load(f)
for resDictIdx, resDict in enumerate(selectivity_data):
keylistAcc = []
keylistConf = []
metricsKeys = resDict.keys()
for keyStr in resDict.keys():
if "_acc" in keyStr: keylistAcc.append(keyStr)
if "_conf" in keyStr: keylistConf.append(keyStr)
# Need to check if network correctly predicted the image
for metrics_accStr in keylistAcc:
if 1 not in resDict[metrics_accStr]: print("resDictIdx")
## gtClassNum - set to gtClassNum=0 for standard implemention, or specific class idx for local explanation
def acquireAttribution(opt, super_model, input, segmTensor, gtClassNum, lowestAccKey, device):
channels = 1
if opt.rgb:
channels = 3
### Perform attribution
if "intgrad_" in lowestAccKey:
ig = IntegratedGradients(super_model)
attributions = ig.attribute(input, target=gtClassNum)
elif "gradshap_" in lowestAccKey:
gs = GradientShap(super_model)
baseline_dist = torch.zeros((1, channels, opt.imgH, opt.imgW))
baseline_dist = baseline_dist.to(device)
attributions = gs.attribute(input, baselines=baseline_dist, target=gtClassNum)
elif "deeplift_" in lowestAccKey:
dl = DeepLift(super_model)
attributions = dl.attribute(input, target=gtClassNum)
elif "saliency_" in lowestAccKey:
saliency = Saliency(super_model)
attributions = saliency.attribute(input, target=gtClassNum)
elif "inpxgrad_" in lowestAccKey:
input_x_gradient = InputXGradient(super_model)
attributions = input_x_gradient.attribute(input, target=gtClassNum)
elif "guidedbp_" in lowestAccKey:
gbp = GuidedBackprop(super_model)
attributions = gbp.attribute(input, target=gtClassNum)
elif "deconv_" in lowestAccKey:
deconv = Deconvolution(super_model)
attributions = deconv.attribute(input, target=gtClassNum)
elif "featablt_" in lowestAccKey:
ablator = FeatureAblation(super_model)
attributions = ablator.attribute(input, target=gtClassNum, feature_mask=segmTensor)
elif "shapley_" in lowestAccKey:
svs = ShapleyValueSampling(super_model)
attributions = svs.attribute(input, target=gtClassNum, feature_mask=segmTensor)
elif "lime_" in lowestAccKey:
interpretable_model = SkLearnRidge(alpha=1, fit_intercept=True) ### This is the default used by LIME
lime = Lime(super_model, interpretable_model=interpretable_model)
attributions = lime.attribute(input, target=gtClassNum, feature_mask=segmTensor)
elif "kernelshap_" in lowestAccKey:
ks = KernelShap(super_model)
attributions = ks.attribute(input, target=gtClassNum, feature_mask=segmTensor)
else:
assert False
return attributions
### In addition to acquire_average_auc(), this function also returns the best selectivity_acc attr-based method
### pklFile - you need to pass pkl file here
def acquire_bestacc_attr(opt, pickleFile):
# pickleFile = "metrics_sensitivity_eval_results_IIIT5k_3000.pkl"
# pickleFile = "/home/goo/str/str_vit_dataexplain_lambda/shapley_singlechar_ave_matrn_SVT.pkl"
acquireSelectivity = True # If True, set to
acquireInfidelity = False
acquireSensitivity = False
with open(pickleFile, 'rb') as f:
data = pickle.load(f)
metricDict = {} # Keys: "saliency_acc", "saliency_conf", "saliency_infid", "saliency_sens"
selectivity_acc_auc_normalized = [] # Normalized because it is divided by the full rectangle
for imgData in data:
if acquireSelectivity:
for keyStr in imgData.keys():
if ("_acc" in keyStr or "_conf" in keyStr) and not ("_local_" in keyStr or "_global_local_" in keyStr): # Accept only selectivity
if keyStr not in metricDict:
metricDict[keyStr] = []
dataList = copy.deepcopy(imgData[keyStr]) # list of 0,1 [1,1,1,0,0,0,0]
dataList.insert(0, 1) # Insert 1 at beginning to avoid np.trapz([1]) = 0.0
denom = [1] * len(dataList) # Denominator to normalize AUC
auc_norm = np.trapz(dataList) / np.trapz(denom)
metricDict[keyStr].append(auc_norm)
elif acquireInfidelity:
pass # TODO
elif acquireSensitivity:
pass # TODO
lowestAccKey = ""
lowestAcc = 10000000
for metricKey in metricDict:
if "_acc" in metricKey: # Used for selectivity accuracy only
statisticVal = statistics.mean(metricDict[metricKey])
if statisticVal < lowestAcc:
lowestAcc = statisticVal
lowestAccKey = metricKey
# print("{}: {}".format(metricKey, statisticVal))
assert lowestAccKey!=""
return lowestAccKey
def saveAttrData(filename, attribution, segmData, origImg):
pklData = {}
pklData['attribution'] = torch.clone(attribution).detach().cpu().numpy()
pklData['segmData'] = segmData
pklData['origImg'] = origImg
with open(filename, 'wb') as f:
pickle.dump(pklData, f)
### New code (8/3/2022) to acquire average selectivity, infidelity, etc. after running captum test
def acquire_average_auc(opt):
# pickleFile = "metrics_sensitivity_eval_results_IIIT5k_3000.pkl"
pickleFile = "/home/goo/str/str_vit_dataexplain_lambda/shapley_singlechar_ave_vitstr_IC03_860.pkl"
acquireSelectivity = True # If True, set to
acquireInfidelity = False
acquireSensitivity = False
with open(pickleFile, 'rb') as f:
data = pickle.load(f)
metricDict = {} # Keys: "saliency_acc", "saliency_conf", "saliency_infid", "saliency_sens"
selectivity_acc_auc_normalized = [] # Normalized because it is divided by the full rectangle
for imgData in data:
if acquireSelectivity:
for keyStr in imgData.keys():
if "_acc" in keyStr or "_conf" in keyStr: # Accept only selectivity
if keyStr not in metricDict:
metricDict[keyStr] = []
dataList = copy.deepcopy(imgData[keyStr]) # list of 0,1 [1,1,1,0,0,0,0]
dataList.insert(0, 1) # Insert 1 at beginning to avoid np.trapz([1]) = 0.0
denom = [1] * len(dataList) # Denominator to normalize AUC
auc_norm = np.trapz(dataList) / np.trapz(denom)
metricDict[keyStr].append(auc_norm)
elif acquireInfidelity:
pass # TODO
elif acquireSensitivity:
pass # TODO
for metricKey in metricDict:
print("{}: {}".format(metricKey, statistics.mean(metricDict[metricKey])))
### Use this acquire list
def acquireListOfAveAUC(opt):
acquireSelectivity = True
acquireInfidelity = False
acquireSensitivity = False
totalChars = 10
collectedMetricDict = {}
for charNum in range(0, totalChars):
pickleFile = f"/home/goo/str/str_vit_dataexplain_lambda/singlechar{charNum}_results_{totalChars}chardataset.pkl"
with open(pickleFile, 'rb') as f:
data = pickle.load(f)
metricDict = {} # Keys: "saliency_acc", "saliency_conf", "saliency_infid", "saliency_sens"
selectivity_acc_auc_normalized = [] # Normalized because it is divided by the full rectangle
for imgData in data:
if acquireSelectivity:
for keyStr in imgData.keys():
if "_acc" in keyStr or "_conf" in keyStr: # Accept only selectivity
if keyStr not in metricDict:
metricDict[keyStr] = []
dataList = copy.deepcopy(imgData[keyStr]) # list of 0,1 [1,1,1,0,0,0,0]
dataList.insert(0, 1) # Insert 1 at beginning to avoid np.trapz([1]) = 0.0
denom = [1] * len(dataList) # Denominator to normalize AUC
auc_norm = np.trapz(dataList) / np.trapz(denom)
metricDict[keyStr].append(auc_norm)
for metricKey in metricDict:
selec_auc_normalize = statistics.mean(metricDict[metricKey])
if metricKey not in collectedMetricDict:
collectedMetricDict[metricKey] = []
collectedMetricDict[metricKey].append(selec_auc_normalize)
for collectedMetricDictKey in collectedMetricDict:
print("{}: {}".format(collectedMetricDictKey, collectedMetricDict[collectedMetricDictKey]))
for charNum in range(0, totalChars):
selectivityAcrossCharsLs = []
for collectedMetricDictKey in collectedMetricDict:
if "_acc" in collectedMetricDictKey:
selectivityAcrossCharsLs.append(collectedMetricDict[collectedMetricDictKey][charNum])
print("accuracy -- {}: {}".format(charNum, statistics.mean(selectivityAcrossCharsLs)))
for charNum in range(0, totalChars):
selectivityAcrossCharsLs = []
for collectedMetricDictKey in collectedMetricDict:
if "_conf" in collectedMetricDictKey:
selectivityAcrossCharsLs.append(collectedMetricDict[collectedMetricDictKey][charNum])
print("confidence -- {}: {}".format(charNum, statistics.mean(selectivityAcrossCharsLs)))
if __name__ == '__main__':
    # Script entry point: parse evaluation options, configure cuDNN and run.
    # NOTE: `opt` is bound here at module scope. acquireSelectivityHit() reads
    # a module-level `opt`, so this name must stay a global.
    # NOTE(review): deleteInf is not defined in the visible source either.
    # deleteInf()
    opt = get_args(is_train=False)
    """ vocab / character number configuration """
    if opt.sensitive:
        opt.character = string.printable[:-6]  # same with ASTER setting (use 94 char).
    # cuDNN autotuning, with deterministic algorithm selection requested.
    cudnn.benchmark = True
    cudnn.deterministic = True
    opt.num_gpu = torch.cuda.device_count()
    # NOTE(review): no `main` is defined in the visible source of this file.
    # Unless it is provided elsewhere, this line raises NameError. Confirm
    # which routine (e.g. acquire_average_auc) is intended to run here.
    main(opt)