"""Function Token Decoder."""
from collections import defaultdict
from dataclasses import dataclass, field
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from esm.layers.regression_head import RegressionHead
from esm.layers.transformer_stack import TransformerStack
from esm.tokenization.function_tokenizer import (
InterProQuantizedTokenizer,
)
from esm.utils.constants import esm3 as C
from esm.utils.misc import merge_ranges
from esm.utils.types import FunctionAnnotation
@dataclass(frozen=True)
class FunctionTokenDecoderConfig:
"""Configures function token decoder."""
# Embedding dimension of decoder.
d_model: int = 1024
# Number of attention heads of decoder.
n_heads: int = 8
# Number of layers of decoder.
n_layers: int = 3
# Number of integer values that function tokens may assume.
function_token_vocab_size: int = 260
# Number of function tokens at each position.
function_token_depth: int = 8
# Number of InterPro labels that can be decoded.
num_interpro_classes: int = 29026
# Number of function keywords that can be decoded.
keyword_vocabulary_size: int = 58641
# List of supported InterPro ids.
interpro_entry_list: str = field(
default_factory=lambda: str(C.data_root() / C.INTERPRO_ENTRY)
)
# Path to keywords vocabulary.
keyword_vocabulary_path: str = field(
default_factory=lambda: str(C.data_root() / C.KEYWORDS_VOCABULARY)
)
# Whether to unpack LSH bits into single-bit tokens.
unpack_lsh_bits: bool = True
# The number of special tokens in the function tokenizer vocabulary which come
# before the LSH tokens.
num_special_tokens: int = 4
# The number of bits per LSH token in the function tokenizer.
bits_per_token: int = 8
class FunctionTokenDecoder(nn.Module):
def __init__(self, config: FunctionTokenDecoderConfig | None = None):
"""Constructs function token decoder."""
super().__init__()
if config is None:
config = FunctionTokenDecoderConfig()
self.config = config
# Get the supported set of InterPro ids.
df = pd.read_csv(config.interpro_entry_list, sep="\t")
self.interpro_ids = sorted(df.ENTRY_AC)
self.interpro2index = {
interpro_id: i for i, interpro_id in enumerate(self.interpro_ids)
}
assert len(self.interpro_ids) == config.num_interpro_classes
with open(config.keyword_vocabulary_path, "r") as f:
self.keywords_vocabulary: list[str] = list(f.read().strip().split("\n"))
assert len(self.keywords_vocabulary) == config.keyword_vocabulary_size
if config.unpack_lsh_bits:
vocab_size = 2 * config.function_token_depth * config.bits_per_token
else:
# Function-token id's re-use the same token ids at each position along the depth
# dimension, despite distinct meanings. The decoder should take this into
# account so create distinct embeddings for tokens at each position.
vocab_size = (
self.config.function_token_depth * self.config.function_token_vocab_size
)
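        # Illustrative arithmetic, assuming the default config: unpacking LSH bits gives
        # 2 * 8 * 8 = 128 single-bit embeddings, whereas the packed alternative would
        # need 8 * 260 = 2080 position-specific token embeddings.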
self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
embedding_dim=config.d_model,
)
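        # A small transformer stack (n_layers_geom=0, so no geometric attention layers)
        # contextualizes the per-token embeddings before pooling.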
self.decoder = TransformerStack(
d_model=config.d_model,
n_heads=config.n_heads,
v_heads=None,
n_layers=config.n_layers,
n_layers_geom=0,
scale_residue=False,
bias=True,
qk_layernorm=False,
ffn_type="gelu",
expansion_ratio=4,
)
self.heads = nn.ModuleDict(
{
# Binary classification head predicting which keywords are present.
"keyword_logits": RegressionHead(
d_model=config.d_model,
output_dim=config.keyword_vocabulary_size,
hidden_dim=4 * config.d_model,
),
# Regresses the TF-IDF value of each present keyword.
"keyword_tfidf": RegressionHead(
d_model=config.d_model,
output_dim=config.keyword_vocabulary_size,
hidden_dim=4 * config.d_model,
),
# Predicts which InterPro annotations are present.
"interpro_logits": RegressionHead(
d_model=config.d_model,
output_dim=config.num_interpro_classes,
hidden_dim=4 * config.d_model,
),
}
)
def forward(self, token_ids: torch.Tensor) -> dict[str, torch.Tensor]:
"""Forward pass through function token decoder.
Args:
            token_ids: <int>[batch_size, function_token_depth] batch of function token
                ids to decode.
Returns:
            Dict with one entry per prediction head:
                - "keyword_logits": <float>[batch_size, keyword_vocabulary_size] binary
                  classification logits for keyword presence.
                - "keyword_tfidf": <float>[batch_size, keyword_vocabulary_size] regressed
                  TF-IDF value of each keyword.
                - "interpro_logits": <float>[batch_size, num_interpro_classes] binary
                  classification logits for InterPro annotation presence.
        """
assert token_ids.ndim == 2
assert token_ids.shape[1] == self.config.function_token_depth
batch_size, depth = token_ids.shape
if self.config.unpack_lsh_bits:
# Shift values into [0, 2^bits/token)
lsh_bits = token_ids - self.config.num_special_tokens
# extract each bit. (hob stands for highest-order bit)
bits = torch.concat(
[
torch.bitwise_and(lsh_bits, 1 << hob).gt(0).to(torch.int32)
for hob in range(self.config.bits_per_token)
],
dim=1,
)
assert bits.shape == (batch_size, depth * self.config.bits_per_token)
# Shift each bit into individual vocabulary ranges, so they get distinct
# embeddings.
vocab_offsets = 2 * torch.arange(
depth * self.config.bits_per_token, device=token_ids.device
)
inputs = vocab_offsets[None, :] + bits
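            # Example (illustrative): with num_special_tokens=4, token id 7 encodes LSH
            # value 3 = 0b00000011, so its two lowest bits are 1 and the rest are 0; the
            # offsets above give every (depth position, bit) pair its own {2k, 2k + 1}
            # slot in the embedding table.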
# zero-out special tokens, i.e. non LSH tokens.
where_special = token_ids < self.config.num_special_tokens
inputs = torch.where(where_special.any(dim=1, keepdim=True), 0, inputs)
else:
            # Apply depth-position offset to use distinct vocabs. See __init__ for
            # explanation.
vocab_offsets = self.config.function_token_vocab_size * torch.arange(
self.config.function_token_depth,
device=token_ids.device,
)
inputs = token_ids + vocab_offsets[None, :]
embed = self.embedding(inputs)
encoding, _ = self.decoder(embed)
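        # Mean-pool over the token dimension so each example is summarized by a single
        # d_model vector before the prediction heads.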
pooled = torch.mean(encoding, dim=1)
return {name: head(pooled) for name, head in self.heads.items()}
@property
def device(self) -> torch.device:
return next(self.parameters()).device
def decode(
self,
function_token_ids: torch.Tensor,
tokenizer: InterProQuantizedTokenizer,
decode_annotations: bool = True,
annotation_threshold: float = 0.1,
decode_keywords=True,
keywords_threshold: float = 0.5,
annotation_min_length: int | None = 5,
annotation_gap_merge_max: int | None = 3,
):
"""Decodes function tokens into predicted annotations and keywords.
Args:
function_token_ids: <int>[length, depth] function token ids. NOTE:
without <bos>/<eos> prefix
tokenizer: function tokenizer.
decode_annotations: whether to decode InterPro annotations.
annotation_threshold: threshold for emitting a function annotation.
decode_keywords: whether to decode function keywords.
keywords_threshold: threshold for emitting a keyword.
annotation_min_length: optional minimum length of predicted annotations for
size filtering.
            annotation_gap_merge_max: optionally merge adjacent annotations of the same
                type that are separated by a gap no larger than this size.
Returns:
Decoder outputs:
- "interpro_logits": <float>[length, num_interpro] predicted interpro logits.
- "interpro_preds": <bool>[length, num_interpro] predicted intepro labels.
- "interpro_annotations": list[FunctionAnnotation] predicted InterPro
annotations
- "keyword_logits": <float>[length, keyword_vocabulary] binary prediciton
logits for keywrods.
- "function_keywords": list[FunctionAnnotation] predicted function keyword
ranges.
"""
assert function_token_ids.ndim == 2
assert function_token_ids.shape[1] == tokenizer.depth
assert self.config.function_token_depth == tokenizer.depth
        outputs = self(function_token_ids.to(self.device))
# Only decode in positions that have function tokens.
where_decode = torch.all(
(function_token_ids != tokenizer.vocab_to_index["<pad>"])
& (function_token_ids != tokenizer.vocab_to_index["<none>"])
& (function_token_ids != tokenizer.vocab_to_index["<unk>"]),
dim=1,
)
# Decode InterPro annotations ranges.
interpro_preds = F.sigmoid(outputs["interpro_logits"])
interpro_preds = interpro_preds >= annotation_threshold
interpro_preds[~where_decode, :] = False
outputs["interpro_preds"] = interpro_preds
if decode_annotations:
annotations: list[FunctionAnnotation] = []
preds: np.ndarray = interpro_preds.detach().cpu().numpy()
for position_index, class_index in zip(*preds.nonzero()):
interpro_id = self.interpro_ids[class_index]
annotation = FunctionAnnotation(
label=interpro_id,
start=position_index + 1, # zero-index -> one-index inclusive
end=position_index + 1, # zero-index -> one-index inclusive
)
annotations.append(annotation)
annotations = _merge_annotations(
annotations,
merge_gap_max=annotation_gap_merge_max,
)
# Drop very small annotations.
if annotation_min_length is not None:
annotations = [
annotation
for annotation in annotations
if annotation.end - annotation.start + 1 >= annotation_min_length
]
outputs["interpro_annotations"] = annotations
# Decode function keyword ranges.
keyword_logits = outputs["keyword_logits"]
keyword_logits[~where_decode, :] = -torch.inf
if decode_keywords:
keyword_preds = F.sigmoid(keyword_logits) >= keywords_threshold
outputs["function_keywords"] = self._preds_to_keywords(
keyword_preds.detach().cpu().numpy()
)
return outputs
def _preds_to_keywords(self, keyword_preds: np.ndarray) -> list[FunctionAnnotation]:
"""Converts output log-TFDF to predicted keywords over the sequence.
Args:
keyword_precs: <bool>[length, keyword_vocab] positional predictions of
function keywords from the keyword prediction head.
Returns:
Non-overlapping keyword annotated ranges along the sequence. Note that indices
will index into the *sequence*, not the function token array which has a
<pad> prefix.
"""
assert keyword_preds.ndim == 2
assert keyword_preds.shape[1] == self.config.keyword_vocabulary_size
keyword_positions: dict[str, list[range]] = defaultdict(list)
for position, keyword_id in zip(*np.nonzero(keyword_preds)):
keyword = self.keywords_vocabulary[keyword_id]
keyword_positions[keyword].append(range(position, position + 1))
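        # Example (illustrative): a keyword predicted at zero-indexed positions 3, 4 and 5
        # yields ranges that merge_ranges collapses into range(3, 6), emitted below as a
        # one-indexed annotation spanning residues 4-6.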
annotations: list[FunctionAnnotation] = []
for keyword, ranges in keyword_positions.items():
for range_ in merge_ranges(ranges):
annotation = FunctionAnnotation(
label=keyword,
start=range_.start + 1, # zero-index -> one-index
end=range_.stop + 1 - 1, # zero-index excl -> one-index incl
)
annotations.append(annotation)
return annotations
def _merge_annotations(
annotations: list[FunctionAnnotation],
merge_gap_max: int | None = None,
) -> list[FunctionAnnotation]:
"""Merges annotations into non-overlapping segments.
Args:
annotations: annotations to merge.
merge_gap_max: optionally merge neighboring ranges that are separated by a gap
no larger than this size.
Returns:
non-overlapping annotations with gaps merged.
"""
grouped: dict[str, list[range]] = defaultdict(list)
for a in annotations:
# Convert one-indexed inclusive-inclusive, to range()
grouped[a.label].append(range(a.start, a.end + 1))
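    # Example (illustrative): with merge_gap_max=3, two annotations of the same label
    # covering residues 10-20 and 23-30 are merged into a single 10-30 annotation.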
merged = []
for label, ranges in grouped.items():
merged_ranges = merge_ranges(ranges, merge_gap_max=merge_gap_max)
for range_ in merged_ranges:
annotation = FunctionAnnotation(
label=label,
                start=range_.start,  # already one-indexed inclusive
                end=range_.stop - 1,  # one-index exclusive -> one-index inclusive
)
merged.append(annotation)
return merged
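
if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative): run a randomly initialized decoder on
    # fake function token ids. Assumes the InterPro entry list and keyword vocabulary
    # referenced by the default config are available locally; real usage loads
    # pretrained weights and pairs `decode` with an InterProQuantizedTokenizer.
    decoder = FunctionTokenDecoder()
    config = decoder.config
    fake_token_ids = torch.randint(
        low=0,
        high=config.function_token_vocab_size,
        size=(2, config.function_token_depth),
    )
    predictions = decoder(fake_token_ids)
    for head_name, tensor in predictions.items():
        print(head_name, tuple(tensor.shape))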