from functools import cached_property
from pathlib import Path
from typing import Any

import pandas as pd
import torch
import torch.nn.functional as F

from esm.tokenization.tokenizer_base import EsmTokenizerBase
from esm.utils.constants import esm3 as C

Sample = dict[str, Any]


class ResidueAnnotationsTokenizer(EsmTokenizerBase):
    def __init__(
        self,
        csv_path: str | None = None,
        max_annotations: int = 16,
    ):
        if csv_path is None:
            csv_path = str(C.data_root() / C.RESID_CSV)
        self.csv_path = csv_path
        self.max_annotations = max_annotations

    @cached_property
    def _description2label(self) -> dict[str, str]:
        with Path(self.csv_path).open() as f:  # type: ignore
            df = pd.read_csv(f)
        return dict(zip(df.label, df.label_clean))

    @cached_property
    def _labels(self) -> list[str]:
        with Path(self.csv_path).open() as f:  # type: ignore
            df = pd.read_csv(f)
        labels = (
            df.groupby("label_clean")["count"]
            .sum()
            .sort_values(ascending=False, kind="stable")  # type: ignore
            .index.tolist()
        )
        assert isinstance(labels, list)
        return labels  # type: ignore

    def _description2id(self, description: str) -> int | None:
        label = self._description2label.get(description)
        return self._label2id.get(label)  # type: ignore

    @cached_property
    def _label2id(self) -> dict[str, int]:
        offset = len(self.special_tokens) + 1  # +1 for "<none>"
        return {label: offset + i for i, label in enumerate(self._labels)}

    @cached_property
    def special_tokens(self) -> list[str]:
        """List of special tokens which come before cluster toknes in vocab."""
        return ["<pad>", "<motif>", "<unk>"]

    @cached_property
    def vocab(self):
        annotation_tokens = [f"<ra:{id}>" for _, id in self._label2id.items()]
        return self.special_tokens + ["<none>"] + annotation_tokens

    @cached_property
    def vocab_to_index(self) -> dict[str, int]:
        return {token: token_id for token_id, token in enumerate(self.vocab)}

    @cached_property
    def vocabulary(self) -> list[str]:
        """Full vocabulary."""
        return [*self.special_tokens, "<none>", *self._labels]

    def get_special_tokens_mask(self, encoded: torch.Tensor) -> torch.Tensor:
        """Determines where in the sequence are special tokens."""
        return encoded[:, 0] < len(self.special_tokens)

    def tokenize(
        self, sample: Sample | None, sequence: str, fail_on_mismatch: bool = False
    ) -> list[str]:
        """
        # interpro_site_starts
        # interpro_site_ends  # should always == interpro_site_starts.  but I haven't checked overall.
        # interpro_site_residues  # the residue identity of the specfic residue that is annotated.  good for a sanity check that parsing occurred correctly.
        # interpro_site_descriptions
        # ASSERT (i.e. drop if bad)
        # interpro_site_residues matches the residue at that position
        # all these lists ^ above are the same length
        """
        seqlen = len(sequence)
        assert seqlen >= 0
        # None mean sequence is *not annotated* - so use full <pad>
        if sample is None:
            return ["<pad>"] * seqlen

        if any(
            sample.get(field) is None
            for field in [
                "interpro_site_descriptions",
                "interpro_site_starts",
                "interpro_site_ends",
                "interpro_site_residues",
            ]
        ):
            return ["<pad>"] * seqlen

        num_annotations = len(sample["interpro_site_descriptions"])
        if any(
            len(sample[field]) != num_annotations
            for field in [
                "interpro_site_starts",
                "interpro_site_ends",
                "interpro_site_residues",
            ]
        ):
            # mismatched length.
            return ["<pad>"] * seqlen

        positional_ids = [set() for _ in range(seqlen)]
        for description, start, end, residues in zip(
            sample["interpro_site_descriptions"],
            sample["interpro_site_starts"],
            sample["interpro_site_ends"],
            sample["interpro_site_residues"],
        ):
            try:
                start = int(start)
                end = int(end)
            except (TypeError, ValueError):
                continue

            # Start / End are 1-indexed [inclusive, inclusive].
            if start <= 0 or end > seqlen or start > end:
                print(f"invalid start/end: ({start}, {end}), len: {seqlen}")
                continue

            if len(residues) != (end - start) + 1:
                print(f"bad reference residue: {residues}")
                continue

            token_id = self._description2id(description)
            if token_id is None:
                token_id = self.vocab_to_index["<unk>"]

            for i, residue in zip(range(start - 1, end), residues):
                # If there are any mismatching residues, skip the entire sample.
                if sequence[i] != residue:
                    if fail_on_mismatch:
                        raise ValueError(
                            f"Residue mismatch at position {i} (1-indexed): {sequence[i]} != {residue}"
                        )
                    return ["<pad>"] * seqlen

                positional_ids[i].add(token_id)

        tokens = []
        for token_ids in positional_ids:
            if token_ids:
                token = "<ra:" + ",".join(str(token_id) for token_id in token_ids) + ">"
            else:
                token = "<none>"
            tokens.append(token)
        return tokens

    def _token2ids(self, token: str) -> list[int]:
        if token.startswith("<ra:") and token.endswith(">"):
            return [int(token_id) for token_id in token[4:-1].split(",")]
        else:
            token_id = self.vocab_to_index[token]
            return [token_id]

    def encode(
        self, tokens: list[str], add_special_tokens: bool = True
    ) -> torch.Tensor:
        token_ids = torch.full(
            size=(len(tokens), self.max_annotations),
            dtype=torch.int64,
            fill_value=self.vocab_to_index["<pad>"],
        )
        for i, token in enumerate(tokens):
            ids = self._token2ids(token)[: self.max_annotations]
            token_ids[i, : len(ids)] = torch.tensor(ids)

        if add_special_tokens:
            token_ids = F.pad(
                token_ids, (0, 0, 1, 1), value=self.vocab_to_index["<pad>"]
            )
        return token_ids

    def decode(self, encoded: torch.Tensor) -> list[str]:
        raise NotImplementedError(
            "Residue annotation decoding should be handled with util.decoding.decode_residue_annotations"
        )

    @property
    def mask_token(self) -> str:
        return "<pad>"

    @property
    def mask_token_id(self) -> int:
        return self.vocab_to_index[self.mask_token]

    @property
    def bos_token(self) -> str:
        return "<pad>"

    @property
    def bos_token_id(self) -> int:
        return self.vocab_to_index[self.bos_token]

    @property
    def eos_token(self) -> str:
        return "<pad>"

    @property
    def eos_token_id(self) -> int:
        return self.vocab_to_index[self.eos_token]

    @property
    def pad_token(self) -> str:
        return "<pad>"

    @property
    def pad_token_id(self) -> int:
        return self.vocab_to_index[self.pad_token]