import logging
import os
from typing import TYPE_CHECKING, Dict, Optional, Union

from .constants import FINETRAINERS_LOG_LEVEL


if TYPE_CHECKING:
    from .parallel import ParallelBackendType


class FinetrainersLoggerAdapter(logging.LoggerAdapter):
    """Logger adapter that filters or orders records across parallel processes.

    Every logging call accepts three extra keyword-only flags
    (``main_process_only``, ``local_main_process_only``, ``in_order``) that
    decide which processes actually emit the record. Until a parallel backend
    is attached, only the process whose ``RANK`` environment variable is ``0``
    (or unset) emits anything.
    """

    def __init__(self, logger: logging.Logger, parallel_backend: Optional["ParallelBackendType"] = None) -> None:
        """Wrap *logger*.

        Args:
            logger: The underlying logger that records are forwarded to.
            parallel_backend: Backend queried for process identity. May be
                ``None`` and assigned later through the ``parallel_backend``
                attribute.
        """
        super().__init__(logger, {})
        self.parallel_backend = parallel_backend
        # Per-name state for `log_freq`: the frequency registered on first use,
        # and the number of calls seen so far.
        self._log_freq: Dict[str, int] = {}
        self._log_freq_counter: Dict[str, int] = {}

    def log(
        self,
        level,
        msg,
        *args,
        main_process_only: bool = False,
        local_main_process_only: bool = True,
        in_order: bool = False,
        **kwargs,
    ):
        """Emit *msg* at *level* on the processes selected by the flags.

        Args:
            level: Numeric logging level.
            msg: Message (format string) to log.
            *args: Positional arguments forwarded to the underlying logger.
            main_process_only: Emit when the backend reports this process as
                the main process.
            local_main_process_only: Emit when the backend reports this
                process as the local main process.
            in_order: Emit on every process, one rank at a time, calling the
                backend's ``wait_for_everyone()`` after each rank's turn.
            **kwargs: Forwarded to the underlying logger. ``stacklevel``
                defaults to ``2``.

        Raises:
            ValueError: If ``in_order`` is combined with ``main_process_only``
                or ``local_main_process_only`` while a backend is attached.
        """
        # set `stacklevel` to exclude ourself in `Logger.findCaller()` while respecting user's choice
        kwargs.setdefault("stacklevel", 2)

        if not self.isEnabledFor(level):
            return

        backend = self.parallel_backend

        if backend is None:
            # No backend attached yet: fall back to the RANK environment variable.
            if int(os.environ.get("RANK", 0)) == 0:
                msg, kwargs = self.process(msg, kwargs)
                self.logger.log(level, msg, *args, **kwargs)
            return

        if (main_process_only or local_main_process_only) and in_order:
            raise ValueError(
                "Cannot set `main_process_only` or `local_main_process_only` to True while `in_order` is True."
            )

        if (main_process_only and backend.is_main_process) or (
            local_main_process_only and backend.is_local_main_process
        ):
            msg, kwargs = self.process(msg, kwargs)
            self.logger.log(level, msg, *args, **kwargs)
            return

        if in_order:
            # The adapter has no `rank` attribute of its own; read it from the backend.
            # NOTE(review): assumes the backend exposes `rank` alongside `world_size` - confirm.
            rank = backend.rank
            for turn in range(backend.world_size):
                if rank == turn:
                    msg, kwargs = self.process(msg, kwargs)
                    self.logger.log(level, msg, *args, **kwargs)
                # Every process reaches the barrier on every turn, so ranks emit one at a time.
                backend.wait_for_everyone()
            return

        if not main_process_only and not local_main_process_only:
            # No restriction requested: every process emits.
            msg, kwargs = self.process(msg, kwargs)
            self.logger.log(level, msg, *args, **kwargs)
            return

    def log_freq(
        self,
        level: int,
        name: str,
        msg: str,
        frequency: int,
        *,
        main_process_only: bool = False,
        local_main_process_only: bool = True,
        in_order: bool = False,
        **kwargs,
    ) -> None:
        """Log *msg* once every ``frequency`` calls made under the same *name*.

        The first call for a given *name* logs and registers ``frequency``;
        later calls for that *name* keep using the registered value. Calls
        with ``frequency <= 0`` do nothing and are not counted.

        Args:
            level: Numeric logging level.
            name: Key identifying the call counter.
            msg: Message to log.
            frequency: Emit on every ``frequency``-th call.
            main_process_only: See `log`.
            local_main_process_only: See `log`.
            in_order: See `log`.
            **kwargs: Forwarded to `log`. ``stacklevel`` defaults to ``3``.
        """
        if frequency <= 0:
            return

        if name not in self._log_freq_counter:
            self._log_freq[name] = frequency
            self._log_freq_counter[name] = 0

        if self._log_freq_counter[name] % self._log_freq[name] == 0:
            # `log_freq` adds one more frame between the caller and `log`, so skip
            # it as well unless the caller chose a stacklevel explicitly.
            kwargs.setdefault("stacklevel", 3)
            self.log(
                level,
                msg,
                main_process_only=main_process_only,
                local_main_process_only=local_main_process_only,
                in_order=in_order,
                **kwargs,
            )

        self._log_freq_counter[name] += 1


def get_logger() -> Union[logging.Logger, FinetrainersLoggerAdapter]:
    """Return the module-level logger object."""
    return _logger


def _set_parallel_backend(parallel_backend: "ParallelBackendType") -> FinetrainersLoggerAdapter:
    """Attach *parallel_backend* to the module-level logger and return that logger."""
    _logger.parallel_backend = parallel_backend
    return _logger


# Configure the "finetrainers" logger with a stream handler, both at FINETRAINERS_LOG_LEVEL.
_logger = logging.getLogger("finetrainers")
_logger.setLevel(FINETRAINERS_LOG_LEVEL)
_console_handler = logging.StreamHandler()
_console_handler.setLevel(FINETRAINERS_LOG_LEVEL)
_formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
_console_handler.setFormatter(_formatter)
_logger.addHandler(_console_handler)
# Rebind the name to the adapter so `get_logger()` returns the process-aware wrapper.
_logger = FinetrainersLoggerAdapter(_logger)