Spaces:
Build error
Build error
import pickle | |
import copy | |
import numpy as np | |
import statistics | |
import sys | |
import os | |
from captum.attr._utils.visualization import visualize_image_attr | |
import matplotlib.pyplot as plt | |
### New code (8/3/2022) to acquire average selectivity, infidelity, etc. after running captum test | |
def acquire_average_auc(): | |
# pickleFile = "metrics_sensitivity_eval_results_IIIT5k_3000.pkl" | |
pickleFile = "shapley_singlechar_ave_vitstr_IC15_1811.pkl" | |
acquireSelectivity = True # If True, set to | |
acquireInfidelity = False | |
acquireSensitivity = False | |
with open(pickleFile, 'rb') as f: | |
data = pickle.load(f) | |
metricDict = {} # Keys: "saliency_acc", "saliency_conf", "saliency_infid", "saliency_sens" | |
selectivity_acc_auc_normalized = [] # Normalized because it is divided by the full rectangle | |
for imgData in data: | |
if acquireSelectivity: | |
for keyStr in imgData.keys(): | |
if "_acc" in keyStr or "_conf" in keyStr: # Accept only selectivity | |
if keyStr not in metricDict: | |
metricDict[keyStr] = [] | |
dataList = copy.deepcopy(imgData[keyStr]) # list of 0,1 [1,1,1,0,0,0,0] | |
dataList.insert(0, 1) # Insert 1 at beginning to avoid np.trapz([1]) = 0.0 | |
denom = [1] * len(dataList) # Denominator to normalize AUC | |
auc_norm = np.trapz(dataList) / np.trapz(denom) | |
if not np.isnan(auc_norm).any(): | |
metricDict[keyStr].append(auc_norm) | |
elif acquireInfidelity: | |
pass # TODO | |
elif acquireSensitivity: | |
pass # TODO | |
for metricKey in metricDict: | |
print("{}: {}".format(metricKey, statistics.mean(metricDict[metricKey]))) | |
### | |
def sumOfAllAttributions(): | |
modelName = "trba" | |
datasetName = "IC15_1811" # IIIT5k_3000, IC03_867, IC13_857, IC15_1811 | |
mainRootDir = "/data/goo/strattr/" | |
rootDir = f"{mainRootDir}attributionData/{modelName}/{datasetName}/" | |
numpyOutputDir = mainRootDir | |
if modelName=="vitstr": | |
shape = [1, 1, 224, 224] | |
elif modelName =="parseq": | |
shape = [1, 3, 32, 128] | |
elif modelName =="trba": | |
shape = [1, 1, 32, 100] | |
# pickleFile = f"shapley_singlechar_ave_{modelName}_{datasetName}.pkl" | |
# acquireSelectivity = True | |
# with open(pickleFile, 'rb') as f: | |
# data = pickle.load(f) | |
# metricDict = {} # Keys: "saliency_acc", "saliency_conf", "saliency_infid", "saliency_sens" | |
# | |
# for imgData in data: | |
# if acquireSelectivity: | |
# for keyStr in imgData.keys(): | |
# print("keyStr: ", keyStr) | |
# if "_acc" in keyStr or "_conf" in keyStr: # Accept only selectivity | |
# if keyStr not in metricDict: | |
# metricDict[keyStr] = [] | |
# dataList = copy.deepcopy(imgData[keyStr]) # list of 0,1 [1,1,1,0,0,0,0] | |
# dataList.insert(0, 1) # Insert 1 at beginning to avoid np.trapz([1]) = 0.0 | |
# denom = [1] * len(dataList) # Denominator to normalize AUC | |
# auc_norm = np.trapz(dataList) / np.trapz(denom) | |
totalImgCount = 0 | |
# From a folder containing saved attribution pickle files, convert them into attribution images | |
for path, subdirs, files in os.walk(rootDir): | |
for name in files: | |
fullfilename = os.path.join(rootDir, name) # Value | |
# fullfilename: /data/goo/strattr/attributionData/trba/CUTE80/66_featablt.pkl | |
if "_gl." not in fullfilename.split('/')[-1]: # Accept only global+local | |
continue | |
totalImgCount += 1 | |
shape[0] = totalImgCount | |
main_np = np.memmap(numpyOutputDir+f"aveattr_{modelName}_{datasetName}.dat", dtype='float32', mode='w+', shape=tuple(shape)) | |
attrIdx = 0 | |
# From a folder containing saved attribution pickle files, convert them into attribution images | |
leftGreaterRightAcc = 0.0 | |
for path, subdirs, files in os.walk(rootDir): | |
for name in files: | |
fullfilename = os.path.join(rootDir, name) # Value | |
# fullfilename: /data/goo/strattr/attributionData/trba/CUTE80/66_featablt.pkl | |
if "_gl." not in fullfilename.split('/')[-1]: # Accept only global+local | |
continue | |
print("fullfilename: ", fullfilename) | |
# imgNum = int(partfilename.split('_')[0]) | |
# attrImgName = partfilename.replace('.pkl', '.png') | |
# minNumber = min(minNumber, imgNum) | |
# maxNumber = max(maxNumber, imgNum) | |
with open(fullfilename, 'rb') as f: | |
pklData = pickle.load(f) | |
attributions = pklData['attribution'] | |
segmDataNP = pklData['segmData'] | |
origImgNP = pklData['origImg'] | |
if np.isnan(attributions).any(): | |
continue | |
# attributions[0] = (attributions[0] - attributions[0].min()) / (attributions[0].max() - attributions[0].min()) | |
main_np[attrIdx] = attributions[0] | |
sumLeft = np.sum(attributions[0,:,:,0:attributions.shape[3]//2]) | |
sumRight = np.sum(attributions[0,:,:,attributions.shape[3]//2:]) | |
if sumLeft > sumRight: | |
leftGreaterRightAcc += 1.0 | |
attrIdx += 1 | |
print("leftGreaterRightAcc: ", leftGreaterRightAcc/attrIdx) | |
main_np.flush() | |
meanAveAttr = np.transpose(np.mean(main_np, axis=0), (1,2,0)) | |
print("meanAveAttr shape: ", meanAveAttr.shape) # (1, 3, 32, 128) | |
meanAveAttr = 2*((meanAveAttr - meanAveAttr.min()) / (meanAveAttr.max() - meanAveAttr.min())) - 1.0 | |
mplotfig, _ = visualize_image_attr(meanAveAttr, cmap='RdYlGn') # input should be in (H,W,C) | |
mplotfig.savefig(numpyOutputDir+f"aveattr_{modelName}_{datasetName}.png") | |
mplotfig.clear() | |
plt.close(mplotfig) | |
if __name__ == '__main__': | |
# acquire_average_auc() | |
sumOfAllAttributions() | |