Spaces:
Running
on
T4
Running
on
T4
# Copyright 2021 DeepMind Technologies Limited | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Quaternion geometry modules. | |
This introduces a representation of coordinate frames that is based around a | |
‘QuatAffine’ object. This object describes an array of coordinate frames. | |
It consists of vectors corresponding to the | |
origin of the frames as well as orientations which are stored in two | |
ways, as unit quaternions as well as a rotation matrices. | |
The rotation matrices are derived from the unit quaternions and the two are kept | |
in sync. | |
For an explanation of the relation between unit quaternions and rotations see | |
https://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation | |
This representation is used in the model for the backbone frames. | |
One important thing to note here, is that while we update both representations | |
the jit compiler is going to ensure that only the parts that are | |
actually used are executed. | |
""" | |
import functools | |
from typing import Tuple | |
import jax | |
import jax.numpy as jnp | |
import numpy as np | |
# pylint: disable=bad-whitespace | |
QUAT_TO_ROT = np.zeros((4, 4, 3, 3), dtype=np.float32) | |
QUAT_TO_ROT[0, 0] = [[ 1, 0, 0], [ 0, 1, 0], [ 0, 0, 1]] # rr | |
QUAT_TO_ROT[1, 1] = [[ 1, 0, 0], [ 0,-1, 0], [ 0, 0,-1]] # ii | |
QUAT_TO_ROT[2, 2] = [[-1, 0, 0], [ 0, 1, 0], [ 0, 0,-1]] # jj | |
QUAT_TO_ROT[3, 3] = [[-1, 0, 0], [ 0,-1, 0], [ 0, 0, 1]] # kk | |
QUAT_TO_ROT[1, 2] = [[ 0, 2, 0], [ 2, 0, 0], [ 0, 0, 0]] # ij | |
QUAT_TO_ROT[1, 3] = [[ 0, 0, 2], [ 0, 0, 0], [ 2, 0, 0]] # ik | |
QUAT_TO_ROT[2, 3] = [[ 0, 0, 0], [ 0, 0, 2], [ 0, 2, 0]] # jk | |
QUAT_TO_ROT[0, 1] = [[ 0, 0, 0], [ 0, 0,-2], [ 0, 2, 0]] # ir | |
QUAT_TO_ROT[0, 2] = [[ 0, 0, 2], [ 0, 0, 0], [-2, 0, 0]] # jr | |
QUAT_TO_ROT[0, 3] = [[ 0,-2, 0], [ 2, 0, 0], [ 0, 0, 0]] # kr | |
QUAT_MULTIPLY = np.zeros((4, 4, 4), dtype=np.float32) | |
QUAT_MULTIPLY[:, :, 0] = [[ 1, 0, 0, 0], | |
[ 0,-1, 0, 0], | |
[ 0, 0,-1, 0], | |
[ 0, 0, 0,-1]] | |
QUAT_MULTIPLY[:, :, 1] = [[ 0, 1, 0, 0], | |
[ 1, 0, 0, 0], | |
[ 0, 0, 0, 1], | |
[ 0, 0,-1, 0]] | |
QUAT_MULTIPLY[:, :, 2] = [[ 0, 0, 1, 0], | |
[ 0, 0, 0,-1], | |
[ 1, 0, 0, 0], | |
[ 0, 1, 0, 0]] | |
QUAT_MULTIPLY[:, :, 3] = [[ 0, 0, 0, 1], | |
[ 0, 0, 1, 0], | |
[ 0,-1, 0, 0], | |
[ 1, 0, 0, 0]] | |
QUAT_MULTIPLY_BY_VEC = QUAT_MULTIPLY[:, 1:, :] | |
# pylint: enable=bad-whitespace | |
def rot_to_quat(rot, unstack_inputs=False): | |
"""Convert rotation matrix to quaternion. | |
Note that this function calls self_adjoint_eig which is extremely expensive on | |
the GPU. If at all possible, this function should run on the CPU. | |
Args: | |
rot: rotation matrix (see below for format). | |
unstack_inputs: If true, rotation matrix should be shape (..., 3, 3) | |
otherwise the rotation matrix should be a list of lists of tensors. | |
Returns: | |
Quaternion as (..., 4) tensor. | |
""" | |
if unstack_inputs: | |
rot = [jnp.moveaxis(x, -1, 0) for x in jnp.moveaxis(rot, -2, 0)] | |
[[xx, xy, xz], [yx, yy, yz], [zx, zy, zz]] = rot | |
# pylint: disable=bad-whitespace | |
k = [[ xx + yy + zz, zy - yz, xz - zx, yx - xy,], | |
[ zy - yz, xx - yy - zz, xy + yx, xz + zx,], | |
[ xz - zx, xy + yx, yy - xx - zz, yz + zy,], | |
[ yx - xy, xz + zx, yz + zy, zz - xx - yy,]] | |
# pylint: enable=bad-whitespace | |
k = (1./3.) * jnp.stack([jnp.stack(x, axis=-1) for x in k], | |
axis=-2) | |
# Get eigenvalues in non-decreasing order and associated. | |
_, qs = jnp.linalg.eigh(k) | |
return qs[..., -1] | |
def rot_list_to_tensor(rot_list): | |
"""Convert list of lists to rotation tensor.""" | |
return jnp.stack( | |
[jnp.stack(rot_list[0], axis=-1), | |
jnp.stack(rot_list[1], axis=-1), | |
jnp.stack(rot_list[2], axis=-1)], | |
axis=-2) | |
def vec_list_to_tensor(vec_list): | |
"""Convert list to vector tensor.""" | |
return jnp.stack(vec_list, axis=-1) | |
def quat_to_rot(normalized_quat): | |
"""Convert a normalized quaternion to a rotation matrix.""" | |
rot_tensor = jnp.sum( | |
np.reshape(QUAT_TO_ROT, (4, 4, 9)) * | |
normalized_quat[..., :, None, None] * | |
normalized_quat[..., None, :, None], | |
axis=(-3, -2)) | |
rot = jnp.moveaxis(rot_tensor, -1, 0) # Unstack. | |
return [[rot[0], rot[1], rot[2]], | |
[rot[3], rot[4], rot[5]], | |
[rot[6], rot[7], rot[8]]] | |
def quat_multiply_by_vec(quat, vec): | |
"""Multiply a quaternion by a pure-vector quaternion.""" | |
return jnp.sum( | |
QUAT_MULTIPLY_BY_VEC * | |
quat[..., :, None, None] * | |
vec[..., None, :, None], | |
axis=(-3, -2)) | |
def quat_multiply(quat1, quat2): | |
"""Multiply a quaternion by another quaternion.""" | |
return jnp.sum( | |
QUAT_MULTIPLY * | |
quat1[..., :, None, None] * | |
quat2[..., None, :, None], | |
axis=(-3, -2)) | |
def apply_rot_to_vec(rot, vec, unstack=False): | |
"""Multiply rotation matrix by a vector.""" | |
if unstack: | |
x, y, z = [vec[:, i] for i in range(3)] | |
else: | |
x, y, z = vec | |
return [rot[0][0] * x + rot[0][1] * y + rot[0][2] * z, | |
rot[1][0] * x + rot[1][1] * y + rot[1][2] * z, | |
rot[2][0] * x + rot[2][1] * y + rot[2][2] * z] | |
def apply_inverse_rot_to_vec(rot, vec): | |
"""Multiply the inverse of a rotation matrix by a vector.""" | |
# Inverse rotation is just transpose | |
return [rot[0][0] * vec[0] + rot[1][0] * vec[1] + rot[2][0] * vec[2], | |
rot[0][1] * vec[0] + rot[1][1] * vec[1] + rot[2][1] * vec[2], | |
rot[0][2] * vec[0] + rot[1][2] * vec[1] + rot[2][2] * vec[2]] | |
class QuatAffine(object): | |
"""Affine transformation represented by quaternion and vector.""" | |
def __init__(self, quaternion, translation, rotation=None, normalize=True, | |
unstack_inputs=False): | |
"""Initialize from quaternion and translation. | |
Args: | |
quaternion: Rotation represented by a quaternion, to be applied | |
before translation. Must be a unit quaternion unless normalize==True. | |
translation: Translation represented as a vector. | |
rotation: Same rotation as the quaternion, represented as a (..., 3, 3) | |
tensor. If None, rotation will be calculated from the quaternion. | |
normalize: If True, l2 normalize the quaternion on input. | |
unstack_inputs: If True, translation is a vector with last component 3 | |
""" | |
if quaternion is not None: | |
assert quaternion.shape[-1] == 4 | |
if unstack_inputs: | |
if rotation is not None: | |
rotation = [jnp.moveaxis(x, -1, 0) # Unstack. | |
for x in jnp.moveaxis(rotation, -2, 0)] # Unstack. | |
translation = jnp.moveaxis(translation, -1, 0) # Unstack. | |
if normalize and quaternion is not None: | |
quaternion = quaternion / jnp.linalg.norm(quaternion, axis=-1, | |
keepdims=True) | |
if rotation is None: | |
rotation = quat_to_rot(quaternion) | |
self.quaternion = quaternion | |
self.rotation = [list(row) for row in rotation] | |
self.translation = list(translation) | |
assert all(len(row) == 3 for row in self.rotation) | |
assert len(self.translation) == 3 | |
def to_tensor(self): | |
return jnp.concatenate( | |
[self.quaternion] + | |
[jnp.expand_dims(x, axis=-1) for x in self.translation], | |
axis=-1) | |
def apply_tensor_fn(self, tensor_fn): | |
"""Return a new QuatAffine with tensor_fn applied (e.g. stop_gradient).""" | |
return QuatAffine( | |
tensor_fn(self.quaternion), | |
[tensor_fn(x) for x in self.translation], | |
rotation=[[tensor_fn(x) for x in row] for row in self.rotation], | |
normalize=False) | |
def apply_rotation_tensor_fn(self, tensor_fn): | |
"""Return a new QuatAffine with tensor_fn applied to the rotation part.""" | |
return QuatAffine( | |
tensor_fn(self.quaternion), | |
[x for x in self.translation], | |
rotation=[[tensor_fn(x) for x in row] for row in self.rotation], | |
normalize=False) | |
def scale_translation(self, position_scale): | |
"""Return a new quat affine with a different scale for translation.""" | |
return QuatAffine( | |
self.quaternion, | |
[x * position_scale for x in self.translation], | |
rotation=[[x for x in row] for row in self.rotation], | |
normalize=False) | |
def from_tensor(cls, tensor, normalize=False): | |
quaternion, tx, ty, tz = jnp.split(tensor, [4, 5, 6], axis=-1) | |
return cls(quaternion, | |
[tx[..., 0], ty[..., 0], tz[..., 0]], | |
normalize=normalize) | |
def pre_compose(self, update): | |
"""Return a new QuatAffine which applies the transformation update first. | |
Args: | |
update: Length-6 vector. 3-vector of x, y, and z such that the quaternion | |
update is (1, x, y, z) and zero for the 3-vector is the identity | |
quaternion. 3-vector for translation concatenated. | |
Returns: | |
New QuatAffine object. | |
""" | |
vector_quaternion_update, x, y, z = jnp.split(update, [3, 4, 5], axis=-1) | |
trans_update = [jnp.squeeze(x, axis=-1), | |
jnp.squeeze(y, axis=-1), | |
jnp.squeeze(z, axis=-1)] | |
new_quaternion = (self.quaternion + | |
quat_multiply_by_vec(self.quaternion, | |
vector_quaternion_update)) | |
trans_update = apply_rot_to_vec(self.rotation, trans_update) | |
new_translation = [ | |
self.translation[0] + trans_update[0], | |
self.translation[1] + trans_update[1], | |
self.translation[2] + trans_update[2]] | |
return QuatAffine(new_quaternion, new_translation) | |
def apply_to_point(self, point, extra_dims=0): | |
"""Apply affine to a point. | |
Args: | |
point: List of 3 tensors to apply affine. | |
extra_dims: Number of dimensions at the end of the transformed_point | |
shape that are not present in the rotation and translation. The most | |
common use is rotation N points at once with extra_dims=1 for use in a | |
network. | |
Returns: | |
Transformed point after applying affine. | |
""" | |
rotation = self.rotation | |
translation = self.translation | |
for _ in range(extra_dims): | |
expand_fn = functools.partial(jnp.expand_dims, axis=-1) | |
rotation = jax.tree_map(expand_fn, rotation) | |
translation = jax.tree_map(expand_fn, translation) | |
rot_point = apply_rot_to_vec(rotation, point) | |
return [ | |
rot_point[0] + translation[0], | |
rot_point[1] + translation[1], | |
rot_point[2] + translation[2]] | |
def invert_point(self, transformed_point, extra_dims=0): | |
"""Apply inverse of transformation to a point. | |
Args: | |
transformed_point: List of 3 tensors to apply affine | |
extra_dims: Number of dimensions at the end of the transformed_point | |
shape that are not present in the rotation and translation. The most | |
common use is rotation N points at once with extra_dims=1 for use in a | |
network. | |
Returns: | |
Transformed point after applying affine. | |
""" | |
rotation = self.rotation | |
translation = self.translation | |
for _ in range(extra_dims): | |
expand_fn = functools.partial(jnp.expand_dims, axis=-1) | |
rotation = jax.tree_map(expand_fn, rotation) | |
translation = jax.tree_map(expand_fn, translation) | |
rot_point = [ | |
transformed_point[0] - translation[0], | |
transformed_point[1] - translation[1], | |
transformed_point[2] - translation[2]] | |
return apply_inverse_rot_to_vec(rotation, rot_point) | |
def __repr__(self): | |
return 'QuatAffine(%r, %r)' % (self.quaternion, self.translation) | |
def _multiply(a, b): | |
return jnp.stack([ | |
jnp.array([a[0][0]*b[0][0] + a[0][1]*b[1][0] + a[0][2]*b[2][0], | |
a[0][0]*b[0][1] + a[0][1]*b[1][1] + a[0][2]*b[2][1], | |
a[0][0]*b[0][2] + a[0][1]*b[1][2] + a[0][2]*b[2][2]]), | |
jnp.array([a[1][0]*b[0][0] + a[1][1]*b[1][0] + a[1][2]*b[2][0], | |
a[1][0]*b[0][1] + a[1][1]*b[1][1] + a[1][2]*b[2][1], | |
a[1][0]*b[0][2] + a[1][1]*b[1][2] + a[1][2]*b[2][2]]), | |
jnp.array([a[2][0]*b[0][0] + a[2][1]*b[1][0] + a[2][2]*b[2][0], | |
a[2][0]*b[0][1] + a[2][1]*b[1][1] + a[2][2]*b[2][1], | |
a[2][0]*b[0][2] + a[2][1]*b[1][2] + a[2][2]*b[2][2]])]) | |
def make_canonical_transform( | |
n_xyz: jnp.ndarray, | |
ca_xyz: jnp.ndarray, | |
c_xyz: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: | |
"""Returns translation and rotation matrices to canonicalize residue atoms. | |
Note that this method does not take care of symmetries. If you provide the | |
atom positions in the non-standard way, the N atom will end up not at | |
[-0.527250, 1.359329, 0.0] but instead at [-0.527250, -1.359329, 0.0]. You | |
need to take care of such cases in your code. | |
Args: | |
n_xyz: An array of shape [batch, 3] of nitrogen xyz coordinates. | |
ca_xyz: An array of shape [batch, 3] of carbon alpha xyz coordinates. | |
c_xyz: An array of shape [batch, 3] of carbon xyz coordinates. | |
Returns: | |
A tuple (translation, rotation) where: | |
translation is an array of shape [batch, 3] defining the translation. | |
rotation is an array of shape [batch, 3, 3] defining the rotation. | |
After applying the translation and rotation to all atoms in a residue: | |
* All atoms will be shifted so that CA is at the origin, | |
* All atoms will be rotated so that C is at the x-axis, | |
* All atoms will be shifted so that N is in the xy plane. | |
""" | |
assert len(n_xyz.shape) == 2, n_xyz.shape | |
assert n_xyz.shape[-1] == 3, n_xyz.shape | |
assert n_xyz.shape == ca_xyz.shape == c_xyz.shape, ( | |
n_xyz.shape, ca_xyz.shape, c_xyz.shape) | |
# Place CA at the origin. | |
translation = -ca_xyz | |
n_xyz = n_xyz + translation | |
c_xyz = c_xyz + translation | |
# Place C on the x-axis. | |
c_x, c_y, c_z = [c_xyz[:, i] for i in range(3)] | |
# Rotate by angle c1 in the x-y plane (around the z-axis). | |
sin_c1 = -c_y / jnp.sqrt(1e-20 + c_x**2 + c_y**2) | |
cos_c1 = c_x / jnp.sqrt(1e-20 + c_x**2 + c_y**2) | |
zeros = jnp.zeros_like(sin_c1) | |
ones = jnp.ones_like(sin_c1) | |
# pylint: disable=bad-whitespace | |
c1_rot_matrix = jnp.stack([jnp.array([cos_c1, -sin_c1, zeros]), | |
jnp.array([sin_c1, cos_c1, zeros]), | |
jnp.array([zeros, zeros, ones])]) | |
# Rotate by angle c2 in the x-z plane (around the y-axis). | |
sin_c2 = c_z / jnp.sqrt(1e-20 + c_x**2 + c_y**2 + c_z**2) | |
cos_c2 = jnp.sqrt(c_x**2 + c_y**2) / jnp.sqrt( | |
1e-20 + c_x**2 + c_y**2 + c_z**2) | |
c2_rot_matrix = jnp.stack([jnp.array([cos_c2, zeros, sin_c2]), | |
jnp.array([zeros, ones, zeros]), | |
jnp.array([-sin_c2, zeros, cos_c2])]) | |
c_rot_matrix = _multiply(c2_rot_matrix, c1_rot_matrix) | |
n_xyz = jnp.stack(apply_rot_to_vec(c_rot_matrix, n_xyz, unstack=True)).T | |
# Place N in the x-y plane. | |
_, n_y, n_z = [n_xyz[:, i] for i in range(3)] | |
# Rotate by angle alpha in the y-z plane (around the x-axis). | |
sin_n = -n_z / jnp.sqrt(1e-20 + n_y**2 + n_z**2) | |
cos_n = n_y / jnp.sqrt(1e-20 + n_y**2 + n_z**2) | |
n_rot_matrix = jnp.stack([jnp.array([ones, zeros, zeros]), | |
jnp.array([zeros, cos_n, -sin_n]), | |
jnp.array([zeros, sin_n, cos_n])]) | |
# pylint: enable=bad-whitespace | |
return (translation, | |
jnp.transpose(_multiply(n_rot_matrix, c_rot_matrix), [2, 0, 1])) | |
def make_transform_from_reference( | |
n_xyz: jnp.ndarray, | |
ca_xyz: jnp.ndarray, | |
c_xyz: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: | |
"""Returns rotation and translation matrices to convert from reference. | |
Note that this method does not take care of symmetries. If you provide the | |
atom positions in the non-standard way, the N atom will end up not at | |
[-0.527250, 1.359329, 0.0] but instead at [-0.527250, -1.359329, 0.0]. You | |
need to take care of such cases in your code. | |
Args: | |
n_xyz: An array of shape [batch, 3] of nitrogen xyz coordinates. | |
ca_xyz: An array of shape [batch, 3] of carbon alpha xyz coordinates. | |
c_xyz: An array of shape [batch, 3] of carbon xyz coordinates. | |
Returns: | |
A tuple (rotation, translation) where: | |
rotation is an array of shape [batch, 3, 3] defining the rotation. | |
translation is an array of shape [batch, 3] defining the translation. | |
After applying the translation and rotation to the reference backbone, | |
the coordinates will approximately equal to the input coordinates. | |
The order of translation and rotation differs from make_canonical_transform | |
because the rotation from this function should be applied before the | |
translation, unlike make_canonical_transform. | |
""" | |
translation, rotation = make_canonical_transform(n_xyz, ca_xyz, c_xyz) | |
return np.transpose(rotation, (0, 2, 1)), -translation | |