|
import torch |
|
import io, os, logging, urllib.parse
|
import yaml |
|
import trimesh |
|
import imageio |
|
import numbers |
|
import math |
|
import numpy as np |
|
from collections import OrderedDict |
|
from plyfile import PlyData |
|
from torch import nn |
|
from torch.nn import functional as F |
|
from torch.utils import model_zoo |
|
from skimage import measure, img_as_float32 |
|
from igl import adjacency_matrix, connected_components
# PyTorch3D is required by calc_inters_points / mesh_rasterization below
from pytorch3d.renderer import PerspectiveCameras
from pytorch3d.renderer.mesh.rasterize_meshes import rasterize_meshes
from pytorch3d.structures import Meshes
|
|
|
|
|
|
|
|
|
def fftfreqs(res, dtype=torch.float32, exact=True): |
|
""" |
|
Helper function to return frequency tensors |
|
:param res: n_dims int tuple of number of frequency modes |
|
:return: |
|
""" |
|
|
|
n_dims = len(res) |
|
freqs = [] |
|
for dim in range(n_dims - 1): |
|
r_ = res[dim] |
|
freq = np.fft.fftfreq(r_, d=1/r_) |
|
freqs.append(torch.tensor(freq, dtype=dtype)) |
|
r_ = res[-1] |
|
if exact: |
|
freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1/r_), dtype=dtype)) |
|
else: |
|
freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1/r_)[:-1], dtype=dtype)) |
|
    omega = torch.meshgrid(freqs)
    omega = torch.stack(omega, dim=-1)
|
|
|
return omega |
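# Example (a minimal sketch, not part of the original code): frequency grid for
# a 32^3 volume. With exact=True the last dimension keeps all res//2 + 1 rfft modes.
#
#   omega = fftfreqs((32, 32, 32))   # shape (32, 32, 17, 3)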
|
|
|
def img(x, deg=1): |
|
""" |
|
multiply tensor x by i ** deg |
|
""" |
|
deg %= 4 |
|
if deg == 0: |
|
res = x |
|
elif deg == 1: |
|
res = x[..., [1, 0]] |
|
res[..., 0] = -res[..., 0] |
|
elif deg == 2: |
|
res = -x |
|
elif deg == 3: |
|
res = x[..., [1, 0]] |
|
res[..., 1] = -res[..., 1] |
|
return res |
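# Example (illustrative, not part of the original code): with the last dim
# holding (real, imag), multiplying 1 + 2i by i gives -2 + 1i.
#
#   x = torch.tensor([[1., 2.]])
#   img(x, deg=1)   # tensor([[-2., 1.]])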
|
|
|
def spec_gaussian_filter(res, sig):
    '''
    Gaussian low-pass filter defined over the Fourier modes of a grid.
    '''
    omega = fftfreqs(res, dtype=torch.float64)
    dis = torch.sqrt(torch.sum(omega ** 2, dim=-1))  # radial frequency magnitude
    filter_ = torch.exp(-0.5*((sig*2*dis/res[0])**2)).unsqueeze(-1).unsqueeze(-1)
    filter_.requires_grad = False
|
|
|
return filter_ |
|
|
|
def grid_interp(grid, pts, batched=True): |
|
""" |
|
:param grid: tensor of shape (batch, *size, in_features) |
|
:param pts: tensor of shape (batch, num_points, dim) within range (0, 1) |
|
:return values at query points |
|
""" |
|
if not batched: |
|
grid = grid.unsqueeze(0) |
|
pts = pts.unsqueeze(0) |
|
dim = pts.shape[-1] |
|
bs = grid.shape[0] |
|
size = torch.tensor(grid.shape[1:-1]).to(grid.device).type(pts.dtype) |
|
cubesize = 1.0 / size |
|
|
|
    ind0 = torch.floor(pts / cubesize).long()  # lower-corner cell index (B, N, dim)
    ind1 = torch.fmod(torch.ceil(pts / cubesize), size).long()  # upper corner, wrapped at the boundary
    ind01 = torch.stack((ind0, ind1), dim=0)  # (2, B, N, dim)
    tmp = torch.tensor([0, 1], dtype=torch.long)
    com_ = torch.stack(torch.meshgrid(tuple([tmp] * dim)), dim=-1).view(-1, dim)  # the 2**dim corner combinations
    dim_ = torch.arange(dim).repeat(com_.shape[0], 1)  # (2**dim, dim)
    ind_ = ind01[com_, ..., dim_]  # (2**dim, dim, B, N)
    ind_n = ind_.permute(2, 3, 0, 1)  # (B, N, 2**dim, dim)
    ind_b = torch.arange(bs).expand(ind_n.shape[1], ind_n.shape[2], bs).permute(2, 0, 1)  # batch index per corner
|
|
|
if dim == 2: |
|
lat = grid.clone()[ind_b, ind_n[..., 0], ind_n[..., 1]] |
|
else: |
|
lat = grid.clone()[ind_b, ind_n[..., 0], ind_n[..., 1], ind_n[..., 2]] |
|
|
|
|
|
    # trilinear weight of each corner = product of normalized distances
    # from the query point to the *opposite* corner
    xyz0 = ind0.type(cubesize.dtype) * cubesize          # lower-corner coordinates
    xyz1 = (ind0.type(cubesize.dtype) + 1) * cubesize    # upper-corner coordinates
    xyz01 = torch.stack((xyz0, xyz1), dim=0)             # (2, B, N, dim)
    pos_ = xyz01[1 - com_, ..., dim_].permute(2, 3, 0, 1)  # opposite-corner coordinates
    pos_ = pos_.type(pts.dtype)
    dxyz_ = torch.abs(pts.unsqueeze(-2) - pos_) / cubesize
    weights = torch.prod(dxyz_, dim=-1, keepdim=False)   # (B, N, 2**dim)
|
query_values = torch.sum(lat * weights.unsqueeze(-1), dim=-2) |
|
if not batched: |
|
query_values = query_values.squeeze(0) |
|
|
|
return query_values |
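# Example usage (a sketch with made-up shapes): interpolate an 8-feature
# 32^3 grid at 100 random query points in (0, 1)^3.
#
#   grid = torch.rand(1, 32, 32, 32, 8)
#   pts = torch.rand(1, 100, 3)
#   feats = grid_interp(grid, pts)   # shape (1, 100, 8)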
|
|
|
def scatter_to_grid(inds, vals, size): |
|
""" |
|
Scatter update values into empty tensor of size size. |
|
:param inds: (#values, dims) |
|
:param vals: (#values) |
|
:param size: tuple for size. len(size)=dims |
|
""" |
|
dims = inds.shape[1] |
|
assert(inds.shape[0] == vals.shape[0]) |
|
assert(len(size) == dims) |
|
dev = vals.device |
|
|
|
|
|
result = torch.zeros(*size, device=dev).view(-1).type(vals.dtype) |
|
|
|
fac = [np.prod(size[i+1:]) for i in range(len(size)-1)] + [1] |
|
fac = torch.tensor(fac, device=dev).type(inds.dtype) |
|
inds_fold = torch.sum(inds*fac, dim=-1) |
|
result.scatter_add_(0, inds_fold, vals) |
|
result = result.view(*size) |
|
return result |
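# Example (illustrative): duplicate indices accumulate via scatter_add_.
#
#   inds = torch.tensor([[0, 1], [0, 1], [2, 3]])
#   vals = torch.tensor([1., 2., 3.])
#   scatter_to_grid(inds, vals, (4, 4))   # grid[0, 1] == 3., grid[2, 3] == 3.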
|
|
|
def point_rasterize(pts, vals, size): |
|
""" |
|
:param pts: point coords, tensor of shape (batch, num_points, dim) within range (0, 1) |
|
:param vals: point values, tensor of shape (batch, num_points, features) |
|
:param size: len(size)=dim tuple for grid size |
|
:return rasterized values (batch, features, res0, res1, res2) |
|
""" |
|
dim = pts.shape[-1] |
|
assert(pts.shape[:2] == vals.shape[:2]) |
|
assert(pts.shape[2] == dim) |
|
size_list = list(size) |
|
size = torch.tensor(size).to(pts.device).float() |
|
cubesize = 1.0 / size |
|
bs = pts.shape[0] |
|
nf = vals.shape[-1] |
|
npts = pts.shape[1] |
|
dev = pts.device |
|
|
|
    # same corner-index / trilinear-weight computation as in grid_interp
    ind0 = torch.floor(pts / cubesize).long()
|
ind1 = torch.fmod(torch.ceil(pts / cubesize), size).long() |
|
ind01 = torch.stack((ind0, ind1), dim=0) |
|
tmp = torch.tensor([0,1],dtype=torch.long) |
|
com_ = torch.stack(torch.meshgrid(tuple([tmp] * dim)), dim=-1).view(-1, dim) |
|
dim_ = torch.arange(dim).repeat(com_.shape[0], 1) |
|
ind_ = ind01[com_, ..., dim_] |
|
ind_n = ind_.permute(2, 3, 0, 1) |
|
|
|
ind_b = torch.arange(bs, device=dev).expand(ind_n.shape[1], ind_n.shape[2], bs).permute(2, 0, 1) |
|
|
|
|
|
xyz0 = ind0.type(cubesize.dtype) * cubesize |
|
xyz1 = (ind0.type(cubesize.dtype) + 1) * cubesize |
|
xyz01 = torch.stack((xyz0, xyz1), dim=0) |
|
|
pos_ = xyz01[1-com_, ..., dim_].permute(2,3,0,1) |
|
pos_ = pos_.type(pts.dtype) |
|
dxyz_ = torch.abs(pts.unsqueeze(-2) - pos_) / cubesize |
|
weights = torch.prod(dxyz_, dim=-1, keepdim=False) |
|
|
|
    ind_b = ind_b.unsqueeze(-1).unsqueeze(-1)  # (B, N, 2**dim, 1, 1)
    ind_n = ind_n.unsqueeze(-2)                # (B, N, 2**dim, 1, dim)
    ind_f = torch.arange(nf, device=dev).view(1, 1, 1, nf, 1)  # feature index

    # expand all index tensors to a common shape and concatenate into
    # (batch, feature, *spatial) index tuples
    ind_b = ind_b.expand(bs, npts, 2**dim, nf, 1)
    ind_n = ind_n.expand(bs, npts, 2**dim, nf, dim).to(dev)
    ind_f = ind_f.expand(bs, npts, 2**dim, nf, 1)
    inds = torch.cat([ind_b, ind_f, ind_n], dim=-1)

    # weighted contribution of each point to each of its 2**dim corners
    vals = weights.unsqueeze(-1) * vals.unsqueeze(-2)  # (B, N, 2**dim, nf)
|
|
|
    inds = inds.view(-1, dim+2).long()
    vals = vals.reshape(-1)
    tensor_size = [bs, nf] + size_list
    raster = scatter_to_grid(inds, vals, tensor_size)
|
|
|
return raster |
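# Example usage (a sketch with assumed shapes): point_rasterize is the adjoint
# of grid_interp -- it splats point values onto the grid with the same
# trilinear weights.
#
#   pts = torch.rand(1, 100, 3)
#   vals = torch.rand(1, 100, 8)
#   raster = point_rasterize(pts, vals, (32, 32, 32))   # shape (1, 8, 32, 32, 32)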
|
|
|
|
|
|
|
|
|
|
|
|
|
class AverageMeter(object): |
|
"""Computes and stores the average and current value""" |
|
def __init__(self): |
|
self.reset() |
|
|
|
def reset(self): |
|
self.val = 0 |
|
self.n = 0 |
|
self.avg = 0 |
|
self.sum = 0 |
|
self.count = 0 |
|
|
|
def update(self, val, n=1): |
|
self.val = val |
|
self.n = n |
|
self.sum += val * n |
|
self.count += n |
|
self.avg = self.sum / self.count |
|
|
|
    @property
    def valcavg(self):
        # channel-averaged current value; assumes tensor-valued `val` and `n`
        return self.val.sum().item() / (self.n != 0).sum().item()

    @property
    def avgcavg(self):
        # channel-averaged running mean; assumes tensor-valued `avg` and `count`
        return self.avg.sum().item() / (self.count != 0).sum().item()
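# Typical usage in a training loop (illustrative only; `train_step` and
# `batch_size` are hypothetical names):
#
#   meter = AverageMeter()
#   for batch in loader:
#       loss = train_step(batch)
#       meter.update(loss.item(), n=batch_size)
#   print(meter.avg)   # running mean over all samples seen so far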
|
|
|
def load_model_manual(state_dict, model): |
|
new_state_dict = OrderedDict() |
|
is_model_parallel = isinstance(model, torch.nn.DataParallel) |
|
    for k, v in state_dict.items():
        if k.startswith('module.') != is_model_parallel:
            if k.startswith('module.'):
                # checkpoint was saved from a DataParallel model; strip the prefix
                k = k[7:]
            else:
                # model is wrapped in DataParallel; add the prefix
                k = 'module.' + k
        new_state_dict[k] = v

    model.load_state_dict(new_state_dict)
|
|
|
def mc_from_psr(psr_grid, pytorchify=False, real_scale=False, zero_level=0): |
|
    '''
    Run marching cubes on a PSR indicator grid and return the extracted mesh.
    '''
    batch_size = psr_grid.shape[0]
    s = psr_grid.shape[-1]  # grid resolution
    psr_grid_numpy = psr_grid.squeeze().detach().cpu().numpy()

    if batch_size > 1:
        verts, faces, normals = [], [], []
        for i in range(batch_size):
            verts_cur, faces_cur, normals_cur, _ = measure.marching_cubes(psr_grid_numpy[i], level=zero_level)
            verts.append(verts_cur)
            faces.append(faces_cur)
            normals.append(normals_cur)
        # NOTE: stacking assumes all meshes in the batch have the same number of vertices/faces
        verts = np.stack(verts, axis=0)
        faces = np.stack(faces, axis=0)
        normals = np.stack(normals, axis=0)
    else:
        try:
            verts, faces, normals, _ = measure.marching_cubes(psr_grid_numpy, level=zero_level)
        except ValueError:
            # zero_level lies outside the grid's value range; fall back to the default level
            verts, faces, normals, _ = measure.marching_cubes(psr_grid_numpy)
    if real_scale:
        verts = verts / (s - 1)  # scale to range [0, 1]
    else:
        verts = verts / s  # scale to range [0, 1)
|
|
|
if pytorchify: |
|
device = psr_grid.device |
|
verts = torch.Tensor(np.ascontiguousarray(verts)).to(device) |
|
faces = torch.Tensor(np.ascontiguousarray(faces)).to(device) |
|
normals = torch.Tensor(np.ascontiguousarray(-normals)).to(device) |
|
|
|
return verts, faces, normals |
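# Example usage (a minimal sketch; the grid here is random, whereas a real
# PSR grid would come from a Poisson surface reconstruction solver):
#
#   psr_grid = torch.randn(1, 1, 64, 64, 64)
#   verts, faces, normals = mc_from_psr(psr_grid, pytorchify=True)
#   # verts lie in [0, 1)^3 (or [0, 1]^3 with real_scale=True)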
|
|
|
def calc_inters_points(verts, faces, pose, img_size, mask_gt=None): |
|
verts = verts.squeeze() |
|
faces = faces.squeeze() |
|
    pix_to_face, w, mask = mesh_rasterization(verts, faces, pose, img_size)
    if mask_gt is not None:
        # only evaluate within the intersection of the rendered and GT masks
        mask = mask & mask_gt

    # find the 3D point where each masked pixel ray intersects the mesh
    # (an alternative would be to unproject NDC + depth via the camera,
    # which is faster but less accurate)
    w_masked = w[mask]
    f_p = faces[pix_to_face[mask]].long()  # corresponding face for each pixel
    # vertices of those faces
    v_a, v_b, v_c = verts[f_p[..., 0]], verts[f_p[..., 1]], verts[f_p[..., 2]]

    # barycentric interpolation of the face vertices
    p_inters = w_masked[..., 0, None] * v_a + \
               w_masked[..., 1, None] * v_b + \
               w_masked[..., 2, None] * v_c
|
|
|
|
|
    # discard intersection points that fall outside [-1, 1]^3
    # (NOTE: f_p and w_masked are returned unfiltered in this case)
    if (p_inters.max() > 1) | (p_inters.min() < -1):
        mask_bound = (p_inters >= -1) & (p_inters <= 1)
        mask_bound = (mask_bound.sum(dim=-1) == 3)
        mask[mask == True] = mask_bound
        p_inters = p_inters[mask_bound]
        print('Warning: found and discarded intersection points outside [-1, 1]^3')
|
|
|
return p_inters, mask, f_p, w_masked |
|
|
|
def mesh_rasterization(verts, faces, pose, img_size): |
|
''' |
|
Use PyTorch3D to rasterize the mesh given a camera |
|
''' |
|
    transformed_v = pose.transform_points(verts.detach())
    if isinstance(pose, PerspectiveCameras):
        # perspective projection stores 1/z in the NDC z-coordinate;
        # invert it so the rasterizer's z-buffer holds actual depth
        transformed_v[..., 2] = 1/transformed_v[..., 2]
|
|
|
transformed_mesh = Meshes(verts=[transformed_v], faces=[faces]) |
|
pix_to_face, zbuf, bary_coords, dists = rasterize_meshes( |
|
transformed_mesh, |
|
image_size=img_size, |
|
blur_radius=0, |
|
faces_per_pixel=1, |
|
perspective_correct=False |
|
) |
|
pix_to_face = pix_to_face.reshape(1, -1) |
|
mask = pix_to_face.clone() != -1 |
|
mask = mask.squeeze() |
|
pix_to_face = pix_to_face.squeeze() |
|
w = bary_coords.reshape(-1, 3) |
|
|
|
return pix_to_face, w, mask |
|
|
|
def verts_on_largest_mesh(verts, faces): |
|
    '''
    Return the vertices and faces of the largest connected component.
    verts: Numpy array or Torch.Tensor (N, 3)
    faces: Numpy array (M, 3)
    '''
    if torch.is_tensor(faces):
        verts = verts.squeeze().detach().cpu().numpy()
        faces = faces.squeeze().int().detach().cpu().numpy()

    A = adjacency_matrix(faces)
    num, conn_idx, conn_size = connected_components(A)
    if num == 0:
        v_large, f_large = verts, faces
    else:
        max_idx = conn_size.argmax()  # index of the largest connected component
        # split the mesh and keep the largest component
        # (assumes trimesh.split orders components consistently with igl's labeling)
        mesh_largest = trimesh.Trimesh(verts, faces)
        connected_comp = mesh_largest.split(only_watertight=False)
        mesh_largest = connected_comp[max_idx]
        v_large, f_large = mesh_largest.vertices, mesh_largest.faces
        v_large = v_large.astype(np.float32)
    return v_large, f_large
|
|
|
def update_recursive(dict1, dict2): |
|
    ''' Update two config dictionaries recursively.

    Args:
        dict1 (dict): dictionary to be updated
        dict2 (dict): dictionary whose entries should be used for the update
    '''
|
for k, v in dict2.items(): |
|
if k not in dict1: |
|
dict1[k] = dict() |
|
if isinstance(v, dict): |
|
update_recursive(dict1[k], v) |
|
else: |
|
dict1[k] = v |
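# Example (illustrative): nested keys are merged rather than overwritten.
#
#   cfg = {'train': {'lr': 1e-3}}
#   update_recursive(cfg, {'train': {'batch_size': 4}})
#   # cfg == {'train': {'lr': 1e-3, 'batch_size': 4}}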
|
|
|
def scale2onet(p, scale=1.2): |
|
''' |
|
Scale the point cloud from SAP to ONet range |
|
''' |
|
return (p - 0.5) * scale |
|
|
|
def update_optimizer(inputs, cfg, epoch, model=None, schedule=None): |
|
if model is not None: |
|
if schedule is not None: |
|
optimizer = torch.optim.Adam([ |
|
{"params": model.parameters(), |
|
"lr": schedule[0].get_learning_rate(epoch)}, |
|
{"params": inputs, |
|
"lr": schedule[1].get_learning_rate(epoch)}]) |
|
elif 'lr' in cfg['train']: |
|
optimizer = torch.optim.Adam([ |
|
{"params": model.parameters(), |
|
"lr": float(cfg['train']['lr'])}, |
|
{"params": inputs, |
|
"lr": float(cfg['train']['lr_pcl'])}]) |
|
else: |
|
            raise ValueError('no learning rate found in the config and no schedule given')
|
else: |
|
if schedule is not None: |
|
optimizer = torch.optim.Adam([inputs], lr=schedule[0].get_learning_rate(epoch)) |
|
else: |
|
optimizer = torch.optim.Adam([inputs], lr=float(cfg['train']['lr_pcl'])) |
|
|
|
return optimizer |
|
|
|
|
|
def is_url(url): |
|
scheme = urllib.parse.urlparse(url).scheme |
|
return scheme in ('http', 'https') |
|
|
|
def load_url(url): |
|
'''Load a module dictionary from url. |
|
|
|
Args: |
|
url (str): url to saved model |
|
''' |
|
print(url) |
|
print('=> Loading checkpoint from url...') |
|
state_dict = model_zoo.load_url(url, progress=True) |
|
|
|
return state_dict |
|
|
|
|
|
class GaussianSmoothing(nn.Module): |
|
""" |
|
    Apply gaussian smoothing on a
    1d, 2d or 3d tensor. Filtering is performed separately for each channel
    in the input using a depthwise convolution.
|
Arguments: |
|
channels (int, sequence): Number of channels of the input tensors. Output will have this number of channels as well. |
|
kernel_size (int, sequence): Size of the gaussian kernel. |
|
sigma (float, sequence): Standard deviation of the gaussian kernel. |
|
dim (int, optional): The number of dimensions of the data. |
|
Default value is 2 (spatial). |
|
""" |
|
def __init__(self, channels, kernel_size, sigma, dim=3): |
|
super(GaussianSmoothing, self).__init__() |
|
if isinstance(kernel_size, numbers.Number): |
|
kernel_size = [kernel_size] * dim |
|
if isinstance(sigma, numbers.Number): |
|
sigma = [sigma] * dim |
|
|
|
|
|
|
|
        # the Gaussian kernel is the product of a 1D Gaussian along each dimension
        kernel = 1
|
meshgrids = torch.meshgrid( |
|
[ |
|
torch.arange(size, dtype=torch.float32) |
|
for size in kernel_size |
|
] |
|
) |
|
for size, std, mgrid in zip(kernel_size, sigma, meshgrids): |
|
mean = (size - 1) / 2 |
|
kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \ |
|
torch.exp(-((mgrid - mean) / std) ** 2 / 2) |
|
|
|
|
|
        # normalize so the kernel sums to 1
        kernel = kernel / torch.sum(kernel)
|
|
|
|
|
        # reshape to a depthwise convolution weight of shape (channels, 1, *kernel_size)
        kernel = kernel.view(1, 1, *kernel.size())
        kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
|
|
|
self.register_buffer('weight', kernel) |
|
self.groups = channels |
|
|
|
if dim == 1: |
|
self.conv = F.conv1d |
|
elif dim == 2: |
|
self.conv = F.conv2d |
|
elif dim == 3: |
|
self.conv = F.conv3d |
|
else: |
|
raise RuntimeError( |
|
'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim) |
|
) |
|
|
|
def forward(self, input): |
|
""" |
|
Apply gaussian filter to input. |
|
Arguments: |
|
input (torch.Tensor): Input to apply gaussian filter on. |
|
Returns: |
|
filtered (torch.Tensor): Filtered output. |
|
""" |
|
        return self.conv(input, weight=self.weight, groups=self.groups)  # no padding: output shrinks by kernel_size - 1 per dim
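# Example usage (a hedged sketch): pad by kernel_size // 2 first to keep the
# spatial size, since forward() applies no padding.
#
#   smoothing = GaussianSmoothing(channels=1, kernel_size=5, sigma=1.0, dim=3)
#   x = torch.rand(1, 1, 32, 32, 32)
#   out = smoothing(F.pad(x, (2,) * 6, mode='replicate'))   # shape (1, 1, 32, 32, 32)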
|
|
|
|
|
def get_learning_rate_schedules(schedule_specs): |
|
|
|
schedules = [] |
|
|
|
    for key in schedule_specs.keys():
        schedules.append(StepLearningRateSchedule(
            schedule_specs[key]['initial'],
            schedule_specs[key]['interval'],
            schedule_specs[key]['factor'],
            schedule_specs[key]['final']))
|
return schedules |
|
|
|
class LearningRateSchedule:
    def get_learning_rate(self, epoch):
        raise NotImplementedError
|
class StepLearningRateSchedule(LearningRateSchedule): |
|
def __init__(self, initial, interval, factor, final=1e-6): |
|
self.initial = float(initial) |
|
self.interval = interval |
|
self.factor = factor |
|
self.final = float(final) |
|
|
|
    def get_learning_rate(self, epoch):
        # step decay, clamped below by a hard floor of 5e-6 and by `final`
        lr = np.maximum(self.initial * (self.factor ** (epoch // self.interval)), 5.0e-6)
|
if lr > self.final: |
|
return lr |
|
else: |
|
return self.final |
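# Example (illustrative): halve the learning rate every 100 epochs,
# never going below `final`.
#
#   sched = StepLearningRateSchedule(1e-3, interval=100, factor=0.5, final=1e-5)
#   sched.get_learning_rate(0)     # 1e-3
#   sched.get_learning_rate(250)   # 2.5e-4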
|
|
|
def adjust_learning_rate(lr_schedules, optimizer, epoch): |
|
for i, param_group in enumerate(optimizer.param_groups): |
|
param_group["lr"] = lr_schedules[i].get_learning_rate(epoch) |