from functools import cached_property
from pathlib import Path
from typing import Any

import pandas as pd
import torch
import torch.nn.functional as F

from esm.tokenization.tokenizer_base import EsmTokenizerBase
from esm.utils.constants import esm3 as C

Sample = dict[str, Any]


class ResidueAnnotationsTokenizer(EsmTokenizerBase):
    def __init__(
        self,
        csv_path: str | None = None,
        max_annotations: int = 16,
    ):
        if csv_path is None:
            csv_path = str(C.data_root() / C.RESID_CSV)
        self.csv_path = csv_path
        self.max_annotations = max_annotations

    @cached_property
    def _description2label(self) -> dict[str, str]:
        # Maps the raw annotation description to its cleaned label.
        with Path(self.csv_path).open() as f:  # type: ignore
            df = pd.read_csv(f)
        return dict(zip(df.label, df.label_clean))

    @cached_property
    def _labels(self) -> list[str]:
        # Cleaned labels, ordered by total count (most frequent first).
        with Path(self.csv_path).open() as f:  # type: ignore
            df = pd.read_csv(f)
        labels = (
            df.groupby("label_clean")["count"]
            .sum()
            .sort_values(ascending=False, kind="stable")  # type: ignore
            .index.tolist()
        )
        assert isinstance(labels, list)
        return labels  # type: ignore

    def _description2id(self, description: str) -> int | None:
        label = self._description2label.get(description)
        return self._label2id.get(label)  # type: ignore

    @cached_property
    def _label2id(self) -> dict[str, int]:
        offset = len(self.special_tokens) + 1  # +1 for "<none>"
        return {label: offset + i for i, label in enumerate(self._labels)}

    @cached_property
    def special_tokens(self) -> list[str]:
        """List of special tokens which come before cluster tokens in vocab."""
        return ["<pad>", "<motif>", "<unk>"]

    @cached_property
    def vocab(self):
        annotation_tokens = [f"<ra:{id}>" for _, id in self._label2id.items()]
        return self.special_tokens + ["<none>"] + annotation_tokens

    @cached_property
    def vocab_to_index(self) -> dict[str, int]:
        return {token: token_id for token_id, token in enumerate(self.vocab)}

    @cached_property
    def vocabulary(self) -> list[str]:
        """Full vocabulary."""
        return [*self.special_tokens, "<none>", *self._labels]

    def get_special_tokens_mask(self, encoded: torch.Tensor) -> torch.Tensor:
        """Determines where in the sequence are special tokens."""
        return encoded[:, 0] < len(self.special_tokens)

    def tokenize(
        self, sample: Sample | None, sequence: str, fail_on_mismatch: bool = False
    ) -> list[str]:
        """
        # interpro_site_starts
        # interpro_site_ends
        #   should always == interpro_site_starts. but I haven't checked overall.
        # interpro_site_residues
        #   the residue identity of the specific residue that is annotated.
        #   good for a sanity check that parsing occurred correctly.
        # interpro_site_descriptions
        # ASSERT (i.e. drop if bad)
        #   interpro_site_residues matches the residue at that position
        #   all these lists ^ above are the same length
        """
        seqlen = len(sequence)
        assert seqlen >= 0

        # None means sequence is *not annotated* - so use full <pad>
        if sample is None:
            return ["<pad>"] * seqlen

        if any(
            sample.get(field) is None
            for field in [
                "interpro_site_descriptions",
                "interpro_site_starts",
                "interpro_site_ends",
                "interpro_site_residues",
            ]
        ):
            return ["<pad>"] * seqlen

        num_annotations = len(sample["interpro_site_descriptions"])
        if any(
            len(sample[field]) != num_annotations
            for field in [
                "interpro_site_starts",
                "interpro_site_ends",
                "interpro_site_residues",
            ]
        ):
            # mismatched length.
            return ["<pad>"] * seqlen

        positional_ids = [set() for _ in range(seqlen)]
        for description, start, end, residues in zip(
            sample["interpro_site_descriptions"],
            sample["interpro_site_starts"],
            sample["interpro_site_ends"],
            sample["interpro_site_residues"],
        ):
            try:
                start = int(start)
                end = int(end)
            except (TypeError, ValueError):
                continue

            # Start / End are 1-indexed [inclusive, inclusive].
            if start <= 0 or end > seqlen or start > end:
                print(f"invalid start/end: ({start}, {end}), len: {seqlen}")
                continue
            if len(residues) != (end - start) + 1:
                print(f"bad reference residue: {residues}")
                continue

            token_id = self._description2id(description)
            if token_id is None:
                token_id = self.vocab_to_index["<unk>"]

            for i, residue in zip(range(start - 1, end), residues):
                # If there are any mismatching residues, skip the entire sample.
                if sequence[i] != residue:
                    if fail_on_mismatch:
                        raise ValueError(
                            f"Residue mismatch at position {i} (0-indexed): {sequence[i]} != {residue}"
                        )
                    return ["<pad>"] * seqlen
                positional_ids[i].add(token_id)

        tokens = []
        for token_ids in positional_ids:
            if token_ids:
                token = f"<ra:{','.join(str(token_id) for token_id in sorted(token_ids))}>"
            else:
                token = "<none>"
            tokens.append(token)
        return tokens

    def _token2ids(self, token: str) -> list[int]:
        if token.startswith("<ra:"):
            return [int(token_id) for token_id in token[4:-1].split(",")]
        else:
            token_id = self.vocab_to_index[token]
            return [token_id]

    def encode(
        self, tokens: list[str], add_special_tokens: bool = True
    ) -> torch.Tensor:
        # Each sequence position holds up to `max_annotations` vocab ids,
        # padded out with <pad>.
        token_ids = torch.full(
            size=(len(tokens), self.max_annotations),
            dtype=torch.int64,
            fill_value=self.vocab_to_index["<pad>"],
        )
        for i, token in enumerate(tokens):
            ids = self._token2ids(token)[: self.max_annotations]
            token_ids[i, : len(ids)] = torch.tensor(ids)

        if add_special_tokens:
            # Prepend a BOS row and append an EOS row (both <pad> for this track).
            token_ids = F.pad(
                token_ids, (0, 0, 1, 1), value=self.vocab_to_index["<pad>"]
            )
        return token_ids

    def decode(self, encoded: torch.Tensor) -> list[str]:
        raise NotImplementedError(
            "Residue annotation decoding should be handled with "
            "util.decoding.decode_residue_annotations"
        )

    @property
    def mask_token(self) -> str:
        return "<pad>"

    @property
    def mask_token_id(self) -> int:
        return self.vocab_to_index[self.mask_token]

    @property
    def bos_token(self) -> str:
        return "<pad>"

    @property
    def bos_token_id(self) -> int:
        return self.vocab_to_index[self.bos_token]

    @property
    def eos_token(self) -> str:
        return "<pad>"

    @property
    def eos_token_id(self) -> int:
        return self.vocab_to_index[self.eos_token]

    @property
    def pad_token(self) -> str:
        return "<pad>"

    @property
    def pad_token_id(self) -> int:
        return self.vocab_to_index[self.pad_token]
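

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the library API). It writes a
# tiny annotation CSV with the columns this tokenizer reads (label,
# label_clean, count) so it does not depend on the real
# C.data_root() / C.RESID_CSV asset; the descriptions, labels, counts, and
# sample fields below are made up for the demo.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import tempfile

    with tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False) as f:
        f.write("label,label_clean,count\n")
        f.write("Active site Ser,active site,10\n")
        f.write("Binding site,binding site,5\n")

    tokenizer = ResidueAnnotationsTokenizer(csv_path=f.name, max_annotations=4)

    sequence = "MSKGE"
    sample = {
        "interpro_site_descriptions": ["Active site Ser"],
        "interpro_site_starts": [2],  # 1-indexed, inclusive
        "interpro_site_ends": [2],
        "interpro_site_residues": ["S"],
    }

    tokens = tokenizer.tokenize(sample, sequence)
    # One token per residue: annotated positions get "<ra:...>" carrying the
    # annotation ids, unannotated positions get "<none>".
    print(tokens)

    encoded = tokenizer.encode(tokens)
    # Shape is (len(sequence) + 2, max_annotations): BOS/EOS rows are added
    # by default.
    print(encoded.shape)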