markytools commited on
Commit
dffa77d
·
1 Parent(s): c60f05d

updated app

Browse files
Files changed (1) hide show
  1. app.py +160 -17
app.py CHANGED
@@ -1,5 +1,84 @@
1
  import streamlit as st
2
  from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  # x = st.slider('Select a value')
5
  # st.write(x, 'squared is', x * x)
@@ -17,20 +96,84 @@ with col2: # To display brand logo
17
 
18
  uploaded_file = st.file_uploader("Choose a file", type=["png", "jpg"])
19
  if uploaded_file is not None:
20
- # To read file as bytes:
21
- bytes_data = uploaded_file.getvalue()
22
- pillowImg = Image.open(uploaded_file)
23
- # print("pillowImg shape: ", )
24
- st.write(pillowImg.size)
25
-
26
- # # To convert to a string based IO:
27
- # stringio = StringIO(uploaded_file.getvalue().decode("utf-8"))
28
- # st.write(stringio)
29
- #
30
- # # To read file as string:
31
- # string_data = stringio.read()
32
- # st.write(string_data)
33
- #
34
- # # Can be used wherever a "file-like" object is accepted:
35
- # dataframe = pd.read_csv(uploaded_file)
36
- # st.write(dataframe)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  from PIL import Image
3
+ import settings
4
+ import captum
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import torch.backends.cudnn as cudnn
9
+ from utils import get_args
10
+ from utils import CTCLabelConverter, AttnLabelConverter, Averager, TokenLabelConverter
11
+ import string
12
+ import time
13
+ import sys
14
+ from dataset import hierarchical_dataset, AlignCollate
15
+ import validators
16
+ from model import Model, STRScore
17
+ from PIL import Image
18
+ from lime.wrappers.scikit_image import SegmentationAlgorithm
19
+ from captum._utils.models.linear_model import SkLearnLinearModel, SkLearnRidge
20
+ import random
21
+ import os
22
+ from skimage.color import gray2rgb
23
+ import pickle
24
+ from train_shap_corr import getPredAndConf
25
+ import re
26
+ from captum_test import acquire_average_auc, saveAttrData
27
+ import copy
28
+ from skimage.color import gray2rgb
29
+ from matplotlib import pyplot as plt
30
+ from torchvision import transforms
31
+
32
+ device = torch.device('cpu')
33
+ opt = get_args(is_train=False)
34
+
35
+ """ vocab / character number configuration """
36
+ if opt.sensitive:
37
+ opt.character = string.printable[:-6] # same with ASTER setting (use 94 char).
38
+
39
+ cudnn.benchmark = True
40
+ cudnn.deterministic = True
41
+ # opt.num_gpu = torch.cuda.device_count()
42
+
43
+ # combineBestDataXAI(opt)
44
+ # acquire_average_auc(opt)
45
+ # acquireSingleCharAttrAve(opt)
46
+ modelName = "parseq"
47
+ opt.modelName = modelName
48
+ # opt.eval_data = "datasets/data_lmdb_release/evaluation"
49
+
50
+ if modelName=="vitstr":
51
+ opt.benchmark_all_eval = True
52
+ opt.Transformation = "None"
53
+ opt.FeatureExtraction = "None"
54
+ opt.SequenceModeling = "None"
55
+ opt.Prediction = "None"
56
+ opt.Transformer = True
57
+ opt.sensitive = True
58
+ opt.imgH = 224
59
+ opt.imgW = 224
60
+ opt.data_filtering_off = True
61
+ opt.TransformerModel= "vitstr_base_patch16_224"
62
+ opt.saved_model = "pretrained/vitstr_base_patch16_224_aug.pth"
63
+ opt.batch_size = 1
64
+ opt.workers = 0
65
+ opt.scorer = "mean"
66
+ opt.blackbg = True
67
+ elif modelName=="parseq":
68
+ opt.benchmark_all_eval = True
69
+ opt.Transformation = "None"
70
+ opt.FeatureExtraction = "None"
71
+ opt.SequenceModeling = "None"
72
+ opt.Prediction = "None"
73
+ opt.Transformer = True
74
+ opt.sensitive = True
75
+ opt.imgH = 32
76
+ opt.imgW = 128
77
+ opt.data_filtering_off = True
78
+ opt.batch_size = 1
79
+ opt.workers = 0
80
+ opt.scorer = "mean"
81
+ opt.blackbg = True
82
 
83
  # x = st.slider('Select a value')
84
  # st.write(x, 'squared is', x * x)
 
96
 
97
  uploaded_file = st.file_uploader("Choose a file", type=["png", "jpg"])
98
  if uploaded_file is not None:
99
+ # To read file as bytes:
100
+ bytes_data = uploaded_file.getvalue()
101
+ pilImg = Image.open(uploaded_file)
102
+
103
+ orig_img_tensors = transforms.ToTensor()(pilImg).unsqueeze(0)
104
+ img1 = orig_img_tensors.to(device)
105
+ # image_tensors = ((torch.clone(orig_img_tensors) + 1.0) / 2.0) * 255.0
106
+ image_tensors = torch.mean(orig_img_tensors, dim=1).unsqueeze(0).unsqueeze(0)
107
+ imgDataDict = {}
108
+ img_numpy = image_tensors.cpu().detach().numpy()[0] ### Need to set batch size to 1 only
109
+ if img_numpy.shape[0] == 1:
110
+ img_numpy = gray2rgb(img_numpy[0])
111
+ # print("img_numpy shape: ", img_numpy.shape) # (1, 32, 128, 3)
112
+ segmOutput = segmentation_fn(img_numpy[0])
113
+
114
+ results_dict = {}
115
+ aveAttr = []
116
+ aveAttr_charContrib = []
117
+ target = converter.encode([labels])
118
+
119
+ # labels: RONALDO
120
+ segmDataNP = segmOutput
121
+ img1.requires_grad = True
122
+ bgImg = torch.zeros(img1.shape).to(device)
123
+
124
+ # preds = model(img1, seqlen=converter.batch_max_length)
125
+ input = img1
126
+ origImgNP = torch.clone(orig_img_tensors).detach().cpu().numpy()[0][0] # (1, 1, 224, 224)
127
+ origImgNP = gray2rgb(origImgNP)
128
+ charOffset = 0
129
+ img1 = transforms.Normalize(0.5, 0.5)(img1) # Between -1 to 1
130
+ target = converter.encode([labels])
131
+
132
+ ### Local explanations only
133
+ collectedAttributions = []
134
+ for charIdx in range(0, len(labels)):
135
+ scoring_singlechar.setSingleCharOutput(charIdx + charOffset)
136
+ gtClassNum = target[0][charIdx + charOffset]
137
+
138
+ gs = GradientShap(super_pixel_model_singlechar)
139
+ baseline_dist = torch.zeros((1, 3, opt.imgH, opt.imgW))
140
+ baseline_dist = baseline_dist.to(device)
141
+ attributions = gs.attribute(input, baselines=baseline_dist, target=0)
142
+ collectedAttributions.append(attributions)
143
+ aveAttributions = torch.mean(torch.cat(collectedAttributions,dim=0), dim=0).unsqueeze(0)
144
+ # if not torch.isnan(aveAttributions).any():
145
+ # rankedAttr = rankedAttributionsBySegm(aveAttributions, segmDataNP)
146
+ # rankedAttr = rankedAttr.detach().cpu().numpy()[0][0]
147
+ # rankedAttr = gray2rgb(rankedAttr)
148
+ # mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map', cmap='RdYlGn')
149
+ # mplotfig.savefig(outputDir + '{}_shapley_l.png'.format(nameNoExt))
150
+ # mplotfig.clear()
151
+ # plt.close(mplotfig)
152
+
153
+ ### Local Sampling
154
+ gs = GradientShap(super_pixel_model)
155
+ baseline_dist = torch.zeros((1, 3, opt.imgH, opt.imgW))
156
+ baseline_dist = baseline_dist.to(device)
157
+ attributions = gs.attribute(input, baselines=baseline_dist, target=0)
158
+ # if not torch.isnan(attributions).any():
159
+ # collectedAttributions.append(attributions)
160
+ # rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP)
161
+ # rankedAttr = rankedAttr.detach().cpu().numpy()[0][0]
162
+ # rankedAttr = gray2rgb(rankedAttr)
163
+ # mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map', cmap='RdYlGn')
164
+ # mplotfig.savefig(outputDir + '{}_shapley.png'.format(nameNoExt))
165
+ # mplotfig.clear()
166
+ # plt.close(mplotfig)
167
+
168
+ ### Global + Local context
169
+ aveAttributions = torch.mean(torch.cat(collectedAttributions,dim=0), dim=0).unsqueeze(0)
170
+ if not torch.isnan(aveAttributions).any():
171
+ rankedAttr = rankedAttributionsBySegm(aveAttributions, segmDataNP)
172
+ rankedAttr = rankedAttr.detach().cpu().numpy()[0][0]
173
+ rankedAttr = gray2rgb(rankedAttr)
174
+ mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map', cmap='RdYlGn')
175
+ fig = mplotfig.figure(figsize=(8,8))
176
+ st.pyplot(fig)
177
+ # mplotfig.savefig(outputDir + '{}_shapley_gl.png'.format(nameNoExt))
178
+ # mplotfig.clear()
179
+ # plt.close(mplotfig)