File size: 3,861 Bytes
205a7af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""This module implements the writer class for logging to tensorboard or wandb."""

import logging
import os
from collections.abc import Mapping
from typing import Any, Dict, Optional

from omegaconf import DictConfig
from torch import nn
from torch.utils.tensorboard import SummaryWriter as TFSummaryWriter

from siclib import __module_name__

logger = logging.getLogger(__name__)

# wandb is an optional dependency: if it cannot be imported, the name is bound to None
# and SummaryWriter only raises when a wandb writer is actually requested.
try:
    import wandb
except ImportError:
    wandb = None
    logger.debug("Could not import wandb.")

# mypy: ignore-errors


def dot_conf(conf: Mapping[Any, Any]) -> Dict[str, Any]:
    """Flatten a nested config into a single-level dict with dot-joined keys.

    Nested mappings (``DictConfig`` sections or plain dicts) are expanded recursively,
    so ``{"a": {"b": 1}, "c": 2}`` becomes ``{"a.b": 1, "c": 2}``. Any other value,
    including lists / ``ListConfig``, is copied through unchanged. An empty nested
    section contributes no keys.

    Args:
        conf: The configuration to flatten (a ``DictConfig`` or any other mapping).

    Returns:
        A new flat dict; ``conf`` itself is not modified.
    """
    flat: Dict[str, Any] = {}
    for key, value in conf.items():
        if isinstance(value, Mapping):
            # DictConfig subclasses MutableMapping, so nested DictConfig sections and
            # plain nested dicts are both flattened here.
            for sub_key, sub_value in dot_conf(value).items():
                flat[f"{key}.{sub_key}"] = sub_value
        else:
            flat[key] = value
    return flat


class SummaryWriter:
    """Writer class for logging to tensorboard and/or wandb.

    The active backends are chosen from ``conf.train.writer``: wandb is used when it
    contains ``"wandb"`` and tensorboard when it contains ``"tensorboard"``. When it is
    falsy, both backends are disabled and every logging method is a no-op.
    """

    def __init__(self, conf: DictConfig, args: DictConfig, log_dir: str):
        """Initialize the writer and start the selected backends.

        Args:
            conf: Full experiment config. ``conf.train.writer`` selects the backends, and
                the whole config, flattened with ``dot_conf``, is sent to wandb.
            args: Run arguments; ``args.experiment`` becomes the wandb run name.
            log_dir: Output directory for the tensorboard writer.

        Raises:
            ImportError: If wandb is requested but could not be imported.
            NotImplementedError: If ``conf.train.writer`` is set but names neither
                ``"wandb"`` nor ``"tensorboard"``.
        """
        self.conf = conf
        # Tensorboard backend; stays None unless tensorboard logging is enabled.
        self.writer = None

        writer_conf = conf.train.writer
        if not writer_conf:
            self.use_wandb = False
            self.use_tensorboard = False
            return

        self.use_wandb = "wandb" in writer_conf
        self.use_tensorboard = "tensorboard" in writer_conf

        if self.use_wandb and not wandb:
            raise ImportError("wandb not installed.")

        if not self.use_wandb and not self.use_tensorboard:
            raise NotImplementedError(f"Writer {writer_conf} not implemented")

        if self.use_tensorboard:
            self.writer = TFSummaryWriter(log_dir=log_dir)

        if self.use_wandb:
            # Extend wandb's service wait timeout. setdefault keeps any value the user
            # already exported instead of silently overwriting it.
            os.environ.setdefault("WANDB__SERVICE_WAIT", "300")
            wandb.init(project=__module_name__, name=args.experiment, config=dot_conf(conf))

    @staticmethod
    def _wandb_step(step: Optional[int]) -> Optional[int]:
        """Return the step reported to wandb: step 0 becomes 1, anything else is unchanged."""
        # NOTE(review): the reason for remapping step 0 is not visible in this file. The
        # remap is applied to wandb only; tensorboard always receives the caller's step.
        return 1 if step == 0 else step

    def add_scalar(self, tag: str, value: float, step: Optional[int] = None):
        """Log a scalar value to tensorboard and/or wandb."""
        if self.use_wandb:
            wandb.log({tag: value}, step=self._wandb_step(step))

        if self.use_tensorboard:
            self.writer.add_scalar(tag, value, step)

    def add_figure(self, tag: str, figure, step: Optional[int] = None):
        """Log a figure to tensorboard and/or wandb."""
        if self.use_wandb:
            wandb.log({tag: figure}, step=self._wandb_step(step))

        if self.use_tensorboard:
            self.writer.add_figure(tag, figure, step)

    def add_histogram(self, tag: str, values, step: Optional[int] = None):
        """Log a histogram to tensorboard (no-op when tensorboard is disabled)."""
        if self.use_tensorboard:
            self.writer.add_histogram(tag, values, step)

    def add_text(self, tag: str, text: str, step: Optional[int] = None):
        """Log text to tensorboard (no-op when tensorboard is disabled)."""
        if self.use_tensorboard:
            self.writer.add_text(tag, text, step)

    def add_pr_curve(self, tag: str, values, step: Optional[int] = None):
        """Log a precision-recall curve to tensorboard and/or wandb.

        NOTE(review): both backend calls look mis-wired and need verification before use.
        The wandb usage previously sketched here was
        ``wandb.plots.precision_recall(y_test, y_probas, labels)`` (three arguments), but
        a single ``values`` argument is passed. Tensorboard's ``add_pr_curve`` signature
        is ``(tag, labels, predictions, global_step=None, ...)``, so ``step`` is
        currently passed in the ``predictions`` position.
        """
        if self.use_wandb:
            # TODO: check if this works.
            wandb.log({tag: wandb.plots.precision_recall(values)}, step=self._wandb_step(step))

        if self.use_tensorboard:
            self.writer.add_pr_curve(tag, values, step)

    def watch(self, model: nn.Module, log_freq: int = 1000):
        """Register ``model`` with ``wandb.watch`` to log gradients (wandb only).

        ``log_freq`` is forwarded unchanged to ``wandb.watch``.
        """
        if self.use_wandb:
            wandb.watch(
                model,
                log="gradients",
                log_freq=log_freq,
            )

    def close(self):
        """Finish the wandb run and close the tensorboard writer, for whichever are enabled."""
        if self.use_wandb:
            wandb.finish()

        if self.use_tensorboard:
            self.writer.close()