|
|
|
|
|
import random
|
|
from typing import Tuple
|
|
import torch
|
|
from torch import nn
|
|
from torch.nn import functional as F
|
|
|
|
from detectron2.config import CfgNode
|
|
|
|
from densepose.structures.mesh import create_mesh
|
|
|
|
from .utils import sample_random_indices
|
|
|
|
|
|
class ShapeToShapeCycleLoss(nn.Module):
    """
    Cycle Loss for Shapes.
    Inspired by:
    "Mapping in a Cycle: Sinkhorn Regularized Unsupervised Learning for Point Cloud Shapes".

    For a pair of meshes, vertex embeddings are used to build soft
    correspondences mesh 1 -> mesh 2 and mesh 2 -> mesh 1 (row-wise softmax
    over the embedding similarity matrix). Composing both gives a round-trip
    matrix on each mesh; its entries are weighted by the geodesic distance
    between the start and the end vertex and reduced with a p-norm.
    """

    def __init__(self, cfg: "CfgNode"):
        """
        Args:
            cfg (CfgNode): config; the `MODEL.ROI_DENSEPOSE_HEAD.CSE` node
                provides the embedder (mesh) names and the
                `SHAPE_TO_SHAPE_CYCLE_LOSS` settings (NORM_P, TEMPERATURE,
                MAX_NUM_VERTICES)
        """
        super().__init__()
        cse_cfg = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE
        loss_cfg = cse_cfg.SHAPE_TO_SHAPE_CYCLE_LOSS
        self.shape_names = list(cse_cfg.EMBEDDERS.keys())
        # every unordered pair of distinct shapes, listed exactly once
        self.all_shape_pairs = [
            (x, y) for i, x in enumerate(self.shape_names) for y in self.shape_names[i + 1 :]
        ]
        random.shuffle(self.all_shape_pairs)
        # position of the next pair to hand out from `all_shape_pairs`
        self.cur_pos = 0
        self.norm_p = loss_cfg.NORM_P
        self.temperature = loss_cfg.TEMPERATURE
        self.max_num_vertices = loss_cfg.MAX_NUM_VERTICES

    def _sample_random_pair(self) -> Tuple[str, str]:
        """
        Produce a random pair of different mesh names

        Pairs are served from a shuffled list; once the list is exhausted it
        is reshuffled and served again from the start.

        Return:
            tuple(str, str): a pair of different mesh names
        """
        if self.cur_pos >= len(self.all_shape_pairs):
            random.shuffle(self.all_shape_pairs)
            self.cur_pos = 0
        shape_pair = self.all_shape_pairs[self.cur_pos]
        self.cur_pos += 1
        return shape_pair

    def forward(self, embedder: nn.Module):
        """
        Do a forward pass with a random (src, dst) pair of shapes
        Args:
            embedder (nn.Module): module that computes vertex embeddings for different meshes
        Return:
            Tensor containing the loss value for the sampled pair
        """
        src_mesh_name, dst_mesh_name = self._sample_random_pair()
        return self._forward_one_pair(embedder, src_mesh_name, dst_mesh_name)

    def fake_value(self, embedder: nn.Module):
        """
        Produce a zero-valued loss that still evaluates the embedder for
        every mesh listed in `embedder.mesh_names`.
        NOTE(review): presumably used so that all embedder parameters take
        part in the graph when no real loss is computed - confirm with callers.
        Args:
            embedder (nn.Module): module that computes vertex embeddings for different meshes
        Return:
            Scalar tensor equal to zero
        """
        losses = [embedder(mesh_name).sum() * 0 for mesh_name in embedder.mesh_names]
        return torch.mean(torch.stack(losses))

    @staticmethod
    def _select_vertices(
        embeddings: torch.Tensor, geodists: torch.Tensor, indices: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Restrict embeddings and geodesic distances to the given vertices.
        Args:
            embeddings (torch.Tensor of size [V, D]): embeddings of all mesh vertices
            geodists (torch.Tensor of size [V, V]): pairwise geodesic distances
            indices (torch.Tensor of size [N]): indices of the vertices to keep
        Return:
            embeddings (torch.Tensor of size [N, D]): rows `indices` of the input
            geodists (torch.Tensor of size [N, N]): entry (i, j) is the distance
                between vertices `indices[i]` and `indices[j]`
        """
        # Broadcasting a column of indices against a row of indices picks the
        # same sub-matrix as indexing with torch.meshgrid(indices, indices),
        # without the `indexing` argument warning and without materializing
        # two N x N index tensors.
        return embeddings[indices], geodists[indices[:, None], indices[None, :]]

    def _get_embeddings_and_geodists_for_mesh(
        self, embedder: nn.Module, mesh_name: str
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Produces embeddings and geodesic distance tensors for a given mesh. May subsample
        the mesh, if it contains too many vertices (controlled by
        SHAPE_TO_SHAPE_CYCLE_LOSS.MAX_NUM_VERTICES parameter).
        Args:
            embedder (nn.Module): module that computes embeddings for mesh vertices
            mesh_name (str): mesh name
        Return:
            embeddings (torch.Tensor of size [N, D]): embeddings for selected mesh
                vertices (N = number of selected vertices, D = embedding space dim)
            geodists (torch.Tensor of size [N, N]): geodesic distances for the selected
                mesh vertices (N = number of selected vertices)
        """
        embeddings = embedder(mesh_name)
        indices = sample_random_indices(
            embeddings.shape[0], self.max_num_vertices, embeddings.device
        )
        mesh = create_mesh(mesh_name, embeddings.device)
        geodists = mesh.geodists
        # `None` means that no subsampling is requested: keep all the vertices
        if indices is not None:
            embeddings, geodists = self._select_vertices(embeddings, geodists, indices)
        return embeddings, geodists

    def _forward_one_pair(
        self, embedder: nn.Module, mesh_name_1: str, mesh_name_2: str
    ) -> torch.Tensor:
        """
        Do a forward pass with a selected pair of meshes
        Args:
            embedder (nn.Module): module that computes vertex embeddings for different meshes
            mesh_name_1 (str): first mesh name
            mesh_name_2 (str): second mesh name
        Return:
            Tensor containing the loss value
        """
        embeddings_1, geodists_1 = self._get_embeddings_and_geodists_for_mesh(embedder, mesh_name_1)
        embeddings_2, geodists_2 = self._get_embeddings_and_geodists_for_mesh(embedder, mesh_name_2)
        # pairwise similarities between vertices of mesh 1 (rows) and mesh 2 (columns)
        sim_matrix_12 = embeddings_1.mm(embeddings_2.T)

        # soft assignments: each row sums to one
        c_12 = F.softmax(sim_matrix_12 / self.temperature, dim=1)
        c_21 = F.softmax(sim_matrix_12.T / self.temperature, dim=1)
        # round trips 1 -> 2 -> 1 and 2 -> 1 -> 2
        c_11 = c_12.mm(c_21)
        c_22 = c_21.mm(c_12)

        # round-trip weight that lands away from the start vertex is
        # penalized proportionally to the geodesic distance travelled
        loss_cycle_11 = torch.norm(geodists_1 * c_11, p=self.norm_p)
        loss_cycle_22 = torch.norm(geodists_2 * c_22, p=self.norm_p)

        return loss_cycle_11 + loss_cycle_22
|
|
|