from dataclasses import dataclass, field
from typing import Literal

import torch
from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerFast

# Special control tokens used by the chat template and multi-modal encoding.
# <|semantic|> / <|mel|> act as per-frame placeholders in the token stream;
# the actual VQ codes / mel frames are carried separately (see Message.encode).
IM_START_TOKEN = "<|im_start|>"
IM_END_TOKEN = "<|im_end|>"
SEMANTIC_TOKEN = "<|semantic|>"
MEL_TOKEN = "<|mel|>"
PHONEME_START_TOKEN = "<|phoneme_start|>"
PHONEME_END_TOKEN = "<|phoneme_end|>"
ALL_SPECIAL_TOKENS = [
    IM_START_TOKEN,
    IM_END_TOKEN,
    SEMANTIC_TOKEN,
    MEL_TOKEN,
    PHONEME_START_TOKEN,
    PHONEME_END_TOKEN,
]

# Codebook id reserved for padding; real codes are shifted by +1 so that 0
# stays free for this purpose (see Message.encode).
CODEBOOK_PAD_TOKEN_ID = 0


class FishTokenizerConfig(PretrainedConfig):
    """Configuration carrying the codebook layout of the Fish tokenizer.

    Attributes:
        share_codebook_embeddings: whether all codebooks share one embedding
            table (if False, each codebook gets its own id range).
        codebook_size: number of entries per codebook.
        num_codebooks: number of parallel codebooks.
    """

    # Class-level defaults kept for backward compatibility with any code that
    # reads these attributes off the class itself.
    share_codebook_embeddings: bool = True
    codebook_size: int = 1024
    num_codebooks: int = 8

    def __init__(
        self,
        share_codebook_embeddings: bool = True,
        codebook_size: int = 1024,
        num_codebooks: int = 8,
        **kwargs,
    ):
        # Store values on the instance so PretrainedConfig serializes them in
        # to_dict()/save_pretrained(); bare class attributes are not written
        # to config.json.
        super().__init__(**kwargs)
        self.share_codebook_embeddings = share_codebook_embeddings
        self.codebook_size = codebook_size
        self.num_codebooks = num_codebooks


class FishTokenizerFast(PreTrainedTokenizerFast):
    """Fast tokenizer that additionally tracks the VQ codebook layout."""

    def __init__(self, *args, **kwargs):
        # BUG FIX: pop the custom kwargs BEFORE delegating to the base class.
        # Previously super().__init__ was called first, so the custom keys
        # were still inside kwargs when it ran, and the pops afterwards were
        # too late to keep them out of the base tokenizer's initialization.
        self.share_codebook_embeddings = kwargs.pop("share_codebook_embeddings", True)
        self.codebook_size = kwargs.pop("codebook_size", 1024)
        self.num_codebooks = kwargs.pop("num_codebooks", 8)
        super().__init__(*args, **kwargs)


# Let AutoTokenizer resolve FishTokenizerConfig to the fast subclass above.
AutoTokenizer.register(FishTokenizerConfig, fast_tokenizer_class=FishTokenizerFast)


@dataclass(kw_only=True)
class BasePart:
    """Common base for all message content parts (text / VQ codes / mels)."""

    pass


@dataclass(kw_only=True)
class VQPart(BasePart):
    """Audio content as VQ codebook indices."""

    # assumed shape (num_codebooks, num_frames): encode() reads shape[1] as
    # the frame count and iterates codebooks along dim 0 — TODO confirm
    codes: torch.Tensor


@dataclass(kw_only=True)
class TextPart(BasePart):
    """Plain-text content, tokenized directly by the tokenizer."""

    text: str


@dataclass(kw_only=True)
class MelPart(BasePart):
    """Audio content as mel-spectrogram frames."""

    # encode() reads shape[1] as the frame count; presumably
    # (n_mels, num_frames) — verify against the caller
    mels: torch.Tensor


@dataclass(kw_only=True)
class EncodedMessage:
    """Result of encoding one Message (or a whole Conversation)."""

    # Flat 1-D token id sequence (placeholders stand in for VQ/mel frames).
    tokens: torch.Tensor
    # Same shape as tokens; -100 marks positions excluded from the LM loss.
    labels: torch.Tensor
    # Raw VQ code tensors, in order of appearance.
    vq_parts: list[torch.Tensor]
    # Raw mel tensors, in order of appearance.
    mel_parts: list[torch.Tensor]
    # Per-VQ-part bool flags (only set by Conversation.encode).
    vq_require_losses: torch.Tensor | None = None


@dataclass(kw_only=True)
class Message:
    """One chat-template message composed of text / VQ / mel parts."""

    # Speaker of the message.
    role: Literal["system", "user", "assistant"]
    # Ordered multi-modal contents.
    parts: list[VQPart | TextPart | MelPart] = field(default_factory=list)
    # Whether to auto-wrap the message in "<|im_start|>{role}\n" ... "<|im_end|>".
    add_im_start: bool = True
    add_im_end: bool = True
    # If True, labels mirror tokens (trained on); otherwise labels are -100.
    cal_loss: bool = False

    # By default, ignore the loss of the auto-generated im_start token
    ignore_im_start_loss: bool = True

    def encode(
        self: "Message",
        tokenizer: AutoTokenizer,
    ) -> EncodedMessage:
        """Tokenize this message into a flat token/label sequence.

        Text parts are tokenized directly; VQ/mel parts appear in the token
        stream as runs of <|semantic|> / <|mel|> placeholders (one per frame)
        while the actual codes/mels are collected into vq_parts / mel_parts.

        Raises:
            ValueError: if a part is not a TextPart, VQPart or MelPart.
        """
        all_tokens = []
        all_labels = []

        # Multi-modal payloads gathered alongside their placeholder tokens.
        vq_parts = []
        mel_parts = []

        semantic_id, mel_id = tokenizer.convert_tokens_to_ids(
            [SEMANTIC_TOKEN, MEL_TOKEN]
        )

        parts = self.parts.copy()
        if self.add_im_start:
            # Use the shared module constants instead of re-hardcoding the
            # control strings, keeping them in sync with ALL_SPECIAL_TOKENS.
            parts.insert(0, TextPart(text=f"{IM_START_TOKEN}{self.role}\n"))

        if self.add_im_end:
            parts.append(TextPart(text=IM_END_TOKEN))

        for part in parts:
            if isinstance(part, TextPart):
                tokens = tokenizer.encode(
                    part.text,
                    add_special_tokens=False,
                    truncation=False,
                    return_tensors="pt",
                ).int()[0]
            elif isinstance(part, VQPart):
                # One <|semantic|> placeholder per code frame.
                tokens = torch.zeros(part.codes.shape[1], dtype=torch.int) + semantic_id
                # Shift by +1 so id 0 stays free as CODEBOOK_PAD_TOKEN_ID.
                codes = part.codes.clone() + 1

                if getattr(tokenizer, "share_codebook_embeddings", True) is False:
                    # Give each codebook its own id range when embeddings are
                    # not shared across codebooks.
                    for i in range(len(codes)):
                        codes[i] += tokenizer.codebook_size * i

                vq_parts.append(codes)
            elif isinstance(part, MelPart):
                # One <|mel|> placeholder per mel frame.
                tokens = torch.zeros(part.mels.shape[1], dtype=torch.int) + mel_id
                mel_parts.append(part.mels)
            else:
                raise ValueError(f"Unsupported part type: {type(part)}")

            all_tokens.append(tokens)
            if self.cal_loss:
                all_labels.append(tokens.clone())
            else:
                all_labels.append(torch.full_like(tokens, -100))

        tokens = torch.cat(all_tokens, dim=0)
        labels = torch.cat(all_labels, dim=0)
        assert tokens.shape == labels.shape

        if self.ignore_im_start_loss and self.add_im_start:
            # Never train on the auto-inserted "<|im_start|>{role}\n" prefix.
            labels[: len(all_tokens[0])] = -100

        return EncodedMessage(
            tokens=tokens,
            labels=labels,
            vq_parts=vq_parts,
            mel_parts=mel_parts,
        )


@dataclass
class Conversation:
    """An ordered sequence of Messages plus encode/visualize helpers."""

    messages: list[Message]

    def encode(
        self: "Conversation",
        tokenizer: AutoTokenizer,
        add_shift: bool = True,
    ) -> EncodedMessage:
        """Concatenate every message's encoding into one training sample.

        Args:
            tokenizer: tokenizer forwarded to Message.encode.
            add_shift: if True, drop the last token and the first label so
                tokens[i] is trained to predict labels[i] (next-token shift).
        """
        # Build the input_ids and labels
        tokens = []
        labels = []
        vq_parts = []
        mel_parts = []
        vq_require_losses = []

        for message in self.messages:
            encoded = message.encode(
                tokenizer,
            )
            tokens.append(encoded.tokens)
            labels.append(encoded.labels)
            vq_parts.extend(encoded.vq_parts)
            mel_parts.extend(encoded.mel_parts)
            # Each VQ part inherits its message's loss flag.
            vq_require_losses.extend([message.cal_loss] * len(encoded.vq_parts))

        tokens = torch.cat(tokens, dim=0)
        labels = torch.cat(labels, dim=0)
        vq_require_losses = torch.tensor(vq_require_losses, dtype=torch.bool)

        if add_shift:
            tokens = tokens[:-1]
            labels = labels[1:]

        # BUG FIX: the failure message previously referenced the undefined
        # global name `conversation`, so a dtype violation raised NameError
        # instead of the intended AssertionError. Use `self`.
        assert tokens.dtype in [
            torch.int,
            torch.long,
        ], f"Invalid dtype: {tokens.dtype}, conv: {self}"

        return EncodedMessage(
            tokens=tokens,
            labels=labels,
            vq_parts=vq_parts,
            mel_parts=mel_parts,
            vq_require_losses=vq_require_losses,
        )

    def encode_for_inference(
        self: "Conversation",
        tokenizer: AutoTokenizer,
        num_codebooks: int,
    ) -> torch.Tensor:
        """Encode into a (num_codebooks + 1, seq_len) int tensor.

        Row 0 holds the text tokens; rows 1..num_codebooks hold the VQ codes
        aligned under their <|semantic|> placeholder positions (zeros, i.e.
        CODEBOOK_PAD_TOKEN_ID, elsewhere).

        BUG FIX: the return annotation previously claimed EncodedMessage,
        but both return paths yield a plain tensor.
        """
        encoded = self.encode(tokenizer, add_shift=False)
        tokens = encoded.tokens
        values = torch.zeros((num_codebooks + 1, len(tokens)), dtype=torch.int)
        values[0] = tokens

        if encoded.vq_parts is None or len(encoded.vq_parts) == 0:
            return values

        # NOTE(review): mel_id is looked up but mel placeholders are never
        # filled below — presumably mels are consumed elsewhere; confirm.
        semantic_id, mel_id = tokenizer.convert_tokens_to_ids(
            [SEMANTIC_TOKEN, MEL_TOKEN]
        )
        vq_parts = encoded.vq_parts
        vq_parts = torch.cat(vq_parts, dim=1)
        values[1:, tokens == semantic_id] = vq_parts
        return values

    def visualize(self: "Conversation", tokenizer: AutoTokenizer):
        """Print the decoded tokens, green = loss-ignored, blue = trained."""
        encoded = self.encode(tokenizer, add_shift=False)

        # Local defs instead of assigned lambdas (PEP 8 E731).
        def print_in_blue(x):
            print("\033[94m" + x + "\033[0m", end="")

        def print_in_green(x):
            print("\033[92m" + x + "\033[0m", end="")

        for tok, lab in zip(encoded.tokens, encoded.labels):
            val = tokenizer.decode(tok, skip_special_tokens=False)
            if val == "\n":
                val = "\\n\n"

            if lab == -100:
                print_in_green(val)
            else:
                print_in_blue(val)

        print()

if __name__ == "__main__":
    # Smoke-test/demo: build a tiny two-message conversation and print its
    # encoding. Requires a local checkpoint at checkpoints/Qwen2-1.5B-Instruct.
    message0 = Message(
        role="user",
        parts=[
            TextPart(text="Hello, how are you?"),
            # NOTE(review): these codes are float (torch.zeros default dtype);
            # encode() tolerates that, but encode_for_inference assigns them
            # into an int tensor — confirm intended dtype (torch.int?).
            VQPart(codes=torch.zeros((4, 10))),
        ],
        cal_loss=False,
    )

    message1 = Message(
        role="assistant",
        parts=[TextPart(text="I'm fine, thank you.")],
        cal_loss=True,  # only the assistant reply contributes to the loss
    )
    conversation = Conversation([message0, message1])
    tokenizer = AutoTokenizer.from_pretrained("checkpoints/Qwen2-1.5B-Instruct")
    conversation.visualize(tokenizer)

    encoded = conversation.encode(tokenizer)
    print(encoded)
    print(tokenizer.batch_decode(encoded.tokens))