markytools committed
Commit 918d78a · 1 Parent(s): b4759d0

updated app
Files changed (1):
  1. app.py (+19 -3)
app.py CHANGED
@@ -58,7 +58,6 @@ from captum.attr._utils.visualization import visualize_image_attr
 device = torch.device('cpu')
 opt = get_args(is_train=False)
 
-""" vocab / character number configuration """
 if opt.sensitive:
     opt.character = string.printable[:-6]  # same with ASTER setting (use 94 char).
 
@@ -125,7 +124,6 @@ if modelName=="vitstr":
 model = torch.nn.DataParallel(model_obj).to(device)
 modelCopy = copy.deepcopy(model)
 
-""" evaluation """
 scoring_singlechar = STRScore(opt=opt, converter=converter, device=device, enableSingleCharAttrAve=True)
 super_pixel_model_singlechar = torch.nn.Sequential(
     # super_pixler,
@@ -193,7 +191,25 @@ if opt.blackbg:
 # x = st.slider('Select a value')
 # st.write(x, 'squared is', x * x)
 
-labels = st.text_input('You need to put the text of the image here...', 'BALLYS')
+### Acquire pixelwise attributions and replace them with ranked numbers averaged
+### across segmentations, with the largest contribution assigned the largest number
+### and the smallest set to 1, the minimum rank.
+### attr - original attribution
+### segm - image segmentations
+def rankedAttributionsBySegm(attr, segm):
+    aveSegmentations, sortedDict = averageSegmentsOut(attr[0,0], segm)
+    totalSegm = len(sortedDict.keys())  # total number of segmentations
+    sortedKeys = [k for k, v in sorted(sortedDict.items(), key=lambda item: item[1])]
+    sortedKeys = sortedKeys[::-1]  # keys ordered from largest to smallest score
+    currentRank = totalSegm
+    rankedSegmImg = torch.clone(attr)
+    for totalSegToHide in range(0, len(sortedKeys)):
+        currentSegmentToHide = sortedKeys[totalSegToHide]
+        rankedSegmImg[0,0][segm == currentSegmentToHide] = currentRank
+        currentRank -= 1
+    return rankedSegmImg
+
+labels = st.text_input('You need to put the text of the image here (e.g. BALLYS)')
 
 image = Image.open('demo_image/demo_ballys.jpg')  # Brand logo image (optional)
 image2 = Image.open('demo_image/demo_ronaldo.jpg')  # Brand logo image (optional)
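For reference, a minimal usage sketch of the new rankedAttributionsBySegm helper follows. It is not part of the commit: the stand-in averageSegmentsOut below (the real helper is defined elsewhere in app.py), the input shapes, and the segment count are all assumptions, chosen only to show that every pixel of a segment ends up holding that segment's rank.

import torch

# Hypothetical stand-in for averageSegmentsOut (the real helper lives elsewhere in
# the repo). It returns the mean attribution per segment label twice, since
# rankedAttributionsBySegm only sorts the second return value by score.
def averageSegmentsOut(attr2d, segm):
    scores = {int(lbl): attr2d[segm == lbl].mean().item() for lbl in torch.unique(segm)}
    return scores, scores

H, W = 32, 100                      # assumed STR input size (height, width)
attr = torch.rand(1, 1, H, W)       # placeholder pixelwise attributions
segm = torch.randint(0, 5, (H, W))  # assumed segmentation map with 5 superpixels

ranked = rankedAttributionsBySegm(attr, segm)
# ranked keeps attr's shape; pixels of the highest-scoring segment become 5,
# the next segment 4, ..., and the lowest-scoring segment 1.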