# anyantudre's picture
# moved from training repo to inference
# caa56d6
import os
import logging
import torch.distributed as dist
class RankFilter(logging.Filter):
    """Logging filter that only lets records through on one distributed rank.

    Attach an instance to a logger or handler so that, in a multi-process
    ``torch.distributed`` job, only the process whose rank equals ``rank``
    emits log records.
    """

    def __init__(self, rank):
        """Remember which rank is allowed to log.

        Args:
            rank: Rank value to compare against the current process rank.
        """
        super().__init__()
        self.rank = rank

    def filter(self, record):
        """Return True when the current process is the selected rank.

        Args:
            record: The log record being considered (not inspected).

        Returns:
            bool: True if the record should be emitted by this process.
        """
        if dist.is_available() and dist.is_initialized():
            current_rank = dist.get_rank()
        else:
            # Without an initialized default process group the rank cannot be
            # queried, and an exception here would break every logging call.
            # A non-distributed run is treated as the single rank 0 process.
            current_rank = 0
        return current_rank == self.rank
def create_logger(log_path):
    """Configure the root logger to write to a file and to the console.

    The parent directory of ``log_path`` is created if it does not exist.
    The root logger level is set to INFO, and a file handler plus a stream
    handler are attached, both using the same
    ``timestamp - level - message`` format.

    Note: every call attaches a new pair of handlers to the root logger, so
    calling this more than once in a process duplicates the output.

    Args:
        log_path: Path of the log file to write to.

    Returns:
        logging.Logger: The configured root logger.

    Raises:
        OSError: If the log directory or the log file cannot be created.
    """
    # Make sure the target directory exists. A bare filename has an empty
    # directory component, which os.makedirs would reject, so skip it.
    log_dir = os.path.dirname(log_path)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    # Configure the root logger
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

    # File handler: persists the log to ``log_path``
    fh = logging.FileHandler(log_path)
    fh.setFormatter(formatter)
    logger.addHandler(fh)

    # Stream handler: mirrors the log to the console
    sh = logging.StreamHandler()
    sh.setLevel(logging.INFO)
    sh.setFormatter(formatter)
    logger.addHandler(sh)

    return logger