# strexp / attribution_ops.py
# Uploaded by markytools -- commit d61b9c7 ("added strexp")
import os
import pickle
from captum_improve_vitstr import rankedAttributionsBySegm
import matplotlib.pyplot as plt
from skimage.color import gray2rgb
from captum.attr._utils.visualization import visualize_image_attr
import torch
import numpy as np
def attr_one_dataset(modelName="vitstr", datasetName="IIIT5k_3000",
                     attrDataRoot="/data/goo/strattr/attributionData",
                     attrImgRoot="/data/goo/strattr/attributionDataImgs"):
    """Convert the saved attribution pickles of one dataset into PNG images.

    Walks ``{attrDataRoot}/{modelName}/{datasetName}`` (sub-directories
    included), loads every ``*.pkl`` file found there and saves a blended
    heat-map figure named after the pickle (``66_featablt.pkl`` ->
    ``66_featablt.png``) into ``{attrImgRoot}/{modelName}/{datasetName}``.
    The output directory is flat, so pickles with the same file name in
    different sub-directories overwrite each other's image.

    Each pickle must contain a dict with the keys ``'attribution'``,
    ``'segmData'`` and ``'origImg'``. Pickles whose attribution array
    contains NaN values are skipped, as are files without a ``.pkl``
    extension.

    Args:
        modelName: Model sub-directory name under both roots.
        datasetName: Dataset sub-directory name under ``modelName``.
        attrDataRoot: Root directory holding the attribution pickles.
        attrImgRoot: Root directory that receives the rendered images;
            the model/dataset sub-directory is created when missing.

    Raises:
        KeyError: If a pickle lacks one of the expected dict keys.
        OSError: If a pickle cannot be read or an image cannot be written.
    """
    rootDir = os.path.join(attrDataRoot, modelName, datasetName)
    attrOutputImgs = os.path.join(attrImgRoot, modelName, datasetName)
    os.makedirs(attrOutputImgs, exist_ok=True)

    # From a folder containing saved attribution pickle files, convert them into attribution images
    for path, _subdirs, files in os.walk(rootDir):
        for name in files:
            stem, ext = os.path.splitext(name)
            if ext != '.pkl':
                # Stray files (logs, notes, ...) are not attribution data.
                continue
            # Join with the directory currently being walked, not rootDir,
            # so pickles inside sub-directories resolve to a real path.
            # fullfilename: /data/goo/strattr/attributionData/trba/CUTE80/66_featablt.pkl
            fullfilename = os.path.join(path, name)
            print("fullfilename: ", fullfilename)

            # SECURITY: pickle.load can execute arbitrary code; only point
            # this function at pickles produced by a trusted pipeline.
            with open(fullfilename, 'rb') as f:
                pklData = pickle.load(f)
            attributions = pklData['attribution']
            segmDataNP = pklData['segmData']
            origImgNP = pklData['origImg']

            if np.isnan(attributions).any():
                # NaN attributions cannot be ranked or rendered meaningfully.
                continue

            attributions = torch.from_numpy(attributions)
            rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP)
            # Keep only the first entry along the two leading axes
            # (presumably batch and channel -- TODO confirm), then replicate
            # the remaining 2-D map into three channels for the visualizer.
            rankedAttr = rankedAttr.detach().cpu().numpy()[0][0]
            rankedAttr = gray2rgb(rankedAttr)

            mplotfig, _ = visualize_image_attr(
                rankedAttr, origImgNP, method='blended_heat_map', cmap='RdYlGn')
            try:
                mplotfig.savefig(os.path.join(attrOutputImgs, stem + '.png'))
            finally:
                # Always release the figure, otherwise matplotlib keeps one
                # open figure per processed file.
                mplotfig.clear()
                plt.close(mplotfig)
def attr_all_dataset(modelName="vitstr", datasetNameList=None,
                     attrDataRoot="/data/goo/strattr/attributionData",
                     attrImgRoot="/data/goo/strattr/attributionDataImgs"):
    """Convert the saved attribution pickles of several datasets into PNGs.

    For every dataset name, walks ``{attrDataRoot}/{modelName}/{dataset}``
    (sub-directories included), loads each ``*.pkl`` file and saves a
    blended heat-map figure named after the pickle into
    ``{attrImgRoot}/{modelName}/{dataset}``. Each output directory is flat,
    so pickles with the same file name in different sub-directories
    overwrite each other's image.

    Each pickle must contain a dict with the keys ``'attribution'``,
    ``'segmData'`` and ``'origImg'``. Pickles whose attribution array
    contains NaN values are skipped, as are files without a ``.pkl``
    extension.

    Args:
        modelName: Model sub-directory name under both roots.
        datasetNameList: Iterable of dataset sub-directory names. ``None``
            selects the ten benchmark datasets listed in the body.
        attrDataRoot: Root directory holding the attribution pickles.
        attrImgRoot: Root directory that receives the rendered images;
            each model/dataset sub-directory is created when missing.

    Raises:
        KeyError: If a pickle lacks one of the expected dict keys.
        OSError: If a pickle cannot be read or an image cannot be written.
    """
    if datasetNameList is None:
        datasetNameList = ['IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', 'IC13_857',
                           'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80']

    for datasetName in datasetNameList:
        rootDir = os.path.join(attrDataRoot, modelName, datasetName)
        attrOutputImgs = os.path.join(attrImgRoot, modelName, datasetName)
        os.makedirs(attrOutputImgs, exist_ok=True)

        # From a folder containing saved attribution pickle files, convert them into attribution images
        for path, _subdirs, files in os.walk(rootDir):
            for name in files:
                stem, ext = os.path.splitext(name)
                if ext != '.pkl':
                    # Stray files (logs, notes, ...) are not attribution data.
                    continue
                # Join with the directory currently being walked, not
                # rootDir, so pickles inside sub-directories resolve.
                # fullfilename: /data/goo/strattr/attributionData/trba/CUTE80/66_featablt.pkl
                fullfilename = os.path.join(path, name)

                # SECURITY: pickle.load can execute arbitrary code; only
                # point this function at pickles from a trusted pipeline.
                with open(fullfilename, 'rb') as f:
                    pklData = pickle.load(f)
                attributions = pklData['attribution']
                segmDataNP = pklData['segmData']
                origImgNP = pklData['origImg']

                if np.isnan(attributions).any():
                    # Same guard as the single-dataset variant: NaN
                    # attributions cannot be ranked or rendered.
                    continue

                attributions = torch.from_numpy(attributions)
                rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP)
                # Keep only the first entry along the two leading axes
                # (presumably batch and channel -- TODO confirm), then
                # replicate the 2-D map into three channels.
                rankedAttr = rankedAttr.detach().cpu().numpy()[0][0]
                rankedAttr = gray2rgb(rankedAttr)

                mplotfig, _ = visualize_image_attr(
                    rankedAttr, origImgNP, method='blended_heat_map', cmap='RdYlGn')
                try:
                    mplotfig.savefig(os.path.join(attrOutputImgs, stem + '.png'))
                finally:
                    # Always release the figure, otherwise matplotlib keeps
                    # one open figure per processed file.
                    mplotfig.clear()
                    plt.close(mplotfig)
if __name__ == "__main__":
    # Script entry point: render the attribution images of the single
    # default dataset. To process every benchmark dataset instead, call
    # the multi-dataset variant:
    # attr_all_dataset()
    attr_one_dataset()