import os
import logging

import torch.distributed as dist


class RankFilter(logging.Filter):
    """Logging filter that only lets records through on one distributed rank.

    Args:
        rank: The process rank whose records should be emitted.
    """

    def __init__(self, rank):
        super().__init__()
        self.rank = rank

    def filter(self, record):
        """Return True only when the current process rank equals ``self.rank``.

        When torch.distributed is unavailable or no process group has been
        initialized, ``dist.get_rank()`` would raise; in that case the current
        process is treated as rank 0 (the only process of a non-distributed run).
        """
        if dist.is_available() and dist.is_initialized():
            current_rank = dist.get_rank()
        else:
            current_rank = 0
        return current_rank == self.rank


def create_logger(log_path):
    """Configure the root logger to write INFO+ records to a file and the console.

    Args:
        log_path: Path of the log file. Missing parent directories are created.

    Returns:
        The root ``logging.Logger`` with one ``FileHandler`` (for ``log_path``)
        and one ``StreamHandler`` attached, both sharing the same formatter.

    Note:
        Handlers are attached to the root logger on every call, so calling this
        function more than once will duplicate output lines.
    """
    # Create the parent directory if it does not exist yet. A bare filename
    # has an empty dirname, and os.makedirs("") would raise, so skip it.
    log_dir = os.path.dirname(log_path)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    # Configure the root logger.
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

    # File handler: persists records to log_path.
    fh = logging.FileHandler(log_path)
    fh.setFormatter(formatter)
    logger.addHandler(fh)

    # Stream handler: echoes records to the console (stderr by default).
    sh = logging.StreamHandler()
    sh.setLevel(logging.INFO)
    sh.setFormatter(formatter)
    logger.addHandler(sh)

    return logger