|
|
|
|
|
from dataclasses import dataclass |
|
from typing import Any, Optional |
|
import torch |
|
|
|
from detectron2.structures import BoxMode, Instances |
|
|
|
from .utils import AnnotationsAccumulator |
|
|
|
|
|
@dataclass
class PackedCseAnnotations:
    """
    Annotations of several instances packed into single tensors.

    The fields below are described as they are filled in by
    `CseAnnotationsAccumulator.pack`.

    Attributes:
        x_gt: per-instance `x` point data, concatenated along dim 0
        y_gt: per-instance `y` point data, concatenated along dim 0
        coarse_segm_gt: per-instance `segm` tensors stacked along a new
            leading dim; None if some accumulated instance had no `segm`
        vertex_mesh_ids_gt: for every point, the `mesh_id` of the instance
            the point belongs to
        vertex_ids_gt: per-instance `vertex_ids`, concatenated along dim 0
        bbox_xywh_gt: GT boxes, one row of 4 values per accumulated instance
        bbox_xywh_est: estimated boxes, one row of 4 values per accumulated
            instance
        point_bbox_with_dp_indices: for every point, the index of its box
            counted only among the boxes that were accumulated
        point_bbox_indices: for every point, the index of its box counted
            among all the boxes seen by the accumulator
        bbox_indices: for every accumulated instance, the index of its box
            counted among all the boxes seen by the accumulator
    """

    # NOTE: field order is part of the interface (positional construction).
    x_gt: torch.Tensor
    y_gt: torch.Tensor
    coarse_segm_gt: Optional[torch.Tensor]
    vertex_mesh_ids_gt: torch.Tensor
    vertex_ids_gt: torch.Tensor
    bbox_xywh_gt: torch.Tensor
    bbox_xywh_est: torch.Tensor
    point_bbox_with_dp_indices: torch.Tensor
    point_bbox_indices: torch.Tensor
    bbox_indices: torch.Tensor
|
|
|
|
|
class CseAnnotationsAccumulator(AnnotationsAccumulator):
    """
    Accumulates annotations by batches that correspond to objects detected on
    individual images. Can pack them together into single tensors.

    Two running counters are maintained:
     - `nxt_bbox_index` advances for every matched box that is seen,
       whether or not it carries usable DensePose data;
     - `nxt_bbox_with_dp_index` advances only for boxes whose data was
       actually accumulated.
    """

    def __init__(self):
        # Per-instance point data (one list entry per accumulated instance)
        self.x_gt = []
        self.y_gt = []
        # Coarse segmentation; may stay shorter than the other lists, since
        # it is only filled for instances that have a `segm` attribute
        self.s_gt = []
        self.vertex_mesh_ids_gt = []
        self.vertex_ids_gt = []
        # Boxes of the accumulated instances, each stored as a [-1, 4] view
        self.bbox_xywh_gt = []
        self.bbox_xywh_est = []
        # Index bookkeeping (per point and per instance)
        self.point_bbox_with_dp_indices = []
        self.point_bbox_indices = []
        self.bbox_indices = []
        # Running counters, see the class docstring
        self.nxt_bbox_with_dp_index = 0
        self.nxt_bbox_index = 0

    def accumulate(self, instances_one_image: Instances):
        """
        Accumulate instances data for one image

        Proposal and GT boxes are converted from XYXY_ABS to XYWH_ABS (on
        cloned tensors) and must be equal in number. Nothing is recorded if
        there are no boxes. If the instances carry no `gt_densepose` data,
        only the global box counter is advanced.

        Args:
            instances_one_image (Instances): instances data to accumulate
        """
        est_xywh_all = BoxMode.convert(
            instances_one_image.proposal_boxes.tensor.clone(),
            BoxMode.XYXY_ABS,
            BoxMode.XYWH_ABS,
        )
        gt_xywh_all = BoxMode.convert(
            instances_one_image.gt_boxes.tensor.clone(),
            BoxMode.XYXY_ABS,
            BoxMode.XYWH_ABS,
        )
        num_gt = len(gt_xywh_all)
        num_est = len(est_xywh_all)
        assert num_gt == num_est, f"Got {num_est} proposal boxes and {num_gt} GT boxes"
        if num_gt == 0:
            # no boxes at all for this image, nothing to count or store
            return

        densepose_gt = getattr(instances_one_image, "gt_densepose", None)
        if densepose_gt is None:
            # no DensePose GT for these boxes: keep the global box numbering
            # in step, but store nothing
            self.nxt_bbox_index += num_gt
            return

        for est_xywh, gt_xywh, dp_gt in zip(est_xywh_all, gt_xywh_all, densepose_gt):
            has_points = dp_gt is not None and len(dp_gt.x) > 0
            if has_points:
                self._do_accumulate(gt_xywh, est_xywh, dp_gt)
            # every box advances the global counter, with or without data
            self.nxt_bbox_index += 1

    def _do_accumulate(self, box_xywh_gt: torch.Tensor, box_xywh_est: torch.Tensor, dp_gt: Any):
        """
        Accumulate data for a single instance, given that the data is not empty

        Args:
            box_xywh_gt (tensor): GT bounding box
            box_xywh_est (tensor): estimated bounding box
            dp_gt: GT densepose data with the following attributes:
             - x: normalized X coordinates
             - y: normalized Y coordinates
             - segm: tensor of size [S, S] with coarse segmentation
               (optional, stored only if the attribute is present)
             - vertex_ids: tensor with vertex IDs; also serves as the
               shape / dtype / device template for the per-point index tensors
             - mesh_id: value used to fill the per-point mesh ID tensor
        """
        self.x_gt.append(dp_gt.x)
        self.y_gt.append(dp_gt.y)
        if hasattr(dp_gt, "segm"):
            # add a leading dim so that segmentations can be concatenated
            self.s_gt.append(dp_gt.segm.unsqueeze(0))

        vertex_ids = dp_gt.vertex_ids
        self.vertex_ids_gt.append(vertex_ids)
        self.vertex_mesh_ids_gt.append(torch.full_like(vertex_ids, dp_gt.mesh_id))

        self.bbox_xywh_gt.append(box_xywh_gt.view(-1, 4))
        self.bbox_xywh_est.append(box_xywh_est.view(-1, 4))

        # per-point indices: position among accumulated boxes and among all boxes
        dp_index_per_point = torch.full_like(vertex_ids, self.nxt_bbox_with_dp_index)
        self.point_bbox_with_dp_indices.append(dp_index_per_point)
        global_index_per_point = torch.full_like(vertex_ids, self.nxt_bbox_index)
        self.point_bbox_indices.append(global_index_per_point)

        self.bbox_indices.append(self.nxt_bbox_index)
        self.nxt_bbox_with_dp_index += 1

    def pack(self) -> Optional[PackedCseAnnotations]:
        """
        Pack data into tensors

        Returns:
            PackedCseAnnotations with all the accumulated lists concatenated
            along dim 0, or None if nothing has been accumulated.
        """
        if not self.x_gt:
            # nothing accumulated: None signals empty annotations instead of
            # a set of empty tensors
            return None

        # segmentation is packed only if every accumulated instance had one
        all_have_segm = len(self.s_gt) == len(self.bbox_xywh_gt)
        # box indices are plain ints, place them next to the point data
        target_device = self.x_gt[0].device
        return PackedCseAnnotations(
            x_gt=torch.cat(self.x_gt, 0),
            y_gt=torch.cat(self.y_gt, 0),
            vertex_mesh_ids_gt=torch.cat(self.vertex_mesh_ids_gt, 0),
            vertex_ids_gt=torch.cat(self.vertex_ids_gt, 0),
            coarse_segm_gt=torch.cat(self.s_gt, 0) if all_have_segm else None,
            bbox_xywh_gt=torch.cat(self.bbox_xywh_gt, 0),
            bbox_xywh_est=torch.cat(self.bbox_xywh_est, 0),
            point_bbox_with_dp_indices=torch.cat(self.point_bbox_with_dp_indices, 0),
            point_bbox_indices=torch.cat(self.point_bbox_indices, 0),
            bbox_indices=torch.as_tensor(
                self.bbox_indices, dtype=torch.long, device=target_device
            ),
        )
|
|