import math
import torch
import torch.nn as nn
from esm.layers.blocks import UnifiedTransformerBlock
from esm.utils.structure.affine3d import Affine3D


class TransformerStack(nn.Module):
    """
    A stack of transformer blocks used in the ESM-3 model. Each block is a
    UnifiedTransformerBlock, which can use either geometric attention or
    standard multi-head attention.

    Args:
        d_model (int): The dimensionality of the input and output feature vectors.
        n_heads (int): The number of attention heads.
        v_heads (int | None): The number of voting heads.
        n_layers (int): The number of transformer blocks in the stack.
        n_layers_geom (int, optional): The number of transformer blocks that use
            geometric attention.
        scale_residue (bool, optional): Whether to scale the residue connections
            in each transformer block.
        mask_and_zero_frameless (bool, optional): Whether to mask and zero
            frameless positions in the input. Only applies in the geometric
            attention blocks, which are conditioned on the structure.
    """
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        v_heads: int | None,
        n_layers: int,
        n_layers_geom: int = 1,
        scale_residue: bool = True,
        mask_and_zero_frameless: bool = False,
        bias: bool = False,
        qk_layernorm: bool = True,
        ffn_type: str = "swiglu",  # swiglu | gelu
        expansion_ratio: float = 8 / 3,
    ):
        super().__init__()
        self.blocks = nn.ModuleList(
            [
                UnifiedTransformerBlock(
                    d_model,
                    n_heads,
                    v_heads=v_heads,
                    # The first n_layers_geom blocks use geometric attention;
                    # the remaining blocks use standard multi-head attention.
                    use_geom_attn=i < n_layers_geom,
                    # Scale residual connections by sqrt(n_layers / 36), keeping
                    # residual magnitudes comparable across depths (the factor
                    # is 1 at a depth of 36 layers).
                    residue_scaling_factor=(
                        math.sqrt(n_layers / 36) if scale_residue else 1.0
                    ),
                    expansion_ratio=expansion_ratio,
                    mask_and_zero_frameless=mask_and_zero_frameless,
                    bias=bias,
                    qk_layernorm=qk_layernorm,
                    ffn_type=ffn_type,
                )
                for i in range(n_layers)
            ]
        )
        # Final layer norm applied to the output of the last block.
        self.norm = nn.LayerNorm(d_model, bias=False)
    def forward(
        self,
        x: torch.Tensor,
        sequence_id: torch.Tensor | None = None,
        affine: Affine3D | None = None,
        affine_mask: torch.Tensor | None = None,
        chain_id: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass of the TransformerStack.

        Args:
            x (torch.Tensor): The input tensor of shape (batch_size, sequence_length, d_model).
            sequence_id (torch.Tensor | None): The sequence ID tensor of shape
                (batch_size, sequence_length), or None.
            affine (Affine3D | None): The affine transformation tensor or None.
            affine_mask (torch.Tensor | None): The affine mask tensor or None.
            chain_id (torch.Tensor | None): The protein chain tensor of shape
                (batch_size, sequence_length), or None. Only used in geometric attention.

        Returns:
            post_norm (torch.Tensor): The layer-normalized output tensor of shape
                (batch_size, sequence_length, d_model).
            pre_norm (torch.Tensor): The embedding before the final layer norm,
                of shape (batch_size, sequence_length, d_model).
        """
        *batch_dims, _ = x.shape
        if sequence_id is None:
            # Default: treat every position as belonging to a single sequence.
            sequence_id = torch.ones(
                size=batch_dims, dtype=torch.int64, device=x.device
            )
        if chain_id is None:
            # Default: treat every position as belonging to a single chain.
            chain_id = torch.ones(size=batch_dims, dtype=torch.int64, device=x.device)
        for block in self.blocks:
            x = block(x, sequence_id, affine, affine_mask, chain_id)
        # Return both the normalized output and the raw pre-norm embedding.
        return self.norm(x), x
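

# Minimal usage sketch (illustrative hyperparameters, not ESM-3's published
# configuration). With n_layers_geom=0, every block uses standard multi-head
# attention, so affine/affine_mask can stay None; v_heads is assumed unused
# when no geometric attention blocks are present.
if __name__ == "__main__":
    stack = TransformerStack(
        d_model=256,
        n_heads=8,
        v_heads=None,
        n_layers=4,
        n_layers_geom=0,
    )
    x = torch.randn(2, 16, 256)  # (batch_size, sequence_length, d_model)
    post_norm, pre_norm = stack(x)
    print(post_norm.shape, pre_norm.shape)  # both torch.Size([2, 16, 256])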