"""GeoCalib model definition.""" import logging from typing import Dict import torch from torch import nn from torch.nn import functional as F from geocalib.lm_optimizer import LMOptimizer from geocalib.modules import MSCAN, ConvModule, LightHamHead # mypy: ignore-errors logger = logging.getLogger(__name__) class LowLevelEncoder(nn.Module): """Very simple low-level encoder.""" def __init__(self): """Simple low-level encoder.""" super().__init__() self.in_channel = 3 self.feat_dim = 64 self.conv1 = ConvModule(self.in_channel, self.feat_dim, kernel_size=3, padding=1) self.conv2 = ConvModule(self.feat_dim, self.feat_dim, kernel_size=3, padding=1) def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """Forward pass.""" x = data["image"] assert ( x.shape[-1] % 32 == 0 and x.shape[-2] % 32 == 0 ), "Image size must be multiple of 32 if not using single image input." c1 = self.conv1(x) c2 = self.conv2(c1) return {"features": c2} class UpDecoder(nn.Module): """Minimal implementation of UpDecoder.""" def __init__(self): """Up decoder.""" super().__init__() self.decoder = LightHamHead() self.linear_pred_up = nn.Conv2d(self.decoder.out_channels, 2, kernel_size=1) def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """Forward pass.""" x, log_confidence = self.decoder(data["features"]) up = self.linear_pred_up(x) return {"up_field": F.normalize(up, dim=1), "up_confidence": torch.sigmoid(log_confidence)} class LatitudeDecoder(nn.Module): """Minimal implementation of LatitudeDecoder.""" def __init__(self): """Latitude decoder.""" super().__init__() self.decoder = LightHamHead() self.linear_pred_latitude = nn.Conv2d(self.decoder.out_channels, 1, kernel_size=1) def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """Forward pass.""" x, log_confidence = self.decoder(data["features"]) eps = 1e-5 # avoid nan in backward of asin lat = torch.tanh(self.linear_pred_latitude(x)) lat = torch.asin(torch.clamp(lat, -1 + eps, 1 - eps)) return {"latitude_field": lat, "latitude_confidence": torch.sigmoid(log_confidence)} class PerspectiveDecoder(nn.Module): """Minimal implementation of PerspectiveDecoder.""" def __init__(self): """Perspective decoder wrapping up and latitude decoders.""" super().__init__() self.up_head = UpDecoder() self.latitude_head = LatitudeDecoder() def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """Forward pass.""" return self.up_head(data) | self.latitude_head(data) class GeoCalib(nn.Module): """GeoCalib inference model.""" def __init__(self, **optimizer_options): """Initialize the GeoCalib inference model. Args: optimizer_options: Options for the lm optimizer. """ super().__init__() self.backbone = MSCAN() self.ll_enc = LowLevelEncoder() self.perspective_decoder = PerspectiveDecoder() self.optimizer = LMOptimizer({**optimizer_options}) def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """Forward pass.""" features = {"hl": self.backbone(data)["features"], "ll": self.ll_enc(data)["features"]} out = self.perspective_decoder({"features": features}) out |= { k: data[k] for k in ["image", "scales", "prior_gravity", "prior_focal", "prior_k1"] if k in data } out |= self.optimizer(out) return out def flexible_load(self, state_dict: Dict[str, torch.Tensor]) -> None: """Load a checkpoint with flexible key names.""" dict_params = set(state_dict.keys()) model_params = set(map(lambda n: n[0], self.named_parameters())) if dict_params == model_params: # perfect fit logger.info("Loading all parameters of the checkpoint.") self.load_state_dict(state_dict, strict=True) return elif len(dict_params & model_params) == 0: # perfect mismatch strip_prefix = lambda x: ".".join(x.split(".")[:1] + x.split(".")[2:]) state_dict = {strip_prefix(n): p for n, p in state_dict.items()} dict_params = set(state_dict.keys()) if len(dict_params & model_params) == 0: raise ValueError( "Could not manage to load the checkpoint with" "parameters:" + "\n\t".join(sorted(dict_params)) ) common_params = dict_params & model_params left_params = dict_params - model_params left_params = [ p for p in left_params if "running" not in p and "num_batches_tracked" not in p ] logger.debug("Loading parameters:\n\t" + "\n\t".join(sorted(common_params))) if left_params: # ignore running stats of batchnorm logger.warning("Could not load parameters:\n\t" + "\n\t".join(sorted(left_params))) self.load_state_dict(state_dict, strict=False)