Your Name committed on
Commit
74a6211
·
1 Parent(s): 5c5ae78
Files changed (12)
  1. LICENSE +25 -0
  2. README.md +142 -13
  3. app.py +122 -0
  4. callbacks.py +360 -0
  5. commands.sh +3 -0
  6. config.yaml +71 -0
  7. dataset.py +278 -0
  8. losses.py +72 -0
  9. main.py +246 -0
  10. requirements.txt +9 -0
  11. transforms.py +329 -0
  12. utils.py +304 -0
LICENSE ADDED
@@ -0,0 +1,25 @@
+ ABINet for non-commercial purposes
+
+ Copyright (c) 2021, USTC
+ All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are met:
+
+ 1. Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+ 2. Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
README.md CHANGED
@@ -1,13 +1,142 @@
- ---
- title: Dvatch Captcha Sneedium
- emoji: 🌍
- colorFrom: green
- colorTo: pink
- sdk: gradio
- sdk_version: 3.3.1
- app_file: app.py
- pinned: false
- license: openrail
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Read Like Humans: Autonomous, Bidirectional and Iterative Language Modeling for Scene Text Recognition
+
+ The official code of [ABINet](https://arxiv.org/pdf/2103.06495.pdf) (CVPR 2021, Oral).
+
+ ABINet uses a vision model and an explicit language model, trained in an end-to-end way, to recognize text in the wild. The language model (BCN) achieves a bidirectional language representation by simulating a cloze test, and additionally applies an iterative correction strategy.
+
+ ![framework](./figs/framework.png)
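+
+ At inference the sub-models cooperate in a loop: the vision model predicts once from the image, then for `iter_size` rounds the language model refines the current text prediction and the alignment model fuses it with the vision output. A minimal sketch of that control flow, with stand-in functions instead of the real modules (shapes and names here are illustrative only, not the actual API):
+ ```
+ import torch
+
+ def vision(image):                  # stand-in for the vision model
+     return torch.randn(1, 26, 37)   # (batch, position, character) logits
+
+ def language(logits):               # stand-in for BCN, refining from text alone
+     return logits + 0.1 * torch.randn_like(logits)
+
+ def alignment(v, l):                # stand-in for the alignment model
+     return (v + l) / 2
+
+ v_logits = vision(torch.randn(1, 3, 32, 128))
+ logits = v_logits
+ for _ in range(3):                  # iter_size: 3 in config.yaml
+     # detach mirrors `detach: True` in config.yaml: no gradient flows back
+     logits = alignment(v_logits, language(logits.detach()))
+ ```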
+
+ ## Runtime Environment
+
+ - We provide a pre-built docker image, built from the Dockerfile at `docker/Dockerfile`
+
+ - Running in Docker
+ ```
+ $ git clone git@github.com:FangShancheng/ABINet.git
+ $ docker run --gpus all --rm -ti --ipc=host -v $(pwd)/ABINet:/app fangshancheng/fastai:torch1.1 /bin/bash
+ ```
+ - (Untested) Or install the dependencies directly
+ ```
+ pip install -r requirements.txt
+ ```
+
+ ## Datasets
+
+ - Training datasets
+
+     1. [MJSynth](http://www.robots.ox.ac.uk/~vgg/data/text/) (MJ):
+         - Use `tools/create_lmdb_dataset.py` to convert the images into an LMDB dataset
+         - [LMDB dataset BaiduNetdisk(passwd:n23k)](https://pan.baidu.com/s/1mgnTiyoR8f6Cm655rFI4HQ)
+     2. [SynthText](http://www.robots.ox.ac.uk/~vgg/data/scenetext/) (ST):
+         - Use `tools/crop_by_word_bb.py` to crop images from the original [SynthText](http://www.robots.ox.ac.uk/~vgg/data/scenetext/) dataset, and convert the images into an LMDB dataset with `tools/create_lmdb_dataset.py`
+         - [LMDB dataset BaiduNetdisk(passwd:n23k)](https://pan.baidu.com/s/1mgnTiyoR8f6Cm655rFI4HQ)
+     3. [WikiText103](https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip), which is only used for pre-training language models:
+         - Use `notebooks/prepare_wikitext103.ipynb` to convert the text into CSV format.
+         - [CSV dataset BaiduNetdisk(passwd:dk01)](https://pan.baidu.com/s/1yabtnPYDKqhBb_Ie9PGFXA)
+
+ - Evaluation datasets: the LMDB datasets can be downloaded from [BaiduNetdisk(passwd:1dbv)](https://pan.baidu.com/s/1RUg3Akwp7n8kZYJ55rU5LQ) or [GoogleDrive](https://drive.google.com/file/d/1dTI0ipu14Q1uuK4s4z32DqbqF3dJPdkk/view?usp=sharing).
+     1. ICDAR 2013 (IC13)
+     2. ICDAR 2015 (IC15)
+     3. IIIT5K Words (IIIT)
+     4. Street View Text (SVT)
+     5. Street View Text-Perspective (SVTP)
+     6. CUTE80 (CUTE)
+
+ - The structure of the `data` directory is
+ ```
+ data
+ ├── charset_36.txt
+ ├── evaluation
+ │   ├── CUTE80
+ │   ├── IC13_857
+ │   ├── IC15_1811
+ │   ├── IIIT5k_3000
+ │   ├── SVT
+ │   └── SVTP
+ ├── training
+ │   ├── MJ
+ │   │   ├── MJ_test
+ │   │   ├── MJ_train
+ │   │   └── MJ_valid
+ │   └── ST
+ ├── WikiText-103.csv
+ └── WikiText-103_eval_d1.csv
+ ```
+
+ ### Pretrained Models
+
+ Get the pretrained models from [BaiduNetdisk(passwd:kwck)](https://pan.baidu.com/s/1b3vyvPwvh_75FkPlp87czQ) or [GoogleDrive](https://drive.google.com/file/d/1mYM_26qHUom_5NU7iutHneB_KHlLjL5y/view?usp=sharing). The performance of the pretrained models is summarized as follows:
+
+ |Model|IC13|SVT|IIIT|IC15|SVTP|CUTE|AVG|
+ |-|-|-|-|-|-|-|-|
+ |ABINet-SV|97.1|92.7|95.2|84.0|86.7|88.5|91.4|
+ |ABINet-LV|97.0|93.4|96.4|85.9|89.5|89.2|92.7|
+
+ ## Training
+
+ 1. Pre-train the vision model
+ ```
+ CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --config=configs/pretrain_vision_model.yaml
+ ```
+ 2. Pre-train the language model
+ ```
+ CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --config=configs/pretrain_language_model.yaml
+ ```
+ 3. Train ABINet
+ ```
+ CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --config=configs/train_abinet.yaml
+ ```
+ Note:
+ - You can set the `checkpoint` paths for the vision and language models separately to start from specific pretrained models, or set them to `None` to train from scratch; see the config snippet below.
+
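+ For example, the relevant block of a training config (mirroring the `config.yaml` added in this commit) looks like:
+ ```
+ model:
+   vision: {
+     checkpoint: workdir/pretrain-vision-model/best-pretrain-vision-model.pth,
+   }
+   language: {
+     checkpoint: workdir/pretrain-language-model/pretrain-language-model.pth,
+   }
+ ```
+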
+ ## Evaluation
+
+ ```
+ CUDA_VISIBLE_DEVICES=0 python main.py --config=configs/train_abinet.yaml --phase test --image_only
+ ```
+ Additional flags:
+ - `--checkpoint /path/to/checkpoint` set the path of the evaluation model
+ - `--test_root /path/to/dataset` set the path of the evaluation dataset
+ - `--model_eval [alignment|vision]` which sub-model to evaluate
+ - `--image_only` disable dumping visualizations of attention masks
+
+ ## Web Demo
+
+ Integrated into [Huggingface Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio). Try out the Web Demo: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/tomofi/ABINet-OCR)
+
+ ## Run Demo
+
+ ```
+ python demo.py --config=configs/train_abinet.yaml --input=figs/test
+ ```
+ Additional flags:
+ - `--config /path/to/config` set the path of the configuration file
+ - `--input /path/to/image-directory` set the path of an image directory or a wildcard path, e.g., `--input='figs/test/*.png'`
+ - `--checkpoint /path/to/checkpoint` set the path of the trained model
+ - `--cuda [-1|0|1|2|3...]` set the cuda id; the default -1 stands for CPU
+ - `--model_eval [alignment|vision]` which sub-model to use
+ - `--image_only` disable dumping visualizations of attention masks
+
+ ## Visualization
+ Success and failure cases on low-quality images:
+
+ ![cases](./figs/cases.png)
+
+ ## Citation
+ If you find our method useful for your research, please cite
+ ```
+ @inproceedings{fang2021read,
+   title={Read Like Humans: Autonomous, Bidirectional and Iterative Language Modeling for Scene Text Recognition},
+   author={Fang, Shancheng and Xie, Hongtao and Wang, Yuxin and Mao, Zhendong and Zhang, Yongdong},
+   booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
+   year={2021}
+ }
+ ```
+
+ ## License
+
+ This project is free only for academic research purposes, licensed under the 2-clause BSD License - see the LICENSE file for details.
+
+ Feel free to contact [email protected] if you have any questions.
app.py ADDED
@@ -0,0 +1,122 @@
+ import argparse
+ import logging
+ import os
+ import glob
+ import tqdm
+ import torch, re
+ import PIL
+ import cv2
+ import numpy as np
+ import torch.nn.functional as F
+ from torchvision import transforms
+ from utils import Config, Logger, CharsetMapper
+
+ def get_model(config):
+     import importlib
+     names = config.model_name.split('.')
+     module_name, class_name = '.'.join(names[:-1]), names[-1]
+     cls = getattr(importlib.import_module(module_name), class_name)
+     model = cls(config)
+     logging.info(model)
+     model = model.eval()
+     return model
+
+ def preprocess(img, width, height):
+     img = cv2.resize(np.array(img), (width, height))
+     img = transforms.ToTensor()(img).unsqueeze(0)
+     mean = torch.tensor([0.485, 0.456, 0.406])
+     std = torch.tensor([0.229, 0.224, 0.225])
+     return (img - mean[..., None, None]) / std[..., None, None]
+
+ def postprocess(output, charset, model_eval):
+     def _get_output(last_output, model_eval):
+         if isinstance(last_output, (tuple, list)):
+             for res in last_output:
+                 if res['name'] == model_eval: output = res
+         else: output = last_output
+         return output
+
+     def _decode(logit):
+         """ Greedy decode """
+         out = F.softmax(logit, dim=2)
+         pt_text, pt_scores, pt_lengths = [], [], []
+         for o in out:
+             text = charset.get_text(o.argmax(dim=1), padding=False, trim=False)
+             text = text.split(charset.null_char)[0]  # end at end-token
+             pt_text.append(text)
+             pt_scores.append(o.max(dim=1)[0])
+             pt_lengths.append(min(len(text) + 1, charset.max_length))  # one for end-token
+         return pt_text, pt_scores, pt_lengths
+
+     output = _get_output(output, model_eval)
+     logits, pt_lengths = output['logits'], output['pt_lengths']
+     pt_text, pt_scores, pt_lengths_ = _decode(logits)
+
+     return pt_text, pt_scores, pt_lengths_
+
+ def load(model, file, device=None, strict=True):
+     if device is None: device = 'cpu'
+     elif isinstance(device, int): device = torch.device('cuda', device)
+     assert os.path.isfile(file)
+     state = torch.load(file, map_location=device)
+     if set(state.keys()) == {'model', 'opt'}:
+         state = state['model']
+     model.load_state_dict(state, strict=strict)
+     return model
+
+
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--config', type=str, default='configs/train_abinet.yaml',
+                         help='path to config file')
+     parser.add_argument('--input', type=str, default='figs/test')
+     parser.add_argument('--cuda', type=int, default=-1)
+     parser.add_argument('--checkpoint', type=str, default='workdir/train-abinet/best-train-abinet.pth')
+     parser.add_argument('--model_eval', type=str, default='alignment',
+                         choices=['alignment', 'vision', 'language'])
+     args = parser.parse_args()
+     config = Config(args.config)
+     if args.checkpoint is not None: config.model_checkpoint = args.checkpoint
+     if args.model_eval is not None: config.model_eval = args.model_eval
+     config.global_phase = 'test'
+     config.model_vision_checkpoint, config.model_language_checkpoint = None, None
+     device = 'cpu' if args.cuda < 0 else f'cuda:{args.cuda}'
+
+     Logger.init(config.global_workdir, config.global_name, config.global_phase)
+     Logger.enable_file()
+     logging.info(config)
+
+     logging.info('Construct model.')
+     model = get_model(config).to(device)
+     model = load(model, config.model_checkpoint, device=device)
+     charset = CharsetMapper(filename=config.dataset_charset_path,
+                             max_length=config.dataset_max_length + 1)
+
+     if os.path.isdir(args.input):
+         paths = [os.path.join(args.input, fname) for fname in os.listdir(args.input)]
+     else:
+         paths = glob.glob(os.path.expanduser(args.input))
+     assert paths, "The input path(s) were not found"
+     paths = sorted(paths)
+
+     count = 0
+     checks = 0
+     for path in tqdm.tqdm(paths):
+         img = PIL.Image.open(path).convert('RGB')
+         img = preprocess(img, config.dataset_image_width, config.dataset_image_height)
+         img = img.to(device)
+         res = model(img)
+         pt_text, _, __ = postprocess(res, charset, config.model_eval)
+         a = re.findall(r'(\d{6}).png', path)[0]
+         # print(a)
+         # print(pt_text[0], "Lol")
+         # a = re.findall(r'base/(.*).pn', path)[0]
+         checks += 1
+         if a.lower() != pt_text[0].lower():
+             count += 1
+             print(f'label:{a.lower()} ||| guess:{pt_text[0]} ||| count_fails:{count}/{checks}')
+
+ if __name__ == '__main__':
+     main()
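+
+ # Usage sketch (hedged): this script scores a recognizer on captcha-style images
+ # whose filenames carry the 6-digit ground-truth label, matching the regex
+ # r'(\d{6}).png' above, e.g.:
+ #     python app.py --config=configs/train_abinet.yaml --input='base/*.png' --cuda 0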
callbacks.py ADDED
@@ -0,0 +1,360 @@
+ import logging
+ import shutil
+ import time
+
+ import editdistance as ed
+ import torchvision.utils as vutils
+ from fastai.callbacks.tensorboard import (LearnerTensorboardWriter,
+                                           SummaryWriter, TBWriteRequest,
+                                           asyncTBWriter)
+ from fastai.vision import *
+ from torch.nn.parallel import DistributedDataParallel
+ from torchvision import transforms
+
+ import dataset
+ from utils import CharsetMapper, Timer, blend_mask
+
+
+ class IterationCallback(LearnerTensorboardWriter):
+     "A `TrackerCallback` that monitors performance at each iteration."
+     def __init__(self, learn:Learner, name:str='model', checkpoint_keep_num=5,
+                  show_iters:int=50, eval_iters:int=1000, save_iters:int=20000,
+                  start_iters:int=0, stats_iters=20000):
+         #if self.learn.rank is not None: time.sleep(self.learn.rank)  # keep all event files
+         super().__init__(learn, base_dir='.', name=learn.path, loss_iters=show_iters,
+                          stats_iters=stats_iters, hist_iters=stats_iters)
+         self.name, self.bestname = Path(name).name, f'best-{Path(name).name}'
+         self.show_iters = show_iters
+         self.eval_iters = eval_iters
+         self.save_iters = save_iters
+         self.start_iters = start_iters
+         self.checkpoint_keep_num = checkpoint_keep_num
+         self.metrics_root = 'metrics/'  # rewrite
+         self.timer = Timer()
+         self.host = self.learn.rank is None or self.learn.rank == 0
+
+     def _write_metrics(self, iteration:int, names:List[str], last_metrics:MetricsList)->None:
+         "Writes training metrics to Tensorboard."
+         for i, name in enumerate(names):
+             if last_metrics is None or len(last_metrics) < i+1: return
+             scalar_value = last_metrics[i]
+             self._write_scalar(name=name, scalar_value=scalar_value, iteration=iteration)
+
+     def _write_sub_loss(self, iteration:int, last_losses:dict)->None:
+         "Writes sub losses to Tensorboard."
+         for name, loss in last_losses.items():
+             scalar_value = to_np(loss)
+             tag = self.metrics_root + name
+             self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=iteration)
+
+     def _save(self, name):
+         if isinstance(self.learn.model, DistributedDataParallel):
+             tmp = self.learn.model
+             self.learn.model = self.learn.model.module
+             self.learn.save(name)
+             self.learn.model = tmp
+         else: self.learn.save(name)
+
+     def _validate(self, dl=None, callbacks=None, metrics=None, keeped_items=False):
+         "Validate on `dl` with potential `callbacks` and `metrics`."
+         dl = ifnone(dl, self.learn.data.valid_dl)
+         metrics = ifnone(metrics, self.learn.metrics)
+         cb_handler = CallbackHandler(ifnone(callbacks, []), metrics)
+         cb_handler.on_train_begin(1, None, metrics); cb_handler.on_epoch_begin()
+         if keeped_items: cb_handler.state_dict.update(dict(keeped_items=[]))
+         val_metrics = validate(self.learn.model, dl, self.loss_func, cb_handler)
+         cb_handler.on_epoch_end(val_metrics)
+         if keeped_items: return cb_handler.state_dict['keeped_items']
+         else: return cb_handler.state_dict['last_metrics']
+
+     def jump_to_epoch_iter(self, epoch:int, iteration:int)->None:
+         try:
+             self.learn.load(f'{self.name}_{epoch}_{iteration}', purge=False)
+             logging.info(f'Loaded {self.name}_{epoch}_{iteration}')
+         except: logging.info(f'Model {self.name}_{epoch}_{iteration} not found.')
+
+     def on_train_begin(self, n_epochs, **kwargs):
+         # TODO: can not write graph here
+         # super().on_train_begin(**kwargs)
+         self.best = -float('inf')
+         self.timer.tic()
+         if self.host:
+             checkpoint_path = self.learn.path/'checkpoint.yaml'
+             if checkpoint_path.exists():
+                 os.remove(checkpoint_path)
+             open(checkpoint_path, 'w').close()
+         return {'skip_validate': True, 'iteration':self.start_iters}  # disable default validate
+
+     def on_batch_begin(self, **kwargs:Any)->None:
+         self.timer.toc_data()
+         super().on_batch_begin(**kwargs)
+
+     def on_batch_end(self, iteration, epoch, last_loss, smooth_loss, train, **kwargs):
+         super().on_batch_end(last_loss, iteration, train, **kwargs)
+         if iteration == 0: return
+
+         if iteration % self.loss_iters == 0:
+             last_losses = self.learn.loss_func.last_losses
+             self._write_sub_loss(iteration=iteration, last_losses=last_losses)
+             self.tbwriter.add_scalar(tag=self.metrics_root + 'lr',
+                                      scalar_value=self.opt.lr, global_step=iteration)
+
+         if iteration % self.show_iters == 0:
+             log_str = f'epoch {epoch} iter {iteration}: loss = {last_loss:6.4f}, ' \
+                       f'smooth loss = {smooth_loss:6.4f}'
+             logging.info(log_str)
+             # log_str = f'data time = {self.timer.data_diff:.4f}s, running time = {self.timer.running_diff:.4f}s'
+             # logging.info(log_str)
+
+         if iteration % self.eval_iters == 0:
+             # TODO: or move the timing to on_epoch_end
+             # 1. Record time
+             log_str = f'average data time = {self.timer.average_data_time():.4f}s, ' \
+                       f'average running time = {self.timer.average_running_time():.4f}s'
+             logging.info(log_str)
+
+             # 2. Call validate
+             last_metrics = self._validate()
+             self.learn.model.train()
+             log_str = f'epoch {epoch} iter {iteration}: eval loss = {last_metrics[0]:6.4f}, ' \
+                       f'ccr = {last_metrics[1]:6.4f}, cwr = {last_metrics[2]:6.4f}, ' \
+                       f'ted = {last_metrics[3]:6.4f}, ned = {last_metrics[4]:6.4f}, ' \
+                       f'ted/w = {last_metrics[5]:6.4f}, '
+             logging.info(log_str)
+             names = ['eval_loss', 'ccr', 'cwr', 'ted', 'ned', 'ted/w']
+             self._write_metrics(iteration, names, last_metrics)
+
+             # 3. Save best model
+             current = last_metrics[2]
+             if current is not None and current > self.best:
+                 logging.info(f'Better model found at epoch {epoch}, '\
+                              f'iter {iteration} with accuracy value: {current:6.4f}.')
+                 self.best = current
+                 self._save(f'{self.bestname}')
+
+         if iteration % self.save_iters == 0 and self.host:
+             logging.info(f'Save model {self.name}_{epoch}_{iteration}')
+             filename = f'{self.name}_{epoch}_{iteration}'
+             self._save(filename)
+
+             checkpoint_path = self.learn.path/'checkpoint.yaml'
+             if not checkpoint_path.exists():
+                 open(checkpoint_path, 'w').close()
+             with open(checkpoint_path, 'r') as file:
+                 checkpoints = yaml.load(file, Loader=yaml.FullLoader) or dict()
+             checkpoints['all_checkpoints'] = (
+                 checkpoints.get('all_checkpoints') or list())
+             checkpoints['all_checkpoints'].insert(0, filename)
+             if len(checkpoints['all_checkpoints']) > self.checkpoint_keep_num:
+                 removed_checkpoint = checkpoints['all_checkpoints'].pop()
+                 removed_checkpoint = self.learn.path/self.learn.model_dir/f'{removed_checkpoint}.pth'
+                 os.remove(removed_checkpoint)
+             checkpoints['current_checkpoint'] = filename
+             with open(checkpoint_path, 'w') as file:
+                 yaml.dump(checkpoints, file)
+
+         self.timer.toc_running()
+
+     def on_train_end(self, **kwargs):
+         #self.learn.load(f'{self.bestname}', purge=False)
+         pass
+
+     def on_epoch_end(self, last_metrics:MetricsList, iteration:int, **kwargs)->None:
+         self._write_embedding(iteration=iteration)
+
+
+ class TextAccuracy(Callback):
+     _names = ['ccr', 'cwr', 'ted', 'ned', 'ted/w']
+     def __init__(self, charset_path, max_length, case_sensitive, model_eval):
+         self.charset_path = charset_path
+         self.max_length = max_length
+         self.case_sensitive = case_sensitive
+         self.charset = CharsetMapper(charset_path, self.max_length)
+         self.names = self._names
+
+         self.model_eval = model_eval or 'alignment'
+         assert self.model_eval in ['vision', 'language', 'alignment']
+
+     def on_epoch_begin(self, **kwargs):
+         self.total_num_char = 0.
+         self.total_num_word = 0.
+         self.correct_num_char = 0.
+         self.correct_num_word = 0.
+         self.total_ed = 0.
+         self.total_ned = 0.
+
+     def _get_output(self, last_output):
+         if isinstance(last_output, (tuple, list)):
+             for res in last_output:
+                 if res['name'] == self.model_eval: output = res
+         else: output = last_output
+         return output
+
+     def _update_output(self, last_output, items):
+         if isinstance(last_output, (tuple, list)):
+             for res in last_output:
+                 if res['name'] == self.model_eval: res.update(items)
+         else: last_output.update(items)
+         return last_output
+
+     def on_batch_end(self, last_output, last_target, **kwargs):
+         output = self._get_output(last_output)
+         logits, pt_lengths = output['logits'], output['pt_lengths']
+         pt_text, pt_scores, pt_lengths_ = self.decode(logits)
+         assert (pt_lengths == pt_lengths_).all(), f'{pt_lengths} != {pt_lengths_} for {pt_text}'
+         last_output = self._update_output(last_output, {'pt_text':pt_text, 'pt_scores':pt_scores})
+
+         pt_text = [self.charset.trim(t) for t in pt_text]
+         label = last_target[0]
+         if label.dim() == 3: label = label.argmax(dim=-1)  # one-hot label
+         gt_text = [self.charset.get_text(l, trim=True) for l in label]
+
+         for i in range(len(gt_text)):
+             if not self.case_sensitive:
+                 gt_text[i], pt_text[i] = gt_text[i].lower(), pt_text[i].lower()
+             distance = ed.eval(gt_text[i], pt_text[i])
+             self.total_ed += distance
+             self.total_ned += float(distance) / max(len(gt_text[i]), 1)
+
+             if gt_text[i] == pt_text[i]:
+                 self.correct_num_word += 1
+             self.total_num_word += 1
+
+             for j in range(min(len(gt_text[i]), len(pt_text[i]))):
+                 if gt_text[i][j] == pt_text[i][j]:
+                     self.correct_num_char += 1
+             self.total_num_char += len(gt_text[i])
+
+         return {'last_output': last_output}
+
+     def on_epoch_end(self, last_metrics, **kwargs):
+         mets = [self.correct_num_char / self.total_num_char,
+                 self.correct_num_word / self.total_num_word,
+                 self.total_ed,
+                 self.total_ned,
+                 self.total_ed / self.total_num_word]
+         return add_metrics(last_metrics, mets)
+
+     def decode(self, logit):
+         """ Greedy decode """
+         # TODO: test running time and decode on GPU
+         out = F.softmax(logit, dim=2)
+         pt_text, pt_scores, pt_lengths = [], [], []
+         for o in out:
+             text = self.charset.get_text(o.argmax(dim=1), padding=False, trim=False)
+             text = text.split(self.charset.null_char)[0]  # end at end-token
+             pt_text.append(text)
+             pt_scores.append(o.max(dim=1)[0])
+             pt_lengths.append(min(len(text) + 1, self.max_length))  # one for end-token
+         pt_scores = torch.stack(pt_scores)
+         pt_lengths = pt_scores.new_tensor(pt_lengths, dtype=torch.long)
+         return pt_text, pt_scores, pt_lengths
+
+
+ class TopKTextAccuracy(TextAccuracy):
+     _names = ['ccr', 'cwr']
+     def __init__(self, k, charset_path, max_length, case_sensitive, model_eval):
+         self.k = k
+         self.charset_path = charset_path
+         self.max_length = max_length
+         self.case_sensitive = case_sensitive
+         self.charset = CharsetMapper(charset_path, self.max_length)
+         self.names = self._names
+
+     def on_epoch_begin(self, **kwargs):
+         self.total_num_char = 0.
+         self.total_num_word = 0.
+         self.correct_num_char = 0.
+         self.correct_num_word = 0.
+
+     def on_batch_end(self, last_output, last_target, **kwargs):
+         logits, pt_lengths = last_output['logits'], last_output['pt_lengths']
+         gt_labels, gt_lengths = last_target[:]
+
+         for logit, pt_length, label, length in zip(logits, pt_lengths, gt_labels, gt_lengths):
+             word_flag = True
+             for i in range(length):
+                 char_logit = logit[i].topk(self.k)[1]
+                 char_label = label[i].argmax(-1)
+                 if char_label in char_logit: self.correct_num_char += 1
+                 else: word_flag = False
+                 self.total_num_char += 1
+             if pt_length == length and word_flag:
+                 self.correct_num_word += 1
+             self.total_num_word += 1
+
+     def on_epoch_end(self, last_metrics, **kwargs):
+         mets = [self.correct_num_char / self.total_num_char,
+                 self.correct_num_word / self.total_num_word,
+                 0., 0., 0.]
+         return add_metrics(last_metrics, mets)
+
+
+ class DumpPrediction(LearnerCallback):
+
+     def __init__(self, learn, dataset, charset_path, model_eval, image_only=False, debug=False):
+         super().__init__(learn=learn)
+         self.debug = debug
+         self.model_eval = model_eval or 'alignment'
+         self.image_only = image_only
+         assert self.model_eval in ['vision', 'language', 'alignment']
+
+         self.dataset, self.root = dataset, Path(self.learn.path)/f'{dataset}-{self.model_eval}'
+         self.attn_root = self.root/'attn'
+         self.charset = CharsetMapper(charset_path)
+         if self.root.exists(): shutil.rmtree(self.root)
+         self.root.mkdir(), self.attn_root.mkdir()
+
+         self.pil = transforms.ToPILImage()
+         self.tensor = transforms.ToTensor()
+         size = self.learn.data.img_h, self.learn.data.img_w
+         self.resize = transforms.Resize(size=size, interpolation=0)
+         self.c = 0
+
+     def on_batch_end(self, last_input, last_output, last_target, **kwargs):
+         if isinstance(last_output, (tuple, list)):
+             for res in last_output:
+                 if res['name'] == self.model_eval: pt_text = res['pt_text']
+                 if res['name'] == 'vision': attn_scores = res['attn_scores'].detach().cpu()
+                 if res['name'] == self.model_eval: logits = res['logits']
+         else:
+             pt_text = last_output['pt_text']
+             attn_scores = last_output['attn_scores'].detach().cpu()
+             logits = last_output['logits']
+
+         images = last_input[0] if isinstance(last_input, (tuple, list)) else last_input
+         images = images.detach().cpu()
+         pt_text = [self.charset.trim(t) for t in pt_text]
+         gt_label = last_target[0]
+         if gt_label.dim() == 3: gt_label = gt_label.argmax(dim=-1)  # one-hot label
+         gt_text = [self.charset.get_text(l, trim=True) for l in gt_label]
+
+         prediction, false_prediction = [], []
+         for gt, pt, image, attn, logit in zip(gt_text, pt_text, images, attn_scores, logits):
+             prediction.append(f'{gt}\t{pt}\n')
+             if gt != pt:
+                 if self.debug:
+                     scores = torch.softmax(logit, dim=-1)[:max(len(pt), len(gt)) + 1]
+                     logging.info(f'{self.c} gt {gt}, pt {pt}, logit {logit.shape}, scores {scores.topk(5, dim=-1)}')
+                 false_prediction.append(f'{gt}\t{pt}\n')
+
+             image = self.learn.data.denorm(image)
+             if not self.image_only:
+                 image_np = np.array(self.pil(image))
+                 attn_pil = [self.pil(a) for a in attn[:, None, :, :]]
+                 attn = [self.tensor(self.resize(a)).repeat(3, 1, 1) for a in attn_pil]
+                 attn_sum = np.array([np.array(a) for a in attn_pil[:len(pt)]]).sum(axis=0)
+                 blended_sum = self.tensor(blend_mask(image_np, attn_sum))
+                 blended = [self.tensor(blend_mask(image_np, np.array(a))) for a in attn_pil]
+                 save_image = torch.stack([image] + attn + [blended_sum] + blended)
+                 save_image = save_image.view(2, -1, *save_image.shape[1:])
+                 save_image = save_image.permute(1, 0, 2, 3, 4).flatten(0, 1)
+                 vutils.save_image(save_image, self.attn_root/f'{self.c}_{gt}_{pt}.jpg',
+                                   nrow=2, normalize=True, scale_each=True)
+             else:
+                 self.pil(image).save(self.attn_root/f'{self.c}_{gt}_{pt}.jpg')
+             self.c += 1
+
+         with open(self.root/f'{self.model_eval}.txt', 'a') as f: f.writelines(prediction)
+         with open(self.root/f'{self.model_eval}-false.txt', 'a') as f: f.writelines(false_prediction)
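+
+ # Reading the TextAccuracy metrics (a summary derived from on_batch_end above):
+ #   ccr   - per-position character matches / total ground-truth characters
+ #   cwr   - exactly matched words / total words (used to select the best model)
+ #   ted   - summed edit distance, e.g. ed.eval('hello', 'helo') == 1
+ #   ned   - summed edit distance, each normalized by its ground-truth length
+ #   ted/w - total edit distance divided by the number of words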
commands.sh ADDED
@@ -0,0 +1,3 @@
+ python tools/create_lmdb_dataset.py
+ python demo.py --config=configs/train_abinet.yaml --input=base/
+ python main.py --config=configs/train_abinet.yaml
config.yaml ADDED
@@ -0,0 +1,71 @@
+ global:
+   name: train-abinet
+   phase: train
+   stage: train-super
+   workdir: workdir
+   seed: ~
+
+ dataset:
+   train: {
+     roots: [
+       'output_tbell_dataset2/',
+       # 'data/training/MJ/MJ_train/',
+       # 'data/training/MJ/MJ_test/',
+       # 'data/training/MJ/MJ_valid/',
+       # 'data/training/ST'
+     ],
+     batch_size: 20
+   }
+   test: {
+     roots: [
+       'output_tbell_dataset2/'
+     ],
+     batch_size: 2
+   }
+   data_aug: True
+   multiscales: False
+   num_workers: 1
+
+ training:
+   epochs: 100000
+   show_iters: 50
+   eval_iters: 600
+   # save_iters: 3000
+
+ optimizer:
+   type: AdamW
+   true_wd: False
+   wd: 0.0
+   bn_wd: False
+   clip_grad: 20
+   lr: 0.0001
+   args: {
+     betas: !!python/tuple [0.9, 0.999], # for default Adam
+   }
+   scheduler: {
+     periods: [6, 4],
+     gamma: 0.1,
+   }
+
+ model:
+   name: 'modules.model_abinet_iter.ABINetIterModel'
+   iter_size: 3
+   ensemble: ''
+   use_vision: True
+   vision: {
+     checkpoint: workdir/pretrain-vision-model/best-pretrain-vision-model.pth,
+     loss_weight: 1.,
+     attention: 'position',
+     backbone: 'transformer',
+     backbone_ln: 3,
+   }
+   language: {
+     checkpoint: workdir/pretrain-language-model/pretrain-language-model.pth,
+     num_layers: 4,
+     loss_weight: 1.,
+     detach: True,
+     use_self_attn: False
+   }
+   alignment: {
+     loss_weight: 1.,
+   }
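+
+ # Note (an assumption; utils.Config's implementation is not shown in this diff):
+ # main.py and app.py read these settings as underscore-flattened keys, e.g.
+ # config.dataset_train_batch_size or config.model_vision_checkpoint, which
+ # implies Config flattens the nested sections above.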
dataset.py ADDED
@@ -0,0 +1,278 @@
+ import logging
+ import re
+
+ import cv2
+ import lmdb
+ import six
+ from fastai.vision import *
+ from torchvision import transforms
+
+ from transforms import CVColorJitter, CVDeterioration, CVGeometry
+ from utils import CharsetMapper, onehot
+
+
+ class ImageDataset(Dataset):
+     "`ImageDataset` reads data from an LMDB database."
+
+     def __init__(self,
+                  path:PathOrStr,
+                  is_training:bool=True,
+                  img_h:int=32,
+                  img_w:int=100,
+                  max_length:int=25,
+                  check_length:bool=True,
+                  case_sensitive:bool=False,
+                  charset_path:str='data/charset_36.txt',
+                  convert_mode:str='RGB',
+                  data_aug:bool=True,
+                  deteriorate_ratio:float=0.,
+                  multiscales:bool=True,
+                  one_hot_y:bool=True,
+                  return_idx:bool=False,
+                  return_raw:bool=False,
+                  **kwargs):
+         self.path, self.name = Path(path), Path(path).name
+         assert self.path.is_dir() and self.path.exists(), f"{path} is not a valid directory."
+         self.convert_mode, self.check_length = convert_mode, check_length
+         self.img_h, self.img_w = img_h, img_w
+         self.max_length, self.one_hot_y = max_length, one_hot_y
+         self.return_idx, self.return_raw = return_idx, return_raw
+         self.case_sensitive, self.is_training = case_sensitive, is_training
+         self.data_aug, self.multiscales = data_aug, multiscales
+         self.charset = CharsetMapper(charset_path, max_length=max_length+1)
+         self.c = self.charset.num_classes
+
+         self.env = lmdb.open(str(path), readonly=True, lock=False, readahead=False, meminit=False)
+         assert self.env, f'Cannot open LMDB dataset from {path}.'
+         with self.env.begin(write=False) as txn:
+             self.length = int(txn.get('num-samples'.encode()))
+
+         if self.is_training and self.data_aug:
+             self.augment_tfs = transforms.Compose([
+                 CVGeometry(degrees=45, translate=(0.0, 0.0), scale=(0.5, 2.), shear=(45, 15), distortion=0.5, p=0.5),
+                 CVDeterioration(var=20, degrees=6, factor=4, p=0.25),
+                 CVColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, p=0.25)
+             ])
+         self.totensor = transforms.ToTensor()
+
+     def __len__(self): return self.length
+
+     def _next_image(self, index):
+         next_index = random.randint(0, len(self) - 1)
+         return self.get(next_index)
+
+     def _check_image(self, x, pixels=6):
+         if x.size[0] <= pixels or x.size[1] <= pixels: return False
+         else: return True
+
+     def resize_multiscales(self, img, borderType=cv2.BORDER_CONSTANT):
+         def _resize_ratio(img, ratio, fix_h=True):
+             if ratio * self.img_w < self.img_h:
+                 if fix_h: trg_h = self.img_h
+                 else: trg_h = int(ratio * self.img_w)
+                 trg_w = self.img_w
+             else: trg_h, trg_w = self.img_h, int(self.img_h / ratio)
+             img = cv2.resize(img, (trg_w, trg_h))
+             pad_h, pad_w = (self.img_h - trg_h) / 2, (self.img_w - trg_w) / 2
+             top, bottom = math.ceil(pad_h), math.floor(pad_h)
+             left, right = math.ceil(pad_w), math.floor(pad_w)
+             img = cv2.copyMakeBorder(img, top, bottom, left, right, borderType)
+             return img
+
+         if self.is_training:
+             if random.random() < 0.5:
+                 base, maxh, maxw = self.img_h, self.img_h, self.img_w
+                 h, w = random.randint(base, maxh), random.randint(base, maxw)
+                 return _resize_ratio(img, h/w)
+             else: return _resize_ratio(img, img.shape[0] / img.shape[1])  # keep aspect ratio
+         else: return _resize_ratio(img, img.shape[0] / img.shape[1])  # keep aspect ratio
+
+     def resize(self, img):
+         if self.multiscales: return self.resize_multiscales(img, cv2.BORDER_REPLICATE)
+         else: return cv2.resize(img, (self.img_w, self.img_h))
+
+     def get(self, idx):
+         with self.env.begin(write=False) as txn:
+             image_key, label_key = f'image-{idx+1:09d}', f'label-{idx+1:09d}'
+             try:
+                 label = str(txn.get(label_key.encode()), 'utf-8')  # label
+                 label = re.sub('[^0-9a-zA-Z]+', '', label)
+                 if self.check_length and self.max_length > 0:
+                     if len(label) > self.max_length or len(label) <= 0:
+                         #logging.info(f'Long or short text image is found: {self.name}, {idx}, {label}, {len(label)}')
+                         return self._next_image(idx)
+                 label = label[:self.max_length]
+
+                 imgbuf = txn.get(image_key.encode())  # image
+                 buf = six.BytesIO()
+                 buf.write(imgbuf)
+                 buf.seek(0)
+                 with warnings.catch_warnings():
+                     warnings.simplefilter("ignore", UserWarning)  # EXIF warning from TiffPlugin
+                     image = PIL.Image.open(buf).convert(self.convert_mode)
+                 if self.is_training and not self._check_image(image):
+                     #logging.info(f'Invalid image is found: {self.name}, {idx}, {label}, {len(label)}')
+                     return self._next_image(idx)
+             except:
+                 import traceback
+                 traceback.print_exc()
+                 logging.info(f'Corrupted image is found: {self.name}, {idx}, {label}, {len(label)}')
+                 return self._next_image(idx)
+         return image, label, idx
+
+     def _process_training(self, image):
+         if self.data_aug: image = self.augment_tfs(image)
+         image = self.resize(np.array(image))
+         return image
+
+     def _process_test(self, image):
+         return self.resize(np.array(image))  # TODO: move is_training to here
+
+     def __getitem__(self, idx):
+         image, text, idx_new = self.get(idx)
+         if not self.is_training: assert idx == idx_new, f'idx {idx} != idx_new {idx_new} during testing.'
+
+         if self.is_training: image = self._process_training(image)
+         else: image = self._process_test(image)
+         if self.return_raw: return image, text
+         image = self.totensor(image)
+
+         length = tensor(len(text) + 1).to(dtype=torch.long)  # one for end token
+         label = self.charset.get_labels(text, case_sensitive=self.case_sensitive)
+         label = tensor(label).to(dtype=torch.long)
+         if self.one_hot_y: label = onehot(label, self.charset.num_classes)
+
+         if self.return_idx: y = [label, length, idx_new]
+         else: y = [label, length]
+         return image, y
+
+
+ class TextDataset(Dataset):
+     def __init__(self,
+                  path:PathOrStr,
+                  delimiter:str='\t',
+                  max_length:int=25,
+                  charset_path:str='data/charset_36.txt',
+                  case_sensitive=False,
+                  one_hot_x=True,
+                  one_hot_y=True,
+                  is_training=True,
+                  smooth_label=False,
+                  smooth_factor=0.2,
+                  use_sm=False,
+                  **kwargs):
+         self.path = Path(path)
+         self.case_sensitive, self.use_sm = case_sensitive, use_sm
+         self.smooth_factor, self.smooth_label = smooth_factor, smooth_label
+         self.charset = CharsetMapper(charset_path, max_length=max_length+1)
+         self.one_hot_x, self.one_hot_y, self.is_training = one_hot_x, one_hot_y, is_training
+         if self.is_training and self.use_sm: self.sm = SpellingMutation(charset=self.charset)
+
+         dtype = {'inp': str, 'gt': str}
+         self.df = pd.read_csv(self.path, dtype=dtype, delimiter=delimiter, na_filter=False)
+         self.inp_col, self.gt_col = 0, 1
+
+     def __len__(self): return len(self.df)
+
+     def __getitem__(self, idx):
+         text_x = self.df.iloc[idx, self.inp_col]
+         text_x = re.sub('[^0-9a-zA-Z]+', '', text_x)
+         if not self.case_sensitive: text_x = text_x.lower()
+         if self.is_training and self.use_sm: text_x = self.sm(text_x)
+
+         length_x = tensor(len(text_x) + 1).to(dtype=torch.long)  # one for end token
+         label_x = self.charset.get_labels(text_x, case_sensitive=self.case_sensitive)
+         label_x = tensor(label_x)
+         if self.one_hot_x:
+             label_x = onehot(label_x, self.charset.num_classes)
+             if self.is_training and self.smooth_label:
+                 label_x = torch.stack([self.prob_smooth_label(l) for l in label_x])
+         x = [label_x, length_x]
+
+         text_y = self.df.iloc[idx, self.gt_col]
+         text_y = re.sub('[^0-9a-zA-Z]+', '', text_y)
+         if not self.case_sensitive: text_y = text_y.lower()
+         length_y = tensor(len(text_y) + 1).to(dtype=torch.long)  # one for end token
+         label_y = self.charset.get_labels(text_y, case_sensitive=self.case_sensitive)
+         label_y = tensor(label_y)
+         if self.one_hot_y: label_y = onehot(label_y, self.charset.num_classes)
+         y = [label_y, length_y]
+
+         return x, y
+
+     def prob_smooth_label(self, one_hot):
+         one_hot = one_hot.float()
+         delta = torch.rand([]) * self.smooth_factor
+         num_classes = len(one_hot)
+         noise = torch.rand(num_classes)
+         noise = noise / noise.sum() * delta
+         one_hot = one_hot * (1 - delta) + noise
+         return one_hot
+
+
+ class SpellingMutation(object):
+     def __init__(self, pn0=0.7, pn1=0.85, pn2=0.95, pt0=0.7, pt1=0.85, charset=None):
+         """
+         Args:
+             pn0: the prob of not modifying any character is (pn0)
+             pn1: the prob of modifying one character is (pn1 - pn0)
+             pn2: the prob of modifying two characters is (pn2 - pn1),
+                 and three is (1 - pn2)
+             pt0: the prob of the replacing operation is pt0.
+             pt1: the prob of the inserting operation is (pt1 - pt0),
+                 and of the deleting operation is (1 - pt1)
+         """
+         super().__init__()
+         self.pn0, self.pn1, self.pn2 = pn0, pn1, pn2
+         self.pt0, self.pt1 = pt0, pt1
+         self.charset = charset
+         logging.info(f'the probs: pn0={self.pn0}, pn1={self.pn1} ' +
+                      f'pn2={self.pn2}, pt0={self.pt0}, pt1={self.pt1}')
+
+     def is_digit(self, text, ratio=0.5):
+         length = max(len(text), 1)
+         digit_num = sum([t in self.charset.digits for t in text])
+         if digit_num / length < ratio: return False
+         return True
+
+     def is_unk_char(self, char):
+         # return char == self.charset.unk_char
+         return (char not in self.charset.digits) and (char not in self.charset.alphabets)
+
+     def get_num_to_modify(self, length):
+         prob = random.random()
+         if prob < self.pn0: num_to_modify = 0
+         elif prob < self.pn1: num_to_modify = 1
+         elif prob < self.pn2: num_to_modify = 2
+         else: num_to_modify = 3
+
+         if length <= 1: num_to_modify = 0
+         elif length >= 2 and length <= 4: num_to_modify = min(num_to_modify, 1)
+         else: num_to_modify = min(num_to_modify, length // 2)  # smaller than length // 2
+         return num_to_modify
+
+     def __call__(self, text, debug=False):
+         if self.is_digit(text): return text
+         length = len(text)
+         num_to_modify = self.get_num_to_modify(length)
+         if num_to_modify <= 0: return text
+
+         chars = []
+         index = np.arange(0, length)
+         random.shuffle(index)
+         index = index[: num_to_modify]
+         if debug: self.index = index
+         for i, t in enumerate(text):
+             if i not in index: chars.append(t)
+             elif self.is_unk_char(t): chars.append(t)
+             else:
+                 prob = random.random()
+                 if prob < self.pt0:  # replace
+                     chars.append(random.choice(self.charset.alphabets))
+                 elif prob < self.pt1:  # insert
+                     chars.append(random.choice(self.charset.alphabets))
+                     chars.append(t)
+                 else:  # delete
+                     continue
+         new_text = ''.join(chars[: self.charset.max_length-1])
+         return new_text if len(new_text) >= 1 else text
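+
+ # Hedged usage sketch: SpellingMutation only touches charset.digits,
+ # charset.alphabets and charset.max_length, so a toy stub is enough to try it:
+ #     class ToyCharset:  # hypothetical stand-in for CharsetMapper
+ #         digits = '0123456789'
+ #         alphabets = 'abcdefghijklmnopqrstuvwxyz'
+ #         max_length = 26
+ #     SpellingMutation(charset=ToyCharset())('hello')  # unchanged with prob pn0=0.7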
losses.py ADDED
@@ -0,0 +1,72 @@
+ from fastai.vision import *
+
+ from modules.model import Model
+
+
+ class MultiLosses(nn.Module):
+     def __init__(self, one_hot=True):
+         super().__init__()
+         self.ce = SoftCrossEntropyLoss() if one_hot else torch.nn.CrossEntropyLoss()
+         self.bce = torch.nn.BCELoss()
+
+     @property
+     def last_losses(self):
+         return self.losses
+
+     def _flatten(self, sources, lengths):
+         return torch.cat([t[:l] for t, l in zip(sources, lengths)])
+
+     def _merge_list(self, all_res):
+         if not isinstance(all_res, (list, tuple)):
+             return all_res
+         def merge(items):
+             if isinstance(items[0], torch.Tensor): return torch.cat(items, dim=0)
+             else: return items[0]
+         res = dict()
+         for key in all_res[0].keys():
+             items = [r[key] for r in all_res]
+             res[key] = merge(items)
+         return res
+
+     def _ce_loss(self, output, gt_labels, gt_lengths, idx=None, record=True):
+         loss_name = output.get('name')
+         pt_logits, weight = output['logits'], output['loss_weight']
+
+         assert pt_logits.shape[0] % gt_labels.shape[0] == 0
+         iter_size = pt_logits.shape[0] // gt_labels.shape[0]
+         if iter_size > 1:
+             gt_labels = gt_labels.repeat(iter_size, 1, 1)
+             gt_lengths = gt_lengths.repeat(iter_size)
+         flat_gt_labels = self._flatten(gt_labels, gt_lengths)
+         flat_pt_logits = self._flatten(pt_logits, gt_lengths)
+
+         nll = output.get('nll')
+         if nll is not None:
+             loss = self.ce(flat_pt_logits, flat_gt_labels, softmax=False) * weight
+         else:
+             loss = self.ce(flat_pt_logits, flat_gt_labels) * weight
+         if record and loss_name is not None: self.losses[f'{loss_name}_loss'] = loss
+
+         return loss
+
+     def forward(self, outputs, *args):
+         self.losses = {}
+         if isinstance(outputs, (tuple, list)):
+             outputs = [self._merge_list(o) for o in outputs]
+             return sum([self._ce_loss(o, *args) for o in outputs if o['loss_weight'] > 0.])
+         else:
+             return self._ce_loss(outputs, *args, record=False)
+
+
+ class SoftCrossEntropyLoss(nn.Module):
+     def __init__(self, reduction="mean"):
+         super().__init__()
+         self.reduction = reduction
+
+     def forward(self, input, target, softmax=True):
+         if softmax: log_prob = F.log_softmax(input, dim=-1)
+         else: log_prob = torch.log(input)
+         loss = -(target * log_prob).sum(dim=-1)
+         if self.reduction == "mean": return loss.mean()
+         elif self.reduction == "sum": return loss.sum()
+         else: return loss
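+
+ # Hedged sanity check: with one-hot targets, SoftCrossEntropyLoss reduces to the
+ # standard cross entropy, e.g.:
+ #     x = torch.randn(4, 37)
+ #     y = torch.eye(37)[torch.tensor([1, 5, 9, 0])]
+ #     torch.allclose(SoftCrossEntropyLoss()(x, y),
+ #                    F.cross_entropy(x, y.argmax(-1)))  # -> True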
main.py ADDED
@@ -0,0 +1,246 @@
1
+ import argparse
2
+ import logging
3
+ import os
4
+ import random
5
+
6
+ import torch
7
+ from fastai.callbacks.general_sched import GeneralScheduler, TrainingPhase
8
+ from fastai.distributed import *
9
+ from fastai.vision import *
10
+ from torch.backends import cudnn
11
+
12
+ from callbacks import DumpPrediction, IterationCallback, TextAccuracy, TopKTextAccuracy
13
+ from dataset import ImageDataset, TextDataset
14
+ from losses import MultiLosses
15
+ from utils import Config, Logger, MyDataParallel, MyConcatDataset
16
+
17
+
18
+ def _set_random_seed(seed):
19
+ if seed is not None:
20
+ random.seed(seed)
21
+ torch.manual_seed(seed)
22
+ cudnn.deterministic = True
23
+ logging.warning('You have chosen to seed training. '
24
+ 'This will slow down your training!')
25
+
26
+ def _get_training_phases(config, n):
27
+ lr = np.array(config.optimizer_lr)
28
+ periods = config.optimizer_scheduler_periods
29
+ sigma = [config.optimizer_scheduler_gamma ** i for i in range(len(periods))]
30
+ phases = [TrainingPhase(n * periods[i]).schedule_hp('lr', lr * sigma[i])
31
+ for i in range(len(periods))]
32
+ return phases
33
+
34
+ def _get_dataset(ds_type, paths, is_training, config, **kwargs):
35
+ kwargs.update({
36
+ 'img_h': config.dataset_image_height,
37
+ 'img_w': config.dataset_image_width,
38
+ 'max_length': config.dataset_max_length,
39
+ 'case_sensitive': config.dataset_case_sensitive,
40
+ 'charset_path': config.dataset_charset_path,
41
+ 'data_aug': config.dataset_data_aug,
42
+ 'deteriorate_ratio': config.dataset_deteriorate_ratio,
43
+ 'is_training': is_training,
44
+ 'multiscales': config.dataset_multiscales,
45
+ 'one_hot_y': config.dataset_one_hot_y,
46
+ })
47
+ datasets = [ds_type(p, **kwargs) for p in paths]
48
+ if len(datasets) > 1: return MyConcatDataset(datasets)
49
+ else: return datasets[0]
50
+
51
+
52
+ def _get_language_databaunch(config):
53
+ kwargs = {
54
+ 'max_length': config.dataset_max_length,
55
+ 'case_sensitive': config.dataset_case_sensitive,
56
+ 'charset_path': config.dataset_charset_path,
57
+ 'smooth_label': config.dataset_smooth_label,
58
+ 'smooth_factor': config.dataset_smooth_factor,
59
+ 'one_hot_y': config.dataset_one_hot_y,
60
+ 'use_sm': config.dataset_use_sm,
61
+ }
62
+ train_ds = TextDataset(config.dataset_train_roots[0], is_training=True, **kwargs)
63
+ valid_ds = TextDataset(config.dataset_test_roots[0], is_training=False, **kwargs)
64
+ data = DataBunch.create(
65
+ path=train_ds.path,
66
+ train_ds=train_ds,
67
+ valid_ds=valid_ds,
68
+ bs=config.dataset_train_batch_size,
69
+ val_bs=config.dataset_test_batch_size,
70
+ num_workers=config.dataset_num_workers,
71
+ pin_memory=config.dataset_pin_memory)
72
+ logging.info(f'{len(data.train_ds)} training items found.')
73
+ if not data.empty_val:
74
+ logging.info(f'{len(data.valid_ds)} valid items found.')
75
+ return data
76
+
77
+ def _get_databaunch(config):
78
+ # An awkward way to reduce loadding data time during test
79
+ if config.global_phase == 'test': config.dataset_train_roots = config.dataset_test_roots
80
+ train_ds = _get_dataset(ImageDataset, config.dataset_train_roots, True, config)
81
+ valid_ds = _get_dataset(ImageDataset, config.dataset_test_roots, False, config)
82
+ data = ImageDataBunch.create(
83
+ train_ds=train_ds,
84
+ valid_ds=valid_ds,
85
+ bs=config.dataset_train_batch_size,
86
+ val_bs=config.dataset_test_batch_size,
87
+ num_workers=config.dataset_num_workers,
88
+ pin_memory=config.dataset_pin_memory).normalize(imagenet_stats)
89
+ ar_tfm = lambda x: ((x[0], x[1]), x[1]) # auto-regression only for dtd
90
+ data.add_tfm(ar_tfm)
91
+
92
+ logging.info(f'{len(data.train_ds)} training items found.')
93
+ if not data.empty_val:
94
+ logging.info(f'{len(data.valid_ds)} valid items found.')
95
+
96
+ return data
97
+
98
+ def _get_model(config):
99
+ import importlib
100
+ names = config.model_name.split('.')
101
+ module_name, class_name = '.'.join(names[:-1]), names[-1]
102
+ cls = getattr(importlib.import_module(module_name), class_name)
103
+ model = cls(config)
104
+ logging.info(model)
105
+ return model
106
+
107
+
108
+ def _get_learner(config, data, model, local_rank=None):
109
+ strict = ifnone(config.model_strict, True)
110
+ if config.global_stage == 'pretrain-language':
111
+ metrics = [TopKTextAccuracy(
112
+ k=ifnone(config.model_k, 5),
113
+ charset_path=config.dataset_charset_path,
114
+ max_length=config.dataset_max_length + 1,
115
+ case_sensitive=config.dataset_eval_case_sensisitves,
116
+ model_eval=config.model_eval)]
117
+ else:
118
+ metrics = [TextAccuracy(
119
+ charset_path=config.dataset_charset_path,
120
+ max_length=config.dataset_max_length + 1,
121
+ case_sensitive=config.dataset_eval_case_sensisitves,
122
+ model_eval=config.model_eval)]
123
+ opt_type = getattr(torch.optim, config.optimizer_type)
124
+ learner = Learner(data, model, silent=True, model_dir='.',
125
+ true_wd=config.optimizer_true_wd,
126
+ wd=config.optimizer_wd,
127
+ bn_wd=config.optimizer_bn_wd,
128
+ path=config.global_workdir,
129
+ metrics=metrics,
130
+ opt_func=partial(opt_type, **config.optimizer_args or dict()),
131
+ loss_func=MultiLosses(one_hot=config.dataset_one_hot_y))
132
+ learner.split(lambda m: children(m))
133
+
134
+ if config.global_phase == 'train':
135
+ num_replicas = 1 if local_rank is None else torch.distributed.get_world_size()
136
+ phases = _get_training_phases(config, len(learner.data.train_dl)//num_replicas)
137
+ learner.callback_fns += [
138
+ partial(GeneralScheduler, phases=phases),
139
+ partial(GradientClipping, clip=config.optimizer_clip_grad),
140
+ partial(IterationCallback, name=config.global_name,
141
+ show_iters=config.training_show_iters,
142
+ eval_iters=config.training_eval_iters,
143
+ save_iters=config.training_save_iters,
144
+ start_iters=config.training_start_iters,
145
+ stats_iters=config.training_stats_iters)]
146
+ else:
147
+ learner.callbacks += [
148
+ DumpPrediction(learn=learner,
149
+ dataset='-'.join([Path(p).name for p in config.dataset_test_roots]),charset_path=config.dataset_charset_path,
150
+ model_eval=config.model_eval,
151
+ debug=config.global_debug,
152
+ image_only=config.global_image_only)]
153
+
154
+ learner.rank = local_rank
155
+ if local_rank is not None:
156
+ logging.info(f'Set model to distributed with rank {local_rank}.')
157
+ learner.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(learner.model)
158
+ learner.model.to(local_rank)
159
+ learner = learner.to_distributed(local_rank)
160
+
161
+ if torch.cuda.device_count() > 1 and local_rank is None:
162
+ logging.info(f'Use {torch.cuda.device_count()} GPUs.')
163
+ learner.model = MyDataParallel(learner.model)
164
+
165
+     if config.model_checkpoint:
+         if Path(config.model_checkpoint).exists():
+             with open(config.model_checkpoint, 'rb') as f:
+                 buffer = io.BytesIO(f.read())
+             learner.load(buffer, strict=strict)
+         else:
+             from distutils.dir_util import copy_tree
+             src = Path('/data/fangsc/model')/config.global_name
+             trg = Path('/output')/config.global_name
+             if src.exists(): copy_tree(str(src), str(trg))
+             learner.load(config.model_checkpoint, strict=strict)
+         logging.info(f'Read model from {config.model_checkpoint}')
+     elif config.global_phase == 'test':
+         learner.load(f'best-{config.global_name}', strict=strict)
+         logging.info(f'Read model from best-{config.global_name}')
+ 
+     if learner.opt_func.func.__name__ == 'Adadelta':  # fastai bug, fixed after 1.0.60
+         learner.fit(epochs=0, lr=config.optimizer_lr)
+         learner.opt.mom = 0.
+ 
+     return learner
+ 
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--config', type=str, required=True,
+                         help='path to config file')
+     parser.add_argument('--phase', type=str, default=None, choices=['train', 'test'])
+     parser.add_argument('--name', type=str, default=None)
+     parser.add_argument('--checkpoint', type=str, default=None)
+     parser.add_argument('--test_root', type=str, default=None)
+     parser.add_argument('--local_rank', type=int, default=None)
+     parser.add_argument('--debug', action='store_true', default=None)
+     parser.add_argument('--image_only', action='store_true', default=None)
+     parser.add_argument('--model_strict', action='store_false', default=None)
+     parser.add_argument('--model_eval', type=str, default=None,
+                         choices=['alignment', 'vision', 'language'])
+     args = parser.parse_args()
+     config = Config(args.config)
+     if args.name is not None: config.global_name = args.name
+     if args.phase is not None: config.global_phase = args.phase
+     if args.test_root is not None: config.dataset_test_roots = [args.test_root]
+     if args.checkpoint is not None: config.model_checkpoint = args.checkpoint
+     if args.debug is not None: config.global_debug = args.debug
+     if args.image_only is not None: config.global_image_only = args.image_only
+     if args.model_eval is not None: config.model_eval = args.model_eval
+     if args.model_strict is not None: config.model_strict = args.model_strict
+ 
+     Logger.init(config.global_workdir, config.global_name, config.global_phase)
+     Logger.enable_file()
+     _set_random_seed(config.global_seed)
+     logging.info(config)
+ 
+     if args.local_rank is not None:
+         logging.info(f'Init distributed training at device {args.local_rank}.')
+         torch.cuda.set_device(args.local_rank)
+         torch.distributed.init_process_group(backend='nccl', init_method='env://')
+ 
+     logging.info('Construct dataset.')
+     if config.global_stage == 'pretrain-language': data = _get_language_databaunch(config)
+     else: data = _get_databaunch(config)
+ 
+     logging.info('Construct model.')
+     model = _get_model(config)
+ 
+     logging.info('Construct learner.')
+     learner = _get_learner(config, data, model, args.local_rank)
+ 
+     if config.global_phase == 'train':
+         logging.info('Start training.')
+         learner.fit(epochs=config.training_epochs,
+                     lr=config.optimizer_lr)
+     else:
+         logging.info('Start validation.')
+         last_metrics = learner.validate()
+         log_str = f'eval loss = {last_metrics[0]:6.3f}, ' \
+                   f'ccr = {last_metrics[1]:6.3f}, cwr = {last_metrics[2]:6.3f}, ' \
+                   f'ted = {last_metrics[3]:6.3f}, ned = {last_metrics[4]:6.0f}, ' \
+                   f'ted/w = {last_metrics[5]:6.3f}'
+         logging.info(log_str)
+ 
+ if __name__ == '__main__':
+     main()
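One subtlety in the flags above: `--model_strict` is declared with `action='store_false'`, so passing the flag *relaxes* strict checkpoint loading, and the `default=None` pattern lets the YAML value survive whenever a flag is omitted. A minimal standalone sketch of that behavior (illustrative, not part of the commit):

```python
import argparse

# Mirrors the store_false/default=None pattern used in main.py above.
parser = argparse.ArgumentParser()
parser.add_argument('--model_strict', action='store_false', default=None)

args = parser.parse_args([])                   # flag omitted -> None, YAML default wins
assert args.model_strict is None
args = parser.parse_args(['--model_strict'])   # flag passed -> False, relax loading
assert args.model_strict is False
```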
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ torch
+ torchvision
+ fastai==1.0.60
+ lmdb
+ Pillow
+ opencv-python
+ tensorboardX
+ PyYAML
+ gdown
transforms.py ADDED
@@ -0,0 +1,329 @@
+ import math
+ import numbers
+ import random
+ 
+ import cv2
+ import numpy as np
+ from PIL import Image
+ from torchvision import transforms
+ from torchvision.transforms import Compose
+ 
+ 
+ def sample_asym(magnitude, size=None):
+     return np.random.beta(1, 4, size) * magnitude
+ 
+ def sample_sym(magnitude, size=None):
+     return (np.random.beta(4, 4, size=size) - 0.5) * 2 * magnitude
+ 
+ def sample_uniform(low, high, size=None):
+     return np.random.uniform(low, high, size=size)
+ 
+ def get_interpolation(type='random'):
+     if type == 'random':
+         choice = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA]
+         interpolation = choice[random.randint(0, len(choice)-1)]
+     elif type == 'nearest': interpolation = cv2.INTER_NEAREST
+     elif type == 'linear': interpolation = cv2.INTER_LINEAR
+     elif type == 'cubic': interpolation = cv2.INTER_CUBIC
+     elif type == 'area': interpolation = cv2.INTER_AREA
+     else: raise TypeError('Only nearest, linear, cubic and area interpolation types are supported!')
+     return interpolation
+ 
+ class CVRandomRotation(object):
+     def __init__(self, degrees=15):
+         assert isinstance(degrees, numbers.Number), "degree should be a single number."
+         assert degrees >= 0, "degree must be non-negative."
+         self.degrees = degrees
+ 
+     @staticmethod
+     def get_params(degrees):
+         return sample_sym(degrees)
+ 
+     def __call__(self, img):
+         angle = self.get_params(self.degrees)
+         src_h, src_w = img.shape[:2]
+         M = cv2.getRotationMatrix2D(center=(src_w/2, src_h/2), angle=angle, scale=1.0)
+         abs_cos, abs_sin = abs(M[0,0]), abs(M[0,1])
+         dst_w = int(src_h * abs_sin + src_w * abs_cos)
+         dst_h = int(src_h * abs_cos + src_w * abs_sin)
+         M[0, 2] += (dst_w - src_w)/2
+         M[1, 2] += (dst_h - src_h)/2
+ 
+         flags = get_interpolation()
+         return cv2.warpAffine(img, M, (dst_w, dst_h), flags=flags, borderMode=cv2.BORDER_REPLICATE)
+ 
+ class CVRandomAffine(object):
+     def __init__(self, degrees, translate=None, scale=None, shear=None):
+         assert isinstance(degrees, numbers.Number), "degree should be a single number."
+         assert degrees >= 0, "degree must be non-negative."
+         self.degrees = degrees
+ 
+         if translate is not None:
+             assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
+                 "translate should be a list or tuple and it must be of length 2."
+             for t in translate:
+                 if not (0.0 <= t <= 1.0):
+                     raise ValueError("translation values should be between 0 and 1")
+         self.translate = translate
+ 
+         if scale is not None:
+             assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
+                 "scale should be a list or tuple and it must be of length 2."
+             for s in scale:
+                 if s <= 0:
+                     raise ValueError("scale values should be positive")
+         self.scale = scale
+ 
+         if shear is not None:
+             if isinstance(shear, numbers.Number):
+                 if shear < 0:
+                     raise ValueError("If shear is a single number, it must be positive.")
+                 self.shear = [shear]
+             else:
+                 assert isinstance(shear, (tuple, list)) and (len(shear) == 2), \
+                     "shear should be a list or tuple and it must be of length 2."
+                 self.shear = shear
+         else:
+             self.shear = shear
+ 
+     def _get_inverse_affine_matrix(self, center, angle, translate, scale, shear):
+         # https://github.com/pytorch/vision/blob/v0.4.0/torchvision/transforms/functional.py#L717
+         from numpy import sin, cos, tan
+ 
+         if isinstance(shear, numbers.Number):
+             shear = [shear, 0]
+ 
+         if not (isinstance(shear, (tuple, list)) and len(shear) == 2):
+             raise ValueError(
+                 "Shear should be a single value or a tuple/list containing " +
+                 "two values. Got {}".format(shear))
+ 
+         rot = math.radians(angle)
+         sx, sy = [math.radians(s) for s in shear]
+ 
+         cx, cy = center
+         tx, ty = translate
+ 
+         # RSS without scaling
+         a = cos(rot - sy) / cos(sy)
+         b = -cos(rot - sy) * tan(sx) / cos(sy) - sin(rot)
+         c = sin(rot - sy) / cos(sy)
+         d = -sin(rot - sy) * tan(sx) / cos(sy) + cos(rot)
+ 
+         # Inverted rotation matrix with scale and shear
+         # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
+         M = [d, -b, 0,
+              -c, a, 0]
+         M = [x / scale for x in M]
+ 
+         # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
+         M[2] += M[0] * (-cx - tx) + M[1] * (-cy - ty)
+         M[5] += M[3] * (-cx - tx) + M[4] * (-cy - ty)
+ 
+         # Apply center translation: C * RSS^-1 * C^-1 * T^-1
+         M[2] += cx
+         M[5] += cy
+         return M
+ 
+     @staticmethod
+     def get_params(degrees, translate, scale_ranges, shears, height):
+         angle = sample_sym(degrees)
+         if translate is not None:
+             max_dx = translate[0] * height
+             max_dy = translate[1] * height
+             translations = (np.round(sample_sym(max_dx)), np.round(sample_sym(max_dy)))
+         else:
+             translations = (0, 0)
+ 
+         if scale_ranges is not None:
+             scale = sample_uniform(scale_ranges[0], scale_ranges[1])
+         else:
+             scale = 1.0
+ 
+         if shears is not None:
+             if len(shears) == 1:
+                 shear = [sample_sym(shears[0]), 0.]
+             elif len(shears) == 2:
+                 shear = [sample_sym(shears[0]), sample_sym(shears[1])]
+         else:
+             shear = 0.0
+ 
+         return angle, translations, scale, shear
+ 
+     def __call__(self, img):
+         src_h, src_w = img.shape[:2]
+         angle, translate, scale, shear = self.get_params(
+             self.degrees, self.translate, self.scale, self.shear, src_h)
+ 
+         M = self._get_inverse_affine_matrix((src_w/2, src_h/2), angle, (0, 0), scale, shear)
+         M = np.array(M).reshape(2,3)
+ 
+         startpoints = [(0, 0), (src_w - 1, 0), (src_w - 1, src_h - 1), (0, src_h - 1)]
+         project = lambda x, y, a, b, c: int(a*x + b*y + c)
+         endpoints = [(project(x, y, *M[0]), project(x, y, *M[1])) for x, y in startpoints]
+ 
+         rect = cv2.minAreaRect(np.array(endpoints))
+         bbox = cv2.boxPoints(rect).astype(int)
+         max_x, max_y = bbox[:, 0].max(), bbox[:, 1].max()
+         min_x, min_y = bbox[:, 0].min(), bbox[:, 1].min()
+ 
+         dst_w = int(max_x - min_x)
+         dst_h = int(max_y - min_y)
+         M[0, 2] += (dst_w - src_w) / 2
+         M[1, 2] += (dst_h - src_h) / 2
+ 
+         # add translate
+         dst_w += int(abs(translate[0]))
+         dst_h += int(abs(translate[1]))
+         if translate[0] < 0: M[0, 2] += abs(translate[0])
+         if translate[1] < 0: M[1, 2] += abs(translate[1])
+ 
+         flags = get_interpolation()
+         return cv2.warpAffine(img, M, (dst_w, dst_h), flags=flags, borderMode=cv2.BORDER_REPLICATE)
+ 
+ class CVRandomPerspective(object):
+     def __init__(self, distortion=0.5):
+         self.distortion = distortion
+ 
+     def get_params(self, width, height, distortion):
+         offset_h = sample_asym(distortion * height / 2, size=4).astype(int)
+         offset_w = sample_asym(distortion * width / 2, size=4).astype(int)
+         topleft = (offset_w[0], offset_h[0])
+         topright = (width - 1 - offset_w[1], offset_h[1])
+         botright = (width - 1 - offset_w[2], height - 1 - offset_h[2])
+         botleft = (offset_w[3], height - 1 - offset_h[3])
+ 
+         startpoints = [(0, 0), (width - 1, 0), (width - 1, height - 1), (0, height - 1)]
+         endpoints = [topleft, topright, botright, botleft]
+         return np.array(startpoints, dtype=np.float32), np.array(endpoints, dtype=np.float32)
+ 
+     def __call__(self, img):
+         height, width = img.shape[:2]
+         startpoints, endpoints = self.get_params(width, height, self.distortion)
+         M = cv2.getPerspectiveTransform(startpoints, endpoints)
+ 
+         # TODO: more robust way to crop image
+         rect = cv2.minAreaRect(endpoints)
+         bbox = cv2.boxPoints(rect).astype(int)
+         max_x, max_y = bbox[:, 0].max(), bbox[:, 1].max()
+         min_x, min_y = bbox[:, 0].min(), bbox[:, 1].min()
+         min_x, min_y = max(min_x, 0), max(min_y, 0)
+ 
+         flags = get_interpolation()
+         img = cv2.warpPerspective(img, M, (max_x, max_y), flags=flags, borderMode=cv2.BORDER_REPLICATE)
+         img = img[min_y:, min_x:]
+         return img
+ 
+ class CVRescale(object):
+     def __init__(self, factor=4, base_size=(128, 512)):
+         """Define image scales with a Gaussian pyramid and rescale the image to the sampled scale.
+ 
+         Args:
+             factor: number of pyramid-down steps from the base size; a single
+                 number samples the step count uniformly from [0, factor].
+             base_size: base size used to build the bottom layer of the pyramid.
+         """
+         if isinstance(factor, numbers.Number):
+             self.factor = round(sample_uniform(0, factor))
+         elif isinstance(factor, (tuple, list)) and len(factor) == 2:
+             self.factor = round(sample_uniform(factor[0], factor[1]))
+         else:
+             raise Exception('factor must be number or list with length 2')
+         # assert factor is valid
+         self.base_h, self.base_w = base_size[:2]
+ 
+     def __call__(self, img):
+         if self.factor == 0: return img
+         src_h, src_w = img.shape[:2]
+         cur_w, cur_h = self.base_w, self.base_h
+         scale_img = cv2.resize(img, (cur_w, cur_h), interpolation=get_interpolation())
+         for _ in range(self.factor):
+             scale_img = cv2.pyrDown(scale_img)
+         scale_img = cv2.resize(scale_img, (src_w, src_h), interpolation=get_interpolation())
+         return scale_img
+ 
+ class CVGaussianNoise(object):
+     def __init__(self, mean=0, var=20):
+         self.mean = mean
+         if isinstance(var, numbers.Number):
+             self.var = max(int(sample_asym(var)), 1)
+         elif isinstance(var, (tuple, list)) and len(var) == 2:
+             self.var = int(sample_uniform(var[0], var[1]))
+         else:
+             raise Exception('var must be number or list with length 2')
+ 
+     def __call__(self, img):
+         noise = np.random.normal(self.mean, self.var**0.5, img.shape)
+         img = np.clip(img + noise, 0, 255).astype(np.uint8)
+         return img
+ 
+ class CVMotionBlur(object):
+     def __init__(self, degrees=12, angle=90):
+         if isinstance(degrees, numbers.Number):
+             self.degree = max(int(sample_asym(degrees)), 1)
+         elif isinstance(degrees, (tuple, list)) and len(degrees) == 2:
+             self.degree = int(sample_uniform(degrees[0], degrees[1]))
+         else:
+             raise Exception('degrees must be number or list with length 2')
+         self.angle = sample_uniform(-angle, angle)
+ 
+     def __call__(self, img):
+         M = cv2.getRotationMatrix2D((self.degree // 2, self.degree // 2), self.angle, 1)
+         motion_blur_kernel = np.zeros((self.degree, self.degree))
+         motion_blur_kernel[self.degree // 2, :] = 1
+         motion_blur_kernel = cv2.warpAffine(motion_blur_kernel, M, (self.degree, self.degree))
+         motion_blur_kernel = motion_blur_kernel / self.degree
+         img = cv2.filter2D(img, -1, motion_blur_kernel)
+         img = np.clip(img, 0, 255).astype(np.uint8)
+         return img
+ 
+ class CVGeometry(object):
+     def __init__(self, degrees=15, translate=(0.3, 0.3), scale=(0.5, 2.),
+                  shear=(45, 15), distortion=0.5, p=0.5):
+         self.p = p
+         type_p = random.random()
+         if type_p < 0.33:
+             self.transforms = CVRandomRotation(degrees=degrees)
+         elif type_p < 0.66:
+             self.transforms = CVRandomAffine(degrees=degrees, translate=translate, scale=scale, shear=shear)
+         else:
+             self.transforms = CVRandomPerspective(distortion=distortion)
+ 
+     def __call__(self, img):
+         if random.random() < self.p:
+             img = np.array(img)
+             return Image.fromarray(self.transforms(img))
+         else: return img
+ 
+ class CVDeterioration(object):
+     def __init__(self, var, degrees, factor, p=0.5):
+         self.p = p
+         transforms = []
+         if var is not None:
+             transforms.append(CVGaussianNoise(var=var))
+         if degrees is not None:
+             transforms.append(CVMotionBlur(degrees=degrees))
+         if factor is not None:
+             transforms.append(CVRescale(factor=factor))
+ 
+         random.shuffle(transforms)
+         transforms = Compose(transforms)
+         self.transforms = transforms
+ 
+     def __call__(self, img):
+         if random.random() < self.p:
+             img = np.array(img)
+             return Image.fromarray(self.transforms(img))
+         else: return img
+ 
+ 
+ class CVColorJitter(object):
+     def __init__(self, brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, p=0.5):
+         self.p = p
+         self.transforms = transforms.ColorJitter(brightness=brightness, contrast=contrast,
+                                                  saturation=saturation, hue=hue)
+ 
+     def __call__(self, img):
+         if random.random() < self.p: return self.transforms(img)
+         else: return img
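The three wrapper classes at the end (`CVGeometry`, `CVDeterioration`, `CVColorJitter`) each fire with probability `p` and hand NumPy arrays to the OpenCV transforms while keeping a PIL-in, PIL-out interface, so they can slot into a standard torchvision pipeline. A minimal usage sketch, assuming the definitions above are in scope (the parameter values and the random stand-in image are illustrative, not prescribed by the commit):

```python
import numpy as np
from PIL import Image
from torchvision.transforms import Compose

# stand-in for a real 32x128 text crop
img = Image.fromarray(np.random.randint(0, 256, (32, 128, 3), dtype=np.uint8))

augment = Compose([
    CVGeometry(degrees=45, translate=(0.0, 0.0), scale=(0.5, 2.0),
               shear=(45, 15), distortion=0.5, p=0.5),
    CVDeterioration(var=20, degrees=6, factor=4, p=0.25),
    CVColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, p=0.25),
])
out = augment(img)  # still a PIL Image, ready for ToTensor()
```

Note that `CVGeometry` picks one of rotation, affine, or perspective at construction time, and `CVDeterioration` samples its noise/blur/rescale parameters once per instance rather than per call.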
utils.py ADDED
@@ -0,0 +1,304 @@
+ import logging
+ import os
+ import time
+ 
+ import cv2
+ import numpy as np
+ import torch
+ import yaml
+ from matplotlib import colors
+ from matplotlib import pyplot as plt
+ from torch import Tensor, nn
+ from torch.utils.data import ConcatDataset
+ 
+ class CharsetMapper(object):
+     """A simple class to map ids into strings.
+ 
+     It works only when the character set is a 1:1 mapping between individual
+     characters and individual ids.
+     """
+ 
+     def __init__(self,
+                  filename='',
+                  max_length=30,
+                  null_char=u'\u2591'):
+         """Creates a lookup table.
+ 
+         Args:
+             filename: Path to charset file which maps characters to ids.
+             max_length: The max length of ids and string.
+             null_char: A unicode character used to replace the '<null>' character.
+                 The default value is a light shade block '░'.
+         """
+         self.null_char = null_char
+         self.max_length = max_length
+ 
+         self.label_to_char = self._read_charset(filename)
+         self.char_to_label = dict(map(reversed, self.label_to_char.items()))
+         self.num_classes = len(self.label_to_char)
+ 
+     def _read_charset(self, filename):
+         """Reads a charset definition from a tab separated text file.
+ 
+         Args:
+             filename: a path to the charset file.
+ 
+         Returns:
+             a dictionary mapping character codes (keys) to unicode
+             characters (values).
+         """
+         import re
+         pattern = re.compile(r'(\d+)\t(.+)')
+         charset = {}
+         self.null_label = 0
+         charset[self.null_label] = self.null_char
+         with open(filename, 'r') as f:
+             for i, line in enumerate(f):
+                 m = pattern.match(line)
+                 assert m, f'Incorrect charset file. line #{i}: {line}'
+                 label = int(m.group(1)) + 1
+                 char = m.group(2)
+                 charset[label] = char
+         return charset
+ 
+     def trim(self, text):
+         assert isinstance(text, str)
+         return text.replace(self.null_char, '')
+ 
+     def get_text(self, labels, length=None, padding=True, trim=False):
+         """Returns a string corresponding to a sequence of character ids."""
+         length = length if length else self.max_length
+         labels = [l.item() if isinstance(l, Tensor) else int(l) for l in labels]
+         if padding:
+             labels = labels + [self.null_label] * (length-len(labels))
+         text = ''.join([self.label_to_char[label] for label in labels])
+         if trim: text = self.trim(text)
+         return text
+ 
+     def get_labels(self, text, length=None, padding=True, case_sensitive=False):
+         """Returns the labels of the corresponding text."""
+         length = length if length else self.max_length
+         if padding:
+             text = text + self.null_char * (length - len(text))
+         if not case_sensitive:
+             text = text.lower()
+         labels = [self.char_to_label[char] for char in text]
+         return labels
+ 
+     def pad_labels(self, labels, length=None):
+         length = length if length else self.max_length
+ 
+         return labels + [self.null_label] * (length - len(labels))
+ 
+     @property
+     def digits(self):
+         return '0123456789'
+ 
+     @property
+     def digit_labels(self):
+         return self.get_labels(self.digits, padding=False)
+ 
+     @property
+     def alphabets(self):
+         all_chars = list(self.char_to_label.keys())
+         valid_chars = []
+         for c in all_chars:
+             if c in 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ':
+                 valid_chars.append(c)
+         return ''.join(valid_chars)
+ 
+     @property
+     def alphabet_labels(self):
+         return self.get_labels(self.alphabets, padding=False)
+ 
+ 
+ class Timer(object):
+     """A simple timer."""
+     def __init__(self):
+         self.data_time = 0.
+         self.data_diff = 0.
+         self.data_total_time = 0.
+         self.data_call = 0
+         self.running_time = 0.
+         self.running_diff = 0.
+         self.running_total_time = 0.
+         self.running_call = 0
+ 
+     def tic(self):
+         self.start_time = time.time()
+         self.running_time = self.start_time
+ 
+     def toc_data(self):
+         self.data_time = time.time()
+         self.data_diff = self.data_time - self.running_time
+         self.data_total_time += self.data_diff
+         self.data_call += 1
+ 
+     def toc_running(self):
+         self.running_time = time.time()
+         self.running_diff = self.running_time - self.data_time
+         self.running_total_time += self.running_diff
+         self.running_call += 1
+ 
+     def total_time(self):
+         return self.data_total_time + self.running_total_time
+ 
+     def average_time(self):
+         return self.average_data_time() + self.average_running_time()
+ 
+     def average_data_time(self):
+         return self.data_total_time / (self.data_call or 1)
+ 
+     def average_running_time(self):
+         return self.running_total_time / (self.running_call or 1)
+ 
+ 
+ class Logger(object):
+     _handle = None
+     _root = None
+ 
+     @staticmethod
+     def init(output_dir, name, phase):
+         fmt = '[%(asctime)s %(filename)s:%(lineno)d %(levelname)s {}] ' \
+               '%(message)s'.format(name)
+         logging.basicConfig(level=logging.INFO, format=fmt)
+ 
+         os.makedirs(output_dir, exist_ok=True)
+         log_path = os.path.join(output_dir, f'{phase}.txt')
+         Logger._handle = logging.FileHandler(log_path)
+         Logger._root = logging.getLogger()
+ 
+     @staticmethod
+     def enable_file():
+         if Logger._handle is None or Logger._root is None:
+             raise Exception('Invoke Logger.init() first!')
+         Logger._root.addHandler(Logger._handle)
+ 
+     @staticmethod
+     def disable_file():
+         if Logger._handle is None or Logger._root is None:
+             raise Exception('Invoke Logger.init() first!')
+         Logger._root.removeHandler(Logger._handle)
+ 
+ 
+ class Config(object):
+ 
+     def __init__(self, config_path, host=True):
+         def __dict2attr(d, prefix=''):
+             for k, v in d.items():
+                 if isinstance(v, dict):
+                     __dict2attr(v, f'{prefix}{k}_')
+                 else:
+                     if k == 'phase':
+                         assert v in ['train', 'test']
+                     if k == 'stage':
+                         assert v in ['pretrain-vision', 'pretrain-language',
+                                      'train-semi-super', 'train-super']
+                     self.__setattr__(f'{prefix}{k}', v)
+ 
+         assert os.path.exists(config_path), '%s does not exist!' % config_path
+         with open(config_path) as file:
+             config_dict = yaml.load(file, Loader=yaml.FullLoader)
+         with open('configs/template.yaml') as file:
+             default_config_dict = yaml.load(file, Loader=yaml.FullLoader)
+         __dict2attr(default_config_dict)
+         __dict2attr(config_dict)
+         self.global_workdir = os.path.join(self.global_workdir, self.global_name)
+ 
+     def __getattr__(self, item):
+         attr = self.__dict__.get(item)
+         if attr is None:
+             attr = dict()
+             prefix = f'{item}_'
+             for k, v in self.__dict__.items():
+                 if k.startswith(prefix):
+                     n = k.replace(prefix, '')
+                     attr[n] = v
+             return attr if len(attr) > 0 else None
+         else:
+             return attr
+ 
+     def __repr__(self):
+         repr_str = 'ModelConfig(\n'
+         for i, (k, v) in enumerate(sorted(vars(self).items())):
+             repr_str += f'\t({i}): {k} = {v}\n'
+         repr_str += ')'
+         return repr_str
+ 
+ def blend_mask(image, mask, alpha=0.5, cmap='jet', color='b', color_alpha=1.0):
+     # normalize mask
+     mask = (mask - mask.min()) / (mask.max() - mask.min() + np.finfo(float).eps)
+     if mask.shape != image.shape:
+         mask = cv2.resize(mask, (image.shape[1], image.shape[0]))
+     # get color map
+     color_map = plt.get_cmap(cmap)
+     mask = color_map(mask)[:, :, :3]
+     # convert float to uint8
+     mask = (mask * 255).astype(dtype=np.uint8)
+ 
+     # set the basic color
+     basic_color = np.array(colors.to_rgb(color)) * 255
+     basic_color = np.tile(basic_color, [image.shape[0], image.shape[1], 1])
+     basic_color = basic_color.astype(dtype=np.uint8)
+     # blend with basic color
+     blended_img = cv2.addWeighted(image, color_alpha, basic_color, 1 - color_alpha, 0)
+     # blend with mask
+     blended_img = cv2.addWeighted(blended_img, alpha, mask, 1 - alpha, 0)
+ 
+     return blended_img
+ 
+ def onehot(label, depth, device=None):
+     """
+     Args:
+         label: shape (n1, n2, ...)
+         depth: a scalar
+ 
+     Returns:
+         onehot: (n1, n2, ..., depth)
+     """
+     if not isinstance(label, torch.Tensor):
+         label = torch.tensor(label, device=device)
+     onehot = torch.zeros(label.size() + torch.Size([depth]), device=device)
+     onehot = onehot.scatter_(-1, label.unsqueeze(-1), 1)
+ 
+     return onehot
+ 
+ class MyDataParallel(nn.DataParallel):
+ 
+     def gather(self, outputs, target_device):
+         r"""
+         Gathers tensors from different GPUs on a specified device
+         (-1 means the CPU).
+         """
+         def gather_map(outputs):
+             out = outputs[0]
+             if isinstance(out, (str, int, float)):
+                 return out
+             if isinstance(out, list) and isinstance(out[0], str):
+                 return [o for out in outputs for o in out]
+             if isinstance(out, torch.Tensor):
+                 return torch.nn.parallel._functions.Gather.apply(target_device, self.dim, *outputs)
+             if out is None:
+                 return None
+             if isinstance(out, dict):
+                 if not all((len(out) == len(d) for d in outputs)):
+                     raise ValueError('All dicts must have the same number of keys')
+                 return type(out)(((k, gather_map([d[k] for d in outputs]))
+                                   for k in out))
+             return type(out)(map(gather_map, zip(*outputs)))
+ 
+         # Recursive function calls like this create reference cycles.
+         # Setting the function to None clears the refcycle.
+         try:
+             res = gather_map(outputs)
+         finally:
+             gather_map = None
+         return res
+ 
+ 
+ class MyConcatDataset(ConcatDataset):
+     def __getattr__(self, k):
+         return getattr(self.datasets[0], k)
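As a quick sanity check of `CharsetMapper`, here is a round trip from text to padded labels and back, assuming a charset file in the tab-separated `<id>\t<char>` format that `_read_charset` expects (the path below is hypothetical):

```python
# hypothetical charset path; any file in the "<id>\t<char>" format works,
# provided it covers the lowercase letters used here
charset = CharsetMapper('data/charset_36.txt', max_length=25)

labels = charset.get_labels('abinet')       # lowercased, padded to max_length
text = charset.get_text(labels, trim=True)  # strip the null padding
assert text == 'abinet'
```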