import settings
import captum
import numpy as np
import torch
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from utils import get_args
from utils import CTCLabelConverter, AttnLabelConverter, Averager, TokenLabelConverter
import string
import time
import sys
from dataset import hierarchical_dataset, AlignCollate
import validators
from model import Model, STRScore
from PIL import Image
from lime.wrappers.scikit_image import SegmentationAlgorithm
from captum._utils.models.linear_model import SkLearnLinearModel, SkLearnRidge
import random
import os
from skimage.color import gray2rgb
import pickle
from train_shap_corr import getPredAndConf
import re
import copy
import statistics
from captum.attr import (
    GradientShap,
    DeepLift,
    DeepLiftShap,
    IntegratedGradients,
    LayerConductance,
    NeuronConductance,
    NoiseTunnel,
    Saliency,
    InputXGradient,
    GuidedBackprop,
    Deconvolution,
    GuidedGradCam,
    FeatureAblation,
    ShapleyValueSampling,
    Lime,
    KernelShap
)
from captum.metrics import (
    infidelity,
    sensitivity_max
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
### Returns the per-segment mean attribution, with the same shape as the input
### This function can only process one attribution image at a time
def averageSegmentsOut(attr, segments):
    averagedInput = torch.clone(attr)
    sortedDict = {}
    for x in np.unique(segments):
        segmentMean = torch.mean(attr[segments == x])
        sortedDict[x] = float(segmentMean.detach().cpu().numpy())
        averagedInput[segments == x] = segmentMean
    return averagedInput, sortedDict
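
### Usage sketch (hypothetical tensors, for illustration only): average a
### 224x224 attribution map over precomputed segments, then rank segment ids
### from most to least important.
# attr = torch.rand(224, 224, device=device)          # per-pixel attributions
# segments = np.random.randint(0, 50, (224, 224))     # segment id per pixel
# averaged, scores = averageSegmentsOut(attr, segments)
# ranked = sorted(scores, key=scores.get, reverse=True)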
### Output and save segmentations for one dataset only
def outputSegmOnly(opt):
    ### targetDataset - one dataset only, SVTP-645, CUTE80-288images
    targetDataset = "CUTE80" # ['IIIT5k_3000', 'SVT', 'IC03_867', 'IC13_1015', 'IC15_2077', 'SVTP', 'CUTE80']
    segmRootDir = "/home/uclpc1/Documents/STR/datasets/segmentations/224X224/{}/".format(targetDataset)
    if not os.path.exists(segmRootDir):
        os.makedirs(segmRootDir)
    opt.eval = True
    ### Evaluate on the selected target dataset only
    if opt.fast_acc:
        # To easily compute the total accuracy of our paper.
        eval_data_list = [targetDataset]
    else:
        # The evaluation datasets; dataset order is the same as Table 1 in our paper.
        eval_data_list = [targetDataset]
    ### Taken from LIME
    segmentation_fn = SegmentationAlgorithm('quickshift', kernel_size=4,
                                            max_dist=200, ratio=0.2,
                                            random_seed=random.randint(0, 1000))
    for eval_data in eval_data_list:
        eval_data_path = os.path.join(opt.eval_data, eval_data)
        AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD, opt=opt)
        eval_data, eval_data_log = hierarchical_dataset(root=eval_data_path, opt=opt)
        evaluation_loader = torch.utils.data.DataLoader(
            eval_data, batch_size=1,
            shuffle=False,
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_evaluation, pin_memory=True)
        for i, (image_tensors, labels) in enumerate(evaluation_loader):
            imgDataDict = {}
            img_numpy = image_tensors.cpu().detach().numpy()[0] ### Need to set batch size to 1 only
            if img_numpy.shape[0] == 1:
                img_numpy = gray2rgb(img_numpy[0])
            # print("img_numpy shape: ", img_numpy.shape) # (224,224,3)
            segmOutput = segmentation_fn(img_numpy)
            imgDataDict['segdata'] = segmOutput
            imgDataDict['label'] = labels[0]
            outputPickleFile = segmRootDir + "{}.pkl".format(i)
            with open(outputPickleFile, 'wb') as f:
                pickle.dump(imgDataDict, f)
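
### Usage sketch: reload one of the segmentation pickles written above
### (index 0 of the CUTE80 run; adjust the path to your own segmRootDir).
# with open("/home/uclpc1/Documents/STR/datasets/segmentations/224X224/CUTE80/0.pkl", 'rb') as f:
#     imgDataDict = pickle.load(f)
# segments, label = imgDataDict['segdata'], imgDataDict['label']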
### Note: relies on the module-level `opt` (set in __main__) when calling getPredAndConf()
def acquireSelectivityHit(origImg, attributions, segmentations, model, converter, labels, scoring):
    # print("segmentations unique len: ", np.unique(segmentations))
    aveSegmentations, sortedDict = averageSegmentsOut(attributions[0,0], segmentations)
    sortedKeys = [k for k, v in sorted(sortedDict.items(), key=lambda item: item[1])]
    sortedKeys = sortedKeys[::-1] ### Reverse so the list runs from largest to smallest score
    # print("sortedDict: ", sortedDict) # {0: -5.51e-06, 1: -1.469e-05, 2: -3.06e-05,...}
    # print("aveSegmentations unique len: ", np.unique(aveSegmentations))
    # print("aveSegmentations device: ", aveSegmentations.device) # cuda:0
    # print("aveSegmentations shape: ", aveSegmentations.shape) # (224,224)
    # print("aveSegmentations: ", aveSegmentations)
    n_correct = []
    confidenceList = [] # First index is one feature removed, second index two features removed, and so on...
    clonedImg = torch.clone(origImg)
    gt = str(labels)
    for totalSegToHide in range(0, len(sortedKeys)):
        ### Acquire LIME prediction result
        currentSegmentToHide = sortedKeys[totalSegToHide]
        clonedImg[0,0][segmentations == currentSegmentToHide] = 0.0
        pred, confScore = getPredAndConf(opt, model, scoring, clonedImg, converter, np.array([gt]))
        # To evaluate 'case sensitive model' with alphanumeric and case insensitive setting.
        if opt.sensitive and opt.data_filtering_off:
            pred = pred.lower()
            gt = gt.lower()
            alphanumeric_case_insensitive = '0123456789abcdefghijklmnopqrstuvwxyz'
            out_of_alphanumeric_case_insensitive = f'[^{alphanumeric_case_insensitive}]'
            pred = re.sub(out_of_alphanumeric_case_insensitive, '', pred)
            gt = re.sub(out_of_alphanumeric_case_insensitive, '', gt)
        if pred == gt:
            n_correct.append(1)
        else:
            n_correct.append(0)
        confScore = confScore[0][0]*100
        confidenceList.append(confScore)
    return n_correct, confidenceList
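
### Usage sketch (argument names are illustrative): turn the per-removal
### accuracy list into the normalized selectivity AUC, mirroring the
### computation in acquire_average_auc() below.
# n_correct, confidenceList = acquireSelectivityHit(img, attributions, segments,
#                                                   model, converter, label, scoring)
# dataList = [1] + n_correct                            # prepend 1, as below
# auc_norm = np.trapz(dataList) / np.trapz([1] * len(dataList))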
### Once you have the selectivity_eval_results.pkl file, use this to inspect
### per-image selectivity results and flag images never predicted correctly
def acquire_selectivity_auc(opt, pkl_filename=None):
    if pkl_filename is None:
        pkl_filename = "/home/goo/str/str_vit_dataexplain_lambda/metrics_sensitivity_eval_results_CUTE80.pkl" # VITSTR
    accKeys = []
    with open(pkl_filename, 'rb') as f:
        selectivity_data = pickle.load(f)
    for resDictIdx, resDict in enumerate(selectivity_data):
        keylistAcc = []
        keylistConf = []
        metricsKeys = resDict.keys()
        for keyStr in resDict.keys():
            if "_acc" in keyStr: keylistAcc.append(keyStr)
            if "_conf" in keyStr: keylistConf.append(keyStr)
        # Need to check if the network correctly predicted the image
        for metrics_accStr in keylistAcc:
            if 1 not in resDict[metrics_accStr]: print(resDictIdx)
## gtClassNum - set to gtClassNum=0 for the standard implementation, or to a specific class idx for a local explanation
def acquireAttribution(opt, super_model, input, segmTensor, gtClassNum, lowestAccKey, device):
    channels = 1
    if opt.rgb:
        channels = 3
    ### Perform attribution
    if "intgrad_" in lowestAccKey:
        ig = IntegratedGradients(super_model)
        attributions = ig.attribute(input, target=gtClassNum)
    elif "gradshap_" in lowestAccKey:
        gs = GradientShap(super_model)
        baseline_dist = torch.zeros((1, channels, opt.imgH, opt.imgW))
        baseline_dist = baseline_dist.to(device)
        attributions = gs.attribute(input, baselines=baseline_dist, target=gtClassNum)
    elif "deeplift_" in lowestAccKey:
        dl = DeepLift(super_model)
        attributions = dl.attribute(input, target=gtClassNum)
    elif "saliency_" in lowestAccKey:
        saliency = Saliency(super_model)
        attributions = saliency.attribute(input, target=gtClassNum)
    elif "inpxgrad_" in lowestAccKey:
        input_x_gradient = InputXGradient(super_model)
        attributions = input_x_gradient.attribute(input, target=gtClassNum)
    elif "guidedbp_" in lowestAccKey:
        gbp = GuidedBackprop(super_model)
        attributions = gbp.attribute(input, target=gtClassNum)
    elif "deconv_" in lowestAccKey:
        deconv = Deconvolution(super_model)
        attributions = deconv.attribute(input, target=gtClassNum)
    elif "featablt_" in lowestAccKey:
        ablator = FeatureAblation(super_model)
        attributions = ablator.attribute(input, target=gtClassNum, feature_mask=segmTensor)
    elif "shapley_" in lowestAccKey:
        svs = ShapleyValueSampling(super_model)
        attributions = svs.attribute(input, target=gtClassNum, feature_mask=segmTensor)
    elif "lime_" in lowestAccKey:
        interpretable_model = SkLearnRidge(alpha=1, fit_intercept=True) ### This is the default used by LIME
        lime = Lime(super_model, interpretable_model=interpretable_model)
        attributions = lime.attribute(input, target=gtClassNum, feature_mask=segmTensor)
    elif "kernelshap_" in lowestAccKey:
        ks = KernelShap(super_model)
        attributions = ks.attribute(input, target=gtClassNum, feature_mask=segmTensor)
    else:
        raise ValueError("Unsupported attribution method key: {}".format(lowestAccKey))
    return attributions
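
### Usage sketch (names are illustrative): pick the attribution method via the
### key returned by acquire_bestacc_attr() below, then attribute one image.
# lowestAccKey = acquire_bestacc_attr(opt, "shapley_singlechar_ave_matrn_SVT.pkl")
# attributions = acquireAttribution(opt, super_model, input_img, segmTensor,
#                                   0, lowestAccKey, device)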
### In addition to acquire_average_auc(), this function also returns the attribution-based
### method with the best (lowest) selectivity accuracy AUC
### pklFile - you need to pass the pkl file here
def acquire_bestacc_attr(opt, pickleFile):
    # pickleFile = "metrics_sensitivity_eval_results_IIIT5k_3000.pkl"
    # pickleFile = "/home/goo/str/str_vit_dataexplain_lambda/shapley_singlechar_ave_matrn_SVT.pkl"
    acquireSelectivity = True # If True, compute selectivity-based metrics (accuracy/confidence AUC)
    acquireInfidelity = False
    acquireSensitivity = False
    with open(pickleFile, 'rb') as f:
        data = pickle.load(f)
    metricDict = {} # Keys: "saliency_acc", "saliency_conf", "saliency_infid", "saliency_sens"
    selectivity_acc_auc_normalized = [] # Normalized because it is divided by the full rectangle
    for imgData in data:
        if acquireSelectivity:
            for keyStr in imgData.keys():
                if ("_acc" in keyStr or "_conf" in keyStr) and not ("_local_" in keyStr or "_global_local_" in keyStr): # Accept only selectivity
                    if keyStr not in metricDict:
                        metricDict[keyStr] = []
                    dataList = copy.deepcopy(imgData[keyStr]) # list of 0,1 [1,1,1,0,0,0,0]
                    dataList.insert(0, 1) # Insert 1 at beginning to avoid np.trapz([1]) = 0.0
                    denom = [1] * len(dataList) # Denominator to normalize AUC
                    auc_norm = np.trapz(dataList) / np.trapz(denom)
                    metricDict[keyStr].append(auc_norm)
        elif acquireInfidelity:
            pass # TODO
        elif acquireSensitivity:
            pass # TODO
    lowestAccKey = ""
    lowestAcc = 10000000
    for metricKey in metricDict:
        if "_acc" in metricKey: # Used for selectivity accuracy only
            statisticVal = statistics.mean(metricDict[metricKey])
            if statisticVal < lowestAcc:
                lowestAcc = statisticVal
                lowestAccKey = metricKey
            # print("{}: {}".format(metricKey, statisticVal))
    assert lowestAccKey != ""
    return lowestAccKey
def saveAttrData(filename, attribution, segmData, origImg):
    pklData = {}
    pklData['attribution'] = torch.clone(attribution).detach().cpu().numpy()
    pklData['segmData'] = segmData
    pklData['origImg'] = origImg
    with open(filename, 'wb') as f:
        pickle.dump(pklData, f)
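
### Usage sketch: reload a pickle written by saveAttrData() (filename is illustrative).
# with open("attr_0.pkl", 'rb') as f:
#     pklData = pickle.load(f)
# attribution = torch.from_numpy(pklData['attribution'])
# segmData, origImg = pklData['segmData'], pklData['origImg']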
### New code (8/3/2022) to acquire average selectivity, infidelity, etc. after running the captum test
def acquire_average_auc(opt):
    # pickleFile = "metrics_sensitivity_eval_results_IIIT5k_3000.pkl"
    pickleFile = "/home/goo/str/str_vit_dataexplain_lambda/shapley_singlechar_ave_vitstr_IC03_860.pkl"
    acquireSelectivity = True # If True, compute selectivity-based metrics (accuracy/confidence AUC)
    acquireInfidelity = False
    acquireSensitivity = False
    with open(pickleFile, 'rb') as f:
        data = pickle.load(f)
    metricDict = {} # Keys: "saliency_acc", "saliency_conf", "saliency_infid", "saliency_sens"
    selectivity_acc_auc_normalized = [] # Normalized because it is divided by the full rectangle
    for imgData in data:
        if acquireSelectivity:
            for keyStr in imgData.keys():
                if "_acc" in keyStr or "_conf" in keyStr: # Accept only selectivity
                    if keyStr not in metricDict:
                        metricDict[keyStr] = []
                    dataList = copy.deepcopy(imgData[keyStr]) # list of 0,1 [1,1,1,0,0,0,0]
                    dataList.insert(0, 1) # Insert 1 at beginning to avoid np.trapz([1]) = 0.0
                    denom = [1] * len(dataList) # Denominator to normalize AUC
                    auc_norm = np.trapz(dataList) / np.trapz(denom)
                    metricDict[keyStr].append(auc_norm)
        elif acquireInfidelity:
            pass # TODO
        elif acquireSensitivity:
            pass # TODO
    for metricKey in metricDict:
        print("{}: {}".format(metricKey, statistics.mean(metricDict[metricKey])))
### Use this to acquire the list of average AUCs, one per character position
def acquireListOfAveAUC(opt):
    acquireSelectivity = True
    acquireInfidelity = False
    acquireSensitivity = False
    totalChars = 10
    collectedMetricDict = {}
    for charNum in range(0, totalChars):
        pickleFile = f"/home/goo/str/str_vit_dataexplain_lambda/singlechar{charNum}_results_{totalChars}chardataset.pkl"
        with open(pickleFile, 'rb') as f:
            data = pickle.load(f)
        metricDict = {} # Keys: "saliency_acc", "saliency_conf", "saliency_infid", "saliency_sens"
        selectivity_acc_auc_normalized = [] # Normalized because it is divided by the full rectangle
        for imgData in data:
            if acquireSelectivity:
                for keyStr in imgData.keys():
                    if "_acc" in keyStr or "_conf" in keyStr: # Accept only selectivity
                        if keyStr not in metricDict:
                            metricDict[keyStr] = []
                        dataList = copy.deepcopy(imgData[keyStr]) # list of 0,1 [1,1,1,0,0,0,0]
                        dataList.insert(0, 1) # Insert 1 at beginning to avoid np.trapz([1]) = 0.0
                        denom = [1] * len(dataList) # Denominator to normalize AUC
                        auc_norm = np.trapz(dataList) / np.trapz(denom)
                        metricDict[keyStr].append(auc_norm)
        for metricKey in metricDict:
            selec_auc_normalize = statistics.mean(metricDict[metricKey])
            if metricKey not in collectedMetricDict:
                collectedMetricDict[metricKey] = []
            collectedMetricDict[metricKey].append(selec_auc_normalize)
    for collectedMetricDictKey in collectedMetricDict:
        print("{}: {}".format(collectedMetricDictKey, collectedMetricDict[collectedMetricDictKey]))
    for charNum in range(0, totalChars):
        selectivityAcrossCharsLs = []
        for collectedMetricDictKey in collectedMetricDict:
            if "_acc" in collectedMetricDictKey:
                selectivityAcrossCharsLs.append(collectedMetricDict[collectedMetricDictKey][charNum])
        print("accuracy -- {}: {}".format(charNum, statistics.mean(selectivityAcrossCharsLs)))
    for charNum in range(0, totalChars):
        selectivityAcrossCharsLs = []
        for collectedMetricDictKey in collectedMetricDict:
            if "_conf" in collectedMetricDictKey:
                selectivityAcrossCharsLs.append(collectedMetricDict[collectedMetricDictKey][charNum])
        print("confidence -- {}: {}".format(charNum, statistics.mean(selectivityAcrossCharsLs)))
if __name__ == '__main__':
    # deleteInf()
    opt = get_args(is_train=False)

    """ vocab / character number configuration """
    if opt.sensitive:
        opt.character = string.printable[:-6] # same as ASTER setting (use 94 chars).

    cudnn.benchmark = True
    cudnn.deterministic = True
    opt.num_gpu = torch.cuda.device_count()

    main(opt) # main() is expected to be defined elsewhere in the full script; it is not part of this excerpt