Spaces:
Build error
Build error
File size: 4,099 Bytes
d61b9c7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
import os
import pickle
from captum_improve_vitstr import rankedAttributionsBySegm
import matplotlib.pyplot as plt
from skimage.color import gray2rgb
from captum.attr._utils.visualization import visualize_image_attr
import torch
import numpy as np
def attr_one_dataset(modelName="vitstr", datasetName="IIIT5k_3000"):
    """Convert the saved attribution pickles of one dataset into PNG images.

    Walks ``/data/goo/strattr/attributionData/<modelName>/<datasetName>/``
    (including sub-folders), loads every ``.pkl`` file, passes its
    attribution array and segmentation data through
    ``rankedAttributionsBySegm`` and saves the resulting blended heat map to
    ``/data/goo/strattr/attributionDataImgs/<modelName>/<datasetName>/``
    under the same base name with a ``.png`` extension.

    Pickles whose attribution array contains NaN are skipped. Files that do
    not end in ``.pkl`` are ignored.

    Args:
        modelName: Folder name of the model under the attribution root.
        datasetName: Folder name of the dataset to convert.
    """
    rootDir = f"/data/goo/strattr/attributionData/{modelName}/{datasetName}/"
    attrOutputImgs = f"/data/goo/strattr/attributionDataImgs/{modelName}/{datasetName}/"
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(attrOutputImgs, exist_ok=True)

    # From a folder containing saved attribution pickle files, convert them
    # into attribution images.
    for path, _subdirs, files in os.walk(rootDir):
        for name in files:
            if not name.endswith('.pkl'):
                # Anything else in the folder is not an attribution dump.
                continue
            # Join with the directory os.walk is currently visiting, not with
            # rootDir, so files inside sub-folders resolve correctly.
            # e.g. /data/goo/strattr/attributionData/trba/CUTE80/66_featablt.pkl
            fullfilename = os.path.join(path, name)
            print("fullfilename: ", fullfilename)
            attrImgName = os.path.splitext(name)[0] + '.png'

            # NOTE(security): pickle.load can execute arbitrary code; only
            # run this on attribution files produced by a trusted pipeline.
            with open(fullfilename, 'rb') as f:
                pklData = pickle.load(f)
            attributions = pklData['attribution']
            segmDataNP = pklData['segmData']
            origImgNP = pklData['origImg']

            if np.isnan(attributions).any():
                # NaN attributions cannot be visualised; skip this sample.
                continue

            attributions = torch.from_numpy(attributions)
            rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP)
            # Presumably selects the first batch item / first channel, leaving
            # a 2-D map for gray2rgb -- TODO confirm against the helper.
            rankedAttr = rankedAttr.detach().cpu().numpy()[0][0]
            rankedAttr = gray2rgb(rankedAttr)
            mplotfig, _ = visualize_image_attr(
                rankedAttr, origImgNP, method='blended_heat_map', cmap='RdYlGn'
            )
            mplotfig.savefig(attrOutputImgs + attrImgName)
            # Release the figure so memory does not grow across many images.
            mplotfig.clear()
            plt.close(mplotfig)
def attr_all_dataset(modelName="vitstr", datasetNameList=None):
    """Convert the saved attribution pickles of several datasets into PNGs.

    For each dataset name, walks
    ``/data/goo/strattr/attributionData/<modelName>/<dataset>/`` (including
    sub-folders), loads every ``.pkl`` file, passes its attribution array and
    segmentation data through ``rankedAttributionsBySegm`` and saves the
    resulting blended heat map to
    ``/data/goo/strattr/attributionDataImgs/<modelName>/<dataset>/`` under
    the same base name with a ``.png`` extension.

    Pickles whose attribution array contains NaN are skipped. Files that do
    not end in ``.pkl`` are ignored.

    Args:
        modelName: Folder name of the model under the attribution root.
        datasetNameList: Dataset folder names to process. ``None`` selects
            the built-in list of ten benchmark folders.
    """
    if datasetNameList is None:
        datasetNameList = [
            'IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', 'IC13_857',
            'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80',
        ]

    for datasetName in datasetNameList:
        rootDir = f"/data/goo/strattr/attributionData/{modelName}/{datasetName}/"
        attrOutputImgs = f"/data/goo/strattr/attributionDataImgs/{modelName}/{datasetName}/"
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(attrOutputImgs, exist_ok=True)

        # From a folder containing saved attribution pickle files, convert
        # them into attribution images.
        for path, _subdirs, files in os.walk(rootDir):
            for name in files:
                if not name.endswith('.pkl'):
                    # Anything else in the folder is not an attribution dump.
                    continue
                # Join with the directory os.walk is currently visiting, not
                # with rootDir, so files inside sub-folders resolve correctly.
                # e.g. /data/goo/strattr/attributionData/trba/CUTE80/66_featablt.pkl
                fullfilename = os.path.join(path, name)
                attrImgName = os.path.splitext(name)[0] + '.png'

                # NOTE(security): pickle.load can execute arbitrary code; only
                # run this on attribution files from a trusted pipeline.
                with open(fullfilename, 'rb') as f:
                    pklData = pickle.load(f)
                attributions = pklData['attribution']
                segmDataNP = pklData['segmData']
                origImgNP = pklData['origImg']

                if np.isnan(attributions).any():
                    # NaN attributions cannot be visualised; skip this sample.
                    continue

                attributions = torch.from_numpy(attributions)
                rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP)
                # Presumably selects the first batch item / first channel,
                # leaving a 2-D map for gray2rgb -- TODO confirm.
                rankedAttr = rankedAttr.detach().cpu().numpy()[0][0]
                rankedAttr = gray2rgb(rankedAttr)
                mplotfig, _ = visualize_image_attr(
                    rankedAttr, origImgNP, method='blended_heat_map', cmap='RdYlGn'
                )
                mplotfig.savefig(attrOutputImgs + attrImgName)
                # Release the figure so memory does not grow across images.
                mplotfig.clear()
                plt.close(mplotfig)
if __name__ == '__main__':
    # Script entry point: render attribution images for the single default
    # dataset. Swap in attr_all_dataset() to process every benchmark folder.
    attr_one_dataset()
|