# tasks/utils/relational_transformer.py
import hydra
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops.layers.torch import Rearrange
from utils.pooling import HomogeneousAggregator
class RelationalTransformer(nn.Module):
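    """Relational transformer operating on a graph of node and edge features.

    The graph is built from the raw inputs by a hydra-instantiated
    ``graph_constructor``; a stack of ``RTLayer`` blocks then updates node and
    edge features jointly, and the result is pooled (via a CLS token or a
    ``HomogeneousAggregator``) and projected to ``d_out``.
    """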
def __init__(
self,
d_node,
d_edge,
d_attn_hid,
d_node_hid,
d_edge_hid,
d_out_hid,
d_out,
n_layers,
n_heads,
layer_layout,
graph_constructor,
dropout=0.0,
node_update_type="rt",
disable_edge_updates=False,
use_cls_token=False,
pooling_method="cat",
pooling_layer_idx="last",
rev_edge_features=False,
modulate_v=True,
use_ln=True,
tfixit_init=False,
):
super().__init__()
        assert use_cls_token == (pooling_method == "cls_token"), (
            "use_cls_token must be True iff pooling_method == 'cls_token'"
        )
self.pooling_method = pooling_method
self.pooling_layer_idx = pooling_layer_idx
self.rev_edge_features = rev_edge_features
self.nodes_per_layer = layer_layout
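        # ``graph_constructor`` is a hydra config (``_target_`` plus kwargs);
        # the keyword arguments below are forwarded to that target.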
self.construct_graph = hydra.utils.instantiate(
graph_constructor,
d_node=d_node,
d_edge=d_edge,
layer_layout=layer_layout,
rev_edge_features=rev_edge_features,
)
self.use_cls_token = use_cls_token
if use_cls_token:
self.cls_token = nn.Parameter(torch.randn(d_node))
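        # Edge updates are skipped on the last layer (or when globally disabled)
        # unless pooling still needs edges ("mean_edge") or all layers ("all").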
self.layers = nn.ModuleList(
[
torch.jit.script(
RTLayer(
d_node,
d_edge,
d_attn_hid,
d_node_hid,
d_edge_hid,
n_heads,
dropout,
node_update_type=node_update_type,
disable_edge_updates=(
(disable_edge_updates or (i == n_layers - 1))
and pooling_method != "mean_edge"
and pooling_layer_idx != "all"
),
modulate_v=modulate_v,
use_ln=use_ln,
tfixit_init=tfixit_init,
n_layers=n_layers,
)
)
for i in range(n_layers)
]
)
if pooling_method != "cls_token":
self.pool = HomogeneousAggregator(
pooling_method,
pooling_layer_idx,
layer_layout,
)
self.num_graph_features = (
layer_layout[-1] * d_node
if pooling_method == "cat" and pooling_layer_idx == "last"
else d_edge if pooling_method in ("mean_edge", "max_edge") else d_node
)
self.proj_out = nn.Sequential(
nn.Linear(self.num_graph_features, d_out_hid),
nn.ReLU(),
# nn.Linear(d_out_hid, d_out_hid),
# nn.ReLU(),
nn.Linear(d_out_hid, d_out),
)
        self.final_features = (None, None, None, None)
def forward(self, inputs):
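        """Construct the graph from ``inputs``, optionally prepend a CLS token,
        run the RT layers, pool, and project to the output size.

        The last ``(graph, node, edge, attention)`` features are cached in
        ``self.final_features`` for later inspection.
        """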
attn_weights = None
node_features, edge_features, mask = self.construct_graph(inputs)
if self.use_cls_token:
node_features = torch.cat(
[
# repeat(self.cls_token, "d -> b 1 d", b=node_features.size(0)),
self.cls_token.unsqueeze(0).expand(node_features.size(0), 1, -1),
node_features,
],
dim=1,
)
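            # Pad a zero row and column of edge features for the new CLS node.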
edge_features = F.pad(edge_features, (0, 0, 1, 0, 1, 0), value=0)
for layer in self.layers:
node_features, edge_features, attn_weights = layer(node_features, edge_features, mask)
if self.pooling_method == "cls_token":
graph_features = node_features[:, 0]
else:
graph_features = self.pool(node_features, edge_features)
self.final_features = (graph_features, node_features, edge_features, attn_weights)
return self.proj_out(graph_features)
class RTLayer(nn.Module):
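    """A single relational-transformer block.

    Applies relational self-attention followed by a node MLP (both with
    residual connections and optional LayerNorm), then optionally updates the
    edge features with an ``EdgeLayer``.
    """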
def __init__(
self,
d_node,
d_edge,
d_attn_hid,
d_node_hid,
d_edge_hid,
n_heads,
dropout,
node_update_type="rt",
disable_edge_updates=False,
modulate_v=True,
use_ln=True,
tfixit_init=False,
n_layers=None,
):
super().__init__()
self.node_update_type = node_update_type
self.disable_edge_updates = disable_edge_updates
self.use_ln = use_ln
self.n_layers = n_layers
self.self_attn = torch.jit.script(
RTAttention(
d_node,
d_edge,
d_attn_hid,
n_heads,
modulate_v=modulate_v,
use_ln=use_ln,
)
)
# self.self_attn = RTAttention(d_hid, d_hid, d_hid, n_heads)
self.lin0 = Linear(d_node, d_node)
self.dropout0 = nn.Dropout(dropout)
if use_ln:
self.node_ln0 = nn.LayerNorm(d_node)
self.node_ln1 = nn.LayerNorm(d_node)
else:
self.node_ln0 = nn.Identity()
self.node_ln1 = nn.Identity()
act_fn = nn.GELU
self.node_mlp = nn.Sequential(
Linear(d_node, d_node_hid, bias=False),
act_fn(),
Linear(d_node_hid, d_node),
nn.Dropout(dropout),
)
if not self.disable_edge_updates:
self.edge_updates = EdgeLayer(
d_node=d_node,
d_edge=d_edge,
d_edge_hid=d_edge_hid,
dropout=dropout,
act_fn=act_fn,
use_ln=use_ln,
)
else:
self.edge_updates = NoEdgeLayer()
if tfixit_init:
self.fixit_init()
def fixit_init(self):
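        """Rescale residual-branch weights in the spirit of T-Fixup:
        MLP weights by ``0.67 * n_layers ** (-0.25)`` and attention weights by
        the same factor times ``sqrt(2)``."""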
temp_state_dict = self.state_dict()
n_layers = self.n_layers
        for name, param in self.named_parameters():
            if "weight" not in name:
                continue
            top_module = name.split(".")[0]
            # node_mlp weights live at the top level; edge_mlp0/edge_mlp1 are
            # nested under ``edge_updates``, so match them by substring.
            if top_module == "node_mlp" or "edge_mlp0" in name or "edge_mlp1" in name:
                temp_state_dict[name] = (0.67 * (n_layers) ** (-1.0 / 4.0)) * param
            elif top_module == "self_attn":
                temp_state_dict[name] = (0.67 * (n_layers) ** (-1.0 / 4.0)) * (
                    param * (2**0.5)
                )
self.load_state_dict(temp_state_dict)
def node_updates(self, node_features, edge_features, mask):
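        """Residual attention update followed by a residual MLP update, each
        wrapped in (optional) LayerNorm; also returns the attention weights."""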
out = self.self_attn(node_features, edge_features, mask)
attn_out, attn_weights = out
node_features = self.node_ln0(
node_features
+ self.dropout0(
self.lin0(attn_out)
)
)
node_features = self.node_ln1(node_features + self.node_mlp(node_features))
return node_features, attn_weights
def forward(self, node_features, edge_features, mask):
node_features, attn_weights = self.node_updates(node_features, edge_features, mask)
edge_features = self.edge_updates(node_features, edge_features, mask)
return node_features, edge_features, attn_weights
class EdgeLayer(nn.Module):
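    """Two residual edge-update blocks: an ``EdgeMLP`` conditioned on node
    features, followed by a plain MLP, each with optional LayerNorm."""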
def __init__(
self,
*,
d_node,
d_edge,
d_edge_hid,
dropout,
act_fn,
use_ln=True,
) -> None:
super().__init__()
self.edge_mlp0 = EdgeMLP(
d_edge=d_edge,
d_node=d_node,
d_edge_hid=d_edge_hid,
act_fn=act_fn,
dropout=dropout,
)
self.edge_mlp1 = nn.Sequential(
Linear(d_edge, d_edge_hid, bias=False),
act_fn(),
Linear(d_edge_hid, d_edge),
nn.Dropout(dropout),
)
if use_ln:
self.eln0 = nn.LayerNorm(d_edge)
self.eln1 = nn.LayerNorm(d_edge)
else:
self.eln0 = nn.Identity()
self.eln1 = nn.Identity()
def forward(self, node_features, edge_features, mask):
edge_features = self.eln0(
edge_features + self.edge_mlp0(node_features, edge_features)
)
edge_features = self.eln1(edge_features + self.edge_mlp1(edge_features))
return edge_features
class NoEdgeLayer(nn.Module):
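    """Identity edge update: passes edge features through unchanged."""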
def forward(self, node_features, edge_features, mask):
return edge_features
class EdgeMLP(nn.Module):
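    """Mixes each edge with its reversed counterpart and with linear
    projections of its source and target nodes, then maps back to ``d_edge``."""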
def __init__(self, *, d_node, d_edge, d_edge_hid, act_fn, dropout):
super().__init__()
self.reverse_edge = Rearrange("b n m d -> b m n d")
self.lin0_e = Linear(2 * d_edge, d_edge_hid)
self.lin0_s = Linear(d_node, d_edge_hid)
self.lin0_t = Linear(d_node, d_edge_hid)
self.act = act_fn()
self.lin1 = Linear(d_edge_hid, d_edge)
self.drop = nn.Dropout(dropout)
def forward(self, node_features, edge_features):
source_nodes = (
self.lin0_s(node_features)
.unsqueeze(-2)
.expand(-1, -1, node_features.size(-2), -1)
)
target_nodes = (
self.lin0_t(node_features)
.unsqueeze(-3)
.expand(-1, node_features.size(-2), -1, -1)
)
# reversed_edge_features = self.reverse_edge(edge_features)
edge_features = self.lin0_e(
torch.cat([edge_features, self.reverse_edge(edge_features)], dim=-1)
)
edge_features = edge_features + source_nodes + target_nodes
edge_features = self.act(edge_features)
edge_features = self.lin1(edge_features)
edge_features = self.drop(edge_features)
return edge_features
class RTAttention(nn.Module):
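    """Multi-head attention in which per-edge features are added to the
    queries, keys and values of each node pair (and, with ``modulate_v``,
    also gate the values), so attention logits depend on the edge between
    each pair of nodes."""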
def __init__(self, d_node, d_edge, d_hid, n_heads, modulate_v=None, use_ln=True):
super().__init__()
self.n_heads = n_heads
self.d_node = d_node
self.d_edge = d_edge
self.d_hid = d_hid
self.use_ln = use_ln
self.modulate_v = modulate_v
self.scale = 1 / (d_hid**0.5)
self.split_head_node = Rearrange("b n (h d) -> b h n d", h=n_heads)
self.split_head_edge = Rearrange("b n m (h d) -> b h n m d", h=n_heads)
self.cat_head_node = Rearrange("... h n d -> ... n (h d)", h=n_heads)
self.qkv_node = Linear(d_node, 3 * d_hid, bias=False)
self.edge_factor = 4 if modulate_v else 3
self.qkv_edge = Linear(d_edge, self.edge_factor * d_hid, bias=False)
self.proj_out = Linear(d_hid, d_node)
def forward(self, node_features, edge_features, mask):
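        """Compute edge-conditioned attention and return (output, weights).

        Note: ``mask`` is accepted for interface compatibility but is not
        applied here; the masking code below is commented out.
        """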
qkv_node = self.qkv_node(node_features)
# qkv_node = rearrange(qkv_node, "b n (h d) -> b h n d", h=self.n_heads)
qkv_node = self.split_head_node(qkv_node)
q_node, k_node, v_node = torch.chunk(qkv_node, 3, dim=-1)
qkv_edge = self.qkv_edge(edge_features)
# qkv_edge = rearrange(qkv_edge, "b n m (h d) -> b h n m d", h=self.n_heads)
qkv_edge = self.split_head_edge(qkv_edge)
qkv_edge = torch.chunk(qkv_edge, self.edge_factor, dim=-1)
# q_edge, k_edge, v_edge, q_edge_b, k_edge_b, v_edge_b = torch.chunk(
# qkv_edge, 6, dim=-1
# )
# qkv_edge = [item.masked_fill(mask.unsqueeze(1) == 0, 0) for item in qkv_edge]
q = q_node.unsqueeze(-2) + qkv_edge[0] # + q_edge_b
k = k_node.unsqueeze(-3) + qkv_edge[1] # + k_edge_b
if self.modulate_v:
v = v_node.unsqueeze(-3) * qkv_edge[3] + qkv_edge[2]
else:
v = v_node.unsqueeze(-3) + qkv_edge[2]
dots = self.scale * torch.einsum("b h i j d, b h i j d -> b h i j", q, k)
# dots.masked_fill_(mask.unsqueeze(1).squeeze(-1) == 0, -1e-9)
attn = F.softmax(dots, dim=-1)
out = torch.einsum("b h i j, b h i j d -> b h i d", attn, v)
out = self.cat_head_node(out)
return self.proj_out(out), attn
def Linear(in_features, out_features, bias=True):
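    """``nn.Linear`` factory with Xavier-uniform weights and zero bias."""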
m = nn.Linear(in_features, out_features, bias)
nn.init.xavier_uniform_(m.weight) # , gain=1 / math.sqrt(2))
if bias:
nn.init.constant_(m.bias, 0.0)
return m
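

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original module): it exercises
    # RTAttention alone with assumed toy shapes, since a full
    # RelationalTransformer also needs a hydra graph-constructor config.
    batch, n_nodes, d_node, d_edge, d_hid, n_heads = 2, 5, 16, 8, 16, 4
    nodes = torch.randn(batch, n_nodes, d_node)
    edges = torch.randn(batch, n_nodes, n_nodes, d_edge)
    mask = torch.ones(batch, n_nodes, n_nodes, 1)  # currently unused by RTAttention
    attn = RTAttention(d_node, d_edge, d_hid, n_heads, modulate_v=True)
    out, weights = attn(nodes, edges, mask)
    print(out.shape, weights.shape)  # expected: (2, 5, 16) and (2, 4, 5, 5)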