# Web-viewer residue captured with this file (not module code):
# Simon Duerr | add fast af | 85bd48b | raw | history blame | 21.4 kB
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data for AlphaFold."""
from alphafold.common import residue_constants
from alphafold.model.tf import shape_helpers
from alphafold.model.tf import shape_placeholders
from alphafold.model.tf import utils
import numpy as np
import tensorflow.compat.v1 as tf
# Pylint gets confused by the curry1 decorator because it changes the number
# of arguments to the function.
# pylint:disable=no-value-for-parameter
# Local aliases for the dimension placeholders defined in shape_placeholders.
# In this module they key the target-size map in make_fixed_size, and NUM_RES
# is compared against shape-schema entries in random_crop_to_size.
NUM_RES = shape_placeholders.NUM_RES
NUM_MSA_SEQ = shape_placeholders.NUM_MSA_SEQ
NUM_EXTRA_SEQ = shape_placeholders.NUM_EXTRA_SEQ
NUM_TEMPLATES = shape_placeholders.NUM_TEMPLATES
def cast_64bit_ints(protein):
  """Downcast every int64 feature in `protein` to int32, in place.

  Args:
    protein: dict mapping feature names to values exposing a `dtype`.

  Returns:
    The same dict, with each int64 entry replaced by its int32 cast.
  """
  for name in protein:
    feature = protein[name]
    if feature.dtype == tf.int64:
      protein[name] = tf.cast(feature, tf.int32)
  return protein
# Names of the per-sequence MSA features. The MSA sampling, cropping and
# deletion helpers in this module index these (and their `extra_*`
# counterparts) along the first axis together, so rows stay aligned.
_MSA_FEATURE_NAMES = [
    'msa', 'deletion_matrix', 'msa_mask', 'msa_row_mask', 'bert_mask',
    'true_msa'
]
def make_seq_mask(protein):
  """Add an all-ones float32 `seq_mask` shaped like `aatype`.

  Args:
    protein: feature dict containing `aatype`.

  Returns:
    The same dict with `seq_mask` set.
  """
  aatype_shape = shape_helpers.shape_list(protein['aatype'])
  protein['seq_mask'] = tf.ones(aatype_shape, dtype=tf.float32)
  return protein
def make_template_mask(protein):
  """Add an all-ones float32 `template_mask` shaped like the domain names.

  Args:
    protein: feature dict containing `template_domain_names`.

  Returns:
    The same dict with `template_mask` set.
  """
  names_shape = shape_helpers.shape_list(protein['template_domain_names'])
  protein['template_mask'] = tf.ones(names_shape, dtype=tf.float32)
  return protein
def curry1(f):
  """Supply all arguments but the first.

  Turns `f(x, *args, **kwargs)` into a factory: calling the returned function
  with `(*args, **kwargs)` yields a one-argument callable `x -> f(x, ...)`.
  The factory keeps the name and docstring of `f`, so decorated transforms
  remain identifiable in logs, tracebacks and `help()`.

  Args:
    f: callable whose first positional argument is supplied last.

  Returns:
    A function accepting every argument of `f` except the first and returning
    a single-argument callable.
  """
  import functools  # Local import keeps this block self-contained.

  @functools.wraps(f)
  def fc(*args, **kwargs):
    return lambda x: f(x, *args, **kwargs)

  return fc
@curry1
def add_distillation_flag(protein, distillation):
  """Record `distillation` as a scalar float32 `is_distillation` feature.

  Args:
    protein: feature dict to update.
    distillation: value convertible with `float()`.

  Returns:
    The same dict with `is_distillation` set.
  """
  flag_value = float(distillation)
  protein['is_distillation'] = tf.constant(
      flag_value, shape=[], dtype=tf.float32)
  return protein
def make_all_atom_aatype(protein):
  """Expose `aatype` under the additional key `all_atom_aatype`."""
  aatype = protein['aatype']
  protein['all_atom_aatype'] = aatype
  return protein
def fix_templates_aatype(protein):
  """Convert one-hot template residue types to indices in our ordering.

  Args:
    protein: feature dict whose `template_aatype` is one-hot on its last axis.

  Returns:
    The same dict with `template_aatype` replaced by int32 indices remapped
    through MAP_HHBLITS_AATYPE_TO_OUR_AATYPE.
  """
  # Collapse the trailing one-hot axis into integer ids.
  hhsearch_ids = tf.argmax(
      protein['template_aatype'], axis=-1, output_type=tf.int32)
  # Lookup table: position = hhsearch id, value = id in our aatype ordering.
  lookup = tf.constant(
      residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE, dtype=tf.int32)
  protein['template_aatype'] = tf.gather(params=lookup, indices=hhsearch_ids)
  return protein
def correct_msa_restypes(protein):
  """Reorder MSA residue ids and profile columns to residue_constants' order.

  Args:
    protein: feature dict containing `msa` and, optionally, any features
      whose name contains 'profile'.

  Returns:
    The same dict with `msa` remapped and profile features permuted.
  """
  hhblits_to_ours = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
  lookup = tf.constant(hhblits_to_ours, dtype=protein['msa'].dtype)
  protein['msa'] = tf.gather(lookup, protein['msa'], axis=0)

  # Permutation matrix with a one at (hhblits id, our id) for every mapping.
  permutation = np.zeros((22, 22), dtype=np.float32)
  for source_id, target_id in enumerate(hhblits_to_ours):
    permutation[source_id, target_id] = 1.

  # Include both hhblits and psiblast profiles.
  for name in protein:
    if 'profile' not in name:
      continue
    width = protein[name].shape.as_list()[-1]
    assert width in [20, 21, 22], (
        'num_dim for %s out of expected range: %s' % (name, width))
    protein[name] = tf.tensordot(
        protein[name], permutation[:width, :width], 1)
  return protein
def squeeze_features(protein):
  """Remove singleton and repeated dimensions in protein features.

  Args:
    protein: feature dict with a one-hot `aatype` and optional other features.

  Returns:
    The same dict with `aatype` as int32 indices, trailing size-1 axes of the
    listed features squeezed, and `seq_length` / `num_alignments` reduced to
    their first element.
  """
  protein['aatype'] = tf.argmax(
      protein['aatype'], output_type=tf.int32, axis=-1)

  squeezable = (
      'domain_name', 'msa', 'num_alignments', 'seq_length', 'sequence',
      'superfamily', 'deletion_matrix', 'resolution',
      'between_segment_residues', 'residue_index', 'template_all_atom_masks')
  for name in squeezable:
    if name not in protein:
      continue
    last_dim = shape_helpers.shape_list(protein[name])[-1]
    # Only a trailing size reported as the Python int 1 is dropped.
    if isinstance(last_dim, int) and last_dim == 1:
      protein[name] = tf.squeeze(protein[name], axis=-1)

  for name in ('seq_length', 'num_alignments'):
    if name in protein:
      protein[name] = protein[name][0]  # Remove fake sequence dimension
  return protein
def make_random_crop_to_size_seed(protein):
  """Store a seed, later used for cropping residues and templates."""
  crop_seed = utils.make_random_seed()
  protein['random_crop_to_size_seed'] = crop_seed
  return protein
@curry1
def randomly_replace_msa_with_unknown(protein, replace_proportion):
  """Replace a proportion of the MSA with 'X'.

  Each non-gap MSA entry, and independently each `aatype` entry, is replaced
  by the unknown-residue id when a uniform random draw falls below
  `replace_proportion`.

  Args:
    protein: feature dict containing integer `msa` and `aatype` tensors.
    replace_proportion: replacement probability per position.

  Returns:
    The same dict with `msa` and `aatype` updated.
  """
  x_idx = 20
  gap_idx = 21

  msa_mask = (tf.random.uniform(shape_helpers.shape_list(protein['msa'])) <
              replace_proportion)
  # Gaps must never be turned into 'X'. tf.not_equal is used rather than the
  # Python `!=` operator: `!=` is only element-wise when tensor equality is
  # active, and otherwise compares object identity (always True here), which
  # would silently drop the gap exclusion.
  msa_mask = tf.logical_and(msa_mask, tf.not_equal(protein['msa'], gap_idx))
  protein['msa'] = tf.where(msa_mask,
                            tf.ones_like(protein['msa']) * x_idx,
                            protein['msa'])

  aatype_mask = (
      tf.random.uniform(shape_helpers.shape_list(protein['aatype'])) <
      replace_proportion)
  protein['aatype'] = tf.where(aatype_mask,
                               tf.ones_like(protein['aatype']) * x_idx,
                               protein['aatype'])
  return protein
@curry1
def sample_msa(protein, max_seq, keep_extra):
  """Sample MSA randomly, remaining sequences are stored as `extra_*`.

  Args:
    protein: batch to sample msa from.
    max_seq: number of sequences to sample.
    keep_extra: When True sequences not sampled are put into fields starting
      with 'extra_*'.

  Returns:
    Protein with sampled msa.
  """
  total = tf.shape(protein['msa'])[0]
  # Row 0 always stays first; only the remaining rows are shuffled.
  row_order = tf.concat(
      [[0], tf.random_shuffle(tf.range(1, total))], axis=0)
  num_kept = tf.minimum(max_seq, total)
  kept_rows, leftover_rows = tf.split(
      row_order, [num_kept, total - num_kept])

  for name in _MSA_FEATURE_NAMES:
    if name not in protein:
      continue
    if keep_extra:
      protein['extra_' + name] = tf.gather(protein[name], leftover_rows)
    protein[name] = tf.gather(protein[name], kept_rows)
  return protein
@curry1
def crop_extra_msa(protein, max_extra_msa):
  """MSA features are cropped so only `max_extra_msa` sequences are kept.

  Args:
    protein: feature dict containing `extra_msa`.
    max_extra_msa: upper bound on the number of extra sequences to keep.

  Returns:
    The same dict with every present `extra_*` MSA feature gathered at the
    same randomly chosen rows.
  """
  total = tf.shape(protein['extra_msa'])[0]
  num_kept = tf.minimum(max_extra_msa, total)
  chosen_rows = tf.random_shuffle(tf.range(0, total))[:num_kept]

  for name in _MSA_FEATURE_NAMES:
    extra_name = 'extra_' + name
    if extra_name in protein:
      protein[extra_name] = tf.gather(protein[extra_name], chosen_rows)
  return protein
def delete_extra_msa(protein):
  """Drop every `extra_*` MSA feature from `protein`, in place.

  Args:
    protein: feature dict; missing `extra_*` keys are ignored.

  Returns:
    The same dict without the `extra_*` MSA features.
  """
  for name in _MSA_FEATURE_NAMES:
    protein.pop('extra_' + name, None)
  return protein
@curry1
def block_delete_msa(protein, config):
  """Sample MSA by deleting contiguous blocks.

  Jumper et al. (2021) Suppl. Alg. 1 "MSABlockDeletion"

  Arguments:
    protein: batch dict containing the msa
    config: ConfigDict with parameters; this function reads
      `msa_fraction_per_block`, `randomize_num_blocks` and `num_blocks`.

  Returns:
    updated protein
  """
  num_seq = shape_helpers.shape_list(protein['msa'])[0]
  # Rows per deleted block: floor(num_seq * msa_fraction_per_block).
  block_num_seq = tf.cast(
      tf.floor(tf.cast(num_seq, tf.float32) * config.msa_fraction_per_block),
      tf.int32)

  if config.randomize_num_blocks:
    # Random block count, drawn with bounds 0 and num_blocks + 1.
    nb = tf.random.uniform([], 0, config.num_blocks + 1, dtype=tf.int32)
  else:
    nb = config.num_blocks

  # One random start row per block; each block spans block_num_seq consecutive
  # rows, with indices past the end clipped to the last row.
  del_block_starts = tf.random.uniform([nb], 0, num_seq, dtype=tf.int32)
  del_blocks = del_block_starts[:, None] + tf.range(block_num_seq)
  del_blocks = tf.clip_by_value(del_blocks, 0, num_seq - 1)
  # Flatten, sort and de-duplicate the rows marked for deletion.
  del_indices = tf.unique(tf.sort(tf.reshape(del_blocks, [-1])))[0]

  # Make sure we keep the original sequence: row 0 is left out of the
  # candidate set below and prepended unconditionally afterwards.
  sparse_diff = tf.sets.difference(tf.range(1, num_seq)[None],
                                   del_indices[None])
  keep_indices = tf.squeeze(tf.sparse.to_dense(sparse_diff), 0)
  keep_indices = tf.concat([[0], keep_indices], axis=0)

  # Apply the same row selection to every MSA feature that is present.
  for k in _MSA_FEATURE_NAMES:
    if k in protein:
      protein[k] = tf.gather(protein[k], keep_indices)

  return protein
@curry1
def nearest_neighbor_clusters(protein, gap_agreement_weight=0.):
  """Assign each extra MSA sequence to its nearest neighbor in sampled MSA.

  Args:
    protein: feature dict containing `msa`, `msa_mask`, `extra_msa` and
      `extra_msa_mask`.
    gap_agreement_weight: weight given to agreement on class 21 (gap).

  Returns:
    The same dict with int32 `extra_cluster_assignment` set.
  """
  num_classes = 23
  # Per-class agreement weights. In theory a full blosum matrix could be used,
  # but for now gap agreement is just down-weighted because it could be
  # spurious, and agreeing on the BERT mask class never counts.
  class_weights = tf.concat([
      tf.ones(21),
      gap_agreement_weight * tf.ones(1),
      np.zeros(1)], 0)

  # Masked one-hot encodings; agreement is a weighted Hamming-style score.
  centers = (protein['msa_mask'][:, :, None] *
             tf.one_hot(protein['msa'], num_classes))
  extras = (protein['extra_msa_mask'][:, :, None] *
            tf.one_hot(protein['extra_msa'], num_classes))

  num_centers, num_res, _ = shape_helpers.shape_list(centers)
  num_extras = shape_helpers.shape_list(extras)[0]

  # Equivalent to tf.einsum('mrc,nrc,c->mn', centers, extras, class_weights),
  # written as one matmul to avoid possible memory or computation blowup.
  extras_flat = tf.reshape(extras, [num_extras, num_res * num_classes])
  weighted_centers_flat = tf.reshape(
      centers * class_weights, [num_centers, num_res * num_classes])
  agreement = tf.matmul(
      extras_flat, weighted_centers_flat, transpose_b=True)

  # Each extra sequence goes to the sampled sequence it agrees with most.
  protein['extra_cluster_assignment'] = tf.argmax(
      agreement, axis=1, output_type=tf.int32)
  return protein
@curry1
def summarize_clusters(protein):
  """Produce profile and deletion_matrix_mean within each cluster.

  Args:
    protein: feature dict containing `msa`, `msa_mask`, `deletion_matrix`,
      `extra_msa`, `extra_msa_mask`, `extra_deletion_matrix` and
      `extra_cluster_assignment`.

  Returns:
    The same dict with `cluster_profile` and `cluster_deletion_mean` set.
  """
  num_clusters = shape_helpers.shape_list(protein['msa'])[0]
  assignment = protein['extra_cluster_assignment']

  def sum_per_cluster(values):
    return tf.math.unsorted_segment_sum(values, assignment, num_clusters)

  extra_mask = protein['extra_msa_mask']
  # Members per position, including the cluster center; the 1e-6 keeps the
  # divisions below away from zero.
  member_counts = 1e-6 + protein['msa_mask'] + sum_per_cluster(extra_mask)

  profile_sum = sum_per_cluster(
      extra_mask[:, :, None] * tf.one_hot(protein['extra_msa'], 23))
  profile_sum += tf.one_hot(protein['msa'], 23)  # Original sequence
  protein['cluster_profile'] = profile_sum / member_counts[:, :, None]
  del profile_sum

  deletion_sum = sum_per_cluster(
      extra_mask * protein['extra_deletion_matrix'])
  deletion_sum += protein['deletion_matrix']  # Original sequence
  protein['cluster_deletion_mean'] = deletion_sum / member_counts
  del deletion_sum

  return protein
def make_msa_mask(protein):
  """Mask features are all ones, but will later be zero-padded.

  Args:
    protein: feature dict containing `msa`.

  Returns:
    The same dict with float32 `msa_mask` (shaped like `msa`) and
    `msa_row_mask` (sized by the first dimension of `msa`).
  """
  msa = protein['msa']
  full_shape = shape_helpers.shape_list(msa)
  protein['msa_mask'] = tf.ones(full_shape, dtype=tf.float32)
  num_rows = shape_helpers.shape_list(msa)[0]
  protein['msa_row_mask'] = tf.ones(num_rows, dtype=tf.float32)
  return protein
def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
  """Create pseudo beta features.

  Selects the CA atom for glycine residues and the CB atom otherwise.

  Args:
    aatype: residue type indices.
    all_atom_positions: per-atom coordinates, indexed [..., atom, xyz].
    all_atom_masks: per-atom mask indexed [..., atom], or None.

  Returns:
    The pseudo-beta positions, or a `(positions, float32 mask)` tuple when
    `all_atom_masks` is given.
  """
  glycine = tf.equal(aatype, residue_constants.restype_order['G'])
  ca = residue_constants.atom_order['CA']
  cb = residue_constants.atom_order['CB']

  # Repeat the glycine flag three times along a new trailing axis so it lines
  # up with the coordinate triples.
  tile_multiples = [1] * len(glycine.shape) + [3]
  glycine_xyz = tf.tile(glycine[..., None], tile_multiples)
  pseudo_beta = tf.where(
      glycine_xyz,
      all_atom_positions[..., ca, :],
      all_atom_positions[..., cb, :])

  if all_atom_masks is None:
    return pseudo_beta

  pseudo_beta_mask = tf.where(
      glycine, all_atom_masks[..., ca], all_atom_masks[..., cb])
  return pseudo_beta, tf.cast(pseudo_beta_mask, tf.float32)
@curry1
def make_pseudo_beta(protein, prefix=''):
  """Create pseudo-beta (alpha for glycine) position and mask.

  Args:
    protein: feature dict holding the atom positions, masks and residue types.
    prefix: '' for the target itself or 'template_' for its templates.

  Returns:
    The same dict with `<prefix>pseudo_beta` and `<prefix>pseudo_beta_mask`.
  """
  assert prefix in ['', 'template_']
  if prefix:
    aatype_key, mask_key = 'template_aatype', 'template_all_atom_masks'
  else:
    aatype_key, mask_key = 'all_atom_aatype', 'all_atom_mask'

  positions, mask = pseudo_beta_fn(
      protein[aatype_key],
      protein[prefix + 'all_atom_positions'],
      protein[mask_key])
  protein[prefix + 'pseudo_beta'] = positions
  protein[prefix + 'pseudo_beta_mask'] = mask
  return protein
@curry1
def add_constant_field(protein, key, value):
  """Store `value`, converted to a tensor, under `key` in `protein`."""
  as_tensor = tf.convert_to_tensor(value)
  protein[key] = as_tensor
  return protein
def shaped_categorical(probs, epsilon=1e-10):
  """Sample one class index per distribution on the last axis of `probs`.

  Args:
    probs: class probabilities; the last axis indexes classes.
    epsilon: added to `probs` before the logarithm is taken.

  Returns:
    int32 tensor of sampled indices, shaped like `probs` minus its last axis.
  """
  full_shape = shape_helpers.shape_list(probs)
  batch_shape, num_classes = full_shape[:-1], full_shape[-1]
  logits = tf.reshape(tf.log(probs + epsilon), [-1, num_classes])
  samples = tf.random.categorical(logits, 1, dtype=tf.int32)
  return tf.reshape(samples, batch_shape)
def make_hhblits_profile(protein):
  """Compute the HHblits MSA profile if not already present.

  Args:
    protein: feature dict containing `msa`.

  Returns:
    The same dict; `hhblits_profile` is added only when it was missing.
  """
  if 'hhblits_profile' not in protein:
    # Profile for every residue: mean of the 22-class one-hot MSA over all
    # MSA sequences (axis 0).
    msa_one_hot = tf.one_hot(protein['msa'], 22)
    protein['hhblits_profile'] = tf.reduce_mean(msa_one_hot, axis=0)
  return protein
@curry1
def make_masked_msa(protein, config, replace_fraction):
  """Create data for BERT on raw MSA.

  Args:
    protein: feature dict containing `msa` and `hhblits_profile`.
    config: object providing `uniform_prob`, `profile_prob` and `same_prob`.
    replace_fraction: per-position probability of being selected for
      replacement.

  Returns:
    The same dict with `bert_mask` (selected positions), `true_msa` (the
    input MSA) and `msa` (the corrupted MSA).
  """
  # Uniform over the first 20 classes, zero on the last two.
  uniform_aa = tf.constant([0.05] * 20 + [0., 0.], dtype=tf.float32)
  replacement_probs = (
      config.uniform_prob * uniform_aa +
      config.profile_prob * protein['hhblits_profile'] +
      config.same_prob * tf.one_hot(protein['msa'], 22))

  # Put all remaining probability on [MASK], which is a new last column.
  mask_prob = 1. - config.profile_prob - config.same_prob - config.uniform_prob
  assert mask_prob >= 0.
  paddings = [[0, 0] for _ in range(len(replacement_probs.shape))]
  paddings[-1][1] = 1
  replacement_probs = tf.pad(
      replacement_probs, paddings, constant_values=mask_prob)

  msa_shape = shape_helpers.shape_list(protein['msa'])
  replace_here = tf.random.uniform(msa_shape) < replace_fraction
  sampled_msa = shaped_categorical(replacement_probs)
  # Mix real and masked MSA.
  corrupted_msa = tf.where(replace_here, sampled_msa, protein['msa'])

  protein['bert_mask'] = tf.cast(replace_here, tf.float32)
  protein['true_msa'] = protein['msa']
  protein['msa'] = corrupted_msa
  return protein
@curry1
def make_fixed_size(protein, shape_schema, msa_cluster_size, extra_msa_size,
                    num_res, num_templates=0):
  """Guess at the MSA and sequence dimensions to make fixed size.

  Args:
    protein: feature dict to pad in place.
    shape_schema: maps each feature name to a sequence of per-dimension
      placeholders.
    msa_cluster_size: target size for NUM_MSA_SEQ dimensions.
    extra_msa_size: target size for NUM_EXTRA_SEQ dimensions.
    num_res: target size for NUM_RES dimensions.
    num_templates: target size for NUM_TEMPLATES dimensions.

  Returns:
    The same dict with every feature except `extra_cluster_assignment` padded
    at the end of each dimension up to its target size.
  """
  target_sizes = {
      NUM_RES: num_res,
      NUM_MSA_SEQ: msa_cluster_size,
      NUM_EXTRA_SEQ: extra_msa_size,
      NUM_TEMPLATES: num_templates,
  }

  for k, tensor in protein.items():
    # Don't transfer this to the accelerator.
    if k == 'extra_cluster_assignment':
      continue
    shape = tensor.shape.as_list()
    schema = shape_schema[k]
    assert len(shape) == len(schema), (
        f'Rank mismatch between shape and shape schema for {k}: '
        f'{shape} vs {schema}')

    # A falsy target (placeholder not in the map, None or 0) falls back to
    # the current size of that dimension.
    padded_shape = []
    for current, placeholder in zip(shape, schema):
      padded_shape.append(target_sizes.get(placeholder, None) or current)

    paddings = [
        (0, target - tf.shape(tensor)[axis])
        for axis, target in enumerate(padded_shape)]
    if paddings:
      protein[k] = tf.pad(tensor, paddings, name=f'pad_to_fixed_{k}')
      protein[k].set_shape(padded_shape)
  return protein
@curry1
def make_msa_feat(protein):
  """Create and concatenate MSA features.

  Args:
    protein: feature dict containing `between_segment_residues`, `aatype`,
      `msa` and `deletion_matrix`; cluster and extra-MSA features are used
      when present.

  Returns:
    The same dict with `msa_feat` and `target_feat` set, plus
    `extra_has_deletion` / `extra_deletion_value` when
    `extra_deletion_matrix` is present.
  """

  def squash_deletions(counts):
    # atan(x / 3) * 2 / pi, applied element-wise.
    return tf.atan(counts / 3.) * (2. / np.pi)

  # Whether there is a domain break. Always zero for chains, but keeping
  # for compatibility with domain datasets.
  has_break = tf.clip_by_value(
      tf.cast(protein['between_segment_residues'], tf.float32), 0, 1)
  target_parts = [
      tf.expand_dims(has_break, axis=-1),
      tf.one_hot(protein['aatype'], 21, axis=-1),  # The original sequence.
  ]

  deletions = protein['deletion_matrix']
  msa_parts = [
      tf.one_hot(protein['msa'], 23, axis=-1),
      tf.expand_dims(tf.clip_by_value(deletions, 0., 1.), axis=-1),
      tf.expand_dims(squash_deletions(deletions), axis=-1),
  ]

  if 'cluster_profile' in protein:
    cluster_deletions = squash_deletions(protein['cluster_deletion_mean'])
    msa_parts.append(protein['cluster_profile'])
    msa_parts.append(tf.expand_dims(cluster_deletions, axis=-1))

  if 'extra_deletion_matrix' in protein:
    extra_deletions = protein['extra_deletion_matrix']
    protein['extra_has_deletion'] = tf.clip_by_value(
        extra_deletions, 0., 1.)
    protein['extra_deletion_value'] = squash_deletions(extra_deletions)

  protein['msa_feat'] = tf.concat(msa_parts, axis=-1)
  protein['target_feat'] = tf.concat(target_parts, axis=-1)
  return protein
@curry1
def select_feat(protein, feature_list):
  """Return a new dict with only the features named in `feature_list`.

  Args:
    protein: feature dict to filter; it is not modified.
    feature_list: container of feature names to keep.

  Returns:
    A new dict holding the entries of `protein` whose key is in
    `feature_list`.
  """
  selected = {}
  for name, value in protein.items():
    if name in feature_list:
      selected[name] = value
  return selected
@curry1
def crop_templates(protein, max_templates):
  """Truncate every `template_*` feature to its first `max_templates` rows.

  Args:
    protein: feature dict to update in place.
    max_templates: number of leading entries to keep.

  Returns:
    The same dict with the `template_*` features sliced.
  """
  template_keys = [key for key in protein if key.startswith('template_')]
  for key in template_keys:
    protein[key] = protein[key][:max_templates]
  return protein
@curry1
def random_crop_to_size(protein, crop_size, max_templates, shape_schema,
                        subsample_templates=False):
  """Crop randomly to `crop_size`, or keep as is if shorter than that.

  Args:
    protein: feature dict; reads `seq_length`, `random_crop_to_size_seed` and,
      when present, `template_mask`.
    crop_size: maximum number of residues to keep.
    max_templates: maximum number of templates to keep.
    shape_schema: maps feature names to per-dimension placeholders; only
      features listed here are considered for cropping.
    subsample_templates: when True, templates are randomly permuted and a
      random crop start is drawn for them; otherwise the crop starts at 0.

  Returns:
    The same dict with the affected features sliced and `seq_length` set to
    the number of residues kept.
  """
  seq_length = protein['seq_length']
  if 'template_mask' in protein:
    num_templates = tf.cast(
        shape_helpers.shape_list(protein['template_mask'])[0], tf.int32)
  else:
    num_templates = tf.constant(0, dtype=tf.int32)
  num_res_crop_size = tf.math.minimum(seq_length, crop_size)

  # Ensures that the cropping of residues and templates happens in the same way
  # across ensembling iterations.
  # Do not use for randomness that should vary in ensembling.
  seed_maker = utils.SeedMaker(initial_seed=protein['random_crop_to_size_seed'])

  # NOTE: seed_maker() is called once per stateless random op below, so the
  # order of these calls is part of the behaviour.
  if subsample_templates:
    templates_crop_start = tf.random.stateless_uniform(
        shape=(), minval=0, maxval=num_templates + 1, dtype=tf.int32,
        seed=seed_maker())
  else:
    templates_crop_start = 0

  num_templates_crop_size = tf.math.minimum(
      num_templates - templates_crop_start, max_templates)

  num_res_crop_start = tf.random.stateless_uniform(
      shape=(), minval=0, maxval=seq_length - num_res_crop_size + 1,
      dtype=tf.int32, seed=seed_maker())

  # Random permutation of template indices (argsort of uniform noise).
  templates_select_indices = tf.argsort(tf.random.stateless_uniform(
      [num_templates], seed=seed_maker()))

  for k, v in protein.items():
    # Skip features without a schema, and non-template features that have no
    # residue dimension.
    if k not in shape_schema or (
        'template' not in k and NUM_RES not in shape_schema[k]):
      continue

    # randomly permute the templates before cropping them.
    if k.startswith('template') and subsample_templates:
      v = tf.gather(v, templates_select_indices)

    crop_sizes = []
    crop_starts = []
    for i, (dim_size, dim) in enumerate(zip(shape_schema[k],
                                            shape_helpers.shape_list(v))):
      is_num_res = (dim_size == NUM_RES)
      # NOTE(review): `crop_size` below rebinds the function parameter; it has
      # already been consumed by num_res_crop_size above.
      if i == 0 and k.startswith('template'):
        # First axis of a template feature: crop along the template axis.
        crop_size = num_templates_crop_size
        crop_start = templates_crop_start
      else:
        # Residue axes get the residue crop; other axes are kept whole
        # (-1 when the size is None).
        crop_start = num_res_crop_start if is_num_res else 0
        crop_size = (num_res_crop_size if is_num_res else
                     (-1 if dim is None else dim))
      crop_sizes.append(crop_size)
      crop_starts.append(crop_start)
    protein[k] = tf.slice(v, crop_starts, crop_sizes)

  protein['seq_length'] = num_res_crop_size
  return protein
def make_atom14_masks(protein):
  """Construct denser atom positions (14 dimensions instead of 37).

  Args:
    protein: feature dict containing `aatype`.

  Returns:
    The same dict with `atom14_atom_exists`, `residx_atom14_to_atom37`,
    `residx_atom37_to_atom14` and `atom37_atom_exists` set, each gathered
    from a per-residue-type table by `aatype`.
  """
  atom14_to_atom37_rows = []  # mapping (restype, atom14) --> atom37
  atom37_to_atom14_rows = []  # mapping (restype, atom37) --> atom14
  atom14_mask_rows = []

  for letter in residue_constants.restypes:
    names14 = residue_constants.restype_name_to_atom14_names[
        residue_constants.restype_1to3[letter]]

    # Empty atom names are unused slots: they map to index 0, mask 0.
    atom14_to_atom37_rows.append(
        [residue_constants.atom_order[name] if name else 0
         for name in names14])

    slot_of = {name: slot for slot, name in enumerate(names14)}
    atom37_to_atom14_rows.append(
        [slot_of.get(name, 0) for name in residue_constants.atom_types])

    atom14_mask_rows.append([1. if name else 0. for name in names14])

  # Add dummy mapping for restype 'UNK'
  atom14_to_atom37_rows.append([0] * 14)
  atom37_to_atom14_rows.append([0] * 37)
  atom14_mask_rows.append([0.] * 14)

  atom14_to_atom37_table = np.array(atom14_to_atom37_rows, dtype=np.int32)
  atom37_to_atom14_table = np.array(atom37_to_atom14_rows, dtype=np.int32)
  atom14_mask_table = np.array(atom14_mask_rows, dtype=np.float32)

  aatype = protein['aatype']

  # Mapping (residx, atom14) --> atom37 for this protein, and its mask.
  residx_atom14_to_atom37 = tf.gather(atom14_to_atom37_table, aatype)
  residx_atom14_mask = tf.gather(atom14_mask_table, aatype)
  protein['atom14_atom_exists'] = residx_atom14_mask
  protein['residx_atom14_to_atom37'] = residx_atom14_to_atom37

  # Gather indices for mapping back from atom37 to atom14.
  protein['residx_atom37_to_atom14'] = tf.gather(
      atom37_to_atom14_table, aatype)

  # Corresponding atom37 mask; the last row stays all zero.
  atom37_mask_table = np.zeros([21, 37], dtype=np.float32)
  for row, letter in enumerate(residue_constants.restypes):
    three_letter = residue_constants.restype_1to3[letter]
    for atom_name in residue_constants.residue_atoms[three_letter]:
      atom37_mask_table[row, residue_constants.atom_order[atom_name]] = 1

  protein['atom37_atom_exists'] = tf.gather(atom37_mask_table, aatype)
  return protein