|
from datetime import datetime |
|
from io import TextIOBase |
|
import logging |
|
from logging import FileHandler, Formatter, Handler, StreamHandler |
|
from pathlib import Path |
|
import sys |
|
import time |
|
from typing import Optional |
|
|
|
import colorama |
|
|
|
from .env_vars import dispatcher_env_vars, trial_env_vars |
|
|
|
|
|
handlers = {} |
|
|
|
log_format = '[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s' |
|
time_format = '%Y-%m-%d %H:%M:%S' |
|
formatter = Formatter(log_format, time_format) |
|
|
|
|
|
def init_logger() -> None: |
|
""" |
|
This function will (and should only) get invoked on the first time of importing nni (no matter which submodule). |
|
It will try to detect the running environment and setup logger accordingly. |
|
|
|
The detection should work in most cases but for `nnictl` and `nni.experiment`. |
|
They will be identified as "standalone" mode and must configure the logger by themselves. |
|
""" |
|
colorama.init() |
|
|
|
if dispatcher_env_vars.SDK_PROCESS == 'dispatcher': |
|
_init_logger_dispatcher() |
|
return |
|
|
|
trial_platform = trial_env_vars.NNI_PLATFORM |
|
|
|
if trial_platform == 'unittest': |
|
return |
|
|
|
if trial_platform and not trial_env_vars.REUSE_MODE: |
|
_init_logger_trial() |
|
return |
|
|
|
_init_logger_standalone() |
|
|
|
logging.getLogger('filelock').setLevel(logging.WARNING) |
|
|
|
_exp_log_initialized = False |
|
|
|
def init_logger_experiment() -> None: |
|
""" |
|
Initialize logger for `nni.experiment.Experiment`. |
|
|
|
This function will get invoked after `init_logger()`. |
|
""" |
|
global _exp_log_initialized |
|
if not _exp_log_initialized: |
|
_exp_log_initialized = True |
|
colorful_formatter = Formatter(log_format, time_format) |
|
colorful_formatter.format = _colorful_format |
|
handlers['_default_'].setFormatter(colorful_formatter) |
|
|
|
def start_experiment_log(experiment_id: str, log_directory: Path, debug: bool) -> None: |
|
log_path = _prepare_log_dir(log_directory) / 'dispatcher.log' |
|
log_level = logging.DEBUG if debug else logging.WARNING |
|
_register_handler(FileHandler(log_path), log_level, experiment_id) |
|
|
|
def stop_experiment_log(experiment_id: str) -> None: |
|
if experiment_id in handlers: |
|
logging.getLogger().removeHandler(handlers.pop(experiment_id)) |
|
|
|
|
|
def _init_logger_dispatcher() -> None: |
|
log_level_map = { |
|
'fatal': logging.CRITICAL, |
|
'error': logging.ERROR, |
|
'warning': logging.WARNING, |
|
'info': logging.WARNING, |
|
'debug': logging.DEBUG, |
|
'trace': 0 |
|
} |
|
|
|
log_path = _prepare_log_dir(dispatcher_env_vars.NNI_LOG_DIRECTORY) / 'dispatcher.log' |
|
log_level = log_level_map.get(dispatcher_env_vars.NNI_LOG_LEVEL, logging.WARNING) |
|
_register_handler(FileHandler(log_path), log_level) |
|
|
|
|
|
def _init_logger_trial() -> None: |
|
log_path = _prepare_log_dir(trial_env_vars.NNI_OUTPUT_DIR) / 'trial.log' |
|
log_file = open(log_path, 'a') |
|
_register_handler(StreamHandler(log_file), logging.WARNING) |
|
|
|
if trial_env_vars.NNI_PLATFORM == 'local': |
|
sys.stdout = _LogFileWrapper(log_file) |
|
|
|
|
|
def _init_logger_standalone() -> None: |
|
_register_handler(StreamHandler(sys.stdout), logging.WARNING) |
|
|
|
|
|
def _prepare_log_dir(path: Optional[str]) -> Path: |
|
if path is None: |
|
return Path() |
|
ret = Path(path) |
|
ret.mkdir(parents=True, exist_ok=True) |
|
return ret |
|
|
|
def _register_handler(handler: Handler, level: int, tag: str = '_default_') -> None: |
|
assert tag not in handlers |
|
handlers[tag] = handler |
|
handler.setFormatter(formatter) |
|
logger = logging.getLogger() |
|
logger.addHandler(handler) |
|
logger.setLevel(level) |
|
|
|
def _colorful_format(record): |
|
time = formatter.formatTime(record, time_format) |
|
if not record.name.startswith('nni.'): |
|
return '[{}] ({}) {}'.format(time, record.name, record.msg % record.args) |
|
if record.levelno >= logging.ERROR: |
|
color = colorama.Fore.RED |
|
elif record.levelno >= logging.WARNING: |
|
color = colorama.Fore.YELLOW |
|
elif record.levelno >= logging.WARNING: |
|
color = colorama.Fore.GREEN |
|
else: |
|
color = colorama.Fore.BLUE |
|
msg = color + (record.msg % record.args) + colorama.Style.RESET_ALL |
|
if record.levelno < logging.WARNING: |
|
return '[{}] {}:{} {}'.format(time, record.threadName, record.name, msg) |
|
else: |
|
return '[{}] {}'.format(time, msg) |
|
|
|
class _LogFileWrapper(TextIOBase): |
|
|
|
|
|
def __init__(self, log_file: TextIOBase): |
|
self.file: TextIOBase = log_file |
|
self.line_buffer: Optional[str] = None |
|
self.line_start_time: Optional[datetime] = None |
|
|
|
def write(self, s: str) -> int: |
|
cur_time = datetime.now() |
|
if self.line_buffer and (cur_time - self.line_start_time).total_seconds() > 0.1: |
|
self.flush() |
|
|
|
if self.line_buffer: |
|
self.line_buffer += s |
|
else: |
|
self.line_buffer = s |
|
self.line_start_time = cur_time |
|
|
|
if '\n' not in s: |
|
return len(s) |
|
|
|
time_str = cur_time.strftime(time_format) |
|
lines = self.line_buffer.split('\n') |
|
for line in lines[:-1]: |
|
self.file.write(f'[{time_str}] PRINT {line}\n') |
|
self.file.flush() |
|
|
|
self.line_buffer = lines[-1] |
|
self.line_start_time = cur_time |
|
return len(s) |
|
|
|
def flush(self) -> None: |
|
if self.line_buffer: |
|
time_str = self.line_start_time.strftime(time_format) |
|
self.file.write(f'[{time_str}] PRINT {self.line_buffer}\n') |
|
self.file.flush() |
|
self.line_buffer = None |
|
|