import torch
import torch.nn as nn
from rff.layers import GaussianEncoding

# Required when `num_probe_features > 0`; note that `nn` here refers to the
# repo's own `nn` package, not `torch.nn`.
from nn.probe_features import GraphProbeFeatures


def sparsify_graph(edges, fraction=0.1):
    """Return a boolean mask keeping the top `fraction` of edges by magnitude.

    The top-k selection runs over the flattened tensor, i.e. across the whole
    batch, not per sample.
    """
    abs_edges = torch.abs(edges)
    flat_abs_tensor = abs_edges.flatten()
    num_elements = flat_abs_tensor.numel()
    top_k = int(num_elements * fraction)
    _, topk_indices = torch.topk(flat_abs_tensor, top_k)
    mask = torch.zeros_like(flat_abs_tensor, dtype=torch.bool)
    mask[topk_indices] = True
    return mask.view(edges.shape)


def batch_to_graphs(
    weights,
    biases,
    weights_mean=None,
    weights_std=None,
    biases_mean=None,
    biases_std=None,
    sparsify=False,
    sym_edges=False,
):
    """Pack a batch of per-layer MLP weights and biases into dense graph tensors.

    Each neuron becomes a node; each weight becomes an edge feature between the
    corresponding input and output neurons.
    """
    device = weights[0].device
    bsz = weights[0].shape[0]
    # Total nodes: input neurons plus every layer's output neurons.
    num_nodes = weights[0].shape[1] + sum(w.shape[2] for w in weights)

    node_features = torch.zeros(bsz, num_nodes, biases[0].shape[-1], device=device)
    edge_features = torch.zeros(
        bsz, num_nodes, num_nodes, weights[0].shape[-1], device=device
    )

    row_offset = 0
    col_offset = weights[0].shape[1]  # no edges into the input nodes
    for i, w in enumerate(weights):
        _, num_in, num_out, _ = w.shape
        w_mean = weights_mean[i] if weights_mean is not None else 0
        w_std = weights_std[i] if weights_std is not None else 1
        w = (w - w_mean) / w_std
        if sparsify:
            # Zero out all but the largest-magnitude weights.
            w[~sparsify_graph(w)] = 0
        edge_features[
            :, row_offset : row_offset + num_in, col_offset : col_offset + num_out
        ] = w
        if sym_edges:
            # Mirror the block across the diagonal so edges are undirected.
            edge_features[
                :, col_offset : col_offset + num_out, row_offset : row_offset + num_in
            ] = torch.swapaxes(w, 1, 2)
        row_offset += num_in
        col_offset += num_out

    row_offset = weights[0].shape[1]  # input nodes carry no bias
    for i, b in enumerate(biases):
        _, num_out, _ = b.shape
        b_mean = biases_mean[i] if biases_mean is not None else 0
        b_std = biases_std[i] if biases_std is not None else 1
        node_features[:, row_offset : row_offset + num_out] = (b - b_mean) / b_std
        row_offset += num_out

    return node_features, edge_features


class GraphConstructor(nn.Module):
    def __init__(
        self,
        d_in,
        d_edge_in,
        d_node,
        d_edge,
        layer_layout,
        rev_edge_features=False,
        zero_out_bias=False,
        zero_out_weights=False,
        inp_factor=1,
        input_layers=1,
        sin_emb=False,
        sin_emb_dim=128,
        use_pos_embed=False,
        num_probe_features=0,
        inr_model=None,
        stats=None,
        sparsify=False,
        sym_edges=False,
    ):
        super().__init__()
        self.rev_edge_features = rev_edge_features
        self.nodes_per_layer = layer_layout
        self.zero_out_bias = zero_out_bias
        self.zero_out_weights = zero_out_weights
        self.use_pos_embed = use_pos_embed
        self.stats = stats if stats is not None else {}
        self._d_node = d_node
        self._d_edge = d_edge
        self.sparse = sparsify
        self.sym_edges = sym_edges

        # Input/output nodes each get their own positional embedding; hidden
        # nodes share one embedding per layer.
        self.pos_embed_layout = (
            [1] * layer_layout[0] + layer_layout[1:-1] + [1] * layer_layout[-1]
        )
        self.pos_embed = nn.Parameter(torch.randn(len(self.pos_embed_layout), d_node))

        if not self.zero_out_weights:
            proj_weight = []
            if sin_emb:
                proj_weight.append(
                    GaussianEncoding(
                        sigma=inp_factor,
                        input_size=d_edge_in
                        + (2 * d_edge_in if rev_edge_features else 0),
                        encoded_size=sin_emb_dim,
                    )
                )
                # GaussianEncoding outputs 2 * encoded_size features.
                proj_weight.append(nn.Linear(2 * sin_emb_dim, d_edge))
            else:
                proj_weight.append(
                    nn.Linear(
                        d_edge_in + (2 * d_edge_in if rev_edge_features else 0),
                        d_edge,
                    )
                )
            for _ in range(input_layers - 1):
                proj_weight.append(nn.SiLU())
                proj_weight.append(nn.Linear(d_edge, d_edge))
            self.proj_weight = nn.Sequential(*proj_weight)

        if not self.zero_out_bias:
            proj_bias = []
            if sin_emb:
                proj_bias.append(
                    GaussianEncoding(
                        sigma=inp_factor,
                        input_size=d_in,
                        encoded_size=sin_emb_dim,
                    )
                )
                proj_bias.append(nn.Linear(2 * sin_emb_dim, d_node))
            else:
                proj_bias.append(nn.Linear(d_in, d_node))
            for _ in range(input_layers - 1):
                proj_bias.append(nn.SiLU())
                proj_bias.append(nn.Linear(d_node, d_node))
            self.proj_bias = nn.Sequential(*proj_bias)

        self.proj_node_in = nn.Linear(d_node, d_node)
        self.proj_edge_in = nn.Linear(d_edge, d_edge)

        if num_probe_features > 0:
            self.gpf = GraphProbeFeatures(
                d_in=layer_layout[0],
                num_inputs=num_probe_features,
                inr_model=inr_model,
                input_init=None,
                proj_dim=d_node,
            )
        else:
            self.gpf = None

    def forward(self, inputs):
        # Forward the sparsify/sym_edges options stored in __init__; they were
        # previously accepted by the constructor but never used.
        node_features, edge_features = batch_to_graphs(
            *inputs,
            **self.stats,
            sparsify=self.sparse,
            sym_edges=self.sym_edges,
        )
        mask = edge_features.sum(dim=-1, keepdim=True) != 0

        if self.rev_edge_features:
            rev_edge_features = edge_features.transpose(-2, -3)
            edge_features = torch.cat(
                [edge_features, rev_edge_features, edge_features + rev_edge_features],
                dim=-1,
            )
            mask = mask | mask.transpose(-3, -2)

        if self.zero_out_weights:
            edge_features = torch.zeros(
                (*edge_features.shape[:-1], self._d_edge),
                device=edge_features.device,
                dtype=edge_features.dtype,
            )
        else:
            edge_features = self.proj_weight(edge_features)

        if self.zero_out_bias:
            # Only zero out the bias features; probe features (gpf) are still
            # added below.
            node_features = torch.zeros(
                (*node_features.shape[:-1], self._d_node),
                device=node_features.device,
                dtype=node_features.dtype,
            )
        else:
            node_features = self.proj_bias(node_features)

        if self.gpf is not None:
            probe_features = self.gpf(*inputs)
            node_features = node_features + probe_features

        node_features = self.proj_node_in(node_features)
        edge_features = self.proj_edge_in(edge_features)

        if self.use_pos_embed:
            pos_embed = torch.cat(
                [
                    # equivalent to einops.repeat(self.pos_embed[i], "d -> 1 n d", n=n)
                    self.pos_embed[i].unsqueeze(0).expand(1, n, -1)
                    for i, n in enumerate(self.pos_embed_layout)
                ],
                dim=1,
            )
            node_features = node_features + pos_embed
        return node_features, edge_features, mask
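

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the original module).
# It assumes a toy MLP layout [2, 16, 1], with weights batched as
# (bsz, fan_in, fan_out, 1) and biases as (bsz, fan_out, 1) -- the shape
# convention batch_to_graphs expects. All dimensions below are made up, and
# running it requires the `rff` package and the repo's `nn.probe_features`
# module to be importable.
if __name__ == "__main__":
    bsz = 4
    layer_layout = [2, 16, 1]
    weights = [torch.randn(bsz, 2, 16, 1), torch.randn(bsz, 16, 1, 1)]
    biases = [torch.randn(bsz, 16, 1), torch.randn(bsz, 1, 1)]

    constructor = GraphConstructor(
        d_in=1,
        d_edge_in=1,
        d_node=64,
        d_edge=32,
        layer_layout=layer_layout,
        rev_edge_features=True,
        use_pos_embed=True,
    )
    node_features, edge_features, mask = constructor((weights, biases))
    # 2 input + 16 hidden + 1 output = 19 nodes
    print(node_features.shape)  # torch.Size([4, 19, 64])
    print(edge_features.shape)  # torch.Size([4, 19, 19, 32])
    print(mask.shape)           # torch.Size([4, 19, 19, 1])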