submission-template / tasks /utils /probe_features.py
IlayMalinyak
cnnkan
2f54ec8
import hydra
import torch
import torch.nn as nn
from einops.layers.torch import Rearrange
from nn.inr import make_functional, params_to_tensor, wrap_func
class GraphProbeFeatures(nn.Module):
    """Probe a batch of INR parameter sets at a shared set of input points.

    An INR is instantiated from the hydra config ``inr_model`` only to obtain a
    functional version of it (``make_functional``) and the shapes of its
    parameters (``params_to_tensor``); the instantiated parameter values are not
    used. In ``forward``, each batch element's weights and biases are flattened
    into one parameter vector. The functional INR is then evaluated with
    ``torch.vmap`` at the probe points ``self.inputs``, and the per-layer
    results are returned as features.

    The ``Rearrange`` patterns fix the expected argument layout: each weight
    tensor is ``(b, i, o, 1)`` and each bias tensor is ``(b, o, 1)``.
    """

    def __init__(self, d_in, num_inputs, inr_model, input_init=None, proj_dim=None):
        """Build the vmapped functional INR, the probe points and optional heads.

        Args:
            d_in: dimensionality of each probe point (only used when
                ``input_init`` is None).
            num_inputs: number of probe points; also the input size of every
                projection head.
            inr_model: hydra config handed to ``hydra.utils.instantiate``.
            input_init: optional tensor of probe points. When given, it is
                stored as a frozen (``requires_grad=False``) parameter.
                ``forward`` calls ``expand(batch, -1, -1)`` on it, so it must be
                expandable that way, e.g. ``(1, num_inputs, d_in)``.
            proj_dim: when not None, create ``inr.num_layers + 1`` heads, each
                ``Linear(num_inputs, proj_dim)`` followed by
                ``LayerNorm(proj_dim)``.
        """
        super().__init__()
        inr_net = hydra.utils.instantiate(inr_model)
        functional_inr, init_params = make_functional(inr_net)
        # Only the parameter shapes are needed here; the flat values are unused.
        _, param_shapes = params_to_tensor(init_params)
        # torch.vmap maps over dim 0 of every argument (default in_dims=0).
        self.sirens = torch.vmap(wrap_func(functional_inr, param_shapes))

        learn_probes = input_init is None
        if learn_probes:
            # Uniform samples in [-1, 1); these probe points are trainable.
            probe_points = 2 * torch.rand(1, num_inputs, d_in) - 1
        else:
            probe_points = input_init
        self.inputs = nn.Parameter(probe_points, requires_grad=learn_probes)

        # (b, i, o, 1) -> (b, o*i): the output index becomes the outer index.
        # NOTE(review): presumably this matches the layout params_to_tensor
        # expects for wrap_func — confirm.
        self.reshape_weights = Rearrange("b i o 1 -> b (o i)")
        # (b, o, 1) -> (b, o)
        self.reshape_biases = Rearrange("b o 1 -> b o")

        self.proj_dim = proj_dim
        if proj_dim is not None:
            n_heads = inr_net.num_layers + 1
            self.proj = nn.ModuleList(
                nn.Sequential(nn.Linear(num_inputs, proj_dim), nn.LayerNorm(proj_dim))
                for _ in range(n_heads)
            )

    def forward(self, weights, biases):
        """Evaluate every parameter set in the batch at the probe points.

        Args:
            weights: sequence of per-layer weight tensors, each ``(b, i, o, 1)``.
            biases: sequence of per-layer bias tensors, each ``(b, o, 1)``.
                It is paired with ``weights`` via ``zip``, so only the first
                ``min(len(weights), len(biases))`` pairs reach the INR.

        Returns:
            Each layer output must be 3-D (``permute(0, 2, 1)`` requires it).
            The assumed layout is ``(b, n_probes, f_l)``; this is required on
            the projection path by ``Linear(num_inputs, ...)``.

            - If ``proj_dim`` is None, the layer outputs are concatenated on
              their last dim and transposed, giving ``(b, sum f_l, n_probes)``.
            - Otherwise, each layer output is transposed and passed through
              its own head, and the results are concatenated on dim 1, giving
              ``(b, sum f_l, proj_dim)``.
        """
        flat_w = [self.reshape_weights(w) for w in weights]
        flat_b = [self.reshape_biases(b) for b in biases]
        # Interleave layer by layer: w0, b0, w1, b1, ...
        pieces = []
        for w_vec, b_vec in zip(flat_w, flat_b):
            pieces.append(w_vec)
            pieces.append(b_vec)
        params_flat = torch.cat(pieces, dim=-1)

        batch_size = params_flat.shape[0]
        # The same probe points are shared by every batch element.
        probes = self.inputs.expand(batch_size, -1, -1)
        layer_outputs = self.sirens(params_flat, probes)

        if self.proj_dim is None:
            stacked = torch.cat(layer_outputs, dim=-1)
            return stacked.permute(0, 2, 1)

        projected = []
        for idx, head in enumerate(self.proj):
            projected.append(head(layer_outputs[idx].permute(0, 2, 1)))
        return torch.cat(projected, dim=1)