from typing import Callable

import attr
import torch
from tqdm import tqdm

from esm.sdk.api import (
    ESM3InferenceClient,
    ESMProtein,
    ESMProteinTensor,
    GenerationConfig,
    SamplingConfig,
    SamplingTrackConfig,
)
from esm.tokenization import (
    EsmTokenizerBase,
    TokenizerCollectionProtocol,
)
from esm.utils.constants import esm3 as C
from esm.utils.noise_schedules import NOISE_SCHEDULE_REGISTRY


def iterative_sampling_raw(
    client: ESM3InferenceClient,
    input: ESMProtein,
    config: GenerationConfig,
):
    """Encode a protein, run generation on its tokens, and decode the result.

    Args:
        client: Inference client providing ``encode``, ``generate`` and
            ``decode``.
        input: Protein to use as the generation prompt.
        config: Generation settings; ``config.track`` names the sampled track.

    Returns:
        The object returned by ``client.decode`` for the generated tokens.
        Unless the sampled track is ``"function"`` or
        ``"residue_annotations"``, its ``function_annotations`` attribute is
        overwritten with ``input.function_annotations``.
    """
    # Keep structure tokens
    input_tokens = client.encode(input)

    output_tokens = client.generate(input_tokens, config)

    raw_protein = client.decode(output_tokens)

    track_to_sample = config.track

    if track_to_sample not in ["function", "residue_annotations"]:
        # Function and residue annotation encoding/decoding is lossy
        # There is no guarantee that decoding encoded tokens will yield the same input
        raw_protein.function_annotations = input.function_annotations

    return raw_protein


def iterative_sampling_tokens(
    client: ESM3InferenceClient,
    input_tokens: ESMProteinTensor,
    config: GenerationConfig,
    tokenizers: TokenizerCollectionProtocol,
) -> ESMProteinTensor:
    """Iteratively fill in the masked positions of one track of ``input_tokens``.

    Each of the ``config.num_steps`` steps calls ``client.forward_and_sample``
    on the current tokens, asks the noise schedule selected by
    ``config.schedule`` how many positions should be decoded so far, and
    commits the newly sampled tokens at the still-masked positions with the
    lowest entropy. The first and last positions are never sampled.

    Args:
        client: Inference client providing ``forward_and_sample``.
        input_tokens: Prompt tokens. They are copied with ``attr.evolve``;
            the attributes of ``input_tokens`` itself are not reassigned.
        config: Generation settings. This function reads ``track``,
            ``schedule``, ``num_steps``, ``temperature``, ``top_p``,
            ``invalid_ids`` and ``condition_on_coordinates_only``.
        tokenizers: Tokenizer collection; the tokenizer for a track is looked
            up as the attribute with the track's name.

    Returns:
        A copy of ``input_tokens`` whose ``config.track`` attribute holds the
        sampled tokens. Every other track listed in ``SamplingConfig`` is
        reset to a clone of the corresponding input track.

    Raises:
        ValueError: If the track to sample is present in the input but
            contains no mask tokens.
        NotImplementedError: If ``config.track`` is ``"function"`` or
            ``"residue_annotations"`` and at least one decoding step is
            requested.
    """
    track_to_sample = config.track

    # Get all tracks that require sampling
    all_tracks = [
        f.name for f in attr.fields(SamplingConfig) if "embedding" not in f.name
    ]

    sequence_length = len(input_tokens)
    device = input_tokens.device

    # Initialize schedule and masks
    decoding_schedule = NOISE_SCHEDULE_REGISTRY[config.schedule]

    sampled_tokens = attr.evolve(input_tokens)  # Make a copy
    if config.condition_on_coordinates_only and input_tokens.coordinates is not None:
        sampled_tokens.structure = None

    # Positions still eligible for sampling; the two end positions hold the
    # special start/end tokens and are excluded up front.
    sampling_mask = torch.ones(
        sequence_length,
        dtype=torch.bool,
        device=device,
    )
    sampling_mask[0] = False
    sampling_mask[-1] = False

    get_tokenizer: Callable[[str], EsmTokenizerBase] = lambda s: getattr(tokenizers, s)
    # The tokenizer and its mask id do not change during decoding, so resolve
    # them once instead of on every use.
    track_tokenizer = get_tokenizer(track_to_sample)
    mask_token_id = track_tokenizer.mask_token_id

    if getattr(sampled_tokens, track_to_sample) is None:
        # The track is absent: start from a fully masked track of the right shape.
        if track_to_sample == "function":
            dims = (sequence_length, tokenizers.function.depth)
        elif track_to_sample == "residue_annotations":
            dims = (sequence_length, C.MAX_RESIDUE_ANNOTATIONS)
        else:
            dims = (sequence_length,)
        masked_tokens = torch.full(
            dims,
            mask_token_id,
            dtype=torch.long,
            device=device,
        )
        if track_to_sample == "sequence":
            masked_tokens[0] = tokenizers.sequence.cls_token_id  # type: ignore
            masked_tokens[-1] = tokenizers.sequence.eos_token_id  # type: ignore
        else:
            masked_tokens[0] = track_tokenizer.bos_token_id
            masked_tokens[-1] = track_tokenizer.eos_token_id

        setattr(
            sampled_tokens,
            track_to_sample,
            masked_tokens,
        )
    else:
        # The track is present: only its masked positions may be sampled.
        is_mask: torch.Tensor = getattr(input_tokens, track_to_sample) == mask_token_id
        if not is_mask.any().item():
            raise ValueError(f"Cannot sample {config.track} when input has no masks.")
        sampling_mask = sampling_mask & is_mask

    # TODO: Implement iterative decoding for function and residue_annotations
    # TODO: Fix encode/decode of interpro tokens (not yet supported)
    # Fail before the first forward pass rather than after it. Zero-step runs
    # never enter the decoding loop, so they are left untouched.
    if (
        track_to_sample in ("function", "residue_annotations")
        and config.num_steps > 0
    ):
        raise NotImplementedError(
            f"Iterative decoding for {track_to_sample} is not supported yet."
        )

    # Decode

    def maybe_clone(x: torch.Tensor | None) -> torch.Tensor | None:
        return x.clone() if x is not None else None

    # Number of positions between the two special end tokens.
    L = sequence_length - 2
    positions_sampled = 0
    for t in tqdm(range(config.num_steps)):
        # Single step sampling at all positions
        track_sample_config = SamplingTrackConfig()
        track_sample_config.invalid_ids = config.invalid_ids
        track_sample_config.temperature = config.temperature
        track_sample_config.top_p = config.top_p
        sampling_config = SamplingConfig(**{track_to_sample: track_sample_config})  # type: ignore

        forward_and_sample_output = client.forward_and_sample(
            sampled_tokens, sampling_config
        )
        new_samples = forward_and_sample_output.protein_tensor

        # Calculate number of tokens to sample
        perc_masked = decoding_schedule(torch.tensor((t + 1) / config.num_steps))
        num_to_sample = int((1 - perc_masked) * L) - positions_sampled
        positions_sampled += num_to_sample

        # Select tokens based on lowest entropy
        # Positions committed in earlier steps no longer hold the mask token
        # and therefore drop out of the sampling mask here.
        sampling_mask = sampling_mask & (
            getattr(sampled_tokens, track_to_sample) == mask_token_id
        )

        track_entropy: torch.Tensor = getattr(
            forward_and_sample_output.entropy, track_to_sample
        )
        # Ineligible positions get the largest finite value so that the
        # smallest-k selection below prefers eligible positions.
        track_entropy = track_entropy.masked_fill(
            ~sampling_mask, torch.finfo(track_entropy.dtype).max
        )
        _, indices = track_entropy.topk(num_to_sample, dim=-1, largest=False)
        # True at every position whose index appears in ``indices``.
        is_top_k = ~(
            torch.arange(sequence_length, device=device)[:, None] != indices[None, :]
        ).all(-1)
        tokens_to_sample = sampling_mask & is_top_k

        old_track_samples = getattr(sampled_tokens, track_to_sample)
        new_track_samples = getattr(new_samples, track_to_sample)

        new_track_samples = torch.where(
            tokens_to_sample, new_track_samples, old_track_samples
        )

        setattr(sampled_tokens, track_to_sample, new_track_samples)

    # Do not update tracks that were not sampled (e.g. keep None instead of masks)
    for track in all_tracks:
        if track != track_to_sample:
            setattr(
                sampled_tokens,
                track,
                maybe_clone(getattr(input_tokens, track)),
            )

    return sampled_tokens