OrienterNet / maploc /models /orienternet.py
Paul-Edouard Sarlin
Code formatting (#47)
b0cf684 unverified
raw
history blame
7.75 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
import numpy as np
import torch
from torch.nn.functional import normalize
from . import get_model
from .base import BaseModel
from .bev_net import BEVNet
from .bev_projection import CartesianProjection, PolarProjectionDepth
from .map_encoder import MapEncoder
from .metrics import AngleError, AngleRecall, Location2DError, Location2DRecall
from .voting import (
TemplateSampler,
argmax_xyr,
conv2d_fft_batchwise,
expectation_xyr,
log_softmax_spatial,
mask_yaw_prior,
nll_loss_xyr,
nll_loss_xyr_smoothed,
)
class OrienterNet(BaseModel):
    """Match image-derived bird's-eye-view (BEV) features against map features.

    The forward pass encodes the map and the image, projects the image
    features onto a BEV grid, matches rotated BEV templates against the map
    features, and derives argmax and expectation estimates from the scores.
    """

    # NOTE(review): "???" entries have no usable default here; presumably they
    # mark mandatory values that the experiment configuration must provide.
    default_conf = {
        "image_encoder": "???",
        "map_encoder": "???",
        "bev_net": "???",
        "latent_dim": "???",
        "matching_dim": "???",
        "scale_range": [0, 9],
        "num_scale_bins": "???",
        "z_min": None,
        "z_max": "???",
        "x_max": "???",
        "pixel_per_meter": "???",
        "num_rotations": "???",
        "add_temperature": False,
        "normalize_features": False,
        "padding_matching": "replicate",
        "apply_map_prior": True,
        "do_label_smoothing": False,
        "sigma_xy": 1,
        "sigma_r": 2,
        # deprecated: `_init` asserts that these keep the values below
        "depth_parameterization": "scale",
        "norm_depth_scores": False,
        "normalize_scores_by_dim": False,
        "normalize_scores_by_num_valid": True,
        "prior_renorm": True,
        "retrieval_dim": None,
    }

    def _init(self, conf):
        """Instantiate the sub-modules described by `conf`.

        Asserts first that the deprecated options still hold their default
        values. The sub-modules are created in a fixed order; keep it stable,
        since it determines the parameter registration order.
        """
        # Deprecated options: only the default behavior is supported.
        assert not self.conf.norm_depth_scores
        assert self.conf.depth_parameterization == "scale"
        assert not self.conf.normalize_scores_by_dim
        assert self.conf.normalize_scores_by_num_valid
        assert self.conf.prior_renorm

        # Image and map encoders.
        encoder_name = conf.image_encoder.get("name", "feature_extractor_v2")
        encoder_cls = get_model(encoder_name)
        self.image_encoder = encoder_cls(conf.image_encoder.backbone)
        self.map_encoder = MapEncoder(conf.map_encoder)

        # Optional network applied to the BEV features.
        if conf.bev_net is None:
            self.bev_net = None
        else:
            self.bev_net = BEVNet(conf.bev_net)

        # Projections image -> polar -> BEV, and the rotated template sampler.
        pixel_per_meter = conf.pixel_per_meter
        self.projection_polar = PolarProjectionDepth(
            conf.z_max, pixel_per_meter, conf.scale_range, conf.z_min
        )
        self.projection_bev = CartesianProjection(
            conf.z_max, conf.x_max, pixel_per_meter, conf.z_min
        )
        self.template_sampler = TemplateSampler(
            self.projection_bev.grid_xz, pixel_per_meter, conf.num_rotations
        )

        # Per-pixel classifier over the scale bins.
        self.scale_classifier = torch.nn.Linear(conf.latent_dim, conf.num_scale_bins)

        # Without a BEV network, a linear layer maps to the matching dimension.
        if conf.bev_net is None:
            self.feature_projection = torch.nn.Linear(
                conf.latent_dim, conf.matching_dim
            )

        # Optional learnable temperature; the scores are scaled by exp(temperature).
        if conf.add_temperature:
            self.register_parameter(
                "temperature", torch.nn.Parameter(torch.tensor(0.0))
            )

    def exhaustive_voting(self, f_bev, f_map, valid_bev, confidence_bev=None):
        """Match rotated BEV templates against the map features.

        Args:
            f_bev: BEV features; normalized along dim 1 if
                `conf.normalize_features` is set.
            f_map: map features; normalized along dim 1 under the same option.
            valid_bev: boolean validity mask of the BEV features; invalid
                entries of `f_bev` are zeroed before the templates are built.
            confidence_bev: optional weights multiplied into `f_bev`.

        Returns:
            The matching scores, scaled by `exp(temperature)` if
            `conf.add_temperature` is set, and divided by the number of valid
            pixels of each template.
        """
        if self.conf.normalize_features:
            f_bev = normalize(f_bev, dim=1)
            f_map = normalize(f_map, dim=1)

        # Weight and mask the BEV features, then sample the templates.
        if confidence_bev is not None:
            f_bev = f_bev * confidence_bev.unsqueeze(1)
        invalid = ~valid_bev.unsqueeze(1)
        f_bev = f_bev.masked_fill(invalid, 0.0)
        templates = self.template_sampler(f_bev)

        # The matching is computed in float32 with autocast disabled.
        with torch.autocast("cuda", enabled=False):
            scores = conv2d_fft_batchwise(
                f_map.float(),
                templates.float(),
                padding_mode=self.conf.padding_matching,
            )
        if self.conf.add_temperature:
            scores = scores * self.temperature.exp()

        # Reweight the different rotations based on the number of valid pixels
        # in each template. Axis-aligned rotations have the maximum number of
        # valid pixels.
        sampled_valid = self.template_sampler(valid_bev.float().unsqueeze(0))
        valid_templates = sampled_valid > (1 - 1e-4)
        num_valid = valid_templates.float().sum((-3, -2, -1))
        return scores / num_valid.unsqueeze(-1).unsqueeze(-1)

    def _forward(self, data):
        """Run the model on a batch.

        Reads `data["camera"]` and `data["image"]` directly (the encoders
        receive the whole `data` dict), and applies `data["map_mask"]` and
        `data["yaw_prior"]` to the scores when present.

        Returns:
            A dict with the intermediate predictions ("map", "pixel_scales",
            and "bev" if a BEV network is used), the "scores" and "log_probs",
            the argmax ("uvr_max", "uv_max", "yaw_max") and expectation
            ("uvr_expectation", "uv_expectation", "yaw_expectation")
            estimates, and "features_image", "features_bev", "valid_bev".
        """
        pred = {}

        # Encode the map; the first entry of "map_features" is matched.
        map_outputs = self.map_encoder(data)
        pred["map"] = map_outputs
        map_features = map_outputs["map_features"][0]

        # Extract image features and scale the camera by the inverse of the
        # encoder scale at the selected level.
        level = 0
        image_features = self.image_encoder(data)["feature_maps"][level]
        camera = (
            data["camera"]
            .scale(1 / self.image_encoder.scales[level])
            .to(data["image"].device, non_blocking=True)
        )

        # Estimate the monocular priors.
        pixel_scales = self.scale_classifier(image_features.moveaxis(1, -1))
        pred["pixel_scales"] = pixel_scales
        polar_features = self.projection_polar(image_features, pixel_scales, camera)

        # Map to the BEV, in float32 with autocast disabled.
        with torch.autocast("cuda", enabled=False):
            bev_features, valid_bev, _ = self.projection_bev(
                polar_features.float(), None, camera.float()
            )

        if self.conf.bev_net is None:
            # channel last -> classifier -> channel first
            channels_last = bev_features.moveaxis(1, -1)
            bev_features = self.feature_projection(channels_last).moveaxis(-1, 1)
            confidence = None
        else:
            bev_outputs = self.bev_net({"input": bev_features})
            pred["bev"] = bev_outputs
            bev_features = bev_outputs["output"]
            confidence = bev_outputs.get("confidence")

        scores = self.exhaustive_voting(
            bev_features, map_features, valid_bev, confidence
        )
        scores = scores.moveaxis(1, -1)  # B,H,W,N

        if "log_prior" in map_outputs and self.conf.apply_map_prior:
            scores = scores + map_outputs["log_prior"][0].unsqueeze(-1)
        # Debugging aid, kept disabled:
        # pred["scores_unmasked"] = scores.clone()

        # The masks are applied in place on the scores.
        if "map_mask" in data:
            outside_map = ~data["map_mask"].unsqueeze(-1)
            scores.masked_fill_(outside_map, -np.inf)
        if "yaw_prior" in data:
            mask_yaw_prior(scores, data["yaw_prior"], self.conf.num_rotations)

        log_probs = log_softmax_spatial(scores)
        # The point estimates are computed without tracking gradients.
        with torch.no_grad():
            uvr_max = argmax_xyr(scores).to(scores)
            uvr_avg, _ = expectation_xyr(log_probs.exp())

        outputs = dict(pred)
        outputs.update(
            {
                "scores": scores,
                "log_probs": log_probs,
                "uvr_max": uvr_max,
                "uv_max": uvr_max[..., :2],
                "yaw_max": uvr_max[..., 2],
                "uvr_expectation": uvr_avg,
                "uv_expectation": uvr_avg[..., :2],
                "yaw_expectation": uvr_avg[..., 2],
                "features_image": image_features,
                "features_bev": bev_features,
                "valid_bev": valid_bev.squeeze(1),
            }
        )
        return outputs

    def loss(self, pred, data):
        """Compute the negative log-likelihood of the ground-truth pose.

        The ground truth is read from `data["uv"]` and from the last
        component of `data["roll_pitch_yaw"]`. With `conf.do_label_smoothing`
        the smoothed loss is used, and it receives `data.get("map_mask")`.

        Returns:
            A dict with "total" and "nll" (the same value), plus "temperature"
            in training mode when `conf.add_temperature` is set.
        """
        xy_gt = data["uv"]
        yaw_gt = data["roll_pitch_yaw"][..., -1]
        log_probs = pred["log_probs"]

        if not self.conf.do_label_smoothing:
            nll = nll_loss_xyr(log_probs, xy_gt, yaw_gt)
        else:
            sigma_xy = self.conf.sigma_xy / self.conf.pixel_per_meter
            nll = nll_loss_xyr_smoothed(
                log_probs,
                xy_gt,
                yaw_gt,
                sigma_xy,
                self.conf.sigma_r,
                mask=data.get("map_mask"),
            )

        losses = {"total": nll, "nll": nll}
        if self.training and self.conf.add_temperature:
            # Expanded to one entry per element of `nll`.
            losses["temperature"] = self.temperature.expand(len(nll))
        return losses

    def metrics(self):
        """Return the evaluation metrics, keyed by name.

        The position metrics are built with `conf.pixel_per_meter`; the
        recall metrics use thresholds of 2 and 5.
        """
        ppm = self.conf.pixel_per_meter
        return {
            "xy_max_error": Location2DError("uv_max", ppm),
            "xy_expectation_error": Location2DError("uv_expectation", ppm),
            "yaw_max_error": AngleError("yaw_max"),
            "xy_recall_2m": Location2DRecall(2.0, ppm, "uv_max"),
            "xy_recall_5m": Location2DRecall(5.0, ppm, "uv_max"),
            "yaw_recall_2°": AngleRecall(2.0, "yaw_max"),
            "yaw_recall_5°": AngleRecall(5.0, "yaw_max"),
        }