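"""Gather benchmarked models and publish them together with their logs.

Usage sketch (the paths below are illustrative only):

    python gather_models.py work_dirs configs out_dir model_infos.json \
        --filter pspnet --all
"""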
import argparse
import glob
import json
import os
import os.path as osp
import shutil
import subprocess

import mmcv
import torch

# look-up table of the metric keys to collect from the json logs
RESULTS_LUT = ['mIoU', 'mAcc', 'aAcc']


def process_checkpoint(in_file, out_file):
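    """Drop the optimizer state from a checkpoint and publish it under a name
    carrying the first 8 characters of its sha256 sum."""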
    checkpoint = torch.load(in_file, map_location='cpu')
    # remove optimizer for smaller file size
    if 'optimizer' in checkpoint:
        del checkpoint['optimizer']
    # if it is necessary to remove some sensitive data in checkpoint['meta'],
    # add the code here.
    torch.save(checkpoint, out_file)
    sha = subprocess.check_output(['sha256sum', out_file]).decode()
    # use splitext rather than str.rstrip('.pth'), which strips a character
    # set, not a suffix
    final_file = osp.splitext(out_file)[0] + '-{}.pth'.format(sha[:8])
    # move synchronously so the published file exists before we return
    shutil.move(out_file, final_file)
    return final_file


def get_final_iter(config):
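    """Parse the total iteration number from a config name assumed to follow
    the ``*_{N}k_{dataset}`` convention, e.g. an (illustrative) config named
    ``fcn_r50-d8_512x1024_40k_cityscapes`` yields 40000."""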
    iter_num = config.split('_')[-2]
    assert iter_num.endswith('k')
    return int(iter_num[:-1]) * 1000


def get_final_results(log_json_path, iter_num):
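    """Scan a json log for the entry at ``iter_num`` and collect the metrics
    listed in RESULTS_LUT plus the training memory; return None if the log
    never reaches that iteration."""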
    result_dict = dict()
    with open(log_json_path, 'r') as f:
        for line in f:
            log_line = json.loads(line)
            if 'mode' not in log_line.keys():
                continue

            if log_line['mode'] == 'train' and log_line['iter'] == iter_num:
                result_dict['memory'] = log_line['memory']

            # only return on the validation entry, which carries the
            # evaluation metrics; the train entry at the same iteration
            # only provides the memory footprint
            if log_line['mode'] == 'val' and log_line['iter'] == iter_num:
                result_dict.update({
                    key: log_line[key]
                    for key in RESULTS_LUT if key in log_line
                })
                return result_dict


def parse_args():
    parser = argparse.ArgumentParser(description='Gather benchmarked models')
    parser.add_argument(
        'root',
        type=str,
        help='root path of benchmarked models to be gathered')
    parser.add_argument(
        'config',
        type=str,
        help='root path of benchmarked configs to be gathered')
    parser.add_argument(
        'out_dir',
        type=str,
        help='output path of gathered models to be stored')
    parser.add_argument('out_file', type=str, help='the output json file name')
    parser.add_argument(
        '--filter',
        type=str,
        nargs='+',
        default=[],
        help='only gather configs whose name contains one of these strings')
    parser.add_argument(
        '--all',
        action='store_true',
        help='whether to also gather the config (.py) and text log files')

    args = parser.parse_args()
    return args


def main():
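    """Gather finished experiments under ``root``, publish their checkpoints
    to ``out_dir`` and dump the collected model information to ``out_file``."""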
    args = parse_args()
    models_root = args.root
    models_out = args.out_dir
    config_name = args.config
    mmcv.mkdir_or_exist(models_out)

    # scan all configs; those with a matching work dir under the models root
    # are the experiments to be gathered
    raw_configs = list(mmcv.scandir(config_name, '.py', recursive=True))

    # filter out configs that were not trained in the experiments dir
    used_configs = []
    for raw_config in raw_configs:
        work_dir = osp.splitext(osp.basename(raw_config))[0]
        if osp.exists(osp.join(models_root, work_dir)):
            used_configs.append((work_dir, raw_config))
    print(f'Found {len(used_configs)} models to be gathered')

    # find the final checkpoint and log file of each trained config
    # and parse the final performance
    model_infos = []
    for used_config, raw_config in used_configs:
        # an empty --filter keeps every config
        if args.filter and not any(p in used_config for p in args.filter):
            continue
        exp_dir = osp.join(models_root, used_config)
        # check whether the experiment is finished
        final_iter = get_final_iter(used_config)
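        # the final checkpoint is assumed to follow mmcv's 'iter_<N>.pth'
        # naming scheme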
        final_model = 'iter_{}.pth'.format(final_iter)
        model_path = osp.join(exp_dir, final_model)

        # skip if the model is still training
        if not osp.exists(model_path):
            print(f'{used_config} train not finished yet')
            continue

        # get logs
        log_json_paths = glob.glob(osp.join(exp_dir, '*.log.json'))
        if not log_json_paths:
            print(f'{used_config} has no json log')
            continue
        log_json_path = log_json_paths[0]
        model_performance = None
        for _log_json_path in log_json_paths:
            model_performance = get_final_results(_log_json_path, final_iter)
            if model_performance is not None:
                log_json_path = _log_json_path
                break

        if model_performance is None:
            print(f'{used_config}: no final results found in its logs')
            continue

        model_time = osp.split(log_json_path)[-1].split('.')[0]
        model_infos.append(
            dict(
                config=used_config,
                raw_config=raw_config,
                results=model_performance,
                iters=final_iter,
                model_time=model_time,
                log_json_path=osp.split(log_json_path)[-1]))

    # publish model for each checkpoint
    publish_model_infos = []
    for model in model_infos:
        model_publish_dir = osp.join(models_out,
                                     osp.splitext(model['raw_config'])[0])
        model_name = osp.split(model['config'])[-1].split('.')[0]

        publish_model_path = osp.join(model_publish_dir,
                                      model_name + '_' + model['model_time'])
        trained_model_path = osp.join(models_root, model['config'],
                                      'iter_{}.pth'.format(model['iters']))
        if osp.exists(model_publish_dir):
            for file in os.listdir(model_publish_dir):
                if file.endswith('.pth'):
                    print(f'model {file} found')
                    model['model_path'] = osp.abspath(
                        osp.join(model_publish_dir, file))
                    break
            if 'model_path' not in model:
                print(f'dir {model_publish_dir} exists, no model found')

        else:
            mmcv.mkdir_or_exist(model_publish_dir)

            # convert model
            final_model_path = process_checkpoint(trained_model_path,
                                                  publish_model_path)
            model['model_path'] = final_model_path

        new_json_path = f'{model_name}-{model["log_json_path"]}'
        # copy the json log
        shutil.copy(
            osp.join(models_root, model['config'], model['log_json_path']),
            osp.join(model_publish_dir, new_json_path))

        if args.all:
            # also copy the plain-text log; use splitext rather than
            # str.rstrip('.json'), which strips characters, not a suffix
            new_txt_path = osp.splitext(new_json_path)[0]
            shutil.copy(
                osp.join(models_root, model['config'],
                         osp.splitext(model['log_json_path'])[0]),
                osp.join(model_publish_dir, new_txt_path))
            # copy the config to guarantee reproducibility
            raw_config = osp.join(config_name, model['raw_config'])
            mmcv.Config.fromfile(raw_config).dump(
                osp.join(model_publish_dir, osp.basename(raw_config)))

        publish_model_infos.append(model)

    models = dict(models=publish_model_infos)
    mmcv.dump(models, osp.join(models_out, args.out_file))


if __name__ == '__main__':
    main()