#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii Inc. All rights reserved.

import inspect
import os
import sys
from loguru import logger

import torch


def get_caller_name(depth=0):
    """Return the module name of a calling frame.

    Args:
        depth (int): Depth of caller context, use 0 for caller depth.
            Each increment walks one frame further up the call stack.
            Default value: 0.

    Returns:
        str: module name of the caller
    """
    # the following logic is a little bit faster than inspect.stack() logic
    frame = inspect.currentframe().f_back
    for _ in range(depth):
        frame = frame.f_back

    return frame.f_globals["__name__"]


class StreamToLoguru:
    """
    stream object that redirects writes to a logger instance.

    Only writes issued from the modules listed in ``caller_names`` are sent
    to loguru; everything else is passed through to the original stdout.
    """

    def __init__(self, level="INFO", caller_names=("apex", "pycocotools")):
        """
        Args:
            level(str): log level string of loguru. Default value: "INFO".
            caller_names(tuple): caller names of redirected module.
                Default value: (apex, pycocotools).
        """
        self.level = level
        self.linebuf = ""
        self.caller_names = caller_names

    def write(self, buf):
        """Log ``buf`` line by line if it comes from a redirected module.

        Args:
            buf (str): text written to the stream.
        """
        full_name = get_caller_name(depth=1)
        # Only the top-level package name is compared against caller_names.
        module_name = full_name.split(".", 1)[0]
        if module_name in self.caller_names:
            for line in buf.rstrip().splitlines():
                # use caller level log
                logger.opt(depth=2).log(self.level, line.rstrip())
        else:
            sys.__stdout__.write(buf)

    def flush(self):
        """Flush the original stdout, which receives pass-through writes."""
        # write() forwards non-redirected text to sys.__stdout__, so an
        # explicit flush (e.g. print(..., flush=True)) must reach it too.
        sys.__stdout__.flush()

    def isatty(self):
        """Report whether the original stdout is attached to a terminal."""
        # Libraries commonly probe sys.stdout.isatty(); without this method
        # they raise AttributeError once the stream has been replaced.
        return sys.__stdout__.isatty()


def redirect_sys_output(log_level="INFO"):
    """Replace ``sys.stdout`` and ``sys.stderr`` with a ``StreamToLoguru``.

    Args:
        log_level (str): loguru level used for redirected lines.
            Default value: "INFO".
    """
    redirect_logger = StreamToLoguru(log_level)
    sys.stderr = redirect_logger
    sys.stdout = redirect_logger


def setup_logger(save_dir, distributed_rank=0, filename="log.txt", mode="a"):
    """setup logger for training and testing.

    Args:
        save_dir(str): location to save log file
        distributed_rank(int): device rank when multi-gpu environment
        filename (string): log save name.
        mode(str): log file write mode, `append` or `override`. default is `a`.
            Pass `o` to delete an existing log file before logging starts.

    Return:
        logger instance.
    """
    loguru_format = (
        "{time:YYYY-MM-DD HH:mm:ss} | "
        "{level: <8} | "
        "{name}:{line} - {message}"
    )

    logger.remove()
    save_file = os.path.join(save_dir, filename)
    if mode == "o":
        # Several ranks may reach this point together; tolerate the file
        # having already been removed instead of racing on an exists() check.
        try:
            os.remove(save_file)
        except FileNotFoundError:
            pass

    # only keep logger in rank0 process
    if distributed_rank == 0:
        logger.add(
            sys.stderr,
            format=loguru_format,
            level="INFO",
            enqueue=True,
        )
        logger.add(save_file)

    # redirect stdout/stderr to loguru
    redirect_sys_output("INFO")

    return logger


class WandbLogger(object):
    """
    Log training runs, datasets, models, and predictions to Weights & Biases.
    This logger sends information to W&B at wandb.ai.
    By default, this information includes hyperparameters,
    system configuration and metrics, model metrics,
    and basic data metrics and analyses.

    For more information, please refer to:
    https://docs.wandb.ai/guides/track
    """

    def __init__(self,
                 project=None,
                 name=None,
                 id=None,
                 entity=None,
                 save_dir=None,
                 config=None,
                 **kwargs):
        """
        Args:
            project (str): wandb project name.
            name (str): wandb run name.
            id (str): wandb run id.
            entity (str): wandb entity name.
            save_dir (str): save directory.
            config (dict): config dict.
            **kwargs: other kwargs, forwarded to ``wandb.init``.

        Raises:
            ModuleNotFoundError: if the ``wandb`` package is not installed.
        """
        try:
            import wandb
            self.wandb = wandb
        except ModuleNotFoundError as err:
            raise ModuleNotFoundError(
                "wandb is not installed. "
                "Please install wandb using pip install wandb"
            ) from err

        self.project = project
        self.name = name
        self.id = id
        self.save_dir = save_dir
        self.config = config
        self.kwargs = kwargs
        self.entity = entity
        self._run = None
        self._wandb_init = dict(
            project=self.project,
            name=self.name,
            id=self.id,
            entity=self.entity,
            dir=self.save_dir,
            resume="allow"
        )
        self._wandb_init.update(**kwargs)

        # Touch the property so the run is created (or reused) right away.
        _ = self.run

        if self.config:
            self.run.config.update(self.config)
        self.run.define_metric("epoch")
        # A trailing "*" is what makes this a glob over every "val/..." metric;
        # a bare "val/" would only match a metric with exactly that name.
        self.run.define_metric("val/*", step_metric="epoch")

    @property
    def run(self):
        """Return the wandb run, creating it on first access.

        An already active ``wandb.run`` is reused instead of starting a new one.
        """
        if self._run is None:
            if self.wandb.run is not None:
                logger.info(
                    "There is a wandb run already in progress "
                    "and newly created instances of `WandbLogger` will reuse"
                    " this run. If this is not desired, call `wandb.finish()` "
                    "before instantiating `WandbLogger`."
                )
                self._run = self.wandb.run
            else:
                self._run = self.wandb.init(**self._wandb_init)
        return self._run

    def log_metrics(self, metrics, step=None):
        """
        Args:
            metrics (dict): metrics dict. Tensor values are converted with
                ``.item()``; the dict passed in is left unmodified.
            step (int): step number.
        """
        # Build a converted copy rather than rewriting the caller's dict.
        payload = {
            key: value.item() if isinstance(value, torch.Tensor) else value
            for key, value in metrics.items()
        }

        if step is not None:
            self.run.log(payload, step=step)
        else:
            self.run.log(payload)

    def save_checkpoint(self, save_dir, model_name, is_best):
        """
        Upload ``<save_dir>/<model_name>_ckpt.pth`` as a wandb model artifact.

        Args:
            save_dir (str): save directory.
            model_name (str): model name.
            is_best (bool): whether the model is the best model.
        """
        filename = os.path.join(save_dir, model_name + "_ckpt.pth")
        artifact = self.wandb.Artifact(
            name=f"model-{self.run.id}",
            type="model"
        )
        artifact.add_file(filename, name="model_ckpt.pth")

        aliases = ["latest"]

        if is_best:
            aliases.append("best")

        self.run.log_artifact(artifact, aliases=aliases)

    def finish(self):
        """Finish the underlying wandb run."""
        self.run.finish()