"""Tokenizes annotations of protein function."""

import re
import string
from functools import cache, cached_property, partial
from typing import Collection

import numpy as np
import pandas as pd
import scipy.sparse as sp
import torch
import torch.nn.functional as F

from esm.tokenization.tokenizer_base import EsmTokenizerBase
from esm.utils.constants import esm3 as C
from esm.utils.function import interpro, lsh, tfidf
from esm.utils.misc import stack_variable_length_tensors
from esm.utils.types import FunctionAnnotation


class InterProQuantizedTokenizer(EsmTokenizerBase):
    """Tokenizer for functional annotations.

    This tokenizer converts InterPro and/or function keywords into a multi-token
    representation by hashing TF-IDF vector representations of the text associated with
    the function and then applying a locality sensitive hash (LSH).

    Vocabulary layout: the special tokens ``<pad>``, ``<motif>``, ``<unk>`` come
    first, then ``<none>``, then ``2 ** lsh_bits_per_token`` LSH tokens.
    """

    def __init__(
        self,
        depth: int = 8,
        lsh_bits_per_token: int = 8,
        lsh_path: str | None = None,
        keyword_vocabulary_path: str | None = None,
        keyword_idf_path: str | None = None,
        interpro_entry_path: str | None = None,
        interpro2keywords_path: str | None = None,
    ):
        """Constructs function tokenizer.

        Any path left as None falls back to the corresponding file under
        ``C.data_root()``.

        Args:
            depth: number of tokens emitted in each position.
            lsh_bits_per_token: Number of LSH bits per token. Determines the vocabulary
                  size (``1 << lsh_bits_per_token`` LSH tokens).
            lsh_path: path to locality sensitive hash (LSH) hyperplanes.
            keyword_vocabulary_path: path to csv containing function keyword vocabulary.
            keyword_idf_path: path to IDF values for each keyword.
            interpro_entry_path: path to list of InterPro entries in CSV format.
            interpro2keywords_path: path to CSV mapping InterPro IDs to function keywords.
        """
        self.depth = depth

        def default(path, data_file):
            # Use the caller-supplied path, otherwise the bundled data file.
            return path if path is not None else C.data_root() / data_file

        self.keyword_vocabulary_path = default(
            keyword_vocabulary_path, C.KEYWORDS_VOCABULARY
        )
        self.keyword_idf_path = default(keyword_idf_path, C.KEYWORDS_IDF)

        self._interpro2keywords_path = default(
            interpro2keywords_path, C.INTERPRO2KEYWORDS
        )
        self.interpro_ = interpro.InterPro(
            entries_path=default(interpro_entry_path, C.INTERPRO_ENTRY)
        )

        self.lsh_vocab_size = 1 << lsh_bits_per_token
        # NOTE(review): the default LSH table is always the "8bit" one, even when
        # lsh_bits_per_token != 8; pass lsh_path explicitly for other widths.
        self._lsh = lsh.LSHTokenized(
            lsh_bits_per_token,
            len(self.keyword_vocabulary),
            self.depth,
            default(lsh_path, C.LSH_TABLE_PATHS["8bit"]),
        )

        # This is the offset into the vocabulary where LSH tokens start.
        self._lsh_token_vocab_offset = len(self.special_tokens) + 1  # +1 for <none>

    @cached_property
    def interpro2keywords(self) -> dict[str, list[str]]:
        """Mapping from InterPro ID to function keywords.

        Loaded lazily from the CSV at ``_interpro2keywords_path``. The CSV must
        have ``interpro_id`` and ``keywords`` columns, with keywords comma-separated.
        """
        df = pd.read_csv(self._interpro2keywords_path)
        assert "interpro_id" in df.columns and "keywords" in df.columns, df.columns
        return dict(zip(df.interpro_id, df.keywords.str.split(",")))

    @cached_property
    def interpro_labels(self) -> list[str]:
        """The set of supported InterPro labels, sorted."""
        return sorted(self.interpro2keywords.keys())

    @cached_property
    def interpro_to_index(self) -> dict[str, int]:
        """Mapping from InterPro id to its position in ``interpro_labels``."""
        return {id: i for i, id in enumerate(self.interpro_labels)}

    @property
    def keyword_vocabulary(self) -> list[str]:
        """Set of supported keywords."""
        return self._tfidf.vocabulary

    @property
    def keyword_to_index(self) -> dict[str, int]:
        """Mapping from keywords to index."""
        return self._tfidf.vocab_to_index

    @cached_property
    def _tfidf(self) -> tfidf.TFIDFModel:
        """Creates TF-IDF model for encoding function keywords."""
        return tfidf.TFIDFModel(
            vocabulary_path=self.keyword_vocabulary_path,
            idf_path=self.keyword_idf_path,
        )

    @cached_property
    def special_tokens(self) -> list[str]:
        """List of special tokens which come before cluster tokens in vocab."""
        return ["<pad>", "<motif>", "<unk>"]

    @cached_property
    def vocab(self) -> list[str]:
        """Vocabulary of function tokens: special tokens, <none>, then LSH tokens."""
        lsh_tokens = [f"<lsh:{i}>" for i in range(self.lsh_vocab_size)]
        return self.special_tokens + ["<none>"] + lsh_tokens

    @cached_property
    def vocab_to_index(self) -> dict[str, int]:
        """Mapping from single vocabulary token string to token id.

        Multi-id tokens emitted by ``tokenize`` (e.g. ``"<lsh:1,2,...>"``) are not
        keys here; they are converted by ``_token2ids``.
        """
        return {token: token_id for token_id, token in enumerate(self.vocab)}

    def get_special_tokens_mask(self, encoded: torch.Tensor) -> torch.Tensor:
        """Determines where in the sequence are special tokens.

        Args:
            encoded: token ids indexed as [position, depth], presumably as produced
                by ``encode``.
        Returns:
            Boolean tensor with one entry per position, True where the position
            holds a special token.

        Asserts that every position is either entirely special or entirely
        non-special across the depth dimension.
        """
        where = encoded < len(self.special_tokens)
        assert torch.all(torch.all(where, dim=1) | torch.all(~where, dim=1))
        return where[:, 0]

    def tokenize(
        self,
        annotations: list[FunctionAnnotation],
        seqlen: int,
        p_keyword_dropout: float = 0.0,
    ) -> list[str]:
        """Encodes range-annotations of protein function as tokens.

        Args:
            annotations: Annotated function ranges (1-based, inclusive), labeled
              either as InterPro ids or keywords.
            seqlen: length of sequence.
            p_keyword_dropout: Optional probability of dropping out keywords from the
              input annotations.
        Returns:
            Tokenized representation of function annotations as a list of string tokens
            of size seqlen. Unannotated positions get "<none>", positions whose text
            hashes to nothing get "<unk>", and a call with no annotations at all
            returns "<pad>" for every position.
        Raises:
            ValueError: if a label is neither a known InterPro id nor a known keyword.
        """
        assert seqlen >= 0

        if not annotations:
            return ["<pad>"] * seqlen

        # Expand the range annotations into positional annotation sets.
        positional_labels: list[set[str]] = [set() for _ in range(seqlen)]
        for annotation in annotations:
            assert 1 <= annotation.start <= annotation.end <= seqlen, (
                f"Invalid annotation range [{annotation.start}, {annotation.end}] for "
                f"sequence length {seqlen}."
            )
            for i in range(annotation.start - 1, annotation.end):
                positional_labels[i].add(annotation.label)

        if p_keyword_dropout > 0:
            # One dropout decision per keyword, shared by every position in this call.
            keyword_mask = (
                np.random.random(len(self._tfidf.vocabulary)) < p_keyword_dropout
            )
        else:
            keyword_mask = None

        # Annotations tend to be repetitive over the length of the sequence - cache their
        # hashes to speed up tokenization.
        hash_fn = cache(partial(self._function_text_hash, keyword_mask=keyword_mask))

        tokens: list[str] = []
        for labels in positional_labels:
            if not labels:
                token = "<none>"
            else:
                lsh_hash = hash_fn(frozenset(labels))
                if lsh_hash is not None:
                    assert len(lsh_hash) == self.depth
                    token = "<lsh:" + ",".join(map(str, lsh_hash)) + ">"
                else:
                    token = "<unk>"

            tokens.append(token)

        return tokens

    def _function_text_hash(
        self,
        labels: Collection[str],
        keyword_mask: np.ndarray | None = None,
    ) -> np.ndarray | None:
        """Applies a locality sensitive hash (LSH) to function text.

        Args:
            labels: InterPro ids and/or keywords.
            keyword_mask: optional boolean array shaped (keyword_vocab_size,) indicating
                which keywords to drop before hashing.
        Returns:
            LSH shaped (depth,) or None if there is no text or keywords to hash.
        Raises:
            ValueError: if a label is neither a supported InterPro id nor a keyword.
        """
        # Split labels into either InterPro ids or keywords.
        interpro_ids = []
        keywords = []
        for label in labels:
            match = re.match(r"IPR\d+", label)
            if match and match.group() in self.interpro_to_index:
                interpro_ids.append(match.group())
            elif label in self._tfidf.vocab_to_index:
                keywords.append(label)
            else:
                raise ValueError(f"Unsupported: {label}")

        # Perform an element-wise maximum over TF-IDF vectors from distinct tags to
        # avoid tags getting "washed out" by eg. 4 very similar tags. Keywords are
        # incorporated as another TF-IDF vector.
        vec: sp.csr_matrix = self._tfidf.encode(keywords)
        for interpro_id in interpro_ids:
            interpro_keywords = self.interpro2keywords.get(interpro_id, [])
            vec = vec.maximum(self._tfidf.encode(interpro_keywords))

        if keyword_mask is not None:
            # Zero the weights of dropped keywords in place; a vector whose every
            # keyword was dropped then fails the all-zero check below.
            vec.data *= 1 - np.take(keyword_mask, vec.indices)

        if vec.sum() == 0:
            return None

        return self._lsh(vec)[0, :]

    def encode(
        self, tokens: list[str], add_special_tokens: bool = True
    ) -> torch.Tensor:
        """Encodes string tokens as token-id tensor.

        Args:
            tokens: list of individual tokens. e.g. ["<none>", "<lsh:1,2,3,4,5,6,7,8>"]
              (an LSH token must carry exactly ``depth`` ids).
            add_special_tokens: whether to add a single pad token at the start and end
              of the sequence to act as <cls> and <eos> tokens.
        Returns:
            <int>[length, depth] function tokens. Length will be +2 of input tokens
            length when add_special_tokens is True.
        Raises:
            ValueError: if a token is not recognized.
        """
        token_ids = torch.zeros(size=(len(tokens), self.depth), dtype=torch.int64)
        for i, token in enumerate(tokens):
            token_ids[i, :] = torch.tensor(self._token2ids(token))
        if add_special_tokens:
            # Pad one row before and one after along the length dimension.
            token_ids = F.pad(
                token_ids, (0, 0, 1, 1), value=self.vocab_to_index["<pad>"]
            )
        return token_ids

    def lookup_annotation_name(self, annotation: FunctionAnnotation) -> str | None:
        """Returns the InterPro name for the annotation's label, if one is found."""
        return self.interpro_.lookup_name(annotation.label)

    def format_annotation(self, annotation: FunctionAnnotation) -> str:
        """Formats an annotation as "name (label)", or just the label if unnamed."""
        annotation_name = self.lookup_annotation_name(annotation)
        if annotation_name is not None:
            return f"{annotation_name} ({annotation.label})"
        else:
            return annotation.label

    def _token2ids(self, token: str) -> list[int]:
        """Converts token into token_id set of length depth.

        LSH tokens map each hash id to ``_lsh_token_vocab_offset + id``. ``<none>``
        and special tokens repeat their single id ``depth`` times.
        """
        if re.match(r"<lsh:[\d+,]+>", token):
            lsh_ids = [int(lsh_id) for lsh_id in re.findall(r"\d+", token)]
            assert (
                len(lsh_ids) == self.depth
            ), f"Expected token to have {self.depth} ids found {lsh_ids}"
            return [self._lsh_token_vocab_offset + lsh_id for lsh_id in lsh_ids]
        elif token == "<none>" or token in self.special_tokens:
            return [self.vocab_to_index[token]] * self.depth
        else:
            raise ValueError(f"Unknown token: {token}")

    def batch_encode(
        self,
        token_batch: list[list[str]],
        add_special_tokens: bool = True,
    ) -> torch.Tensor:
        """Encodes batch of function tokens.

        Args:
            token_batch: batch of function tokens.
            add_special_tokens: whether to add special tokens.
        Returns:
            <int>[batch_size, max_length, depth] batch of encoded tokens. Shorter
            sequences are padded with the <pad> id.
        """
        encoded = [
            self.encode(tokens, add_special_tokens=add_special_tokens)
            for tokens in token_batch
        ]
        return stack_variable_length_tensors(
            encoded,
            constant_value=self.vocab_to_index["<pad>"],
        )

    def decode(self, encoded: torch.Tensor):
        """Not supported; always raises NotImplementedError."""
        raise NotImplementedError(
            "Function token decoding should be handled with "
            "util.decoding.decode_function_annotations"
        )

    # <pad> doubles as the mask, BOS, EOS and pad token for this tokenizer.
    @property
    def mask_token(self) -> str:
        return "<pad>"

    @property
    def mask_token_id(self) -> int:
        return self.vocab_to_index[self.mask_token]

    @property
    def bos_token(self) -> str:
        return "<pad>"

    @property
    def bos_token_id(self) -> int:
        return self.vocab_to_index[self.bos_token]

    @property
    def eos_token(self) -> str:
        return "<pad>"

    @property
    def eos_token_id(self) -> int:
        return self.vocab_to_index[self.eos_token]

    @property
    def pad_token(self) -> str:
        return "<pad>"

    @property
    def pad_token_id(self) -> int:
        return self.vocab_to_index[self.pad_token]


def _texts_to_keywords(texts: list[str]) -> list[str]:
    """Flattens the unigram/bigram keywords of several free-text descriptions.

    Each description is broken into terms with ``_keywords_from_text``. The
    per-description results are concatenated in input order, keeping duplicates.

    Args:
        texts: collection of text descriptions, i.e. InterPro/GO names.
    Returns:
        Collection of terms/n-grams
    """
    return [term for text in texts for term in _keywords_from_text(text)]


def _keywords_from_text(text: str) -> list[str]:
    """Splits text into unigrams and bigrams.

    The text is first cut on ", " separators, so bigrams never span two
    comma-separated pieces. Each piece is sanitized and split on whitespace; it
    contributes its words followed by its adjacent word pairs. Terms of a single
    character, and terms listed in ``_EXCLUDED_TERMS``, are dropped.
    """
    candidates: list[str] = []
    for piece in text.split(", "):
        tokens = _sanitize(piece).split()
        # Unigrams first, then bigrams of neighbouring words.
        candidates += tokens
        candidates += [f"{left} {right}" for left, right in zip(tokens, tokens[1:])]

    return [
        candidate
        for candidate in candidates
        if len(candidate) > 1 and candidate not in _EXCLUDED_TERMS
    ]


# Translation table used by ``_sanitize``, built once at import time rather than on
# every call. Hyphens become spaces, so hyphenated words split apart. Every other
# ASCII punctuation character (``string.punctuation``) is deleted.
_SANITIZE_TABLE = str.maketrans("-", " ", string.punctuation.replace("-", ""))


def _sanitize(text: str) -> str:
    """Normalizes free text for keyword extraction.

    Replaces hyphens with spaces, removes the remaining ASCII punctuation, and
    lower-cases the result. All of this happens in one ``translate`` pass.

    Args:
        text: raw description text.
    Returns:
        The sanitized, lower-cased text.
    """
    return text.translate(_SANITIZE_TABLE).lower()


# These terms are omitted from textual representations because they are pervasive
# and not specific to any particular protein function. `_keywords_from_text`
# sanitizes text (removing punctuation such as "_") and drops single-character terms
# before consulting this set. The underscore and single-letter entries therefore
# never match there; they are kept as-is in case other code checks this set.
_EXCLUDED_TERMS = {
    "binding domain",
    "biological_process",
    "biological process",
    "biologicalprocess",
    "c",
    "cellular_component",
    "cellular component",
    "cellularcomponent",
    "cellular_process",
    "cellularprocess",
    "cellular process",
    "like domain",
    "molecular function",
    "molecular_function",
    "molecularfunction",
    "n",
}