jbilcke-hf's picture
jbilcke-hf HF Staff
upgrading finetrainers (and losing my extra code + improvements)
80ebcb3
raw
history blame
3.67 kB
import logging
import os
from typing import TYPE_CHECKING, Union
from .constants import FINETRAINERS_LOG_LEVEL
if TYPE_CHECKING:
from .parallel import ParallelBackendType
class FinetrainersLoggerAdapter(logging.LoggerAdapter):
    """Logger adapter that decides, per call, which processes emit a record.

    When no parallel backend is attached, a record is emitted only if the
    ``RANK`` environment variable is unset or equal to ``0``. Once a backend
    is attached, the ``main_process_only`` / ``local_main_process_only`` /
    ``in_order`` flags of :meth:`log` select the emitting processes.
    """

    def __init__(
        self,
        logger: logging.Logger,
        parallel_backend: Union["ParallelBackendType", None] = None,
    ) -> None:
        """Wrap ``logger``; ``parallel_backend`` may be attached later."""
        super().__init__(logger, {})
        self.parallel_backend = parallel_backend
        # Per-name state for `log_freq`: the frequency registered on first use
        # and the number of calls seen so far.
        self._log_freq = {}
        self._log_freq_counter = {}

    def log(
        self,
        level,
        msg,
        *args,
        main_process_only: bool = False,
        local_main_process_only: bool = True,
        in_order: bool = False,
        **kwargs,
    ):
        """Log ``msg`` at ``level`` on the processes selected by the flags.

        Args:
            level: Numeric logging level.
            msg: Message (format string) forwarded to the wrapped logger.
            *args: Positional format arguments for ``msg``.
            main_process_only: Emit when ``parallel_backend.is_main_process``
                is true.
            local_main_process_only: Emit when
                ``parallel_backend.is_local_main_process`` is true.
            in_order: Emit on every rank, one rank at a time in rank order.
                Because ``local_main_process_only`` defaults to ``True``, it
                must be passed as ``False`` explicitly to use this mode.
            **kwargs: Extra keyword arguments for ``Logger.log``.

        Raises:
            ValueError: If a parallel backend is attached and ``in_order`` is
                combined with ``main_process_only`` or
                ``local_main_process_only``.
        """
        # set `stacklevel` to exclude ourself in `Logger.findCaller()` while respecting user's choice
        kwargs.setdefault("stacklevel", 2)

        if not self.isEnabledFor(level):
            return

        if self.parallel_backend is None:
            # No backend yet: fall back to the RANK environment variable.
            if int(os.environ.get("RANK", 0)) == 0:
                msg, kwargs = self.process(msg, kwargs)
                self.logger.log(level, msg, *args, **kwargs)
            return

        if (main_process_only or local_main_process_only) and in_order:
            raise ValueError(
                "Cannot set `main_process_only` or `local_main_process_only` to True while `in_order` is True."
            )

        if (main_process_only and self.parallel_backend.is_main_process) or (
            local_main_process_only and self.parallel_backend.is_local_main_process
        ):
            msg, kwargs = self.process(msg, kwargs)
            self.logger.log(level, msg, *args, **kwargs)
            return

        if in_order:
            # The rank lives on the backend, not on the adapter; reading
            # `self.rank` here raised AttributeError.
            current_rank = self.parallel_backend.rank
            for i in range(self.parallel_backend.world_size):
                if current_rank == i:
                    msg, kwargs = self.process(msg, kwargs)
                    self.logger.log(level, msg, *args, **kwargs)
                # Called after every turn (presumably a cross-process barrier)
                # so that ranks emit one after another.
                self.parallel_backend.wait_for_everyone()
            return

        if not main_process_only and not local_main_process_only:
            # No restriction requested: every process logs.
            msg, kwargs = self.process(msg, kwargs)
            self.logger.log(level, msg, *args, **kwargs)
            return

    def log_freq(
        self,
        level: int,
        name: str,
        msg: str,
        frequency: int,
        *,
        main_process_only: bool = False,
        local_main_process_only: bool = True,
        in_order: bool = False,
        **kwargs,
    ) -> None:
        """Log ``msg`` on every ``frequency``-th call made under ``name``.

        The first call for a given ``name`` logs and registers ``frequency``;
        the ``frequency`` passed on later calls for that ``name`` is ignored,
        except that a non-positive value makes the call a no-op that is not
        counted.

        Args:
            level: Numeric logging level, forwarded to :meth:`log`.
            name: Key identifying the call counter.
            msg: Message to log.
            frequency: Emit once per this many calls; ``<= 0`` disables.
            main_process_only: Forwarded to :meth:`log`.
            local_main_process_only: Forwarded to :meth:`log`.
            in_order: Forwarded to :meth:`log`.
            **kwargs: Extra keyword arguments forwarded to :meth:`log`.
        """
        if frequency <= 0:
            return
        if name not in self._log_freq_counter:
            self._log_freq[name] = frequency
            self._log_freq_counter[name] = 0
        if self._log_freq_counter[name] % self._log_freq[name] == 0:
            self.log(
                level,
                msg,
                main_process_only=main_process_only,
                local_main_process_only=local_main_process_only,
                in_order=in_order,
                **kwargs,
            )
        self._log_freq_counter[name] += 1
def get_logger() -> Union[logging.Logger, FinetrainersLoggerAdapter]:
    """Return the module-level ``_logger`` object."""
    # Only reading the module global, so no `global` declaration is required.
    return _logger
def _set_parallel_backend(parallel_backend: "ParallelBackendType") -> FinetrainersLoggerAdapter:
    """Attach ``parallel_backend`` to the module-level logger and return it.

    Args:
        parallel_backend: Backend object stored on ``_logger.parallel_backend``.

    Returns:
        The module-level ``_logger`` object, as the return annotation
        declares (previously the function implicitly returned ``None``).
    """
    _logger.parallel_backend = parallel_backend
    return _logger
# Console handler: stream output, formatted with timestamp, logger name and level.
_formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
_console_handler = logging.StreamHandler()
_console_handler.setFormatter(_formatter)
_console_handler.setLevel(FINETRAINERS_LOG_LEVEL)

# The "finetrainers" logger gets the same level threshold and the console handler.
_logger = logging.getLogger("finetrainers")
_logger.setLevel(FINETRAINERS_LOG_LEVEL)
_logger.addHandler(_console_handler)

# Rebind `_logger` to the adapter wrapping the configured logger; no parallel
# backend is attached at this point.
_logger = FinetrainersLoggerAdapter(_logger)