|
|
|
import functools |
|
import hashlib |
|
import itertools |
|
import json |
|
import logging |
|
import os |
|
import os.path |
|
import re |
|
import tempfile |
|
from dataclasses import dataclass, field |
|
from importlib import __import__ |
|
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union |
|
from weakref import WeakSet |
|
|
|
import torch._logging.structured |
|
from torch.utils._traceback import CapturedTraceback |
|
|
|
log = logging.getLogger(__name__) |
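# The "torch.__trace" logger (below) handles structured trace logging.  _reset_logs()
# turns off propagation for it and _init_logs() attaches a LazyTraceHandler, so trace
# records are written to a dedicated log file rather than stderr and are not affected
# by the user-facing log level settings.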
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
trace_log = logging.getLogger("torch.__trace") |
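# Environment variables recognized by this module (see help_message() and _init_logs()):
#   TORCH_LOGS        - comma-separated list of log / artifact settings
#   TORCH_LOGS_OUT    - mirror log output to this file as well
#   TORCH_LOGS_FORMAT - format string for log records ("short"/"basic" selects
#                       logging.BASIC_FORMAT)
#   TORCH_TRACE       - directory for the structured trace log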
|
|
|
DEFAULT_LOG_LEVEL = logging.WARNING |
|
LOG_ENV_VAR = "TORCH_LOGS" |
|
LOG_OUT_ENV_VAR = "TORCH_LOGS_OUT" |
|
LOG_FORMAT_ENV_VAR = "TORCH_LOGS_FORMAT" |
|
TRACE_ENV_VAR = "TORCH_TRACE" |
|
|
|
|
|
@dataclass |
|
class LogRegistry: |
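    """Registry of all log aliases and artifact names known to the logging system,
    along with their metadata.  Populated via register_log / register_artifact; a
    single module-level instance (``log_registry``) is used throughout this file.
    """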
|
|
|
|
|
|
|
|
|
    # Maps a shorthand setting alias (e.g. "dynamo") to the list of logger
    # qualified names it controls; populated by register_log().
    log_alias_to_log_qnames: Dict[str, List[str]] = field(default_factory=dict)
|
|
|
|
|
|
|
|
|
|
|
    # Qualified names of artifact loggers, of the form "<module>.__<artifact_name>";
    # populated lazily by getArtifactLogger().
    artifact_log_qnames: Set[str] = field(default_factory=set)
|
|
|
|
|
|
|
|
|
|
|
    # Qualified names of per-module logs that were requested underneath an already
    # registered log (see register_child_log and _parse_log_settings).
    child_log_qnames: Set[str] = field(default_factory=set)
|
|
|
|
|
|
|
artifact_names: Set[str] = field(default_factory=set) |
|
|
|
|
|
visible_artifacts: Set[str] = field(default_factory=set) |
|
|
|
|
|
artifact_descriptions: Dict[str, str] = field(default_factory=dict) |
|
|
|
|
|
|
|
|
|
    # Artifacts that are not emitted when their parent log is enabled at DEBUG
    # unless they are explicitly turned on (see configure_artifact_log).
    off_by_default_artifact_names: Set[str] = field(default_factory=set)
|
|
|
|
|
    # Optional per-artifact formatter, built from the log_format argument passed to
    # register_artifact().
    artifact_log_formatters: Dict[str, logging.Formatter] = field(default_factory=dict)
|
|
|
def is_artifact(self, name): |
|
return name in self.artifact_names |
|
|
|
def is_log(self, alias): |
|
return alias in self.log_alias_to_log_qnames |
|
|
|
|
|
def register_log(self, alias, log_qnames: Union[str, List[str]]): |
|
if isinstance(log_qnames, str): |
|
log_qnames = [log_qnames] |
|
self.log_alias_to_log_qnames[alias] = log_qnames |
|
|
|
|
|
def register_artifact_name( |
|
self, name, description, visible, off_by_default, log_format |
|
): |
|
self.artifact_names.add(name) |
|
if visible: |
|
self.visible_artifacts.add(name) |
|
self.artifact_descriptions[name] = description |
|
|
|
|
|
|
|
if off_by_default: |
|
self.off_by_default_artifact_names.add(name) |
|
|
|
if log_format is not None: |
|
self.artifact_log_formatters[name] = logging.Formatter(log_format) |
|
|
|
|
|
|
|
|
|
def register_artifact_log(self, artifact_log_qname): |
|
self.artifact_log_qnames.add(artifact_log_qname) |
|
|
|
def register_child_log(self, log_qname): |
|
self.child_log_qnames.add(log_qname) |
|
|
|
|
|
def get_log_qnames(self) -> Set[str]: |
|
return { |
|
qname |
|
for qnames in self.log_alias_to_log_qnames.values() |
|
for qname in qnames |
|
} |
|
|
|
def get_artifact_log_qnames(self): |
|
return set(self.artifact_log_qnames) |
|
|
|
def get_child_log_qnames(self): |
|
return set(self.child_log_qnames) |
|
|
|
def is_off_by_default(self, artifact_qname): |
|
return artifact_qname in self.off_by_default_artifact_names |
|
|
|
|
|
@dataclass |
|
class LogState: |
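    """The logging settings requested by the user, either through the TORCH_LOGS
    environment variable or through set_logs()."""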
|
|
|
log_qname_to_level: Dict[str, str] = field(default_factory=dict) |
|
|
|
|
|
artifact_names: Set[str] = field(default_factory=set) |
|
|
|
def enable_artifact(self, artifact_name): |
|
self.artifact_names.add(artifact_name) |
|
|
|
def is_artifact_enabled(self, name): |
|
return name in self.artifact_names |
|
|
|
def enable_log(self, log_qnames, log_level): |
|
if isinstance(log_qnames, str): |
|
log_qnames = [log_qnames] |
|
for log_qname in log_qnames: |
|
self.log_qname_to_level[log_qname] = log_level |
|
|
|
def get_log_level_pairs(self): |
|
"""Returns all qualified module names for which the user requested |
|
explicit logging settings. |
|
|
|
        .. warning::

            This function used to return all loggers, regardless of whether
            the user specified them; it now only returns logs which were
            explicitly mentioned by the user (and torch, which is always
            implicitly requested when we initialize our logging subsystem).
|
""" |
|
return self.log_qname_to_level.items() |
|
|
|
def clear(self): |
|
self.log_qname_to_level.clear() |
|
self.artifact_names.clear() |
|
|
|
|
|
log_registry = LogRegistry() |
|
log_state = LogState() |
|
|
|
|
|
DEFAULT_LOGGING = { |
|
"dynamo": logging.DEBUG, |
|
"aot": logging.DEBUG, |
|
"inductor": logging.DEBUG, |
|
"ddp_graphs": True, |
|
"graph_breaks": True, |
|
"guards": True, |
|
"recompiles": True, |
|
"dynamic": logging.INFO, |
|
} |
|
|
|
|
|
def set_logs( |
|
*, |
|
all: Optional[int] = None, |
|
dynamo: Optional[int] = None, |
|
aot: Optional[int] = None, |
|
autograd: Optional[int] = None, |
|
dynamic: Optional[int] = None, |
|
inductor: Optional[int] = None, |
|
distributed: Optional[int] = None, |
|
dist_c10d: Optional[int] = None, |
|
dist_ddp: Optional[int] = None, |
|
dist_fsdp: Optional[int] = None, |
|
onnx: Optional[int] = None, |
|
bytecode: bool = False, |
|
aot_graphs: bool = False, |
|
aot_joint_graph: bool = False, |
|
ddp_graphs: bool = False, |
|
graph: bool = False, |
|
graph_code: bool = False, |
|
graph_breaks: bool = False, |
|
graph_sizes: bool = False, |
|
guards: bool = False, |
|
recompiles: bool = False, |
|
recompiles_verbose: bool = False, |
|
trace_source: bool = False, |
|
trace_call: bool = False, |
|
trace_bytecode: bool = False, |
|
output_code: bool = False, |
|
kernel_code: bool = False, |
|
schedule: bool = False, |
|
perf_hints: bool = False, |
|
post_grad_graphs: bool = False, |
|
onnx_diagnostics: bool = False, |
|
fusion: bool = False, |
|
overlap: bool = False, |
|
export: Optional[int] = None, |
|
modules: Optional[Dict[str, Union[int, bool]]] = None, |
|
cudagraphs: bool = False, |
|
sym_node: bool = False, |
|
compiled_autograd_verbose: bool = False, |
|
): |
|
""" |
|
Sets the log level for individual components and toggles individual log |
|
artifact types. |
|
|
|
.. warning:: This feature is a prototype and may have compatibility |
|
breaking changes in the future. |
|
|
|
.. note:: The ``TORCH_LOGS`` environment variable has complete precedence |
|
over this function, so if it was set, this function does nothing. |
|
|
|
A component is a set of related features in PyTorch. All of the log |
|
messages emitted from a given component have their own log levels. If the |
|
log level of a particular message has priority greater than or equal to its |
|
component's log level setting, it is emitted. Otherwise, it is suppressed. |
|
This allows you to, for instance, silence large groups of log messages that |
|
are not relevant to you and increase verbosity of logs for components that |
|
are relevant. The expected log level values, ordered from highest to lowest |
|
priority, are: |
|
|
|
* ``logging.CRITICAL`` |
|
* ``logging.ERROR`` |
|
* ``logging.WARNING`` |
|
* ``logging.INFO`` |
|
* ``logging.DEBUG`` |
|
* ``logging.NOTSET`` |
|
|
|
See documentation for the Python ``logging`` module for more information on |
|
log levels: `<https://docs.python.org/3/library/logging.html#logging-levels>`_ |
|
|
|
An artifact is a particular type of log message. Each artifact is assigned |
|
to a parent component. A component can emit many different kinds of |
|
artifacts. In general, an artifact is emitted if either its corresponding |
|
setting in the argument list below is turned on or if its parent component |
|
is set to a log level less than or equal to the log level of the artifact. |
|
|
|
Keyword args: |
|
all (:class:`Optional[int]`): |
|
The default log level for all components. Default: ``logging.WARN`` |
|
|
|
dynamo (:class:`Optional[int]`): |
|
The log level for the TorchDynamo component. Default: ``logging.WARN`` |
|
|
|
aot (:class:`Optional[int]`): |
|
The log level for the AOTAutograd component. Default: ``logging.WARN`` |
|
|
|
autograd (:class:`Optional[int]`): |
|
The log level for autograd. Default: ``logging.WARN`` |
|
|
|
inductor (:class:`Optional[int]`): |
|
The log level for the TorchInductor component. Default: ``logging.WARN`` |
|
|
|
dynamic (:class:`Optional[int]`): |
|
The log level for dynamic shapes. Default: ``logging.WARN`` |
|
|
|
distributed (:class:`Optional[int]`): |
|
Whether to log c10d communication operations and other debug info from PyTorch Distributed components. |
|
Default: ``logging.WARN`` |
|
|
|
dist_c10d (:class:`Optional[int]`): |
|
Whether to log c10d communication operations related debug info in PyTorch Distributed components. |
|
Default: ``logging.WARN`` |
|
|
|
dist_ddp (:class:`Optional[int]`): |
|
            Whether to log debug info related to ``DistributedDataParallel`` (DDP) from PyTorch Distributed components.
|
Default: ``logging.WARN`` |
|
|
|
dist_fsdp (:class:`Optional[int]`): |
|
            Whether to log debug info related to ``FullyShardedDataParallel`` (FSDP) in PyTorch Distributed components.
|
Default: ``logging.WARN`` |
|
|
|
onnx (:class:`Optional[int]`): |
|
The log level for the ONNX exporter component. Default: ``logging.WARN`` |
|
|
|
bytecode (:class:`bool`): |
|
Whether to emit the original and generated bytecode from TorchDynamo. |
|
Default: ``False`` |
|
|
|
aot_graphs (:class:`bool`): |
|
Whether to emit the graphs generated by AOTAutograd. Default: ``False`` |
|
|
|
aot_joint_graph (:class:`bool`): |
|
Whether to emit the joint forward-backward graph generated by AOTAutograd. Default: ``False`` |
|
|
|
        cudagraphs (:class:`bool`):
            Whether to emit debug information from inductor cudagraphs. Default: ``False``
|
|
|
ddp_graphs (:class:`bool`): |
|
Whether to emit graphs generated by DDPOptimizer. Default: ``False`` |
|
|
|
graph (:class:`bool`): |
|
Whether to emit the graph captured by TorchDynamo in tabular format. |
|
Default: ``False`` |
|
|
|
graph_code (:class:`bool`): |
|
Whether to emit the python source of the graph captured by TorchDynamo. |
|
Default: ``False`` |
|
|
|
graph_breaks (:class:`bool`): |
|
Whether to emit the graph breaks encountered by TorchDynamo. |
|
Default: ``False`` |
|
|
|
graph_sizes (:class:`bool`): |
|
Whether to emit tensor sizes of the graph captured by TorchDynamo. |
|
Default: ``False`` |
|
|
|
guards (:class:`bool`): |
|
Whether to emit the guards generated by TorchDynamo for each compiled |
|
function. Default: ``False`` |
|
|
|
recompiles (:class:`bool`): |
|
Whether to emit a guard failure reason and message every time |
|
TorchDynamo recompiles a function. Default: ``False`` |
|
|
|
recompiles_verbose (:class:`bool`): |
|
Whether to emit all guard failure reasons when TorchDynamo recompiles |
|
a function, even those that are not actually run. Default: ``False`` |
|
|
|
trace_source (:class:`bool`): |
|
            Whether to emit a message when TorchDynamo begins tracing a new line. Default: ``False``
|
|
|
trace_call (:class:`bool`): |
|
Whether to emit detailed line location when TorchDynamo creates an FX node |
|
            corresponding to a function call. Python 3.11+ only. Default: ``False``
|
|
|
trace_bytecode (:class:`bool`): |
|
Whether to emit bytecode instructions and traced stack state as TorchDynamo |
|
traces bytecode. Default: ``False`` |
|
|
|
output_code (:class:`bool`): |
|
Whether to emit the TorchInductor output code on a per-graph basis. Default: ``False`` |
|
|
|
kernel_code (:class:`bool`): |
|
            Whether to emit the TorchInductor output code on a per-kernel basis. Default: ``False``
|
|
|
schedule (:class:`bool`): |
|
Whether to emit the TorchInductor schedule. Default: ``False`` |
|
|
|
perf_hints (:class:`bool`): |
|
Whether to emit the TorchInductor perf hints. Default: ``False`` |
|
|
|
post_grad_graphs (:class:`bool`): |
|
            Whether to emit the graphs generated by the post-grad passes. Default: ``False``
|
|
|
onnx_diagnostics (:class:`bool`): |
|
Whether to emit the ONNX exporter diagnostics in logging. Default: ``False`` |
|
|
|
fusion (:class:`bool`): |
|
Whether to emit detailed Inductor fusion decisions. Default: ``False`` |
|
|
|
overlap (:class:`bool`): |
|
Whether to emit detailed Inductor compute/comm overlap decisions. Default: ``False`` |
|
|
|
sym_node (:class:`bool`): |
|
            Whether to emit debug info for various SymNode operations. Default: ``False``
|
|
|
export (:class:`Optional[int]`): |
|
The log level for export. Default: ``logging.WARN`` |
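
        compiled_autograd_verbose (:class:`bool`):
            Whether to emit verbose logs from compiled autograd. Default: ``False``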
|
|
|
modules (dict): |
|
            This argument provides an alternate way to specify the above log
            component and artifact settings, in the format of a keyword args
            dictionary given as a single argument. There are two cases where
            this is useful: (1) if a new log component or artifact has been
            registered but a keyword argument for it has not yet been added
            to this function, and (2) if the log level for an unregistered
            module needs to be set. The latter can be done by providing the
            fully-qualified module name as the key, with the log level as the
            value. Default: ``None``
|
|
|
|
|
Example:: |
|
|
|
>>> # xdoctest: +SKIP |
|
>>> import logging |
|
|
|
# The following changes the "dynamo" component to emit DEBUG-level |
|
# logs, and to emit "graph_code" artifacts. |
|
|
|
>>> torch._logging.set_logs(dynamo=logging.DEBUG, graph_code=True) |
|
|
|
# The following enables the logs for a different module |
|
|
|
>>> torch._logging.set_logs(modules={"unregistered.module.name": logging.DEBUG}) |
|
""" |
|
|
|
if LOG_ENV_VAR in os.environ: |
|
log.warning( |
|
"Using TORCH_LOGS environment variable for log settings, ignoring call to set_logs" |
|
) |
|
return |
|
|
|
log_state.clear() |
|
|
|
modules = modules or {} |
|
|
|
def _set_logs(**kwargs): |
|
for alias, val in itertools.chain(kwargs.items(), modules.items()): |
|
if val is None: |
|
continue |
|
|
|
if log_registry.is_artifact(alias): |
|
if not isinstance(val, bool): |
|
raise ValueError( |
|
f"Expected bool to enable artifact {alias}, received {val}" |
|
) |
|
|
|
if val: |
|
log_state.enable_artifact(alias) |
|
elif log_registry.is_log(alias) or alias in log_registry.child_log_qnames: |
|
if val not in logging._levelToName: |
|
raise ValueError( |
|
f"Unrecognized log level for log {alias}: {val}, valid level values " |
|
f"are: {','.join([str(k) for k in logging._levelToName.keys()])}" |
|
) |
|
|
|
log_state.enable_log( |
|
log_registry.log_alias_to_log_qnames.get(alias, alias), val |
|
) |
|
else: |
|
raise ValueError( |
|
f"Unrecognized log or artifact name passed to set_logs: {alias}" |
|
) |
|
|
|
_init_logs() |
|
|
|
_set_logs( |
|
torch=all, |
|
dynamo=dynamo, |
|
aot=aot, |
|
autograd=autograd, |
|
inductor=inductor, |
|
dynamic=dynamic, |
|
bytecode=bytecode, |
|
aot_graphs=aot_graphs, |
|
aot_joint_graph=aot_joint_graph, |
|
ddp_graphs=ddp_graphs, |
|
distributed=distributed, |
|
dist_c10d=dist_c10d, |
|
dist_ddp=dist_ddp, |
|
dist_fsdp=dist_fsdp, |
|
graph=graph, |
|
graph_code=graph_code, |
|
graph_breaks=graph_breaks, |
|
graph_sizes=graph_sizes, |
|
guards=guards, |
|
recompiles=recompiles, |
|
recompiles_verbose=recompiles_verbose, |
|
trace_source=trace_source, |
|
trace_call=trace_call, |
|
trace_bytecode=trace_bytecode, |
|
output_code=output_code, |
|
kernel_code=kernel_code, |
|
schedule=schedule, |
|
perf_hints=perf_hints, |
|
post_grad_graphs=post_grad_graphs, |
|
onnx=onnx, |
|
onnx_diagnostics=onnx_diagnostics, |
|
fusion=fusion, |
|
overlap=overlap, |
|
sym_node=sym_node, |
|
export=export, |
|
cudagraphs=cudagraphs, |
|
compiled_autograd_verbose=compiled_autograd_verbose, |
|
) |
|
|
|
|
|
def get_loggers(): |
|
""" |
|
Returns: a list of all registered loggers |
|
""" |
|
return [logging.getLogger(qname) for qname in log_registry.get_log_qnames()] |
|
|
|
|
|
def register_log(setting_name, log_name): |
|
""" |
|
Enables a log to be controlled by the env var and user API with the setting_name |
|
Args: |
|
setting_name: the shorthand name used in the env var and user API |
|
        log_name: the qualified name (or list of qualified names) of the log(s)
            associated with the setting_name
|
""" |
|
log_registry.register_log(setting_name, log_name) |
|
|
|
|
|
def register_artifact( |
|
setting_name, description, visible=False, off_by_default=False, log_format=None |
|
): |
|
""" |
|
    Enables an artifact to be controlled by the env var and user API with the given setting_name
|
Args: |
|
setting_name: the shorthand name used in the env var and user API |
|
description: A description of what this outputs |
|
visible: Whether it gets suggested to users by default |
|
        off_by_default: if True, the artifact is not logged even when its ancestor
            loggers are enabled at level DEBUG; it must be enabled explicitly
        log_format: an optional format string used to format this artifact's records
|
""" |
|
log_registry.register_artifact_name( |
|
setting_name, description, visible, off_by_default, log_format |
|
) |
|
|
|
|
|
def getArtifactLogger(module_qname, artifact_name): |
|
if artifact_name not in log_registry.artifact_names: |
|
        raise ValueError(
            f"Artifact name: {repr(artifact_name)} not registered, "
            f"please call register_artifact({repr(artifact_name)}) in torch._logging.registrations."
|
) |
|
qname = module_qname + f".__{artifact_name}" |
|
log = logging.getLogger(qname) |
|
log.artifact_name = artifact_name |
|
log_registry.register_artifact_log(qname) |
|
configure_artifact_log(log) |
|
return log |
|
|
|
|
|
INCR_VERBOSITY_CHAR = "+" |
|
DECR_VERBOSITY_CHAR = "-" |
|
VERBOSITY_REGEX = ( |
|
"(" |
|
+ "|".join([re.escape(INCR_VERBOSITY_CHAR), re.escape(DECR_VERBOSITY_CHAR)]) |
|
+ "?)" |
|
) |
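# A TORCH_LOGS entry may be prefixed with "+" (sets the log to DEBUG) or "-"
# (sets it to ERROR); an unprefixed entry defaults to INFO.  See
# get_name_level_pair inside _parse_log_settings.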
|
|
|
|
|
def configure_artifact_log(log): |
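    # Artifacts registered as "off by default" do not propagate records to their
    # parent log, so they stay silent even when the parent is enabled at DEBUG.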
|
|
|
|
|
|
|
if log_registry.is_off_by_default(log.artifact_name): |
|
log.propagate = False |
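    # Explicitly enabling an artifact forces DEBUG level and (re-)enables
    # propagation so its records reach the parent log's handlers.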
|
|
|
|
|
if log_state.is_artifact_enabled(log.artifact_name): |
|
log.setLevel(logging.DEBUG) |
|
log.propagate = True |
|
|
|
|
|
|
|
def _gen_settings_regex(): |
|
return re.compile(r"((\+|-)?[\w\.]+,\s*)*(\+|-)?[\w\.]+?") |
|
|
|
|
|
def _validate_settings(settings): |
|
return re.fullmatch(_gen_settings_regex(), settings) is not None |
|
|
|
|
|
def help_message(verbose=False): |
|
def pad_to(s, length=30): |
|
assert len(s) <= length |
|
return s + " " * (length - len(s)) |
|
|
|
if verbose: |
|
printed_artifacts = log_registry.artifact_names |
|
else: |
|
printed_artifacts = log_registry.visible_artifacts |
|
|
|
if verbose: |
|
heading = "All registered names" |
|
else: |
|
heading = "Visible registered names (use TORCH_LOGS='+help' for full list)" |
|
lines = ( |
|
["all"] |
|
+ sorted(log_registry.log_alias_to_log_qnames.keys()) |
|
+ sorted( |
|
[ |
|
f"{pad_to(name)}\t{log_registry.artifact_descriptions[name]}" |
|
for name in printed_artifacts |
|
] |
|
) |
|
) |
|
setting_info = " " + "\n ".join(lines) |
|
examples = """ |
|
Examples: |
|
TORCH_LOGS="+dynamo,aot" will set the log level of TorchDynamo to |
|
logging.DEBUG and AOT to logging.INFO |
|
|
|
TORCH_LOGS="-dynamo,+inductor" will set the log level of TorchDynamo to |
|
logging.ERROR and TorchInductor to logging.DEBUG |
|
|
|
TORCH_LOGS="aot_graphs" will enable the aot_graphs artifact |
|
|
|
  TORCH_LOGS="+dynamo,schedule" will set the log level of TorchDynamo
  to logging.DEBUG and enable the schedule artifact
|
|
|
TORCH_LOGS="+some.random.module,schedule" will set the log level of |
|
some.random.module to logging.DEBUG and enable the schedule artifact |
|
|
|
TORCH_LOGS_FORMAT="%(levelname)s: %(message)s" or any provided format |
|
string will set the output format |
|
Valid keys are "levelname", "message", "pathname", "levelno", "lineno", |
|
"filename" and "name". |
|
|
|
TORCH_LOGS_OUT=/tmp/output.txt will output the logs to /tmp/output.txt as |
|
well. This is useful when the output is long. |
|
""" |
|
msg = f""" |
|
TORCH_LOGS Info |
|
{examples} |
|
|
|
{heading} |
|
{setting_info} |
|
""" |
|
return msg |
|
|
|
|
|
def _invalid_settings_err_msg(settings, verbose=False): |
|
valid_settings = ", ".join( |
|
["all"] |
|
+ list(log_registry.log_alias_to_log_qnames.keys()) |
|
+ list(log_registry.artifact_names) |
|
) |
|
msg = f""" |
|
Invalid log settings: {settings}, must be a comma separated list of fully |
|
qualified module names, registered log names or registered artifact names. |
|
For more info on various settings, try TORCH_LOGS="help" |
|
Valid settings: |
|
{valid_settings} |
|
""" |
|
return msg |
|
|
|
|
|
@functools.lru_cache |
|
def _parse_log_settings(settings): |
|
if settings == "": |
|
return dict() |
|
|
|
if settings == "help": |
|
raise ValueError(help_message(verbose=False)) |
|
elif settings == "+help": |
|
raise ValueError(help_message(verbose=True)) |
|
if not _validate_settings(settings): |
|
raise ValueError(_invalid_settings_err_msg(settings)) |
|
|
|
settings = re.sub(r"\s+", "", settings) |
|
log_names = settings.split(",") |
|
|
|
def get_name_level_pair(name): |
|
clean_name = name.replace(INCR_VERBOSITY_CHAR, "") |
|
clean_name = clean_name.replace(DECR_VERBOSITY_CHAR, "") |
|
|
|
if name[0] == INCR_VERBOSITY_CHAR: |
|
level = logging.DEBUG |
|
elif name[0] == DECR_VERBOSITY_CHAR: |
|
level = logging.ERROR |
|
else: |
|
level = logging.INFO |
|
|
|
return clean_name, level |
|
|
|
log_state = LogState() |
|
|
|
for name in log_names: |
|
name, level = get_name_level_pair(name) |
|
|
|
if name == "all": |
|
name = "torch" |
|
|
|
if log_registry.is_log(name): |
|
assert level is not None |
|
log_qnames = log_registry.log_alias_to_log_qnames[name] |
|
log_state.enable_log(log_qnames, level) |
|
elif log_registry.is_artifact(name): |
|
log_state.enable_artifact(name) |
|
elif _is_valid_module(name): |
|
if not _has_registered_parent(name): |
|
log_registry.register_log(name, name) |
|
else: |
|
log_registry.register_child_log(name) |
|
log_state.enable_log(name, level) |
|
else: |
|
raise ValueError(_invalid_settings_err_msg(settings)) |
|
|
|
return log_state |
|
|
|
|
|
def _is_valid_module(qname): |
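    # Note: this imports the module as a side effect of checking that it exists.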
|
try: |
|
__import__(qname) |
|
return True |
|
except ImportError: |
|
return False |
|
|
|
|
|
def _update_log_state_from_env(): |
|
global log_state |
|
log_setting = os.environ.get(LOG_ENV_VAR, None) |
|
if log_setting is not None: |
|
log_state = _parse_log_settings(log_setting) |
|
|
|
|
|
def _has_registered_parent(log_qname): |
|
cur_log = logging.getLogger(log_qname) |
|
|
|
registered_log_qnames = log_registry.get_log_qnames() |
|
|
|
while cur_log.parent: |
|
if cur_log.name in registered_log_qnames: |
|
return True |
|
cur_log = cur_log.parent |
|
|
|
return False |
|
|
|
|
|
|
|
class TorchLogsFormatter(logging.Formatter): |
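    """Formatter used for TORCH_LOGS output.

    Produces a glog-style prefix (level, timestamp, thread, file:line, plus an
    optional rank and compile trace id).  With ``trace=True`` it instead formats
    structured trace records: the record's metadata is emitted as JSON and any
    payload is appended as indented lines.
    """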
|
def __init__(self, *, trace: bool = False): |
|
super().__init__() |
|
self._is_trace = trace |
|
|
|
def format(self, record): |
|
artifact_name = getattr(logging.getLogger(record.name), "artifact_name", None) |
|
if artifact_name is not None: |
|
artifact_formatter = log_registry.artifact_log_formatters.get( |
|
artifact_name, None |
|
) |
|
if artifact_formatter is not None: |
|
return artifact_formatter.format(record) |
|
|
|
record.message = record.getMessage() |
|
record.asctime = self.formatTime(record, "%m%d %H:%M:%S") |
|
|
|
|
|
s = record.message |
|
if record.exc_info: |
|
|
|
|
|
if not record.exc_text: |
|
record.exc_text = self.formatException(record.exc_info) |
|
if record.exc_text: |
|
if s[-1:] != "\n": |
|
s = s + "\n" |
|
s = s + record.exc_text |
|
if record.stack_info: |
|
if s[-1:] != "\n": |
|
s = s + "\n" |
|
s = s + self.formatStack(record.stack_info) |
|
|
|
record.rankprefix = "" |
|
if not self._is_trace and dist.is_available() and dist.is_initialized(): |
|
record.rankprefix = f"[rank{dist.get_rank()}]:" |
|
|
|
record.traceid = "" |
|
if ( |
|
not self._is_trace |
|
and (trace_id := torch._guards.CompileContext.current_trace_id()) |
|
is not None |
|
): |
|
record.traceid = f" [{trace_id}]" |
|
|
|
glog_level_to_abbr = { |
|
"DEBUG": "V", |
|
"INFO": "I", |
|
"WARNING": "W", |
|
"ERROR": "E", |
|
"CRITICAL": "C", |
|
} |
|
|
|
shortlevel = glog_level_to_abbr.get(record.levelname, record.levelname) |
|
|
|
record.artifactprefix = "" |
|
if artifact_name is not None: |
|
record.artifactprefix = f" [__{artifact_name}]" |
|
|
|
prefix = ( |
|
f"{record.rankprefix}{shortlevel}{record.asctime}.{int(record.msecs*1000):06d} {record.thread} " |
|
f"{os.path.relpath(record.pathname, os.path.dirname(os.path.dirname(torch.__file__)))}:" |
|
f"{record.lineno}]{record.traceid}{record.artifactprefix}" |
|
) |
|
if self._is_trace: |
|
assert s == "" |
|
try: |
|
r = f"{prefix} {json.dumps(record.metadata)}" |
|
except TypeError: |
|
log.warning("failing metadata: %r", record.metadata) |
|
raise |
|
if record.payload is not None: |
|
r += "".join(f"\n\t{l}" for l in record.payload.split("\n")) |
|
return r |
|
else: |
|
lines = s.split("\n") |
|
return "\n".join(f"{prefix} {l}" for l in lines) |
|
|
|
|
|
def _default_formatter(): |
|
fmt = os.environ.get(LOG_FORMAT_ENV_VAR, None) |
|
if fmt is None: |
|
return TorchLogsFormatter() |
|
else: |
|
if fmt in ("short", "basic"): |
|
fmt = logging.BASIC_FORMAT |
|
return logging.Formatter(fmt) |
|
|
|
|
|
DEFAULT_FORMATTER = _default_formatter() |
|
|
|
|
|
def _setup_handlers(create_handler_fn, log): |
|
debug_handler = _track_handler(create_handler_fn()) |
|
debug_handler.setFormatter(DEFAULT_FORMATTER) |
|
debug_handler.setLevel(logging.DEBUG) |
|
log.addHandler(debug_handler) |
|
|
|
|
|
handlers = WeakSet() |
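# Handlers created by this module are tracked in a WeakSet so that _clear_handlers()
# only removes handlers we installed, leaving any user-added handlers untouched.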
|
|
|
|
|
|
|
|
|
def _track_handler(handler): |
|
handlers.add(handler) |
|
return handler |
|
|
|
|
|
def _is_torch_handler(handler): |
|
return handler in handlers |
|
|
|
|
|
|
|
def _clear_handlers(log): |
|
to_remove = [handler for handler in log.handlers if _is_torch_handler(handler)] |
|
for handler in to_remove: |
|
log.removeHandler(handler) |
|
|
|
|
|
def _reset_logs(): |
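    # Reset every registered top-level log to WARNING with propagation disabled, and
    # remove any handlers that this module previously installed.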
|
|
|
for log_qname in log_registry.get_log_qnames(): |
|
log = logging.getLogger(log_qname) |
|
log.setLevel(logging.WARNING) |
|
log.propagate = False |
|
_clear_handlers(log) |
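    # Artifact and child logs are reset to NOTSET with propagation enabled so that
    # they defer to their parents.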
|
|
|
|
|
for artifact_log_qname in itertools.chain( |
|
log_registry.get_artifact_log_qnames(), log_registry.get_child_log_qnames() |
|
): |
|
log = logging.getLogger(artifact_log_qname) |
|
log.setLevel(logging.NOTSET) |
|
log.propagate = True |
|
|
|
trace_log.propagate = False |
|
_clear_handlers(trace_log) |
|
|
|
|
|
def _get_log_state(): |
|
return log_state |
|
|
|
|
|
def _set_log_state(state): |
|
global log_state |
|
log_state = state |
|
|
|
|
|
def _init_logs(log_file_name=None): |
|
_reset_logs() |
|
_update_log_state_from_env() |
|
|
|
out = os.environ.get(LOG_OUT_ENV_VAR, None) |
|
if out is not None: |
|
log_file_name = out |
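    # Reset every registered log except the top-level "torch" logger to NOTSET so
    # that it defers to its parent's effective level until configured below.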
|
|
|
|
|
|
|
for log_qname in log_registry.get_log_qnames(): |
|
|
|
|
|
if log_qname == "torch": |
|
continue |
|
log = logging.getLogger(log_qname) |
|
log.setLevel(logging.NOTSET) |
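    # Apply the log levels the user explicitly requested via TORCH_LOGS or set_logs().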
|
|
|
|
|
|
|
for log_qname, level in log_state.get_log_level_pairs(): |
|
log = logging.getLogger(log_qname) |
|
log.setLevel(level) |
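    # Install a stream handler (and, if requested, a file handler) on every registered log.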
|
|
|
|
|
for log_qname in log_registry.get_log_qnames(): |
|
log = logging.getLogger(log_qname) |
|
_setup_handlers( |
|
logging.StreamHandler, |
|
log, |
|
) |
|
|
|
if log_file_name is not None: |
|
_setup_handlers( |
|
lambda: logging.FileHandler(log_file_name), |
|
log, |
|
) |
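    # Configure artifact loggers; do this after the parent logs above have been set up.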
|
|
|
|
|
|
|
for artifact_log_qname in log_registry.get_artifact_log_qnames(): |
|
log = logging.getLogger(artifact_log_qname) |
|
configure_artifact_log(log) |
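    # Set up structured trace logging: TORCH_TRACE picks the output directory;
    # otherwise LazyTraceHandler decides lazily in emit() whether and where to
    # open a file.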
|
|
|
|
|
|
|
trace_dir_name = os.environ.get(TRACE_ENV_VAR, None) |
|
|
|
|
|
|
|
|
|
|
|
handler = LazyTraceHandler(trace_dir_name) |
|
|
|
|
|
|
|
trace_log.setLevel(logging.DEBUG) |
|
trace_log_handler = _track_handler(handler) |
|
trace_log_handler.setFormatter(TorchLogsFormatter(trace=True)) |
|
trace_log.addHandler(trace_log_handler) |
|
|
|
|
|
class LazyTraceHandler(logging.StreamHandler): |
|
"""Like FileHandler, but the file is allocated lazily only upon the first log message""" |
|
|
|
def __init__(self, root_dir: Optional[str]): |
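        # The output stream is created lazily in emit(), similar to the "delay"
        # option of logging.FileHandler.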
|
|
|
|
|
self.root_dir = root_dir |
|
logging.Handler.__init__(self) |
|
self.stream = None |
|
self._builtin_open = open |
|
|
|
|
|
def close(self): |
|
self.acquire() |
|
try: |
|
try: |
|
if self.stream: |
|
try: |
|
self.flush() |
|
finally: |
|
stream = self.stream |
|
self.stream = None |
|
if hasattr(stream, "close"): |
|
stream.close() |
|
finally: |
|
|
|
|
|
|
|
|
|
logging.StreamHandler.close(self) |
|
finally: |
|
self.release() |
|
|
|
def emit(self, record): |
|
if self.stream is None: |
|
ok = False |
|
if self.root_dir is None: |
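                # No TORCH_TRACE directory was given; only fall back to the internal
                # /logs location when it exists and is writable, otherwise disable
                # trace logging entirely (see the removeHandler branch below).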
|
TRACE_LOG_DIR = "/logs" |
|
open_func = self._builtin_open |
|
|
|
import torch.version as torch_version |
|
|
|
if hasattr(torch_version, "git_version"): |
|
log.info("LazyTraceHandler: disabled because not fbcode") |
|
elif not torch._utils_internal.justknobs_check("pytorch/trace:enable"): |
|
log.info( |
|
"LazyTraceHandler: disabled because justknobs_check('pytorch/trace:enable') returned False" |
|
) |
|
elif not os.path.exists(TRACE_LOG_DIR): |
|
log.info( |
|
"LazyTraceHandler: disabled because %s does not exist", |
|
TRACE_LOG_DIR, |
|
) |
|
elif not os.access(TRACE_LOG_DIR, os.W_OK): |
|
log.info( |
|
"LazyTraceHandler: disabled because %s is not writeable", |
|
TRACE_LOG_DIR, |
|
) |
|
else: |
|
self.root_dir = TRACE_LOG_DIR |
|
|
|
if self.root_dir is not None: |
|
os.makedirs(self.root_dir, exist_ok=True) |
|
ranksuffix = "" |
|
if dist.is_available() and dist.is_initialized(): |
|
ranksuffix = f"rank_{dist.get_rank()}_" |
|
self.stream = tempfile.NamedTemporaryFile( |
|
mode="w+", |
|
suffix=".log", |
|
prefix=f"dedicated_log_torch_trace_{ranksuffix}", |
|
dir=self.root_dir, |
|
delete=False, |
|
) |
|
log.info("LazyTraceHandler: logging to %s", self.stream.name) |
|
else: |
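                # Nowhere to write: detach this handler so we never try again.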
|
|
|
trace_log.removeHandler(self) |
|
return |
|
if self.stream: |
|
super().emit(record) |
|
|
|
|
|
@functools.lru_cache(None) |
|
def warning_once(logger_obj, *args, **kwargs): |
|
""" |
|
This function is similar to `logger.warning()`, but will emit the warning with the same message only once |
|
Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the cache. |
|
    The assumption here is that all warning messages are unique across the code. If they aren't, then we need to switch
    to another type of cache that includes the caller frame information in the hashing function.
|
""" |
|
logger_obj.warning(*args, **kwargs) |
|
|
|
|
|
class LazyString: |
|
def __init__(self, func, *args, **kwargs): |
|
self.func = func |
|
self.args = args |
|
self.kwargs = kwargs |
|
|
|
def __str__(self): |
|
return self.func(*self.args, **self.kwargs) |
|
|
|
|
|
def trace_structured( |
|
name: str, |
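    # metadata_fn should return a JSON-serializable dict (or a (str, int) tuple);
    # it is only invoked when trace logging is active (trace_log has handlers).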
|
|
|
|
|
metadata_fn: Callable[[], Union[Dict[str, Any], Tuple[str, int]]] = dict, |
|
*, |
|
payload_fn: Callable[[], Optional[Union[str, object]]] = lambda: None, |
|
suppress_context: bool = False, |
|
): |
|
""" |
|
metadata is an arbitrary JSON compatible struct, but it's expected to not be |
|
too long (e.g., less than 1MB) |
|
|
|
payload is an arbitrary string, which can be arbitrarily long (but expected to have |
|
newlines so no lines are too long) |
|
""" |
|
    assert name not in ["rank", "frame_id", "frame_compile_id", "attempt"]
|
assert callable( |
|
metadata_fn |
|
), f"metadata_fn should be callable, but got {type(metadata_fn)}" |
|
assert callable( |
|
payload_fn |
|
), f"payload_fn should be callable, but got {type(payload_fn)}" |
|
|
|
|
|
if trace_log.handlers: |
|
record: Dict[str, object] = {} |
|
record[name] = metadata_fn() |
|
if not suppress_context: |
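            # Attach rank and compile-context information so the record can be
            # correlated with a particular compilation attempt.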
|
|
|
|
|
|
|
if dist.is_available() and dist.is_initialized(): |
|
record["rank"] = dist.get_rank() |
|
if ( |
|
trace_id := torch._guards.CompileContext.current_trace_id() |
|
) is not None: |
|
record["frame_id"] = trace_id.compile_id.frame_id |
|
record["frame_compile_id"] = trace_id.compile_id.frame_compile_id |
|
record["attempt"] = trace_id.attempt |
|
else: |
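                # No compile context is available; record the calling stack instead
                # so the event can still be attributed.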
|
|
|
|
|
record["stack"] = torch._logging.structured.from_traceback( |
|
CapturedTraceback.extract(skip=1).summary() |
|
) |
|
payload = payload_fn() |
|
if payload is not None: |
|
if not isinstance(payload, str): |
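                # Non-string payloads are JSON-encoded; lists are written one element
                # per line to keep individual lines short.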
|
if isinstance(payload, list): |
|
|
|
payload = "[\n" + ",\n".join(json.dumps(i) for i in payload) + "\n]" |
|
else: |
|
|
|
payload = json.dumps(payload, indent=0) |
|
h = hashlib.md5() |
|
h.update(payload.encode("utf-8")) |
|
record["has_payload"] = h.hexdigest() |
|
trace_log.debug( |
|
"", extra={"metadata": record, "payload": payload}, stacklevel=2 |
|
) |
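# Imported at the end of the module (rather than at the top), most likely to avoid
# circular imports: these names are only referenced from inside the functions above,
# and by the time these imports run, everything defined above already exists.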
|
|
|
|
|
import torch._guards |
|
import torch._utils_internal |
|
import torch.distributed as dist |
|
|