StableVITON / densepose /vis /extractor.py
rlawjdghek's picture
det2 (#6)
1527335 verified
raw
history blame
6.77 kB
# Copyright (c) Facebook, Inc. and its affiliates.
import logging
from typing import List, Optional, Sequence, Tuple
import torch
from detectron2.layers.nms import batched_nms
from detectron2.structures.instances import Instances
from densepose.converters import ToChartResultConverterWithConfidences
from densepose.structures import (
DensePoseChartResultWithConfidences,
DensePoseEmbeddingPredictorOutput,
)
from densepose.vis.bounding_box import BoundingBoxVisualizer, ScoredBoundingBoxVisualizer
from densepose.vis.densepose_outputs_vertex import DensePoseOutputsVertexVisualizer
from densepose.vis.densepose_results import DensePoseResultsVisualizer
from .base import CompoundVisualizer
# Type alias: a sequence of float scores.
Scores = Sequence[float]
# Type alias: a list of DensePose chart results with confidences (one entry per
# converted instance when produced by the chart-result extractor in this module).
DensePoseChartResultsWithConfidences = List[DensePoseChartResultWithConfidences]
def extract_scores_from_instances(instances: Instances, select=None):
    """
    Return the ``scores`` field of ``instances``, optionally indexed by ``select``.

    Args:
        instances: detection results queried via ``has("scores")``.
        select: optional index / mask applied to the scores.

    Returns:
        The (possibly selected) scores, or None if the field is absent.
    """
    if not instances.has("scores"):
        return None
    all_scores = instances.scores
    if select is None:
        return all_scores
    return all_scores[select]
def extract_boxes_xywh_from_instances(instances: Instances, select=None):
    """
    Convert ``pred_boxes`` from corner form to (x, y, w, h) form.

    Assumes ``pred_boxes.tensor`` stores (x0, y0, x1, y1) rows, which is what
    the subtraction below relies on. The source tensor is cloned, so the
    instances are not modified.

    Args:
        instances: detection results queried via ``has("pred_boxes")``.
        select: optional index / mask applied after conversion.

    Returns:
        The converted (possibly selected) boxes, or None if the field is absent.
    """
    if not instances.has("pred_boxes"):
        return None
    xywh = instances.pred_boxes.tensor.clone()
    # Columns 2:4 become width/height; columns 0:2 (top-left corner) stay as-is.
    xywh[:, 2:4] -= xywh[:, 0:2]
    if select is not None:
        xywh = xywh[select]
    return xywh
def create_extractor(visualizer: object):
    """
    Build the extractor that feeds ``visualizer``.

    Compound visualizers get a CompoundExtractor whose children are built
    recursively, one per sub-visualizer. Unsupported visualizers are logged
    as an error and yield None.
    """
    if isinstance(visualizer, CompoundVisualizer):
        return CompoundExtractor([create_extractor(child) for child in visualizer.visualizers])
    if isinstance(visualizer, DensePoseResultsVisualizer):
        return DensePoseResultExtractor()
    # The scored variant is checked before the plain bounding-box visualizer;
    # keep this order so the more specific match wins if one derives from the other.
    if isinstance(visualizer, ScoredBoundingBoxVisualizer):
        return CompoundExtractor([extract_boxes_xywh_from_instances, extract_scores_from_instances])
    if isinstance(visualizer, BoundingBoxVisualizer):
        return extract_boxes_xywh_from_instances
    if isinstance(visualizer, DensePoseOutputsVertexVisualizer):
        return DensePoseOutputsExtractor()
    logging.getLogger(__name__).error(f"Could not create extractor for {visualizer}")
    return None
class BoundingBoxExtractor:
    """
    Extracts bounding boxes, in (x, y, w, h) form, from instances
    """

    def __call__(self, instances: Instances, select=None):
        """
        Return the instances' boxes as (x, y, w, h), or None if absent.

        Args:
            instances: detection results; boxes are read from ``pred_boxes``.
            select: optional index / mask applied to the boxes. Accepting it
                lets this extractor be driven like the other extractors in this
                module, which are all invoked as ``extractor(instances, select)``.
        """
        return extract_boxes_xywh_from_instances(instances, select)
class ScoredBoundingBoxExtractor:
    """
    Extracts bounding boxes, in (x, y, w, h) form, together with their scores
    from instances
    """

    def __call__(self, instances: Instances, select=None):
        """
        Return ``(boxes_xywh, scores)`` for the (optionally selected) instances.

        Either element is None when its source field (``pred_boxes`` or
        ``scores``) is missing. ``select`` is applied to every element that is
        present, so a missing field never causes the other one to be returned
        unfiltered.
        """
        scores = extract_scores_from_instances(instances, select)
        boxes_xywh = extract_boxes_xywh_from_instances(instances, select)
        return (boxes_xywh, scores)
class DensePoseResultExtractor:
    """
    Extracts DensePose chart result with confidences from instances
    """

    def __call__(
        self, instances: Instances, select=None
    ) -> Tuple[Optional[DensePoseChartResultsWithConfidences], Optional[torch.Tensor]]:
        """
        Convert each (selected) instance's DensePose prediction into a chart result.

        Args:
            instances: must provide ``pred_densepose`` and ``pred_boxes``.
            select: optional index / mask applied to the instances first.

        Returns:
            ``(results, boxes_xywh)`` where ``results[i]`` and ``boxes_xywh[i]``
            refer to the same selected instance, or ``(None, None)`` when either
            required field is missing.
        """
        if not (instances.has("pred_densepose") and instances.has("pred_boxes")):
            return None, None
        dpout = instances.pred_densepose
        boxes_xyxy = instances.pred_boxes
        # The XYWH boxes must receive the same selection as the predictions,
        # otherwise they would not line up index-by-index with ``results``.
        boxes_xywh = extract_boxes_xywh_from_instances(instances, select)
        if select is not None:
            dpout = dpout[select]
            boxes_xyxy = boxes_xyxy[select]
        converter = ToChartResultConverterWithConfidences()
        # boxes_xyxy[[i]] keeps a length-1 Boxes object rather than a single row.
        results = [converter.convert(dpout[i], boxes_xyxy[[i]]) for i in range(len(dpout))]
        return results, boxes_xywh
class DensePoseOutputsExtractor:
    """
    Extracts DensePose result from instances
    """

    def __call__(
        self,
        instances: Instances,
        select=None,
    ) -> Tuple[
        Optional[DensePoseEmbeddingPredictorOutput], Optional[torch.Tensor], Optional[List[int]]
    ]:
        """
        Return the raw DensePose outputs, boxes and classes of the selected instances.

        Args:
            instances: must provide ``pred_densepose`` and ``pred_boxes``;
                ``pred_classes`` is optional.
            select: optional index / mask applied consistently to every output.

        Returns:
            ``(dpout, boxes_xywh, classes)``. ``classes`` is a Python list
            (from ``pred_classes.tolist()``) or None when the instances carry
            no classes. Returns ``(None, None, None)`` when a required field is
            missing.
        """
        if not (instances.has("pred_densepose") and instances.has("pred_boxes")):
            return None, None, None

        dpout = instances.pred_densepose
        # Filter boxes with the same selection as dpout so indices stay aligned.
        boxes_xywh = extract_boxes_xywh_from_instances(instances, select)
        # pred_classes is presumably a tensor (it supports .tolist()); keep it as
        # such until after selection, because a Python list cannot be indexed
        # by a tensor mask / index tensor.
        classes = instances.pred_classes if instances.has("pred_classes") else None

        if select is not None:
            dpout = dpout[select]
            if classes is not None:
                classes = classes[select]

        if classes is not None:
            classes = classes.tolist()
        return dpout, boxes_xywh, classes
class CompoundExtractor:
    """
    Extracts data for CompoundVisualizer by running several extractors in turn
    """

    def __init__(self, extractors):
        # Ordered extractors; each is invoked as extractor(instances, select).
        self.extractors = extractors

    def __call__(self, instances: Instances, select=None):
        """Return a list holding each extractor's output, in extractor order."""
        return [extract(instances, select) for extract in self.extractors]
class NmsFilteredExtractor:
    """
    Extracts data in the format accepted by NmsFilteredVisualizer
    """

    def __init__(self, extractor, iou_threshold):
        # Wrapped extractor, invoked with the NMS-filtered selection.
        self.extractor = extractor
        # IoU threshold forwarded to batched_nms.
        self.iou_threshold = iou_threshold

    def __call__(self, instances: Instances, select=None):
        """
        Run non-maximum suppression and delegate to the wrapped extractor with
        only the surviving instances selected.

        Args:
            instances: must provide ``pred_boxes`` and ``scores``.
            select: optional selection combined (``&``) with the NMS keep-mask;
                assumed to be a boolean mask over all instances.

        Returns:
            Whatever the wrapped extractor returns, or None when boxes or
            scores are missing.
        """
        if not instances.has("pred_boxes"):
            return None
        scores = extract_scores_from_instances(instances)
        if scores is None:
            # NMS cannot rank boxes without scores.
            return None
        # batched_nms works on corner-form (x1, y1, x2, y2) boxes, so pass the
        # original pred_boxes rather than their XYWH conversion.
        boxes_xyxy = instances.pred_boxes.tensor
        device = boxes_xyxy.device
        keep_idx = batched_nms(
            boxes_xyxy,
            scores,
            # All instances share one group; create it on the boxes' device.
            torch.zeros(len(scores), dtype=torch.int32, device=device),
            iou_threshold=self.iou_threshold,
        )
        select_local = torch.zeros(len(boxes_xyxy), dtype=torch.bool, device=device)
        select_local[keep_idx] = True
        select = select_local if select is None else (select & select_local)
        return self.extractor(instances, select=select)
class ScoreThresholdedExtractor:
    """
    Extracts data in the format accepted by ScoreThresholdedVisualizer
    """

    def __init__(self, extractor, min_score):
        # Wrapped extractor, invoked with the score-filtered selection.
        self.extractor = extractor
        # Instances are kept only when their score is strictly greater than this.
        self.min_score = min_score

    def __call__(self, instances: Instances, select=None):
        """
        Keep instances scoring above ``min_score`` (intersected with ``select``
        when given) and delegate to the wrapped extractor.

        Returns None when the instances carry no ``scores`` field.
        """
        scores = extract_scores_from_instances(instances)
        if scores is None:
            return None
        above_threshold = scores > self.min_score
        if select is not None:
            above_threshold = select & above_threshold
        return self.extractor(instances, select=above_threshold)