# Copyright Generate Biomedicines, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from chroma.layers.norm import MaskedBatchNorm1d


class NoOp(nn.Module):
    """A dummy nn.Module wrapping an identity operation.

    A no-op placeholder module, used to keep the surrounding code structure
    uniform when no transformation is needed.

    Inputs:
        x (any)

    Outputs:
        x (any)
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, **kwargs):
        return x


class Transpose(nn.Module):
    """An nn.Module wrapping ```torch.transpose```.

    Args:
        d1 (int): the first (of two) dimensions to swap
        d2 (int): the second (of two) dimensions to swap

    Inputs:
        x (torch.tensor)

    Outputs:
        y (torch.tensor): ```y = x.transpose(d1, d2)```
    """

    def __init__(self, d1=1, d2=2):
        super().__init__()
        self.d1 = d1
        self.d2 = d2

    def forward(self, x):
        return x.transpose(self.d1, self.d2)


class Unsqueeze(nn.Module):
    """An nn.Module wrapping ```torch.unsqueeze```.

    Args:
        dim (int): the dimension along which to unsqueeze input tensors

    Inputs:
        x (torch.tensor)

    Outputs:
        y (torch.tensor): where ```y = x.unsqueeze(dim)```
    """

    def __init__(self, dim=1):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        return x.unsqueeze(self.dim)


class OneHot(nn.Module):
    """An nn.Module wrapping ```F.one_hot```.

    Args:
        n_tokens (int): the number of tokens comprising input sequences

    Inputs:
        x (torch.LongTensor): of size ```(batch_size, *)```

    Outputs:
        y (torch.LongTensor): of size ```(batch_size, *, n_tokens)``` on the
            same device as the input
    """

    def __init__(self, n_tokens):
        super().__init__()
        self.n_tokens = n_tokens

    def forward(self, x):
        return F.one_hot(x, self.n_tokens)


class MeanEmbedding(nn.Module):
    """A wrapper around ```nn.Embedding``` that allows for one-hot-like
    representation inputs (as well as the standard tokenized representation),
    optionally applying a softmax to the last dimension if the input
    corresponds to a log-PMF.

    Args:
        embedding (nn.Embedding): Embedding to wrap
        use_softmax (bool): Whether to apply a softmax to the last dimension
            if the input is one-hot-like.

    Inputs:
        x (torch.tensor): of size (batch_size, sequence_length)
            (standard tokenized representation) -OR-
            (batch_size, sequence_length, number_tokens) (one-hot representation)

    Outputs:
        y (torch.tensor): of size (batch_size, sequence_length, embedding_dimension),
            obtained via lookup into ```self.embedding.weight``` if the input is in
            standard tokenized form, or by matrix multiplication of the input with
            ```self.embedding.weight``` if the input is one-hot-like. Note that if
            the input is a one-hot matrix the output is the same regardless of
            representation.
    """

    def __init__(self, embedding, use_softmax=True):
        super(MeanEmbedding, self).__init__()
        self.embedding = embedding
        self.use_softmax = use_softmax
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        if len(x.shape) == 2:
            return self.embedding(x)
        elif len(x.shape) == 3:
            if self.use_softmax:
                return self.softmax(x) @ self.embedding.weight
            else:
                return x @ self.embedding.weight
        else:
            raise NotImplementedError
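
# A minimal usage sketch (illustrative only, not part of the original module):
# it shows how MeanEmbedding accepts either tokenized inputs or one-hot-like
# inputs and, per the docstring above, returns the same embeddings for both.
# The vocabulary size, embedding dimension, and tensor shapes are arbitrary.
def _example_mean_embedding():
    n_tokens, d_embed = 20, 8
    embed = MeanEmbedding(nn.Embedding(n_tokens, d_embed), use_softmax=False)
    tokens = torch.randint(0, n_tokens, (2, 5))           # (batch, length)
    onehot = OneHot(n_tokens)(tokens).float()             # (batch, length, n_tokens)
    assert torch.allclose(embed(tokens), embed(onehot))   # identical outputs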


class PeriodicPositionalEncoding(nn.Module):
    """Positional encoding, adapted from 'The Annotated Transformer'
    http://nlp.seas.harvard.edu/2018/04/03/attention.html

    This layer implements the periodic (sinusoidal) positional encoding used
    in Transformer models, generating position codes from sine and cosine
    functions at geometrically spaced frequencies.

    Args:
        d_model (int): input and output dimension for the layer
        max_seq_len (int): maximum allowed sequence length
        dropout (float): Dropout rate

    Inputs:
        x (torch.tensor): of size (batch_size, sequence_length, d_model)

    Outputs:
        y (torch.tensor): of size (batch_size, sequence_length, d_model)
    """

    def __init__(self, d_model, max_seq_len=4000, dropout=0.0):
        super(PeriodicPositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0.0, max_seq_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0.0, d_model, 2) * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)


class PositionWiseFeedForward(nn.Module):
    """Position-wise feed-forward network using 1x1 convolutions, a building
    block of legacy Transformer code (not optimized).

    This layer implements the position-wise feed-forward sublayer of the
    Transformer, realized here with 1x1 convolutions.

    Args:
        d_model (int): input and output dimension for the layer
        d_hidden (int): size of the hidden layer in the position-wise
            feed-forward sublayer
        dropout (float): dropout rate

    Inputs:
        x (torch.tensor): of size (batch_size, sequence_length, d_model)

    Outputs:
        y (torch.tensor): of size (batch_size, sequence_length, d_model)
    """

    def __init__(self, d_model, d_hidden, dropout=0.1):
        super(PositionWiseFeedForward, self).__init__()
        self.activation = nn.ReLU()
        self.linear1 = nn.Conv1d(d_model, d_hidden, 1)
        self.linear2 = nn.Conv1d(d_hidden, d_model, 1)
        self.dropout = nn.Dropout(p=dropout)

    def reset_parameters(self):
        self.linear1.reset_parameters()
        self.linear2.reset_parameters()

    def forward(self, x):
        output = self.activation(self.linear1(x.transpose(1, 2)))
        output = self.linear2(output).transpose(1, 2)
        return self.dropout(output)


class DropNormLin(nn.Module):
    """nn.Module applying a linear layer, normalization, dropout, and an activation.

    The normalization can be layer normalization (```'ln'```), masked batch
    normalization (```'bn'```), or skipped entirely for any other value.

    Args:
        in_features (int): input dimension
        out_features (int): output dimension
        norm_type (str): ```'ln'``` for layer normalization or ```'bn'``` for
            batch normalization; any other value skips normalization
        dropout (float): dropout to apply
        actn (nn.Module): activation function to apply

    Input:
        x (torch.tensor): of size (batch_size, sequence_length, in_features)
        input_mask (torch.tensor): of size (batch_size, 1, sequence_length) (optional)

    Output:
        y (torch.tensor): of size (batch_size, sequence_length, out_features)
    """

    def __init__(
        self, in_features, out_features, norm_type="ln", dropout=0.0, actn=nn.ReLU()
    ):
        super(DropNormLin, self).__init__()
        self.linear = nn.Linear(in_features, out_features)
        if norm_type == "ln":
            self.norm_layer = nn.LayerNorm(out_features)
        elif norm_type == "bn":
            self.norm_layer = MaskedBatchNorm1d(out_features)
        else:
            self.norm_layer = NoOp()
        self.dropout = nn.Dropout(p=dropout)
        self.actn = actn

    def forward(self, x, input_mask=None):
        h = self.linear(x)
        if isinstance(self.norm_layer, MaskedBatchNorm1d):
            h = self.norm_layer(h.transpose(1, 2), input_mask=input_mask).transpose(
                1, 2
            )
        else:
            h = self.norm_layer(h)
        return self.dropout(self.actn(h))
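
# A minimal usage sketch (illustrative only, not part of the original module):
# it chains the sinusoidal positional encoding with the 1x1-convolution
# feed-forward block on a random (batch, length, d_model) tensor. Shapes and
# hyperparameters are arbitrary.
def _example_transformer_blocks():
    d_model, d_hidden = 64, 128
    x = torch.randn(2, 10, d_model)                 # (batch, length, d_model)
    x = PeriodicPositionalEncoding(d_model)(x)      # adds sin/cos position codes
    y = PositionWiseFeedForward(d_model, d_hidden)(x)
    assert y.shape == x.shape                       # (2, 10, d_model)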


class ResidualLinearLayer(nn.Module):
    """A simple residual layer built from a linear layer, a ReLU, and an
    optional layer norm.

    Args:
        d_model (int): Model dimension
        use_norm (bool, *optional*): Optionally use a layer norm. Default `True`.
    """

    def __init__(self, d_model, use_norm=True):
        super(ResidualLinearLayer, self).__init__()
        self.linear = nn.Linear(d_model, d_model)
        self.ReLU = nn.ReLU()
        self.use_norm = use_norm
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        z = self.linear(x)
        z = self.ReLU(z)
        if self.use_norm:
            z = self.norm(z)
        return x + z


class TriangleMultiplication(nn.Module):
    def __init__(self, d_model=512, mode="outgoing"):
        """Triangle multiplication as defined in Jumper et al. (2021).

        This layer takes a 4D pair-representation tensor and combines edges
        around residue triangles via linear projections, sigmoid gating, and a
        batched multiplication implemented with ```torch.einsum```.

        Args:
            d_model (int): dimension of the embedding at each position
            mode (str): Must be 'outgoing' (algorithm 11) or 'incoming'
                (algorithm 12).

        Inputs:
            X (torch.tensor): Pair representations of size
                (batch, nres, nres, channels)
            mask (torch.tensor): of dtype `torch.bool` and size
                (batch, nres, nres, channels) (or broadcastable to this size)

        Outputs:
            Y (torch.tensor): Pair representations of size
                (batch, nres, nres, channels)
        """
        super().__init__()
        self.mode = mode
        assert self.mode in ["outgoing", "incoming"]
        self.equation = (
            "bikc,bjkc->bijc" if self.mode == "outgoing" else "bkjc,bkic->bijc"
        )
        self.layer_norm = nn.LayerNorm(d_model)
        self.left_edge_mlp = nn.Sequential(
            nn.Linear(d_model, d_model), nn.Sigmoid(), nn.Linear(d_model, d_model)
        )
        self.right_edge_mlp = nn.Sequential(
            nn.Linear(d_model, d_model), nn.Sigmoid(), nn.Linear(d_model, d_model)
        )
        self.skip = nn.Sequential(nn.Linear(d_model, d_model), nn.Sigmoid())
        self.combine = nn.Sequential(nn.LayerNorm(d_model), nn.Linear(d_model, d_model))

    def forward(self, X, mask=None):
        h = self.layer_norm(X)
        A = self.left_edge_mlp(h)
        B = self.right_edge_mlp(h)
        G = self.skip(h)
        if mask is not None:
            A = A.masked_fill(~mask, 0.0)
            B = B.masked_fill(~mask, 0.0)
        h = torch.einsum(self.equation, A, B)
        h = self.combine(h) * G
        return h


class NodeProduct(nn.Module):
    """Like Alg. 10 in Jumper et al. (2021), but instead of computing a mean
    over the MSA dimension, processes single-sequence inputs.

    Node embeddings are passed through a layer norm and a pair of linear maps,
    then broadcast pairwise and concatenated to produce edge embeddings.

    Args:
        d_in (int): dimension of node embeddings (inputs)
        d_out (int): dimension of edge embeddings (outputs)

    Inputs:
        node_features (torch.tensor): of size (batch_size, nres, d_in)
        node_mask (torch.tensor): of size (batch_size, nres)
        edge_mask (torch.tensor): of size (batch_size, nres, nres)

    Outputs:
        edge_features (torch.tensor): of size (batch_size, nres, nres, d_out)
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        self.layer_norm = nn.LayerNorm(d_in)
        self.left_lin = nn.Linear(d_in, d_in)
        self.right_lin = nn.Linear(d_in, d_in)
        self.edge_lin = nn.Linear(2 * d_in, d_out)

    def forward(self, node_features, node_mask=None, edge_mask=None):
        _, nres, _ = node_features.size()

        node_features = self.layer_norm(node_features)
        left_embs = self.left_lin(node_features)
        right_embs = self.right_lin(node_features)

        if node_mask is not None:
            mask = node_mask[:, :, None]
            left_embs = left_embs.masked_fill(~mask, 0.0)
            right_embs = right_embs.masked_fill(~mask, 0.0)

        left_embs = left_embs[:, None, :, :].repeat(1, nres, 1, 1)
        right_embs = right_embs[:, :, None, :].repeat(1, 1, nres, 1)
        edge_features = torch.cat([left_embs, right_embs], dim=-1)
        edge_features = self.edge_lin(edge_features)

        if edge_mask is not None:
            mask = edge_mask[:, :, :, None]
            edge_features = edge_features.masked_fill(~mask, 0.0)

        return edge_features
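
# A minimal usage sketch (illustrative only, not part of the original module):
# it builds pair features from node features with NodeProduct, then updates
# them with TriangleMultiplication in 'outgoing' mode. Sizes are arbitrary.
def _example_pair_updates():
    batch, nres, d_node, d_pair = 2, 16, 32, 48
    nodes = torch.randn(batch, nres, d_node)
    pairs = NodeProduct(d_node, d_pair)(nodes)             # (batch, nres, nres, d_pair)
    mask = torch.ones(batch, nres, nres, 1, dtype=torch.bool)
    pairs = TriangleMultiplication(d_pair, mode="outgoing")(pairs, mask=mask)
    assert pairs.shape == (batch, nres, nres, d_pair)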


class FourierFeaturization(nn.Module):
    """Applies Fourier featurization of low-dimensional (usually spatial) input
    data as described in [https://arxiv.org/abs/2006.10739], optionally
    trainable as described in [https://arxiv.org/abs/2106.02795].

    This maps inputs onto random Fourier features; the frequencies can
    optionally be learned.

    Args:
        d_input (int): dimension of inputs
        d_model (int): dimension of outputs
        trainable (bool): whether to learn the frequencies of the Fourier features
        scale (float): if not trainable, controls the scale of the Fourier
            feature periods (see reference for a description; this parameter
            matters and should be tuned!)

    Inputs:
        input (torch.tensor): of size (batch_size, *, d_input)

    Outputs:
        output (torch.tensor): of size (batch_size, *, d_model)
    """

    def __init__(self, d_input, d_model, trainable=False, scale=1.0):
        super().__init__()
        self.scale = scale

        if d_model % 2 != 0:
            raise ValueError(
                "d_model needs to be even for this featurization, try again!"
            )

        B = 2 * math.pi * scale * torch.randn(d_input, d_model // 2)
        self.trainable = trainable
        if not trainable:
            self.register_buffer("B", B)
        else:
            self.register_parameter("B", torch.nn.Parameter(B))

    def forward(self, inputs):
        h = inputs @ self.B
        return torch.cat([h.cos(), h.sin()], -1)


class PositionalEncoding(nn.Module):
    """Axis-aligned positional encodings with log-linear spacing.

    Each input dimension is expanded into sine and cosine features whose
    periods are log-linearly spaced between the given bounds.

    Args:
        d_model (int): dimension of outputs
        d_input (int): dimension of inputs
        period_range (tuple of floats): Minimum and maximum periods for the
            frequency components. Fourier features will be log-linearly spaced
            between these values (inclusive).

    Inputs:
        input (torch.tensor): of size (..., d_input)

    Outputs:
        output (torch.tensor): of size (..., d_model)
    """

    def __init__(self, d_model, d_input=1, period_range=(1.0, 1000.0)):
        super().__init__()
        if d_model % (2 * d_input) != 0:
            raise ValueError(
                "d_model needs to be divisible by 2*d_input for this featurization, "
                f"but got d_model={d_model} and d_input={d_input}"
            )
        num_frequencies = d_model // (2 * d_input)
        log_bounds = np.log10(period_range)
        p = torch.logspace(log_bounds[0], log_bounds[1], num_frequencies, base=10.0)
        w = 2 * math.pi / p
        self.register_buffer("w", w)

    def forward(self, inputs):
        batch_dims = list(inputs.shape)[:-1]
        # (..., 1, num_out) * (..., num_in, 1)
        w = self.w.reshape(len(batch_dims) * [1] + [1, -1])
        h = w * inputs[..., None]
        h = torch.cat([h.cos(), h.sin()], -1).reshape(batch_dims + [-1])
        return h
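
# A minimal usage sketch (illustrative only, not part of the original module):
# it featurizes random 3D coordinates with fixed random Fourier features and
# encodes scalar sequence positions with the log-linearly spaced encoding.
# Dimensions are arbitrary.
def _example_featurizations():
    xyz = torch.randn(2, 10, 3)                                  # (batch, *, d_input)
    feats = FourierFeaturization(d_input=3, d_model=64)(xyz)     # (2, 10, 64)
    positions = torch.arange(10.0).reshape(1, 10, 1)             # (..., d_input=1)
    codes = PositionalEncoding(d_model=32)(positions)            # (1, 10, 32)
    assert feats.shape == (2, 10, 64) and codes.shape == (1, 10, 32)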


class MaybeOnehotEmbedding(nn.Embedding):
    """Wrapper around :class:`torch.nn.Embedding` that supports either
    int-encoded LongTensor inputs or one-hot encoded FloatTensor inputs.

    If the input is floating point (one-hot-like), the embedding is computed by
    matrix multiplication with the embedding weights; otherwise the parent
    class's forward is used for a standard lookup.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.dtype.is_floating_point:  # onehot
            return x @ self.weight
        return super().forward(x)
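
# A minimal usage sketch (illustrative only, not part of the original module):
# the same MaybeOnehotEmbedding instance handles integer tokens and their
# one-hot encodings, yielding identical embeddings. Sizes are arbitrary.
def _example_maybe_onehot_embedding():
    embed = MaybeOnehotEmbedding(20, 8)
    tokens = torch.randint(0, 20, (2, 5))
    onehot = F.one_hot(tokens, 20).float()
    assert torch.allclose(embed(tokens), embed(onehot))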