Spaces:
Paused
Paused
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | |
import pickle | |
from functools import lru_cache | |
from typing import Dict, Optional, Tuple | |
import torch | |
from detectron2.utils.file_io import PathManager | |
from densepose.data.meshes.catalog import MeshCatalog, MeshInfo | |
def _maybe_copy_to_device( | |
attribute: Optional[torch.Tensor], device: torch.device | |
) -> Optional[torch.Tensor]: | |
if attribute is None: | |
return None | |
return attribute.to(device) | |
class Mesh: | |
def __init__( | |
self, | |
vertices: Optional[torch.Tensor] = None, | |
faces: Optional[torch.Tensor] = None, | |
geodists: Optional[torch.Tensor] = None, | |
symmetry: Optional[Dict[str, torch.Tensor]] = None, | |
texcoords: Optional[torch.Tensor] = None, | |
mesh_info: Optional[MeshInfo] = None, | |
device: Optional[torch.device] = None, | |
): | |
""" | |
Args: | |
vertices (tensor [N, 3] of float32): vertex coordinates in 3D | |
faces (tensor [M, 3] of long): triangular face represented as 3 | |
vertex indices | |
geodists (tensor [N, N] of float32): geodesic distances from | |
vertex `i` to vertex `j` (optional, default: None) | |
symmetry (dict: str -> tensor): various mesh symmetry data: | |
- "vertex_transforms": vertex mapping under horizontal flip, | |
tensor of size [N] of type long; vertex `i` is mapped to | |
vertex `tensor[i]` (optional, default: None) | |
texcoords (tensor [N, 2] of float32): texture coordinates, i.e. global | |
and normalized mesh UVs (optional, default: None) | |
mesh_info (MeshInfo type): necessary to load the attributes on-the-go, | |
can be used instead of passing all the variables one by one | |
device (torch.device): device of the Mesh. If not provided, will use | |
the device of the vertices | |
""" | |
self._vertices = vertices | |
self._faces = faces | |
self._geodists = geodists | |
self._symmetry = symmetry | |
self._texcoords = texcoords | |
self.mesh_info = mesh_info | |
self.device = device | |
assert self._vertices is not None or self.mesh_info is not None | |
all_fields = [self._vertices, self._faces, self._geodists, self._texcoords] | |
if self.device is None: | |
for field in all_fields: | |
if field is not None: | |
self.device = field.device | |
break | |
if self.device is None and symmetry is not None: | |
for key in symmetry: | |
self.device = symmetry[key].device | |
break | |
self.device = torch.device("cpu") if self.device is None else self.device | |
assert all([var.device == self.device for var in all_fields if var is not None]) | |
if symmetry: | |
assert all(symmetry[key].device == self.device for key in symmetry) | |
if texcoords and vertices: | |
assert len(vertices) == len(texcoords) | |
def to(self, device: torch.device): | |
device_symmetry = self._symmetry | |
if device_symmetry: | |
device_symmetry = {key: value.to(device) for key, value in device_symmetry.items()} | |
return Mesh( | |
_maybe_copy_to_device(self._vertices, device), | |
_maybe_copy_to_device(self._faces, device), | |
_maybe_copy_to_device(self._geodists, device), | |
device_symmetry, | |
_maybe_copy_to_device(self._texcoords, device), | |
self.mesh_info, | |
device, | |
) | |
def vertices(self): | |
if self._vertices is None and self.mesh_info is not None: | |
self._vertices = load_mesh_data(self.mesh_info.data, "vertices", self.device) | |
return self._vertices | |
def faces(self): | |
if self._faces is None and self.mesh_info is not None: | |
self._faces = load_mesh_data(self.mesh_info.data, "faces", self.device) | |
return self._faces | |
def geodists(self): | |
if self._geodists is None and self.mesh_info is not None: | |
self._geodists = load_mesh_auxiliary_data(self.mesh_info.geodists, self.device) | |
return self._geodists | |
def symmetry(self): | |
if self._symmetry is None and self.mesh_info is not None: | |
self._symmetry = load_mesh_symmetry(self.mesh_info.symmetry, self.device) | |
return self._symmetry | |
def texcoords(self): | |
if self._texcoords is None and self.mesh_info is not None: | |
self._texcoords = load_mesh_auxiliary_data(self.mesh_info.texcoords, self.device) | |
return self._texcoords | |
def get_geodists(self): | |
if self.geodists is None: | |
self.geodists = self._compute_geodists() | |
return self.geodists | |
def _compute_geodists(self): | |
# TODO: compute using Laplace-Beltrami | |
geodists = None | |
return geodists | |
def load_mesh_data(
    mesh_fpath: str,
    field: str,
    device: Optional[torch.device] = None,
    dtype: torch.dtype = torch.float,
) -> torch.Tensor:
    """
    Load one field of a pickled mesh file as a tensor.

    Args:
        mesh_fpath (str): path, resolved through ``PathManager``, to a pickle
            file whose loaded object is indexed by ``field``
        field (str): key of the entry to load, e.g. "vertices" or "faces"
        device (torch.device, optional): device to place the tensor on
        dtype (torch.dtype): dtype of the returned tensor; defaults to
            ``torch.float`` (the historical behavior). Pass ``torch.long`` to
            load index data such as faces as integers.
    Returns:
        Tensor built from the loaded ``field`` entry.
    """
    # NOTE(security): pickle.load can execute arbitrary code; only load mesh
    # files that come from a trusted source.
    with PathManager.open(mesh_fpath, "rb") as hFile:
        mesh_data = pickle.load(hFile)
    return torch.as_tensor(mesh_data[field], dtype=dtype).to(device)
def load_mesh_auxiliary_data(
    fpath: str, device: Optional[torch.device] = None
) -> Optional[torch.Tensor]:
    """
    Load a pickled auxiliary mesh array (e.g. geodesic distances or texture
    coordinates) as a float tensor on ``device``.

    The path is first materialized locally via ``PathManager.get_local_path``.
    """
    local_path = PathManager.get_local_path(fpath)
    # NOTE(security): pickle.load must only be used on trusted files.
    with PathManager.open(local_path, "rb") as stream:
        payload = pickle.load(stream)
    return torch.as_tensor(payload, dtype=torch.float).to(device)
def load_mesh_symmetry(
    symmetry_fpath: str, device: Optional[torch.device] = None
) -> Optional[Dict[str, torch.Tensor]]:
    """
    Load mesh symmetry data from a pickle file.

    Only the "vertex_transforms" entry is read; it is returned as a long
    tensor on ``device`` under the same key.
    """
    # NOTE(security): pickle.load must only be used on trusted files.
    with PathManager.open(symmetry_fpath, "rb") as stream:
        raw = pickle.load(stream)
    vertex_transforms = torch.as_tensor(raw["vertex_transforms"], dtype=torch.long)
    return {"vertex_transforms": vertex_transforms.to(device)}
def create_mesh(mesh_name: str, device: Optional[torch.device] = None) -> Mesh:
    """
    Build a Mesh for the catalog entry ``mesh_name``. Its attributes are
    loaded on demand from the catalog's ``MeshInfo`` onto ``device``.
    """
    info = MeshCatalog[mesh_name]
    return Mesh(device=device, mesh_info=info)