# Copyright (c) Meta Platforms, Inc. and affiliates.
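"""Generic PyTorch Lightning wrapper around the models defined in .models.

Provides GenericModule (training/validation steps, optimizer and scheduler setup,
checkpoint loading with optional best-checkpoint selection) and AverageKeyMeter,
a metric that averages a single entry of the loss dictionary.
"""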
from pathlib import Path

import pytorch_lightning as pl
import torch
from omegaconf import DictConfig, OmegaConf, open_dict
from torchmetrics import MeanMetric, MetricCollection

from . import logger
from .models import get_model


class AverageKeyMeter(MeanMetric):
    """Average a single key of a loss dictionary, ignoring non-finite values."""

    def __init__(self, key, *args, **kwargs):
        self.key = key
        super().__init__(*args, **kwargs)

    def update(self, values):
        value = values[self.key]
        # Drop NaN/inf entries so they do not corrupt the running mean.
        value = value[torch.isfinite(value)]
        return super().update(value)


class GenericModule(pl.LightningModule):
    def __init__(self, cfg):
        super().__init__()
        name = cfg.model.get("name")
        # Default to "orienternet" when the name is missing or uses the old alias.
        name = "orienternet" if name in ("localizer_bev_depth", None) else name
        self.model = get_model(name)(cfg.model)
        self.cfg = cfg
        self.save_hyperparameters(cfg)
        self.metrics_val = MetricCollection(self.model.metrics(), prefix="val/")
        self.losses_val = None  # we do not know the loss keys in advance

    def forward(self, batch):
        return self.model(batch)

    def training_step(self, batch):
        pred = self(batch)
        losses = self.model.loss(pred, batch)
        self.log_dict(
            {f"loss/{k}/train": v.mean() for k, v in losses.items()},
            prog_bar=True,
            rank_zero_only=True,
        )
        return losses["total"].mean()

    def validation_step(self, batch, batch_idx):
        pred = self(batch)
        losses = self.model.loss(pred, batch)
        if self.losses_val is None:
            self.losses_val = MetricCollection(
                {k: AverageKeyMeter(k).to(self.device) for k in losses},
                prefix="loss/",
                postfix="/val",
            )
        self.metrics_val(pred, batch)
        self.log_dict(self.metrics_val, sync_dist=True)
        self.losses_val.update(losses)
        self.log_dict(self.losses_val, sync_dist=True)

    def on_validation_epoch_start(self):
        # Reset the loss meters so they are rebuilt with the current loss keys.
        self.losses_val = None

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.cfg.training.lr)
        ret = {"optimizer": optimizer}
        cfg_scheduler = self.cfg.training.get("lr_scheduler")
        if cfg_scheduler is not None:
            scheduler = getattr(torch.optim.lr_scheduler, cfg_scheduler.name)(
                optimizer=optimizer, **cfg_scheduler.get("args", {})
            )
            ret["lr_scheduler"] = {
                "scheduler": scheduler,
                "interval": "epoch",
                "frequency": 1,
                "monitor": "loss/total/val",
                "strict": True,
                "name": "learning_rate",
            }
        return ret

    @classmethod
    def load_from_checkpoint(
        cls,
        checkpoint_path,
        map_location=None,
        hparams_file=None,
        strict=True,
        cfg=None,
        find_best=False,
    ):
        assert hparams_file is None, "hparams are not supported."
        checkpoint_path = Path(checkpoint_path)
        checkpoint = torch.load(
            checkpoint_path, map_location=map_location or (lambda storage, loc: storage)
        )
        if find_best:
            best_score, best_name = None, None
            modes = {"min": torch.lt, "max": torch.gt}
            for key, state in checkpoint["callbacks"].items():
                if not key.startswith("ModelCheckpoint"):
                    continue
                # The callback key embeds a dict repr of its state,
                # e.g. "ModelCheckpoint{'monitor': ..., 'mode': 'min', ...}".
                mode = eval(key.replace("ModelCheckpoint", ""))["mode"]
                if best_score is None or modes[mode](
                    state["best_model_score"], best_score
                ):
                    best_score = state["best_model_score"]
                    best_name = Path(state["best_model_path"]).name
            logger.info("Loading best checkpoint %s", best_name)
            if best_name != checkpoint_path.name:
                return cls.load_from_checkpoint(
                    checkpoint_path.parent / best_name,
                    map_location,
                    hparams_file,
                    strict,
                    cfg,
                    find_best=False,
                )
        logger.info(
            "Using checkpoint %s from epoch %d and step %d.",
            checkpoint_path.name,
            checkpoint["epoch"],
            checkpoint["global_step"],
        )
        cfg_ckpt = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]
        if list(cfg_ckpt.keys()) == ["cfg"]:  # backward compatibility
            cfg_ckpt = cfg_ckpt["cfg"]
        cfg_ckpt = OmegaConf.create(cfg_ckpt)
        if cfg is None:
            cfg = {}
        if not isinstance(cfg, DictConfig):
            cfg = OmegaConf.create(cfg)
        with open_dict(cfg_ckpt):
            # Let the user-provided cfg override the values stored in the checkpoint.
            cfg = OmegaConf.merge(cfg_ckpt, cfg)
        return pl.core.saving._load_state(cls, checkpoint, strict=strict, cfg=cfg)