import torch
from torch import nn
from torch.nn import Module
import torch.nn.functional as F


class MultiHeadAttention(Module):
    r"""A class that implements a Multi-head Attention mechanism.
    Multi-head attention allows the model to focus on different positions,
    capturing various aspects of the input.

    Args:
        query_dim (int): The dimensionality of the query.
        key_dim (int): The dimensionality of the key.
        num_units (int): The total dimensionality of the output. Must be
            divisible by num_heads.
        num_heads (int): The number of parallel attention heads.

    Inputs: query, key
        - **query**: Tensor of shape [N, T_q, query_dim]
        - **key**: Tensor of shape [N, T_k, key_dim]

    Outputs:
        - An output tensor of shape [N, T_q, num_units]
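
    Example:
        A minimal usage sketch (shapes are illustrative; num_units must be
        divisible by num_heads)::

            >>> attention = MultiHeadAttention(
            ...     query_dim=64, key_dim=64, num_units=128, num_heads=4,
            ... )
            >>> query = torch.randn(2, 10, 64)  # [N, T_q, query_dim]
            >>> key = torch.randn(2, 20, 64)  # [N, T_k, key_dim]
            >>> attention(query, key).shape
            torch.Size([2, 10, 128])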
    """

    def __init__(
        self,
        query_dim: int,
        key_dim: int,
        num_units: int,
        num_heads: int,
    ):
        super().__init__()
        self.num_units = num_units
        self.num_heads = num_heads
        self.key_dim = key_dim

        # Learned linear projections: queries are projected from query_dim,
        # while keys and values are both projected from the key input.
        self.W_query = nn.Linear(
            in_features=query_dim,
            out_features=num_units,
            bias=False,
        )
        self.W_key = nn.Linear(
            in_features=key_dim,
            out_features=num_units,
            bias=False,
        )
        self.W_value = nn.Linear(
            in_features=key_dim,
            out_features=num_units,
            bias=False,
        )

    def forward(self, query: torch.Tensor, key: torch.Tensor) -> torch.Tensor:
        r"""Performs the forward pass over input tensors.

        Args:
            query (torch.Tensor): The input tensor containing query vectors.
                It is expected to have the dimensions [N, T_q, query_dim]
                where N is the batch size, T_q is the sequence length of queries,
                and query_dim is the dimensionality of a single query vector.

            key (torch.Tensor): The input tensor containing key vectors.
                It is expected to have the dimensions [N, T_k, key_dim]
                where N is the batch size, T_k is the sequence length of keys,
                and key_dim is the dimensionality of a single key vector.

        Returns:
            torch.Tensor: The output tensor of shape [N, T_q, num_units]
                containing the result of multi-head attention applied to the
                provided queries and keys.
        """
        queries = self.W_query(query)  # [N, T_q, num_units]
        keys = self.W_key(key)  # [N, T_k, num_units]
        values = self.W_value(key)  # [N, T_k, num_units]
        split_size = self.num_units // self.num_heads  # per-head dimensionality

        # Split each projection into num_heads chunks along the feature
        # dimension and stack them along a new leading head dimension.
        queries = torch.stack(
            torch.split(queries, split_size, dim=2), dim=0,
        )  # [h, N, T_q, num_units/h]
        keys = torch.stack(
            torch.split(keys, split_size, dim=2), dim=0,
        )  # [h, N, T_k, num_units/h]
        values = torch.stack(
            torch.split(values, split_size, dim=2), dim=0,
        )  # [h, N, T_k, num_units/h]

        # scores = softmax(Q K^T / sqrt(key_dim))
        scores = torch.matmul(queries, keys.transpose(2, 3))  # [h, N, T_q, T_k]
        scores = scores / (self.key_dim**0.5)
        scores = F.softmax(scores, dim=3)
        # out = scores @ V
        out = torch.matmul(scores, values)  # [h, N, T_q, num_units/h]
        # Concatenate the heads back along the feature dimension.
        out = torch.cat(torch.split(out, 1, dim=0), dim=3)  # [1, N, T_q, num_units]
        return out.squeeze(0)  # [N, T_q, num_units]
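

if __name__ == "__main__":
    # Minimal smoke test with illustrative shapes (not part of the module's
    # public API): 4 heads over 128 total units, batch of 2, 10 query and
    # 20 key positions.
    attention = MultiHeadAttention(query_dim=64, key_dim=64, num_units=128, num_heads=4)
    query = torch.randn(2, 10, 64)  # [N, T_q, query_dim]
    key = torch.randn(2, 20, 64)  # [N, T_k, key_dim]
    out = attention(query, key)
    print(out.shape)  # expected: torch.Size([2, 10, 128])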