Spaces:
Paused
Paused
# Copyright (c) Facebook, Inc. and its affiliates. | |
from typing import Any, Dict, List, Tuple | |
import torch | |
from torch.nn import functional as F | |
from detectron2.config import CfgNode | |
from detectron2.structures import Instances | |
from densepose.converters.base import IntTupleBox | |
from densepose.data.utils import get_class_to_mesh_name_mapping | |
from densepose.modeling.cse.utils import squared_euclidean_distance_matrix | |
from densepose.structures import DensePoseDataRelative | |
from .densepose_base import DensePoseBaseSampler | |
class DensePoseCSEBaseSampler(DensePoseBaseSampler):
    """
    Base DensePose sampler to produce DensePose data from DensePose predictions.
    Samples for each class are drawn according to some distribution over all pixels estimated
    to belong to that class.

    The sampling distribution itself is not defined here: `_sample` delegates the
    choice of pixels to `self._produce_index_sample`, which is expected to be
    provided by the base class or a subclass.
    """

    def __init__(
        self,
        cfg: CfgNode,
        use_gt_categories: bool,
        embedder: torch.nn.Module,
        count_per_class: int = 8,
    ):
        """
        Constructor

        Args:
          cfg (CfgNode): the config of the model
          use_gt_categories (bool): if True, the category of an instance is read
              from `instance.dataset_classes`, otherwise from `instance.pred_classes`
          embedder (torch.nn.Module): necessary to compute mesh vertex embeddings
          count_per_class (int): the sampler produces at most `count_per_class`
              samples for each category
        """
        super().__init__(count_per_class)
        self.embedder = embedder
        self.class_to_mesh_name = get_class_to_mesh_name_mapping(cfg)
        self.use_gt_categories = use_gt_categories

    def _sample(self, instance: Instances, bbox_xywh: IntTupleBox) -> Dict[str, List[Any]]:
        """
        Sample DensePoseDataRelative from estimation results

        Args:
          instance (Instances): a single detection; must provide `pred_densepose`
              and either `dataset_classes` or `pred_classes` (depending on
              `use_gt_categories`)
          bbox_xywh (IntTupleBox): the corresponding bounding box
        Return:
          Dictionary with the X, Y and vertex ID lists (all of the same length,
          at most `count_per_class` entries) and the mesh name. The lists are
          empty if no pixel was estimated to belong to the foreground.
        """
        # Only the first entry is used: the instance is treated as one detection.
        if self.use_gt_categories:
            instance_class = instance.dataset_classes.tolist()[0]
        else:
            instance_class = instance.pred_classes.tolist()[0]
        mesh_name = self.class_to_mesh_name[instance_class]

        annotation = {
            DensePoseDataRelative.X_KEY: [],
            DensePoseDataRelative.Y_KEY: [],
            DensePoseDataRelative.VERTEX_IDS_KEY: [],
            DensePoseDataRelative.MESH_NAME_KEY: mesh_name,
        }

        mask, embeddings, other_values = self._produce_mask_and_results(instance, bbox_xywh)
        # (row indices, column indices) of all foreground pixels
        indices = torch.nonzero(mask, as_tuple=True)
        # [D, H, W] -> [H, W, D], then gather one embedding per foreground pixel: [K, D]
        selected_embeddings = embeddings.permute(1, 2, 0)[indices].cpu()
        values = other_values[:, indices[0], indices[1]]
        k = values.shape[1]

        count = min(self.count_per_class, k)
        if count <= 0:
            return annotation

        index_sample = self._produce_index_sample(values, count)
        mesh_vertex_embeddings = self.embedder(mesh_name)
        # The pixel embeddings were moved to CPU above, while the embedder output
        # lives on whatever device the embedder module is on. Align the (few)
        # sampled embeddings with it so that the distance computation does not
        # fail on a device mismatch; this is a no-op when both are on CPU.
        sampled_embeddings = selected_embeddings[index_sample].to(mesh_vertex_embeddings.device)
        distances = squared_euclidean_distance_matrix(sampled_embeddings, mesh_vertex_embeddings)
        # for every sampled pixel, the index of the nearest mesh vertex
        closest_vertices = torch.argmin(distances, dim=1)

        # + 0.5 shifts integer pixel indices to pixel centers
        sampled_y = indices[0][index_sample] + 0.5
        sampled_x = indices[1][index_sample] + 0.5
        # prepare / normalize data: rescale box-relative coordinates to [0, 256]
        _, _, w, h = bbox_xywh
        x = (sampled_x / w * 256.0).cpu().tolist()
        y = (sampled_y / h * 256.0).cpu().tolist()
        # extend annotations
        annotation[DensePoseDataRelative.X_KEY].extend(x)
        annotation[DensePoseDataRelative.Y_KEY].extend(y)
        annotation[DensePoseDataRelative.VERTEX_IDS_KEY].extend(closest_vertices.cpu().tolist())
        return annotation

    def _produce_mask_and_results(
        self, instance: Instances, bbox_xywh: IntTupleBox
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Method to get labels and DensePose results from an instance

        Args:
          instance (Instances): an instance whose `pred_densepose` attribute
              provides `coarse_segm` and `embedding` tensors
          bbox_xywh (IntTupleBox): the corresponding bounding box
        Return:
          mask (torch.Tensor): shape [H, W], DensePose segmentation mask
          embeddings (torch.Tensor): a tensor of shape [D, H, W],
              DensePose CSE Embeddings
          other_values (torch.Tensor): a tensor of shape [0, H, W],
              for potential other values
        """
        densepose_output = instance.pred_densepose
        S = densepose_output.coarse_segm
        E = densepose_output.embedding
        _, _, w, h = bbox_xywh
        # `align_corners=False` is what bilinear interpolation uses when the
        # argument is omitted; it is spelled out to match `_resample_mask`.
        # Only the first element along the batch dimension is used.
        embeddings = F.interpolate(E, size=(h, w), mode="bilinear", align_corners=False)[0]
        coarse_segm_resized = F.interpolate(
            S, size=(h, w), mode="bilinear", align_corners=False
        )[0]
        # foreground = pixels whose highest-scoring channel is not channel 0
        mask = coarse_segm_resized.argmax(0) > 0
        # empty placeholder: zero channels of additional per-pixel values
        other_values = torch.empty((0, h, w), device=E.device)
        return mask, embeddings, other_values

    def _resample_mask(self, output: Any) -> torch.Tensor:
        """
        Convert DensePose predictor output to segmentation annotation - tensors of size
        (256, 256) and type `int64`.

        Args:
          output: DensePose predictor output with the following attributes:
            - coarse_segm: tensor of size [N, D, H, W] with unnormalized coarse
              segmentation scores
        Return:
          Tensor of size (S, S) and type `int64` with coarse segmentation annotations,
          where S = DensePoseDataRelative.MASK_SIZE
        """
        sz = DensePoseDataRelative.MASK_SIZE
        mask = (
            F.interpolate(output.coarse_segm, (sz, sz), mode="bilinear", align_corners=False)
            .argmax(dim=1)
            .long()
            # drops the singleton batch dimension (N == 1) to obtain (S, S)
            .squeeze()
            .cpu()
        )
        return mask