"""Function Token Decoder."""
from collections import defaultdict
from dataclasses import dataclass, field
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from esm.layers.regression_head import RegressionHead
from esm.layers.transformer_stack import TransformerStack
from esm.tokenization.function_tokenizer import (
InterProQuantizedTokenizer,
)
from esm.utils.constants import esm3 as C
from esm.utils.misc import merge_ranges
from esm.utils.types import FunctionAnnotation


@dataclass(frozen=True)
class FunctionTokenDecoderConfig:
"""Configures function token decoder."""
# Embedding dimension of decoder.
d_model: int = 1024
# Number of attention heads of decoder.
n_heads: int = 8
# Number of layers of decoder.
n_layers: int = 3
# Number of integer values that function tokens may assume.
function_token_vocab_size: int = 260
# Number of function tokens at each position.
function_token_depth: int = 8
# Number of InterPro labels that can be decoded.
num_interpro_classes: int = 29026
# Number of function keywords that can be decoded.
keyword_vocabulary_size: int = 58641
# List of supported InterPro ids.
interpro_entry_list: str = field(
default_factory=lambda: str(C.data_root() / C.INTERPRO_ENTRY)
)
# Path to keywords vocabulary.
keyword_vocabulary_path: str = field(
default_factory=lambda: str(C.data_root() / C.KEYWORDS_VOCABULARY)
)
# Whether to unpack LSH bits into single-bit tokens.
unpack_lsh_bits: bool = True
# The number of special tokens in the function tokenizer vocabulary which come
# before the LSH tokens.
num_special_tokens: int = 4
# The number of bits per LSH token in the function tokenizer.
bits_per_token: int = 8
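
    # Illustrative note: with these defaults, the 260-value function token vocabulary
    # presumably decomposes as 2**bits_per_token = 256 LSH tokens plus the
    # num_special_tokens = 4 special tokens that precede them.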


class FunctionTokenDecoder(nn.Module):
def __init__(self, config: FunctionTokenDecoderConfig | None = None):
"""Constructs function token decoder."""
super().__init__()
if config is None:
config = FunctionTokenDecoderConfig()
self.config = config
# Get the supported set of InterPro ids.
df = pd.read_csv(config.interpro_entry_list, sep="\t")
self.interpro_ids = sorted(df.ENTRY_AC)
self.interpro2index = {
interpro_id: i for i, interpro_id in enumerate(self.interpro_ids)
}
assert len(self.interpro_ids) == config.num_interpro_classes
with open(config.keyword_vocabulary_path, "r") as f:
self.keywords_vocabulary: list[str] = list(f.read().strip().split("\n"))
assert len(self.keywords_vocabulary) == config.keyword_vocabulary_size
if config.unpack_lsh_bits:
vocab_size = 2 * config.function_token_depth * config.bits_per_token
else:
            # Function-token ids re-use the same token values at each position along
            # the depth dimension, despite having distinct meanings. The decoder
            # accounts for this by creating distinct embeddings per depth position.
            vocab_size = (
                self.config.function_token_depth
                * self.config.function_token_vocab_size
            )
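        # Worked example (illustrative, default config): unpacked LSH bits need
        # 2 * 8 * 8 = 128 embeddings (a "0" and a "1" row for each of the
        # depth * bits_per_token bit slots), whereas packed tokens need
        # 8 * 260 = 2080 embeddings, one per (depth position, token id) pair.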
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=config.d_model,
        )
self.decoder = TransformerStack(
d_model=config.d_model,
n_heads=config.n_heads,
v_heads=None,
n_layers=config.n_layers,
n_layers_geom=0,
scale_residue=False,
bias=True,
qk_layernorm=False,
ffn_type="gelu",
expansion_ratio=4,
)
self.heads = nn.ModuleDict(
{
# Binary classification head predicting which keywords are present.
"keyword_logits": RegressionHead(
d_model=config.d_model,
output_dim=config.keyword_vocabulary_size,
hidden_dim=4 * config.d_model,
),
# Regresses the TF-IDF value of each present keyword.
"keyword_tfidf": RegressionHead(
d_model=config.d_model,
output_dim=config.keyword_vocabulary_size,
hidden_dim=4 * config.d_model,
),
# Predicts which InterPro annotations are present.
"interpro_logits": RegressionHead(
d_model=config.d_model,
output_dim=config.num_interpro_classes,
hidden_dim=4 * config.d_model,
),
}
)
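        # Note: each head maps the pooled <float>[batch, d_model] encoding to
        # independent per-class outputs, so keyword and InterPro predictions are
        # multi-label rather than mutually exclusive.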

    def forward(self, token_ids: torch.Tensor) -> dict[str, torch.Tensor]:
        """Forward pass through function token decoder.

        Args:
            token_ids: <int>[batch_size, function_token_depth] batch of function token
                ids to decode.
        Returns:
            Dict of head outputs, each <float>[batch_size, output_dim], under the keys
            "keyword_logits", "keyword_tfidf", and "interpro_logits".
        """
assert token_ids.ndim == 2
assert token_ids.shape[1] == self.config.function_token_depth
batch_size, depth = token_ids.shape
if self.config.unpack_lsh_bits:
# Shift values into [0, 2^bits/token)
lsh_bits = token_ids - self.config.num_special_tokens
            # Extract each bit, lowest-order bit first.
            bits = torch.concat(
                [
                    torch.bitwise_and(lsh_bits, 1 << bit).gt(0).to(torch.int32)
                    for bit in range(self.config.bits_per_token)
],
dim=1,
)
assert bits.shape == (batch_size, depth * self.config.bits_per_token)
# Shift each bit into individual vocabulary ranges, so they get distinct
# embeddings.
vocab_offsets = 2 * torch.arange(
depth * self.config.bits_per_token, device=token_ids.device
)
inputs = vocab_offsets[None, :] + bits
            # Zero out special (non-LSH) tokens.
where_special = token_ids < self.config.num_special_tokens
inputs = torch.where(where_special.any(dim=1, keepdim=True), 0, inputs)
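            # Worked example (illustrative, num_special_tokens=4): token id 7 gives
            # lsh_bits = 3 = 0b00000011, which unpacks (lowest-order bit first) to
            # bits [1, 1, 0, 0, 0, 0, 0, 0]; bit slot j then selects embedding id
            # 2 * j + bit, so each slot owns a distinct pair of embedding rows.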
else:
            # Apply depth-position offset to use distinct vocabs. See __init__ for
            # explanation.
vocab_offsets = self.config.function_token_vocab_size * torch.arange(
self.config.function_token_depth,
device=token_ids.device,
)
inputs = token_ids + vocab_offsets[None, :]
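            # Worked example (illustrative): token id t at depth position d maps to
            # embedding id d * 260 + t, so identical token values at different depths
            # receive distinct embeddings.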
embed = self.embedding(inputs)
encoding, _ = self.decoder(embed)
pooled = torch.mean(encoding, dim=1)
return {name: head(pooled) for name, head in self.heads.items()}

    @property
def device(self) -> torch.device:
return next(self.parameters()).device

    def decode(
self,
function_token_ids: torch.Tensor,
tokenizer: InterProQuantizedTokenizer,
decode_annotations: bool = True,
annotation_threshold: float = 0.1,
        decode_keywords: bool = True,
keywords_threshold: float = 0.5,
annotation_min_length: int | None = 5,
annotation_gap_merge_max: int | None = 3,
):
"""Decodes function tokens into predicted annotations and keywords.
Args:
function_token_ids: <int>[length, depth] function token ids. NOTE:
without <bos>/<eos> prefix
tokenizer: function tokenizer.
decode_annotations: whether to decode InterPro annotations.
annotation_threshold: threshold for emitting a function annotation.
decode_keywords: whether to decode function keywords.
keywords_threshold: threshold for emitting a keyword.
annotation_min_length: optional minimum length of predicted annotations for
size filtering.
annotation_gap_merge_max: optional merge adjacent annotation of the same type
Returns:
Decoder outputs:
- "interpro_logits": <float>[length, num_interpro] predicted interpro logits.
- "interpro_preds": <bool>[length, num_interpro] predicted intepro labels.
- "interpro_annotations": list[FunctionAnnotation] predicted InterPro
annotations
- "keyword_logits": <float>[length, keyword_vocabulary] binary prediciton
logits for keywrods.
- "function_keywords": list[FunctionAnnotation] predicted function keyword
ranges.
"""
assert function_token_ids.ndim == 2
assert function_token_ids.shape[1] == tokenizer.depth
assert self.config.function_token_depth == tokenizer.depth
        outputs = self(function_token_ids.to(self.device))
# Only decode in positions that have function tokens.
where_decode = torch.all(
(function_token_ids != tokenizer.vocab_to_index["<pad>"])
& (function_token_ids != tokenizer.vocab_to_index["<none>"])
& (function_token_ids != tokenizer.vocab_to_index["<unk>"]),
dim=1,
)
        # Decode InterPro annotation ranges.
interpro_preds = F.sigmoid(outputs["interpro_logits"])
interpro_preds = interpro_preds >= annotation_threshold
interpro_preds[~where_decode, :] = False
outputs["interpro_preds"] = interpro_preds
if decode_annotations:
annotations: list[FunctionAnnotation] = []
preds: np.ndarray = interpro_preds.detach().cpu().numpy()
for position_index, class_index in zip(*preds.nonzero()):
interpro_id = self.interpro_ids[class_index]
annotation = FunctionAnnotation(
label=interpro_id,
start=position_index + 1, # zero-index -> one-index inclusive
end=position_index + 1, # zero-index -> one-index inclusive
)
annotations.append(annotation)
annotations = _merge_annotations(
annotations,
merge_gap_max=annotation_gap_merge_max,
)
# Drop very small annotations.
if annotation_min_length is not None:
annotations = [
annotation
for annotation in annotations
if annotation.end - annotation.start + 1 >= annotation_min_length
]
outputs["interpro_annotations"] = annotations
# Decode function keyword ranges.
keyword_logits = outputs["keyword_logits"]
keyword_logits[~where_decode, :] = -torch.inf
if decode_keywords:
keyword_preds = F.sigmoid(keyword_logits) >= keywords_threshold
outputs["function_keywords"] = self._preds_to_keywords(
keyword_preds.detach().cpu().numpy()
)
return outputs
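
    # Illustrative usage sketch (hypothetical names; assumes a constructed
    # InterProQuantizedTokenizer and <int>[length, depth] token ids without
    # <bos>/<eos>):
    #
    #   decoder = FunctionTokenDecoder().eval()
    #   outputs = decoder.decode(function_token_ids, tokenizer=tokenizer)
    #   for annotation in outputs["interpro_annotations"]:
    #       print(annotation.label, annotation.start, annotation.end)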

    def _preds_to_keywords(self, keyword_preds: np.ndarray) -> list[FunctionAnnotation]:
        """Converts binary keyword predictions into keyword ranges over the sequence.

        Args:
            keyword_preds: <bool>[length, keyword_vocab] positional predictions of
                function keywords from the keyword prediction head.
        Returns:
            Non-overlapping keyword annotated ranges along the sequence. Note that
            indices will index into the *sequence*, not the function token array
            which has a <pad> prefix.
        """
assert keyword_preds.ndim == 2
assert keyword_preds.shape[1] == self.config.keyword_vocabulary_size
keyword_positions: dict[str, list[range]] = defaultdict(list)
for position, keyword_id in zip(*np.nonzero(keyword_preds)):
keyword = self.keywords_vocabulary[keyword_id]
keyword_positions[keyword].append(range(position, position + 1))
annotations: list[FunctionAnnotation] = []
for keyword, ranges in keyword_positions.items():
for range_ in merge_ranges(ranges):
annotation = FunctionAnnotation(
label=keyword,
start=range_.start + 1, # zero-index -> one-index
                    end=range_.stop,  # zero-index excl -> one-index incl
)
annotations.append(annotation)
return annotations
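
    # Worked example (illustrative): a keyword firing at zero-indexed positions 4, 5,
    # and 6 merges into range(4, 7), which becomes a one-indexed FunctionAnnotation
    # with start=5 and end=7.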


def _merge_annotations(
    annotations: list[FunctionAnnotation],
    merge_gap_max: int | None = None,
) -> list[FunctionAnnotation]:
    """Merges annotations into non-overlapping segments.

    Args:
annotations: annotations to merge.
merge_gap_max: optionally merge neighboring ranges that are separated by a gap
no larger than this size.
Returns:
non-overlapping annotations with gaps merged.
"""
grouped: dict[str, list[range]] = defaultdict(list)
for a in annotations:
        # Convert one-indexed inclusive [start, end] to a half-open range().
grouped[a.label].append(range(a.start, a.end + 1))
merged = []
for label, ranges in grouped.items():
merged_ranges = merge_ranges(ranges, merge_gap_max=merge_gap_max)
for range_ in merged_ranges:
annotation = FunctionAnnotation(
label=label,
                start=range_.start,  # ranges are already one-indexed here
                end=range_.stop - 1,  # exclusive stop -> one-index inclusive end
)
merged.append(annotation)
return merged
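

# Worked example (illustrative): two annotations with the same label spanning
# one-indexed residues 1-5 and 8-10 are separated by a gap of 2 (positions 6-7),
# so merge_gap_max=3 merges them into a single annotation spanning 1-10.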