# Uploaded by veichta ("Upload folder using huggingface_hub"), commit 205a7af (verified).
import logging
from siclib.models import get_model
from siclib.models.base_model import BaseModel
# Module logger, named after this module so its output can be filtered per module.
logger = logging.getLogger(name=__name__)
# flake8: noqa
# mypy: ignore-errors
class GeoCalib(BaseModel):
    """Network composed of registry-built sub-models.

    The forward pass chains ``backbone`` -> optional ``ll_enc`` ->
    ``perspective_decoder`` -> optional ``optimizer``. Every sub-model is built
    as ``get_model(<entry>["name"])(<entry>)`` from its config entry, and
    ``ll_enc`` / ``optimizer`` are disabled when their entry is falsy
    (e.g. ``None``).
    """

    default_conf = {
        "backbone": {"name": "encoders.mscan"},
        "ll_enc": {"name": "encoders.low_level_encoder"},
        "perspective_decoder": {"name": "decoders.perspective_decoder"},
        "optimizer": {"name": "optimization.lm_optimizer"},
    }

    required_data_keys = ["image"]

    def _init(self, conf):
        """Instantiate the sub-models described by ``conf``.

        Args:
            conf: Model configuration accessed by attribute (``conf.backbone``,
                ``conf.ll_enc``, ``conf.perspective_decoder``, ``conf.optimizer``).
                Each enabled entry must provide a ``"name"`` key understood by
                ``get_model``.
        """
        # Lazy %-style args: the config is only stringified if DEBUG is enabled.
        logger.debug("Initializing GeoCalib with %s", conf)
        self.backbone = get_model(conf.backbone["name"])(conf.backbone)
        # Optional low-level encoder; a falsy config entry disables it.
        self.ll_enc = get_model(conf.ll_enc["name"])(conf.ll_enc) if conf.ll_enc else None
        self.perspective_decoder = get_model(conf.perspective_decoder["name"])(
            conf.perspective_decoder
        )
        # Optional stage run on the decoder output; a falsy config entry disables it.
        self.optimizer = (
            get_model(conf.optimizer["name"])(conf.optimizer) if conf.optimizer else None
        )

    def _forward(self, data):
        """Run the network on one batch of inputs.

        Args:
            data: Input dict. ``required_data_keys`` lists ``"image"``; the keys
                ``"scales"``, ``"prior_gravity"``, ``"prior_focal"`` and
                ``"prior_k1"`` are optional and are copied into the output when
                present.

        Returns:
            The perspective decoder's output dict, extended with the forwarded
            input keys and, if an optimizer is configured, merged with the
            optimizer's outputs (same-named keys are overwritten by the optimizer).
        """
        backbone_out = self.backbone(data)
        # Backbone features go under "hl"; "padding" is None if the backbone gives none.
        features = {"hl": backbone_out["features"], "padding": backbone_out.get("padding")}

        if self.ll_enc is not None:
            features["ll"] = self.ll_enc(data)["features"]  # low-level features

        out = self.perspective_decoder({"features": features})

        # Forward the image and any provided priors so the optimizer (which
        # receives `out`) and the caller see them in the prediction dict.
        passthrough_keys = ("image", "scales", "prior_gravity", "prior_focal", "prior_k1")
        out |= {key: data[key] for key in passthrough_keys if key in data}

        if self.optimizer is not None:
            out |= self.optimizer(out)

        return out

    def loss(self, pred, data):
        """Compute training losses and metrics.

        Args:
            pred: Prediction dict (presumably the output of ``_forward``).
            data: Input dict passed unchanged to the sub-models' ``loss`` methods.

        Returns:
            A ``(losses, metrics)`` pair. ``losses["total"]`` is
            ``losses["perspective_total"]``, plus ``losses["param_total"]`` when
            an optimizer is configured. The optimizer's losses and metrics are
            merged into the decoder's dicts in that case.
        """
        losses, metrics = self.perspective_decoder.loss(pred, data)
        total = losses["perspective_total"]

        if self.optimizer is not None:
            opt_losses, param_metrics = self.optimizer.loss(pred, data)
            losses |= opt_losses
            metrics |= param_metrics
            total = total + opt_losses["param_total"]

        losses["total"] = total
        return losses, metrics