Spaces:
Sleeping
Sleeping
import functools
from typing import TYPE_CHECKING

from tensorboardX import SummaryWriter

if TYPE_CHECKING:
    # TYPE_CHECKING is always False at runtime, but mypy will evaluate the contents of this block.
    # So if you import this module within TYPE_CHECKING, you will get code hints and other benefits.
    # Here is a good answer on stackoverflow:
    # https://stackoverflow.com/questions/39740632/python-type-hinting-without-cyclic-imports
    from ding.framework import Parallel
class DistributedWriter(SummaryWriter):
    """
    Overview:
        A simple subclass of SummaryWriter that supports writing to one process in multi-process mode.
        The best way is to use it in conjunction with the ``router`` to take advantage of the message \
            and event components of the router (see ``writer.plugin``).
    Interfaces:
        ``get_instance``, ``plugin``, ``initialize``, ``__del__``
    """
    # Class-level slot holding the first instance created by ``get_instance`` with arguments.
    root = None

    def __init__(self, *args, **kwargs):
        """
        Overview:
            Initialize the DistributedWriter object. The real file writer is not opened here; it is \
                created lazily by ``initialize``.
        Arguments:
            - args (:obj:`Tuple`): The arguments passed to the ``__init__`` function of the parent class, \
                SummaryWriter.
            - kwargs (:obj:`Dict`): The keyword arguments passed to the ``__init__`` function of the parent class, \
                SummaryWriter.
        """
        # Remember the caller's ``write_to_disk`` choice (True when not given) so ``initialize`` can restore it.
        self._default_writer_to_disk = kwargs.get("write_to_disk", True)
        # We need to write data to files lazily, so we should not use the file writer in __init__.
        # Instead, the file writer is initialized when the user calls an add_* function for the first time.
        kwargs["write_to_disk"] = False
        super().__init__(*args, **kwargs)
        self._in_parallel = False  # Set to True by ``plugin`` when the router is active.
        self._router = None  # Router used to forward add_* calls; set by ``plugin``.
        self._is_writer = False  # Whether this process performs the actual writes.
        self._lazy_initialized = False  # Whether the lazy initialization step is done (or deliberately skipped).

    @classmethod
    def get_instance(cls, *args, **kwargs) -> "DistributedWriter":
        """
        Overview:
            Get an instance and set the root level instance on the first call. If no args or kwargs are given, \
                this method returns the root instance (``None`` until an instance has been created here).
        Arguments:
            - args (:obj:`Tuple`): The arguments passed to the ``__init__`` function of the parent class, \
                SummaryWriter.
            - kwargs (:obj:`Dict`): The keyword arguments passed to the ``__init__`` function of the parent class, \
                SummaryWriter.
        Returns:
            - writer (:obj:`DistributedWriter`): A newly created writer, or the root writer when called \
                without arguments.
        """
        if args or kwargs:
            ins = cls(*args, **kwargs)
            # Only the very first instance created through this method becomes the shared root.
            if cls.root is None:
                cls.root = ins
            return ins
        else:
            return cls.root

    def plugin(self, router: "Parallel", is_writer: bool = False) -> "DistributedWriter":
        """
        Overview:
            Plugin ``router``, so when using this writer with active router, it will automatically send requests\
                to the main writer instead of writing it to the disk. So we can collect data from multiple processes\
                and write them into one file.
        Arguments:
            - router (:obj:`Parallel`): The router to be plugged in.
            - is_writer (:obj:`bool`): Whether this writer is the main writer.
        Returns:
            - writer (:obj:`DistributedWriter`): ``self``, to allow chaining.
        Examples:
            >>> DistributedWriter().plugin(router, is_writer=True)
        """
        if router.is_active:
            self._in_parallel = True
            self._router = router
            self._is_writer = is_writer
            if is_writer:
                self.initialize()
            # Non-writer processes forward their add_* calls over the router, so they must not open a
            # file writer of their own: mark the lazy initialization as done for every plugged process.
            self._lazy_initialized = True
            router.on("distributed_writer", self._on_distributed_writer)
        return self

    def _on_distributed_writer(self, fn_name: str, *args, **kwargs):
        """
        Overview:
            This method is called when the router receives a request to write data. Only the main writer \
                executes the request; other processes ignore it.
        Arguments:
            - fn_name (:obj:`str`): The name of the function to be called.
            - args (:obj:`Tuple`): The arguments passed to the function to be called.
            - kwargs (:obj:`Dict`): The keyword arguments passed to the function to be called.
        """
        if self._is_writer:
            getattr(self, fn_name)(*args, **kwargs)

    def initialize(self):
        """
        Overview:
            Initialize the file writer: close any existing writers, restore the ``write_to_disk`` value given \
                at construction, and create the file writer via SummaryWriter's ``_get_file_writer``.
        """
        self.close()
        self._write_to_disk = self._default_writer_to_disk
        self._get_file_writer()
        self._lazy_initialized = True

    def __del__(self):
        """
        Overview:
            Close the file writer.
        """
        self.close()
def enable_parallel(fn_name, fn):
    """
    Overview:
        Decorator to enable parallel writing: wrap a SummaryWriter ``add_*`` method so that, in parallel mode, \
            non-writer processes forward the call over the router instead of writing locally.
    Arguments:
        - fn_name (:obj:`str`): The name of the function to be called.
        - fn (:obj:`Callable`): The function to be called.
    Returns:
        - parallel_fn (:obj:`Callable`): The wrapped function. It keeps ``fn``'s name and docstring, and returns \
            ``fn``'s result when executed locally (``None`` when the call is forwarded).
    """

    @functools.wraps(fn)
    def _parallel_fn(self: "DistributedWriter", *args, **kwargs):
        # Open the file writer lazily, on the first add_* call.
        if not self._lazy_initialized:
            self.initialize()
        if self._in_parallel and not self._is_writer:
            # Non-writer process: hand the call to the main writer through the router.
            self._router.emit("distributed_writer", fn_name, *args, **kwargs)
            return None
        return fn(self, *args, **kwargs)

    return _parallel_fn
# SummaryWriter methods that are routed through ``enable_parallel``, so that worker processes forward
# them to the main writer. Names that the installed tensorboardX SummaryWriter does not define are skipped.
ready_to_parallel_fns = [
    'add_audio',
    'add_custom_scalars', 'add_custom_scalars_marginchart', 'add_custom_scalars_multilinechart',
    'add_embedding', 'add_figure',
    'add_graph', 'add_graph_deprecated',
    'add_histogram', 'add_histogram_raw',
    'add_hparams',
    'add_image', 'add_image_with_boxes', 'add_images',
    'add_mesh', 'add_onnx_graph', 'add_openvino_graph',
    'add_pr_curve', 'add_pr_curve_raw',
    'add_scalar', 'add_scalars',
    'add_text', 'add_video',
]
for fn_name in ready_to_parallel_fns:
    if not hasattr(DistributedWriter, fn_name):
        continue
    setattr(
        DistributedWriter,
        fn_name,
        enable_parallel(fn_name, getattr(DistributedWriter, fn_name)),
    )
# Module-level default writer, created at import time with SummaryWriter's default arguments.
# (It is not registered as ``DistributedWriter.root``; only ``get_instance`` sets that.)
# Examples:
#   In the main process: `distributed_writer.plugin(task.router, is_writer=True)`,
#   In other processes: `distributed_writer.plugin(task.router)`,
#   Then anywhere (e.g. in a middleware): `distributed_writer.add_scalar("reward", reward, step)`
distributed_writer = DistributedWriter()