File size: 4,099 Bytes
d61b9c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import os
import pickle
from captum_improve_vitstr import rankedAttributionsBySegm
import matplotlib.pyplot as plt
from skimage.color import gray2rgb
from captum.attr._utils.visualization import visualize_image_attr
import torch
import numpy as np

def attr_one_dataset(modelName="vitstr", datasetName="IIIT5k_3000",
                     dataRoot="/data/goo/strattr/attributionData",
                     imgRoot="/data/goo/strattr/attributionDataImgs"):
    """Render the saved attribution pickles of one dataset as PNG heat maps.

    Walks ``<dataRoot>/<modelName>/<datasetName>/`` (including subdirectories)
    for ``*.pkl`` files. Each is loaded with ``pickle.load`` and its
    ``'attribution'``, ``'segmData'`` and ``'origImg'`` entries are read.
    Samples whose attributions contain any NaN are skipped. The others are
    passed through ``rankedAttributionsBySegm``, drawn with captum's
    ``visualize_image_attr`` as a blended heat map over the original image,
    and saved as ``<stem>.png`` under ``<imgRoot>/<modelName>/<datasetName>/``.
    Any subdirectory layout of the input is mirrored in the output.

    Args:
        modelName: Model subfolder name (default ``"vitstr"``).
        datasetName: Dataset subfolder name (default ``"IIIT5k_3000"``).
        dataRoot: Base directory containing the attribution pickles.
        imgRoot: Base directory the PNG images are written to.

    Note:
        ``pickle.load`` can execute arbitrary code; only point this at
        trusted, locally generated attribution files.
    """
    srcDir = os.path.join(dataRoot, modelName, datasetName)
    outDir = os.path.join(imgRoot, modelName, datasetName)
    os.makedirs(outDir, exist_ok=True)

    # From a folder containing saved attribution pickle files, convert them into attribution images
    for path, _subdirs, files in os.walk(srcDir):
        for name in sorted(files):
            # Only saved attribution pickles are processed; other files (logs,
            # hidden files, ...) would otherwise crash the pickle load below.
            if not name.endswith('.pkl'):
                continue
            # Join with the directory os.walk is currently visiting (not the
            # root) so pickles inside subdirectories are actually found.
            # e.g. /data/goo/strattr/attributionData/trba/CUTE80/66_featablt.pkl
            fullfilename = os.path.join(path, name)
            print("fullfilename: ", fullfilename)
            with open(fullfilename, 'rb') as f:
                pklData = pickle.load(f)
            attributions = pklData['attribution']
            segmDataNP = pklData['segmData']
            origImgNP = pklData['origImg']
            # NaN attributions cannot be ranked/visualized meaningfully; skip.
            if np.isnan(attributions).any():
                continue
            rankedAttr = rankedAttributionsBySegm(torch.from_numpy(attributions), segmDataNP)
            # Keep element [0][0] of the ranked result as the 2-D map to draw
            # (presumably batch 0, channel 0 — TODO confirm with the helper).
            rankedAttr = rankedAttr.detach().cpu().numpy()[0][0]
            rankedAttr = gray2rgb(rankedAttr)
            # Mirror the input subdirectory so same-named files cannot collide.
            imgDir = os.path.normpath(os.path.join(outDir, os.path.relpath(path, srcDir)))
            os.makedirs(imgDir, exist_ok=True)
            attrImgName = os.path.splitext(name)[0] + '.png'
            mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map', cmap='RdYlGn')
            try:
                mplotfig.savefig(os.path.join(imgDir, attrImgName))
            finally:
                # Always release the figure, even if saving fails, so long
                # runs do not accumulate open matplotlib figures.
                mplotfig.clear()
                plt.close(mplotfig)

def attr_all_dataset(modelName="vitstr", datasetNameList=None,
                     dataRoot="/data/goo/strattr/attributionData",
                     imgRoot="/data/goo/strattr/attributionDataImgs"):
    """Render the saved attribution pickles of several datasets as PNG heat maps.

    For every dataset name, walks ``<dataRoot>/<modelName>/<datasetName>/``
    (including subdirectories) for ``*.pkl`` files. Each is loaded with
    ``pickle.load`` and its ``'attribution'``, ``'segmData'`` and ``'origImg'``
    entries are read. Samples whose attributions contain any NaN are skipped.
    The others are passed through ``rankedAttributionsBySegm``, drawn with
    captum's ``visualize_image_attr`` as a blended heat map, and saved as
    ``<stem>.png`` under ``<imgRoot>/<modelName>/<datasetName>/``, mirroring
    any input subdirectory. The output directory of every listed dataset is
    created, even when its input directory is missing or empty.

    Args:
        modelName: Model subfolder name (default ``"vitstr"``).
        datasetNameList: Dataset subfolder names; ``None`` selects the
            standard benchmark list used originally.
        dataRoot: Base directory containing the attribution pickles.
        imgRoot: Base directory the PNG images are written to.

    Note:
        ``pickle.load`` can execute arbitrary code; only point this at
        trusted, locally generated attribution files.
    """
    if datasetNameList is None:
        datasetNameList = ['IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', 'IC13_857',
                           'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80']

    for datasetName in datasetNameList:
        srcDir = os.path.join(dataRoot, modelName, datasetName)
        outDir = os.path.join(imgRoot, modelName, datasetName)
        os.makedirs(outDir, exist_ok=True)

        # From a folder containing saved attribution pickle files, convert them into attribution images
        for path, _subdirs, files in os.walk(srcDir):
            for name in sorted(files):
                # Only saved attribution pickles are processed; other files
                # would otherwise crash the pickle load below.
                if not name.endswith('.pkl'):
                    continue
                # Join with the directory currently being walked (not the
                # root) so pickles inside subdirectories are found.
                # e.g. /data/goo/strattr/attributionData/trba/CUTE80/66_featablt.pkl
                fullfilename = os.path.join(path, name)
                with open(fullfilename, 'rb') as f:
                    pklData = pickle.load(f)
                attributions = pklData['attribution']
                segmDataNP = pklData['segmData']
                origImgNP = pklData['origImg']
                # Same guard as the single-dataset variant: NaN attributions
                # cannot be ranked/visualized meaningfully, so skip them.
                if np.isnan(attributions).any():
                    continue
                rankedAttr = rankedAttributionsBySegm(torch.from_numpy(attributions), segmDataNP)
                # Keep element [0][0] of the ranked result as the 2-D map
                # (presumably batch 0, channel 0 — TODO confirm with the helper).
                rankedAttr = rankedAttr.detach().cpu().numpy()[0][0]
                rankedAttr = gray2rgb(rankedAttr)
                # Mirror the input subdirectory so same-named files cannot collide.
                imgDir = os.path.normpath(os.path.join(outDir, os.path.relpath(path, srcDir)))
                os.makedirs(imgDir, exist_ok=True)
                attrImgName = os.path.splitext(name)[0] + '.png'
                mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map', cmap='RdYlGn')
                try:
                    mplotfig.savefig(os.path.join(imgDir, attrImgName))
                finally:
                    # Always release the figure so long runs over many
                    # datasets do not accumulate open matplotlib figures.
                    mplotfig.clear()
                    plt.close(mplotfig)

if __name__ == "__main__":
    # Script entry point: render attribution images for the single default
    # dataset only. Call attr_all_dataset() here instead to process every
    # benchmark dataset in one run.
    attr_one_dataset()