Spaces:
Build error
Build error
Commit
·
6ccedcd
1
Parent(s):
dffa77d
updated app
Browse files
app.py
CHANGED
@@ -80,6 +80,90 @@ elif modelName=="parseq":
|
|
80 |
opt.scorer = "mean"
|
81 |
opt.blackbg = True
|
82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
# x = st.slider('Select a value')
|
84 |
# st.write(x, 'squared is', x * x)
|
85 |
|
@@ -99,6 +183,7 @@ if uploaded_file is not None:
|
|
99 |
# To read file as bytes:
|
100 |
bytes_data = uploaded_file.getvalue()
|
101 |
pilImg = Image.open(uploaded_file)
|
|
|
102 |
|
103 |
orig_img_tensors = transforms.ToTensor()(pilImg).unsqueeze(0)
|
104 |
img1 = orig_img_tensors.to(device)
|
|
|
80 |
opt.scorer = "mean"
|
81 |
opt.blackbg = True
|
82 |
|
83 |
+
segmentation_fn = SegmentationAlgorithm('quickshift', kernel_size=4,
|
84 |
+
max_dist=200, ratio=0.2,
|
85 |
+
random_seed=random.randint(0, 1000))
|
86 |
+
|
87 |
+
if modelName=="vitstr":
|
88 |
+
if opt.Transformer:
|
89 |
+
converter = TokenLabelConverter(opt)
|
90 |
+
elif 'CTC' in opt.Prediction:
|
91 |
+
converter = CTCLabelConverter(opt.character)
|
92 |
+
else:
|
93 |
+
converter = AttnLabelConverter(opt.character)
|
94 |
+
opt.num_class = len(converter.character)
|
95 |
+
if opt.rgb:
|
96 |
+
opt.input_channel = 3
|
97 |
+
model_obj = Model(opt)
|
98 |
+
|
99 |
+
model = torch.nn.DataParallel(model_obj).to(device)
|
100 |
+
modelCopy = copy.deepcopy(model)
|
101 |
+
|
102 |
+
""" evaluation """
|
103 |
+
scoring_singlechar = STRScore(opt=opt, converter=converter, device=device, enableSingleCharAttrAve=True)
|
104 |
+
super_pixel_model_singlechar = torch.nn.Sequential(
|
105 |
+
# super_pixler,
|
106 |
+
# numpy2torch_converter,
|
107 |
+
modelCopy,
|
108 |
+
scoring_singlechar
|
109 |
+
).to(device)
|
110 |
+
modelCopy.eval()
|
111 |
+
scoring_singlechar.eval()
|
112 |
+
super_pixel_model_singlechar.eval()
|
113 |
+
|
114 |
+
# Single Char Attribution Averaging
|
115 |
+
# enableSingleCharAttrAve - set to True
|
116 |
+
scoring = STRScore(opt=opt, converter=converter, device=device)
|
117 |
+
super_pixel_model = torch.nn.Sequential(
|
118 |
+
# super_pixler,
|
119 |
+
# numpy2torch_converter,
|
120 |
+
model,
|
121 |
+
scoring
|
122 |
+
).to(device)
|
123 |
+
model.eval()
|
124 |
+
scoring.eval()
|
125 |
+
super_pixel_model.eval()
|
126 |
+
|
127 |
+
elif modelName=="parseq":
|
128 |
+
model = torch.hub.load('baudm/parseq', 'parseq', pretrained=True)
|
129 |
+
# checkpoint = torch.hub.load_state_dict_from_url('https://github.com/baudm/parseq/releases/download/v1.0.0/parseq-bb5792a6.pt', map_location="cpu")
|
130 |
+
# # state_dict = {key.replace("module.", ""): value for key, value in checkpoint["state_dict"].items()}
|
131 |
+
# model.load_state_dict(checkpoint)
|
132 |
+
model = model.to(device)
|
133 |
+
model_obj = model
|
134 |
+
converter = TokenLabelConverter(opt)
|
135 |
+
modelCopy = copy.deepcopy(model)
|
136 |
+
|
137 |
+
""" evaluation """
|
138 |
+
scoring_singlechar = STRScore(opt=opt, converter=converter, device=device, enableSingleCharAttrAve=True, model=modelCopy)
|
139 |
+
super_pixel_model_singlechar = torch.nn.Sequential(
|
140 |
+
# super_pixler,
|
141 |
+
# numpy2torch_converter,
|
142 |
+
modelCopy,
|
143 |
+
scoring_singlechar
|
144 |
+
).to(device)
|
145 |
+
modelCopy.eval()
|
146 |
+
scoring_singlechar.eval()
|
147 |
+
super_pixel_model_singlechar.eval()
|
148 |
+
|
149 |
+
# Single Char Attribution Averaging
|
150 |
+
# enableSingleCharAttrAve - set to True
|
151 |
+
scoring = STRScore(opt=opt, converter=converter, device=device, model=model)
|
152 |
+
super_pixel_model = torch.nn.Sequential(
|
153 |
+
# super_pixler,
|
154 |
+
# numpy2torch_converter,
|
155 |
+
model,
|
156 |
+
scoring
|
157 |
+
).to(device)
|
158 |
+
model.eval()
|
159 |
+
scoring.eval()
|
160 |
+
super_pixel_model.eval()
|
161 |
+
|
162 |
+
if opt.blackbg:
|
163 |
+
shapImgLs = np.zeros(shape=(1, 1, 224, 224)).astype(np.float32)
|
164 |
+
trainList = np.array(shapImgLs)
|
165 |
+
background = torch.from_numpy(trainList).to(device)
|
166 |
+
|
167 |
# x = st.slider('Select a value')
|
168 |
# st.write(x, 'squared is', x * x)
|
169 |
|
|
|
183 |
# To read file as bytes:
|
184 |
bytes_data = uploaded_file.getvalue()
|
185 |
pilImg = Image.open(uploaded_file)
|
186 |
+
pilImg = pilImg.resize((opt.imgW, opt.imgH))
|
187 |
|
188 |
orig_img_tensors = transforms.ToTensor()(pilImg).unsqueeze(0)
|
189 |
img1 = orig_img_tensors.to(device)
|