Spaces:
Runtime error
Runtime error
Your Name
commited on
Commit
·
74a6211
1
Parent(s):
5c5ae78
- LICENSE +25 -0
- README.md +142 -13
- app.py +122 -0
- callbacks.py +360 -0
- commands.sh +3 -0
- config.yaml +71 -0
- dataset.py +278 -0
- losses.py +72 -0
- main.py +246 -0
- requirements.txt +9 -0
- transforms.py +329 -0
- utils.py +304 -0
LICENSE
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
ABINet for non-commercial purposes
|
2 |
+
|
3 |
+
Copyright (c) 2021, USTC
|
4 |
+
All rights reserved.
|
5 |
+
|
6 |
+
Redistribution and use in source and binary forms, with or without
|
7 |
+
modification, are permitted provided that the following conditions are met:
|
8 |
+
|
9 |
+
1. Redistributions of source code must retain the above copyright notice, this
|
10 |
+
list of conditions and the following disclaimer.
|
11 |
+
|
12 |
+
2. Redistributions in binary form must reproduce the above copyright notice,
|
13 |
+
this list of conditions and the following disclaimer in the documentation
|
14 |
+
and/or other materials provided with the distribution.
|
15 |
+
|
16 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
17 |
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
18 |
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
19 |
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
20 |
+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
21 |
+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
22 |
+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
23 |
+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
24 |
+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
25 |
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
README.md
CHANGED
@@ -1,13 +1,142 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Read Like Humans: Autonomous, Bidirectional and Iterative Language Modeling for Scene Text Recognition
|
2 |
+
|
3 |
+
The official code of [ABINet](https://arxiv.org/pdf/2103.06495.pdf) (CVPR 2021, Oral).
|
4 |
+
|
5 |
+
ABINet uses a vision model and an explicit language model to recognize text in the wild, which are trained in end-to-end way. The language model (BCN) achieves bidirectional language representation in simulating cloze test, additionally utilizing iterative correction strategy.
|
6 |
+
|
7 |
+

|
8 |
+
|
9 |
+
## Runtime Environment
|
10 |
+
|
11 |
+
- We provide a pre-built docker image using the Dockerfile from `docker/Dockerfile`
|
12 |
+
|
13 |
+
- Running in Docker
|
14 |
+
```
|
15 |
+
$ [email protected]:FangShancheng/ABINet.git
|
16 |
+
$ docker run --gpus all --rm -ti --ipc=host -v $(pwd)/ABINet:/app fangshancheng/fastai:torch1.1 /bin/bash
|
17 |
+
```
|
18 |
+
- (Untested) Or using the dependencies
|
19 |
+
```
|
20 |
+
pip install -r requirements.txt
|
21 |
+
```
|
22 |
+
|
23 |
+
## Datasets
|
24 |
+
|
25 |
+
- Training datasets
|
26 |
+
|
27 |
+
1. [MJSynth](http://www.robots.ox.ac.uk/~vgg/data/text/) (MJ):
|
28 |
+
- Use `tools/create_lmdb_dataset.py` to convert images into LMDB dataset
|
29 |
+
- [LMDB dataset BaiduNetdisk(passwd:n23k)](https://pan.baidu.com/s/1mgnTiyoR8f6Cm655rFI4HQ)
|
30 |
+
2. [SynthText](http://www.robots.ox.ac.uk/~vgg/data/scenetext/) (ST):
|
31 |
+
- Use `tools/crop_by_word_bb.py` to crop images from original [SynthText](http://www.robots.ox.ac.uk/~vgg/data/scenetext/) dataset, and convert images into LMDB dataset by `tools/create_lmdb_dataset.py`
|
32 |
+
- [LMDB dataset BaiduNetdisk(passwd:n23k)](https://pan.baidu.com/s/1mgnTiyoR8f6Cm655rFI4HQ)
|
33 |
+
3. [WikiText103](https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip), which is only used for pre-trainig language models:
|
34 |
+
- Use `notebooks/prepare_wikitext103.ipynb` to convert text into CSV format.
|
35 |
+
- [CSV dataset BaiduNetdisk(passwd:dk01)](https://pan.baidu.com/s/1yabtnPYDKqhBb_Ie9PGFXA)
|
36 |
+
|
37 |
+
- Evaluation datasets, LMDB datasets can be downloaded from [BaiduNetdisk(passwd:1dbv)](https://pan.baidu.com/s/1RUg3Akwp7n8kZYJ55rU5LQ), [GoogleDrive](https://drive.google.com/file/d/1dTI0ipu14Q1uuK4s4z32DqbqF3dJPdkk/view?usp=sharing).
|
38 |
+
1. ICDAR 2013 (IC13)
|
39 |
+
2. ICDAR 2015 (IC15)
|
40 |
+
3. IIIT5K Words (IIIT)
|
41 |
+
4. Street View Text (SVT)
|
42 |
+
5. Street View Text-Perspective (SVTP)
|
43 |
+
6. CUTE80 (CUTE)
|
44 |
+
|
45 |
+
|
46 |
+
- The structure of `data` directory is
|
47 |
+
```
|
48 |
+
data
|
49 |
+
├── charset_36.txt
|
50 |
+
├── evaluation
|
51 |
+
│ ├── CUTE80
|
52 |
+
│ ├── IC13_857
|
53 |
+
│ ├── IC15_1811
|
54 |
+
│ ├── IIIT5k_3000
|
55 |
+
│ ├── SVT
|
56 |
+
│ └── SVTP
|
57 |
+
├── training
|
58 |
+
│ ├── MJ
|
59 |
+
│ │ ├── MJ_test
|
60 |
+
│ │ ├── MJ_train
|
61 |
+
│ │ └── MJ_valid
|
62 |
+
│ └── ST
|
63 |
+
├── WikiText-103.csv
|
64 |
+
└── WikiText-103_eval_d1.csv
|
65 |
+
```
|
66 |
+
|
67 |
+
### Pretrained Models
|
68 |
+
|
69 |
+
Get the pretrained models from [BaiduNetdisk(passwd:kwck)](https://pan.baidu.com/s/1b3vyvPwvh_75FkPlp87czQ), [GoogleDrive](https://drive.google.com/file/d/1mYM_26qHUom_5NU7iutHneB_KHlLjL5y/view?usp=sharing). Performances of the pretrained models are summaried as follows:
|
70 |
+
|
71 |
+
|Model|IC13|SVT|IIIT|IC15|SVTP|CUTE|AVG|
|
72 |
+
|-|-|-|-|-|-|-|-|
|
73 |
+
|ABINet-SV|97.1|92.7|95.2|84.0|86.7|88.5|91.4|
|
74 |
+
|ABINet-LV|97.0|93.4|96.4|85.9|89.5|89.2|92.7|
|
75 |
+
|
76 |
+
## Training
|
77 |
+
|
78 |
+
1. Pre-train vision model
|
79 |
+
```
|
80 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --config=configs/pretrain_vision_model.yaml
|
81 |
+
```
|
82 |
+
2. Pre-train language model
|
83 |
+
```
|
84 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --config=configs/pretrain_language_model.yaml
|
85 |
+
```
|
86 |
+
3. Train ABINet
|
87 |
+
```
|
88 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --config=configs/train_abinet.yaml
|
89 |
+
```
|
90 |
+
Note:
|
91 |
+
- You can set the `checkpoint` path for vision and language models separately for specific pretrained model, or set to `None` to train from scratch
|
92 |
+
|
93 |
+
|
94 |
+
## Evaluation
|
95 |
+
|
96 |
+
```
|
97 |
+
CUDA_VISIBLE_DEVICES=0 python main.py --config=configs/train_abinet.yaml --phase test --image_only
|
98 |
+
```
|
99 |
+
Additional flags:
|
100 |
+
- `--checkpoint /path/to/checkpoint` set the path of evaluation model
|
101 |
+
- `--test_root /path/to/dataset` set the path of evaluation dataset
|
102 |
+
- `--model_eval [alignment|vision]` which sub-model to evaluate
|
103 |
+
- `--image_only` disable dumping visualization of attention masks
|
104 |
+
|
105 |
+
## Web Demo
|
106 |
+
|
107 |
+
Integrated into [Huggingface Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio). Try out the Web Demo: [](https://huggingface.co/spaces/tomofi/ABINet-OCR)
|
108 |
+
|
109 |
+
## Run Demo
|
110 |
+
|
111 |
+
```
|
112 |
+
python demo.py --config=configs/train_abinet.yaml --input=figs/test
|
113 |
+
```
|
114 |
+
Additional flags:
|
115 |
+
- `--config /path/to/config` set the path of configuration file
|
116 |
+
- `--input /path/to/image-directory` set the path of image directory or wildcard path, e.g, `--input='figs/test/*.png'`
|
117 |
+
- `--checkpoint /path/to/checkpoint` set the path of trained model
|
118 |
+
- `--cuda [-1|0|1|2|3...]` set the cuda id, by default -1 is set and stands for cpu
|
119 |
+
- `--model_eval [alignment|vision]` which sub-model to use
|
120 |
+
- `--image_only` disable dumping visualization of attention masks
|
121 |
+
|
122 |
+
## Visualization
|
123 |
+
Successful and failure cases on low-quality images:
|
124 |
+
|
125 |
+

|
126 |
+
|
127 |
+
## Citation
|
128 |
+
If you find our method useful for your reserach, please cite
|
129 |
+
```bash
|
130 |
+
@article{fang2021read,
|
131 |
+
title={Read Like Humans: Autonomous, Bidirectional and Iterative Language Modeling for Scene Text Recognition},
|
132 |
+
author={Fang, Shancheng and Xie, Hongtao and Wang, Yuxin and Mao, Zhendong and Zhang, Yongdong},
|
133 |
+
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
|
134 |
+
year={2021}
|
135 |
+
}
|
136 |
+
```
|
137 |
+
|
138 |
+
## License
|
139 |
+
|
140 |
+
This project is only free for academic research purposes, licensed under the 2-clause BSD License - see the LICENSE file for details.
|
141 |
+
|
142 |
+
Feel free to contact [email protected] if you have any questions.
|
app.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
import glob
|
5 |
+
import tqdm
|
6 |
+
import torch, re
|
7 |
+
import PIL
|
8 |
+
import cv2
|
9 |
+
import numpy as np
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from torchvision import transforms
|
12 |
+
from utils import Config, Logger, CharsetMapper
|
13 |
+
|
14 |
+
def get_model(config):
|
15 |
+
import importlib
|
16 |
+
names = config.model_name.split('.')
|
17 |
+
module_name, class_name = '.'.join(names[:-1]), names[-1]
|
18 |
+
cls = getattr(importlib.import_module(module_name), class_name)
|
19 |
+
model = cls(config)
|
20 |
+
logging.info(model)
|
21 |
+
model = model.eval()
|
22 |
+
return model
|
23 |
+
|
24 |
+
def preprocess(img, width, height):
|
25 |
+
img = cv2.resize(np.array(img), (width, height))
|
26 |
+
img = transforms.ToTensor()(img).unsqueeze(0)
|
27 |
+
mean = torch.tensor([0.485, 0.456, 0.406])
|
28 |
+
std = torch.tensor([0.229, 0.224, 0.225])
|
29 |
+
return (img-mean[...,None,None]) / std[...,None,None]
|
30 |
+
|
31 |
+
def postprocess(output, charset, model_eval):
|
32 |
+
def _get_output(last_output, model_eval):
|
33 |
+
if isinstance(last_output, (tuple, list)):
|
34 |
+
for res in last_output:
|
35 |
+
if res['name'] == model_eval: output = res
|
36 |
+
else: output = last_output
|
37 |
+
return output
|
38 |
+
|
39 |
+
def _decode(logit):
|
40 |
+
""" Greed decode """
|
41 |
+
out = F.softmax(logit, dim=2)
|
42 |
+
pt_text, pt_scores, pt_lengths = [], [], []
|
43 |
+
for o in out:
|
44 |
+
text = charset.get_text(o.argmax(dim=1), padding=False, trim=False)
|
45 |
+
text = text.split(charset.null_char)[0] # end at end-token
|
46 |
+
pt_text.append(text)
|
47 |
+
pt_scores.append(o.max(dim=1)[0])
|
48 |
+
pt_lengths.append(min(len(text) + 1, charset.max_length)) # one for end-token
|
49 |
+
return pt_text, pt_scores, pt_lengths
|
50 |
+
|
51 |
+
output = _get_output(output, model_eval)
|
52 |
+
logits, pt_lengths = output['logits'], output['pt_lengths']
|
53 |
+
pt_text, pt_scores, pt_lengths_ = _decode(logits)
|
54 |
+
|
55 |
+
return pt_text, pt_scores, pt_lengths_
|
56 |
+
|
57 |
+
def load(model, file, device=None, strict=True):
|
58 |
+
if device is None: device = 'cpu'
|
59 |
+
elif isinstance(device, int): device = torch.device('cuda', device)
|
60 |
+
assert os.path.isfile(file)
|
61 |
+
state = torch.load(file, map_location=device)
|
62 |
+
if set(state.keys()) == {'model', 'opt'}:
|
63 |
+
state = state['model']
|
64 |
+
model.load_state_dict(state, strict=strict)
|
65 |
+
return model
|
66 |
+
|
67 |
+
|
68 |
+
def main():
|
69 |
+
parser = argparse.ArgumentParser()
|
70 |
+
parser.add_argument('--config', type=str, default='configs/train_abinet.yaml',
|
71 |
+
help='path to config file')
|
72 |
+
parser.add_argument('--input', type=str, default='figs/test')
|
73 |
+
parser.add_argument('--cuda', type=int, default=-1)
|
74 |
+
parser.add_argument('--checkpoint', type=str, default='workdir/train-abinet/best-train-abinet.pth')
|
75 |
+
parser.add_argument('--model_eval', type=str, default='alignment',
|
76 |
+
choices=['alignment', 'vision', 'language'])
|
77 |
+
args = parser.parse_args()
|
78 |
+
config = Config(args.config)
|
79 |
+
if args.checkpoint is not None: config.model_checkpoint = args.checkpoint
|
80 |
+
if args.model_eval is not None: config.model_eval = args.model_eval
|
81 |
+
config.global_phase = 'test'
|
82 |
+
config.model_vision_checkpoint, config.model_language_checkpoint = None, None
|
83 |
+
device = 'cpu' if args.cuda < 0 else f'cuda:{args.cuda}'
|
84 |
+
|
85 |
+
Logger.init(config.global_workdir, config.global_name, config.global_phase)
|
86 |
+
Logger.enable_file()
|
87 |
+
logging.info(config)
|
88 |
+
|
89 |
+
logging.info('Construct model.')
|
90 |
+
model = get_model(config).to(device)
|
91 |
+
model = load(model, config.model_checkpoint, device=device)
|
92 |
+
charset = CharsetMapper(filename=config.dataset_charset_path,
|
93 |
+
max_length=config.dataset_max_length + 1)
|
94 |
+
|
95 |
+
if os.path.isdir(args.input):
|
96 |
+
paths = [os.path.join(args.input, fname) for fname in os.listdir(args.input)]
|
97 |
+
else:
|
98 |
+
paths = glob.glob(os.path.expanduser(args.input))
|
99 |
+
assert paths, "The input path(s) was not found"
|
100 |
+
paths = sorted(paths)
|
101 |
+
|
102 |
+
|
103 |
+
count = 0
|
104 |
+
checks = 0
|
105 |
+
print(tqdm.tqdm(paths))
|
106 |
+
for path in tqdm.tqdm(paths):
|
107 |
+
img = PIL.Image.open(path).convert('RGB')
|
108 |
+
img = preprocess(img, config.dataset_image_width, config.dataset_image_height)
|
109 |
+
img = img.to(device)
|
110 |
+
res = model(img)
|
111 |
+
pt_text, _, __ = postprocess(res, charset, config.model_eval)
|
112 |
+
a = re.findall(r'(\d{6}).png', path)[0]
|
113 |
+
# print(a)
|
114 |
+
# print(pt_text[0], "Lol")
|
115 |
+
# a = re.findall(r'base/(.*).pn', path)[0]
|
116 |
+
checks += 1
|
117 |
+
if a.lower() != pt_text[0].lower():
|
118 |
+
count += 1
|
119 |
+
print(f'label:{a.lower()} ||| guess:{pt_text[0]} ||| count_fails:{str(count)}/{str(checks)}')
|
120 |
+
|
121 |
+
if __name__ == '__main__':
|
122 |
+
main()
|
callbacks.py
ADDED
@@ -0,0 +1,360 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import shutil
|
3 |
+
import time
|
4 |
+
|
5 |
+
import editdistance as ed
|
6 |
+
import torchvision.utils as vutils
|
7 |
+
from fastai.callbacks.tensorboard import (LearnerTensorboardWriter,
|
8 |
+
SummaryWriter, TBWriteRequest,
|
9 |
+
asyncTBWriter)
|
10 |
+
from fastai.vision import *
|
11 |
+
from torch.nn.parallel import DistributedDataParallel
|
12 |
+
from torchvision import transforms
|
13 |
+
|
14 |
+
import dataset
|
15 |
+
from utils import CharsetMapper, Timer, blend_mask
|
16 |
+
|
17 |
+
|
18 |
+
class IterationCallback(LearnerTensorboardWriter):
|
19 |
+
"A `TrackerCallback` that monitor in each iteration."
|
20 |
+
def __init__(self, learn:Learner, name:str='model', checpoint_keep_num=5,
|
21 |
+
show_iters:int=50, eval_iters:int=1000, save_iters:int=20000,
|
22 |
+
start_iters:int=0, stats_iters=20000):
|
23 |
+
#if self.learn.rank is not None: time.sleep(self.learn.rank) # keep all event files
|
24 |
+
super().__init__(learn, base_dir='.', name=learn.path, loss_iters=show_iters,
|
25 |
+
stats_iters=stats_iters, hist_iters=stats_iters)
|
26 |
+
self.name, self.bestname = Path(name).name, f'best-{Path(name).name}'
|
27 |
+
self.show_iters = show_iters
|
28 |
+
self.eval_iters = eval_iters
|
29 |
+
self.save_iters = save_iters
|
30 |
+
self.start_iters = start_iters
|
31 |
+
self.checpoint_keep_num = checpoint_keep_num
|
32 |
+
self.metrics_root = 'metrics/' # rewrite
|
33 |
+
self.timer = Timer()
|
34 |
+
self.host = self.learn.rank is None or self.learn.rank == 0
|
35 |
+
|
36 |
+
def _write_metrics(self, iteration:int, names:List[str], last_metrics:MetricsList)->None:
|
37 |
+
"Writes training metrics to Tensorboard."
|
38 |
+
for i, name in enumerate(names):
|
39 |
+
if last_metrics is None or len(last_metrics) < i+1: return
|
40 |
+
scalar_value = last_metrics[i]
|
41 |
+
self._write_scalar(name=name, scalar_value=scalar_value, iteration=iteration)
|
42 |
+
|
43 |
+
def _write_sub_loss(self, iteration:int, last_losses:dict)->None:
|
44 |
+
"Writes sub loss to Tensorboard."
|
45 |
+
for name, loss in last_losses.items():
|
46 |
+
scalar_value = to_np(loss)
|
47 |
+
tag = self.metrics_root + name
|
48 |
+
self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=iteration)
|
49 |
+
|
50 |
+
def _save(self, name):
|
51 |
+
if isinstance(self.learn.model, DistributedDataParallel):
|
52 |
+
tmp = self.learn.model
|
53 |
+
self.learn.model = self.learn.model.module
|
54 |
+
self.learn.save(name)
|
55 |
+
self.learn.model = tmp
|
56 |
+
else: self.learn.save(name)
|
57 |
+
|
58 |
+
def _validate(self, dl=None, callbacks=None, metrics=None, keeped_items=False):
|
59 |
+
"Validate on `dl` with potential `callbacks` and `metrics`."
|
60 |
+
dl = ifnone(dl, self.learn.data.valid_dl)
|
61 |
+
metrics = ifnone(metrics, self.learn.metrics)
|
62 |
+
cb_handler = CallbackHandler(ifnone(callbacks, []), metrics)
|
63 |
+
cb_handler.on_train_begin(1, None, metrics); cb_handler.on_epoch_begin()
|
64 |
+
if keeped_items: cb_handler.state_dict.update(dict(keeped_items=[]))
|
65 |
+
val_metrics = validate(self.learn.model, dl, self.loss_func, cb_handler)
|
66 |
+
cb_handler.on_epoch_end(val_metrics)
|
67 |
+
if keeped_items: return cb_handler.state_dict['keeped_items']
|
68 |
+
else: return cb_handler.state_dict['last_metrics']
|
69 |
+
|
70 |
+
def jump_to_epoch_iter(self, epoch:int, iteration:int)->None:
|
71 |
+
try:
|
72 |
+
self.learn.load(f'{self.name}_{epoch}_{iteration}', purge=False)
|
73 |
+
logging.info(f'Loaded {self.name}_{epoch}_{iteration}')
|
74 |
+
except: logging.info(f'Model {self.name}_{epoch}_{iteration} not found.')
|
75 |
+
|
76 |
+
def on_train_begin(self, n_epochs, **kwargs):
|
77 |
+
# TODO: can not write graph here
|
78 |
+
# super().on_train_begin(**kwargs)
|
79 |
+
self.best = -float('inf')
|
80 |
+
self.timer.tic()
|
81 |
+
if self.host:
|
82 |
+
checkpoint_path = self.learn.path/'checkpoint.yaml'
|
83 |
+
if checkpoint_path.exists():
|
84 |
+
os.remove(checkpoint_path)
|
85 |
+
open(checkpoint_path, 'w').close()
|
86 |
+
return {'skip_validate': True, 'iteration':self.start_iters} # disable default validate
|
87 |
+
|
88 |
+
def on_batch_begin(self, **kwargs:Any)->None:
|
89 |
+
self.timer.toc_data()
|
90 |
+
super().on_batch_begin(**kwargs)
|
91 |
+
|
92 |
+
def on_batch_end(self, iteration, epoch, last_loss, smooth_loss, train, **kwargs):
|
93 |
+
super().on_batch_end(last_loss, iteration, train, **kwargs)
|
94 |
+
if iteration == 0: return
|
95 |
+
|
96 |
+
if iteration % self.loss_iters == 0:
|
97 |
+
last_losses = self.learn.loss_func.last_losses
|
98 |
+
self._write_sub_loss(iteration=iteration, last_losses=last_losses)
|
99 |
+
self.tbwriter.add_scalar(tag=self.metrics_root + 'lr',
|
100 |
+
scalar_value=self.opt.lr, global_step=iteration)
|
101 |
+
|
102 |
+
if iteration % self.show_iters == 0:
|
103 |
+
log_str = f'epoch {epoch} iter {iteration}: loss = {last_loss:6.4f}, ' \
|
104 |
+
f'smooth loss = {smooth_loss:6.4f}'
|
105 |
+
logging.info(log_str)
|
106 |
+
# log_str = f'data time = {self.timer.data_diff:.4f}s, runing time = {self.timer.running_diff:.4f}s'
|
107 |
+
# logging.info(log_str)
|
108 |
+
|
109 |
+
if iteration % self.eval_iters == 0:
|
110 |
+
# TODO: or remove time to on_epoch_end
|
111 |
+
# 1. Record time
|
112 |
+
log_str = f'average data time = {self.timer.average_data_time():.4f}s, ' \
|
113 |
+
f'average running time = {self.timer.average_running_time():.4f}s'
|
114 |
+
logging.info(log_str)
|
115 |
+
|
116 |
+
# 2. Call validate
|
117 |
+
last_metrics = self._validate()
|
118 |
+
self.learn.model.train()
|
119 |
+
log_str = f'epoch {epoch} iter {iteration}: eval loss = {last_metrics[0]:6.4f}, ' \
|
120 |
+
f'ccr = {last_metrics[1]:6.4f}, cwr = {last_metrics[2]:6.4f}, ' \
|
121 |
+
f'ted = {last_metrics[3]:6.4f}, ned = {last_metrics[4]:6.4f}, ' \
|
122 |
+
f'ted/w = {last_metrics[5]:6.4f}, '
|
123 |
+
logging.info(log_str)
|
124 |
+
names = ['eval_loss', 'ccr', 'cwr', 'ted', 'ned', 'ted/w']
|
125 |
+
self._write_metrics(iteration, names, last_metrics)
|
126 |
+
|
127 |
+
# 3. Save best model
|
128 |
+
current = last_metrics[2]
|
129 |
+
if current is not None and current > self.best:
|
130 |
+
logging.info(f'Better model found at epoch {epoch}, '\
|
131 |
+
f'iter {iteration} with accuracy value: {current:6.4f}.')
|
132 |
+
self.best = current
|
133 |
+
self._save(f'{self.bestname}')
|
134 |
+
|
135 |
+
if iteration % self.save_iters == 0 and self.host:
|
136 |
+
logging.info(f'Save model {self.name}_{epoch}_{iteration}')
|
137 |
+
filename = f'{self.name}_{epoch}_{iteration}'
|
138 |
+
self._save(filename)
|
139 |
+
|
140 |
+
checkpoint_path = self.learn.path/'checkpoint.yaml'
|
141 |
+
if not checkpoint_path.exists():
|
142 |
+
open(checkpoint_path, 'w').close()
|
143 |
+
with open(checkpoint_path, 'r') as file:
|
144 |
+
checkpoints = yaml.load(file, Loader=yaml.FullLoader) or dict()
|
145 |
+
checkpoints['all_checkpoints'] = (
|
146 |
+
checkpoints.get('all_checkpoints') or list())
|
147 |
+
checkpoints['all_checkpoints'].insert(0, filename)
|
148 |
+
if len(checkpoints['all_checkpoints']) > self.checpoint_keep_num:
|
149 |
+
removed_checkpoint = checkpoints['all_checkpoints'].pop()
|
150 |
+
removed_checkpoint = self.learn.path/self.learn.model_dir/f'{removed_checkpoint}.pth'
|
151 |
+
os.remove(removed_checkpoint)
|
152 |
+
checkpoints['current_checkpoint'] = filename
|
153 |
+
with open(checkpoint_path, 'w') as file:
|
154 |
+
yaml.dump(checkpoints, file)
|
155 |
+
|
156 |
+
|
157 |
+
self.timer.toc_running()
|
158 |
+
|
159 |
+
def on_train_end(self, **kwargs):
|
160 |
+
#self.learn.load(f'{self.bestname}', purge=False)
|
161 |
+
pass
|
162 |
+
|
163 |
+
def on_epoch_end(self, last_metrics:MetricsList, iteration:int, **kwargs)->None:
|
164 |
+
self._write_embedding(iteration=iteration)
|
165 |
+
|
166 |
+
|
167 |
+
class TextAccuracy(Callback):
|
168 |
+
_names = ['ccr', 'cwr', 'ted', 'ned', 'ted/w']
|
169 |
+
def __init__(self, charset_path, max_length, case_sensitive, model_eval):
|
170 |
+
self.charset_path = charset_path
|
171 |
+
self.max_length = max_length
|
172 |
+
self.case_sensitive = case_sensitive
|
173 |
+
self.charset = CharsetMapper(charset_path, self.max_length)
|
174 |
+
self.names = self._names
|
175 |
+
|
176 |
+
self.model_eval = model_eval or 'alignment'
|
177 |
+
assert self.model_eval in ['vision', 'language', 'alignment']
|
178 |
+
|
179 |
+
def on_epoch_begin(self, **kwargs):
|
180 |
+
self.total_num_char = 0.
|
181 |
+
self.total_num_word = 0.
|
182 |
+
self.correct_num_char = 0.
|
183 |
+
self.correct_num_word = 0.
|
184 |
+
self.total_ed = 0.
|
185 |
+
self.total_ned = 0.
|
186 |
+
|
187 |
+
def _get_output(self, last_output):
|
188 |
+
if isinstance(last_output, (tuple, list)):
|
189 |
+
for res in last_output:
|
190 |
+
if res['name'] == self.model_eval: output = res
|
191 |
+
else: output = last_output
|
192 |
+
return output
|
193 |
+
|
194 |
+
def _update_output(self, last_output, items):
|
195 |
+
if isinstance(last_output, (tuple, list)):
|
196 |
+
for res in last_output:
|
197 |
+
if res['name'] == self.model_eval: res.update(items)
|
198 |
+
else: last_output.update(items)
|
199 |
+
return last_output
|
200 |
+
|
201 |
+
def on_batch_end(self, last_output, last_target, **kwargs):
|
202 |
+
output = self._get_output(last_output)
|
203 |
+
logits, pt_lengths = output['logits'], output['pt_lengths']
|
204 |
+
pt_text, pt_scores, pt_lengths_ = self.decode(logits)
|
205 |
+
assert (pt_lengths == pt_lengths_).all(), f'{pt_lengths} != {pt_lengths_} for {pt_text}'
|
206 |
+
last_output = self._update_output(last_output, {'pt_text':pt_text, 'pt_scores':pt_scores})
|
207 |
+
|
208 |
+
pt_text = [self.charset.trim(t) for t in pt_text]
|
209 |
+
label = last_target[0]
|
210 |
+
if label.dim() == 3: label = label.argmax(dim=-1) # one-hot label
|
211 |
+
gt_text = [self.charset.get_text(l, trim=True) for l in label]
|
212 |
+
|
213 |
+
for i in range(len(gt_text)):
|
214 |
+
if not self.case_sensitive:
|
215 |
+
gt_text[i], pt_text[i] = gt_text[i].lower(), pt_text[i].lower()
|
216 |
+
distance = ed.eval(gt_text[i], pt_text[i])
|
217 |
+
self.total_ed += distance
|
218 |
+
self.total_ned += float(distance) / max(len(gt_text[i]), 1)
|
219 |
+
|
220 |
+
if gt_text[i] == pt_text[i]:
|
221 |
+
self.correct_num_word += 1
|
222 |
+
self.total_num_word += 1
|
223 |
+
|
224 |
+
for j in range(min(len(gt_text[i]), len(pt_text[i]))):
|
225 |
+
if gt_text[i][j] == pt_text[i][j]:
|
226 |
+
self.correct_num_char += 1
|
227 |
+
self.total_num_char += len(gt_text[i])
|
228 |
+
|
229 |
+
return {'last_output': last_output}
|
230 |
+
|
231 |
+
def on_epoch_end(self, last_metrics, **kwargs):
|
232 |
+
mets = [self.correct_num_char / self.total_num_char,
|
233 |
+
self.correct_num_word / self.total_num_word,
|
234 |
+
self.total_ed,
|
235 |
+
self.total_ned,
|
236 |
+
self.total_ed / self.total_num_word]
|
237 |
+
return add_metrics(last_metrics, mets)
|
238 |
+
|
239 |
+
def decode(self, logit):
|
240 |
+
""" Greed decode """
|
241 |
+
# TODO: test running time and decode on GPU
|
242 |
+
out = F.softmax(logit, dim=2)
|
243 |
+
pt_text, pt_scores, pt_lengths = [], [], []
|
244 |
+
for o in out:
|
245 |
+
text = self.charset.get_text(o.argmax(dim=1), padding=False, trim=False)
|
246 |
+
text = text.split(self.charset.null_char)[0] # end at end-token
|
247 |
+
pt_text.append(text)
|
248 |
+
pt_scores.append(o.max(dim=1)[0])
|
249 |
+
pt_lengths.append(min(len(text) + 1, self.max_length)) # one for end-token
|
250 |
+
pt_scores = torch.stack(pt_scores)
|
251 |
+
pt_lengths = pt_scores.new_tensor(pt_lengths, dtype=torch.long)
|
252 |
+
return pt_text, pt_scores, pt_lengths
|
253 |
+
|
254 |
+
|
255 |
+
class TopKTextAccuracy(TextAccuracy):
|
256 |
+
_names = ['ccr', 'cwr']
|
257 |
+
def __init__(self, k, charset_path, max_length, case_sensitive, model_eval):
|
258 |
+
self.k = k
|
259 |
+
self.charset_path = charset_path
|
260 |
+
self.max_length = max_length
|
261 |
+
self.case_sensitive = case_sensitive
|
262 |
+
self.charset = CharsetMapper(charset_path, self.max_length)
|
263 |
+
self.names = self._names
|
264 |
+
|
265 |
+
def on_epoch_begin(self, **kwargs):
|
266 |
+
self.total_num_char = 0.
|
267 |
+
self.total_num_word = 0.
|
268 |
+
self.correct_num_char = 0.
|
269 |
+
self.correct_num_word = 0.
|
270 |
+
|
271 |
+
def on_batch_end(self, last_output, last_target, **kwargs):
|
272 |
+
logits, pt_lengths = last_output['logits'], last_output['pt_lengths']
|
273 |
+
gt_labels, gt_lengths = last_target[:]
|
274 |
+
|
275 |
+
for logit, pt_length, label, length in zip(logits, pt_lengths, gt_labels, gt_lengths):
|
276 |
+
word_flag = True
|
277 |
+
for i in range(length):
|
278 |
+
char_logit = logit[i].topk(self.k)[1]
|
279 |
+
char_label = label[i].argmax(-1)
|
280 |
+
if char_label in char_logit: self.correct_num_char += 1
|
281 |
+
else: word_flag = False
|
282 |
+
self.total_num_char += 1
|
283 |
+
if pt_length == length and word_flag:
|
284 |
+
self.correct_num_word += 1
|
285 |
+
self.total_num_word += 1
|
286 |
+
|
287 |
+
def on_epoch_end(self, last_metrics, **kwargs):
|
288 |
+
mets = [self.correct_num_char / self.total_num_char,
|
289 |
+
self.correct_num_word / self.total_num_word,
|
290 |
+
0., 0., 0.]
|
291 |
+
return add_metrics(last_metrics, mets)
|
292 |
+
|
293 |
+
|
294 |
+
class DumpPrediction(LearnerCallback):
|
295 |
+
|
296 |
+
def __init__(self, learn, dataset, charset_path, model_eval, image_only=False, debug=False):
|
297 |
+
super().__init__(learn=learn)
|
298 |
+
self.debug = debug
|
299 |
+
self.model_eval = model_eval or 'alignment'
|
300 |
+
self.image_only = image_only
|
301 |
+
assert self.model_eval in ['vision', 'language', 'alignment']
|
302 |
+
|
303 |
+
self.dataset, self.root = dataset, Path(self.learn.path)/f'{dataset}-{self.model_eval}'
|
304 |
+
self.attn_root = self.root/'attn'
|
305 |
+
self.charset = CharsetMapper(charset_path)
|
306 |
+
if self.root.exists(): shutil.rmtree(self.root)
|
307 |
+
self.root.mkdir(), self.attn_root.mkdir()
|
308 |
+
|
309 |
+
self.pil = transforms.ToPILImage()
|
310 |
+
self.tensor = transforms.ToTensor()
|
311 |
+
size = self.learn.data.img_h, self.learn.data.img_w
|
312 |
+
self.resize = transforms.Resize(size=size, interpolation=0)
|
313 |
+
self.c = 0
|
314 |
+
|
315 |
+
def on_batch_end(self, last_input, last_output, last_target, **kwargs):
|
316 |
+
if isinstance(last_output, (tuple, list)):
|
317 |
+
for res in last_output:
|
318 |
+
if res['name'] == self.model_eval: pt_text = res['pt_text']
|
319 |
+
if res['name'] == 'vision': attn_scores = res['attn_scores'].detach().cpu()
|
320 |
+
if res['name'] == self.model_eval: logits = res['logits']
|
321 |
+
else:
|
322 |
+
pt_text = last_output['pt_text']
|
323 |
+
attn_scores = last_output['attn_scores'].detach().cpu()
|
324 |
+
logits = last_output['logits']
|
325 |
+
|
326 |
+
images = last_input[0] if isinstance(last_input, (tuple, list)) else last_input
|
327 |
+
images = images.detach().cpu()
|
328 |
+
pt_text = [self.charset.trim(t) for t in pt_text]
|
329 |
+
gt_label = last_target[0]
|
330 |
+
if gt_label.dim() == 3: gt_label = gt_label.argmax(dim=-1) # one-hot label
|
331 |
+
gt_text = [self.charset.get_text(l, trim=True) for l in gt_label]
|
332 |
+
|
333 |
+
prediction, false_prediction = [], []
|
334 |
+
for gt, pt, image, attn, logit in zip(gt_text, pt_text, images, attn_scores, logits):
|
335 |
+
prediction.append(f'{gt}\t{pt}\n')
|
336 |
+
if gt != pt:
|
337 |
+
if self.debug:
|
338 |
+
scores = torch.softmax(logit, dim=-1)[:max(len(pt), len(gt)) + 1]
|
339 |
+
logging.info(f'{self.c} gt {gt}, pt {pt}, logit {logit.shape}, scores {scores.topk(5, dim=-1)}')
|
340 |
+
false_prediction.append(f'{gt}\t{pt}\n')
|
341 |
+
|
342 |
+
image = self.learn.data.denorm(image)
|
343 |
+
if not self.image_only:
|
344 |
+
image_np = np.array(self.pil(image))
|
345 |
+
attn_pil = [self.pil(a) for a in attn[:, None, :, :]]
|
346 |
+
attn = [self.tensor(self.resize(a)).repeat(3, 1, 1) for a in attn_pil]
|
347 |
+
attn_sum = np.array([np.array(a) for a in attn_pil[:len(pt)]]).sum(axis=0)
|
348 |
+
blended_sum = self.tensor(blend_mask(image_np, attn_sum))
|
349 |
+
blended = [self.tensor(blend_mask(image_np, np.array(a))) for a in attn_pil]
|
350 |
+
save_image = torch.stack([image] + attn + [blended_sum] + blended)
|
351 |
+
save_image = save_image.view(2, -1, *save_image.shape[1:])
|
352 |
+
save_image = save_image.permute(1, 0, 2, 3, 4).flatten(0, 1)
|
353 |
+
vutils.save_image(save_image, self.attn_root/f'{self.c}_{gt}_{pt}.jpg',
|
354 |
+
nrow=2, normalize=True, scale_each=True)
|
355 |
+
else:
|
356 |
+
self.pil(image).save(self.attn_root/f'{self.c}_{gt}_{pt}.jpg')
|
357 |
+
self.c += 1
|
358 |
+
|
359 |
+
with open(self.root/f'{self.model_eval}.txt', 'a') as f: f.writelines(prediction)
|
360 |
+
with open(self.root/f'{self.model_eval}-false.txt', 'a') as f: f.writelines(false_prediction)
|
commands.sh
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
python tools/create_lmdb_dataset.py
|
2 |
+
python demo.py --config=configs/train_abinet.yaml --input=base/
|
3 |
+
python main.py --config=configs/train_abinet.yaml
|
config.yaml
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
global:
|
2 |
+
name: train-abinet
|
3 |
+
phase: train
|
4 |
+
stage: train-super
|
5 |
+
workdir: workdir
|
6 |
+
seed: ~
|
7 |
+
|
8 |
+
dataset:
|
9 |
+
train: {
|
10 |
+
roots: [
|
11 |
+
'output_tbell_dataset2/',
|
12 |
+
# 'data/training/MJ/MJ_train/',
|
13 |
+
# 'data/training/MJ/MJ_test/',
|
14 |
+
# 'data/training/MJ/MJ_valid/',
|
15 |
+
# 'data/training/ST'
|
16 |
+
],
|
17 |
+
batch_size: 20
|
18 |
+
}
|
19 |
+
test: {
|
20 |
+
roots: [
|
21 |
+
'output_tbell_dataset2/'
|
22 |
+
],
|
23 |
+
batch_size: 2
|
24 |
+
}
|
25 |
+
data_aug: True
|
26 |
+
multiscales: False
|
27 |
+
num_workers: 1
|
28 |
+
|
29 |
+
training:
|
30 |
+
epochs: 100000
|
31 |
+
show_iters: 50
|
32 |
+
eval_iters: 600
|
33 |
+
# save_iters: 3000
|
34 |
+
|
35 |
+
optimizer:
|
36 |
+
type: AdamW
|
37 |
+
true_wd: False
|
38 |
+
wd: 0.0
|
39 |
+
bn_wd: False
|
40 |
+
clip_grad: 20
|
41 |
+
lr: 0.0001
|
42 |
+
args: {
|
43 |
+
betas: !!python/tuple [0.9, 0.999], # for default Adam
|
44 |
+
}
|
45 |
+
scheduler: {
|
46 |
+
periods: [6, 4],
|
47 |
+
gamma: 0.1,
|
48 |
+
}
|
49 |
+
|
50 |
+
model:
|
51 |
+
name: 'modules.model_abinet_iter.ABINetIterModel'
|
52 |
+
iter_size: 3
|
53 |
+
ensemble: ''
|
54 |
+
use_vision: True
|
55 |
+
vision: {
|
56 |
+
checkpoint: workdir/pretrain-vision-model/best-pretrain-vision-model.pth,
|
57 |
+
loss_weight: 1.,
|
58 |
+
attention: 'position',
|
59 |
+
backbone: 'transformer',
|
60 |
+
backbone_ln: 3,
|
61 |
+
}
|
62 |
+
language: {
|
63 |
+
checkpoint: workdir/pretrain-language-model/pretrain-language-model.pth,
|
64 |
+
num_layers: 4,
|
65 |
+
loss_weight: 1.,
|
66 |
+
detach: True,
|
67 |
+
use_self_attn: False
|
68 |
+
}
|
69 |
+
alignment: {
|
70 |
+
loss_weight: 1.,
|
71 |
+
}
|
dataset.py
ADDED
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import re
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import lmdb
|
6 |
+
import six
|
7 |
+
from fastai.vision import *
|
8 |
+
from torchvision import transforms
|
9 |
+
|
10 |
+
from transforms import CVColorJitter, CVDeterioration, CVGeometry
|
11 |
+
from utils import CharsetMapper, onehot
|
12 |
+
|
13 |
+
|
14 |
+
class ImageDataset(Dataset):
|
15 |
+
"`ImageDataset` read data from LMDB database."
|
16 |
+
|
17 |
+
def __init__(self,
|
18 |
+
path:PathOrStr,
|
19 |
+
is_training:bool=True,
|
20 |
+
img_h:int=32,
|
21 |
+
img_w:int=100,
|
22 |
+
max_length:int=25,
|
23 |
+
check_length:bool=True,
|
24 |
+
case_sensitive:bool=False,
|
25 |
+
charset_path:str='data/charset_36.txt',
|
26 |
+
convert_mode:str='RGB',
|
27 |
+
data_aug:bool=True,
|
28 |
+
deteriorate_ratio:float=0.,
|
29 |
+
multiscales:bool=True,
|
30 |
+
one_hot_y:bool=True,
|
31 |
+
return_idx:bool=False,
|
32 |
+
return_raw:bool=False,
|
33 |
+
**kwargs):
|
34 |
+
self.path, self.name = Path(path), Path(path).name
|
35 |
+
assert self.path.is_dir() and self.path.exists(), f"{path} is not a valid directory."
|
36 |
+
self.convert_mode, self.check_length = convert_mode, check_length
|
37 |
+
self.img_h, self.img_w = img_h, img_w
|
38 |
+
self.max_length, self.one_hot_y = max_length, one_hot_y
|
39 |
+
self.return_idx, self.return_raw = return_idx, return_raw
|
40 |
+
self.case_sensitive, self.is_training = case_sensitive, is_training
|
41 |
+
self.data_aug, self.multiscales = data_aug, multiscales
|
42 |
+
self.charset = CharsetMapper(charset_path, max_length=max_length+1)
|
43 |
+
self.c = self.charset.num_classes
|
44 |
+
|
45 |
+
self.env = lmdb.open(str(path), readonly=True, lock=False, readahead=False, meminit=False)
|
46 |
+
assert self.env, f'Cannot open LMDB dataset from {path}.'
|
47 |
+
with self.env.begin(write=False) as txn:
|
48 |
+
self.length = int(txn.get('num-samples'.encode()))
|
49 |
+
|
50 |
+
if self.is_training and self.data_aug:
|
51 |
+
self.augment_tfs = transforms.Compose([
|
52 |
+
CVGeometry(degrees=45, translate=(0.0, 0.0), scale=(0.5, 2.), shear=(45, 15), distortion=0.5, p=0.5),
|
53 |
+
CVDeterioration(var=20, degrees=6, factor=4, p=0.25),
|
54 |
+
CVColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, p=0.25)
|
55 |
+
])
|
56 |
+
self.totensor = transforms.ToTensor()
|
57 |
+
|
58 |
+
def __len__(self): return self.length
|
59 |
+
|
60 |
+
def _next_image(self, index):
|
61 |
+
next_index = random.randint(0, len(self) - 1)
|
62 |
+
return self.get(next_index)
|
63 |
+
|
64 |
+
def _check_image(self, x, pixels=6):
|
65 |
+
if x.size[0] <= pixels or x.size[1] <= pixels: return False
|
66 |
+
else: return True
|
67 |
+
|
68 |
+
def resize_multiscales(self, img, borderType=cv2.BORDER_CONSTANT):
|
69 |
+
def _resize_ratio(img, ratio, fix_h=True):
|
70 |
+
if ratio * self.img_w < self.img_h:
|
71 |
+
if fix_h: trg_h = self.img_h
|
72 |
+
else: trg_h = int(ratio * self.img_w)
|
73 |
+
trg_w = self.img_w
|
74 |
+
else: trg_h, trg_w = self.img_h, int(self.img_h / ratio)
|
75 |
+
img = cv2.resize(img, (trg_w, trg_h))
|
76 |
+
pad_h, pad_w = (self.img_h - trg_h) / 2, (self.img_w - trg_w) / 2
|
77 |
+
top, bottom = math.ceil(pad_h), math.floor(pad_h)
|
78 |
+
left, right = math.ceil(pad_w), math.floor(pad_w)
|
79 |
+
img = cv2.copyMakeBorder(img, top, bottom, left, right, borderType)
|
80 |
+
return img
|
81 |
+
|
82 |
+
if self.is_training:
|
83 |
+
if random.random() < 0.5:
|
84 |
+
base, maxh, maxw = self.img_h, self.img_h, self.img_w
|
85 |
+
h, w = random.randint(base, maxh), random.randint(base, maxw)
|
86 |
+
return _resize_ratio(img, h/w)
|
87 |
+
else: return _resize_ratio(img, img.shape[0] / img.shape[1]) # keep aspect ratio
|
88 |
+
else: return _resize_ratio(img, img.shape[0] / img.shape[1]) # keep aspect ratio
|
89 |
+
|
90 |
+
def resize(self, img):
|
91 |
+
if self.multiscales: return self.resize_multiscales(img, cv2.BORDER_REPLICATE)
|
92 |
+
else: return cv2.resize(img, (self.img_w, self.img_h))
|
93 |
+
|
94 |
+
def get(self, idx):
|
95 |
+
with self.env.begin(write=False) as txn:
|
96 |
+
image_key, label_key = f'image-{idx+1:09d}', f'label-{idx+1:09d}'
|
97 |
+
try:
|
98 |
+
label = str(txn.get(label_key.encode()), 'utf-8') # label
|
99 |
+
label = re.sub('[^0-9a-zA-Z]+', '', label)
|
100 |
+
if self.check_length and self.max_length > 0:
|
101 |
+
if len(label) > self.max_length or len(label) <= 0:
|
102 |
+
#logging.info(f'Long or short text image is found: {self.name}, {idx}, {label}, {len(label)}')
|
103 |
+
return self._next_image(idx)
|
104 |
+
label = label[:self.max_length]
|
105 |
+
|
106 |
+
imgbuf = txn.get(image_key.encode()) # image
|
107 |
+
buf = six.BytesIO()
|
108 |
+
buf.write(imgbuf)
|
109 |
+
buf.seek(0)
|
110 |
+
with warnings.catch_warnings():
|
111 |
+
warnings.simplefilter("ignore", UserWarning) # EXIF warning from TiffPlugin
|
112 |
+
image = PIL.Image.open(buf).convert(self.convert_mode)
|
113 |
+
if self.is_training and not self._check_image(image):
|
114 |
+
#logging.info(f'Invalid image is found: {self.name}, {idx}, {label}, {len(label)}')
|
115 |
+
return self._next_image(idx)
|
116 |
+
except:
|
117 |
+
import traceback
|
118 |
+
traceback.print_exc()
|
119 |
+
logging.info(f'Corrupted image is found: {self.name}, {idx}, {label}, {len(label)}')
|
120 |
+
return self._next_image(idx)
|
121 |
+
return image, label, idx
|
122 |
+
|
123 |
+
def _process_training(self, image):
|
124 |
+
if self.data_aug: image = self.augment_tfs(image)
|
125 |
+
image = self.resize(np.array(image))
|
126 |
+
return image
|
127 |
+
|
128 |
+
def _process_test(self, image):
|
129 |
+
return self.resize(np.array(image)) # TODO:move is_training to here
|
130 |
+
|
131 |
+
def __getitem__(self, idx):
|
132 |
+
image, text, idx_new = self.get(idx)
|
133 |
+
if not self.is_training: assert idx == idx_new, f'idx {idx} != idx_new {idx_new} during testing.'
|
134 |
+
|
135 |
+
if self.is_training: image = self._process_training(image)
|
136 |
+
else: image = self._process_test(image)
|
137 |
+
if self.return_raw: return image, text
|
138 |
+
image = self.totensor(image)
|
139 |
+
|
140 |
+
length = tensor(len(text) + 1).to(dtype=torch.long) # one for end token
|
141 |
+
label = self.charset.get_labels(text, case_sensitive=self.case_sensitive)
|
142 |
+
label = tensor(label).to(dtype=torch.long)
|
143 |
+
if self.one_hot_y: label = onehot(label, self.charset.num_classes)
|
144 |
+
|
145 |
+
if self.return_idx: y = [label, length, idx_new]
|
146 |
+
else: y = [label, length]
|
147 |
+
return image, y
|
148 |
+
|
149 |
+
|
150 |
+
class TextDataset(Dataset):
|
151 |
+
def __init__(self,
|
152 |
+
path:PathOrStr,
|
153 |
+
delimiter:str='\t',
|
154 |
+
max_length:int=25,
|
155 |
+
charset_path:str='data/charset_36.txt',
|
156 |
+
case_sensitive=False,
|
157 |
+
one_hot_x=True,
|
158 |
+
one_hot_y=True,
|
159 |
+
is_training=True,
|
160 |
+
smooth_label=False,
|
161 |
+
smooth_factor=0.2,
|
162 |
+
use_sm=False,
|
163 |
+
**kwargs):
|
164 |
+
self.path = Path(path)
|
165 |
+
self.case_sensitive, self.use_sm = case_sensitive, use_sm
|
166 |
+
self.smooth_factor, self.smooth_label = smooth_factor, smooth_label
|
167 |
+
self.charset = CharsetMapper(charset_path, max_length=max_length+1)
|
168 |
+
self.one_hot_x, self.one_hot_y, self.is_training = one_hot_x, one_hot_y, is_training
|
169 |
+
if self.is_training and self.use_sm: self.sm = SpellingMutation(charset=self.charset)
|
170 |
+
|
171 |
+
dtype = {'inp': str, 'gt': str}
|
172 |
+
self.df = pd.read_csv(self.path, dtype=dtype, delimiter=delimiter, na_filter=False)
|
173 |
+
self.inp_col, self.gt_col = 0, 1
|
174 |
+
|
175 |
+
def __len__(self): return len(self.df)
|
176 |
+
|
177 |
+
def __getitem__(self, idx):
|
178 |
+
text_x = self.df.iloc[idx, self.inp_col]
|
179 |
+
text_x = re.sub('[^0-9a-zA-Z]+', '', text_x)
|
180 |
+
if not self.case_sensitive: text_x = text_x.lower()
|
181 |
+
if self.is_training and self.use_sm: text_x = self.sm(text_x)
|
182 |
+
|
183 |
+
length_x = tensor(len(text_x) + 1).to(dtype=torch.long) # one for end token
|
184 |
+
label_x = self.charset.get_labels(text_x, case_sensitive=self.case_sensitive)
|
185 |
+
label_x = tensor(label_x)
|
186 |
+
if self.one_hot_x:
|
187 |
+
label_x = onehot(label_x, self.charset.num_classes)
|
188 |
+
if self.is_training and self.smooth_label:
|
189 |
+
label_x = torch.stack([self.prob_smooth_label(l) for l in label_x])
|
190 |
+
x = [label_x, length_x]
|
191 |
+
|
192 |
+
text_y = self.df.iloc[idx, self.gt_col]
|
193 |
+
text_y = re.sub('[^0-9a-zA-Z]+', '', text_y)
|
194 |
+
if not self.case_sensitive: text_y = text_y.lower()
|
195 |
+
length_y = tensor(len(text_y) + 1).to(dtype=torch.long) # one for end token
|
196 |
+
label_y = self.charset.get_labels(text_y, case_sensitive=self.case_sensitive)
|
197 |
+
label_y = tensor(label_y)
|
198 |
+
if self.one_hot_y: label_y = onehot(label_y, self.charset.num_classes)
|
199 |
+
y = [label_y, length_y]
|
200 |
+
|
201 |
+
return x, y
|
202 |
+
|
203 |
+
def prob_smooth_label(self, one_hot):
|
204 |
+
one_hot = one_hot.float()
|
205 |
+
delta = torch.rand([]) * self.smooth_factor
|
206 |
+
num_classes = len(one_hot)
|
207 |
+
noise = torch.rand(num_classes)
|
208 |
+
noise = noise / noise.sum() * delta
|
209 |
+
one_hot = one_hot * (1 - delta) + noise
|
210 |
+
return one_hot
|
211 |
+
|
212 |
+
|
213 |
+
class SpellingMutation(object):
|
214 |
+
def __init__(self, pn0=0.7, pn1=0.85, pn2=0.95, pt0=0.7, pt1=0.85, charset=None):
|
215 |
+
"""
|
216 |
+
Args:
|
217 |
+
pn0: the prob of not modifying characters is (pn0)
|
218 |
+
pn1: the prob of modifying one characters is (pn1 - pn0)
|
219 |
+
pn2: the prob of modifying two characters is (pn2 - pn1),
|
220 |
+
and three (1 - pn2)
|
221 |
+
pt0: the prob of replacing operation is pt0.
|
222 |
+
pt1: the prob of inserting operation is (pt1 - pt0),
|
223 |
+
and deleting operation is (1 - pt1)
|
224 |
+
"""
|
225 |
+
super().__init__()
|
226 |
+
self.pn0, self.pn1, self.pn2 = pn0, pn1, pn2
|
227 |
+
self.pt0, self.pt1 = pt0, pt1
|
228 |
+
self.charset = charset
|
229 |
+
logging.info(f'the probs: pn0={self.pn0}, pn1={self.pn1} ' +
|
230 |
+
f'pn2={self.pn2}, pt0={self.pt0}, pt1={self.pt1}')
|
231 |
+
|
232 |
+
def is_digit(self, text, ratio=0.5):
|
233 |
+
length = max(len(text), 1)
|
234 |
+
digit_num = sum([t in self.charset.digits for t in text])
|
235 |
+
if digit_num / length < ratio: return False
|
236 |
+
return True
|
237 |
+
|
238 |
+
def is_unk_char(self, char):
|
239 |
+
# return char == self.charset.unk_char
|
240 |
+
return (char not in self.charset.digits) and (char not in self.charset.alphabets)
|
241 |
+
|
242 |
+
def get_num_to_modify(self, length):
|
243 |
+
prob = random.random()
|
244 |
+
if prob < self.pn0: num_to_modify = 0
|
245 |
+
elif prob < self.pn1: num_to_modify = 1
|
246 |
+
elif prob < self.pn2: num_to_modify = 2
|
247 |
+
else: num_to_modify = 3
|
248 |
+
|
249 |
+
if length <= 1: num_to_modify = 0
|
250 |
+
elif length >= 2 and length <= 4: num_to_modify = min(num_to_modify, 1)
|
251 |
+
else: num_to_modify = min(num_to_modify, length // 2) # smaller than length // 2
|
252 |
+
return num_to_modify
|
253 |
+
|
254 |
+
def __call__(self, text, debug=False):
|
255 |
+
if self.is_digit(text): return text
|
256 |
+
length = len(text)
|
257 |
+
num_to_modify = self.get_num_to_modify(length)
|
258 |
+
if num_to_modify <= 0: return text
|
259 |
+
|
260 |
+
chars = []
|
261 |
+
index = np.arange(0, length)
|
262 |
+
random.shuffle(index)
|
263 |
+
index = index[: num_to_modify]
|
264 |
+
if debug: self.index = index
|
265 |
+
for i, t in enumerate(text):
|
266 |
+
if i not in index: chars.append(t)
|
267 |
+
elif self.is_unk_char(t): chars.append(t)
|
268 |
+
else:
|
269 |
+
prob = random.random()
|
270 |
+
if prob < self.pt0: # replace
|
271 |
+
chars.append(random.choice(self.charset.alphabets))
|
272 |
+
elif prob < self.pt1: # insert
|
273 |
+
chars.append(random.choice(self.charset.alphabets))
|
274 |
+
chars.append(t)
|
275 |
+
else: # delete
|
276 |
+
continue
|
277 |
+
new_text = ''.join(chars[: self.charset.max_length-1])
|
278 |
+
return new_text if len(new_text) >= 1 else text
|
losses.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastai.vision import *
|
2 |
+
|
3 |
+
from modules.model import Model
|
4 |
+
|
5 |
+
|
6 |
+
class MultiLosses(nn.Module):
|
7 |
+
def __init__(self, one_hot=True):
|
8 |
+
super().__init__()
|
9 |
+
self.ce = SoftCrossEntropyLoss() if one_hot else torch.nn.CrossEntropyLoss()
|
10 |
+
self.bce = torch.nn.BCELoss()
|
11 |
+
|
12 |
+
@property
|
13 |
+
def last_losses(self):
|
14 |
+
return self.losses
|
15 |
+
|
16 |
+
def _flatten(self, sources, lengths):
|
17 |
+
return torch.cat([t[:l] for t, l in zip(sources, lengths)])
|
18 |
+
|
19 |
+
def _merge_list(self, all_res):
|
20 |
+
if not isinstance(all_res, (list, tuple)):
|
21 |
+
return all_res
|
22 |
+
def merge(items):
|
23 |
+
if isinstance(items[0], torch.Tensor): return torch.cat(items, dim=0)
|
24 |
+
else: return items[0]
|
25 |
+
res = dict()
|
26 |
+
for key in all_res[0].keys():
|
27 |
+
items = [r[key] for r in all_res]
|
28 |
+
res[key] = merge(items)
|
29 |
+
return res
|
30 |
+
|
31 |
+
def _ce_loss(self, output, gt_labels, gt_lengths, idx=None, record=True):
|
32 |
+
loss_name = output.get('name')
|
33 |
+
pt_logits, weight = output['logits'], output['loss_weight']
|
34 |
+
|
35 |
+
assert pt_logits.shape[0] % gt_labels.shape[0] == 0
|
36 |
+
iter_size = pt_logits.shape[0] // gt_labels.shape[0]
|
37 |
+
if iter_size > 1:
|
38 |
+
gt_labels = gt_labels.repeat(3, 1, 1)
|
39 |
+
gt_lengths = gt_lengths.repeat(3)
|
40 |
+
flat_gt_labels = self._flatten(gt_labels, gt_lengths)
|
41 |
+
flat_pt_logits = self._flatten(pt_logits, gt_lengths)
|
42 |
+
|
43 |
+
nll = output.get('nll')
|
44 |
+
if nll is not None:
|
45 |
+
loss = self.ce(flat_pt_logits, flat_gt_labels, softmax=False) * weight
|
46 |
+
else:
|
47 |
+
loss = self.ce(flat_pt_logits, flat_gt_labels) * weight
|
48 |
+
if record and loss_name is not None: self.losses[f'{loss_name}_loss'] = loss
|
49 |
+
|
50 |
+
return loss
|
51 |
+
|
52 |
+
def forward(self, outputs, *args):
|
53 |
+
self.losses = {}
|
54 |
+
if isinstance(outputs, (tuple, list)):
|
55 |
+
outputs = [self._merge_list(o) for o in outputs]
|
56 |
+
return sum([self._ce_loss(o, *args) for o in outputs if o['loss_weight'] > 0.])
|
57 |
+
else:
|
58 |
+
return self._ce_loss(outputs, *args, record=False)
|
59 |
+
|
60 |
+
|
61 |
+
class SoftCrossEntropyLoss(nn.Module):
|
62 |
+
def __init__(self, reduction="mean"):
|
63 |
+
super().__init__()
|
64 |
+
self.reduction = reduction
|
65 |
+
|
66 |
+
def forward(self, input, target, softmax=True):
|
67 |
+
if softmax: log_prob = F.log_softmax(input, dim=-1)
|
68 |
+
else: log_prob = torch.log(input)
|
69 |
+
loss = -(target * log_prob).sum(dim=-1)
|
70 |
+
if self.reduction == "mean": return loss.mean()
|
71 |
+
elif self.reduction == "sum": return loss.sum()
|
72 |
+
else: return loss
|
main.py
ADDED
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
import random
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from fastai.callbacks.general_sched import GeneralScheduler, TrainingPhase
|
8 |
+
from fastai.distributed import *
|
9 |
+
from fastai.vision import *
|
10 |
+
from torch.backends import cudnn
|
11 |
+
|
12 |
+
from callbacks import DumpPrediction, IterationCallback, TextAccuracy, TopKTextAccuracy
|
13 |
+
from dataset import ImageDataset, TextDataset
|
14 |
+
from losses import MultiLosses
|
15 |
+
from utils import Config, Logger, MyDataParallel, MyConcatDataset
|
16 |
+
|
17 |
+
|
18 |
+
def _set_random_seed(seed):
|
19 |
+
if seed is not None:
|
20 |
+
random.seed(seed)
|
21 |
+
torch.manual_seed(seed)
|
22 |
+
cudnn.deterministic = True
|
23 |
+
logging.warning('You have chosen to seed training. '
|
24 |
+
'This will slow down your training!')
|
25 |
+
|
26 |
+
def _get_training_phases(config, n):
|
27 |
+
lr = np.array(config.optimizer_lr)
|
28 |
+
periods = config.optimizer_scheduler_periods
|
29 |
+
sigma = [config.optimizer_scheduler_gamma ** i for i in range(len(periods))]
|
30 |
+
phases = [TrainingPhase(n * periods[i]).schedule_hp('lr', lr * sigma[i])
|
31 |
+
for i in range(len(periods))]
|
32 |
+
return phases
|
33 |
+
|
34 |
+
def _get_dataset(ds_type, paths, is_training, config, **kwargs):
|
35 |
+
kwargs.update({
|
36 |
+
'img_h': config.dataset_image_height,
|
37 |
+
'img_w': config.dataset_image_width,
|
38 |
+
'max_length': config.dataset_max_length,
|
39 |
+
'case_sensitive': config.dataset_case_sensitive,
|
40 |
+
'charset_path': config.dataset_charset_path,
|
41 |
+
'data_aug': config.dataset_data_aug,
|
42 |
+
'deteriorate_ratio': config.dataset_deteriorate_ratio,
|
43 |
+
'is_training': is_training,
|
44 |
+
'multiscales': config.dataset_multiscales,
|
45 |
+
'one_hot_y': config.dataset_one_hot_y,
|
46 |
+
})
|
47 |
+
datasets = [ds_type(p, **kwargs) for p in paths]
|
48 |
+
if len(datasets) > 1: return MyConcatDataset(datasets)
|
49 |
+
else: return datasets[0]
|
50 |
+
|
51 |
+
|
52 |
+
def _get_language_databaunch(config):
|
53 |
+
kwargs = {
|
54 |
+
'max_length': config.dataset_max_length,
|
55 |
+
'case_sensitive': config.dataset_case_sensitive,
|
56 |
+
'charset_path': config.dataset_charset_path,
|
57 |
+
'smooth_label': config.dataset_smooth_label,
|
58 |
+
'smooth_factor': config.dataset_smooth_factor,
|
59 |
+
'one_hot_y': config.dataset_one_hot_y,
|
60 |
+
'use_sm': config.dataset_use_sm,
|
61 |
+
}
|
62 |
+
train_ds = TextDataset(config.dataset_train_roots[0], is_training=True, **kwargs)
|
63 |
+
valid_ds = TextDataset(config.dataset_test_roots[0], is_training=False, **kwargs)
|
64 |
+
data = DataBunch.create(
|
65 |
+
path=train_ds.path,
|
66 |
+
train_ds=train_ds,
|
67 |
+
valid_ds=valid_ds,
|
68 |
+
bs=config.dataset_train_batch_size,
|
69 |
+
val_bs=config.dataset_test_batch_size,
|
70 |
+
num_workers=config.dataset_num_workers,
|
71 |
+
pin_memory=config.dataset_pin_memory)
|
72 |
+
logging.info(f'{len(data.train_ds)} training items found.')
|
73 |
+
if not data.empty_val:
|
74 |
+
logging.info(f'{len(data.valid_ds)} valid items found.')
|
75 |
+
return data
|
76 |
+
|
77 |
+
def _get_databaunch(config):
|
78 |
+
# An awkward way to reduce loadding data time during test
|
79 |
+
if config.global_phase == 'test': config.dataset_train_roots = config.dataset_test_roots
|
80 |
+
train_ds = _get_dataset(ImageDataset, config.dataset_train_roots, True, config)
|
81 |
+
valid_ds = _get_dataset(ImageDataset, config.dataset_test_roots, False, config)
|
82 |
+
data = ImageDataBunch.create(
|
83 |
+
train_ds=train_ds,
|
84 |
+
valid_ds=valid_ds,
|
85 |
+
bs=config.dataset_train_batch_size,
|
86 |
+
val_bs=config.dataset_test_batch_size,
|
87 |
+
num_workers=config.dataset_num_workers,
|
88 |
+
pin_memory=config.dataset_pin_memory).normalize(imagenet_stats)
|
89 |
+
ar_tfm = lambda x: ((x[0], x[1]), x[1]) # auto-regression only for dtd
|
90 |
+
data.add_tfm(ar_tfm)
|
91 |
+
|
92 |
+
logging.info(f'{len(data.train_ds)} training items found.')
|
93 |
+
if not data.empty_val:
|
94 |
+
logging.info(f'{len(data.valid_ds)} valid items found.')
|
95 |
+
|
96 |
+
return data
|
97 |
+
|
98 |
+
def _get_model(config):
|
99 |
+
import importlib
|
100 |
+
names = config.model_name.split('.')
|
101 |
+
module_name, class_name = '.'.join(names[:-1]), names[-1]
|
102 |
+
cls = getattr(importlib.import_module(module_name), class_name)
|
103 |
+
model = cls(config)
|
104 |
+
logging.info(model)
|
105 |
+
return model
|
106 |
+
|
107 |
+
|
108 |
+
def _get_learner(config, data, model, local_rank=None):
|
109 |
+
strict = ifnone(config.model_strict, True)
|
110 |
+
if config.global_stage == 'pretrain-language':
|
111 |
+
metrics = [TopKTextAccuracy(
|
112 |
+
k=ifnone(config.model_k, 5),
|
113 |
+
charset_path=config.dataset_charset_path,
|
114 |
+
max_length=config.dataset_max_length + 1,
|
115 |
+
case_sensitive=config.dataset_eval_case_sensisitves,
|
116 |
+
model_eval=config.model_eval)]
|
117 |
+
else:
|
118 |
+
metrics = [TextAccuracy(
|
119 |
+
charset_path=config.dataset_charset_path,
|
120 |
+
max_length=config.dataset_max_length + 1,
|
121 |
+
case_sensitive=config.dataset_eval_case_sensisitves,
|
122 |
+
model_eval=config.model_eval)]
|
123 |
+
opt_type = getattr(torch.optim, config.optimizer_type)
|
124 |
+
learner = Learner(data, model, silent=True, model_dir='.',
|
125 |
+
true_wd=config.optimizer_true_wd,
|
126 |
+
wd=config.optimizer_wd,
|
127 |
+
bn_wd=config.optimizer_bn_wd,
|
128 |
+
path=config.global_workdir,
|
129 |
+
metrics=metrics,
|
130 |
+
opt_func=partial(opt_type, **config.optimizer_args or dict()),
|
131 |
+
loss_func=MultiLosses(one_hot=config.dataset_one_hot_y))
|
132 |
+
learner.split(lambda m: children(m))
|
133 |
+
|
134 |
+
if config.global_phase == 'train':
|
135 |
+
num_replicas = 1 if local_rank is None else torch.distributed.get_world_size()
|
136 |
+
phases = _get_training_phases(config, len(learner.data.train_dl)//num_replicas)
|
137 |
+
learner.callback_fns += [
|
138 |
+
partial(GeneralScheduler, phases=phases),
|
139 |
+
partial(GradientClipping, clip=config.optimizer_clip_grad),
|
140 |
+
partial(IterationCallback, name=config.global_name,
|
141 |
+
show_iters=config.training_show_iters,
|
142 |
+
eval_iters=config.training_eval_iters,
|
143 |
+
save_iters=config.training_save_iters,
|
144 |
+
start_iters=config.training_start_iters,
|
145 |
+
stats_iters=config.training_stats_iters)]
|
146 |
+
else:
|
147 |
+
learner.callbacks += [
|
148 |
+
DumpPrediction(learn=learner,
|
149 |
+
dataset='-'.join([Path(p).name for p in config.dataset_test_roots]),charset_path=config.dataset_charset_path,
|
150 |
+
model_eval=config.model_eval,
|
151 |
+
debug=config.global_debug,
|
152 |
+
image_only=config.global_image_only)]
|
153 |
+
|
154 |
+
learner.rank = local_rank
|
155 |
+
if local_rank is not None:
|
156 |
+
logging.info(f'Set model to distributed with rank {local_rank}.')
|
157 |
+
learner.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(learner.model)
|
158 |
+
learner.model.to(local_rank)
|
159 |
+
learner = learner.to_distributed(local_rank)
|
160 |
+
|
161 |
+
if torch.cuda.device_count() > 1 and local_rank is None:
|
162 |
+
logging.info(f'Use {torch.cuda.device_count()} GPUs.')
|
163 |
+
learner.model = MyDataParallel(learner.model)
|
164 |
+
|
165 |
+
if config.model_checkpoint:
|
166 |
+
if Path(config.model_checkpoint).exists():
|
167 |
+
with open(config.model_checkpoint, 'rb') as f:
|
168 |
+
buffer = io.BytesIO(f.read())
|
169 |
+
learner.load(buffer, strict=strict)
|
170 |
+
else:
|
171 |
+
from distutils.dir_util import copy_tree
|
172 |
+
src = Path('/data/fangsc/model')/config.global_name
|
173 |
+
trg = Path('/output')/config.global_name
|
174 |
+
if src.exists(): copy_tree(str(src), str(trg))
|
175 |
+
learner.load(config.model_checkpoint, strict=strict)
|
176 |
+
logging.info(f'Read model from {config.model_checkpoint}')
|
177 |
+
elif config.global_phase == 'test':
|
178 |
+
learner.load(f'best-{config.global_name}', strict=strict)
|
179 |
+
logging.info(f'Read model from best-{config.global_name}')
|
180 |
+
|
181 |
+
if learner.opt_func.func.__name__ == 'Adadelta': # fastai bug, fix after 1.0.60
|
182 |
+
learner.fit(epochs=0, lr=config.optimizer_lr)
|
183 |
+
learner.opt.mom = 0.
|
184 |
+
|
185 |
+
return learner
|
186 |
+
|
187 |
+
def main():
|
188 |
+
parser = argparse.ArgumentParser()
|
189 |
+
parser.add_argument('--config', type=str, required=True,
|
190 |
+
help='path to config file')
|
191 |
+
parser.add_argument('--phase', type=str, default=None, choices=['train', 'test'])
|
192 |
+
parser.add_argument('--name', type=str, default=None)
|
193 |
+
parser.add_argument('--checkpoint', type=str, default=None)
|
194 |
+
parser.add_argument('--test_root', type=str, default=None)
|
195 |
+
parser.add_argument("--local_rank", type=int, default=None)
|
196 |
+
parser.add_argument('--debug', action='store_true', default=None)
|
197 |
+
parser.add_argument('--image_only', action='store_true', default=None)
|
198 |
+
parser.add_argument('--model_strict', action='store_false', default=None)
|
199 |
+
parser.add_argument('--model_eval', type=str, default=None,
|
200 |
+
choices=['alignment', 'vision', 'language'])
|
201 |
+
args = parser.parse_args()
|
202 |
+
config = Config(args.config)
|
203 |
+
if args.name is not None: config.global_name = args.name
|
204 |
+
if args.phase is not None: config.global_phase = args.phase
|
205 |
+
if args.test_root is not None: config.dataset_test_roots = [args.test_root]
|
206 |
+
if args.checkpoint is not None: config.model_checkpoint = args.checkpoint
|
207 |
+
if args.debug is not None: config.global_debug = args.debug
|
208 |
+
if args.image_only is not None: config.global_image_only = args.image_only
|
209 |
+
if args.model_eval is not None: config.model_eval = args.model_eval
|
210 |
+
if args.model_strict is not None: config.model_strict = args.model_strict
|
211 |
+
|
212 |
+
Logger.init(config.global_workdir, config.global_name, config.global_phase)
|
213 |
+
Logger.enable_file()
|
214 |
+
_set_random_seed(config.global_seed)
|
215 |
+
logging.info(config)
|
216 |
+
|
217 |
+
if args.local_rank is not None:
|
218 |
+
logging.info(f'Init distribution training at device {args.local_rank}.')
|
219 |
+
torch.cuda.set_device(args.local_rank)
|
220 |
+
torch.distributed.init_process_group(backend='nccl', init_method='env://')
|
221 |
+
|
222 |
+
logging.info('Construct dataset.')
|
223 |
+
if config.global_stage == 'pretrain-language': data = _get_language_databaunch(config)
|
224 |
+
else: data = _get_databaunch(config)
|
225 |
+
|
226 |
+
logging.info('Construct model.')
|
227 |
+
model = _get_model(config)
|
228 |
+
|
229 |
+
logging.info('Construct learner.')
|
230 |
+
learner = _get_learner(config, data, model, args.local_rank)
|
231 |
+
|
232 |
+
if config.global_phase == 'train':
|
233 |
+
logging.info('Start training.')
|
234 |
+
learner.fit(epochs=config.training_epochs,
|
235 |
+
lr=config.optimizer_lr)
|
236 |
+
else:
|
237 |
+
logging.info('Start validate')
|
238 |
+
last_metrics = learner.validate()
|
239 |
+
log_str = f'eval loss = {last_metrics[0]:6.3f}, ' \
|
240 |
+
f'ccr = {last_metrics[1]:6.3f}, cwr = {last_metrics[2]:6.3f}, ' \
|
241 |
+
f'ted = {last_metrics[3]:6.3f}, ned = {last_metrics[4]:6.0f}, ' \
|
242 |
+
f'ted/w = {last_metrics[5]:6.3f}, '
|
243 |
+
logging.info(log_str)
|
244 |
+
|
245 |
+
if __name__ == '__main__':
|
246 |
+
main()
|
requirements.txt
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
torchvision
|
3 |
+
fastai==1.0.60
|
4 |
+
LMDB
|
5 |
+
Pillow
|
6 |
+
opencv-python
|
7 |
+
tensorboardX
|
8 |
+
PyYAML
|
9 |
+
gdown
|
transforms.py
ADDED
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import numbers
|
3 |
+
import random
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
from PIL import Image
|
8 |
+
from torchvision import transforms
|
9 |
+
from torchvision.transforms import Compose
|
10 |
+
|
11 |
+
|
12 |
+
def sample_asym(magnitude, size=None):
|
13 |
+
return np.random.beta(1, 4, size) * magnitude
|
14 |
+
|
15 |
+
def sample_sym(magnitude, size=None):
|
16 |
+
return (np.random.beta(4, 4, size=size) - 0.5) * 2 * magnitude
|
17 |
+
|
18 |
+
def sample_uniform(low, high, size=None):
|
19 |
+
return np.random.uniform(low, high, size=size)
|
20 |
+
|
21 |
+
def get_interpolation(type='random'):
|
22 |
+
if type == 'random':
|
23 |
+
choice = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA]
|
24 |
+
interpolation = choice[random.randint(0, len(choice)-1)]
|
25 |
+
elif type == 'nearest': interpolation = cv2.INTER_NEAREST
|
26 |
+
elif type == 'linear': interpolation = cv2.INTER_LINEAR
|
27 |
+
elif type == 'cubic': interpolation = cv2.INTER_CUBIC
|
28 |
+
elif type == 'area': interpolation = cv2.INTER_AREA
|
29 |
+
else: raise TypeError('Interpolation types only nearest, linear, cubic, area are supported!')
|
30 |
+
return interpolation
|
31 |
+
|
32 |
+
class CVRandomRotation(object):
|
33 |
+
def __init__(self, degrees=15):
|
34 |
+
assert isinstance(degrees, numbers.Number), "degree should be a single number."
|
35 |
+
assert degrees >= 0, "degree must be positive."
|
36 |
+
self.degrees = degrees
|
37 |
+
|
38 |
+
@staticmethod
|
39 |
+
def get_params(degrees):
|
40 |
+
return sample_sym(degrees)
|
41 |
+
|
42 |
+
def __call__(self, img):
|
43 |
+
angle = self.get_params(self.degrees)
|
44 |
+
src_h, src_w = img.shape[:2]
|
45 |
+
M = cv2.getRotationMatrix2D(center=(src_w/2, src_h/2), angle=angle, scale=1.0)
|
46 |
+
abs_cos, abs_sin = abs(M[0,0]), abs(M[0,1])
|
47 |
+
dst_w = int(src_h * abs_sin + src_w * abs_cos)
|
48 |
+
dst_h = int(src_h * abs_cos + src_w * abs_sin)
|
49 |
+
M[0, 2] += (dst_w - src_w)/2
|
50 |
+
M[1, 2] += (dst_h - src_h)/2
|
51 |
+
|
52 |
+
flags = get_interpolation()
|
53 |
+
return cv2.warpAffine(img, M, (dst_w, dst_h), flags=flags, borderMode=cv2.BORDER_REPLICATE)
|
54 |
+
|
55 |
+
class CVRandomAffine(object):
|
56 |
+
def __init__(self, degrees, translate=None, scale=None, shear=None):
|
57 |
+
assert isinstance(degrees, numbers.Number), "degree should be a single number."
|
58 |
+
assert degrees >= 0, "degree must be positive."
|
59 |
+
self.degrees = degrees
|
60 |
+
|
61 |
+
if translate is not None:
|
62 |
+
assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
|
63 |
+
"translate should be a list or tuple and it must be of length 2."
|
64 |
+
for t in translate:
|
65 |
+
if not (0.0 <= t <= 1.0):
|
66 |
+
raise ValueError("translation values should be between 0 and 1")
|
67 |
+
self.translate = translate
|
68 |
+
|
69 |
+
if scale is not None:
|
70 |
+
assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
|
71 |
+
"scale should be a list or tuple and it must be of length 2."
|
72 |
+
for s in scale:
|
73 |
+
if s <= 0:
|
74 |
+
raise ValueError("scale values should be positive")
|
75 |
+
self.scale = scale
|
76 |
+
|
77 |
+
if shear is not None:
|
78 |
+
if isinstance(shear, numbers.Number):
|
79 |
+
if shear < 0:
|
80 |
+
raise ValueError("If shear is a single number, it must be positive.")
|
81 |
+
self.shear = [shear]
|
82 |
+
else:
|
83 |
+
assert isinstance(shear, (tuple, list)) and (len(shear) == 2), \
|
84 |
+
"shear should be a list or tuple and it must be of length 2."
|
85 |
+
self.shear = shear
|
86 |
+
else:
|
87 |
+
self.shear = shear
|
88 |
+
|
89 |
+
def _get_inverse_affine_matrix(self, center, angle, translate, scale, shear):
|
90 |
+
# https://github.com/pytorch/vision/blob/v0.4.0/torchvision/transforms/functional.py#L717
|
91 |
+
from numpy import sin, cos, tan
|
92 |
+
|
93 |
+
if isinstance(shear, numbers.Number):
|
94 |
+
shear = [shear, 0]
|
95 |
+
|
96 |
+
if not isinstance(shear, (tuple, list)) and len(shear) == 2:
|
97 |
+
raise ValueError(
|
98 |
+
"Shear should be a single value or a tuple/list containing " +
|
99 |
+
"two values. Got {}".format(shear))
|
100 |
+
|
101 |
+
rot = math.radians(angle)
|
102 |
+
sx, sy = [math.radians(s) for s in shear]
|
103 |
+
|
104 |
+
cx, cy = center
|
105 |
+
tx, ty = translate
|
106 |
+
|
107 |
+
# RSS without scaling
|
108 |
+
a = cos(rot - sy) / cos(sy)
|
109 |
+
b = -cos(rot - sy) * tan(sx) / cos(sy) - sin(rot)
|
110 |
+
c = sin(rot - sy) / cos(sy)
|
111 |
+
d = -sin(rot - sy) * tan(sx) / cos(sy) + cos(rot)
|
112 |
+
|
113 |
+
# Inverted rotation matrix with scale and shear
|
114 |
+
# det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
|
115 |
+
M = [d, -b, 0,
|
116 |
+
-c, a, 0]
|
117 |
+
M = [x / scale for x in M]
|
118 |
+
|
119 |
+
# Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
|
120 |
+
M[2] += M[0] * (-cx - tx) + M[1] * (-cy - ty)
|
121 |
+
M[5] += M[3] * (-cx - tx) + M[4] * (-cy - ty)
|
122 |
+
|
123 |
+
# Apply center translation: C * RSS^-1 * C^-1 * T^-1
|
124 |
+
M[2] += cx
|
125 |
+
M[5] += cy
|
126 |
+
return M
|
127 |
+
|
128 |
+
@staticmethod
|
129 |
+
def get_params(degrees, translate, scale_ranges, shears, height):
|
130 |
+
angle = sample_sym(degrees)
|
131 |
+
if translate is not None:
|
132 |
+
max_dx = translate[0] * height
|
133 |
+
max_dy = translate[1] * height
|
134 |
+
translations = (np.round(sample_sym(max_dx)), np.round(sample_sym(max_dy)))
|
135 |
+
else:
|
136 |
+
translations = (0, 0)
|
137 |
+
|
138 |
+
if scale_ranges is not None:
|
139 |
+
scale = sample_uniform(scale_ranges[0], scale_ranges[1])
|
140 |
+
else:
|
141 |
+
scale = 1.0
|
142 |
+
|
143 |
+
if shears is not None:
|
144 |
+
if len(shears) == 1:
|
145 |
+
shear = [sample_sym(shears[0]), 0.]
|
146 |
+
elif len(shears) == 2:
|
147 |
+
shear = [sample_sym(shears[0]), sample_sym(shears[1])]
|
148 |
+
else:
|
149 |
+
shear = 0.0
|
150 |
+
|
151 |
+
return angle, translations, scale, shear
|
152 |
+
|
153 |
+
|
154 |
+
def __call__(self, img):
|
155 |
+
src_h, src_w = img.shape[:2]
|
156 |
+
angle, translate, scale, shear = self.get_params(
|
157 |
+
self.degrees, self.translate, self.scale, self.shear, src_h)
|
158 |
+
|
159 |
+
M = self._get_inverse_affine_matrix((src_w/2, src_h/2), angle, (0, 0), scale, shear)
|
160 |
+
M = np.array(M).reshape(2,3)
|
161 |
+
|
162 |
+
startpoints = [(0, 0), (src_w - 1, 0), (src_w - 1, src_h - 1), (0, src_h - 1)]
|
163 |
+
project = lambda x, y, a, b, c: int(a*x + b*y + c)
|
164 |
+
endpoints = [(project(x, y, *M[0]), project(x, y, *M[1])) for x, y in startpoints]
|
165 |
+
|
166 |
+
rect = cv2.minAreaRect(np.array(endpoints))
|
167 |
+
bbox = cv2.boxPoints(rect).astype(dtype=np.int)
|
168 |
+
max_x, max_y = bbox[:, 0].max(), bbox[:, 1].max()
|
169 |
+
min_x, min_y = bbox[:, 0].min(), bbox[:, 1].min()
|
170 |
+
|
171 |
+
dst_w = int(max_x - min_x)
|
172 |
+
dst_h = int(max_y - min_y)
|
173 |
+
M[0, 2] += (dst_w - src_w) / 2
|
174 |
+
M[1, 2] += (dst_h - src_h) / 2
|
175 |
+
|
176 |
+
# add translate
|
177 |
+
dst_w += int(abs(translate[0]))
|
178 |
+
dst_h += int(abs(translate[1]))
|
179 |
+
if translate[0] < 0: M[0, 2] += abs(translate[0])
|
180 |
+
if translate[1] < 0: M[1, 2] += abs(translate[1])
|
181 |
+
|
182 |
+
flags = get_interpolation()
|
183 |
+
return cv2.warpAffine(img, M, (dst_w , dst_h), flags=flags, borderMode=cv2.BORDER_REPLICATE)
|
184 |
+
|
185 |
+
class CVRandomPerspective(object):
|
186 |
+
def __init__(self, distortion=0.5):
|
187 |
+
self.distortion = distortion
|
188 |
+
|
189 |
+
def get_params(self, width, height, distortion):
|
190 |
+
offset_h = sample_asym(distortion * height / 2, size=4).astype(dtype=np.int)
|
191 |
+
offset_w = sample_asym(distortion * width / 2, size=4).astype(dtype=np.int)
|
192 |
+
topleft = ( offset_w[0], offset_h[0])
|
193 |
+
topright = (width - 1 - offset_w[1], offset_h[1])
|
194 |
+
botright = (width - 1 - offset_w[2], height - 1 - offset_h[2])
|
195 |
+
botleft = ( offset_w[3], height - 1 - offset_h[3])
|
196 |
+
|
197 |
+
startpoints = [(0, 0), (width - 1, 0), (width - 1, height - 1), (0, height - 1)]
|
198 |
+
endpoints = [topleft, topright, botright, botleft]
|
199 |
+
return np.array(startpoints, dtype=np.float32), np.array(endpoints, dtype=np.float32)
|
200 |
+
|
201 |
+
def __call__(self, img):
|
202 |
+
height, width = img.shape[:2]
|
203 |
+
startpoints, endpoints = self.get_params(width, height, self.distortion)
|
204 |
+
M = cv2.getPerspectiveTransform(startpoints, endpoints)
|
205 |
+
|
206 |
+
# TODO: more robust way to crop image
|
207 |
+
rect = cv2.minAreaRect(endpoints)
|
208 |
+
bbox = cv2.boxPoints(rect).astype(dtype=np.int)
|
209 |
+
max_x, max_y = bbox[:, 0].max(), bbox[:, 1].max()
|
210 |
+
min_x, min_y = bbox[:, 0].min(), bbox[:, 1].min()
|
211 |
+
min_x, min_y = max(min_x, 0), max(min_y, 0)
|
212 |
+
|
213 |
+
flags = get_interpolation()
|
214 |
+
img = cv2.warpPerspective(img, M, (max_x, max_y), flags=flags, borderMode=cv2.BORDER_REPLICATE)
|
215 |
+
img = img[min_y:, min_x:]
|
216 |
+
return img
|
217 |
+
|
218 |
+
class CVRescale(object):
|
219 |
+
|
220 |
+
def __init__(self, factor=4, base_size=(128, 512)):
|
221 |
+
""" Define image scales using gaussian pyramid and rescale image to target scale.
|
222 |
+
|
223 |
+
Args:
|
224 |
+
factor: the decayed factor from base size, factor=4 keeps target scale by default.
|
225 |
+
base_size: base size the build the bottom layer of pyramid
|
226 |
+
"""
|
227 |
+
if isinstance(factor, numbers.Number):
|
228 |
+
self.factor = round(sample_uniform(0, factor))
|
229 |
+
elif isinstance(factor, (tuple, list)) and len(factor) == 2:
|
230 |
+
self.factor = round(sample_uniform(factor[0], factor[1]))
|
231 |
+
else:
|
232 |
+
raise Exception('factor must be number or list with length 2')
|
233 |
+
# assert factor is valid
|
234 |
+
self.base_h, self.base_w = base_size[:2]
|
235 |
+
|
236 |
+
def __call__(self, img):
|
237 |
+
if self.factor == 0: return img
|
238 |
+
src_h, src_w = img.shape[:2]
|
239 |
+
cur_w, cur_h = self.base_w, self.base_h
|
240 |
+
scale_img = cv2.resize(img, (cur_w, cur_h), interpolation=get_interpolation())
|
241 |
+
for _ in range(self.factor):
|
242 |
+
scale_img = cv2.pyrDown(scale_img)
|
243 |
+
scale_img = cv2.resize(scale_img, (src_w, src_h), interpolation=get_interpolation())
|
244 |
+
return scale_img
|
245 |
+
|
246 |
+
class CVGaussianNoise(object):
|
247 |
+
def __init__(self, mean=0, var=20):
|
248 |
+
self.mean = mean
|
249 |
+
if isinstance(var, numbers.Number):
|
250 |
+
self.var = max(int(sample_asym(var)), 1)
|
251 |
+
elif isinstance(var, (tuple, list)) and len(var) == 2:
|
252 |
+
self.var = int(sample_uniform(var[0], var[1]))
|
253 |
+
else:
|
254 |
+
raise Exception('degree must be number or list with length 2')
|
255 |
+
|
256 |
+
def __call__(self, img):
|
257 |
+
noise = np.random.normal(self.mean, self.var**0.5, img.shape)
|
258 |
+
img = np.clip(img + noise, 0, 255).astype(np.uint8)
|
259 |
+
return img
|
260 |
+
|
261 |
+
class CVMotionBlur(object):
|
262 |
+
def __init__(self, degrees=12, angle=90):
|
263 |
+
if isinstance(degrees, numbers.Number):
|
264 |
+
self.degree = max(int(sample_asym(degrees)), 1)
|
265 |
+
elif isinstance(degrees, (tuple, list)) and len(degrees) == 2:
|
266 |
+
self.degree = int(sample_uniform(degrees[0], degrees[1]))
|
267 |
+
else:
|
268 |
+
raise Exception('degree must be number or list with length 2')
|
269 |
+
self.angle = sample_uniform(-angle, angle)
|
270 |
+
|
271 |
+
def __call__(self, img):
|
272 |
+
M = cv2.getRotationMatrix2D((self.degree // 2, self.degree // 2), self.angle, 1)
|
273 |
+
motion_blur_kernel = np.zeros((self.degree, self.degree))
|
274 |
+
motion_blur_kernel[self.degree // 2, :] = 1
|
275 |
+
motion_blur_kernel = cv2.warpAffine(motion_blur_kernel, M, (self.degree, self.degree))
|
276 |
+
motion_blur_kernel = motion_blur_kernel / self.degree
|
277 |
+
img = cv2.filter2D(img, -1, motion_blur_kernel)
|
278 |
+
img = np.clip(img, 0, 255).astype(np.uint8)
|
279 |
+
return img
|
280 |
+
|
281 |
+
class CVGeometry(object):
|
282 |
+
def __init__(self, degrees=15, translate=(0.3, 0.3), scale=(0.5, 2.),
|
283 |
+
shear=(45, 15), distortion=0.5, p=0.5):
|
284 |
+
self.p = p
|
285 |
+
type_p = random.random()
|
286 |
+
if type_p < 0.33:
|
287 |
+
self.transforms = CVRandomRotation(degrees=degrees)
|
288 |
+
elif type_p < 0.66:
|
289 |
+
self.transforms = CVRandomAffine(degrees=degrees, translate=translate, scale=scale, shear=shear)
|
290 |
+
else:
|
291 |
+
self.transforms = CVRandomPerspective(distortion=distortion)
|
292 |
+
|
293 |
+
def __call__(self, img):
|
294 |
+
if random.random() < self.p:
|
295 |
+
img = np.array(img)
|
296 |
+
return Image.fromarray(self.transforms(img))
|
297 |
+
else: return img
|
298 |
+
|
299 |
+
class CVDeterioration(object):
|
300 |
+
def __init__(self, var, degrees, factor, p=0.5):
|
301 |
+
self.p = p
|
302 |
+
transforms = []
|
303 |
+
if var is not None:
|
304 |
+
transforms.append(CVGaussianNoise(var=var))
|
305 |
+
if degrees is not None:
|
306 |
+
transforms.append(CVMotionBlur(degrees=degrees))
|
307 |
+
if factor is not None:
|
308 |
+
transforms.append(CVRescale(factor=factor))
|
309 |
+
|
310 |
+
random.shuffle(transforms)
|
311 |
+
transforms = Compose(transforms)
|
312 |
+
self.transforms = transforms
|
313 |
+
|
314 |
+
def __call__(self, img):
|
315 |
+
if random.random() < self.p:
|
316 |
+
img = np.array(img)
|
317 |
+
return Image.fromarray(self.transforms(img))
|
318 |
+
else: return img
|
319 |
+
|
320 |
+
|
321 |
+
class CVColorJitter(object):
|
322 |
+
def __init__(self, brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, p=0.5):
|
323 |
+
self.p = p
|
324 |
+
self.transforms = transforms.ColorJitter(brightness=brightness, contrast=contrast,
|
325 |
+
saturation=saturation, hue=hue)
|
326 |
+
|
327 |
+
def __call__(self, img):
|
328 |
+
if random.random() < self.p: return self.transforms(img)
|
329 |
+
else: return img
|
utils.py
ADDED
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import time
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import yaml
|
9 |
+
from matplotlib import colors
|
10 |
+
from matplotlib import pyplot as plt
|
11 |
+
from torch import Tensor, nn
|
12 |
+
from torch.utils.data import ConcatDataset
|
13 |
+
|
14 |
+
class CharsetMapper(object):
|
15 |
+
"""A simple class to map ids into strings.
|
16 |
+
|
17 |
+
It works only when the character set is 1:1 mapping between individual
|
18 |
+
characters and individual ids.
|
19 |
+
"""
|
20 |
+
|
21 |
+
def __init__(self,
|
22 |
+
filename='',
|
23 |
+
max_length=30,
|
24 |
+
null_char=u'\u2591'):
|
25 |
+
"""Creates a lookup table.
|
26 |
+
|
27 |
+
Args:
|
28 |
+
filename: Path to charset file which maps characters to ids.
|
29 |
+
max_sequence_length: The max length of ids and string.
|
30 |
+
null_char: A unicode character used to replace '<null>' character.
|
31 |
+
the default value is a light shade block '░'.
|
32 |
+
"""
|
33 |
+
self.null_char = null_char
|
34 |
+
self.max_length = max_length
|
35 |
+
|
36 |
+
self.label_to_char = self._read_charset(filename)
|
37 |
+
self.char_to_label = dict(map(reversed, self.label_to_char.items()))
|
38 |
+
self.num_classes = len(self.label_to_char)
|
39 |
+
|
40 |
+
def _read_charset(self, filename):
|
41 |
+
"""Reads a charset definition from a tab separated text file.
|
42 |
+
|
43 |
+
Args:
|
44 |
+
filename: a path to the charset file.
|
45 |
+
|
46 |
+
Returns:
|
47 |
+
a dictionary with keys equal to character codes and values - unicode
|
48 |
+
characters.
|
49 |
+
"""
|
50 |
+
import re
|
51 |
+
pattern = re.compile(r'(\d+)\t(.+)')
|
52 |
+
charset = {}
|
53 |
+
self.null_label = 0
|
54 |
+
charset[self.null_label] = self.null_char
|
55 |
+
with open(filename, 'r') as f:
|
56 |
+
for i, line in enumerate(f):
|
57 |
+
m = pattern.match(line)
|
58 |
+
assert m, f'Incorrect charset file. line #{i}: {line}'
|
59 |
+
label = int(m.group(1)) + 1
|
60 |
+
char = m.group(2)
|
61 |
+
charset[label] = char
|
62 |
+
return charset
|
63 |
+
|
64 |
+
def trim(self, text):
|
65 |
+
assert isinstance(text, str)
|
66 |
+
return text.replace(self.null_char, '')
|
67 |
+
|
68 |
+
def get_text(self, labels, length=None, padding=True, trim=False):
|
69 |
+
""" Returns a string corresponding to a sequence of character ids.
|
70 |
+
"""
|
71 |
+
length = length if length else self.max_length
|
72 |
+
labels = [l.item() if isinstance(l, Tensor) else int(l) for l in labels]
|
73 |
+
if padding:
|
74 |
+
labels = labels + [self.null_label] * (length-len(labels))
|
75 |
+
text = ''.join([self.label_to_char[label] for label in labels])
|
76 |
+
if trim: text = self.trim(text)
|
77 |
+
return text
|
78 |
+
|
79 |
+
def get_labels(self, text, length=None, padding=True, case_sensitive=False):
|
80 |
+
""" Returns the labels of the corresponding text.
|
81 |
+
"""
|
82 |
+
length = length if length else self.max_length
|
83 |
+
if padding:
|
84 |
+
text = text + self.null_char * (length - len(text))
|
85 |
+
if not case_sensitive:
|
86 |
+
text = text.lower()
|
87 |
+
labels = [self.char_to_label[char] for char in text]
|
88 |
+
return labels
|
89 |
+
|
90 |
+
def pad_labels(self, labels, length=None):
|
91 |
+
length = length if length else self.max_length
|
92 |
+
|
93 |
+
return labels + [self.null_label] * (length - len(labels))
|
94 |
+
|
95 |
+
@property
|
96 |
+
def digits(self):
|
97 |
+
return '0123456789'
|
98 |
+
|
99 |
+
@property
|
100 |
+
def digit_labels(self):
|
101 |
+
return self.get_labels(self.digits, padding=False)
|
102 |
+
|
103 |
+
@property
|
104 |
+
def alphabets(self):
|
105 |
+
all_chars = list(self.char_to_label.keys())
|
106 |
+
valid_chars = []
|
107 |
+
for c in all_chars:
|
108 |
+
if c in 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ':
|
109 |
+
valid_chars.append(c)
|
110 |
+
return ''.join(valid_chars)
|
111 |
+
|
112 |
+
@property
|
113 |
+
def alphabet_labels(self):
|
114 |
+
return self.get_labels(self.alphabets, padding=False)
|
115 |
+
|
116 |
+
|
117 |
+
class Timer(object):
|
118 |
+
"""A simple timer."""
|
119 |
+
def __init__(self):
|
120 |
+
self.data_time = 0.
|
121 |
+
self.data_diff = 0.
|
122 |
+
self.data_total_time = 0.
|
123 |
+
self.data_call = 0
|
124 |
+
self.running_time = 0.
|
125 |
+
self.running_diff = 0.
|
126 |
+
self.running_total_time = 0.
|
127 |
+
self.running_call = 0
|
128 |
+
|
129 |
+
def tic(self):
|
130 |
+
self.start_time = time.time()
|
131 |
+
self.running_time = self.start_time
|
132 |
+
|
133 |
+
def toc_data(self):
|
134 |
+
self.data_time = time.time()
|
135 |
+
self.data_diff = self.data_time - self.running_time
|
136 |
+
self.data_total_time += self.data_diff
|
137 |
+
self.data_call += 1
|
138 |
+
|
139 |
+
def toc_running(self):
|
140 |
+
self.running_time = time.time()
|
141 |
+
self.running_diff = self.running_time - self.data_time
|
142 |
+
self.running_total_time += self.running_diff
|
143 |
+
self.running_call += 1
|
144 |
+
|
145 |
+
def total_time(self):
|
146 |
+
return self.data_total_time + self.running_total_time
|
147 |
+
|
148 |
+
def average_time(self):
|
149 |
+
return self.average_data_time() + self.average_running_time()
|
150 |
+
|
151 |
+
def average_data_time(self):
|
152 |
+
return self.data_total_time / (self.data_call or 1)
|
153 |
+
|
154 |
+
def average_running_time(self):
|
155 |
+
return self.running_total_time / (self.running_call or 1)
|
156 |
+
|
157 |
+
|
158 |
+
class Logger(object):
|
159 |
+
_handle = None
|
160 |
+
_root = None
|
161 |
+
|
162 |
+
@staticmethod
|
163 |
+
def init(output_dir, name, phase):
|
164 |
+
format = '[%(asctime)s %(filename)s:%(lineno)d %(levelname)s {}] ' \
|
165 |
+
'%(message)s'.format(name)
|
166 |
+
logging.basicConfig(level=logging.INFO, format=format)
|
167 |
+
|
168 |
+
try: os.makedirs(output_dir)
|
169 |
+
except: pass
|
170 |
+
config_path = os.path.join(output_dir, f'{phase}.txt')
|
171 |
+
Logger._handle = logging.FileHandler(config_path)
|
172 |
+
Logger._root = logging.getLogger()
|
173 |
+
|
174 |
+
@staticmethod
|
175 |
+
def enable_file():
|
176 |
+
if Logger._handle is None or Logger._root is None:
|
177 |
+
raise Exception('Invoke Logger.init() first!')
|
178 |
+
Logger._root.addHandler(Logger._handle)
|
179 |
+
|
180 |
+
@staticmethod
|
181 |
+
def disable_file():
|
182 |
+
if Logger._handle is None or Logger._root is None:
|
183 |
+
raise Exception('Invoke Logger.init() first!')
|
184 |
+
Logger._root.removeHandler(Logger._handle)
|
185 |
+
|
186 |
+
|
187 |
+
class Config(object):
|
188 |
+
|
189 |
+
def __init__(self, config_path, host=True):
|
190 |
+
def __dict2attr(d, prefix=''):
|
191 |
+
for k, v in d.items():
|
192 |
+
if isinstance(v, dict):
|
193 |
+
__dict2attr(v, f'{prefix}{k}_')
|
194 |
+
else:
|
195 |
+
if k == 'phase':
|
196 |
+
assert v in ['train', 'test']
|
197 |
+
if k == 'stage':
|
198 |
+
assert v in ['pretrain-vision', 'pretrain-language',
|
199 |
+
'train-semi-super', 'train-super']
|
200 |
+
self.__setattr__(f'{prefix}{k}', v)
|
201 |
+
|
202 |
+
assert os.path.exists(config_path), '%s does not exists!' % config_path
|
203 |
+
with open(config_path) as file:
|
204 |
+
config_dict = yaml.load(file, Loader=yaml.FullLoader)
|
205 |
+
with open('configs/template.yaml') as file:
|
206 |
+
default_config_dict = yaml.load(file, Loader=yaml.FullLoader)
|
207 |
+
__dict2attr(default_config_dict)
|
208 |
+
__dict2attr(config_dict)
|
209 |
+
self.global_workdir = os.path.join(self.global_workdir, self.global_name)
|
210 |
+
|
211 |
+
def __getattr__(self, item):
|
212 |
+
attr = self.__dict__.get(item)
|
213 |
+
if attr is None:
|
214 |
+
attr = dict()
|
215 |
+
prefix = f'{item}_'
|
216 |
+
for k, v in self.__dict__.items():
|
217 |
+
if k.startswith(prefix):
|
218 |
+
n = k.replace(prefix, '')
|
219 |
+
attr[n] = v
|
220 |
+
return attr if len(attr) > 0 else None
|
221 |
+
else:
|
222 |
+
return attr
|
223 |
+
|
224 |
+
def __repr__(self):
|
225 |
+
str = 'ModelConfig(\n'
|
226 |
+
for i, (k, v) in enumerate(sorted(vars(self).items())):
|
227 |
+
str += f'\t({i}): {k} = {v}\n'
|
228 |
+
str += ')'
|
229 |
+
return str
|
230 |
+
|
231 |
+
def blend_mask(image, mask, alpha=0.5, cmap='jet', color='b', color_alpha=1.0):
|
232 |
+
# normalize mask
|
233 |
+
mask = (mask-mask.min()) / (mask.max() - mask.min() + np.finfo(float).eps)
|
234 |
+
if mask.shape != image.shape:
|
235 |
+
mask = cv2.resize(mask,(image.shape[1], image.shape[0]))
|
236 |
+
# get color map
|
237 |
+
color_map = plt.get_cmap(cmap)
|
238 |
+
mask = color_map(mask)[:,:,:3]
|
239 |
+
# convert float to uint8
|
240 |
+
mask = (mask * 255).astype(dtype=np.uint8)
|
241 |
+
|
242 |
+
# set the basic color
|
243 |
+
basic_color = np.array(colors.to_rgb(color)) * 255
|
244 |
+
basic_color = np.tile(basic_color, [image.shape[0], image.shape[1], 1])
|
245 |
+
basic_color = basic_color.astype(dtype=np.uint8)
|
246 |
+
# blend with basic color
|
247 |
+
blended_img = cv2.addWeighted(image, color_alpha, basic_color, 1-color_alpha, 0)
|
248 |
+
# blend with mask
|
249 |
+
blended_img = cv2.addWeighted(blended_img, alpha, mask, 1-alpha, 0)
|
250 |
+
|
251 |
+
return blended_img
|
252 |
+
|
253 |
+
def onehot(label, depth, device=None):
|
254 |
+
"""
|
255 |
+
Args:
|
256 |
+
label: shape (n1, n2, ..., )
|
257 |
+
depth: a scalar
|
258 |
+
|
259 |
+
Returns:
|
260 |
+
onehot: (n1, n2, ..., depth)
|
261 |
+
"""
|
262 |
+
if not isinstance(label, torch.Tensor):
|
263 |
+
label = torch.tensor(label, device=device)
|
264 |
+
onehot = torch.zeros(label.size() + torch.Size([depth]), device=device)
|
265 |
+
onehot = onehot.scatter_(-1, label.unsqueeze(-1), 1)
|
266 |
+
|
267 |
+
return onehot
|
268 |
+
|
269 |
+
class MyDataParallel(nn.DataParallel):
|
270 |
+
|
271 |
+
def gather(self, outputs, target_device):
|
272 |
+
r"""
|
273 |
+
Gathers tensors from different GPUs on a specified device
|
274 |
+
(-1 means the CPU).
|
275 |
+
"""
|
276 |
+
def gather_map(outputs):
|
277 |
+
out = outputs[0]
|
278 |
+
if isinstance(out, (str, int, float)):
|
279 |
+
return out
|
280 |
+
if isinstance(out, list) and isinstance(out[0], str):
|
281 |
+
return [o for out in outputs for o in out]
|
282 |
+
if isinstance(out, torch.Tensor):
|
283 |
+
return torch.nn.parallel._functions.Gather.apply(target_device, self.dim, *outputs)
|
284 |
+
if out is None:
|
285 |
+
return None
|
286 |
+
if isinstance(out, dict):
|
287 |
+
if not all((len(out) == len(d) for d in outputs)):
|
288 |
+
raise ValueError('All dicts must have the same number of keys')
|
289 |
+
return type(out)(((k, gather_map([d[k] for d in outputs]))
|
290 |
+
for k in out))
|
291 |
+
return type(out)(map(gather_map, zip(*outputs)))
|
292 |
+
|
293 |
+
# Recursive function calls like this create reference cycles.
|
294 |
+
# Setting the function to None clears the refcycle.
|
295 |
+
try:
|
296 |
+
res = gather_map(outputs)
|
297 |
+
finally:
|
298 |
+
gather_map = None
|
299 |
+
return res
|
300 |
+
|
301 |
+
|
302 |
+
class MyConcatDataset(ConcatDataset):
|
303 |
+
def __getattr__(self, k):
|
304 |
+
return getattr(self.datasets[0], k)
|