|
|
|
|
|
import random |
|
from typing import Optional, Tuple |
|
import torch |
|
|
|
from densepose.converters import ToChartResultConverterWithConfidences |
|
|
|
from .densepose_base import DensePoseBaseSampler |
|
|
|
|
|
class DensePoseConfidenceBasedSampler(DensePoseBaseSampler):
    """
    Samples DensePose data from DensePose predictions.

    Samples for each class are drawn using confidence value estimates: the
    estimates of a class are ranked by the selected confidence channel, and
    samples are drawn uniformly at random from the top-ranked subset.
    """

    def __init__(
        self,
        confidence_channel: str,
        count_per_class: int = 8,
        search_count_multiplier: Optional[float] = None,
        search_proportion: Optional[float] = None,
    ):
        """
        Constructor

        Args:
            confidence_channel (str): confidence channel to use for sampling
                (required); it is read as an attribute of the chart result
                produced by `ToChartResultConverterWithConfidences`.
                Possible values:
                  "sigma_2": confidences for UV values
                  "fine_segm_confidence": confidences for fine segmentation
                  "coarse_segm_confidence": confidences for coarse segmentation
            count_per_class (int): the sampler produces at most `count_per_class`
                samples for each category (default: 8)
            search_count_multiplier (float or None): if not None, the total number
                of the most confident estimates of a given class to consider is
                defined as
                `min(max(search_count_multiplier * count_per_class, count_per_class), N)`,
                where `N` is the total number of estimates of the class; cannot be
                specified together with `search_proportion` (default: None)
            search_proportion (float or None): if not None, the total number of
                the most confident estimates of a given class to consider is
                defined as `min(max(search_proportion * N, count_per_class), N)`,
                where `N` is the total number of estimates of the class; cannot be
                specified together with `search_count_multiplier` (default: None)
        """
        super().__init__(count_per_class)
        self.confidence_channel = confidence_channel
        self.search_count_multiplier = search_count_multiplier
        self.search_proportion = search_proportion
        # The two search-size strategies are mutually exclusive.
        assert (search_count_multiplier is None) or (search_proportion is None), (
            f"Cannot specify both search_count_multiplier (={search_count_multiplier}) "
            f"and search_proportion (={search_proportion})"
        )

    def _search_count(self, count: int, k: int) -> int:
        """
        Compute the size of the pool of top-ranked estimates to sample from.

        Args:
            count (int): number of samples that will be drawn from the pool
            k (int): total number of estimates available

        Return:
            int: pool size, always within [count, k] when count <= k
        """
        if self.search_count_multiplier is not None:
            # Clamp from below by `count`: a multiplier < 1 would otherwise make
            # the pool smaller than the sample and `random.sample` would raise.
            search_count = max(int(count * self.search_count_multiplier), count)
        elif self.search_proportion is not None:
            search_count = max(int(k * self.search_proportion), count)
        else:
            search_count = count
        return min(search_count, k)

    def _produce_index_sample(self, values: torch.Tensor, count: int):
        """
        Produce a sample of indices to select data based on confidences

        Args:
            values (torch.Tensor): an array of size [n, k] that contains
                estimated values (U, V, confidences);
                n: number of channels (U, V, confidences)
                k: number of points labeled with part_id
            count (int): number of samples to produce, should be positive and <= k

        Return:
            list(int) when `count == k` (all indices, in order), otherwise a
            1D torch.Tensor of `count` indices of values (along axis 1)
            selected as a sample
        """
        k = values.shape[1]
        if k == count:
            # Every estimate is needed; no ranking required.
            return list(range(k))
        search_count = self._search_count(count, k)
        # Rank estimates by ascending value of the confidence channel (row 2),
        # then draw `count` of the first `search_count` uniformly at random.
        # NOTE(review): ascending order treats smaller values as "more
        # confident", which fits "sigma_2" (an uncertainty); confirm this is
        # the intended ordering for the segmentation confidence channels.
        _, sorted_confidence_indices = torch.sort(values[2])
        sample_from_top = random.sample(range(search_count), count)
        return sorted_confidence_indices[:search_count][sample_from_top]

    def _produce_labels_and_results(self, instance) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Method to get labels and DensePose results from an instance, with confidences

        Args:
            instance (Instances): an instance with `pred_densepose` (a
                `DensePoseChartPredictorOutputWithConfidences`) and `pred_boxes`

        Return:
            labels (torch.Tensor): shape [H, W], DensePose segmentation labels
            dp_result (torch.Tensor): shape [3, H, W], DensePose results u and v
                stacked with the confidence channel
        """
        converter = ToChartResultConverterWithConfidences
        chart_result = converter.convert(instance.pred_densepose, instance.pred_boxes)
        labels, dp_result = chart_result.labels.cpu(), chart_result.uv.cpu()
        # Append the selected confidence map as an extra leading channel so that
        # `_produce_index_sample` can rank estimates by `values[2]`.
        dp_result = torch.cat(
            (dp_result, getattr(chart_result, self.confidence_channel)[None].cpu())
        )
        return labels, dp_result
|
|