markytools commited on
Commit
3431661
·
1 Parent(s): 21bfe64

updated app

Browse files
Files changed (2) hide show
  1. app(orig).py +36 -0
  2. app.py +17 -160
app(orig).py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+
4
+ # x = st.slider('Select a value')
5
+ # st.write(x, 'squared is', x * x)
6
+
7
+ image = Image.open('demo_image/demo_ballys.jpg') #Brand logo image (optional)
8
+ #Create two columns with different width
9
+ col1, col2 = st.columns( [0.8, 0.2])
10
+ with col1: # To display the header text using css style
11
+ st.markdown(""" <style> .font {
12
+ font-size:35px ; font-family: 'Cooper Black'; color: #FF9633;}
13
+ </style> """, unsafe_allow_html=True)
14
+ st.markdown('<p class="font">Upload your photo here...</p>', unsafe_allow_html=True)
15
+ with col2: # To display brand logo
16
+ st.image(image, width=150)
17
+
18
+ uploaded_file = st.file_uploader("Choose a file", type=["png", "jpg"])
19
+ if uploaded_file is not None:
20
+ # To read file as bytes:
21
+ bytes_data = uploaded_file.getvalue()
22
+ pillowImg = Image.open(uploaded_file)
23
+ # print("pillowImg shape: ", )
24
+ st.write(pillowImg.size)
25
+
26
+ # # To convert to a string based IO:
27
+ # stringio = StringIO(uploaded_file.getvalue().decode("utf-8"))
28
+ # st.write(stringio)
29
+ #
30
+ # # To read file as string:
31
+ # string_data = stringio.read()
32
+ # st.write(string_data)
33
+ #
34
+ # # Can be used wherever a "file-like" object is accepted:
35
+ # dataframe = pd.read_csv(uploaded_file)
36
+ # st.write(dataframe)
app.py CHANGED
@@ -1,84 +1,5 @@
1
  import streamlit as st
2
  from PIL import Image
3
- import settings
4
- import captum
5
- import numpy as np
6
- import torch
7
- import torch.nn.functional as F
8
- import torch.backends.cudnn as cudnn
9
- from utils import get_args
10
- from utils import CTCLabelConverter, AttnLabelConverter, Averager, TokenLabelConverter
11
- import string
12
- import time
13
- import sys
14
- from dataset import hierarchical_dataset, AlignCollate
15
- import validators
16
- from model import Model, STRScore
17
- from PIL import Image
18
- from lime.wrappers.scikit_image import SegmentationAlgorithm
19
- from captum._utils.models.linear_model import SkLearnLinearModel, SkLearnRidge
20
- import random
21
- import os
22
- from skimage.color import gray2rgb
23
- import pickle
24
- from train_shap_corr import getPredAndConf
25
- import re
26
- from captum_test import acquire_average_auc, saveAttrData
27
- import copy
28
- from skimage.color import gray2rgb
29
- from matplotlib import pyplot as plt
30
- from torchvision import transforms
31
-
32
- device = torch.device('cpu')
33
- opt = get_args(is_train=False)
34
-
35
- """ vocab / character number configuration """
36
- if opt.sensitive:
37
- opt.character = string.printable[:-6] # same with ASTER setting (use 94 char).
38
-
39
- cudnn.benchmark = True
40
- cudnn.deterministic = True
41
- # opt.num_gpu = torch.cuda.device_count()
42
-
43
- # combineBestDataXAI(opt)
44
- # acquire_average_auc(opt)
45
- # acquireSingleCharAttrAve(opt)
46
- modelName = "parseq"
47
- opt.modelName = modelName
48
- # opt.eval_data = "datasets/data_lmdb_release/evaluation"
49
-
50
- if modelName=="vitstr":
51
- opt.benchmark_all_eval = True
52
- opt.Transformation = "None"
53
- opt.FeatureExtraction = "None"
54
- opt.SequenceModeling = "None"
55
- opt.Prediction = "None"
56
- opt.Transformer = True
57
- opt.sensitive = True
58
- opt.imgH = 224
59
- opt.imgW = 224
60
- opt.data_filtering_off = True
61
- opt.TransformerModel= "vitstr_base_patch16_224"
62
- opt.saved_model = "pretrained/vitstr_base_patch16_224_aug.pth"
63
- opt.batch_size = 1
64
- opt.workers = 0
65
- opt.scorer = "mean"
66
- opt.blackbg = True
67
- elif modelName=="parseq":
68
- opt.benchmark_all_eval = True
69
- opt.Transformation = "None"
70
- opt.FeatureExtraction = "None"
71
- opt.SequenceModeling = "None"
72
- opt.Prediction = "None"
73
- opt.Transformer = True
74
- opt.sensitive = True
75
- opt.imgH = 32
76
- opt.imgW = 128
77
- opt.data_filtering_off = True
78
- opt.batch_size = 1
79
- opt.workers = 0
80
- opt.scorer = "mean"
81
- opt.blackbg = True
82
 
83
  # x = st.slider('Select a value')
84
  # st.write(x, 'squared is', x * x)
@@ -96,84 +17,20 @@ with col2: # To display brand logo
96
 
97
  uploaded_file = st.file_uploader("Choose a file", type=["png", "jpg"])
98
  if uploaded_file is not None:
99
- # To read file as bytes:
100
- bytes_data = uploaded_file.getvalue()
101
- pilImg = Image.open(uploaded_file)
102
-
103
- orig_img_tensors = transforms.ToTensor()(pilImg).unsqueeze(0)
104
- img1 = orig_img_tensors.to(device)
105
- # image_tensors = ((torch.clone(orig_img_tensors) + 1.0) / 2.0) * 255.0
106
- image_tensors = torch.mean(orig_img_tensors, dim=1).unsqueeze(0).unsqueeze(0)
107
- imgDataDict = {}
108
- img_numpy = image_tensors.cpu().detach().numpy()[0] ### Need to set batch size to 1 only
109
- if img_numpy.shape[0] == 1:
110
- img_numpy = gray2rgb(img_numpy[0])
111
- # print("img_numpy shape: ", img_numpy.shape) # (1, 32, 128, 3)
112
- segmOutput = segmentation_fn(img_numpy[0])
113
-
114
- results_dict = {}
115
- aveAttr = []
116
- aveAttr_charContrib = []
117
- target = converter.encode([labels])
118
-
119
- # labels: RONALDO
120
- segmDataNP = segmOutput
121
- img1.requires_grad = True
122
- bgImg = torch.zeros(img1.shape).to(device)
123
-
124
- # preds = model(img1, seqlen=converter.batch_max_length)
125
- input = img1
126
- origImgNP = torch.clone(orig_img_tensors).detach().cpu().numpy()[0][0] # (1, 1, 224, 224)
127
- origImgNP = gray2rgb(origImgNP)
128
- charOffset = 0
129
- img1 = transforms.Normalize(0.5, 0.5)(img1) # Between -1 to 1
130
- target = converter.encode([labels])
131
-
132
- ### Local explanations only
133
- collectedAttributions = []
134
- for charIdx in range(0, len(labels)):
135
- scoring_singlechar.setSingleCharOutput(charIdx + charOffset)
136
- gtClassNum = target[0][charIdx + charOffset]
137
-
138
- gs = GradientShap(super_pixel_model_singlechar)
139
- baseline_dist = torch.zeros((1, 3, opt.imgH, opt.imgW))
140
- baseline_dist = baseline_dist.to(device)
141
- attributions = gs.attribute(input, baselines=baseline_dist, target=0)
142
- collectedAttributions.append(attributions)
143
- aveAttributions = torch.mean(torch.cat(collectedAttributions,dim=0), dim=0).unsqueeze(0)
144
- # if not torch.isnan(aveAttributions).any():
145
- # rankedAttr = rankedAttributionsBySegm(aveAttributions, segmDataNP)
146
- # rankedAttr = rankedAttr.detach().cpu().numpy()[0][0]
147
- # rankedAttr = gray2rgb(rankedAttr)
148
- # mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map', cmap='RdYlGn')
149
- # mplotfig.savefig(outputDir + '{}_shapley_l.png'.format(nameNoExt))
150
- # mplotfig.clear()
151
- # plt.close(mplotfig)
152
-
153
- ### Local Sampling
154
- gs = GradientShap(super_pixel_model)
155
- baseline_dist = torch.zeros((1, 3, opt.imgH, opt.imgW))
156
- baseline_dist = baseline_dist.to(device)
157
- attributions = gs.attribute(input, baselines=baseline_dist, target=0)
158
- # if not torch.isnan(attributions).any():
159
- # collectedAttributions.append(attributions)
160
- # rankedAttr = rankedAttributionsBySegm(attributions, segmDataNP)
161
- # rankedAttr = rankedAttr.detach().cpu().numpy()[0][0]
162
- # rankedAttr = gray2rgb(rankedAttr)
163
- # mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map', cmap='RdYlGn')
164
- # mplotfig.savefig(outputDir + '{}_shapley.png'.format(nameNoExt))
165
- # mplotfig.clear()
166
- # plt.close(mplotfig)
167
-
168
- ### Global + Local context
169
- aveAttributions = torch.mean(torch.cat(collectedAttributions,dim=0), dim=0).unsqueeze(0)
170
- if not torch.isnan(aveAttributions).any():
171
- rankedAttr = rankedAttributionsBySegm(aveAttributions, segmDataNP)
172
- rankedAttr = rankedAttr.detach().cpu().numpy()[0][0]
173
- rankedAttr = gray2rgb(rankedAttr)
174
- mplotfig, _ = visualize_image_attr(rankedAttr, origImgNP, method='blended_heat_map', cmap='RdYlGn')
175
- fig = mplotfig.figure(figsize=(8,8))
176
- st.pyplot(fig)
177
- # mplotfig.savefig(outputDir + '{}_shapley_gl.png'.format(nameNoExt))
178
- # mplotfig.clear()
179
- # plt.close(mplotfig)
 
1
  import streamlit as st
2
  from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  # x = st.slider('Select a value')
5
  # st.write(x, 'squared is', x * x)
 
17
 
18
  uploaded_file = st.file_uploader("Choose a file", type=["png", "jpg"])
19
  if uploaded_file is not None:
20
+ # To read file as bytes:
21
+ bytes_data = uploaded_file.getvalue()
22
+ pillowImg = Image.open(uploaded_file)
23
+ # print("pillowImg shape: ", )
24
+ st.write(pillowImg.size)
25
+
26
+ # # To convert to a string based IO:
27
+ # stringio = StringIO(uploaded_file.getvalue().decode("utf-8"))
28
+ # st.write(stringio)
29
+ #
30
+ # # To read file as string:
31
+ # string_data = stringio.read()
32
+ # st.write(string_data)
33
+ #
34
+ # # Can be used wherever a "file-like" object is accepted:
35
+ # dataframe = pd.read_csv(uploaded_file)
36
+ # st.write(dataframe)