WindVChen's picture
Update
033bd8b
import torch
from torch import nn as nn
import numpy as np
import math
import torch.nn.functional as F
class SimpleInputFusion(nn.Module):
    """Fuse an image with an extra input through a small 1x1-convolution stack.

    The two inputs are concatenated along the channel axis, projected to
    ``ch`` channels, passed through LeakyReLU and a normalisation layer, and
    projected back to ``rgb_ch`` channels.

    Args:
        add_ch: Channel count of the additional input.
        rgb_ch: Channel count of the image; also the output channel count.
        ch: Width of the hidden 1x1 convolution.
        norm_layer: Callable that builds the normalisation layer from a
            channel count.
    """

    def __init__(self, add_ch=1, rgb_ch=3, ch=8, norm_layer=nn.BatchNorm2d):
        super().__init__()
        # Layers are built in execution order, so parameter initialisation
        # draws from the RNG in a stable sequence.
        expand = nn.Conv2d(add_ch + rgb_ch, ch, kernel_size=1)
        activation = nn.LeakyReLU(negative_slope=0.2)
        normalise = norm_layer(ch)
        project = nn.Conv2d(ch, rgb_ch, kernel_size=1)
        self.fusion_conv = nn.Sequential(expand, activation, normalise, project)

    def forward(self, image, additional_input):
        """Concatenate both inputs on dim 1 and run the fusion stack."""
        stacked = torch.cat([image, additional_input], dim=1)
        return self.fusion_conv(stacked)
class MaskedChannelAttention(nn.Module):
    """Channel attention driven by mask-aware global descriptors.

    Three per-channel descriptors are pooled from the feature map (max over
    the masked region, max over the complementary region, and the global
    average), concatenated, and mapped by a small MLP to one sigmoid weight
    per channel. The input features are rescaled by those weights.

    Args:
        in_channels: Number of channels of the feature map ``x``.
        *args, **kwargs: Accepted for signature compatibility and ignored.
    """

    def __init__(self, in_channels, *args, **kwargs):
        super().__init__()
        self.global_max_pool = MaskedGlobalMaxPool2d()
        self.global_avg_pool = FastGlobalAvgPool2d()
        intermediate_channels_count = max(in_channels // 16, 8)
        self.attention_transform = nn.Sequential(
            # 3 * in_channels: two masked max descriptors plus the average.
            nn.Linear(3 * in_channels, intermediate_channels_count),
            nn.ReLU(inplace=True),
            nn.Linear(intermediate_channels_count, in_channels),
            nn.Sigmoid(),
        )

    def forward(self, x, mask):
        """Return ``x`` rescaled per channel.

        Args:
            x: Feature map; dims 0/1 are treated as batch/channel and the
                remaining dims as spatial (assumed ``(N, C, H, W)``).
            mask: Mask broadcastable against ``x``; it is bilinearly resized
                when its spatial size differs from that of ``x``.
        """
        # Compare spatial dims with spatial dims. The previous check used
        # x.shape[:2] (batch, channel), which skipped the resize whenever the
        # mask's H, W happened to equal x's N, C.
        if mask.shape[2:] != x.shape[2:]:
            mask = nn.functional.interpolate(
                mask, size=x.size()[-2:],
                mode='bilinear', align_corners=True
            )
        pooled_x = torch.cat([
            self.global_max_pool(x, mask),
            self.global_avg_pool(x)
        ], dim=1)
        # Append two singleton dims so the weights broadcast over space.
        channel_attention_weights = self.attention_transform(pooled_x)[..., None, None]
        return channel_attention_weights * x
class MaskedGlobalMaxPool2d(nn.Module):
    """Global max pooling computed separately inside and outside a mask."""

    def __init__(self):
        super().__init__()
        self.global_max_pool = FastGlobalMaxPool2d()

    def forward(self, x, mask):
        """Return the two pooled descriptors concatenated on dim 1.

        The first half pools ``x * mask``; the second half pools
        ``x * (1 - mask)``.
        """
        inside = self.global_max_pool(x * mask)
        outside = self.global_max_pool(x * (1.0 - mask))
        return torch.cat([inside, outside], dim=1)
class FastGlobalAvgPool2d(nn.Module):
    """Average every trailing dimension, keeping dims 0 and 1."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return the mean over all dims after the first two.

        Args:
            x: Tensor with at least two dims (assumed ``(N, C, H, W)``).

        Returns:
            Tensor of shape ``(x.size(0), x.size(1))``.
        """
        in_size = x.size()
        # reshape (not view) so inputs whose strides cannot be merged, such
        # as a transposed tensor, are pooled instead of raising.
        return x.reshape(in_size[0], in_size[1], -1).mean(dim=2)
class FastGlobalMaxPool2d(nn.Module):
    """Take the maximum over every trailing dimension, keeping dims 0 and 1."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return the maximum over all dims after the first two.

        Args:
            x: Tensor with at least two dims (assumed ``(N, C, H, W)``).

        Returns:
            Tensor of shape ``(x.size(0), x.size(1))`` holding the max values
            (the indices are discarded).
        """
        in_size = x.size()
        # reshape (not view) so inputs whose strides cannot be merged, such
        # as a transposed tensor, are pooled instead of raising.
        return x.reshape(in_size[0], in_size[1], -1).max(dim=2)[0]
class ScaleLayer(nn.Module):
    """Multiply the input by a single learnable, non-negative scalar.

    The parameter is stored divided by ``lr_mult`` and multiplied back in
    ``forward``; the absolute value of that product is the applied scale.

    Args:
        init_value: Scale applied at initialisation (before ``abs``).
        lr_mult: Factor applied to the stored parameter in ``forward``.
    """

    def __init__(self, init_value=1.0, lr_mult=1):
        super().__init__()
        self.lr_mult = lr_mult
        stored_value = init_value / lr_mult
        self.scale = nn.Parameter(
            torch.full((1,), stored_value, dtype=torch.float32)
        )

    def forward(self, x):
        """Return ``x`` multiplied by ``|scale * lr_mult|``."""
        effective = (self.scale * self.lr_mult).abs()
        return x * effective
class FeaturesConnector(nn.Module):
    """Merge a feature map ``x`` with side ``features`` using a named mode.

    Modes:
        ``'cat'``  - channel concatenation.
        ``'catc'`` - channel concatenation followed by a 1x1 convolution.
        ``'sum'``  - 1x1 convolution of ``features`` added to ``x``.
        anything else, or a falsy ``feature_channels`` - ``x`` is passed
        through unchanged.

    Attributes:
        mode: Effective mode (``''`` when ``feature_channels`` is falsy).
        output_channels: ``in_channels + feature_channels`` for ``'cat'``,
            otherwise ``out_channels``.
    """

    def __init__(self, mode, in_channels, feature_channels, out_channels):
        super().__init__()
        self.mode = mode if feature_channels else ''

        # Only the two convolving modes own a reduce_conv.
        if self.mode == 'catc':
            conv_inputs = in_channels + feature_channels
        elif self.mode == 'sum':
            conv_inputs = feature_channels
        else:
            conv_inputs = None
        if conv_inputs is not None:
            self.reduce_conv = nn.Conv2d(conv_inputs, out_channels, kernel_size=1)

        if self.mode == 'cat':
            self.output_channels = in_channels + feature_channels
        else:
            self.output_channels = out_channels

    def forward(self, x, features):
        """Combine ``x`` and ``features`` according to ``self.mode``."""
        mode = self.mode
        if mode == 'sum':
            return self.reduce_conv(features) + x
        if mode in ('cat', 'catc'):
            merged = torch.cat((x, features), 1)
            return self.reduce_conv(merged) if mode == 'catc' else merged
        return x

    def extra_repr(self):
        """Show the effective mode in the module's repr."""
        return self.mode
class PosEncodingNeRF(nn.Module):
    """Sinusoidal positional encoding of 1-, 2- or 3-dimensional coordinates.

    Each coordinate ``c`` is expanded to ``sin(2**i * pi * c)`` and
    ``cos(2**i * pi * c)`` for ``i`` in ``range(num_frequencies)``; the raw
    coordinates are kept at the front of the encoding.

    Args:
        in_features: Coordinate dimensionality; must be 1, 2 or 3.
        sidelength: Required when ``in_features == 2``. An int or a pair;
            the smaller side drives the Nyquist-based frequency count.
        fn_samples: Required when ``in_features == 1``; number of samples
            used for the Nyquist-based frequency count.
        use_nyquist: When False, 1-D and 2-D inputs use 4 frequencies.

    Attributes:
        num_frequencies: Number of frequency octaves (10 for 3-D input).
        out_dim: Size of the last dimension returned by ``forward``.

    Raises:
        ValueError: If ``in_features`` is not 1, 2 or 3.
    """

    def __init__(self, in_features, sidelength=None, fn_samples=None, use_nyquist=True):
        super().__init__()
        self.in_features = in_features
        if self.in_features == 3:
            self.num_frequencies = 10
        elif self.in_features == 2:
            assert sidelength is not None
            if isinstance(sidelength, int):
                sidelength = (sidelength, sidelength)
            self.num_frequencies = 4
            if use_nyquist:
                self.num_frequencies = self.get_num_frequencies_nyquist(min(sidelength[0], sidelength[1]))
        elif self.in_features == 1:
            assert fn_samples is not None
            self.num_frequencies = 4
            if use_nyquist:
                self.num_frequencies = self.get_num_frequencies_nyquist(fn_samples)
        else:
            # Previously this fell through and failed later with an opaque
            # AttributeError on the unset num_frequencies.
            raise ValueError(
                f"PosEncodingNeRF supports in_features of 1, 2 or 3, got {in_features!r}")
        self.out_dim = in_features + 2 * in_features * self.num_frequencies

    def get_num_frequencies_nyquist(self, samples):
        """Return ``floor(log2(nyquist_rate))`` for the given sample count."""
        nyquist_rate = 1 / (2 * (2 * 1 / samples))
        return int(math.floor(math.log(nyquist_rate, 2)))

    def forward(self, coords):
        """Encode ``coords``.

        Args:
            coords: Tensor viewable as ``(coords.shape[0], -1, in_features)``.

        Returns:
            Tensor of shape ``(coords.shape[0], -1, out_dim)`` laid out as
            ``[coords, sin/cos pairs per coordinate for i = 0, ...]``.
        """
        coords = coords.view(coords.shape[0], -1, self.in_features)
        # Collect the pieces and concatenate once; re-concatenating the
        # growing tensor on every step copied it quadratically.
        pieces = [coords]
        for i in range(self.num_frequencies):
            angles = (2 ** i) * np.pi * coords
            # Interleave as sin(c0), cos(c0), sin(c1), cos(c1), ...
            paired = torch.stack((torch.sin(angles), torch.cos(angles)), dim=-1)
            pieces.append(paired.flatten(-2))
        coords_pos_enc = torch.cat(pieces, dim=-1)
        return coords_pos_enc.reshape(coords.shape[0], -1, self.out_dim)
class RandomFourier(nn.Module):
    """Random Fourier feature embedding of 2-D coordinates.

    A fixed random projection ``embed`` of shape ``(2, embedding_length)`` is
    drawn from ``N(0, 1) * std_scale``. Coordinates are projected, passed
    through sin and cos, and appended to the raw coordinates.

    Args:
        std_scale: Multiplier applied to the standard-normal projection.
        embedding_length: Number of random frequencies.
        device: Device on which the projection is created.

    Attributes:
        out_dim: Size of the last dimension returned by ``forward``
            (``2 * embedding_length + 2``).
    """

    def __init__(self, std_scale, embedding_length, device):
        super().__init__()
        embed = torch.normal(0, 1, (2, embedding_length)) * std_scale
        # Non-persistent buffer: follows Module.to()/cuda() like a parameter
        # but, as before, is not stored in the state dict.
        self.register_buffer('embed', embed.to(device), persistent=False)
        self.out_dim = embedding_length * 2 + 2

    def forward(self, coords):
        """Embed ``coords``.

        Args:
            coords: Tensor whose last dim is 2 (assumed ``(B, N, 2)``).

        Returns:
            Tensor ``[coords, sin(proj), cos(proj)]`` with last dim ``out_dim``.
        """
        projected = torch.matmul(2 * np.pi * coords, self.embed)
        coords_pos_enc = torch.cat([torch.sin(projected), torch.cos(projected)], dim=-1)
        # The sin/cos features carry out_dim - 2 channels; the raw coords
        # supply the other two. Reshaping to out_dim made the concat fail.
        fourier_dim = self.out_dim - 2
        return torch.cat(
            [coords, coords_pos_enc.reshape(coords.shape[0], -1, fourier_dim)], dim=-1)
class CIPS_embed(nn.Module):
    """Coordinate embedding combining a learned sine layer and a constant grid.

    The output concatenates, along the last dim, the raw coordinates, the
    ``Predict_embed`` features, and the ``ConstantInput`` features.

    Args:
        size: Side length handed to ``ConstantInput``.
        embedding_length: Channel count of each of the two embeddings.

    Attributes:
        out_dim: ``2 * embedding_length + 2``.
    """

    def __init__(self, size, embedding_length):
        super().__init__()
        # Construction order is kept so random initialisation is unchanged.
        self.fourier_embed = ConstantInput(size, embedding_length)
        self.predict_embed = Predict_embed(embedding_length)
        self.out_dim = 2 + 2 * embedding_length

    def forward(self, coord, res=None):
        """Return ``[coord, learned features, constant features]``.

        ``res`` is forwarded to the constant embedding as its resolution.
        """
        learned = self.predict_embed(coord)
        constant = self.fourier_embed(learned, coord, res)
        return torch.cat((coord, learned, constant), dim=-1)
class Predict_embed(nn.Module):
    """Learned Fourier features: a linear map of 2-D input followed by sine.

    Args:
        embedding_length: Number of output features.
    """

    def __init__(self, embedding_length):
        super().__init__()
        self.ffm = nn.Linear(2, embedding_length, bias=True)
        # Weights are re-drawn uniformly from [-sqrt(9/2), sqrt(9/2)].
        bound = np.sqrt(9 / 2)
        nn.init.uniform_(self.ffm.weight, -bound, bound)

    def forward(self, x):
        """Return ``sin(ffm(x))``."""
        return torch.sin(self.ffm(x))
class ConstantInput(nn.Module):
    """Learnable constant embedding laid out on a ``size x size`` grid.

    Args:
        size: Grid side length; the parameter holds ``size ** 2`` cells.
        channel: Number of channels per cell.
    """

    def __init__(self, size, channel):
        super().__init__()
        self.input = nn.Parameter(torch.randn(1, size ** 2, channel))

    def forward(self, input, coord, resolution=None):
        """Return the constant embedding for every coordinate.

        Args:
            input: Only its leading dimension (the batch size) is used.
            coord: Coordinates, last dim 2; when ``coord.shape[1]`` equals
                the number of stored cells the embedding is returned tiled
                as is, otherwise it is bilinearly resampled at ``coord``
                (presumably normalised to [-1, 1] as ``grid_sample``
                expects - confirm with callers).
            resolution: Optional ``(H, W)`` of the coordinate grid; when
                omitted a square grid is assumed.

        Returns:
            Tensor of shape ``(batch, -1, channel)``.
        """
        batch = input.shape[0]
        _, num_cells, channels = self.input.shape
        tiled = self.input.repeat(batch, 1, 1)
        if coord.shape[1] == num_cells:
            return tiled

        # Rebuild the (batch, C, side, side) image the cells were cut from.
        side = int(num_cells ** 0.5)
        grid_image = tiled.permute(0, 2, 1).contiguous().view(batch, channels, side, side)

        if resolution is None:
            edge = int(coord.shape[1] ** 0.5)
            target = (edge, edge)
        else:
            target = resolution
        grid = coord.view(coord.shape[0], *target, coord.shape[-1])

        # The last coordinate axis is flipped before sampling.
        sampled = F.grid_sample(grid_image, grid.flip(-1), mode='bilinear',
                                padding_mode='border', align_corners=True)
        return sampled.permute(0, 2, 3, 1).contiguous().view(batch, -1, channels)
class INRGAN_embed(nn.Module):
    """INR-GAN style coordinate embedding built from several Fourier bases.

    Depending on the sizes in ``res_cfg`` the embedding combines:
      * a fixed logarithmic basis (buffer ``log_basis``),
      * a fixed random basis (buffer ``random_basis``),
      * a learnable shared basis (parameter ``shared_basis``),
      * a basis predicted from a latent ``w`` (linear layer ``affine``),
      * learnable constant embeddings (parameter ``const_embs``).

    Phases from the enabled bases are passed through sin (and cos when
    ``use_cosine``); the constant embeddings and the raw coordinates are
    concatenated alongside.

    Args:
        resolution: Side length used for the logarithmic basis and for the
            ``resolution ** 2`` constant embeddings.
        w_dim: Latent size; only needed when the predictable basis is enabled.

    Attributes:
        out_dim: Size of the last dimension returned by ``forward``.
    """

    def __init__(self, resolution: int, w_dim=None):
        super().__init__()
        self.resolution = resolution
        self.res_cfg = {"log_emb_size": 32,
                        "random_emb_size": 32,
                        "const_emb_size": 64,
                        "use_cosine": True}
        # Every basis below is built for 2-D coordinates. This replaces the
        # former ``self.cfg.coord_dim`` lookup, which read an attribute that
        # was never defined.
        self.coord_dim = 2
        self.log_emb_size = self.res_cfg.get('log_emb_size', 0)
        self.random_emb_size = self.res_cfg.get('random_emb_size', 0)
        self.shared_emb_size = self.res_cfg.get('shared_emb_size', 0)
        self.predictable_emb_size = self.res_cfg.get('predictable_emb_size', 0)
        self.const_emb_size = self.res_cfg.get('const_emb_size', 0)
        self.fourier_scale = self.res_cfg.get('fourier_scale', np.sqrt(10))
        self.use_cosine = self.res_cfg.get('use_cosine', False)

        if self.log_emb_size > 0:
            self.register_buffer('log_basis', generate_logarithmic_basis(
                resolution, self.log_emb_size, use_diagonal=self.res_cfg.get('use_diagonal', False)))
        if self.random_emb_size > 0:
            self.register_buffer(
                'random_basis',
                self.sample_w_matrix((self.coord_dim, self.random_emb_size), self.fourier_scale))
        if self.shared_emb_size > 0:
            self.shared_basis = nn.Parameter(
                self.sample_w_matrix((self.coord_dim, self.shared_emb_size), self.fourier_scale))
        if self.predictable_emb_size > 0:
            # affine predicts a flattened weight matrix followed by a bias.
            self.W_size = self.predictable_emb_size * self.coord_dim
            self.b_size = self.predictable_emb_size
            self.affine = nn.Linear(w_dim, self.W_size + self.b_size)
        if self.const_emb_size > 0:
            self.const_embs = nn.Parameter(torch.randn(1, resolution ** 2, self.const_emb_size))

        # +2 accounts for the raw coordinates prepended in forward().
        self.out_dim = self.get_total_dim() + 2

    def sample_w_matrix(self, shape, scale: float):
        """Return a standard-normal tensor of ``shape`` multiplied by ``scale``."""
        return torch.randn(shape) * scale

    def get_total_dim(self) -> int:
        """Return the embedding width, excluding the raw coordinates."""
        periodic_factor = 2 if self.use_cosine else 1
        total_dim = 0
        if self.log_emb_size > 0:
            # The log basis may hold fewer rows than log_emb_size allows.
            total_dim += self.log_basis.shape[0] * periodic_factor
        total_dim += self.random_emb_size * periodic_factor
        total_dim += self.shared_emb_size * periodic_factor
        total_dim += self.predictable_emb_size * periodic_factor
        total_dim += self.const_emb_size
        return total_dim

    def forward(self, raw_coords, w=None):
        """Embed ``raw_coords``.

        Args:
            raw_coords: Tensor of shape ``(batch, num_points, coord_dim)``.
                With constant embeddings enabled ``num_points`` has to match
                their stored count for the final concatenation to succeed.
            w: Latent of shape ``(batch, w_dim)``; only read when the
                predictable basis is enabled.

        Returns:
            ``[raw_coords, sin(phases), cos(phases)?, const_embs?]``
            concatenated on the last dim.
        """
        batch_size, _, in_channels = raw_coords.shape

        phases = []
        if self.log_emb_size > 0:
            # log_basis is stored as (num_feats, 2); transpose per sample.
            log_bases = self.log_basis.unsqueeze(0).repeat(batch_size, 1, 1).permute(0, 2, 1)
            phases.append(torch.matmul(raw_coords, log_bases))
        if self.random_emb_size > 0:
            random_bases = self.random_basis.unsqueeze(0).repeat(batch_size, 1, 1)
            phases.append(torch.matmul(raw_coords, random_bases))
        if self.shared_emb_size > 0:
            shared_bases = self.shared_basis.unsqueeze(0).repeat(batch_size, 1, 1)
            phases.append(torch.matmul(raw_coords, shared_bases))
        if self.predictable_emb_size > 0:
            mod = self.affine(w)
            W = self.fourier_scale * mod[:, :self.W_size]
            # The coordinate width comes from the input itself; the view
            # fails loudly if it disagrees with W_size.
            W = W.view(batch_size, in_channels, self.predictable_emb_size)
            bias = mod[:, self.W_size:].view(batch_size, 1, self.predictable_emb_size)
            phases.append(torch.matmul(raw_coords, W) + bias)

        # Assemble all pieces in one list so configurations without any
        # periodic basis no longer hit an unbound variable.
        pieces = [raw_coords]
        if phases:
            raw_embs = torch.cat(phases, dim=-1).contiguous()
            pieces.append(raw_embs.sin())
            if self.use_cosine:
                pieces.append(raw_embs.cos())
        if self.const_emb_size > 0:
            pieces.append(self.const_embs.repeat([batch_size, 1, 1]))
        return torch.cat(pieces, dim=-1)
def generate_logarithmic_basis(
        resolution,
        max_num_feats,
        remove_lowest_freq: bool = False,
        use_diagonal: bool = True):
    """
    Build a directional logarithmic basis by stacking wavefront bases for:
        - the horizontal direction
        - the vertical direction
        - the main diagonal (only if `use_diagonal`)
        - the anti-diagonal (only if `use_diagonal`)

    Each direction contributes `ceil(log2(resolution))` rows, or one fewer
    when `remove_lowest_freq` drops the first row of every direction.

    The result is asserted to hold at most `max_num_feats` rows.
    """
    feats_per_direction = np.ceil(np.log2(resolution)).astype(int)

    builders = [generate_horizontal_basis, generate_vertical_basis]
    if use_diagonal:
        builders += [generate_diag_main_basis, generate_anti_diag_basis]
    bases = [build(feats_per_direction) for build in builders]

    if remove_lowest_freq:
        bases = [direction[1:] for direction in bases]

    # NOTE: the earlier trimming strategies (first dropping the anti-diagonal
    # and main-diagonal directions, then keeping every second feature) are
    # disabled. Horizontal and vertical directions were never dropped, since
    # the model could not locate the position without them. An oversized
    # basis is now rejected by the assertion below.
    basis = torch.cat(bases, dim=0)

    assert basis.shape[0] <= max_num_feats, \
        f"num_coord_feats > max_num_fixed_coord_feats: {basis.shape, max_num_feats}."

    return basis
def generate_horizontal_basis(num_feats: int):
    """Wavefront basis along the direction [0, 1] with period length 4."""
    return generate_wavefront_basis(
        num_feats=num_feats, basis_block=[0.0, 1.0], period_length=4.0)
def generate_vertical_basis(num_feats: int):
    """Wavefront basis along the direction [1, 0] with period length 4."""
    return generate_wavefront_basis(
        num_feats=num_feats, basis_block=[1.0, 0.0], period_length=4.0)
def generate_diag_main_basis(num_feats: int):
    """Wavefront basis along the unit direction [-1, 1] / sqrt(2), period 4 * sqrt(2)."""
    inv_sqrt2 = 1.0 / np.sqrt(2)
    return generate_wavefront_basis(
        num_feats=num_feats,
        basis_block=[-inv_sqrt2, inv_sqrt2],
        period_length=4.0 * np.sqrt(2),
    )
def generate_anti_diag_basis(num_feats: int):
    """Wavefront basis along the unit direction [1, 1] / sqrt(2), period 4 * sqrt(2)."""
    inv_sqrt2 = 1.0 / np.sqrt(2)
    return generate_wavefront_basis(
        num_feats=num_feats,
        basis_block=[inv_sqrt2, inv_sqrt2],
        period_length=4.0 * np.sqrt(2),
    )
def generate_wavefront_basis(num_feats: int, basis_block, period_length: float):
    """Return `num_feats` copies of a direction scaled by octave frequencies.

    Row `k` equals `basis_block * 2**k * (2 * pi / period_length)`.

    Args:
        num_feats: Number of rows (octaves) to generate.
        basis_block: Two-element direction vector.
        period_length: Period of the lowest-frequency row.

    Returns:
        float32 tensor of shape [num_feats, 2].
    """
    direction = torch.tensor([basis_block])  # [1, 2], broadcast across rows
    octaves = torch.pow(2, torch.arange(num_feats)).unsqueeze(1)  # [num_feats, 1]
    angular_coef = 2.0 * np.pi / period_length
    scaled = direction * octaves * angular_coef  # [num_feats, 2]
    return scaled.float()