Spaces:
Sleeping
Sleeping
import json | |
from ditk import logging | |
import os | |
from typing import Optional, Tuple, Union, Dict, Any | |
import ditk.logging | |
import numpy as np | |
import yaml | |
from hbutils.system import touch | |
from tabulate import tabulate | |
from .log_writer_helper import DistributedWriter | |
def build_logger(
    path: str,
    name: Optional[str] = None,
    need_tb: bool = True,
    need_text: bool = True,
    text_level: Union[int, str] = logging.INFO
) -> Tuple[Optional[logging.Logger], Optional['SummaryWriter']]:  # noqa
    """
    Overview:
        Build the text logger and the tensorboard logger that share one output directory.
    Arguments:
        - path (:obj:`str`): Directory under which both loggers store their output.
        - name (:obj:`Optional[str]`): Base name of the loggers; ``'default'`` is used when it is ``None``.
        - need_tb (:obj:`bool`): Whether a tensorboard writer is created and returned.
        - need_text (:obj:`bool`): Whether a ``logging.Logger`` instance is created and returned.
        - text_level (:obj:`int` or :obj:`str`): Logging level of the text logger, defaults to ``logging.INFO``.
    Returns:
        - logger (:obj:`Optional[logging.Logger]`): The text logger, or ``None`` when ``need_text`` is ``False``.
        - tb_logger (:obj:`Optional['SummaryWriter']`): The tensorboard writer, or ``None`` when ``need_tb`` \
            is ``False``.
    """
    if name is None:
        name = 'default'

    # Each factory is only touched when its logger is requested, so disabling both is side-effect free.
    logger = LoggerFactory.create_logger(path, name=name, level=text_level) if need_text else None

    # The tensorboard writer gets its own sub-directory: ``<path>/<name>_tb_logger``.
    tb_name = name + '_tb_logger'
    tb_logger = TBLoggerFactory.create_logger(os.path.join(path, tb_name)) if need_tb else None

    return logger, tb_logger
class TBLoggerFactory(object):
    """
    Overview:
        TBLoggerFactory is a factory class for tensorboard writers. It keeps one writer per log directory
        and hands the same instance back on repeated requests.
    Interfaces:
        ``create_logger``
    Properties:
        - ``tb_loggers`` (:obj:`Dict[str, DistributedWriter]`): A dict that maps a log directory to the \
            writer created for it.
    """

    # Class-level cache shared by every caller: log directory -> writer instance.
    tb_loggers = {}

    @classmethod
    def create_logger(cls: type, logdir: str) -> 'DistributedWriter':
        """
        Overview:
            Return the writer bound to ``logdir``, creating and caching it on first use.
        Arguments:
            - logdir (:obj:`str`): Directory handed to ``DistributedWriter``; also used as the cache key.
        Returns:
            - tb_logger (:obj:`DistributedWriter`): The cached writer for ``logdir``.
        """
        # Must be a classmethod: callers invoke ``TBLoggerFactory.create_logger(logdir)`` with one argument.
        if logdir not in cls.tb_loggers:
            cls.tb_loggers[logdir] = DistributedWriter(logdir)
        return cls.tb_loggers[logdir]
class LoggerFactory(object):
    """
    Overview:
        LoggerFactory is a factory class for ``logging.Logger``.
    Interfaces:
        ``create_logger``, ``get_tabulate_vars``, ``get_tabulate_vars_hor``
    """

    @classmethod
    def create_logger(cls, path: str, name: str = 'default', level: Union[int, str] = logging.INFO) -> logging.Logger:
        """
        Overview:
            Create a logger that writes to ``<path>/<name>_logger.txt``.
        Arguments:
            - path (:obj:`str`): Logger's save dir
            - name (:obj:`str`): Logger's name; ``_logger`` is appended to build the logger and file name
            - level (:obj:`int` or :obj:`str`): Used to set the level. Reference: ``Logger.setLevel`` method.
        Returns:
            - (:obj:`logging.Logger`): The logger, with ``get_tabulate_vars`` and ``get_tabulate_vars_hor`` \
                attached as attributes.
        """
        ditk.logging.try_init_root(level)

        logger_name = f'{name}_logger'
        logger_file_path = os.path.join(path, f'{logger_name}.txt')
        # Create the log file up front so it exists before the logger is attached to it.
        touch(logger_file_path)

        logger = ditk.logging.getLogger(logger_name, level, [logger_file_path])
        # Expose the table helpers on the logger object itself for caller convenience.
        logger.get_tabulate_vars = LoggerFactory.get_tabulate_vars
        logger.get_tabulate_vars_hor = LoggerFactory.get_tabulate_vars_hor
        return logger

    @staticmethod
    def get_tabulate_vars(variables: Dict[str, Any]) -> str:
        """
        Overview:
            Render the variables as a vertical two-column (``Name`` / ``Value``) grid table.
        Arguments:
            - variables (:obj:`Dict[str, Any]`): Mapping from variable name to its value. String values are \
                shown as-is; every other value is formatted with six decimal places.
        Returns:
            - string (:obj:`str`): The table text, prefixed with a newline.
        """
        headers = ["Name", "Value"]
        # Strings cannot take the ``.6f`` format; pass them through like ``get_tabulate_vars_hor`` does.
        data = [[k, v if isinstance(v, str) else "{:.6f}".format(v)] for k, v in variables.items()]
        return "\n" + tabulate(data, headers=headers, tablefmt='grid')

    @staticmethod
    def get_tabulate_vars_hor(variables: Dict[str, Any]) -> str:
        """
        Overview:
            Render the variables as horizontal grid tables: each table has a row of names and a row of
            values, and holds at most four variables next to the leading ``Name`` / ``Value`` cell.
        Arguments:
            - variables (:obj:`Dict[str, Any]`): Mapping from variable name to its value. Non-string scalars \
                are formatted with six decimal places; anything else is shown as-is.
        Returns:
            - string (:obj:`str`): The tables joined by newlines, prefixed with a newline. An empty mapping \
                yields just ``"\\n"``.
        """
        column_to_divide = 5  # which includes the header "Name & Value"
        vars_per_table = column_to_divide - 1

        def _render(v: Any) -> Any:
            # ``np.isscalar`` is also true for ``str``, hence the explicit exclusion.
            if not isinstance(v, str) and np.isscalar(v):
                return "{:.6f}".format(v)
            return v

        items = list(variables.items())
        tables = []
        # Stepping over the items in fixed-size chunks never yields an empty trailing table,
        # even when the number of variables is an exact multiple of the chunk size.
        for start in range(0, len(items), vars_per_table):
            chunk = items[start:start + vars_per_table]
            names = ["Name"] + [k for k, _ in chunk]
            values = ["Value"] + [_render(v) for _, v in chunk]
            tables.append(tabulate([names, values], tablefmt='grid') + "\n")
        return "\n" + "".join(tables)
def pretty_print(result: dict, direct_print: bool = True) -> str:
    """
    Overview:
        Format a dict ``result`` as block-style YAML text, skipping entries whose value is ``None``.
    Arguments:
        - result (:obj:`dict`): The result to print; it is not modified.
        - direct_print (:obj:`bool`): Whether to also print the formatted text to stdout.
    Returns:
        - string (:obj:`str`): The pretty-printed result in str format
    """
    # Building a new dict leaves the caller's ``result`` untouched, so no defensive copy is needed.
    out = {k: v for k, v in result.items() if v is not None}

    # Round-trip through JSON so that only JSON-representable plain data reaches ``yaml.safe_dump``.
    cleaned = json.dumps(out)
    string = yaml.safe_dump(json.loads(cleaned), default_flow_style=False)

    if direct_print:
        print(string)
    return string