"""Base class for trainable models."""

import logging
import re
from abc import ABCMeta, abstractmethod
from copy import copy

import omegaconf
import torch
from omegaconf import OmegaConf
from torch import nn

logger = logging.getLogger(__name__)

try:
    import wandb
except ImportError:
    logger.debug("Could not import wandb.")
    wandb = None

# flake8: noqa
# mypy: ignore-errors


class MetaModel(ABCMeta):
    """Metaclass that accumulates the `default_conf` of all base classes into `base_default_conf`."""

    @classmethod
    def __prepare__(mcs, name, bases, **kwds):
        total_conf = OmegaConf.create()
        for base in bases:
            for key in ("base_default_conf", "default_conf"):
                update = getattr(base, key, {})
                if isinstance(update, dict):
                    update = OmegaConf.create(update)
                total_conf = OmegaConf.merge(total_conf, update)
        return dict(base_default_conf=total_conf)

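# For example (hypothetical class names): if `Extractor(BaseModel)` declares
# default_conf = {"dim": 128} and `SuperExtractor(Extractor)` declares
# default_conf = {"nms": True}, instances of SuperExtractor see a configuration
# that exposes both `dim` and `nms` on top of BaseModel's defaults.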

class BaseModel(nn.Module, metaclass=MetaModel):
    """
    What the child model is expected to declare:
        default_conf: dictionary of the default configuration of the model.
        It recursively updates the default_conf of all parent classes, and
        it is updated by the user-provided configuration passed to __init__.
        Configurations can be nested.

        required_data_keys: list of expected keys in the input data dictionary.

        strict_conf (optional): boolean. If false, BaseModel does not raise
        an error when the user provides an unknown configuration entry.

        _init(self, conf): initialization method, where conf is the final
        configuration object (also accessible with `self.conf`). Accessing
        unknown configuration entries will raise an error.

        _forward(self, data): method that returns a dictionary of batched
        prediction tensors based on a dictionary of batched input data tensors.

        loss(self, pred, data): method that returns a dictionary of losses,
        computed from model predictions and input data. Each loss is a batch
        of scalars, i.e. a torch.Tensor of shape (B,).
        The total loss to be optimized has the key `'total'`.

        metrics(self, pred, data): method that returns a dictionary of metrics,
        each as a batch of scalars.
    """

    default_conf = {
        "name": None,
        "trainable": True,  # if false: do not optimize this model parameters
        "freeze_batch_normalization": False,  # use test-time statistics
        "timeit": False,  # time forward pass
        "watch": False,  # log weights and gradients to wandb
    }
    required_data_keys = []
    strict_conf = False

    def __init__(self, conf):
        """Perform some logic and call the _init method of the child model."""
        super().__init__()
        default_conf = OmegaConf.merge(self.base_default_conf, OmegaConf.create(self.default_conf))
        if self.strict_conf:
            OmegaConf.set_struct(default_conf, True)

        # fixme: backward compatibility
        if "pad" in conf and "pad" not in default_conf:
            with omegaconf.read_write(conf):
                with omegaconf.open_dict(conf):
                    conf["interpolation"] = {"pad": conf.pop("pad")}

        if isinstance(conf, dict):
            conf = OmegaConf.create(conf)
        self.conf = conf = OmegaConf.merge(default_conf, conf)
        OmegaConf.set_readonly(conf, True)
        OmegaConf.set_struct(conf, True)
        self.required_data_keys = copy(self.required_data_keys)
        self._init(conf)

        # load pretrained weights
        if "weights" in conf and conf.weights is not None:
            logger.info(f"Loading checkpoint {conf.weights}")
            ckpt = torch.load(str(conf.weights), map_location="cpu", weights_only=False)
            weights_key = "model" if "model" in ckpt else "state_dict"
            self.flexible_load(ckpt[weights_key])

        if not conf.trainable:
            for p in self.parameters():
                p.requires_grad = False

        if conf.watch:
            try:
                wandb.watch(self, log="all", log_graph=True, log_freq=10)
                logger.info(f"Watching {self.__class__.__name__}.")
            except (AttributeError, ValueError):
                # wandb is not installed (imported as None) or no run is initialized
                logger.warning(f"Could not watch {self.__class__.__name__}.")

        n_trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        logger.info(f"Creating model {self.__class__.__name__} ({n_trainable/1e6:.2f} Mio)")

    def flexible_load(self, state_dict):
        """TODO: fix a probable nasty bug, and move to BaseModel."""
        # replace *gravity* with *up*
        for key in list(state_dict.keys()):
            if "gravity" in key:
                new_key = key.replace("gravity", "up")
                state_dict[new_key] = state_dict.pop(key)
                # print(f"Renaming {key} to {new_key}")

        # replace *_head.* with *_head.decoder.* for original paramnet checkpoints
        for key in list(state_dict.keys()):
            if "linear_pred_latitude" in key or "linear_pred_up" in key:
                continue

            if "_head" in key and "_head.decoder" not in key:
                # check if _head.{num} in key
                pattern = r"_head\.\d+"
                if re.search(pattern, key):
                    continue
                new_key = key.replace("_head.", "_head.decoder.")
                state_dict[new_key] = state_dict.pop(key)
                # print(f"Renaming {key} to {new_key}")

        dict_params = set(state_dict.keys())
        model_params = {name for name, _ in self.named_parameters()}

        if dict_params == model_params:  # perfect fit
            logger.info("Loading all parameters of the checkpoint.")
            self.load_state_dict(state_dict, strict=True)
            return
        elif len(dict_params & model_params) == 0:  # perfect mismatch
            # drop the second dot-separated component of each key,
            # e.g. "prefix.wrapper.layer.weight" -> "prefix.layer.weight"
            strip_prefix = lambda x: ".".join(x.split(".")[:1] + x.split(".")[2:])
            state_dict = {strip_prefix(n): p for n, p in state_dict.items()}
            dict_params = set(state_dict.keys())
            if len(dict_params & model_params) == 0:
                raise ValueError(
                    "Could not manage to load the checkpoint with parameters:\n\t"
                    + "\n\t".join(sorted(dict_params))
                )
        common_params = dict_params & model_params
        left_params = dict_params - model_params
        # ignore the running statistics of batch-normalization layers
        left_params = [
            p for p in left_params if "running" not in p and "num_batches_tracked" not in p
        ]
        logger.debug("Loading parameters:\n\t" + "\n\t".join(sorted(common_params)))
        if left_params:
            logger.warning("Could not load parameters:\n\t" + "\n\t".join(sorted(left_params)))
        self.load_state_dict(state_dict, strict=False)

    def train(self, mode=True):
        super().train(mode)

        def freeze_bn(module):
            if isinstance(module, nn.modules.batchnorm._BatchNorm):
                module.eval()

        if self.conf.freeze_batch_normalization:
            self.apply(freeze_bn)

        return self

    def forward(self, data):
        """Check the data and call the _forward method of the child model."""

        def recursive_key_check(expected, given):
            for key in expected:
                assert key in given, f"Missing key {key} in data: {list(given.keys())}"
                if isinstance(expected, dict):
                    recursive_key_check(expected[key], given[key])

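        # `required_data_keys` may be a flat list (e.g. ["image"]) or a nested
        # dict (e.g. {"view0": ["image"]}) to require keys inside sub-dictionaries;
        # these example keys are hypothetical.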
        recursive_key_check(self.required_data_keys, data)
        return self._forward(data)

    @abstractmethod
    def _init(self, conf):
        """To be implemented by the child class."""
        raise NotImplementedError

    @abstractmethod
    def _forward(self, data):
        """To be implemented by the child class."""
        raise NotImplementedError

    @abstractmethod
    def loss(self, pred, data):
        """To be implemented by the child class."""
        raise NotImplementedError