markytools commited on
Commit
6ccedcd
·
1 Parent(s): dffa77d

updated app

Browse files
Files changed (1) hide show
  1. app.py +85 -0
app.py CHANGED
@@ -80,6 +80,90 @@ elif modelName=="parseq":
80
  opt.scorer = "mean"
81
  opt.blackbg = True
82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  # x = st.slider('Select a value')
84
  # st.write(x, 'squared is', x * x)
85
 
@@ -99,6 +183,7 @@ if uploaded_file is not None:
99
  # To read file as bytes:
100
  bytes_data = uploaded_file.getvalue()
101
  pilImg = Image.open(uploaded_file)
 
102
 
103
  orig_img_tensors = transforms.ToTensor()(pilImg).unsqueeze(0)
104
  img1 = orig_img_tensors.to(device)
 
80
  opt.scorer = "mean"
81
  opt.blackbg = True
82
 
83
+ segmentation_fn = SegmentationAlgorithm('quickshift', kernel_size=4,
84
+ max_dist=200, ratio=0.2,
85
+ random_seed=random.randint(0, 1000))
86
+
87
+ if modelName=="vitstr":
88
+ if opt.Transformer:
89
+ converter = TokenLabelConverter(opt)
90
+ elif 'CTC' in opt.Prediction:
91
+ converter = CTCLabelConverter(opt.character)
92
+ else:
93
+ converter = AttnLabelConverter(opt.character)
94
+ opt.num_class = len(converter.character)
95
+ if opt.rgb:
96
+ opt.input_channel = 3
97
+ model_obj = Model(opt)
98
+
99
+ model = torch.nn.DataParallel(model_obj).to(device)
100
+ modelCopy = copy.deepcopy(model)
101
+
102
+ """ evaluation """
103
+ scoring_singlechar = STRScore(opt=opt, converter=converter, device=device, enableSingleCharAttrAve=True)
104
+ super_pixel_model_singlechar = torch.nn.Sequential(
105
+ # super_pixler,
106
+ # numpy2torch_converter,
107
+ modelCopy,
108
+ scoring_singlechar
109
+ ).to(device)
110
+ modelCopy.eval()
111
+ scoring_singlechar.eval()
112
+ super_pixel_model_singlechar.eval()
113
+
114
+ # Single Char Attribution Averaging
115
+ # enableSingleCharAttrAve - set to True
116
+ scoring = STRScore(opt=opt, converter=converter, device=device)
117
+ super_pixel_model = torch.nn.Sequential(
118
+ # super_pixler,
119
+ # numpy2torch_converter,
120
+ model,
121
+ scoring
122
+ ).to(device)
123
+ model.eval()
124
+ scoring.eval()
125
+ super_pixel_model.eval()
126
+
127
+ elif modelName=="parseq":
128
+ model = torch.hub.load('baudm/parseq', 'parseq', pretrained=True)
129
+ # checkpoint = torch.hub.load_state_dict_from_url('https://github.com/baudm/parseq/releases/download/v1.0.0/parseq-bb5792a6.pt', map_location="cpu")
130
+ # # state_dict = {key.replace("module.", ""): value for key, value in checkpoint["state_dict"].items()}
131
+ # model.load_state_dict(checkpoint)
132
+ model = model.to(device)
133
+ model_obj = model
134
+ converter = TokenLabelConverter(opt)
135
+ modelCopy = copy.deepcopy(model)
136
+
137
+ """ evaluation """
138
+ scoring_singlechar = STRScore(opt=opt, converter=converter, device=device, enableSingleCharAttrAve=True, model=modelCopy)
139
+ super_pixel_model_singlechar = torch.nn.Sequential(
140
+ # super_pixler,
141
+ # numpy2torch_converter,
142
+ modelCopy,
143
+ scoring_singlechar
144
+ ).to(device)
145
+ modelCopy.eval()
146
+ scoring_singlechar.eval()
147
+ super_pixel_model_singlechar.eval()
148
+
149
+ # Single Char Attribution Averaging
150
+ # enableSingleCharAttrAve - set to True
151
+ scoring = STRScore(opt=opt, converter=converter, device=device, model=model)
152
+ super_pixel_model = torch.nn.Sequential(
153
+ # super_pixler,
154
+ # numpy2torch_converter,
155
+ model,
156
+ scoring
157
+ ).to(device)
158
+ model.eval()
159
+ scoring.eval()
160
+ super_pixel_model.eval()
161
+
162
+ if opt.blackbg:
163
+ shapImgLs = np.zeros(shape=(1, 1, 224, 224)).astype(np.float32)
164
+ trainList = np.array(shapImgLs)
165
+ background = torch.from_numpy(trainList).to(device)
166
+
167
  # x = st.slider('Select a value')
168
  # st.write(x, 'squared is', x * x)
169
 
 
183
  # To read file as bytes:
184
  bytes_data = uploaded_file.getvalue()
185
  pilImg = Image.open(uploaded_file)
186
+ pilImg = pilImg.resize((opt.imgW, opt.imgH))
187
 
188
  orig_img_tensors = transforms.ToTensor()(pilImg).unsqueeze(0)
189
  img1 = orig_img_tensors.to(device)