StableVITON / densepose /structures /chart_confidence.py
rlawjdghek's picture
det2 (#6)
1527335 verified
raw
history blame
3.78 kB
# Copyright (c) Facebook, Inc. and its affiliates.
from dataclasses import make_dataclass
from functools import lru_cache
from typing import Any, Optional
import torch
@lru_cache(maxsize=None)
def decorate_predictor_output_class_with_confidences(BasePredictorOutput: type) -> type:
"""
Create a new output class from an existing one by adding new attributes
related to confidence estimation:
- sigma_1 (tensor)
- sigma_2 (tensor)
- kappa_u (tensor)
- kappa_v (tensor)
- fine_segm_confidence (tensor)
- coarse_segm_confidence (tensor)
Details on confidence estimation parameters can be found in:
N. Neverova, D. Novotny, A. Vedaldi "Correlated Uncertainty for Learning
Dense Correspondences from Noisy Labels", p. 918--926, in Proc. NIPS 2019
A. Sanakoyeu et al., Transferring Dense Pose to Proximal Animal Classes, CVPR 2020
The new class inherits the provided `BasePredictorOutput` class,
it's name is composed of the name of the provided class and
"WithConfidences" suffix.
Args:
BasePredictorOutput (type): output type to which confidence data
is to be added, assumed to be a dataclass
Return:
New dataclass derived from the provided one that has attributes
for confidence estimation
"""
PredictorOutput = make_dataclass(
BasePredictorOutput.__name__ + "WithConfidences",
fields=[
("sigma_1", Optional[torch.Tensor], None),
("sigma_2", Optional[torch.Tensor], None),
("kappa_u", Optional[torch.Tensor], None),
("kappa_v", Optional[torch.Tensor], None),
("fine_segm_confidence", Optional[torch.Tensor], None),
("coarse_segm_confidence", Optional[torch.Tensor], None),
],
bases=(BasePredictorOutput,),
)
# add possibility to index PredictorOutput
def slice_if_not_none(data, item):
if data is None:
return None
if isinstance(item, int):
return data[item].unsqueeze(0)
return data[item]
def PredictorOutput_getitem(self, item):
PredictorOutput = type(self)
base_predictor_output_sliced = super(PredictorOutput, self).__getitem__(item)
return PredictorOutput(
**base_predictor_output_sliced.__dict__,
coarse_segm_confidence=slice_if_not_none(self.coarse_segm_confidence, item),
fine_segm_confidence=slice_if_not_none(self.fine_segm_confidence, item),
sigma_1=slice_if_not_none(self.sigma_1, item),
sigma_2=slice_if_not_none(self.sigma_2, item),
kappa_u=slice_if_not_none(self.kappa_u, item),
kappa_v=slice_if_not_none(self.kappa_v, item),
)
PredictorOutput.__getitem__ = PredictorOutput_getitem
def PredictorOutput_to(self, device: torch.device):
"""
Transfers all tensors to the given device
"""
PredictorOutput = type(self)
base_predictor_output_to = super(PredictorOutput, self).to(device) # pyre-ignore[16]
def to_device_if_tensor(var: Any):
if isinstance(var, torch.Tensor):
return var.to(device)
return var
return PredictorOutput(
**base_predictor_output_to.__dict__,
sigma_1=to_device_if_tensor(self.sigma_1),
sigma_2=to_device_if_tensor(self.sigma_2),
kappa_u=to_device_if_tensor(self.kappa_u),
kappa_v=to_device_if_tensor(self.kappa_v),
fine_segm_confidence=to_device_if_tensor(self.fine_segm_confidence),
coarse_segm_confidence=to_device_if_tensor(self.coarse_segm_confidence),
)
PredictorOutput.to = PredictorOutput_to
return PredictorOutput