File size: 5,413 Bytes
938e515 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from dataclasses import dataclass
from typing import Any, Optional
import torch
from detectron2.structures import BoxMode, Instances
from .utils import AnnotationsAccumulator
@dataclass
class PackedCseAnnotations:
    """
    Packed CSE ground-truth annotations for a batch of images.

    Instances are produced by `CseAnnotationsAccumulator.pack`, which
    concatenates the per-box data accumulated over all images into single
    tensors (along dim 0).
    """

    # Per-point X coordinates of all annotated boxes, concatenated
    x_gt: torch.Tensor
    # Per-point Y coordinates of all annotated boxes, concatenated
    y_gt: torch.Tensor
    # Coarse segmentation masks stacked along dim 0 (one per annotated box);
    # None when not every annotated box provided a segmentation
    coarse_segm_gt: Optional[torch.Tensor]
    # Mesh id for every GT vertex id (same shape as `vertex_ids_gt`)
    vertex_mesh_ids_gt: torch.Tensor
    # GT vertex ids of all annotated boxes, concatenated
    vertex_ids_gt: torch.Tensor
    # GT boxes of the annotated detections, XYWH_ABS, one row of 4 per box
    bbox_xywh_gt: torch.Tensor
    # Estimated (proposal) boxes of the annotated detections, XYWH_ABS,
    # one row of 4 per box
    bbox_xywh_est: torch.Tensor
    # For each GT point: index of its box among boxes that carry DensePose data
    point_bbox_with_dp_indices: torch.Tensor
    # For each GT point: index of its box among all matched boxes in the batch
    point_bbox_indices: torch.Tensor
    # For each annotated box: its index among all matched boxes (torch.long)
    bbox_indices: torch.Tensor
class CseAnnotationsAccumulator(AnnotationsAccumulator):
    """
    Accumulates annotations by batches that correspond to objects detected on
    individual images. Can pack them together into single tensors.

    Two running counters are maintained across calls to `accumulate`:
      - `nxt_bbox_index`: index of the next matched (proposal, GT) box pair
        among *all* matched pairs seen so far, annotated or not;
      - `nxt_bbox_with_dp_index`: index of the next box among only those boxes
        whose DensePose data was actually accumulated.
    """

    def __init__(self):
        # Per-box lists of tensors; each is concatenated along dim 0 in `pack`.
        self.x_gt = []
        self.y_gt = []
        # Coarse segmentation masks; only filled for boxes that provide one,
        # so this list may be shorter than the others.
        self.s_gt = []
        self.vertex_mesh_ids_gt = []
        self.vertex_ids_gt = []
        self.bbox_xywh_gt = []
        self.bbox_xywh_est = []
        self.point_bbox_with_dp_indices = []
        self.point_bbox_indices = []
        # Plain Python ints; converted to a long tensor in `pack`.
        self.bbox_indices = []
        self.nxt_bbox_with_dp_index = 0
        self.nxt_bbox_index = 0

    def accumulate(self, instances_one_image: Instances):
        """
        Accumulate instances data for one image

        Args:
            instances_one_image (Instances): instances data to accumulate; its
                `proposal_boxes` and `gt_boxes` are read in XYXY_ABS format and
                converted to XYWH_ABS. An optional `gt_densepose` field holds
                one entry per box (an entry may be None).
        """
        boxes_xywh_est = BoxMode.convert(
            instances_one_image.proposal_boxes.tensor.clone(), BoxMode.XYXY_ABS, BoxMode.XYWH_ABS
        )
        boxes_xywh_gt = BoxMode.convert(
            instances_one_image.gt_boxes.tensor.clone(), BoxMode.XYXY_ABS, BoxMode.XYWH_ABS
        )
        n_matches = len(boxes_xywh_gt)
        assert n_matches == len(
            boxes_xywh_est
        ), f"Got {len(boxes_xywh_est)} proposal boxes and {len(boxes_xywh_gt)} GT boxes"
        if not n_matches:
            # no detection - GT matches
            return
        if (
            not hasattr(instances_one_image, "gt_densepose")
            or instances_one_image.gt_densepose is None
        ):
            # no densepose GT for the detections, just increase the bbox index
            self.nxt_bbox_index += n_matches
            return
        for box_xywh_est, box_xywh_gt, dp_gt in zip(
            boxes_xywh_est, boxes_xywh_gt, instances_one_image.gt_densepose
        ):
            if (dp_gt is not None) and (len(dp_gt.x) > 0):
                # pyre-fixme[6]: For 1st argument expected `Tensor` but got `float`.
                # pyre-fixme[6]: For 2nd argument expected `Tensor` but got `float`.
                self._do_accumulate(box_xywh_gt, box_xywh_est, dp_gt)
            # Every matched box consumes a global index, annotated or not.
            self.nxt_bbox_index += 1

    def _do_accumulate(self, box_xywh_gt: torch.Tensor, box_xywh_est: torch.Tensor, dp_gt: Any):
        """
        Accumulate data for one matched box, given that its DensePose data is not empty

        Args:
            box_xywh_gt (tensor): GT bounding box (XYWH_ABS)
            box_xywh_est (tensor): estimated bounding box (XYWH_ABS)
            dp_gt: GT densepose data with the following attributes:
             - x: normalized X coordinates
             - y: normalized Y coordinates
             - vertex_ids: tensor of GT vertex ids (presumably one per point)
             - mesh_id: scalar mesh identifier, broadcast to every vertex id
             - segm (optional): tensor of size [S, S] with coarse segmentation;
               treated as absent when the attribute is missing or None
        """
        self.x_gt.append(dp_gt.x)
        self.y_gt.append(dp_gt.y)
        # A missing or None mask just skips this box; `pack` then drops coarse
        # segmentation for the whole batch instead of crashing on None.unsqueeze.
        segm = getattr(dp_gt, "segm", None)
        if segm is not None:
            self.s_gt.append(segm.unsqueeze(0))
        self.vertex_ids_gt.append(dp_gt.vertex_ids)
        self.vertex_mesh_ids_gt.append(torch.full_like(dp_gt.vertex_ids, dp_gt.mesh_id))
        self.bbox_xywh_gt.append(box_xywh_gt.view(-1, 4))
        self.bbox_xywh_est.append(box_xywh_est.view(-1, 4))
        # Per-point box indices share the shape of `vertex_ids`.
        self.point_bbox_with_dp_indices.append(
            torch.full_like(dp_gt.vertex_ids, self.nxt_bbox_with_dp_index)
        )
        self.point_bbox_indices.append(torch.full_like(dp_gt.vertex_ids, self.nxt_bbox_index))
        self.bbox_indices.append(self.nxt_bbox_index)
        self.nxt_bbox_with_dp_index += 1

    def pack(self) -> Optional[PackedCseAnnotations]:
        """
        Pack data into tensors

        Returns:
            PackedCseAnnotations whose fields are the accumulated per-box data
            concatenated along dim 0, or None if nothing was accumulated.
            `coarse_segm_gt` is None unless every accumulated box provided a
            segmentation mask.
        """
        if not self.x_gt:
            # TODO:
            # returning proper empty annotations would require
            # creating empty tensors of appropriate shape and
            # type on an appropriate device;
            # we return None so far to indicate empty annotations
            return None
        return PackedCseAnnotations(
            x_gt=torch.cat(self.x_gt, 0),
            y_gt=torch.cat(self.y_gt, 0),
            vertex_mesh_ids_gt=torch.cat(self.vertex_mesh_ids_gt, 0),
            vertex_ids_gt=torch.cat(self.vertex_ids_gt, 0),
            # ignore segmentation annotations, if not all the instances contain those
            coarse_segm_gt=torch.cat(self.s_gt, 0)
            if len(self.s_gt) == len(self.bbox_xywh_gt)
            else None,
            bbox_xywh_gt=torch.cat(self.bbox_xywh_gt, 0),
            bbox_xywh_est=torch.cat(self.bbox_xywh_est, 0),
            point_bbox_with_dp_indices=torch.cat(self.point_bbox_with_dp_indices, 0),
            point_bbox_indices=torch.cat(self.point_bbox_indices, 0),
            # Place the box-index tensor on the same device as the GT points.
            bbox_indices=torch.as_tensor(
                self.bbox_indices, dtype=torch.long, device=self.x_gt[0].device
            ),
        )
|