GeoCalib / siclib /models /decoders /perspective_decoder.py
veichta's picture
Upload folder using huggingface_hub
205a7af verified
raw
history blame
1.87 kB
"""Perspective fields decoder heads.
Adapted from https://github.com/jinlinyi/PerspectiveFields
"""
import logging
from siclib.models import get_model
from siclib.models.base_model import BaseModel
logger = logging.getLogger(__name__)
# flake8: noqa
# mypy: ignore-errors
class PerspectiveDecoder(BaseModel):
    """Decoder wrapping an optional up head and an optional latitude head.

    Each head is resolved through ``get_model`` using the ``name`` entry of its
    sub-config. A sub-config set to ``None`` disables the corresponding head.
    """

    # Default sub-configs: both heads are enabled unless overridden with None.
    default_conf = {
        "up_decoder": {"name": "decoders.up_decoder"},
        "latitude_decoder": {"name": "decoders.latitude_decoder"},
    }

    required_data_keys = ["features"]

    def _init(self, conf):
        """Build the heads that are enabled in ``conf``.

        Args:
            conf: configuration exposing ``up_decoder`` and ``latitude_decoder``
                entries; an entry that is ``None`` turns the matching head off.
        """
        logger.debug(f"Initializing PerspectiveDecoder with config: {conf}")

        up_conf = conf.up_decoder
        latitude_conf = conf.latitude_decoder

        # Record which heads exist so the other methods can skip missing ones.
        self.use_up = up_conf is not None
        self.use_latitude = latitude_conf is not None

        # The up head is created before the latitude head.
        if self.use_up:
            up_cls = get_model(up_conf.name)
            self.up_head = up_cls(up_conf)

        if self.use_latitude:
            latitude_cls = get_model(latitude_conf.name)
            self.latitude_head = latitude_cls(latitude_conf)

    def _forward(self, data):
        """Run every enabled head on ``data`` and merge their outputs.

        Args:
            data: input passed unchanged to each enabled head.

        Returns:
            The union of the head outputs; on duplicate keys the latitude head's
            entry takes precedence. Empty when both heads are disabled.
        """
        up_pred = {}
        if self.use_up:
            up_pred = self.up_head(data)

        latitude_pred = {}
        if self.use_latitude:
            latitude_pred = self.latitude_head(data)

        return up_pred | latitude_pred

    def loss(self, pred, data):
        """Collect the losses and metrics of every enabled head.

        Args:
            pred: predictions forwarded to each head's ``loss``.
            data: ground truth forwarded to each head's ``loss``; must contain
                ``"up_field"`` when the up head is enabled, otherwise
                ``"latitude_field"``.

        Returns:
            A ``(losses, metrics)`` pair of dicts. ``losses`` additionally holds
            ``"perspective_total"``, the sum of ``"up_total"`` and
            ``"latitude_total"`` (a missing entry counts as 0).
        """
        ref_key = "up_field" if self.use_up else "latitude_field"
        ref = data[ref_key]

        # Zero accumulator sized by the first dimension of the reference field
        # (presumably the batch dimension -- confirm against the data loader).
        total = ref.new_zeros(ref.shape[0])

        losses = {}
        metrics = {}

        if self.use_up:
            head_losses, head_metrics = self.up_head.loss(pred, data)
            losses.update(head_losses)
            metrics.update(head_metrics)
            total = total + losses.get("up_total", 0)

        if self.use_latitude:
            head_losses, head_metrics = self.latitude_head.loss(pred, data)
            losses.update(head_losses)
            metrics.update(head_metrics)
            total = total + losses.get("latitude_total", 0)

        losses["perspective_total"] = total
        return losses, metrics