# Copyright Generate Biomedicines, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Layers for batched 3D transformations, such as residue poses. | |
This module contains pytorch layers for computing and composing with | |
3D, 6-degree-of freedom transformations. | |
""" | |
from typing import Optional, Tuple | |
import torch | |
import torch.nn.functional as F | |
from chroma.layers import graph | |
from chroma.layers.structure import geometry | |


def compose_transforms(
    R_a: torch.Tensor, t_a: torch.Tensor, R_b: torch.Tensor, t_b: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compose transforms `T_compose = T_a * T_b` (broadcastable).

    Args:
        R_a (torch.Tensor): Transform `T_a` rotation matrix with shape `(...,3,3)`.
        t_a (torch.Tensor): Transform `T_a` translation with shape `(...,3)`.
        R_b (torch.Tensor): Transform `T_b` rotation matrix with shape `(...,3,3)`.
        t_b (torch.Tensor): Transform `T_b` translation with shape `(...,3)`.

    Returns:
        R_composed (torch.Tensor): Composed transform `a * b` rotation matrix with
            shape `(...,3,3)`.
        t_composed (torch.Tensor): Composed transform `a * b` translation vector with
            shape `(...,3)`.
    """
    R_composed = R_a @ R_b
    t_composed = t_a + (R_a @ t_b.unsqueeze(-1)).squeeze(-1)
    return R_composed, t_composed
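
# Example (editor's illustrative sketch, not part of the library API):
# composing any transform with the identity transform leaves it unchanged.
#
#   R = torch.linalg.qr(torch.randn(4, 3, 3)).Q  # random orthonormal frames
#   t = torch.randn(4, 3)
#   R_id = torch.eye(3).expand(4, 3, 3)
#   t_id = torch.zeros(4, 3)
#   R_c, t_c = compose_transforms(R, t, R_id, t_id)
#   # torch.allclose(R_c, R) and torch.allclose(t_c, t) hold up to precision.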


def compose_translation(
    R_a: torch.Tensor, t_a: torch.Tensor, t_b: torch.Tensor
) -> torch.Tensor:
    """Compose translation component of `T_compose = T_a * T_b` (broadcastable).

    Args:
        R_a (torch.Tensor): Transform `T_a` rotation matrix with shape `(...,3,3)`.
        t_a (torch.Tensor): Transform `T_a` translation with shape `(...,3)`.
        t_b (torch.Tensor): Transform `T_b` translation with shape `(...,3)`.

    Returns:
        t_composed (torch.Tensor): Composed transform `a * b` translation vector with
            shape `(...,3)`.
    """
    t_composed = t_a + (R_a @ t_b.unsqueeze(-1)).squeeze(-1)
    return t_composed


def compose_inner_transforms(
    R_a: torch.Tensor, t_a: torch.Tensor, R_b: torch.Tensor, t_b: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compose the relative inner transform `T_ab = T_a^{-1} * T_b`.

    Args:
        R_a (torch.Tensor): Transform `T_a` rotation matrix with shape `(...,3,3)`.
        t_a (torch.Tensor): Transform `T_a` translation with shape `(...,3)`.
        R_b (torch.Tensor): Transform `T_b` rotation matrix with shape `(...,3,3)`.
        t_b (torch.Tensor): Transform `T_b` translation with shape `(...,3)`.

    Returns:
        R_ab (torch.Tensor): Relative transform `T_a^{-1} * T_b` rotation matrix
            with shape `(...,3,3)`. Inner dimensions are broadcastable.
        t_ab (torch.Tensor): Relative transform `T_a^{-1} * T_b` translation
            vector with shape `(...,3)`.
    """
    R_a_inverse = R_a.transpose(-1, -2)
    R_ab = R_a_inverse @ R_b
    t_ab = (R_a_inverse @ ((t_b - t_a).unsqueeze(-1))).squeeze(-1)
    return R_ab, t_ab
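
# Example (editor's sketch, not part of the library): by construction
# `T_a * (T_a^{-1} * T_b) = T_b`, so composing `T_a` with the inner transform
# recovers `T_b`, provided `R_a` is a valid rotation matrix.
#
#   R_a = torch.linalg.qr(torch.randn(3, 3)).Q
#   R_b = torch.linalg.qr(torch.randn(3, 3)).Q
#   t_a, t_b = torch.randn(3), torch.randn(3)
#   R_ab, t_ab = compose_inner_transforms(R_a, t_a, R_b, t_b)
#   R_rec, t_rec = compose_transforms(R_a, t_a, R_ab, t_ab)
#   # torch.allclose(R_rec, R_b, atol=1e-5) and torch.allclose(t_rec, t_b, atol=1e-5)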


def fuse_gaussians_isometric_plus_radial(
    x: torch.Tensor,
    p_iso: torch.Tensor,
    p_rad: torch.Tensor,
    direction: torch.Tensor,
    dim: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Fuse Gaussians along a dimension ``dim``. This assumes the Gaussian
    precision matrices are a sum of an isometric part P_iso together with
    a part P_rad that provides information only along one direction.

    Args:
        x (torch.Tensor): A (...,3)-shaped tensor of means.
        p_iso (torch.Tensor): A (...)-shaped tensor of weights of the isometric
            part of the precision matrix.
        p_rad (torch.Tensor): A (...)-shaped tensor of weights of the radial
            part of the precision matrix.
        direction (torch.Tensor): A (...,3)-shaped tensor of directions along which
            information is available.
        dim (int): The dimension over which to aggregate (fuse).

    Returns:
        A tuple ``(x_fused, P_fused)`` of fused mean and precision, with
        specified ``dim`` removed.
    """
    assert dim >= 0, "dimension must index from the left"
    # P_rad has information only parallel to the edge.
    outer = direction.unsqueeze(-1) * direction.unsqueeze(-2)
    inner = direction.square().sum(-1).clamp(min=1e-10)
    P_rad = (p_rad / inner)[..., None, None] * outer
    P_iso = p_iso.unsqueeze(-1).expand(p_iso.shape + (3,)).diag_embed()
    P = P_iso + P_rad
    # Compute the Bayesian fusion aka product-of-experts of the Gaussians.
    P_fused = P.sum(dim)
    Px_fused = (P @ x.unsqueeze(-1)).squeeze(-1).sum(dim)
    # There might be a cheaper way to do this, either via Cholesky
    # or hand-coding the 3x3 matrix solve operation.
    x_fused = torch.linalg.solve(P_fused, Px_fused)
    return x_fused, P_fused
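
# Worked note (editor's sketch): each expert k contributes a precision matrix
#   P_k = p_iso_k * I + p_rad_k * (d_k d_k^T) / ||d_k||^2,
# and the product-of-experts mean is x* = (sum_k P_k)^{-1} (sum_k P_k x_k).
# Fusing experts that share the same mean therefore returns that mean:
#
#   x = torch.ones(5, 3)
#   p_iso, p_rad = torch.rand(5), torch.rand(5)
#   direction = torch.randn(5, 3)
#   x_fused, P_fused = fuse_gaussians_isometric_plus_radial(
#       x, p_iso, p_rad, direction, dim=0
#   )
#   # torch.allclose(x_fused, torch.ones(3), atol=1e-5); P_fused has shape (3, 3).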


def collect_neighbor_transforms(
    R_i: torch.Tensor, t_i: torch.Tensor, edge_idx: torch.LongTensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Collect neighbor transforms.

    Args:
        R_i (torch.Tensor): Transform `T` rotation matrices with shape
            `(num_batch, num_residues, 3, 3)`.
        t_i (torch.Tensor): Transform `T` translations with shape
            `(num_batch, num_residues, 3)`.
        edge_idx (torch.LongTensor): Edge indices for neighbors with shape
            `(num_batch, num_nodes, num_neighbors)`.

    Returns:
        R_j (torch.Tensor): Rotation matrices of neighbor transforms, with shape
            `(num_batch, num_residues, num_neighbors, 3, 3)`.
        t_j (torch.Tensor): Translations of neighbor transforms, with shape
            `(num_batch, num_residues, num_neighbors, 3)`.
    """
    num_batch, num_residues, num_neighbors = edge_idx.shape
    R_i_flat = R_i.reshape([num_batch, num_residues, 9])
    R_j = graph.collect_neighbors(R_i_flat, edge_idx).reshape(
        [num_batch, num_residues, num_neighbors, 3, 3]
    )
    t_j = graph.collect_neighbors(t_i, edge_idx)
    return R_j, t_j
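
# Shape sketch (editor's note, placeholder tensors): this is a pure gather in
# which each node i receives the frames of its K neighbors j.
#
#   B, N, K = 1, 8, 4
#   R_i = torch.eye(3).expand(B, N, 3, 3).contiguous()
#   t_i = torch.randn(B, N, 3)
#   edge_idx = torch.randint(0, N, (B, N, K))
#   R_j, t_j = collect_neighbor_transforms(R_i, t_i, edge_idx)
#   # R_j: (B, N, K, 3, 3), t_j: (B, N, K, 3)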


def collect_neighbor_inner_transforms(
    R_i: torch.Tensor, t_i: torch.Tensor, edge_idx: torch.LongTensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Collect inner transforms between neighbors.

    Args:
        R_i (torch.Tensor): Transform `T` rotation matrices with shape
            `(num_batch, num_residues, 3, 3)`.
        t_i (torch.Tensor): Transform `T` translations with shape
            `(num_batch, num_residues, 3)`.
        edge_idx (torch.LongTensor): Edge indices for neighbors with shape
            `(num_batch, num_nodes, num_neighbors)`.

    Returns:
        R_ji (torch.Tensor): Rotation matrices of the relative transforms
            `T_ji = T_j^{-1} * T_i`, with shape
            `(num_batch, num_residues, num_neighbors, 3, 3)`.
        t_ji (torch.Tensor): Translations of the relative transforms
            `T_ji = T_j^{-1} * T_i`, with shape
            `(num_batch, num_residues, num_neighbors, 3)`.
    """
    R_j, t_j = collect_neighbor_transforms(R_i, t_i, edge_idx)
    R_ji, t_ji = compose_inner_transforms(
        R_j, t_j, R_i.unsqueeze(-3), t_i.unsqueeze(-2)
    )
    return R_ji, t_ji


def equilibrate_transforms(
    R_i: torch.Tensor,
    t_i: torch.Tensor,
    R_ji: torch.Tensor,
    t_ji: torch.Tensor,
    logit_ij: torch.Tensor,
    mask_ij: torch.Tensor,
    edge_idx: torch.LongTensor,
    iterations: int = 1,
    R_global: Optional[torch.Tensor] = None,
    t_global: Optional[torch.Tensor] = None,
    R_global_i: Optional[torch.Tensor] = None,
    t_global_i: Optional[torch.Tensor] = None,
    logit_global_i: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Equilibrate neighbor transforms.

    Args:
        R_i (torch.Tensor): Transform `T` rotation matrices with shape
            `(num_batch, num_residues, 3, 3)`.
        t_i (torch.Tensor): Transform `T` translations with shape
            `(num_batch, num_residues, 3)`.
        R_ji (torch.Tensor): Rotation matrices to go between frames for nodes i and j
            with shape `(num_batch, num_residues, num_neighbors, 3, 3)`.
        t_ji (torch.Tensor): Translations to go between frames for nodes i and j with
            shape `(num_batch, num_residues, num_neighbors, 3)`.
        logit_ij (torch.Tensor): Logits for averaging neighbor transforms with shape
            `(num_batch, num_residues, num_neighbors, num_weights)`. Note that
            `num_weights` must be 1, 2, or 3; see the documentation for
            `chroma.layers.structure.transforms.average_transforms` for an
            explanation of the interpretations with different `num_weights`.
        mask_ij (torch.Tensor): Mask for averaging neighbor transforms with shape
            `(num_batch, num_residues, num_neighbors)`.
        edge_idx (torch.LongTensor): Edge indices for neighbors with shape
            `(num_batch, num_nodes, num_neighbors)`.
        iterations (int): Number of iterations to equilibrate for.
        R_global (torch.Tensor): Optional global frame rotation matrix with shape
            `(num_batch, 3, 3)`.
        t_global (torch.Tensor): Optional global frame translation with shape
            `(num_batch, 3)`.
        R_global_i (torch.Tensor): Optional rotation matrix for global frame from
            nodes with shape `(num_batch, num_residues, 3, 3)`.
        t_global_i (torch.Tensor): Optional translation for global frame from nodes
            with shape `(num_batch, num_residues, 3)`.
        logit_global_i (torch.Tensor): Logits for averaging global frame transform
            with shape `(num_batch, num_residues, num_weights)`. `num_weights`
            should match that of `logit_ij`.

    Returns:
        R_i (torch.Tensor): Rotation matrices of equilibrated transforms, with shape
            `(num_batch, num_residues, 3, 3)`.
        t_i (torch.Tensor): Translations of equilibrated transforms, with shape
            `(num_batch, num_residues, 3)`.
    """
    # Optional global frames are treated as an additional neighbor
    update_global = False
    if None not in [R_global, t_global, R_global_i, t_global_i, logit_global_i]:
        update_global = True
        num_batch, num_residues, num_neighbors = list(mask_ij.shape)
        R_global_i = R_global_i.unsqueeze(2)
        t_global_i = t_global_i.unsqueeze(2)
        R_ji = torch.cat((R_ji, R_global_i), dim=2)
        t_ji = torch.cat((t_ji, t_global_i), dim=2)
        logit_ij = torch.cat((logit_ij, logit_global_i.unsqueeze(2)), dim=2)
        R_global = R_global.reshape([num_batch, 1, 1, 3, 3]).expand(R_global_i.shape)
        t_global = t_global.reshape([num_batch, 1, 1, 3]).expand(t_global_i.shape)
        mask_i = (mask_ij.sum(2, keepdims=True) > 0).float()
        mask_ij = torch.cat((mask_ij, mask_i), dim=2)

    t_edge = None
    for i in range(iterations):
        R_j, t_j = collect_neighbor_transforms(R_i, t_i, edge_idx)
        if update_global:
            R_j = torch.cat((R_j, R_global), dim=2)
            t_j = torch.cat((t_j, t_global), dim=2)
        R_i_pred, t_i_pred = compose_transforms(R_j, t_j, R_ji, t_ji)
        if logit_ij.size(-1) == 3:
            # Compute i-j displacement in the same coordinate system as
            # t_i_pred, i.e. in global coords. Sign does not matter.
            t_edge = t_j - t_i_pred
        R_i, t_i = average_transforms(
            R_i_pred, t_i_pred, logit_ij, mask_ij, t_edge=t_edge, dim=2
        )
    return R_i, t_i
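
# Usage sketch (editor's illustration with placeholder shapes and values):
# each iteration re-predicts every frame from its neighbors' current frames
# through the fixed relative transforms `T_ji`, then averages the predictions
# with the supplied logits.
#
#   B, N, K = 1, 8, 4
#   R_i = torch.eye(3).expand(B, N, 3, 3).contiguous()
#   t_i = torch.randn(B, N, 3)
#   edge_idx = torch.randint(0, N, (B, N, K))
#   R_ji, t_ji = collect_neighbor_inner_transforms(R_i, t_i, edge_idx)
#   logit_ij = torch.zeros(B, N, K, 2)  # num_weights=2: separate t and R weights
#   mask_ij = torch.ones(B, N, K)
#   R_eq, t_eq = equilibrate_transforms(
#       R_i, t_i, R_ji, t_ji, logit_ij, mask_ij, edge_idx, iterations=3
#   )
#   # R_eq: (B, N, 3, 3), t_eq: (B, N, 3)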


def average_transforms(
    R: torch.Tensor,
    t: torch.Tensor,
    w: torch.Tensor,
    mask: torch.Tensor,
    dim: int,
    t_edge: Optional[torch.Tensor] = None,
    dither: bool = True,
    dither_eps: float = 1e-4,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Average transforms with optional dithering.

    Args:
        R (torch.Tensor): Transform `T` rotation matrix with shape `(...,3,3)`.
        t (torch.Tensor): Transform `T` translation with shape `(...,3)`.
        w (torch.Tensor): Logits for averaging weights with shape
            `(...,num_weights)`. `num_weights` can be 1 (single scalar
            weight per transform), 2 (separate weights for each rotation
            and translation), or 3 (one weight for rotation, two weights
            for translation corresponding to precision in all directions /
            along t_edge).
        mask (torch.Tensor): Mask for averaging weights with shape `(...)`.
        dim (int): Dimension to average along.
        t_edge (torch.Tensor, optional): Translation `T` of shape `(..., 3)`
            indicating the displacement between source and target nodes.
        dither (bool): Whether to noise final rotations.
        dither_eps (float): Fractional amount by which to noise rotations.

    Returns:
        R_avg (torch.Tensor): Average transform `T_avg` rotation matrix with
            shape `(...{reduced}...,3,3)`.
        t_avg (torch.Tensor): Average transform `T_avg` translation with
            shape `(...{reduced}...,3)`.
    """
    assert dim >= 0, "dimension must index from the left"
    w = torch.where(
        mask[..., None].bool(), w, torch.full_like(w, torch.finfo(w.dtype).min)
    )

    # We use different averaging models based on the number of weights
    num_transform_weights = w.size(-1)
    if num_transform_weights == 1:
        # Share a single scalar weight between t and R.
        probs = w.softmax(dim)
        t_probs = probs
        R_probs = probs[..., None]
        # Average translation.
        t_avg = (t * t_probs).sum(dim)
    elif num_transform_weights == 2:
        # Use separate scalar weights for each of t and R.
        probs = w.softmax(dim)
        t_probs, R_probs = probs.unbind(-1)
        t_probs = t_probs[..., None]
        R_probs = R_probs[..., None, None]
        # Average translation.
        t_avg = (t * t_probs).sum(dim)
    elif num_transform_weights == 3:
        # For R use a single scalar weight.
        R_probs = w[..., 2].softmax(dim)[..., None, None]
        # For t use a two-parameter precision matrix P = P_isometric + P_radial.
        # We need to hand compute softmax over the shared dim x 2 elements.
        w_t = w[..., :2]
        w_t_total = w_t.logsumexp([dim, -1], True)
        p_iso, p_rad = (w_t - w_t_total).exp().unbind(-1)
        # Use Gaussian fusion for translation.
        t_edge = t_edge * mask.to(t_edge.dtype)[..., None]
        t_avg, _ = fuse_gaussians_isometric_plus_radial(t, p_iso, p_rad, t_edge, dim)
    else:
        raise NotImplementedError

    # Average rotation via SVD, optionally dithering the weighted sum.
    R_avg_unc = (R * R_probs).sum(dim)
    if dither:
        R_avg_unc = R_avg_unc + dither_eps * torch.randn_like(R_avg_unc)
    U, S, Vh = torch.linalg.svd(R_avg_unc, full_matrices=True)
    R_avg = U @ Vh

    # Enforce that matrix is a rotation matrix
    d = torch.linalg.det(R_avg)
    d_expand = F.pad(d[..., None, None], (2, 0), value=1.0)
    Vh = Vh * d_expand
    R_avg = U @ Vh
    return R_avg, t_avg
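
# Example (editor's sketch): the SVD step maps the weighted sum of rotation
# matrices back onto an orthonormal matrix, and the determinant correction
# flips the sign of the last column of `Vh` when needed so the result is a
# proper rotation (det = +1).
#
#   M = torch.randn(3, 3)
#   U, S, Vh = torch.linalg.svd(M)
#   d = torch.linalg.det(U @ Vh)
#   R_proj = U @ (Vh * F.pad(d[None, None], (2, 0), value=1.0))
#   # R_proj @ R_proj.T ≈ I and torch.linalg.det(R_proj) ≈ 1.0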


def _debug_plot_transforms(
    R_ij: torch.Tensor,
    t_ij: torch.Tensor,
    logits_ij: torch.Tensor,
    edge_idx: torch.LongTensor,
    mask_ij: torch.Tensor,
    dist_eps: float = 1e-3,
):
    """Visualize 6-dof frame transformations."""
    from matplotlib import pyplot as plt

    num_batch = R_ij.shape[0]
    num_residues = R_ij.shape[1]

    # Masked softmax on logits
    # logits_ij = torch.where(
    #     mask_ij.bool(), logits_ij,
    #     torch.full_like(logits_ij, torch.finfo(logits_ij.dtype).min)
    # )
    p_ij = torch.log_softmax(logits_ij, 2)
    P_ij = graph.scatter_edges(p_ij[..., None], edge_idx)[..., 0]

    q_ij = geometry.quaternions_from_rotations(R_ij)
    q_ij = graph.scatter_edges(q_ij, edge_idx)
    t_ij = graph.scatter_edges(t_ij, edge_idx)

    # Convert to distance, direction, orientation
    D = torch.sqrt(t_ij.square().sum(-1))
    U = t_ij / (D[..., None] + dist_eps)
    D_max = D.max().item()
    t_ij = t_ij / D_max
    q_axis = q_ij[..., 1:]

    # Distance features
    D_img = D
    D_img_min = D_img.min().item()
    D_img_max = D_img.max().item()

    def _format(T):
        T = T.cpu().data.numpy()
        # RGB on (0,1)^3
        if len(T.shape) == 3:
            T = (T + 1) / 2
        return T

    base_width = 4
    num_cols = 4
    plt.figure(figsize=(base_width * num_cols, base_width * num_batch), dpi=300)
    ix = 1
    for i in range(num_batch):
        plt.subplot(num_batch, num_cols, ix)
        plt.imshow(_format(D_img[i, :, :]), cmap="inferno")
        # plt.clim([D_img_min, D_img_max])
        plt.axis("off")

        plt.subplot(num_batch, num_cols, ix + 1)
        plt.imshow(_format(U[i, :, :, :]))
        plt.axis("off")

        plt.subplot(num_batch, num_cols, ix + 2)
        plt.imshow(_format(q_axis[i, :, :, :]))
        plt.axis("off")

        # Confidence plots
        plt.subplot(num_batch, num_cols, ix + 3)
        plt.imshow(_format(P_ij[i, :, :]), cmap="inferno")
        # plt.clim([0, P_ij[i,:,:].max().item()])
        plt.axis("off")

        ix = ix + num_cols
    plt.tight_layout()
    return