# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

# pyre-unsafe

import logging
import numpy as np
import pickle
from enum import Enum
from typing import Optional
import torch
from torch import nn

from detectron2.config import CfgNode
from detectron2.utils.file_io import PathManager

from .vertex_direct_embedder import VertexDirectEmbedder
from .vertex_feature_embedder import VertexFeatureEmbedder


class EmbedderType(Enum):
    """
    Kind of embedder, i.e. the way mesh vertices are mapped into the embedding space.

    The member value is the string expected in an embedder config's ``TYPE`` field:
     - VERTEX_DIRECT  ("vertex_direct"):  direct vertex embedding
     - VERTEX_FEATURE ("vertex_feature"): embedding computed from vertex features
    """

    VERTEX_DIRECT = "vertex_direct"  # direct vertex embedding
    VERTEX_FEATURE = "vertex_feature"  # embedding of vertex features


def create_embedder(embedder_spec: CfgNode, embedder_dim: int) -> nn.Module:
    """
    Create an embedder based on the provided configuration

    Args:
        embedder_spec (CfgNode): embedder configuration; reads TYPE, NUM_VERTICES,
            INIT_FILE and IS_TRAINABLE, plus FEATURE_DIM and FEATURES_TRAINABLE
            for the "vertex_feature" type
        embedder_dim (int): embedding space dimensionality
    Return:
        An embedder instance for the specified configuration
    Raises:
        ValueError: in case of unexpected embedder type
    """
    try:
        embedder_type = EmbedderType(embedder_spec.TYPE)
    except ValueError as e:
        # EmbedderType(...) itself rejects unknown values, so report the offending
        # value together with the supported ones here
        supported = ", ".join(repr(t.value) for t in EmbedderType)
        raise ValueError(
            f"Unexpected embedder type {embedder_spec.TYPE!r}, expected one of: {supported}"
        ) from e

    if embedder_type == EmbedderType.VERTEX_DIRECT:
        embedder = VertexDirectEmbedder(
            num_vertices=embedder_spec.NUM_VERTICES,
            embed_dim=embedder_dim,
        )
    elif embedder_type == EmbedderType.VERTEX_FEATURE:
        embedder = VertexFeatureEmbedder(
            num_vertices=embedder_spec.NUM_VERTICES,
            feature_dim=embedder_spec.FEATURE_DIM,
            embed_dim=embedder_dim,
            train_features=embedder_spec.FEATURES_TRAINABLE,
        )
    else:
        # only reachable if a new EmbedderType member is added without a branch here
        raise ValueError(f"Unexpected embedder type {embedder_type}")

    # both embedder classes are initialized the same way from an optional file
    if embedder_spec.INIT_FILE != "":
        embedder.load(embedder_spec.INIT_FILE)

    if not embedder_spec.IS_TRAINABLE:
        embedder.requires_grad_(False)

    return embedder


class Embedder(nn.Module):
    """
    Embedder module that serves as a container for embedders to use with different
    meshes. Extends Module to automatically save / load state dict.
    """

    DEFAULT_MODEL_CHECKPOINT_PREFIX = "roi_heads.embedder."

    def __init__(self, cfg: CfgNode):
        """
        Initialize mesh embedders. An embedder for mesh `i` is stored in a submodule
        "embedder_{i}".

        Args:
            cfg (CfgNode): configuration options; embedders are read from
                MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDERS (mesh name -> embedder spec),
                and weights are loaded from MODEL.WEIGHTS if it is non-empty
        """
        super().__init__()
        self.mesh_names = set()
        embedder_dim = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE
        logger = logging.getLogger(__name__)
        for mesh_name, embedder_spec in cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDERS.items():
            logger.info("Adding embedder embedder_%s with spec %s", mesh_name, embedder_spec)
            self.add_module(f"embedder_{mesh_name}", create_embedder(embedder_spec, embedder_dim))
            self.mesh_names.add(mesh_name)
        if cfg.MODEL.WEIGHTS != "":
            self.load_from_model_checkpoint(cfg.MODEL.WEIGHTS)

    def load_from_model_checkpoint(self, fpath: str, prefix: Optional[str] = None) -> None:
        """
        Load embedder weights from a model checkpoint.

        Only entries of the checkpoint's "model" mapping whose keys start with `prefix`
        are used; the prefix is stripped before loading them into this module.
        Checkpoints without a "model" entry are ignored.

        Args:
            fpath (str): checkpoint path opened through PathManager; files ending with
                ".pkl" are read with `pickle`, all others with `torch.load` (on CPU)
            prefix (Optional[str]): key prefix of embedder parameters in the
                checkpoint; defaults to DEFAULT_MODEL_CHECKPOINT_PREFIX
        """
        if prefix is None:
            prefix = Embedder.DEFAULT_MODEL_CHECKPOINT_PREFIX
        # NOTE(security): pickle.load (and torch.load when it falls back to full
        # unpickling) can execute arbitrary code; only load trusted checkpoints
        with PathManager.open(fpath, "rb") as f:
            if fpath.endswith(".pkl"):
                # presumably latin1 is needed for pickles written by Python 2 — confirm
                state_dict = pickle.load(f, encoding="latin1")
            else:
                state_dict = torch.load(f, map_location=torch.device("cpu"))
        if state_dict is None or "model" not in state_dict:
            return
        state_dict_local = {}
        for key, value in state_dict["model"].items():
            if not key.startswith(prefix):
                continue
            # numpy arrays (e.g. from .pkl checkpoints) are converted to tensors
            if isinstance(value, np.ndarray):
                value = torch.from_numpy(value)
            state_dict_local[key[len(prefix) :]] = value
        # non-strict loading to finetune on different meshes
        self.load_state_dict(state_dict_local, strict=False)

    def forward(self, mesh_name: str) -> torch.Tensor:
        """
        Produce vertex embeddings for the specific mesh; vertex embeddings are
        a tensor of shape [N, D] where:
            N = number of vertices
            D = number of dimensions in the embedding space
        Args:
            mesh_name (str): name of a mesh for which to obtain vertex embeddings
        Return:
            Vertex embeddings, a tensor of shape [N, D]
        """
        return getattr(self, f"embedder_{mesh_name}")()

    def has_embeddings(self, mesh_name: str) -> bool:
        """
        Check whether an embedder submodule exists for the given mesh.

        Args:
            mesh_name (str): name of a mesh
        Return:
            True if attribute "embedder_{mesh_name}" is present, False otherwise
        """
        return hasattr(self, f"embedder_{mesh_name}")