Spaces:
Runtime error
Runtime error
File size: 6,369 Bytes
938e515 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import pickle
from functools import lru_cache
from typing import Dict, Optional, Tuple
import torch
from detectron2.utils.file_io import PathManager
from densepose.data.meshes.catalog import MeshCatalog, MeshInfo
def _maybe_copy_to_device(
attribute: Optional[torch.Tensor], device: torch.device
) -> Optional[torch.Tensor]:
if attribute is None:
return None
return attribute.to(device)
class Mesh:
    """
    Triangular mesh whose attributes are given explicitly or loaded on demand.

    Attributes that were not passed to the constructor are read from the files
    referenced by `mesh_info` the first time the corresponding property is
    accessed, and are then kept on the instance.
    """

    def __init__(
        self,
        vertices: Optional[torch.Tensor] = None,
        faces: Optional[torch.Tensor] = None,
        geodists: Optional[torch.Tensor] = None,
        symmetry: Optional[Dict[str, torch.Tensor]] = None,
        texcoords: Optional[torch.Tensor] = None,
        mesh_info: Optional["MeshInfo"] = None,
        device: Optional[torch.device] = None,
    ):
        """
        Args:
            vertices (tensor [N, 3] of float32): vertex coordinates in 3D
            faces (tensor [M, 3] of long): triangular face represented as 3
                vertex indices
            geodists (tensor [N, N] of float32): geodesic distances from
                vertex `i` to vertex `j` (optional, default: None)
            symmetry (dict: str -> tensor): various mesh symmetry data:
                - "vertex_transforms": vertex mapping under horizontal flip,
                  tensor of size [N] of type long; vertex `i` is mapped to
                  vertex `tensor[i]` (optional, default: None)
            texcoords (tensor [N, 2] of float32): texture coordinates, i.e. global
                and normalized mesh UVs (optional, default: None)
            mesh_info (MeshInfo type): necessary to load the attributes on-the-go,
                can be used instead of passing all the variables one by one
            device (torch.device): device of the Mesh. If not provided, will use
                the device of the first tensor that was passed (vertices, faces,
                geodists, texcoords, then symmetry entries) and fall back to CPU
        Raises:
            AssertionError: if neither `vertices` nor `mesh_info` is given, if the
                provided tensors do not all live on the mesh device, or if
                `vertices` and `texcoords` have different lengths
        """
        self._vertices = vertices
        self._faces = faces
        self._geodists = geodists
        self._symmetry = symmetry
        self._texcoords = texcoords
        self.mesh_info = mesh_info
        self.device = device

        assert self._vertices is not None or self.mesh_info is not None

        all_fields = [self._vertices, self._faces, self._geodists, self._texcoords]

        # Infer the device from the first tensor available, if not given explicitly.
        if self.device is None:
            self.device = next(
                (field.device for field in all_fields if field is not None), None
            )
        if self.device is None and symmetry is not None:
            self.device = next((value.device for value in symmetry.values()), None)
        if self.device is None:
            self.device = torch.device("cpu")

        # Every tensor that was passed in must already live on the mesh device.
        assert all(var.device == self.device for var in all_fields if var is not None)
        if symmetry:
            assert all(value.device == self.device for value in symmetry.values())
        # Compare against None explicitly: the truth value of a tensor with more
        # than one element is ambiguous and raises RuntimeError.
        if texcoords is not None and vertices is not None:
            assert len(vertices) == len(texcoords)

    def to(self, device: torch.device):
        """
        Create a new Mesh with all currently held tensors copied to `device`.

        Attributes that have not been loaded yet stay unloaded; the new mesh
        keeps `mesh_info` and loads them onto `device` when accessed.

        Args:
            device (torch.device): target device
        Returns:
            a new Mesh instance bound to `device`
        """
        device_symmetry = self._symmetry
        if device_symmetry:
            device_symmetry = {key: value.to(device) for key, value in device_symmetry.items()}
        return Mesh(
            _maybe_copy_to_device(self._vertices, device),
            _maybe_copy_to_device(self._faces, device),
            _maybe_copy_to_device(self._geodists, device),
            device_symmetry,
            _maybe_copy_to_device(self._texcoords, device),
            self.mesh_info,
            device,
        )

    @property
    def vertices(self):
        """Vertex coordinates; loaded from `mesh_info.data` on first access."""
        if self._vertices is None and self.mesh_info is not None:
            self._vertices = load_mesh_data(self.mesh_info.data, "vertices", self.device)
        return self._vertices

    @property
    def faces(self):
        """Face vertex indices; loaded from `mesh_info.data` on first access."""
        # NOTE(review): load_mesh_data converts every field to float32, whereas
        # the constructor docstring describes faces as long - confirm the
        # intended dtype before indexing with the loaded tensor.
        if self._faces is None and self.mesh_info is not None:
            self._faces = load_mesh_data(self.mesh_info.data, "faces", self.device)
        return self._faces

    @property
    def geodists(self):
        """Geodesic distances; loaded from `mesh_info.geodists` on first access."""
        if self._geodists is None and self.mesh_info is not None:
            self._geodists = load_mesh_auxiliary_data(self.mesh_info.geodists, self.device)
        return self._geodists

    @property
    def symmetry(self):
        """Symmetry data dict; loaded from `mesh_info.symmetry` on first access."""
        if self._symmetry is None and self.mesh_info is not None:
            self._symmetry = load_mesh_symmetry(self.mesh_info.symmetry, self.device)
        return self._symmetry

    @property
    def texcoords(self):
        """Texture coordinates; loaded from `mesh_info.texcoords` on first access."""
        if self._texcoords is None and self.mesh_info is not None:
            self._texcoords = load_mesh_auxiliary_data(self.mesh_info.texcoords, self.device)
        return self._texcoords

    def get_geodists(self):
        """
        Return the geodesic distances, computing them if they are unavailable.

        Returns:
            the geodesic distance tensor, or None while the computation is
            not implemented and no stored distances exist
        """
        if self.geodists is None:
            # Store into the backing field: `geodists` is a read-only property,
            # so assigning to it would raise AttributeError.
            self._geodists = self._compute_geodists()
        return self.geodists

    def _compute_geodists(self):
        """Compute geodesic distances from the mesh (not implemented, returns None)."""
        # TODO: compute using Laplace-Beltrami
        geodists = None
        return geodists
def load_mesh_data(
    mesh_fpath: str, field: str, device: Optional[torch.device] = None
) -> torch.Tensor:
    """
    Load one entry of a pickled mesh file as a float32 tensor.

    Args:
        mesh_fpath (str): path of the pickled mesh data, opened via PathManager
        field (str): key of the entry to read from the unpickled object
        device (torch.device): device passed to `Tensor.to` for the result
            (optional, default: None)
    Returns:
        tensor of float32 with the contents of the requested entry
    """
    # NOTE: pickle.load can execute arbitrary code while unpickling; only use
    # this with mesh files from a trusted source.
    with PathManager.open(mesh_fpath, "rb") as hFile:
        mesh_data = pickle.load(hFile)
        return torch.as_tensor(mesh_data[field], dtype=torch.float).to(device)
def load_mesh_auxiliary_data(
    fpath: str, device: Optional[torch.device] = None
) -> Optional[torch.Tensor]:
    """
    Load a pickled auxiliary mesh attribute as a float32 tensor.

    The path is first resolved with `PathManager.get_local_path`, and the
    resolved file is then unpickled as a whole.

    Args:
        fpath (str): path of the pickled data
        device (torch.device): device passed to `Tensor.to` for the result
            (optional, default: None)
    Returns:
        tensor of float32 built from the unpickled object
    """
    local_path = PathManager.get_local_path(fpath)
    with PathManager.open(local_path, "rb") as stream:
        payload = pickle.load(stream)
        as_float = torch.as_tensor(payload, dtype=torch.float)
        return as_float.to(device)
@lru_cache()
def load_mesh_symmetry(
    symmetry_fpath: str, device: Optional[torch.device] = None
) -> Optional[Dict[str, torch.Tensor]]:
    """
    Load pickled mesh symmetry data; results are memoized by `lru_cache`.

    Args:
        symmetry_fpath (str): path of the pickled symmetry data, opened via
            PathManager
        device (torch.device): device passed to `Tensor.to` for the tensors
            (optional, default: None)
    Returns:
        dict with a single key "vertex_transforms" holding the corresponding
        entry of the file as a tensor of type long
    """
    with PathManager.open(symmetry_fpath, "rb") as stream:
        raw = pickle.load(stream)
        flip_map = torch.as_tensor(raw["vertex_transforms"], dtype=torch.long)
        return {"vertex_transforms": flip_map.to(device)}
@lru_cache()
def create_mesh(mesh_name: str, device: Optional[torch.device] = None) -> Mesh:
    """
    Build a Mesh from its catalog entry; results are memoized by `lru_cache`.

    Args:
        mesh_name (str): key of the mesh in MeshCatalog
        device (torch.device): device of the created Mesh
            (optional, default: None)
    Returns:
        Mesh constructed from `MeshCatalog[mesh_name]` and `device`
    """
    info = MeshCatalog[mesh_name]
    return Mesh(mesh_info=info, device=device)
|