# submission-template/tasks/utils/graph_constructor.py
import torch
import torch.nn as nn
from rff.layers import GaussianEncoding
# GraphProbeFeatures is required only when num_probe_features > 0; restore this
# import (with the path matching your package layout) to enable probe features.
# from nn.probe_features import GraphProbeFeatures
def sparsify_graph(edges, fraction=0.1):
    """Return a boolean mask keeping the top `fraction` of entries of `edges`
    by absolute value, computed over the flattened tensor."""
    flat_abs = torch.abs(edges).flatten()
    top_k = int(flat_abs.numel() * fraction)
    _, topk_indices = torch.topk(flat_abs, top_k)
    mask = torch.zeros_like(flat_abs, dtype=torch.bool)
    mask[topk_indices] = True
    return mask.view(edges.shape)
def batch_to_graphs(
weights,
biases,
weights_mean=None,
weights_std=None,
biases_mean=None,
biases_std=None,
sparsify=False,
sym_edges=False
):
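    """Pack per-layer weight/bias tensors into dense graph tensors.

    Nodes are ordered layer by layer, input nodes first. `weights[i]` has
    shape (bsz, num_in, num_out, d_edge_in) and fills the edge-matrix block
    connecting layer i to layer i + 1; `biases[i]` has shape
    (bsz, num_out, d_in) and fills the node features of layer i + 1 (input
    nodes keep zero features). Optional per-layer statistics standardize the
    values; `sparsify` zeroes all but the largest-magnitude weights, and
    `sym_edges` mirrors each weight block to make the edge matrix symmetric.
    """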
device = weights[0].device
bsz = weights[0].shape[0]
num_nodes = weights[0].shape[1] + sum(w.shape[2] for w in weights)
node_features = torch.zeros(bsz, num_nodes, biases[0].shape[-1], device=device)
edge_features = torch.zeros(
bsz, num_nodes, num_nodes, weights[0].shape[-1], device=device
)
row_offset = 0
col_offset = weights[0].shape[1] # no edge to input nodes
for i, w in enumerate(weights):
_, num_in, num_out, _ = w.shape
w_mean = weights_mean[i] if weights_mean is not None else 0
w_std = weights_std[i] if weights_std is not None else 1
w = (w - w_mean) / w_std
if sparsify:
w[~sparsify_graph(w)] = 0
edge_features[
:, row_offset : row_offset + num_in, col_offset : col_offset + num_out
] = w
        if sym_edges:
            edge_features[
                :, col_offset : col_offset + num_out, row_offset : row_offset + num_in
            ] = torch.swapaxes(w, 1, 2)
row_offset += num_in
col_offset += num_out
row_offset = weights[0].shape[1] # no bias in input nodes
for i, b in enumerate(biases):
_, num_out, _ = b.shape
b_mean = biases_mean[i] if biases_mean is not None else 0
b_std = biases_std[i] if biases_std is not None else 1
node_features[:, row_offset : row_offset + num_out] = (b - b_mean) / b_std
row_offset += num_out
return node_features, edge_features
class GraphConstructor(nn.Module):
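    """Constructs a graph representation of a batch of MLPs.

    Biases become node features and weights become edge features (via
    `batch_to_graphs`), then both are projected to `d_node` / `d_edge`
    dimensions, optionally with random Fourier (sinusoidal) input encodings,
    reversed-edge features, learned positional embeddings, and graph probe
    features.
    """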
def __init__(
self,
d_in,
d_edge_in,
d_node,
d_edge,
layer_layout,
rev_edge_features=False,
zero_out_bias=False,
zero_out_weights=False,
inp_factor=1,
input_layers=1,
sin_emb=False,
sin_emb_dim=128,
use_pos_embed=False,
num_probe_features=0,
inr_model=None,
stats=None,
sparsify=False,
sym_edges=False,
):
super().__init__()
self.rev_edge_features = rev_edge_features
self.nodes_per_layer = layer_layout
self.zero_out_bias = zero_out_bias
self.zero_out_weights = zero_out_weights
self.use_pos_embed = use_pos_embed
self.stats = stats if stats is not None else {}
self._d_node = d_node
self._d_edge = d_edge
self.sparse = sparsify
self.sym_edges = sym_edges
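        # One positional embedding per input/output node; hidden-layer nodes
        # share a single embedding per layer.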
self.pos_embed_layout = (
[1] * layer_layout[0] + layer_layout[1:-1] + [1] * layer_layout[-1]
)
self.pos_embed = nn.Parameter(torch.randn(len(self.pos_embed_layout), d_node))
if not self.zero_out_weights:
proj_weight = []
if sin_emb:
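                # GaussianEncoding concatenates cos/sin features, so its
                # output size is 2 * sin_emb_dim (hence the Linear below).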
proj_weight.append(
GaussianEncoding(
sigma=inp_factor,
input_size=d_edge_in
+ (2 * d_edge_in if rev_edge_features else 0),
encoded_size=sin_emb_dim,
)
)
proj_weight.append(nn.Linear(2 * sin_emb_dim, d_edge))
else:
proj_weight.append(
nn.Linear(
d_edge_in + (2 * d_edge_in if rev_edge_features else 0), d_edge
)
)
            for _ in range(input_layers - 1):
proj_weight.append(nn.SiLU())
proj_weight.append(nn.Linear(d_edge, d_edge))
self.proj_weight = nn.Sequential(*proj_weight)
if not self.zero_out_bias:
proj_bias = []
if sin_emb:
proj_bias.append(
GaussianEncoding(
sigma=inp_factor,
input_size=d_in,
encoded_size=sin_emb_dim,
)
)
proj_bias.append(nn.Linear(2 * sin_emb_dim, d_node))
else:
proj_bias.append(nn.Linear(d_in, d_node))
            for _ in range(input_layers - 1):
proj_bias.append(nn.SiLU())
proj_bias.append(nn.Linear(d_node, d_node))
self.proj_bias = nn.Sequential(*proj_bias)
self.proj_node_in = nn.Linear(d_node, d_node)
self.proj_edge_in = nn.Linear(d_edge, d_edge)
if num_probe_features > 0:
self.gpf = GraphProbeFeatures(
d_in=layer_layout[0],
num_inputs=num_probe_features,
inr_model=inr_model,
input_init=None,
proj_dim=d_node,
)
else:
self.gpf = None
    def forward(self, inputs):
        # `inputs` is a (weights, biases) pair of per-layer tensor lists; pass
        # the stored sparsify/sym_edges flags so they actually take effect.
        node_features, edge_features = batch_to_graphs(
            *inputs, **self.stats, sparsify=self.sparse, sym_edges=self.sym_edges
        )
mask = edge_features.sum(dim=-1, keepdim=True) != 0
if self.rev_edge_features:
rev_edge_features = edge_features.transpose(-2, -3)
edge_features = torch.cat(
[edge_features, rev_edge_features, edge_features + rev_edge_features],
dim=-1,
)
mask = mask | mask.transpose(-3, -2)
if self.zero_out_weights:
edge_features = torch.zeros(
(*edge_features.shape[:-1], self._d_edge),
device=edge_features.device,
dtype=edge_features.dtype,
)
else:
edge_features = self.proj_weight(edge_features)
if self.zero_out_bias:
# only zero out bias, not gpf
node_features = torch.zeros(
(*node_features.shape[:-1], self._d_node),
device=node_features.device,
dtype=node_features.dtype,
)
else:
node_features = self.proj_bias(node_features)
if self.gpf is not None:
probe_features = self.gpf(*inputs)
node_features = node_features + probe_features
node_features = self.proj_node_in(node_features)
edge_features = self.proj_edge_in(edge_features)
if self.use_pos_embed:
pos_embed = torch.cat(
[
# repeat(self.pos_embed[i], "d -> 1 n d", n=n)
self.pos_embed[i].unsqueeze(0).expand(1, n, -1)
for i, n in enumerate(self.pos_embed_layout)
],
dim=1,
)
node_features = node_features + pos_embed
return node_features, edge_features, mask
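

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only; the shapes and sizes below are
    # assumptions, not taken from the original pipeline): encode a batch of 8
    # two-layer MLPs with layout [2, 16, 1] and scalar weight/bias features
    # (d_edge_in = d_in = 1).
    bsz, layout = 8, [2, 16, 1]
    weights = [
        torch.randn(bsz, layout[0], layout[1], 1),
        torch.randn(bsz, layout[1], layout[2], 1),
    ]
    biases = [torch.randn(bsz, layout[1], 1), torch.randn(bsz, layout[2], 1)]
    constructor = GraphConstructor(
        d_in=1, d_edge_in=1, d_node=64, d_edge=32, layer_layout=layout
    )
    nodes, edges, mask = constructor((weights, biases))
    # 2 + 16 + 1 = 19 nodes per graph:
    print(nodes.shape, edges.shape, mask.shape)
    # torch.Size([8, 19, 64]) torch.Size([8, 19, 19, 32]) torch.Size([8, 19, 19, 1])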