Your Name commited on
Commit
22496cb
·
1 Parent(s): 929aa32
Files changed (1) hide show
  1. demo.py +122 -0
demo.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ import os
4
+ import glob
5
+ import tqdm
6
+ import torch, re
7
+ import PIL
8
+ import cv2
9
+ import numpy as np
10
+ import torch.nn.functional as F
11
+ from torchvision import transforms
12
+ from utils import Config, Logger, CharsetMapper
13
+
14
+ def get_model(config):
15
+ import importlib
16
+ names = config.model_name.split('.')
17
+ module_name, class_name = '.'.join(names[:-1]), names[-1]
18
+ cls = getattr(importlib.import_module(module_name), class_name)
19
+ model = cls(config)
20
+ logging.info(model)
21
+ model = model.eval()
22
+ return model
23
+
24
+ def preprocess(img, width, height):
25
+ img = cv2.resize(np.array(img), (width, height))
26
+ img = transforms.ToTensor()(img).unsqueeze(0)
27
+ mean = torch.tensor([0.485, 0.456, 0.406])
28
+ std = torch.tensor([0.229, 0.224, 0.225])
29
+ return (img-mean[...,None,None]) / std[...,None,None]
30
+
31
+ def postprocess(output, charset, model_eval):
32
+ def _get_output(last_output, model_eval):
33
+ if isinstance(last_output, (tuple, list)):
34
+ for res in last_output:
35
+ if res['name'] == model_eval: output = res
36
+ else: output = last_output
37
+ return output
38
+
39
+ def _decode(logit):
40
+ """ Greed decode """
41
+ out = F.softmax(logit, dim=2)
42
+ pt_text, pt_scores, pt_lengths = [], [], []
43
+ for o in out:
44
+ text = charset.get_text(o.argmax(dim=1), padding=False, trim=False)
45
+ text = text.split(charset.null_char)[0] # end at end-token
46
+ pt_text.append(text)
47
+ pt_scores.append(o.max(dim=1)[0])
48
+ pt_lengths.append(min(len(text) + 1, charset.max_length)) # one for end-token
49
+ return pt_text, pt_scores, pt_lengths
50
+
51
+ output = _get_output(output, model_eval)
52
+ logits, pt_lengths = output['logits'], output['pt_lengths']
53
+ pt_text, pt_scores, pt_lengths_ = _decode(logits)
54
+
55
+ return pt_text, pt_scores, pt_lengths_
56
+
57
+ def load(model, file, device=None, strict=True):
58
+ if device is None: device = 'cpu'
59
+ elif isinstance(device, int): device = torch.device('cuda', device)
60
+ assert os.path.isfile(file)
61
+ state = torch.load(file, map_location=device)
62
+ if set(state.keys()) == {'model', 'opt'}:
63
+ state = state['model']
64
+ model.load_state_dict(state, strict=strict)
65
+ return model
66
+
67
+
68
+ def main():
69
+ parser = argparse.ArgumentParser()
70
+ parser.add_argument('--config', type=str, default='configs/train_abinet.yaml',
71
+ help='path to config file')
72
+ parser.add_argument('--input', type=str, default='figs/test')
73
+ parser.add_argument('--cuda', type=int, default=-1)
74
+ parser.add_argument('--checkpoint', type=str, default='workdir/train-abinet/best-train-abinet.pth')
75
+ parser.add_argument('--model_eval', type=str, default='alignment',
76
+ choices=['alignment', 'vision', 'language'])
77
+ args = parser.parse_args()
78
+ config = Config(args.config)
79
+ if args.checkpoint is not None: config.model_checkpoint = args.checkpoint
80
+ if args.model_eval is not None: config.model_eval = args.model_eval
81
+ config.global_phase = 'test'
82
+ config.model_vision_checkpoint, config.model_language_checkpoint = None, None
83
+ device = 'cpu' if args.cuda < 0 else f'cuda:{args.cuda}'
84
+
85
+ Logger.init(config.global_workdir, config.global_name, config.global_phase)
86
+ Logger.enable_file()
87
+ logging.info(config)
88
+
89
+ logging.info('Construct model.')
90
+ model = get_model(config).to(device)
91
+ model = load(model, config.model_checkpoint, device=device)
92
+ charset = CharsetMapper(filename=config.dataset_charset_path,
93
+ max_length=config.dataset_max_length + 1)
94
+
95
+ if os.path.isdir(args.input):
96
+ paths = [os.path.join(args.input, fname) for fname in os.listdir(args.input)]
97
+ else:
98
+ paths = glob.glob(os.path.expanduser(args.input))
99
+ assert paths, "The input path(s) was not found"
100
+ paths = sorted(paths)
101
+
102
+
103
+ count = 0
104
+ checks = 0
105
+ print(tqdm.tqdm(paths))
106
+ for path in tqdm.tqdm(paths):
107
+ img = PIL.Image.open(path).convert('RGB')
108
+ img = preprocess(img, config.dataset_image_width, config.dataset_image_height)
109
+ img = img.to(device)
110
+ res = model(img)
111
+ pt_text, _, __ = postprocess(res, charset, config.model_eval)
112
+ a = re.findall(r'(\d{6}).png', path)[0]
113
+ # print(a)
114
+ # print(pt_text[0], "Lol")
115
+ # a = re.findall(r'base/(.*).pn', path)[0]
116
+ checks += 1
117
+ if a.lower() != pt_text[0].lower():
118
+ count += 1
119
+ print(f'label:{a.lower()} ||| guess:{pt_text[0]} ||| count_fails:{str(count)}/{str(checks)}')
120
+
121
+ if __name__ == '__main__':
122
+ main()