Spaces:
Running
Running
import math | |
import torch | |
import torch.nn as nn | |
from esm.layers.blocks import UnifiedTransformerBlock | |
from esm.utils.structure.affine3d import Affine3D | |
class TransformerStack(nn.Module):
    """
    Sequential stack of UnifiedTransformerBlock layers followed by a final LayerNorm,
    as used in the ESM-3 model.

    Blocks whose index is below ``n_layers_geom`` are constructed with
    ``use_geom_attn=True``; every later block is constructed with
    ``use_geom_attn=False``.

    Args:
        d_model (int): Width of the feature vectors consumed and produced by every block;
            also the size of the final LayerNorm.
        n_heads (int): Number of attention heads per block.
        v_heads (int | None): Number of voting heads, forwarded unchanged to every block.
        n_layers (int): Total number of blocks in the stack.
        n_layers_geom (int, optional): How many of the leading blocks are built with
            geometric attention enabled. Defaults to 1.
        scale_residue (bool, optional): If True, every block receives
            ``residue_scaling_factor = sqrt(n_layers / 36)``; otherwise the factor is 1.0.
            Defaults to True.
        mask_and_zero_frameless (bool, optional): Forwarded to every block. Per the
            upstream description it only matters in the geometric-attention blocks,
            which are conditioned on the structure. Defaults to False.
        bias (bool, optional): Forwarded to every block (presumably toggles bias terms in
            the block's projections — confirm against UnifiedTransformerBlock).
            Defaults to False.
        qk_layernorm (bool, optional): Forwarded to every block. Defaults to True.
        ffn_type (str, optional): Feed-forward variant forwarded to every block, either
            ``"swiglu"`` or ``"gelu"``. Defaults to ``"swiglu"``.
        expansion_ratio (float, optional): Forwarded to every block (presumably the
            feed-forward hidden-size multiplier). Defaults to 8/3.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        v_heads: int | None,
        n_layers: int,
        n_layers_geom: int = 1,
        scale_residue: bool = True,
        mask_and_zero_frameless: bool = False,
        bias: bool = False,
        qk_layernorm: bool = True,
        ffn_type: str = "swiglu",  # swiglu | gelu
        expansion_ratio: float = 8 / 3,
    ):
        super().__init__()

        def make_block(layer_idx: int) -> nn.Module:
            # Build the transformer block sitting at position `layer_idx` in the stack.
            geometric = layer_idx < n_layers_geom
            # The factor depends only on n_layers, so every block gets the same value.
            if scale_residue:
                residual_factor = math.sqrt(n_layers / 36)
            else:
                residual_factor = 1.0
            return UnifiedTransformerBlock(
                d_model,
                n_heads,
                v_heads=v_heads,
                use_geom_attn=geometric,
                residue_scaling_factor=residual_factor,
                expansion_ratio=expansion_ratio,
                mask_and_zero_frameless=mask_and_zero_frameless,
                bias=bias,
                qk_layernorm=qk_layernorm,
                ffn_type=ffn_type,
            )

        self.blocks = nn.ModuleList([make_block(idx) for idx in range(n_layers)])
        # Normalization applied to the last block's output; it has no bias term.
        self.norm = nn.LayerNorm(d_model, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        sequence_id: torch.Tensor | None = None,
        affine: Affine3D | None = None,
        affine_mask: torch.Tensor | None = None,
        chain_id: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Run ``x`` through every block in order, then apply the final LayerNorm.

        Args:
            x (torch.Tensor): Input features whose last dimension is ``d_model``,
                typically of shape (batch_size, sequence_length, d_model).
            sequence_id (torch.Tensor | None): Per-position sequence ids shaped like
                ``x.shape[:-1]``. When None, an all-ones int64 tensor on ``x.device``
                is substituted.
            affine (Affine3D | None): Affine frames passed unchanged to every block.
            affine_mask (torch.Tensor | None): Mask for ``affine``, passed unchanged to
                every block.
            chain_id (torch.Tensor | None): Per-position protein chain ids shaped like
                ``x.shape[:-1]``; only used in geometric attention. When None, an
                all-ones int64 tensor on ``x.device`` is substituted.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: ``(post_norm, pre_norm)`` — the last
            block's output after and before ``self.norm``. Both are expected to be of
            shape (batch_size, sequence_length, d_model).
        """
        *leading_dims, _ = x.shape

        def default_ids() -> torch.Tensor:
            # All-ones id tensor covering every leading (non-feature) dimension of x.
            return torch.ones(size=leading_dims, dtype=torch.int64, device=x.device)

        if sequence_id is None:
            sequence_id = default_ids()
        if chain_id is None:
            chain_id = default_ids()

        hidden = x
        for block in self.blocks:
            hidden = block(hidden, sequence_id, affine, affine_mask, chain_id)

        return self.norm(hidden), hidden