"""This module implements the writer class for logging to tensorboard or wandb."""
import logging
import os
from typing import Any, Dict, Optional
from omegaconf import DictConfig
from torch import nn
from torch.utils.tensorboard import SummaryWriter as TFSummaryWriter
from siclib import __module_name__
logger = logging.getLogger(__name__)
try:
import wandb
except ImportError:
logger.debug("Could not import wandb.")
wandb = None
# mypy: ignore-errors


def dot_conf(conf: DictConfig) -> Dict[str, Any]:
    """Recursively convert a DictConfig to a flat dict with keys joined by dots."""
    d = {}
    for k, v in conf.items():
        if isinstance(v, DictConfig):
            d |= {f"{k}.{k2}": v2 for k2, v2 in dot_conf(v).items()}
        else:
            d[k] = v
    return d
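
# Illustrative dot_conf example (values chosen for demonstration only):
#   dot_conf(OmegaConf.create({"train": {"lr": 1e-3}, "seed": 0}))
#   -> {"train.lr": 0.001, "seed": 0}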


class SummaryWriter:
    """Writer class for logging to tensorboard or wandb."""

    def __init__(self, conf: DictConfig, args: DictConfig, log_dir: str):
        """Initialize the writer."""
        self.conf = conf

        if not conf.train.writer:
            self.use_wandb = False
            self.use_tensorboard = False
            return

        self.use_wandb = "wandb" in conf.train.writer
        self.use_tensorboard = "tensorboard" in conf.train.writer

        if self.use_wandb and not wandb:
            raise ImportError("wandb not installed.")

        if self.use_tensorboard:
            self.writer = TFSummaryWriter(log_dir=log_dir)

        if self.use_wandb:
            # Give the wandb service up to 300 seconds to start.
            os.environ["WANDB__SERVICE_WAIT"] = "300"
            wandb.init(project=__module_name__, name=args.experiment, config=dot_conf(conf))

        if conf.train.writer and not self.use_wandb and not self.use_tensorboard:
            raise NotImplementedError(f"Writer {conf.train.writer} not implemented")

    def add_scalar(self, tag: str, value: float, step: Optional[int] = None):
        """Log a scalar value to tensorboard or wandb."""
        if self.use_wandb:
            # Map step 0 to 1 before logging to wandb.
            step = 1 if step == 0 else step
            wandb.log({tag: value}, step=step)
        if self.use_tensorboard:
            self.writer.add_scalar(tag, value, step)

    def add_figure(self, tag: str, figure, step: Optional[int] = None):
        """Log a figure to tensorboard or wandb."""
        if self.use_wandb:
            step = 1 if step == 0 else step
            wandb.log({tag: figure}, step=step)
        if self.use_tensorboard:
            self.writer.add_figure(tag, figure, step)

    def add_histogram(self, tag: str, values, step: Optional[int] = None):
        """Log a histogram to tensorboard."""
        if self.use_tensorboard:
            self.writer.add_histogram(tag, values, step)

    def add_text(self, tag: str, text: str, step: Optional[int] = None):
        """Log text to tensorboard."""
        if self.use_tensorboard:
            self.writer.add_text(tag, text, step)

    def add_pr_curve(self, tag: str, values, step: Optional[int] = None):
        """Log a precision-recall curve to tensorboard or wandb."""
        if self.use_wandb:
            step = 1 if step == 0 else step
            # @TODO: check if this works
            # wandb.log({"pr": wandb.plots.precision_recall(y_test, y_probas, labels)})
            wandb.log({tag: wandb.plots.precision_recall(values)}, step=step)
        if self.use_tensorboard:
            self.writer.add_pr_curve(tag, values, step)

    def watch(self, model: nn.Module, log_freq: int = 1000):
        """Watch a model for gradient updates."""
        if self.use_wandb:
            wandb.watch(
                model,
                log="gradients",
                log_freq=log_freq,
            )

    def close(self):
        """Close the writer."""
        if self.use_wandb:
            wandb.finish()
        if self.use_tensorboard:
            self.writer.close()
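

# Minimal usage sketch (assumes a config exposing a `train.writer` list and an
# args object with an `experiment` name, as expected by `__init__` above):
#
#   from omegaconf import OmegaConf
#
#   conf = OmegaConf.create({"train": {"writer": ["tensorboard"]}})
#   args = OmegaConf.create({"experiment": "demo"})
#   writer = SummaryWriter(conf, args, log_dir="runs/demo")
#   writer.add_scalar("loss", 0.5, step=1)
#   writer.close()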