Spaces:
Sleeping
Sleeping
# Copyright (c) OpenMMLab. All rights reserved. | |
import argparse | |
import json | |
import os.path as osp | |
import re | |
from pathlib import Path | |
from unittest.mock import MagicMock | |
import matplotlib.pyplot as plt | |
import rich | |
import torch.nn as nn | |
from mmengine.config import Config, DictAction | |
from mmengine.hooks import Hook | |
from mmengine.model import BaseModel | |
from mmengine.registry import init_default_scope | |
from mmengine.runner import Runner | |
from mmengine.visualization import Visualizer | |
from rich.progress import BarColumn, MofNCompleteColumn, Progress, TextColumn | |
from mmocr.registry import DATASETS | |
class SimpleModel(BaseModel):
    """Placeholder model whose forward pass and train step are no-ops.

    It only needs to be constructible and to own a few parameters; neither
    ``forward`` nor ``train_step`` performs any computation.
    """

    def __init__(self):
        super().__init__()
        # Swap the data preprocessor for a pass-through module.
        self.data_preprocessor = nn.Identity()
        # A 1x1 convolution; it gives the model some parameters.
        self.conv = nn.Conv2d(1, 1, 1)

    def forward(self, inputs, data_samples, mode='tensor'):
        """Ignore all arguments and return ``None``."""
        return None

    def train_step(self, data, optim_wrapper):
        """Skip the optimization step and return ``None``."""
        return None
class ParamRecordHook(Hook):
    """Hook that records lr, momentum and weight decay after each iteration.

    A rich progress bar is displayed while training runs. It counts epochs
    when ``by_epoch`` is true and iterations otherwise.

    Args:
        by_epoch (bool): Whether the progress bar advances once per epoch
            instead of once per iteration.
    """

    def __init__(self, by_epoch):
        super().__init__()
        self.by_epoch = by_epoch
        # Each list receives one value per training iteration.
        self.lr_list = []
        self.momentum_list = []
        self.wd_list = []
        self.task_id = 0
        bar_columns = (BarColumn(), MofNCompleteColumn(),
                       TextColumn('{task.description}'))
        self.progress = Progress(*bar_columns)

    def before_train(self, runner):
        """Register the progress task and start rendering the bar."""
        train_loop = runner.train_loop
        if self.by_epoch:
            description, total = 'epochs', train_loop.max_epochs
        else:
            description, total = 'iters', train_loop.max_iters
        self.task_id = self.progress.add_task(
            description, start=True, total=total)
        self.progress.start()

    def after_train_epoch(self, runner):
        """Advance the bar by one step in epoch-based mode."""
        if not self.by_epoch:
            return
        self.progress.update(self.task_id, advance=1)

    def after_train_iter(self, runner, batch_idx, data_batch, outputs):
        """Advance the bar (iteration mode) and record current values."""
        if not self.by_epoch:
            self.progress.update(self.task_id, advance=1)

        optim_wrapper = runner.optim_wrapper
        # Only the first entry of each value list / param group is recorded.
        current_lr = optim_wrapper.get_lr()['lr'][0]
        current_momentum = optim_wrapper.get_momentum()['momentum'][0]
        current_wd = optim_wrapper.param_groups[0]['weight_decay']

        self.lr_list.append(current_lr)
        self.momentum_list.append(current_momentum)
        self.wd_list.append(current_wd)

    def after_train(self, runner):
        """Stop rendering the progress bar."""
        self.progress.stop()
def parse_args():
    """Parse and validate the command-line arguments of this script.

    Returns:
        argparse.Namespace: The parsed arguments. ``window_size`` is either
        an empty string or a string of the form ``'<int>*<int>'``.

    Raises:
        SystemExit: Raised by argparse (after printing a usage message) when
            the arguments are invalid, including a malformed
            ``--window-size``.
    """
    parser = argparse.ArgumentParser(
        description='Visualize a parameter scheduler')
    parser.add_argument('config', help='config file path')
    parser.add_argument(
        '-p',
        '--parameter',
        type=str,
        default='lr',
        choices=['lr', 'momentum', 'wd'],
        help='The parameter to visualize its change curve, choose from '
        '"lr", "wd" and "momentum". Defaults to "lr".')
    parser.add_argument(
        '-d',
        '--dataset-size',
        type=int,
        help='The size of the dataset. If specify, `build_dataset` will '
        'be skipped and use this size as the dataset size.')
    parser.add_argument(
        '-n',
        '--ngpus',
        type=int,
        default=1,
        help='The number of GPUs used in training.')
    parser.add_argument(
        '-s',
        '--save-path',
        type=Path,
        help='The learning rate curve plot save path')
    parser.add_argument(
        '--log-level',
        default='WARNING',
        help='The log level of the handler and logger. Defaults to '
        'WARNING.')
    parser.add_argument('--title', type=str, help='title of figure')
    parser.add_argument(
        '--style', type=str, default='whitegrid', help='style of plt')
    parser.add_argument('--not-show', default=False, action='store_true')
    parser.add_argument(
        '--window-size',
        default='12*7',
        help='Size of the window to display images, in format of "$W*$H".')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    args = parser.parse_args()

    # An empty string is still let through unchanged. Anything else must
    # match the whole pattern: a prefix-only match would accept values such
    # as '12*7abc' that fail later when converted to integers. Reporting via
    # ``parser.error`` keeps the check active under ``python -O``, where
    # ``assert`` statements are removed.
    if args.window_size != '' and \
            re.fullmatch(r'\d+\*\d+', args.window_size) is None:
        parser.error("'window-size' must be in format 'W*H'.")
    return args
def plot_curve(lr_list, args, param_name, iters_per_epoch, by_epoch=True):
    """Plot the recorded parameter values against the iteration index.

    The figure is created on the current pyplot state; saving or showing it
    is left to the caller.

    Args:
        lr_list (list): Values to plot, one per training iteration.
        args (argparse.Namespace): Parsed command-line arguments. The
            attributes ``style``, ``window_size``, ``title`` and ``config``
            are read.
        param_name (str): Label of the y-axis, also used in the default
            title.
        iters_per_epoch (int): Number of iterations per epoch, used to
            convert between the iteration axis and the epoch axis.
        by_epoch (bool): If true, iterations are shown on the top axis and
            a secondary bottom axis shows epochs. Defaults to True.
    """
    try:
        import seaborn as sns
        sns.set_style(args.style)
    except ImportError:
        # seaborn is optional; without it the plot keeps matplotlib styling.
        pass

    if args.window_size:
        wind_w, wind_h = (int(side) for side in args.window_size.split('*'))
        plt.figure(figsize=(wind_w, wind_h))
    else:
        # An empty window size is accepted by the argument parser, but it
        # cannot be split into width and height, so fall back to
        # matplotlib's default figure size instead of failing on unpacking.
        plt.figure()

    ax: plt.Axes = plt.subplot()
    ax.plot(lr_list, linewidth=1)

    if by_epoch:
        # Primary x-axis (iterations) goes on top, epochs at the bottom.
        ax.xaxis.tick_top()
        ax.set_xlabel('Iters')
        ax.xaxis.set_label_position('top')
        sec_ax = ax.secondary_xaxis(
            'bottom',
            functions=(lambda x: x / iters_per_epoch,
                       lambda y: y * iters_per_epoch))
        sec_ax.set_xlabel('Epochs')
    else:
        plt.xlabel('Iters')
    plt.ylabel(param_name)

    if args.title is None:
        plt.title(f'{osp.basename(args.config)} {param_name} curve')
    else:
        plt.title(args.title)
def simulate_train(data_loader, cfg, by_epoch):
    """Run a training loop with a no-op model and collect parameter values.

    Args:
        data_loader: Object passed to the runner as ``train_dataloader``.
        cfg: Config providing ``default_hooks``, ``work_dir``, ``train_cfg``,
            ``log_level``, ``optim_wrapper``, ``param_scheduler``,
            ``default_scope`` and optionally ``custom_hooks``.
        by_epoch (bool): Forwarded to :class:`ParamRecordHook`.

    Returns:
        dict: Keys ``'lr'``, ``'momentum'`` and ``'wd'``, each mapping to
        the list of values recorded by the hook.
    """
    record_hook = ParamRecordHook(by_epoch=by_epoch)

    # Keep the configured param scheduler hook and add the recording hook.
    # NOTE(review): the ``None`` entries presumably tell the runner to skip
    # those default hooks -- confirm against the Runner documentation.
    hooks = {
        'param_scheduler': cfg.default_hooks['param_scheduler'],
        'runtime_info': None,
        'timer': None,
        'logger': None,
        'checkpoint': None,
        'sampler_seed': None,
        'param_record': record_hook,
    }

    runner = Runner(
        model=SimpleModel(),
        work_dir=cfg.work_dir,
        train_dataloader=data_loader,
        train_cfg=cfg.train_cfg,
        log_level=cfg.log_level,
        optim_wrapper=cfg.optim_wrapper,
        param_scheduler=cfg.param_scheduler,
        default_scope=cfg.default_scope,
        default_hooks=hooks,
        visualizer=MagicMock(spec=Visualizer),
        custom_hooks=cfg.get('custom_hooks', None))
    runner.train()

    return {
        'lr': record_hook.lr_list,
        'momentum': record_hook.momentum_list,
        'wd': record_hook.wd_list,
    }
def build_dataset(cfg):
    """Build a dataset from ``cfg`` through the ``DATASETS`` registry.

    Args:
        cfg: Dataset config handed to ``DATASETS.build``.

    Returns:
        The object returned by ``DATASETS.build``.
    """
    dataset = DATASETS.build(cfg)
    return dataset
def main():
    """Simulate training for a config and plot the chosen parameter curve.

    Raises:
        FileNotFoundError: If the parent directory of ``--save-path`` does
            not exist.
        ValueError: If ``train_cfg`` has neither a ``by_epoch`` nor a
            ``type`` key.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)
    init_default_scope(cfg.get('default_scope', 'mmocr'))

    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    if cfg.get('work_dir', None) is None:
        # Fall back to a work_dir derived from the config file name.
        config_stem = osp.splitext(osp.basename(args.config))[0]
        cfg.work_dir = osp.join('./work_dirs', config_stem)

    cfg.log_level = args.log_level

    # The directory that will hold the saved figure must already exist.
    if args.save_path and not args.save_path.parent.exists():
        raise FileNotFoundError(
            f'The save path is {args.save_path}, and directory '
            f"'{args.save_path.parent}' do not exist.")

    # Show the scheduler configuration that is about to be simulated.
    print('Param_scheduler :')
    rich.print_json(json.dumps(cfg.param_scheduler))

    # Total batch size across all GPUs.
    batch_size = cfg.train_dataloader.batch_size * args.ngpus

    train_cfg = cfg.train_cfg
    if 'by_epoch' in train_cfg:
        by_epoch = train_cfg.get('by_epoch')
    elif 'type' in train_cfg:
        by_epoch = train_cfg.get('type') == 'EpochBasedTrainLoop'
    else:
        raise ValueError('please set `train_cfg`.')

    # The real dataset is only built when its size is needed and not given.
    if by_epoch and args.dataset_size is None:
        dataset_size = len(build_dataset(cfg.train_dataloader.dataset))
    else:
        dataset_size = args.dataset_size or batch_size

    class FakeDataloader(list):
        dataset = MagicMock(metainfo=None)

    # A list with one placeholder element per iteration of an epoch.
    data_loader = FakeDataloader(range(dataset_size // batch_size))

    info_parts = [
        f'\nDataset infos:',
        f'\n - Dataset size: {dataset_size}',
        f'\n - Batch size per GPU: {cfg.train_dataloader.batch_size}',
        f'\n - Number of GPUs: {args.ngpus}',
        f'\n - Total batch size: {batch_size}',
    ]
    if by_epoch:
        info_parts.append(
            f'\n - Iterations per epoch: {len(data_loader)}')
    rich.print(''.join(info_parts) + '\n')

    # Run the simulated training and pick the requested parameter series.
    param_dict = simulate_train(data_loader, cfg, by_epoch)
    param_list = param_dict[args.parameter]

    display_names = {'lr': 'Learning Rate', 'momentum': 'Momentum'}
    param_name = display_names.get(args.parameter, 'Weight Decay')

    plot_curve(param_list, args, param_name, len(data_loader), by_epoch)

    if args.save_path:
        plt.savefig(args.save_path)
        print(f'\nThe {param_name} graph is saved at {args.save_path}')

    if not args.not_show:
        plt.show()
# Run the tool only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()