Spaces:
Paused
Paused
# Copyright (c) Facebook, Inc. and its affiliates. | |
import logging | |
from typing import List, Optional, Sequence, Tuple | |
import torch | |
from detectron2.layers.nms import batched_nms | |
from detectron2.structures.instances import Instances | |
from densepose.converters import ToChartResultConverterWithConfidences | |
from densepose.structures import ( | |
DensePoseChartResultWithConfidences, | |
DensePoseEmbeddingPredictorOutput, | |
) | |
from densepose.vis.bounding_box import BoundingBoxVisualizer, ScoredBoundingBoxVisualizer | |
from densepose.vis.densepose_outputs_vertex import DensePoseOutputsVertexVisualizer | |
from densepose.vis.densepose_results import DensePoseResultsVisualizer | |
from .base import CompoundVisualizer | |
Scores = Sequence[float] | |
DensePoseChartResultsWithConfidences = List[DensePoseChartResultWithConfidences] | |
def extract_scores_from_instances(instances: Instances, select=None): | |
if instances.has("scores"): | |
return instances.scores if select is None else instances.scores[select] | |
return None | |
def extract_boxes_xywh_from_instances(instances: Instances, select=None): | |
if instances.has("pred_boxes"): | |
boxes_xywh = instances.pred_boxes.tensor.clone() | |
boxes_xywh[:, 2] -= boxes_xywh[:, 0] | |
boxes_xywh[:, 3] -= boxes_xywh[:, 1] | |
return boxes_xywh if select is None else boxes_xywh[select] | |
return None | |
def create_extractor(visualizer: object): | |
""" | |
Create an extractor for the provided visualizer | |
""" | |
if isinstance(visualizer, CompoundVisualizer): | |
extractors = [create_extractor(v) for v in visualizer.visualizers] | |
return CompoundExtractor(extractors) | |
elif isinstance(visualizer, DensePoseResultsVisualizer): | |
return DensePoseResultExtractor() | |
elif isinstance(visualizer, ScoredBoundingBoxVisualizer): | |
return CompoundExtractor([extract_boxes_xywh_from_instances, extract_scores_from_instances]) | |
elif isinstance(visualizer, BoundingBoxVisualizer): | |
return extract_boxes_xywh_from_instances | |
elif isinstance(visualizer, DensePoseOutputsVertexVisualizer): | |
return DensePoseOutputsExtractor() | |
else: | |
logger = logging.getLogger(__name__) | |
logger.error(f"Could not create extractor for {visualizer}") | |
return None | |
class BoundingBoxExtractor: | |
""" | |
Extracts bounding boxes from instances | |
""" | |
def __call__(self, instances: Instances): | |
boxes_xywh = extract_boxes_xywh_from_instances(instances) | |
return boxes_xywh | |
class ScoredBoundingBoxExtractor: | |
""" | |
Extracts bounding boxes from instances | |
""" | |
def __call__(self, instances: Instances, select=None): | |
scores = extract_scores_from_instances(instances) | |
boxes_xywh = extract_boxes_xywh_from_instances(instances) | |
if (scores is None) or (boxes_xywh is None): | |
return (boxes_xywh, scores) | |
if select is not None: | |
scores = scores[select] | |
boxes_xywh = boxes_xywh[select] | |
return (boxes_xywh, scores) | |
class DensePoseResultExtractor: | |
""" | |
Extracts DensePose chart result with confidences from instances | |
""" | |
def __call__( | |
self, instances: Instances, select=None | |
) -> Tuple[Optional[DensePoseChartResultsWithConfidences], Optional[torch.Tensor]]: | |
if instances.has("pred_densepose") and instances.has("pred_boxes"): | |
dpout = instances.pred_densepose | |
boxes_xyxy = instances.pred_boxes | |
boxes_xywh = extract_boxes_xywh_from_instances(instances) | |
if select is not None: | |
dpout = dpout[select] | |
boxes_xyxy = boxes_xyxy[select] | |
converter = ToChartResultConverterWithConfidences() | |
results = [converter.convert(dpout[i], boxes_xyxy[[i]]) for i in range(len(dpout))] | |
return results, boxes_xywh | |
else: | |
return None, None | |
class DensePoseOutputsExtractor: | |
""" | |
Extracts DensePose result from instances | |
""" | |
def __call__( | |
self, | |
instances: Instances, | |
select=None, | |
) -> Tuple[ | |
Optional[DensePoseEmbeddingPredictorOutput], Optional[torch.Tensor], Optional[List[int]] | |
]: | |
if not (instances.has("pred_densepose") and instances.has("pred_boxes")): | |
return None, None, None | |
dpout = instances.pred_densepose | |
boxes_xyxy = instances.pred_boxes | |
boxes_xywh = extract_boxes_xywh_from_instances(instances) | |
if instances.has("pred_classes"): | |
classes = instances.pred_classes.tolist() | |
else: | |
classes = None | |
if select is not None: | |
dpout = dpout[select] | |
boxes_xyxy = boxes_xyxy[select] | |
if classes is not None: | |
classes = classes[select] | |
return dpout, boxes_xywh, classes | |
class CompoundExtractor: | |
""" | |
Extracts data for CompoundVisualizer | |
""" | |
def __init__(self, extractors): | |
self.extractors = extractors | |
def __call__(self, instances: Instances, select=None): | |
datas = [] | |
for extractor in self.extractors: | |
data = extractor(instances, select) | |
datas.append(data) | |
return datas | |
class NmsFilteredExtractor: | |
""" | |
Extracts data in the format accepted by NmsFilteredVisualizer | |
""" | |
def __init__(self, extractor, iou_threshold): | |
self.extractor = extractor | |
self.iou_threshold = iou_threshold | |
def __call__(self, instances: Instances, select=None): | |
scores = extract_scores_from_instances(instances) | |
boxes_xywh = extract_boxes_xywh_from_instances(instances) | |
if boxes_xywh is None: | |
return None | |
select_local_idx = batched_nms( | |
boxes_xywh, | |
scores, | |
torch.zeros(len(scores), dtype=torch.int32), | |
iou_threshold=self.iou_threshold, | |
).squeeze() | |
select_local = torch.zeros(len(boxes_xywh), dtype=torch.bool, device=boxes_xywh.device) | |
select_local[select_local_idx] = True | |
select = select_local if select is None else (select & select_local) | |
return self.extractor(instances, select=select) | |
class ScoreThresholdedExtractor: | |
""" | |
Extracts data in the format accepted by ScoreThresholdedVisualizer | |
""" | |
def __init__(self, extractor, min_score): | |
self.extractor = extractor | |
self.min_score = min_score | |
def __call__(self, instances: Instances, select=None): | |
scores = extract_scores_from_instances(instances) | |
if scores is None: | |
return None | |
select_local = scores > self.min_score | |
select = select_local if select is None else (select & select_local) | |
data = self.extractor(instances, select=select) | |
return data | |