# Copyright Generate Biomedicines, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from chroma.layers.norm import MaskedBatchNorm1d
class NoOp(nn.Module):
"""A dummy nn.Module wrapping an identity operation.
    A no-op module used to satisfy the code structure.
Inputs:
x (any)
Outputs:
x (any)
"""
def __init__(self):
super().__init__()
def forward(self, x, **kwargs):
return x
class Transpose(nn.Module):
"""An nn.Module wrapping ```torch.transpose```.
Args:
d1 (int): the first (of two) dimensions to swap
d2 (int): the second (of two) dimensions to swap
Inputs:
x (torch.tensor)
Outputs:
y (torch.tensor): ```y = x.transpose(d1,d2)```
"""
def __init__(self, d1=1, d2=2):
super().__init__()
self.d1 = d1
self.d2 = d2
def forward(self, x):
return x.transpose(self.d1, self.d2)
class Unsqueeze(nn.Module):
"""An nn.Module wrapping ```torch.unsqueeze```.
Args:
dim (int): the dimension to unsqueeze input tensors
Inputs:
x (torch.tensor):
Outputs:
y (torch.tensor): where ```y=x.unsqueeze(dim)```
"""
def __init__(self, dim=1):
super().__init__()
self.dim = dim
def forward(self, x):
return x.unsqueeze(self.dim)
class OneHot(nn.Module):
"""An nn.Module that wraps F.one_hot```.
Args:
n_tokens (int): the number of tokens comprising input sequences
Inputs:
x (torch.LongTensor): of size ```(batch_size, *)```
Outputs:
        y (torch.LongTensor): of size (batch_size, *, n_tokens), on the same device as the input
"""
def __init__(self, n_tokens):
super().__init__()
self.n_tokens = n_tokens
def forward(self, x):
return F.one_hot(x, self.n_tokens)
class MeanEmbedding(nn.Module):
"""A wrapper around ```nn.Embedding``` that allows for one-hot-like representation inputs (as well as standard tokenized representation),
optionally applying a softmax to the last dimension if the input corresponds to a log-PMF.
Args:
embedding (nn.Embedding): Embedding to wrap
use_softmax (bool): Whether to apply a softmax to the last dimension if input is one-hot-like.
Inputs:
x (torch.tensor): of size (batch_size, sequence_length) (standard tokenized representation) -OR- (batch_size, sequence_length, number_tokens) (one-hot representation)
Outputs:
        y (torch.tensor): of size (batch_size, sequence_length, embedding_dimension), obtained via lookup into ```self.embedding.weight``` if
            the input is in standard tokenized form, or by matrix multiplication of the input with ```self.embedding.weight``` if the input is one-hot-like.
            Note that if the input is a one-hot matrix, the output is the same regardless of representation.
    This module wraps ```nn.Embedding``` so that it accepts one-hot-like inputs (as well as the standard tokenized representation),
    and can optionally apply a softmax over the last dimension when the input corresponds to a log-PMF.
"""
def __init__(self, embedding, use_softmax=True):
super(MeanEmbedding, self).__init__()
self.embedding = embedding
self.use_softmax = use_softmax
self.softmax = nn.Softmax(dim=-1)
def forward(self, x):
if len(x.shape) == 2:
return self.embedding(x)
elif len(x.shape) == 3:
if self.use_softmax:
return self.softmax(x) @ self.embedding.weight
else:
return x @ self.embedding.weight
else:
            raise NotImplementedError
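# A minimal usage sketch (illustrative helper, not part of the original module): the same
# embedding weights serve both integer-coded tokens and soft one-hot / log-PMF style inputs.
def _example_mean_embedding():
    embedding = nn.Embedding(num_embeddings=20, embedding_dim=64)
    layer = MeanEmbedding(embedding, use_softmax=True)
    tokens = torch.randint(0, 20, (2, 10))  # (batch_size, sequence_length)
    logits = torch.randn(2, 10, 20)  # (batch_size, sequence_length, n_tokens), log-PMF-like
    y_tokens = layer(tokens)  # lookup path -> (2, 10, 64)
    y_soft = layer(logits)  # softmax(x) @ embedding.weight -> (2, 10, 64)
    assert y_tokens.shape == y_soft.shape == (2, 10, 64)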
class PeriodicPositionalEncoding(nn.Module):
"""Positional encoding, adapted from 'The Annotated Transformer'
http://nlp.seas.harvard.edu/2018/04/03/attention.html
    This module implements periodic positional encoding, a core component of the Transformer architecture,
    generating the encodings from sine and cosine functions.
Args:
d_model (int): input and output dimension for the layer
max_seq_len (int): maximum allowed sequence length
dropout (float): Dropout rate
Inputs:
x (torch.tensor): of size (batch_size, sequence_length, d_model)
Outputs:
y (torch.tensor): of size (batch_size, sequence_length, d_model)
"""
def __init__(self, d_model, max_seq_len=4000, dropout=0.0):
super(PeriodicPositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros(max_seq_len, d_model)
position = torch.arange(0.0, max_seq_len).unsqueeze(1)
div_term = torch.exp(
torch.arange(0.0, d_model, 2) * -(math.log(10000.0) / d_model)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer("pe", pe)
def forward(self, x):
x = x + self.pe[:, : x.size(1)]
return self.dropout(x)
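# A minimal usage sketch (illustrative helper, not part of the original module): sinusoidal
# encodings are added onto a (batch, length, d_model) input, for up to max_seq_len positions.
def _example_periodic_positional_encoding():
    pos_enc = PeriodicPositionalEncoding(d_model=128, max_seq_len=512, dropout=0.0)
    x = torch.randn(4, 100, 128)
    y = pos_enc(x)  # x + pe[:, :100], then dropout
    assert y.shape == x.shape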
class PositionWiseFeedForward(nn.Module):
"""Position-wise feed-forward using 1x1 convolutions, a building block of legacy Transformer code (not code optimized).
这个模块实现了位置感知的前馈网络,这也是Transformer模型的一个重要组成部分。
它使用1x1的卷积来实现前馈网络。
Args:
d_model (int): input and output dimension for the layer
        d_hidden (int): size of the hidden layer in the position-wise feed-forward sublayer
        dropout (float): dropout rate
Inputs:
x (torch.tensor): of size (batch_size, sequence_length, d_model)
Outputs:
y (torch.tensor): of size (batch_size, sequence_length, d_model)
"""
def __init__(self, d_model, d_hidden, dropout=0.1):
super(PositionWiseFeedForward, self).__init__()
self.activation = nn.ReLU()
self.linear1 = nn.Conv1d(d_model, d_hidden, 1)
self.linear2 = nn.Conv1d(d_hidden, d_model, 1)
self.dropout = nn.Dropout(p=dropout)
def reset_parameters(self):
self.linear1.reset_parameters()
self.linear2.reset_parameters()
def forward(self, x):
output = self.activation(self.linear1(x.transpose(1, 2)))
output = self.linear2(output).transpose(1, 2)
return self.dropout(output)
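# A minimal usage sketch (illustrative helper, not part of the original module): the 1x1
# convolutions act per position, so the output shape matches the input shape.
def _example_position_wise_feed_forward():
    ffn = PositionWiseFeedForward(d_model=128, d_hidden=512, dropout=0.0)
    x = torch.randn(4, 100, 128)
    y = ffn(x)  # (batch, length, d_model) -> (batch, length, d_model)
    assert y.shape == (4, 100, 128)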
class DropNormLin(nn.Module):
"""nn.Module applying a linear layer, normalization, dropout, and activation
    This module applies a linear layer, normalization, dropout, and an activation; layer normalization ('ln') or batch normalization ('bn') can be selected, or normalization can be skipped.
Args:
in_features (int): input dimension
out_features (int): output dimension
norm_type (str): ```'ln'``` for layer normalization or ```'bn'``` for batch normalization else skip normalization
dropout (float): dropout to apply
actn (nn.Module): activation function to apply
Input:
x (torch.tensor): of size (batch_size, sequence_length, in_features)
input_mask (torch.tensor): of size (batch_size, 1, sequence_length) (optional)
Output:
y (torch.tensor): of size (batch_size, sequence_length, out_features)
"""
def __init__(
self, in_features, out_features, norm_type="ln", dropout=0.0, actn=nn.ReLU()
):
super(DropNormLin, self).__init__()
self.linear = nn.Linear(in_features, out_features)
if norm_type == "ln":
self.norm_layer = nn.LayerNorm(out_features)
elif norm_type == "bn":
self.norm_layer = MaskedBatchNorm1d(out_features)
else:
self.norm_layer = NoOp()
self.dropout = nn.Dropout(p=dropout)
self.actn = actn
def forward(self, x, input_mask=None):
h = self.linear(x)
if isinstance(self.norm_layer, MaskedBatchNorm1d):
h = self.norm_layer(h.transpose(1, 2), input_mask=input_mask).transpose(
1, 2
)
else:
h = self.norm_layer(h)
return self.dropout(self.actn(h))
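# A minimal usage sketch (illustrative helper, not part of the original module): with the
# default norm_type="ln" no mask is needed; with norm_type="bn" an optional
# (batch, 1, length) mask is forwarded to MaskedBatchNorm1d.
def _example_drop_norm_lin():
    layer = DropNormLin(in_features=64, out_features=128, norm_type="ln", dropout=0.1)
    x = torch.randn(4, 100, 64)
    y = layer(x)  # linear -> layer norm -> ReLU -> dropout
    assert y.shape == (4, 100, 128)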
class ResidualLinearLayer(nn.Module):
"""A Simple Residual Layer using a linear layer a relu and an optional layer norm.
这个模块实现了一个简单的残差层,使用了一个线性层、ReLU激活函数和一个可选的层归一化。
Args:
d_model (int): Model Dimension
use_norm (bool, *optional*): Optionally Use a Layer Norm. Default `True`.
"""
def __init__(self, d_model, use_norm=True):
super(ResidualLinearLayer, self).__init__()
self.linear = nn.Linear(d_model, d_model)
self.ReLU = nn.ReLU()
self.use_norm = use_norm
self.norm = nn.LayerNorm(d_model)
def forward(self, x):
z = self.linear(x)
z = self.ReLU(z)
if self.use_norm:
z = self.norm(z)
return x + z
class TriangleMultiplication(nn.Module):
def __init__(self, d_model=512, mode="outgoing"):
"""
Triangle multiplication as defined in Jumper et al. (2021)
        This module implements the triangle multiplication defined by Jumper et al. (2021). It takes a four-dimensional tensor as input
        and computes the output through a series of linear transformations, non-linear activations, and a multiplication step implemented with torch.einsum.
Args:
d_model (int): dimension of the embedding at each position
mode (str): Must be 'outgoing' (algorithm 11) or 'incoming' (algorithm 12).
Inputs:
X (torch.tensor): Pair representations of size (batch, nres, nres, channels)
mask (torch.tensor): of dtype `torch.bool` and size (batch, nres, nres, channels) (or broadcastable to this size)
Outputs:
Y (torch.tensor): Pair representations of size (batch, nres, nres, channels)
"""
super().__init__()
self.mode = mode
assert self.mode in ["outgoing", "incoming"]
self.equation = (
"bikc,bjkc->bijc" if self.mode == "outgoing" else "bkjc,bkic->bijc"
)
self.layer_norm = nn.LayerNorm(d_model)
self.left_edge_mlp = nn.Sequential(
nn.Linear(d_model, d_model), nn.Sigmoid(), nn.Linear(d_model, d_model)
)
self.right_edge_mlp = nn.Sequential(
nn.Linear(d_model, d_model), nn.Sigmoid(), nn.Linear(d_model, d_model)
)
self.skip = nn.Sequential(nn.Linear(d_model, d_model), nn.Sigmoid())
self.combine = nn.Sequential(nn.LayerNorm(d_model), nn.Linear(d_model, d_model))
def forward(self, X, mask=None):
h = self.layer_norm(X)
A = self.left_edge_mlp(h)
B = self.right_edge_mlp(h)
G = self.skip(h)
if mask is not None:
A = A.masked_fill(~mask, 0.0)
B = B.masked_fill(~mask, 0.0)
h = torch.einsum(self.equation, A, B)
h = self.combine(h) * G
return h
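# A minimal usage sketch (illustrative helper, not part of the original module): pair
# features of shape (batch, nres, nres, d_model) map to the same shape; the boolean mask
# zeroes edges before the einsum contraction.
def _example_triangle_multiplication():
    layer = TriangleMultiplication(d_model=32, mode="outgoing")
    X = torch.randn(2, 16, 16, 32)
    mask = torch.ones(2, 16, 16, 1, dtype=torch.bool)  # broadcastable over channels
    Y = layer(X, mask=mask)
    assert Y.shape == X.shape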
class NodeProduct(nn.Module):
"""Like Alg. 10 in Jumper et al. (2021) but instead of computing a mean over MSA dimension,
process for single-sequence inputs.
这个模块实现了Jumper等人在2021年的论文中描述的节点乘积算法。
它接受一个二维的张量作为输入,然后通过一系列的线性变换和层归一化操作,来计算输出。
Args:
d_in (int): dimension of node embeddings (inputs)
d_out (int): dimension of edge embeddings (outputs)
Inputs:
node_features (torch.tensor): of size (batch_size, nres, d_model)
node_mask (torch.tensor): of size (batch_size, nres)
edge_mask (torch.tensor): of size (batch_size, nres, nres)
Outputs:
edge_features (torch.tensor): of size (batch_size, nres, nres, d_model)
"""
def __init__(self, d_in, d_out):
super().__init__()
self.layer_norm = nn.LayerNorm(d_in)
self.left_lin = nn.Linear(d_in, d_in)
self.right_lin = nn.Linear(d_in, d_in)
self.edge_lin = nn.Linear(2 * d_in, d_out)
def forward(self, node_features, node_mask=None, edge_mask=None):
_, nres, _ = node_features.size()
node_features = self.layer_norm(node_features)
left_embs = self.left_lin(node_features)
right_embs = self.right_lin(node_features)
if node_mask is not None:
mask = node_mask[:, :, None]
left_embs = left_embs.masked_fill(~mask, 0.0)
right_embs = right_embs.masked_fill(~mask, 0.0)
left_embs = left_embs[:, None, :, :].repeat(1, nres, 1, 1)
right_embs = right_embs[:, :, None, :].repeat(1, 1, nres, 1)
edge_features = torch.cat([left_embs, right_embs], dim=-1)
edge_features = self.edge_lin(edge_features)
if edge_mask is not None:
mask = edge_mask[:, :, :, None]
edge_features = edge_features.masked_fill(~mask, 0.0)
return edge_features
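# A minimal usage sketch (illustrative helper, not part of the original module): node
# features are broadcast pairwise and projected to edge features of shape
# (batch, nres, nres, d_out).
def _example_node_product():
    layer = NodeProduct(d_in=64, d_out=128)
    node_features = torch.randn(2, 16, 64)
    node_mask = torch.ones(2, 16, dtype=torch.bool)
    edge_mask = torch.ones(2, 16, 16, dtype=torch.bool)
    edge_features = layer(node_features, node_mask=node_mask, edge_mask=edge_mask)
    assert edge_features.shape == (2, 16, 16, 128)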
class FourierFeaturization(nn.Module):
"""Applies fourier featurization of low-dimensional (usually spatial) input data as described in [https://arxiv.org/abs/2006.10739] ,
optionally trainable as described in [https://arxiv.org/abs/2106.02795].
这个模块实现了对低维输入数据的傅里叶特征化,这是一种将输入数据转换为频域表示的方法。
这个模块可以选择是否学习傅里叶特征的频率
Args:
d_input (int): dimension of inputs
d_model (int): dimension of outputs
trainable (bool): whether to learn the frequency of fourier features
        scale (float): if not trainable, controls the scale of the Fourier feature periods (see the reference for details; this parameter matters and should be tuned!)
Inputs:
input (torch.tensor): of size (batch_size, *, d_input)
Outputs:
        output (torch.tensor): of size (batch_size, *, d_model)
"""
def __init__(self, d_input, d_model, trainable=False, scale=1.0):
super().__init__()
self.scale = scale
if d_model % 2 != 0:
raise ValueError(
"d_model needs to be even for this featurization, try again!"
)
B = 2 * math.pi * scale * torch.randn(d_input, d_model // 2)
self.trainable = trainable
if not trainable:
self.register_buffer("B", B)
else:
self.register_parameter("B", torch.nn.Parameter(B))
def forward(self, inputs):
h = inputs @ self.B
return torch.cat([h.cos(), h.sin()], -1)
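# A minimal usage sketch (illustrative helper, not part of the original module): 3-D
# coordinates are mapped to d_model Fourier features (cosines and sines of random
# projections); the scale hyperparameter controls the frequency spectrum.
def _example_fourier_featurization():
    layer = FourierFeaturization(d_input=3, d_model=64, trainable=False, scale=1.0)
    coords = torch.randn(2, 100, 3)
    feats = layer(coords)  # cat([cos(x @ B), sin(x @ B)], dim=-1)
    assert feats.shape == (2, 100, 64)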
class PositionalEncoding(nn.Module):
"""Axis-aligned positional encodings with log-linear spacing.
    This module encodes positional information as continuous vectors,
    using frequency components that are log-linearly spaced.
Args:
d_input (int): dimension of inputs
d_model (int): dimension of outputs
period_range (tuple of floats): Min and maximum periods for the
frequency components. Fourier features will be log-linearly spaced
between these values (inclusive).
Inputs:
input (torch.tensor): of size (..., d_input)
Outputs:
output (torch.tensor): of size (..., d_model)
"""
def __init__(self, d_model, d_input=1, period_range=(1.0, 1000.0)):
super().__init__()
if d_model % (2 * d_input) != 0:
raise ValueError(
"d_model needs to be divisible by 2*d_input for this featurization, "
f"but got {d_model} versus {d_input}"
)
num_frequencies = d_model // (2 * d_input)
log_bounds = np.log10(period_range)
p = torch.logspace(log_bounds[0], log_bounds[1], num_frequencies, base=10.0)
w = 2 * math.pi / p
self.register_buffer("w", w)
def forward(self, inputs):
batch_dims = list(inputs.shape)[:-1]
# (..., 1, num_out) * (..., num_in, 1)
w = self.w.reshape(len(batch_dims) * [1] + [1, -1])
h = w * inputs[..., None]
h = torch.cat([h.cos(), h.sin()], -1).reshape(batch_dims + [-1])
return h
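# A minimal usage sketch (illustrative helper, not part of the original module): scalar
# positions (d_input=1) are expanded into d_model sin/cos features with log-linearly
# spaced periods between 1 and 1000.
def _example_positional_encoding():
    layer = PositionalEncoding(d_model=64, d_input=1, period_range=(1.0, 1000.0))
    positions = torch.arange(100.0).reshape(1, 100, 1)  # (batch, length, d_input)
    feats = layer(positions)
    assert feats.shape == (1, 100, 64)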
class MaybeOnehotEmbedding(nn.Embedding):
"""Wrapper around :class:`torch.nn.Embedding` to support either int-encoded
LongTensors or one-hot encoded FloatTensors.
    This module wraps torch.nn.Embedding to accept either integer-encoded LongTensor inputs or one-hot encoded FloatTensor inputs.
    If the input is a floating-point tensor, the embedding is computed by matrix multiplication; otherwise the parent class's forward is used.
"""
def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.dtype.is_floating_point: # onehot
return x @ self.weight
return super().forward(x)
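# A minimal usage sketch (illustrative helper, not part of the original module): the same
# layer accepts integer token indices or float one-hot (or soft) distributions, and the
# two paths agree exactly for hard one-hot inputs.
def _example_maybe_onehot_embedding():
    layer = MaybeOnehotEmbedding(num_embeddings=20, embedding_dim=64)
    tokens = torch.randint(0, 20, (2, 10))
    onehot = F.one_hot(tokens, 20).float()
    assert torch.allclose(layer(tokens), layer(onehot))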