import hydra
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops.layers.torch import Rearrange

from utils.pooling import HomogeneousAggregator


class RelationalTransformer(nn.Module):
    """Relational transformer over the node/edge graph built by `graph_constructor`."""

    def __init__(
        self,
        d_node,
        d_edge,
        d_attn_hid,
        d_node_hid,
        d_edge_hid,
        d_out_hid,
        d_out,
        n_layers,
        n_heads,
        layer_layout,
        graph_constructor,
        dropout=0.0,
        node_update_type="rt",
        disable_edge_updates=False,
        use_cls_token=False,
        pooling_method="cat",
        pooling_layer_idx="last",
        rev_edge_features=False,
        modulate_v=True,
        use_ln=True,
        tfixit_init=False,
    ):
        super().__init__()
        assert use_cls_token == (pooling_method == "cls_token")
        self.pooling_method = pooling_method
        self.pooling_layer_idx = pooling_layer_idx
        self.rev_edge_features = rev_edge_features
        self.nodes_per_layer = layer_layout
        self.construct_graph = hydra.utils.instantiate(
            graph_constructor,
            d_node=d_node,
            d_edge=d_edge,
            layer_layout=layer_layout,
            rev_edge_features=rev_edge_features,
        )
        self.use_cls_token = use_cls_token
        if use_cls_token:
            self.cls_token = nn.Parameter(torch.randn(d_node))

        self.layers = nn.ModuleList(
            [
                torch.jit.script(
                    RTLayer(
                        d_node,
                        d_edge,
                        d_attn_hid,
                        d_node_hid,
                        d_edge_hid,
                        n_heads,
                        dropout,
                        node_update_type=node_update_type,
                        # Skip the last layer's edge update unless the pooling
                        # actually consumes edge features.
                        disable_edge_updates=(
                            (disable_edge_updates or (i == n_layers - 1))
                            and pooling_method != "mean_edge"
                            and pooling_layer_idx != "all"
                        ),
                        modulate_v=modulate_v,
                        use_ln=use_ln,
                        tfixit_init=tfixit_init,
                        n_layers=n_layers,
                    )
                )
                for i in range(n_layers)
            ]
        )

        if pooling_method != "cls_token":
            self.pool = HomogeneousAggregator(
                pooling_method,
                pooling_layer_idx,
                layer_layout,
            )

        self.num_graph_features = (
            layer_layout[-1] * d_node
            if pooling_method == "cat" and pooling_layer_idx == "last"
            else d_edge
            if pooling_method in ("mean_edge", "max_edge")
            else d_node
        )
        self.proj_out = nn.Sequential(
            nn.Linear(self.num_graph_features, d_out_hid),
            nn.ReLU(),
            # nn.Linear(d_out_hid, d_out_hid),
            # nn.ReLU(),
            nn.Linear(d_out_hid, d_out),
        )
        self.final_features = (None, None, None, None)

    def forward(self, inputs):
        attn_weights = None
        node_features, edge_features, mask = self.construct_graph(inputs)
        if self.use_cls_token:
            node_features = torch.cat(
                [
                    # repeat(self.cls_token, "d -> b 1 d", b=node_features.size(0)),
                    self.cls_token.unsqueeze(0).expand(node_features.size(0), 1, -1),
                    node_features,
                ],
                dim=1,
            )
            # Pad a zero edge row and column for the CLS node.
            edge_features = F.pad(edge_features, (0, 0, 1, 0, 1, 0), value=0)

        for layer in self.layers:
            node_features, edge_features, attn_weights = layer(
                node_features, edge_features, mask
            )

        if self.pooling_method == "cls_token":
            graph_features = node_features[:, 0]
        else:
            graph_features = self.pool(node_features, edge_features)
        self.final_features = (
            graph_features,
            node_features,
            edge_features,
            attn_weights,
        )
        return self.proj_out(graph_features)


class RTLayer(nn.Module):
    """One relational transformer block: attention-based node update followed by an optional edge update."""

    def __init__(
        self,
        d_node,
        d_edge,
        d_attn_hid,
        d_node_hid,
        d_edge_hid,
        n_heads,
        dropout,
        node_update_type="rt",
        disable_edge_updates=False,
        modulate_v=True,
        use_ln=True,
        tfixit_init=False,
        n_layers=None,
    ):
        super().__init__()
        self.node_update_type = node_update_type
        self.disable_edge_updates = disable_edge_updates
        self.use_ln = use_ln
        self.n_layers = n_layers

        self.self_attn = torch.jit.script(
            RTAttention(
                d_node,
                d_edge,
                d_attn_hid,
                n_heads,
                modulate_v=modulate_v,
                use_ln=use_ln,
            )
        )
        # self.self_attn = RTAttention(d_hid, d_hid, d_hid, n_heads)
        self.lin0 = Linear(d_node, d_node)
        self.dropout0 = nn.Dropout(dropout)
        if use_ln:
            self.node_ln0 = nn.LayerNorm(d_node)
            self.node_ln1 = nn.LayerNorm(d_node)
        else:
            self.node_ln0 = nn.Identity()
            self.node_ln1 = nn.Identity()

        act_fn = nn.GELU
        self.node_mlp = nn.Sequential(
            Linear(d_node, d_node_hid, bias=False),
            act_fn(),
            Linear(d_node_hid, d_node),
            nn.Dropout(dropout),
        )

        if not self.disable_edge_updates:
            self.edge_updates = EdgeLayer(
                d_node=d_node,
                d_edge=d_edge,
                d_edge_hid=d_edge_hid,
                dropout=dropout,
                act_fn=act_fn,
                use_ln=use_ln,
            )
        else:
            self.edge_updates = NoEdgeLayer()

        if tfixit_init:
            self.fixit_init()

    def fixit_init(self):
        # T-Fixup-style rescaling: shrink MLP and attention weights by
        # 0.67 * n_layers^(-1/4); attention weights get an extra sqrt(2) factor.
        temp_state_dict = self.state_dict()
        n_layers = self.n_layers
        for name, param in self.named_parameters():
            if "weight" in name:
                if name.split(".")[0] in ["node_mlp", "edge_mlp0", "edge_mlp1"]:
                    temp_state_dict[name] = (0.67 * (n_layers) ** (-1.0 / 4.0)) * param
                elif name.split(".")[0] in ["self_attn"]:
                    temp_state_dict[name] = (0.67 * (n_layers) ** (-1.0 / 4.0)) * (
                        param * (2**0.5)
                    )
        self.load_state_dict(temp_state_dict)

    def node_updates(self, node_features, edge_features, mask):
        # Residual attention update followed by a residual MLP, each with
        # (optional) LayerNorm applied after the residual connection.
        out = self.self_attn(node_features, edge_features, mask)
        attn_out, attn_weights = out
        node_features = self.node_ln0(
            node_features + self.dropout0(self.lin0(attn_out))
        )
        node_features = self.node_ln1(node_features + self.node_mlp(node_features))
        return node_features, attn_weights

    def forward(self, node_features, edge_features, mask):
        node_features, attn_weights = self.node_updates(
            node_features, edge_features, mask
        )
        edge_features = self.edge_updates(node_features, edge_features, mask)
        return node_features, edge_features, attn_weights


class EdgeLayer(nn.Module):
    """Residual edge update: a node-conditioned edge MLP followed by a plain edge MLP."""

    def __init__(
        self,
        *,
        d_node,
        d_edge,
        d_edge_hid,
        dropout,
        act_fn,
        use_ln=True,
    ) -> None:
        super().__init__()
        self.edge_mlp0 = EdgeMLP(
            d_edge=d_edge,
            d_node=d_node,
            d_edge_hid=d_edge_hid,
            act_fn=act_fn,
            dropout=dropout,
        )
        self.edge_mlp1 = nn.Sequential(
            Linear(d_edge, d_edge_hid, bias=False),
            act_fn(),
            Linear(d_edge_hid, d_edge),
            nn.Dropout(dropout),
        )
        if use_ln:
            self.eln0 = nn.LayerNorm(d_edge)
            self.eln1 = nn.LayerNorm(d_edge)
        else:
            self.eln0 = nn.Identity()
            self.eln1 = nn.Identity()

    def forward(self, node_features, edge_features, mask):
        edge_features = self.eln0(
            edge_features + self.edge_mlp0(node_features, edge_features)
        )
        edge_features = self.eln1(edge_features + self.edge_mlp1(edge_features))
        return edge_features


class NoEdgeLayer(nn.Module):
    """Identity edge update, used when edge updates are disabled."""

    def forward(self, node_features, edge_features, mask):
        return edge_features


class EdgeMLP(nn.Module):
    """Updates each edge (i, j) from its own features, the reverse edge (j, i),
    and the source/target node features."""

    def __init__(self, *, d_node, d_edge, d_edge_hid, act_fn, dropout):
        super().__init__()
        self.reverse_edge = Rearrange("b n m d -> b m n d")
        self.lin0_e = Linear(2 * d_edge, d_edge_hid)
        self.lin0_s = Linear(d_node, d_edge_hid)
        self.lin0_t = Linear(d_node, d_edge_hid)
        self.act = act_fn()
        self.lin1 = Linear(d_edge_hid, d_edge)
        self.drop = nn.Dropout(dropout)

    def forward(self, node_features, edge_features):
        # Broadcast source-node features along rows and target-node features
        # along columns of the edge grid.
        source_nodes = (
            self.lin0_s(node_features)
            .unsqueeze(-2)
            .expand(-1, -1, node_features.size(-2), -1)
        )
        target_nodes = (
            self.lin0_t(node_features)
            .unsqueeze(-3)
            .expand(-1, node_features.size(-2), -1, -1)
        )

        # reversed_edge_features = self.reverse_edge(edge_features)
        edge_features = self.lin0_e(
            torch.cat([edge_features, self.reverse_edge(edge_features)], dim=-1)
        )
        edge_features = edge_features + source_nodes + target_nodes
        edge_features = self.act(edge_features)
        edge_features = self.lin1(edge_features)
        edge_features = self.drop(edge_features)
        return edge_features


class RTAttention(nn.Module):
    """Relational attention: queries, keys and (optionally) values are modulated
    by per-edge features in addition to the usual per-node projections."""

    def __init__(self, d_node, d_edge, d_hid, n_heads, modulate_v=None, use_ln=True):
        super().__init__()
        self.n_heads = n_heads
        self.d_node = d_node
        self.d_edge = d_edge
        self.d_hid = d_hid
        self.use_ln = use_ln
        self.modulate_v = modulate_v
        self.scale = 1 / (d_hid**0.5)
        self.split_head_node = Rearrange("b n (h d) -> b h n d", h=n_heads)
        self.split_head_edge = Rearrange("b n m (h d) -> b h n m d", h=n_heads)
        self.cat_head_node = Rearrange("... h n d -> ... n (h d)", h=n_heads)

        self.qkv_node = Linear(d_node, 3 * d_hid, bias=False)
        self.edge_factor = 4 if modulate_v else 3
        self.qkv_edge = Linear(d_edge, self.edge_factor * d_hid, bias=False)
        self.proj_out = Linear(d_hid, d_node)

    def forward(self, node_features, edge_features, mask):
        qkv_node = self.qkv_node(node_features)
        # qkv_node = rearrange(qkv_node, "b n (h d) -> b h n d", h=self.n_heads)
        qkv_node = self.split_head_node(qkv_node)
        # q_node, k_node, v_node: (b, h, n, d_head)
        q_node, k_node, v_node = torch.chunk(qkv_node, 3, dim=-1)

        qkv_edge = self.qkv_edge(edge_features)
        # qkv_edge = rearrange(qkv_edge, "b n m (h d) -> b h n m d", h=self.n_heads)
        qkv_edge = self.split_head_edge(qkv_edge)
        # each chunk: (b, h, n, m, d_head)
        qkv_edge = torch.chunk(qkv_edge, self.edge_factor, dim=-1)
        # q_edge, k_edge, v_edge, q_edge_b, k_edge_b, v_edge_b = torch.chunk(
        #     qkv_edge, 6, dim=-1
        # )
        # qkv_edge = [item.masked_fill(mask.unsqueeze(1) == 0, 0) for item in qkv_edge]

        q = q_node.unsqueeze(-2) + qkv_edge[0]  # + q_edge_b
        k = k_node.unsqueeze(-3) + qkv_edge[1]  # + k_edge_b
        if self.modulate_v:
            v = v_node.unsqueeze(-3) * qkv_edge[3] + qkv_edge[2]
        else:
            v = v_node.unsqueeze(-3) + qkv_edge[2]

        dots = self.scale * torch.einsum("b h i j d, b h i j d -> b h i j", q, k)
        # dots.masked_fill_(mask.unsqueeze(1).squeeze(-1) == 0, -1e-9)
        attn = F.softmax(dots, dim=-1)

        out = torch.einsum("b h i j, b h i j d -> b h i d", attn, v)
        out = self.cat_head_node(out)
        return self.proj_out(out), attn


def Linear(in_features, out_features, bias=True):
    # nn.Linear with Xavier-uniform weight init and zero bias.
    m = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(m.weight)  # , gain=1 / math.sqrt(2))
    if bias:
        nn.init.constant_(m.bias, 0.0)
    return m
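

if __name__ == "__main__":
    # Quick, self-contained shape check of RTAttention in eager mode. The
    # dimensions below are arbitrary assumptions for illustration only; the
    # full RelationalTransformer additionally requires a hydra
    # `graph_constructor` config and `utils.pooling.HomogeneousAggregator`.
    torch.manual_seed(0)
    batch, n_nodes, d_node, d_edge, d_hid, n_heads = 2, 5, 8, 6, 16, 4

    attn_module = RTAttention(d_node, d_edge, d_hid, n_heads, modulate_v=True)
    node_features = torch.randn(batch, n_nodes, d_node)
    edge_features = torch.randn(batch, n_nodes, n_nodes, d_edge)
    mask = torch.ones(batch, n_nodes, n_nodes, 1)  # unused by the current forward

    out, attn_weights = attn_module(node_features, edge_features, mask)
    assert out.shape == (batch, n_nodes, d_node)
    assert attn_weights.shape == (batch, n_heads, n_nodes, n_nodes)
    print("node update:", tuple(out.shape), "attention:", tuple(attn_weights.shape))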