from typing import Optional, Tuple
import torch
from torch import Tensor, nn
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import softmax
from torch_scatter import scatter
from torch_sparse import SparseTensor
import loralib as lora
from esm.multihead_attention import MultiheadAttention
import math
# fall back to eager execution instead of raising if torch.compile / TorchDynamo fails
from torch import _dynamo
_dynamo.config.suppress_errors = True
from ..module.utils import (
    CosineCutoff,
    act_class_mapping,
    get_template_fn,
    gelu
)


# original torchmd-net attention layer
class EquivariantMultiHeadAttention(MessagePassing):
    """Equivariant multi-head attention layer."""

    def __init__(
            self,
            x_channels,
            x_hidden_channels,
            vec_channels,
            vec_hidden_channels,
            share_kv,
            edge_attr_channels,
            distance_influence,
            num_heads,
            activation,
            attn_activation,
            cutoff_lower,
            cutoff_upper,
            use_lora=None,
    ):
        super(EquivariantMultiHeadAttention, self).__init__(
            aggr="mean", node_dim=0)
        assert x_hidden_channels % num_heads == 0 \
            and vec_channels % num_heads == 0, (
                f"The number of hidden channels x_hidden_channels ({x_hidden_channels}) "
                f"and vec_channels ({vec_channels}) "
                f"must be evenly divisible by the number of "
                f"attention heads ({num_heads})"
            )
        assert vec_hidden_channels == x_channels, (
            f"The number of hidden channels x_channels ({x_channels}) "
            f"and vec_hidden_channels ({vec_hidden_channels}) "
            f"must be equal"
        )

        self.distance_influence = distance_influence
        self.num_heads = num_heads
        self.x_channels = x_channels
        self.x_hidden_channels = x_hidden_channels
        self.x_head_dim = x_hidden_channels // num_heads
        self.vec_channels = vec_channels
        self.vec_hidden_channels = vec_hidden_channels
        # important, not vec_hidden_channels // num_heads
        self.vec_head_dim = vec_channels // num_heads
        self.share_kv = share_kv
        self.layernorm = nn.LayerNorm(x_channels)
        self.act = activation()
        self.attn_activation = act_class_mapping[attn_activation]()
        self.cutoff = CosineCutoff(cutoff_lower, cutoff_upper)

        if use_lora is not None:
            self.q_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora)
            self.k_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora) if not share_kv else None
            self.v_proj = lora.Linear(
                x_channels, x_hidden_channels + vec_channels * 2, r=use_lora)
            self.o_proj = lora.Linear(
                x_hidden_channels, x_channels * 2 + vec_channels, r=use_lora)
            self.vec_proj = lora.Linear(
                vec_channels, vec_hidden_channels * 2 + vec_channels, bias=False, r=use_lora)
        else:
            self.q_proj = nn.Linear(x_channels, x_hidden_channels)
            self.k_proj = nn.Linear(x_channels, x_hidden_channels) if not share_kv else None
            self.v_proj = nn.Linear(
                x_channels, x_hidden_channels + vec_channels * 2)
            self.o_proj = nn.Linear(
                x_hidden_channels, x_channels * 2 + vec_channels)
            self.vec_proj = nn.Linear(
                vec_channels, vec_hidden_channels * 2 + vec_channels, bias=False)

        self.dk_proj = None
        if distance_influence in ["keys", "both"]:
            if use_lora is not None:
                self.dk_proj = lora.Linear(edge_attr_channels, x_hidden_channels, r=use_lora)
            else:
                self.dk_proj = nn.Linear(edge_attr_channels, x_hidden_channels)

        self.dv_proj = None
        if distance_influence in ["values", "both"]:
            if use_lora is not None:
                self.dv_proj = lora.Linear(edge_attr_channels, x_hidden_channels + vec_channels * 2, r=use_lora)
            else:
                self.dv_proj = nn.Linear(edge_attr_channels, x_hidden_channels + vec_channels * 2)

        self.reset_parameters()

    def reset_parameters(self):
        self.layernorm.reset_parameters()
        nn.init.xavier_uniform_(self.q_proj.weight)
        self.q_proj.bias.data.fill_(0)
        if self.k_proj is not None:
            # k_proj is None when share_kv is enabled
            nn.init.xavier_uniform_(self.k_proj.weight)
            self.k_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.v_proj.weight)
        self.v_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.o_proj.weight)
        self.o_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.vec_proj.weight)
        if self.dk_proj:
            nn.init.xavier_uniform_(self.dk_proj.weight)
            self.dk_proj.bias.data.fill_(0)
        if self.dv_proj:
            nn.init.xavier_uniform_(self.dv_proj.weight)
            self.dv_proj.bias.data.fill_(0)

    def forward(self, x, vec, edge_index, r_ij, f_ij, d_ij, return_attn=False):
        x = self.layernorm(x)
        q = self.q_proj(x).reshape(-1, self.num_heads, self.x_head_dim)
        v = self.v_proj(x).reshape(-1, self.num_heads,
                                   self.x_head_dim + self.vec_head_dim * 2)
        if self.share_kv:
            k = v[:, :, :self.x_head_dim]
        else:
            k = self.k_proj(x).reshape(-1, self.num_heads, self.x_head_dim)

        vec1, vec2, vec3 = torch.split(self.vec_proj(vec),
                                       [self.vec_hidden_channels, self.vec_hidden_channels, self.vec_channels], dim=-1)
        vec = vec.reshape(-1, 3, self.num_heads, self.vec_head_dim)
        vec_dot = (vec1 * vec2).sum(dim=1)

        dk = (
            self.act(self.dk_proj(f_ij)).reshape(-1,
                                                 self.num_heads, self.x_head_dim)
            if self.dk_proj is not None
            else None
        )
        dv = (
            self.act(self.dv_proj(f_ij)).reshape(-1, self.num_heads,
                                                 self.x_head_dim + self.vec_head_dim * 2)
            if self.dv_proj is not None
            else None
        )

        # propagate_type: (q: Tensor, k: Tensor, v: Tensor, vec: Tensor, dk: Tensor, dv: Tensor, r_ij: Tensor,
        # d_ij: Tensor)
        x, vec, attn = self.propagate(
            edge_index,
            q=q,
            k=k,
            v=v,
            vec=vec,
            dk=dk,
            dv=dv,
            r_ij=r_ij,
            d_ij=d_ij,
            size=None,
        )
        x = x.reshape(-1, self.x_hidden_channels)
        vec = vec.reshape(-1, 3, self.vec_channels)

        o1, o2, o3 = torch.split(
            self.o_proj(x), [self.vec_channels, self.x_channels, self.x_channels], dim=1)
        dx = vec_dot * o2 + o3
        dvec = vec3 * o1.unsqueeze(1) + vec
        if return_attn:
            return dx, dvec, torch.concat((edge_index.T, attn), dim=1)
        else:
            return dx, dvec, None

    def message(self, q_i, k_j, v_j, vec_j, dk, dv, r_ij, d_ij):
        # attention mechanism
        if dk is None:
            attn = (q_i * k_j).sum(dim=-1)
        else:  # TODO: consider add or multiply dk
            attn = (q_i * k_j * dk).sum(dim=-1)

        # attention activation function
        attn = self.attn_activation(attn) * self.cutoff(r_ij).unsqueeze(1)

        # value pathway
        if dv is not None:
            v_j = v_j * dv
        x, vec1, vec2 = torch.split(
            v_j, [self.x_head_dim, self.vec_head_dim, self.vec_head_dim], dim=2)

        # update scalar features
        x = x * attn.unsqueeze(2)
        # update vector features
        vec = vec_j * vec1.unsqueeze(1) + vec2.unsqueeze(1) * \
            d_ij.unsqueeze(2).unsqueeze(3)
        return x, vec, attn

    def aggregate(
            self,
            features: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
            index: torch.Tensor,
            ptr: Optional[torch.Tensor],
            dim_size: Optional[int],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        x, vec, attn = features
        x = scatter(x, index, dim=self.node_dim, dim_size=dim_size)
        vec = scatter(vec, index, dim=self.node_dim, dim_size=dim_size)
        return x, vec, attn

    def update(
            self, inputs: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return inputs

    def message_and_aggregate(self, adj_t: SparseTensor) -> Tensor:
        pass

    def edge_update(self) -> Tensor:
        pass
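

# Hedged usage sketch (illustrative, not part of the original API). All sizes and
# hyper-parameters below are assumptions; "silu" is assumed to be a valid key of
# act_class_mapping, and f_ij / d_ij would normally come from an RBF expansion of
# the edge distances and from the normalized edge vectors, respectively.
def _example_equivariant_attention():
    num_nodes, num_edges = 10, 30
    layer = EquivariantMultiHeadAttention(
        x_channels=128, x_hidden_channels=128, vec_channels=64,
        vec_hidden_channels=128, share_kv=False, edge_attr_channels=32,
        distance_influence="both", num_heads=8, activation=nn.SiLU,
        attn_activation="silu", cutoff_lower=0.0, cutoff_upper=5.0,
    )
    x = torch.randn(num_nodes, 128)                       # invariant node features
    vec = torch.randn(num_nodes, 3, 64)                   # equivariant node features
    edge_index = torch.randint(0, num_nodes, (2, num_edges))
    r_ij = torch.rand(num_edges) * 5.0                    # edge distances
    f_ij = torch.randn(num_edges, 32)                     # edge attributes (e.g. RBF of r_ij)
    d_ij = nn.functional.normalize(torch.randn(num_edges, 3), dim=-1)
    dx, dvec, _ = layer(x, vec, edge_index, r_ij, f_ij, d_ij)
    return dx.shape, dvec.shape                           # (10, 128) and (10, 3, 64)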


# ESM multi-head attention layer, added LoRA
class ESMMultiheadAttention(MultiheadAttention):
    """Multi-headed attention.

    See "Attention Is All You Need" for more details.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        add_bias_kv: bool = False,
        add_zero_attn: bool = False,
        self_attention: bool = False,
        encoder_decoder_attention: bool = False,
        use_rotary_embeddings: bool = False,
    ):
        super().__init__(embed_dim, num_heads, kdim, vdim, dropout, bias, add_bias_kv, add_zero_attn, self_attention,
                         encoder_decoder_attention, use_rotary_embeddings)
        # change the projection to LoRA
        self.k_proj = lora.Linear(self.kdim, embed_dim, bias=bias, r=16)
        self.v_proj = lora.Linear(self.vdim, embed_dim, bias=bias, r=16)
        self.q_proj = lora.Linear(embed_dim, embed_dim, bias=bias, r=16)
        self.out_proj = lora.Linear(embed_dim, embed_dim, bias=bias, r=16)
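

# Hedged fine-tuning sketch (illustrative): once the q/k/v/out projections are
# lora.Linear modules (rank r=16 is hard-coded above), only the low-rank adapters
# need gradients. loralib ships a helper for exactly this.
def _freeze_non_lora_parameters(model: nn.Module) -> None:
    # freeze every parameter except the LoRA A/B matrices
    lora.mark_only_lora_as_trainable(model, bias="none")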


# torchmd-net attention layer extended with pair-wise PAE confidence
class EquivariantPAEMultiHeadAttention(EquivariantMultiHeadAttention):
    """Equivariant multi-head attention layer."""

    def __init__(
            self,
            x_channels,
            x_hidden_channels,
            vec_channels,
            vec_hidden_channels,
            share_kv,
            edge_attr_channels,
            edge_attr_dist_channels,
            distance_influence,
            num_heads,
            activation,
            attn_activation,
            cutoff_lower,
            cutoff_upper,
            use_lora=None,
    ):
        super(EquivariantPAEMultiHeadAttention, self).__init__(
            x_channels=x_channels,
            x_hidden_channels=x_hidden_channels,
            vec_channels=vec_channels,
            vec_hidden_channels=vec_hidden_channels,
            share_kv=share_kv,
            edge_attr_channels=edge_attr_channels,
            distance_influence=distance_influence,
            num_heads=num_heads,
            activation=activation,
            attn_activation=attn_activation,
            cutoff_lower=cutoff_lower,
            cutoff_upper=cutoff_upper,
            use_lora=use_lora)
        # we cancel the cutoff function
        self.cutoff = None
        # we set separate projection for distance influence
        self.dk_dist_proj = None
        if distance_influence in ["keys", "both"]:
            if use_lora is not None:
                self.dk_dist_proj = lora.Linear(edge_attr_dist_channels, x_hidden_channels, r=use_lora)
            else:
                self.dk_dist_proj = nn.Linear(edge_attr_dist_channels, x_hidden_channels)
        self.dv_dist_proj = None
        if distance_influence in ["values", "both"]:
            if use_lora is not None:
                self.dv_dist_proj = lora.Linear(edge_attr_dist_channels, x_hidden_channels + vec_channels * 2, r=use_lora)
            else:
                self.dv_dist_proj = nn.Linear(edge_attr_dist_channels, x_hidden_channels + vec_channels * 2)
        if self.dk_dist_proj:
            nn.init.xavier_uniform_(self.dk_dist_proj.weight)
            self.dk_dist_proj.bias.data.fill_(0)
        if self.dv_dist_proj:
            nn.init.xavier_uniform_(self.dv_dist_proj.weight)
            self.dv_dist_proj.bias.data.fill_(0)

    def forward(self, x, vec, edge_index, w_ij, f_dist_ij, f_ij, d_ij, plddt, return_attn=False):
        # we replaced r_ij with w_ij as the pair-wise confidence
        # we added plddt as the position-wise confidence
        x = self.layernorm(x)
        q = self.q_proj(x).reshape(-1, self.num_heads, self.x_head_dim)
        v = self.v_proj(x).reshape(-1, self.num_heads,
                                   self.x_head_dim + self.vec_head_dim * 2)
        if self.share_kv:
            k = v[:, :, :self.x_head_dim]
        else:
            k = self.k_proj(x).reshape(-1, self.num_heads, self.x_head_dim)

        vec1, vec2, vec3 = torch.split(self.vec_proj(vec),
                                       [self.vec_hidden_channels, self.vec_hidden_channels, self.vec_channels], dim=-1)
        vec = vec.reshape(-1, 3, self.num_heads, self.vec_head_dim)
        vec_dot = (vec1 * vec2).sum(dim=1)

        dk = (
            self.act(self.dk_proj(f_ij)).reshape(-1, self.num_heads, self.x_head_dim)
            if self.dk_proj is not None
            else None
        )
        dk_dist = (
            self.act(self.dk_dist_proj(f_dist_ij)).reshape(-1, self.num_heads, self.x_head_dim)
            if self.dk_dist_proj is not None
            else None
        )
        dv = (
            self.act(self.dv_proj(f_ij)).reshape(-1, self.num_heads, self.x_head_dim + self.vec_head_dim * 2)
            if self.dv_proj is not None
            else None
        )
        dv_dist = (
            self.act(self.dv_dist_proj(f_dist_ij)).reshape(-1, self.num_heads, self.x_head_dim + self.vec_head_dim * 2)
            if self.dv_dist_proj is not None
            else None
        )

        # propagate_type: (q: Tensor, k: Tensor, v: Tensor, vec: Tensor, dk: Tensor, dv: Tensor, r_ij: Tensor,
        # d_ij: Tensor)
        x, vec, attn = self.propagate(
            edge_index,
            q=q,
            k=k,
            v=v,
            vec=vec,
            dk=dk,
            dk_dist=dk_dist,
            dv=dv,
            dv_dist=dv_dist,
            d_ij=d_ij,
            w_ij=w_ij,
            size=None,
        )
        x = x.reshape(-1, self.x_hidden_channels)
        vec = vec.reshape(-1, 3, self.vec_channels)

        o1, o2, o3 = torch.split(
            self.o_proj(x), [self.vec_channels, self.x_channels, self.x_channels], dim=1)
        dx = vec_dot * o2 * plddt.unsqueeze(1) + o3
        dvec = vec3 * o1.unsqueeze(1) * plddt.unsqueeze(1).unsqueeze(2) + vec
        if return_attn:
            return dx, dvec, torch.concat((edge_index.T, attn), dim=1)
        else:
            return dx, dvec, None

    def message(self, q_i, k_j, v_j, vec_j, dk, dk_dist, dv, dv_dist, d_ij, w_ij):
        # attention mechanism
        attn = (q_i * k_j)
        if dk is not None:
            attn += dk
        if dk_dist is not None:
            attn += dk_dist * w_ij.unsqueeze(1).unsqueeze(2)
        attn = attn.sum(dim=-1)

        # attention activation function
        attn = self.attn_activation(attn)

        # value pathway: add dv, and weight dv_dist by w_ij
        if dv is not None:
            v_j += dv
        if dv_dist is not None:
            v_j += dv_dist * w_ij.unsqueeze(1).unsqueeze(2)
        x, vec1, vec2 = torch.split(
            v_j, [self.x_head_dim, self.vec_head_dim, self.vec_head_dim], dim=2)

        # update scalar features
        x = x * attn.unsqueeze(2)
        # update vector features
        vec = vec_j * vec1.unsqueeze(1) + vec2.unsqueeze(1) * \
            d_ij.unsqueeze(2).unsqueeze(3)
        return x, vec, attn
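

# Hedged summary of the PAE-aware message above (illustrative). With a per-edge
# confidence w_ij supplied by the caller and a per-node confidence plddt, the
# attention logit for edge (i, j) becomes
#
#     attn_ij = sum_d ( q_i * k_j + dk + w_ij * dk_dist )_d
#
# with no distance cutoff, and in forward() the vec_dot term of dx and the vec3
# gate of dvec are additionally scaled by plddt, so low-confidence positions
# contribute less to the update.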


# torchmd-net attention layer extended with pair-wise PAE confidence and a learnable sigmoid gate
class EquivariantWeightedPAEMultiHeadAttention(EquivariantMultiHeadAttention):
    """Equivariant multi-head attention layer."""

    def __init__(
            self,
            x_channels,
            x_hidden_channels,
            vec_channels,
            vec_hidden_channels,
            share_kv,
            edge_attr_channels,
            edge_attr_dist_channels,
            distance_influence,
            num_heads,
            activation,
            attn_activation,
            cutoff_lower,
            cutoff_upper,
            use_lora=None,
    ):
        super(EquivariantWeightedPAEMultiHeadAttention, self).__init__(
            x_channels=x_channels,
            x_hidden_channels=x_hidden_channels,
            vec_channels=vec_channels,
            vec_hidden_channels=vec_hidden_channels,
            share_kv=share_kv,
            edge_attr_channels=edge_attr_channels,
            distance_influence=distance_influence,
            num_heads=num_heads,
            activation=activation,
            attn_activation=attn_activation,
            cutoff_lower=cutoff_lower,
            cutoff_upper=cutoff_upper,
            use_lora=use_lora)
        # we cancel the cutoff function
        self.cutoff = None
        # we set a separate weight for distance influence
        self.pae_weight = nn.Linear(1, 1, bias=True)
        self.pae_weight.weight.data.fill_(-0.5)
        self.pae_weight.bias.data.fill_(7.5)
        # we set separate projection for distance influence
        self.dk_dist_proj = None
        if distance_influence in ["keys", "both"]:
            if use_lora is not None:
                self.dk_dist_proj = lora.Linear(edge_attr_dist_channels, x_hidden_channels, r=use_lora)
            else:
                self.dk_dist_proj = nn.Linear(edge_attr_dist_channels, x_hidden_channels)
        self.dv_dist_proj = None
        if distance_influence in ["values", "both"]:
            if use_lora is not None:
                self.dv_dist_proj = lora.Linear(edge_attr_dist_channels, x_hidden_channels + vec_channels * 2, r=use_lora)
            else:
                self.dv_dist_proj = nn.Linear(edge_attr_dist_channels, x_hidden_channels + vec_channels * 2)
        if self.dk_dist_proj:
            nn.init.xavier_uniform_(self.dk_dist_proj.weight)
            self.dk_dist_proj.bias.data.fill_(0)
        if self.dv_dist_proj:
            nn.init.xavier_uniform_(self.dv_dist_proj.weight)
            self.dv_dist_proj.bias.data.fill_(0)

    def forward(self, x, vec, edge_index, w_ij, f_dist_ij, f_ij, d_ij, plddt, return_attn=False):
        # we replaced r_ij with w_ij as the pair-wise confidence
        # we added plddt as the position-wise confidence
        x = self.layernorm(x)
        q = self.q_proj(x).reshape(-1, self.num_heads, self.x_head_dim)
        v = self.v_proj(x).reshape(-1, self.num_heads,
                                   self.x_head_dim + self.vec_head_dim * 2)
        if self.share_kv:
            k = v[:, :, :self.x_head_dim]
        else:
            k = self.k_proj(x).reshape(-1, self.num_heads, self.x_head_dim)

        vec1, vec2, vec3 = torch.split(self.vec_proj(vec),
                                       [self.vec_hidden_channels, self.vec_hidden_channels, self.vec_channels], dim=-1)
        vec = vec.reshape(-1, 3, self.num_heads, self.vec_head_dim)
        vec_dot = (vec1 * vec2).sum(dim=1)

        dk = (
            self.act(self.dk_proj(f_ij)).reshape(-1, self.num_heads, self.x_head_dim)
            if self.dk_proj is not None
            else None
        )
        dk_dist = (
            self.act(self.dk_dist_proj(f_dist_ij)).reshape(-1, self.num_heads, self.x_head_dim)
            if self.dk_dist_proj is not None
            else None
        )
        dv = (
            self.act(self.dv_proj(f_ij)).reshape(-1, self.num_heads, self.x_head_dim + self.vec_head_dim * 2)
            if self.dv_proj is not None
            else None
        )
        dv_dist = (
            self.act(self.dv_dist_proj(f_dist_ij)).reshape(-1, self.num_heads, self.x_head_dim + self.vec_head_dim * 2)
            if self.dv_dist_proj is not None
            else None
        )

        # propagate_type: (q: Tensor, k: Tensor, v: Tensor, vec: Tensor, dk: Tensor, dv: Tensor, r_ij: Tensor,
        # d_ij: Tensor)
        x, vec, attn = self.propagate(
            edge_index,
            q=q,
            k=k,
            v=v,
            vec=vec,
            dk=dk,
            dk_dist=dk_dist,
            dv=dv,
            dv_dist=dv_dist,
            d_ij=d_ij,
            w_ij=nn.functional.sigmoid(self.pae_weight(w_ij.unsqueeze(-1)).squeeze(-1)),
            size=None,
        )
        x = x.reshape(-1, self.x_hidden_channels)
        vec = vec.reshape(-1, 3, self.vec_channels)

        o1, o2, o3 = torch.split(
            self.o_proj(x), [self.vec_channels, self.x_channels, self.x_channels], dim=1)
        dx = vec_dot * o2 * plddt.unsqueeze(1) + o3
        dvec = vec3 * o1.unsqueeze(1) * plddt.unsqueeze(1).unsqueeze(2) + vec
        if return_attn:
            return dx, dvec, torch.concat((edge_index.T, attn), dim=1)
        else:
            return dx, dvec, None

    def message(self, q_i, k_j, v_j, vec_j, dk, dk_dist, dv, dv_dist, d_ij, w_ij):
        # attention mechanism
        attn = (q_i * k_j)
        if dk_dist is not None:
            if dk is not None:
                attn *= (dk + dk_dist * w_ij.unsqueeze(1).unsqueeze(2))
            else:
                attn *= dk_dist * w_ij.unsqueeze(1).unsqueeze(2)
        else:
            if dk is not None:
                attn *= dk
        attn = attn.sum(dim=-1)

        # attention activation function
        attn = self.attn_activation(attn)

        # value pathway: add dv, and weight dv_dist by w_ij
        if dv is not None:
            v_j += dv
        if dv_dist is not None:
            v_j += dv_dist * w_ij.unsqueeze(1).unsqueeze(2)
        x, vec1, vec2 = torch.split(
            v_j, [self.x_head_dim, self.vec_head_dim, self.vec_head_dim], dim=2)

        # update scalar features
        x = x * attn.unsqueeze(2)
        # update vector features
        vec = vec_j * vec1.unsqueeze(1) + vec2.unsqueeze(1) * \
            d_ij.unsqueeze(2).unsqueeze(3)
        return x, vec, attn
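

# Hedged worked example (illustrative): pae_weight above is initialised so that the
# gate applied to w_ij in forward() starts as sigmoid(7.5 - 0.5 * w_ij). If w_ij is
# a PAE value in Angstrom (an assumption about the input units), the initial gate is
# roughly 0.99 at PAE = 5, 0.5 at PAE = 15 and 0.01 at PAE = 25, so confident edges
# pass through while uncertain edges are damped; slope and midpoint remain learnable.
def _initial_pae_gate(pae: Tensor) -> Tensor:
    return torch.sigmoid(-0.5 * pae + 7.5)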


class EquivariantPAEMultiHeadAttentionSoftMaxFullGraph(nn.Module):
    """Equivariant multi-head attention layer with softmax, apply attention on full graph by default"""
    def __init__(
            self,
            x_channels,
            x_hidden_channels,
            vec_channels,
            vec_hidden_channels,
            share_kv,
            edge_attr_channels,
            edge_attr_dist_channels,
            distance_influence,
            num_heads,
            activation,
            attn_activation,
            cutoff_lower,
            cutoff_upper,
            use_lora=None,
    ):
        # same as EquivariantPAEMultiHeadAttentionSoftMax, but applies attention to the full graph by default
        super(EquivariantPAEMultiHeadAttentionSoftMaxFullGraph, self).__init__()
        assert x_hidden_channels % num_heads == 0 \
            and vec_channels % num_heads == 0, (
                f"The number of hidden channels x_hidden_channels ({x_hidden_channels}) "
                f"and vec_channels ({vec_channels}) "
                f"must be evenly divisible by the number of "
                f"attention heads ({num_heads})"
            )
        assert vec_hidden_channels == x_channels, (
            f"The number of hidden channels x_channels ({x_channels}) "
            f"and vec_hidden_channels ({vec_hidden_channels}) "
            f"must be equal"
        )

        self.distance_influence = distance_influence
        self.num_heads = num_heads
        self.x_channels = x_channels
        self.x_hidden_channels = x_hidden_channels
        self.x_head_dim = x_hidden_channels // num_heads
        self.vec_channels = vec_channels
        self.vec_hidden_channels = vec_hidden_channels
        # important, not vec_hidden_channels // num_heads
        self.vec_head_dim = vec_channels // num_heads
        self.share_kv = share_kv
        self.layernorm = nn.LayerNorm(x_channels)
        self.act = activation()
        self.cutoff = None
        self.scaling = self.x_head_dim**-0.5
        if use_lora is not None:
            self.q_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora)
            self.k_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora) if not share_kv else None
            self.v_proj = lora.Linear(x_channels, x_hidden_channels + vec_channels * 2, r=use_lora)
            self.o_proj = lora.Linear(x_hidden_channels, x_channels * 2 + vec_channels, r=use_lora)
            self.vec_proj = lora.Linear(vec_channels, vec_hidden_channels * 2 + vec_channels, bias=False, r=use_lora)
        else:
            self.q_proj = nn.Linear(x_channels, x_hidden_channels)
            self.k_proj = nn.Linear(x_channels, x_hidden_channels) if not share_kv else None
            self.v_proj = nn.Linear(x_channels, x_hidden_channels + vec_channels * 2)
            self.o_proj = nn.Linear(x_hidden_channels, x_channels * 2 + vec_channels)
            self.vec_proj = nn.Linear(vec_channels, vec_hidden_channels * 2 + vec_channels, bias=False)

        self.dk_proj = None
        self.dk_dist_proj = None
        self.dv_proj = None
        self.dv_dist_proj = None
        if distance_influence in ["keys", "both"]:
            if use_lora is not None:
                self.dk_proj = lora.Linear(edge_attr_channels, x_hidden_channels, r=use_lora)
                self.dk_dist_proj = lora.Linear(edge_attr_dist_channels, x_hidden_channels, r=use_lora)
            else:
                self.dk_proj = nn.Linear(edge_attr_channels, x_hidden_channels)
                self.dk_dist_proj = nn.Linear(edge_attr_dist_channels, x_hidden_channels)

        if distance_influence in ["values", "both"]:
            if use_lora is not None:
                self.dv_proj = lora.Linear(edge_attr_channels, x_hidden_channels + vec_channels * 2, r=use_lora)
                self.dv_dist_proj = lora.Linear(edge_attr_dist_channels, x_hidden_channels + vec_channels * 2, r=use_lora)
            else:
                self.dv_proj = nn.Linear(edge_attr_channels, x_hidden_channels + vec_channels * 2)
                self.dv_dist_proj = nn.Linear(edge_attr_dist_channels, x_hidden_channels + vec_channels * 2)
        # set the PAE weight as a learnable parameter, basically a sigmoid gate
        self.pae_weight = nn.Linear(1, 1, bias=True)
        self.reset_parameters()

    def reset_parameters(self):
        self.layernorm.reset_parameters()
        nn.init.xavier_uniform_(self.q_proj.weight)
        self.q_proj.bias.data.fill_(0)
        if self.k_proj is not None:
            # k_proj is None when share_kv is enabled
            nn.init.xavier_uniform_(self.k_proj.weight)
            self.k_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.v_proj.weight)
        self.v_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.o_proj.weight)
        self.o_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.vec_proj.weight)
        self.pae_weight.weight.data.fill_(-0.5)
        self.pae_weight.bias.data.fill_(7.5)
        if self.dk_proj:
            nn.init.xavier_uniform_(self.dk_proj.weight)
            self.dk_proj.bias.data.fill_(0)
        if self.dv_proj:
            nn.init.xavier_uniform_(self.dv_proj.weight)
            self.dv_proj.bias.data.fill_(0)
        if self.dk_dist_proj:
            nn.init.xavier_uniform_(self.dk_dist_proj.weight)
            self.dk_dist_proj.bias.data.fill_(0)
        if self.dv_dist_proj:
            nn.init.xavier_uniform_(self.dv_dist_proj.weight)
            self.dv_dist_proj.bias.data.fill_(0)

    def forward(self, x, vec, edge_index, w_ij, f_dist_ij, f_ij, d_ij, plddt, key_padding_mask, return_attn=False):
        # we replaced r_ij with w_ij as the pair-wise confidence
        # we added plddt as the position-wise confidence
        # edge_index is unused
        x = self.layernorm(x)
        q = self.q_proj(x) * self.scaling
        v = self.v_proj(x)
        # if self.share_kv:
        #     k = v[:, :, :self.x_head_dim]
        # else:
        k = self.k_proj(x)

        vec1, vec2, vec3 = torch.split(self.vec_proj(vec),
                                       [self.vec_hidden_channels, self.vec_hidden_channels, self.vec_channels], dim=-1)
        vec_dot = (vec1 * vec2).sum(dim=-2)

        dk = self.act(self.dk_proj(f_ij)) 
        dk_dist = self.act(self.dk_dist_proj(f_dist_ij)) 
        dv = self.act(self.dv_proj(f_ij))  
        dv_dist = self.act(self.dv_dist_proj(f_dist_ij)) 

        # full graph attention
        x, vec, attn = self.attention(
            q=q,
            k=k,
            v=v,
            vec=vec,
            dk=dk,
            dk_dist=dk_dist,
            dv=dv,
            dv_dist=dv_dist,
            d_ij=d_ij,
            w_ij=nn.functional.sigmoid(self.pae_weight(w_ij.unsqueeze(-1)).squeeze(-1)),
            key_padding_mask=key_padding_mask,
        )
        o1, o2, o3 = torch.split(self.o_proj(x), [self.vec_channels, self.x_channels, self.x_channels], dim=-1)
        dx = vec_dot * o2 * plddt.unsqueeze(-1) + o3
        dvec = vec3 * o1.unsqueeze(-2) * plddt.unsqueeze(-1).unsqueeze(-2) + vec
        # apply key_padding_mask to dx
        dx = dx.masked_fill(key_padding_mask.unsqueeze(-1), 0)
        dvec = dvec.masked_fill(key_padding_mask.unsqueeze(-1).unsqueeze(-1), 0)
        if return_attn:
            return dx, dvec, attn
        else:
            return dx, dvec, None

    def attention(self, q, k, v, vec, dk, dk_dist, dv, dv_dist, d_ij, w_ij, key_padding_mask=None, need_head_weights=False):
        # note that q is of shape (bsz, tgt_len, num_heads * head_dim)
        # k, v is of shape (bsz, src_len, num_heads * head_dim)
        # vec is of shape (bsz, src_len, 3, num_heads * head_dim)
        # dk, dk_dist, dv, dv_dist is of shape (bsz, tgt_len, src_len, num_heads * head_dim)
        # d_ij is of shape (bsz, tgt_len, src_len, 3)
        # w_ij is of shape (bsz, tgt_len, src_len)
        # key_padding_mask is of shape (bsz, src_len)
        bsz, tgt_len, _ = q.size()
        src_len = k.size(1)
        # change q size to (bsz * num_heads, tgt_len, head_dim)
        # change k,v size to (bsz * num_heads, src_len, head_dim)
        q = q.transpose(0, 1).reshape(tgt_len, bsz * self.num_heads, self.x_head_dim).transpose(0, 1).contiguous()
        k = k.transpose(0, 1).reshape(src_len, bsz * self.num_heads, self.x_head_dim).transpose(0, 1).contiguous()
        v = v.transpose(0, 1).reshape(src_len, bsz * self.num_heads, self.x_head_dim + 2 * self.vec_head_dim).transpose(0, 1).contiguous()
        # change vec to (bsz * num_heads, src_len, 3, head_dim)
        vec = vec.permute(1, 2, 0, 3).reshape(src_len, 3, bsz * self.num_heads, self.vec_head_dim).permute(2, 0, 1, 3).contiguous()
        # dk size is (bsz, tgt_len, src_len, num_heads * head_dim)
        # if dk is not None:
        # change dk to (bsz * num_heads, tgt_len, src_len, head_dim)
        dk = dk.permute(1, 2, 0, 3).reshape(tgt_len, src_len, bsz * self.num_heads, self.x_head_dim).permute(2, 0, 1, 3).contiguous()
        # if dk_dist is not None:
        # change dk_dist to (bsz * num_heads, tgt_len, src_len, head_dim)
        dk_dist = dk_dist.permute(1, 2, 0, 3).reshape(tgt_len, src_len, bsz * self.num_heads, self.x_head_dim).permute(2, 0, 1, 3).contiguous()
        # dv size is (bsz, tgt_len, src_len, num_heads * head_dim)
        # if dv is not None:
        # change dv to (bsz * num_heads, tgt_len, src_len, head_dim)
        dv = dv.permute(1, 2, 0, 3).reshape(tgt_len, src_len, bsz * self.num_heads, self.x_head_dim + 2 * self.vec_head_dim).permute(2, 0, 1, 3).contiguous()
        # if dv_dist is not None:
        # change dv_dist to (bsz * num_heads, tgt_len, src_len, head_dim)
        dv_dist = dv_dist.permute(1, 2, 0, 3).reshape(tgt_len, src_len, bsz * self.num_heads, self.x_head_dim + 2 * self.vec_head_dim).permute(2, 0, 1, 3).contiguous()
        # if key_padding_mask is not None:
        # key_padding_mask should be (bsz, src_len)
        assert key_padding_mask.size(0) == bsz
        assert key_padding_mask.size(1) == src_len
        # attn_weights size is (bsz * num_heads, tgt_len, src_len, head_dim)
        attn_weights = torch.multiply(q[:, :, None, :], k[:, None, :, :])
        # w_ij is PAE confidence
        # w_ij size is (bsz, tgt_len, src_len)
        # change dimension of w_ij to (bsz * num_heads, tgt_len, src_len, head_dim)
        # if dk_dist is not None:
        assert w_ij is not None
        #     if dk is not None:
        # expand w_ij to the (bsz * num_heads, tgt_len, src_len, 1) layout used for q/k/dk
        # (repeat_interleave keeps the batch-major ordering of the reshapes above)
        w_ij = w_ij.repeat_interleave(self.num_heads, dim=0).unsqueeze(-1)
        attn_weights *= (dk + dk_dist * w_ij)
        # add dv and dv_dist, weighting the PAE-derived part by w_ij
        v = v.unsqueeze(1) + dv + dv_dist * w_ij
        #     else:
        #         attn_weights *= dk_dist * w_ij
        # else:
        #     if dk is not None:
        #         attn_weights *= dk
        # attn_weights size is (bsz * num_heads, tgt_len, src_len)
        attn_weights = attn_weights.sum(dim=-1)
        # apply key_padding_mask to attn_weights
        # if key_padding_mask is not None:
        # don't attend to padding symbols
        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len).contiguous()
        attn_weights = attn_weights.masked_fill(
            key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf")
        )
        attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len).contiguous()
        # apply softmax to attn_weights
        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)
        # x, vec1, vec2 are of shape (bsz * num_heads, src_len, head_dim)
        x, vec1, vec2 = torch.split(v, [self.x_head_dim, self.vec_head_dim, self.vec_head_dim], dim=-1)
        # first get invariant feature outputs x, size is (bsz * num_heads, tgt_len, head_dim)
        x_out = torch.einsum('bts,btsh->bth', attn_weights, x)
        # next get equivariant feature outputs vec_out_1, size is (bsz * num_heads, tgt_len, 3, head_dim)
        vec_out_1 = torch.einsum('bsih,btsh->btih', vec, vec1)
        # next get equivariant feature outputs vec_out_2, size is (bsz * num_heads, tgt_len, 3, head_dim)
        # d_ij arrives as (bsz, tgt_len, src_len, 3); repeat it per head to match the
        # batch-major (bsz * num_heads, ...) layout used above
        d_ij = d_ij.repeat_interleave(self.num_heads, dim=0)
        vec_out_2 = torch.einsum('btsi,btsh->btih', d_ij, vec2)
        # adds up vec_out_1 and vec_out_2, get vec_out, size is (bsz * num_heads, tgt_len, 3, head_dim)
        vec_out = vec_out_1 + vec_out_2
        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
        # if not need_head_weights:
        # average attention weights over heads
        attn_weights = attn_weights.mean(dim=0)
        # reshape x_out to (bsz, tgt_len, num_heads * head_dim)
        x_out = x_out.transpose(1, 0).reshape(tgt_len, bsz, self.num_heads * self.x_head_dim).transpose(1, 0).contiguous()
        # reshape vec_out to (bsz, tgt_len, 3, num_heads * head_dim)
        vec_out = vec_out.permute(1, 2, 0, 3).reshape(tgt_len, 3, bsz, self.num_heads * self.vec_head_dim).permute(2, 0, 1, 3).contiguous()
        return x_out, vec_out, attn_weights
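

# Hedged usage sketch for the dense full-graph variant (illustrative; all sizes are
# assumptions, "silu" is assumed to be a key of act_class_mapping, and edge_index is
# ignored by this layer).
def _example_full_graph_attention():
    bsz, n = 2, 16
    layer = EquivariantPAEMultiHeadAttentionSoftMaxFullGraph(
        x_channels=128, x_hidden_channels=128, vec_channels=64,
        vec_hidden_channels=128, share_kv=False, edge_attr_channels=32,
        edge_attr_dist_channels=32, distance_influence="both", num_heads=8,
        activation=nn.SiLU, attn_activation="silu",
        cutoff_lower=0.0, cutoff_upper=5.0,
    )
    x = torch.randn(bsz, n, 128)                          # invariant features
    vec = torch.randn(bsz, n, 3, 64)                      # equivariant features
    w_ij = torch.rand(bsz, n, n) * 30.0                   # pair-wise PAE
    f_ij = torch.randn(bsz, n, n, 32)                     # pair features
    f_dist_ij = torch.randn(bsz, n, n, 32)                # PAE-derived pair features
    d_ij = torch.randn(bsz, n, n, 3)                      # pair-wise direction vectors
    plddt = torch.rand(bsz, n)                            # per-residue confidence
    key_padding_mask = torch.zeros(bsz, n, dtype=torch.bool)
    dx, dvec, _ = layer(x, vec, None, w_ij, f_dist_ij, f_ij, d_ij, plddt, key_padding_mask)
    return dx.shape, dvec.shape                           # (2, 16, 128) and (2, 16, 3, 64)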


class MultiHeadAttentionSoftMaxFullGraph(nn.Module):
    """
    Multi-head attention layer with softmax, apply attention on full graph by default
    No equivariant property, but can take structure information as input, just didn't use it
    """
    def __init__(
            self,
            x_channels,
            x_hidden_channels,
            vec_channels,
            vec_hidden_channels,
            share_kv,
            edge_attr_channels,
            edge_attr_dist_channels,
            distance_influence,
            num_heads,
            activation,
            attn_activation,
            cutoff_lower,
            cutoff_upper,
            use_lora=None,
    ):
        # same as EquivariantPAEMultiHeadAttentionSoftMaxFullGraph, but without the equivariant vector pathway and the PAE-weighted terms
        super(MultiHeadAttentionSoftMaxFullGraph, self).__init__()
        assert x_hidden_channels % num_heads == 0 \
            and vec_channels % num_heads == 0, (
                f"The number of hidden channels x_hidden_channels ({x_hidden_channels}) "
                f"and vec_channels ({vec_channels}) "
                f"must be evenly divisible by the number of "
                f"attention heads ({num_heads})"
            )
        assert vec_hidden_channels == x_channels, (
            f"The number of hidden channels x_channels ({x_channels}) "
            f"and vec_hidden_channels ({vec_hidden_channels}) "
            f"must be equal"
        )

        self.distance_influence = distance_influence
        self.num_heads = num_heads
        self.x_channels = x_channels
        self.x_hidden_channels = x_hidden_channels
        self.x_head_dim = x_hidden_channels // num_heads
        self.vec_channels = vec_channels
        self.vec_hidden_channels = vec_hidden_channels
        # important, not vec_hidden_channels // num_heads
        self.vec_head_dim = vec_channels // num_heads
        self.share_kv = share_kv
        self.layernorm = nn.LayerNorm(x_channels)
        self.act = activation()
        self.cutoff = None
        self.scaling = self.x_head_dim**-0.5
        if use_lora is not None:
            self.q_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora)
            self.k_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora) if not share_kv else None
            self.v_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora)
            self.o_proj = lora.Linear(x_hidden_channels, x_channels, r=use_lora)
            # self.vec_proj = lora.Linear(vec_channels, vec_hidden_channels * 2 + vec_channels, bias=False, r=use_lora)
        else:
            self.q_proj = nn.Linear(x_channels, x_hidden_channels)
            self.k_proj = nn.Linear(x_channels, x_hidden_channels) if not share_kv else None
            self.v_proj = nn.Linear(x_channels, x_hidden_channels)
            self.o_proj = nn.Linear(x_hidden_channels, x_channels)
            # self.vec_proj = nn.Linear(vec_channels, vec_hidden_channels * 2 + vec_channels, bias=False)

        self.dk_proj = None
        self.dk_dist_proj = None
        self.dv_proj = None
        self.dv_dist_proj = None
        if distance_influence in ["keys", "both"]:
            if use_lora is not None:
                self.dk_proj = lora.Linear(edge_attr_channels, x_hidden_channels, r=use_lora)
                # self.dk_dist_proj = lora.Linear(edge_attr_dist_channels, x_hidden_channels, r=use_lora)
            else:
                self.dk_proj = nn.Linear(edge_attr_channels, x_hidden_channels)
                # self.dk_dist_proj = nn.Linear(edge_attr_dist_channels, x_hidden_channels)

        if distance_influence in ["values", "both"]:
            if use_lora is not None:
                self.dv_proj = lora.Linear(edge_attr_channels, x_hidden_channels, r=use_lora)
                # self.dv_dist_proj = lora.Linear(edge_attr_dist_channels, x_hidden_channels + vec_channels * 2, r=use_lora)
            else:
                self.dv_proj = nn.Linear(edge_attr_channels, x_hidden_channels)
                # self.dv_dist_proj = nn.Linear(edge_attr_dist_channels, x_hidden_channels + vec_channels * 2)
        # the PAE weight (a learnable sigmoid gate) is not used in this variant
        # self.pae_weight = nn.Linear(1, 1, bias=True)
        self.reset_parameters()

    def reset_parameters(self):
        self.layernorm.reset_parameters()
        nn.init.xavier_uniform_(self.q_proj.weight)
        self.q_proj.bias.data.fill_(0)
        if self.k_proj is not None:
            # k_proj is None when share_kv is enabled
            nn.init.xavier_uniform_(self.k_proj.weight)
            self.k_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.v_proj.weight)
        self.v_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.o_proj.weight)
        self.o_proj.bias.data.fill_(0)
        # nn.init.xavier_uniform_(self.vec_proj.weight)
        # self.pae_weight.weight.data.fill_(-0.5)
        # self.pae_weight.bias.data.fill_(7.5)
        if self.dk_proj:
            nn.init.xavier_uniform_(self.dk_proj.weight)
            self.dk_proj.bias.data.fill_(0)
        if self.dv_proj:
            nn.init.xavier_uniform_(self.dv_proj.weight)
            self.dv_proj.bias.data.fill_(0)

    def forward(self, x, vec, edge_index, w_ij, f_dist_ij, f_ij, d_ij, plddt, key_padding_mask, return_attn=False):
        # w_ij, f_dist_ij, d_ij and plddt are accepted for interface compatibility but unused here
        # edge_index is unused
        x = self.layernorm(x)
        q = self.q_proj(x) * self.scaling
        v = self.v_proj(x)
        # if self.share_kv:
        #     k = v[:, :, :self.x_head_dim]
        # else:
        k = self.k_proj(x)

        # vec1, vec2, vec3 = torch.split(self.vec_proj(vec),
        #                                [self.vec_hidden_channels, self.vec_hidden_channels, self.vec_channels], dim=-1)
        # vec_dot = (vec1 * vec2).sum(dim=-2)

        dk = self.act(self.dk_proj(f_ij)) 
        # dk_dist = self.act(self.dk_dist_proj(f_dist_ij)) 
        dv = self.act(self.dv_proj(f_ij))  
        # dv_dist = self.act(self.dv_dist_proj(f_dist_ij)) 

        # full graph attention
        x, vec, attn = self.attention(
            q=q,
            k=k,
            v=v,
            vec=vec,
            dk=dk,
            # dk_dist=dk_dist,
            dv=dv,
            # dv_dist=dv_dist,
            # d_ij=d_ij,
            # w_ij=nn.functional.sigmoid(self.pae_weight(w_ij.unsqueeze(-1)).squeeze(-1)),
            key_padding_mask=key_padding_mask,
        )
        # o1, o2, o3 = torch.split(self.o_proj(x), [self.vec_channels, self.x_channels, self.x_channels], dim=-1)
        # dx = vec_dot * o2 * plddt.unsqueeze(-1) + o3
        dx = self.o_proj(x)
        # dvec = vec3 * o1.unsqueeze(-2) * plddt.unsqueeze(-1).unsqueeze(-2) + vec
        # apply key_padding_mask to dx
        dx = dx.masked_fill(key_padding_mask.unsqueeze(-1), 0)
        # dvec = dvec.masked_fill(key_padding_mask.unsqueeze(-1).unsqueeze(-1), 0)
        if return_attn:
            return dx, vec, attn
        else:
            return dx, vec, None

    def attention(self, q, k, v, vec, dk, dv, key_padding_mask=None, need_head_weights=False):
        # note that q is of shape (bsz, tgt_len, num_heads * head_dim)
        # k, v are of shape (bsz, src_len, num_heads * head_dim)
        # vec is of shape (bsz, src_len, 3, num_heads * head_dim) and is returned unchanged
        # dk, dv are of shape (bsz, tgt_len, src_len, num_heads * head_dim)
        # key_padding_mask is of shape (bsz, src_len)
        bsz, tgt_len, _ = q.size()
        src_len = k.size(1)
        # change q size to (bsz * num_heads, tgt_len, head_dim)
        # change k,v size to (bsz * num_heads, src_len, head_dim)
        q = q.transpose(0, 1).reshape(tgt_len, bsz * self.num_heads, self.x_head_dim).transpose(0, 1).contiguous()
        k = k.transpose(0, 1).reshape(src_len, bsz * self.num_heads, self.x_head_dim).transpose(0, 1).contiguous()
        v = v.transpose(0, 1).reshape(src_len, bsz * self.num_heads, self.x_head_dim).transpose(0, 1).contiguous()
        # change vec to (bsz * num_heads, src_len, 3, head_dim)
        # vec = vec.permute(1, 2, 0, 3).reshape(src_len, 3, bsz * self.num_heads, self.vec_head_dim).permute(2, 0, 1, 3).contiguous()
        # dk size is (bsz, tgt_len, src_len, num_heads * head_dim)
        # if dk is not None:
        # change dk to (bsz * num_heads, tgt_len, src_len, head_dim)
        dk = dk.permute(1, 2, 0, 3).reshape(tgt_len, src_len, bsz * self.num_heads, self.x_head_dim).permute(2, 0, 1, 3).contiguous()
        # if dk_dist is not None:
        # change dk_dist to (bsz * num_heads, tgt_len, src_len, head_dim)
        # dk_dist = dk_dist.permute(1, 2, 0, 3).reshape(tgt_len, src_len, bsz * self.num_heads, self.x_head_dim).permute(2, 0, 1, 3).contiguous()
        # dv size is (bsz, tgt_len, src_len, num_heads * head_dim)
        # if dv is not None:
        # change dv to (bsz * num_heads, tgt_len, src_len, head_dim)
        dv = dv.permute(1, 2, 0, 3).reshape(tgt_len, src_len, bsz * self.num_heads, self.x_head_dim).permute(2, 0, 1, 3).contiguous()
        # if dv_dist is not None:
        # change dv_dist to (bsz * num_heads, tgt_len, src_len, head_dim)
        # dv_dist = dv_dist.permute(1, 2, 0, 3).reshape(tgt_len, src_len, bsz * self.num_heads, self.x_head_dim + 2 * self.vec_head_dim).permute(2, 0, 1, 3).contiguous()
        # if key_padding_mask is not None:
        # key_padding_mask should be (bsz, src_len)
        assert key_padding_mask.size(0) == bsz
        assert key_padding_mask.size(1) == src_len
        # attn_weights size is (bsz * num_heads, tgt_len, src_len, head_dim)
        attn_weights = torch.multiply(q[:, :, None, :], k[:, None, :, :])
        # w_ij is PAE confidence
        # w_ij size is (bsz, tgt_len, src_len)
        # change dimension of w_ij to (bsz * num_heads, tgt_len, src_len, head_dim)
        # if dk_dist is not None:
        # assert w_ij is not None
        #     if dk is not None:
        attn_weights *= dk
        # add dv and dv_dist
        v = v.unsqueeze(1) + dv
        #     else:
        #         attn_weights *= dk_dist * w_ij
        # else:
        #     if dk is not None:
        #         attn_weights *= dk
        # attn_weights size is (bsz * num_heads, tgt_len, src_len)
        attn_weights = attn_weights.sum(dim=-1)
        # apply key_padding_mask to attn_weights
        # if key_padding_mask is not None:
        # don't attend to padding symbols
        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len).contiguous()
        attn_weights = attn_weights.masked_fill(
            key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf")
        )
        attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len).contiguous()
        # apply softmax to attn_weights
        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)
        # x, vec1, vec2 are of shape (bsz * num_heads, src_len, head_dim)
        # x, vec1, vec2 = torch.split(v, [self.x_head_dim, self.vec_head_dim, self.vec_head_dim], dim=-1)
        # first get invariant feature outputs x, size is (bsz * num_heads, tgt_len, head_dim)
        x_out = torch.einsum('bts,btsh->bth', attn_weights, v)
        # next get equivariant feature outputs vec_out_1, size is (bsz * num_heads, tgt_len, 3, head_dim)
        # vec_out_1 = torch.einsum('bsih,btsh->btih', vec, vec1)
        # next get equivariant feature outputs vec_out_2, size is (bsz * num_heads, tgt_len, src_len, 3, head_dim)
        # vec_out_2 = torch.einsum('btsi,btsh->btih', d_ij, vec2)
        # adds up vec_out_1 and vec_out_2, get vec_out, size is (bsz * num_heads, tgt_len, 3, head_dim)
        # vec_out = vec_out_1 + vec_out_2
        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
        # if not need_head_weights:
        # average attention weights over heads
        attn_weights = attn_weights.mean(dim=0)
        # reshape x_out to (bsz, tgt_len, num_heads * head_dim)
        x_out = x_out.transpose(1, 0).reshape(tgt_len, bsz, self.num_heads * self.x_head_dim).transpose(1, 0).contiguous()
        # reshape vec_out to (bsz, tgt_len, 3, num_heads * head_dim)
        # vec_out = vec_out.permute(1, 2, 0, 3).reshape(tgt_len, 3, bsz, self.num_heads * self.vec_head_dim).permute(2, 0, 1, 3).contiguous()
        return x_out, vec, attn_weights
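

# Hedged note (illustrative): MultiHeadAttentionSoftMaxFullGraph keeps the same
# forward signature as the equivariant variant above for drop-in use, but ignores
# w_ij / f_dist_ij / d_ij / plddt, modulates the logits only with the f_ij-derived
# dk term, and returns vec unchanged, making it a sequence-only ablation of the
# structure-aware attention.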


class PAEMultiHeadAttentionSoftMaxStarGraph(nn.Module):
    """Equivariant multi-head attention layer with softmax, apply attention on full graph by default"""
    def __init__(
            self,
            x_channels,
            x_hidden_channels,
            vec_channels,
            vec_hidden_channels,
            share_kv,
            edge_attr_channels,
            edge_attr_dist_channels,
            distance_influence,
            num_heads,
            activation,
            cutoff_lower,
            cutoff_upper,
            use_lora=None,
    ):
        # similar to EquivariantPAEMultiHeadAttentionSoftMaxFullGraph, but non-equivariant and the attention query is a single center node (star graph)
        super(PAEMultiHeadAttentionSoftMaxStarGraph, self).__init__()
        assert x_hidden_channels % num_heads == 0 \
            and vec_channels % num_heads == 0, (
                f"The number of hidden channels x_hidden_channels ({x_hidden_channels}) "
                f"and vec_channels ({vec_channels}) "
                f"must be evenly divisible by the number of "
                f"attention heads ({num_heads})"
            )
        assert vec_hidden_channels == x_channels, (
            f"The number of hidden channels x_channels ({x_channels}) "
            f"and vec_hidden_channels ({vec_hidden_channels}) "
            f"must be equal"
        )

        self.distance_influence = distance_influence
        self.num_heads = num_heads
        self.x_channels = x_channels
        self.x_hidden_channels = x_hidden_channels
        self.x_head_dim = x_hidden_channels // num_heads
        self.vec_channels = vec_channels
        self.vec_hidden_channels = vec_hidden_channels
        # important, not vec_hidden_channels // num_heads
        self.vec_head_dim = vec_channels // num_heads
        self.share_kv = share_kv
        self.layernorm = nn.LayerNorm(x_channels)
        self.act = activation()
        self.cutoff = None
        self.scaling = self.x_head_dim**-0.5
        if use_lora is not None:
            self.q_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora)
            self.k_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora) if not share_kv else None
            self.v_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora)  # values are scalar-only here, matching the non-LoRA branch
        else:
            self.q_proj = nn.Linear(x_channels, x_hidden_channels)
            self.k_proj = nn.Linear(x_channels, x_hidden_channels) if not share_kv else None
            self.v_proj = nn.Linear(x_channels, x_hidden_channels)

        self.dk_proj = None
        self.dk_dist_proj = None
        self.dv_proj = None
        self.dv_dist_proj = None
        if distance_influence in ["keys", "both"]:
            if use_lora is not None:
                self.dk_proj = lora.Linear(edge_attr_channels, x_hidden_channels, r=use_lora)
                self.dk_dist_proj = lora.Linear(edge_attr_dist_channels, x_hidden_channels, r=use_lora)
            else:
                self.dk_proj = nn.Linear(edge_attr_channels, x_hidden_channels)
                self.dk_dist_proj = nn.Linear(edge_attr_dist_channels, x_hidden_channels)

        if distance_influence in ["values", "both"]:
            if use_lora is not None:
                # star-graph values carry no vector channel, so these match v_proj (x_hidden_channels)
                self.dv_proj = lora.Linear(edge_attr_channels, x_hidden_channels, r=use_lora)
                self.dv_dist_proj = lora.Linear(edge_attr_dist_channels, x_hidden_channels, r=use_lora)
            else:
                self.dv_proj = nn.Linear(edge_attr_channels, x_hidden_channels)
                self.dv_dist_proj = nn.Linear(edge_attr_dist_channels, x_hidden_channels)
        # set the PAE weight as a learnable parameter, basically a sigmoid gate
        self.pae_weight = nn.Linear(1, 1, bias=True)
        self.reset_parameters()

    def reset_parameters(self):
        self.layernorm.reset_parameters()
        nn.init.xavier_uniform_(self.q_proj.weight)
        self.q_proj.bias.data.fill_(0)
        if self.k_proj is not None:
            # k_proj is None when share_kv is enabled
            nn.init.xavier_uniform_(self.k_proj.weight)
            self.k_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.v_proj.weight)
        self.v_proj.bias.data.fill_(0)
        self.pae_weight.weight.data.fill_(-0.5)
        self.pae_weight.bias.data.fill_(7.5)
        if self.dk_proj:
            nn.init.xavier_uniform_(self.dk_proj.weight)
            self.dk_proj.bias.data.fill_(0)
        if self.dv_proj:
            nn.init.xavier_uniform_(self.dv_proj.weight)
            self.dv_proj.bias.data.fill_(0)
        if self.dk_dist_proj:
            nn.init.xavier_uniform_(self.dk_dist_proj.weight)
            self.dk_dist_proj.bias.data.fill_(0)
        if self.dv_dist_proj:
            nn.init.xavier_uniform_(self.dv_dist_proj.weight)
            self.dv_dist_proj.bias.data.fill_(0)

    def forward(self, x, x_center_index, w_ij, f_dist_ij, f_ij, key_padding_mask, return_attn=False):
        # w_ij carries the pair-wise confidence
        # only the center node (selected by x_center_index) forms the attention query
        x = self.layernorm(x)
        q = self.q_proj(x[x_center_index].unsqueeze(1)) * self.scaling
        v = self.v_proj(x)
        # if self.share_kv:
        #     k = v[:, :, :self.x_head_dim]
        # else:
        k = self.k_proj(x)

        dk = self.act(self.dk_proj(f_ij)) 
        dk_dist = self.act(self.dk_dist_proj(f_dist_ij)) 
        dv = self.act(self.dv_proj(f_ij)) 
        dv_dist = self.act(self.dv_dist_proj(f_dist_ij)) 

        # full graph attention
        x, attn = self.attention(
            q=q,
            k=k,
            v=v,
            dk=dk,
            dk_dist=dk_dist,
            dv=dv,
            dv_dist=dv_dist,
            w_ij=nn.functional.sigmoid(self.pae_weight(w_ij.unsqueeze(-1)).squeeze(-1)),
            key_padding_mask=key_padding_mask,
        )
        if return_attn:
            return x, attn
        else:
            return x, None

    def attention(self, q, k, v, dk, dk_dist, dv, dv_dist, w_ij, key_padding_mask=None, need_head_weights=False):
        # note that q is of shape (bsz, tgt_len, num_heads * head_dim); tgt_len == 1 for the star graph
        # k, v are of shape (bsz, src_len, num_heads * head_dim)
        # dk, dk_dist, dv, dv_dist are of shape (bsz, tgt_len, src_len, num_heads * head_dim)
        # w_ij is of shape (bsz, tgt_len, src_len)
        # key_padding_mask is of shape (bsz, src_len)
        bsz, tgt_len, _ = q.size()
        src_len = k.size(1)
        # change q size to (bsz * num_heads, tgt_len, head_dim)
        # change k,v size to (bsz * num_heads, src_len, head_dim)
        q = q.transpose(0, 1).reshape(tgt_len, bsz * self.num_heads, self.x_head_dim).transpose(0, 1).contiguous()
        k = k.transpose(0, 1).reshape(src_len, bsz * self.num_heads, self.x_head_dim).transpose(0, 1).contiguous()
        v = v.transpose(0, 1).reshape(src_len, bsz * self.num_heads, self.x_head_dim).transpose(0, 1).contiguous()
        # dk size is (bsz, tgt_len, src_len, num_heads * head_dim)
        # if dk is not None:
        # change dk to (bsz * num_heads, tgt_len, src_len, head_dim)
        dk = dk.permute(1, 2, 0, 3).reshape(tgt_len, src_len, bsz * self.num_heads, self.x_head_dim).permute(2, 0, 1, 3).contiguous()
        # if dk_dist is not None:
        # change dk_dist to (bsz * num_heads, tgt_len, src_len, head_dim)
        dk_dist = dk_dist.permute(1, 2, 0, 3).reshape(tgt_len, src_len, bsz * self.num_heads, self.x_head_dim).permute(2, 0, 1, 3).contiguous()
        # dv size is (bsz, tgt_len, src_len, num_heads * head_dim)
        # if dv is not None:
        # change dv to (bsz * num_heads, tgt_len, src_len, head_dim)
        dv = dv.permute(1, 2, 0, 3).reshape(tgt_len, src_len, bsz * self.num_heads, self.x_head_dim).permute(2, 0, 1, 3).contiguous()
        # if dv_dist is not None:
        # change dv_dist to (bsz * num_heads, tgt_len, src_len, head_dim)
        dv_dist = dv_dist.permute(1, 2, 0, 3).reshape(tgt_len, src_len, bsz * self.num_heads, self.x_head_dim).permute(2, 0, 1, 3).contiguous()
        # if key_padding_mask is not None:
        # key_padding_mask should be (bsz, src_len)
        assert key_padding_mask.size(0) == bsz
        assert key_padding_mask.size(1) == src_len
        # attn_weights size is (bsz * num_heads, tgt_len, src_len, head_dim)
        attn_weights = torch.multiply(q[:, :, None, :], k[:, None, :, :])
        # w_ij is PAE confidence
        # w_ij size is (bsz, tgt_len, src_len)
        # change dimension of w_ij to (bsz * num_heads, tgt_len, src_len, head_dim)
        # if dk_dist is not None:
        assert w_ij is not None
        #     if dk is not None:
        # expand w_ij to the (bsz * num_heads, tgt_len, src_len, 1) layout used for q/k/dk
        # (repeat_interleave keeps the batch-major ordering of the reshapes above)
        w_ij = w_ij.repeat_interleave(self.num_heads, dim=0).unsqueeze(-1)
        attn_weights *= (dk + dk_dist * w_ij)
        # add dv and dv_dist, weighting the PAE-derived part by w_ij
        v = v.unsqueeze(1) + dv + dv_dist * w_ij
        #     else:
        #         attn_weights *= dk_dist * w_ij
        # else:
        #     if dk is not None:
        #         attn_weights *= dk
        # attn_weights size is (bsz * num_heads, tgt_len, src_len)
        attn_weights = attn_weights.sum(dim=-1)
        # apply key_padding_mask to attn_weights
        # if key_padding_mask is not None:
        # don't attend to padding symbols
        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len).contiguous()
        attn_weights = attn_weights.masked_fill(
            key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf")
        )
        attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len).contiguous()
        # apply softmax to attn_weights
        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)
        # first get invariant feature outputs x, size is (bsz * num_heads, tgt_len, head_dim)
        x_out = torch.einsum('bts,btsh->bth', attn_weights, v)
        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
        # if not need_head_weights:
        # average attention weights over heads
        attn_weights = attn_weights.mean(dim=0)
        # reshape x_out to (bsz, tgt_len, num_heads * head_dim)
        x_out = x_out.transpose(1, 0).reshape(tgt_len, bsz, self.num_heads * self.x_head_dim).transpose(1, 0).contiguous()
        return x_out, attn_weights


class MultiHeadAttentionSoftMaxStarGraph(nn.Module):
    """Equivariant multi-head attention layer with softmax, apply attention on full graph by default"""
    def __init__(
            self,
            x_channels,
            x_hidden_channels,
            vec_channels,
            vec_hidden_channels,
            share_kv,
            edge_attr_channels,
            edge_attr_dist_channels,
            distance_influence,
            num_heads,
            activation,
            cutoff_lower,
            cutoff_upper,
            use_lora=None,
    ):
        # same as EquivariantPAEMultiHeadAttentionSoftMax, but apply attention on full graph by default
        super(MultiHeadAttentionSoftMaxStarGraph, self).__init__()
        assert x_hidden_channels % num_heads == 0 \
            and vec_channels % num_heads == 0, (
                f"The number of hidden channels x_hidden_channels ({x_hidden_channels}) "
                f"and vec_channels ({vec_channels}) "
                f"must be evenly divisible by the number of "
                f"attention heads ({num_heads})"
            )
        assert vec_hidden_channels == x_channels, (
            f"The number of hidden channels x_channels ({x_channels}) "
            f"and vec_hidden_channels ({vec_hidden_channels}) "
            f"must be equal"
        )

        self.distance_influence = distance_influence
        self.num_heads = num_heads
        self.x_channels = x_channels
        self.x_hidden_channels = x_hidden_channels
        self.x_head_dim = x_hidden_channels // num_heads
        self.vec_channels = vec_channels
        self.vec_hidden_channels = vec_hidden_channels
        # important, not vec_hidden_channels // num_heads
        self.vec_head_dim = vec_channels // num_heads
        self.share_kv = share_kv
        self.layernorm = nn.LayerNorm(x_channels)
        self.act = activation()
        self.cutoff = None
        self.scaling = self.x_head_dim**-0.5
        if use_lora is not None:
            self.q_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora)
            self.k_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora) if not share_kv else None
            self.v_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora)
        else:
            self.q_proj = nn.Linear(x_channels, x_hidden_channels)
            self.k_proj = nn.Linear(x_channels, x_hidden_channels) if not share_kv else None
            self.v_proj = nn.Linear(x_channels, x_hidden_channels)

        self.dk_proj = None
        # self.dk_dist_proj = None
        self.dv_proj = None
        # self.dv_dist_proj = None
        if distance_influence in ["keys", "both"]:
            if use_lora is not None:
                self.dk_proj = lora.Linear(edge_attr_channels, x_hidden_channels, r=use_lora)
                # self.dk_dist_proj = lora.Linear(edge_attr_dist_channels, x_hidden_channels, r=use_lora)
            else:
                self.dk_proj = nn.Linear(edge_attr_channels, x_hidden_channels)
                # self.dk_dist_proj = nn.Linear(edge_attr_dist_channels, x_hidden_channels)

        if distance_influence in ["values", "both"]:
            if use_lora is not None:
                self.dv_proj = lora.Linear(edge_attr_channels, x_hidden_channels, r=use_lora)
                # self.dv_dist_proj = lora.Linear(edge_attr_dist_channels, x_hidden_channels + vec_channels * 2, r=use_lora)
            else:
                self.dv_proj = nn.Linear(edge_attr_channels, x_hidden_channels)
                # self.dv_dist_proj = nn.Linear(edge_attr_dist_channels, x_hidden_channels + vec_channels * 2)
        # set PAE weight as a learnable parameter, basically a sigmoid gate (currently disabled)
        # self.pae_weight = nn.Linear(1, 1, bias=True)
        self.reset_parameters()

    def reset_parameters(self):
        self.layernorm.reset_parameters()
        nn.init.xavier_uniform_(self.q_proj.weight)
        self.q_proj.bias.data.fill_(0)
        if self.k_proj is not None:
            nn.init.xavier_uniform_(self.k_proj.weight)
            self.k_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.v_proj.weight)
        self.v_proj.bias.data.fill_(0)
        # self.pae_weight.weight.data.fill_(-0.5)
        # self.pae_weight.bias.data.fill_(7.5)
        if self.dk_proj:
            nn.init.xavier_uniform_(self.dk_proj.weight)
            self.dk_proj.bias.data.fill_(0)
        if self.dv_proj:
            nn.init.xavier_uniform_(self.dv_proj.weight)
            self.dv_proj.bias.data.fill_(0)
        # if self.dk_dist_proj:
        #     nn.init.xavier_uniform_(self.dk_dist_proj.weight)
        #     self.dk_dist_proj.bias.data.fill_(0)
        # if self.dv_dist_proj:
        #     nn.init.xavier_uniform_(self.dv_dist_proj.weight)
        #     self.dv_dist_proj.bias.data.fill_(0)

    def forward(self, x, x_center_index, w_ij, f_dist_ij, f_ij, key_padding_mask, return_attn=False):
        # we replaced r_ij with w_ij as the pair-wise confidence
        # and add pLDDT as the position-wise confidence
        # edge_index is unused
        x = self.layernorm(x)
        q = self.q_proj(x[x_center_index].unsqueeze(1)) * self.scaling
        v = self.v_proj(x)
        if self.share_kv:
            # k and v share the same projection when share_kv is set (k_proj is None)
            k = v
        else:
            k = self.k_proj(x)

        dk = self.act(self.dk_proj(f_ij)) 
        # dk_dist = self.act(self.dk_dist_proj(f_dist_ij)) 
        dv = self.act(self.dv_proj(f_ij)) 
        # dv_dist = self.act(self.dv_dist_proj(f_dist_ij)) 

        # full graph attention
        x, attn = self.attention(
            q=q,
            k=k,
            v=v,
            dk=dk,
            # dk_dist=dk_dist,
            dv=dv,
            # dv_dist=dv_dist,
            # w_ij=nn.functional.sigmoid(self.pae_weight(w_ij.unsqueeze(-1)).squeeze(-1)),
            key_padding_mask=key_padding_mask,
        )
        if return_attn:
            return x, attn
        else:
            return x, None

    def attention(self, q, k, v, dk, dv, key_padding_mask=None, need_head_weights=False):
        # note that q is of shape (bsz, tgt_len, num_heads * head_dim)
        # k, v is of shape (bsz, src_len, num_heads * head_dim)
        # vec is of shape (bsz, src_len, 3, num_heads * head_dim)
        # dk, dk_dist, dv, dv_dist is of shape (bsz, tgt_len, src_len, num_heads * head_dim)
        # d_ij is of shape (bsz, tgt_len, src_len, 3)
        # w_ij is of shape (bsz, tgt_len, src_len)
        # key_padding_mask is of shape (bsz, src_len)
        bsz, tgt_len, _ = q.size()
        src_len = k.size(1)
        # change q size to (bsz * num_heads, tgt_len, head_dim)
        # change k,v size to (bsz * num_heads, src_len, head_dim)
        q = q.transpose(0, 1).reshape(tgt_len, bsz * self.num_heads, self.x_head_dim).transpose(0, 1).contiguous()
        k = k.transpose(0, 1).reshape(src_len, bsz * self.num_heads, self.x_head_dim).transpose(0, 1).contiguous()
        v = v.transpose(0, 1).reshape(src_len, bsz * self.num_heads, self.x_head_dim).transpose(0, 1).contiguous()
        # dk size is (bsz, tgt_len, src_len, num_heads * head_dim)
        # if dk is not None:
        # change dk to (bsz * num_heads, tgt_len, src_len, head_dim)
        dk = dk.permute(1, 2, 0, 3).reshape(tgt_len, src_len, bsz * self.num_heads, self.x_head_dim).permute(2, 0, 1, 3).contiguous()
        # if dk_dist is not None:
        # change dk_dist to (bsz * num_heads, tgt_len, src_len, head_dim)
        # dk_dist = dk_dist.permute(1, 2, 0, 3).reshape(tgt_len, src_len, bsz * self.num_heads, self.x_head_dim).permute(2, 0, 1, 3).contiguous()
        # dv size is (bsz, tgt_len, src_len, num_heads * head_dim)
        # if dv is not None:
        # change dv to (bsz * num_heads, tgt_len, src_len, head_dim)
        dv = dv.permute(1, 2, 0, 3).reshape(tgt_len, src_len, bsz * self.num_heads, self.x_head_dim).permute(2, 0, 1, 3).contiguous()
        # if dv_dist is not None:
        # change dv_dist to (bsz * num_heads, tgt_len, src_len, head_dim)
        # dv_dist = dv_dist.permute(1, 2, 0, 3).reshape(tgt_len, src_len, bsz * self.num_heads, self.x_head_dim + 2 * self.vec_head_dim).permute(2, 0, 1, 3).contiguous()
        # if key_padding_mask is not None:
        # key_padding_mask should be (bsz, src_len)
        assert key_padding_mask.size(0) == bsz
        assert key_padding_mask.size(1) == src_len
        # attn_weights size is (bsz * num_heads, tgt_len, src_len, head_dim)
        attn_weights = torch.multiply(q[:, :, None, :], k[:, None, :, :])
        # w_ij is PAE confidence
        # w_ij size is (bsz, tgt_len, src_len)
        # change dimension of w_ij to (bsz * num_heads, tgt_len, src_len, head_dim)
        # if dk_dist is not None:
        # assert w_ij is not None
        #     if dk is not None:
        attn_weights *= dk
        # add dv and dv_dist
        v = v.unsqueeze(1) + dv
        #     else:
        #         attn_weights *= dk_dist * w_ij
        # else:
        #     if dk is not None:
        #         attn_weights *= dk
        # attn_weights size is (bsz * num_heads, tgt_len, src_len)
        attn_weights = attn_weights.sum(dim=-1)
        # apply key_padding_mask to attn_weights
        # if key_padding_mask is not None:
        # don't attend to padding symbols
        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len).contiguous()
        attn_weights = attn_weights.masked_fill(
            key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf")
        )
        attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len).contiguous()
        # apply softmax to attn_weights
        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)
        # first get invariant feature outputs x, size is (bsz * num_heads, tgt_len, head_dim)
        x_out = torch.einsum('bts,btsh->bth', attn_weights, v)
        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
        # if not need_head_weights:
        # average attention weights over heads
        attn_weights = attn_weights.mean(dim=0)
        # reshape x_out to (bsz, tgt_len, num_heads * head_dim)
        x_out = x_out.transpose(1, 0).reshape(tgt_len, bsz, self.num_heads * self.x_head_dim).transpose(1, 0).contiguous()
        return x_out, attn_weights
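# Hedged usage note for MultiHeadAttentionSoftMaxStarGraph (shapes inferred from the
# comments in attention(), not verified against callers): x appears to be
# (bsz, src_len, x_channels) with x_center_index selecting one center position per
# sample, so tgt_len == 1 after unsqueeze(1); f_ij is
# (bsz, tgt_len, src_len, edge_attr_channels) and key_padding_mask is a
# (bsz, src_len) bool mask over padded source positions. w_ij and f_dist_ij are
# currently unused in forward() because the *_dist projections are commented out.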


# original torchmd-net attention layer, let k, v share the same projection
class EquivariantProMultiHeadAttention(MessagePassing):
    """Equivariant multi-head attention layer."""

    def __init__(
            self,
            x_channels,
            x_hidden_channels,
            vec_channels,
            vec_hidden_channels,
            edge_attr_channels,
            distance_influence,
            num_heads,
            activation,
            attn_activation,
            cutoff_lower,
            cutoff_upper,
    ):
        super(EquivariantProMultiHeadAttention, self).__init__(
            aggr="mean", node_dim=0)
        assert x_hidden_channels % num_heads == 0 \
            and vec_channels % num_heads == 0, (
                f"The number of hidden channels x_hidden_channels ({x_hidden_channels}) "
                f"and vec_channels ({vec_channels}) "
                f"must be evenly divisible by the number of "
                f"attention heads ({num_heads})"
            )
        assert vec_hidden_channels == x_channels, (
            f"The number of hidden channels x_channels ({x_channels}) "
            f"and vec_hidden_channels ({vec_hidden_channels}) "
            f"must be equal"
        )

        self.distance_influence = distance_influence
        self.num_heads = num_heads
        self.x_channels = x_channels
        self.x_hidden_channels = x_hidden_channels
        self.x_head_dim = x_hidden_channels // num_heads
        self.vec_channels = vec_channels
        self.vec_hidden_channels = vec_hidden_channels
        # important, not vec_hidden_channels // num_heads
        self.vec_head_dim = vec_channels // num_heads

        self.layernorm = nn.LayerNorm(x_channels)
        self.act = activation()
        self.attn_activation = act_class_mapping[attn_activation]()
        self.cutoff = CosineCutoff(cutoff_lower, cutoff_upper)

        self.q_proj = nn.Linear(x_channels, x_hidden_channels)
        # self.k_proj = nn.Linear(x_channels, x_hidden_channels)
        self.kv_proj = nn.Linear(
            x_channels, x_hidden_channels + vec_channels * 2)
        self.o_proj = nn.Linear(
            x_hidden_channels, x_channels * 2 + vec_channels)

        self.vec_proj = nn.Linear(
            vec_channels, vec_hidden_channels * 2 + vec_channels, bias=False)

        self.dk_proj = None
        if distance_influence in ["keys", "both"]:
            self.dk_proj = nn.Linear(edge_attr_channels, x_hidden_channels)

        self.dv_proj = None
        if distance_influence in ["values", "both"]:
            self.dv_proj = nn.Linear(
                edge_attr_channels, x_hidden_channels + vec_channels * 2)

        self.reset_parameters()

    def reset_parameters(self):
        self.layernorm.reset_parameters()
        nn.init.xavier_uniform_(self.q_proj.weight)
        self.q_proj.bias.data.fill_(0)
        # nn.init.xavier_uniform_(self.k_proj.weight)
        # self.k_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.kv_proj.weight)
        self.kv_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.o_proj.weight)
        self.o_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.vec_proj.weight)
        if self.dk_proj:
            nn.init.xavier_uniform_(self.dk_proj.weight)
            self.dk_proj.bias.data.fill_(0)
        if self.dv_proj:
            nn.init.xavier_uniform_(self.dv_proj.weight)
            self.dv_proj.bias.data.fill_(0)

    def forward(self, x, vec, edge_index, r_ij, f_ij, d_ij, return_attn=False):
        x = self.layernorm(x)
        q = self.q_proj(x).reshape(-1, self.num_heads, self.x_head_dim)
        # k = self.k_proj(x).reshape(-1, self.num_heads, self.x_head_dim)
        v = self.kv_proj(x).reshape(-1, self.num_heads,
                                   self.x_head_dim + self.vec_head_dim * 2)
        k = v[:, :, :self.x_head_dim]

        vec1, vec2, vec3 = torch.split(self.vec_proj(vec),
                                       [self.vec_hidden_channels, self.vec_hidden_channels, self.vec_channels], dim=-1)
        vec = vec.reshape(-1, 3, self.num_heads, self.vec_head_dim)
        vec_dot = (vec1 * vec2).sum(dim=1)

        dk = (
            self.act(self.dk_proj(f_ij)).reshape(-1, self.num_heads, self.x_head_dim)
            if self.dk_proj is not None
            else None
        )
        dv = (
            self.act(self.dv_proj(f_ij)).reshape(-1, self.num_heads, self.x_head_dim + self.vec_head_dim * 2)
            if self.dv_proj is not None
            else None
        )

        # propagate_type: (q: Tensor, k: Tensor, v: Tensor, vec: Tensor, dk: Tensor, dv: Tensor, r_ij: Tensor,
        # d_ij: Tensor)
        x, vec, attn = self.propagate(
            edge_index,
            q=q,
            k=k,
            v=v,
            vec=vec,
            dk=dk,
            dv=dv,
            r_ij=r_ij,
            d_ij=d_ij,
            size=None,
        )
        x = x.reshape(-1, self.x_hidden_channels)
        vec = vec.reshape(-1, 3, self.vec_channels)

        o1, o2, o3 = torch.split(self.o_proj(
            x), [self.vec_channels, self.x_channels, self.x_channels], dim=1)
        dx = vec_dot * o2 + o3
        dvec = vec3 * o1.unsqueeze(1) + vec
        if return_attn:
            return dx, dvec, torch.concat((edge_index.T, attn), dim=1)
        else:
            return dx, dvec, None

    def message(self, q_i, k_j, v_j, vec_j, dk, dv, r_ij, d_ij):
        # attention mechanism
        if dk is None:
            attn = (q_i * k_j).sum(dim=-1)
        else:  # TODO: consider add or multiply dk
            attn = (q_i * k_j * dk).sum(dim=-1)

        # attention activation function
        attn = self.attn_activation(attn) * self.cutoff(r_ij).unsqueeze(1)

        # value pathway
        if dv is not None:
            v_j = v_j * dv
        x, vec1, vec2 = torch.split(
            v_j, [self.x_head_dim, self.vec_head_dim, self.vec_head_dim], dim=2)

        # update scalar features
        x = x * attn.unsqueeze(2)
        # update vector features
        vec = vec_j * vec1.unsqueeze(1) + vec2.unsqueeze(1) * \
            d_ij.unsqueeze(2).unsqueeze(3)
        return x, vec, attn

    def aggregate(
            self,
            features: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
            index: torch.Tensor,
            ptr: Optional[torch.Tensor],
            dim_size: Optional[int],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        x, vec, attn = features
        x = scatter(x, index, dim=self.node_dim, dim_size=dim_size)
        vec = scatter(vec, index, dim=self.node_dim, dim_size=dim_size)
        return x, vec, attn

    def update(
            self, inputs: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return inputs

    def message_and_aggregate(self, adj_t: SparseTensor) -> Tensor:
        pass

    def edge_update(self) -> Tensor:
        pass


# softmax version of torchmd-net attention layer
class EquivariantMultiHeadAttentionSoftMax(EquivariantMultiHeadAttention):
    """Equivariant multi-head attention layer with softmax"""

    def __init__(
            self,
            x_channels,
            x_hidden_channels,
            vec_channels,
            vec_hidden_channels,
            share_kv,
            edge_attr_channels,
            distance_influence,
            num_heads,
            activation,
            attn_activation,
            cutoff_lower,
            cutoff_upper,
            use_lora=None,
    ):
        super(EquivariantMultiHeadAttentionSoftMax, self).__init__(x_channels=x_channels,
                                                                   x_hidden_channels=x_hidden_channels,
                                                                   vec_channels=vec_channels,
                                                                   vec_hidden_channels=vec_hidden_channels,
                                                                   share_kv=share_kv,
                                                                   edge_attr_channels=edge_attr_channels,
                                                                   distance_influence=distance_influence,
                                                                   num_heads=num_heads,
                                                                   activation=activation,
                                                                   attn_activation=attn_activation,
                                                                   cutoff_lower=cutoff_lower,
                                                                   cutoff_upper=cutoff_upper,
                                                                   use_lora=use_lora)
        self.attn_activation = nn.LeakyReLU(0.2)

    def message(self, q_i, k_j, v_j, vec_j, dk, dv, r_ij, d_ij,
                index: Tensor,
                ptr: Optional[Tensor],
                size_i: Optional[int]):
        # attention mechanism
        if dk is None:
            attn = (q_i * k_j).sum(dim=-1)
        else:  # TODO: consider add or multiply dk
            attn = (q_i * k_j * dk).sum(dim=-1)

        # attention activation function
        attn = self.attn_activation(attn) * self.cutoff(r_ij).unsqueeze(1)
        attn = softmax(attn, index, ptr, size_i)
        # TODO: consider drop out attn or not.
        # attn = F.dropout(attn, p=self.dropout, training=self.training)
        # value pathway
        if dv is not None:
            v_j = v_j * dv
        x, vec1, vec2 = torch.split(
            v_j, [self.x_head_dim, self.vec_head_dim, self.vec_head_dim], dim=2)

        # update scalar features
        x = x * attn.unsqueeze(2)
        # update vector features
        vec = (vec1.unsqueeze(1) * vec_j + vec2.unsqueeze(1) * d_ij.unsqueeze(2).unsqueeze(3)) \
            * attn.unsqueeze(1).unsqueeze(3)
        return x, vec, attn


# softmax version of torchmd-net attention layer, add pair-wise confidence of PAE
class EquivariantPAEMultiHeadAttentionSoftMax(EquivariantPAEMultiHeadAttention):
    """Equivariant multi-head attention layer with softmax"""

    def __init__(
            self,
            x_channels,
            x_hidden_channels,
            vec_channels,
            vec_hidden_channels,
            share_kv,
            edge_attr_channels,
            edge_attr_dist_channels,
            distance_influence,
            num_heads,
            activation,
            attn_activation,
            cutoff_lower,
            cutoff_upper,
            use_lora=None,
    ):
        super(EquivariantPAEMultiHeadAttentionSoftMax, self).__init__(
            x_channels=x_channels,
            x_hidden_channels=x_hidden_channels,
            vec_channels=vec_channels,
            vec_hidden_channels=vec_hidden_channels,
            share_kv=share_kv,
            edge_attr_channels=edge_attr_channels,
            edge_attr_dist_channels=edge_attr_dist_channels,
            distance_influence=distance_influence,
            num_heads=num_heads,
            activation=activation,
            attn_activation=attn_activation,
            cutoff_lower=cutoff_lower,
            cutoff_upper=cutoff_upper,
            use_lora=use_lora)
        self.attn_activation = nn.LeakyReLU(0.2)

    def message(self, q_i, k_j, v_j, vec_j, dk, dk_dist, dv, dv_dist, d_ij, w_ij,
                index: Tensor,
                ptr: Optional[Tensor],
                size_i: Optional[int]):
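        # w_ij is the pair-wise PAE confidence; it gates the distance-derived
        # dk_dist / dv_dist contributions to the keys and values below.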
        # attention mechanism
        attn = (q_i * k_j)
        if dk is not None:
            attn += dk
        if dk_dist is not None:
            attn += dk_dist * w_ij.unsqueeze(1).unsqueeze(2)
        attn = attn.sum(dim=-1)
        # attention activation function
        attn = self.attn_activation(attn)
        attn = softmax(attn, index, ptr, size_i)
        # TODO: consider drop out attn or not.
        # attn = F.dropout(attn, p=self.dropout, training=self.training)
        # value pathway
        if dv is not None:
            v_j += dv
        if dv_dist is not None:
            v_j += dv_dist * w_ij.unsqueeze(1).unsqueeze(2)
        x, vec1, vec2 = torch.split(
            v_j, [self.x_head_dim, self.vec_head_dim, self.vec_head_dim], dim=2)

        # update scalar features
        x = x * attn.unsqueeze(2)
        # update vector features
        vec = (vec1.unsqueeze(1) * vec_j + vec2.unsqueeze(1) * d_ij.unsqueeze(2).unsqueeze(3)) \
            * attn.unsqueeze(1).unsqueeze(3)
        return x, vec, attn

# softmax version of torchmd-net attention layer, add pair-wise confidence of PAE
class EquivariantWeightedPAEMultiHeadAttentionSoftMax(EquivariantWeightedPAEMultiHeadAttention):
    """Equivariant multi-head attention layer with softmax"""

    def __init__(
            self,
            x_channels,
            x_hidden_channels,
            vec_channels,
            vec_hidden_channels,
            share_kv,
            edge_attr_channels,
            edge_attr_dist_channels,
            distance_influence,
            num_heads,
            activation,
            attn_activation,
            cutoff_lower,
            cutoff_upper,
            use_lora=None,
    ):
        super(EquivariantWeightedPAEMultiHeadAttentionSoftMax, self).__init__(
            x_channels=x_channels,
            x_hidden_channels=x_hidden_channels,
            vec_channels=vec_channels,
            vec_hidden_channels=vec_hidden_channels,
            share_kv=share_kv,
            edge_attr_channels=edge_attr_channels,
            edge_attr_dist_channels=edge_attr_dist_channels,
            distance_influence=distance_influence,
            num_heads=num_heads,
            activation=activation,
            attn_activation=attn_activation,
            cutoff_lower=cutoff_lower,
            cutoff_upper=cutoff_upper,
            use_lora=use_lora)
        self.attn_activation = nn.LeakyReLU(0.2)

    def message(self, q_i, k_j, v_j, vec_j, dk, dk_dist, dv, dv_dist, d_ij, w_ij,
                index: Tensor,
                ptr: Optional[Tensor],
                size_i: Optional[int]):
        # attention mechanism
        attn = (q_i * k_j)
        if dk_dist is not None:
            if dk is not None:
                attn *= (dk + dk_dist * w_ij.unsqueeze(1).unsqueeze(2))
            else:
                attn *= dk_dist * w_ij
        else:
            if dk is not None:
                attn *= dk
        attn = attn.sum(dim=-1)
        # attention activation function
        attn = self.attn_activation(attn)
        attn = softmax(attn, index, ptr, size_i)
        # TODO: consider drop out attn or not.
        # attn = F.dropout(attn, p=self.dropout, training=self.training)
        # value pathway
        if dv is not None:
            v_j += dv
        if dv_dist is not None:
            v_j += dv_dist * w_ij.unsqueeze(1).unsqueeze(2)
        x, vec1, vec2 = torch.split(
            v_j, [self.x_head_dim, self.vec_head_dim, self.vec_head_dim], dim=2)

        # update scalar features
        x = x * attn.unsqueeze(2)
        # update vector features
        vec = (vec1.unsqueeze(1) * vec_j + vec2.unsqueeze(1) * d_ij.unsqueeze(2).unsqueeze(3)) \
            * attn.unsqueeze(1).unsqueeze(3)
        return x, vec, attn


# MSA encoder adapted from gMVP
class MSAEncoder(nn.Module):
    def __init__(self, num_species, pairwise_type, weighting_schema):
        """[summary]

        Args:
            num_species (int): Number of species to use from MSA. [1,200] // 200 used in default gMVP
            pairwise_type ([str]): method for calculating pairwise coevolution. only "cov" supported
            weighting_schema ([str]): species weighting type; "spe" -> use dense layer to weight speices 
                                        "none" -> uniformly weight species 

        Raises:
            NotImplementedError: [description]
        """        
        super(MSAEncoder, self).__init__()
        self.num_species = num_species
        self.pairwise_type = pairwise_type
        self.weighting_schema = weighting_schema
        if self.weighting_schema == 'spe':
            self.W = nn.parameter.Parameter(
                torch.zeros((1, num_species)), 
                requires_grad=True)

        elif self.weighting_schema == 'none':
            # uniform species weights; kept as shape (1, num_species) to match the 'spe' case
            self.W = torch.full((1, self.num_species), 1.0 / self.num_species)
        else:
            raise NotImplementedError
       
    def forward(self, x, edge_index):
        # x: L (nodes) x N (num_species)
        shape = x.shape
        L, N = shape[0], shape[1]
        E = edge_index.shape[1]
        
        A = 21 # number of amino acids 
        x = x[:, :self.num_species]
        if self.weighting_schema == 'spe':
            sm = torch.nn.Softmax(dim=-1)
            W = sm(self.W)
        else:
            W = self.W
        x = nn.functional.one_hot(x.type(torch.int64), A).type(torch.float32) # L x N x A
        x1 = torch.matmul(W[:, None], x) # L x 1 x A

        if self.pairwise_type  == 'fre':
            x2 = torch.matmul(x[edge_index[0], :, :, None], x[edge_index[1], :, None, :]) # E x N x A x A
            x2 = x2.reshape((E, N, A * A)) # E x N x (A x A)
            x2 = (W[:, :, None] * x2).sum(dim=1) # E x (A x A)
        elif self.pairwise_type == 'cov':
            #numerical stability
            x2 = torch.matmul(x[edge_index[0], :, :, None], x[edge_index[1], :, None, :]) # E x N x A x A
            x2 = (W[:, :, None, None] * x2).sum(dim=1) # E x A x A
            x2_t = x1[edge_index[0], 0, :, None] * x1[edge_index[1], 0, None, :] # E x A x A
            x2 = (x2 - x2_t).reshape(E, A * A) # E x (A x A)
            norm = torch.sqrt(torch.sum(torch.square(x2), dim=-1, keepdim=True) + 1e-12)
            x2 = torch.cat([x2, norm], dim=-1) # E x (A x A + 1)
        elif self.pairwise_type == 'cov_all':
            raise NotImplementedError('cov_all not implemented in MSAEncoder')
        elif self.pairwise_type == 'inv_cov':
            raise NotImplementedError('inv_cov not implemented in MSAEncoder')
        elif self.pairwise_type == 'none':
            x2 = None
        else:
            raise NotImplementedError(
                f'pairwise_type {self.pairwise_type} not implemented')

        x1 = torch.squeeze(x1, dim=1) # L x A

        return x1, x2
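

# Hedged usage sketch (illustrative only; the helper name, sizes, and edge set below are
# arbitrary and not part of the original API): encode a toy MSA with MSAEncoder.
def _example_msa_encoder():
    torch.manual_seed(0)
    n_nodes, n_species = 6, 8
    msa = torch.randint(0, 21, (n_nodes, n_species))       # L x N amino-acid indices
    edge_index = torch.tensor([[0, 1, 2], [3, 4, 5]])      # 2 x E
    encoder = MSAEncoder(num_species=n_species, pairwise_type='cov', weighting_schema='spe')
    x1, x2 = encoder(msa, edge_index)
    return x1.shape, x2.shape                              # (L, 21), (E, 21 * 21 + 1)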


# MSA encoder adapted from gMVP
class MSAEncoderFullGraph(nn.Module):
    def __init__(self, num_species, pairwise_type, weighting_schema):
        """[summary]

        Args:
            num_species (int): Number of species to use from MSA. [1,200] // 200 used in default gMVP
            pairwise_type ([str]): method for calculating pairwise coevolution. only "cov" supported
            weighting_schema ([str]): species weighting type; "spe" -> use dense layer to weight speices 
                                        "none" -> uniformly weight species 

        Raises:
            NotImplementedError: [description]
        """        
        super(MSAEncoderFullGraph, self).__init__()
        self.num_species = num_species
        self.pairwise_type = pairwise_type
        self.weighting_schema = weighting_schema
        if self.weighting_schema == 'spe':
            self.W = nn.parameter.Parameter(
                torch.zeros((num_species)), 
                requires_grad=True)

        elif self.weighting_schema == 'none':
            self.W = torch.tensor(1.0 / self.num_species).repeat(self.num_species)
        else:
            raise NotImplementedError
       
    def forward(self, x):
        # x: B (batch) x L (length) x N (num_species)
        shape = x.shape
        B, L, N = shape[0], shape[1], shape[2]
        A = 21 # number of amino acids 
        x = x[:, :, :self.num_species]
        if self.weighting_schema == 'spe':
            W = torch.nn.functional.softmax(self.W, dim=-1)
        else:
            W = self.W
        x = nn.functional.one_hot(x.type(torch.int64), A).type(torch.float32) # B x L x N x A
        x1 = torch.einsum('blna,n->bla', x, W) # B x L x A

        if self.pairwise_type == 'cov':
            #numerical stability
            # x2 = torch.einsum('bLnA,blna,n->bLlAa', x, x, W) # B x L x L x A x A, check if ram supports this
            # x2_t = x1[:, :, None, :, None] * x1[:, None, :, None, :] # B x L x L x A x A
            # x2 = (x2 - x2_t).reshape(B, L, L, A * A) # B x L x L x (A x A)
            # complete that in one line to save memory
            x2 = (torch.einsum('bLnA,blna,n->bLlAa', x, x, W) - x1[:, :, None, :, None] * x1[:, None, :, None, :]).reshape(B, L, L, A * A)
            norm = torch.sqrt(torch.sum(torch.square(x2), dim=-1, keepdim=True) + 1e-12) # B x L x L x 1
            x2 = torch.cat([x2, norm], dim=-1) # B x L x L x (A x A + 1)
        elif self.pairwise_type == 'cov_all':
            raise NotImplementedError('cov_all not implemented in MSAEncoderFullGraph')
        elif self.pairwise_type == 'inv_cov':
            raise NotImplementedError('inv_cov not implemented in MSAEncoderFullGraph')
        elif self.pairwise_type == 'none':
            x2 = None
        else:
            raise NotImplementedError(
                f'pairwise_type {self.pairwise_type} not implemented')
        return x1, x2
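

# Hedged usage sketch (illustrative only; the helper name and sizes are arbitrary):
# encode a toy batched MSA with MSAEncoderFullGraph.
def _example_msa_encoder_full_graph():
    torch.manual_seed(0)
    batch, seq_len, n_species = 2, 5, 8
    msa = torch.randint(0, 21, (batch, seq_len, n_species))    # B x L x N amino-acid indices
    encoder = MSAEncoderFullGraph(num_species=n_species, pairwise_type='cov', weighting_schema='none')
    x1, x2 = encoder(msa)
    return x1.shape, x2.shape                                  # (B, L, 21), (B, L, L, 21 * 21 + 1)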


class NodeToEdgeAttr(nn.Module):
    def __init__(self, node_channel, hidden_channel, edge_attr_channel, use_lora=None, layer_norm=False):
        super().__init__()
        self.layer_norm = layer_norm
        if layer_norm:
            self.layernorm = nn.LayerNorm(node_channel)
        if use_lora is not None:
            self.proj = lora.Linear(node_channel, hidden_channel * 2, bias=True, r=use_lora)
            self.o_proj = lora.Linear(2 * hidden_channel, edge_attr_channel, r=use_lora)
        else:
            self.proj = nn.Linear(node_channel, hidden_channel * 2, bias=True)
            self.o_proj = nn.Linear(2 * hidden_channel, edge_attr_channel, bias=True)

        torch.nn.init.zeros_(self.proj.bias)
        torch.nn.init.zeros_(self.o_proj.bias)

    def forward(self, x, edge_index):
        """
        Inputs:
          x: N x sequence_state_dim

        Output:
          edge_attr: edge_index.shape[0] x pairwise_state_dim

        Intermediate state:
          B x L x L x 2*inner_dim
        """
        x = self.layernorm(x) if self.layer_norm else x
        q, k = self.proj(x).chunk(2, dim=-1)

        prod = q[edge_index[0], :] * k[edge_index[1], :]
        diff = q[edge_index[0], :] - k[edge_index[1], :]

        edge_attr = torch.cat([prod, diff], dim=-1)
        edge_attr = self.o_proj(edge_attr)

        return edge_attr
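

# Hedged usage sketch (illustrative only; the helper name and channel sizes are arbitrary):
# derive per-edge attributes from node features for a small graph.
def _example_node_to_edge_attr():
    torch.manual_seed(0)
    node_feat = torch.randn(4, 32)                         # N x node_channel
    edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])      # 2 x E
    mod = NodeToEdgeAttr(node_channel=32, hidden_channel=64, edge_attr_channel=16)
    edge_attr = mod(node_feat, edge_index)
    return edge_attr.shape                                 # (E, 16)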


class MultiplicativeUpdate(MessagePassing):
    def __init__(self, vec_in_channel, hidden_channel, hidden_vec_channel, ee_channels=None, use_lora=None, layer_norm=True) -> None:
        super(MultiplicativeUpdate, self).__init__(aggr="mean")
        self.vec_in_channel = vec_in_channel
        self.hidden_channel = hidden_channel
        self.hidden_vec_channel = hidden_vec_channel

        if use_lora is not None:
            self.linear_a_p = lora.Linear(self.vec_in_channel, self.hidden_vec_channel, bias=False, r=use_lora)
            self.linear_b_p = lora.Linear(self.vec_in_channel, self.hidden_vec_channel, bias=False, r=use_lora)
            self.linear_g = lora.Linear(self.hidden_vec_channel, self.hidden_channel, r=use_lora)
        else:
            self.linear_a_p = nn.Linear(self.vec_in_channel, self.hidden_vec_channel, bias=False)
            self.linear_b_p = nn.Linear(self.vec_in_channel, self.hidden_vec_channel, bias=False)
            self.linear_g = nn.Linear(self.hidden_vec_channel, self.hidden_channel)
        if ee_channels is not None:
            if use_lora is not None:
                self.linear_ee = lora.Linear(ee_channels, self.hidden_channel, r=use_lora)
            else:
                self.linear_ee = nn.Linear(ee_channels, self.hidden_channel)
        else:
            self.linear_ee = None
        self.layer_norm = layer_norm
        if layer_norm:
            self.layer_norm_in = nn.LayerNorm(self.hidden_channel)
            self.layer_norm_out = nn.LayerNorm(self.hidden_channel)

        self.sigmoid = nn.Sigmoid()

    def forward(self,
                edge_attr: torch.Tensor,
                edge_vec: torch.Tensor,
                edge_edge_index: torch.Tensor,
                edge_edge_attr: torch.Tensor,
                ) -> torch.Tensor:
        """
        Args:
            edge_vec:
                [*, 3, in_channel] input tensor
            edge_attr:
                [*, hidden_channel] input mask
        Returns:
            [*, hidden_channel] output tensor
        """
        if self.layer_norm:
            x = self.layer_norm_in(edge_attr)
        else:
            x = edge_attr
        x = self.propagate(edge_index=edge_edge_index, 
                           a=self.linear_a_p(edge_vec).reshape(edge_attr.shape[0], -1),
                           b=self.linear_b_p(edge_vec).reshape(edge_attr.shape[0], -1),
                           edge_attr=x,
                           ee_ij=edge_edge_attr, )
        if self.layer_norm:
            x = self.layer_norm_out(x)
        edge_attr = edge_attr + x
        return edge_attr

    def message(self, a_i: Tensor, b_j: Tensor, edge_attr_j: Tensor, ee_ij: Tensor,) -> Tensor:
        # a: [*, 3, hidden_channel]
        # b: [*, 3, hidden_channel]
        s = (a_i.reshape(-1, 3, self.hidden_vec_channel).permute(0, 2, 1) \
             * b_j.reshape(-1, 3, self.hidden_vec_channel).permute(0, 2, 1)).sum(dim=-1)
        if ee_ij is not None and self.linear_ee is not None:
            s = self.sigmoid(self.linear_ee(ee_ij) + self.linear_g(s))
        else:
            s = self.sigmoid(self.linear_g(s))
        return s * edge_attr_j
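

# Note: in this file MultiplicativeUpdate is driven with edges treated as nodes of a
# line graph: edge_attr / edge_vec are per-edge features, and edge_edge_index (built by
# get_start_index / get_end_index below) connects edges that share a start or end node.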


# let k v share the same weight
class EquivariantTriAngularMultiHeadAttention(MessagePassing):
    """Equivariant multi-head attention layer. Add Triangular update between edges."""

    def __init__(
            self,
            x_channels,
            x_hidden_channels,
            vec_channels,
            vec_hidden_channels,
            edge_attr_channels,
            distance_influence,
            num_heads,
            activation,
            attn_activation,
            cutoff_lower,
            cutoff_upper,
            triangular_update=False,
            ee_channels=None,
    ):
        super(EquivariantTriAngularMultiHeadAttention, self).__init__(aggr="mean", node_dim=0)

        self.distance_influence = distance_influence
        self.num_heads = num_heads
        self.x_channels = x_channels
        self.x_hidden_channels = x_hidden_channels
        self.x_head_dim = x_hidden_channels // num_heads
        self.vec_channels = vec_channels
        self.vec_hidden_channels = vec_hidden_channels
        self.ee_channels = ee_channels
        # important, not vec_hidden_channels // num_heads

        self.layernorm_in = nn.LayerNorm(x_channels)
        self.layernorm_out = nn.LayerNorm(x_hidden_channels)

        self.act = activation()
        self.attn_activation = act_class_mapping[attn_activation]()

        self.q_proj = nn.Linear(x_channels, x_hidden_channels)
        self.kv_proj = nn.Linear(x_channels, x_hidden_channels)
        # self.v_proj = nn.Linear(x_channels, x_hidden_channels)
        self.o_proj = nn.Linear(x_hidden_channels, x_hidden_channels)
        self.out = nn.Linear(x_hidden_channels, x_channels)
        # add residue to x
        # self.residue_hidden = nn.Linear(x_channels, x_hidden_channels)

        self.dk_proj = nn.Linear(edge_attr_channels, x_hidden_channels)
        self.triangular_update = triangular_update
        if self.triangular_update:
            self.edge_triangle_start_update = MultiplicativeUpdate(vec_in_channel=vec_channels,
                                                                hidden_channel=edge_attr_channels,
                                                                hidden_vec_channel=vec_hidden_channels,
                                                                ee_channels=ee_channels, )
            self.edge_triangle_end_update = MultiplicativeUpdate(vec_in_channel=vec_channels,
                                                                hidden_channel=edge_attr_channels,
                                                                hidden_vec_channel=vec_hidden_channels,
                                                                ee_channels=ee_channels, )
            self.node_to_edge_attr = NodeToEdgeAttr(node_channel=x_channels,
                                                    hidden_channel=x_hidden_channels,
                                                    edge_attr_channel=edge_attr_channels)

        self.reset_parameters()

    def reset_parameters(self):
        self.layernorm_in.reset_parameters()
        self.layernorm_out.reset_parameters()
        nn.init.xavier_uniform_(self.q_proj.weight)
        self.q_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.kv_proj.weight)
        self.kv_proj.bias.data.fill_(0)
        # nn.init.xavier_uniform_(self.v_proj.weight)
        # self.v_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.o_proj.weight)
        self.o_proj.bias.data.fill_(0)
        if self.dk_proj:
            nn.init.xavier_uniform_(self.dk_proj.weight)
            self.dk_proj.bias.data.fill_(0)

    def get_start_index(self, edge_index):
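        # Build an index over pairs of distinct edges that share the same start node,
        # in both orderings. e.g. for edge_index = [[0, 0, 1], [1, 2, 2]], edges 0 and 1
        # both start at node 0, so this returns [[0, 1], [1, 0]].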
        edge_start_index = []
        start_node_count = edge_index[0].unique(return_counts=True)
        start_nodes = start_node_count[0][start_node_count[1] > 1]
        for i in start_nodes:
            node_start_index = torch.where(edge_index[0] == i)[0]
            candidates = torch.combinations(node_start_index, r=2).T
            edge_start_index.append(torch.cat([candidates, candidates.flip(0)], dim=1))
        edge_start_index = torch.concat(edge_start_index, dim=1)
        edge_start_index = edge_start_index[:, edge_start_index[0] != edge_start_index[1]]
        return edge_start_index

    def get_end_index(self, edge_index):
        edge_end_index = []
        end_node_count = edge_index[1].unique(return_counts=True)
        end_nodes = end_node_count[0][end_node_count[1] > 1]
        for i in end_nodes:
            node_end_index = torch.where(edge_index[1] == i)[0]
            candidates = torch.combinations(node_end_index, r=2).T
            edge_end_index.append(torch.cat([candidates, candidates.flip(0)], dim=1))
        edge_end_index = torch.concat(edge_end_index, dim=1)
        edge_end_index = edge_end_index[:, edge_end_index[0] != edge_end_index[1]]
        return edge_end_index

    def forward(self, x, coords, edge_index, edge_attr, edge_vec, return_attn=False):
        residue = x
        x = self.layernorm_in(x)
        q = self.q_proj(x).reshape(-1, self.num_heads, self.x_head_dim)
        k = self.kv_proj(x).reshape(-1, self.num_heads, self.x_head_dim)
        v = k

        # project node (point) attributes onto edge_attr
        if self.triangular_update:
            edge_attr += self.node_to_edge_attr(x, edge_index)

            # Triangular edge update
            # TODO: Add drop out layers here
            edge_edge_index = self.get_start_index(edge_index)
            if self.ee_channels is not None:
                edge_edge_attr = coords[edge_index[1][edge_edge_index[0]], :, [0]] - coords[edge_index[1][edge_edge_index[1]], :, [0]]
                edge_edge_attr = torch.norm(edge_edge_attr, dim=-1, keepdim=True)
            else:
                edge_edge_attr = None
            edge_attr = self.edge_triangle_start_update(
                edge_attr, edge_vec, 
                edge_edge_index,
                edge_edge_attr
            )
            edge_edge_index = self.get_end_index(edge_index)
            if self.ee_channels is not None:
                edge_edge_attr = coords[edge_index[0][edge_edge_index[0]], :, [0]] - coords[edge_index[0][edge_edge_index[1]], :, [0]]
                edge_edge_attr = torch.norm(edge_edge_attr, dim=-1, keepdim=True)
            else:
                edge_edge_attr = None
            edge_attr = self.edge_triangle_end_update(
                edge_attr, edge_vec, 
                edge_edge_index,
                edge_edge_attr
            )
            del edge_edge_attr, edge_edge_index

        dk = (
            self.act(self.dk_proj(edge_attr)).reshape(-1, self.num_heads, self.x_head_dim) 
            if self.dk_proj is not None else None
        )

        # propagate_type: (q: Tensor, k: Tensor, v: Tensor, vec: Tensor, dk: Tensor, dv: Tensor, r_ij: Tensor,
        # d_ij: Tensor)
        x, attn = self.propagate(
            edge_index,
            q=q,
            k=k,
            v=v,
            dk=dk,
            size=None,
        )
        x = x.reshape(-1, self.x_hidden_channels)
        x = residue + x
        x = self.layernorm_out(x)
        x = gelu(self.o_proj(x))
        x = self.out(x)
        del residue, q, k, v, dk
        if return_attn:
            return x, edge_attr, torch.concat((edge_index.T, attn), dim=1)
        else:
            return x, edge_attr, None

    def message(self, q_i, k_j, v_j, dk):
        # attention mechanism
        if dk is None:
            attn = (q_i * k_j).sum(dim=-1)
        else:  # TODO: consider add or multiply dk
            attn = (q_i * k_j * dk).sum(dim=-1)

        # attention activation function
        attn = self.attn_activation(attn)

        # update scalar features
        x = v_j * attn.unsqueeze(2)
        return x, attn

    def aggregate(
            self,
            features: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
            index: torch.Tensor,
            ptr: Optional[torch.Tensor],
            dim_size: Optional[int],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        x, attn = features
        x = scatter(x, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr)
        return x, attn

    def update(
            self, inputs: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return inputs

    def message_and_aggregate(self, adj_t: SparseTensor) -> Tensor:
        pass

    def edge_update(self) -> Tensor:
        pass


# let k v share the same weight, dropout attention weights, with option LoRA
class EquivariantTriAngularDropMultiHeadAttention(MessagePassing):
    """Equivariant multi-head attention layer. Add Triangular update between edges."""

    def __init__(
            self,
            x_channels,
            x_hidden_channels,
            vec_channels,
            vec_hidden_channels,
            edge_attr_channels,
            distance_influence,
            num_heads,
            activation,
            attn_activation,
            rbf_channels, 
            triangular_update=False,
            ee_channels=None,
            drop_out_rate=0.0,
            use_lora=None,
            layer_norm=True,
    ):
        super(EquivariantTriAngularDropMultiHeadAttention, self).__init__(aggr="mean", node_dim=0)

        self.distance_influence = distance_influence
        self.num_heads = num_heads
        self.x_channels = x_channels
        self.x_hidden_channels = x_hidden_channels
        self.x_head_dim = x_hidden_channels // num_heads
        self.vec_channels = vec_channels
        self.vec_hidden_channels = vec_hidden_channels
        self.ee_channels = ee_channels
        self.rbf_channels = rbf_channels
        self.layer_norm = layer_norm
        # important, not vec_hidden_channels // num_heads
        if layer_norm:
            self.layernorm_in = nn.LayerNorm(x_channels)
            self.layernorm_out = nn.LayerNorm(x_hidden_channels)

        self.act = activation()
        self.attn_activation = act_class_mapping[attn_activation]()

        if use_lora is not None:
            self.q_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora)
            self.kv_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora)
            self.dk_proj = lora.Linear(edge_attr_channels, x_hidden_channels, r=use_lora)
            self.o_proj = lora.Linear(x_hidden_channels, x_hidden_channels, r=use_lora)
        else:
            self.q_proj = nn.Linear(x_channels, x_hidden_channels)
            self.kv_proj = nn.Linear(x_channels, x_hidden_channels)
            self.dk_proj = nn.Linear(edge_attr_channels, x_hidden_channels)
            self.o_proj = nn.Linear(x_hidden_channels, x_hidden_channels)

        self.triangular_drop = nn.Dropout(drop_out_rate)
        self.rbf_drop = nn.Dropout(drop_out_rate)
        self.dense_drop = nn.Dropout(drop_out_rate)
        self.dropout = nn.Dropout(drop_out_rate)
        self.triangular_update = triangular_update
        if self.triangular_update:
            self.edge_triangle_end_update = MultiplicativeUpdate(vec_in_channel=vec_channels,
                                                                 hidden_channel=edge_attr_channels,
                                                                 hidden_vec_channel=vec_hidden_channels,
                                                                 ee_channels=ee_channels, 
                                                                 layer_norm=layer_norm,
                                                                 use_lora=use_lora)
            self.node_to_edge_attr = NodeToEdgeAttr(node_channel=x_channels,
                                                    hidden_channel=x_hidden_channels,
                                                    edge_attr_channel=edge_attr_channels,
                                                    use_lora=use_lora)
            self.triangle_update_dropout = nn.Dropout(0.5)
        self.reset_parameters()

    def reset_parameters(self):
        if self.layer_norm:
            self.layernorm_in.reset_parameters()
            self.layernorm_out.reset_parameters()
        nn.init.xavier_uniform_(self.q_proj.weight)
        self.q_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.kv_proj.weight)
        self.kv_proj.bias.data.fill_(0)
        # nn.init.xavier_uniform_(self.v_proj.weight)
        # self.v_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.o_proj.weight)
        self.o_proj.bias.data.fill_(0)
        if self.dk_proj:
            nn.init.xavier_uniform_(self.dk_proj.weight)
            self.dk_proj.bias.data.fill_(0)

    def get_start_index(self, edge_index):
        edge_start_index = []
        start_node_count = edge_index[0].unique(return_counts=True)
        start_nodes = start_node_count[0][start_node_count[1] > 1]
        for i in start_nodes:
            node_start_index = torch.where(edge_index[0] == i)[0]
            candidates = torch.combinations(node_start_index, r=2).T
            edge_start_index.append(torch.cat([candidates, candidates.flip(0)], dim=1))
        edge_start_index = torch.concat(edge_start_index, dim=1)
        edge_start_index = edge_start_index[:, edge_start_index[0] != edge_start_index[1]]
        return edge_start_index

    def get_end_index(self, edge_index):
        edge_end_index = []
        end_node_count = edge_index[1].unique(return_counts=True)
        end_nodes = end_node_count[0][end_node_count[1] > 1]
        for i in end_nodes:
            node_end_index = torch.where(edge_index[1] == i)[0]
            candidates = torch.combinations(node_end_index, r=2).T
            edge_end_index.append(torch.cat([candidates, candidates.flip(0)], dim=1))
        edge_end_index = torch.concat(edge_end_index, dim=1)
        edge_end_index = edge_end_index[:, edge_end_index[0] != edge_end_index[1]]
        return edge_end_index

    def forward(self, x, coords, edge_index, edge_attr, edge_vec, return_attn=False):
        residue = x
        if self.layer_norm:
            x = self.layernorm_in(x)
        q = self.q_proj(x).reshape(-1, self.num_heads, self.x_head_dim)
        k = self.kv_proj(x).reshape(-1, self.num_heads, self.x_head_dim)
        v = k

        # project node (point) attributes onto edge_attr
        if self.triangular_update:
            edge_attr += self.node_to_edge_attr(x, edge_index)
            # Triangular edge update
            # TODO: Add drop out layers here
            edge_edge_index = self.get_end_index(edge_index)
            edge_edge_index = edge_edge_index[:, self.triangular_drop(
                torch.ones(edge_edge_index.shape[1], device=edge_edge_index.device)
                ).to(torch.bool)]
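            # randomly drop edge-edge pairs: dropout over a vector of ones zeroes a
            # random subset, and the bool cast keeps only the surviving pairs.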
            if self.ee_channels is not None:
                edge_edge_attr = coords[edge_index[0][edge_edge_index[0]], :, [0]] - coords[edge_index[0][edge_edge_index[1]], :, [0]]
                edge_edge_attr = torch.norm(edge_edge_attr, dim=-1, keepdim=True)
            else:
                edge_edge_attr = None
            edge_attr = self.edge_triangle_end_update(
                edge_attr, edge_vec, 
                edge_edge_index,
                edge_edge_attr
            )
            del edge_edge_attr, edge_edge_index

        # drop rbfs
        edge_attr = torch.cat((edge_attr[:, :-self.rbf_channels],
                              self.rbf_drop(edge_attr[:, -self.rbf_channels:])), 
                              dim=-1)

        dk = (
            self.act(self.dk_proj(edge_attr)).reshape(-1, self.num_heads, self.x_head_dim) 
            if self.dk_proj is not None else None
        )

        # propagate_type: (q: Tensor, k: Tensor, v: Tensor, vec: Tensor, dk: Tensor, dv: Tensor, r_ij: Tensor,
        # d_ij: Tensor)
        x, attn = self.propagate(
            edge_index,
            q=q,
            k=k,
            v=v,
            dk=dk,
            size=None,
        )
        x = x.reshape(-1, self.x_hidden_channels)
        if self.layer_norm:
            x = self.layernorm_out(x)
        x = self.dense_drop(x)
        x = residue + gelu(x)
        x = self.o_proj(x)
        x = self.dropout(x)
        del residue, q, k, v, dk
        if return_attn:
            return x, edge_attr, torch.concat((edge_index.T, attn), dim=1)
        else:
            return x, edge_attr, None

    def message(self, q_i, k_j, v_j, dk):
        # attention mechanism
        if dk is None:
            attn = (q_i * k_j).sum(dim=-1)
        else:  # TODO: consider add or multiply dk
            attn = (q_i * k_j * dk).sum(dim=-1)

        # attention activation function
        attn = self.attn_activation(attn)

        # update scalar features
        x = v_j * attn.unsqueeze(2)
        return x, attn

    def aggregate(
            self,
            features: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
            index: torch.Tensor,
            ptr: Optional[torch.Tensor],
            dim_size: Optional[int],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        x, attn = features
        x = scatter(x, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr)
        return x, attn

    def update(
            self, inputs: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return inputs

    def message_and_aggregate(self, adj_t: SparseTensor) -> Tensor:
        pass

    def edge_update(self) -> Tensor:
        pass


# let k v share the same weight
class EquivariantTriAngularStarMultiHeadAttention(MessagePassing):
    """
    Equivariant multi-head attention layer. Add Triangular update between edges. Only update the center node.
    """

    def __init__(
            self,
            x_channels,
            x_hidden_channels,
            vec_channels,
            vec_hidden_channels,
            edge_attr_channels,
            distance_influence,
            num_heads,
            activation,
            attn_activation,
            cutoff_lower,
            cutoff_upper,
            triangular_update=False,
            ee_channels=None,
    ):
        super(EquivariantTriAngularStarMultiHeadAttention, self).__init__(aggr="mean", node_dim=0)

        self.distance_influence = distance_influence
        self.num_heads = num_heads
        self.x_channels = x_channels
        self.x_hidden_channels = x_hidden_channels
        self.x_head_dim = x_hidden_channels // num_heads
        self.vec_channels = vec_channels
        self.vec_hidden_channels = vec_hidden_channels
        self.ee_channels = ee_channels
        # important, not vec_hidden_channels // num_heads

        # self.layernorm_in = nn.LayerNorm(x_channels)
        self.layernorm_out = nn.LayerNorm(x_hidden_channels)

        self.act = activation()
        self.attn_activation = act_class_mapping[attn_activation]()

        self.q_proj = nn.Linear(x_channels, x_hidden_channels)
        self.kv_proj = nn.Linear(x_channels, x_hidden_channels)
        # self.v_proj = nn.Linear(x_channels, x_hidden_channels)
        # self.o_proj = nn.Linear(x_hidden_channels, x_hidden_channels)
        # self.out = nn.Linear(x_hidden_channels, x_channels)
        # add residue to x
        # self.residue_hidden = nn.Linear(x_channels, x_hidden_channels)
        self.gru = nn.GRUCell(x_channels, x_channels)
        self.dk_proj = nn.Linear(edge_attr_channels, x_hidden_channels)
        self.triangular_update = triangular_update
        if self.triangular_update:
            # self.edge_triangle_start_update = MultiplicativeUpdate(vec_in_channel=vec_channels,
            #                                                     hidden_channel=edge_attr_channels,
            #                                                     hidden_vec_channel=vec_hidden_channels,
            #                                                     ee_channels=ee_channels, )
            self.edge_triangle_end_update = MultiplicativeUpdate(vec_in_channel=vec_channels,
                                                                hidden_channel=edge_attr_channels,
                                                                hidden_vec_channel=vec_hidden_channels,
                                                                ee_channels=ee_channels, )
            self.node_to_edge_attr = NodeToEdgeAttr(node_channel=x_channels,
                                                    hidden_channel=x_hidden_channels,
                                                    edge_attr_channel=edge_attr_channels)

        self.reset_parameters()

    def reset_parameters(self):
        # self.layernorm_in.reset_parameters()
        self.layernorm_out.reset_parameters()
        nn.init.xavier_uniform_(self.q_proj.weight)
        self.q_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.kv_proj.weight)
        self.kv_proj.bias.data.fill_(0)
        # nn.init.xavier_uniform_(self.v_proj.weight)
        # self.v_proj.bias.data.fill_(0)
        # nn.init.xavier_uniform_(self.o_proj.weight)
        # self.o_proj.bias.data.fill_(0)
        if self.dk_proj:
            nn.init.xavier_uniform_(self.dk_proj.weight)
            self.dk_proj.bias.data.fill_(0)

    def get_start_index(self, edge_index):
        edge_start_index = []
        start_node_count = edge_index[0].unique(return_counts=True)
        start_nodes = start_node_count[0][start_node_count[1] > 1]
        for i in start_nodes:
            node_start_index = torch.where(edge_index[0] == i)[0]
            candidates = torch.combinations(node_start_index, r=2).T
            edge_start_index.append(torch.cat([candidates, candidates.flip(0)], dim=1))
        edge_start_index = torch.concat(edge_start_index, dim=1)
        edge_start_index = edge_start_index[:, edge_start_index[0] != edge_start_index[1]]
        return edge_start_index

    def get_end_index(self, edge_index):
        edge_end_index = []
        end_node_count = edge_index[1].unique(return_counts=True)
        end_nodes = end_node_count[0][end_node_count[1] > 1]
        for i in end_nodes:
            node_end_index = torch.where(edge_index[1] == i)[0]
            candidates = torch.combinations(node_end_index, r=2).T
            edge_end_index.append(torch.cat([candidates, candidates.flip(0)], dim=1))
        edge_end_index = torch.concat(edge_end_index, dim=1)
        edge_end_index = edge_end_index[:, edge_end_index[0] != edge_end_index[1]]
        return edge_end_index
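
    # Worked example (comment only, not executed): for
    #     edge_index = [[1, 2, 3, 0],
    #                   [0, 0, 0, 3]]
    # node 0 is the end node of edges 0, 1 and 2, so get_end_index returns all
    # ordered pairs of distinct edges sharing that end node:
    #     [[0, 0, 1, 1, 2, 2],
    #      [1, 2, 2, 0, 0, 1]]
    # Edge 3 (whose end node 3 has only one incoming edge) contributes no pairs.
    # get_start_index builds the symmetric pairing over shared start nodes.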

    def forward(self, x, coords, edge_index, edge_attr, edge_vec, return_attn=False):
        # select center nodes (nodes with more than one incoming edge); the rest are context
        end_node_count = edge_index[1].unique(return_counts=True)
        center_nodes = end_node_count[0][end_node_count[1] > 1]
        other_nodes = end_node_count[0][end_node_count[1] <= 1]
        residue = x[center_nodes]  # [num_center_nodes, x_channels]
        # keep only the context-to-center edges
        edge_attr = edge_attr[torch.isin(edge_index[1], center_nodes), :]
        edge_vec = edge_vec[torch.isin(edge_index[1], center_nodes), :]
        edge_index = edge_index[:, torch.isin(edge_index[1], center_nodes)]
        # q comes from the center nodes; k and v share the same projection of the context nodes
        q = self.q_proj(residue).reshape(-1, self.num_heads, self.x_head_dim)
        kv = self.kv_proj(x[other_nodes]).reshape(-1, self.num_heads, self.x_head_dim)
        qkv = torch.zeros(x.shape[0], self.num_heads, self.x_head_dim).to(x.device, non_blocking=True)
        qkv[center_nodes] = q
        qkv[other_nodes] = kv
        # add node-pair features to edge_attr
        if self.triangular_update:
            edge_attr += self.node_to_edge_attr(x, edge_index)
            # Triangular edge update
            # TODO: add dropout layers here
            edge_edge_index = self.get_end_index(edge_index)
            if self.ee_channels is not None:
                edge_edge_attr = coords[edge_index[0][edge_edge_index[0]], :, [0]] - coords[edge_index[0][edge_edge_index[1]], :, [0]]
                edge_edge_attr = torch.norm(edge_edge_attr, dim=-1, keepdim=True)
            else:
                edge_edge_attr = None
            edge_attr = self.edge_triangle_end_update(
                edge_attr, edge_vec, 
                edge_edge_index,
                edge_edge_attr
            )
            del edge_edge_attr, edge_edge_index

        dk = (
            self.act(self.dk_proj(edge_attr)).reshape(-1, self.num_heads, self.x_head_dim) 
            if self.dk_proj is not None else None
        ) # TODO: check self.act

        # propagate_type: (q: Tensor, k: Tensor, v: Tensor, dk: Tensor)
        x, attn = self.propagate(
            edge_index,
            q=qkv,
            k=qkv,
            v=qkv,
            dk=dk,
            size=None,
        )
        x = x.reshape(-1, self.x_hidden_channels)
        # only get the center nodes
        x = x[center_nodes]
        x = self.layernorm_out(x)
        x = self.gru(residue, x)
        del residue, dk
        if return_attn:
            return x, edge_attr, torch.concat((edge_index.T, attn), dim=1)
        else:
            return x, edge_attr, None

    def message(self, q_i, k_j, v_j, dk):
        # attention mechanism
        if dk is None:
            attn = (q_i * k_j).sum(dim=-1)
        else:  # TODO: decide whether dk should be added or multiplied
            attn = (q_i * k_j + dk).sum(dim=-1)

        # attention activation function
        attn = self.attn_activation(attn) / self.x_head_dim

        # update scalar features
        x = v_j * attn.unsqueeze(2)
        return x, attn
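
    # Note (added comment): unlike standard scaled dot-product attention,
    # which divides the logits by sqrt(head_dim) before a softmax, this
    # message divides the activated attention weights by x_head_dim and
    # relies on the mean aggregation below rather than a softmax over
    # neighbours.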

    def aggregate(
            self,
            features: Tuple[torch.Tensor, torch.Tensor],
            index: torch.Tensor,
            ptr: Optional[torch.Tensor],
            dim_size: Optional[int],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        x, attn = features
        x = scatter(x, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr)
        return x, attn

    def update(
            self, inputs: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return inputs

    def message_and_aggregate(self, adj_t: SparseTensor) -> Tensor:
        pass

    def edge_update(self) -> Tensor:
        pass
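

# Minimal usage sketch for the star attention layer above. Illustrative only:
# the toy sizes, the SiLU activation, and the "silu" key for attn_activation
# are assumptions, not values taken from any training configuration.
def _star_attention_usage_sketch():
    torch.manual_seed(0)
    num_nodes, x_dim, vec_dim, edge_dim = 5, 16, 8, 12
    layer = EquivariantTriAngularStarMultiHeadAttention(
        x_channels=x_dim,
        x_hidden_channels=x_dim,  # must equal x_channels for the GRU update
        vec_channels=vec_dim,
        vec_hidden_channels=x_dim,
        edge_attr_channels=edge_dim,
        distance_influence="both",
        num_heads=4,
        activation=nn.SiLU,
        attn_activation="silu",  # assumed to be a key of act_class_mapping
        cutoff_lower=0.0,
        cutoff_upper=5.0,
        triangular_update=False,
    )
    # Star graph: context nodes 1..4 point to center node 0; the reverse edges
    # make nodes 1..4 appear as end nodes, so they are treated as context.
    edge_index = torch.tensor([[1, 2, 3, 4, 0, 0, 0, 0],
                               [0, 0, 0, 0, 1, 2, 3, 4]])
    x = torch.randn(num_nodes, x_dim)
    coords = torch.randn(num_nodes, 3)  # unused when ee_channels is None (default)
    edge_attr = torch.randn(edge_index.shape[1], edge_dim)
    edge_vec = torch.randn(edge_index.shape[1], vec_dim)
    x_center, edge_attr, _ = layer(x, coords, edge_index, edge_attr, edge_vec)
    return x_center  # one row per center node, here shape [1, x_dim]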


# let k and v share the same projection weights, with dropout and optional LoRA
class EquivariantTriAngularStarDropMultiHeadAttention(MessagePassing):
    """
    Equivariant multi-head attention layer with a triangular update between edges,
    dropout, and optional LoRA. Only the center node is updated.
    """

    def __init__(
            self,
            x_channels,
            x_hidden_channels,
            vec_channels,
            vec_hidden_channels,
            edge_attr_channels,
            distance_influence,
            num_heads,
            activation,
            attn_activation,
            rbf_channels,
            triangular_update=False,
            ee_channels=None,
            drop_out_rate=0.0,
            use_lora=None,
    ):
        super(EquivariantTriAngularStarDropMultiHeadAttention, self).__init__(aggr="mean", node_dim=0)

        self.distance_influence = distance_influence
        self.num_heads = num_heads
        self.x_channels = x_channels
        self.x_hidden_channels = x_hidden_channels
        self.x_head_dim = x_hidden_channels // num_heads
        self.vec_channels = vec_channels
        self.vec_hidden_channels = vec_hidden_channels
        self.ee_channels = ee_channels
        self.rbf_channels = rbf_channels

        # self.layernorm_in = nn.LayerNorm(x_channels)
        self.layernorm_out = nn.LayerNorm(x_hidden_channels)

        self.act = activation()
        self.attn_activation = act_class_mapping[attn_activation]()

        if use_lora is not None:
            self.q_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora)
            self.kv_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora)
            self.dk_proj = lora.Linear(edge_attr_channels, x_hidden_channels, r=use_lora)
        else:
            self.q_proj = nn.Linear(x_channels, x_hidden_channels)
            self.kv_proj = nn.Linear(x_channels, x_hidden_channels)
            self.dk_proj = nn.Linear(edge_attr_channels, x_hidden_channels)
        # self.v_proj = nn.Linear(x_channels, x_hidden_channels)
        # self.o_proj = nn.Linear(x_hidden_channels, x_hidden_channels)
        # self.out = nn.Linear(x_hidden_channels, x_channels)
        # add residue to x
        # self.residue_hidden = nn.Linear(x_channels, x_hidden_channels)
        self.gru = nn.GRUCell(x_channels, x_channels)
        
        self.triangular_drop = nn.Dropout(drop_out_rate)
        self.rbf_drop = nn.Dropout(drop_out_rate)
        self.dense_drop = nn.Dropout(drop_out_rate)
        self.dropout = nn.Dropout(drop_out_rate)
        self.triangular_update = triangular_update
        if self.triangular_update:
            self.edge_triangle_end_update = MultiplicativeUpdate(vec_in_channel=vec_channels,
                                                                 hidden_channel=edge_attr_channels,
                                                                 hidden_vec_channel=vec_hidden_channels,
                                                                 ee_channels=ee_channels, 
                                                                 use_lora=use_lora)
            self.node_to_edge_attr = NodeToEdgeAttr(node_channel=x_channels,
                                                    hidden_channel=x_hidden_channels,
                                                    edge_attr_channel=edge_attr_channels,
                                                    use_lora=use_lora)
            self.triangle_update_dropout = nn.Dropout(0.5)

        self.reset_parameters()

    def reset_parameters(self):
        # self.layernorm_in.reset_parameters()
        self.layernorm_out.reset_parameters()
        nn.init.xavier_uniform_(self.q_proj.weight)
        self.q_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.kv_proj.weight)
        self.kv_proj.bias.data.fill_(0)
        # nn.init.xavier_uniform_(self.v_proj.weight)
        # self.v_proj.bias.data.fill_(0)
        # nn.init.xavier_uniform_(self.o_proj.weight)
        # self.o_proj.bias.data.fill_(0)
        if self.dk_proj:
            nn.init.xavier_uniform_(self.dk_proj.weight)
            self.dk_proj.bias.data.fill_(0)

    def get_start_index(self, edge_index):
        edge_start_index = []
        start_node_count = edge_index[0].unique(return_counts=True)
        start_nodes = start_node_count[0][start_node_count[1] > 1]
        for i in start_nodes:
            node_start_index = torch.where(edge_index[0] == i)[0]
            candidates = torch.combinations(node_start_index, r=2).T
            edge_start_index.append(torch.cat([candidates, candidates.flip(0)], dim=1))
        edge_start_index = torch.concat(edge_start_index, dim=1)
        edge_start_index = edge_start_index[:, edge_start_index[0] != edge_start_index[1]]
        return edge_start_index

    def get_end_index(self, edge_index):
        edge_end_index = []
        end_node_count = edge_index[1].unique(return_counts=True)
        end_nodes = end_node_count[0][end_node_count[1] > 1]
        for i in end_nodes:
            node_end_index = torch.where(edge_index[1] == i)[0]
            candidates = torch.combinations(node_end_index, r=2).T
            edge_end_index.append(torch.cat([candidates, candidates.flip(0)], dim=1))
        edge_end_index = torch.concat(edge_end_index, dim=1)
        edge_end_index = edge_end_index[:, edge_end_index[0] != edge_end_index[1]]
        return edge_end_index

    def forward(self, x, coords, edge_index, edge_attr, edge_vec, return_attn=False):
        # select center nodes (nodes with more than one incoming edge); the rest are context
        end_node_count = edge_index[1].unique(return_counts=True)
        center_nodes = end_node_count[0][end_node_count[1] > 1]
        other_nodes = end_node_count[0][end_node_count[1] <= 1]
        residue = x[center_nodes]  # [num_center_nodes, x_channels]
        # keep only the context-to-center edges
        edge_attr = edge_attr[torch.isin(edge_index[1], center_nodes), :]
        edge_vec = edge_vec[torch.isin(edge_index[1], center_nodes), :]
        edge_index = edge_index[:, torch.isin(edge_index[1], center_nodes)]
        # q comes from the center nodes; k and v share the same projection of the context nodes
        q = self.q_proj(residue).reshape(-1, self.num_heads, self.x_head_dim)
        kv = self.kv_proj(x[other_nodes]).reshape(-1, self.num_heads, self.x_head_dim)
        qkv = torch.zeros(x.shape[0], self.num_heads, self.x_head_dim).to(x.device, non_blocking=True)
        qkv[center_nodes] = q
        qkv[other_nodes] = kv
        # add node-pair features to edge_attr
        if self.triangular_update:
            edge_attr += self.node_to_edge_attr(x, edge_index)
            # Triangular edge update, with dropout applied to the edge-edge pairs below
            edge_edge_index = self.get_end_index(edge_index)
            edge_edge_index = edge_edge_index[:, self.triangular_drop(
                torch.ones(edge_edge_index.shape[1], device=edge_edge_index.device)
                ).to(torch.bool)]
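            # The dropout-on-a-ones-vector trick above keeps each edge-edge
            # pair with probability 1 - drop_out_rate during training (the
            # rescaling applied by nn.Dropout is irrelevant after the cast to
            # bool) and keeps all pairs in eval mode.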
            if self.ee_channels is not None:
                edge_edge_attr = coords[edge_index[0][edge_edge_index[0]], :, [0]] - coords[edge_index[0][edge_edge_index[1]], :, [0]]
                edge_edge_attr = torch.norm(edge_edge_attr, dim=-1, keepdim=True)
            else:
                edge_edge_attr = None
            edge_attr = self.edge_triangle_end_update(
                edge_attr, edge_vec, 
                edge_edge_index,
                edge_edge_attr
            )
            del edge_edge_attr, edge_edge_index
        
        # apply dropout to the RBF features, assumed to occupy the last rbf_channels columns of edge_attr
        edge_attr = torch.cat((edge_attr[:, :-self.rbf_channels],
                              self.rbf_drop(edge_attr[:, -self.rbf_channels:])), 
                              dim=-1)

        dk = (
            self.act(self.dk_proj(edge_attr)).reshape(-1, self.num_heads, self.x_head_dim) 
            if self.dk_proj is not None else None
        ) # TODO: check self.act

        # propagate_type: (q: Tensor, k: Tensor, v: Tensor, dk: Tensor)
        x, attn = self.propagate(
            edge_index,
            q=qkv,
            k=qkv,
            v=qkv,
            dk=dk,
            size=None,
        )
        x = x.reshape(-1, self.x_hidden_channels)
        # only get the center nodes
        x = x[center_nodes]
        x = self.layernorm_out(x)
        x = self.dense_drop(x)
        x = self.gru(residue, x)
        x = self.dropout(x)
        del residue, dk
        if return_attn:
            return x, edge_attr, torch.concat((edge_index.T, attn), dim=1)
        else:
            return x, edge_attr, None

    def message(self, q_i, k_j, v_j, dk):
        # attention mechanism
        if dk is None:
            attn = (q_i * k_j).sum(dim=-1)
        else:  # TODO: decide whether dk should be added or multiplied
            attn = (q_i * k_j + dk).sum(dim=-1)

        # attention activation function
        attn = self.attn_activation(attn) / self.x_head_dim

        # update scalar features
        x = v_j * attn.unsqueeze(2)
        return x, attn

    def aggregate(
            self,
            features: Tuple[torch.Tensor, torch.Tensor],
            index: torch.Tensor,
            ptr: Optional[torch.Tensor],
            dim_size: Optional[int],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        x, attn = features
        x = scatter(x, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr)
        return x, attn

    def update(
            self, inputs: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return inputs

    def message_and_aggregate(self, adj_t: SparseTensor) -> Tensor:
        pass

    def edge_update(self) -> Tensor:
        pass
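

# The Drop variant above follows the same call pattern as the sketch after
# EquivariantTriAngularStarMultiHeadAttention; its constructor differs in that
# rbf_channels replaces the cutoff arguments, while drop_out_rate and use_lora
# (a LoRA rank) control regularisation. Illustrative keyword values only; the
# sizes below are assumptions:
#
#     EquivariantTriAngularStarDropMultiHeadAttention(
#         x_channels=16, x_hidden_channels=16, vec_channels=8,
#         vec_hidden_channels=16, edge_attr_channels=12,
#         distance_influence="both", num_heads=4, activation=nn.SiLU,
#         attn_activation="silu", rbf_channels=8, triangular_update=False,
#         drop_out_rate=0.1, use_lora=None)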


# Transform sequence, structure, and relative position into a pair feature
class PairFeatureNet(nn.Module):

    def __init__(self, c_s, c_p, relpos_k=32, template_type="exp-normal-smearing-distance"):
        super(PairFeatureNet, self).__init__()

        self.c_s = c_s
        self.c_p = c_p

        self.linear_s_p_i = nn.Linear(c_s, c_p)
        self.linear_s_p_j = nn.Linear(c_s, c_p)

        self.relpos_k = relpos_k
        self.n_bin = 2 * relpos_k + 1
        self.linear_relpos = nn.Linear(self.n_bin, c_p)

        # TODO: implement structure to pairwise feature function
        self.template_fn, c_template = get_template_fn(template_type)
        self.linear_template = nn.Linear(c_template, c_p)

    def relpos(self, r):
        # AlphaFold 2 Algorithm 4 & 5
        # Based on OpenFold utils/tensor_utils.py
        # Input: [b, n_res]

        # [b, n_res, n_res]
        d = r[:, :, None] - r[:, None, :]

        # [n_bin]
        v = torch.arange(-self.relpos_k, self.relpos_k + 1).to(r.device, non_blocking=True)

        # [1, 1, 1, n_bin]
        v_reshaped = v.view(*((1,) * len(d.shape) + (len(v),)))

        # [b, n_res, n_res]
        b = torch.argmin(torch.abs(d[:, :, :, None] - v_reshaped), dim=-1)

        # [b, n_res, n_res, n_bin]
        oh = nn.functional.one_hot(b, num_classes=len(v)).float()

        # [b, n_res, n_res, c_p]
        p = self.linear_relpos(oh)

        return p
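
    # Worked example (comment only): with relpos_k = 32 there are 65 bins
    # covering offsets -32..32; an offset of +50 is closest to the +32 bin,
    # so all separations beyond ±32 residues share the two extreme embeddings.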

    def template(self, t):
        return self.linear_template(self.template_fn(t))

    def forward(self, s, t, r, mask):
        # Input: [b, n_res, c_s]
        p_mask = mask.unsqueeze(1) * mask.unsqueeze(2)
        # [b, n_res, c_p]
        p_i = self.linear_s_p_i(s)
        p_j = self.linear_s_p_j(s)

        # [b, n_res, n_res, c_p]
        p = p_i[:, :, None, :] + p_j[:, None, :, :]

        # [b, n_res, n_res, c_p]
        p += self.relpos(r)  # upper bound is 64 AA
        p += self.template(t)  # upper bound is 100 Å

        # [b, n_res, n_res, c_p]
        p *= p_mask.unsqueeze(-1)

        return p


# AF2-style TriangularSelfAttentionBlock, with the pairwise (triangle) attention removed because of memory issues.
# Genie does the same.
class TriangularSelfAttentionBlock(nn.Module):
    def __init__(
        self,
        sequence_state_dim,
        pairwise_state_dim,
        sequence_head_width,
        pairwise_head_width,
        dropout=0,
        **__kwargs,
    ):
        super().__init__()
        from openfold.model.triangular_multiplicative_update import (
            TriangleMultiplicationIncoming,
            TriangleMultiplicationOutgoing,
        )
        from esm.esmfold.v1.misc import (
            Attention,
            Dropout,
            PairToSequence,
            ResidueMLP,
            SequenceToPair,
        )
        assert sequence_state_dim % sequence_head_width == 0
        assert pairwise_state_dim % pairwise_head_width == 0
        sequence_num_heads = sequence_state_dim // sequence_head_width
        pairwise_num_heads = pairwise_state_dim // pairwise_head_width
        assert sequence_state_dim == sequence_num_heads * sequence_head_width
        assert pairwise_state_dim == pairwise_num_heads * pairwise_head_width
        assert pairwise_state_dim % 2 == 0

        self.sequence_state_dim = sequence_state_dim
        self.pairwise_state_dim = pairwise_state_dim

        self.layernorm_1 = nn.LayerNorm(sequence_state_dim)

        self.sequence_to_pair = SequenceToPair(
            sequence_state_dim, pairwise_state_dim // 2, pairwise_state_dim
        )
        self.pair_to_sequence = PairToSequence(
            pairwise_state_dim, sequence_num_heads)

        self.seq_attention = Attention(
            sequence_state_dim, sequence_num_heads, sequence_head_width, gated=True
        )
        self.tri_mul_out = TriangleMultiplicationOutgoing(
            pairwise_state_dim,
            pairwise_state_dim,
        )
        self.tri_mul_in = TriangleMultiplicationIncoming(
            pairwise_state_dim,
            pairwise_state_dim,
        )

        self.mlp_seq = ResidueMLP(
            sequence_state_dim, 4 * sequence_state_dim, dropout=dropout)
        self.mlp_pair = ResidueMLP(
            pairwise_state_dim, 4 * pairwise_state_dim, dropout=dropout)

        assert dropout < 0.4
        self.drop = nn.Dropout(dropout)
        self.row_drop = Dropout(dropout * 2, 2)
        self.col_drop = Dropout(dropout * 2, 1)

        torch.nn.init.zeros_(self.tri_mul_in.linear_z.weight)
        torch.nn.init.zeros_(self.tri_mul_in.linear_z.bias)
        torch.nn.init.zeros_(self.tri_mul_out.linear_z.weight)
        torch.nn.init.zeros_(self.tri_mul_out.linear_z.bias)

        torch.nn.init.zeros_(self.sequence_to_pair.o_proj.weight)
        torch.nn.init.zeros_(self.sequence_to_pair.o_proj.bias)
        torch.nn.init.zeros_(self.pair_to_sequence.linear.weight)
        torch.nn.init.zeros_(self.seq_attention.o_proj.weight)
        torch.nn.init.zeros_(self.seq_attention.o_proj.bias)
        torch.nn.init.zeros_(self.mlp_seq.mlp[-2].weight)
        torch.nn.init.zeros_(self.mlp_seq.mlp[-2].bias)
        torch.nn.init.zeros_(self.mlp_pair.mlp[-2].weight)
        torch.nn.init.zeros_(self.mlp_pair.mlp[-2].bias)

    def forward(self, sequence_state, pairwise_state, mask=None, chunk_size=None, **__kwargs):
        """
        Inputs:
          sequence_state: B x L x sequence_state_dim
          pairwise_state: B x L x L x pairwise_state_dim
          mask: B x L boolean tensor of valid positions

        Output:
          sequence_state: B x L x sequence_state_dim
          pairwise_state: B x L x L x pairwise_state_dim
        """
        assert len(sequence_state.shape) == 3
        assert len(pairwise_state.shape) == 4
        if mask is not None:
            assert len(mask.shape) == 2

        batch_dim, seq_dim, sequence_state_dim = sequence_state.shape
        pairwise_state_dim = pairwise_state.shape[3]
        assert sequence_state_dim == self.sequence_state_dim
        assert pairwise_state_dim == self.pairwise_state_dim
        assert batch_dim == pairwise_state.shape[0]
        assert seq_dim == pairwise_state.shape[1]
        assert seq_dim == pairwise_state.shape[2]

        # Update sequence state
        bias = self.pair_to_sequence(pairwise_state)

        # Self attention with bias + mlp.
        y = self.layernorm_1(sequence_state)
        y, _ = self.seq_attention(y, mask=mask, bias=bias)
        sequence_state = sequence_state + self.drop(y)
        sequence_state = self.mlp_seq(sequence_state)

        # Update pairwise state
        pairwise_state = pairwise_state + self.sequence_to_pair(sequence_state)

        # Axial attention with triangular bias.
        tri_mask = (
            mask.unsqueeze(2) * mask.unsqueeze(1) if mask is not None else None
        )
        pairwise_state = pairwise_state + self.row_drop(
            self.tri_mul_out(pairwise_state, mask=tri_mask)
        )
        pairwise_state = pairwise_state + self.col_drop(
            self.tri_mul_in(pairwise_state, mask=tri_mask)
        )

        # MLP over pairs.
        pairwise_state = self.mlp_pair(pairwise_state)

        return sequence_state, pairwise_state
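

# Usage sketch for TriangularSelfAttentionBlock (comments only; the toy sizes
# below are assumptions):
#
#     block = TriangularSelfAttentionBlock(
#         sequence_state_dim=64, pairwise_state_dim=32,
#         sequence_head_width=16, pairwise_head_width=16, dropout=0.1)
#     s = torch.randn(1, 20, 64)        # B x L x sequence_state_dim
#     z = torch.randn(1, 20, 20, 32)    # B x L x L x pairwise_state_dim
#     mask = torch.ones(1, 20)
#     s, z = block(s, z, mask=mask)
#
# Instantiation requires the optional `openfold` and `esm` packages imported
# inside __init__.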


# A Self-Attention Pooling Block
class SeqPairAttentionOutput(nn.Module):
    def __init__(self, seq_state_dim, pairwise_state_dim, num_heads, output_dim, dropout):
        super(SeqPairAttentionOutput, self).__init__()
        from esm.esmfold.v1.misc import (
            Attention,
            PairToSequence,
            ResidueMLP,
        )
        self.seq_state_dim = seq_state_dim
        self.pairwise_state_dim = pairwise_state_dim
        self.output_dim = output_dim
        seq_head_width = seq_state_dim // num_heads

        self.layernorm = nn.LayerNorm(seq_state_dim)
        self.seq_attention = Attention(
            seq_state_dim, num_heads, seq_head_width, gated=True
        )
        self.pair_to_sequence = PairToSequence(pairwise_state_dim, num_heads)
        self.mlp_seq = ResidueMLP(
            seq_state_dim, 4 * seq_state_dim, dropout=dropout)
        self.drop = nn.Dropout(dropout)

    def forward(self, sequence_state, pairwise_state, mask=None):
        # Update sequence state
        bias = self.pair_to_sequence(pairwise_state)

        # Self attention with bias + mlp.
        y = self.layernorm(sequence_state)
        y, _ = self.seq_attention(y, mask=mask, bias=bias)
        sequence_state = sequence_state + self.drop(y)
        sequence_state = self.mlp_seq(sequence_state)

        return sequence_state
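
    # Note (added comment): the returned tensor keeps the shape
    # B x L x seq_state_dim; `output_dim` is stored on the module but is not
    # applied inside this forward pass.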