|
|
|
|
|
from dataclasses import dataclass |
|
from typing import Any, Callable, Dict, List, Optional |
|
|
|
from detectron2.structures import Instances |
|
|
|
# Per-image model output: a dictionary that is expected to carry an
# "instances" entry (read by PredictionToGroundTruthSampler.__call__).
ModelOutput = Dict[str, Any]
# Per-image sampled data: the same dictionary after sampling, additionally
# carrying a "dataset" entry.
SampledData = Dict[str, Any]
|
|
|
|
|
@dataclass
class _Sampler:
    """
    Sampler registry entry.

    Attributes:
        src (str): name of the source field to sample from; the owning
            sampler deletes this field after sampling unless it is also the
            destination field
        dst (Optional[str]): name of the destination field the sampled value
            is stored to; if None, nothing is stored (the source field is
            only deleted)
        func (Optional[Callable: Any -> Any]): function that performs
            sampling; it is called with the whole instances container (not
            with the value of the source field) and its result is stored to
            `dst`; if None, a reference copy of the source field is performed
    """

    src: str
    # Defaults mirror `PredictionToGroundTruthSampler.register_sampler`, so an
    # entry that only deletes a field can be created as `_Sampler("name")`.
    dst: Optional[str] = None
    func: Optional[Callable[[Any], Any]] = None
|
|
|
|
|
class PredictionToGroundTruthSampler:
    """
    Sampler implementation that converts predictions to GT using registered
    samplers for different fields of `Instances`.

    Samplers are kept in a registry keyed by the pair
    (prediction field name, GT field name); registering a sampler for an
    already registered pair replaces the previous entry.
    """

    def __init__(self, dataset_name: str = ""):
        """
        Create a sampler with the default field samplers registered

        Args:
            dataset_name (str): value stored under the "dataset" key of every
                processed model output entry
        """
        self.dataset_name = dataset_name
        # registry: (prediction_attr, gt_attr) -> _Sampler
        self._samplers = {}
        self.register_sampler("pred_boxes", "gt_boxes", None)
        self.register_sampler("pred_classes", "gt_classes", None)
        # no destination field: "scores" is only deleted
        self.register_sampler("scores")

    def __call__(self, model_output: List[ModelOutput]) -> List[SampledData]:
        """
        Transform model output into ground truth data through sampling

        The entries and their instances are modified in place; the returned
        list is the very same list object that was passed in.

        Args:
            model_output (List[Dict[str, Any]]): model outputs, each entry
                must contain the key "instances"
        Returns:
            List[Dict[str, Any]]: sampled data, i.e. the input entries with
                sampled fields set on their instances, source fields removed
                and the key "dataset" set to the dataset name
        """
        for model_output_i in model_output:
            instances: Instances = model_output_i["instances"]
            # transform data in each field; source fields are not deleted
            # here, so they stay available to every sampler during this pass
            for sampler in self._samplers.values():
                if not instances.has(sampler.src) or sampler.dst is None:
                    continue
                if sampler.func is None:
                    # reference copy of the source field
                    instances.set(sampler.dst, instances.get(sampler.src))
                else:
                    # the sampler function receives the whole container
                    instances.set(sampler.dst, sampler.func(instances))
            # delete model output data that was transformed (or that was
            # registered without a destination field)
            for sampler in self._samplers.values():
                if sampler.src != sampler.dst and instances.has(sampler.src):
                    instances.remove(sampler.src)
            model_output_i["dataset"] = self.dataset_name
        return model_output

    def register_sampler(
        self,
        prediction_attr: str,
        gt_attr: Optional[str] = None,
        func: Optional[Callable[[Any], Any]] = None,
    ):
        """
        Register sampler for a field

        Args:
            prediction_attr (str): field to replace with a sampled value
            gt_attr (Optional[str]): field to store the sampled value to, if not None
            func (Optional[Callable: Any -> Any]): sampler function, called
                with the whole instances container; if None, the value of
                `prediction_attr` is copied by reference
        """
        self._samplers[(prediction_attr, gt_attr)] = _Sampler(
            src=prediction_attr, dst=gt_attr, func=func
        )

    def remove_sampler(
        self,
        prediction_attr: str,
        gt_attr: Optional[str] = None,
    ):
        """
        Remove sampler for a field

        Args:
            prediction_attr (str): field to replace with a sampled value
            gt_attr (Optional[str]): field to store the sampled value to, if not None
        Raises:
            AssertionError: if no sampler is registered for the given pair
        """
        key = (prediction_attr, gt_attr)
        if key not in self._samplers:
            # Raised explicitly (instead of using an `assert` statement) so
            # that the check is not stripped when Python runs with -O; the
            # exception type is the one the assert statement would raise.
            raise AssertionError(f"No sampler registered for {key}")
        del self._samplers[key]
|
|