# esm/utils/function/encode_decode.py
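"""Utilities for encoding FunctionAnnotations into function / residue
annotation token tensors, and for decoding token tensors back into
FunctionAnnotations."""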
import re
from typing import Sequence
import torch
from esm.models.function_decoder import (
FunctionTokenDecoder,
_merge_annotations,
)
from esm.tokenization.function_tokenizer import (
InterProQuantizedTokenizer,
)
from esm.tokenization.residue_tokenizer import (
ResidueAnnotationsTokenizer,
)
from esm.utils.constants import esm3 as C
from esm.utils.types import FunctionAnnotation
def encode_function_annotations(
sequence: str,
function_annotations: Sequence[FunctionAnnotation],
function_tokens_tokenizer: InterProQuantizedTokenizer,
residue_annotations_tokenizer: ResidueAnnotationsTokenizer,
add_special_tokens: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
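    """Encodes user-provided FunctionAnnotations into token id tensors.

    Annotations are routed by label type: InterPro accessions and function
    keywords go to `function_tokens_tokenizer`; residue annotation labels go
    to `residue_annotations_tokenizer`.

    Args:
        sequence: amino acid sequence the annotations describe.
        function_annotations: annotations with 1-indexed, inclusive (start, end) ranges.
        function_tokens_tokenizer: tokenizer for InterPro/keyword annotations.
        residue_annotations_tokenizer: tokenizer for residue annotations.
        add_special_tokens: whether to add special tokens during encoding.
    Returns:
        Tuple of (function_token_ids, residue_annotation_ids) tensors.
    """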
assert isinstance(
residue_annotations_tokenizer, ResidueAnnotationsTokenizer
), "residue_annotations_tokenizer must be of type ResidueAnnotationsTokenizer"
# Split the user's annotations by type
ft_annotations: list[FunctionAnnotation] = []
ra_annotations: list[FunctionAnnotation] = []
for fa in function_annotations:
        assert (
            1 <= fa.start <= fa.end <= len(sequence)
        ), f"Invalid (start, end) in function annotation {fa}. Indices are 1-indexed and inclusive at both ends."
supported_label = False
# Is it an InterPro label?
if match := re.match(r"IPR\d+", fa.label):
if match.group() in function_tokens_tokenizer.interpro_to_index:
ft_annotations.append(fa)
supported_label = True
# Is it a function keyword?
if fa.label in function_tokens_tokenizer._tfidf.vocab_to_index:
ft_annotations.append(fa)
supported_label = True
# Is it a residue annotation?
if fa.label in residue_annotations_tokenizer._labels:
ra_annotations.append(fa)
supported_label = True
if not supported_label:
raise ValueError(f"Unknown label in FunctionAnnotation: {fa.label}")
# Convert function token FunctionAnnotations -> Tensor
function_tokens = function_tokens_tokenizer.tokenize(
annotations=ft_annotations,
seqlen=len(sequence),
)
function_token_ids = function_tokens_tokenizer.encode(
function_tokens, add_special_tokens=add_special_tokens
)
# Convert residue annotation FunctionAnnotations -> Tensor
if ra_annotations:
descriptions, starts, ends = zip(
*[(anot.label, anot.start, anot.end) for anot in ra_annotations]
)
else:
descriptions = starts = ends = None
ra_tokens = residue_annotations_tokenizer.tokenize(
{
"interpro_site_descriptions": descriptions,
"interpro_site_starts": starts,
"interpro_site_ends": ends,
},
sequence=sequence,
fail_on_mismatch=True,
)
residue_annotation_ids = residue_annotations_tokenizer.encode(
ra_tokens, add_special_tokens=add_special_tokens
)
return function_token_ids, residue_annotation_ids
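
# Usage sketch for `encode_function_annotations` (illustrative only): the bare
# tokenizer constructors and the example label are assumptions, not part of
# this module, and labels must exist in the relevant tokenizer vocabularies.
#
#   ft_tokenizer = InterProQuantizedTokenizer()
#   ra_tokenizer = ResidueAnnotationsTokenizer()
#   sequence = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"
#   annotations = [FunctionAnnotation(label="IPR000001", start=1, end=10)]
#   function_token_ids, residue_annotation_ids = encode_function_annotations(
#       sequence=sequence,
#       function_annotations=annotations,
#       function_tokens_tokenizer=ft_tokenizer,
#       residue_annotations_tokenizer=ra_tokenizer,
#   )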
def decode_function_tokens(
function_token_ids: torch.Tensor,
function_token_decoder: FunctionTokenDecoder,
function_tokens_tokenizer: InterProQuantizedTokenizer,
decoder_annotation_threshold: float = 0.1,
annotation_min_length: int | None = 5,
annotation_gap_merge_max: int | None = 3,
) -> list[FunctionAnnotation]:
"""Decodes model prediction logits into function predictions.
Merges function token and residue annotation predictions into a single
set of FunctionAnnotation predictions.
Args:
function_token_ids: Tensor <float>[length, depth] of
function token ids.
residue_annotation_logits: Tensor <float>[length, RA-vocab] of residue
annotation binary classification logits.
function_tokens_tokenizer: InterPro annotation tokenizer.
residue_annotation_threshold: tokenizer of residue annotations.
residue_annotation_threshold: predicted probability threshold for emitting
a predicted residue annotation.
Returns:
Predicted function annotations merged from both predictions.
"""
assert (
function_token_ids.ndim == 2
), "function_token_ids must be of shape (length, depth)"
annotations: list[FunctionAnnotation] = []
# Function Annotations from predicted function tokens.
decoded = function_token_decoder.decode(
function_token_ids,
tokenizer=function_tokens_tokenizer,
annotation_threshold=decoder_annotation_threshold,
annotation_min_length=annotation_min_length,
annotation_gap_merge_max=annotation_gap_merge_max,
)
# Convert predicted InterPro annotation to FunctionAnnotation.
annotations.extend(decoded["function_keywords"])
for annotation in decoded["interpro_annotations"]:
annotation: FunctionAnnotation
label = function_tokens_tokenizer.format_annotation(annotation)
annotations.append(
FunctionAnnotation(label=label, start=annotation.start, end=annotation.end)
)
return annotations
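
# Usage sketch for `decode_function_tokens` (illustrative only): how the
# decoder and its weights are obtained is an assumption here; in practice a
# pretrained FunctionTokenDecoder would be loaded rather than constructed bare.
#
#   decoder = FunctionTokenDecoder()  # assumed default construction
#   predictions = decode_function_tokens(
#       function_token_ids=function_token_ids,  # <int>[length, depth]
#       function_token_decoder=decoder,
#       function_tokens_tokenizer=ft_tokenizer,
#   )
#   for annotation in predictions:
#       print(annotation.label, annotation.start, annotation.end)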
def decode_residue_annotation_tokens(
residue_annotations_token_ids: torch.Tensor,
residue_annotations_tokenizer: ResidueAnnotationsTokenizer,
annotation_min_length: int | None = 5,
annotation_gap_merge_max: int | None = 3,
) -> list[FunctionAnnotation]:
"""Decodes residue annotation tokens into FunctionAnnotations.
Args:
tokens: Tensor <int>[length, MAX_RESIDUE_ANNOTATIONS] of residue annotation tokens.
residue_annotations_tokenizer: Tokenizer of residue annotations.
threshold: predicted probability threshold for emitting a predicted residue
annotation.
Returns:
Predicted residue annotations.
"""
    assert (
        residue_annotations_token_ids.ndim == 2
    ), "residue_annotations_token_ids must be of shape (length, MAX_RESIDUE_ANNOTATIONS)"
annotations: list[FunctionAnnotation] = []
    # Each of the MAX_RESIDUE_ANNOTATIONS depth "tracks" may carry its own
    # annotation token per residue position.
    for depth in range(C.MAX_RESIDUE_ANNOTATIONS):
        token_ids = residue_annotations_token_ids[:, depth]
        # torch.nonzero on a 1D tensor returns positions only (shape [n, 1]),
        # not (position, id) pairs; the vocab index is the token id stored at
        # that position.
        for loc in torch.nonzero(token_ids).squeeze(-1).cpu().numpy():
            vocab_index: int = token_ids[loc].item()
            label = residue_annotations_tokenizer.vocabulary[vocab_index]
            if label not in [*residue_annotations_tokenizer.special_tokens, "<none>"]:
                annotation = FunctionAnnotation(label=label, start=loc, end=loc)
                annotations.append(annotation)
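    # Merge per-residue hits into contiguous ranges, bridging small gaps.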
annotations = _merge_annotations(
annotations,
merge_gap_max=annotation_gap_merge_max,
)
# Drop very small annotations.
if annotation_min_length is not None:
annotations = [
annotation
for annotation in annotations
if annotation.end - annotation.start + 1 >= annotation_min_length
]
return annotations
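
# Usage sketch for `decode_residue_annotation_tokens` (illustrative only; the
# `residue_annotation_ids` tensor is assumed to come from the model or from
# `encode_function_annotations` above, and `ra_tokenizer` from the earlier sketch):
#
#   residue_predictions = decode_residue_annotation_tokens(
#       residue_annotations_token_ids=residue_annotation_ids,
#       residue_annotations_tokenizer=ra_tokenizer,
#       annotation_min_length=5,
#       annotation_gap_merge_max=3,
#   )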