from collections.abc import Callable
from typing import Any

import numpy as np
from matplotlib.figure import Figure
from tensorboard.backend.event_processing import event_accumulator
from torch.utils.tensorboard import SummaryWriter

from tianshou.utils.logger.base import (
    VALID_LOG_VALS,
    VALID_LOG_VALS_TYPE,
    BaseLogger,
    TRestoredData,
)


class TensorboardLogger(BaseLogger):
    """A logger that relies on tensorboard SummaryWriter by default to visualize and log statistics.

    :param SummaryWriter writer: the writer to log data.
    :param train_interval: the log interval in log_train_data(). Defaults to 1000.
    :param test_interval: the log interval in log_test_data(). Defaults to 1.
    :param update_interval: the log interval in log_update_data(). Defaults to 1000.
    :param info_interval: the log interval in log_info_data(). Defaults to 1.
    :param save_interval: the save interval in save_data(). Defaults to 1 (save at
        the end of each epoch).
    :param write_flush: whether to flush the tensorboard result after each
        add_scalar operation. Defaults to True.
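
    Example (an illustrative sketch; the log directory ``log/run1`` and the logged
    values are assumptions, not part of this module)::

        writer = SummaryWriter("log/run1")
        logger = TensorboardLogger(writer)
        # "train" becomes the scope; 0.1 is written under the tag "train/loss"
        logger.write("train/env_step", 5, {"loss": 0.1})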
""" |
|
|
|
    def __init__(
        self,
        writer: SummaryWriter,
        train_interval: int = 1000,
        test_interval: int = 1,
        update_interval: int = 1000,
        info_interval: int = 1,
        save_interval: int = 1,
        write_flush: bool = True,
    ) -> None:
        super().__init__(train_interval, test_interval, update_interval, info_interval)
        self.save_interval = save_interval
        self.write_flush = write_flush
        self.last_save_step = -1
        self.writer = writer

    def prepare_dict_for_logging(
        self,
        input_dict: dict[str, Any],
        parent_key: str = "",
        delimiter: str = "/",
        exclude_arrays: bool = True,
    ) -> dict[str, VALID_LOG_VALS_TYPE]:
        """Flattens and filters a nested dictionary by recursively traversing all levels and compressing the keys.

        Filtering is performed with respect to valid logging data types.

        :param input_dict: The nested dictionary to be flattened and filtered.
        :param parent_key: The parent key used as a prefix before the input_dict keys.
        :param delimiter: The delimiter used to separate the keys.
        :param exclude_arrays: Whether to exclude numpy arrays from the output.
        :return: A flattened dictionary where the keys are compressed and values are filtered.
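
        Example (illustrative; ``logger`` stands for an existing TensorboardLogger
        instance)::

            >>> logger.prepare_dict_for_logging({"a": {"b": 1.0, "arr": np.zeros(3)}})
            {'a/b': 1.0}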
""" |
|
result = {} |
|
|
|
        def add_to_result(
            cur_dict: dict,
            prefix: str = "",
        ) -> None:
            # Recursively walk the nested dict, compressing keys with the delimiter
            # and keeping only values of valid logging types.
            for key, value in cur_dict.items():
                if exclude_arrays and isinstance(value, np.ndarray):
                    continue

                new_key = prefix + delimiter + key
                new_key = new_key.lstrip(delimiter)

                if isinstance(value, dict):
                    add_to_result(
                        value,
                        new_key,
                    )
                elif isinstance(value, VALID_LOG_VALS):
                    result[new_key] = value

        add_to_result(input_dict, prefix=parent_key)
        return result

    def write(self, step_type: str, step: int, data: dict[str, Any]) -> None:
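        """Write the step and the values in ``data`` to tensorboard.

        The scope (e.g. "train" or "save") is the part of ``step_type`` before the
        slash; the step itself is logged as a scalar under ``step_type``, numpy
        arrays in ``data`` are logged as histograms, matplotlib figures via
        ``add_figure``, and everything else as scalars.
        """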
        scope, step_name = step_type.split("/")
        self.writer.add_scalar(step_type, step, global_step=step)
        for k, v in data.items():
            scope_key = f"{scope}/{k}"
            if isinstance(v, np.ndarray):
                self.writer.add_histogram(scope_key, v, global_step=step, bins="auto")
            elif isinstance(v, Figure):
                self.writer.add_figure(scope_key, v, global_step=step)
            else:
                self.writer.add_scalar(scope_key, v, global_step=step)
        if self.write_flush:
            self.writer.flush()

    def save_data(
        self,
        epoch: int,
        env_step: int,
        gradient_step: int,
        save_checkpoint_fn: Callable[[int, int, int], str] | None = None,
    ) -> None:
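        """Save checkpoint metadata and trigger ``save_checkpoint_fn``.

        The checkpoint function is only called when it is provided and at least
        ``save_interval`` epochs have passed since the last save; the epoch,
        env_step and gradient_step of that save are then logged under the "save"
        scope so that ``restore_data`` can recover them later.
        """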
        if save_checkpoint_fn and epoch - self.last_save_step >= self.save_interval:
            self.last_save_step = epoch
            save_checkpoint_fn(epoch, env_step, gradient_step)
            self.write("save/epoch", epoch, {"save/epoch": epoch})
            self.write("save/env_step", env_step, {"save/env_step": env_step})
            self.write(
                "save/gradient_step",
                gradient_step,
                {"save/gradient_step": gradient_step},
            )

    def restore_data(self) -> tuple[int, int, int]:
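        """Restore epoch, env_step and gradient_step from the writer's log directory.

        The values are read back from the "save/*" scalars written by ``save_data``;
        entries that are missing default to 0. The logger's last-log counters are
        updated so that interval-based logging continues from the restored steps.
        """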
        ea = event_accumulator.EventAccumulator(self.writer.log_dir)
        ea.Reload()

        try:
            epoch = ea.scalars.Items("save/epoch")[-1].step
            self.last_save_step = self.last_log_test_step = epoch
            gradient_step = ea.scalars.Items("save/gradient_step")[-1].step
            self.last_log_update_step = gradient_step
        except KeyError:
            epoch, gradient_step = 0, 0
        try:
            env_step = ea.scalars.Items("save/env_step")[-1].step
            self.last_log_train_step = env_step
        except KeyError:
            env_step = 0

        return epoch, env_step, gradient_step

    def restore_logged_data(
        self,
        log_path: str,
    ) -> TRestoredData:
        """Restores the logged data from the tensorboard log directory.

        The result is a nested dictionary where the keys are the tensorboard keys
        and the values are the corresponding numpy arrays. The keys in each level
        form a nested structure, where the hierarchy is represented by the slashes
        in the tensorboard key-strings.
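
        For example (illustrative key names; the actual keys depend on what was
        logged), scalars stored under "train/returns" and "train/length" would be
        restored as::

            {"train": {"returns": np.array([...]), "length": np.array([...])}}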
""" |
|
ea = event_accumulator.EventAccumulator(log_path) |
|
ea.Reload() |
|
|
|
        def add_value_to_innermost_nested_dict(
            data_dict: dict[str, Any],
            key_string: str,
            value: Any,
        ) -> None:
            """Walk through the keys in the `key_string` and add the value to
            `data_dict` in a nested manner, creating nested dictionaries on the fly
            if necessary, or updating existing ones. The value is added only to the
            innermost-nested dictionary.

            Example:
            -------
            >>> data_dict = {}
            >>> add_value_to_innermost_nested_dict(data_dict, "a/b/c", 1)
            >>> data_dict
            {"a": {"b": {"c": 1}}}
            """
            keys = key_string.split("/")

            cur_nested_dict = data_dict
            # Descend into (or create) the intermediate dictionaries for all but
            # the last key.
            for k in keys[:-1]:
                cur_nested_dict = cur_nested_dict.setdefault(k, {})
            # Assign the value in the innermost dictionary.
            cur_nested_dict[keys[-1]] = value

        restored_data: dict[str, np.ndarray | dict] = {}
        for key_string in ea.scalars.Keys():
            add_value_to_innermost_nested_dict(
                restored_data,
                key_string,
                np.array([s.value for s in ea.scalars.Items(key_string)]),
            )

        return restored_data
|