# Kano001's picture
# Upload 5252 files
# c61ccee verified
# raw
# history blame
# 1.63 kB
import itertools
import logging
from torch.hub import _Faketqdm, tqdm
# The torch.compile() progress bar is switched off by default. This flag lives
# in this module instead of the dynamo config, because reading it from the
# config here would create a circular import.
disable_progress: bool = True
# Return all loggers that torchdynamo/torchinductor is responsible for
def get_loggers():
return [
logging.getLogger("torch.fx.experimental.symbolic_shapes"),
logging.getLogger("torch._dynamo"),
logging.getLogger("torch._inductor"),
]
# get_step_logger() creates a logging function that logs a message with the
# step number prepended. It should be called lazily (i.e. at runtime, not at
# module-load time) so that step numbers are assigned properly, e.g.:
#
#     @functools.lru_cache(None)
#     def _step_logger():
#         return get_step_logger(logging.getLogger(...))
#
#     def fn():
#         _step_logger()(logging.INFO, "msg")
# Shared counter. Each get_step_logger() call takes the next value, starting at 1.
_step_counter = itertools.count(1)

# Keep num_steps in sync with the phases being tracked (Dynamo, AOT, Backend).
# The count is inductor-centric: the Backend phase is only counted when triton
# can be imported. Triton is probed with a direct import because calling
# _inductor.utils.has_triton() from here causes a circular import.
if not disable_progress:
    try:
        import triton  # noqa: F401
    except ImportError:
        num_steps = 2
    else:
        num_steps = 3
    pbar = tqdm(total=num_steps, desc="torch.compile()", delay=0)
def get_step_logger(logger):
    """Claim the next step number and return a step-prefixed logging function.

    Call this lazily (at runtime, not at module-load time) so that step numbers
    follow the order in which phases actually run.

    If the progress bar is enabled and was created at import time, it is
    advanced by one. When a real tqdm bar is in use, its postfix shows
    ``logger.name``.

    Args:
        logger: the ``logging.Logger`` that the returned function forwards to.

    Returns:
        A function ``log(level, msg, **kwargs)``. It calls ``logger.log`` with
        the message formatted as ``"Step <n>: <msg>"`` and passes ``kwargs``
        through unchanged.
    """
    if not disable_progress:
        # ``pbar`` is only created at import time when progress was enabled then.
        # If ``disable_progress`` is flipped to False later, no bar exists.
        # Skip it instead of raising NameError.
        bar = globals().get("pbar")
        if bar is not None:
            bar.update(1)
            # torch.hub may hand back its _Faketqdm fallback instead of a real
            # tqdm. It presumably has no set_postfix_str, hence the check.
            if not isinstance(bar, _Faketqdm):
                bar.set_postfix_str(f"{logger.name}")

    # The step number is fixed now and captured by the closure below.
    step = next(_step_counter)

    def log(level, msg, **kwargs):
        logger.log(level, "Step %s: %s", step, msg, **kwargs)

    return log