File size: 2,526 Bytes
9b19c29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import os
from abc import ABC, abstractmethod
from typing import Literal, TypeAlias

from torch.utils.tensorboard import SummaryWriter

from tianshou.utils import BaseLogger, TensorboardLogger, WandbLogger
from tianshou.utils.string import ToStringMixin

# Type alias for the logger objects returned by `LoggerFactory.create_logger`.
TLogger: TypeAlias = BaseLogger


class LoggerFactory(ToStringMixin, ABC):
    """Abstract factory for the logger that records an experiment's data.

    Concrete subclasses implement :meth:`create_logger` to choose the logging backend.
    """

    @abstractmethod
    def create_logger(
        self,
        log_dir: str,
        experiment_name: str,
        run_id: str | None,
        config_dict: dict,
    ) -> TLogger:
        """Build and return a logger.

        :param log_dir: directory in which the logger is to store its data
        :param experiment_name: name of the job; it may contain `os.path.sep`
        :param run_id: a unique name which, depending on the logging framework,
            may be used to identify the logger
        :param config_dict: dictionary of data that is to be logged
        :return: the logger
        """


class LoggerFactoryDefault(LoggerFactory):
    def __init__(
        self,
        logger_type: Literal["tensorboard", "wandb", "pandas"] = "tensorboard",
        wandb_project: str | None = None,
    ):
        if logger_type == "wandb" and wandb_project is None:
            raise ValueError("Must provide 'wandb_project'")
        self.logger_type = logger_type
        self.wandb_project = wandb_project

    def create_logger(
        self,
        log_dir: str,
        experiment_name: str,
        run_id: str | None,
        config_dict: dict,
    ) -> TLogger:
        if self.logger_type in ["wandb", "tensorboard"]:
            writer = SummaryWriter(log_dir)
            writer.add_text(
                "args",
                str(
                    dict(
                        log_dir=log_dir,
                        logger_type=self.logger_type,
                        wandb_project=self.wandb_project,
                    ),
                ),
            )
        match self.logger_type:
            case "wandb":
                wandb_logger = WandbLogger(
                    save_interval=1,
                    name=experiment_name.replace(os.path.sep, "__"),
                    run_id=run_id,
                    config=config_dict,
                    project=self.wandb_project,
                )
                wandb_logger.load(writer)
                return wandb_logger
            case "tensorboard":
                return TensorboardLogger(writer)
            case _:
                raise ValueError(f"Unknown logger type '{self.logger_type}'")