from functools import cached_property
from pathlib import Path
from typing import Any

import pandas as pd
import torch
import torch.nn.functional as F

from esm.tokenization.tokenizer_base import EsmTokenizerBase
from esm.utils.constants import esm3 as C

Sample = dict[str, Any]
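# A Sample is a raw annotation record: parallel "interpro_site_descriptions",
# "interpro_site_starts", "interpro_site_ends", and "interpro_site_residues"
# lists, as consumed by tokenize() below.
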
class ResidueAnnotationsTokenizer(EsmTokenizerBase):
    def __init__(
        self,
        csv_path: str | None = None,
        max_annotations: int = 16,
    ):
        if csv_path is None:
            csv_path = str(C.data_root() / C.RESID_CSV)
        self.csv_path = csv_path
        self.max_annotations = max_annotations

    @cached_property
    def _description2label(self) -> dict[str, str]:
        with Path(self.csv_path).open() as f:  # type: ignore
            df = pd.read_csv(f)
        return dict(zip(df.label, df.label_clean))

    @cached_property
    def _labels(self) -> list[str]:
        with Path(self.csv_path).open() as f:  # type: ignore
            df = pd.read_csv(f)
        labels = (
            df.groupby("label_clean")["count"]
            .sum()
            .sort_values(ascending=False, kind="stable")  # type: ignore
            .index.tolist()
        )
        assert isinstance(labels, list)
        return labels  # type: ignore

    def _description2id(self, description: str) -> int | None:
        label = self._description2label.get(description)
        return self._label2id.get(label)  # type: ignore

    @cached_property
    def _label2id(self) -> dict[str, int]:
        offset = len(self.special_tokens) + 1  # +1 for "<none>"
        return {label: offset + i for i, label in enumerate(self._labels)}

    @cached_property
    def special_tokens(self) -> list[str]:
        """List of special tokens, which come before the annotation tokens in the vocab."""
        return ["<pad>", "<motif>", "<unk>"]

    @cached_property
    def vocab(self) -> list[str]:
        annotation_tokens = [f"<ra:{id}>" for _, id in self._label2id.items()]
        return self.special_tokens + ["<none>"] + annotation_tokens
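
    # The resulting vocab layout is:
    #   index 0..2 -> "<pad>", "<motif>", "<unk>"  (special tokens)
    #   index 3    -> "<none>"                     (position has no annotation)
    #   index 4..  -> "<ra:4>", "<ra:5>", ...      (one token per cleaned label,
    #                                               ordered by descending count)
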
    @cached_property
    def vocab_to_index(self) -> dict[str, int]:
        return {token: token_id for token_id, token in enumerate(self.vocab)}

    @cached_property
    def vocabulary(self) -> list[str]:
        """Full vocabulary."""
        return [*self.special_tokens, "<none>", *self._labels]

    def get_special_tokens_mask(self, encoded: torch.Tensor) -> torch.Tensor:
        """Determines which positions in the sequence are special tokens."""
        return encoded[:, 0] < len(self.special_tokens)

    def tokenize(
        self, sample: Sample | None, sequence: str, fail_on_mismatch: bool = False
    ) -> list[str]:
        """Tokenize a sample's residue annotations against its sequence.

        The sample is expected to carry parallel lists, all of the same length:
        # interpro_site_descriptions
        # interpro_site_starts
        # interpro_site_ends  # should always == interpro_site_starts, but this hasn't been checked overall
        # interpro_site_residues  # identity of the specific residue that is annotated; a sanity check that parsing occurred correctly

        Samples that fail validation (list-length mismatch, or a reference
        residue that doesn't match the sequence at that position) are dropped,
        i.e. tokenized as all <pad>.
        """
        seqlen = len(sequence)
        assert seqlen >= 0
        # None means the sequence is *not annotated*, so emit all <pad>.
        if sample is None:
            return ["<pad>"] * seqlen

        if any(
            sample.get(field) is None
            for field in [
                "interpro_site_descriptions",
                "interpro_site_starts",
                "interpro_site_ends",
                "interpro_site_residues",
            ]
        ):
            return ["<pad>"] * seqlen

        num_annotations = len(sample["interpro_site_descriptions"])
        if any(
            len(sample[field]) != num_annotations
            for field in [
                "interpro_site_starts",
                "interpro_site_ends",
                "interpro_site_residues",
            ]
        ):
            # Mismatched list lengths.
            return ["<pad>"] * seqlen

        positional_ids = [set() for _ in range(seqlen)]
        for description, start, end, residues in zip(
            sample["interpro_site_descriptions"],
            sample["interpro_site_starts"],
            sample["interpro_site_ends"],
            sample["interpro_site_residues"],
        ):
            try:
                start = int(start)
                end = int(end)
            except (TypeError, ValueError):
                continue

            # Start / end are 1-indexed, [inclusive, inclusive].
            if start <= 0 or end > seqlen or start > end:
                print(f"invalid start/end: ({start}, {end}), len: {seqlen}")
                continue
            if len(residues) != (end - start) + 1:
                print(f"bad reference residue: {residues}")
                continue

            token_id = self._description2id(description)
            if token_id is None:
                token_id = self.vocab_to_index["<unk>"]

            for i, residue in zip(range(start - 1, end), residues):
                # If any residue mismatches the sequence, drop the entire sample.
                if sequence[i] != residue:
                    if fail_on_mismatch:
                        raise ValueError(
                            f"Residue mismatch at position {i} (0-indexed): "
                            f"{sequence[i]} != {residue}"
                        )
                    return ["<pad>"] * seqlen
                positional_ids[i].add(token_id)

        tokens = []
        for token_ids in positional_ids:
            if token_ids:
                token = "<ra:" + ",".join(str(token_id) for token_id in token_ids) + ">"
            else:
                token = "<none>"
            tokens.append(token)
        return tokens
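
    # Illustrative sketch (hypothetical values): for sequence "MKV" and a sample
    # whose only site has start=2, end=2, residues="K", and whose description
    # maps to id 7, tokenize() returns ["<none>", "<ra:7>", "<none>"]. A position
    # carrying two annotation ids, say 7 and 9, would instead yield "<ra:7,9>"
    # (ids come from a set, so "<ra:9,7>" is equally possible).
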
    def _token2ids(self, token: str) -> list[int]:
        if token.startswith("<ra:") and token.endswith(">"):
            return [int(token_id) for token_id in token[4:-1].split(",")]
        else:
            token_id = self.vocab_to_index[token]
            return [token_id]

    def encode(
        self, tokens: list[str], add_special_tokens: bool = True
    ) -> torch.Tensor:
        token_ids = torch.full(
            size=(len(tokens), self.max_annotations),
            dtype=torch.int64,
            fill_value=self.vocab_to_index["<pad>"],
        )
        for i, token in enumerate(tokens):
            ids = self._token2ids(token)[: self.max_annotations]
            token_ids[i, : len(ids)] = torch.tensor(ids)
        if add_special_tokens:
            # Prepend/append a <pad> row for the BOS/EOS positions.
            token_ids = F.pad(
                token_ids, (0, 0, 1, 1), value=self.vocab_to_index["<pad>"]
            )
        return token_ids
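
    # The encoded tensor has shape (len(tokens), max_annotations), or
    # (len(tokens) + 2, max_annotations) with add_special_tokens=True; each row
    # holds that position's annotation ids, right-padded with the <pad> id (0).
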
    def decode(self, encoded: torch.Tensor) -> list[str]:
        raise NotImplementedError(
            "Residue annotation decoding should be handled with util.decoding.decode_residue_annotations"
        )

    @property
    def mask_token(self) -> str:
        return "<pad>"

    @property
    def mask_token_id(self) -> int:
        return self.vocab_to_index[self.mask_token]

    @property
    def bos_token(self) -> str:
        return "<pad>"

    @property
    def bos_token_id(self) -> int:
        return self.vocab_to_index[self.bos_token]

    @property
    def eos_token(self) -> str:
        return "<pad>"

    @property
    def eos_token_id(self) -> int:
        return self.vocab_to_index[self.eos_token]

    @property
    def pad_token(self) -> str:
        return "<pad>"

    @property
    def pad_token_id(self) -> int:
        return self.vocab_to_index[self.pad_token]
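

# --- Usage sketch (not part of the module proper) ---
# A minimal example, assuming the default annotations CSV is available via
# C.data_root(); the sample values below are hypothetical.
if __name__ == "__main__":
    tokenizer = ResidueAnnotationsTokenizer()
    sequence = "MKV"
    sample = {
        "interpro_site_descriptions": ["Active site"],  # hypothetical description
        "interpro_site_starts": [2],  # 1-indexed, inclusive
        "interpro_site_ends": [2],
        "interpro_site_residues": ["K"],
    }
    tokens = tokenizer.tokenize(sample, sequence)
    encoded = tokenizer.encode(tokens)  # shape: (len(sequence) + 2, 16)
    print(tokens, encoded.shape)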