# File size: 3,369 Bytes | 938e515
# Copyright (c) Facebook, Inc. and its affiliates.
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional
from detectron2.structures import Instances
# Type of a single entry of the list passed to
# `PredictionToGroundTruthSampler.__call__`: a dict with string keys.
ModelOutput = Dict[str, Any]
# Type of a single entry of the list returned by
# `PredictionToGroundTruthSampler.__call__`: a dict with string keys.
SampledData = Dict[str, Any]
@dataclass
class _Sampler:
"""
Sampler registry entry that contains:
- src (str): source field to sample from (deleted after sampling)
- dst (Optional[str]): destination field to sample to, if not None
- func (Optional[Callable: Any -> Any]): function that performs sampling,
if None, reference copy is performed
"""
src: str
dst: Optional[str]
func: Optional[Callable[[Any], Any]]
class PredictionToGroundTruthSampler:
    """
    Sampler implementation that converts predictions to GT using registered
    samplers for different fields of `Instances`.

    On construction the following samplers are registered:
     - "pred_boxes" -> "gt_boxes" (reference copy)
     - "pred_classes" -> "gt_classes" (reference copy)
     - "scores" -> None (the field is only removed)
    """

    def __init__(self, dataset_name: str = ""):
        """
        Args:
            dataset_name (str): value stored under the "dataset" key of every
                model output entry processed by `__call__`
        """
        self.dataset_name = dataset_name
        # registry: (prediction_attr, gt_attr) -> _Sampler entry
        self._samplers = {}
        self.register_sampler("pred_boxes", "gt_boxes", None)
        self.register_sampler("pred_classes", "gt_classes", None)
        # delete scores
        self.register_sampler("scores")

    def __call__(self, model_output: List[ModelOutput]) -> List[SampledData]:
        """
        Transform model output into ground truth data through sampling

        Entries of `model_output` and their "instances" objects are modified
        in place; the same list object is returned.

        Args:
            model_output (List[Dict[str, Any]]): model outputs, every entry
                must contain an "instances" item
        Returns:
            List[Dict[str, Any]]: sampled data, one entry per model output,
                with the "dataset" key set to `self.dataset_name`
        """
        for model_output_i in model_output:
            instances: Instances = model_output_i["instances"]
            # transform data in each field
            for sampler in self._samplers.values():
                if not instances.has(sampler.src) or sampler.dst is None:
                    continue
                if sampler.func is None:
                    # reference copy of the source field
                    instances.set(sampler.dst, instances.get(sampler.src))
                else:
                    # the sampling function receives the whole instances object
                    instances.set(sampler.dst, sampler.func(instances))
            # delete model output data that was transformed; this is a separate
            # pass, so source fields stay available while samplers are applied
            for sampler in self._samplers.values():
                if sampler.src != sampler.dst and instances.has(sampler.src):
                    instances.remove(sampler.src)
            model_output_i["dataset"] = self.dataset_name
        return model_output

    def register_sampler(
        self,
        prediction_attr: str,
        gt_attr: Optional[str] = None,
        func: Optional[Callable[[Any], Any]] = None,
    ):
        """
        Register sampler for a field

        An entry previously registered for the same
        (prediction_attr, gt_attr) pair is overwritten.

        Args:
            prediction_attr (str): field to replace with a sampled value
            gt_attr (Optional[str]): field to store the sampled value to, if not None
            func (Optional[Callable: Any -> Any]): sampler function
        """
        self._samplers[(prediction_attr, gt_attr)] = _Sampler(
            src=prediction_attr, dst=gt_attr, func=func
        )

    def remove_sampler(
        self,
        prediction_attr: str,
        gt_attr: Optional[str] = None,
    ):
        """
        Remove sampler for a field

        The (prediction_attr, gt_attr) pair must have been registered before;
        otherwise an AssertionError is raised (when assertions are enabled).

        Args:
            prediction_attr (str): field to replace with a sampled value
            gt_attr (Optional[str]): field to store the sampled value to, if not None
        """
        key = (prediction_attr, gt_attr)
        assert key in self._samplers, f"No sampler registered for {key}"
        del self._samplers[key]
|