"""Base class for trainable models."""

import logging
import re
from abc import ABCMeta, abstractmethod
from copy import copy

import omegaconf
import torch
from omegaconf import OmegaConf
from torch import nn

logger = logging.getLogger(__name__)

try:
    import wandb
except ImportError:
    logger.debug("Could not import wandb.")
    wandb = None
|
class MetaModel(ABCMeta):
    def __prepare__(name, bases, **kwds):
        total_conf = OmegaConf.create()
        for base in bases:
            for key in ("base_default_conf", "default_conf"):
                update = getattr(base, key, {})
                if isinstance(update, dict):
                    update = OmegaConf.create(update)
                total_conf = OmegaConf.merge(total_conf, update)
        return dict(base_default_conf=total_conf)
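
# How the metaclass merge plays out, sketched with hypothetical child classes
# (these names are for illustration only and do not exist in this module):
#
#     class Extractor(BaseModel):
#         default_conf = {"backbone": "resnet18", "trainable": False}
#
#     class DeepExtractor(Extractor):
#         default_conf = {"backbone": "resnet50"}
#
# When DeepExtractor is defined, __prepare__ merges the `base_default_conf` and
# `default_conf` of every base class into DeepExtractor.base_default_conf, so
# BaseModel's defaults and Extractor's overrides are folded in before
# DeepExtractor.default_conf is merged on top of them in BaseModel.__init__.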
|
|
|
|
|
class BaseModel(nn.Module, metaclass=MetaModel):
    """
    What the child model is expected to declare:
        default_conf: dictionary of the default configuration of the model.
        It recursively updates the default_conf of all parent classes, and
        it is updated by the user-provided configuration passed to __init__.
        Configurations can be nested.

        required_data_keys: list of expected keys in the input data dictionary.

        strict_conf (optional): boolean. If false, BaseModel does not raise
        an error when the user provides an unknown configuration entry.

        _init(self, conf): initialization method, where conf is the final
        configuration object (also accessible with `self.conf`). Accessing
        unknown configuration entries will raise an error.

        _forward(self, data): method that returns a dictionary of batched
        prediction tensors based on a dictionary of batched input data tensors.

        loss(self, pred, data): method that returns a dictionary of losses,
        computed from model predictions and input data. Each loss is a batch
        of scalars, i.e. a torch.Tensor of shape (B,).
        The total loss to be optimized has the key `'total'`.

        metrics(self, pred, data): method that returns a dictionary of metrics,
        each as a batch of scalars.
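
    A minimal sketch of a hypothetical child model (the class and its
    configuration entries are made up for illustration only):

        class ConstantScaler(BaseModel):
            default_conf = {"scale": 2.0}
            required_data_keys = ["image"]

            def _init(self, conf):
                self.scale = conf.scale

            def _forward(self, data):
                return {"prediction": self.scale * data["image"]}

            def loss(self, pred, data):
                diff = (pred["prediction"] - data["image"]).flatten(1)
                return {"total": diff.abs().mean(dim=-1)}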
|
""" |
|
|
|
default_conf = { |
|
"name": None, |
|
"trainable": True, |
|
"freeze_batch_normalization": False, |
|
"timeit": False, |
|
"watch": False, |
|
} |
|
required_data_keys = [] |
|
strict_conf = False |
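
    # A hypothetical override, for illustration only (no such model exists here):
    # a subclass could set `default_conf = {"name": "toy", "watch": True}` and a
    # user could instantiate it with a partial conf such as {"trainable": False}.
    # __init__ merges the user conf over the defaults, so conf.trainable becomes
    # False while all other defaults are kept; the merged conf is then made
    # read-only and struct so that unknown entries raise an error.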
|
|
|
    def __init__(self, conf):
        """Perform some logic and call the _init method of the child model."""
        super().__init__()
        default_conf = OmegaConf.merge(
            self.base_default_conf, OmegaConf.create(self.default_conf)
        )
        if self.strict_conf:
            OmegaConf.set_struct(default_conf, True)

        # Backward compatibility: move a top-level `pad` entry under `interpolation`.
        if "pad" in conf and "pad" not in default_conf:
            with omegaconf.read_write(conf):
                with omegaconf.open_dict(conf):
                    conf["interpolation"] = {"pad": conf.pop("pad")}

        if isinstance(conf, dict):
            conf = OmegaConf.create(conf)
        self.conf = conf = OmegaConf.merge(default_conf, conf)
        OmegaConf.set_readonly(conf, True)
        OmegaConf.set_struct(conf, True)
        self.required_data_keys = copy(self.required_data_keys)
        self._init(conf)

        if "weights" in conf and conf.weights is not None:
            logger.info(f"Loading checkpoint {conf.weights}")
            ckpt = torch.load(str(conf.weights), map_location="cpu", weights_only=False)
            weights_key = "model" if "model" in ckpt else "state_dict"
            self.flexible_load(ckpt[weights_key])

        if not conf.trainable:
            for p in self.parameters():
                p.requires_grad = False

        if conf.watch:
            try:
                # wandb may be None if the import at the top of the file failed.
                wandb.watch(self, log="all", log_graph=True, log_freq=10)
                logger.info(f"Watching {self.__class__.__name__}.")
            except (AttributeError, ValueError):
                logger.warning(f"Could not watch {self.__class__.__name__}.")

        n_trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        logger.info(f"Creating model {self.__class__.__name__} ({n_trainable/1e6:.2f} Mio)")
|
|
|
    def flexible_load(self, state_dict):
        """TODO: fix a probable nasty bug, and move to BaseModel."""
        # Rename legacy "gravity" parameters to their new "up" name.
        for key in list(state_dict.keys()):
            if "gravity" in key:
                new_key = key.replace("gravity", "up")
                state_dict[new_key] = state_dict.pop(key)

        # Insert the "decoder" level into head parameter names of older checkpoints,
        # unless the key already targets the decoder or an indexed head layer.
        for key in list(state_dict.keys()):
            if "linear_pred_latitude" in key or "linear_pred_up" in key:
                continue

            if "_head" in key and "_head.decoder" not in key:
                pattern = r"_head\.\d+"
                if re.search(pattern, key):
                    continue
                new_key = key.replace("_head.", "_head.decoder.")
                state_dict[new_key] = state_dict.pop(key)

        dict_params = set(state_dict.keys())
        model_params = set(map(lambda n: n[0], self.named_parameters()))

        if dict_params == model_params:  # checkpoint and model match exactly
            logger.info("Loading all parameters of the checkpoint.")
            self.load_state_dict(state_dict, strict=True)
            return
        elif len(dict_params & model_params) == 0:  # no overlap: try stripping a prefix level
            strip_prefix = lambda x: ".".join(x.split(".")[:1] + x.split(".")[2:])
            state_dict = {strip_prefix(n): p for n, p in state_dict.items()}
            dict_params = set(state_dict.keys())
            if len(dict_params & model_params) == 0:
                raise ValueError(
                    "Could not manage to load the checkpoint with "
                    "parameters:\n\t" + "\n\t".join(sorted(dict_params))
                )
        common_params = dict_params & model_params
        left_params = dict_params - model_params
        # Ignore the running statistics of batch-normalization layers.
        left_params = [
            p for p in left_params if "running" not in p and "num_batches_tracked" not in p
        ]
        logger.debug("Loading parameters:\n\t" + "\n\t".join(sorted(common_params)))
        if left_params:
            logger.warning("Could not load parameters:\n\t" + "\n\t".join(sorted(left_params)))
        self.load_state_dict(state_dict, strict=False)
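
    # An illustration of the fallback above (hypothetical parameter names): if the
    # checkpoint keys share nothing with the model, e.g. they were saved as
    # "extractor.module.conv1.weight" while the model expects
    # "extractor.conv1.weight", strip_prefix drops the second name component so the
    # keys line up; parameters that still do not match are skipped by the
    # non-strict load and reported, apart from batch-normalization running
    # statistics, which are silently ignored.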
|
|
|
    def train(self, mode=True):
        super().train(mode)

        def freeze_bn(module):
            # Keep batch-normalization layers in eval mode so that their
            # running statistics are not updated during training.
            if isinstance(module, nn.modules.batchnorm._BatchNorm):
                module.eval()

        if self.conf.freeze_batch_normalization:
            self.apply(freeze_bn)

        return self
|
|
|
    def forward(self, data):
        """Check the data and call the _forward method of the child model."""

        def recursive_key_check(expected, given):
            for key in expected:
                assert key in given, f"Missing key {key} in data: {list(given.keys())}"
                if isinstance(expected, dict):
                    recursive_key_check(expected[key], given[key])

        recursive_key_check(self.required_data_keys, data)
        return self._forward(data)
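
    # For example (illustration only), a child model declaring
    # required_data_keys = ["image"] accepts model({"image": tensor}) and raises
    # an AssertionError when "image" is missing; if required_data_keys is a nested
    # dictionary, the check recurses into the corresponding sub-dictionaries of
    # `data`.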
|
|
|
    @abstractmethod
    def _init(self, conf):
        """To be implemented by the child class."""
        raise NotImplementedError

    @abstractmethod
    def _forward(self, data):
        """To be implemented by the child class."""
        raise NotImplementedError

    @abstractmethod
    def loss(self, pred, data):
        """To be implemented by the child class."""
        raise NotImplementedError
|
|