import torch
import torch.nn as nn
from rff.layers import GaussianEncoding

from nn.probe_features import GraphProbeFeatures


def sparsify_graph(edges, fraction=0.1):
    """Return a boolean mask keeping the top `fraction` of edges by absolute value."""
    abs_edges = torch.abs(edges)
    flat_abs_tensor = abs_edges.flatten()
    num_elements = flat_abs_tensor.numel()
    top_k = int(num_elements * fraction)
    _, topk_indices = torch.topk(flat_abs_tensor, top_k)
    mask = torch.zeros_like(flat_abs_tensor, dtype=torch.bool)
    mask[topk_indices] = True
    mask = mask.view(edges.shape)
    return mask
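
# Usage sketch (illustrative, not part of the original module): keep the 10%
# largest-magnitude entries of a (batch, fan_in, fan_out, features) weight
# tensor and zero out the rest. Note the mask is computed over the flattened
# tensor, so the threshold is global across the whole batch.
#
#   w = torch.randn(2, 3, 4, 1)
#   mask = sparsify_graph(w, fraction=0.1)  # bool tensor, same shape as w
#   w = torch.where(mask, w, torch.zeros_like(w))
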
def batch_to_graphs(
    weights,
    biases,
    weights_mean=None,
    weights_std=None,
    biases_mean=None,
    biases_std=None,
    sparsify=False,
    sym_edges=False,
):
    """Convert batched MLP weights/biases into dense node and edge feature tensors."""
    device = weights[0].device
    bsz = weights[0].shape[0]
    num_nodes = weights[0].shape[1] + sum(w.shape[2] for w in weights)

    node_features = torch.zeros(bsz, num_nodes, biases[0].shape[-1], device=device)
    edge_features = torch.zeros(
        bsz, num_nodes, num_nodes, weights[0].shape[-1], device=device
    )

    row_offset = 0
    col_offset = weights[0].shape[1]  # no edge to input nodes
    for i, w in enumerate(weights):
        _, num_in, num_out, _ = w.shape
        w_mean = weights_mean[i] if weights_mean is not None else 0
        w_std = weights_std[i] if weights_std is not None else 1
        w = (w - w_mean) / w_std
        if sparsify:
            w[~sparsify_graph(w)] = 0
        edge_features[
            :, row_offset : row_offset + num_in, col_offset : col_offset + num_out
        ] = w
        if sym_edges:
            # Mirror each weight block so the edge tensor is symmetric.
            edge_features[
                :, col_offset : col_offset + num_out, row_offset : row_offset + num_in
            ] = torch.swapaxes(w, 1, 2)
        row_offset += num_in
        col_offset += num_out

    row_offset = weights[0].shape[1]  # no bias in input nodes
    for i, b in enumerate(biases):
        _, num_out, _ = b.shape
        b_mean = biases_mean[i] if biases_mean is not None else 0
        b_std = biases_std[i] if biases_std is not None else 1
        node_features[:, row_offset : row_offset + num_out] = (b - b_mean) / b_std
        row_offset += num_out

    return node_features, edge_features
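
# Shape sketch (illustrative): for a batch of B two-layer MLPs with layout
# 2 -> 16 -> 1, weights = [(B, 2, 16, 1), (B, 16, 1, 1)] and
# biases = [(B, 16, 1), (B, 1, 1)] produce
#   node_features: (B, 19, 1)      # 2 + 16 + 1 nodes; input-node rows stay zero
#   edge_features: (B, 19, 19, 1)  # nonzero blocks at [0:2, 2:18] and [2:18, 18:19]
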
class GraphConstructor(nn.Module):
    def __init__(
        self,
        d_in,
        d_edge_in,
        d_node,
        d_edge,
        layer_layout,
        rev_edge_features=False,
        zero_out_bias=False,
        zero_out_weights=False,
        inp_factor=1,
        input_layers=1,
        sin_emb=False,
        sin_emb_dim=128,
        use_pos_embed=False,
        num_probe_features=0,
        inr_model=None,
        stats=None,
        sparsify=False,
        sym_edges=False,
    ):
        super().__init__()
        self.rev_edge_features = rev_edge_features
        self.nodes_per_layer = layer_layout
        self.zero_out_bias = zero_out_bias
        self.zero_out_weights = zero_out_weights
        self.use_pos_embed = use_pos_embed
        self.stats = stats if stats is not None else {}
        self._d_node = d_node
        self._d_edge = d_edge
        self.sparse = sparsify
        self.sym_edges = sym_edges

        # Each input and output node gets its own positional embedding;
        # all nodes within a hidden layer share one embedding per layer.
        self.pos_embed_layout = (
            [1] * layer_layout[0] + layer_layout[1:-1] + [1] * layer_layout[-1]
        )
        self.pos_embed = nn.Parameter(torch.randn(len(self.pos_embed_layout), d_node))
        if not self.zero_out_weights:
            proj_weight = []
            if sin_emb:
                # GaussianEncoding emits 2 * encoded_size features
                # (cos and sin), hence the 2 * sin_emb_dim input below.
                proj_weight.append(
                    GaussianEncoding(
                        sigma=inp_factor,
                        input_size=d_edge_in
                        + (2 * d_edge_in if rev_edge_features else 0),
                        encoded_size=sin_emb_dim,
                    )
                )
                proj_weight.append(nn.Linear(2 * sin_emb_dim, d_edge))
            else:
                proj_weight.append(
                    nn.Linear(
                        d_edge_in + (2 * d_edge_in if rev_edge_features else 0), d_edge
                    )
                )
            for _ in range(input_layers - 1):
                proj_weight.append(nn.SiLU())
                proj_weight.append(nn.Linear(d_edge, d_edge))
            self.proj_weight = nn.Sequential(*proj_weight)
        if not self.zero_out_bias:
            proj_bias = []
            if sin_emb:
                proj_bias.append(
                    GaussianEncoding(
                        sigma=inp_factor,
                        input_size=d_in,
                        encoded_size=sin_emb_dim,
                    )
                )
                proj_bias.append(nn.Linear(2 * sin_emb_dim, d_node))
            else:
                proj_bias.append(nn.Linear(d_in, d_node))
            for _ in range(input_layers - 1):
                proj_bias.append(nn.SiLU())
                proj_bias.append(nn.Linear(d_node, d_node))
            self.proj_bias = nn.Sequential(*proj_bias)
        self.proj_node_in = nn.Linear(d_node, d_node)
        self.proj_edge_in = nn.Linear(d_edge, d_edge)

        if num_probe_features > 0:
            self.gpf = GraphProbeFeatures(
                d_in=layer_layout[0],
                num_inputs=num_probe_features,
                inr_model=inr_model,
                input_init=None,
                proj_dim=d_node,
            )
        else:
            self.gpf = None

    def forward(self, inputs):
        node_features, edge_features = batch_to_graphs(
            *inputs,
            **self.stats,
            sparsify=self.sparse,
            sym_edges=self.sym_edges,
        )
        mask = edge_features.sum(dim=-1, keepdim=True) != 0

        if self.rev_edge_features:
            rev_edge_features = edge_features.transpose(-2, -3)
            edge_features = torch.cat(
                [edge_features, rev_edge_features, edge_features + rev_edge_features],
                dim=-1,
            )
            mask = mask | mask.transpose(-3, -2)

        if self.zero_out_weights:
            edge_features = torch.zeros(
                (*edge_features.shape[:-1], self._d_edge),
                device=edge_features.device,
                dtype=edge_features.dtype,
            )
        else:
            edge_features = self.proj_weight(edge_features)

        if self.zero_out_bias:
            # only zero out bias, not gpf
            node_features = torch.zeros(
                (*node_features.shape[:-1], self._d_node),
                device=node_features.device,
                dtype=node_features.dtype,
            )
        else:
            node_features = self.proj_bias(node_features)

        if self.gpf is not None:
            probe_features = self.gpf(*inputs)
            node_features = node_features + probe_features

        node_features = self.proj_node_in(node_features)
        edge_features = self.proj_edge_in(edge_features)

        if self.use_pos_embed:
            pos_embed = torch.cat(
                [
                    self.pos_embed[i].unsqueeze(0).expand(1, n, -1)
                    for i, n in enumerate(self.pos_embed_layout)
                ],
                dim=1,
            )
            node_features = node_features + pos_embed
        return node_features, edge_features, mask
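

if __name__ == "__main__":
    # Smoke-test sketch (illustrative, not part of the original module): build
    # graph features for a toy batch of 2-16-1 MLPs. Weight tensors are
    # (batch, fan_in, fan_out, 1) and bias tensors are (batch, fan_out, 1);
    # the hyperparameters below are arbitrary choices.
    bsz = 4
    weights = [torch.randn(bsz, 2, 16, 1), torch.randn(bsz, 16, 1, 1)]
    biases = [torch.randn(bsz, 16, 1), torch.randn(bsz, 1, 1)]
    constructor = GraphConstructor(
        d_in=1,
        d_edge_in=1,
        d_node=32,
        d_edge=32,
        layer_layout=[2, 16, 1],
        use_pos_embed=True,
    )
    node_features, edge_features, mask = constructor((weights, biases))
    print(node_features.shape)  # torch.Size([4, 19, 32]); 2 + 16 + 1 = 19 nodes
    print(edge_features.shape)  # torch.Size([4, 19, 19, 32])
    print(mask.shape)           # torch.Size([4, 19, 19, 1])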