Simon Duerr
add fast af
85bd48b
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Modules and code used in the core part of AlphaFold.
The structure generation code is in 'folding.py'.
"""
import functools
from alphafold.common import residue_constants
from alphafold.model import all_atom
from alphafold.model import common_modules
from alphafold.model import folding
from alphafold.model import layer_stack
from alphafold.model import lddt
from alphafold.model import mapping
from alphafold.model import prng
from alphafold.model import quat_affine
from alphafold.model import utils
import haiku as hk
import jax
import jax.numpy as jnp
from alphafold.model.r3 import Rigids, Rots, Vecs
def softmax_cross_entropy(logits, labels):
  """Softmax cross entropy between logits and one-hot class labels.

  Args:
    logits: Unnormalized log-probabilities; the class axis is the last one.
    labels: Class labels (one-hot along the last axis), broadcastable against
      `logits`.

  Returns:
    An array with the last axis reduced: the negated sum over classes of
    `labels * log_softmax(logits)`.
  """
  log_probs = jax.nn.log_softmax(logits)
  weighted_log_probs = labels * log_probs
  return jnp.asarray(-jnp.sum(weighted_log_probs, axis=-1))
def sigmoid_cross_entropy(logits, labels):
  """Element-wise sigmoid cross entropy for (multi-)label targets.

  Args:
    logits: Unnormalized scores, one per label.
    labels: Targets with the same (or a broadcastable) shape as `logits`.

  Returns:
    An array of per-element losses; no reduction is applied.
  """
  log_prob_on = jax.nn.log_sigmoid(logits)
  # log(1 - sigmoid(x)) equals log_sigmoid(-x); the latter form is the more
  # numerically stable way to evaluate it.
  log_prob_off = jax.nn.log_sigmoid(-logits)
  return jnp.asarray(-labels * log_prob_on - (1. - labels) * log_prob_off)
def apply_dropout(*, tensor, safe_key, rate, is_training, broadcast_dim=None):
  """Applies (inverted) dropout to a tensor.

  Args:
    tensor: Array to apply dropout to.
    safe_key: Object whose `get()` method returns the PRNG key used to sample
      the keep mask.
    rate: Drop probability; entries are kept with probability `1 - rate` and
      the kept entries are divided by `1 - rate`.
    is_training: When falsy, `tensor` is returned untouched.
    broadcast_dim: Optional axis along which the mask has size 1, so that one
      sampled mask is shared (broadcast) across that axis.

  Returns:
    The masked and rescaled tensor in training mode, otherwise `tensor`.
  """
  if not is_training:
    return tensor

  # NOTE: there is intentionally no `rate == 0` shortcut here; the original
  # check was commented out in the source.
  mask_shape = list(tensor.shape)
  if broadcast_dim is not None:
    mask_shape[broadcast_dim] = 1

  keep_prob = 1.0 - rate
  keep_mask = jax.random.bernoulli(safe_key.get(), keep_prob, shape=mask_shape)
  return keep_mask * tensor / keep_prob
def dropout_wrapper(module,
                    input_act,
                    mask,
                    safe_key,
                    global_config,
                    output_act=None,
                    is_training=True,
                    scale_rate=1.0,
                    **kwargs):
  """Runs `module`, applies dropout to its output and adds it as a residual.

  Args:
    module: Callable invoked as `module(input_act, mask, is_training=...,
      **kwargs)`. Its `config` must expose `dropout_rate`, `shared_dropout`
      and, when dropout is shared, `orientation`.
    input_act: Activations passed to the module.
    mask: Mask passed to the module.
    safe_key: Random key holder forwarded to `apply_dropout`.
    global_config: Global config; when `deterministic` is set the dropout rate
      is forced to 0.
    output_act: Activations the residual is added to; defaults to `input_act`.
    is_training: Forwarded to the module and to `apply_dropout`.
    scale_rate: Multiplier applied to the dropout rate.
    **kwargs: Extra keyword arguments forwarded to the module.

  Returns:
    `output_act` plus the (dropped-out) module output.
  """
  if output_act is None:
    output_act = input_act

  update = module(input_act, mask, is_training=is_training, **kwargs)

  if global_config.deterministic:
    rate = 0.0
  else:
    rate = module.config.dropout_rate

  # With shared dropout a single mask is broadcast along one axis, chosen by
  # the module orientation: axis 0 for 'per_row', axis 1 for anything else.
  if not module.config.shared_dropout:
    shared_axis = None
  elif module.config.orientation == 'per_row':
    shared_axis = 0
  else:
    shared_axis = 1

  update = apply_dropout(tensor=update,
                         safe_key=safe_key,
                         rate=rate * scale_rate,
                         is_training=is_training,
                         broadcast_dim=shared_axis)
  return output_act + update
def create_extra_msa_feature(batch):
  """Builds the extra-MSA feature tensor from the raw extra-MSA entries.

  The one-hot expansion is done here, as late as possible, because the one-hot
  extra MSA can be very large.

  Arguments:
    batch: a dictionary with the following keys:
     * 'extra_msa': [N_extra_seq, N_res] MSA that wasn't selected as a cluster
       centre. Note, that this is not one-hot encoded.
     * 'extra_has_deletion': [N_extra_seq, N_res] Whether there is a deletion to
       the left of each position in the extra MSA.
     * 'extra_deletion_value': [N_extra_seq, N_res] The number of deletions to
       the left of each position in the extra MSA.

  Returns:
    Concatenation along the last axis of the 23-class one-hot MSA and the two
    deletion features (each given a trailing axis of size 1).
  """
  # 23 = 20 amino acids + 'X' for unknown + gap + bert mask
  features = [jax.nn.one_hot(batch['extra_msa'], 23)]
  for key in ('extra_has_deletion', 'extra_deletion_value'):
    features.append(jnp.expand_dims(batch[key], axis=-1))
  return jnp.concatenate(features, axis=-1)
class AlphaFoldIteration(hk.Module):
  """A single recycling iteration of AlphaFold architecture.

  Computes ensembled (averaged) representations from the provided features.
  These representations are then passed to the various heads
  that have been requested by the configuration file. Each head also returns a
  loss which is combined as a weighted sum to produce the total loss.

  Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 3-22
  """

  def __init__(self, config, global_config, name='alphafold_iteration'):
    """Stores the model config and the global config on the module."""
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self,
               ensembled_batch,
               non_ensembled_batch,
               is_training,
               compute_loss=False,
               ensemble_representations=False,
               return_representations=False):
    """Runs the Evoformer and every head with a non-zero weight.

    Arguments:
      ensembled_batch: Dictionary of features whose values carry a leading
        ensemble axis (each value is indexed with `v[i]`). Must contain
        'seq_length' and 'aatype'.
      non_ensembled_batch: Dictionary merged into every sliced batch.
      is_training: Forwarded to the Evoformer and to the heads.
      compute_loss: When True, each executed head's loss is computed and the
        weighted sum is returned alongside the outputs.
      ensemble_representations: When False, the ensemble axis is asserted to
        have size 1. When True, all representations except 'msa' are averaged
        over the ensemble members.
      return_representations: Not referenced anywhere in this method body.

    Returns:
      `ret` when compute_loss is False, otherwise the tuple
      `(ret, total_loss)`. `ret` maps 'representations' and each executed
      head name to that head's output dictionary.
    """
    num_ensemble = jnp.asarray(ensembled_batch['seq_length'].shape[0])
    if not ensemble_representations:
      assert ensembled_batch['seq_length'].shape[0] == 1
    def slice_batch(i):
      """Selects ensemble member `i` and merges in the non-ensembled features."""
      b = {k: v[i] for k, v in ensembled_batch.items()}
      b.update(non_ensembled_batch)
      return b
    # Compute representations for each batch element and average.
    evoformer_module = EmbeddingsAndEvoformer(
        self.config.embeddings_and_evoformer, self.global_config)
    batch0 = slice_batch(0)
    representations = evoformer_module(batch0, is_training)
    # MSA representations are not ensembled so
    # we don't pass tensor into the loop.
    msa_representation = representations['msa']
    del representations['msa']
    # Average the representations (except MSA) over the batch dimension.
    if ensemble_representations:
      def body(x):
        """Add one element to the representations ensemble."""
        i, current_representations = x
        feats = slice_batch(i)
        representations_update = evoformer_module(
            feats, is_training)
        new_representations = {}
        for k in current_representations:
          new_representations[k] = (
              current_representations[k] + representations_update[k])
        return i+1, new_representations
      if hk.running_init():
        # When initializing the Haiku module, run one iteration of the
        # while_loop to initialize the Haiku modules used in `body`.
        _, representations = body((1, representations))
      else:
        # The loop starts at index 1 because member 0 was already evaluated
        # above (batch0).
        _, representations = hk.while_loop(
            lambda x: x[0] < num_ensemble,
            body,
            (1, representations))
      # Turn the accumulated sums into means. 'msa' was removed from the
      # dictionary above, so the `k != 'msa'` guard never skips anything here.
      for k in representations:
        if k != 'msa':
          representations[k] /= num_ensemble.astype(representations[k].dtype)
    representations['msa'] = msa_representation
    batch = batch0  # We are not ensembled from here on.
    # 'aatype' is unpacked as rank 2 for integer dtypes and as rank 3
    # otherwise (presumably a one-hot / soft encoding - TODO confirm).
    # NOTE(review): `num_residues` is not used in the rest of this method.
    if jnp.issubdtype(ensembled_batch['aatype'].dtype, jnp.integer):
      _, num_residues = ensembled_batch['aatype'].shape
    else:
      _, num_residues, _ = ensembled_batch['aatype'].shape
    # `config.use_struct` switches between the real structure module and
    # `folding.dummy`.
    if self.config.use_struct:
      struct_module = folding.StructureModule
    else:
      struct_module = folding.dummy
    heads = {}
    for head_name, head_config in sorted(self.config.heads.items()):
      if not head_config.weight:
        continue  # Do not instantiate zero-weight heads.
      head_factory = {
          'masked_msa': MaskedMsaHead,
          'distogram': DistogramHead,
          'structure_module': functools.partial(struct_module, compute_loss=compute_loss),
          'predicted_lddt': PredictedLDDTHead,
          'predicted_aligned_error': PredictedAlignedErrorHead,
          'experimentally_resolved': ExperimentallyResolvedHead,
      }[head_name]
      heads[head_name] = (head_config,
                          head_factory(head_config, self.global_config))
    total_loss = 0.
    ret = {}
    ret['representations'] = representations
    def loss(module, head_config, ret, name, filter_ret=True):
      """Computes a head's loss, stores it in `ret[name]`, returns it weighted.

      With filter_ret=True the head only sees its own output `ret[name]`;
      otherwise it receives the whole `ret` dictionary.
      """
      if filter_ret:
        value = ret[name]
      else:
        value = ret
      loss_output = module.loss(value, batch)
      ret[name].update(loss_output)
      loss = head_config.weight * ret[name]['loss']
      return loss
    for name, (head_config, module) in heads.items():
      # Skip PredictedLDDTHead and PredictedAlignedErrorHead until
      # StructureModule is executed.
      if name in ('predicted_lddt', 'predicted_aligned_error'):
        continue
      else:
        ret[name] = module(representations, batch, is_training)
        if 'representations' in ret[name]:
          # Extra representations from the head. Used by the structure module
          # to provide activations for the PredictedLDDTHead.
          representations.update(ret[name].pop('representations'))
      if compute_loss:
        total_loss += loss(module, head_config, ret, name)
    # The two confidence heads are only run when the structure module is
    # enabled.
    if self.config.use_struct:
      # The dotted key 'predicted_lddt.weight' relies on the config object's
      # `get` resolving nested keys - NOTE(review): confirm for the config
      # type in use.
      if self.config.heads.get('predicted_lddt.weight', 0.0):
        # Add PredictedLDDTHead after StructureModule executes.
        name = 'predicted_lddt'
        # Feed all previous results to give access to structure_module result.
        head_config, module = heads[name]
        ret[name] = module(representations, batch, is_training)
        if compute_loss:
          total_loss += loss(module, head_config, ret, name, filter_ret=False)
      if ('predicted_aligned_error' in self.config.heads
          and self.config.heads.get('predicted_aligned_error.weight', 0.0)):
        # Add PredictedAlignedErrorHead after StructureModule executes.
        name = 'predicted_aligned_error'
        # Feed all previous results to give access to structure_module result.
        head_config, module = heads[name]
        ret[name] = module(representations, batch, is_training)
        if compute_loss:
          total_loss += loss(module, head_config, ret, name, filter_ret=False)
    if compute_loss:
      return ret, total_loss
    else:
      return ret
class AlphaFold(hk.Module):
  """AlphaFold model with recycling.

  Jumper et al. (2021) Suppl. Alg. 2 "Inference"
  """

  def __init__(self, config, name='alphafold'):
    """Stores the model config; the global config is read from it."""
    super().__init__(name=name)
    self.config = config
    self.global_config = config.global_config

  def __call__(
      self,
      batch,
      is_training,
      compute_loss=False,
      ensemble_representations=False,
      return_representations=False):
    """Run the AlphaFold model.

    Arguments:
      batch: Dictionary with inputs to the AlphaFold model. A 'scale_rate'
        entry is added in place when missing. Optional 'init_pos',
        'init_msa_first_row', 'init_pair' and 'init_dgram' entries seed the
        corresponding recycling inputs.
      is_training: Whether the system is in training or inference mode.
      compute_loss: Whether to compute losses (requires extra features
        to be present in the batch and knowing the true structure).
      ensemble_representations: Whether to use ensembling of representations.
      return_representations: Whether to also return the intermediate
        representations.

    Returns:
      When compute_loss is True:
        a tuple `(outputs, [total_loss])` where `outputs` is the output of
        AlphaFoldIteration.
      When compute_loss is False:
        just output of AlphaFoldIteration.

      The output of AlphaFoldIteration is a nested dictionary containing
      predictions from the various heads. When `config.add_prev` is set and
      at least one recycling iteration ran, the distogram / pLDDT / PAE logits
      in it are replaced by their mean over all iterations.
    """
    if "scale_rate" not in batch:
      batch["scale_rate"] = jnp.ones((1,))
    impl = AlphaFoldIteration(self.config, self.global_config)
    # 'aatype' is unpacked as rank 2 for integer dtypes and rank 3 otherwise.
    if jnp.issubdtype(batch['aatype'].dtype, jnp.integer):
      batch_size, num_residues = batch['aatype'].shape
    else:
      batch_size, num_residues, _ = batch['aatype'].shape

    def get_prev(ret):
      """Extracts the recycling inputs for the next iteration from `ret`."""
      new_prev = {
          'prev_msa_first_row': ret['representations']['msa_first_row'],
          'prev_pair': ret['representations']['pair'],
          'prev_dgram': ret["distogram"]["logits"],
      }
      if self.config.use_struct:
        new_prev.update({'prev_pos': ret['structure_module']['final_atom_positions'],
                         'prev_plddt': ret["predicted_lddt"]["logits"]})
      if "predicted_aligned_error" in ret:
        new_prev["prev_pae"] = ret["predicted_aligned_error"]["logits"]
      if not self.config.backprop_recycle:
        # Block gradients through the recycled positions and representations.
        for k in ["prev_pos", "prev_msa_first_row", "prev_pair"]:
          if k in new_prev:
            new_prev[k] = jax.lax.stop_gradient(new_prev[k])
      return new_prev

    def do_call(prev,
                recycle_idx,
                compute_loss=compute_loss):
      """Runs one AlphaFoldIteration with `prev` as the recycled inputs."""
      if self.config.resample_msa_in_recycling:
        # The leading batch axis holds (num_recycle + 1) groups of ensemble
        # members; pick the group belonging to this recycling iteration.
        num_ensemble = batch_size // (self.config.num_recycle + 1)
        def slice_recycle_idx(x):
          start = recycle_idx * num_ensemble
          size = num_ensemble
          return jax.lax.dynamic_slice_in_dim(x, start, size, axis=0)
        # jax.tree_util.tree_map is used instead of the top-level
        # jax.tree_map alias, which newer JAX releases no longer provide.
        ensembled_batch = jax.tree_util.tree_map(slice_recycle_idx, batch)
      else:
        num_ensemble = batch_size
        ensembled_batch = batch
      non_ensembled_batch = jax.tree_util.tree_map(lambda x: x, prev)
      return impl(ensembled_batch=ensembled_batch,
                  non_ensembled_batch=non_ensembled_batch,
                  is_training=is_training,
                  compute_loss=compute_loss,
                  ensemble_representations=ensemble_representations)

    # Zero-initialised recycling inputs for the first iteration.
    emb_config = self.config.embeddings_and_evoformer
    prev = {
        'prev_msa_first_row': jnp.zeros([num_residues, emb_config.msa_channel]),
        'prev_pair': jnp.zeros([num_residues, num_residues, emb_config.pair_channel]),
        'prev_dgram': jnp.zeros([num_residues, num_residues, 64]),
    }
    if self.config.use_struct:
      prev.update({'prev_pos': jnp.zeros([num_residues, residue_constants.atom_type_num, 3]),
                   'prev_plddt': jnp.zeros([num_residues, 50]),
                   'prev_pae': jnp.zeros([num_residues, num_residues, 64])})
    # Optional caller-provided initial values; only element 0 of each is used.
    for k in ["pos", "msa_first_row", "pair", "dgram"]:
      if f"init_{k}" in batch:
        prev[f"prev_{k}"] = batch[f"init_{k}"][0]

    if self.config.num_recycle:
      if 'num_iter_recycling' in batch:
        # Training time: num_iter_recycling is in batch.
        # The value for each ensemble batch is the same, so arbitrarily taking
        # 0-th.
        num_iter = batch['num_iter_recycling'][0]
        # Add insurance that we will not run more
        # recyclings than the model is configured to run.
        num_iter = jnp.minimum(num_iter, self.config.num_recycle)
      else:
        # Eval mode or tests: use the maximum number of iterations.
        num_iter = self.config.num_recycle

      def add_prev(p, p_):
        """Accumulates the logits of `p` into `p_` (in place) and returns it."""
        p_["prev_dgram"] += p["prev_dgram"]
        if self.config.use_struct:
          p_["prev_plddt"] += p["prev_plddt"]
          p_["prev_pae"] += p["prev_pae"]
        return p_

      def body(p, i):
        """One recycling step: run the model without losses, update `prev`."""
        p_ = get_prev(do_call(p, recycle_idx=i, compute_loss=False))
        if self.config.add_prev:
          p_ = add_prev(p, p_)
        return p_, None

      if hk.running_init():
        # During Haiku init a single direct call creates the parameters.
        prev, _ = body(prev, 0)
      else:
        prev, _ = hk.scan(body, prev, jnp.arange(num_iter))
    else:
      num_iter = 0

    ret = do_call(prev=prev, recycle_idx=num_iter)
    # With compute_loss the iteration returns `(outputs, total_loss)`. Always
    # work on the outputs dictionary itself: indexing the tuple with string
    # keys (as get_prev and the logit averaging below do) raises a TypeError.
    outputs = ret[0] if compute_loss else ret
    if self.config.add_prev:
      prev_ = get_prev(outputs)
    if compute_loss:
      ret = outputs, [ret[1]]
    if not return_representations:
      del outputs['representations']
    # NOTE(review): `num_iter > 0` is evaluated as a Python condition; on the
    # 'num_iter_recycling' path num_iter is a JAX value - confirm that path
    # is never combined with add_prev under tracing.
    if self.config.add_prev and num_iter > 0:
      # Replace the final logits by the mean over all (num_iter + 1) passes.
      # `outputs` is the same dict object held in `ret`, so the update is
      # visible in the returned value on both return paths.
      prev_ = add_prev(prev, prev_)
      outputs["distogram"]["logits"] = prev_["prev_dgram"]/(num_iter+1)
      if self.config.use_struct:
        outputs["predicted_lddt"]["logits"] = prev_["prev_plddt"]/(num_iter+1)
        if "predicted_aligned_error" in outputs:
          outputs["predicted_aligned_error"]["logits"] = prev_["prev_pae"]/(num_iter+1)
    return ret
class TemplatePairStack(hk.Module):
  """Pair stack for the templates.

  Jumper et al. (2021) Suppl. Alg. 16 "TemplatePairStack"
  """

  def __init__(self, config, global_config, name='template_pair_stack'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, pair_act, pair_mask, is_training, safe_key=None, scale_rate=1.0):
    """Builds TemplatePairStack module.

    Arguments:
      pair_act: Pair activations for single template, shape [N_res, N_res, c_t].
      pair_mask: Pair mask, shape [N_res, N_res].
      is_training: Whether the module is in training mode.
      safe_key: Safe key object encapsulating the random number generation key.
        A fresh one is created from `hk.next_rng_key()` when omitted.
      scale_rate: Multiplier for the dropout rate, forwarded to
        `dropout_wrapper`.

    Returns:
      Updated pair_act, shape [N_res, N_res, c_t].
    """
    if safe_key is None:
      safe_key = prng.SafeKey(hk.next_rng_key())

    stack_cfg = self.config
    global_cfg = self.global_config

    # A stack with zero blocks is the identity.
    if not stack_cfg.num_block:
      return pair_act

    def block(carry):
      """One block of the template pair stack."""
      act, key = carry
      apply_layer = functools.partial(
          dropout_wrapper,
          is_training=is_training,
          global_config=global_cfg,
          scale_rate=scale_rate)

      # One key is carried on to the next block, five are consumed below.
      key, *layer_keys = key.split(6)

      # (module class, module config, module name), in execution order.
      layer_specs = (
          (TriangleAttention,
           stack_cfg.triangle_attention_starting_node,
           'triangle_attention_starting_node'),
          (TriangleAttention,
           stack_cfg.triangle_attention_ending_node,
           'triangle_attention_ending_node'),
          (TriangleMultiplication,
           stack_cfg.triangle_multiplication_outgoing,
           'triangle_multiplication_outgoing'),
          (TriangleMultiplication,
           stack_cfg.triangle_multiplication_incoming,
           'triangle_multiplication_incoming'),
          (Transition,
           stack_cfg.pair_transition,
           'pair_transition'),
      )
      for (layer_cls, layer_cfg, layer_name), layer_key in zip(
          layer_specs, layer_keys):
        layer = layer_cls(layer_cfg, global_cfg, name=layer_name)
        act = apply_layer(layer, act, pair_mask, layer_key)
      return act, key

    if global_cfg.use_remat:
      block = hk.remat(block)

    stacked_blocks = layer_stack.layer_stack(stack_cfg.num_block)(block)
    pair_act, safe_key = stacked_blocks((pair_act, safe_key))
    return pair_act
class Transition(hk.Module):
  """Transition layer.

  Jumper et al. (2021) Suppl. Alg. 9 "MSATransition"
  Jumper et al. (2021) Suppl. Alg. 15 "PairTransition"
  """

  def __init__(self, config, global_config, name='transition_block'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, act, mask, is_training=True):
    """Builds Transition module.

    Layer-normalises the input, then applies a two-layer MLP (Linear, ReLU,
    Linear) whose hidden width is `num_intermediate_factor` times the input
    channel count.

    Arguments:
      act: A tensor of queries of size [batch_size, N_res, N_channel].
      mask: A tensor denoting the mask of size [batch_size, N_res]. It is
        expanded with a trailing axis but not applied to the output.
      is_training: Whether the module is in training mode; low-memory
        sub-batching is enabled only when this is False.

    Returns:
      A tensor of size [batch_size, N_res, N_channel].
    """
    # Unpacking also enforces that `act` has exactly three axes.
    _, _, num_channel = act.shape
    hidden_channel = int(num_channel * self.config.num_intermediate_factor)
    mask = jnp.expand_dims(mask, axis=-1)

    input_norm = hk.LayerNorm(
        axis=[-1],
        create_scale=True,
        create_offset=True,
        name='input_layer_norm')
    act = input_norm(act)

    expand = common_modules.Linear(
        hidden_channel,
        initializer='relu',
        name='transition1')
    project = common_modules.Linear(
        num_channel,
        initializer=utils.final_init(self.global_config),
        name='transition2')
    mlp = hk.Sequential([expand, jax.nn.relu, project])

    return mapping.inference_subbatch(
        mlp,
        self.global_config.subbatch_size,
        batched_args=[act],
        nonbatched_args=[],
        low_memory=not is_training)
def glorot_uniform():
  """Returns a Haiku variance-scaling initializer (fan_avg, uniform, scale 1)."""
  initializer = hk.initializers.VarianceScaling(
      distribution='uniform',
      mode='fan_avg',
      scale=1.0)
  return initializer
class Attention(hk.Module):
  """Multihead attention."""

  def __init__(self, config, global_config, output_dim, name='attention'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config
    self.output_dim = output_dim

  def __call__(self, q_data, m_data, bias, nonbatched_bias=None):
    """Builds Attention module.

    Arguments:
      q_data: A tensor of queries, shape [batch_size, N_queries, q_channels].
      m_data: A tensor of memories from which the keys and values are
        projected, shape [batch_size, N_keys, m_channels].
      bias: A bias for the attention, shape [batch_size, N_queries, N_keys].
      nonbatched_bias: Shared bias, shape [N_queries, N_keys].

    Returns:
      A tensor of shape [batch_size, N_queries, output_dim].
    """
    cfg = self.config
    num_head = cfg.num_head

    # When the config has no explicit sizes, fall back to the channel counts
    # of the inputs.
    total_key_dim = cfg.get('key_dim', int(q_data.shape[-1]))
    total_value_dim = cfg.get('value_dim', int(m_data.shape[-1]))
    assert total_key_dim % num_head == 0
    assert total_value_dim % num_head == 0
    key_dim = total_key_dim // num_head
    value_dim = total_value_dim // num_head

    # Parameters are requested in a fixed order: query, key, value,
    # (gating), output.
    query_w = hk.get_parameter(
        'query_w', shape=(q_data.shape[-1], num_head, key_dim),
        init=glorot_uniform())
    key_w = hk.get_parameter(
        'key_w', shape=(m_data.shape[-1], num_head, key_dim),
        init=glorot_uniform())
    value_w = hk.get_parameter(
        'value_w', shape=(m_data.shape[-1], num_head, value_dim),
        init=glorot_uniform())

    # Queries are pre-scaled by 1/sqrt(per-head key dim).
    queries = jnp.einsum('bqa,ahc->bqhc', q_data, query_w) * key_dim**(-0.5)
    keys = jnp.einsum('bka,ahc->bkhc', m_data, key_w)
    values = jnp.einsum('bka,ahc->bkhc', m_data, value_w)

    logits = jnp.einsum('bqhc,bkhc->bhqk', queries, keys) + bias
    if nonbatched_bias is not None:
      logits += jnp.expand_dims(nonbatched_bias, axis=0)
    attn = jax.nn.softmax(logits)
    context = jnp.einsum('bhqk,bkhc->bqhc', attn, values)

    output_init = (hk.initializers.Constant(0.0)
                   if self.global_config.zero_init else glorot_uniform())

    if cfg.gating:
      # Zero weights with unit bias: every gate starts at sigmoid(1).
      gate_w = hk.get_parameter(
          'gating_w',
          shape=(q_data.shape[-1], num_head, value_dim),
          init=hk.initializers.Constant(0.0))
      gate_b = hk.get_parameter(
          'gating_b',
          shape=(num_head, value_dim),
          init=hk.initializers.Constant(1.0))
      gates = jax.nn.sigmoid(
          jnp.einsum('bqc, chv->bqhv', q_data, gate_w) + gate_b)
      context *= gates

    output_w = hk.get_parameter(
        'output_w', shape=(num_head, value_dim, self.output_dim),
        init=output_init)
    output_b = hk.get_parameter(
        'output_b', shape=(self.output_dim,),
        init=hk.initializers.Constant(0.0))

    return jnp.einsum('bqhc,hco->bqo', context, output_w) + output_b
class GlobalAttention(hk.Module):
  """Global attention.

  Jumper et al. (2021) Suppl. Alg. 19 "MSAColumnGlobalAttention" lines 2-7
  """

  def __init__(self, config, global_config, output_dim, name='attention'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config
    self.output_dim = output_dim

  def __call__(self, q_data, m_data, q_mask, bias):
    """Builds GlobalAttention module.

    Arguments:
      q_data: A tensor of queries with size [batch_size, N_queries,
        q_channels]
      m_data: A tensor of memories from which the keys and values
        projected. Size [batch_size, N_keys, m_channels]
      q_mask: A binary mask for q_data with zeros in the padded sequence
        elements and ones otherwise. Size [batch_size, N_queries, q_channels]
        (or broadcastable to this shape).
      bias: A bias for the attention. The passed value is not used: the bias
        is recomputed from `q_mask` before the logits are formed.

    Returns:
      A tensor of size [batch_size, N_queries, output_dim] when gating is
      enabled, otherwise of size [batch_size, 1, output_dim].
    """
    cfg = self.config
    num_head = cfg.num_head

    # When the config has no explicit sizes, fall back to the channel counts
    # of the inputs.
    total_key_dim = cfg.get('key_dim', int(q_data.shape[-1]))
    total_value_dim = cfg.get('value_dim', int(m_data.shape[-1]))
    assert total_key_dim % num_head == 0
    assert total_value_dim % num_head == 0
    key_dim = total_key_dim // num_head
    value_dim = total_value_dim // num_head

    # Only the queries are per-head; keys and values share one projection.
    query_w = hk.get_parameter(
        'query_w', shape=(q_data.shape[-1], num_head, key_dim),
        init=glorot_uniform())
    key_w = hk.get_parameter(
        'key_w', shape=(m_data.shape[-1], key_dim),
        init=glorot_uniform())
    value_w = hk.get_parameter(
        'value_w', shape=(m_data.shape[-1], value_dim),
        init=glorot_uniform())

    values = jnp.einsum('bka,ac->bkc', m_data, value_w)

    # A single masked-mean query per batch element.
    mean_query = utils.mask_mean(q_mask, q_data, axis=1)
    queries = jnp.einsum('ba,ahc->bhc', mean_query, query_w) * key_dim**(-0.5)
    keys = jnp.einsum('bka,ac->bkc', m_data, key_w)

    # Overrides the `bias` argument with a mask-derived bias.
    bias = (1e9 * (q_mask[:, None, :, 0] - 1.))
    logits = jnp.einsum('bhc,bkc->bhk', queries, keys) + bias
    attn = jax.nn.softmax(logits)
    context = jnp.einsum('bhk,bkc->bhc', attn, values)

    output_init = (hk.initializers.Constant(0.0)
                   if self.global_config.zero_init else glorot_uniform())
    output_w = hk.get_parameter(
        'output_w', shape=(num_head, value_dim, self.output_dim),
        init=output_init)
    output_b = hk.get_parameter(
        'output_b', shape=(self.output_dim,),
        init=hk.initializers.Constant(0.0))

    if not cfg.gating:
      ungated = jnp.einsum('bhc,hco->bo', context, output_w) + output_b
      return ungated[:, None]

    gate_w = hk.get_parameter(
        'gating_w',
        shape=(q_data.shape[-1], num_head, value_dim),
        init=hk.initializers.Constant(0.0))
    gate_b = hk.get_parameter(
        'gating_b',
        shape=(num_head, value_dim),
        init=hk.initializers.Constant(1.0))
    gates = jnp.einsum('bqc, chv->bqhv', q_data, gate_w)
    gates = jax.nn.sigmoid(gates + gate_b)
    # Broadcast the pooled context over the query axis before gating.
    gated = context[:, None] * gates
    return jnp.einsum('bqhc,hco->bqo', gated, output_w) + output_b
class MSARowAttentionWithPairBias(hk.Module):
  """MSA per-row attention biased by the pair representation.

  Jumper et al. (2021) Suppl. Alg. 7 "MSARowAttentionWithPairBias"
  """

  def __init__(self, config, global_config,
               name='msa_row_attention_with_pair_bias'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self,
               msa_act,
               msa_mask,
               pair_act,
               is_training=False):
    """Builds MSARowAttentionWithPairBias module.

    Arguments:
      msa_act: [N_seq, N_res, c_m] MSA representation.
      msa_mask: [N_seq, N_res] mask of non-padded regions.
      pair_act: [N_res, N_res, c_z] pair representation.
      is_training: Whether the module is in training mode.

    Returns:
      Update to msa_act, shape [N_seq, N_res, c_m].
    """
    cfg = self.config

    assert len(msa_act.shape) == 3
    assert len(msa_mask.shape) == 2
    assert cfg.orientation == 'per_row'

    # Masked positions receive a large negative bias (-1e9).
    mask_bias = (1e9 * (msa_mask - 1.))[:, None, None, :]
    assert len(mask_bias.shape) == 4

    query_norm = hk.LayerNorm(
        axis=[-1], create_scale=True, create_offset=True, name='query_norm')
    msa_act = query_norm(msa_act)

    pair_norm = hk.LayerNorm(
        axis=[-1],
        create_scale=True,
        create_offset=True,
        name='feat_2d_norm')
    pair_act = pair_norm(pair_act)

    # Per-head linear projection of the pair representation into a bias that
    # is shared by all MSA rows.
    pair_channels = pair_act.shape[-1]
    init_factor = 1. / jnp.sqrt(int(pair_channels))
    pair_to_bias = hk.get_parameter(
        'feat_2d_weights',
        shape=(pair_channels, cfg.num_head),
        init=hk.initializers.RandomNormal(stddev=init_factor))
    pair_bias = jnp.einsum('qkc,ch->hqk', pair_act, pair_to_bias)

    attention = Attention(cfg, self.global_config, msa_act.shape[-1])
    return mapping.inference_subbatch(
        attention,
        self.global_config.subbatch_size,
        batched_args=[msa_act, msa_act, mask_bias],
        nonbatched_args=[pair_bias],
        low_memory=not is_training)
class MSAColumnAttention(hk.Module):
  """MSA per-column attention.

  Jumper et al. (2021) Suppl. Alg. 8 "MSAColumnAttention"
  """

  def __init__(self, config, global_config, name='msa_column_attention'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self,
               msa_act,
               msa_mask,
               is_training=False):
    """Builds MSAColumnAttention module.

    Arguments:
      msa_act: [N_seq, N_res, c_m] MSA representation.
      msa_mask: [N_seq, N_res] mask of non-padded regions.
      is_training: Whether the module is in training mode.

    Returns:
      Update to msa_act, shape [N_seq, N_res, c_m]
    """
    cfg = self.config

    assert len(msa_act.shape) == 3
    assert len(msa_mask.shape) == 2
    assert cfg.orientation == 'per_column'

    # Swap the sequence and residue axes so that attention runs over the
    # sequences within each column.
    columns = jnp.swapaxes(msa_act, -2, -3)
    column_mask = jnp.swapaxes(msa_mask, -1, -2)

    # Masked positions receive a large negative bias (-1e9).
    mask_bias = (1e9 * (column_mask - 1.))[:, None, None, :]
    assert len(mask_bias.shape) == 4

    query_norm = hk.LayerNorm(
        axis=[-1], create_scale=True, create_offset=True, name='query_norm')
    columns = query_norm(columns)

    attention = Attention(cfg, self.global_config, columns.shape[-1])
    columns = mapping.inference_subbatch(
        attention,
        self.global_config.subbatch_size,
        batched_args=[columns, columns, mask_bias],
        nonbatched_args=[],
        low_memory=not is_training)

    # Restore the original axis order.
    return jnp.swapaxes(columns, -2, -3)
class MSAColumnGlobalAttention(hk.Module):
  """MSA per-column global attention.

  Jumper et al. (2021) Suppl. Alg. 19 "MSAColumnGlobalAttention"
  """

  def __init__(self, config, global_config, name='msa_column_global_attention'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self,
               msa_act,
               msa_mask,
               is_training=False):
    """Builds MSAColumnGlobalAttention module.

    Arguments:
      msa_act: [N_seq, N_res, c_m] MSA representation.
      msa_mask: [N_seq, N_res] mask of non-padded regions.
      is_training: Whether the module is in training mode.

    Returns:
      Update to msa_act, shape [N_seq, N_res, c_m].
    """
    cfg = self.config

    assert len(msa_act.shape) == 3
    assert len(msa_mask.shape) == 2
    assert cfg.orientation == 'per_column'

    # Swap the sequence and residue axes so that attention runs over the
    # sequences within each column.
    columns = jnp.swapaxes(msa_act, -2, -3)
    column_mask = jnp.swapaxes(msa_mask, -1, -2)

    # Masked positions receive a large negative bias (-1e9).
    mask_bias = (1e9 * (column_mask - 1.))[:, None, None, :]
    assert len(mask_bias.shape) == 4

    query_norm = hk.LayerNorm(
        axis=[-1], create_scale=True, create_offset=True, name='query_norm')
    columns = query_norm(columns)

    attention = GlobalAttention(
        cfg, self.global_config, columns.shape[-1],
        name='attention')

    # The (already transposed) mask gets a trailing singleton axis.
    column_mask = jnp.expand_dims(column_mask, axis=-1)
    columns = mapping.inference_subbatch(
        attention,
        self.global_config.subbatch_size,
        batched_args=[columns, columns, column_mask, mask_bias],
        nonbatched_args=[],
        low_memory=not is_training)

    # Restore the original axis order.
    return jnp.swapaxes(columns, -2, -3)
class TriangleAttention(hk.Module):
  """Triangle Attention.

  Jumper et al. (2021) Suppl. Alg. 13 "TriangleAttentionStartingNode"
  Jumper et al. (2021) Suppl. Alg. 14 "TriangleAttentionEndingNode"
  """

  def __init__(self, config, global_config, name='triangle_attention'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, pair_act, pair_mask, is_training=False):
    """Builds TriangleAttention module.

    Arguments:
      pair_act: [N_res, N_res, c_z] pair activations tensor
      pair_mask: [N_res, N_res] mask of non-padded regions in the tensor.
      is_training: Whether the module is in training mode.

    Returns:
      Update to pair_act, shape [N_res, N_res, c_z].
    """
    cfg = self.config

    assert len(pair_act.shape) == 3
    assert len(pair_mask.shape) == 2
    assert cfg.orientation in ['per_row', 'per_column']

    # The per-column variant is the per-row computation on transposed inputs.
    transposed = cfg.orientation == 'per_column'
    if transposed:
      pair_act = jnp.swapaxes(pair_act, -2, -3)
      pair_mask = jnp.swapaxes(pair_mask, -1, -2)

    # Masked positions receive a large negative bias (-1e9).
    mask_bias = (1e9 * (pair_mask - 1.))[:, None, None, :]
    assert len(mask_bias.shape) == 4

    query_norm = hk.LayerNorm(
        axis=[-1], create_scale=True, create_offset=True, name='query_norm')
    pair_act = query_norm(pair_act)

    # Per-head bias computed from the normalised pair activations themselves.
    pair_channels = pair_act.shape[-1]
    init_factor = 1. / jnp.sqrt(int(pair_channels))
    pair_to_bias = hk.get_parameter(
        'feat_2d_weights',
        shape=(pair_channels, cfg.num_head),
        init=hk.initializers.RandomNormal(stddev=init_factor))
    pair_bias = jnp.einsum('qkc,ch->hqk', pair_act, pair_to_bias)

    attention = Attention(cfg, self.global_config, pair_channels)
    pair_act = mapping.inference_subbatch(
        attention,
        self.global_config.subbatch_size,
        batched_args=[pair_act, pair_act, mask_bias],
        nonbatched_args=[pair_bias],
        low_memory=not is_training)

    if transposed:
      pair_act = jnp.swapaxes(pair_act, -2, -3)
    return pair_act
class MaskedMsaHead(hk.Module):
  """Head to predict MSA at the masked locations.

  The MaskedMsaHead employs a BERT-style objective to reconstruct a masked
  version of the full MSA, based on a linear projection of
  the MSA representation.

  Jumper et al. (2021) Suppl. Sec. 1.9.9 "Masked MSA prediction"
  """

  def __init__(self, config, global_config, name='masked_msa_head'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, representations, batch, is_training):
    """Builds MaskedMsaHead module.

    Arguments:
      representations: Dictionary of representations, must contain:
        * 'msa': MSA representation, shape [N_seq, N_res, c_m].
      batch: Batch, unused.
      is_training: Whether the module is in training mode.

    Returns:
      Dictionary containing:
        * 'logits': logits of shape [N_seq, N_res, N_aatype] with
            (unnormalized) log probabilities of predicted aatype at position.
    """
    del batch
    projection = common_modules.Linear(
        self.config.num_output,
        initializer=utils.final_init(self.global_config),
        name='logits')
    return {'logits': projection(representations['msa'])}

  def loss(self, value, batch):
    """Masked-position cross entropy between the logits and the true MSA.

    Arguments:
      value: Head output containing 'logits'.
      batch: Must contain 'true_msa' (class indices, one-hot encoded here
        with 23 classes) and 'bert_mask' (weights of the positions that
        contribute to the loss).

    Returns:
      Dictionary with 'loss': the mask-weighted mean error, reduced over the
      last two axes.
    """
    bert_mask = batch['bert_mask']
    true_one_hot = jax.nn.one_hot(batch['true_msa'], num_classes=23)
    errors = softmax_cross_entropy(labels=true_one_hot, logits=value['logits'])

    masked_error_sum = jnp.sum(errors * bert_mask, axis=(-2, -1))
    # The epsilon guards against an all-zero mask.
    mask_count = 1e-8 + jnp.sum(bert_mask, axis=(-2, -1))
    return {'loss': masked_error_sum / mask_count}
class PredictedLDDTHead(hk.Module):
  """Head to predict the per-residue LDDT to be used as a confidence measure.

  Jumper et al. (2021) Suppl. Sec. 1.9.6 "Model confidence prediction (pLDDT)"
  Jumper et al. (2021) Suppl. Alg. 29 "predictPerResidueLDDT_Ca"
  """

  def __init__(self, config, global_config, name='predicted_lddt_head'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, representations, batch, is_training):
    """Builds PredictedLDDTHead module.

    Arguments:
      representations: Dictionary of representations, must contain:
        * 'structure_module': Single representation from the structure module,
             shape [N_res, c_s].
      batch: Batch, unused.
      is_training: Whether the module is in training mode (not read here).

    Returns:
      Dictionary containing :
        * 'logits': logits of shape [N_res, N_bins] with
            (unnormalized) log probabilities of binned predicted lDDT.
    """
    act = representations['structure_module']

    input_norm = hk.LayerNorm(
        axis=[-1],
        create_scale=True,
        create_offset=True,
        name='input_layer_norm')
    act = input_norm(act)

    # Two hidden Linear + ReLU layers of equal width.
    for layer_name in ('act_0', 'act_1'):
      hidden = common_modules.Linear(
          self.config.num_channels,
          initializer='relu',
          name=layer_name)
      act = jax.nn.relu(hidden(act))

    to_logits = common_modules.Linear(
        self.config.num_bins,
        initializer=utils.final_init(self.global_config),
        name='logits')
    return {'logits': to_logits(act)}

  def loss(self, value, batch):
    """Cross entropy between predicted lDDT bins and the binned true lDDT.

    Arguments:
      value: Dictionary of all head outputs; reads
        value['structure_module']['final_atom_positions'] and
        value['predicted_lddt']['logits'].
      batch: Must contain 'all_atom_positions' and 'all_atom_mask', plus
        'resolution' when `config.filter_by_resolution` is set.

    Returns:
      Dictionary with 'loss': the CA-mask weighted mean error, zeroed when the
      resolution filter is enabled and the resolution is out of range.
    """
    # Shape (num_res, 37, 3)
    predicted_positions = value['structure_module']['final_atom_positions']
    # Shape (num_res, 37, 3)
    true_positions = batch['all_atom_positions']
    # Shape (num_res, 37)
    atom_mask = batch['all_atom_mask']

    # lDDT is evaluated on atom index 1 only, with a leading batch axis of
    # size 1 added for the call and stripped again from the result.
    lddt_ca = lddt.lddt(
        predicted_points=predicted_positions[None, :, 1, :],
        true_points=true_positions[None, :, 1, :],
        true_points_mask=atom_mask[None, :, 1:2].astype(jnp.float32),
        cutoff=15.,
        per_residue=True)[0]
    # The target is treated as a constant: no gradient flows through it.
    lddt_ca = jax.lax.stop_gradient(lddt_ca)

    num_bins = self.config.num_bins
    target_bin = jnp.floor(lddt_ca * num_bins).astype(jnp.int32)
    # protect against out of range for lddt_ca == 1
    target_bin = jnp.minimum(target_bin, num_bins - 1)
    target_one_hot = jax.nn.one_hot(target_bin, num_classes=num_bins)

    errors = softmax_cross_entropy(
        labels=target_one_hot, logits=value['predicted_lddt']['logits'])

    # Shape (num_res,)
    ca_mask = atom_mask[:, residue_constants.atom_order['CA']]
    ca_mask = ca_mask.astype(jnp.float32)
    loss = jnp.sum(errors * ca_mask) / (jnp.sum(ca_mask) + 1e-8)

    if self.config.filter_by_resolution:
      # NMR & distillation have resolution = 0
      resolution = batch['resolution']
      in_range = ((resolution >= self.config.min_resolution)
                  & (resolution <= self.config.max_resolution))
      loss *= in_range.astype(jnp.float32)

    return {'loss': loss}
class PredictedAlignedErrorHead(hk.Module):
  """Head predicting distance errors in the backbone alignment frames.

  The output can be used to compute a predicted TM-Score.
  Jumper et al. (2021) Suppl. Sec. 1.9.7 "TM-score prediction"
  """

  def __init__(self, config, global_config,
               name='predicted_aligned_error_head'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, representations, batch, is_training):
    """Builds PredictedAlignedErrorHead module.

    Arguments:
      representations: Dictionary of representations, must contain:
        * 'pair': pair representation, shape [N_res, N_res, c_z].
      batch: Batch, unused.
      is_training: Whether the module is in training mode.

    Returns:
      Dictionary containing:
        * logits: logits for aligned error, shape [N_res, N_res, N_bins].
        * breaks: array containing bin breaks, shape [N_bins - 1].
    """
    cfg = self.config
    pair_act = representations['pair']

    # One logit per error bin for every (frame, residue) pair.
    to_logits = common_modules.Linear(
        cfg.num_bins,
        initializer=utils.final_init(self.global_config),
        name='logits')
    logits = to_logits(pair_act)

    # num_bins - 1 evenly spaced edges between 0 and max_error_bin.
    breaks = jnp.linspace(0., cfg.max_error_bin, cfg.num_bins - 1)
    return {'logits': logits, 'breaks': breaks}

  def loss(self, value, batch):
    """Cross entropy between predicted and true binned aligned errors.

    Arguments:
      value: Dictionary containing:
        * 'structure_module': dict with 'final_affines'.
        * 'predicted_aligned_error': dict with 'logits' and 'breaks'.
      batch: Dictionary containing 'backbone_affine_tensor',
        'backbone_affine_mask' and, when `config.filter_by_resolution` is
        set, 'resolution'.

    Returns:
      Dictionary with 'loss'.
    """
    pred_frames = quat_affine.QuatAffine.from_tensor(
        value['structure_module']['final_affines'])
    true_frames = quat_affine.QuatAffine.from_tensor(
        batch['backbone_affine_tensor'])

    # Per-residue frame mask and its outer product over residue pairs.
    frame_mask = batch['backbone_affine_mask']
    pair_mask = frame_mask[:, None] * frame_mask[None, :]

    n_bins = self.config.num_bins
    head_out = value['predicted_aligned_error']
    bin_edges = head_out['breaks']
    logits = head_out['logits']

    def in_local_frames(frames):
      # Express every translation in the local frame of every residue.
      coords = [jnp.expand_dims(axis_vals, axis=-2)
                for axis_vals in frames.translation]
      return frames.invert_point(coords, extra_dims=1)

    pred_local = in_local_frames(pred_frames)
    true_local = in_local_frames(true_frames)

    # Squared distance, accumulated over the coordinate components.
    # First num_res are alignment frames, second num_res are the residues.
    sq_error = sum([jnp.square(p - t) for p, t in zip(pred_local, true_local)])
    # The target is a constant as far as gradients are concerned.
    sq_error = jax.lax.stop_gradient(sq_error)

    # Bin index = number of (squared) edges strictly below the squared error.
    exceeds_edge = sq_error[..., None] > jnp.square(bin_edges)
    target_bins = jnp.sum(exceeds_edge.astype(jnp.int32), axis=-1)

    errors = softmax_cross_entropy(
        labels=jax.nn.one_hot(target_bins, n_bins, axis=-1), logits=logits)

    masked_total = jnp.sum(errors * pair_mask, axis=(-2, -1))
    mask_total = jnp.sum(pair_mask, axis=(-2, -1))
    loss = masked_total / (1e-8 + mask_total)

    if self.config.filter_by_resolution:
      # NMR & distillation have resolution = 0
      resolution = batch['resolution']
      in_range = ((resolution >= self.config.min_resolution)
                  & (resolution <= self.config.max_resolution))
      loss *= in_range.astype(jnp.float32)

    return {'loss': loss}
class ExperimentallyResolvedHead(hk.Module):
  """Predicts if an atom is experimentally resolved in a high-res structure.

  Only trained on high-resolution X-ray crystals & cryo-EM.
  Jumper et al. (2021) Suppl. Sec. 1.9.10 '"Experimentally resolved" prediction'
  """

  def __init__(self, config, global_config,
               name='experimentally_resolved_head'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, representations, batch, is_training):
    """Builds ExperimentallyResolvedHead module.

    Arguments:
      representations: Dictionary of representations, must contain:
        * 'single': Single representation, shape [N_res, c_s].
      batch: Batch, unused.
      is_training: Whether the module is in training mode.

    Returns:
      Dictionary containing:
        * 'logits': logits of shape [N_res, 37],
            log probability that an atom is resolved in atom37 representation,
            can be converted to probability by applying sigmoid.
    """
    to_logits = common_modules.Linear(
        37,  # atom_exists.shape[-1]
        initializer=utils.final_init(self.global_config),
        name='logits')
    return {'logits': to_logits(representations['single'])}

  def loss(self, value, batch):
    """Sigmoid cross entropy against the experimentally resolved atom mask.

    Arguments:
      value: Dictionary containing 'logits' of rank 2.
      batch: Dictionary containing 'atom37_atom_exists', 'all_atom_mask' and,
        when `config.filter_by_resolution` is set, 'resolution'.

    Returns:
      Dictionary with 'loss'.
    """
    cfg = self.config
    logits = value['logits']
    assert len(logits.shape) == 2

    # Does the atom appear in the amino acid?
    exists = batch['atom37_atom_exists']
    # Is the atom resolved in the experiment? Subset of atom_exists,
    # *except for OXT*
    resolved = batch['all_atom_mask'].astype(jnp.float32)

    per_atom = sigmoid_cross_entropy(labels=resolved, logits=logits)
    loss = jnp.sum(per_atom * exists) / (1e-8 + jnp.sum(exists))

    if cfg.filter_by_resolution:
      # NMR & distillation examples have resolution = 0.
      resolution = batch['resolution']
      keep = ((resolution >= cfg.min_resolution)
              & (resolution <= cfg.max_resolution))
      loss *= keep.astype(jnp.float32)

    return {'loss': loss}
class TriangleMultiplication(hk.Module):
  """Triangle multiplication layer ("outgoing" or "incoming").

  Jumper et al. (2021) Suppl. Alg. 11 "TriangleMultiplicationOutgoing"
  Jumper et al. (2021) Suppl. Alg. 12 "TriangleMultiplicationIncoming"
  """

  def __init__(self, config, global_config, name='triangle_multiplication'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, act, mask, is_training=True):
    """Builds TriangleMultiplication module.

    Arguments:
      act: Pair activations, shape [N_res, N_res, c_z]
      mask: Pair mask, shape [N_res, N_res].
      is_training: Whether the module is in training mode.

    Returns:
      Outputs, same shape/type as act.
    """
    del is_training  # Not used by this layer.
    cfg = self.config
    global_cfg = self.global_config
    mid_channels = cfg.num_intermediate_channel

    pair_mask = mask[..., None]
    normed = hk.LayerNorm(
        axis=[-1],
        create_scale=True,
        create_offset=True,
        name='layer_norm_input')(normed_input := act) if False else hk.LayerNorm(
            axis=[-1],
            create_scale=True,
            create_offset=True,
            name='layer_norm_input')(act)

    def sigmoid_gate(width, name):
      # Gate computed from the normalised input.
      gate_linear = common_modules.Linear(
          width,
          bias_init=1.,
          initializer=utils.final_init(global_cfg),
          name=name)
      return jax.nn.sigmoid(gate_linear(normed))

    # Masked projections first, then their gates (same module order as the
    # parameters are expected to be created in).
    left = pair_mask * common_modules.Linear(
        mid_channels, name='left_projection')(normed)
    right = pair_mask * common_modules.Linear(
        mid_channels, name='right_projection')(normed)
    left_gate = sigmoid_gate(mid_channels, 'left_gate')
    right_gate = sigmoid_gate(mid_channels, 'right_gate')
    left = left * left_gate
    right = right * right_gate

    # "Outgoing" edges equation: 'ikc,jkc->ijc'
    # "Incoming" edges equation: 'kjc,kic->ijc'
    # Note on the Suppl. Alg. 11 & 12 notation:
    # For the "outgoing" edges, a = left and b = right
    # For the "incoming" edges, it's swapped:
    #   b = left and a = right
    combined = jnp.einsum(cfg.equation, left, right)

    combined = hk.LayerNorm(
        axis=[-1],
        create_scale=True,
        create_offset=True,
        name='center_layer_norm')(combined)

    # Project back to the channel count of the (normalised) input.
    out_channels = int(normed.shape[-1])
    projected = common_modules.Linear(
        out_channels,
        initializer=utils.final_init(global_cfg),
        name='output_projection')(combined)

    return projected * sigmoid_gate(out_channels, 'gating_linear')
class DistogramHead(hk.Module):
  """Head to predict a distogram.

  Jumper et al. (2021) Suppl. Sec. 1.9.8 "Distogram prediction"
  """

  def __init__(self, config, global_config, name='distogram_head'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, representations, batch, is_training):
    """Builds DistogramHead module.

    Arguments:
      representations: Dictionary of representations, must contain:
        * 'pair': pair representation, shape [N_res, N_res, c_z].
      batch: Batch, unused.
      is_training: Whether the module is in training mode.

    Returns:
      Dictionary containing:
        * logits: logits for distogram, shape [N_res, N_res, N_bins].
        * bin_edges: array containing bin breaks, shape [N_bins - 1,].
    """
    cfg = self.config
    to_half_logits = common_modules.Linear(
        cfg.num_bins,
        initializer=utils.final_init(self.global_config),
        name='half_logits')
    half = to_half_logits(representations['pair'])

    # Adding the residue-swapped copy makes the logits symmetric in the two
    # residue axes.
    symmetric_logits = half + jnp.swapaxes(half, -2, -3)
    edges = jnp.linspace(cfg.first_break, cfg.last_break, cfg.num_bins - 1)

    return {'logits': symmetric_logits, 'bin_edges': edges}

  def loss(self, value, batch):
    """Returns the distogram log loss for this head's own output `value`."""
    return _distogram_log_loss(
        value['logits'], value['bin_edges'], batch, self.config.num_bins)
def _distogram_log_loss(logits, bin_edges, batch, num_bins):
  """Log loss of a distogram.

  Arguments:
    logits: Rank-3 distogram logits.
    bin_edges: Bin edges (as distances, squared internally).
    batch: Dictionary containing 'pseudo_beta' (last axis of size 3) and
      'pseudo_beta_mask'.
    num_bins: Number of classes used for the one-hot targets.

  Returns:
    Dictionary with 'loss' (masked mean cross entropy) and 'true_dist'
    (pairwise distances, sqrt(1e-6 + squared distance)).
  """
  assert len(logits.shape) == 3
  coords = batch['pseudo_beta']
  coord_mask = batch['pseudo_beta_mask']
  assert coords.shape[-1] == 3

  squared_edges = jnp.square(bin_edges)

  # Pairwise squared distances, keeping a trailing axis for the edge compare.
  diffs = (jnp.expand_dims(coords, axis=-2) -
           jnp.expand_dims(coords, axis=-3))
  sq_dist = jnp.sum(jnp.square(diffs), axis=-1, keepdims=True)

  # Target bin = number of edges strictly below the squared distance.
  target_bins = jnp.sum(sq_dist > squared_edges, axis=-1)

  xent = softmax_cross_entropy(
      labels=jax.nn.one_hot(target_bins, num_bins), logits=logits)

  pair_mask = (jnp.expand_dims(coord_mask, axis=-2) *
               jnp.expand_dims(coord_mask, axis=-1))

  numerator = jnp.sum(xent * pair_mask, axis=(-2, -1))
  denominator = 1e-6 + jnp.sum(pair_mask, axis=(-2, -1))
  mean_xent = numerator / denominator

  true_dist = jnp.sqrt(1e-6 + sq_dist[..., 0])
  return {'loss': mean_xent, 'true_dist': true_dist}
class OuterProductMean(hk.Module):
  """Computes mean outer product.

  Jumper et al. (2021) Suppl. Alg. 10 "OuterProductMean"
  """

  def __init__(self,
               config,
               global_config,
               num_output_channel,
               name='outer_product_mean'):
    super().__init__(name=name)
    self.global_config = global_config
    self.config = config
    self.num_output_channel = num_output_channel

  def __call__(self, act, mask, is_training=True):
    """Builds OuterProductMean module.

    Arguments:
      act: MSA representation, shape [N_seq, N_res, c_m].
      mask: MSA mask, shape [N_seq, N_res].
      is_training: Whether the module is in training mode.

    Returns:
      Update to pair representation, shape [N_res, N_res, c_z].
    """
    cfg = self.config
    global_cfg = self.global_config
    outer_channels = cfg.num_outer_channel
    out_channels = self.num_output_channel

    msa_mask = mask[..., None]
    normed = hk.LayerNorm(
        axis=[-1],
        create_scale=True,
        create_offset=True,
        name='layer_norm_input')(act)

    left_proj = common_modules.Linear(
        outer_channels, initializer='linear', name='left_projection')
    left = msa_mask * left_proj(normed)

    right_proj = common_modules.Linear(
        outer_channels, initializer='linear', name='right_projection')
    right = msa_mask * right_proj(normed)

    # The output weights start at zero when zero_init is requested.
    if global_cfg.zero_init:
      weight_init = hk.initializers.Constant(0.0)
    else:
      weight_init = hk.initializers.VarianceScaling(scale=2., mode='fan_in')

    out_w = hk.get_parameter(
        'output_w',
        shape=(outer_channels, outer_channels, out_channels),
        init=weight_init)
    out_b = hk.get_parameter(
        'output_b',
        shape=(out_channels,),
        init=hk.initializers.Constant(0.0))

    def outer_product_chunk(left_chunk):
      # This is equivalent to
      #
      # act = jnp.einsum('abc,ade->dceb', left_chunk, right)
      # act = jnp.einsum('dceb,cef->bdf', act, out_w) + out_b
      #
      # but faster.
      left_t = jnp.transpose(left_chunk, [0, 2, 1])
      outer = jnp.einsum('acb,ade->dceb', left_t, right)
      projected = jnp.einsum('dceb,cef->dbf', outer, out_w) + out_b
      return jnp.transpose(projected, [1, 0, 2])

    # Chunk over the residue axis of the left activations.
    pair_update = mapping.inference_subbatch(
        outer_product_chunk,
        cfg.chunk_size,
        batched_args=[left],
        nonbatched_args=[],
        low_memory=True,
        input_subbatch_dim=1,
        output_subbatch_dim=0)

    # Divide by the per-pair mask overlap to turn the sum into a mean.
    pair_count = jnp.einsum('abc,adc->bdc', msa_mask, msa_mask)
    pair_update /= 1e-3 + pair_count

    return pair_update
def dgram_from_positions(positions, num_bins, min_bin, max_bin):
  """Compute distogram from amino acid positions.

  Arguments:
    positions: [N_res, 3] Position coordinates.
    num_bins: The number of bins in the distogram.
    min_bin: The left edge of the first bin.
    max_bin: The left edge of the final bin. The final bin catches
        everything larger than `max_bin`.

  Returns:
    Distogram with the specified number of bins. Comparisons are strict, so
    a squared distance exactly on a bin edge is assigned to no bin.
  """
  # Work in squared distances so no square root is needed.
  sq_lower = jnp.square(jnp.linspace(min_bin, max_bin, num_bins))
  # Each bin's upper edge is the next bin's lower edge; the last bin is
  # effectively open-ended.
  sq_upper = jnp.concatenate(
      [sq_lower[1:], jnp.array([1e8], dtype=jnp.float32)], axis=-1)

  deltas = (jnp.expand_dims(positions, axis=-2) -
            jnp.expand_dims(positions, axis=-3))
  sq_dist = jnp.sum(jnp.square(deltas), axis=-1, keepdims=True)

  above_lower = (sq_dist > sq_lower).astype(jnp.float32)
  below_upper = (sq_dist < sq_upper).astype(jnp.float32)
  return above_lower * below_upper
def dgram_from_positions_soft(positions, num_bins, min_bin, max_bin, temp=2.0):
  """Soft (sigmoid-based) positions-to-distogram converter.

  Arguments:
    positions: Position coordinates; the last axis holds the coordinates and
      the second-to-last indexes the points.
    num_bins: The number of bins in the returned distogram.
    min_bin: The left edge of the first returned bin.
    max_bin: The left edge of the final returned bin.
    temp: Temperature dividing the sigmoid arguments; smaller is sharper.

  Returns:
    Soft bin memberships with `num_bins` entries on the last axis. An extra
    leading bin below `min_bin` takes part in the normalisation and is then
    dropped.
  """
  edges = jnp.linspace(min_bin, max_bin, num_bins)
  # Prepend a catch-all bin below min_bin and close the last bin at +1e8.
  lower_edges = jnp.append(-1e8, edges)
  upper_edges = jnp.append(lower_edges[1:], 1e8)

  deltas = positions[..., :, None, :] - positions[..., None, :, :]
  sq_dist = jnp.square(deltas).sum(-1, keepdims=True)
  # The 1e-8 keeps the square root differentiable at zero distance.
  dist = jnp.sqrt(sq_dist + 1e-8)

  past_lower = jax.nn.sigmoid((dist - lower_edges) / temp)
  before_upper = jax.nn.sigmoid((upper_edges - dist) / temp)
  membership = past_lower * before_upper
  membership = membership / (membership.sum(-1, keepdims=True) + 1e-8)

  return membership[..., 1:]
def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
  """Create pseudo beta features.

  The pseudo beta position is the CA atom for glycine and the CB atom
  otherwise.

  Arguments:
    aatype: Either integer residue types, or a float array whose last axis is
      indexed by residue type (the glycine entry is then used as a soft
      weight).
    all_atom_positions: Atom positions, indexed as [..., atom, xyz].
    all_atom_masks: Atom masks indexed as [..., atom], or None.

  Returns:
    `pseudo_beta` if `all_atom_masks` is None, otherwise the tuple
    `(pseudo_beta, pseudo_beta_mask)`. The mask is cast to float32 only for
    integer `aatype`.
  """
  ca_idx = residue_constants.atom_order['CA']
  cb_idx = residue_constants.atom_order['CB']
  gly_idx = residue_constants.restype_order['G']
  integer_aatype = jnp.issubdtype(aatype.dtype, jnp.integer)

  ca_pos = all_atom_positions[..., ca_idx, :]
  cb_pos = all_atom_positions[..., cb_idx, :]

  if integer_aatype:
    # Hard selection between CA and CB.
    is_gly = jnp.equal(aatype, gly_idx)
    is_gly_xyz = jnp.tile(is_gly[..., None], [1] * len(is_gly.shape) + [3])
    pseudo_beta = jnp.where(is_gly_xyz, ca_pos, cb_pos)
  else:
    # Soft blend weighted by the glycine entry of aatype.
    is_gly = aatype[..., gly_idx]
    gly_weight = is_gly[..., None]
    pseudo_beta = gly_weight * ca_pos + (1 - gly_weight) * cb_pos

  if all_atom_masks is None:
    return pseudo_beta

  ca_mask = all_atom_masks[..., ca_idx]
  cb_mask = all_atom_masks[..., cb_idx]
  if integer_aatype:
    pseudo_beta_mask = jnp.where(is_gly, ca_mask, cb_mask)
    pseudo_beta_mask = pseudo_beta_mask.astype(jnp.float32)
  else:
    pseudo_beta_mask = is_gly * ca_mask + (1 - is_gly) * cb_mask
  return pseudo_beta, pseudo_beta_mask
class EvoformerIteration(hk.Module):
  """Single iteration (block) of Evoformer stack.

  Jumper et al. (2021) Suppl. Alg. 6 "EvoformerStack" lines 2-10
  """

  def __init__(self, config, global_config, is_extra_msa,
               name='evoformer_iteration'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config
    self.is_extra_msa = is_extra_msa

  def __call__(self, activations, masks, is_training=True, safe_key=None, scale_rate=1.0):
    """Builds EvoformerIteration module.

    Arguments:
      activations: Dictionary containing activations:
        * 'msa': MSA activations, shape [N_seq, N_res, c_m].
        * 'pair': pair activations, shape [N_res, N_res, c_z].
      masks: Dictionary of masks:
        * 'msa': MSA mask, shape [N_seq, N_res].
        * 'pair': pair mask, shape [N_res, N_res].
      is_training: Whether the module is in training mode.
      safe_key: prng.SafeKey encapsulating rng key.
      scale_rate: Forwarded unchanged to `dropout_wrapper`.

    Returns:
      Dictionary with the updated 'msa' and 'pair' activations.
    """
    cfg = self.config
    global_cfg = self.global_config

    msa = activations['msa']
    pair = activations['pair']

    if safe_key is None:
      safe_key = prng.SafeKey(hk.next_rng_key())

    msa_mask = masks['msa']
    pair_mask = masks['pair']

    # One sub-key per sub-layer, consumed strictly in application order.
    safe_key, *layer_keys = safe_key.split(10)
    key_stream = iter(layer_keys)

    def apply_layer(module, act, mask, **extra):
      # Shared wrapper call: only the module, its input/mask and any extra
      # keyword (pair_act / output_act) vary between sub-layers.
      return dropout_wrapper(
          module,
          act,
          mask,
          safe_key=next(key_stream),
          is_training=is_training,
          global_config=global_cfg,
          scale_rate=scale_rate,
          **extra)

    # --- MSA track ---
    msa = apply_layer(
        MSARowAttentionWithPairBias(
            cfg.msa_row_attention_with_pair_bias, global_cfg,
            name='msa_row_attention_with_pair_bias'),
        msa, msa_mask, pair_act=pair)

    # The extra-MSA stack uses global column attention instead.
    if self.is_extra_msa:
      column_attention = MSAColumnGlobalAttention(
          cfg.msa_column_attention, global_cfg,
          name='msa_column_global_attention')
    else:
      column_attention = MSAColumnAttention(
          cfg.msa_column_attention, global_cfg, name='msa_column_attention')
    msa = apply_layer(column_attention, msa, msa_mask)

    msa = apply_layer(
        Transition(cfg.msa_transition, global_cfg, name='msa_transition'),
        msa, msa_mask)

    # --- MSA -> pair communication ---
    pair = apply_layer(
        OuterProductMean(
            config=cfg.outer_product_mean,
            global_config=self.global_config,
            num_output_channel=int(pair.shape[-1]),
            name='outer_product_mean'),
        msa, msa_mask, output_act=pair)

    # --- Pair track ---
    pair = apply_layer(
        TriangleMultiplication(
            cfg.triangle_multiplication_outgoing, global_cfg,
            name='triangle_multiplication_outgoing'),
        pair, pair_mask)
    pair = apply_layer(
        TriangleMultiplication(
            cfg.triangle_multiplication_incoming, global_cfg,
            name='triangle_multiplication_incoming'),
        pair, pair_mask)
    pair = apply_layer(
        TriangleAttention(
            cfg.triangle_attention_starting_node, global_cfg,
            name='triangle_attention_starting_node'),
        pair, pair_mask)
    pair = apply_layer(
        TriangleAttention(
            cfg.triangle_attention_ending_node, global_cfg,
            name='triangle_attention_ending_node'),
        pair, pair_mask)
    pair = apply_layer(
        Transition(cfg.pair_transition, global_cfg, name='pair_transition'),
        pair, pair_mask)

    return {'msa': msa, 'pair': pair}
class EmbeddingsAndEvoformer(hk.Module):
  """Embeds the input data and runs Evoformer.

  Produces the MSA, single and pair representations.
  Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 5-18
  """

  def __init__(self, config, global_config, name='evoformer'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, batch, is_training, safe_key=None):
    """Embeds the batch and runs the extra-MSA and main Evoformer stacks.

    Arguments:
      batch: Dictionary of input features. Keys read directly in this method:
        'target_feat', 'msa_feat', 'seq_mask', 'msa_mask', 'extra_msa_mask'
        and 'scale_rate'; the optional recycling inputs 'prev_pos' (together
        with 'aatype') or 'prev_dgram', 'prev_msa_first_row' and 'prev_pair';
        one of 'rel_pos', 'offset' or 'residue_index' when relative position
        features are enabled; and the 'template_*' features when templates
        are enabled. `create_extra_msa_feature` reads further keys.
      is_training: Whether the module is in training mode.
      safe_key: prng.SafeKey encapsulating rng key. When None, a fresh key is
        drawn with `hk.next_rng_key()`.

    Returns:
      Dictionary with keys 'single', 'pair', 'msa' and 'msa_first_row'.
    """

    c = self.config
    gc = self.global_config

    if safe_key is None:
      safe_key = prng.SafeKey(hk.next_rng_key())

    # Embed clustered MSA.
    # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 5
    # Jumper et al. (2021) Suppl. Alg. 3 "InputEmbedder"
    preprocess_1d = common_modules.Linear(
        c.msa_channel, name='preprocess_1d')(
            batch['target_feat'])

    preprocess_msa = common_modules.Linear(
        c.msa_channel, name='preprocess_msa')(
            batch['msa_feat'])

    # The target embedding is broadcast over the leading (sequence) axis.
    msa_activations = jnp.expand_dims(preprocess_1d, axis=0) + preprocess_msa

    left_single = common_modules.Linear(
        c.pair_channel, name='left_single')(
            batch['target_feat'])
    right_single = common_modules.Linear(
        c.pair_channel, name='right_single')(
            batch['target_feat'])
    # Outer sum of the two target embeddings gives the initial pair features.
    pair_activations = left_single[:, None] + right_single[None]
    mask_2d = batch['seq_mask'][:, None] * batch['seq_mask'][None, :]

    # Inject previous outputs for recycling.
    # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 6
    # Jumper et al. (2021) Suppl. Alg. 32 "RecyclingEmbedder"
    # NOTE(review): `dgram` is only bound when 'prev_pos' or 'prev_dgram' is
    # in the batch; with neither key the `prev_pos_linear` line below raises
    # a NameError. Confirm callers always supply one of them.
    if "prev_pos" in batch:
      # use predicted position input
      prev_pseudo_beta = pseudo_beta_fn(batch['aatype'], batch['prev_pos'], None)
      if c.backprop_dgram:
        dgram = dgram_from_positions_soft(prev_pseudo_beta, temp=c.backprop_dgram_temp, **c.prev_pos)
      else:
        dgram = dgram_from_positions(prev_pseudo_beta, **c.prev_pos)

    elif 'prev_dgram' in batch:
      # use predicted distogram input (from Sergey)
      dgram = jax.nn.softmax(batch["prev_dgram"])
      # [64, 15] mapping matrix: source bins are grouped in runs of 4 onto the
      # 15 target bins (the first 8 source bins all point at target bin 0),
      # and target column 0 is then zeroed.
      dgram_map = jax.nn.one_hot(jnp.repeat(jnp.append(0,jnp.arange(15)),4),15).at[:,0].set(0)
      dgram = dgram @ dgram_map

    pair_activations += common_modules.Linear(c.pair_channel, name='prev_pos_linear')(dgram)

    if c.recycle_features:
      if 'prev_msa_first_row' in batch:
        prev_msa_first_row = hk.LayerNorm([-1],
                                          True,
                                          True,
                                          name='prev_msa_first_row_norm')(
                                              batch['prev_msa_first_row'])
        # Only the first MSA row receives the recycled update.
        msa_activations = msa_activations.at[0].add(prev_msa_first_row)

      if 'prev_pair' in batch:
        pair_activations += hk.LayerNorm([-1],
                                         True,
                                         True,
                                         name='prev_pair_norm')(
                                             batch['prev_pair'])

    # Relative position encoding.
    # Jumper et al. (2021) Suppl. Alg. 4 "relpos"
    # Jumper et al. (2021) Suppl. Alg. 5 "one_hot"
    if c.max_relative_feature:
      # Add one-hot-encoded clipped residue distances to the pair activations.
      # Precedence: a precomputed 'rel_pos', then a precomputed 'offset', then
      # offsets derived from 'residue_index'.
      if "rel_pos" in batch:
        rel_pos = batch['rel_pos']
      else:
        if "offset" in batch:
          offset = batch['offset']
        else:
          pos = batch['residue_index']
          offset = pos[:, None] - pos[None, :]
        rel_pos = jax.nn.one_hot(
            jnp.clip(
                offset + c.max_relative_feature,
                a_min=0,
                a_max=2 * c.max_relative_feature),
            2 * c.max_relative_feature + 1)
      # NOTE: the misspelled module name 'pair_activiations' is a runtime
      # string that names the parameters; it must not be "corrected".
      pair_activations += common_modules.Linear(c.pair_channel, name='pair_activiations')(rel_pos)

    # Embed templates into the pair activations.
    # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9-13
    if c.template.enabled:
      template_batch = {k: batch[k] for k in batch if k.startswith('template_')}
      template_pair_representation = TemplateEmbedding(c.template, gc)(
          pair_activations,
          template_batch,
          mask_2d,
          is_training=is_training,
          scale_rate=batch["scale_rate"])

      pair_activations += template_pair_representation

    # Embed extra MSA features.
    # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 14-16
    extra_msa_feat = create_extra_msa_feature(batch)
    extra_msa_activations = common_modules.Linear(
        c.extra_msa_channel,
        name='extra_msa_activations')(
            extra_msa_feat)

    # Extra MSA Stack.
    # Jumper et al. (2021) Suppl. Alg. 18 "ExtraMsaStack"
    extra_msa_stack_input = {
        'msa': extra_msa_activations,
        'pair': pair_activations,
    }

    extra_msa_stack_iteration = EvoformerIteration(
        c.evoformer, gc, is_extra_msa=True, name='extra_msa_stack')

    def extra_msa_stack_fn(x):
      # Carries (activations, key) through the stack; a fresh sub-key is split
      # off for each block.
      act, safe_key = x
      safe_key, safe_subkey = safe_key.split()
      extra_evoformer_output = extra_msa_stack_iteration(
          activations=act,
          masks={
              'msa': batch['extra_msa_mask'],
              'pair': mask_2d
          },
          is_training=is_training,
          safe_key=safe_subkey, scale_rate=batch["scale_rate"])
      return (extra_evoformer_output, safe_key)

    if gc.use_remat:
      extra_msa_stack_fn = hk.remat(extra_msa_stack_fn)

    extra_msa_stack = layer_stack.layer_stack(
        c.extra_msa_stack_num_block)(
            extra_msa_stack_fn)
    extra_msa_output, safe_key = extra_msa_stack(
        (extra_msa_stack_input, safe_key))

    # Only the pair output of the extra-MSA stack is carried forward.
    pair_activations = extra_msa_output['pair']

    evoformer_input = {
        'msa': msa_activations,
        'pair': pair_activations,
    }

    evoformer_masks = {'msa': batch['msa_mask'], 'pair': mask_2d}

    ####################################################################
    ####################################################################

    # Append num_templ rows to msa_activations with template embeddings.
    # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 7-8
    if c.template.enabled and c.template.embed_torsion_angles:
      # 'template_aatype' may be integer class indices or an already
      # one-hot/soft array with a trailing class axis.
      if jnp.issubdtype(batch['template_aatype'].dtype, jnp.integer):
        num_templ, num_res = batch['template_aatype'].shape
        # Embed the templates aatypes.
        aatype = batch['template_aatype']
        aatype_one_hot = jax.nn.one_hot(batch['template_aatype'], 22, axis=-1)
      else:
        num_templ, num_res, _ = batch['template_aatype'].shape
        aatype = batch['template_aatype'].argmax(-1)
        aatype_one_hot = batch['template_aatype']

      # Embed the templates aatype, torsion angles and masks.
      # Shape (templates, residues, msa_channels)
      ret = all_atom.atom37_to_torsion_angles(
          aatype=aatype,
          all_atom_pos=batch['template_all_atom_positions'],
          all_atom_mask=batch['template_all_atom_masks'],
          # Ensure consistent behaviour during testing:
          placeholder_for_undefined=not gc.zero_init)

      template_features = jnp.concatenate([
          aatype_one_hot,
          jnp.reshape(ret['torsion_angles_sin_cos'], [num_templ, num_res, 14]),
          jnp.reshape(ret['alt_torsion_angles_sin_cos'], [num_templ, num_res, 14]),
          ret['torsion_angles_mask']], axis=-1)

      template_activations = common_modules.Linear(
          c.msa_channel,
          initializer='relu',
          name='template_single_embedding')(template_features)
      template_activations = jax.nn.relu(template_activations)
      template_activations = common_modules.Linear(
          c.msa_channel,
          initializer='relu',
          name='template_projection')(template_activations)

      # Concatenate the templates to the msa.
      evoformer_input['msa'] = jnp.concatenate([evoformer_input['msa'], template_activations], axis=0)

      # Concatenate templates masks to the msa masks.
      # Use mask from the psi angle, as it only depends on the backbone atoms
      # from a single residue.
      torsion_angle_mask = ret['torsion_angles_mask'][:, :, 2]
      torsion_angle_mask = torsion_angle_mask.astype(evoformer_masks['msa'].dtype)
      evoformer_masks['msa'] = jnp.concatenate([evoformer_masks['msa'], torsion_angle_mask], axis=0)

    ####################################################################
    ####################################################################

    # Main trunk of the network
    # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 17-18
    evoformer_iteration = EvoformerIteration(
        c.evoformer, gc, is_extra_msa=False, name='evoformer_iteration')

    def evoformer_fn(x):
      # Same (activations, key) carry pattern as the extra-MSA stack.
      act, safe_key = x
      safe_key, safe_subkey = safe_key.split()
      evoformer_output = evoformer_iteration(
          activations=act,
          masks=evoformer_masks,
          is_training=is_training,
          safe_key=safe_subkey, scale_rate=batch["scale_rate"])
      return (evoformer_output, safe_key)

    if gc.use_remat:
      evoformer_fn = hk.remat(evoformer_fn)

    evoformer_stack = layer_stack.layer_stack(c.evoformer_num_block)(evoformer_fn)
    evoformer_output, safe_key = evoformer_stack((evoformer_input, safe_key))

    msa_activations = evoformer_output['msa']
    pair_activations = evoformer_output['pair']

    # The single representation is a linear projection of the first MSA row.
    single_activations = common_modules.Linear(
        c.seq_channel, name='single_activations')(msa_activations[0])

    num_sequences = batch['msa_feat'].shape[0]
    output = {
        'single': single_activations,
        'pair': pair_activations,
        # Crop away template rows such that they are not used in MaskedMsaHead.
        'msa': msa_activations[:num_sequences, :, :],
        'msa_first_row': msa_activations[0],
    }

    return output
####################################################################
####################################################################
class SingleTemplateEmbedding(hk.Module):
  """Embeds a single template.

  Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9+11
  """

  def __init__(self, config, global_config, name='single_template_embedding'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, query_embedding, batch, mask_2d, is_training, scale_rate=1.0):
    """Build the single template embedding.

    Arguments:
      query_embedding: Query pair representation, shape [N_res, N_res, c_z].
      batch: A batch of template features (note the template dimension has been
        stripped out as this module only runs over a single template).
      mask_2d: Padding mask (Note: this doesn't care if a template exists,
        unlike the template_pseudo_beta_mask).
      is_training: Whether the module is in training mode.
      scale_rate: Forwarded unchanged to `TemplatePairStack`.

    Returns:
      A template embedding [N_res, N_res, c_z].
    """
    assert mask_2d.dtype == query_embedding.dtype
    cfg = self.config
    dtype = query_embedding.dtype
    num_res = batch['template_aatype'].shape[0]
    num_channels = (
        cfg.template_pair_stack.triangle_attention_ending_node.value_dim)

    # Pair mask derived from the pseudo-beta mask.
    pb_mask = batch['template_pseudo_beta_mask']
    pb_mask_2d = (pb_mask[:, None] * pb_mask[None, :]).astype(dtype)

    # Distogram: supplied directly, or computed (soft or hard) from positions.
    if "template_dgram" in batch:
      template_dgram = batch["template_dgram"]
    elif cfg.backprop_dgram:
      template_dgram = dgram_from_positions_soft(
          batch['template_pseudo_beta'],
          temp=cfg.backprop_dgram_temp,
          **cfg.dgram_features)
    else:
      template_dgram = dgram_from_positions(
          batch['template_pseudo_beta'], **cfg.dgram_features)
    template_dgram = template_dgram.astype(dtype)

    features = [template_dgram, pb_mask_2d[:, :, None]]

    # Residue types: one-hot encode integer input, pass anything else through.
    if jnp.issubdtype(batch['template_aatype'].dtype, jnp.integer):
      aatype = jax.nn.one_hot(batch['template_aatype'], 22, axis=-1, dtype=dtype)
    else:
      aatype = batch['template_aatype']

    features.append(jnp.tile(aatype[None, :, :], [num_res, 1, 1]))
    features.append(jnp.tile(aatype[:, None, :], [1, num_res, 1]))

    # Backbone affine mask: whether the residue has C, CA, N
    # (the pseudo-beta mask above only considers pseudo CB).
    n, ca, c = [residue_constants.atom_order[a] for a in ('N', 'CA', 'C')]
    atom_masks = batch['template_all_atom_masks']
    bb_mask = atom_masks[..., n] * atom_masks[..., ca] * atom_masks[..., c]
    bb_mask_2d = bb_mask[:, None] * bb_mask[None, :]

    # compute unit_vector (not used by default)
    if cfg.use_template_unit_vector:
      atom_pos = batch['template_all_atom_positions']
      rot, trans = quat_affine.make_transform_from_reference(
          n_xyz=atom_pos[:, n],
          ca_xyz=atom_pos[:, ca],
          c_xyz=atom_pos[:, c])
      affines = quat_affine.QuatAffine(
          quaternion=quat_affine.rot_to_quat(rot, unstack_inputs=True),
          translation=trans,
          rotation=rot,
          unstack_inputs=True)
      points = [jnp.expand_dims(x, axis=-2) for x in affines.translation]
      affine_vec = affines.invert_point(points, extra_dims=1)
      # Inverse length of each local-frame displacement, zeroed where the
      # backbone mask is zero.
      inv_norm = jax.lax.rsqrt(1e-6 + sum([jnp.square(x) for x in affine_vec]))
      inv_norm *= bb_mask_2d.astype(inv_norm.dtype)
      unit_vector = [(x * inv_norm)[..., None] for x in affine_vec]
    else:
      unit_vector = [jnp.zeros((num_res,num_res,1))] * 3

    features.extend([x.astype(dtype) for x in unit_vector])

    bb_mask_2d = bb_mask_2d.astype(dtype)
    features.append(bb_mask_2d[..., None])

    act = jnp.concatenate(features, axis=-1)

    # Mask out non-template regions so we don't get arbitrary values in the
    # distogram for these regions.
    act *= bb_mask_2d[..., None]

    # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 9
    act = common_modules.Linear(
        num_channels,
        initializer='relu',
        name='embedding2d')(act)

    # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 11
    pair_stack = TemplatePairStack(cfg.template_pair_stack, self.global_config)
    act = pair_stack(act, mask_2d, is_training, scale_rate=scale_rate)

    return hk.LayerNorm([-1], True, True, name='output_layer_norm')(act)
class TemplateEmbedding(hk.Module):
  """Embeds a set of templates.

  Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9-12
  Jumper et al. (2021) Suppl. Alg. 17 "TemplatePointwiseAttention"
  """

  def __init__(self, config, global_config, name='template_embedding'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, query_embedding, template_batch, mask_2d, is_training, scale_rate=1.0):
    """Build TemplateEmbedding module.

    Arguments:
      query_embedding: Query pair representation, shape [N_res, N_res, c_z].
      template_batch: A batch of template features.
      mask_2d: Padding mask (Note: this doesn't care if a template exists,
        unlike the template_pseudo_beta_mask).
      is_training: Whether the module is in training mode.
      scale_rate: Forwarded unchanged to `SingleTemplateEmbedding`.

    Returns:
      A template embedding [N_res, N_res, c_z].
    """
    cfg = self.config
    global_cfg = self.global_config

    num_templates = template_batch['template_mask'].shape[0]
    num_channels = (
        cfg.template_pair_stack.triangle_attention_ending_node.value_dim)
    num_res = query_embedding.shape[0]
    query_channels = query_embedding.shape[-1]
    dtype = query_embedding.dtype

    template_mask = template_batch['template_mask'].astype(dtype)

    # Make sure the weights are shared across templates by constructing the
    # embedder here.
    # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9-12
    single_embedder = SingleTemplateEmbedding(cfg, global_cfg)

    def embed_one(single_template):
      return single_embedder(
          query_embedding, single_template, mask_2d, is_training,
          scale_rate=scale_rate)

    per_template = mapping.sharded_map(embed_one, in_axes=0)(template_batch)

    # Cross attend from the query to the templates along the residue
    # dimension by flattening everything else into the batch dimension.
    # Jumper et al. (2021) Suppl. Alg. 17 "TemplatePointwiseAttention"
    num_pairs = num_res * num_res
    flat_query = jnp.reshape(query_embedding, [num_pairs, 1, query_channels])
    # Move the template axis behind the two residue axes before flattening.
    templates_last = jnp.transpose(per_template, [1, 2, 0, 3])
    flat_templates = jnp.reshape(
        templates_last, [num_pairs, num_templates, num_channels])

    # Large negative bias wherever the template mask is zero.
    bias = (1e9 * (template_mask[None, None, None, :] - 1.))

    pointwise_attention = Attention(cfg.attention, global_cfg, query_channels)

    embedding = mapping.inference_subbatch(
        pointwise_attention,
        cfg.subbatch_size,
        batched_args=[flat_query, flat_templates],
        nonbatched_args=[bias],
        low_memory=not is_training)
    embedding = jnp.reshape(embedding, [num_res, num_res, query_channels])

    # No gradients if no templates.
    has_templates = (jnp.sum(template_mask) > 0.).astype(embedding.dtype)
    embedding *= has_templates

    return embedding
####################################################################