import math
import os
from typing import Union, List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.models as models
from siren_pytorch import SirenNet

import video3d.utils.misc as misc
from video3d.triplane_texture.lift_architecture import Lift_Encoder
from video3d.triplane_texture.triplane_transformer import Triplane_Transformer

EPS = 1e-7
def get_activation(name, inplace=True, lrelu_param=0.2): | |
if name == 'tanh': | |
return nn.Tanh() | |
elif name == 'sigmoid': | |
return nn.Sigmoid() | |
elif name == 'relu': | |
return nn.ReLU(inplace=inplace) | |
elif name == 'lrelu': | |
return nn.LeakyReLU(lrelu_param, inplace=inplace) | |
else: | |
raise NotImplementedError | |
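# Coordinate MLP: encodes input coordinates with a harmonic (positional) embedding,
# optionally mirrors points about the x=0 plane for bilateral symmetry, and can
# concatenate an extra global feature vector before the output MLP.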
class MLPWithPositionalEncoding(nn.Module): | |
def __init__(self, | |
cin, | |
cout, | |
num_layers, | |
nf=256, | |
dropout=0, | |
activation=None, | |
n_harmonic_functions=10, | |
omega0=1, | |
extra_dim=0, | |
embed_concat_pts=True, | |
symmetrize=False): | |
super().__init__() | |
self.extra_dim = extra_dim | |
if n_harmonic_functions > 0: | |
self.embedder = HarmonicEmbedding(n_harmonic_functions, omega0) | |
dim_in = cin * 2 * n_harmonic_functions | |
self.embed_concat_pts = embed_concat_pts | |
if embed_concat_pts: | |
dim_in += cin | |
else: | |
self.embedder = None | |
dim_in = cin | |
self.in_layer = nn.Linear(dim_in, nf) | |
self.relu = nn.ReLU(inplace=True) | |
self.mlp = MLP(nf + extra_dim, cout, num_layers, nf, dropout, activation) | |
self.symmetrize = symmetrize | |
def forward(self, x, feat=None): | |
assert (feat is None and self.extra_dim == 0) or feat.shape[-1] == self.extra_dim | |
if self.symmetrize: | |
xs, ys, zs = x.unbind(-1) | |
x = torch.stack([xs.abs(), ys, zs], -1) # mirror -x to +x | |
if self.embedder is not None: | |
x_in = self.embedder(x) | |
if self.embed_concat_pts: | |
x_in = torch.cat([x, x_in], -1) | |
else: | |
x_in = x | |
x_in = self.relu(self.in_layer(x_in)) | |
if feat is not None: | |
# if len(feat.shape) == 1: | |
# for _ in range(len(x_in.shape) - 1): | |
# feat = feat.unsqueeze(0) | |
# feat = feat.repeat(*x_in.shape[:-1], 1) | |
x_in = torch.concat([x_in, feat], dim=-1) | |
return self.mlp(x_in) | |
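# Variant of the coordinate MLP above that conditions on a feature vector through a
# style pathway: either FiLM (per-channel scale and shift) or weight modulation ('mod').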
class MLPWithPositionalEncoding_Style(nn.Module): | |
def __init__(self, | |
cin, | |
cout, | |
num_layers, | |
nf=256, | |
dropout=0, | |
activation=None, | |
n_harmonic_functions=10, | |
omega0=1, | |
extra_dim=0, | |
embed_concat_pts=True, | |
symmetrize=False, | |
style_choice='film'): | |
super().__init__() | |
self.extra_dim = extra_dim | |
if n_harmonic_functions > 0: | |
self.embedder = HarmonicEmbedding(n_harmonic_functions, omega0) | |
dim_in = cin * 2 * n_harmonic_functions | |
self.embed_concat_pts = embed_concat_pts | |
if embed_concat_pts: | |
dim_in += cin | |
else: | |
self.embedder = None | |
dim_in = cin | |
self.in_layer = nn.Linear(dim_in, nf) | |
self.relu = nn.ReLU(inplace=True) | |
if extra_dim == 0: | |
self.mlp = MLP(nf + extra_dim, cout, num_layers, nf, dropout, activation) | |
else: | |
if style_choice == 'film': | |
self.mlp = MLP_FiLM(nf, cout, num_layers, nf, dropout, activation) | |
self.style_mlp = MLP(extra_dim, nf*2, 2, nf, dropout, None) | |
elif style_choice == 'mod': | |
self.mlp = MLP_Mod(nf, cout, num_layers, nf, dropout, activation) | |
self.style_mlp = MLP(extra_dim, nf, 2, nf, dropout, None) | |
else: | |
raise NotImplementedError | |
self.style_choice = style_choice | |
self.symmetrize = symmetrize | |
def forward(self, x, feat=None): | |
assert (feat is None and self.extra_dim == 0) or feat.shape[-1] == self.extra_dim | |
if self.symmetrize: | |
xs, ys, zs = x.unbind(-1) | |
x = torch.stack([xs.abs(), ys, zs], -1) # mirror -x to +x | |
if self.embedder is not None: | |
x_in = self.embedder(x) | |
if self.embed_concat_pts: | |
x_in = torch.cat([x, x_in], -1) | |
else: | |
x_in = x | |
x_in = self.relu(self.in_layer(x_in)) | |
if feat is not None: | |
style = self.style_mlp(feat) | |
if self.style_choice == 'film': | |
style = style.reshape(style.shape[:-1] + (-1, 2)) | |
out = self.mlp(x_in, style) | |
else: | |
out = self.mlp(x_in) | |
return out | |
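# MLP whose linear layers are conditioned by FiLM-style (scale, shift) parameters
# predicted from a style vector.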
class MLP_FiLM(nn.Module): | |
def __init__(self, cin, cout, num_layers, nf=256, dropout=0, activation=None): | |
# default no dropout | |
super().__init__() | |
assert num_layers >= 1 | |
self.num_layers = num_layers | |
if num_layers == 1: | |
self.network = Linear_FiLM(cin, cout, bias=False) | |
else: | |
self.relu = nn.ReLU(inplace=True) | |
for i in range(num_layers): | |
if i == 0: | |
setattr(self, f'linear_{i}', Linear_FiLM(cin, nf, bias=False)) | |
elif i == (num_layers-1): | |
setattr(self, f'linear_{i}', Linear_FiLM(nf, cout, bias=False)) | |
else: | |
setattr(self, f'linear_{i}', Linear_FiLM(nf, nf, bias=False)) | |
def forward(self, input, style): | |
if self.num_layers == 1: | |
out = self.network(input, style) | |
else: | |
x = input | |
for i in range(self.num_layers): | |
linear_layer = getattr(self, f'linear_{i}') | |
                x = linear_layer(x, style)
                if i < (self.num_layers - 1):
                    x = self.relu(x)
out = x | |
return out | |
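# MLP whose linear layers use weight modulation followed by demodulation driven by a
# style vector (cf. StyleGAN2-style modulated layers).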
class MLP_Mod(nn.Module): | |
def __init__(self, cin, cout, num_layers, nf=256, dropout=0, activation=None): | |
# default no dropout | |
super().__init__() | |
assert num_layers >= 1 | |
self.num_layers = num_layers | |
if num_layers == 1: | |
self.network = Linear_Mod(cin, cout, bias=False) | |
else: | |
self.relu = nn.ReLU(inplace=True) | |
for i in range(num_layers): | |
if i == 0: | |
setattr(self, f'linear_{i}', Linear_Mod(cin, nf, bias=False)) | |
elif i == (num_layers-1): | |
setattr(self, f'linear_{i}', Linear_Mod(nf, cout, bias=False)) | |
else: | |
setattr(self, f'linear_{i}', Linear_Mod(nf, nf, bias=False)) | |
def forward(self, input, style): | |
if self.num_layers == 1: | |
out = self.network(input, style) | |
else: | |
x = input | |
for i in range(self.num_layers): | |
linear_layer = getattr(self, f'linear_{i}') | |
                x = linear_layer(x, style)
                if i < (self.num_layers - 1):
                    x = self.relu(x)
out = x | |
return out | |
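# Linear layer with FiLM conditioning: the input is scaled by style[..., 0] and shifted
# by style[..., 1] per channel before the affine transform.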
class Linear_FiLM(nn.Module): | |
def __init__(self, in_features: int, out_features: int, bias: bool = True, | |
device=None, dtype=None) -> None: | |
factory_kwargs = {'device': device, 'dtype': dtype} | |
super().__init__() | |
self.in_features = in_features | |
self.out_features = out_features | |
self.weight = nn.Parameter(torch.empty((out_features, in_features), **factory_kwargs)) | |
if bias: | |
self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs)) | |
else: | |
self.register_parameter('bias', None) | |
self.reset_parameters() | |
def reset_parameters(self) -> None: | |
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) | |
if self.bias is not None: | |
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) | |
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 | |
nn.init.uniform_(self.bias, -bound, bound) | |
def forward(self, input, style): | |
# if input is [..., D], style should be [..., D, 2] | |
x = input * style[..., 0] + style[..., 1] | |
return torch.nn.functional.linear(x, self.weight, self.bias) | |
def extra_repr(self) -> str: | |
return 'in_features={}, out_features={}, bias={}'.format( | |
self.in_features, self.out_features, self.bias is not None | |
) | |
class Linear_Mod(nn.Module): | |
def __init__(self, in_features: int, out_features: int, bias: bool = True, | |
device=None, dtype=None) -> None: | |
factory_kwargs = {'device': device, 'dtype': dtype} | |
super().__init__() | |
self.in_features = in_features | |
self.out_features = out_features | |
self.weight = nn.Parameter(torch.empty((out_features, in_features), **factory_kwargs)) | |
if bias: | |
self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs)) | |
else: | |
self.register_parameter('bias', None) | |
self.reset_parameters() | |
def reset_parameters(self) -> None: | |
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) | |
if self.bias is not None: | |
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) | |
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 | |
nn.init.uniform_(self.bias, -bound, bound) | |
def forward(self, input, style): | |
# weight: [out_features, in_features] | |
# style: [..., in_features] | |
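        # note: when a batched style tensor is given, only the first style vector is used,
        # since the modulated weight is shared across the whole batch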
if len(style.shape) > 1: | |
style = style.reshape(-1, style.shape[-1]) | |
style = style[0] | |
weight = self.weight * style.unsqueeze(0) | |
decoefs = ((weight * weight).sum(dim=-1, keepdim=True) + 1e-5).sqrt() | |
weight = weight / decoefs | |
return torch.nn.functional.linear(input, weight, self.bias) | |
def extra_repr(self) -> str: | |
return 'in_features={}, out_features={}, bias={}'.format( | |
self.in_features, self.out_features, self.bias is not None | |
) | |
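# Texture field: maps (harmonically embedded) 3D points plus an optional global feature
# to texture values (e.g. albedo), optionally rescaled into the range given by min_max.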
class MLPTextureSimple(nn.Module): | |
def __init__(self, | |
cin, | |
cout, | |
num_layers, | |
nf=256, | |
dropout=0, | |
activation=None, | |
min_max=None, | |
n_harmonic_functions=10, | |
omega0=1, | |
extra_dim=0, | |
embed_concat_pts=True, | |
perturb_normal=False, | |
symmetrize=False, | |
texture_act='relu', | |
linear_bias=False): | |
super().__init__() | |
self.extra_dim = extra_dim | |
if n_harmonic_functions > 0: | |
self.embedder = HarmonicEmbedding(n_harmonic_functions, omega0) | |
dim_in = cin * 2 * n_harmonic_functions | |
self.embed_concat_pts = embed_concat_pts | |
if embed_concat_pts: | |
dim_in += cin | |
else: | |
self.embedder = None | |
dim_in = cin | |
self.in_layer = nn.Linear(dim_in, nf) | |
self.relu = nn.ReLU(inplace=True) | |
if texture_act == 'sin': | |
print('using siren network for texture mlp here') | |
self.mlp = SirenNet( | |
dim_in=(nf + extra_dim), | |
dim_hidden=nf, | |
dim_out=cout, | |
num_layers=num_layers, | |
final_activation=get_activation(activation), | |
w0_initial=30, | |
use_bias=linear_bias, | |
dropout=dropout | |
) | |
else: | |
self.mlp = MLP(nf + extra_dim, cout, num_layers, nf, dropout, activation, inner_act=texture_act, linear_bias=linear_bias) | |
self.perturb_normal = perturb_normal | |
self.symmetrize = symmetrize | |
if min_max is not None: | |
self.register_buffer('min_max', min_max) | |
else: | |
self.min_max = None | |
self.bsdf = None | |
def sample(self, x, feat=None): | |
assert (feat is None and self.extra_dim == 0) or (feat.shape[-1] == self.extra_dim) | |
b, h, w, c = x.shape | |
if self.symmetrize: | |
xs, ys, zs = x.unbind(-1) | |
x = torch.stack([xs.abs(), ys, zs], -1) # mirror -x to +x | |
x = x.view(-1, c) | |
if self.embedder is not None: | |
x_in = self.embedder(x) | |
if self.embed_concat_pts: | |
x_in = torch.cat([x, x_in], -1) | |
else: | |
x_in = x | |
x_in = self.in_layer(x_in) | |
if feat is not None: | |
feat = feat[:,None,None].expand(b, h, w, -1).reshape(b*h*w, -1) | |
x_in = torch.concat([x_in, feat], dim=-1) | |
out = self.mlp(self.relu(x_in)) | |
if self.min_max is not None: | |
out = out * (self.min_max[1][None, :] - self.min_max[0][None, :]) + self.min_max[0][None, :] | |
return out.view(b, h, w, -1) | |
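# Texture field conditioned on local features: a Triplane_Transformer lifts the image
# patch features into a triplane representation, and each query point is concatenated
# with the features sampled from the three planes before the texture MLP.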
class MLPTextureTriplane(nn.Module): | |
def __init__(self, | |
cin, | |
cout, | |
num_layers, | |
nf=256, | |
dropout=0, | |
activation=None, | |
min_max=None, | |
n_harmonic_functions=10, | |
omega0=1, | |
extra_dim=0, | |
embed_concat_pts=True, | |
perturb_normal=False, | |
symmetrize=False, | |
texture_act='relu', | |
linear_bias=False, | |
cam_pos_z_offset=10., | |
grid_scale=7,): | |
super().__init__() | |
self.extra_dim = extra_dim | |
if n_harmonic_functions > 0: | |
self.embedder = HarmonicEmbedding(n_harmonic_functions, omega0) | |
dim_in = cin * 2 * n_harmonic_functions | |
self.embed_concat_pts = embed_concat_pts | |
if embed_concat_pts: | |
dim_in += cin | |
else: | |
self.embedder = None | |
dim_in = cin | |
self.in_layer = nn.Linear(dim_in, nf) | |
self.relu = nn.ReLU(inplace=True) | |
self.feat_net = Triplane_Transformer( | |
emb_dim=256, | |
num_layers=8, | |
triplane_dim=80, | |
triplane_scale=grid_scale | |
) | |
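        # the global feature dimension is replaced by the concatenated per-point
        # features sampled from the three triplane planes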
self.extra_dim -= extra_dim | |
self.extra_dim += (self.feat_net.triplane_dim * 3) | |
if texture_act == 'sin': | |
print('using siren network for texture mlp here') | |
self.mlp = SirenNet( | |
dim_in=(nf + self.extra_dim), | |
dim_hidden=nf, | |
dim_out=cout, | |
num_layers=num_layers, | |
final_activation=get_activation(activation), | |
w0_initial=30, | |
use_bias=linear_bias, | |
dropout=dropout | |
) | |
else: | |
self.mlp = MLP(nf + self.extra_dim, cout, num_layers, nf, dropout, activation, inner_act=texture_act, linear_bias=linear_bias) | |
self.perturb_normal = perturb_normal | |
self.symmetrize = symmetrize | |
if min_max is not None: | |
self.register_buffer('min_max', min_max) | |
else: | |
self.min_max = None | |
self.bsdf = None | |
def sample(self, x, feat=None, feat_map=None, mvp=None, w2c=None, deform_xyz=None): | |
# assert (feat is None and self.extra_dim == 0) or (feat.shape[-1] == self.extra_dim) | |
b, h, w, c = x.shape | |
if self.symmetrize: | |
xs, ys, zs = x.unbind(-1) | |
x = torch.stack([xs.abs(), ys, zs], -1) # mirror -x to +x | |
if isinstance(feat_map, dict): | |
feat_map = feat_map["im_features_map"] | |
feat_map = feat_map.permute(0, 2, 3, 1) | |
_, ph, pw, _ = feat_map.shape | |
feat_map = feat_map.reshape(feat_map.shape[0], ph*pw, feat_map.shape[-1]) | |
pts_feat = self.feat_net(feat_map, x.reshape(b, -1, 3)) | |
pts_c = pts_feat.shape[-1] | |
pts_feat = pts_feat.reshape(-1, pts_c) | |
x = x.view(-1, c) | |
if self.embedder is not None: | |
x_in = self.embedder(x) | |
if self.embed_concat_pts: | |
x_in = torch.cat([x, x_in], -1) | |
else: | |
x_in = x | |
x_in = self.in_layer(x_in) | |
x_in = torch.concat([x_in, pts_feat], dim=-1) | |
out = self.mlp(self.relu(x_in)) | |
if self.min_max is not None: | |
out = out * (self.min_max[1][None, :] - self.min_max[0][None, :]) + self.min_max[0][None, :] | |
return out.view(b, h, w, -1) | |
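# Convolutional upsampler: each stage is a 3x3 conv followed by PixelShuffle(2), so the
# input feature map is upscaled by 2**upscale_num before a final 3x3 conv head.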
class LocalFeatureBlock(nn.Module): | |
def __init__(self, local_feat_dim, input_dim=384, output_dim=384, upscale_num=3): | |
super().__init__() | |
self.local_feat_dim = local_feat_dim | |
self.conv_list = nn.ModuleList([]) | |
self.upscale_list = nn.ModuleList([]) | |
for i in range(upscale_num): | |
if i == 0: | |
self.conv_list.append(nn.Conv2d(input_dim, 4 * local_feat_dim, 3, stride=1, padding=1, dilation=1)) | |
else: | |
self.conv_list.append(nn.Conv2d(local_feat_dim, 4 * local_feat_dim, 3, stride=1, padding=1, dilation=1)) | |
self.upscale_list.append(nn.PixelShuffle(2)) | |
self.conv_head = nn.Conv2d(local_feat_dim, output_dim, 3, stride=1, padding=1, dilation=1) | |
def forward(self, x): | |
for idx, conv in enumerate(self.conv_list): | |
x = conv(x) | |
x = self.upscale_list[idx](x) | |
out = self.conv_head(x) | |
return out | |
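# Texture field that mixes a global image feature with local features obtained by
# projecting each (deformed) 3D point into the input view, sampling the feature map at
# the projected location, and masking out points that fail a depth-based visibility
# test against the recorded input-view depth.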
class MLPTextureLocal(nn.Module): | |
def __init__(self, | |
cin, | |
cout, | |
num_layers, | |
nf=256, | |
dropout=0, | |
activation=None, | |
min_max=None, | |
n_harmonic_functions=10, | |
omega0=1, | |
extra_dim=0, | |
embed_concat_pts=True, | |
perturb_normal=False, | |
symmetrize=False, | |
texture_way=None, | |
larger_tex_dim=False, | |
cam_pos_z_offset=10., | |
grid_scale=7.): | |
super().__init__() | |
self.extra_dim = extra_dim | |
self.cam_pos_z_offset = cam_pos_z_offset | |
self.grid_scale = grid_scale | |
local_feat_dim = 64 | |
assert texture_way is not None | |
self.texture_way = texture_way | |
if 'local' in texture_way and 'global' in texture_way: | |
# self.extra_dim = extra_dim + local_feat_dim | |
self.extra_dim = extra_dim | |
elif 'local' in texture_way and 'global' not in texture_way: | |
# self.extra_dim = local_feat_dim | |
self.extra_dim = extra_dim | |
elif 'local' not in texture_way and 'global' in texture_way: | |
self.extra_dim = extra_dim | |
if n_harmonic_functions > 0: | |
self.embedder = HarmonicEmbedding(n_harmonic_functions, omega0) | |
dim_in = cin * 2 * n_harmonic_functions | |
self.embed_concat_pts = embed_concat_pts | |
if embed_concat_pts: | |
dim_in += cin | |
else: | |
self.embedder = None | |
dim_in = cin | |
# self.local_feature_block = LocalFeatureBlock(local_feat_dim=local_feat_dim, input_dim=384, output_dim=256) | |
self.local_feature_block = nn.Linear(384, nf, bias=False) | |
self.in_layer = nn.Linear(dim_in, nf) | |
self.relu = nn.ReLU(inplace=True) | |
self.mlp = MLP(nf + self.extra_dim, cout, num_layers, nf, dropout, activation) | |
self.perturb_normal = perturb_normal | |
self.symmetrize = symmetrize | |
if min_max is not None: | |
self.register_buffer('min_max', min_max) | |
else: | |
self.min_max = None | |
self.bsdf = None | |
def get_uv_depth(self, xyz, mvp): | |
# xyz: [b, k, 3] | |
# mvp: [b, 4, 4] | |
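        # project points with the MVP matrix and apply the perspective divide; uv are the
        # resulting xy (NDC) coordinates used for grid_sample, and depth is the z component
        # after adding the camera z offset and rescaling by the grid scale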
cam4 = torch.matmul(torch.nn.functional.pad(xyz, pad=(0,1), mode='constant', value=1.0), torch.transpose(mvp, 1, 2)) | |
cam3 = cam4[..., :3] / cam4[..., 3:4] | |
cam_uv = cam3[..., :2] | |
# cam_uv = cam_uv.detach() | |
cam_depth = cam3 + torch.FloatTensor([0, 0, self.cam_pos_z_offset]).to(xyz.device).view(1, 1, 3) | |
cam_depth = cam_depth / self.grid_scale * 2 | |
cam_depth = cam_depth[..., 2:3] | |
# cam_depth = cam_depth.detach() | |
return cam_uv, cam_depth | |
def proj_sample_deform(self, xyz, feat_map, mvp, w2c, img_h, img_w): | |
        # here the xyz are the deformed points,
        # and we do not apply any symmetry here
b, k, c = xyz.shape | |
THRESHOLD = 1e-4 | |
if isinstance(feat_map, torch.Tensor): | |
coordinates = xyz | |
# use pre-symmetry points to get feature and record depth | |
cam_uv, cam_depth = self.get_uv_depth(coordinates, mvp) | |
cam_uv = cam_uv.detach() | |
cam_depth = cam_depth.detach() | |
# get local feature | |
feature = F.grid_sample(feat_map, cam_uv.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, c] | |
self.input_depth = cam_depth.reshape(b, 256, 256, 1) # [B, 256, 256, 1] | |
self.input_pts = coordinates.detach() | |
elif isinstance(feat_map, dict): | |
original_mvp = feat_map['original_mvp'] | |
local_feat_map = feat_map['im_features_map'] | |
original_depth = self.input_depth[0:b] | |
coordinates = xyz | |
cam_uv, cam_depth = self.get_uv_depth(coordinates, original_mvp) | |
cam_uv = cam_uv.detach() | |
cam_depth = cam_depth.detach() | |
project_feature = F.grid_sample(local_feat_map, cam_uv.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, c] | |
project_depth = F.grid_sample(original_depth.permute(0, 3, 1, 2), cam_uv.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, 1] | |
use_mask = cam_depth <= project_depth + THRESHOLD | |
feature = project_feature * use_mask.repeat(1, 1, project_feature.shape[-1]) | |
ret_feature = self.local_feature_block(feature.reshape(b*k, -1)) # the linear is without bias, so 0 value feature will still get 0 value | |
return ret_feature | |
def proj_sample(self, xyz, feat_map, mvp, w2c, img_h, img_w, xyz_before_sym=None): | |
# the new one with no input feature map upsampling | |
# feat_map: [B, C, H, W] | |
b, k, c = xyz.shape | |
if isinstance(feat_map, torch.Tensor): | |
if xyz_before_sym is None: | |
coordinates = xyz | |
else: | |
coordinates = xyz_before_sym | |
# use pre-symmetry points to get feature and record depth | |
cam_uv, cam_depth = self.get_uv_depth(coordinates, mvp) | |
cam_uv = cam_uv.detach() | |
cam_depth = cam_depth.detach() | |
# get local feature | |
feature = F.grid_sample(feat_map, cam_uv.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, c] | |
self.input_depth = cam_depth.reshape(b, 256, 256, 1) # [B, 256, 256, 1] | |
self.input_pts = coordinates.detach() | |
elif isinstance(feat_map, dict): | |
original_mvp = feat_map['original_mvp'] | |
local_feat_map = feat_map['im_features_map'] | |
THRESHOLD = 1e-4 | |
original_depth = self.input_depth[0:b] | |
# if b == 1: | |
# from pdb import set_trace; set_trace() | |
# tmp_mask = xyz[0].reshape(256, 256, 3).sum(dim=-1) != 0 | |
# tmp_mask = tmp_mask.cpu().numpy() | |
# tmp_mask = tmp_mask * 255 | |
# src_dp = self.input_depth[0,:,:,0].cpu().numpy() | |
# input_pts = self.input_pts[0].cpu().numpy() | |
# input_mask = self.input_pts[0].reshape(256, 256, 3).sum(dim=-1) != 0 | |
# input_mask = input_mask.int().cpu().numpy() | |
# input_mask = input_mask * 255 | |
# np.save('./tmp_save/src_dp.npy', src_dp) | |
# np.save('./tmp_save/input_pts.npy', input_pts) | |
# import cv2 | |
# cv2.imwrite('./tmp_save/input_mask.png', input_mask) | |
# cv2.imwrite('./tmp_save/mask.png', tmp_mask) | |
# test_pts_pos = xyz[0].cpu().numpy() | |
# np.save('./tmp_save/test_pts_pos.npy', test_pts_pos) | |
# test_pts_raw = xyz_before_sym[0].cpu().numpy() | |
# np.save('./tmp_save/test_pts_raw.npy', test_pts_raw) | |
# mvp_now = mvp[0].detach().cpu().numpy() | |
# mvp_original = original_mvp[0].detach().cpu().numpy() | |
# np.save('./tmp_save/mvp_now.npy', mvp_now) | |
# np.save('./tmp_save/mvp_original.npy', mvp_original) | |
if xyz_before_sym is None: | |
# just check the project depth of xyz | |
coordinates = xyz | |
cam_uv, cam_depth = self.get_uv_depth(coordinates, original_mvp) | |
cam_uv = cam_uv.detach() | |
cam_depth = cam_depth.detach() | |
project_feature = F.grid_sample(local_feat_map, cam_uv.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, c] | |
project_depth = F.grid_sample(original_depth.permute(0, 3, 1, 2), cam_uv.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, 1] | |
use_mask = cam_depth <= project_depth + THRESHOLD | |
feature = project_feature * use_mask.repeat(1, 1, project_feature.shape[-1]) | |
else: | |
# need to double check, but now we are still use symmetry! Even if the two points are all visible in input view | |
coords_inp = xyz | |
x_check, y_check, z_check = xyz.unbind(-1) | |
xyz_check = torch.stack([-1 * x_check, y_check, z_check], -1) | |
coords_rev = xyz_check # we directly use neg-x to get the points of another side | |
uv_inp, dp_inp = self.get_uv_depth(coords_inp, original_mvp) | |
uv_rev, dp_rev = self.get_uv_depth(coords_rev, original_mvp) | |
uv_inp = uv_inp.detach() | |
uv_rev = uv_rev.detach() | |
dp_inp = dp_inp.detach() | |
dp_rev = dp_rev.detach() | |
proj_feat_inp = F.grid_sample(local_feat_map, uv_inp.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, c] | |
proj_feat_rev = F.grid_sample(local_feat_map, uv_rev.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, c] | |
proj_dp_inp = F.grid_sample(original_depth.permute(0, 3, 1, 2), uv_inp.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, 1] | |
proj_dp_rev = F.grid_sample(original_depth.permute(0, 3, 1, 2), uv_rev.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, 1] | |
use_mask_inp = dp_inp <= proj_dp_inp + THRESHOLD | |
use_mask_rev = dp_rev <= proj_dp_rev + THRESHOLD | |
                # for points visible from both sides, average the two features
                use_mask_inp = use_mask_inp.float()
                use_mask_rev = use_mask_rev.float()
                both_vis = (use_mask_inp == 1) & (use_mask_rev == 1)
                use_mask_inp[both_vis] = 0.5
                use_mask_rev[both_vis] = 0.5
feature = proj_feat_inp * use_mask_inp.repeat(1, 1, proj_feat_inp.shape[-1]) + proj_feat_rev * use_mask_rev.repeat(1, 1, proj_feat_rev.shape[-1]) | |
else: | |
raise NotImplementedError | |
ret_feature = self.local_feature_block(feature.reshape(b*k, -1)) # the linear is without bias, so 0 value feature will still get 0 value | |
return ret_feature | |
def sample(self, x, feat=None, feat_map=None, mvp=None, w2c=None, deform_xyz=None): | |
# assert (feat is None and self.extra_dim == 0) or (feat.shape[-1] <= self.extra_dim) | |
b, h, w, c = x.shape | |
xyz_before_sym = None | |
if self.symmetrize: | |
xyz_before_sym = x.reshape(b, -1, c) | |
xs, ys, zs = x.unbind(-1) | |
x = torch.stack([xs.abs(), ys, zs], -1) # mirror -x to +x | |
mvp = mvp.detach() # [b, 4, 4] | |
w2c = w2c.detach() # [b, 4, 4] | |
pts_xyz = x.reshape(b, -1, c) | |
deform_xyz = deform_xyz.reshape(b, -1, c) | |
if 'global' in self.texture_way and 'local' in self.texture_way: | |
global_feat = feat[:,None,None].expand(b, h, w, -1).reshape(b*h*w, -1) | |
# local_feat = self.proj_sample(pts_xyz, feat_map, mvp, w2c, h, w, xyz_before_sym=xyz_before_sym) | |
local_feat = self.proj_sample_deform(deform_xyz, feat_map, mvp, w2c, h, w) | |
# feature_rep = torch.concat([global_feat, local_feat], dim=-1) | |
feature_rep = global_feat + local_feat | |
elif 'global' not in self.texture_way and 'local' in self.texture_way: | |
# local_feat = self.proj_sample(pts_xyz, feat_map, mvp, w2c, h, w, xyz_before_sym=xyz_before_sym) | |
local_feat = self.proj_sample_deform(deform_xyz, feat_map, mvp, w2c, h, w) | |
feature_rep = local_feat | |
elif 'global' in self.texture_way and 'local' not in self.texture_way: | |
global_feat = feat[:,None,None].expand(b, h, w, -1).reshape(b*h*w, -1) | |
feature_rep = global_feat | |
else: | |
global_feat = feat[:,None,None].expand(b, h, w, -1).reshape(b*h*w, -1) | |
feature_rep = global_feat | |
x = x.view(-1, c) | |
if self.embedder is not None: | |
x_in = self.embedder(x) | |
if self.embed_concat_pts: | |
x_in = torch.cat([x, x_in], -1) | |
else: | |
x_in = x | |
x_in = self.in_layer(x_in) | |
# if feat is not None: | |
# feat = feat[:,None,None].expand(b, h, w, -1).reshape(b*h*w, -1) | |
# x_in = torch.concat([x_in, feat], dim=-1) | |
x_in = torch.concat([x_in, feature_rep], dim=-1) | |
out = self.mlp(self.relu(x_in)) | |
if self.min_max is not None: | |
out = out * (self.min_max[1][None, :] - self.min_max[0][None, :]) + self.min_max[0][None, :] | |
return out.view(b, h, w, -1) | |
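# Texture field whose local features come from a Lift_Encoder that lifts input-view
# image features into a 3D feature grid; per-point features are sampled from this grid
# and concatenated with the global feature before the texture MLP.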
class LiftTexture(nn.Module): | |
def __init__(self, | |
cin, | |
cout, | |
num_layers, | |
nf=256, | |
dropout=0, | |
activation=None, | |
min_max=None, | |
n_harmonic_functions=10, | |
omega0=1, | |
extra_dim=0, | |
embed_concat_pts=True, | |
perturb_normal=False, | |
symmetrize=False, | |
texture_way=None, | |
cam_pos_z_offset=10., | |
grid_scale=7., | |
local_feat_dim=128, | |
grid_size=32, | |
optim_latent=False): | |
super().__init__() | |
self.extra_dim = extra_dim | |
self.cam_pos_z_offset = cam_pos_z_offset | |
self.grid_scale = grid_scale | |
assert texture_way is not None | |
self.extra_dim = local_feat_dim + extra_dim | |
if n_harmonic_functions > 0: | |
self.embedder = HarmonicEmbedding(n_harmonic_functions, omega0) | |
dim_in = cin * 2 * n_harmonic_functions | |
self.embed_concat_pts = embed_concat_pts | |
if embed_concat_pts: | |
dim_in += cin | |
else: | |
self.embedder = None | |
dim_in = cin | |
self.encoder = Lift_Encoder( | |
cin=384, | |
feat_dim=local_feat_dim, | |
grid_scale=grid_scale / 2, # the dmtet is initialized in (-0.5, 0.5) | |
grid_size=grid_size, | |
optim_latent=optim_latent, | |
with_z_feature=True, | |
cam_pos_z_offset=cam_pos_z_offset | |
) | |
self.in_layer = nn.Linear(dim_in, nf) | |
self.relu = nn.ReLU(inplace=True) | |
self.mlp = MLP(nf + self.extra_dim, cout, num_layers, nf, dropout, activation) | |
self.perturb_normal = perturb_normal | |
self.symmetrize = symmetrize | |
if min_max is not None: | |
self.register_buffer('min_max', min_max) | |
else: | |
self.min_max = None | |
self.bsdf = None | |
def get_uv_depth(self, xyz, mvp): | |
# xyz: [b, k, 3] | |
# mvp: [b, 4, 4] | |
cam4 = torch.matmul(torch.nn.functional.pad(xyz, pad=(0,1), mode='constant', value=1.0), torch.transpose(mvp, 1, 2)) | |
cam3 = cam4[..., :3] / cam4[..., 3:4] | |
cam_uv = cam3[..., :2] | |
# cam_uv = cam_uv.detach() | |
cam_depth = cam3 + torch.FloatTensor([0, 0, self.cam_pos_z_offset]).to(xyz.device).view(1, 1, 3) | |
cam_depth = cam_depth / self.grid_scale * 2 | |
cam_depth = cam_depth[..., 2:3] | |
# cam_depth = cam_depth.detach() | |
return cam_uv, cam_depth | |
def proj_sample_deform(self, xyz, feat_map, mvp, w2c, img_h, img_w): | |
        # here the xyz are the deformed points,
        # and we do not apply any symmetry here
if isinstance(feat_map, torch.Tensor): | |
feature = self.encoder(feat_map, mvp, xyz, inference="unproject") | |
elif isinstance(feat_map, dict): | |
feature = self.encoder(feat_map['im_features_map'], mvp, xyz, inference="sample") | |
C = feature.shape[-1] | |
feature = feature.reshape(-1, C) | |
return feature | |
def sample(self, x, feat=None, feat_map=None, mvp=None, w2c=None, deform_xyz=None): | |
# assert (feat is None and self.extra_dim == 0) or (feat.shape[-1] <= self.extra_dim) | |
b, h, w, c = x.shape | |
xyz_before_sym = None | |
if self.symmetrize: | |
xyz_before_sym = x.reshape(b, -1, c) | |
xs, ys, zs = x.unbind(-1) | |
x = torch.stack([xs.abs(), ys, zs], -1) # mirror -x to +x | |
mvp = mvp.detach() # [b, 4, 4] | |
w2c = w2c.detach() # [b, 4, 4] | |
pts_xyz = x.reshape(b, -1, c) | |
deform_xyz = deform_xyz.reshape(b, -1, c) | |
global_feat = feat[:,None,None].expand(b, h, w, -1).reshape(b*h*w, -1) | |
local_feat = self.proj_sample_deform(deform_xyz, feat_map, mvp, w2c, h, w) | |
feature_rep = torch.concat([global_feat, local_feat], dim=-1) | |
x = x.view(-1, c) | |
if self.embedder is not None: | |
x_in = self.embedder(x) | |
if self.embed_concat_pts: | |
x_in = torch.cat([x, x_in], -1) | |
else: | |
x_in = x | |
x_in = self.in_layer(x_in) | |
# if feat is not None: | |
# feat = feat[:,None,None].expand(b, h, w, -1).reshape(b*h*w, -1) | |
# x_in = torch.concat([x_in, feat], dim=-1) | |
x_in = torch.concat([x_in, feature_rep], dim=-1) | |
out = self.mlp(self.relu(x_in)) | |
if self.min_max is not None: | |
out = out * (self.min_max[1][None, :] - self.min_max[0][None, :]) + self.min_max[0][None, :] | |
return out.view(b, h, w, -1) | |
class HarmonicEmbedding(nn.Module): | |
def __init__(self, n_harmonic_functions=10, omega0=1): | |
""" | |
Positional Embedding implementation (adapted from Pytorch3D). | |
Given an input tensor `x` of shape [minibatch, ... , dim], | |
the harmonic embedding layer converts each feature | |
in `x` into a series of harmonic features `embedding` | |
as follows: | |
embedding[..., i*dim:(i+1)*dim] = [ | |
sin(x[..., i]), | |
sin(2*x[..., i]), | |
sin(4*x[..., i]), | |
... | |
sin(2**self.n_harmonic_functions * x[..., i]), | |
cos(x[..., i]), | |
cos(2*x[..., i]), | |
cos(4*x[..., i]), | |
... | |
cos(2**self.n_harmonic_functions * x[..., i]) | |
] | |
Note that `x` is also premultiplied by `omega0` before | |
evaluting the harmonic functions. | |
""" | |
super().__init__() | |
self.frequencies = omega0 * (2.0 ** torch.arange(n_harmonic_functions)) | |
def forward(self, x): | |
""" | |
Args: | |
x: tensor of shape [..., dim] | |
Returns: | |
embedding: a harmonic embedding of `x` | |
of shape [..., n_harmonic_functions * dim * 2] | |
""" | |
embed = (x[..., None] * self.frequencies.to(x.device)).view(*x.shape[:-1], -1) | |
return torch.cat((embed.sin(), embed.cos()), dim=-1) | |
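# minimal usage sketch (assumed shapes): for x of shape [B, N, 3] and the default
# n_harmonic_functions=10, HarmonicEmbedding()(x) returns a tensor of shape [B, N, 60]
# (3 dims * 10 frequencies * 2 for sin/cos).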
class VGGEncoder(nn.Module): | |
def __init__(self, cout, pretrained=False): | |
super().__init__() | |
if pretrained: | |
raise NotImplementedError | |
vgg = models.vgg16() | |
self.vgg_encoder = nn.Sequential(vgg.features, vgg.avgpool) | |
self.linear1 = nn.Linear(25088, 4096) | |
self.linear2 = nn.Linear(4096, cout) | |
self.relu = nn.ReLU(inplace=True) | |
def forward(self, x): | |
batch_size, _, _, _ = x.shape | |
out = self.relu(self.linear1(self.vgg_encoder(x).view(batch_size, -1))) | |
return self.linear2(out) | |
class ResnetEncoder(nn.Module): | |
def __init__(self, cout, pretrained=False): | |
super().__init__() | |
        # keep all ResNet-18 layers except the final fully-connected classifier
        self.resnet = nn.Sequential(*list(models.resnet18(weights="DEFAULT" if pretrained else None).children())[:-1])
        self.final_linear = nn.Linear(512, cout)
    def forward(self, x):
        return self.final_linear(self.resnet(x).flatten(1))
class Encoder(nn.Module): | |
def __init__(self, cin, cout, in_size=128, zdim=None, nf=64, activation=None): | |
super().__init__() | |
network = [ | |
nn.Conv2d(cin, nf, kernel_size=4, stride=2, padding=1, bias=False), # 128x128 -> 64x64 | |
nn.GroupNorm(16, nf), | |
# nn.ReLU(inplace=True), | |
nn.LeakyReLU(0.2, inplace=True), | |
nn.Conv2d(nf, nf*2, kernel_size=4, stride=2, padding=1, bias=False), # 64x64 -> 32x32 | |
nn.GroupNorm(16*2, nf*2), | |
# nn.ReLU(inplace=True), | |
nn.LeakyReLU(0.2, inplace=True), | |
nn.Conv2d(nf*2, nf*4, kernel_size=4, stride=2, padding=1, bias=False), # 32x32 -> 16x16 | |
nn.GroupNorm(16*4, nf*4), | |
# nn.ReLU(inplace=True), | |
nn.LeakyReLU(0.2, inplace=True), | |
nn.Conv2d(nf*4, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 16x16 -> 8x8 | |
# nn.GroupNorm(16*8, nf*8), | |
# nn.ReLU(inplace=True), | |
nn.LeakyReLU(0.2, inplace=True), | |
] | |
add_downsample = int(np.log2(in_size//128)) | |
if add_downsample > 0: | |
for _ in range(add_downsample): | |
network += [ | |
nn.Conv2d(nf*8, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 16x16 -> 8x8 | |
# nn.GroupNorm(16*8, nf*8), | |
# nn.ReLU(inplace=True), | |
nn.LeakyReLU(0.2, inplace=True), | |
] | |
network += [ | |
nn.Conv2d(nf*8, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 8x8 -> 4x4 | |
nn.LeakyReLU(0.2, inplace=True), | |
] | |
if zdim is None: | |
network += [ | |
nn.Conv2d(nf*8, cout, kernel_size=4, stride=1, padding=0, bias=False), # 4x4 -> 1x1 | |
] | |
else: | |
network += [ | |
nn.Conv2d(nf*8, zdim, kernel_size=4, stride=1, padding=0, bias=False), # 4x4 -> 1x1 | |
# nn.ReLU(inplace=True), | |
nn.LeakyReLU(0.2, inplace=True), | |
nn.Conv2d(zdim, cout, kernel_size=1, stride=1, padding=0, bias=False), | |
] | |
if activation is not None: | |
network += [get_activation(activation)] | |
self.network = nn.Sequential(*network) | |
def forward(self, input): | |
return self.network(input).reshape(input.size(0), -1) | |
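# Image encoder with two input branches (RGB image and a DINO feature image) whose
# intermediate feature maps are concatenated and fused by further strided convolutions.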
class EncoderWithDINO(nn.Module): | |
def __init__(self, cin_rgb, cin_dino, cout, in_size=128, zdim=None, nf=64, activation=None): | |
super().__init__() | |
network_rgb_in = [ | |
nn.Conv2d(cin_rgb, nf, kernel_size=4, stride=2, padding=1, bias=False), # 128x128 -> 64x64 | |
nn.GroupNorm(16, nf), | |
# nn.ReLU(inplace=True), | |
nn.LeakyReLU(0.2, inplace=True), | |
nn.Conv2d(nf, nf*2, kernel_size=4, stride=2, padding=1, bias=False), # 64x64 -> 32x32 | |
nn.GroupNorm(16*2, nf*2), | |
# nn.ReLU(inplace=True), | |
nn.LeakyReLU(0.2, inplace=True), | |
nn.Conv2d(nf*2, nf*4, kernel_size=4, stride=2, padding=1, bias=False), # 32x32 -> 16x16 | |
nn.GroupNorm(16*4, nf*4), | |
# nn.ReLU(inplace=True), | |
nn.LeakyReLU(0.2, inplace=True), | |
] | |
self.network_rgb_in = nn.Sequential(*network_rgb_in) | |
network_dino_in = [ | |
nn.Conv2d(cin_dino, nf, kernel_size=4, stride=2, padding=1, bias=False), # 128x128 -> 64x64 | |
nn.GroupNorm(16, nf), | |
# nn.ReLU(inplace=True), | |
nn.LeakyReLU(0.2, inplace=True), | |
nn.Conv2d(nf, nf*2, kernel_size=4, stride=2, padding=1, bias=False), # 64x64 -> 32x32 | |
nn.GroupNorm(16*2, nf*2), | |
# nn.ReLU(inplace=True), | |
nn.LeakyReLU(0.2, inplace=True), | |
nn.Conv2d(nf*2, nf*4, kernel_size=4, stride=2, padding=1, bias=False), # 32x32 -> 16x16 | |
nn.GroupNorm(16*4, nf*4), | |
# nn.ReLU(inplace=True), | |
nn.LeakyReLU(0.2, inplace=True), | |
] | |
self.network_dino_in = nn.Sequential(*network_dino_in) | |
network_fusion = [ | |
nn.Conv2d(nf*4*2, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 16x16 -> 8x8 | |
# nn.GroupNorm(16*8, nf*8), | |
# nn.ReLU(inplace=True), | |
nn.LeakyReLU(0.2, inplace=True), | |
] | |
add_downsample = int(np.log2(in_size//128)) | |
if add_downsample > 0: | |
for _ in range(add_downsample): | |
network_fusion += [ | |
nn.Conv2d(nf*8, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 16x16 -> 8x8 | |
# nn.GroupNorm(16*8, nf*8), | |
# nn.ReLU(inplace=True), | |
nn.LeakyReLU(0.2, inplace=True), | |
] | |
network_fusion += [ | |
nn.Conv2d(nf*8, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 8x8 -> 4x4 | |
nn.LeakyReLU(0.2, inplace=True), | |
] | |
if zdim is None: | |
network_fusion += [ | |
nn.Conv2d(nf*8, cout, kernel_size=4, stride=1, padding=0, bias=False), # 4x4 -> 1x1 | |
] | |
else: | |
network_fusion += [ | |
nn.Conv2d(nf*8, zdim, kernel_size=4, stride=1, padding=0, bias=False), # 4x4 -> 1x1 | |
# nn.ReLU(inplace=True), | |
nn.LeakyReLU(0.2, inplace=True), | |
nn.Conv2d(zdim, cout, kernel_size=1, stride=1, padding=0, bias=False), | |
] | |
if activation is not None: | |
network_fusion += [get_activation(activation)] | |
self.network_fusion = nn.Sequential(*network_fusion) | |
def forward(self, rgb_image, dino_image): | |
rgb_feat = self.network_rgb_in(rgb_image) | |
dino_feat = self.network_dino_in(dino_image) | |
out = self.network_fusion(torch.cat([rgb_feat, dino_feat], dim=1)) | |
return out.reshape(rgb_image.size(0), -1) | |
class Encoder32(nn.Module): | |
def __init__(self, cin, cout, nf=256, activation=None): | |
super().__init__() | |
network = [ | |
nn.Conv2d(cin, nf, kernel_size=4, stride=2, padding=1, bias=False), # 32x32 -> 16x16 | |
nn.GroupNorm(nf//4, nf), | |
nn.LeakyReLU(0.2, inplace=True), | |
nn.Conv2d(nf, nf, kernel_size=4, stride=2, padding=1, bias=False), # 16x16 -> 8x8 | |
nn.GroupNorm(nf//4, nf), | |
nn.LeakyReLU(0.2, inplace=True), | |
nn.Conv2d(nf, nf, kernel_size=4, stride=2, padding=1, bias=False), # 8x8 -> 4x4 | |
nn.GroupNorm(nf//4, nf), | |
nn.LeakyReLU(0.2, inplace=True), | |
nn.Conv2d(nf, cout, kernel_size=4, stride=1, padding=0, bias=False), # 4x4 -> 1x1 | |
] | |
if activation is not None: | |
network += [get_activation(activation)] | |
self.network = nn.Sequential(*network) | |
def forward(self, input): | |
return self.network(input).reshape(input.size(0), -1) | |
class MLP(nn.Module): | |
def __init__(self, cin, cout, num_layers, nf=256, dropout=0, activation=None, inner_act='relu', linear_bias=False): | |
super().__init__() | |
assert num_layers >= 1 | |
layer_act = get_activation(inner_act) | |
if num_layers == 1: | |
network = [nn.Linear(cin, cout, bias=linear_bias)] | |
else: | |
# network = [nn.Linear(cin, nf, bias=False)] | |
# for _ in range(num_layers-2): | |
# network += [ | |
# nn.ReLU(inplace=True), | |
# nn.Linear(nf, nf, bias=False)] | |
# if dropout: | |
# network += [nn.Dropout(dropout)] | |
# network += [ | |
# nn.ReLU(inplace=True), | |
# nn.Linear(nf, cout, bias=False)] | |
network = [nn.Linear(cin, nf, bias=linear_bias)] | |
for _ in range(num_layers-2): | |
network += [ | |
layer_act, | |
nn.Linear(nf, nf, bias=linear_bias)] | |
if dropout: | |
network += [nn.Dropout(dropout)] | |
network += [ | |
layer_act, | |
nn.Linear(nf, cout, bias=linear_bias)] | |
if activation is not None: | |
network += [get_activation(activation)] | |
self.network = nn.Sequential(*network) | |
def forward(self, input): | |
return self.network(input) | |
class Embedding(nn.Module): | |
def __init__(self, cin, cout, zdim=128, nf=64, activation=None): | |
super().__init__() | |
network = [ | |
nn.Linear(cin, nf, bias=False), | |
nn.ReLU(inplace=True), | |
nn.Linear(nf, zdim, bias=False), | |
nn.ReLU(inplace=True), | |
nn.Linear(zdim, cout, bias=False)] | |
if activation is not None: | |
network += [get_activation(activation)] | |
self.network = nn.Sequential(*network) | |
def forward(self, input): | |
return self.network(input.reshape(input.size(0), -1)).reshape(input.size(0), -1) | |
class PerceptualLoss(nn.Module): | |
def __init__(self, requires_grad=False): | |
super(PerceptualLoss, self).__init__() | |
mean_rgb = torch.FloatTensor([0.485, 0.456, 0.406]) | |
std_rgb = torch.FloatTensor([0.229, 0.224, 0.225]) | |
self.register_buffer('mean_rgb', mean_rgb) | |
self.register_buffer('std_rgb', std_rgb) | |
vgg_pretrained_features = torchvision.models.vgg16(pretrained=True).features | |
self.slice1 = nn.Sequential() | |
self.slice2 = nn.Sequential() | |
self.slice3 = nn.Sequential() | |
self.slice4 = nn.Sequential() | |
for x in range(4): | |
self.slice1.add_module(str(x), vgg_pretrained_features[x]) | |
for x in range(4, 9): | |
self.slice2.add_module(str(x), vgg_pretrained_features[x]) | |
for x in range(9, 16): | |
self.slice3.add_module(str(x), vgg_pretrained_features[x]) | |
for x in range(16, 23): | |
self.slice4.add_module(str(x), vgg_pretrained_features[x]) | |
if not requires_grad: | |
for param in self.parameters(): | |
param.requires_grad = False | |
def normalize(self, x): | |
out = x/2 + 0.5 | |
out = (out - self.mean_rgb.view(1,3,1,1)) / self.std_rgb.view(1,3,1,1) | |
return out | |
def __call__(self, im1, im2, mask=None, conf_sigma=None): | |
im = torch.cat([im1,im2], 0) | |
im = self.normalize(im) # normalize input | |
## compute features | |
feats = [] | |
f = self.slice1(im) | |
feats += [torch.chunk(f, 2, dim=0)] | |
f = self.slice2(f) | |
feats += [torch.chunk(f, 2, dim=0)] | |
f = self.slice3(f) | |
feats += [torch.chunk(f, 2, dim=0)] | |
f = self.slice4(f) | |
feats += [torch.chunk(f, 2, dim=0)] | |
losses = [] | |
for f1, f2 in feats[2:3]: # use relu3_3 features only | |
loss = (f1-f2)**2 | |
if conf_sigma is not None: | |
                loss = loss / (2 * conf_sigma**2 + EPS) + (conf_sigma + EPS).log()
if mask is not None: | |
b, c, h, w = loss.shape | |
_, _, hm, wm = mask.shape | |
sh, sw = hm//h, wm//w | |
mask0 = nn.functional.avg_pool2d(mask, kernel_size=(sh,sw), stride=(sh,sw)).expand_as(loss) | |
loss = (loss * mask0).sum() / mask0.sum() | |
else: | |
loss = loss.mean() | |
losses += [loss] | |
return sum(losses) | |
## from: https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py | |
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): | |
"""3x3 convolution with padding""" | |
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, | |
padding=dilation, groups=groups, bias=False, dilation=dilation) | |
def conv1x1(in_planes, out_planes, stride=1): | |
"""1x1 convolution""" | |
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) | |
class BasicBlock(nn.Module): | |
expansion = 1 | |
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, | |
base_width=64, dilation=1, norm_layer=None): | |
super(BasicBlock, self).__init__() | |
if groups != 1 or base_width != 64: | |
raise ValueError('BasicBlock only supports groups=1 and base_width=64') | |
if dilation > 1: | |
raise NotImplementedError("Dilation > 1 not supported in BasicBlock") | |
# Both self.conv1 and self.downsample layers downsample the input when stride != 1 | |
self.conv1 = conv3x3(inplanes, planes, stride) | |
self.relu = nn.ReLU(inplace=True) | |
self.conv2 = conv3x3(planes, planes) | |
self.norm_layer = norm_layer | |
if norm_layer is not None: | |
self.bn1 = norm_layer(planes) | |
self.bn2 = norm_layer(planes) | |
if inplanes != planes: | |
self.downsample = nn.Sequential( | |
conv1x1(inplanes, planes, stride), | |
norm_layer(planes), | |
) | |
else: | |
self.downsample = None | |
self.stride = stride | |
def forward(self, x): | |
identity = x | |
out = self.conv1(x) | |
if self.norm_layer is not None: | |
out = self.bn1(out) | |
out = self.relu(out) | |
out = self.conv2(out) | |
if self.norm_layer is not None: | |
out = self.bn2(out) | |
if self.downsample is not None: | |
identity = self.downsample(x) | |
out += identity | |
out = self.relu(out) | |
return out | |
class ResEncoder(nn.Module): | |
def __init__(self, cin, cout, in_size=128, zdim=None, nf=64, activation=None): | |
super().__init__() | |
network = [ | |
nn.Conv2d(cin, nf, kernel_size=4, stride=2, padding=1, bias=False), # 128x128 -> 64x64 | |
# nn.GroupNorm(16, nf), | |
# nn.ReLU(inplace=True), | |
nn.LeakyReLU(0.2, inplace=True), | |
nn.Conv2d(nf, nf*2, kernel_size=4, stride=2, padding=1, bias=False), # 64x64 -> 32x32 | |
# nn.GroupNorm(16*2, nf*2), | |
# nn.ReLU(inplace=True), | |
nn.LeakyReLU(0.2, inplace=True), | |
BasicBlock(nf*2, nf*2, norm_layer=None), | |
BasicBlock(nf*2, nf*2, norm_layer=None), | |
nn.Conv2d(nf*2, nf*4, kernel_size=4, stride=2, padding=1, bias=False), # 32x32 -> 16x16 | |
# nn.GroupNorm(16*4, nf*4), | |
# nn.ReLU(inplace=True), | |
nn.LeakyReLU(0.2, inplace=True), | |
BasicBlock(nf*4, nf*4, norm_layer=None), | |
BasicBlock(nf*4, nf*4, norm_layer=None), | |
nn.Conv2d(nf*4, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 16x16 -> 8x8 | |
# nn.ReLU(inplace=True), | |
nn.LeakyReLU(0.2, inplace=True), | |
BasicBlock(nf*8, nf*8, norm_layer=None), | |
BasicBlock(nf*8, nf*8, norm_layer=None), | |
] | |
add_downsample = int(np.log2(in_size//64)) | |
if add_downsample > 0: | |
for _ in range(add_downsample): | |
network += [ | |
nn.Conv2d(nf*8, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 8x8 -> 4x4 | |
# nn.ReLU(inplace=True), | |
nn.LeakyReLU(0.2, inplace=True), | |
BasicBlock(nf*8, nf*8, norm_layer=None), | |
BasicBlock(nf*8, nf*8, norm_layer=None), | |
] | |
if zdim is None: | |
network += [ | |
nn.Conv2d(nf*8, cout, kernel_size=4, stride=1, padding=0, bias=False), # 4x4 -> 1x1 | |
] | |
else: | |
network += [ | |
nn.Conv2d(nf*8, zdim, kernel_size=4, stride=1, padding=0, bias=False), # 4x4 -> 1x1 | |
# nn.ReLU(inplace=True), | |
nn.LeakyReLU(0.2, inplace=True), | |
nn.Conv2d(zdim, cout, kernel_size=1, stride=1, padding=0, bias=False), | |
] | |
if activation is not None: | |
network += [get_activation(activation)] | |
self.network = nn.Sequential(*network) | |
def forward(self, input): | |
return self.network(input).reshape(input.size(0), -1) | |
class Attention(nn.Module): | |
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): | |
super().__init__() | |
self.num_heads = num_heads | |
head_dim = dim // num_heads | |
self.scale = qk_scale or head_dim ** -0.5 | |
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) | |
self.attn_drop = nn.Dropout(attn_drop) | |
self.proj = nn.Linear(dim, dim) | |
self.proj_drop = nn.Dropout(proj_drop) | |
def forward(self, x): | |
B, N, C = x.shape | |
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) | |
q, k, v = qkv[0], qkv[1], qkv[2] | |
attn = (q @ k.transpose(-2, -1)) * self.scale | |
attn = attn.softmax(dim=-1) | |
attn = self.attn_drop(attn) | |
x = (attn @ v).transpose(1, 2).reshape(B, N, C) | |
x = self.proj(x) | |
x = self.proj_drop(x) | |
return x, attn | |
class ViTEncoder(nn.Module): | |
def __init__(self, cout, which_vit='dino_vits8', pretrained=False, frozen=False, in_size=256, final_layer_type='none', root='/root'): | |
super().__init__() | |
if misc.is_main_process(): | |
force_reload = not os.path.exists(os.path.join(root, ".cache/torch/hub/checkpoints/")) | |
else: | |
force_reload = False | |
if "dinov2" in which_vit: | |
self.ViT = torch.hub.load('facebookresearch/dinov2:main', which_vit, pretrained=pretrained, force_reload=force_reload) | |
else: | |
self.ViT = torch.hub.load('facebookresearch/dino:main', which_vit, pretrained=pretrained, force_reload=force_reload) | |
if frozen: | |
for p in self.ViT.parameters(): | |
p.requires_grad = False | |
if which_vit == 'dino_vits8': | |
self.vit_feat_dim = 384 | |
self.patch_size = 8 | |
elif which_vit == 'dinov2_vits14': | |
self.vit_feat_dim = 384 | |
self.patch_size = 14 | |
elif which_vit == 'dino_vitb8': | |
self.vit_feat_dim = 768 | |
self.patch_size = 8 | |
self._feats = [] | |
self.hook_handlers = [] | |
if final_layer_type == 'none': | |
pass | |
elif final_layer_type == 'conv': | |
self.final_layer_patch_out = Encoder32(self.vit_feat_dim, cout, nf=256, activation=None) | |
self.final_layer_patch_key = Encoder32(self.vit_feat_dim, cout, nf=256, activation=None) | |
elif final_layer_type == 'attention': | |
raise NotImplementedError | |
self.final_layer = Attention( | |
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) | |
self.fc = nn.Linear(self.vit_feat_dim, cout) | |
else: | |
raise NotImplementedError | |
self.final_layer_type = final_layer_type | |
def _get_hook(self, facet: str): | |
""" | |
generate a hook method for a specific block and facet. | |
""" | |
if facet in ['attn', 'token']: | |
def _hook(model, input, output): | |
self._feats.append(output) | |
return _hook | |
if facet == 'query': | |
facet_idx = 0 | |
elif facet == 'key': | |
facet_idx = 1 | |
elif facet == 'value': | |
facet_idx = 2 | |
else: | |
raise TypeError(f"{facet} is not a supported facet.") | |
def _inner_hook(module, input, output): | |
input = input[0] | |
B, N, C = input.shape | |
qkv = module.qkv(input).reshape(B, N, 3, module.num_heads, C // module.num_heads).permute(2, 0, 3, 1, 4) | |
self._feats.append(qkv[facet_idx]) #Bxhxtxd | |
return _inner_hook | |
def _register_hooks(self, layers: List[int], facet: str) -> None: | |
""" | |
register hook to extract features. | |
:param layers: layers from which to extract features. | |
:param facet: facet to extract. One of the following options: ['key' | 'query' | 'value' | 'token' | 'attn'] | |
""" | |
for block_idx, block in enumerate(self.ViT.blocks): | |
if block_idx in layers: | |
if facet == 'token': | |
self.hook_handlers.append(block.register_forward_hook(self._get_hook(facet))) | |
elif facet == 'attn': | |
self.hook_handlers.append(block.attn.attn_drop.register_forward_hook(self._get_hook(facet))) | |
elif facet in ['key', 'query', 'value']: | |
self.hook_handlers.append(block.attn.register_forward_hook(self._get_hook(facet))) | |
else: | |
raise TypeError(f"{facet} is not a supported facet.") | |
def _unregister_hooks(self) -> None: | |
""" | |
unregisters the hooks. should be called after feature extraction. | |
""" | |
for handle in self.hook_handlers: | |
handle.remove() | |
self.hook_handlers = [] | |
def forward(self, x, return_patches=False): | |
b, c, h, w = x.shape | |
self._feats = [] | |
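        # register a forward hook on the last transformer block (index 11) to capture its
        # attention 'key' features alongside the normal token outputs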
self._register_hooks([11], 'key') | |
#self._register_hooks([11], 'token') | |
x = self.ViT.prepare_tokens(x) | |
#x = self.ViT.prepare_tokens_with_masks(x) | |
for blk in self.ViT.blocks: | |
x = blk(x) | |
out = self.ViT.norm(x) | |
self._unregister_hooks() | |
ph, pw = h // self.patch_size, w // self.patch_size | |
patch_out = out[:, 1:] # first is class token | |
patch_out = patch_out.reshape(b, ph, pw, self.vit_feat_dim).permute(0, 3, 1, 2) | |
patch_key = self._feats[0][:,:,1:] # B, num_heads, num_patches, dim | |
patch_key = patch_key.permute(0, 1, 3, 2).reshape(b, self.vit_feat_dim, ph, pw) | |
if self.final_layer_type == 'none': | |
global_feat_out = out[:, 0].reshape(b, -1) # first is class token | |
global_feat_key = self._feats[0][:, :, 0].reshape(b, -1) # first is class token | |
elif self.final_layer_type == 'conv': | |
global_feat_out = self.final_layer_patch_out(patch_out).view(b, -1) | |
global_feat_key = self.final_layer_patch_key(patch_key).view(b, -1) | |
elif self.final_layer_type == 'attention': | |
raise NotImplementedError | |
else: | |
raise NotImplementedError | |
if not return_patches: | |
patch_out = patch_key = None | |
return global_feat_out, global_feat_key, patch_out, patch_key | |
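# Predicts per-bone articulation (a rotation as three Euler angles) from a feature
# vector and the bone position encoding, using either a plain MLP or a small
# transformer ('attention') backbone.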
class ArticulationNetwork(nn.Module): | |
def __init__(self, net_type, feat_dim, pos_dim, num_layers, nf, n_harmonic_functions=0, omega0=1, activation=None, enable_articulation_idadd=False): | |
super().__init__() | |
if n_harmonic_functions > 0: | |
self.posenc = HarmonicEmbedding(n_harmonic_functions=n_harmonic_functions, omega0=omega0) | |
pos_dim = pos_dim * (n_harmonic_functions * 2 + 1) | |
else: | |
self.posenc = None | |
pos_dim = 4 | |
cout = 3 | |
if net_type == 'mlp': | |
self.network = MLP( | |
feat_dim + pos_dim, # + bone xyz pos and index | |
cout, # We represent the rotation of each bone by its Euler angles ψ, θ, and φ | |
num_layers, | |
nf=nf, | |
dropout=0, | |
activation=activation | |
) | |
elif net_type == 'attention': | |
self.in_layer = nn.Sequential( | |
nn.Linear(feat_dim + pos_dim, nf), | |
nn.GELU(), | |
nn.LayerNorm(nf), | |
) | |
self.blocks = nn.ModuleList([ | |
Block( | |
dim=nf, num_heads=8, mlp_ratio=2., qkv_bias=False, qk_scale=None, | |
drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm) | |
for i in range(num_layers)]) | |
out_layer = [nn.Linear(nf, cout)] | |
if activation: | |
out_layer += [get_activation(activation)] | |
self.out_layer = nn.Sequential(*out_layer) | |
else: | |
raise NotImplementedError | |
self.net_type = net_type | |
self.enable_articulation_idadd = enable_articulation_idadd | |
def forward(self, x, pos): | |
pos_inp = pos | |
if self.posenc is not None: | |
pos = torch.cat([pos, self.posenc(pos)], dim=-1) | |
x = torch.cat([x, pos], dim=-1) | |
if self.enable_articulation_idadd: | |
articulation_id = pos_inp[..., -1:] | |
x = x + articulation_id | |
if self.net_type == 'mlp': | |
out = self.network(x) | |
elif self.net_type == 'attention': | |
x = self.in_layer(x) | |
for blk in self.blocks: | |
x = blk(x) | |
out = self.out_layer(x) | |
else: | |
raise NotImplementedError | |
return out | |
## Attention block from ViT (https://github.com/facebookresearch/dino/blob/main/vision_transformer.py) | |
class Attention(nn.Module): | |
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): | |
super().__init__() | |
self.num_heads = num_heads | |
head_dim = dim // num_heads | |
self.scale = qk_scale or head_dim ** -0.5 | |
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) | |
self.attn_drop = nn.Dropout(attn_drop) | |
self.proj = nn.Linear(dim, dim) | |
self.proj_drop = nn.Dropout(proj_drop) | |
def forward(self, x): | |
B, N, C = x.shape | |
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) | |
q, k, v = qkv[0], qkv[1], qkv[2] | |
attn = (q @ k.transpose(-2, -1)) * self.scale | |
attn = attn.softmax(dim=-1) | |
attn = self.attn_drop(attn) | |
x = (attn @ v).transpose(1, 2).reshape(B, N, C) | |
x = self.proj(x) | |
x = self.proj_drop(x) | |
return x, attn | |
class Mlp(nn.Module): | |
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): | |
super().__init__() | |
out_features = out_features or in_features | |
hidden_features = hidden_features or in_features | |
self.fc1 = nn.Linear(in_features, hidden_features) | |
self.act = act_layer() | |
self.fc2 = nn.Linear(hidden_features, out_features) | |
self.drop = nn.Dropout(drop) | |
def forward(self, x): | |
x = self.fc1(x) | |
x = self.act(x) | |
x = self.drop(x) | |
x = self.fc2(x) | |
x = self.drop(x) | |
return x | |
class DropPath(nn.Module): | |
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). | |
""" | |
def __init__(self, drop_prob=None): | |
super(DropPath, self).__init__() | |
self.drop_prob = drop_prob | |
def forward(self, x): | |
return drop_path(x, self.drop_prob, self.training) | |
class Block(nn.Module): | |
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., | |
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): | |
super().__init__() | |
self.norm1 = norm_layer(dim) | |
self.attn = Attention( | |
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) | |
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() | |
self.norm2 = norm_layer(dim) | |
mlp_hidden_dim = int(dim * mlp_ratio) | |
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) | |
def forward(self, x, return_attention=False): | |
y, attn = self.attn(self.norm1(x)) | |
if return_attention: | |
return attn | |
x = x + self.drop_path(y) | |
x = x + self.drop_path(self.mlp(self.norm2(x))) | |
return x | |
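# Cross-attention readout: query positions (optionally harmonically embedded) attend
# over ViT patch features to produce a per-query feature vector.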
class FeatureAttention(nn.Module): | |
def __init__(self, vit_type, pos_dim, embedder_freq=0, zdim=128, img_size=256, activation=None): | |
super().__init__() | |
self.zdim = zdim | |
if embedder_freq > 0: | |
self.posenc = HarmonicEmbedding(n_harmonic_functions=embedder_freq, omega0=1) | |
pos_dim = pos_dim * (embedder_freq * 2 + 1) | |
else: | |
self.posenc = None | |
self.pos_dim = pos_dim | |
if vit_type == 'dino_vits8': | |
self.vit_feat_dim = 384 | |
patch_size = 8 | |
        elif vit_type == 'dinov2_vits14':
            self.vit_feat_dim = 384
            patch_size = 14
elif vit_type == 'dino_vitb8': | |
self.vit_feat_dim = 768 | |
patch_size = 8 | |
else: | |
raise NotImplementedError | |
self.num_patches_per_dim = img_size // patch_size | |
self.kv = nn.Sequential( | |
nn.Linear(self.vit_feat_dim, zdim), | |
nn.ReLU(inplace=True), | |
nn.LayerNorm(zdim), | |
nn.Linear(zdim, zdim*2), | |
) | |
self.q = nn.Sequential( | |
nn.Linear(pos_dim, zdim), | |
nn.ReLU(inplace=True), | |
nn.LayerNorm(zdim), | |
nn.Linear(zdim, zdim), | |
) | |
final_mlp = [ | |
nn.Linear(zdim, zdim), | |
nn.ReLU(inplace=True), | |
nn.LayerNorm(zdim), | |
nn.Linear(zdim, self.vit_feat_dim) | |
] | |
if activation is not None: | |
final_mlp += [get_activation(activation)] | |
self.final_ln = nn.Sequential(*final_mlp) | |
def forward(self, x, feat): | |
_, vit_feat_dim, ph, pw = feat.shape | |
assert ph == pw and ph == self.num_patches_per_dim and vit_feat_dim == self.vit_feat_dim | |
if self.posenc is not None: | |
x = torch.cat([x, self.posenc(x)], dim=-1) | |
bxf, k, c = x.shape | |
assert c == self.pos_dim | |
query = self.q(x) | |
feat_in = feat.view(bxf, vit_feat_dim, ph*pw).permute(0, 2, 1) # N, K, C | |
k, v = self.kv(feat_in).chunk(2, dim=-1) | |
attn = torch.einsum('bnd,bpd->bnp', query, k).softmax(dim=-1) | |
out = torch.einsum('bnp,bpd->bnd', attn, v) | |
out = self.final_ln(out) | |
return out | |