# Copyright Generate Biomedicines, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Layers for building graph representations of protein structure, all-atom. | |
This module contains pytorch layers for representing protein structure as a | |
graph with node and edge features based on geometric information. The graph | |
features are differentiable with respect to input coordinates and can be used | |
for building protein scoring functions and optimizing protein geometries | |
natively in pytorch. | |
""" | |
import numpy as np
import torch
import torch.nn as nn

from chroma.layers import graph
from chroma.layers.structure import geometry, sidechain


class NodeChiRBF(nn.Module):
    """Layers for featurizing chi angles with a smooth binning.

    Args:
        dim_out (int): Number of output feature dimensions.
        num_chi (int): Number of chi angles.
        num_chi_bins (int): Number of bins for discretizing chi angles.
        bin_scale (float, optional): Scaling parameter that sets bin smoothing.

    Input:
        chi (Tensor): Chi angles with shape `(num_batch, num_residues, num_chi)`.

    Output:
        h_chi (Tensor): Chi angle features with shape
            `(num_batch, num_residues, dim_out)`.
    """

    def __init__(self, dim_out, num_chi, num_chi_bins, bin_scale=2.0):
        super(NodeChiRBF, self).__init__()
        self.dim_out = dim_out
        self.num_chi = num_chi
        self.num_chi_bins = num_chi_bins
        self.bin_scale = bin_scale

        self.embed = nn.Linear(self.num_chi * self.num_chi_bins, dim_out)

    def _featurize(self, chi, mask_chi=None):
        num_batch, num_residues, _ = chi.shape
        chi_bin_center = (
            torch.arange(0, self.num_chi_bins, device=chi.device)
            * 2.0
            * np.pi
            / self.num_chi_bins
        )
        chi_bin_center = chi_bin_center.reshape([1, 1, 1, -1])

        # Set the smoothing length scale from the gap between adjacent bin
        # centers, so that each feature decays to exp(-bin_scale) by the
        # time it reaches the neighboring bin center
        delta_adjacent = np.cos(0.0) - np.cos(2.0 * np.pi / self.num_chi_bins)
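        # Smooth circular one-hot: a von Mises-style bump exp((cos(chi - mu) - 1) * s)
        # around each bin center mu, periodic in chi, so the features vary
        # smoothly and differentiably with the input angles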
        cosine = torch.cos(chi.unsqueeze(-1) - chi_bin_center)
        chi_features = torch.exp((cosine - 1.0) * self.bin_scale / delta_adjacent)

        if mask_chi is not None:
            chi_features = mask_chi.unsqueeze(-1) * chi_features

        chi_features = chi_features.reshape(
            [num_batch, num_residues, self.num_chi * self.num_chi_bins]
        )
        return chi_features

    def forward(self, chi, mask_chi=None):
        chi_features = self._featurize(chi, mask_chi=mask_chi)
        h_chi = self.embed(chi_features)
        return h_chi


class EdgeSidechainsDirect(nn.Module):
    """Layers for direct encoding of side chain geometries.

    Args:
        dim_out (int): Number of output hidden dimensions.
        length_scale (float, optional): Length scale for the radial envelope
            and basis frequencies, in the units of `X` (angstroms).
        distance_eps (float, optional): Small constant added under the square
            root when computing distances, for numerical stability.
        num_fourier (int, optional): Number of random Fourier features, used
            when `basis_type` is `"rff"`.
        fourier_order (int, optional): Order of the Fourier expansion of each
            spherical coordinate, used when `basis_type` is `"spherical"`.
        basis_type (str, optional): Type of basis for encoding local atomic
            coordinates, either `"rff"` or `"spherical"`.

    Input:
        X (Tensor): All atom coordinates with shape
            `(num_batch, num_residues, 14, 3)`.
        C (LongTensor): Chain map with shape `(num_batch, num_residues)`.
        S (LongTensor): Sequence tensor with shape
            `(num_batch, num_residues)`.
        edge_idx (Tensor): Graph indices for expansion with shape
            `(num_batch, num_residues_out, num_neighbors)`. The output
            dimension `num_residues_out` must equal either `num_residues`
            or 1; the latter can be useful for sequential decoding.

    Output:
        h (Tensor): Features with shape
            `(num_batch, num_residues_out, num_neighbors, dim_out)`.
    """

    def __init__(
        self,
        dim_out,
        length_scale=7.5,
        distance_eps=0.1,
        num_fourier=30,
        fourier_order=2,
        basis_type="rff",
    ):
        super(EdgeSidechainsDirect, self).__init__()
        self.dim_out = dim_out
        self.length_scale = length_scale
        self.distance_eps = distance_eps
        self.num_fourier = num_fourier
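        # Random Fourier feature frequencies: each of the num_fourier columns
        # is a 3D frequency vector drawn from a Gaussian with standard
        # deviation 2 * pi / length_scale, so the cos/sin projections below
        # resolve spatial structure at roughly length_scale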
        self.rff = torch.nn.Parameter(
            2.0 * np.pi / self.length_scale * torch.randn((3, self.num_fourier))
        )

        self.basis_type = basis_type
        if self.basis_type == "rff":
            self.embed = nn.Linear(14 * self.num_fourier * 2, dim_out)
        elif self.basis_type == "spherical":
            self.fourier_order = fourier_order
            self.embed = nn.Linear(14 * (self.fourier_order * 2) ** 3, dim_out)
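        # Note: with the spherical basis, each of the 14 atom slots contributes
        # a (2 * fourier_order)^3 outer product of radial and angular features,
        # which sets the input width of self.embed above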

    def _local_coordinates(self, X, C, S, edge_idx):
        num_batch, num_residues, num_neighbors = edge_idx.shape

        # Build atom masks for each residue i and its gathered neighbors j
        mask_atoms = sidechain.atom_mask(C, S)
        mask_atoms_j = graph.collect_neighbors(mask_atoms, edge_idx)
        mask_i = (C > 0).float().reshape([num_batch, num_residues, 1, 1])
        mask_atoms_ij = mask_i * mask_atoms_j

        # Build local reference frames from the backbone of each residue i
        R_i, CA = geometry.frames_from_backbone(X[:, :, :4, :])

        # Transform neighbor X coordinates into local frames
        X_flat = X.reshape([num_batch, num_residues, -1])
        X_j_flat = graph.collect_neighbors(X_flat, edge_idx)
        X_j = X_j_flat.reshape([num_batch, num_residues, num_neighbors, 14, 3])
        dX_ij = X_j - CA.reshape([num_batch, num_residues, 1, 1, 3])
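        # Rotate displacements into the local frame of residue i:
        # U_ij[n, i, j, m, :] = R_i[n, i].T @ dX_ij[n, i, j, m, :]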
        U_ij = torch.einsum("niab,nijma->nijmb", R_i, dX_ij)
        return U_ij, mask_atoms_ij

    def _local_coordinates_t(self, t, X, C, S, edge_idx_t):
        num_batch, _, num_neighbors = edge_idx_t.shape
        num_residues = X.shape[1]

        # Build atom masks for residue t and its gathered neighbors j
        C_i = C[:, t].unsqueeze(1)
        C_j = graph.collect_neighbors(C.unsqueeze(-1), edge_idx_t).reshape(
            [num_batch, num_neighbors]
        )
        S_j = graph.collect_neighbors(S.unsqueeze(-1), edge_idx_t).reshape(
            [num_batch, num_neighbors]
        )
        mask_atoms_j = sidechain.atom_mask(C_j, S_j).unsqueeze(1)
        mask_i = (C_i > 0).float().reshape([num_batch, 1, 1, 1])
        mask_atoms_ij = mask_i * mask_atoms_j

        # Build the local reference frame from the backbone of residue t
        X_bb_i = X[:, t, :4, :].unsqueeze(1)
        R_i, CA = geometry.frames_from_backbone(X_bb_i)

        # Transform neighbor X coordinates into the local frame
        X_flat = X.reshape([num_batch, num_residues, -1])
        X_j_flat = graph.collect_neighbors(X_flat, edge_idx_t)
        X_j = X_j_flat.reshape([num_batch, 1, num_neighbors, 14, 3])
        dX_ij = X_j - CA.reshape([num_batch, 1, 1, 1, 3])
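        # Rotate into the local frame of residue t, as in _local_coordinates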
        U_ij = torch.einsum("niab,nijma->nijmb", R_i, dX_ij)
        return U_ij, mask_atoms_ij

    def _fourier_expand(self, h, order):
        k = torch.arange(order, device=h.device)
        k = k.reshape([1 for i in h.shape] + [-1])
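        # Expand with sine frequencies 1..order and cosine frequencies
        # 0..order-1; the k = 0 cosine term contributes a constant feature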
        return torch.cat(
            [torch.sin(h.unsqueeze(-1) * (k + 1)), torch.cos(h.unsqueeze(-1) * k)],
            dim=-1,
        )

    def _featurize(self, U_ij, mask_atoms_ij):
        if self.basis_type == "rff":
            # Random Fourier features of the local atomic coordinates
            U_ij = mask_atoms_ij.unsqueeze(-1) * U_ij
            U_ff = torch.einsum("nijax,xy->nijay", U_ij, self.rff)
            U_ff = torch.cat([torch.cos(U_ff), torch.sin(U_ff)], -1)

            # Gaussian RBF envelope
            D_ij = torch.sqrt((U_ij ** 2).sum(-1) + self.distance_eps)
            magnitude = torch.exp(-D_ij * D_ij / (2 * self.length_scale ** 2))
            U_ff = magnitude.unsqueeze(-1) * U_ff
            U_ff = U_ff.reshape(list(D_ij.shape)[:3] + [-1])
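            # Gate the edge embeddings by pair validity, using the first atom
            # slot of the neighbor atom mask as a per-residue indicator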
            h = mask_atoms_ij[:, :, :, 0].unsqueeze(-1) * self.embed(U_ff)
        elif self.basis_type == "spherical":
            # Convert to spherical coordinates
            r_ij = torch.sqrt((U_ij ** 2).sum(-1) + self.distance_eps)
            r_ij_scale = r_ij * 2.0 * np.pi / self.length_scale
            x, y, z = U_ij.unbind(-1)
            theta_ij = torch.acos(z / r_ij)
            phi_ij = torch.atan2(y, x)
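            # theta is the polar angle from the local z-axis and phi the
            # azimuth in the local xy-plane of the residue frame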
            # Build Fourier expansions of each coordinate
            r_ff, theta_ff, phi_ff = [
                self._fourier_expand(h, self.fourier_order)
                for h in [r_ij_scale, theta_ij, phi_ij]
            ]

            # Radial envelope function
            r_envelope = mask_atoms_ij * torch.exp(
                -r_ij * r_ij / (2 * self.length_scale ** 2)
            )

            # Tensor outer product
            bf_ij = torch.einsum(
                "bika,bikar,bikat,bikap->bikartp", r_envelope, r_ff, theta_ff, phi_ff
            ).reshape(list(r_ij.shape)[:3] + [-1])
            h = mask_atoms_ij[:, :, :, 0].unsqueeze(-1) * self.embed(bf_ij)
        return h

    def forward(self, X, C, S, edge_idx):
        U_ij, mask_atoms_ij = self._local_coordinates(X, C, S, edge_idx)
        h = self._featurize(U_ij, mask_atoms_ij)
        return h

    def step(self, t, X, C, S, edge_idx_t):
        U_ij, mask_atoms_ij = self._local_coordinates_t(t, X, C, S, edge_idx_t)
        h = self._featurize(U_ij, mask_atoms_ij)
        return h
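

if __name__ == "__main__":
    # Hedged smoke test, not part of the original module: exercises both
    # layers on random inputs. The shapes and the 20-letter sequence alphabet
    # are illustrative assumptions; chain map entries > 0 mark valid residues,
    # matching the (C > 0) masking used above.
    num_batch, num_residues, num_neighbors = 2, 20, 8

    chi_layer = NodeChiRBF(dim_out=64, num_chi=4, num_chi_bins=16)
    chi = 2.0 * np.pi * torch.rand(num_batch, num_residues, 4)
    print(chi_layer(chi).shape)  # torch.Size([2, 20, 64])

    edge_layer = EdgeSidechainsDirect(dim_out=64)
    X = torch.randn(num_batch, num_residues, 14, 3)
    C = torch.ones(num_batch, num_residues, dtype=torch.long)
    S = torch.randint(0, 20, (num_batch, num_residues))
    edge_idx = torch.randint(
        0, num_residues, (num_batch, num_residues, num_neighbors)
    )
    print(edge_layer(X, C, S, edge_idx).shape)  # torch.Size([2, 20, 8, 64])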