import settings import captum import numpy as np import torch import torch.nn.functional as F import torch.backends.cudnn as cudnn from utils import get_args from utils import CTCLabelConverter, AttnLabelConverter, Averager, TokenLabelConverter import string import time import sys from dataset import hierarchical_dataset, AlignCollate import validators from model import Model, STRScore from PIL import Image from lime.wrappers.scikit_image import SegmentationAlgorithm from captum._utils.models.linear_model import SkLearnLinearModel, SkLearnRidge import random import os from skimage.color import gray2rgb import pickle from train_shap_corr import getPredAndConf import re import copy import statistics device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') from captum.attr import ( GradientShap, DeepLift, DeepLiftShap, IntegratedGradients, LayerConductance, NeuronConductance, NoiseTunnel, Saliency, InputXGradient, GuidedBackprop, Deconvolution, GuidedGradCam, FeatureAblation, ShapleyValueSampling, Lime, KernelShap ) from captum.metrics import ( infidelity, sensitivity_max ) ### Returns the mean for each segmentation having shape as the same as the input ### This function can only one attribution image at a time def averageSegmentsOut(attr, segments): averagedInput = torch.clone(attr) sortedDict = {} for x in np.unique(segments): segmentMean = torch.mean(attr[segments == x][:]) sortedDict[x] = float(segmentMean.detach().cpu().numpy()) averagedInput[segments == x] = segmentMean return averagedInput, sortedDict ### Output and save segmentations only for one dataset only def outputSegmOnly(opt): ### targetDataset - one dataset only, SVTP-645, CUTE80-288images targetDataset = "CUTE80" # ['IIIT5k_3000', 'SVT', 'IC03_867', 'IC13_1015', 'IC15_2077', 'SVTP', 'CUTE80'] segmRootDir = "/home/uclpc1/Documents/STR/datasets/segmentations/224X224/{}/".format(targetDataset) if not os.path.exists(segmRootDir): os.makedirs(segmRootDir) opt.eval = True ### Only IIIT5k_3000 if opt.fast_acc: # # To easily compute the total accuracy of our paper. eval_data_list = [targetDataset] else: # The evaluation datasets, dataset order is same with Table 1 in our paper. eval_data_list = [targetDataset] ### Taken from LIME segmentation_fn = SegmentationAlgorithm('quickshift', kernel_size=4, max_dist=200, ratio=0.2, random_seed=random.randint(0, 1000)) for eval_data in eval_data_list: eval_data_path = os.path.join(opt.eval_data, eval_data) AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD, opt=opt) eval_data, eval_data_log = hierarchical_dataset(root=eval_data_path, opt=opt) evaluation_loader = torch.utils.data.DataLoader( eval_data, batch_size=1, shuffle=False, num_workers=int(opt.workers), collate_fn=AlignCollate_evaluation, pin_memory=True) for i, (image_tensors, labels) in enumerate(evaluation_loader): imgDataDict = {} img_numpy = image_tensors.cpu().detach().numpy()[0] ### Need to set batch size to 1 only if img_numpy.shape[0] == 1: img_numpy = gray2rgb(img_numpy[0]) # print("img_numpy shape: ", img_numpy.shape) # (224,224,3) segmOutput = segmentation_fn(img_numpy) imgDataDict['segdata'] = segmOutput imgDataDict['label'] = labels[0] outputPickleFile = segmRootDir + "{}.pkl".format(i) with open(outputPickleFile, 'wb') as f: pickle.dump(imgDataDict, f) def acquireSelectivityHit(origImg, attributions, segmentations, model, converter, labels, scoring): # print("segmentations unique len: ", np.unique(segmentations)) aveSegmentations, sortedDict = averageSegmentsOut(attributions[0,0], segmentations) sortedKeys = [k for k, v in sorted(sortedDict.items(), key=lambda item: item[1])] sortedKeys = sortedKeys[::-1] ### A list that should contain largest to smallest score # print("sortedDict: ", sortedDict) # {0: -5.51e-06, 1: -1.469e-05, 2: -3.06e-05,...} # print("aveSegmentations unique len: ", np.unique(aveSegmentations)) # print("aveSegmentations device: ", aveSegmentations.device) # cuda:0 # print("aveSegmentations shape: ", aveSegmentations.shape) # (224,224) # print("aveSegmentations: ", aveSegmentations) n_correct = [] confidenceList = [] # First index is one feature removed, second index two features removed, and so on... clonedImg = torch.clone(origImg) gt = str(labels) for totalSegToHide in range(0, len(sortedKeys)): ### Acquire LIME prediction result currentSegmentToHide = sortedKeys[totalSegToHide] clonedImg[0,0][segmentations == currentSegmentToHide] = 0.0 pred, confScore = getPredAndConf(opt, model, scoring, clonedImg, converter, np.array([gt])) # To evaluate 'case sensitive model' with alphanumeric and case insensitve setting. if opt.sensitive and opt.data_filtering_off: pred = pred.lower() gt = gt.lower() alphanumeric_case_insensitve = '0123456789abcdefghijklmnopqrstuvwxyz' out_of_alphanumeric_case_insensitve = f'[^{alphanumeric_case_insensitve}]' pred = re.sub(out_of_alphanumeric_case_insensitve, '', pred) gt = re.sub(out_of_alphanumeric_case_insensitve, '', gt) if pred == gt: n_correct.append(1) else: n_correct.append(0) confScore = confScore[0][0]*100 confidenceList.append(confScore) return n_correct, confidenceList ### Once you have the selectivity_eval_results.pkl file, def acquire_selectivity_auc(opt, pkl_filename=None): if pkl_filename is None: pkl_filename = "/home/goo/str/str_vit_dataexplain_lambda/metrics_sensitivity_eval_results_CUTE80.pkl" # VITSTR accKeys = [] with open(pkl_filename, 'rb') as f: selectivity_data = pickle.load(f) for resDictIdx, resDict in enumerate(selectivity_data): keylistAcc = [] keylistConf = [] metricsKeys = resDict.keys() for keyStr in resDict.keys(): if "_acc" in keyStr: keylistAcc.append(keyStr) if "_conf" in keyStr: keylistConf.append(keyStr) # Need to check if network correctly predicted the image for metrics_accStr in keylistAcc: if 1 not in resDict[metrics_accStr]: print("resDictIdx") ## gtClassNum - set to gtClassNum=0 for standard implemention, or specific class idx for local explanation def acquireAttribution(opt, super_model, input, segmTensor, gtClassNum, lowestAccKey, device): channels = 1 if opt.rgb: channels = 3 ### Perform attribution if "intgrad_" in lowestAccKey: ig = IntegratedGradients(super_model) attributions = ig.attribute(input, target=gtClassNum) elif "gradshap_" in lowestAccKey: gs = GradientShap(super_model) baseline_dist = torch.zeros((1, channels, opt.imgH, opt.imgW)) baseline_dist = baseline_dist.to(device) attributions = gs.attribute(input, baselines=baseline_dist, target=gtClassNum) elif "deeplift_" in lowestAccKey: dl = DeepLift(super_model) attributions = dl.attribute(input, target=gtClassNum) elif "saliency_" in lowestAccKey: saliency = Saliency(super_model) attributions = saliency.attribute(input, target=gtClassNum) elif "inpxgrad_" in lowestAccKey: input_x_gradient = InputXGradient(super_model) attributions = input_x_gradient.attribute(input, target=gtClassNum) elif "guidedbp_" in lowestAccKey: gbp = GuidedBackprop(super_model) attributions = gbp.attribute(input, target=gtClassNum) elif "deconv_" in lowestAccKey: deconv = Deconvolution(super_model) attributions = deconv.attribute(input, target=gtClassNum) elif "featablt_" in lowestAccKey: ablator = FeatureAblation(super_model) attributions = ablator.attribute(input, target=gtClassNum, feature_mask=segmTensor) elif "shapley_" in lowestAccKey: svs = ShapleyValueSampling(super_model) attributions = svs.attribute(input, target=gtClassNum, feature_mask=segmTensor) elif "lime_" in lowestAccKey: interpretable_model = SkLearnRidge(alpha=1, fit_intercept=True) ### This is the default used by LIME lime = Lime(super_model, interpretable_model=interpretable_model) attributions = lime.attribute(input, target=gtClassNum, feature_mask=segmTensor) elif "kernelshap_" in lowestAccKey: ks = KernelShap(super_model) attributions = ks.attribute(input, target=gtClassNum, feature_mask=segmTensor) else: assert False return attributions ### In addition to acquire_average_auc(), this function also returns the best selectivity_acc attr-based method ### pklFile - you need to pass pkl file here def acquire_bestacc_attr(opt, pickleFile): # pickleFile = "metrics_sensitivity_eval_results_IIIT5k_3000.pkl" # pickleFile = "/home/goo/str/str_vit_dataexplain_lambda/shapley_singlechar_ave_matrn_SVT.pkl" acquireSelectivity = True # If True, set to acquireInfidelity = False acquireSensitivity = False with open(pickleFile, 'rb') as f: data = pickle.load(f) metricDict = {} # Keys: "saliency_acc", "saliency_conf", "saliency_infid", "saliency_sens" selectivity_acc_auc_normalized = [] # Normalized because it is divided by the full rectangle for imgData in data: if acquireSelectivity: for keyStr in imgData.keys(): if ("_acc" in keyStr or "_conf" in keyStr) and not ("_local_" in keyStr or "_global_local_" in keyStr): # Accept only selectivity if keyStr not in metricDict: metricDict[keyStr] = [] dataList = copy.deepcopy(imgData[keyStr]) # list of 0,1 [1,1,1,0,0,0,0] dataList.insert(0, 1) # Insert 1 at beginning to avoid np.trapz([1]) = 0.0 denom = [1] * len(dataList) # Denominator to normalize AUC auc_norm = np.trapz(dataList) / np.trapz(denom) metricDict[keyStr].append(auc_norm) elif acquireInfidelity: pass # TODO elif acquireSensitivity: pass # TODO lowestAccKey = "" lowestAcc = 10000000 for metricKey in metricDict: if "_acc" in metricKey: # Used for selectivity accuracy only statisticVal = statistics.mean(metricDict[metricKey]) if statisticVal < lowestAcc: lowestAcc = statisticVal lowestAccKey = metricKey # print("{}: {}".format(metricKey, statisticVal)) assert lowestAccKey!="" return lowestAccKey def saveAttrData(filename, attribution, segmData, origImg): pklData = {} pklData['attribution'] = torch.clone(attribution).detach().cpu().numpy() pklData['segmData'] = segmData pklData['origImg'] = origImg with open(filename, 'wb') as f: pickle.dump(pklData, f) ### New code (8/3/2022) to acquire average selectivity, infidelity, etc. after running captum test def acquire_average_auc(opt): # pickleFile = "metrics_sensitivity_eval_results_IIIT5k_3000.pkl" pickleFile = "/home/goo/str/str_vit_dataexplain_lambda/shapley_singlechar_ave_vitstr_IC03_860.pkl" acquireSelectivity = True # If True, set to acquireInfidelity = False acquireSensitivity = False with open(pickleFile, 'rb') as f: data = pickle.load(f) metricDict = {} # Keys: "saliency_acc", "saliency_conf", "saliency_infid", "saliency_sens" selectivity_acc_auc_normalized = [] # Normalized because it is divided by the full rectangle for imgData in data: if acquireSelectivity: for keyStr in imgData.keys(): if "_acc" in keyStr or "_conf" in keyStr: # Accept only selectivity if keyStr not in metricDict: metricDict[keyStr] = [] dataList = copy.deepcopy(imgData[keyStr]) # list of 0,1 [1,1,1,0,0,0,0] dataList.insert(0, 1) # Insert 1 at beginning to avoid np.trapz([1]) = 0.0 denom = [1] * len(dataList) # Denominator to normalize AUC auc_norm = np.trapz(dataList) / np.trapz(denom) metricDict[keyStr].append(auc_norm) elif acquireInfidelity: pass # TODO elif acquireSensitivity: pass # TODO for metricKey in metricDict: print("{}: {}".format(metricKey, statistics.mean(metricDict[metricKey]))) ### Use this acquire list def acquireListOfAveAUC(opt): acquireSelectivity = True acquireInfidelity = False acquireSensitivity = False totalChars = 10 collectedMetricDict = {} for charNum in range(0, totalChars): pickleFile = f"/home/goo/str/str_vit_dataexplain_lambda/singlechar{charNum}_results_{totalChars}chardataset.pkl" with open(pickleFile, 'rb') as f: data = pickle.load(f) metricDict = {} # Keys: "saliency_acc", "saliency_conf", "saliency_infid", "saliency_sens" selectivity_acc_auc_normalized = [] # Normalized because it is divided by the full rectangle for imgData in data: if acquireSelectivity: for keyStr in imgData.keys(): if "_acc" in keyStr or "_conf" in keyStr: # Accept only selectivity if keyStr not in metricDict: metricDict[keyStr] = [] dataList = copy.deepcopy(imgData[keyStr]) # list of 0,1 [1,1,1,0,0,0,0] dataList.insert(0, 1) # Insert 1 at beginning to avoid np.trapz([1]) = 0.0 denom = [1] * len(dataList) # Denominator to normalize AUC auc_norm = np.trapz(dataList) / np.trapz(denom) metricDict[keyStr].append(auc_norm) for metricKey in metricDict: selec_auc_normalize = statistics.mean(metricDict[metricKey]) if metricKey not in collectedMetricDict: collectedMetricDict[metricKey] = [] collectedMetricDict[metricKey].append(selec_auc_normalize) for collectedMetricDictKey in collectedMetricDict: print("{}: {}".format(collectedMetricDictKey, collectedMetricDict[collectedMetricDictKey])) for charNum in range(0, totalChars): selectivityAcrossCharsLs = [] for collectedMetricDictKey in collectedMetricDict: if "_acc" in collectedMetricDictKey: selectivityAcrossCharsLs.append(collectedMetricDict[collectedMetricDictKey][charNum]) print("accuracy -- {}: {}".format(charNum, statistics.mean(selectivityAcrossCharsLs))) for charNum in range(0, totalChars): selectivityAcrossCharsLs = [] for collectedMetricDictKey in collectedMetricDict: if "_conf" in collectedMetricDictKey: selectivityAcrossCharsLs.append(collectedMetricDict[collectedMetricDictKey][charNum]) print("confidence -- {}: {}".format(charNum, statistics.mean(selectivityAcrossCharsLs))) if __name__ == '__main__': # deleteInf() opt = get_args(is_train=False) """ vocab / character number configuration """ if opt.sensitive: opt.character = string.printable[:-6] # same with ASTER setting (use 94 char). cudnn.benchmark = True cudnn.deterministic = True opt.num_gpu = torch.cuda.device_count() main(opt)