import re
from typing import Sequence

import torch

from esm.models.function_decoder import (
    FunctionTokenDecoder,
    _merge_annotations,
)
from esm.tokenization.function_tokenizer import (
    InterProQuantizedTokenizer,
)
from esm.tokenization.residue_tokenizer import (
    ResidueAnnotationsTokenizer,
)
from esm.utils.constants import esm3 as C
from esm.utils.types import FunctionAnnotation


def encode_function_annotations(
    sequence: str,
    function_annotations: Sequence[FunctionAnnotation],
    function_tokens_tokenizer: InterProQuantizedTokenizer,
    residue_annotations_tokenizer: ResidueAnnotationsTokenizer,
    add_special_tokens: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
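    """Encodes function and residue annotations into token id tensors.

    Args:
        sequence: Amino acid sequence the annotations refer to.
        function_annotations: Annotations with 1-indexed, inclusive (start, end) ranges.
            Labels may be InterPro accessions, function keywords, or residue annotations.
        function_tokens_tokenizer: Tokenizer for InterPro / keyword function tokens.
        residue_annotations_tokenizer: Tokenizer for residue annotations.
        add_special_tokens: Whether to add special tokens to the encoded tracks.

    Returns:
        Tuple of (function_token_ids, residue_annotation_ids) tensors.
    """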
    assert isinstance(
        residue_annotations_tokenizer, ResidueAnnotationsTokenizer
    ), "residue_annotations_tokenizer must be of type ResidueAnnotationsTokenizer"

    # Split the user's annotations by type
    ft_annotations: list[FunctionAnnotation] = []
    ra_annotations: list[FunctionAnnotation] = []
    for fa in function_annotations:
        assert (
            1 <= fa.start <= fa.end <= len(sequence)
        ), f"Invalid (start, end) in function annotation {fa}. Indices 1-indexed and [inclusive, inclusive]"

        supported_label = False

        # Is it an InterPro label?
        if match := re.match(r"IPR\d+", fa.label):
            if match.group() in function_tokens_tokenizer.interpro_to_index:
                ft_annotations.append(fa)
                supported_label = True

        # Is it a function keyword?
        if fa.label in function_tokens_tokenizer._tfidf.vocab_to_index:
            ft_annotations.append(fa)
            supported_label = True

        # Is it a residue annotation?
        if fa.label in residue_annotations_tokenizer._labels:
            ra_annotations.append(fa)
            supported_label = True

        if not supported_label:
            raise ValueError(f"Unknown label in FunctionAnnotation: {fa.label}")

    # Convert function token FunctionAnnotations -> Tensor
    function_tokens = function_tokens_tokenizer.tokenize(
        annotations=ft_annotations,
        seqlen=len(sequence),
    )
    function_token_ids = function_tokens_tokenizer.encode(
        function_tokens, add_special_tokens=add_special_tokens
    )

    # Convert residue annotation FunctionAnnotations -> Tensor
    if ra_annotations:
        descriptions, starts, ends = zip(
            *[(anot.label, anot.start, anot.end) for anot in ra_annotations]
        )
    else:
        descriptions = starts = ends = None
    ra_tokens = residue_annotations_tokenizer.tokenize(
        {
            "interpro_site_descriptions": descriptions,
            "interpro_site_starts": starts,
            "interpro_site_ends": ends,
        },
        sequence=sequence,
        fail_on_mismatch=True,
    )
    residue_annotation_ids = residue_annotations_tokenizer.encode(
        ra_tokens, add_special_tokens=add_special_tokens
    )

    return function_token_ids, residue_annotation_ids
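
# Usage sketch (illustrative only, not part of the module API). It assumes the
# default tokenizer constructors load the vocabularies shipped with the package
# and that "IPR000001" is an InterPro accession present in the function
# tokenizer's vocabulary; labels outside all three vocabularies raise ValueError.
#
#   ft_tokenizer = InterProQuantizedTokenizer()
#   ra_tokenizer = ResidueAnnotationsTokenizer()
#   annotations = [FunctionAnnotation(label="IPR000001", start=1, end=10)]
#   function_ids, residue_ids = encode_function_annotations(
#       sequence="MKTVRQERLK",
#       function_annotations=annotations,
#       function_tokens_tokenizer=ft_tokenizer,
#       residue_annotations_tokenizer=ra_tokenizer,
#   )
#   # function_ids: <int>[length (+ special tokens), depth]
#   # residue_ids:  <int>[length (+ special tokens), MAX_RESIDUE_ANNOTATIONS]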


def decode_function_tokens(
    function_token_ids: torch.Tensor,
    function_token_decoder: FunctionTokenDecoder,
    function_tokens_tokenizer: InterProQuantizedTokenizer,
    decoder_annotation_threshold: float = 0.1,
    annotation_min_length: int | None = 5,
    annotation_gap_merge_max: int | None = 3,
) -> list[FunctionAnnotation]:
    """Decodes model prediction logits into function predictions.

    Merges function token and residue annotation predictions into a single
    set of FunctionAnnotation predictions.

    Args:
        function_token_ids: Tensor <float>[length, depth] of
            function token ids.
        residue_annotation_logits: Tensor  <float>[length, RA-vocab] of residue
            annotation binary classification logits.
        function_tokens_tokenizer: InterPro annotation tokenizer.
        residue_annotation_threshold: tokenizer of residue annotations.
        residue_annotation_threshold: predicted probability threshold for emitting
            a predicted residue annotation.
    Returns:
        Predicted function annotations merged from both predictions.
    """
    assert (
        function_token_ids.ndim == 2
    ), "function_token_ids must be of shape (length, depth)"

    annotations: list[FunctionAnnotation] = []

    # Function Annotations from predicted function tokens.
    decoded = function_token_decoder.decode(
        function_token_ids,
        tokenizer=function_tokens_tokenizer,
        annotation_threshold=decoder_annotation_threshold,
        annotation_min_length=annotation_min_length,
        annotation_gap_merge_max=annotation_gap_merge_max,
    )

    # Convert predicted InterPro annotation to FunctionAnnotation.
    annotations.extend(decoded["function_keywords"])
    for annotation in decoded["interpro_annotations"]:
        annotation: FunctionAnnotation
        label = function_tokens_tokenizer.format_annotation(annotation)
        annotations.append(
            FunctionAnnotation(label=label, start=annotation.start, end=annotation.end)
        )

    return annotations
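
# Usage sketch (illustrative). A trained FunctionTokenDecoder and the function
# tokenizer are assumed to be available; how the decoder weights are loaded
# depends on the installed release, so construction is elided here.
#
#   ft_tokenizer = InterProQuantizedTokenizer()
#   decoder: FunctionTokenDecoder = ...  # pretrained decoder weights
#   # function_token_ids: <int>[length, depth], e.g. the function track of a generation
#   predictions = decode_function_tokens(
#       function_token_ids,
#       function_token_decoder=decoder,
#       function_tokens_tokenizer=ft_tokenizer,
#       decoder_annotation_threshold=0.1,
#   )
#   for fa in predictions:
#       print(fa.label, fa.start, fa.end)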


def decode_residue_annotation_tokens(
    residue_annotations_token_ids: torch.Tensor,
    residue_annotations_tokenizer: ResidueAnnotationsTokenizer,
    annotation_min_length: int | None = 5,
    annotation_gap_merge_max: int | None = 3,
) -> list[FunctionAnnotation]:
    """Decodes residue annotation tokens into FunctionAnnotations.

    Args:
        tokens: Tensor <int>[length, MAX_RESIDUE_ANNOTATIONS] of residue annotation tokens.
        residue_annotations_tokenizer: Tokenizer of residue annotations.
        threshold: predicted probability threshold for emitting a predicted residue
            annotation.
    Returns:
        Predicted residue annotations.
    """
    assert (
        residue_annotations_token_ids.ndim == 2
    ), "residue_annotations_token_ids must be of shape (length, MAX_RESIDUE_ANNOTATIONS)"

    annotations: list[FunctionAnnotation] = []

    for depth in range(C.MAX_RESIDUE_ANNOTATIONS):
        token_ids = residue_annotations_token_ids[:, depth]
        # torch.nonzero on this 1D column returns positions only; the vocab index
        # is the token id stored at each nonzero position.
        for loc in torch.nonzero(token_ids).squeeze(dim=1).cpu().numpy():
            vocab_index = int(token_ids[loc].item())
            label = residue_annotations_tokenizer.vocabulary[vocab_index]
            if label not in [*residue_annotations_tokenizer.special_tokens, "<none>"]:
                annotation = FunctionAnnotation(label=label, start=loc, end=loc)
                annotations.append(annotation)

    annotations = _merge_annotations(
        annotations,
        merge_gap_max=annotation_gap_merge_max,
    )

    # Drop very small annotations.
    if annotation_min_length is not None:
        annotations = [
            annotation
            for annotation in annotations
            if annotation.end - annotation.start + 1 >= annotation_min_length
        ]

    return annotations
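
# Usage sketch (illustrative). Residue annotation token ids are expected as an
# <int>[length, MAX_RESIDUE_ANNOTATIONS] tensor, e.g. the residue-annotation
# track of a model generation; zero ids are skipped, as are special and "<none>"
# tokens, per the loop above.
#
#   ra_tokenizer = ResidueAnnotationsTokenizer()
#   predictions = decode_residue_annotation_tokens(
#       residue_annotations_token_ids,
#       residue_annotations_tokenizer=ra_tokenizer,
#       annotation_min_length=5,
#       annotation_gap_merge_max=3,
#   )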