import sys
import argparse
import os
import time
import logging
from datetime import datetime


def _build_arg_parser():
    """Create the command-line parser used by :func:`main`."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', required=True, help='path to config file')
    parser.add_argument('--gpu', default='0', help='GPU(s) to be used')
    parser.add_argument('--resume', default=None, help='path to the weights to be resumed')
    parser.add_argument(
        '--resume_weights_only',
        action='store_true',
        help='specify this argument to restore only the weights (w/o training states), e.g. --resume path/to/resume --resume_weights_only'
    )

    # Exactly one action must be selected per invocation.
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('--train', action='store_true')
    group.add_argument('--validate', action='store_true')
    group.add_argument('--test', action='store_true')
    group.add_argument('--predict', action='store_true')
    # TODO: a separate export action (`--export`)

    parser.add_argument('--exp_dir', default='./exp')
    parser.add_argument('--runs_dir', default='./runs')
    parser.add_argument('--verbose', action='store_true', help='if true, set logging level to DEBUG')
    return parser


def _parse_gpu_ids(spec):
    """Split a comma-separated GPU spec into its non-empty, stripped entries."""
    return [token.strip() for token in spec.split(',') if token.strip()]


def main():
    """Parse the command line and run the selected train/validate/test/predict action.

    Unknown command-line arguments are forwarded to ``load_config`` as config
    overrides. Invalid argument combinations are reported through
    ``parser.error`` (exit status 2) before any heavy import happens.
    """
    parser = _build_arg_parser()
    args, extras = parser.parse_known_args()

    # Fail fast on argument combinations that would otherwise be silently
    # ignored or only blow up after the (slow) framework imports.
    if args.resume_weights_only and not args.resume:
        parser.error('--resume_weights_only requires --resume')

    gpu_ids = _parse_gpu_ids(args.gpu)
    if not gpu_ids:
        parser.error('--gpu must list at least one GPU id')
    n_gpus = len(gpu_ids)

    if sys.platform == 'win32' and n_gpus != 1:
        # does not support multi-gpu on windows
        parser.error('multi-GPU is not supported on Windows; pass a single id to --gpu')

    # set CUDA_VISIBLE_DEVICES then import pytorch-lightning
    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(gpu_ids)

    # These imports are deliberately deferred until the environment variables
    # above are in place.
    import datasets
    import systems
    import pytorch_lightning as pl
    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
    from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger
    from utils.callbacks import CodeSnapshotCallback, ConfigSnapshotCallback, CustomProgressBar
    from utils.misc import load_config

    # parse YAML config to OmegaConf
    config = load_config(args.config, cli_args=extras)
    config.cmd_args = vars(args)

    # Values already present in the config win; otherwise derive defaults from
    # the experiment name / tag and the current time.
    config.trial_name = config.get('trial_name') or (config.tag + datetime.now().strftime('@%Y%m%d-%H%M%S'))
    config.exp_dir = config.get('exp_dir') or os.path.join(args.exp_dir, config.name)
    trial_root = os.path.join(config.exp_dir, config.trial_name)
    config.save_dir = config.get('save_dir') or os.path.join(trial_root, 'save')
    config.ckpt_dir = config.get('ckpt_dir') or os.path.join(trial_root, 'ckpt')
    config.code_dir = config.get('code_dir') or os.path.join(trial_root, 'code')
    config.config_dir = config.get('config_dir') or os.path.join(trial_root, 'config')

    logger = logging.getLogger('pytorch_lightning')
    if args.verbose:
        logger.setLevel(logging.DEBUG)

    if 'seed' not in config:
        # No seed configured: derive one in [0, 999] from the wall clock.
        config.seed = int(time.time() * 1000) % 1000
    pl.seed_everything(config.seed)

    dm = datasets.make(config.dataset.name, config.dataset)
    # With --resume_weights_only the checkpoint path is handed to the system
    # factory; otherwise it is passed to the trainer via `ckpt_path` below.
    system = systems.make(
        config.system.name,
        config,
        load_from_checkpoint=args.resume if args.resume_weights_only else None
    )

    # Callbacks and loggers are only attached for training runs.
    callbacks = []
    loggers = []
    if args.train:
        callbacks += [
            ModelCheckpoint(
                dirpath=config.ckpt_dir,
                **config.checkpoint
            ),
            LearningRateMonitor(logging_interval='step'),
            CodeSnapshotCallback(
                config.code_dir, use_version=False
            ),
            ConfigSnapshotCallback(
                config, config.config_dir, use_version=False
            ),
            CustomProgressBar(refresh_rate=1),
        ]
        loggers += [
            TensorBoardLogger(args.runs_dir, name=config.name, version=config.trial_name),
            CSVLogger(config.exp_dir, name=config.trial_name, version='csv_logs')
        ]

    # The single-GPU requirement on Windows was already enforced above.
    if sys.platform == 'win32':
        strategy = 'dp'
    else:
        strategy = 'ddp_find_unused_parameters_false'

    trainer = Trainer(
        devices=n_gpus,
        accelerator='gpu',
        callbacks=callbacks,
        logger=loggers,
        strategy=strategy,
        **config.trainer
    )

    if args.train:
        if args.resume and not args.resume_weights_only:
            # FIXME: different behavior in pytorch-lightning>1.9 ?
            trainer.fit(system, datamodule=dm, ckpt_path=args.resume)
        else:
            trainer.fit(system, datamodule=dm)
        # A training run is always followed by a test pass.
        trainer.test(system, datamodule=dm)
    elif args.validate:
        trainer.validate(system, datamodule=dm, ckpt_path=args.resume)
    elif args.test:
        trainer.test(system, datamodule=dm, ckpt_path=args.resume)
    elif args.predict:
        trainer.predict(system, datamodule=dm, ckpt_path=args.resume)


if __name__ == '__main__':
    main()