|
|
|
|
|
from typing import List
|
|
import torch
|
|
|
|
from detectron2.config import CfgNode
|
|
from detectron2.structures import Instances
|
|
from detectron2.structures.boxes import matched_pairwise_iou
|
|
|
|
|
|
class DensePoseDataFilter:
    """
    Callable that narrows per-image proposals down to the ones usable as
    DensePose training samples.

    A proposal is kept when its IoU with the ground truth box it is paired
    with exceeds a configured threshold and it carries a DensePose
    annotation or, if mask supervision is enabled, a ground truth mask.
    """

    def __init__(self, cfg: CfgNode):
        """
        Read the filter settings from the config.

        Args:
            cfg (CfgNode): configuration; the values used are
                `MODEL.ROI_DENSEPOSE_HEAD.FG_IOU_THRESHOLD` and
                `MODEL.ROI_DENSEPOSE_HEAD.COARSE_SEGM_TRAINED_BY_MASKS`.
        """
        # A proposal is kept only if its IoU is strictly greater than this value.
        self.iou_threshold = cfg.MODEL.ROI_DENSEPOSE_HEAD.FG_IOU_THRESHOLD
        # When truthy, a ground truth mask alone is enough to keep a proposal.
        self.keep_masks = cfg.MODEL.ROI_DENSEPOSE_HEAD.COARSE_SEGM_TRAINED_BY_MASKS

    @torch.no_grad()
    def __call__(self, features: List[torch.Tensor], proposals_with_targets: List[Instances]):
        """
        Filters proposals with targets to keep only the ones relevant for
        DensePose training

        Args:
            features (list[Tensor]): input data as a list of features,
                each feature is a tensor. Axis 0 represents the number of
                images `N` in the input data; axes 1-3 are channels,
                height, and width, which may vary between features
                (e.g., if a feature pyramid is used).
            proposals_with_targets (list[Instances]): length `N` list of
                `Instances`. The i-th `Instances` contains instances
                (proposals, GT) for the i-th input image. Each entry is
                expected to provide `gt_boxes` and `proposal_boxes` of equal
                length, and optionally `gt_densepose` and / or `gt_masks`.
        Returns:
            list[Tensor]: the input features, returned unchanged
                (features are NOT filtered)
            list[Instances]: filtered proposals. Images that have no
                `gt_densepose` field and no usable `gt_masks` field are
                dropped entirely, so this list can be shorter than `N` and
                its entries are then no longer index-aligned with axis 0
                of the features.
        """
        proposals_filtered = []
        for proposals_per_image in proposals_with_targets:
            # Skip images that carry no supervision at all: neither DensePose
            # annotations nor (when mask supervision is enabled) masks.
            has_densepose = proposals_per_image.has("gt_densepose")
            has_usable_masks = self.keep_masks and proposals_per_image.has("gt_masks")
            if not has_densepose and not has_usable_masks:
                continue

            # Keep proposals whose IoU with their paired GT box is above the threshold.
            iou = matched_pairwise_iou(
                proposals_per_image.gt_boxes, proposals_per_image.proposal_boxes
            )
            proposals_per_image = proposals_per_image[iou > self.iou_threshold]

            N_gt_boxes = len(proposals_per_image.gt_boxes)
            assert N_gt_boxes == len(proposals_per_image.proposal_boxes), (
                f"The number of GT boxes {N_gt_boxes} is different from the "
                f"number of proposal boxes {len(proposals_per_image.proposal_boxes)}"
            )

            # Collect per-proposal targets; a missing (or disabled) field is
            # represented by a list of `None` placeholders of matching length.
            no_targets = [None] * N_gt_boxes
            if self.keep_masks and proposals_per_image.has("gt_masks"):
                gt_masks = proposals_per_image.gt_masks
            else:
                gt_masks = no_targets
            if proposals_per_image.has("gt_densepose"):
                gt_densepose = proposals_per_image.gt_densepose
            else:
                gt_densepose = no_targets
            assert len(gt_masks) == N_gt_boxes
            assert len(gt_densepose) == N_gt_boxes

            # Filter out any proposal whose targets are all `None`.
            selected_indices = [
                idx
                for idx, (dp_target, mask_target) in enumerate(zip(gt_densepose, gt_masks))
                if dp_target is not None or mask_target is not None
            ]
            # Re-index only when something was actually removed.
            if len(selected_indices) != N_gt_boxes:
                proposals_per_image = proposals_per_image[selected_indices]
            assert len(proposals_per_image.gt_boxes) == len(proposals_per_image.proposal_boxes)
            proposals_filtered.append(proposals_per_image)

        return features, proposals_filtered
|
|
|