Spaces:
Running
Running
File size: 5,853 Bytes
224a33f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
from math import sqrt
import torch
from einops import rearrange
from torch import nn
from torch.nn import functional as F
class GeometricReasoningOriginalImpl(nn.Module):
def __init__(
self,
c_s: int,
v_heads: int,
num_vector_messages: int = 1,
mask_and_zero_frameless: bool = True,
divide_residual_by_depth: bool = False,
bias: bool = False,
):
"""Approximate implementation:
ATTN(A, v) := (softmax_j A_ij) v_j
make_rot_vectors(x) := R(i->g) Linear(x).reshape(..., 3)
make_vectors(x) := T(i->g) Linear(x).reshape(..., 3)
v <- make_rot_vectors(x)
q_dir, k_dir <- make_rot_vectors(x)
q_dist, k_dist <- make_vectors(x)
A_ij <- dot(q_dir_i, k_dir_j) -||q_dist_i - k_dist_j||^2
x <- x + Linear(T(g->i) ATTN(A, v))
"""
super().__init__()
self.c_s = c_s
self.v_heads = v_heads
self.num_vector_messages = num_vector_messages
self.mask_and_zero_frameless = mask_and_zero_frameless
self.s_norm = nn.LayerNorm(c_s, bias=bias)
dim_proj = (
4 * self.v_heads * 3 + self.v_heads * 3 * self.num_vector_messages
) # 2 x (q, k) * number of heads * (x, y, z) + number of heads * number of vector messages * (x, y, z)
self.proj = nn.Linear(c_s, dim_proj, bias=bias)
channels_out = self.v_heads * 3 * self.num_vector_messages
self.out_proj = nn.Linear(channels_out, c_s, bias=bias)
# The basic idea is for some attention heads to pay more or less attention to rotation versus distance,
# as well as to control the sharpness of the softmax (i.e., should this head only attend to those residues
# very nearby or should there be shallower dropoff in attention weight?)
self.distance_scale_per_head = nn.Parameter(torch.zeros((self.v_heads)))
self.rotation_scale_per_head = nn.Parameter(torch.zeros((self.v_heads)))
def forward(self, s, affine, affine_mask, sequence_id, chain_id):
attn_bias = sequence_id.unsqueeze(-1) == sequence_id.unsqueeze(-2)
attn_bias = attn_bias.unsqueeze(1).float()
attn_bias = attn_bias.masked_fill(
~affine_mask[:, None, None, :], torch.finfo(attn_bias.dtype).min
)
chain_id_mask = chain_id.unsqueeze(1) != chain_id.unsqueeze(2)
attn_bias = attn_bias.masked_fill(
chain_id_mask.unsqueeze(1), torch.finfo(s.dtype).min
)
ns = self.s_norm(s)
vec_rot, vec_dist = self.proj(ns).split(
[
self.v_heads * 2 * 3 + self.v_heads * 3 * self.num_vector_messages,
self.v_heads * 2 * 3,
],
dim=-1,
)
# Rotate the queries and keys for the rotation term. We also rotate the values.
# NOTE(zeming, thayes): Values are only rotated, not translated. We may wish to change
# this in the future.
query_rot, key_rot, value = (
affine.rot[..., None]
.apply(rearrange(vec_rot, "... (h c) -> ... h c", c=3))
.split(
[
self.v_heads,
self.v_heads,
self.v_heads * self.num_vector_messages,
],
dim=-2,
)
)
# Rotate and translate the queries and keys for the distance term
# NOTE(thayes): a simple speedup would be to apply all rotations together, then
# separately apply the translations.
query_dist, key_dist = (
affine[..., None]
.apply(rearrange(vec_dist, "... (h c) -> ... h c", c=3))
.chunk(2, dim=-2)
)
query_dist = rearrange(query_dist, "b s h d -> b h s 1 d")
key_dist = rearrange(key_dist, "b s h d -> b h 1 s d")
query_rot = rearrange(query_rot, "b s h d -> b h s d")
key_rot = rearrange(key_rot, "b s h d -> b h d s")
value = rearrange(
value, "b s (h m) d -> b h s (m d)", m=self.num_vector_messages
)
distance_term = (query_dist - key_dist).norm(dim=-1) / sqrt(3)
rotation_term = query_rot.matmul(key_rot) / sqrt(3)
distance_term_weight = rearrange(
F.softplus(self.distance_scale_per_head), "h -> h 1 1"
)
rotation_term_weight = rearrange(
F.softplus(self.rotation_scale_per_head), "h -> h 1 1"
)
attn_weight = (
rotation_term * rotation_term_weight - distance_term * distance_term_weight
)
if attn_bias is not None:
# we can re-use the attention bias from the transformer layers
# NOTE(thayes): This attention bias is expected to handle two things:
# 1. Masking attention on padding tokens
# 2. Masking cross sequence attention in the case of bin packing
s_q = attn_weight.size(2)
s_k = attn_weight.size(3)
_s_q = max(0, attn_bias.size(2) - s_q)
_s_k = max(0, attn_bias.size(3) - s_k)
attn_bias = attn_bias[:, :, _s_q:, _s_k:]
attn_weight = attn_weight + attn_bias
attn_weight = torch.softmax(attn_weight, dim=-1)
attn_out = attn_weight.matmul(value)
attn_out = (
affine.rot[..., None]
.invert()
.apply(
rearrange(
attn_out, "b h s (m d) -> b s (h m) d", m=self.num_vector_messages
)
)
)
attn_out = rearrange(
attn_out, "b s (h m) d -> b s (h m d)", m=self.num_vector_messages
)
if self.mask_and_zero_frameless:
attn_out = attn_out.masked_fill(~affine_mask[..., None], 0.0)
s = self.out_proj(attn_out)
return s
|