# Copyright Generate Biomedicines, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from types import SimpleNamespace from typing import List, Optional, Tuple, Union import torch import torch.nn as nn from torch.nn.functional import pad from chroma.data.xcs import validate_XC from chroma.layers.basic import FourierFeaturization from chroma.layers.structure import diffusion from chroma.models import graph_classifier from chroma.models.graph_classifier import GraphClassifier from chroma.models.graph_design import BackboneEncoderGNN from chroma.utility.model import load_model as utility_load_model from chroma.utility.model import save_model as utility_save_model class ProteinCaption(nn.Module): """ProCap model for caption likelihood given a noised structure. Provides an architecture to model the likelihood of a caption representing a protein backbone at an arbitrary diffusion time. For caption processing, it uses a pretrained language model from Hugging Face which can be user-specified and fine-tuned. For structures, ProteinCaption uses a `BackboneEncoderGNN` that encodes a structure and its noise level in the embedding space of the language model. There are several options for interfacing between the representations of the backbone residues and those of the caption. A `ProteinCaption` model can be used to conditionally generate backbones given a natural language caption, through the creation of a `ProCapConditioner` using the model. In this case, the noising parameters used for the `ProteinCaption` model should be identical to those that were used to train the underlying backbone diffusion model. Args: lm_id (str): Base language model to pull from Hugging Face. gnn_dim_edges (int): Number of edges for structure encoder. context_size (int): When encoding structures by chains, specifies the maximum number of chains to be used for the encodings. Not used when `direct_gnn` is specified. context_per_chain (int): When encoding structures by chain, the number of context tokens to use per chain. Not used when `direct_gnn` is specified. gnn_num_neighbors (int): Number of neighbors per node for structure encoder. gnn_num_layers (int): Number of layers for structure encoder. only_encode_caption_chain (bool): Whether to pass structure of only chain whose caption is being predicted, as opposed to entire structure. gnn_embed_ratio (int): Number of context tokens to extract from GNN per chain, stacks with gnn_embed_ratio. graph_criterion (str): Graph criterion for structure encoder, defines how neighbors are chosen. See `chroma.models.graph_design.BackboneEncoderGNN` for allowed values. node_mlp_layers (int): Number of hidden layers for node update function of structure encoder. node_mlp_dim (int, optional): Dimension of hidden layers for node update function of structure encoder, defaults to match output dimension. noise_schedule (str): Noise schedule for mapping between diffusion time and noise level, see chroma.layers.structure.diffusion.DiffusionChainCov for allowed values. covariance_model (str): Covariance mode for mapping between diffusion time and noise level, see chroma.layers.structure.diffusion.DiffusionChainCov for allowed values. noise_complex_scaling (bool): Whether to scale noise for complexes. noiseless (bool): Whether to train with denoised structures only, useful for debugging but resulting model cannot be used for classifier guidance. normalize_context_embeddings (bool): Whether to normalize context embeddings to an overall length of 1. standardize_context_embeddings (bool): Whether to standardize context embeddings to have mean 0 and variance 1. time_feature_type (str): Method of encoding diffusion timestep. time_log_feature_scaling (float): Scaling of diffusion timestep in preprocessing when encoding with `time_feature_type = "log_snr"`. use_transformer (bool): Whether to use transformer to embed context from residue-level GNN outputs. classifier_checkpoint (str, optional): Path to pre-trained graph classifier checkpoint, whose encoder head will be used for structure encoding. direct_gnn (bool): Whether to pass in GNN encodings for chains/complexes directly to the language model, without any pooling or transformer layers. classifier_kwargs (dict, optional): Dictionary of parameters to create classifier network for encoding. Will override classifier_checkpoint if given. Inputs: X (torch.Tensor): Backbone tensor of shape `(num_batch, num_residues, 4, 3)`. C (torch.Tensor): Chain map of shape `(num_batch, num_residues)`. Positions with 0 are masked, positive integers are used for chain indices, and negative integers are used for missing residues of the chains with indices equal to the corresponding positive integers. caption (List[str]): List of captions with length `num_batch`. chain_id (torch.Tensor): Chain indices for given captions of shape `(num_batch)`. For a caption corresponding to an entire complex, use -1. O (torch.Tensor, optional): One-hot sequence tensor of shape `(num_batch, num_residues, num_alphabet)`. If not given, the loss is computed without sequence information. add_noise (bool): Whether to randomly add noise to the input backbones. If structures are already noised, use `t` instead. t (torch.Tensor, optional): Diffusion timesteps corresponding to noisy input backbones, of shape `(num_batch)`. Use zeros when passing structures without noise. by_sample (bool): Whether to return loss per sample, as opposed to overall batch loss. Outputs: loss (Union[transformers.modeling_outputs.CausalLMOutputWithCrossAttentions, torch.Tensor]): Loss containing average -log(p) of caption tokens given output structures. If `by_sample` is specified, loss is output as a tensor of length `(num_batch)`. """ def __init__( self, lm_id: str = "EleutherAI/gpt-neo-125m", gnn_dim_edges: int = 128, context_size: int = 16, context_per_chain: int = 1, gnn_num_neighbors: int = 30, gnn_num_layers: int = 3, only_encode_caption_chain: bool = False, gnn_embed_ratio: int = 1, graph_criterion: str = "knn", node_mlp_layers: int = 1, node_mlp_dim: Optional[int] = None, noise_schedule: str = "log_snr", covariance_model: str = "globular", noise_complex_scaling: bool = False, noiseless: bool = False, normalize_context_embeddings: bool = False, standardize_context_embeddings: bool = False, time_feature_type: str = "t", time_log_feature_scaling: float = 0.05, use_transformer: bool = False, classifier_checkpoint: Optional[str] = None, direct_gnn: bool = False, classifier_kwargs: Optional[dict] = None, ) -> None: super().__init__() # Save configuration in kwargs self.kwargs = locals() self.kwargs.pop("self") for key in list(self.kwargs.keys()): if key.startswith("__") and key.endswith("__"): self.kwargs.pop(key) args = SimpleNamespace(**self.kwargs) try: import transformers except ImportError: print("Install the hugging face package `transformers` to use ProCap") self.context_size = context_size self.context_per_chain = context_per_chain self.only_encode_caption_chain = only_encode_caption_chain self.gnn_embed_ratio = gnn_embed_ratio self.normalize_context_embeddings = normalize_context_embeddings self.standardize_context_embeddings = standardize_context_embeddings self.time_feature_type = time_feature_type self.time_log_feature_scaling = time_log_feature_scaling self.use_transformer = use_transformer self.classifier_checkpoint = classifier_checkpoint self.direct_gnn = direct_gnn self.classifier_kwargs = classifier_kwargs if self.normalize_context_embeddings and self.standardize_context_embeddings: print( "Warning: both normalization and standardization of context embeddings" " are selected, choosing only standardization" ) self.normalize_context_embeddings = False # Use Pretrained Tokenizer From Hugging Face self.tokenizer = transformers.AutoTokenizer.from_pretrained( lm_id, additional_special_tokens=["<|pdb|>", "<|unconditioned|>"], eos_token="<|endoftext|>", pad_token="<|pad|>", ) # Use Pretrained Language Model From Hugging Face self.language_model = transformers.AutoModelForCausalLM.from_pretrained(lm_id) # Embedding self.language_model.resize_token_embeddings(len(self.tokenizer)) self.embedder = self.language_model.get_input_embeddings() self.d_model = self.embedder.embedding_dim # Standardization for context embeddings if self.standardize_context_embeddings: self.context_normalization = nn.LayerNorm( self.d_model, elementwise_affine=False ) # Transformer for context embeddings if self.use_transformer: self.transformer = nn.Transformer( nhead=8, d_model=self.d_model, num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=2048, batch_first=True, ) if gnn_embed_ratio != 1: print( "Warning: both use_transformer and gnn_embed_ratio are set, setting" " gnn_embed_ratio to 1" ) self.gnn_embed_ratio = 1 if context_per_chain != 1: print( "Warning: both use_transformer and context_per_chain are set," " setting context_per_chain to 1" ) self.context_per_chain = 1 if not self.only_encode_caption_chain: print( "Warning: use_transformer is set but only_encode_caption_chain is" " not, this is unsupported! Setting only_encode_caption_chain to" " True" ) self.only_encode_caption_chain = True # Pass in GNN encodings without averaging or transformer if self.direct_gnn: if gnn_embed_ratio != 1: print( "Warning: both direct_gnn and gnn_embed_ratio are set, setting" " gnn_embed_ratio to 1" ) self.gnn_embed_ratio = 1 if context_per_chain != 1: print( "Warning: both direct_gnn and context_per_chain are set, setting" " context_per_chain to 1" ) self.context_per_chain = 1 if not self.only_encode_caption_chain: print( "Warning: direct_gnn is set but only_encode_caption_chain is not," " this is unsupported! Setting only_encode_caption_chain to True" ) self.only_encode_caption_chain = True if self.use_transformer: print( "Warning: direct_gnn and use_transformer are both set, turning off" " use_transformer" ) self.use_transformer = False if self.context_size is not None: print( "Warning: context_size given but not used for direct_gnn, setting" " context_size to None" ) self.context_size = None # Use Standard Protein Encoder if self.classifier_checkpoint is not None or self.classifier_kwargs is not None: if self.classifier_kwargs is not None: self.protein_encoder = GraphClassifier(**classifier_kwargs) else: self.protein_encoder = graph_classifier.load_model( classifier_checkpoint ) self.classifier_kwargs = self.protein_encoder.kwargs self.kwargs["classifier_kwargs"] = self.classifier_kwargs self.protein_encoder_linear = nn.Sequential( nn.Linear( self.protein_encoder.dim_nodes, self.d_model * self.gnn_embed_ratio ), nn.ReLU(), ) else: self.protein_encoder = BackboneEncoderGNN( dim_nodes=self.d_model * self.gnn_embed_ratio, dim_edges=gnn_dim_edges, num_neighbors=gnn_num_neighbors, num_layers=gnn_num_layers, node_mlp_layers=node_mlp_layers, node_mlp_dim=node_mlp_dim, graph_criterion=graph_criterion, ) # Use same Noise Layer as in Graph Energy model if not noiseless: self.noise_generator = diffusion.DiffusionChainCov( log_snr_range=(-7.0, 13.5), noise_schedule=noise_schedule, covariance_model=covariance_model, complex_scaling=noise_complex_scaling, ) else: self.noise_generator = None self.time_features = FourierFeaturization( d_input=1, d_model=self.d_model * self.gnn_embed_ratio, trainable=False, scale=16.0, ) # Embed Tokens for 21 Residue Possibilities self.sequence_embedding = nn.Embedding(22, self.d_model * self.gnn_embed_ratio) @validate_XC() def forward( self, X: torch.Tensor, C: torch.Tensor, caption: List[str], chain_id: torch.Tensor, O: Optional[torch.Tensor] = None, add_noise: bool = True, t: Optional[Union[torch.Tensor, float]] = None, by_sample: bool = False, ) -> Union[ "transformers.modeling_outputs.CausalLMOutputWithCrossAttentions", torch.Tensor ]: if self.noise_generator is None: t = torch.zeros(X.shape[0]).to(X.device) if isinstance(t, float): t = torch.Tensor([t]).to(X.device) elif isinstance(t, torch.Tensor) and t.dim() == 0: t = t.unsqueeze(0) if add_noise and self.noise_generator is not None: # Add Chain Noise X, t = self._noise(X, C) assert all(t <= 1) and all(t >= 0), ( "Noise Temperatures must be between 0 and 1, but got values" f" {t[(t > 1) | (t < 0)]}" ) else: assert t is not None, "Must pass diffusion timestep if not adding noise!" # Encode Protein Context if self.classifier_kwargs is None: # Aux feature encoding node_h = self._time_features(t) if O is not None: # pad one-hot tensor by two to account for special tokens used node_h = node_h + pad(O, (0, 2)) @ self.sequence_embedding.weight.to( X.device ) Xe, _, _, Me, _ = self.protein_encoder.to(X.device)(X, C, node_h_aux=node_h) else: # TODO: is there a better way to deal with sequence padding tokens when batch size > 1? if O is not None and O[:, :, -1].any(): O = None Xe0, _, _, Me, _ = self.protein_encoder.to(X.device).encode(X, C, O, t) Xe = self.protein_encoder_linear.to(X.device)(Xe0) context_embedding, attention_mask_context = self._encode_context( Xe, C, Me, chain_id ) if self.standardize_context_embeddings: context_embedding = self.context_normalization.to(Xe.device)( context_embedding ) elif self.normalize_context_embeddings: context_embedding = torch.nn.functional.normalize(context_embedding, dim=-1) # Encode Text Input if self.direct_gnn: max_caption_tokens = ( self.tokenizer.model_max_length - context_embedding.shape[1] ) else: max_caption_tokens = ( self.tokenizer.model_max_length - (self.context_size - 1) * self.gnn_embed_ratio * self.context_per_chain - 1 ) Y, attention_mask_caption = self._tokenize( caption, add_stop=True, max_length=max_caption_tokens ) Y = Y.to(X.device) attention_mask_caption = attention_mask_caption.to(X.device) caption_embedding = self._embed_text(Y) # Caption inputs_embeds = torch.cat([context_embedding, caption_embedding], dim=1) attention_mask = torch.cat( [attention_mask_context, attention_mask_caption], dim=1 ) labels = torch.cat( [ torch.tensor(-100, device=X.device).expand( attention_mask_context.shape ), Y * attention_mask_caption + (-100) * (1 - attention_mask_caption), ], dim=1, ) # returns a transformers.modeling_outputs.CausalLMOutputWithCrossAttentions object # can get logits with output.logits output = self.language_model.to(X.device).forward( inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels ) if not by_sample: return output else: # below code adapted from transformers/modeling_gpt2.py shift_logits = output.logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() loss_fct = nn.CrossEntropyLoss(reduction="none") loss = loss_fct( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) ).reshape(X.shape[0], -1) return torch.Tensor( (loss * (shift_labels != -100).int()).sum(dim=-1) / (shift_labels != -100).int().sum(dim=-1) ) return output @validate_XC(all_atom=False) def _noise( self, X: torch.Tensor, C: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: """Takes in a Structure Tensor X and Chain Tensor C, adds chain noise with quasi-uniformly sampled temperature. Returns the noised X and the time steps used.""" assert self.noise_generator is not None, "Model does not have noising!" return [x.to(X.device) for x in self.noise_generator.to(X.device)(X, C)] # Taken from graph classifier model def _time_features(self, t: torch.Tensor) -> torch.Tensor: h = { "t": lambda: t, "log_snr": lambda: self.noise_generator.noise_schedule.log_SNR(t), }[self.time_feature_type]() if "log" in self.time_feature_type: h = self.time_log_feature_scaling * h time_h = self.time_features.to(t.device)(h[:, None, None]) return time_h def _encode_context( self, Xe: torch.Tensor, C: torch.Tensor, M: torch.Tensor, polymer_id: int ) -> Tuple[torch.Tensor, torch.Tensor]: """Average Pool over Chains after accounting for masking input: Xe (torch.Tensor): embedding tensor of shape [batch, residue, d_model] C (torch.Tensor): chain tensor indexing which chain each residue belongs to [batch, residue] M (torch.Tensor): mask tensor of shape [batch, residue] polymer_id (int): index in C of chain, or -1 for entire structure, or 0 to apply no conditioning """ Cm = C * M # Mask Chain Map Cm[Cm < 0] = 0 # Remove Negatives from Chain Map B, R, Dm = Xe.shape pooled_encoding = [] for x, c, pid in zip(Xe, Cm, polymer_id): batch_encoding = [] # The predict whole complex token is added under this syntax if pid == -1: pdb_embedding = self._embed_text( self._tokenize(["<|pdb|>"], add_stop=False)[0].to(Xe.device) ).squeeze(0) batch_encoding.append(pdb_embedding) if pid == 0: pdb_embedding = ( self._embed_text( self._tokenize(["<|unconditioned|>"], add_stop=False)[0] ) .squeeze(0) .to(Xe.device) ) batch_encoding.append(pdb_embedding) # Power Average Pool By Chain if pid != 0: if self.only_encode_caption_chain and (pid != -1): cid = self._pid_2_cid(pid, c) residue_mask = c == cid n_residues = residue_mask.sum(-1) if self.use_transformer: encodings = [ self.transformer.to(Xe.device)( x[residue_mask].unsqueeze(0), torch.zeros(1, self.context_size, self.d_model).to( Xe.device ), ).squeeze(0) ] elif self.direct_gnn: encodings = x[residue_mask].unsqueeze(0) else: encodings = [ (x[residue_mask].pow(p).sum(0).unsqueeze(0) / n_residues) .abs() .pow(1 / p) * ( x[residue_mask].pow(p).sum(0).unsqueeze(0).sign() if p % 2 == 1 else 1 ) for p in range(1, self.context_per_chain + 1) ] encodings = [ enc.reshape(self.gnn_embed_ratio, -1) for enc in encodings ] batch_encoding.extend(encodings) else: if self.use_transformer or self.direct_gnn: residue_mask = ( c > 0 ) # just use all embeddings, no chain structure if self.use_transformer: # should have pid == -1 to get here, so need encoding of size context_size - 1 because of <|pdb|> token assert self.only_encode_caption_chain, ( "only_encode_caption chain = False not supported when" " use_transformer = True!" ) batch_encoding.append( self.transformer.to(Xe.device)( x[residue_mask].unsqueeze(0), torch.zeros( 1, self.context_size - 1, self.d_model ).to(Xe.device), ).squeeze(0) ) else: # direct_gnn batch_encoding.extend(x[residue_mask].unsqueeze(0)) else: for cid in torch.unique(c): if cid == 0: continue residue_mask = c == cid n_residues = residue_mask.sum(-1) encodings = [ ( x[residue_mask].pow(p).sum(0).unsqueeze(0) / n_residues ) .abs() .pow(1 / p) * ( x[residue_mask].pow(p).sum(0).unsqueeze(0).sign() if p % 2 == 1 else 1 ) for p in range(1, self.context_per_chain + 1) ] batch_encoding.extend( [ enc.reshape(self.gnn_embed_ratio, -1) for enc in encodings ] ) # Reorder the chain embedding to caption to be first if pid != -1: first_cid = self._pid_2_cid(pid, c) try: if first_cid != 0: ( batch_encoding[ (first_cid - 1) * self.gnn_embed_ratio * self.context_per_chain : (first_cid) * self.gnn_embed_ratio * self.context_per_chain ], batch_encoding[ 0 : self.gnn_embed_ratio * self.context_per_chain ], ) = ( batch_encoding[ 0 : self.gnn_embed_ratio * self.context_per_chain ], batch_encoding[ (first_cid - 1) * self.gnn_embed_ratio * self.context_per_chain : (first_cid) * self.gnn_embed_ratio * self.context_per_chain ], ) except IndexError: print( "Problem: tried to switch encodings at positions 0 and" f" {first_cid}, but failed!" ) # raise pooled_encoding.append(torch.cat(batch_encoding)) # Pad with Zero Tensor X_pooled = torch.nn.utils.rnn.pad_sequence(pooled_encoding, batch_first=True) if self.context_size is not None: if ( X_pooled.shape[1] > (self.context_size - 1) * self.gnn_embed_ratio * self.context_per_chain + 1 ): print([x.shape for x in pooled_encoding]) print(polymer_id) assert ( X_pooled.shape[1] <= (self.context_size - 1) * self.gnn_embed_ratio * self.context_per_chain + 1 ), ( f"Context is of length {X_pooled.shape[1]}, which is larger than the" " allowed number of tokens" f" {(self.context_size - 1) * self.gnn_embed_ratio * self.context_per_chain + 1};" " this will cause the model to behave poorly!" ) if ( X_pooled.shape[1] < (self.context_size - 1) * self.gnn_embed_ratio * self.context_per_chain + 1 and not self.direct_gnn ): pad_shape = ( (self.context_size - 1) * self.gnn_embed_ratio * self.context_per_chain + 1 - X_pooled.shape[1] ) zero_pad = torch.zeros( [B, pad_shape, int(Dm / self.gnn_embed_ratio)], device=Xe.device ) X_pooled = torch.cat([X_pooled, zero_pad], dim=1) M_pooled = (X_pooled != 0)[ :, :, 0 ] # This is a bit dangerous because very rarely X_pooled could contain zeros in masked regions... return X_pooled, M_pooled def _pid_2_cid(self, pid: int, c: int) -> int: """This function converts the polymer_entity_id in the pdb to the chain_id in the XCS format of generate.""" assert pid in c, f"pid value {pid} must be in the chain map!" chain_values = torch.unique(c) nonzero_chain_values = chain_values[chain_values != 0] cid = (nonzero_chain_values == pid).nonzero(as_tuple=True)[0].item() + 1 return cid def _tokenize( self, text: list, add_stop: bool = True, max_length: Optional[int] = None ) -> Tuple[torch.Tensor, torch.Tensor]: """Converts list of strings into a padded tensor, returning the tokenized strings as well as the associated masks.""" if add_stop: text = [x + self.tokenizer.eos_token for x in text] # Note that there are no stop tokens in truncated sequences tokenized_dict = self.tokenizer( text, padding=True, truncation=True, max_length=max_length, return_tensors="pt", ) return tokenized_dict["input_ids"], tokenized_dict["attention_mask"] def _embed_text(self, tokenized_text: torch.Tensor) -> torch.Tensor: """Embeds tokenized text.""" return self.embedder.to(tokenized_text.device)(tokenized_text) def load_model( weight_file: str, device: str = "cpu", strict: bool = False, strict_unexpected: bool = True, ) -> ProteinCaption: """Loads a ProCap model. Args: weight_file (str): Path to the saved model weights. device (str): Device on which to load the model. strict (bool): Whether to require that the keys match between the input file weights and the model created from the parameters stored in the model kwargs. strict_unexpected (bool): Whether to require that there are no unexpected keys when loading model weights, as distinct from the strict option which doesn't allow for missing keys either. By default, we use this option rather than strict for ease of development when adding model features. Returns: model (ProteinCaption): Instance of `ProteinCaption` with loaded weights. For inference the returned model should be set to eval mode with `model.eval()`. """ return utility_load_model( weight_file, ProteinCaption, device=device, strict=strict, strict_unexpected=strict_unexpected, ) def save_model( model: ProteinCaption, weight_file: str, metadata: Optional[dict] = None ) -> None: """Save model, including optional metadata. Args: model (ProteinCaption): An instance of `ProteinCaption`. weight_file (str): The destination path for saving model weights. metadata (dict): A dictionary of additional metadata to add to the model weights. For example, when saving models during training it can be useful to store `args` representing the CLI args, the date and time of training, etc. """ utility_save_model(model, weight_file, metadata=metadata)