from typing import Optional, Tuple, MutableMapping
from typing import Union
import math
from contextlib import nullcontext

import torch
import torch as T
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn.attention import SDPBackend

from einops import rearrange

from utils import si_module, default, exists, load_ckpt

CACHE_FILL_VALUE = -1

def get_cache_len(cache: Optional[Tensor]) -> int:
    """
    cache: (batch, seq_len, 2, kv_heads, head_dim)
    """
    if cache is None:
        return 0
    nonzeros = T.any(cache.flatten(2) != CACHE_FILL_VALUE, dim=-1)
    length = nonzeros.sum(dim=-1).int()
    assert T.all(length == length[0])
    return length[0]


def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(x, cos, sin, offset: int = 0):
    assert (
        cos.shape[1] >= offset + x.shape[1]
    ), f"Offset and/or input sequence is too large,\
        \n offset: {offset}, seq_len: {x.shape[1]}, max: {cos.shape[1]}"

    cos_out = cos[:, offset : offset + x.shape[1], :, :]
    sin_out = sin[:, offset : offset + x.shape[1], :, :]

    return (x * cos_out) + (rotate_half(x) * sin_out)


# Adapted from https://github.com/foundation-model-stack/foundation-model-stack
class ShapeRotator:
    def __init__(
        self,
        dim: int,
        end: int,
        theta: float = 10_000,
    ):
        super().__init__()
        self.dim = dim
        self.ratio = theta
        self.cached_freqs: MutableMapping[int, MutableMapping[int, torch.Tensor]] = {}
        self.max_seq_len_cached: MutableMapping[int, int] = {}
        self.ntk_scaling = False
        self.max_seq_len = end

    def compute_freqs_cis(self, device, max_seq_len=None):
        alpha = 1
        dev_idx = device.index
        max_seq_len = default(max_seq_len, self.max_seq_len)

        if dev_idx not in self.cached_freqs:
            self.cached_freqs[dev_idx] = {}
        if dev_idx not in self.max_seq_len_cached:
            self.max_seq_len_cached[dev_idx] = 0


        if self.max_seq_len_cached[dev_idx] > 0:
            return 1
        max_seq_len = max(max_seq_len, self.max_seq_len)

        if (
            1 in self.cached_freqs[dev_idx]
            and max_seq_len <= self.max_seq_len_cached[dev_idx]
        ):
            return 1

        ratio = self.ratio
        dim = self.dim

        freqs = 1.0 / (ratio ** (torch.arange(0, dim, 2, device=device).float() / dim))

        t = torch.arange(max_seq_len, device=device, dtype=freqs.dtype)
        freqs = torch.einsum("i,j->ij", t, freqs)
        emb = torch.cat((freqs, freqs), dim=-1).to(device)

        cos_to_cache = emb.cos()[None, :, None, :]
        sin_to_cache = emb.sin()[None, :, None, :]

        self.max_seq_len_cached[dev_idx] = max_seq_len

        self.cached_freqs[dev_idx][alpha] = torch.stack(
            [
                cos_to_cache,
                sin_to_cache,
            ],
            dim=-1,
        )

        return alpha

    def rotate(
        self,
        q: Tensor,
        k: Tensor,
        offset: int = 0,
    ) -> Tuple[Tensor, Tensor]:
        """
        Args
        ----
        q : torch.Tensor
            Embedded query tensor, expected size is B x S x H x Eh
        k : torch.Tensor
            Embedded query tensor, expected size is B x S x H x Eh
        """
        assert len(q.size()) == 4
        assert len(k.size()) == 4

        seq_len = self.max_seq_len
        alpha = self.compute_freqs_cis(q.device, seq_len)
        freqs = self.cached_freqs[q.device.index][alpha]

        freqs = freqs.float()  # 1 L D/2 2 2
        q_out = apply_rotary_pos_emb(q, freqs[..., 0], freqs[..., 1], offset=offset).type_as(q)
        k_out = apply_rotary_pos_emb(k, freqs[..., 0], freqs[..., 1], offset=offset).type_as(k)

        return q_out.view_as(q), k_out.view_as(k)

class Linear(nn.Linear):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs, bias=False)

class Norm(nn.Module):
    def __init__(self,
            dim: int,
            eps: float = 1e-5,) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(T.ones((dim,)))

    def forward(self, input: Tensor) -> Tensor:
        return F.layer_norm(input, (self.weight.shape[0],), weight=self.weight, bias=None, eps=self.eps)
    

class FFNN(nn.Module):
    def __init__(self,
            dim: int,
            expand_dim: int = None,):
        super().__init__()
        expand_dim = default(expand_dim, 256 * ((int(2 * 4 * dim / 3) + 256 - 1) // 256))
        self.dim = dim
        self.expand_dim = expand_dim

        self.gateup_proj = Linear(dim, 2*expand_dim)
        self.down_proj = Linear(expand_dim, dim)

    def forward(self, x):
        gate, up = self.gateup_proj(x).chunk(2, dim=-1)
        return self.down_proj(up * F.silu(gate))

class GQA(nn.Module):
    def __init__(self,
            dim: int, 
            n_head: int, 
            shape_rotator: ShapeRotator,
            kv_heads: Optional[int] = None,
            eps: float = 1e-5,
            causal: bool = True,):
        super().__init__()
        self.n_heads = n_head
        self.kv_heads = default(kv_heads, n_head)
        self.head_dim = dim // n_head
        self.causal = causal

        self.proj_qkv = Linear(dim, self.head_dim*(n_head+2*self.kv_heads))

        self.norm_q = Norm(self.head_dim*n_head, eps=eps)
        self.norm_k = Norm(self.head_dim*self.kv_heads, eps=eps)

        self.attn_out = Linear(dim, dim)

        self.shape_rotator = shape_rotator

    def _sdpa(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        k = k.repeat_interleave(self.n_heads // self.kv_heads, dim=2)
        v = v.repeat_interleave(self.n_heads // self.kv_heads, dim=2)
        x = F.scaled_dot_product_attention(
            q.transpose(1, 2), 
            k.transpose(1, 2), 
            v.transpose(1, 2), 
            is_causal=False if (q.size(1) != k.size(1)) else self.causal,
        )
        x = x.transpose(1, 2).contiguous()
        return x

    def _attend(self, q: Tensor, k: Tensor, v: Tensor, kv_cache: Optional[Tensor] = None,):
        cache_len = get_cache_len(kv_cache)
        q, k = self.shape_rotator.rotate(q, k, offset=cache_len)
        if exists(kv_cache):
            k = T.cat([kv_cache[:, :cache_len, 0], k], dim=1)
            v = T.cat([kv_cache[:, :cache_len, 1], v], dim=1)
            kv_cache[:, :k.size(1), 0] = k
            kv_cache[:, :v.size(1), 1] = v
        x = self._sdpa(q, k, v)
        return self.attn_out(rearrange(x, 'b s h d -> b s (h d)'))

    def _project(self, x):
        full_q, full_k, full_v = self.proj_qkv(x).chunk(3, dim=-1)
        normed_full_q = self.norm_q(full_q).to(full_q.dtype)
        normed_full_k = self.norm_k(full_k).to(full_k.dtype)

        q = rearrange(normed_full_q, 'b s (h d) -> b s h d', h=self.n_heads)
        k = rearrange(normed_full_k, 'b s (h d) -> b s h d', h=self.kv_heads)
        v = rearrange(full_v, 'b s (h d) -> b s h d', h=self.kv_heads)
        return q, k, v

    def forward(self,
            x: Tensor,
            kv: Optional[Tensor] = None,):
        """
        x: (B, S, D)
        kv: (B, S, H, D)
        """
        q, k, v = self._project(x)
        return self._attend(q, k, v, kv_cache=kv)
        

class PreNormAttn(nn.Module):
    def __init__(self,
            dim: int, 
            n_head: int,
            shape_rotator: ShapeRotator,
            kv_heads: Optional[int] = None,
            eps: float = 1e-5,
            causal: bool = True,):
        super().__init__()
        self.attn_norm = Norm(dim, eps=eps)
        self.attn = GQA(dim, n_head, shape_rotator, kv_heads, eps=eps, causal=causal)

    def forward(self, x: Tensor, kv: Optional[Tensor] = None) -> Tensor:
        """
        x: (B, S, D)
        kv: (B, S, H, D)
        """
        return x + self.attn(self.attn_norm(x), kv)

class PreNormFFNN(nn.Module):
    def __init__(self,
            dim: int, 
            ff_dim: int,
            eps: float = 1e-5,):
        super().__init__()
        self.ffnn_norm = Norm(dim, eps=eps)
        self.ffnn = FFNN(dim, ff_dim)

    def forward(self, x: Tensor) -> Tensor:
        return x + self.ffnn(self.ffnn_norm(x))

class Block(nn.Module):
    def __init__(self,
            dim: int,
            layer_id: int = 0,
            n_head: int = 16,
            kv_heads: Optional[int] = None,
            ff_dim: Optional[int] = None,
            eps: float = 1e-5,
            causal: bool = True,
            shape_rotator: ShapeRotator = None):
        super().__init__()
        self.attn = PreNormAttn(dim, n_head, shape_rotator, kv_heads, eps=eps, causal=causal)
        self.ffnn = PreNormFFNN(dim, ff_dim, eps=eps)
        self.dim = dim
        self.layer_id = layer_id
        self.head_dim = dim // n_head
        self.expand_dim = self.ffnn.ffnn.expand_dim

        self.reset_parameters()

    def reset_parameters(self):
        std = 1.0 / math.sqrt(self.dim)
        nn.init.trunc_normal_(self.ffnn.ffnn.gateup_proj.weight, std=std, a=-3 * std, b=3 * std)
        nn.init.trunc_normal_(self.attn.attn.proj_qkv.weight, std=std, a=-3 * std, b=3 * std)
        nn.init.trunc_normal_(self.attn.attn.attn_out.weight, std=std, a=-3 * std, b=3 * std)

        xstd = 1.0 / math.sqrt(self.expand_dim)
        nn.init.trunc_normal_(self.ffnn.ffnn.down_proj.weight, std=xstd, a=-3 * xstd, b=3 * xstd)

    def forward(self, x: Tensor, kv: Optional[Tensor] = None) -> Tensor:
        """
        x: (B, S, D)
        kv: (B, S, H, D)
        """
        h = self.attn(x, kv)
        out = self.ffnn(h)
        return out



class GPTOutput(nn.Module):
    def __init__(self, dim, vocab_size):
        super().__init__()
        self.dim = dim
        self.norm = Norm(dim)
        self.output = Linear(dim, vocab_size)

        self.reset_parameters()

    def reset_parameters(self):
        std = 1.0 / math.sqrt(self.dim**2)
        nn.init.trunc_normal_(self.output.weight, std=std, a=-3 * std, b=3 * std)

    def forward(self, x):
        return self.output(self.norm(x))
    
@si_module
class Stack(nn.Module):
    class Config:
        layers: int
        dim: int
        seq_len: int
        n_head: int = 32
        ff_dim: int = None
        kv_heads: int = None
        eps: float = 1e-5
        theta: Union[int, float] = 10_000
        causal: bool = True

        from_pretrained: Optional[Tuple[str, int]] = None

    def __init__(self, c: Config):
        super().__init__()

        from_pretrained = c.from_pretrained
        if exists(from_pretrained):
            checkpoint = load_ckpt(c.from_pretrained)

        self.shape_rotator = ShapeRotator(c.dim//c.n_head, c.seq_len, theta=c.theta)

        self.layers = nn.ModuleList([
            Block(
                dim=c.dim,
                layer_id=l,
                n_head=c.n_head,
                kv_heads=c.kv_heads,
                ff_dim=c.ff_dim,
                eps=c.eps,
                causal=c.causal,
                shape_rotator=self.shape_rotator,
            ) for l in range(c.layers)
        ])

        kv_heads = c.kv_heads or c.n_head
        head_dim = c.dim // c.n_head
        cache_shape = [c.layers, c.seq_len, 2, kv_heads, head_dim]
        self.cache_shape = cache_shape
        self.cache = [None] * c.layers

        if exists(from_pretrained):
            self.load_state_dict(checkpoint)

    def init_cache(self, bsize, device, dtype, length:int=None):
        if self.cache_shape is None:
            return
        cache_shape = self.cache_shape.copy()
        cache_shape[1] = length or cache_shape[1]
        self.cache = T.full((bsize, *cache_shape), CACHE_FILL_VALUE, device=device, dtype=dtype).transpose(0, 1)

    def deinit_cache(self):
        self.cache = [None] * len(self.cache)

    def forward(self, x: Tensor) -> Tensor:
        for l, layer in enumerate(self.layers):
            x = layer(x, kv=self.cache[l])
        return x