Hukuna's picture
Upload 221 files
ce7bf5b verified
# Copyright Generate Biomedicines, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Layers for batched 3D transformations, such as residue poses.
This module contains pytorch layers for computing and composing with
3D, 6-degree-of-freedom transformations.
"""
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
from chroma.layers import graph
from chroma.layers.structure import geometry
def compose_transforms(
    R_a: torch.Tensor, t_a: torch.Tensor, R_b: torch.Tensor, t_b: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compose two rigid transforms, returning `T_a * T_b` (broadcastable).

    The composed transform first applies `T_b` and then `T_a`, so its rotation
    is `R_a @ R_b` and its translation is `t_a + R_a @ t_b`.

    Args:
        R_a (torch.Tensor): Rotation matrix of `T_a`, shape `(...,3,3)`.
        t_a (torch.Tensor): Translation of `T_a`, shape `(...,3)`.
        R_b (torch.Tensor): Rotation matrix of `T_b`, shape `(...,3,3)`.
        t_b (torch.Tensor): Translation of `T_b`, shape `(...,3)`.

    Returns:
        R_composed (torch.Tensor): Rotation of `T_a * T_b`, shape `(...,3,3)`.
        t_composed (torch.Tensor): Translation of `T_a * T_b`, shape `(...,3)`.
    """
    # Treat t_b as a column vector so batched matmul broadcasts over `...`.
    rotated_t_b = torch.matmul(R_a, t_b[..., None])[..., 0]
    return torch.matmul(R_a, R_b), t_a + rotated_t_b
def compose_translation(
    R_a: torch.Tensor, t_a: torch.Tensor, t_b: torch.Tensor
) -> torch.Tensor:
    """Return only the translation part of `T_a * T_b` (broadcastable).

    This is `t_a + R_a @ t_b`, i.e. the translation that `compose_transforms`
    would produce, without computing the composed rotation.

    Args:
        R_a (torch.Tensor): Rotation matrix of `T_a`, shape `(...,3,3)`.
        t_a (torch.Tensor): Translation of `T_a`, shape `(...,3)`.
        t_b (torch.Tensor): Translation of `T_b`, shape `(...,3)`.

    Returns:
        t_composed (torch.Tensor): Translation of `T_a * T_b`, shape `(...,3)`.
    """
    # Column-vector form of t_b lets the matmul broadcast over batch dims.
    rotated_t_b = torch.matmul(R_a, t_b[..., None])[..., 0]
    return t_a + rotated_t_b
def compose_inner_transforms(
    R_a: torch.Tensor, t_a: torch.Tensor, R_b: torch.Tensor, t_b: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compose the relative inner transform `T_ab = T_a^{-1} * T_b`.

    `T_ab` expresses frame `b` in the coordinates of frame `a`, so that
    `T_a * T_ab == T_b`. The inverse of `T_a` is formed by transposing `R_a`,
    which assumes `R_a` is orthonormal.

    Args:
        R_a (torch.Tensor): Transform `T_a` rotation matrix with shape `(...,3,3)`.
        t_a (torch.Tensor): Transform `T_a` translation with shape `(...,3)`.
        R_b (torch.Tensor): Transform `T_b` rotation matrix with shape `(...,3,3)`.
        t_b (torch.Tensor): Transform `T_b` translation with shape `(...,3)`.

    Returns:
        R_ab (torch.Tensor): Rotation matrix of `T_a^{-1} * T_b` with shape
            `(...,3,3)`. Inner dimensions are broadcastable.
        t_ab (torch.Tensor): Translation of `T_a^{-1} * T_b` with shape `(...,3)`.
    """
    # For an orthonormal rotation the inverse is the transpose.
    R_a_inverse = R_a.transpose(-1, -2)
    R_ab = R_a_inverse @ R_b
    # T_a^{-1} maps x -> R_a^T (x - t_a); apply it to b's origin.
    t_ab = (R_a_inverse @ ((t_b - t_a).unsqueeze(-1))).squeeze(-1)
    return R_ab, t_ab
def fuse_gaussians_isometric_plus_radial(
    x: torch.Tensor,
    p_iso: torch.Tensor,
    p_rad: torch.Tensor,
    direction: torch.Tensor,
    dim: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Fuse 3D Gaussians along dimension ``dim`` (product of experts).

    Each Gaussian's precision matrix is modeled as an isometric part
    ``p_iso * I`` plus a rank-one radial part
    ``p_rad * d d^T / |d|^2``, where ``d`` is the matching entry of
    ``direction``. Because of the ``|d|^2`` normalization, ``direction`` need
    not be unit length. Its squared norm is clamped to at least 1e-10 before
    dividing.

    Args:
        x (torch.Tensor): Means, shape ``(...,3)``.
        p_iso (torch.Tensor): Isometric precision weights, shape ``(...)``.
        p_rad (torch.Tensor): Radial precision weights, shape ``(...)``.
        direction (torch.Tensor): Directions of the radial information,
            shape ``(...,3)``.
        dim (int): Non-negative dimension to fuse over.

    Returns:
        A tuple ``(x_fused, P_fused)``: the fused mean ``(...,3)`` and the
        fused precision ``(...,3,3)``, both with ``dim`` removed.
    """
    assert dim >= 0, "dimension must index from the left"

    # Radial (rank-one) precision along each direction.
    sq_norm = direction.square().sum(-1).clamp(min=1e-10)
    projector = direction[..., :, None] * direction[..., None, :]
    precision_radial = (p_rad / sq_norm)[..., None, None] * projector

    # Isometric precision: p_iso on the diagonal.
    precision_iso = torch.diag_embed(p_iso[..., None].expand(*p_iso.shape, 3))
    precision = precision_iso + precision_radial

    # Gaussian fusion: sum the precisions and the precision-weighted means,
    # then solve the 3x3 system for the fused mean.
    fused_precision = precision.sum(dim)
    fused_info = torch.matmul(precision, x[..., None])[..., 0].sum(dim)
    x_fused = torch.linalg.solve(fused_precision, fused_info)
    return x_fused, fused_precision
def collect_neighbor_transforms(
    R_i: torch.Tensor, t_i: torch.Tensor, edge_idx: torch.LongTensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Gather the transforms of each node's neighbors.

    Rotations are flattened to 9-vectors so they can be gathered with
    `graph.collect_neighbors` like any other per-node feature, then reshaped
    back into 3x3 matrices.

    Args:
        R_i (torch.Tensor): Per-node rotation matrices, shape
            `(num_batch, num_residues, 3, 3)`.
        t_i (torch.Tensor): Per-node translations, shape
            `(num_batch, num_residues, 3)`.
        edge_idx (torch.LongTensor): Neighbor indices, shape
            `(num_batch, num_nodes, num_neighbors)`.

    Returns:
        R_j (torch.Tensor): Neighbor rotation matrices, shape
            `(num_batch, num_residues, num_neighbors, 3, 3)`.
        t_j (torch.Tensor): Neighbor translations, shape
            `(num_batch, num_residues, num_neighbors, 3)`.
    """
    batch_size, residue_count, neighbor_count = edge_idx.shape
    R_flat = R_i.reshape([batch_size, residue_count, 9])
    R_gathered = graph.collect_neighbors(R_flat, edge_idx)
    R_j = R_gathered.reshape([batch_size, residue_count, neighbor_count, 3, 3])
    t_j = graph.collect_neighbors(t_i, edge_idx)
    return R_j, t_j
def collect_neighbor_inner_transforms(
    R_i: torch.Tensor, t_i: torch.Tensor, edge_idx: torch.LongTensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute relative transforms between each node and its neighbors.

    For every edge from node `i` to neighbor `j`, this returns
    `T_ji = T_j^{-1} * T_i`, i.e. node `i`'s frame expressed in neighbor
    `j`'s coordinates.

    Args:
        R_i (torch.Tensor): Per-node rotation matrices, shape
            `(num_batch, num_residues, 3, 3)`.
        t_i (torch.Tensor): Per-node translations, shape
            `(num_batch, num_residues, 3)`.
        edge_idx (torch.LongTensor): Neighbor indices, shape
            `(num_batch, num_nodes, num_neighbors)`.

    Returns:
        R_ji (torch.Tensor): Relative rotations, shape
            `(num_batch, num_residues, num_neighbors, 3, 3)`.
        t_ji (torch.Tensor): Relative translations, shape
            `(num_batch, num_residues, num_neighbors, 3)`.
    """
    R_nbr, t_nbr = collect_neighbor_transforms(R_i, t_i, edge_idx)
    # Add a neighbor axis so node i's own transform broadcasts against
    # every one of its neighbors.
    R_self = R_i[..., None, :, :]
    t_self = t_i[..., None, :]
    return compose_inner_transforms(R_nbr, t_nbr, R_self, t_self)
def equilibrate_transforms(
    R_i: torch.Tensor,
    t_i: torch.Tensor,
    R_ji: torch.Tensor,
    t_ji: torch.Tensor,
    logit_ij: torch.Tensor,
    mask_ij: torch.Tensor,
    edge_idx: torch.LongTensor,
    iterations: int = 1,
    R_global: Optional[torch.Tensor] = None,
    t_global: Optional[torch.Tensor] = None,
    R_global_i: Optional[torch.Tensor] = None,
    t_global_i: Optional[torch.Tensor] = None,
    logit_global_i: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Equilibrate neighbor transforms.

    Each iteration gathers the current neighbor transforms `T_j`. It predicts
    node `i`'s transform from every neighbor as `T_j * T_ji` and replaces
    `T_i` with the weighted average of those predictions (`average_transforms`
    over the neighbor dimension). `average_transforms` is called with its
    default dithering, so results include a small amount of random noise.

    Args:
        R_i (torch.Tensor): Transform `T` rotation matrices with shape
            `(num_batch, num_residues, 3, 3)`.
        t_i (torch.Tensor): Transform `T` translations with shape
            `(num_batch, num_residues, 3)`.
        R_ji (torch.Tensor): Rotation matrices to go between frames for nodes i and j
            with shape `(num_batch, num_residues, num_neighbors, 3, 3)`.
        t_ji (torch.Tensor): Translations to go between frames for nodes i and j with
            shape `(num_batch, num_residues, num_neighbors, 3)`.
        logit_ij (torch.Tensor): Logits for averaging neighbor transforms with shape
            `(num_batch, num_residues, num_neighbors, num_weights)`. Note that
            `num_weights` must be 1, 2, or 3; see the documentation for
            `average_transforms` in this module for an explanation of the
            interpretations with different `num_weights`.
        mask_ij (torch.Tensor): Mask for averaging neighbor transforms with shape
            `(num_batch, num_residues, num_neighbors)`.
        edge_idx (torch.LongTensor): Edge indices for neighbors with shape
            `(num_batch, num_nodes, num_neighbors)`.
        iterations (int): Number of iterations to equilibrate for.
        R_global (torch.Tensor): Optional global frame rotation matrix with shape
            `(num_batch, 3, 3)`.
        t_global (torch.Tensor): Optional global frame translation with shape
            `(num_batch, 3)`.
        R_global_i (torch.Tensor): Optional rotation matrix for global frame from
            nodes with shape `(num_batch, num_residues, 3, 3)`.
        t_global_i (torch.Tensor): Optional translation for global frame from nodes
            with shape `(num_batch, num_residues, 3)`.
        logit_global_i (torch.Tensor): Logits for averaging global frame transform
            with shape `(num_batch, num_residues, num_weights)`. `num_weights`
            should match that of `logit_ij`.

    The global frame is used only when all five global arguments are
    provided. If only some are given, they are ignored.

    Returns:
        R_i (torch.Tensor): Rotation matrices of equilibrated transforms, with shape
            `(num_batch, num_residues, 3, 3)`.
        t_i (torch.Tensor): Translations of equilibrated transforms, with shape
            `(num_batch, num_residues, 3)`.
    """
    # The optional global frame is treated as one extra neighbor of every node.
    global_args = (R_global, t_global, R_global_i, t_global_i, logit_global_i)
    update_global = all(arg is not None for arg in global_args)
    if update_global:
        # Unpacking also checks that mask_ij is 3-dimensional.
        num_batch, _, _ = mask_ij.shape
        R_global_i = R_global_i.unsqueeze(2)
        t_global_i = t_global_i.unsqueeze(2)
        R_ji = torch.cat((R_ji, R_global_i), dim=2)
        t_ji = torch.cat((t_ji, t_global_i), dim=2)
        logit_ij = torch.cat((logit_ij, logit_global_i.unsqueeze(2)), dim=2)
        R_global = R_global.reshape([num_batch, 1, 1, 3, 3]).expand(R_global_i.shape)
        t_global = t_global.reshape([num_batch, 1, 1, 3]).expand(t_global_i.shape)
        # The global "edge" is valid for node i iff i has any valid neighbor.
        mask_i = (mask_ij.sum(2, keepdim=True) > 0).float()
        mask_ij = torch.cat((mask_ij, mask_i), dim=2)

    t_edge = None
    for _ in range(iterations):
        R_j, t_j = collect_neighbor_transforms(R_i, t_i, edge_idx)
        if update_global:
            R_j = torch.cat((R_j, R_global), dim=2)
            t_j = torch.cat((t_j, t_global), dim=2)
        # Prediction of T_i from each neighbor j: T_j * T_ji.
        R_i_pred, t_i_pred = compose_transforms(R_j, t_j, R_ji, t_ji)
        if logit_ij.size(-1) == 3:
            # Compute i-j displacement in the same coordinate system as
            # t_i_pred, i.e. in global coords. Sign does not matter.
            t_edge = t_j - t_i_pred
        R_i, t_i = average_transforms(
            R_i_pred, t_i_pred, logit_ij, mask_ij, t_edge=t_edge, dim=2
        )
    return R_i, t_i
def average_transforms(
    R: torch.Tensor,
    t: torch.Tensor,
    w: torch.Tensor,
    mask: torch.Tensor,
    dim: int,
    t_edge: Optional[torch.Tensor] = None,
    dither: Optional[bool] = True,
    dither_eps: float = 1e-4,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Average transforms with optional dithering.

    Translations are averaged with softmax weights, or by Gaussian fusion
    when there are 3 weights. Rotations are averaged by taking the weighted
    sum of rotation matrices and projecting it onto the nearest proper
    rotation via SVD (Kabsch): `U diag(1, 1, det(U Vh)) Vh`.

    Args:
        R (torch.Tensor): Transform `T` rotation matrix with shape `(...,3,3)`.
        t (torch.Tensor): Transform `T` translation with shape `(...,3)`.
        w (torch.Tensor): Logits for averaging weights with shape
            `(...,num_weights)`. `num_weights` can be 1 (single scalar
            weight per transform), 2 (separate weights for each rotation
            and translation), or 3 (one weight for rotation, two weights
            for translation corresponding to precision in all directions /
            along t_edge).
        mask (torch.Tensor): Mask for averaging weights with shape `(...)`.
        dim (int): Non-negative dimension to average along.
        t_edge (torch.Tensor, optional): Translation `T` of shape `(..., 3)`
            indicating the displacement between source and target nodes.
            Required when `num_weights == 3`.
        dither (bool): Whether to add Gaussian noise to the averaged rotation
            matrix before the SVD projection.
        dither_eps (float): Standard deviation of that noise.

    Returns:
        R_avg (torch.Tensor): Average transform `T_avg` rotation matrix with
            shape `(...{reduced}...,3,3)`.
        t_avg (torch.Tensor): Average transform `T_avg` translation with
            shape `(...{reduced}...,3)`.

    Raises:
        ValueError: If `num_weights == 3` and `t_edge` is not given.
        NotImplementedError: If `num_weights` is not 1, 2, or 3.
    """
    assert dim >= 0, "dimension must index from the left"
    # Masked entries get the most negative representable logit, so they
    # receive (numerically) zero weight after the softmax.
    w = torch.where(
        mask[..., None].bool(), w, torch.full_like(w, torch.finfo(w.dtype).min)
    )

    # We use different averaging models based on the number of weights.
    num_transform_weights = w.size(-1)
    if num_transform_weights == 1:
        # Share a single scalar weight between t and R.
        probs = w.softmax(dim)
        t_avg = (t * probs).sum(dim)
        R_probs = probs[..., None]
    elif num_transform_weights == 2:
        # Use separate scalar weights for each of t and R.
        probs = w.softmax(dim)
        t_probs, R_probs = probs.unbind(-1)
        t_avg = (t * t_probs[..., None]).sum(dim)
        R_probs = R_probs[..., None, None]
    elif num_transform_weights == 3:
        if t_edge is None:
            raise ValueError(
                "t_edge is required when w has 3 weights per transform"
            )
        # For R use a signed scalar weight.
        R_probs = w[..., 2].softmax(dim)[..., None, None]
        # For t use a two-parameter precision matrix P = P_isometric + P_radial.
        # The softmax is computed by hand over the shared (dim x 2) elements.
        w_t = w[..., :2]
        w_t_total = w_t.logsumexp([dim, -1], True)
        p_iso, p_rad = (w_t - w_t_total).exp().unbind(-1)
        # Use Gaussian fusion for translation; masked edges carry no direction.
        t_edge = t_edge * mask.to(t_edge.dtype)[..., None]
        t_avg, _ = fuse_gaussians_isometric_plus_radial(t, p_iso, p_rad, t_edge, dim)
    else:
        raise NotImplementedError(
            "average_transforms supports 1, 2, or 3 weights per transform, "
            f"got {num_transform_weights}"
        )

    # Average rotation via SVD.
    R_avg_unc = (R * R_probs).sum(dim)
    if dither:
        R_avg_unc = R_avg_unc + dither_eps * torch.randn_like(R_avg_unc)
    U, _, Vh = torch.linalg.svd(R_avg_unc, full_matrices=True)

    # Enforce a proper rotation (det = +1). Singular values are sorted in
    # descending order, so the sign fix belongs on the singular vector of the
    # smallest singular value. That is the last ROW of Vh. The column vector
    # [1, 1, d]^T has shape (..., 3, 1).
    d = torch.linalg.det(U @ Vh)
    d_expand = F.pad(d[..., None, None], (0, 0, 2, 0), value=1.0)
    R_avg = U @ (Vh * d_expand)
    return R_avg, t_avg
def _debug_plot_transforms(
    R_ij: torch.Tensor,
    t_ij: torch.Tensor,
    logits_ij: torch.Tensor,
    edge_idx: torch.LongTensor,
    mask_ij: torch.Tensor,
    dist_eps: float = 1e-3,
):
    """Visualize 6dof frame transformations.

    Creates a matplotlib figure with one row per batch element and four
    panels per row:
    - edge distance;
    - unit translation direction (as RGB);
    - quaternion components after the first (as RGB);
    - log-softmax edge confidence over the neighbor dimension.

    The figure is drawn on the current pyplot state; it is neither shown nor
    saved here.

    Args:
        R_ij (torch.Tensor): Per-edge rotation matrices; dims 0 and 1 are
            read as batch and residue (presumably
            `(num_batch, num_residues, num_neighbors, 3, 3)` — TODO confirm).
        t_ij (torch.Tensor): Per-edge translations, `(..., 3)`.
        logits_ij (torch.Tensor): Edge logits; softmax is taken over dim 2.
        edge_idx (torch.LongTensor): Neighbor indices passed to
            `graph.scatter_edges`.
        mask_ij (torch.Tensor): Edge mask. NOTE(review): currently unused;
            masked edges are not excluded from the plots.
        dist_eps (float): Added to distances before normalizing directions.
    """
    from matplotlib import pyplot as plt

    num_batch = R_ij.shape[0]

    # Edge confidences as log-probabilities over each node's neighbors,
    # scattered from edge layout with graph.scatter_edges (presumably into a
    # dense node-by-node layout — TODO confirm).
    logp_ij = torch.log_softmax(logits_ij, 2)
    P_ij = graph.scatter_edges(logp_ij[..., None], edge_idx)[..., 0]

    q_ij = graph.scatter_edges(geometry.quaternions_from_rotations(R_ij), edge_idx)
    t_ij = graph.scatter_edges(t_ij, edge_idx)

    # Convert translations to distance and unit direction.
    D = torch.sqrt(t_ij.square().sum(-1))
    U = t_ij / (D[..., None] + dist_eps)
    # All quaternion components after the first (presumably the vector part).
    q_axis = q_ij[..., 1:]

    def _format(T):
        T = T.cpu().data.numpy()
        # 3-channel images are assumed to lie in [-1, 1]; map to RGB on (0,1)^3.
        if len(T.shape) == 3:
            T = (T + 1) / 2
        return T

    base_width = 4
    num_cols = 4
    plt.figure(figsize=(base_width * num_cols, base_width * num_batch), dpi=300)
    ix = 1
    for i in range(num_batch):
        # Distance panel.
        plt.subplot(num_batch, num_cols, ix)
        plt.imshow(_format(D[i, :, :]), cmap="inferno")
        plt.axis("off")

        # Direction panel.
        plt.subplot(num_batch, num_cols, ix + 1)
        plt.imshow(_format(U[i, :, :, :]))
        plt.axis("off")

        # Orientation panel.
        plt.subplot(num_batch, num_cols, ix + 2)
        plt.imshow(_format(q_axis[i, :, :, :]))
        plt.axis("off")

        # Confidence panel.
        plt.subplot(num_batch, num_cols, ix + 3)
        plt.imshow(_format(P_ij[i, :, :]), cmap="inferno")
        plt.axis("off")

        ix = ix + num_cols
    plt.tight_layout()