import pathlib
from enum import Enum
from typing import Any, Dict, List, Optional, Union

from .logging import get_logger


logger = get_logger()


class BaseTracker:
    r"""Base class for loggers. Does nothing by default, so it is useful when you want to disable logging."""

    def log(self, metrics: Dict[str, Any], step: int) -> None:
        r"""Record `metrics` for the given `step`. The base implementation does nothing."""

    def finish(self) -> None:
        r"""Finalize the tracker. The base implementation does nothing."""


class WandbTracker(BaseTracker):
    r"""Logger implementation for Weights & Biases."""

    def __init__(self, experiment_name: str, log_dir: str, config: Optional[Dict[str, Any]] = None) -> None:
        r"""
        Start a WandB run.

        Args:
            experiment_name: Passed to `wandb.init` as the `project`.
            log_dir: Directory passed to `wandb.init` as `dir`. It is created if it does not exist.
            config: Optional configuration dictionary passed to `wandb.init` as `config`.
        """
        # The import lives here, so `wandb` is only imported when a WandbTracker is instantiated.
        import wandb

        self.wandb = wandb

        # WandB does not create a directory if it does not exist and instead starts using the system temp directory.
        pathlib.Path(log_dir).mkdir(parents=True, exist_ok=True)

        self.run = wandb.init(project=experiment_name, dir=log_dir, config=config)
        logger.info("WandB logging enabled")

    def log(self, metrics: Dict[str, Any], step: int) -> None:
        r"""Forward `metrics` to the WandB run at the given `step`."""
        self.run.log(metrics, step=step)

    def finish(self) -> None:
        r"""Finish the WandB run."""
        self.run.finish()


class SequentialTracker(BaseTracker):
    r"""Sequential tracker that logs to multiple trackers in sequence."""

    def __init__(self, trackers: List[BaseTracker]) -> None:
        r"""Store the trackers that `log` and `finish` are forwarded to, in the given order."""
        self.trackers = trackers

    def log(self, metrics: Dict[str, Any], step: int) -> None:
        r"""Call `log` on every wrapped tracker, in order."""
        for tracker in self.trackers:
            tracker.log(metrics, step)

    def finish(self) -> None:
        r"""Call `finish` on every wrapped tracker, in order."""
        for tracker in self.trackers:
            tracker.finish()


class Trackers(str, Enum):
    r"""Enum for supported trackers."""

    NONE = "none"
    WANDB = "wandb"


_SUPPORTED_TRACKERS = [tracker.value for tracker in Trackers.__members__.values()]


def initialize_trackers(
    trackers: List[str], experiment_name: str, config: Optional[Dict[str, Any]], log_dir: str
) -> Union[BaseTracker, SequentialTracker]:
    r"""
    Initialize loggers based on the provided configuration.

    Args:
        trackers: Names of the trackers to enable. Duplicates are ignored; the first occurrence of each
            name determines its position in the resulting sequence.
        experiment_name: Experiment name forwarded to trackers that need one.
        config: Configuration dictionary forwarded to trackers that record it.
        log_dir: Directory forwarded to trackers that write to disk.

    Returns:
        A no-op `BaseTracker` if no trackers were requested, otherwise a `SequentialTracker` wrapping one
        instance per distinct requested tracker.

    Raises:
        ValueError: If any requested tracker name is not in `_SUPPORTED_TRACKERS`.
    """
    logger.info(f"Initializing trackers: {trackers}. Logging to {log_dir=}")

    if not trackers:
        return BaseTracker()

    # De-duplicate while keeping the caller's order. Iterating a `set` of strings would make the order of the
    # wrapped trackers depend on string hashing, which can differ from one interpreter run to the next.
    requested = list(dict.fromkeys(trackers))

    unsupported = [tracker_name for tracker_name in requested if tracker_name not in _SUPPORTED_TRACKERS]
    if unsupported:
        raise ValueError(
            f"Unsupported tracker(s) provided: {unsupported}. Supported trackers: {_SUPPORTED_TRACKERS}"
        )

    tracker_instances: List[BaseTracker] = []
    for tracker_name in requested:
        # `Trackers` subclasses `str`, so plain string names compare equal to the enum members.
        if tracker_name == Trackers.NONE:
            tracker_instances.append(BaseTracker())
        elif tracker_name == Trackers.WANDB:
            tracker_instances.append(WandbTracker(experiment_name, log_dir, config))

    return SequentialTracker(tracker_instances)


TrackerType = Union[BaseTracker, SequentialTracker, WandbTracker]