File size: 6,768 Bytes
938e515 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 |
# Copyright (c) Facebook, Inc. and its affiliates.
import logging
from typing import List, Optional, Sequence, Tuple
import torch
from detectron2.layers.nms import batched_nms
from detectron2.structures.instances import Instances
from densepose.converters import ToChartResultConverterWithConfidences
from densepose.structures import (
DensePoseChartResultWithConfidences,
DensePoseEmbeddingPredictorOutput,
)
from densepose.vis.bounding_box import BoundingBoxVisualizer, ScoredBoundingBoxVisualizer
from densepose.vis.densepose_outputs_vertex import DensePoseOutputsVertexVisualizer
from densepose.vis.densepose_results import DensePoseResultsVisualizer
from .base import CompoundVisualizer
# Type alias: a sequence of float scores.
Scores = Sequence[float]
# Type alias: a list of DensePose chart results with confidences, e.g. the
# list built by DensePoseResultExtractor (one entry per selected instance).
DensePoseChartResultsWithConfidences = List[DensePoseChartResultWithConfidences]
def extract_scores_from_instances(instances: Instances, select=None):
    """
    Return the "scores" field of ``instances``, indexed by ``select`` when one
    is given; return None if the instances have no "scores" field.
    """
    if not instances.has("scores"):
        return None
    picked = instances.scores
    if select is not None:
        picked = picked[select]
    return picked
def extract_boxes_xywh_from_instances(instances: Instances, select=None):
    """
    Return the predicted boxes converted from XYXY to XYWH layout, indexed by
    ``select`` when one is given; return None if there is no "pred_boxes" field.

    The box tensor is cloned first so the in-place conversion never modifies
    ``instances``.
    """
    if not instances.has("pred_boxes"):
        return None
    xywh = instances.pred_boxes.tensor.clone()
    xywh[:, 2] -= xywh[:, 0]  # width  = x1 - x0
    xywh[:, 3] -= xywh[:, 1]  # height = y1 - y0
    if select is None:
        return xywh
    return xywh[select]
def create_extractor(visualizer: object):
    """
    Build the extractor that feeds ``visualizer``.

    Compound visualizers get a CompoundExtractor built recursively from their
    children. Type checks are performed in a fixed order and the first match
    wins. Unsupported visualizers are logged as an error and yield None.
    """
    if isinstance(visualizer, CompoundVisualizer):
        return CompoundExtractor([create_extractor(child) for child in visualizer.visualizers])
    if isinstance(visualizer, DensePoseResultsVisualizer):
        return DensePoseResultExtractor()
    if isinstance(visualizer, ScoredBoundingBoxVisualizer):
        # boxes and scores are extracted side by side, in that order
        return CompoundExtractor([extract_boxes_xywh_from_instances, extract_scores_from_instances])
    if isinstance(visualizer, BoundingBoxVisualizer):
        return extract_boxes_xywh_from_instances
    if isinstance(visualizer, DensePoseOutputsVertexVisualizer):
        return DensePoseOutputsExtractor()
    logging.getLogger(__name__).error(f"Could not create extractor for {visualizer}")
    return None
class BoundingBoxExtractor:
    """
    Extracts bounding boxes (in XYWH layout) from instances
    """

    def __call__(self, instances: Instances, select=None):
        """
        Args:
            instances: predictions that may carry a "pred_boxes" field
            select: optional index / mask applied to the boxes. Accepted so this
                extractor has the same call signature as the others and can be
                wrapped by CompoundExtractor, NmsFilteredExtractor or
                ScoreThresholdedExtractor, which always pass ``select``.
        Returns:
            XYWH box tensor (filtered by ``select`` if given), or None when the
            instances have no "pred_boxes" field
        """
        return extract_boxes_xywh_from_instances(instances, select)
class ScoredBoundingBoxExtractor:
    """
    Extracts bounding boxes (in XYWH layout) together with their scores from instances
    """

    def __call__(self, instances: Instances, select=None):
        """
        Args:
            instances: predictions that may carry "pred_boxes" and "scores" fields
            select: optional index / mask applied to both boxes and scores
        Returns:
            tuple (boxes_xywh, scores). Each element is None when the corresponding
            field is missing. Otherwise it is filtered by ``select``, independently
            of whether the other element is present.
        """
        # Apply select to each field on its own so a missing field never leaves
        # the other one unfiltered.
        boxes_xywh = extract_boxes_xywh_from_instances(instances, select)
        scores = extract_scores_from_instances(instances, select)
        return (boxes_xywh, scores)
class DensePoseResultExtractor:
    """
    Extracts DensePose chart result with confidences from instances
    """

    def __call__(
        self, instances: Instances, select=None
    ) -> Tuple[Optional[DensePoseChartResultsWithConfidences], Optional[torch.Tensor]]:
        """
        Args:
            instances: predictions; both "pred_densepose" and "pred_boxes" are required
            select: optional index / mask selecting a subset of instances
        Returns:
            (results, boxes_xywh). ``results`` holds one chart result per selected
            instance. ``boxes_xywh`` holds the XYWH boxes of the same selected
            instances, in the same order. Returns (None, None) when either
            required field is missing.
        """
        if not (instances.has("pred_densepose") and instances.has("pred_boxes")):
            return None, None
        dpout = instances.pred_densepose
        boxes_xyxy = instances.pred_boxes
        # Filter the XYWH boxes with the same select as dpout/boxes_xyxy so that
        # boxes_xywh[i] corresponds to results[i].
        boxes_xywh = extract_boxes_xywh_from_instances(instances, select)
        if select is not None:
            dpout = dpout[select]
            boxes_xyxy = boxes_xyxy[select]
        converter = ToChartResultConverterWithConfidences()
        # boxes_xyxy[[i]] keeps a Boxes object holding a single box
        results = [converter.convert(dpout[i], boxes_xyxy[[i]]) for i in range(len(dpout))]
        return results, boxes_xywh
class DensePoseOutputsExtractor:
    """
    Extracts DensePose result from instances
    """

    def __call__(
        self,
        instances: Instances,
        select=None,
    ) -> Tuple[
        Optional[DensePoseEmbeddingPredictorOutput], Optional[torch.Tensor], Optional[List[int]]
    ]:
        """
        Args:
            instances: predictions; both "pred_densepose" and "pred_boxes" are required
            select: optional index / mask selecting a subset of instances
        Returns:
            (dpout, boxes_xywh, classes), all restricted to the selected instances.
            ``classes`` is "pred_classes" converted with ``tolist()``, or None when
            that field is absent. Returns (None, None, None) when either required
            field is missing.
        """
        if not (instances.has("pred_densepose") and instances.has("pred_boxes")):
            return None, None, None

        dpout = instances.pred_densepose
        # Filter boxes with the same select as dpout so the outputs stay aligned.
        boxes_xywh = extract_boxes_xywh_from_instances(instances, select)

        classes = None
        if instances.has("pred_classes"):
            pred_classes = instances.pred_classes
            # Index the tensor, not a Python list: a list cannot be indexed by a
            # boolean mask or an index tensor.
            if select is not None:
                pred_classes = pred_classes[select]
            classes = pred_classes.tolist()

        if select is not None:
            dpout = dpout[select]
        return dpout, boxes_xywh, classes
class CompoundExtractor:
    """
    Extracts data for CompoundVisualizer by running a list of child extractors
    """

    def __init__(self, extractors):
        # child extractors; their outputs are returned in this order
        self.extractors = extractors

    def __call__(self, instances: Instances, select=None):
        """Call every child extractor with (instances, select) and return a list of their outputs."""
        return [child(instances, select) for child in self.extractors]
class NmsFilteredExtractor:
    """
    Extracts data in the format accepted by NmsFilteredVisualizer: applies
    non-maximum suppression to the predicted boxes and forwards the kept
    instances to the wrapped extractor.
    """

    def __init__(self, extractor, iou_threshold):
        self.extractor = extractor
        self.iou_threshold = iou_threshold

    def __call__(self, instances: Instances, select=None):
        """
        Args:
            instances: predictions; "pred_boxes" and "scores" are required for NMS
            select: optional boolean mask, combined (logical AND) with the NMS keep mask
        Returns:
            the wrapped extractor's output for the combined selection, or None
            when boxes or scores are missing
        """
        scores = extract_scores_from_instances(instances)
        # Without scores NMS cannot run; bail out like ScoreThresholdedExtractor
        # instead of crashing on len(None).
        if scores is None or not instances.has("pred_boxes"):
            return None
        # NMS computes IoU from (x0, y0, x1, y1) corners, so it must get the raw
        # XYXY boxes rather than the XYWH conversion.
        boxes_xyxy = instances.pred_boxes.tensor
        keep_idx = batched_nms(
            boxes_xyxy,
            scores,
            # a single category for every box: suppression is class-agnostic;
            # allocated on the scores' device to avoid a CPU/GPU mismatch
            torch.zeros(len(scores), dtype=torch.int32, device=scores.device),
            iou_threshold=self.iou_threshold,
        )
        keep = torch.zeros(len(boxes_xyxy), dtype=torch.bool, device=boxes_xyxy.device)
        keep[keep_idx] = True
        select = keep if select is None else (select & keep)
        return self.extractor(instances, select=select)
class ScoreThresholdedExtractor:
    """
    Extracts data in the format accepted by ScoreThresholdedVisualizer: keeps only
    instances whose score is strictly greater than ``min_score`` and forwards them
    to the wrapped extractor.
    """

    def __init__(self, extractor, min_score):
        self.extractor = extractor
        self.min_score = min_score

    def __call__(self, instances: Instances, select=None):
        """
        Returns the wrapped extractor's output for instances that pass the score
        threshold (and ``select``, if given), or None when there are no scores.
        """
        scores = extract_scores_from_instances(instances)
        if scores is None:
            return None
        passing = scores > self.min_score
        if select is not None:
            passing = select & passing
        return self.extractor(instances, select=passing)
|