Spaces:
Build error
Build error
Commit
·
918d78a
1
Parent(s):
b4759d0
updated app
Browse files
app.py
CHANGED
@@ -58,7 +58,6 @@ from captum.attr._utils.visualization import visualize_image_attr
|
|
58 |
device = torch.device('cpu')
|
59 |
opt = get_args(is_train=False)
|
60 |
|
61 |
-
""" vocab / character number configuration """
|
62 |
if opt.sensitive:
|
63 |
opt.character = string.printable[:-6] # same with ASTER setting (use 94 char).
|
64 |
|
@@ -125,7 +124,6 @@ if modelName=="vitstr":
|
|
125 |
model = torch.nn.DataParallel(model_obj).to(device)
|
126 |
modelCopy = copy.deepcopy(model)
|
127 |
|
128 |
-
""" evaluation """
|
129 |
scoring_singlechar = STRScore(opt=opt, converter=converter, device=device, enableSingleCharAttrAve=True)
|
130 |
super_pixel_model_singlechar = torch.nn.Sequential(
|
131 |
# super_pixler,
|
@@ -193,7 +191,25 @@ if opt.blackbg:
|
|
193 |
# x = st.slider('Select a value')
|
194 |
# st.write(x, 'squared is', x * x)
|
195 |
|
196 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
197 |
|
198 |
image = Image.open('demo_image/demo_ballys.jpg') #Brand logo image (optional)
|
199 |
image2 = Image.open('demo_image/demo_ronaldo.jpg') #Brand logo image (optional)
|
|
|
58 |
device = torch.device('cpu')
|
59 |
opt = get_args(is_train=False)
|
60 |
|
|
|
61 |
if opt.sensitive:
|
62 |
opt.character = string.printable[:-6] # same with ASTER setting (use 94 char).
|
63 |
|
|
|
124 |
model = torch.nn.DataParallel(model_obj).to(device)
|
125 |
modelCopy = copy.deepcopy(model)
|
126 |
|
|
|
127 |
scoring_singlechar = STRScore(opt=opt, converter=converter, device=device, enableSingleCharAttrAve=True)
|
128 |
super_pixel_model_singlechar = torch.nn.Sequential(
|
129 |
# super_pixler,
|
|
|
191 |
# x = st.slider('Select a value')
|
192 |
# st.write(x, 'squared is', x * x)
|
193 |
|
194 |
+
### Acquire pixelwise attributions and replace them with ranked numbers averaged
|
195 |
+
### across segmentation with the largest contribution having the largest number
|
196 |
+
### and the smallest set to 1, which is the minimum number.
|
197 |
+
### attr - original attribution
|
198 |
+
### segm - image segmentations
|
199 |
+
def rankedAttributionsBySegm(attr, segm):
|
200 |
+
aveSegmentations, sortedDict = averageSegmentsOut(attr[0,0], segm)
|
201 |
+
totalSegm = len(sortedDict.keys()) # total segmentations
|
202 |
+
sortedKeys = [k for k, v in sorted(sortedDict.items(), key=lambda item: item[1])]
|
203 |
+
sortedKeys = sortedKeys[::-1] ### A list that should contain largest to smallest score
|
204 |
+
currentRank = totalSegm
|
205 |
+
rankedSegmImg = torch.clone(attr)
|
206 |
+
for totalSegToHide in range(0, len(sortedKeys)):
|
207 |
+
currentSegmentToHide = sortedKeys[totalSegToHide]
|
208 |
+
rankedSegmImg[0,0][segm == currentSegmentToHide] = currentRank
|
209 |
+
currentRank -= 1
|
210 |
+
return rankedSegmImg
|
211 |
+
|
212 |
+
labels = st.text_input('You need to put the text of the image here (e.g. BALLYS)')
|
213 |
|
214 |
image = Image.open('demo_image/demo_ballys.jpg') #Brand logo image (optional)
|
215 |
image2 = Image.open('demo_image/demo_ronaldo.jpg') #Brand logo image (optional)
|