# Copyright Generate Biomedicines, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from chroma.layers.norm import MaskedBatchNorm1d
class NoOp(nn.Module):
    """A dummy nn.Module wrapping an identity operation.

    A no-op placeholder used to keep module structures uniform.

    Inputs:
        x (any)

    Outputs:
        x (any)
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, **kwargs):
        return x
class Transpose(nn.Module):
    """An nn.Module wrapping ```torch.transpose```.

    Args:
        d1 (int): the first (of two) dimensions to swap
        d2 (int): the second (of two) dimensions to swap

    Inputs:
        x (torch.tensor)

    Outputs:
        y (torch.tensor): ```y = x.transpose(d1, d2)```
    """

    def __init__(self, d1=1, d2=2):
        super().__init__()
        self.d1 = d1
        self.d2 = d2

    def forward(self, x):
        return x.transpose(self.d1, self.d2)
class Unsqueeze(nn.Module):
    """An nn.Module wrapping ```torch.unsqueeze```.

    Args:
        dim (int): the dimension along which to unsqueeze input tensors

    Inputs:
        x (torch.tensor)

    Outputs:
        y (torch.tensor): where ```y = x.unsqueeze(dim)```
    """

    def __init__(self, dim=1):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        return x.unsqueeze(self.dim)
class OneHot(nn.Module):
    """An nn.Module wrapping ```F.one_hot```.

    Args:
        n_tokens (int): the number of tokens comprising input sequences

    Inputs:
        x (torch.LongTensor): of size ```(batch_size, *)```

    Outputs:
        y (torch.LongTensor): of size ```(batch_size, *, n_tokens)```, on the same device as the input
    """

    def __init__(self, n_tokens):
        super().__init__()
        self.n_tokens = n_tokens

    def forward(self, x):
        return F.one_hot(x, self.n_tokens)
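# Usage sketch for the wrapper modules above (Transpose, Unsqueeze, OneHot).
# The batch sizes, lengths, and token counts below are illustrative assumptions,
# not values taken from this library:
#
#   tokens = torch.randint(0, 20, (4, 100))        # (batch, length)
#   onehot = OneHot(n_tokens=20)(tokens)           # (4, 100, 20), dtype torch.long
#   swapped = Transpose(d1=1, d2=2)(onehot)        # (4, 20, 100)
#   expanded = Unsqueeze(dim=1)(tokens)            # (4, 1, 100)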
class MeanEmbedding(nn.Module):
    """A wrapper around ```nn.Embedding``` that accepts one-hot-like inputs as well as
    standard tokenized inputs, optionally applying a softmax to the last dimension when
    the input corresponds to a log-PMF.

    Args:
        embedding (nn.Embedding): Embedding to wrap
        use_softmax (bool): Whether to apply a softmax to the last dimension if the input is one-hot-like.

    Inputs:
        x (torch.tensor): of size (batch_size, sequence_length) (standard tokenized representation)
            -OR- (batch_size, sequence_length, number_tokens) (one-hot-like representation)

    Outputs:
        y (torch.tensor): of size (batch_size, sequence_length, embedding_dimension), obtained via
            lookup into ```self.embedding.weight``` if the input is in standard tokenized form, or by
            matrix multiplication of the input with ```self.embedding.weight``` if the input is
            one-hot-like. Note that if the input is a one-hot matrix the output is the same
            regardless of representation.
    """

    def __init__(self, embedding, use_softmax=True):
        super(MeanEmbedding, self).__init__()
        self.embedding = embedding
        self.use_softmax = use_softmax
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        if len(x.shape) == 2:
            return self.embedding(x)
        elif len(x.shape) == 3:
            if self.use_softmax:
                return self.softmax(x) @ self.embedding.weight
            else:
                return x @ self.embedding.weight
        else:
            raise NotImplementedError
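# Usage sketch for MeanEmbedding (vocabulary size and dimensions are assumptions):
#
#   emb = MeanEmbedding(nn.Embedding(20, 64), use_softmax=True)
#   tokens = torch.randint(0, 20, (4, 100))
#   logits = torch.randn(4, 100, 20)               # e.g. a log-PMF over tokens
#   y_hard = emb(tokens)                           # (4, 100, 64), standard lookup
#   y_soft = emb(logits)                           # (4, 100, 64), softmax(logits) @ embedding.weight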
class PeriodicPositionalEncoding(nn.Module):
    """Positional encoding, adapted from 'The Annotated Transformer'
    http://nlp.seas.harvard.edu/2018/04/03/attention.html

    Implements the periodic (sinusoidal) positional encoding used in Transformer models,
    generating the encodings from sine and cosine functions.

    Args:
        d_model (int): input and output dimension for the layer
        max_seq_len (int): maximum allowed sequence length
        dropout (float): dropout rate

    Inputs:
        x (torch.tensor): of size (batch_size, sequence_length, d_model)

    Outputs:
        y (torch.tensor): of size (batch_size, sequence_length, d_model)
    """

    def __init__(self, d_model, max_seq_len=4000, dropout=0.0):
        super(PeriodicPositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0.0, max_seq_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0.0, d_model, 2) * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)
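# Usage sketch for PeriodicPositionalEncoding (sizes are illustrative assumptions):
#
#   pos_enc = PeriodicPositionalEncoding(d_model=128, max_seq_len=4000, dropout=0.1)
#   x = torch.randn(4, 256, 128)                   # (batch, length, d_model)
#   y = pos_enc(x)                                 # (4, 256, 128), x plus fixed sin/cos encodings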
class PositionWiseFeedForward(nn.Module):
    """Position-wise feed-forward layer built from 1x1 convolutions, a building block of
    legacy Transformer code (not optimized).

    Args:
        d_model (int): input and output dimension for the layer
        d_hidden (int): size of the hidden layer in the position-wise feed-forward sublayer
        dropout (float): dropout rate

    Inputs:
        x (torch.tensor): of size (batch_size, sequence_length, d_model)

    Outputs:
        y (torch.tensor): of size (batch_size, sequence_length, d_model)
    """

    def __init__(self, d_model, d_hidden, dropout=0.1):
        super(PositionWiseFeedForward, self).__init__()
        self.activation = nn.ReLU()
        self.linear1 = nn.Conv1d(d_model, d_hidden, 1)
        self.linear2 = nn.Conv1d(d_hidden, d_model, 1)
        self.dropout = nn.Dropout(p=dropout)

    def reset_parameters(self):
        self.linear1.reset_parameters()
        self.linear2.reset_parameters()

    def forward(self, x):
        output = self.activation(self.linear1(x.transpose(1, 2)))
        output = self.linear2(output).transpose(1, 2)
        return self.dropout(output)
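# Usage sketch for PositionWiseFeedForward (sizes are illustrative assumptions):
#
#   ffn = PositionWiseFeedForward(d_model=128, d_hidden=512, dropout=0.1)
#   x = torch.randn(4, 256, 128)                   # (batch, length, d_model)
#   y = ffn(x)                                     # (4, 256, 128)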
class DropNormLin(nn.Module):
    """nn.Module applying a linear layer, normalization, dropout, and an activation.

    Layer normalization ('ln'), masked batch normalization ('bn'), or no normalization
    can be selected via ```norm_type```.

    Args:
        in_features (int): input dimension
        out_features (int): output dimension
        norm_type (str): ```'ln'``` for layer normalization or ```'bn'``` for batch normalization, else skip normalization
        dropout (float): dropout to apply
        actn (nn.Module): activation function to apply

    Input:
        x (torch.tensor): of size (batch_size, sequence_length, in_features)
        input_mask (torch.tensor): of size (batch_size, 1, sequence_length) (optional)

    Output:
        y (torch.tensor): of size (batch_size, sequence_length, out_features)
    """

    def __init__(
        self, in_features, out_features, norm_type="ln", dropout=0.0, actn=nn.ReLU()
    ):
        super(DropNormLin, self).__init__()
        self.linear = nn.Linear(in_features, out_features)
        if norm_type == "ln":
            self.norm_layer = nn.LayerNorm(out_features)
        elif norm_type == "bn":
            self.norm_layer = MaskedBatchNorm1d(out_features)
        else:
            self.norm_layer = NoOp()
        self.dropout = nn.Dropout(p=dropout)
        self.actn = actn

    def forward(self, x, input_mask=None):
        h = self.linear(x)
        if isinstance(self.norm_layer, MaskedBatchNorm1d):
            h = self.norm_layer(h.transpose(1, 2), input_mask=input_mask).transpose(
                1, 2
            )
        else:
            h = self.norm_layer(h)
        return self.dropout(self.actn(h))
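# Usage sketch for DropNormLin with layer normalization (sizes are illustrative assumptions):
#
#   layer = DropNormLin(in_features=64, out_features=128, norm_type="ln", dropout=0.1)
#   x = torch.randn(4, 256, 64)                    # (batch, length, in_features)
#   y = layer(x)                                   # (4, 256, 128)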
class ResidualLinearLayer(nn.Module):
    """A simple residual layer using a linear layer, a ReLU, and an optional layer norm.

    Args:
        d_model (int): model dimension
        use_norm (bool, *optional*): optionally use a layer norm. Default `True`.
    """

    def __init__(self, d_model, use_norm=True):
        super(ResidualLinearLayer, self).__init__()
        self.linear = nn.Linear(d_model, d_model)
        self.ReLU = nn.ReLU()
        self.use_norm = use_norm
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        z = self.linear(x)
        z = self.ReLU(z)
        if self.use_norm:
            z = self.norm(z)
        return x + z
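# Usage sketch for ResidualLinearLayer (sizes are illustrative assumptions):
#
#   res = ResidualLinearLayer(d_model=128, use_norm=True)
#   x = torch.randn(4, 256, 128)
#   y = res(x)                                     # (4, 256, 128), computed as x + norm(relu(linear(x)))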
class TriangleMultiplication(nn.Module):
    def __init__(self, d_model=512, mode="outgoing"):
        """Triangle multiplication as defined in Jumper et al. (2021).

        Takes a four-dimensional pair representation and computes the output via linear
        transformations, nonlinear activations, and a triangular multiplication implemented
        with ```torch.einsum```.

        Args:
            d_model (int): dimension of the embedding at each position
            mode (str): Must be 'outgoing' (algorithm 11) or 'incoming' (algorithm 12).

        Inputs:
            X (torch.tensor): Pair representations of size (batch, nres, nres, channels)
            mask (torch.tensor): of dtype `torch.bool` and size (batch, nres, nres, channels) (or broadcastable to this size)

        Outputs:
            Y (torch.tensor): Pair representations of size (batch, nres, nres, channels)
        """
        super().__init__()
        self.mode = mode
        assert self.mode in ["outgoing", "incoming"]
        self.equation = (
            "bikc,bjkc->bijc" if self.mode == "outgoing" else "bkjc,bkic->bijc"
        )
        self.layer_norm = nn.LayerNorm(d_model)
        self.left_edge_mlp = nn.Sequential(
            nn.Linear(d_model, d_model), nn.Sigmoid(), nn.Linear(d_model, d_model)
        )
        self.right_edge_mlp = nn.Sequential(
            nn.Linear(d_model, d_model), nn.Sigmoid(), nn.Linear(d_model, d_model)
        )
        self.skip = nn.Sequential(nn.Linear(d_model, d_model), nn.Sigmoid())
        self.combine = nn.Sequential(nn.LayerNorm(d_model), nn.Linear(d_model, d_model))

    def forward(self, X, mask=None):
        h = self.layer_norm(X)
        A = self.left_edge_mlp(h)
        B = self.right_edge_mlp(h)
        G = self.skip(h)
        if mask is not None:
            A = A.masked_fill(~mask, 0.0)
            B = B.masked_fill(~mask, 0.0)
        h = torch.einsum(self.equation, A, B)
        h = self.combine(h) * G
        return h
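# Usage sketch for TriangleMultiplication (batch and residue counts are illustrative assumptions):
#
#   tri = TriangleMultiplication(d_model=64, mode="outgoing")
#   X = torch.randn(2, 50, 50, 64)                 # pair representations
#   mask = torch.ones(2, 50, 50, 1, dtype=torch.bool)  # broadcastable over channels
#   Y = tri(X, mask=mask)                          # (2, 50, 50, 64)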
class NodeProduct(nn.Module):
    """Like Alg. 10 in Jumper et al. (2021), but instead of computing a mean over the MSA
    dimension, processes single-sequence inputs.

    Takes per-residue node features and produces pair (edge) features via layer
    normalization and linear transformations of the left and right nodes.

    Args:
        d_in (int): dimension of node embeddings (inputs)
        d_out (int): dimension of edge embeddings (outputs)

    Inputs:
        node_features (torch.tensor): of size (batch_size, nres, d_in)
        node_mask (torch.tensor): of size (batch_size, nres)
        edge_mask (torch.tensor): of size (batch_size, nres, nres)

    Outputs:
        edge_features (torch.tensor): of size (batch_size, nres, nres, d_out)
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        self.layer_norm = nn.LayerNorm(d_in)
        self.left_lin = nn.Linear(d_in, d_in)
        self.right_lin = nn.Linear(d_in, d_in)
        self.edge_lin = nn.Linear(2 * d_in, d_out)

    def forward(self, node_features, node_mask=None, edge_mask=None):
        _, nres, _ = node_features.size()
        node_features = self.layer_norm(node_features)
        left_embs = self.left_lin(node_features)
        right_embs = self.right_lin(node_features)
        if node_mask is not None:
            mask = node_mask[:, :, None]
            left_embs = left_embs.masked_fill(~mask, 0.0)
            right_embs = right_embs.masked_fill(~mask, 0.0)
        left_embs = left_embs[:, None, :, :].repeat(1, nres, 1, 1)
        right_embs = right_embs[:, :, None, :].repeat(1, 1, nres, 1)
        edge_features = torch.cat([left_embs, right_embs], dim=-1)
        edge_features = self.edge_lin(edge_features)
        if edge_mask is not None:
            mask = edge_mask[:, :, :, None]
            edge_features = edge_features.masked_fill(~mask, 0.0)
        return edge_features
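# Usage sketch for NodeProduct (sizes are illustrative assumptions):
#
#   node_to_edge = NodeProduct(d_in=64, d_out=128)
#   nodes = torch.randn(2, 50, 64)                 # (batch, nres, d_in)
#   node_mask = torch.ones(2, 50, dtype=torch.bool)
#   edge_mask = torch.ones(2, 50, 50, dtype=torch.bool)
#   edges = node_to_edge(nodes, node_mask=node_mask, edge_mask=edge_mask)  # (2, 50, 50, 128)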
class FourierFeaturization(nn.Module):
    """Applies Fourier featurization to low-dimensional (usually spatial) input data as
    described in [https://arxiv.org/abs/2006.10739], optionally with trainable frequencies
    as described in [https://arxiv.org/abs/2106.02795].

    Maps inputs to a frequency-domain representation; the frequencies of the Fourier
    features can optionally be learned.

    Args:
        d_input (int): dimension of inputs
        d_model (int): dimension of outputs
        trainable (bool): whether to learn the frequencies of the Fourier features
        scale (float): if not trainable, controls the scale of the Fourier feature periods
            (see the reference for a description; this parameter matters and should be tuned!)

    Inputs:
        input (torch.tensor): of size (batch_size, *, d_input)

    Outputs:
        output (torch.tensor): of size (batch_size, *, d_model)
    """

    def __init__(self, d_input, d_model, trainable=False, scale=1.0):
        super().__init__()
        self.scale = scale
        if d_model % 2 != 0:
            raise ValueError(
                "d_model needs to be even for this featurization, try again!"
            )
        B = 2 * math.pi * scale * torch.randn(d_input, d_model // 2)
        self.trainable = trainable
        if not trainable:
            self.register_buffer("B", B)
        else:
            self.register_parameter("B", torch.nn.Parameter(B))

    def forward(self, inputs):
        h = inputs @ self.B
        return torch.cat([h.cos(), h.sin()], -1)
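# Usage sketch for FourierFeaturization (dimensions and scale are illustrative assumptions):
#
#   ff = FourierFeaturization(d_input=3, d_model=256, trainable=False, scale=1.0)
#   coords = torch.randn(4, 100, 3)                # e.g. 3D coordinates
#   feats = ff(coords)                             # (4, 100, 256), concatenated cos/sin features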
class PositionalEncoding(nn.Module):
    """Axis-aligned positional encodings with log-linear spacing.

    Encodes positional information as continuous vectors using frequency components that
    are log-linearly spaced over a range of periods.

    Args:
        d_model (int): dimension of outputs
        d_input (int): dimension of inputs
        period_range (tuple of floats): Minimum and maximum periods for the
            frequency components. Fourier features will be log-linearly spaced
            between these values (inclusive).

    Inputs:
        input (torch.tensor): of size (..., d_input)

    Outputs:
        output (torch.tensor): of size (..., d_model)
    """

    def __init__(self, d_model, d_input=1, period_range=(1.0, 1000.0)):
        super().__init__()
        if d_model % (2 * d_input) != 0:
            raise ValueError(
                "d_model needs to be divisible by 2*d_input for this featurization, "
                f"but got {d_model} versus {d_input}"
            )
        num_frequencies = d_model // (2 * d_input)
        log_bounds = np.log10(period_range)
        p = torch.logspace(log_bounds[0], log_bounds[1], num_frequencies, base=10.0)
        w = 2 * math.pi / p
        self.register_buffer("w", w)

    def forward(self, inputs):
        batch_dims = list(inputs.shape)[:-1]
        # Broadcast (..., 1, num_frequencies) against (..., d_input, 1)
        w = self.w.reshape(len(batch_dims) * [1] + [1, -1])
        h = w * inputs[..., None]
        h = torch.cat([h.cos(), h.sin()], -1).reshape(batch_dims + [-1])
        return h
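# Usage sketch for PositionalEncoding (sizes are illustrative assumptions):
#
#   pos = PositionalEncoding(d_model=64, d_input=1, period_range=(1.0, 1000.0))
#   idx = torch.arange(100.0).reshape(1, 100, 1)   # positions as a (batch, length, d_input) tensor
#   feats = pos(idx)                               # (1, 100, 64)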
class MaybeOnehotEmbedding(nn.Embedding):
    """Wrapper around :class:`torch.nn.Embedding` that supports either int-encoded
    LongTensor inputs or one-hot encoded FloatTensor inputs.

    If the input is floating point, the embedding is computed by matrix multiplication
    with the embedding weights; otherwise the parent class's lookup is used.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.dtype.is_floating_point:  # onehot
            return x @ self.weight
        return super().forward(x)
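# Usage sketch for MaybeOnehotEmbedding (vocabulary size and dimensions are illustrative assumptions):
#
#   emb = MaybeOnehotEmbedding(num_embeddings=20, embedding_dim=64)
#   tokens = torch.randint(0, 20, (4, 100))
#   onehot = F.one_hot(tokens, 20).float()
#   assert torch.allclose(emb(tokens), emb(onehot))  # both input forms give the same embeddings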