import argparse, os, sys, datetime, glob, importlib from omegaconf import OmegaConf import numpy as np from PIL import Image import torch import torchvision from torch.utils.data import DataLoader, Dataset from dataloader import CellLoader from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor import pytorch_lightning as pl from pytorch_lightning import seed_everything from pytorch_lightning.trainer import Trainer from pytorch_lightning.callbacks import Callback from pytorch_lightning.utilities import rank_zero_only def get_obj_from_str(string, reload=False): module, cls = string.rsplit(".", 1) if reload: module_imp = importlib.import_module(module) importlib.reload(module_imp) return getattr(importlib.import_module(module, package=None), cls) def get_parser(**parser_kwargs): def str2bool(v): if isinstance(v, bool): return v if v.lower() in ("yes", "true", "t", "y", "1"): return True elif v.lower() in ("no", "false", "f", "n", "0"): return False else: raise argparse.ArgumentTypeError("Boolean value expected.") parser = argparse.ArgumentParser(**parser_kwargs) parser.add_argument( "-n", "--name", type=str, const=True, default="", nargs="?", help="postfix for logdir", ) parser.add_argument( "-r", "--resume", type=str, const=True, default="", nargs="?", help="resume from logdir or checkpoint in logdir", ) parser.add_argument( "-b", "--base", nargs="*", metavar="base_config.yaml", help="paths to base configs. Loaded from left-to-right. " "Parameters can be overwritten or added with command-line options of the form `--key value`.", default=list(), ) parser.add_argument( "-t", "--train", type=str2bool, const=True, default=False, nargs="?", help="train", ) parser.add_argument( "--no-test", type=str2bool, const=True, default=False, nargs="?", help="disable test", ) parser.add_argument( "-p", "--project", help="name of new or path to existing project" ) parser.add_argument( "-d", "--debug", type=str2bool, nargs="?", const=True, default=False, help="enable post-mortem debugging", ) parser.add_argument( "-s", "--seed", type=int, default=42, help="seed for seed_everything", ) parser.add_argument( "-f", "--postfix", type=str, default="", help="post-postfix for default name", ) return parser def nondefault_trainer_args(opt): parser = argparse.ArgumentParser() parser = Trainer.add_argparse_args(parser) args = parser.parse_args([]) return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k)) def instantiate_from_config(config): if not "target" in config: raise KeyError("Expected key `target` to instantiate.") return get_obj_from_str(config["target"])(**config.get("params", dict())) class WrappedDataset(Dataset): """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset""" def __init__(self, dataset): self.data = dataset def __len__(self): return len(self.data) def __getitem__(self, idx): return self.data[idx] class DataModuleFromConfig(pl.LightningDataModule): def __init__( self, data_csv, dataset, crop_size=256, resize=600, batch_size=1, sequence_mode="latent", vocab="bert", text_seq_len=0, num_workers=1, threshold=False, train=True, validation=True, test=None, wrap=False, **kwargs, ): super().__init__() self.data_csv = data_csv self.dataset = dataset self.image_folders = [] self.crop_size = crop_size self.resize = resize self.batch_size = batch_size self.sequence_mode = sequence_mode self.threshold = threshold self.text_seq_len = int(text_seq_len) self.vocab = vocab self.dataset_configs = dict() self.num_workers = num_workers if num_workers is not None else batch_size * 2 if train is not None: self.dataset_configs["train"] = train self.train_dataloader = self._train_dataloader if validation is not None: self.dataset_configs["validation"] = validation self.val_dataloader = self._val_dataloader if test is not None: self.dataset_configs["test"] = test self.test_dataloader = self._test_dataloader self.wrap = wrap def prepare_data(self): pass def setup(self, stage=None): # called on every GPU self.cell_dataset_train = CellLoader( data_csv=self.data_csv, dataset=self.dataset, crop_size=self.crop_size, split_key="train", crop_method="random", sequence_mode=None, vocab=self.vocab, text_seq_len=self.text_seq_len, threshold=self.threshold, ) self.cell_dataset_val = CellLoader( data_csv=self.data_csv, dataset=self.dataset, crop_size=self.crop_size, split_key="val", crop_method="center", sequence_mode=None, vocab=self.vocab, text_seq_len=self.text_seq_len, threshold=self.threshold, ) def _train_dataloader(self): return DataLoader( self.cell_dataset_train, num_workers=self.num_workers, pin_memory=True, shuffle=True, batch_size=self.batch_size, ) def _val_dataloader(self): return DataLoader( self.cell_dataset_val, num_workers=self.num_workers, pin_memory=True, batch_size=self.batch_size, ) # def _test_dataloader(self): # return DataLoader(self.datasets["test"], batch_size=self.batch_size, # num_workers=self.num_workers) class SetupCallback(Callback): def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config): super().__init__() self.resume = resume self.now = now self.logdir = logdir self.ckptdir = ckptdir self.cfgdir = cfgdir self.config = config self.lightning_config = lightning_config def on_fit_start(self, trainer, pl_module): if trainer.global_rank == 0: # Create logdirs and save configs os.makedirs(self.logdir, exist_ok=True) os.makedirs(self.ckptdir, exist_ok=True) os.makedirs(self.cfgdir, exist_ok=True) print("Project config") print(OmegaConf.to_yaml(self.config)) OmegaConf.save( self.config, os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)), ) print("Lightning config") print(OmegaConf.to_yaml(self.lightning_config)) OmegaConf.save( OmegaConf.create({"lightning": self.lightning_config}), os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)), ) else: # ModelCheckpoint callback created log directory --- remove it if not self.resume and os.path.exists(self.logdir): dst, name = os.path.split(self.logdir) dst = os.path.join(dst, "child_runs", name) os.makedirs(os.path.split(dst)[0], exist_ok=True) try: os.rename(self.logdir, dst) except FileNotFoundError: pass class ImageLogger(Callback): def __init__( self, batch_frequency, max_images, clamp=True, increase_log_steps=True ): super().__init__() self.batch_freq = batch_frequency self.max_images = max_images self.logger_log_images = { pl.loggers.WandbLogger: self._wandb, # pl.loggers.TestTubeLogger: self._testtube, pl.loggers.TensorBoardLogger: self._testtube, } self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)] if not increase_log_steps: self.log_steps = [self.batch_freq] self.clamp = clamp @rank_zero_only def _wandb(self, pl_module, images, batch_idx, split): raise ValueError("No way wandb") grids = dict() for k in images: grid = torchvision.utils.make_grid(images[k]) grids[f"{split}/{k}"] = wandb.Image(grid) pl_module.logger.experiment.log(grids) @rank_zero_only def _testtube(self, pl_module, images, batch_idx, split): for k in images: images[k] -= torch.min(images[k]) images[k] /= torch.max(images[k]) grid = torchvision.utils.make_grid(images[k]) # grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w tag = f"{split}/{k}" pl_module.logger.experiment.add_image( tag, grid, global_step=pl_module.global_step ) @rank_zero_only def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx): root = os.path.join(save_dir, "images", split) for k in images: images[k] -= torch.min(images[k]) images[k] /= torch.max(images[k]) grid = torchvision.utils.make_grid(images[k], nrow=4) # grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1) grid = grid.numpy() grid = (grid * 255).astype(np.uint8) filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format( k, global_step, current_epoch, batch_idx ) path = os.path.join(root, filename) os.makedirs(os.path.split(path)[0], exist_ok=True) Image.fromarray(grid).save(path) def log_img(self, pl_module, batch, batch_idx, split="train"): if ( self.check_frequency(batch_idx) and hasattr(pl_module, "log_images") # batch_idx % self.batch_freq == 0 and callable(pl_module.log_images) and self.max_images > 0 ): logger = type(pl_module.logger) is_train = pl_module.training if is_train: pl_module.eval() with torch.no_grad(): images = pl_module.log_images(batch, split=split) for k in images: N = min(images[k].shape[0], self.max_images) images[k] = images[k][:N] if isinstance(images[k], torch.Tensor): images[k] = images[k].detach().cpu() if self.clamp: images[k] = torch.clamp(images[k], -1.0, 1.0) self.log_local( pl_module.logger.save_dir, split, images, pl_module.global_step, pl_module.current_epoch, batch_idx, ) logger_log_images = self.logger_log_images.get( logger, lambda *args, **kwargs: None ) logger_log_images(pl_module, images, pl_module.global_step, split) if is_train: pl_module.train() def check_frequency(self, batch_idx): if (batch_idx % self.batch_freq) == 0 or (batch_idx in self.log_steps): try: self.log_steps.pop(0) except IndexError: pass return True return False # def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): # def on_train_batch_end(self, *args, **kwargs): def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): self.log_img(pl_module, batch, batch_idx, split="train") def on_validation_batch_end( self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx ): self.log_img(pl_module, batch, batch_idx, split="val") if __name__ == "__main__": # custom parser to specify config files, train, test and debug mode, # postfix, resume. # `--key value` arguments are interpreted as arguments to the trainer. # `nested.key=value` arguments are interpreted as config parameters. # configs are merged from left-to-right followed by command line parameters. # model: # base_learning_rate: float # target: path to lightning module # params: # key: value # data: # target: main.DataModuleFromConfig # params: # batch_size: int # wrap: bool # train: # target: path to train dataset # params: # key: value # validation: # target: path to validation dataset # params: # key: value # test: # target: path to test dataset # params: # key: value # lightning: (optional, has sane defaults and can be specified on cmdline) # trainer: # additional arguments to trainer # logger: # logger to instantiate # modelcheckpoint: # modelcheckpoint to instantiate # callbacks: # callback1: # target: importpath # params: # key: value now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") # add cwd for convenience and to make classes in this file available when # running as `python main.py` # (in particular `main.DataModuleFromConfig`) sys.path.append(os.getcwd()) parser = get_parser() parser = Trainer.add_argparse_args(parser) opt, unknown = parser.parse_known_args() if opt.name and opt.resume: raise ValueError( "-n/--name and -r/--resume cannot be specified both." "If you want to resume training in a new log folder, " "use -n/--name in combination with --resume_from_checkpoint" ) if opt.resume: if not os.path.exists(opt.resume): raise ValueError("Cannot find {}".format(opt.resume)) if os.path.isfile(opt.resume): paths = opt.resume.split("/") idx = len(paths) - paths[::-1].index("logs") + 1 logdir = "/".join(paths[:idx]) ckpt = opt.resume else: assert os.path.isdir(opt.resume), opt.resume logdir = opt.resume.rstrip("/") ckpt = os.path.join(logdir, "checkpoints", "last.ckpt") opt.resume_from_checkpoint = ckpt base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml"))) opt.base = base_configs + opt.base _tmp = logdir.split("/") nowname = _tmp[_tmp.index("logs") + 1] else: if opt.name: name = "_" + opt.name elif opt.base: cfg_fname = os.path.split(opt.base[0])[-1] cfg_name = os.path.splitext(cfg_fname)[0] name = "_" + cfg_name else: name = "" nowname = now + name + opt.postfix logdir = os.path.join("logs", nowname) ckptdir = os.path.join(logdir, "checkpoints") cfgdir = os.path.join(logdir, "configs") seed_everything(opt.seed) try: # init and save configs configs = [OmegaConf.load(cfg) for cfg in opt.base] cli = OmegaConf.from_dotlist(unknown) config = OmegaConf.merge(*configs, cli) lightning_config = config.pop("lightning", OmegaConf.create()) # merge trainer cli with config trainer_config = lightning_config.get("trainer", OmegaConf.create()) # default to ddp trainer_config["distributed_backend"] = "ddp" trainer_config["replace_sampler_ddp"] = False trainer_config["strategy"] = "ddp" trainer_config["persistent_workers"] = True for k in nondefault_trainer_args(opt): trainer_config[k] = getattr(opt, k) if not "gpus" in trainer_config: del trainer_config["distributed_backend"] cpu = True else: gpuinfo = trainer_config["gpus"] print(f"Running on GPUs {gpuinfo}") cpu = False trainer_opt = argparse.Namespace(**trainer_config) lightning_config.trainer = trainer_config # model model = instantiate_from_config(config.model) # trainer and callbacks trainer_kwargs = dict() # default logger configs # NOTE wandb < 0.10.0 interferes with shutdown # wandb >= 0.10.0 seems to fix it but still interferes with pudb # debugging (wrongly sized pudb ui) # thus prefer testtube for now default_logger_cfgs = { "wandb": { "target": "pytorch_lightning.loggers.WandbLogger", "params": { "name": nowname, "save_dir": logdir, "offline": opt.debug, "id": nowname, }, }, "testtube": { # "target": "pytorch_lightning.loggers.TestTubeLogger", "target": "pytorch_lightning.loggers.TensorBoardLogger", "params": { "name": "testtube", "save_dir": logdir, }, }, } default_logger_cfg = default_logger_cfgs["testtube"] try: logger_cfg = lightning_config.logger except: logger_cfg = OmegaConf.create() logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg) trainer_kwargs["logger"] = instantiate_from_config(logger_cfg) # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to # specify which metric is used to determine best models default_modelckpt_cfg = { "checkpoint_callback": { "target": "pytorch_lightning.callbacks.ModelCheckpoint", "params": { "dirpath": ckptdir, "filename": "{epoch:06}", "verbose": True, "save_last": True, }, } } if hasattr(model, "monitor"): print(f"Monitoring {model.monitor} as checkpoint metric.") default_modelckpt_cfg["checkpoint_callback"]["params"][ "monitor" ] = model.monitor default_modelckpt_cfg["checkpoint_callback"]["params"]["save_top_k"] = 3 try: modelckpt_cfg = lightning_config.modelcheckpoint except: modelckpt_cfg = OmegaConf.create() modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg) # trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg) # loaded_model_callbacks = instantiate_from_config(modelckpt_cfg) # add callback which sets up log directory default_callbacks_cfg = { "setup_callback": { "target": "celle_taming_main.SetupCallback", "params": { "resume": opt.resume, "now": now, "logdir": logdir, "ckptdir": ckptdir, "cfgdir": cfgdir, "config": config, "lightning_config": lightning_config, }, }, "image_logger": { "target": "celle_taming_main.ImageLogger", "params": { "batch_frequency": 2000, "max_images": 10, "clamp": True, "increase_log_steps": False, }, }, "learning_rate_logger": { "target": "celle_taming_main.LearningRateMonitor", "params": { "logging_interval": "step", # "log_momentum": True }, }, } try: callbacks_cfg = lightning_config.callbacks except: callbacks_cfg = OmegaConf.create() callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg) callbacks_cfg = OmegaConf.merge(modelckpt_cfg, callbacks_cfg) trainer_kwargs["callbacks"] = [ instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg ] # loaded_callbacks = [ # instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg # ] # trainer_kwargs["callbacks"] = loaded_callbacks.append(loaded_model_callbacks) trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs) # data data = instantiate_from_config(config.data) # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html # calling these ourselves should not be necessary but it is. # lightning still takes care of proper multiprocessing though data.prepare_data() data.setup() # configure learning rate bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate if not cpu: ngpu = len(lightning_config.trainer.gpus.strip(",").split(",")) else: ngpu = 1 try: accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches except: accumulate_grad_batches = 1 print(f"accumulate_grad_batches = {accumulate_grad_batches}") lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr print( "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format( model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr ) ) # allow checkpointing via USR1 def melk(*args, **kwargs): # run all checkpoint hooks if trainer.global_rank == 0: print("Summoning checkpoint.") ckpt_path = os.path.join(ckptdir, "last.ckpt") trainer.save_checkpoint(ckpt_path) def divein(*args, **kwargs): if trainer.global_rank == 0: import pudb pudb.set_trace() import signal signal.signal(signal.SIGUSR1, melk) signal.signal(signal.SIGUSR2, divein) # model = torch.compile(model) # run if opt.train: try: torch.compile(trainer.fit(model, data)) except Exception: melk() raise if not opt.no_test and not trainer.interrupted: trainer.test(model, data) except Exception: if opt.debug and trainer.global_rank == 0: try: import pudb as debugger except ImportError: import pdb as debugger debugger.post_mortem() raise finally: # move newly created debug project to debug_runs if opt.debug and not opt.resume and trainer.global_rank == 0: dst, name = os.path.split(logdir) dst = os.path.join(dst, "debug_runs", name) os.makedirs(os.path.split(dst)[0], exist_ok=True) os.rename(logdir, dst)