import torch
import torch.nn as nn
import logging

from modules.Utilities import util
from modules.Attention import AttentionMethods
from modules.Device import Device
from modules.cond import cast


def Normalize(
    in_channels: int, dtype: torch.dtype = None, device: torch.device = None
) -> torch.nn.GroupNorm:
    """#### Build the group-normalization layer used by the attention blocks.

    The layer always uses 32 groups, an epsilon of `1e-6` and learnable
    affine parameters; only the channel count, dtype and device vary.

    #### Args:
        - `in_channels` (int): Number of channels the layer normalizes.
        - `dtype` (torch.dtype, optional): Parameter data type. Defaults to `None`.
        - `device` (torch.device, optional): Parameter device. Defaults to `None`.

    #### Returns:
        - `torch.nn.GroupNorm`: The configured normalization layer.
    """
    group_norm = torch.nn.GroupNorm(
        32, in_channels, eps=1e-6, affine=True, device=device, dtype=dtype
    )
    return group_norm


# Choose the cross-attention backend once, when this module is imported.
if not Device.xformers_enabled():
    logging.info("Using pytorch cross attention")
    optimized_attention = AttentionMethods.attention_pytorch
else:
    logging.info("Using xformers cross attention")
    optimized_attention = AttentionMethods.attention_xformers

# The masked entry point is an alias of whichever backend was selected above.
optimized_attention_masked = optimized_attention


def optimized_attention_for_device() -> AttentionMethods.attention_pytorch:
    """#### Return the attention function to use for a device.

    No device is inspected: the PyTorch attention implementation is
    returned unconditionally.

    #### Returns:
        - `function`: `AttentionMethods.attention_pytorch`.
    """
    attention_fn = AttentionMethods.attention_pytorch
    return attention_fn


class CrossAttention(nn.Module):
    """#### Cross attention module, which applies attention across the query and context.

    Queries are projected from the input, keys from the context and values
    from either an explicit `value` tensor or the context. The attention
    output is projected back to `query_dim` and passed through dropout.

    #### Args:
        - `query_dim` (int): The query dimension.
        - `context_dim` (int, optional): The context dimension. Defaults to `None`,
          in which case it is resolved via `util.default(context_dim, query_dim)`.
        - `heads` (int, optional): The number of heads. Defaults to `8`.
        - `dim_head` (int, optional): The head dimension. Defaults to `64`.
        - `dropout` (float, optional): The dropout rate. Defaults to `0.0`.
        - `dtype` (torch.dtype, optional): The data type. Defaults to `None`.
        - `device` (torch.device, optional): The device. Defaults to `None`.
        - `operations` (cast.disable_weight_init, optional): The operations. Defaults to `cast.disable_weight_init`.
    """

    def __init__(
        self,
        query_dim: int,
        context_dim: int = None,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        dtype: torch.dtype = None,
        device: torch.device = None,
        operations: cast.disable_weight_init = cast.disable_weight_init,
    ):
        super().__init__()
        # Width shared by the q/k/v projections: all heads side by side.
        inner_dim = dim_head * heads
        context_dim = util.default(context_dim, query_dim)

        self.heads = heads
        self.dim_head = dim_head

        # Bias-free projections into the shared inner dimension.
        self.to_q = operations.Linear(
            query_dim, inner_dim, bias=False, dtype=dtype, device=device
        )
        self.to_k = operations.Linear(
            context_dim, inner_dim, bias=False, dtype=dtype, device=device
        )
        self.to_v = operations.Linear(
            context_dim, inner_dim, bias=False, dtype=dtype, device=device
        )

        # Output projection back to the query width, followed by dropout.
        self.to_out = nn.Sequential(
            operations.Linear(inner_dim, query_dim, dtype=dtype, device=device),
            nn.Dropout(dropout),
        )

    def forward(
        self,
        x: torch.Tensor,
        context: torch.Tensor = None,
        value: torch.Tensor = None,
        mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """#### Forward pass of the cross attention module.

        #### Args:
            - `x` (torch.Tensor): The input tensor, used to compute the queries.
            - `context` (torch.Tensor, optional): The context tensor used for the keys
              (and for the values when `value` is not given). Defaults to `None`,
              resolved via `util.default(context, x)`.
            - `value` (torch.Tensor, optional): Tensor to project the values from instead
              of the context. Defaults to `None`.
            - `mask` (torch.Tensor, optional): Attention mask forwarded to the masked
              attention backend. Defaults to `None`.

        #### Returns:
            - `torch.Tensor`: The output tensor.
        """
        q = self.to_q(x)
        context = util.default(context, x)
        k = self.to_k(context)
        # Honour an explicit value source instead of silently dropping it.
        v = self.to_v(context if value is None else value)

        if mask is None:
            out = optimized_attention(q, k, v, self.heads)
        else:
            # NOTE(review): assumes the selected backend accepts a `mask`
            # keyword argument; confirm against AttentionMethods.
            out = optimized_attention_masked(q, k, v, self.heads, mask=mask)
        return self.to_out(out)


class AttnBlock(nn.Module):
    """#### Residual attention block built from 1x1 convolutions.

    The input is normalized, projected to query/key/value with 1x1
    convolutions, passed through the selected attention function, projected
    again and added back onto the input.

    #### Args:
        - `in_channels` (int): The input channels.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels
        self.norm = Normalize(in_channels)

        def pointwise_conv():
            # Channel-preserving 1x1 convolution shared by q, k, v and proj_out.
            return cast.disable_weight_init.Conv2d(
                in_channels, in_channels, kernel_size=1, stride=1, padding=0
            )

        # Created in the same order as before so module registration is unchanged.
        self.q = pointwise_conv()
        self.k = pointwise_conv()
        self.v = pointwise_conv()
        self.proj_out = pointwise_conv()

        # Pick the VAE attention implementation once per block instance.
        if not Device.xformers_enabled_vae():
            logging.info("Using pytorch attention in VAE")
            self.optimized_attention = AttentionMethods.pytorch_attention
        else:
            logging.info("Using xformers attention in VAE")
            self.optimized_attention = AttentionMethods.xformers_attention

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """#### Forward pass of the attention block.

        #### Args:
            - `x` (torch.Tensor): The input tensor.

        #### Returns:
            - `torch.Tensor`: The input plus the projected attention output.
        """
        normed = self.norm(x)
        attended = self.optimized_attention(
            self.q(normed), self.k(normed), self.v(normed)
        )
        # Residual connection around the whole attention path.
        return x + self.proj_out(attended)


def make_attn(in_channels: int, attn_type: str = "vanilla") -> AttnBlock:
    """#### Create an attention block for the given channel count.

    #### Args:
        - `in_channels` (int): The input channels.
        - `attn_type` (str, optional): The attention type. Accepted but not
          consulted; an `AttnBlock` is always built. Defaults to "vanilla".

    #### Returns:
        - `AttnBlock`: A new attention block instance.
    """
    block = AttnBlock(in_channels)
    return block