# ------------------------------------------------------------------------
# Copyright (c) 2023-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""Text decoder."""

# flash-attn is optional at import time, but the fused attention and rotary
# kernels used below require it at runtime; the names stay None otherwise.
try:
    from flash_attn import flash_attn_func
    from flash_attn import flash_attn_with_kvcache
    from flash_attn.layers.rotary import apply_rotary_emb
except ImportError:
    flash_attn_func = None
    flash_attn_with_kvcache = None
    apply_rotary_emb = None

import torch
from torch import nn


class TransformerCache(nn.Module):
    """Transformer cache module."""

    def __init__(self, device=None, dtype=None):
        super(TransformerCache, self).__init__()
        self.device = device
        self.dtype = dtype
        self.start_pos = 0
        self.cache_dict = {}

    def init_seq(self, max_batch_size):
        seq_lens = torch.zeros(max_batch_size, dtype=torch.int32, device=self.device)
        self.cache_dict["seq_lens"] = seq_lens

    def init_rotary(self, seq_len, dim, theta=10000.0):
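        # Precompute rotary embedding tables: an outer product of positions
        # [0, seq_len) and inverse frequencies theta^(-2i/dim), cached as
        # cos/sin tensors of shape (seq_len, dim // 2).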
        grid = torch.arange(seq_len, dtype=torch.float32).unsqueeze_(-1)
        freq = torch.pow(theta, torch.arange(0, dim, 2)[: dim // 2].float().div_(dim))
        broadcast_freq = grid.mul(freq.reciprocal_().unsqueeze_(0))
        cache_cos = broadcast_freq.cos().view((-1, dim // 2))
        cache_sin = broadcast_freq.sin().view((-1, dim // 2))
        self.cache_dict["cos"] = cache_cos.to(self.device, self.dtype)
        self.cache_dict["sin"] = cache_sin.to(self.device, self.dtype)

    def init_kv(self, mixer, kv_size):
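        # Allocate zero-filled key/value caches for one attention layer,
        # keyed by the layer's id() so every block owns a separate slot.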
        cache_k = torch.zeros(*kv_size, dtype=self.dtype, device=self.device)
        cache_v = torch.zeros(*kv_size, dtype=self.dtype, device=self.device)
        self.cache_dict[f"{id(mixer)}_k"] = cache_k
        self.cache_dict[f"{id(mixer)}_v"] = cache_v

    def set_seq(self, start_pos=0, end_pos=None):
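        # Reset the cached sequence lengths to the decoding offset and, when
        # rotary tables exist, slice out the cos/sin rows for this step.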
        self.start_pos = start_pos
        if "seq_lens" in self.cache_dict:
            self.cache_dict["seq_lens"].fill_(start_pos)
        if "cos" in self.cache_dict and end_pos is not None:
            self.cache_dict["seq_cos"] = self.cache_dict["cos"][self.start_pos : end_pos]
            self.cache_dict["seq_sin"] = self.cache_dict["sin"][self.start_pos : end_pos]

    def forward_rotary(self, q, k, inplace=False):
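        # Apply interleaved rotary embeddings to q and k using the cached
        # cos/sin tables (requires flash-attn's apply_rotary_emb); a no-op
        # when no rotary cache has been initialized.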
        cos = self.cache_dict.get("seq_cos", self.cache_dict.get("cos", None))
        sin = self.cache_dict.get("seq_sin", self.cache_dict.get("sin", None))
        if cos is None or sin is None:
            return q, k
        q = apply_rotary_emb(q, cos, sin, interleaved=True, inplace=inplace)
        k = apply_rotary_emb(k, cos, sin, interleaved=True, inplace=inplace)
        return q, k

    def forward_flash(self, mixer, q, k, v):
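        # Prefill path: no KV cache registered for this layer, so run plain
        # causal flash attention. Decode path: append k/v to the layer's
        # cache and attend over it with flash_attn_with_kvcache.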
        cache_k = self.cache_dict.get(f"{id(mixer)}_k", None)
        cache_v = self.cache_dict.get(f"{id(mixer)}_v", None)
        flash_args = {"softmax_scale": mixer.scale, "causal": True}
        if cache_k is None or cache_v is None:
            return flash_attn_func(q, k, v, **flash_args)
        flash_args["cache_seqlens"] = self.cache_dict["seq_lens"][: q.shape[0]]
        return flash_attn_with_kvcache(q, cache_k, cache_v, k, v, **flash_args)


class Attention(nn.Module):
    """Self-Attention layer."""

    def __init__(self, dim, num_heads, bias=True):
        super(Attention, self).__init__()
        self.qkv = nn.Linear(dim, dim * 3, bias=bias)
        self.proj = nn.Linear(dim, dim, bias=bias)
        self.head_dim = dim // num_heads
        self.num_heads = num_heads
        self.scale = self.head_dim**-0.5
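        # Placeholder only; TextDecoder.reset_cache swaps in the shared
        # TransformerCache through the instance __dict__ so it is not
        # registered as a submodule of every attention layer.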
        self.cache = nn.Module()

    def forward(self, x):
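        # Project to (batch, seq, 3, heads, head_dim), split into q/k/v,
        # apply rotary embeddings, then run flash attention and merge heads.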
        qkv_shape = (-1, x.size(1), 3, self.num_heads, self.head_dim)
        q, k, v = self.qkv(x).view(qkv_shape).unbind(dim=2)
        q, k = self.cache.forward_rotary(q, k, inplace=True)
        o = self.cache.forward_flash(self, q, k, v)
        return self.proj(o.flatten(2))


class MLP(nn.Module):
    """Two layers MLP."""

    def __init__(self, dim, mlp_dim, bias=True):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(dim, mlp_dim, bias=bias)
        self.fc2 = nn.Linear(mlp_dim, dim, bias=bias)
        self.activation = nn.GELU()

    def forward(self, x):
        return self.fc2(self.activation(self.fc1(x)))


class Block(nn.Module):
    """Transformer block."""

    def __init__(self, dim, num_heads, mlp_dim, bias=True):
        super(Block, self).__init__()
        self.attn = Attention(dim, num_heads, bias=bias)
        self.mlp = MLP(dim, mlp_dim, bias=bias)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
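        # Pre-norm residual block; add_ fuses the residual addition in place.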
        x = self.attn(self.norm1(x)).add_(x)
        return self.mlp(self.norm2(x)).add_(x)


class Transformer(nn.Module):
    """Causal transformer decoder."""

    def __init__(self, depth, dim, num_heads, mlp_dim, vocab_size):
        super(Transformer, self).__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.vocab_size = vocab_size
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        self.blocks = nn.ModuleList(Block(dim, num_heads, mlp_dim) for _ in range(depth))
        self.norm = nn.LayerNorm(dim)
        self.text_proj = nn.Linear(dim, vocab_size, bias=False)

    def forward(self, prompts, tokens, start_pos=0):
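        # Prefill (start_pos == 0): prepend the prompt embeddings to the token
        # embeddings and keep only the token positions for prediction.
        # Decode (start_pos > 0): the prompt length is folded into the cache
        # offset and only the new tokens are fed through the blocks.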
        prompt_len = prompts.size(1)
        start_pos = start_pos + (prompt_len if start_pos > 0 else 0)
        end_pos = start_pos + tokens.size(1) + (0 if start_pos > 0 else prompt_len)
        self.cache.set_seq(start_pos, end_pos)
        x = self.tok_embeddings(tokens)
        x = x if start_pos > 0 else torch.cat([prompts, x], dim=1)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x[:, 0 if start_pos > 0 else prompt_len :])
        return self.text_proj(x).float()


class TextDecoder(nn.Module):
    """Module to decode texts."""

    def __init__(
        self,
        depth,
        embed_dim,
        num_heads,
        mlp_ratio,
        prompt_embed_dim,
        max_seq_len,
        vocab_size,
    ):
        super(TextDecoder, self).__init__()
        self.max_seq_len = max_seq_len
        self.max_text_len = self.max_seq_len - 1
        self.encoder = nn.Linear(prompt_embed_dim, embed_dim, bias=False)
        self.transformer = Transformer(
            depth=depth,
            dim=embed_dim,
            mlp_dim=int(embed_dim * mlp_ratio),
            num_heads=num_heads,
            vocab_size=vocab_size,
        )

    def reset_cache(self, max_batch_size=1, max_seq_len=None):
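        # Build one TransformerCache shared by all blocks: sequence lengths,
        # rotary cos/sin tables, and (at inference time) per-layer KV caches
        # shaped (batch, seq, heads, head_dim).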
        device, dtype = self.encoder.weight.device, self.encoder.weight.dtype
        max_seq_len = self.max_seq_len if max_seq_len is None else max_seq_len
        num_heads, head_dim = self.transformer.num_heads, self.transformer.head_dim
        self.transformer.cache = TransformerCache(device=device, dtype=dtype)
        self.transformer.cache.init_seq(max_batch_size)
        self.transformer.cache.init_rotary(max_seq_len, head_dim, theta=10000.0)
        kv_cache_size = (max_batch_size, max_seq_len, num_heads, head_dim)
        for blk in self.transformer.blocks:
            blk.attn.__dict__["cache"] = self.transformer.cache
            if not self.training:
                self.transformer.cache.init_kv(blk.attn, kv_cache_size)

    def get_prompts(self, prompt_tokens):
        return self.encoder(prompt_tokens)

    def get_outputs(self, inputs, start_pos=0):
        return {"text_pred": self.transformer(inputs["prompts"], inputs["tokens"], start_pos)}

    def forward(self, inputs, start_pos=0):
        return self.get_outputs(inputs, start_pos)
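

# A minimal usage sketch (illustrative, not shipped with this module): the
# hyperparameters below are assumptions, and the forward pass itself needs
# flash-attn on a CUDA device with a half-precision model.
#
#   decoder = TextDecoder(depth=2, embed_dim=512, num_heads=8, mlp_ratio=4,
#                         prompt_embed_dim=256, max_seq_len=40,
#                         vocab_size=32000).cuda().half().eval()
#   decoder.reset_cache(max_batch_size=1)
#   prompts = decoder.get_prompts(torch.zeros(1, 1, 256).cuda().half())
#   tokens = torch.zeros(1, 1, dtype=torch.int64).cuda()
#   logits = decoder({"prompts": prompts, "tokens": tokens})["text_pred"]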