# Llama-3.1-8B-DALv0.1/venv/lib/python3.12/site-packages/torch/distributed/checkpoint/logger.py
# mypy: allow-untyped-defs | |
import functools | |
import time | |
from typing import Any, Callable, Dict, List, TypeVar | |
from typing_extensions import ParamSpec | |
import torch.distributed.c10d_logger as c10d_logger | |
from torch.distributed.checkpoint.logging_handlers import DCP_LOGGER_NAME | |
__all__: List[str] = [] | |
global _dcp_logger | |
_dcp_logger = c10d_logger._get_or_create_logger(DCP_LOGGER_NAME) | |
_T = TypeVar("_T") | |
_P = ParamSpec("_P") | |
def _msg_dict_from_dcp_method_args(*args, **kwargs) -> Dict[str, Any]: | |
""" | |
Extracts log data from dcp method args | |
""" | |
msg_dict = {} | |
# checkpoint ID can be passed in through the serializer or through the checkpoint id directly | |
storage_writer = kwargs.get("storage_writer", None) | |
storage_reader = kwargs.get("storage_reader", None) | |
checkpoint_id = kwargs.get("checkpoint_id", None) | |
if not checkpoint_id and (serializer := storage_writer or storage_reader): | |
checkpoint_id = getattr(serializer, "checkpoint_id", None) | |
msg_dict["checkpoint_id"] = ( | |
str(checkpoint_id) if checkpoint_id is not None else checkpoint_id | |
) | |
return msg_dict | |
def _get_msg_dict(func_name, *args, **kwargs) -> Dict[str, Any]:
    """Assemble the log payload for a DCP call named ``func_name``.

    DCP-specific fields (e.g. the checkpoint ID) are first extracted from
    the call's arguments by ``_msg_dict_from_dcp_method_args``. The dict
    returned by ``c10d_logger._get_msg_dict`` is then merged on top, so on a
    key clash the c10d value wins.

    Only the extracted DCP fields, not the raw ``args``/``kwargs``, are
    forwarded to the c10d helper.
    """
    # NOTE(review): presumably the raw call arguments are withheld from the
    # c10d helper to avoid stringifying large objects (e.g. state dicts)
    # into the log. Confirm before changing.
    payload = _msg_dict_from_dcp_method_args(*args, **kwargs)
    c10d_fields = c10d_logger._get_msg_dict(func_name, **payload)
    payload.update(c10d_fields)
    return payload
def _dcp_method_logger(
    log_exceptions: bool = False, **wrapper_kwargs: Any
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:  # pyre-ignore
    """Build a decorator that logs the start, end, and exceptions of a call.

    Each event is sent to ``_dcp_logger`` as a dict built by
    ``_get_msg_dict`` plus an ``"event"`` key and a ``"time"`` key (a
    ``time.time_ns()`` timestamp). The ``"end"`` event also carries
    ``"times_spent"``, the nanoseconds between the start and end
    timestamps. Start and end events are logged at DEBUG level.

    Args:
        log_exceptions: If True, an ``Exception`` raised by the wrapped
            function is logged at ERROR level, with ``"event": "exception"``
            and the stringified error. The exception is re-raised whether or
            not this flag is set.
        **wrapper_kwargs: Extra key/value pairs that feed the log payload.
            A call's own keyword arguments override them on key collisions.
            They are NOT forwarded to the wrapped function.

    Returns:
        A decorator that wraps a function without changing its arguments or
        return value. It preserves the function's metadata (``__name__``,
        ``__doc__``, ``__wrapped__``, ...) via ``functools.wraps``.
    """

    def decorator(func: Callable[_P, _T]) -> Callable[_P, _T]:
        # Without functools.wraps every decorated function would report
        # itself as "wrapper" and lose its docstring and signature.
        @functools.wraps(func)
        def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
            msg_dict = _get_msg_dict(
                func.__name__, *args, **{**wrapper_kwargs, **kwargs}
            )

            # log start event
            msg_dict["event"] = "start"
            t0 = time.time_ns()
            msg_dict["time"] = t0
            # msg_dict is mutated below. Pass each log call a snapshot so
            # handlers that retain records and format them later (e.g.
            # logging.handlers.MemoryHandler) still see this event's values.
            _dcp_logger.debug(dict(msg_dict))

            try:
                result = func(*args, **kwargs)
            except Exception as error:
                if log_exceptions:
                    msg_dict["event"] = "exception"
                    msg_dict["error"] = f"{error}"
                    msg_dict["time"] = time.time_ns()
                    _dcp_logger.error(dict(msg_dict))
                raise

            # log end event. Reuse t1 for "time" so it agrees exactly with
            # the elapsed "times_spent".
            msg_dict["event"] = "end"
            t1 = time.time_ns()
            msg_dict["time"] = t1
            msg_dict["times_spent"] = t1 - t0
            _dcp_logger.debug(dict(msg_dict))
            return result

        return wrapper

    return decorator