"""Utilities for converting Graphein Networks to Geometric Deep Learning formats.
"""
# %%
# Graphein
# Author: Kexin Huang, Arian Jamasb <arian@jamasb.io>
# License: MIT
# Project Website: https://github.com/a-r-j/graphein
# Code Repository: https://github.com/a-r-j/graphein
from __future__ import annotations

from typing import List, Optional

import networkx as nx
import numpy as np
import torch

try:
    from graphein.utils.dependencies import import_message
except ImportError:
    raise Exception('You need to install graphein from source in addition to DSSP to use this model please refer to https://github.com/a-r-j/graphein and https://ssbio.readthedocs.io/en/latest/instructions/dssp.html')

try:
    import torch_geometric
    from torch_geometric.data import Data
except ImportError:
    import_message(
        submodule="graphein.ml.conversion",
        package="torch_geometric",
        pip_install=True,
        conda_channel="rusty1s",
    )

try:
    import dgl
except ImportError:
    import_message(
        submodule="graphein.ml.conversion",
        package="dgl",
        pip_install=True,
        conda_channel="dglteam",
    )

try:
    import jax.numpy as jnp
except ImportError:
    import_message(
        submodule="graphein.ml.conversion",
        package="jax",
        pip_install=True,
        conda_channel="conda-forge",
    )
try:
    import jraph
except ImportError:
    import_message(
        submodule="graphein.ml.conversion",
        package="jraph",
        pip_install=True,
        conda_channel="conda-forge",
    )


# Format identifiers accepted as ``src_format`` / ``dst_format``.
SUPPORTED_FORMATS = ["nx", "pyg", "dgl", "jraph"]
"""Supported conversion formats.

``"nx"``: NetworkX graph

``"pyg"``: PyTorch Geometric Data object

``"dgl"``: DGL graph

``"jraph"``: Jraph GraphsTuple
"""

# Verbosity presets; each selects a predefined list of feature columns to keep.
SUPPORTED_VERBOSITY = ["gnn", "default", "all_info"]
"""Supported verbosity levels for preserving graph features in conversion."""


class GraphFormatConvertor:
    """
    Provides conversion utilities between NetworkX Graphs and geometric deep learning library destination formats.
    Currently, we provide support for conversion from ``nx.Graph`` to ``dgl.DGLGraph``, ``pytorch_geometric.Data``
    and ``jraph.GraphsTuple``. Supported conversion formats can be retrieved from
    :const:`~graphein.ml.conversion.SUPPORTED_FORMATS`.

    :param src_format: The type of graph you'd like to convert from. Supported formats are available in :const:`~graphein.ml.conversion.SUPPORTED_FORMATS`
    :type src_format: Literal["nx", "pyg", "dgl", "jraph"]
    :param dst_format: The type of graph format you'd like to convert to. Supported formats are available in:
        ``graphein.ml.conversion.SUPPORTED_FORMATS``
    :type dst_format:  Literal["nx", "pyg", "dgl", "jraph"]
    :param verbose: Select from ``"gnn"``, ``"default"``, ``"all_info"`` to determine how much information is preserved (features)
        as some are unsupported by various downstream frameworks. Ignored when ``columns`` is given.
    :type verbose: str
    :param columns: List of columns in the node features to retain
    :type columns: List[str], optional
    :raises ValueError: If either format is unsupported, or if ``columns`` is ``None`` and ``verbose`` is not a
        supported verbosity level.
    """

    # Feature columns retained for each verbosity preset when ``columns`` is not supplied.
    _VERBOSITY_COLUMNS = {
        "gnn": (
            "edge_index",
            "coords",
            "dist_mat",
            "name",
            "node_id",
        ),
        "default": (
            "b_factor",
            "chain_id",
            "coords",
            "dist_mat",
            "edge_index",
            "kind",
            "name",
            "node_id",
            "residue_name",
        ),
        "all_info": (
            "atom_type",
            "b_factor",
            "chain_id",
            "chain_ids",
            "config",
            "coords",
            "dist_mat",
            "edge_index",
            "element_symbol",
            "kind",
            "name",
            "node_id",
            "node_type",
            "pdb_df",
            "raw_pdb_df",
            "residue_name",
            "residue_number",
            "rgroup_df",
            "sequence_A",
            "sequence_B",
        ),
    }

    def __init__(
        self,
        src_format: str,
        dst_format: str,
        verbose: str = "gnn",
        columns: Optional[List[str]] = None,
    ):
        if (src_format not in SUPPORTED_FORMATS) or (
            dst_format not in SUPPORTED_FORMATS
        ):
            raise ValueError(
                "Please specify from supported format, "
                + "/".join(SUPPORTED_FORMATS)
            )
        self.src_format = src_format
        self.dst_format = dst_format

        # The verbosity level only matters when no explicit column list is given.
        if (columns is None) and (verbose not in SUPPORTED_VERBOSITY):
            raise ValueError(
                "Please specify the supported verbose mode ("
                + "/".join(SUPPORTED_VERBOSITY)
                + ") or specify column names!"
            )

        if columns is None:
            preset = self._VERBOSITY_COLUMNS.get(verbose)
            # Copy so that instances never share (and mutate) the class-level preset.
            columns = list(preset) if preset is not None else None
        self.columns = columns

        # Declared value kind of known feature columns; used by the DGL
        # conversion to decide which features can be stored as tensors.
        self.type2form = {
            "atom_type": "str",
            "b_factor": "float",
            "chain_id": "str",
            "coords": "np.array",
            "dist_mat": "np.array",
            "element_symbol": "str",
            "node_id": "str",
            "residue_name": "str",
            "residue_number": "int",
            "edge_index": "torch.tensor",
            "kind": "str",
        }

    def _collect_node_features(self, G: nx.Graph) -> dict:
        """
        Gathers the node attributes listed in ``self.columns`` into per-feature lists.

        :param G: Graph whose node attributes are collected.
        :type G: nx.Graph
        :return: Mapping of feature name to the list of values, in node iteration order.
            Nodes lacking a feature contribute no entry to that feature's list.
        :rtype: dict
        """
        features = {}
        for _, feat_dict in G.nodes(data=True):
            for key, value in feat_dict.items():
                name = str(key)
                if name in self.columns:
                    features.setdefault(name, []).append(value)
        return features

    def _collect_edge_features(self, G: nx.Graph) -> dict:
        """
        Gathers the edge attributes listed in ``self.columns`` into flat per-feature lists.

        Iterable attribute values are flattened into the feature's list (so an edge may
        contribute several entries); non-iterable scalars contribute a single entry.

        :param G: Graph whose edge attributes are collected.
        :type G: nx.Graph
        :return: Mapping of feature name to the flattened list of values, in edge iteration order.
        :rtype: dict
        """
        features = {}
        for _, _, feat_dict in G.edges(data=True):
            for key, value in feat_dict.items():
                name = str(key)
                if name not in self.columns:
                    continue
                try:
                    items = list(value)
                except TypeError:
                    # Scalar attribute (e.g. a float): keep it as one entry.
                    items = [value]
                features.setdefault(name, []).extend(items)
        return features

    def _collect_graph_features(self, G: nx.Graph) -> dict:
        """
        Gathers the graph-level attributes listed in ``self.columns``.

        :param G: Graph whose ``G.graph`` attributes are collected.
        :type G: nx.Graph
        :return: Mapping of feature name to a single-element list holding its value.
        :rtype: dict
        """
        return {
            str(feat_name): [G.graph[feat_name]]
            for feat_name in G.graph
            if str(feat_name) in self.columns
        }

    @staticmethod
    def _to_tensor_if_possible(values: list):
        """
        Returns ``torch.tensor(values)`` or, if the values cannot be represented as a tensor
        (e.g. strings), the unchanged list.
        """
        try:
            return torch.tensor(values)
        except (TypeError, ValueError, RuntimeError):
            return values

    def convert_nx_to_dgl(self, G: nx.Graph) -> dgl.DGLGraph:
        """
        Converts ``NetworkX`` graph to ``DGL``

        Only ``coords``, ``dist_mat`` and features declared as ``float``/``int`` in
        ``self.type2form`` are attached as tensors; string-valued, unknown and
        graph-level features are not stored on the returned graph.

        :param G: ``nx.Graph`` to convert to ``DGLGraph``
        :type G: nx.Graph
        :return: ``DGLGraph`` object version of input ``NetworkX`` graph
        :rtype: dgl.DGLGraph
        """
        g = dgl.DGLGraph()
        node_id = list(G.nodes())
        G = nx.convert_node_labels_to_integers(G)

        # Node-level features.
        node_tensors = {}
        for name, values in self._collect_node_features(G).items():
            if name == "coords":
                node_tensors[name] = torch.Tensor(np.asarray(values)).type(
                    "torch.FloatTensor"
                )
            elif name == "dist_mat":
                # Only the first entry is used; it is expected to expose ``.values``.
                node_tensors[name] = torch.Tensor(
                    np.asarray(values[0].values)
                ).type("torch.FloatTensor")
            elif self.type2form.get(name) in ("float", "int"):
                node_tensors[name] = torch.Tensor(np.array(values))
        g.add_nodes(len(node_id), node_tensors)

        # Edge endpoints; an edgeless graph yields two empty index tensors.
        edges = list(G.edges)
        if edges:
            edge_index = torch.LongTensor(edges).t().contiguous()
            src, dst = edge_index[0], edge_index[1]
        else:
            src = dst = torch.LongTensor([])

        # Edge-level features (taken from the edge attributes, not the node ones).
        edge_tensors = {
            name: torch.Tensor(np.array(values))
            for name, values in self._collect_edge_features(G).items()
            if self.type2form.get(name) in ("float", "int")
        }
        g.add_edges(src, dst, edge_tensors)

        return g

    def convert_nx_to_pyg(self, G: nx.Graph) -> Data:
        """
        Converts ``NetworkX`` graph to ``pytorch_geometric.data.Data`` object. Requires ``PyTorch Geometric`` (https://pytorch-geometric.readthedocs.io/en/latest/) to be installed.

        The original node labels are always stored under ``node_id``. Node, edge and
        graph-level attributes are kept only if listed in ``self.columns``; on a name
        clash, edge features replace node features and graph features replace both.

        :param G: ``nx.Graph`` to convert to PyTorch Geometric ``Data`` object
        :type G: nx.Graph
        :return: ``Data`` object containing networkx graph data
        :rtype: pytorch_geometric.data.Data
        """
        # Initialise dict used to construct Data object & assign node ids as a feature
        data = {"node_id": list(G.nodes())}
        G = nx.convert_node_labels_to_integers(G)

        # Construct edge index
        edge_index = torch.LongTensor(list(G.edges)).t().contiguous()

        data.update(self._collect_node_features(G))
        data.update(self._collect_edge_features(G))
        data.update(self._collect_graph_features(G))

        if "edge_index" in self.columns:
            data["edge_index"] = edge_index.view(2, -1)

        data = Data.from_dict(data)
        data.num_nodes = G.number_of_nodes()
        return data

    @staticmethod
    def convert_nx_to_nx(G: nx.Graph) -> nx.Graph:
        """
        Converts NetworkX graph (``nx.Graph``) to NetworkX graph (``nx.Graph``) object. Redundant - returns itself.

        :param G: NetworkX Graph
        :type G: nx.Graph
        :return: NetworkX Graph
        :rtype: nx.Graph
        """
        return G

    @staticmethod
    def convert_dgl_to_nx(G: dgl.DGLGraph) -> nx.Graph:
        """
        Converts a DGL Graph (``dgl.DGLGraph``) to a NetworkX (``nx.Graph``) object. Preserves node and edge attributes.

        :param G: ``dgl.DGLGraph`` to convert to ``NetworkX`` graph.
        :type G: dgl.DGLGraph
        :return: NetworkX graph object.
        :rtype: nx.Graph
        """
        node_attrs = G.node_attr_schemes().keys()
        edge_attrs = G.edge_attr_schemes().keys()
        return dgl.to_networkx(G, node_attrs, edge_attrs)

    @staticmethod
    def convert_pyg_to_nx(G: Data) -> nx.Graph:
        """Converts PyTorch Geometric ``Data`` object to NetworkX graph (``nx.Graph``).

        :param G: Pytorch Geometric Data.
        :type G: torch_geometric.data.Data
        :returns: NetworkX graph.
        :rtype: nx.Graph
        """
        return torch_geometric.utils.to_networkx(G)

    def convert_nx_to_jraph(self, G: nx.Graph) -> jraph.GraphsTuple:
        """Converts NetworkX graph (``nx.Graph``) to Jraph GraphsTuple graph. Requires ``jax`` and ``Jraph``.

        Node features listed in ``self.columns`` are stored as ``torch`` tensors when
        their values can be converted, and as plain lists otherwise.

        :param G: Networkx graph to convert.
        :type G: nx.Graph
        :return: Jraph GraphsTuple graph.
        :rtype: jraph.GraphsTuple
        """
        G = nx.convert_node_labels_to_integers(G)

        n_node = len(G)
        n_edge = G.number_of_edges()
        edge_list = list(G.edges())
        if edge_list:
            senders, receivers = (jnp.array(idx) for idx in zip(*edge_list))
        else:
            # ``zip(*[])`` yields nothing to unpack, so build empty index arrays directly.
            senders = receivers = jnp.array([], dtype=int)

        # Accumulate plain lists first and convert once per feature; converting
        # inside the loop would try to concatenate a tensor with a list.
        node_features = {
            name: self._to_tensor_if_possible(values)
            for name, values in self._collect_node_features(G).items()
        }
        edge_features = self._collect_edge_features(G)
        global_context = self._collect_graph_features(G)

        return jraph.GraphsTuple(
            nodes=node_features,
            senders=senders,
            receivers=receivers,
            edges=edge_features,
            n_node=n_node,
            n_edge=n_edge,
            globals=global_context,
        )

    def __call__(self, G: nx.Graph):
        """
        Converts ``G`` from ``self.src_format`` to ``self.dst_format`` via NetworkX.

        :param G: Graph in the source format.
        :return: Graph in the destination format.
        :raises AttributeError: If no ``convert_<src>_to_nx`` or ``convert_nx_to_<dst>``
            method exists for the configured formats.
        """
        # Look the converters up by name rather than evaluating built-up code strings.
        to_nx = getattr(self, "convert_" + self.src_format + "_to_nx")
        nx_g = to_nx(G)
        from_nx = getattr(self, "convert_nx_to_" + self.dst_format)
        return from_nx(nx_g)


# def convert_nx_to_pyg_data(G: nx.Graph) -> Data:
#     # Initialise dict used to construct Data object
#     data = {"node_id": list(G.nodes())}

#     G = nx.convert_node_labels_to_integers(G)

#     # Construct Edge Index
#     edge_index = torch.LongTensor(list(G.edges)).t().contiguous()

#     # Add node features
#     for i, (_, feat_dict) in enumerate(G.nodes(data=True)):
#         for key, value in feat_dict.items():
#             data[str(key)] = [value] if i == 0 else data[str(key)] + [value]

#     # Add edge features
#     for i, (_, _, feat_dict) in enumerate(G.edges(data=True)):
#         for key, value in feat_dict.items():
#             data[str(key)] = (
#                 list(value) if i == 0 else data[str(key)] + list(value)
#             )

#     # Add graph-level features
#     for feat_name in G.graph:
#         data[str(feat_name)] = [G.graph[feat_name]]

#     data["edge_index"] = edge_index.view(2, -1)
#     data = Data.from_dict(data)
#     data.num_nodes = G.number_of_nodes()

#     return data
def convert_nx_to_pyg_data(G: nx.Graph) -> Data:
    """
    Converts a ``NetworkX`` graph to a PyTorch Geometric ``Data`` object, keeping every
    node, edge and graph-level attribute.

    The original node labels are stored under ``node_id``. Each node attribute becomes a
    list with one entry per node that carries it. Each edge attribute becomes a list with
    one entry per edge that carries it: ``distance`` values are stored as-is, other values
    are stored as ``list(value)`` (or unchanged when the value is not iterable). On a
    name clash, edge features replace node features and graph features replace both.

    :param G: ``nx.Graph`` to convert.
    :type G: nx.Graph
    :return: ``Data`` object with ``edge_index`` of shape ``(2, -1)`` and ``num_nodes`` set.
    :rtype: torch_geometric.data.Data
    """
    # Initialise dict used to construct Data object
    data = {"node_id": list(G.nodes())}

    G = nx.convert_node_labels_to_integers(G)

    # Construct edge index
    edge_index = torch.LongTensor(list(G.edges)).t().contiguous()

    # Node features: collected separately so an attribute that first appears on a
    # later node does not raise, then merged over the initial entries.
    node_features = {}
    for _, feat_dict in G.nodes(data=True):
        for key, value in feat_dict.items():
            node_features.setdefault(str(key), []).append(value)
    data.update(node_features)

    # Edge features: one entry per edge.
    edge_features = {}
    for _, _, feat_dict in G.edges(data=True):
        for key, value in feat_dict.items():
            if key == 'distance':
                entry = value
            else:
                try:
                    entry = list(value)
                except TypeError:
                    # Non-iterable scalar attribute: keep the raw value.
                    entry = value
            edge_features.setdefault(str(key), []).append(entry)
    data.update(edge_features)

    # Graph-level features
    for feat_name in G.graph:
        data[str(feat_name)] = [G.graph[feat_name]]

    data["edge_index"] = edge_index.view(2, -1)
    data = Data.from_dict(data)
    data.num_nodes = G.number_of_nodes()

    return data