File size: 6,460 Bytes
224a33f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
from typing import Callable

import attr
import torch
from tqdm import tqdm

from esm.sdk.api import (
    ESM3InferenceClient,
    ESMProtein,
    ESMProteinTensor,
    GenerationConfig,
    SamplingConfig,
    SamplingTrackConfig,
)
from esm.tokenization import (
    EsmTokenizerBase,
    TokenizerCollectionProtocol,
)
from esm.utils.constants import esm3 as C
from esm.utils.noise_schedules import NOISE_SCHEDULE_REGISTRY


def iterative_sampling_raw(
    client: ESM3InferenceClient,
    input: ESMProtein,
    config: GenerationConfig,
):
    """Run iterative generation on a raw ``ESMProtein``.

    Encodes ``input`` to tokens, generates with ``client`` according to
    ``config``, and decodes the result back to a protein.  Because the
    function / residue-annotation encode-decode round trip is lossy, the
    caller's original function annotations are restored whenever a
    different track was the one being sampled.
    """
    encoded = client.encode(input)
    generated = client.generate(encoded, config)
    protein = client.decode(generated)

    # Encoding then decoding function / residue annotations does not
    # round-trip exactly, so keep the caller's annotations unless those
    # tracks were the ones sampled.
    if config.track not in ("function", "residue_annotations"):
        protein.function_annotations = input.function_annotations

    return protein


def iterative_sampling_tokens(
    client: ESM3InferenceClient,
    input_tokens: ESMProteinTensor,
    config: GenerationConfig,
    tokenizers: TokenizerCollectionProtocol,
) -> ESMProteinTensor:
    """Iteratively decode one track of ``input_tokens`` over ``config.num_steps``.

    Each step samples the whole track with ``client.forward_and_sample`` and
    commits only the lowest-entropy masked positions, with the number of
    positions committed per step driven by the noise schedule named in
    ``config.schedule``.  BOS/EOS positions are never sampled.

    Args:
        client: Inference client used for forward passes and sampling.
        input_tokens: Token tensor; the track named by ``config.track`` may be
            ``None`` (treated as fully masked) or contain mask tokens at the
            positions to fill in.
        config: Generation settings (track, num_steps, schedule, temperature,
            top_p, invalid_ids, condition_on_coordinates_only).
        tokenizers: Provides per-track special-token ids.

    Returns:
        A copy of ``input_tokens`` with the sampled track filled in; every
        other track is restored from the input (``None`` stays ``None``).

    Raises:
        NotImplementedError: for the "function" and "residue_annotations"
            tracks, whose iterative decoding is not supported yet.
        ValueError: if the input track exists but contains no mask tokens.
    """
    track_to_sample = config.track

    # Fail fast for unsupported tracks.  The previous code only raised from
    # inside the sampling loop — behind the tautological guard
    # `if track_to_sample in track_to_sample:` — after wasting one full
    # forward_and_sample call.
    # TODO: Implement iterative decoding for function and residue_annotations
    # TODO: Fix encode/decode of interpro tokens (not yet supported)
    if track_to_sample in ("function", "residue_annotations"):
        raise NotImplementedError(
            f"Iterative decoding for {track_to_sample} is not supported yet."
        )

    # All sample-able tracks declared on SamplingConfig (embeddings excluded).
    all_tracks = [
        f.name for f in attr.fields(SamplingConfig) if "embedding" not in f.name
    ]

    sequence_length = len(input_tokens)
    device = input_tokens.device

    # Initialize schedule and a mutable copy of the input tokens.
    decoding_schedule = NOISE_SCHEDULE_REGISTRY[config.schedule]
    sampled_tokens = attr.evolve(input_tokens)  # Make a copy

    if config.condition_on_coordinates_only and input_tokens.coordinates is not None:
        sampled_tokens.structure = None

    # Positions eligible for sampling; BOS (0) and EOS (-1) are excluded.
    sampling_mask = torch.ones(
        sequence_length,
        dtype=torch.bool,
        device=device,
    )
    sampling_mask[0] = False
    sampling_mask[-1] = False

    get_tokenizer: Callable[[str], EsmTokenizerBase] = lambda s: getattr(tokenizers, s)
    if getattr(sampled_tokens, track_to_sample) is None:
        # Track absent: start from a fully masked track with BOS/EOS set.
        masked_tokens = torch.full(
            (sequence_length,),
            get_tokenizer(track_to_sample).mask_token_id,
            dtype=torch.long,
            device=device,
        )
        if track_to_sample == "sequence":
            masked_tokens[0] = tokenizers.sequence.cls_token_id  # type: ignore
            masked_tokens[-1] = tokenizers.sequence.eos_token_id  # type: ignore
        else:
            masked_tokens[0] = get_tokenizer(track_to_sample).bos_token_id
            masked_tokens[-1] = get_tokenizer(track_to_sample).eos_token_id

        setattr(
            sampled_tokens,
            track_to_sample,
            masked_tokens,
        )
    else:
        # Track present: only positions the caller masked may be sampled.
        is_mask: torch.Tensor = (
            getattr(input_tokens, track_to_sample)
            == get_tokenizer(track_to_sample).mask_token_id
        )
        if not is_mask.any().item():
            raise ValueError(f"Cannot sample {config.track} when input has no masks.")
        sampling_mask = sampling_mask & is_mask

    # Decode

    def maybe_clone(x: torch.Tensor | None) -> torch.Tensor | None:
        # Clone so the result never aliases the caller's tensors.
        return x.clone() if x is not None else None

    L = sequence_length - 2  # number of real (non-BOS/EOS) positions
    positions_sampled = 0
    for t in tqdm(range(config.num_steps)):
        # Single step sampling at all positions; only the most confident
        # positions allowed by the schedule are committed below.
        track_sample_config = SamplingTrackConfig()
        track_sample_config.invalid_ids = config.invalid_ids
        track_sample_config.temperature = config.temperature
        track_sample_config.top_p = config.top_p
        sampling_config = SamplingConfig(**{track_to_sample: track_sample_config})  # type: ignore

        forward_and_sample_output = client.forward_and_sample(
            sampled_tokens, sampling_config
        )
        new_samples = forward_and_sample_output.protein_tensor

        # How many additional positions the schedule lets us commit this step.
        perc_masked = decoding_schedule(torch.tensor((t + 1) / config.num_steps))
        num_to_sample = int((1 - perc_masked) * L) - positions_sampled
        positions_sampled += num_to_sample

        # Candidates are positions still holding a mask token.
        sampling_mask = sampling_mask & (
            getattr(sampled_tokens, track_to_sample)
            == get_tokenizer(track_to_sample).mask_token_id
        )

        # Select the num_to_sample candidates with the lowest entropy.
        track_entropy: torch.Tensor = getattr(
            forward_and_sample_output.entropy, track_to_sample
        )
        track_entropy = track_entropy.masked_fill(
            ~sampling_mask, torch.finfo(track_entropy.dtype).max
        )
        _, indices = track_entropy.topk(num_to_sample, dim=-1, largest=False)
        is_top_k = torch.isin(
            torch.arange(sequence_length, device=device), indices
        )
        tokens_to_sample = sampling_mask & is_top_k

        old_track_samples = getattr(sampled_tokens, track_to_sample)
        new_track_samples = getattr(new_samples, track_to_sample)

        # Commit the selected positions; everything else keeps its old token.
        new_track_samples = torch.where(
            tokens_to_sample, new_track_samples, old_track_samples
        )

        setattr(sampled_tokens, track_to_sample, new_track_samples)

    # Do not update tracks that were not sampled (e.g. keep None instead of masks)
    for track in all_tracks:
        if track != track_to_sample:
            setattr(
                sampled_tokens,
                track,
                maybe_clone(getattr(input_tokens, track)),
            )

    return sampled_tokens