|
|
|
|
|
import logging |
|
import numpy as np |
|
import pickle |
|
from enum import Enum |
|
from typing import Optional |
|
import torch |
|
from torch import nn |
|
|
|
from detectron2.config import CfgNode |
|
from detectron2.utils.file_io import PathManager |
|
|
|
from .vertex_direct_embedder import VertexDirectEmbedder |
|
from .vertex_feature_embedder import VertexFeatureEmbedder |
|
|
|
|
|
class EmbedderType(Enum):
    """
    Supported strategies for mapping mesh vertices into the embedding space.

    The member value is the string expected in an embedder config's TYPE field:
     - VERTEX_DIRECT ("vertex_direct"): vertices are embedded directly
     - VERTEX_FEATURE ("vertex_feature"): vertex features are embedded
    """

    # direct vertex embedding
    VERTEX_DIRECT = "vertex_direct"
    # embedding of per-vertex features
    VERTEX_FEATURE = "vertex_feature"
|
|
|
|
|
def create_embedder(embedder_spec: CfgNode, embedder_dim: int) -> nn.Module:
    """
    Create an embedder based on the provided configuration

    Args:
        embedder_spec (CfgNode): embedder configuration; fields read here are
            TYPE, NUM_VERTICES, INIT_FILE and IS_TRAINABLE, plus FEATURE_DIM and
            FEATURES_TRAINABLE for the "vertex_feature" type
        embedder_dim (int): embedding space dimensionality
    Return:
        An embedder instance for the specified configuration
    Raises:
        ValueError, in case of unexpected embedder type
    """
    try:
        embedder_type = EmbedderType(embedder_spec.TYPE)
    except ValueError as e:
        # The Enum lookup itself rejects unknown values; re-raise with this
        # function's message so a bad config TYPE gets a descriptive error
        # (the `else` branch below cannot see such values).
        raise ValueError(f"Unexpected embedder type {embedder_spec.TYPE}") from e

    if embedder_type == EmbedderType.VERTEX_DIRECT:
        embedder = VertexDirectEmbedder(
            num_vertices=embedder_spec.NUM_VERTICES,
            embed_dim=embedder_dim,
        )
    elif embedder_type == EmbedderType.VERTEX_FEATURE:
        embedder = VertexFeatureEmbedder(
            num_vertices=embedder_spec.NUM_VERTICES,
            feature_dim=embedder_spec.FEATURE_DIM,
            embed_dim=embedder_dim,
            train_features=embedder_spec.FEATURES_TRAINABLE,
        )
    else:
        # Guards against EmbedderType members added without a matching branch
        raise ValueError(f"Unexpected embedder type {embedder_type}")

    # Both embedder classes provide `load`; an empty INIT_FILE means nothing to load
    if embedder_spec.INIT_FILE != "":
        embedder.load(embedder_spec.INIT_FILE)

    if not embedder_spec.IS_TRAINABLE:
        # Freeze all embedder parameters, including any loaded from INIT_FILE
        embedder.requires_grad_(False)

    return embedder
|
|
|
|
|
class Embedder(nn.Module):
    """
    Embedder module that serves as a container for embedders to use with different
    meshes. Extends Module to automatically save / load state dict.
    """

    DEFAULT_MODEL_CHECKPOINT_PREFIX = "roi_heads.embedder."

    def __init__(self, cfg: CfgNode):
        """
        Initialize mesh embedders. An embedder for mesh `i` is stored in a submodule
        "embedder_{i}".

        Args:
            cfg (CfgNode): configuration options; reads
                MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE,
                MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDERS and MODEL.WEIGHTS
        """
        super().__init__()
        self.mesh_names = set()
        embedder_dim = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE
        logger = logging.getLogger(__name__)
        for mesh_name, embedder_spec in cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDERS.items():
            logger.info("Adding embedder embedder_%s with spec %s", mesh_name, embedder_spec)
            self.add_module(f"embedder_{mesh_name}", create_embedder(embedder_spec, embedder_dim))
            self.mesh_names.add(mesh_name)
        # Loaded after all embedders are built, so matching keys from the model
        # checkpoint take precedence over values from per-embedder INIT_FILEs
        if cfg.MODEL.WEIGHTS != "":
            self.load_from_model_checkpoint(cfg.MODEL.WEIGHTS)

    @staticmethod
    def _to_tensor(value):
        # Convert numpy arrays (e.g. from .pkl checkpoints) to tensors; pass others through
        return torch.from_numpy(value) if isinstance(value, np.ndarray) else value

    def load_from_model_checkpoint(self, fpath: str, prefix: Optional[str] = None):
        """
        Load embedder weights from a full model checkpoint.

        Only entries of checkpoint["model"] whose keys start with `prefix` are
        used; the prefix is stripped before loading them into this module.

        Args:
            fpath (str): checkpoint path, opened through PathManager; paths ending
                in ".pkl" are read with pickle, anything else with torch.load
            prefix (Optional[str]): key prefix of embedder parameters in the
                checkpoint; defaults to DEFAULT_MODEL_CHECKPOINT_PREFIX
        """
        if prefix is None:
            prefix = Embedder.DEFAULT_MODEL_CHECKPOINT_PREFIX
        # NOTE(review): pickle.load can execute arbitrary code from the file, and
        # so can torch.load unless restricted to weights-only loading; only load
        # checkpoints from trusted sources.
        with PathManager.open(fpath, "rb") as fh:
            if fpath.endswith(".pkl"):
                state_dict = pickle.load(fh, encoding="latin1")
            else:
                state_dict = torch.load(fh, map_location=torch.device("cpu"))
        if state_dict is None or "model" not in state_dict:
            # Previously ignored silently; keep it non-fatal but visible
            logging.getLogger(__name__).warning(
                "No 'model' entry in checkpoint %s, embedder weights not loaded", fpath
            )
            return
        state_dict_local = {
            key[len(prefix) :]: Embedder._to_tensor(value)
            for key, value in state_dict["model"].items()
            if key.startswith(prefix)
        }
        # Non-strict: keys missing from the checkpoint or unknown to this module are
        # ignored (e.g. a checkpoint trained with a different set of meshes)
        self.load_state_dict(state_dict_local, strict=False)

    def forward(self, mesh_name: str) -> torch.Tensor:
        """
        Produce vertex embeddings for the specific mesh; vertex embeddings are
        a tensor of shape [N, D] where:
            N = number of vertices
            D = number of dimensions in the embedding space
        Args:
            mesh_name (str): name of a mesh for which to obtain vertex embeddings
        Return:
            Vertex embeddings, a tensor of shape [N, D]
        """
        return getattr(self, f"embedder_{mesh_name}")()

    def has_embeddings(self, mesh_name: str) -> bool:
        """
        Check whether an embedder submodule "embedder_{mesh_name}" is registered.

        Args:
            mesh_name (str): mesh name to look up
        Return:
            True if `forward(mesh_name)` has an embedder to call, False otherwise
        """
        return hasattr(self, f"embedder_{mesh_name}")
|
|