import copy
import math
from typing import List, Optional

import torch
import torch.nn.functional as F
from rff.layers import GaussianEncoding, PositionalEncoding
from torch import nn

from .kan.fasterkan import FasterKAN  # noqa: F401  (not used in this module; kept in case it is imported from here)


class Sine(nn.Module):
    """Element-wise sine activation ``sin(w0 * x)``.

    Args:
        w0: Frequency multiplier applied to the input before the sine.
    """

    def __init__(self, w0=1.0):
        super().__init__()
        self.w0 = w0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.sin(self.w0 * x)


def params_to_tensor(params):
    """Flatten and concatenate a sequence of tensors into one 1-D tensor.

    Args:
        params: Iterable of tensors, e.g. a list of parameters or the
            generator returned by ``module.parameters()``.

    Returns:
        ``(flat, shapes)``: ``flat`` is the in-order concatenation of every
        tensor flattened, and ``shapes`` is the list of original shapes, in
        the form expected by :func:`tensor_to_params`.
    """
    # Materialise once: the input is walked twice below, and a generator
    # would otherwise be exhausted by the first pass (yielding no shapes).
    params = list(params)
    return torch.cat([p.flatten() for p in params]), [p.shape for p in params]


def tensor_to_params(tensor, shapes):
    """Split a flat 1-D tensor back into tensors of the given shapes.

    Inverse of :func:`params_to_tensor`. Each piece is obtained by slicing
    and reshaping ``tensor``, so autograd tracks it back to ``tensor``.

    Args:
        tensor: 1-D tensor holding the concatenated values.
        shapes: Sequence of shapes, consumed in order from the start of
            ``tensor``.

    Returns:
        Tuple of tensors, one per entry in ``shapes``.
    """
    params = []
    start = 0
    for shape in shapes:
        # math.prod always returns an int (1 for an empty / scalar shape).
        # Computing the size through a torch tensor gives a float for an
        # empty shape, which is not a valid slice index.
        size = math.prod(shape)
        params.append(tensor[start : start + size].reshape(shape))
        start += size
    return tuple(params)


def wrap_func(func, shapes):
    """Adapt ``func(params, *args, **kwargs)`` to take one flat tensor.

    Returns ``wrapped(flat, *args, **kwargs)``, which splits ``flat`` with
    :func:`tensor_to_params` using ``shapes`` and calls ``func`` with the
    resulting tuple followed by the remaining arguments.
    """

    def wrapped_func(params, *args, **kwargs):
        params = tensor_to_params(params, shapes)
        return func(params, *args, **kwargs)

    return wrapped_func


class Siren(nn.Module):
    """Linear layer followed by an activation (sine by default), SIREN-style.

    Args:
        dim_in: Number of input features.
        dim_out: Number of output features.
        w0: Frequency for the default :class:`Sine` activation; also divides
            the weight-init bound of non-first layers.
        c: Constant in the non-first init bound ``sqrt(c / dim_in) / w0``.
        is_first: If True, use the init bound ``1 / dim_in`` instead.
        use_bias: Whether to include a learnable bias (initialised to zero).
        activation: Module applied after the linear map; ``Sine(w0)`` if None.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        w0=30.0,
        c=6.0,
        is_first=False,
        use_bias=True,
        activation=None,
    ):
        super().__init__()
        self.w0 = w0
        self.c = c
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.is_first = is_first

        weight = torch.zeros(dim_out, dim_in)
        bias = torch.zeros(dim_out) if use_bias else None
        self.init_(weight, bias, c=c, w0=w0)

        self.weight = nn.Parameter(weight)
        self.bias = nn.Parameter(bias) if use_bias else None
        self.activation = Sine(w0) if activation is None else activation

    def init_(
        self,
        weight: torch.Tensor,
        bias: Optional[torch.Tensor],
        c: float,
        w0: float,
    ):
        """Initialise ``weight`` (and ``bias``, if given) in place.

        Weights are drawn from ``U(-w_std, w_std)`` where ``w_std`` is
        ``1 / dim_in`` for the first layer and ``sqrt(c / dim_in) / w0``
        otherwise. The bias is set to zero.
        """
        dim = self.dim_in

        w_std = (1 / dim) if self.is_first else (math.sqrt(c / dim) / w0)
        weight.uniform_(-w_std, w_std)

        if bias is not None:
            # Biases start at zero rather than being drawn from U(-w_std, w_std).
            bias.zero_()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = F.linear(x, self.weight, self.bias)
        out = self.activation(out)
        return out


class INR(nn.Module):
    """SIREN-style implicit neural representation (coordinate MLP).

    Layers, in order: optional positional encoding, ``Siren(in -> hidden)``,
    ``n_layers - 2`` x ``Siren(hidden -> hidden)``, ``nn.Linear(hidden -> out)``.
    The forward pass adds a constant ``0.5`` to the final output.

    Args:
        in_features: Dimensionality of the input coordinates.
        n_layers: Number of Siren/Linear layers, not counting the encoding.
            Values below 2 still produce two layers (first Siren + Linear).
        hidden_features: Width of the hidden layers.
        out_features: Dimensionality of the output.
        pe_features: If not None, prepend a positional encoding of this size.
        fix_pe: If True, use ``rff`` ``PositionalEncoding(sigma=10)``, assumed
            to output ``in_features * pe_features * 2`` features. Otherwise use
            ``GaussianEncoding(sigma=10)``, assumed to output
            ``pe_features * 2`` features.
    """

    def __init__(
        self,
        in_features: int = 2,
        n_layers: int = 3,
        hidden_features: int = 32,
        out_features: int = 1,
        pe_features: Optional[int] = None,
        fix_pe=True,
    ):
        super().__init__()

        # NOTE(review): the first Siren is built with is_first=False, so it
        # gets the sqrt(c/n)/w0 init bound rather than 1/n. Left unchanged so
        # that existing initialisation / results stay reproducible.
        if pe_features is not None:
            if fix_pe:
                self.layers = [PositionalEncoding(sigma=10, m=pe_features)]
                encoded_dim = in_features * pe_features * 2
            else:
                self.layers = [
                    GaussianEncoding(
                        sigma=10, input_size=in_features, encoded_size=pe_features
                    )
                ]
                encoded_dim = pe_features * 2
            self.layers.append(Siren(dim_in=encoded_dim, dim_out=hidden_features))
        else:
            self.layers = [Siren(dim_in=in_features, dim_out=hidden_features)]
        for _ in range(n_layers - 2):
            self.layers.append(Siren(hidden_features, hidden_features))
        self.layers.append(nn.Linear(hidden_features, out_features))
        self.seq = nn.Sequential(*self.layers)
        # Counts the encoding layer too, when one is present.
        self.num_layers = len(self.layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Constant 0.5 offset on the output; presumably centres predictions for
        # targets in [0, 1] -- TODO confirm against the training data.
        return self.seq(x) + 0.5


class INRPerLayer(INR):
    """:class:`INR` whose forward pass returns every intermediate activation."""

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Return ``[x, out_1, ..., out_L]``: the input, then each layer's output.

        The last entry gets the same ``+ 0.5`` offset as :meth:`INR.forward`.
        """
        nodes = [x]
        for layer in self.seq:
            nodes.append(layer(nodes[-1]))
        nodes[-1] = nodes[-1] + 0.5
        return nodes


def make_functional(mod, disable_autograd_tracking=False):
    """Build a stateless, functional version of ``mod``.

    ``mod`` is deep-copied and the copy is moved to the ``"meta"`` device, so
    it carries no real parameter storage. It is then invoked through
    :func:`torch.func.functional_call` with caller-supplied parameters.
    ``mod`` itself is not modified.

    Args:
        mod: Module to convert.
        disable_autograd_tracking: If True, the returned initial parameter
            values are detached from the autograd graph.

    Returns:
        ``(fmodel, params_values)``. ``params_values`` is a tuple of ``mod``'s
        parameters in ``named_parameters()`` order.
        ``fmodel(new_params_values, *args, **kwargs)`` runs the module with
        ``new_params_values`` (same order) in place of those parameters.
        Buffers are not part of ``params_values``: ``fmodel`` uses the buffer
        tensors that ``mod`` holds when ``make_functional`` is called.
    """
    params_dict = dict(mod.named_parameters())
    params_names = tuple(params_dict.keys())
    params_values = tuple(params_dict.values())
    # The meta-device copy has no real data for its buffers either, so the
    # original module's buffers are supplied explicitly on every call.
    buffers_dict = dict(mod.named_buffers())

    stateless_mod = copy.deepcopy(mod)
    stateless_mod.to("meta")

    def fmodel(new_params_values, *args, **kwargs):
        new_params_dict = dict(zip(params_names, new_params_values))
        return torch.func.functional_call(
            stateless_mod, {**buffers_dict, **new_params_dict}, args, kwargs
        )

    if disable_autograd_tracking:
        params_values = tuple(p.detach() for p in params_values)
    return fmodel, params_values