Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
from mmpl.registry import MODELS | |
from mmengine.model import BaseModule | |
class Sirens(BaseModule): | |
def __init__(self, | |
in_channels, | |
out_channels=3, | |
base_channels=256, | |
num_inner_layers=2, | |
is_residual=True | |
): | |
super(Sirens, self).__init__() | |
self.in_channels = in_channels | |
self.out_channels = out_channels | |
self.base_channels = base_channels | |
self.num_inner_layers = num_inner_layers | |
self.is_residual = is_residual | |
self.first_coord = nn.Linear(in_channels, base_channels) | |
self.inner_coords = nn.ModuleList(nn.Linear(base_channels, base_channels) for _ in range(self.num_inner_layers)) | |
self.last_coord = nn.Linear(base_channels, out_channels) | |
def forward(self, x): | |
x = self.first_coord(x) | |
x = torch.sin(x) | |
for idx in range(self.num_inner_layers): | |
residual = x | |
x = self.inner_coords[idx](x) | |
if self.is_residual: | |
x = x + residual | |
x = torch.sin(x) | |
x = self.last_coord(x) | |
return x | |
class ModulatedSirens(BaseModule): | |
def __init__(self, | |
num_inner_layers, | |
in_dim, | |
modulation_dim, | |
out_dim=3, | |
base_channels=256, | |
is_residual=True | |
): | |
super(ModulatedSirens, self).__init__() | |
self.in_dim = in_dim | |
self.num_inner_layers = num_inner_layers | |
self.is_residual = is_residual | |
self.first_mod = nn.Sequential( | |
nn.Conv2d(modulation_dim, base_channels, 1), | |
nn.ReLU() | |
) | |
self.first_coord = nn.Conv2d(in_dim, base_channels, 1) | |
self.inner_mods = nn.ModuleList() | |
self.inner_coords = nn.ModuleList() | |
for _ in range(self.num_inner_layers): | |
self.inner_mods.append( | |
nn.Sequential( | |
nn.Conv2d(modulation_dim+base_channels+base_channels, base_channels, 1), | |
nn.ReLU() | |
) | |
) | |
self.inner_coords.append( | |
nn.Conv2d(base_channels, base_channels, 1) | |
) | |
self.last_coord = nn.Sequential( | |
# nn.Conv2d(base_channels, base_channels//2, 1), | |
# nn.ReLU(), | |
nn.Conv2d(base_channels, out_dim, 1) | |
) | |
def forward(self, x, ori_modulations=None): | |
modulations = self.first_mod(ori_modulations) | |
x = self.first_coord(x) # B 2 H W -> B C H W | |
x = x + modulations | |
x = torch.sin(x) | |
for i_layer in range(self.num_inner_layers): | |
modulations = self.inner_mods[i_layer]( | |
torch.cat((ori_modulations, modulations, x), dim=1)) | |
# modulations = self.inner_mods[i_layer]( | |
# torch.cat((ori_modulations, x), dim=1)) | |
residual = self.inner_coords[i_layer](x) | |
residual = residual + modulations | |
residual = torch.sin(residual) | |
if self.is_residual: | |
x = x + residual | |
else: | |
x = residual | |
x = self.last_coord(x) | |
return x | |