|
|
|
|
|
from typing import Any, Dict, List |
|
import torch |
|
from torch import nn |
|
from torch.nn import functional as F |
|
|
|
from detectron2.config import CfgNode |
|
from detectron2.structures import Instances |
|
|
|
from densepose.data.meshes.catalog import MeshCatalog |
|
from densepose.modeling.cse.utils import normalize_embeddings, squared_euclidean_distance_matrix |
|
|
|
from .embed_utils import PackedCseAnnotations |
|
from .utils import BilinearInterpolationHelper |
|
|
|
|
|
class EmbeddingLoss:
    """
    Computes losses for estimated embeddings given annotated vertices.
    Instances in a minibatch that correspond to the same mesh are grouped
    together. For each group, loss is computed as cross-entropy for
    unnormalized scores given ground truth mesh vertex ids.
    Scores are based on squared distances between estimated vertex embeddings
    and mesh vertex embeddings.
    """

    def __init__(self, cfg: CfgNode):
        """
        Initialize embedding loss from config

        Args:
            cfg (CfgNode): configuration; the value of
                MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDING_DIST_GAUSS_SIGMA is stored
                and later used to scale squared distances into scores
        """
        self.embdist_gauss_sigma = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDING_DIST_GAUSS_SIGMA

    def __call__(
        self,
        proposals_with_gt: List[Instances],
        densepose_predictor_outputs: Any,
        packed_annotations: PackedCseAnnotations,
        interpolator: BilinearInterpolationHelper,
        embedder: nn.Module,
    ) -> Dict[str, torch.Tensor]:
        """
        Produces losses for estimated embeddings given annotated vertices.
        Embeddings for all the vertices of a mesh are computed by the embedder.
        Embeddings for observed pixels are estimated by a predictor.
        Losses are computed as cross-entropy for squared distances between
        observed vertex embeddings and all mesh vertex embeddings given
        ground truth vertex IDs.

        Args:
            proposals_with_gt (list of Instances): detections with associated
                ground truth data; each item corresponds to instances detected
                on 1 image; the number of items corresponds to the number of
                images in a batch (not read by this method)
            densepose_predictor_outputs: an object of a dataclass that contains predictor
                outputs with estimated values; assumed to have the following attributes:
                * embedding - embedding estimates, tensor of shape [N, D, S, S], where
                  N = number of instances (= sum N_i, where N_i is the number of
                      instances on image i)
                  D = embedding space dimensionality (MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE)
                  S = output size (width and height)
            packed_annotations (PackedCseAnnotations): contains various data useful
                for loss computation, each data is packed into a single tensor;
                `vertex_mesh_ids_gt` and `vertex_ids_gt` are read here
            interpolator (BilinearInterpolationHelper): bilinear interpolation helper
            embedder (nn.Module): module that computes vertex embeddings for different meshes;
                called with a mesh name and expected to expose `mesh_names`
        Return:
            dict(str -> tensor): losses keyed by mesh name; every name in
                `embedder.mesh_names` that received no loss from the batch
                is filled with the zero-valued placeholder from `fake_value`
        """
        losses = {}
        for mesh_id_tensor in packed_annotations.vertex_mesh_ids_gt.unique():
            mesh_id = mesh_id_tensor.item()
            mesh_name = MeshCatalog.get_mesh_name(mesh_id)

            # keep only the points flagged valid by the interpolator
            # that are also annotated on the current mesh
            j_valid = interpolator.j_valid * (
                packed_annotations.vertex_mesh_ids_gt == mesh_id
            )
            if not torch.any(j_valid):
                continue

            # sample estimated embeddings at annotated points, keep the valid
            # ones for this mesh and normalize them
            vertex_embeddings_i = normalize_embeddings(
                interpolator.extract_at_points(
                    densepose_predictor_outputs.embedding,
                    slice_fine_segm=slice(None),
                    w_ylo_xlo=interpolator.w_ylo_xlo[:, None],
                    w_ylo_xhi=interpolator.w_ylo_xhi[:, None],
                    w_yhi_xlo=interpolator.w_yhi_xlo[:, None],
                    w_yhi_xhi=interpolator.w_yhi_xhi[:, None],
                )[j_valid, :]
            )

            # ground truth vertex ids for the same valid points
            vertex_indices_i = packed_annotations.vertex_ids_gt[j_valid]

            # embeddings of the vertices of the current mesh
            mesh_vertex_embeddings = embedder(mesh_name)

            # unnormalized scores: negated squared distances scaled by sigma,
            # so that closer mesh vertices get higher scores
            scores = squared_euclidean_distance_matrix(
                vertex_embeddings_i, mesh_vertex_embeddings
            ) / (-self.embdist_gauss_sigma)
            # targets equal to -1 are excluded from the loss (ignore_index)
            losses[mesh_name] = F.cross_entropy(scores, vertex_indices_i, ignore_index=-1)

        # make sure every mesh known to the embedder has an entry in the result
        for mesh_name in embedder.mesh_names:
            if mesh_name not in losses:
                losses[mesh_name] = self.fake_value(
                    densepose_predictor_outputs, embedder, mesh_name
                )
        return losses

    def fake_values(
        self, densepose_predictor_outputs: Any, embedder: nn.Module
    ) -> Dict[str, torch.Tensor]:
        """
        Produces zero-valued placeholder losses for all meshes of the embedder.

        Args:
            densepose_predictor_outputs: predictor outputs, must have
                an `embedding` tensor attribute
            embedder (nn.Module): module that computes vertex embeddings for
                different meshes; must expose `mesh_names`
        Return:
            dict(str -> tensor): one `fake_value` result per name
                in `embedder.mesh_names`
        """
        return {
            mesh_name: self.fake_value(densepose_predictor_outputs, embedder, mesh_name)
            for mesh_name in embedder.mesh_names
        }

    def fake_value(
        self, densepose_predictor_outputs: Any, embedder: nn.Module, mesh_name: str
    ) -> torch.Tensor:
        """
        Produces a zero-valued placeholder loss for a single mesh.

        The value is built from both the predicted embeddings and the mesh
        vertex embeddings multiplied by zero, so the result is connected to
        both tensors while contributing nothing to the total loss.
        NOTE(review): presumably this keeps all parameters in the autograd
        graph when a mesh has no annotations in the batch - confirm.

        Args:
            densepose_predictor_outputs: predictor outputs, must have
                an `embedding` tensor attribute
            embedder (nn.Module): module that computes vertex embeddings for
                different meshes
            mesh_name (str): name of the mesh passed to the embedder
        Return:
            tensor: scalar equal to zero for finite inputs
        """
        return densepose_predictor_outputs.embedding.sum() * 0 + embedder(mesh_name).sum() * 0
|
|