import logging
from copy import deepcopy
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.nn import Identity

from siclib.geometry.camera import SimpleRadial
from siclib.geometry.gravity import Gravity
from siclib.models.base_model import BaseModel
from siclib.models.utils.metrics import dist_error, pitch_error, roll_error, vfov_error
from siclib.models.utils.modules import _DenseBlock, _Transition
from siclib.utils.conversions import deg2rad, pitch2rho, rho2pitch

logger = logging.getLogger(__name__)


def get_centers_and_edges(min: float, max: float, num_bins: int) -> Tuple[torch.Tensor, torch.Tensor]:
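    """Return num_bins + 1 uniformly spaced bin centers covering [min, max] (plus one extra
    center past max) and the corresponding left bin edges used to bucketize continuous targets."""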
    centers = torch.linspace(min, max + ((max - min) / (num_bins - 1)), num_bins + 1).float()
    edges = centers.detach() - ((centers.detach()[1] - centers[0]) / 2.0)
    return centers, edges


class DeepCalib(BaseModel):
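    """DeepCalib-style single-image calibration network.

    A DenseNet backbone feeds one head per parameter to predict roll, rho (relative pitch
    offset), vertical field of view, and k1_hat distortion, either as a distribution over
    bins (NLL loss) or as a direct regression. The predictions are assembled into a
    SimpleRadial camera and a Gravity direction.
    """
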
    default_conf = {
        "name": "densenet",
        "model": "densenet161",
        "loss": "NLL",
        "num_bins": 256,
        "freeze_batch_normalization": False,
        "pretrained": True,
        "heads": ["roll", "rho", "vfov", "k1_hat"],
        "flip": [],
        "rpf_scales": [1, 1, 1],
        "bounds": {
            "roll": [-45, 45],
            "rho": [-1, 1],
            "vfov": [20, 105],
            "k1_hat": [-0.7, 0.7],
        },
        "use_softamax": False,
    }

    # ImageNet normalization statistics used by the pretrained backbone.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    strict_conf = False

    required_data_keys = ["image", "image_size"]

    def _init(self, conf):
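        """Set up the bin centers and edges, the DenseNet backbone, and one head per parameter."""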
        self.is_classification = self.conf.loss in ["NLL"]

        self.num_bins = conf.num_bins

        # Bin centers and edges for each head (used in classification mode).
        self.roll_centers, self.roll_edges = get_centers_and_edges(
            deg2rad(conf.bounds.roll[0]), deg2rad(conf.bounds.roll[1]), self.num_bins
        )
        self.rho_centers, self.rho_edges = get_centers_and_edges(
            conf.bounds.rho[0], conf.bounds.rho[1], self.num_bins
        )
        self.fov_centers, self.fov_edges = get_centers_and_edges(
            deg2rad(conf.bounds.vfov[0]), deg2rad(conf.bounds.vfov[1]), self.num_bins
        )
        self.k1_hat_centers, self.k1_hat_edges = get_centers_and_edges(
            conf.bounds.k1_hat[0], conf.bounds.k1_hat[1], self.num_bins
        )

        # Torchvision backbone, optionally with pretrained weights.
        Model = getattr(torchvision.models, conf.model)
        weights = "DEFAULT" if self.conf.pretrained else None
        self.model = Model(weights=weights)

        # Shared head architecture: transition + dense block + pooling + small MLP on top of
        # the backbone features (the backbone classifier and final norm are removed below).
        layers = []

        num_features = self.model.classifier.in_features
        head_layers = 3
        layers.append(_Transition(num_features, num_features // 2))
        num_features = num_features // 2
        growth_rate = 32
        layers.append(
            _DenseBlock(
                num_layers=head_layers,
                num_input_features=num_features,
                growth_rate=growth_rate,
                bn_size=4,
                drop_rate=0,
            )
        )
        layers.append(nn.BatchNorm2d(num_features + head_layers * growth_rate))
        layers.append(nn.ReLU())
        layers.append(nn.AdaptiveAvgPool2d((1, 1)))
        layers.append(nn.Flatten())
        layers.append(nn.Linear(num_features + head_layers * growth_rate, 512))
        layers.append(nn.ReLU())
        self.model.classifier = Identity()
        self.model.features.norm5 = Identity()

        if self.is_classification:
            # Per-bin log-probabilities for the NLL loss.
            layers.append(nn.Linear(512, self.num_bins))
            layers.append(nn.LogSoftmax(dim=1))
        else:
            # Single regression output squashed to [-1, 1].
            layers.append(nn.Linear(512, 1))
            layers.append(nn.Tanh())

        # One independent copy of the head per predicted parameter.
        self.roll_head = nn.Sequential(*deepcopy(layers))
        self.rho_head = nn.Sequential(*deepcopy(layers))
        self.vfov_head = nn.Sequential(*deepcopy(layers))
        self.k1_hat_head = nn.Sequential(*deepcopy(layers))

    def bins_to_val(self, centers, pred):
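        """Decode per-bin scores into a value: the argmax center, or a softmax-weighted
        average of the bin centers (soft-argmax) if `use_softamax` is enabled."""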
        if centers.device != pred.device:
            centers = centers.to(pred.device)

        if not self.conf.use_softamax:
            return centers[pred.argmax(1)]

        # Soft-argmax: temperature-scaled softmax over the bins, then expectation of the centers.
        beta = 1e-0
        pred_softmax = F.softmax(pred / beta, dim=1)
        weighted_centers = centers[:-1].unsqueeze(0) * pred_softmax
        val = weighted_centers.sum(dim=1)
        return val

    def _forward(self, data):
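        """Normalize the image, run the backbone and the heads, and assemble the predicted
        SimpleRadial camera and Gravity from the decoded parameters."""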
        image = data["image"]
        # Normalize with the backbone's (ImageNet) statistics.
        mean, std = image.new_tensor(self.mean), image.new_tensor(self.std)
        image = (image - mean[:, None, None]) / std[:, None, None]
        shared_features = self.model.features(image)
        pred = {}

        if "roll" in self.conf.heads:
            pred["roll"] = self.roll_head(shared_features)
        if "rho" in self.conf.heads:
            pred["rho"] = self.rho_head(shared_features)
        if "vfov" in self.conf.heads:
            pred["vfov"] = self.vfov_head(shared_features)
        if "vfov" in self.conf.flip:
            pred["vfov"] = pred["vfov"] * -1
        if "k1_hat" in self.conf.heads:
            pred["k1_hat"] = self.k1_hat_head(shared_features)

        size = data["image_size"]
        w, h = size[:, 0], size[:, 1]

        if self.is_classification:
            # Decode the per-bin scores into parameter values.
            parameters = {
                "roll": self.bins_to_val(self.roll_centers, pred["roll"]),
                "rho": self.bins_to_val(self.rho_centers, pred["rho"]),
                "vfov": self.bins_to_val(self.fov_centers, pred["vfov"]),
                "k1_hat": self.bins_to_val(self.k1_hat_centers, pred["k1_hat"]),
                "width": w,
                "height": h,
            }

            for k in self.conf.flip:
                parameters[k] = parameters[k] * -1

            # Optional per-parameter scaling of roll, rho, and vfov.
            for i, k in enumerate(["roll", "rho", "vfov"]):
                parameters[k] = parameters[k] * self.conf.rpf_scales[i]

            camera = SimpleRadial.from_dict(parameters)

            roll, pitch = parameters["roll"], rho2pitch(parameters["rho"], camera.f[..., 1], h)
            gravity = Gravity.from_rp(roll, pitch)

        else:
            # Map the tanh outputs back to the configured parameter ranges.
            if "roll" in self.conf.heads:
                pred["roll"] = pred["roll"] * deg2rad(self.conf.bounds.roll[1])
            if "vfov" in self.conf.heads:
                min_vfov = deg2rad(self.conf.bounds.vfov[0])
                max_vfov = deg2rad(self.conf.bounds.vfov[1])
                pred["vfov"] = (pred["vfov"] + 1) * (max_vfov - min_vfov) / 2 + min_vfov

            camera = SimpleRadial.from_dict(pred | {"width": w, "height": h})
            # Recover the pitch from the predicted rho offset before building the gravity.
            pitch = rho2pitch(pred["rho"], camera.f[..., 1], h)
            gravity = Gravity.from_rp(pred["roll"], pitch)

        return pred | {"camera": camera, "gravity": gravity}

    def loss(self, pred, data):
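        """Compute the per-head losses (NLL over bins or a regression loss) and the metrics."""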
        loss = {"total": 0}
        if self.conf.loss == "Huber":
            loss_fn = nn.HuberLoss(reduction="none")
        elif self.conf.loss == "L1":
            loss_fn = nn.L1Loss(reduction="none")
        elif self.conf.loss == "L2":
            loss_fn = nn.MSELoss(reduction="none")
        elif self.conf.loss == "NLL":
            loss_fn = nn.NLLLoss(reduction="none")
        else:
            raise ValueError(f"Unknown loss: {self.conf.loss}")

        gt_cam = data["camera"]

        if "roll" in self.conf.heads:
            gt_roll = data["gravity"].roll.float()
            pred_roll = pred["roll"].float()

            if gt_roll.device != self.roll_edges.device:
                self.roll_edges = self.roll_edges.to(gt_roll.device)
                self.roll_centers = self.roll_centers.to(gt_roll.device)

            if self.is_classification:
                # Bucketize the continuous roll into bin indices for the NLL loss.
                gt_roll = torch.bucketize(gt_roll.contiguous(), self.roll_edges) - 1
                assert (gt_roll >= 0).all(), gt_roll
                assert (gt_roll < self.num_bins).all(), gt_roll
            else:
                assert pred_roll.dim() == gt_roll.dim()

            loss_roll = loss_fn(pred_roll, gt_roll)
            loss["roll"] = loss_roll
            loss["total"] += loss_roll

        if "rho" in self.conf.heads:
            # Convert the ground-truth pitch to rho using the focal length and image height.
            gt_rho = pitch2rho(data["gravity"].pitch, gt_cam.f[..., 1], gt_cam.size[..., 1]).float()
            pred_rho = pred["rho"].float()

            if gt_rho.device != self.rho_edges.device:
                self.rho_edges = self.rho_edges.to(gt_rho.device)
                self.rho_centers = self.rho_centers.to(gt_rho.device)

            if self.is_classification:
                gt_rho = torch.bucketize(gt_rho.contiguous(), self.rho_edges) - 1
                assert (gt_rho >= 0).all(), gt_rho
                assert (gt_rho < self.num_bins).all(), gt_rho
            else:
                assert pred_rho.dim() == gt_rho.dim()

            loss_rho = loss_fn(pred_rho, gt_rho)
            loss["rho"] = loss_rho
            loss["total"] += loss_rho

        if "vfov" in self.conf.heads:
            gt_vfov = gt_cam.vfov.float()
            pred_vfov = pred["vfov"].float()

            if gt_vfov.device != self.fov_edges.device:
                self.fov_edges = self.fov_edges.to(gt_vfov.device)
                self.fov_centers = self.fov_centers.to(gt_vfov.device)

            if self.is_classification:
                gt_vfov = torch.bucketize(gt_vfov.contiguous(), self.fov_edges) - 1
                assert (gt_vfov >= 0).all(), gt_vfov
                assert (gt_vfov < self.num_bins).all(), gt_vfov
            else:
                # The forward pass already maps the vfov prediction back to radians,
                # so the regression loss is computed directly in radians.
                assert pred_vfov.dim() == gt_vfov.dim()

            loss_vfov = loss_fn(pred_vfov, gt_vfov)
            loss["vfov"] = loss_vfov
            loss["total"] += loss_vfov

        if "k1_hat" in self.conf.heads:
            gt_k1_hat = data["camera"].k1_hat.float()
            pred_k1_hat = pred["k1_hat"].float()

            if gt_k1_hat.device != self.k1_hat_edges.device:
                self.k1_hat_edges = self.k1_hat_edges.to(gt_k1_hat.device)
                self.k1_hat_centers = self.k1_hat_centers.to(gt_k1_hat.device)

            if self.is_classification:
                gt_k1_hat = torch.bucketize(gt_k1_hat.contiguous(), self.k1_hat_edges) - 1
                assert (gt_k1_hat >= 0).all(), gt_k1_hat
                assert (gt_k1_hat < self.num_bins).all(), gt_k1_hat
            else:
                assert pred_k1_hat.dim() == gt_k1_hat.dim()

            loss_k1_hat = loss_fn(pred_k1_hat, gt_k1_hat)
            loss["k1_hat"] = loss_k1_hat
            loss["total"] += loss_k1_hat

        return loss, self.metrics(pred, data)

    def metrics(self, pred, data):
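        """Return roll, pitch, and vfov errors plus the distortion (k1) error w.r.t. the ground truth."""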
        pred_cam, gt_cam = pred["camera"], data["camera"]
        pred_gravity, gt_gravity = pred["gravity"], data["gravity"]

        return {
            "roll_error": roll_error(pred_gravity, gt_gravity),
            "pitch_error": pitch_error(pred_gravity, gt_gravity),
            "vfov_error": vfov_error(pred_cam, gt_cam),
            "k1_error": dist_error(pred_cam, gt_cam),
        }