Spaces:
Running
Running
File size: 2,149 Bytes
499e141 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
from dataclasses import dataclass
from typing import Callable
import numpy as np
from benchmark.reprojection import reprojection_error
from benchmark.utils import VARIANTS_ANGLE_SIN, quat_angle_error
@dataclass
class Inputs:
    """Bundle of the values the metrics need to score one pose estimate.

    Every field is validated in ``__post_init__``; a failed check raises
    ``AssertionError`` whose message names the offending field.
    """

    q_gt: np.ndarray  # ground-truth quaternion, shape (4,)
    t_gt: np.ndarray  # ground-truth translation, shape (3,)
    q_est: np.ndarray  # estimated quaternion, shape (4,)
    t_est: np.ndarray  # estimated translation, shape (3,)
    confidence: float  # must be >= 0
    K: np.ndarray  # shape (3, 3); presumably the camera intrinsics - confirm
    W: int  # image width, must be > 0
    H: int  # image height, must be > 0

    @staticmethod
    def _require(condition, message: str) -> None:
        """Raise ``AssertionError(message)`` unless ``condition`` is truthy."""
        if not condition:
            raise AssertionError(message)

    def __post_init__(self):
        """Check field shapes and ranges, in declaration order.

        Raises:
            AssertionError: on the first failing check. It is raised
                explicitly rather than through ``assert`` statements so the
                validation is not stripped when Python runs with ``-O``.
        """
        require = self._require
        require(self.q_gt.shape == (4,), 'invalid gt quaternion shape')
        require(self.t_gt.shape == (3,), 'invalid gt translation shape')
        require(self.q_est.shape == (4,), 'invalid estimated quaternion shape')
        require(self.t_est.shape == (3,), 'invalid estimated translation shape')
        # ``>=`` (not a negated ``<``) so that NaN is rejected as well.
        require(self.confidence >= 0, 'confidence must be non negative')
        require(self.K.shape == (3, 3), 'invalid K shape')
        require(self.W > 0, 'invalid image width')
        require(self.H > 0, 'invalid image height')
class MyDict(dict):
    """A ``dict`` that can be filled through a decorator."""

    def register(self, fn) -> Callable:
        """Store ``fn`` under its ``__name__`` and hand it back unchanged.

        Returning ``fn`` lets this method be used as a decorator. An existing
        entry with the same name is overwritten.
        """
        key = fn.__name__
        self[key] = fn
        return fn
class MetricManager:
    """Callable that evaluates every registered metric on one sample.

    The static methods below are the metrics. Each one is recorded in the
    class-level ``_metrics`` registry under its own function name while the
    class body executes, so they are evaluated in definition order.
    """

    _metrics = MyDict()

    def __call__(self, inputs: Inputs, results: dict) -> None:
        """Append the value of each registered metric to ``results``.

        ``results`` is indexed by metric name and the looked-up object must
        provide ``append``; this method never assigns keys itself.
        """
        registry = self._metrics
        for name in registry:
            # Resolve the destination before computing the metric.
            push = results[name].append
            push(registry[name](inputs))

    @staticmethod
    @_metrics.register
    def trans_err(inputs: Inputs) -> np.float64:
        """Norm of the difference between estimated and gt translation."""
        offset = inputs.t_est - inputs.t_gt
        return np.linalg.norm(offset)

    @staticmethod
    @_metrics.register
    def rot_err(inputs: Inputs, variant: str = VARIANTS_ANGLE_SIN) -> np.float64:
        """Element ``[0, 0]`` of ``quat_angle_error`` for the two quaternions."""
        # NOTE(review): the estimate is passed as `label` and the ground truth
        # as `pred`. That looks swapped; it is harmless only if
        # quat_angle_error is symmetric in these two arguments - confirm.
        errors = quat_angle_error(
            label=inputs.q_est,
            pred=inputs.q_gt,
            variant=variant,
        )
        return errors[0, 0]

    @staticmethod
    @_metrics.register
    def reproj_err(inputs: Inputs) -> float:
        """Result of ``reprojection_error`` for the estimated and gt poses."""
        return reprojection_error(
            q_est=inputs.q_est,
            t_est=inputs.t_est,
            q_gt=inputs.q_gt,
            t_gt=inputs.t_gt,
            K=inputs.K,
            W=inputs.W,
            H=inputs.H,
        )

    @staticmethod
    @_metrics.register
    def confidence(inputs: Inputs) -> float:
        """Pass the sample's confidence through unchanged."""
        return inputs.confidence
|