import copy
import math
from typing import Optional

import torch
import torch.nn.functional as F
from rff.layers import GaussianEncoding, PositionalEncoding
from torch import nn

from .kan.fasterkan import FasterKAN


class Sine(nn.Module):
    """Sine activation with a tunable frequency factor w0, as used in SIREN layers."""

    def __init__(self, w0=1.0):
        super().__init__()
        self.w0 = w0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.sin(self.w0 * x)
def params_to_tensor(params):
    """Flatten a sequence of parameter tensors into one 1-D tensor.

    Returns the flat tensor together with the original shapes so the
    parameters can be reconstructed with tensor_to_params.
    """
    return torch.cat([p.flatten() for p in params]), [p.shape for p in params]


def tensor_to_params(tensor, shapes):
    """Inverse of params_to_tensor: split a flat tensor back into shaped parameters."""
    params = []
    start = 0
    for shape in shapes:
        size = torch.prod(torch.tensor(shape)).item()
        param = tensor[start : start + size].reshape(shape)
        params.append(param)
        start += size
    return tuple(params)


def wrap_func(func, shapes):
    """Wrap a function expecting shaped parameters so it accepts a flat parameter tensor."""

    def wrapped_func(params, *args, **kwargs):
        params = tensor_to_params(params, shapes)
        return func(params, *args, **kwargs)

    return wrapped_func
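
# Illustrative sketch (not part of the original module): round-tripping a
# model's parameters through params_to_tensor / tensor_to_params. The Linear
# model is an arbitrary placeholder; note that the params argument must be a
# sequence (not a generator), since params_to_tensor iterates it twice.
def _flatten_roundtrip_example():
    model = nn.Linear(4, 2)
    params = tuple(model.parameters())
    flat, shapes = params_to_tensor(params)
    restored = tensor_to_params(flat, shapes)
    # Every restored tensor matches the corresponding original parameter.
    assert all(torch.equal(p, q) for p, q in zip(params, restored))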
class Siren(nn.Module):
    """A single SIREN layer: a linear map followed by a sine activation.

    Weights follow the SIREN initialization (Sitzmann et al., 2020): uniform in
    [-1/dim_in, 1/dim_in] for the first layer, and in
    [-sqrt(c/dim_in)/w0, sqrt(c/dim_in)/w0] otherwise.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        w0=30.0,
        c=6.0,
        is_first=False,
        use_bias=True,
        activation=None,
    ):
        super().__init__()
        self.w0 = w0
        self.c = c
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.is_first = is_first

        weight = torch.zeros(dim_out, dim_in)
        bias = torch.zeros(dim_out) if use_bias else None
        self.init_(weight, bias, c=c, w0=w0)

        self.weight = nn.Parameter(weight)
        self.bias = nn.Parameter(bias) if use_bias else None
        self.activation = Sine(w0) if activation is None else activation

    def init_(
        self, weight: torch.Tensor, bias: Optional[torch.Tensor], c: float, w0: float
    ):
        dim = self.dim_in
        w_std = (1 / dim) if self.is_first else (math.sqrt(c / dim) / w0)
        weight.uniform_(-w_std, w_std)
        if bias is not None:
            # bias.uniform_(-w_std, w_std)
            bias.zero_()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = F.linear(x, self.weight, self.bias)
        out = self.activation(out)
        return out
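
# Illustrative sketch (not part of the original module): a single Siren layer
# applied to a batch of 2-D coordinates in [-1, 1]; the shapes are arbitrary.
def _siren_layer_example():
    layer = Siren(dim_in=2, dim_out=32, is_first=True)
    coords = torch.rand(1024, 2) * 2 - 1
    return layer(coords)  # shape (1024, 32), bounded in [-1, 1] by the sine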
class INR(nn.Module):
    """A SIREN-style implicit neural representation mapping coordinates to values.

    If pe_features is given, the input coordinates are first lifted with a
    positional encoding (deterministic PositionalEncoding when fix_pe is True,
    random GaussianEncoding otherwise) before the SIREN layers.
    """

    def __init__(
        self,
        in_features: int = 2,
        n_layers: int = 3,
        hidden_features: int = 32,
        out_features: int = 1,
        pe_features: Optional[int] = None,
        fix_pe=True,
    ):
        super().__init__()
        if pe_features is not None:
            if fix_pe:
                self.layers = [PositionalEncoding(sigma=10, m=pe_features)]
                encoded_dim = in_features * pe_features * 2
            else:
                self.layers = [
                    GaussianEncoding(
                        sigma=10, input_size=in_features, encoded_size=pe_features
                    )
                ]
                encoded_dim = pe_features * 2
            self.layers.append(Siren(dim_in=encoded_dim, dim_out=hidden_features))
        else:
            self.layers = [Siren(dim_in=in_features, dim_out=hidden_features)]
        for _ in range(n_layers - 2):
            self.layers.append(Siren(hidden_features, hidden_features))
        self.layers.append(nn.Linear(hidden_features, out_features))
        self.seq = nn.Sequential(*self.layers)
        self.num_layers = len(self.layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Shift outputs so they are centered around 0.5 (e.g. intensities in [0, 1]).
        return self.seq(x) + 0.5
class INRPerLayer(INR):
    """Variant of INR whose forward pass returns the activations of every layer."""

    def forward(self, x: torch.Tensor) -> list:
        nodes = [x]
        for layer in self.seq:
            nodes.append(layer(nodes[-1]))
        nodes[-1] = nodes[-1] + 0.5
        return nodes
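
# Illustrative sketch (not part of the original module): rendering an INR on a
# dense coordinate grid; the grid size and coordinate range are arbitrary.
def _inr_render_example(side: int = 28):
    inr = INR(in_features=2, n_layers=3, hidden_features=32, out_features=1)
    xs = torch.linspace(-1, 1, side)
    grid = torch.stack(torch.meshgrid(xs, xs, indexing="ij"), dim=-1)  # (side, side, 2)
    return inr(grid.reshape(-1, 2)).reshape(side, side)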
def make_functional(mod, disable_autograd_tracking=False):
    """Split a module into a pure function of its parameters plus the parameter values.

    The returned fmodel(params, *args, **kwargs) runs a stateless copy of mod
    (moved to the meta device) with the supplied parameter tensors.
    """
    params_dict = dict(mod.named_parameters())
    params_names = params_dict.keys()
    params_values = tuple(params_dict.values())

    stateless_mod = copy.deepcopy(mod)
    stateless_mod.to("meta")

    def fmodel(new_params_values, *args, **kwargs):
        new_params_dict = {
            name: value for name, value in zip(params_names, new_params_values)
        }
        return torch.func.functional_call(stateless_mod, new_params_dict, args, kwargs)

    if disable_autograd_tracking:
        params_values = torch.utils._pytree.tree_map(
            torch.Tensor.detach, params_values
        )
    return fmodel, params_values
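
# Illustrative sketch (not part of the original module): combining
# make_functional with torch.func.vmap to evaluate the same architecture under
# a batch of parameter sets; the batch of two identical copies is a placeholder.
def _functional_vmap_example():
    fmodel, params = make_functional(INR(), disable_autograd_tracking=True)
    coords = torch.rand(64, 2) * 2 - 1
    batched_params = tuple(torch.stack([p, p]) for p in params)
    out = torch.func.vmap(fmodel, in_dims=(0, None))(batched_params, coords)
    return out  # shape (2, 64, 1)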