# Page header captured with this file (Hugging Face Spaces status): "Build error", "Build error".
import os | |
import pickle | |
from captum_improve_vitstr import rankedAttributionsBySegm | |
import matplotlib.pyplot as plt | |
from skimage.color import gray2rgb | |
from captum.attr._utils.visualization import visualize_image_attr | |
import torch | |
import numpy as np | |
def attr_one_dataset(
    modelName="vitstr",
    datasetName="IIIT5k_3000",
    attrDataRoot="/data/goo/strattr/attributionData",
    attrImgRoot="/data/goo/strattr/attributionDataImgs",
):
    """Convert the saved attribution pickles of one dataset into heat-map images.

    Walks ``<attrDataRoot>/<modelName>/<datasetName>`` (including
    subdirectories), and for every ``*.pkl`` file writes a PNG with the same
    base name into ``<attrImgRoot>/<modelName>/<datasetName>``. The output
    directory is flat, so equally named pickles from different subdirectories
    overwrite each other.

    Each pickle is expected to hold a dict with the keys ``'attribution'``
    (a NumPy array), ``'segmData'`` and ``'origImg'``. Pickles whose
    attribution contains NaN are skipped.

    Args:
        modelName: Sub-folder name of the model under both roots.
        datasetName: Sub-folder name of the dataset under both roots.
        attrDataRoot: Root directory holding the attribution pickles.
        attrImgRoot: Root directory that receives the rendered images.

    Returns:
        None. Images are written to disk as a side effect.
    """
    rootDir = os.path.join(attrDataRoot, modelName, datasetName)
    attrOutputImgs = os.path.join(attrImgRoot, modelName, datasetName)
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(attrOutputImgs, exist_ok=True)

    for path, _subdirs, files in os.walk(rootDir):
        for name in files:
            # Anything that is not an attribution dump cannot be unpickled.
            if not name.endswith('.pkl'):
                continue
            # Join with the directory os.walk is currently in, not rootDir,
            # so pickles located in subdirectories resolve to real paths.
            fullfilename = os.path.join(path, name)
            # e.g. /data/goo/strattr/attributionData/trba/CUTE80/66_featablt.pkl
            print("fullfilename: ", fullfilename)
            attrImgName = os.path.splitext(name)[0] + '.png'

            # NOTE: pickle.load can execute arbitrary code; only run this on
            # attribution dumps produced by trusted code.
            with open(fullfilename, 'rb') as f:
                pklData = pickle.load(f)
            attributions = pklData['attribution']
            segmDataNP = pklData['segmData']
            origImgNP = pklData['origImg']

            # Skip samples whose attribution contains NaN values.
            if np.isnan(attributions).any():
                continue

            attributions = torch.from_numpy(attributions)
            rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP)
            # Keep the first entry along the two leading axes
            # (presumably batch and channel -- TODO confirm).
            rankedAttr = rankedAttr.detach().cpu().numpy()[0][0]
            rankedAttr = gray2rgb(rankedAttr)
            mplotfig, _ = visualize_image_attr(
                rankedAttr, origImgNP, method='blended_heat_map', cmap='RdYlGn'
            )
            try:
                mplotfig.savefig(os.path.join(attrOutputImgs, attrImgName))
            finally:
                # Release the figure even if saving fails, so figures do not
                # pile up over a long run.
                mplotfig.clear()
                plt.close(mplotfig)
def attr_all_dataset(
    modelName="vitstr",
    datasetNameList=None,
    attrDataRoot="/data/goo/strattr/attributionData",
    attrImgRoot="/data/goo/strattr/attributionDataImgs",
):
    """Convert the saved attribution pickles of several datasets into images.

    For every dataset name, walks ``<attrDataRoot>/<modelName>/<dataset>``
    (including subdirectories) and writes one PNG per ``*.pkl`` file into
    ``<attrImgRoot>/<modelName>/<dataset>``. Each output directory is flat, so
    equally named pickles from different subdirectories overwrite each other.

    Each pickle is expected to hold a dict with the keys ``'attribution'``
    (a NumPy array), ``'segmData'`` and ``'origImg'``. Pickles whose
    attribution contains NaN are skipped, matching ``attr_one_dataset``.

    Args:
        modelName: Sub-folder name of the model under both roots.
        datasetNameList: Dataset folder names to process; ``None`` selects
            the built-in list of ten datasets.
        attrDataRoot: Root directory holding the attribution pickles.
        attrImgRoot: Root directory that receives the rendered images.

    Returns:
        None. Images are written to disk as a side effect.
    """
    if datasetNameList is None:
        datasetNameList = ['IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', 'IC13_857',
                           'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80']

    for datasetName in datasetNameList:
        rootDir = os.path.join(attrDataRoot, modelName, datasetName)
        attrOutputImgs = os.path.join(attrImgRoot, modelName, datasetName)
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(attrOutputImgs, exist_ok=True)

        # Collect the pickles up front. Joining with the directory os.walk is
        # currently in (not rootDir) keeps files in subdirectories reachable,
        # and the suffix filter keeps stray files away from pickle.load.
        pklFiles = [
            (os.path.join(path, name), name)
            for path, _subdirs, files in os.walk(rootDir)
            for name in files
            if name.endswith('.pkl')
        ]

        for fullfilename, name in pklFiles:
            # e.g. /data/goo/strattr/attributionData/trba/CUTE80/66_featablt.pkl
            attrImgName = os.path.splitext(name)[0] + '.png'

            # NOTE: pickle.load can execute arbitrary code; only run this on
            # attribution dumps produced by trusted code.
            with open(fullfilename, 'rb') as f:
                pklData = pickle.load(f)
            attributions = pklData['attribution']
            segmDataNP = pklData['segmData']
            origImgNP = pklData['origImg']

            # Skip samples whose attribution contains NaN values.
            if np.isnan(attributions).any():
                continue

            attributions = torch.from_numpy(attributions)
            rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP)
            # Keep the first entry along the two leading axes
            # (presumably batch and channel -- TODO confirm).
            rankedAttr = rankedAttr.detach().cpu().numpy()[0][0]
            rankedAttr = gray2rgb(rankedAttr)
            mplotfig, _ = visualize_image_attr(
                rankedAttr, origImgNP, method='blended_heat_map', cmap='RdYlGn'
            )
            try:
                mplotfig.savefig(os.path.join(attrOutputImgs, attrImgName))
            finally:
                # Release the figure even if saving fails, so figures do not
                # pile up over a long run.
                mplotfig.clear()
                plt.close(mplotfig)
if __name__ == "__main__":
    # Script entry point: render images for the single dataset configured in
    # attr_one_dataset. Call attr_all_dataset() here instead to process the
    # whole dataset list.
    attr_one_dataset()