"""Tokenizes annotations of protein function.""" import re import string from functools import cache, cached_property, partial from typing import Collection import numpy as np import pandas as pd import scipy.sparse as sp import torch import torch.nn.functional as F from esm.tokenization.tokenizer_base import EsmTokenizerBase from esm.utils.constants import esm3 as C from esm.utils.function import interpro, lsh, tfidf from esm.utils.misc import stack_variable_length_tensors from esm.utils.types import FunctionAnnotation class InterProQuantizedTokenizer(EsmTokenizerBase): """Tokenizer for functional annotations. This tokenizer converts InterPro and/or function keywords into a multi-token representation by hashing TF-IDF vector representations of the text associated with the fuction and then applying a locality sensitive hash (LSH). """ def __init__( self, depth: int = 8, lsh_bits_per_token: int = 8, lsh_path: str | None = None, keyword_vocabulary_path: str | None = None, keyword_idf_path: str | None = None, interpro_entry_path: str | None = None, interpro2keywords_path: str | None = None, ): """Constructs function tokenizer. Args: depth: number of tokens emitted in each position. lsh_bits_per_token: Number of LSH bits per token. Determines the vocabulary size. lsh_path: path to locality sensitive hash (LSH) hyperplanes. keyword_vocabulary_path: path to csv containing function keyword vocabulary. keyword_idf_path: path to IDF values for each keyword. interpro_entry_csv_path: path to list of InterPro entries in CSV format. interpro2keywords_path: path to CSV mapping InterPro IDs to function keywords. """ self.depth = depth default = lambda x, d: x if x is not None else C.data_root() / d self.keyword_vocabulary_path = default( keyword_vocabulary_path, C.KEYWORDS_VOCABULARY ) self.keyword_idf_path = default(keyword_idf_path, C.KEYWORDS_IDF) self._interpro2keywords_path = default( interpro2keywords_path, C.INTERPRO2KEYWORDS ) self.interpro_ = interpro.InterPro( entries_path=default(interpro_entry_path, C.INTERPRO_ENTRY) ) self.lsh_vocab_size = 1 << lsh_bits_per_token self._lsh = lsh.LSHTokenized( lsh_bits_per_token, len(self.keyword_vocabulary), self.depth, default(lsh_path, C.LSH_TABLE_PATHS["8bit"]), ) # This is the offset into the vocabulary where LSH tokens start. 
        self._lsh_token_vocab_offset = len(self.special_tokens) + 1  # +1 for <none>

    @cached_property
    def interpro2keywords(self) -> dict[str, list[str]]:
        """Mapping from InterPro ID to function keywords."""
        df = pd.read_csv(self._interpro2keywords_path)
        assert "interpro_id" in df.columns and "keywords" in df.columns, df.columns
        return dict(zip(df.interpro_id, df.keywords.str.split(",")))

    @cached_property
    def interpro_labels(self) -> list[str]:
        """The set of supported InterPro labels."""
        return sorted(self.interpro2keywords.keys())

    @cached_property
    def interpro_to_index(self) -> dict[str, int]:
        """Mapping from InterPro id to index."""
        return {id: i for i, id in enumerate(self.interpro_labels)}

    @property
    def keyword_vocabulary(self) -> list[str]:
        """Set of supported keywords."""
        return self._tfidf.vocabulary

    @property
    def keyword_to_index(self) -> dict[str, int]:
        """Mapping from keywords to index."""
        return self._tfidf.vocab_to_index

    @cached_property
    def _tfidf(self) -> tfidf.TFIDFModel:
        """Creates TF-IDF model for encoding function keywords."""
        return tfidf.TFIDFModel(
            vocabulary_path=self.keyword_vocabulary_path,
            idf_path=self.keyword_idf_path,
        )

    @cached_property
    def special_tokens(self) -> list[str]:
        """List of special tokens, which come before the LSH tokens in the vocab."""
        return ["<pad>", "<motif>", "<unk>"]

    @cached_property
    def vocab(self) -> list[str]:
        """Vocabulary of function tokens."""
        lsh_tokens = [f"<lsh:{i}>" for i in range(self.lsh_vocab_size)]
        return self.special_tokens + ["<none>"] + lsh_tokens

    @cached_property
    def vocab_to_index(self) -> dict[str, int]:
        return {token: token_id for token_id, token in enumerate(self.vocab)}

    def get_special_tokens_mask(self, encoded: torch.Tensor) -> torch.Tensor:
        """Determines which positions in the sequence are special tokens."""
        where = encoded < len(self.special_tokens)
        assert torch.all(torch.all(where, dim=1) | torch.all(~where, dim=1))
        return where[:, 0]

    def tokenize(
        self,
        annotations: list[FunctionAnnotation],
        seqlen: int,
        p_keyword_dropout: float = 0.0,
    ) -> list[str]:
        """Encodes range-annotations of protein function as tokens.

        Args:
            annotations: annotated function ranges, either as InterPro ids or keywords.
            seqlen: length of the sequence.
            p_keyword_dropout: optional probability of dropping out keywords from the
                input annotations.

        Returns:
            Tokenized representation of the function annotations as a list of string
            tokens of length seqlen.
        """
        assert seqlen >= 0
        if not annotations:
            return ["<pad>"] * seqlen

        # Expand the range annotations into positional annotation sets.
        positional_labels: list[set[str]] = [set() for _ in range(seqlen)]
        for annotation in annotations:
            assert 1 <= annotation.start <= annotation.end <= seqlen, (
                f"Invalid annotation range [{annotation.start}, {annotation.end}] for "
                f"sequence length {seqlen}."
            )
            for i in range(annotation.start - 1, annotation.end):
                positional_labels[i].add(annotation.label)
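
        # Worked example (illustrative only): FunctionAnnotation(label="IPR000001",
        # start=2, end=4) on a sequence of length 5 expands to the positional label
        # sets [set(), {"IPR000001"}, {"IPR000001"}, {"IPR000001"}, set()], since
        # annotation ranges are 1-indexed and inclusive at both ends.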

        if p_keyword_dropout > 0:
            keyword_mask = (
                np.random.random(len(self._tfidf.vocabulary)) < p_keyword_dropout
            )
        else:
            keyword_mask = None

        # Annotations tend to be repetitive over the length of the sequence - cache
        # their hashes to speed up tokenization.
        hash_fn = cache(partial(self._function_text_hash, keyword_mask=keyword_mask))

        tokens: list[str] = []
        for labels in positional_labels:
            if not labels:
                token = "<none>"
            else:
                lsh_hash = hash_fn(frozenset(labels))
                if lsh_hash is not None:
                    assert len(lsh_hash) == self.depth
                    token = "<lsh:" + ",".join(map(str, lsh_hash)) + ">"
                else:
                    token = "<unk>"
            tokens.append(token)

        return tokens

    def _function_text_hash(
        self,
        labels: Collection[str],
        keyword_mask: np.ndarray | None = None,
    ) -> np.ndarray | None:
        """Applies a locality sensitive hash (LSH) to function text.

        Args:
            labels: InterPro ids and/or keywords.
            keyword_mask: optional boolean array shaped (keyword_vocab_size,)
                indicating which keywords to drop before hashing.

        Returns:
            LSH shaped (depth,), or None if there is no text or keywords to hash.
        """
        # Split labels into either InterPro ids or keywords.
        interpro_ids = []
        keywords = []
        for label in labels:
            match = re.match(r"IPR\d+", label)
            if match and match.group() in self.interpro_to_index:
                interpro_ids.append(match.group())
            elif label in self._tfidf.vocab_to_index:
                keywords.append(label)
            else:
                raise ValueError(f"Unsupported label: {label}")

        # Perform an element-wise maximum over TF-IDF vectors from distinct tags to
        # avoid tags getting "washed out" by e.g. 4 very similar tags. Keywords are
        # incorporated as another TF-IDF vector.
        vec: sp.csr_matrix = self._tfidf.encode(keywords)
        for interpro_id in interpro_ids:
            interpro_keywords = self.interpro2keywords.get(interpro_id, [])
            vec_ = self._tfidf.encode(interpro_keywords)
            vec = vec.maximum(vec_)

        if keyword_mask is not None:
            vec.data *= 1 - np.take(keyword_mask, vec.indices)

        if vec.sum() == 0:
            return None

        return self._lsh(vec)[0, :]

    def encode(
        self, tokens: list[str], add_special_tokens: bool = True
    ) -> torch.Tensor:
        """Encodes string tokens as a token-id tensor.

        Args:
            tokens: list of individual tokens, e.g. ["<none>", "<lsh:1,2,3,4,5,6,7,8>"].
            add_special_tokens: whether to add a single pad token at the start and end
                of the sequence to act as the <cls> and <eos> tokens.

        Returns:
            <int>[length, depth] function tokens. Length will be +2 of the input token
            length when add_special_tokens is True.
        """
        token_ids = torch.zeros(size=(len(tokens), self.depth), dtype=torch.int64)
        for i, token in enumerate(tokens):
            token_ids[i, :] = torch.tensor(self._token2ids(token))

        if add_special_tokens:
            token_ids = F.pad(
                token_ids, (0, 0, 1, 1), value=self.vocab_to_index["<pad>"]
            )
        return token_ids

    def lookup_annotation_name(self, annotation: FunctionAnnotation) -> str | None:
        return self.interpro_.lookup_name(annotation.label)

    def format_annotation(self, annotation: FunctionAnnotation) -> str:
        annotation_name = self.lookup_annotation_name(annotation)
        if annotation_name is not None:
            return f"{annotation_name} ({annotation.label})"
        else:
            return annotation.label
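
    # For reference, with the default depth=8 and special tokens
    # ["<pad>", "<motif>", "<unk>"], "<none>" encodes to [3] * 8, while an LSH token
    # such as "<lsh:5,0,12,7,9,1,3,200>" encodes position-wise to 4 + lsh_id, i.e.
    # [9, 4, 16, 11, 13, 5, 7, 204].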

    def _token2ids(self, token: str) -> list[int]:
        """Converts a token into a token_id set of length depth."""
        if re.match(r"<lsh:[\d+,]+>", token):
            lsh_ids = [int(lsh_id) for lsh_id in re.findall(r"\d+", token)]
            assert (
                len(lsh_ids) == self.depth
            ), f"Expected token to have {self.depth} ids, found {lsh_ids}"
            return [self._lsh_token_vocab_offset + lsh_id for lsh_id in lsh_ids]
        elif token == "<none>" or token in self.special_tokens:
            return [self.vocab_to_index[token]] * self.depth
        else:
            raise ValueError(f"Unknown token: {token}")

    def batch_encode(
        self,
        token_batch: list[list[str]],
        add_special_tokens: bool = True,
    ) -> torch.Tensor:
        """Encodes a batch of function tokens.

        Args:
            token_batch: batch of function tokens.
            add_special_tokens: whether to add special tokens.

        Returns:
            <int>[batch_size, max_length, depth] batch of encoded tokens.
        """
        encoded = [
            self.encode(tokens, add_special_tokens=add_special_tokens)
            for tokens in token_batch
        ]
        return stack_variable_length_tensors(
            encoded,
            constant_value=self.vocab_to_index["<pad>"],
        )

    def decode(self, encoded: torch.Tensor):
        raise NotImplementedError(
            "Function token decoding should be handled with "
            "util.decoding.decode_function_annotations"
        )

    @property
    def mask_token(self) -> str:
        return "<pad>"

    @property
    def mask_token_id(self) -> int:
        return self.vocab_to_index[self.mask_token]

    @property
    def bos_token(self) -> str:
        return "<pad>"

    @property
    def bos_token_id(self) -> int:
        return self.vocab_to_index[self.bos_token]

    @property
    def eos_token(self) -> str:
        return "<pad>"

    @property
    def eos_token_id(self) -> int:
        return self.vocab_to_index[self.eos_token]

    @property
    def pad_token(self) -> str:
        return "<pad>"

    @property
    def pad_token_id(self) -> int:
        return self.vocab_to_index[self.pad_token]


def _texts_to_keywords(texts: list[str]) -> list[str]:
    """Breaks an InterPro/GO free-text description set into a bag of n-grams, n={1,2}.

    Args:
        texts: collection of text descriptions, i.e. InterPro/GO names.

    Returns:
        Collection of terms/n-grams.
    """
    keywords = []
    for text in texts:
        keywords.extend(_keywords_from_text(text))
    return keywords


def _keywords_from_text(text: str) -> list[str]:
    """Splits text into unigrams and bigrams."""
    elements = text.split(", ")

    terms = []
    for element in elements:
        element = _sanitize(element)
        words = element.split()

        # Add 1-mers
        terms.extend(words)

        # Add 2-mers
        for i in range(len(words) - 1):
            bigram = words[i] + " " + words[i + 1]
            terms.append(bigram)

    return [term for term in terms if len(term) > 1 and term not in _EXCLUDED_TERMS]


def _sanitize(text: str) -> str:
    text = text.replace("-", " ")
    text = text.translate(str.maketrans("", "", string.punctuation))
    text = text.lower()
    return text


# These terms are omitted from textual representations since they are pervasive and
# unspecific to particular protein function.
_EXCLUDED_TERMS = {
    "binding domain",
    "biological_process",
    "biological process",
    "biologicalprocess",
    "c",
    "cellular_component",
    "cellular component",
    "cellularcomponent",
    "cellular_process",
    "cellularprocess",
    "cellular process",
    "like domain",
    "molecular function",
    "molecular_function",
    "molecularfunction",
    "n",
}
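

# A minimal usage sketch (not part of the library API): it assumes the default ESM3
# data assets referenced via `esm.utils.constants.esm3` are available locally, and
# that "IPR000001" appears in the bundled InterPro-to-keyword mapping.
if __name__ == "__main__":
    tokenizer = InterProQuantizedTokenizer()
    annotations = [FunctionAnnotation(label="IPR000001", start=1, end=10)]
    tokens = tokenizer.tokenize(annotations, seqlen=25)
    token_ids = tokenizer.encode(tokens, add_special_tokens=True)
    print(token_ids.shape)  # torch.Size([27, 8]): seqlen + 2 special positions, depth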