# Copyright Generate Biomedicines, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from types import SimpleNamespace
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.nn.functional import pad

from chroma.data.xcs import validate_XC
from chroma.layers.basic import FourierFeaturization
from chroma.layers.structure import diffusion
from chroma.models import graph_classifier
from chroma.models.graph_classifier import GraphClassifier
from chroma.models.graph_design import BackboneEncoderGNN
from chroma.utility.model import load_model as utility_load_model
from chroma.utility.model import save_model as utility_save_model


class ProteinCaption(nn.Module):
    """ProCap model for caption likelihood given a noised structure.

    Provides an architecture to model the likelihood of a caption representing a
    protein backbone at an arbitrary diffusion time. For caption processing, it
    uses a pretrained language model from Hugging Face which can be
    user-specified and fine-tuned. For structures, ProteinCaption uses a
    `BackboneEncoderGNN` that encodes a structure and its noise level in the
    embedding space of the language model. There are several options for
    interfacing between the representations of the backbone residues and those
    of the caption.

    A `ProteinCaption` model can be used to conditionally generate backbones
    given a natural-language caption by wrapping the model in a
    `ProCapConditioner`. In this case, the noising parameters of the
    `ProteinCaption` model should be identical to those used to train the
    underlying backbone diffusion model.

    Args:
        lm_id (str): Base language model to pull from Hugging Face.
        gnn_dim_edges (int): Dimension of the edge embeddings in the structure
            encoder.
        context_size (int): When encoding structures by chains, specifies the
            maximum number of chains to be used for the encodings. Not used when
            `direct_gnn` is specified.
        context_per_chain (int): When encoding structures by chain, the number
            of context tokens to use per chain. Not used when `direct_gnn` is
            specified.
        gnn_num_neighbors (int): Number of neighbors per node for structure
            encoder.
        gnn_num_layers (int): Number of layers for structure encoder.
        only_encode_caption_chain (bool): Whether to encode only the chain
            whose caption is being predicted, as opposed to the entire
            structure.
        gnn_embed_ratio (int): Number of context tokens to extract from the GNN
            per chain; stacks with `context_per_chain`.
        graph_criterion (str): Graph criterion for structure encoder, defines
            how neighbors are chosen. See
            `chroma.models.graph_design.BackboneEncoderGNN` for
            allowed values.
        node_mlp_layers (int): Number of hidden layers for node update function
            of structure encoder.
        node_mlp_dim (int, optional): Dimension of hidden layers for node update
            function of structure encoder, defaults to match output dimension.
        noise_schedule (str): Noise schedule for mapping between diffusion time
            and noise level, see
            chroma.layers.structure.diffusion.DiffusionChainCov for allowed
            values.
        covariance_model (str): Covariance model for the structure diffusion
            noise, see chroma.layers.structure.diffusion.DiffusionChainCov for
            allowed values.
        noise_complex_scaling (bool): Whether to scale noise for complexes.
        noiseless (bool): Whether to train on un-noised structures only; useful
            for debugging, but the resulting model cannot be used for
            classifier guidance.
        normalize_context_embeddings (bool): Whether to normalize context
            embeddings to unit length.
        standardize_context_embeddings (bool): Whether to standardize context
            embeddings to have mean 0 and variance 1.
        time_feature_type (str): Method of encoding diffusion timestep.
        time_log_feature_scaling (float): Scaling of diffusion timestep in
            preprocessing when encoding with `time_feature_type = "log_snr"`.
        use_transformer (bool): Whether to use transformer to embed context from
            residue-level GNN outputs.
        classifier_checkpoint (str, optional): Path to pre-trained graph
            classifier checkpoint, whose encoder head will be used for structure
            encoding.
        direct_gnn (bool): Whether to pass in GNN encodings for chains/complexes
            directly to the language model, without any pooling or transformer
            layers.
        classifier_kwargs (dict, optional): Dictionary of parameters to create
            classifier network for encoding. Will override classifier_checkpoint
            if given.


    Inputs:
        X (torch.Tensor): Backbone tensor of shape `(num_batch, num_residues,
            4, 3)`.
        C (torch.Tensor): Chain map of shape `(num_batch, num_residues)`.
            Positions with 0 are masked, positive integers are used for chain
            indices, and negative integers are used for missing residues of the
            chains with indices equal to the corresponding positive integers.
        caption (List[str]): List of captions with length `num_batch`.
        chain_id (torch.Tensor): Chain indices for given captions of shape
            `(num_batch)`. For a caption corresponding to an entire complex, use
            -1.
        O (torch.Tensor, optional): One-hot sequence tensor of shape
            `(num_batch, num_residues, num_alphabet)`. If not given, the loss is
            computed without sequence information.
        add_noise (bool): Whether to randomly add noise to the input backbones.
            If structures are already noised, use `t` instead.
        t (torch.Tensor, optional): Diffusion timesteps corresponding to noisy
            input backbones, of shape `(num_batch)`. Use zeros when passing
            structures without noise.
        by_sample (bool): Whether to return loss per sample, as opposed to
            overall batch loss.

    Outputs:
        loss (Union[transformers.modeling_outputs.CausalLMOutputWithCrossAttentions,
            torch.Tensor]): Output containing the average -log(p) of caption
            tokens given the (noised) input structures. If `by_sample` is
            specified, the loss is returned as a tensor of shape `(num_batch)`.
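
    Example:
        A minimal illustrative sketch; the backbone coordinates, chain map, and
        caption below are hypothetical placeholders, not real data:

            model = ProteinCaption(lm_id="EleutherAI/gpt-neo-125m")
            X = torch.randn(1, 100, 4, 3)             # backbone coordinates
            C = torch.ones(1, 100, dtype=torch.long)  # single-chain map
            caption = ["A small alpha-helical protein."]
            chain_id = torch.tensor([1])              # caption refers to chain 1
            output = model(X, C, caption, chain_id, add_noise=True)
            loss = output.loss  # average -log p(caption tokens | noised structure)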
    """

    def __init__(
        self,
        lm_id: str = "EleutherAI/gpt-neo-125m",
        gnn_dim_edges: int = 128,
        context_size: int = 16,
        context_per_chain: int = 1,
        gnn_num_neighbors: int = 30,
        gnn_num_layers: int = 3,
        only_encode_caption_chain: bool = False,
        gnn_embed_ratio: int = 1,
        graph_criterion: str = "knn",
        node_mlp_layers: int = 1,
        node_mlp_dim: Optional[int] = None,
        noise_schedule: str = "log_snr",
        covariance_model: str = "globular",
        noise_complex_scaling: bool = False,
        noiseless: bool = False,
        normalize_context_embeddings: bool = False,
        standardize_context_embeddings: bool = False,
        time_feature_type: str = "t",
        time_log_feature_scaling: float = 0.05,
        use_transformer: bool = False,
        classifier_checkpoint: Optional[str] = None,
        direct_gnn: bool = False,
        classifier_kwargs: Optional[dict] = None,
    ) -> None:
        super().__init__()

        # Save configuration in kwargs
        self.kwargs = locals()
        self.kwargs.pop("self")
        for key in list(self.kwargs.keys()):
            if key.startswith("__") and key.endswith("__"):
                self.kwargs.pop(key)
        args = SimpleNamespace(**self.kwargs)

        try:
            import transformers
        except ImportError:
            raise ImportError(
                "Install the Hugging Face package `transformers` to use ProCap"
            )

        self.context_size = context_size
        self.context_per_chain = context_per_chain
        self.only_encode_caption_chain = only_encode_caption_chain
        self.gnn_embed_ratio = gnn_embed_ratio
        self.normalize_context_embeddings = normalize_context_embeddings
        self.standardize_context_embeddings = standardize_context_embeddings
        self.time_feature_type = time_feature_type
        self.time_log_feature_scaling = time_log_feature_scaling
        self.use_transformer = use_transformer
        self.classifier_checkpoint = classifier_checkpoint
        self.direct_gnn = direct_gnn
        self.classifier_kwargs = classifier_kwargs

        if self.normalize_context_embeddings and self.standardize_context_embeddings:
            print(
                "Warning: both normalization and standardization of context embeddings"
                " are selected, choosing only standardization"
            )
            self.normalize_context_embeddings = False

        # Use Pretrained Tokenizer From Hugging Face
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            lm_id,
            additional_special_tokens=["<|pdb|>", "<|unconditioned|>"],
            eos_token="<|endoftext|>",
            pad_token="<|pad|>",
        )

        # Use Pretrained Language Model From Hugging Face
        self.language_model = transformers.AutoModelForCausalLM.from_pretrained(lm_id)

        # Embedding
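        # The tokenizer adds <|pdb|>, <|unconditioned|>, and a pad token, so the
        # token embedding table is resized to match. The input embedding layer
        # is kept as `self.embedder` so that captions and special tokens share
        # the same d_model space as the structure context embeddings.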
        self.language_model.resize_token_embeddings(len(self.tokenizer))
        self.embedder = self.language_model.get_input_embeddings()
        self.d_model = self.embedder.embedding_dim

        # Standardization for context embeddings
        if self.standardize_context_embeddings:
            self.context_normalization = nn.LayerNorm(
                self.d_model, elementwise_affine=False
            )

        # Transformer for context embeddings
        if self.use_transformer:
            self.transformer = nn.Transformer(
                nhead=8,
                d_model=self.d_model,
                num_encoder_layers=6,
                num_decoder_layers=6,
                dim_feedforward=2048,
                batch_first=True,
            )
            if gnn_embed_ratio != 1:
                print(
                    "Warning: both use_transformer and gnn_embed_ratio are set, setting"
                    " gnn_embed_ratio to 1"
                )
                self.gnn_embed_ratio = 1
            if context_per_chain != 1:
                print(
                    "Warning: both use_transformer and context_per_chain are set,"
                    " setting context_per_chain to 1"
                )
                self.context_per_chain = 1
            if not self.only_encode_caption_chain:
                print(
                    "Warning: use_transformer is set but only_encode_caption_chain is"
                    " not, this is unsupported! Setting only_encode_caption_chain to"
                    " True"
                )
                self.only_encode_caption_chain = True

        # Pass in GNN encodings without averaging or transformer
        if self.direct_gnn:
            if gnn_embed_ratio != 1:
                print(
                    "Warning: both direct_gnn and gnn_embed_ratio are set, setting"
                    " gnn_embed_ratio to 1"
                )
                self.gnn_embed_ratio = 1
            if context_per_chain != 1:
                print(
                    "Warning: both direct_gnn and context_per_chain are set, setting"
                    " context_per_chain to 1"
                )
                self.context_per_chain = 1
            if not self.only_encode_caption_chain:
                print(
                    "Warning: direct_gnn is set but only_encode_caption_chain is not,"
                    " this is unsupported! Setting only_encode_caption_chain to True"
                )
                self.only_encode_caption_chain = True
            if self.use_transformer:
                print(
                    "Warning: direct_gnn and use_transformer are both set, turning off"
                    " use_transformer"
                )
                self.use_transformer = False
            if self.context_size is not None:
                print(
                    "Warning: context_size given but not used for direct_gnn, setting"
                    " context_size to None"
                )
                self.context_size = None

        # Structure encoder: reuse a pretrained GraphClassifier encoder when a
        # checkpoint or classifier kwargs are given, otherwise build a
        # BackboneEncoderGNN
        if self.classifier_checkpoint is not None or self.classifier_kwargs is not None:
            if self.classifier_kwargs is not None:
                self.protein_encoder = GraphClassifier(**classifier_kwargs)
            else:
                self.protein_encoder = graph_classifier.load_model(
                    classifier_checkpoint
                )
                self.classifier_kwargs = self.protein_encoder.kwargs
                self.kwargs["classifier_kwargs"] = self.classifier_kwargs
            self.protein_encoder_linear = nn.Sequential(
                nn.Linear(
                    self.protein_encoder.dim_nodes, self.d_model * self.gnn_embed_ratio
                ),
                nn.ReLU(),
            )
        else:
            self.protein_encoder = BackboneEncoderGNN(
                dim_nodes=self.d_model * self.gnn_embed_ratio,
                dim_edges=gnn_dim_edges,
                num_neighbors=gnn_num_neighbors,
                num_layers=gnn_num_layers,
                node_mlp_layers=node_mlp_layers,
                node_mlp_dim=node_mlp_dim,
                graph_criterion=graph_criterion,
            )

        # Use same Noise Layer as in Graph Energy model
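        # The noising process here should match the one used to train the
        # backbone diffusion model that a ProCapConditioner will guide;
        # `noiseless` disables noising entirely (debugging only)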
        if not noiseless:
            self.noise_generator = diffusion.DiffusionChainCov(
                log_snr_range=(-7.0, 13.5),
                noise_schedule=noise_schedule,
                covariance_model=covariance_model,
                complex_scaling=noise_complex_scaling,
            )
        else:
            self.noise_generator = None
        self.time_features = FourierFeaturization(
            d_input=1,
            d_model=self.d_model * self.gnn_embed_ratio,
            trainable=False,
            scale=16.0,
        )

        # Embedding for the residue alphabet plus extra slots for special
        # tokens (the one-hot sequence input is padded by two in `forward`)
        self.sequence_embedding = nn.Embedding(22, self.d_model * self.gnn_embed_ratio)

    @validate_XC()
    def forward(
        self,
        X: torch.Tensor,
        C: torch.Tensor,
        caption: List[str],
        chain_id: torch.Tensor,
        O: Optional[torch.Tensor] = None,
        add_noise: bool = True,
        t: Optional[Union[torch.Tensor, float]] = None,
        by_sample: bool = False,
    ) -> Union[
        "transformers.modeling_outputs.CausalLMOutputWithCrossAttentions", torch.Tensor
    ]:
        if self.noise_generator is None:
            t = torch.zeros(X.shape[0]).to(X.device)

        if isinstance(t, float):
            t = torch.Tensor([t]).to(X.device)

        elif isinstance(t, torch.Tensor) and t.dim() == 0:
            t = t.unsqueeze(0)

        if add_noise and self.noise_generator is not None:
            # Add Chain Noise
            X, t = self._noise(X, C)
            assert all(t <= 1) and all(t >= 0), (
                "Noise Temperatures must be between 0 and 1, but got values"
                f" {t[(t > 1) | (t < 0)]}"
            )
        else:
            assert t is not None, "Must pass diffusion timestep if not adding noise!"

        # Encode Protein Context
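        # Two encoder paths: the default BackboneEncoderGNN consumes the time
        # features (and optional sequence embedding) as auxiliary node inputs,
        # while a pretrained GraphClassifier encoder takes (X, C, O, t) directly
        # and is projected to the language-model width below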

        if self.classifier_kwargs is None:
            # Aux feature encoding
            node_h = self._time_features(t)
            if O is not None:
                # pad one-hot tensor by two to account for special tokens used
                node_h = node_h + pad(O, (0, 2)) @ self.sequence_embedding.weight.to(
                    X.device
                )
            Xe, _, _, Me, _ = self.protein_encoder.to(X.device)(X, C, node_h_aux=node_h)
        else:
            # TODO: is there a better way to deal with sequence padding tokens when batch size > 1?
            if O is not None and O[:, :, -1].any():
                O = None
            Xe0, _, _, Me, _ = self.protein_encoder.to(X.device).encode(X, C, O, t)
            Xe = self.protein_encoder_linear.to(X.device)(Xe0)

        context_embedding, attention_mask_context = self._encode_context(
            Xe, C, Me, chain_id
        )
        if self.standardize_context_embeddings:
            context_embedding = self.context_normalization.to(Xe.device)(
                context_embedding
            )
        elif self.normalize_context_embeddings:
            context_embedding = torch.nn.functional.normalize(context_embedding, dim=-1)

        # Encode Text Input
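        # Budget the caption length so that the structure context tokens plus
        # the tokenized caption fit within the language model's maximum context
        # window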
        if self.direct_gnn:
            max_caption_tokens = (
                self.tokenizer.model_max_length - context_embedding.shape[1]
            )
        else:
            max_caption_tokens = (
                self.tokenizer.model_max_length
                - (self.context_size - 1)
                * self.gnn_embed_ratio
                * self.context_per_chain
                - 1
            )
        Y, attention_mask_caption = self._tokenize(
            caption, add_stop=True, max_length=max_caption_tokens
        )
        Y = Y.to(X.device)
        attention_mask_caption = attention_mask_caption.to(X.device)
        caption_embedding = self._embed_text(Y)

        # Caption
        inputs_embeds = torch.cat([context_embedding, caption_embedding], dim=1)
        attention_mask = torch.cat(
            [attention_mask_context, attention_mask_caption], dim=1
        )
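        # Context positions and caption padding receive label -100 so they are
        # ignored by the language model's cross-entropy loss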
        labels = torch.cat(
            [
                torch.tensor(-100, device=X.device).expand(
                    attention_mask_context.shape
                ),
                Y * attention_mask_caption + (-100) * (1 - attention_mask_caption),
            ],
            dim=1,
        )

        # returns a transformers.modeling_outputs.CausalLMOutputWithCrossAttentions object
        # can get logits with output.logits
        output = self.language_model.to(X.device).forward(
            inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels
        )
        if not by_sample:
            return output
        else:  # below code adapted from transformers/modeling_gpt2.py
            shift_logits = output.logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss(reduction="none")
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            ).reshape(X.shape[0], -1)
            return (loss * (shift_labels != -100).int()).sum(dim=-1) / (
                shift_labels != -100
            ).int().sum(dim=-1)

    @validate_XC(all_atom=False)
    def _noise(
        self, X: torch.Tensor, C: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Takes in a Structure Tensor X and Chain Tensor C, adds chain noise with quasi-uniformly sampled temperature.
        Returns the noised X and the time steps used."""
        assert self.noise_generator is not None, "Model does not have noising!"
        return [x.to(X.device) for x in self.noise_generator.to(X.device)(X, C)]

    # Taken from graph classifier model
    def _time_features(self, t: torch.Tensor) -> torch.Tensor:
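        """Featurize the diffusion time (raw t or scaled log-SNR) with fixed
        Fourier features for use as auxiliary GNN node inputs."""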
        h = {
            "t": lambda: t,
            "log_snr": lambda: self.noise_generator.noise_schedule.log_SNR(t),
        }[self.time_feature_type]()

        if "log" in self.time_feature_type:
            h = self.time_log_feature_scaling * h

        time_h = self.time_features.to(t.device)(h[:, None, None])
        return time_h

    def _encode_context(
        self,
        Xe: torch.Tensor,
        C: torch.Tensor,
        M: torch.Tensor,
        polymer_id: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Pool residue embeddings over chains, accounting for masking.

        Args:
            Xe (torch.Tensor): Embedding tensor of shape `(batch, residue,
                d_model)`.
            C (torch.Tensor): Chain tensor indexing which chain each residue
                belongs to, of shape `(batch, residue)`.
            M (torch.Tensor): Mask tensor of shape `(batch, residue)`.
            polymer_id (torch.Tensor): For each sample, the index in `C` of the
                captioned chain, -1 for the entire structure, or 0 to apply no
                conditioning.

        Returns:
            X_pooled (torch.Tensor): Pooled context embeddings.
            M_pooled (torch.Tensor): Attention mask for the context embeddings.
        """

        Cm = C * M  # Mask Chain Map
        Cm[Cm < 0] = 0  # Remove Negatives from Chain Map

        B, R, Dm = Xe.shape
        pooled_encoding = []
        for x, c, pid in zip(Xe, Cm, polymer_id):
            batch_encoding = []

            # A caption for the entire complex is indicated by prepending the
            # <|pdb|> token
            if pid == -1:
                pdb_embedding = self._embed_text(
                    self._tokenize(["<|pdb|>"], add_stop=False)[0].to(Xe.device)
                ).squeeze(0)
                batch_encoding.append(pdb_embedding)

            if pid == 0:
                pdb_embedding = (
                    self._embed_text(
                        self._tokenize(["<|unconditioned|>"], add_stop=False)[0]
                    )
                    .squeeze(0)
                    .to(Xe.device)
                )
                batch_encoding.append(pdb_embedding)

            # Power Average Pool By Chain
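            # For each chain, the context is a signed power mean of its residue
            # embeddings: for p = 1 .. context_per_chain, take the mean of the
            # p-th powers, recover the magnitude with abs()**(1/p), and restore
            # the sign for odd p. Each pooled vector is reshaped into
            # gnn_embed_ratio context tokens.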
            if pid != 0:
                if self.only_encode_caption_chain and (pid != -1):
                    cid = self._pid_2_cid(pid, c)
                    residue_mask = c == cid
                    n_residues = residue_mask.sum(-1)
                    if self.use_transformer:
                        encodings = [
                            self.transformer.to(Xe.device)(
                                x[residue_mask].unsqueeze(0),
                                torch.zeros(1, self.context_size, self.d_model).to(
                                    Xe.device
                                ),
                            ).squeeze(0)
                        ]
                    elif self.direct_gnn:
                        encodings = x[residue_mask].unsqueeze(0)
                    else:
                        encodings = [
                            (x[residue_mask].pow(p).sum(0).unsqueeze(0) / n_residues)
                            .abs()
                            .pow(1 / p)
                            * (
                                x[residue_mask].pow(p).sum(0).unsqueeze(0).sign()
                                if p % 2 == 1
                                else 1
                            )
                            for p in range(1, self.context_per_chain + 1)
                        ]
                        encodings = [
                            enc.reshape(self.gnn_embed_ratio, -1) for enc in encodings
                        ]
                    batch_encoding.extend(encodings)
                else:
                    if self.use_transformer or self.direct_gnn:
                        residue_mask = (
                            c > 0
                        )  # just use all embeddings, no chain structure
                        if self.use_transformer:
                            # should have pid == -1 to get here, so need encoding of size context_size - 1 because of <|pdb|> token
                            assert self.only_encode_caption_chain, (
                                "only_encode_caption chain = False not supported when"
                                " use_transformer = True!"
                            )
                            batch_encoding.append(
                                self.transformer.to(Xe.device)(
                                    x[residue_mask].unsqueeze(0),
                                    torch.zeros(
                                        1, self.context_size - 1, self.d_model
                                    ).to(Xe.device),
                                ).squeeze(0)
                            )
                        else:  # direct_gnn
                            batch_encoding.extend(x[residue_mask].unsqueeze(0))
                    else:
                        for cid in torch.unique(c):
                            if cid == 0:
                                continue
                            residue_mask = c == cid
                            n_residues = residue_mask.sum(-1)
                            encodings = [
                                (
                                    x[residue_mask].pow(p).sum(0).unsqueeze(0)
                                    / n_residues
                                )
                                .abs()
                                .pow(1 / p)
                                * (
                                    x[residue_mask].pow(p).sum(0).unsqueeze(0).sign()
                                    if p % 2 == 1
                                    else 1
                                )
                                for p in range(1, self.context_per_chain + 1)
                            ]
                            batch_encoding.extend(
                                [
                                    enc.reshape(self.gnn_embed_ratio, -1)
                                    for enc in encodings
                                ]
                            )

                    # Reorder so that the embedding of the captioned chain
                    # comes first
                    if pid != -1:
                        first_cid = self._pid_2_cid(pid, c)
                        try:
                            if first_cid != 0:
                                (
                                    batch_encoding[
                                        (first_cid - 1)
                                        * self.gnn_embed_ratio
                                        * self.context_per_chain : (first_cid)
                                        * self.gnn_embed_ratio
                                        * self.context_per_chain
                                    ],
                                    batch_encoding[
                                        0 : self.gnn_embed_ratio
                                        * self.context_per_chain
                                    ],
                                ) = (
                                    batch_encoding[
                                        0 : self.gnn_embed_ratio
                                        * self.context_per_chain
                                    ],
                                    batch_encoding[
                                        (first_cid - 1)
                                        * self.gnn_embed_ratio
                                        * self.context_per_chain : (first_cid)
                                        * self.gnn_embed_ratio
                                        * self.context_per_chain
                                    ],
                                )
                        except IndexError:
                            print(
                                "Problem: tried to switch encodings at positions 0 and"
                                f" {first_cid}, but failed!"
                            )
                            # raise

            pooled_encoding.append(torch.cat(batch_encoding))

        # Pad with Zero Tensor
        X_pooled = torch.nn.utils.rnn.pad_sequence(pooled_encoding, batch_first=True)

        if self.context_size is not None:
            if (
                X_pooled.shape[1]
                > (self.context_size - 1)
                * self.gnn_embed_ratio
                * self.context_per_chain
                + 1
            ):
                print([x.shape for x in pooled_encoding])
                print(polymer_id)
            assert (
                X_pooled.shape[1]
                <= (self.context_size - 1)
                * self.gnn_embed_ratio
                * self.context_per_chain
                + 1
            ), (
                f"Context is of length {X_pooled.shape[1]}, which is larger than the"
                " allowed number of tokens"
                f" {(self.context_size - 1) * self.gnn_embed_ratio * self.context_per_chain + 1};"
                " this will cause the model to behave poorly!"
            )
            if (
                X_pooled.shape[1]
                < (self.context_size - 1)
                * self.gnn_embed_ratio
                * self.context_per_chain
                + 1
                and not self.direct_gnn
            ):
                pad_shape = (
                    (self.context_size - 1)
                    * self.gnn_embed_ratio
                    * self.context_per_chain
                    + 1
                    - X_pooled.shape[1]
                )
                zero_pad = torch.zeros(
                    [B, pad_shape, int(Dm / self.gnn_embed_ratio)], device=Xe.device
                )
                X_pooled = torch.cat([X_pooled, zero_pad], dim=1)

        M_pooled = (X_pooled != 0)[
            :, :, 0
        ]  # This is a bit dangerous because very rarely X_pooled could contain zeros in masked regions...
        return X_pooled, M_pooled

    def _pid_2_cid(self, pid: int, c: torch.Tensor) -> int:
        """Convert a PDB polymer_entity_id into the corresponding chain_id in
        the XCS format."""
        assert pid in c, f"pid value {pid} must be in the chain map!"
        chain_values = torch.unique(c)
        nonzero_chain_values = chain_values[chain_values != 0]
        cid = (nonzero_chain_values == pid).nonzero(as_tuple=True)[0].item() + 1
        return cid

    def _tokenize(
        self, text: list, add_stop: bool = True, max_length: Optional[int] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Converts list of strings into a padded tensor, returning the tokenized strings as well as the associated masks."""
        if add_stop:
            text = [x + self.tokenizer.eos_token for x in text]
        # Note that there are no stop tokens in truncated sequences
        tokenized_dict = self.tokenizer(
            text,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        )
        return tokenized_dict["input_ids"], tokenized_dict["attention_mask"]

    def _embed_text(self, tokenized_text: torch.Tensor) -> torch.Tensor:
        """Embeds tokenized text."""

        return self.embedder.to(tokenized_text.device)(tokenized_text)


def load_model(
    weight_file: str,
    device: str = "cpu",
    strict: bool = False,
    strict_unexpected: bool = True,
) -> ProteinCaption:
    """Loads a ProCap model.

    Args:
        weight_file (str): Path to the saved model weights.
        device (str): Device on which to load the model.
        strict (bool): Whether to require that the keys match between the
            input file weights and the model created from the parameters stored
            in the model kwargs.
        strict_unexpected (bool): Whether to require that there are no
            unexpected keys when loading model weights, as distinct from the
            strict option which doesn't allow for missing keys either. By
            default, we use this option rather than strict for ease of
            development when adding model features.

    Returns:
        model (ProteinCaption): Instance of `ProteinCaption` with loaded
            weights. For inference the returned model should be set to eval mode
            with `model.eval()`.
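
    Example:
        An illustrative sketch; the checkpoint path below is a hypothetical
        placeholder:

            model = load_model("procap_weights.pt", device="cpu")
            model = model.eval()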
    """
    return utility_load_model(
        weight_file,
        ProteinCaption,
        device=device,
        strict=strict,
        strict_unexpected=strict_unexpected,
    )


def save_model(
    model: ProteinCaption, weight_file: str, metadata: Optional[dict] = None
) -> None:
    """Save model, including optional metadata.

    Args:
        model (ProteinCaption): An instance of `ProteinCaption`.
        weight_file (str): The destination path for saving model weights.
        metadata (dict): A dictionary of additional metadata to add to the model
            weights. For example, when saving models during training it can be
            useful to store `args` representing the CLI args, the date and time
            of training, etc.
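
    Example:
        An illustrative sketch; the path and metadata below are hypothetical
        placeholders:

            save_model(model, "procap_weights.pt", metadata={"epoch": 10})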
    """
    utility_save_model(model, weight_file, metadata=metadata)