diff --git a/.gitignore b/.gitignore new file mode 100755 index 0000000000000000000000000000000000000000..33456016bc4d048dfdda99816d1b326ca6fff0e5 --- /dev/null +++ b/.gitignore @@ -0,0 +1,184 @@ +__pycache__ +data +data/*/ +data/*/* +!data/preprocessing/ +pretrained/*/ +results +neural_renderer +*.zip +unchanged/ +cvpr23_results/ +# slurm.bash +results +results/*/ +results/* +results/*/* +results/dor_checkpoints/* +results/dor_checkpoints/*/* +results/dor_checkpoints/*/*/* + + +.vscode +.vscode/ + +dor_bash_files/ +zzli_bash_files/ +ray_bash_files/ + +config/dor_exp/ +config/zzli_exp/ +config/ray_exp/ + +wandb +wandb/*/ +wandb/*/* +wandb/*/*/* +canon/out/* +canon/out/ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ +/.idea + +# dependencies +# nvdiffrast/ +data/preprocessing/videos/RAFT/ +preprocessing_data/RAFT/ +preprocessing_data/RAFT/* +preprocessing_data/preprocessing/videos/RAFT/ +# debug + + +DINO_v2_check/out_dor +DINO_v2_check/out_dor/* + +eval/*/ +scripts/vis/ +eval/ diff --git a/ckpts/configs.yml b/ckpts/configs.yml new file mode 100644 index 0000000000000000000000000000000000000000..29ed6692eecf06fcbdda83672c10c1dde4ae6e6f --- /dev/null +++ b/ckpts/configs.yml @@ -0,0 +1,354 @@ +amb_diff_max: +- 1.0 +- 1.0 +amb_diff_min: +- 0.0 +- 0.5 +arti_reg_loss_epochs: +- 8 +- 276 +arti_reg_loss_weight: 0.2 +articulation_arch: attention +articulation_epochs: +- 2 +- 276 +articulation_feature_mode: sample+global +articulation_multiplier: 0.1 +attach_legs_to_body_epochs: +- 8 +- 276 +avg_seqshape_epochs: +- 0 +- 0 +avg_texture_epochs: +- 0 +- 0 +background_mode: none +backward_prior: true +bank_mean_dist_loss_weight: 0.0 +batch_size: 6 +best_pose_start_iter: 10000 +blur_mask: false +body_bone_idx_preset: + 0: + - 0 + - 0 + - 0 + - 0 + 500000: + - 0 + - 0 + - 0 + - 0 +body_bones_type: z_minmax_y+ +body_rotate_reg_mode: all-bones +bone_y_thresh: 0.4 +bsdf: diffuse +cam_pos_z_offset: 10 +checkpoint_dir: /viscam/u/zzli/workspace/4DAnimalKingdom_dev/results/paper_exp/same_dino_1109/mb_all_data_1k_artiID_r500k +clip_tex: false +clip_tex_loss_weight: 0.0 +combine_dataset: true +config: config/zzli_exp/same_dino_1109/mb_data1k_artiID_r500k.yml +constrain_legs: false +crop_fov_approx: 25 +data_loader_mode: n_frame +dataset: video +debug_seq: false +deform_epochs: +- 0 +- 276 +deformation_reg_loss_weight: 10.0 +device: cuda:0 +diffusion_albedo_ratio: 0.2 +diffusion_angle_front: 60 +diffusion_angle_overhead: 30 +diffusion_append_prompt_directions: true +diffusion_guidance_scale: 100 +diffusion_light_ambient: 0.5 +diffusion_light_diffuse: 0.8 +diffusion_loss_weight: 0.0001 +diffusion_max_step: 0.6 +diffusion_num_random_cameras: 1 +diffusion_phi_offset: 180 +diffusion_precision: float16 +diffusion_prompt: an elephant +diffusion_radius_range: +- 9 +- 11 +diffusion_random_light: true +diffusion_resolution: 256 +diffusion_shading_ratio: 0.4 +diffusion_theta_range: +- 0 +- 100 +diffusion_uniform_sphere_rate: 1 +dim_of_classes: 128 +dino_feat_im_loss_weight: + 0: 10.0 + 300000: 1.0 +dino_feature_dim: 16 +dino_feature_input: false +dino_feature_recon_dim: 16 +dino_max: 1.0 +dino_min: 0.0 +disable_fewshot: false +disc_gt: false +disc_iv: true +disc_iv_label: Real +disc_reg_mul: 10.0 +discriminator_loss_weight: 1.0 +dmtet_grid: 256 +dmtet_grid_smaller: 256 +dmtet_grid_smaller_epoch: 1 +embed_concat_pts: true +embedder_freq_arti: 8 +embedder_freq_deform: 10 +embedder_freq_dino: 8 +embedder_freq_shape: 8 +embedder_freq_tex: 10 +enable_articulation: true +enable_articulation_bone_threshold: true +enable_articulation_idadd: true +enable_deform: true +enable_disc: true +enable_encoder: true +enable_lighting: true +enable_mask_distribution: true +enable_memory_bank: true +enable_pose: true +enable_prior: true +enable_sds: false +encoder_arch: vit +encoder_frozen: true +encoder_pretrained: true 
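+# encoder_*: a pretrained ViT image encoder, kept frozen (encoder_frozen: true); see which_vit: dino_vits8 below.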
+enhance_back_view: true +enhance_back_view_path: /viscam/u/zzli/workspace/Animal-Data-Engine/data/data_resize_update/segmented_back_view_data +extra_renders: + instance: + - geo_normal + - diffuse + - gray +faces_per_pixel: 10 +few_shot_category_num: -1 +few_shot_class_vector_init: copy +few_shot_data_dir: +- /viscam/u/zzli/workspace/Animal-Data-Engine/data/data_resize_update/few_shot_data_all +- /viscam/projects/articulated/dor/Animal-Data-Engine/data/data_resize_update/train_with_classes_filtered +few_shot_iteration_save: true +few_shot_iteration_save_freq: 2000 +few_shot_lr: 0.0001 +few_shot_optimize: exp +few_shot_optimize_bank: all +few_shot_original_classes_num: 7 +few_shot_resume: true +few_shot_test_category_names: +- caracal +- impala +- ox +- squirrel +- wolf +few_shot_test_category_num: 5 +few_shot_val_image_num: 5 +fix_viz_batch: false +flow_loss_epochs: +- 0 +- 0 +flow_loss_weight: 0.0 +forbid_leg_rotate: true +fov_w: 60 +full_size_h: 1080 +full_size_w: 1920 +gamma: 1e-6 +gan_tex: false +grid_scale: 7 +hidden_size: 256 +in_image_size: 256 +init_sdf: ellipsoid +is_dry_run: false +iter_arti_reg_loss_start: 60000 +iter_articulation_start: 20000 +iter_attach_leg_to_body_start: 60000 +iter_deformation_start: 500000 +iter_leg_rotation_start: 300000 +iter_nozeroy_start: 20000 +jitter_grid: 0.05 +kd_max: +- 1.0 +- 1.0 +- 1.0 +- 1.0 +kd_min: +- 0.0 +- 0.0 +- 0.0 +- 0.0 +keep_num_checkpoint: 1 +ks_max: +- 0.0 +- 0.0 +- 0.0 +ks_min: +- 0.0 +- 0.0 +- 0.0 +latent_dim: 256 +load_dino_cluster: false +load_dino_feature: true +log_freq_images: 501 +log_freq_losses: 50 +log_train_images: true +logit_loss_dino_feat_im_loss_multiplier: + 0: 50.0 + 300000: 500.0 +logit_loss_weight: 1.0 +lookat_init: +- 0.0 +- 0.0 +- 0.0 +lookat_zeroy: true +lr: 6.0e-05 +mask_disc_loss_feat_condition: true +mask_disc_loss_weight: 0.1 +mask_discriminator_iter: +- 80000 +- 300000 +mask_distribution_loss_freq: 1 +mask_distribution_loss_weight: 0.0 +mask_distribution_path: /viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new/mask_distribution +max_arti_angle: 60 +max_trans_xy_range_ratio: 0.5 +max_trans_z_range_ratio: 0.5 +memory_bank_init: copy +memory_bank_size: 60 +memory_bank_topk: 10 +memory_encoder: DINO +memory_retrieve: cos-linear +mesh_edge_length_loss_weight: 0.0 +mesh_normal_consistency_loss_weight: 0.0 +min_seq_len: 1 +nrm_max: +- 1.0 +- 1.0 +- 1.0 +nrm_min: +- -1.0 +- -1.0 +- 0.0 +num_body_bones: 8 +num_epochs: 1375 +num_iterations: 10000000 +num_layers_arti: 4 +num_layers_deform: 5 +num_layers_dino: 5 +num_layers_light: 5 +num_layers_tex: 8 +num_leg_bones: 3 +num_legs: 4 +num_sample_frames: 1 +num_workers: 8 +out_image_size: 256 +perturb_articulation_epochs: +- 0 +- 0 +perturb_normal: false +perturb_sdf: false +pose_arch: encoder_dino_patch_key +pose_entropy_loss_weight: 0.0 +pose_epochs: +- 0 +- 0 +pose_xflip_recon_epochs: +- 0 +- 0 +pose_xflip_reg_loss_weight: 0.0 +prior_condition_choice: mod +prior_lr: 0.0006 +prior_sdf_mode: mlp +pyplot_metrics: false +random_flip_train: true +random_mask_law: random_azimuth +random_sample_train_frames: false +random_sample_val_frames: true +rank: 0 +reg_body_rotate_mult: 0.1 +render_dino_mode: feature_mlp +renderer_spp: 4 +resume: true +resume_prior_optim: true +rgb_loss_weight: 1.0 +rgb_suffix: .png +root_dir: /viscam/u/zzli +rot_all_quad_epochs: +- 0 +- 276 +rot_rand_quad_epochs: +- 0 +- 0 +rot_rep: quadlookat +rot_temp_scalar: 1.0 +run_few_shot: true +run_train: true +save_checkpoint_freq: 1 +save_result_freq: 501 
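+# NOTE: the absolute /viscam/... paths in this config (checkpoint_dir, data directories,
+# mask_distribution_path, enhance_back_view_path) are cluster-specific and must be adapted to the local setup.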
+sdf_bce_reg_loss_min_weight: 0 +sdf_bce_reg_loss_weight: 0 +sdf_gradient_reg_loss_min_weight: 0.1 +sdf_gradient_reg_loss_weight: 0.1 +sdf_inflate_reg_loss_epochs: +- 0 +- 0 +sdf_reg_decay_start_iter: 10000 +seed: 0 +seqshape_epochs: +- 0 +- 0 +shuffle_train_seqs: true +sigma: 1e-6 +silhouette_dt_loss_weight: 0.0 +silhouette_inv_dt_loss_weight: 50.0 +silhouette_loss_weight: 5.0 +skinning_temperature: 0.05 +skip_beginning: 0 +skip_end: 0 +small_leg_angle: true +smooth_deformation_loss_weight: 10.0 +static_root_bones: false +sym_deform: true +sym_dino: false +sym_prior_shape: true +sym_texture: true +temp_clip_high: 10.0 +temp_clip_low: 1.0 +tex_im_size: 256 +texture_epochs: +- 0 +- 276 +texture_mode: mlp +train_data_dir: + bear: /viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new/bear_comb_dinov2_new/train + cow: /viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new/cow_comb_dinov2_new/train + elephant: /viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new/elephant_comb_dinov2_new/train + giraffe: /viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new/giraffe_comb_dinov2_new/train + horse: /viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new/horse_comb_dinov2_new/train + sheep: /viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new/sheep_comb_dinov2_new/train + zebra: /viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new/zebra_comb_dinov2_new/train +train_with_cub: false +use_logger: true +use_scheduler: false +use_wandb: false +val_data_dir: + bear: /viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new/bear_comb_dinov2_new/val + cow: /viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new/cow_comb_dinov2_new/val + elephant: /viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new/elephant_comb_dinov2_new/val + giraffe: /viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new/giraffe_comb_dinov2_new/val + horse: /viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new/horse_comb_dinov2_new/val + sheep: /viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new/sheep_comb_dinov2_new/val + zebra: /viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new/zebra_comb_dinov2_new/val +visualize_validation: true +vit_final_layer_type: conv +which_vit: dino_vits8 +world_size: 1 +zflip_epochs: +- 0 +- 0 diff --git a/ckpts/iter0800000.pth b/ckpts/iter0800000.pth new file mode 100644 index 0000000000000000000000000000000000000000..6f7f5e5310f36593bcc60e48d68be28356e1ea1c --- /dev/null +++ b/ckpts/iter0800000.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8c7b090f1ff3e76e2ba608a25a2bd79af2892d6bb307132c9d038082395c1d57 +size 306560367 diff --git a/video3d/__init__.py b/video3d/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..cba85661999ae357b81f9f2c529625ed7cfefb80 --- /dev/null +++ b/video3d/__init__.py @@ -0,0 +1,6 @@ +from .utils.misc import setup_runtime +from .trainer import Trainer +from .trainer_ddp import TrainerDDP +from .model import Unsup3D +from .model_ddp import Unsup3DDDP +from .trainer_few_shot import Fewshot_Trainer diff 
--git a/video3d/cages/cages.py b/video3d/cages/cages.py new file mode 100755 index 0000000000000000000000000000000000000000..130162256f16bc20c72f0dfe980207694dfe27b9 --- /dev/null +++ b/video3d/cages/cages.py @@ -0,0 +1,218 @@ +# Cages code used from https://github.com/yifita/deep_cage +import torch +import numpy as np +import trimesh + + + +def deform_with_MVC(cage, cage_deformed, cage_face, query, verbose=False): + """ + cage (B,C,3) + cage_deformed (B,C,3) + cage_face (B,F,3) int64 + query (B,Q,3) + """ + weights, weights_unnormed = mean_value_coordinates_3D(query, cage, cage_face, verbose=True) +# weights = weights.detach() + deformed = torch.sum(weights.unsqueeze(-1)*cage_deformed.unsqueeze(1), dim=2) + if verbose: + return deformed, weights, weights_unnormed + return deformed + + +def loadInitCage(template): + init_cage_V, init_cage_F = read_trimesh(template) + init_cage_V = torch.from_numpy(init_cage_V[:,:3].astype(np.float32)).unsqueeze(0)*2.0 + init_cage_F = torch.from_numpy(init_cage_F[:,:3].astype(np.int64)).unsqueeze(0) + return init_cage_V, init_cage_F + + +def read_trimesh(path): + mesh = trimesh.load(path) + return mesh.vertices, mesh.faces + + +# util functions from pytorch_points +PI = 3.1415927 + +def normalize_to_box(input): + """ + normalize point cloud to unit bounding box + center = (max - min)/2 + scale = max(abs(x)) + input: pc [N, P, dim] or [P, dim] + output: pc, centroid, furthest_distance + """ + if len(input.shape) == 2: + axis = 0 + P = input.shape[0] + D = input.shape[1] + elif len(input.shape) == 3: + axis = 1 + P = input.shape[1] + D = input.shape[2] + if isinstance(input, np.ndarray): + maxP = np.amax(input, axis=axis, keepdims=True) + minP = np.amin(input, axis=axis, keepdims=True) + centroid = (maxP+minP)/2 + input = input - centroid + furthest_distance = np.amax(np.abs(input), axis=(axis, -1), keepdims=True) + input = input / furthest_distance + elif isinstance(input, torch.Tensor): + maxP = torch.max(input, dim=axis, keepdim=True)[0] + minP = torch.min(input, dim=axis, keepdim=True)[0] + centroid = (maxP+minP)/2 + input = input - centroid + in_shape = list(input.shape[:axis])+[P*D] + furthest_distance = torch.max(torch.abs(input).view(in_shape), dim=axis, keepdim=True)[0] + furthest_distance = furthest_distance.unsqueeze(-1) + input = input / furthest_distance + + return input, centroid, furthest_distance + +def normalize(tensor, dim=-1): + """normalize tensor in specified dimension""" + return torch.nn.functional.normalize(tensor, p=2, dim=dim, eps=1e-12, out=None) + + +def check_values(tensor): + """return true if tensor doesn't contain NaN or Inf""" + return not (torch.any(torch.isnan(tensor)).item() or torch.any(torch.isinf(tensor)).item()) + + +class ScatterAdd(torch.autograd.Function): + @staticmethod + def forward(ctx, src, idx, dim, out_size, fill=0.0): + out = torch.full(out_size, fill, device=src.device, dtype=src.dtype) + ctx.save_for_backward(idx) + out.scatter_add_(dim, idx, src) + ctx.mark_non_differentiable(idx) + ctx.dim = dim + return out + + @staticmethod + def backward(ctx, ograd): + idx, = ctx.saved_tensors + grad = torch.gather(ograd, ctx.dim, idx) + return grad, None, None, None, None + + +_scatter_add = ScatterAdd.apply + + +def scatter_add(src, idx, dim, out_size=None, fill=0.0): + if out_size is None: + out_size = list(src.size()) + dim_size = idx.max().item()+1 + out_size[dim] = dim_size + return _scatter_add(src, idx, dim, out_size, fill) + + +def mean_value_coordinates_3D(query, vertices, faces, verbose=False): + """ + Tao 
Ju et.al. MVC for 3D triangle meshes + params: + query (B,P,3) + vertices (B,N,3) + faces (B,F,3) + return: + wj (B,P,N) + """ + B, F, _ = faces.shape + _, P, _ = query.shape + _, N, _ = vertices.shape + # u_i = p_i - x (B,P,N,3) + uj = vertices.unsqueeze(1) - query.unsqueeze(2) + # \|u_i\| (B,P,N,1) + dj = torch.norm(uj, dim=-1, p=2, keepdim=True) + uj = normalize(uj, dim=-1) + # gather triangle B,P,F,3,3 + ui = torch.gather(uj.unsqueeze(2).expand(-1,-1,F,-1,-1), + 3, + faces.unsqueeze(1).unsqueeze(-1).expand(-1,P,-1,-1,3)) + # li = \|u_{i+1}-u_{i-1}\| (B,P,F,3) + li = torch.norm(ui[:,:,:,[1, 2, 0],:] - ui[:, :, :,[2, 0, 1],:], dim=-1, p=2) + eps = 2e-5 + li = torch.where(li>=2, li-(li.detach()-(2-eps)), li) + li = torch.where(li<=-2, li-(li.detach()+(2-eps)), li) + # asin(x) is inf at +/-1 + # θi = 2arcsin[li/2] (B,P,F,3) + theta_i = 2*torch.asin(li/2) + assert(check_values(theta_i)) + # B,P,F,1 + h = torch.sum(theta_i, dim=-1, keepdim=True)/2 + # wi← sin[θi]d{i−1}d{i+1} + # (B,P,F,3) ci ← (2sin[h]sin[h−θi])/(sin[θ_{i+1}]sin[θ_{i−1}])−1 + ci = 2*torch.sin(h)*torch.sin(h-theta_i)/(torch.sin(theta_i[:,:,:,[1, 2, 0]])*torch.sin(theta_i[:,:,:,[2, 0, 1]]))-1 + + # NOTE: because of floating point ci can be slightly larger than 1, causing problem with sqrt(1-ci^2) + # NOTE: sqrt(x)' is nan for x=0, hence use eps + eps = 1e-5 + ci = torch.where(ci>=1, ci-(ci.detach()-(1-eps)), ci) + ci = torch.where(ci<=-1, ci-(ci.detach()+(1-eps)), ci) + # si← sign[det[u1,u2,u3]]sqrt(1-ci^2) + # (B,P,F)*(B,P,F,3) + + si = torch.sign(torch.det(ui)).unsqueeze(-1)*torch.sqrt(1-ci**2) # sqrt gradient nan for 0 + assert(check_values(si)) + # (B,P,F,3) + di = torch.gather(dj.unsqueeze(2).squeeze(-1).expand(-1,-1,F,-1), 3, + faces.unsqueeze(1).expand(-1,P,-1,-1)) + assert(check_values(di)) + # if si.requires_grad: + # vertices.register_hook(save_grad("mvc/dv")) + # li.register_hook(save_grad("mvc/dli")) + # theta_i.register_hook(save_grad("mvc/dtheta")) + # ci.register_hook(save_grad("mvc/dci")) + # si.register_hook(save_grad("mvc/dsi")) + # di.register_hook(save_grad("mvc/ddi")) + + # wi← (θi −c[i+1]θ[i−1] −c[i−1]θ[i+1])/(disin[θi+1]s[i−1]) + # B,P,F,3 + # CHECK is there a 2* in the denominator + wi = (theta_i-ci[:,:,:,[1,2,0]]*theta_i[:,:,:,[2,0,1]]-ci[:,:,:,[2,0,1]]*theta_i[:,:,:,[1,2,0]])/(di*torch.sin(theta_i[:,:,:,[1,2,0]])*si[:,:,:,[2,0,1]]) + # if ∃i,|si| ≤ ε, set wi to 0. 
coplaner with T but outside + # ignore coplaner outside triangle + # alternative check + # (B,F,3,3) + # triangle_points = torch.gather(vertices.unsqueeze(1).expand(-1,F,-1,-1), 2, faces.unsqueeze(-1).expand(-1,-1,-1,3)) + # # (B,P,F,3), (B,1,F,3) -> (B,P,F,1) + # determinant = dot_product(triangle_points[:,:,:,0].unsqueeze(1)-query.unsqueeze(2), + # torch.cross(triangle_points[:,:,:,1]-triangle_points[:,:,:,0], + # triangle_points[:,:,:,2]-triangle_points[:,:,:,0], dim=-1).unsqueeze(1), dim=-1, keepdim=True).detach() + # # (B,P,F,1) + # sqrdist = determinant*determinant / (4 * sqrNorm(torch.cross(triangle_points[:,:,:,1]-triangle_points[:,:,:,0], triangle_points[:,:,:,2]-triangle_points[:,:,:,0], dim=-1), keepdim=True)) + + wi = torch.where(torch.any(torch.abs(si) <= 1e-5, keepdim=True, dim=-1), torch.zeros_like(wi), wi) + # wi = torch.where(sqrdist <= 1e-5, torch.zeros_like(wi), wi) + + # if π −h < ε, x lies on t, use 2D barycentric coordinates + # inside triangle + inside_triangle = (PI-h).squeeze(-1)<1e-4 + # set all F for this P to zero + wi = torch.where(torch.any(inside_triangle, dim=-1, keepdim=True).unsqueeze(-1), torch.zeros_like(wi), wi) + # CHECK is it di https://www.cse.wustl.edu/~taoju/research/meanvalue.pdf or li http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.516.1856&rep=rep1&type=pdf + wi = torch.where(inside_triangle.unsqueeze(-1).expand(-1,-1,-1,wi.shape[-1]), torch.sin(theta_i)*di[:,:,:,[2,0,1]]*di[:,:,:,[1,2,0]], wi) + + # sum over all faces face -> vertex (B,P,F*3) -> (B,P,N) + wj = scatter_add(wi.reshape(B,P,-1).contiguous(), faces.unsqueeze(1).expand(-1,P,-1,-1).reshape(B,P,-1), 2, out_size=(B,P,N)) + + # close to vertex (B,P,N) + close_to_point = dj.squeeze(-1) < 1e-8 + # set all F for this P to zero + wj = torch.where(torch.any(close_to_point, dim=-1, keepdim=True), torch.zeros_like(wj), wj) + wj = torch.where(close_to_point, torch.ones_like(wj), wj) + + # (B,P,1) + sumWj = torch.sum(wj, dim=-1, keepdim=True) + sumWj = torch.where(sumWj==0, torch.ones_like(sumWj), sumWj) + + wj_normalised = wj / sumWj + # if wj.requires_grad: + # saved_variables["mvc/wi"] = wi + # wi.register_hook(save_grad("mvc/dwi")) + # wj.register_hook(save_grad("mvc/dwj")) + if verbose: + return wj_normalised, wi + else: + return wj_normalised diff --git a/video3d/cub_dataloaders.py b/video3d/cub_dataloaders.py new file mode 100755 index 0000000000000000000000000000000000000000..7da2aade62ecba5a72594395894d649dc02a3cd5 --- /dev/null +++ b/video3d/cub_dataloaders.py @@ -0,0 +1,404 @@ +import os.path as osp +import cv2 +import numpy as np +import scipy.io as sio +import torch +from PIL import Image +from torch.utils.data import Dataset +from types import SimpleNamespace + + +def get_cub_loader(data_dir, split='test', is_validation=False, batch_size=256, num_workers=4, image_size=256): + opts = SimpleNamespace() + opts.data_dir = data_dir + opts.padding_frac = 0.05 + opts.jitter_frac = 0.05 + opts.input_size = image_size + opts.split = split + + dataset = CUBDataset(opts) + loader = torch.utils.data.DataLoader( + dataset, + batch_size=batch_size, + shuffle=not is_validation, + num_workers=num_workers, + pin_memory=True + ) + return loader + + +class CUBDataset(Dataset): + def __init__(self, opts): + super().__init__() + + self.opts = opts + self.img_size = opts.input_size + self.jitter_frac = opts.jitter_frac + self.padding_frac = opts.padding_frac + self.split = opts.split + self.data_dir = opts.data_dir + self.data_cache_dir = osp.join(self.data_dir, 'cachedir/cub') + self.img_dir = 
osp.join(self.data_dir, 'images') + + self.anno_path = osp.join(self.data_cache_dir, 'data', '%s_cub_cleaned.mat' % self.split) + self.anno_sfm_path = osp.join(self.data_cache_dir, 'sfm', 'anno_%s.mat' % self.split) + + if not osp.exists(self.anno_path): + print('%s doesnt exist!' % self.anno_path) + import pdb; pdb.set_trace() + + # Load the annotation file. + print('loading %s' % self.anno_path) + self.anno = sio.loadmat( + self.anno_path, struct_as_record=False, squeeze_me=True)['images'] + self.anno_sfm = sio.loadmat( + self.anno_sfm_path, struct_as_record=False, squeeze_me=True)['sfm_anno'] + + self.kp_perm = np.array([1, 2, 3, 4, 5, 6, 11, 12, 13, 10, 7, 8, 9, 14, 15]) - 1; + + self.num_imgs = len(self.anno) + print('%d images' % self.num_imgs) + + def forward_img(self, index): + data = self.anno[index] + data_sfm = self.anno_sfm[0] + + # sfm_pose = (sfm_c, sfm_t, sfm_r) + sfm_pose = [np.copy(data_sfm.scale), np.copy(data_sfm.trans), np.copy(data_sfm.rot)] + + sfm_rot = np.pad(sfm_pose[2], (0,1), 'constant') + sfm_rot[3, 3] = 1 + sfm_pose[2] = quaternion_from_matrix(sfm_rot, isprecise=True) + + img_path = osp.join(self.img_dir, str(data.rel_path)) + #img_path = img_path.replace("JPEG", "jpg") + img = np.array(Image.open(img_path)) + + # Some are grayscale: + if len(img.shape) == 2: + img = np.repeat(np.expand_dims(img, 2), 3, axis=2) + mask = data.mask + mask = np.expand_dims(mask, 2) + h,w,_ = mask.shape + + # Adjust to 0 indexing + bbox = np.array( + [data.bbox.x1, data.bbox.y1, data.bbox.x2, data.bbox.y2], + float) - 1 + + parts = data.parts.T.astype(float) + kp = np.copy(parts) + vis = kp[:, 2] > 0 + kp[vis, :2] -= 1 + + # Peturb bbox + if self.split == 'train': + bbox = peturb_bbox( + bbox, pf=self.padding_frac, jf=self.jitter_frac) + else: + bbox = peturb_bbox( + bbox, pf=self.padding_frac, jf=0) + bbox = square_bbox(bbox) + + # crop image around bbox, translate kps + img, mask, kp, sfm_pose = self.crop_image(img, mask, bbox, kp, vis, sfm_pose) + + # scale image, and mask. And scale kps. + img, mask, kp, sfm_pose = self.scale_image(img, mask, kp, vis, sfm_pose) + + # Mirror image on random. 
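+        # mirror_image below flips the image, mask, keypoints and SfM pose horizontally with 50% probability.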
+ if self.split == 'train': + img, mask, kp, sfm_pose = self.mirror_image(img, mask, kp, sfm_pose) + + # Normalize kp to be [-1, 1] + img_h, img_w = img.shape[:2] + kp_norm, sfm_pose = self.normalize_kp(kp, sfm_pose, img_h, img_w) + + # img = Image.fromarray(np.asarray(img, np.uint8)) + mask = np.asarray(mask, np.float32) + return img, kp_norm, mask, sfm_pose, img_path + + def normalize_kp(self, kp, sfm_pose, img_h, img_w): + vis = kp[:, 2, None] > 0 + new_kp = np.stack([2 * (kp[:, 0] / img_w) - 1, + 2 * (kp[:, 1] / img_h) - 1, + kp[:, 2]]).T + sfm_pose[0] *= (1.0/img_w + 1.0/img_h) + sfm_pose[1][0] = 2.0 * (sfm_pose[1][0] / img_w) - 1 + sfm_pose[1][1] = 2.0 * (sfm_pose[1][1] / img_h) - 1 + new_kp = vis * new_kp + + return new_kp, sfm_pose + + def crop_image(self, img, mask, bbox, kp, vis, sfm_pose): + # crop image and mask and translate kps + img = crop(img, bbox, bgval=1) + mask = crop(mask, bbox, bgval=0) + kp[vis, 0] -= bbox[0] + kp[vis, 1] -= bbox[1] + sfm_pose[1][0] -= bbox[0] + sfm_pose[1][1] -= bbox[1] + return img, mask, kp, sfm_pose + + def scale_image(self, img, mask, kp, vis, sfm_pose): + # Scale image so largest bbox size is img_size + bwidth = np.shape(img)[0] + bheight = np.shape(img)[1] + scale = self.img_size / float(max(bwidth, bheight)) + img_scale, _ = resize_img(img, scale) + # if img_scale.shape[0] != self.img_size: + # print('bad!') + # import ipdb; ipdb.set_trace() + # mask_scale, _ = resize_img(mask, scale) +# mask_scale, _ = resize_img(mask, scale, interpolation=cv2.INTER_NEAREST) + mask_scale, _ = resize_img(mask, scale) + kp[vis, :2] *= scale + sfm_pose[0] *= scale + sfm_pose[1] *= scale + + return img_scale, mask_scale, kp, sfm_pose + + def mirror_image(self, img, mask, kp, sfm_pose): + kp_perm = self.kp_perm + if np.random.rand(1) > 0.5: + # Need copy bc torch collate doesnt like neg strides + img_flip = img[:, ::-1, :].copy() + mask_flip = mask[:, ::-1].copy() + + # Flip kps. + new_x = img.shape[1] - kp[:, 0] - 1 + kp_flip = np.hstack((new_x[:, None], kp[:, 1:])) + kp_flip = kp_flip[kp_perm, :] + # Flip sfm_pose Rot. 
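+            # Conjugating the rotation matrix by a reflection mirrors it to match the horizontally
+            # flipped image; the result is converted back to a quaternion below.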
+ R = quaternion_matrix(sfm_pose[2]) + flip_R = np.diag([-1, 1, 1, 1]).dot(R.dot(np.diag([-1, 1, 1, 1]))) + sfm_pose[2] = quaternion_from_matrix(flip_R, isprecise=True) + # Flip tx + tx = img.shape[1] - sfm_pose[1][0] - 1 + sfm_pose[1][0] = tx + return img_flip, mask_flip, kp_flip, sfm_pose + else: + return img, mask, kp, sfm_pose + + def __len__(self): + return self.num_imgs + + def __getitem__(self, index): + img, kp, mask, sfm_pose, img_path = self.forward_img(index) + sfm_pose[0].shape = 1 + mask = np.expand_dims(mask, 2) + + images = torch.FloatTensor(img /255.).permute(2,0,1).unsqueeze(0) + masks = torch.FloatTensor(mask).permute(2,0,1).repeat(1,3,1,1) + mask_dt = compute_distance_transform(masks) + # flows = torch.zeros(1,2, self.img_size, self.img_size) + flows = torch.zeros(1) + bboxs = torch.FloatTensor([0, 0, 0, self.img_size, self.img_size, 1, 1, 0]).unsqueeze(0) # frame_id, crop_x0, crop_y0, crop_w, crop_h, resize_sx, resize_sy, sharpness + bg_image = images[0] + seq_idx = torch.LongTensor([index]) + frame_idx = torch.LongTensor([0]) + return images, masks, mask_dt, flows, bboxs, bg_image, seq_idx, frame_idx + + +def compute_distance_transform(mask): + mask_dt = [] + for m in mask: + dt = torch.FloatTensor(cv2.distanceTransform(np.uint8(m[0]), cv2.DIST_L2, cv2.DIST_MASK_PRECISE)) + inv_dt = torch.FloatTensor(cv2.distanceTransform(np.uint8(1 - m[0]), cv2.DIST_L2, cv2.DIST_MASK_PRECISE)) + mask_dt += [torch.stack([dt, inv_dt], 0)] + return torch.stack(mask_dt, 0) # Bx2xHxW + + +def resize_img(img, scale_factor): + new_size = (np.round(np.array(img.shape[:2]) * scale_factor)).astype(int) + new_img = cv2.resize(img, (new_size[1], new_size[0])) + # This is scale factor of [height, width] i.e. [y, x] + actual_factor = [new_size[0] / float(img.shape[0]), + new_size[1] / float(img.shape[1])] + return new_img, actual_factor + + +def peturb_bbox(bbox, pf=0, jf=0): + ''' + Jitters and pads the input bbox. + Args: + bbox: Zero-indexed tight bbox. + pf: padding fraction. + jf: jittering fraction. + Returns: + pet_bbox: Jittered and padded box. Might have -ve or out-of-image coordinates + ''' + pet_bbox = [coord for coord in bbox] + bwidth = bbox[2] - bbox[0] + 1 + bheight = bbox[3] - bbox[1] + 1 + + pet_bbox[0] -= (pf*bwidth) + (1-2*np.random.random())*jf*bwidth + pet_bbox[1] -= (pf*bheight) + (1-2*np.random.random())*jf*bheight + pet_bbox[2] += (pf*bwidth) + (1-2*np.random.random())*jf*bwidth + pet_bbox[3] += (pf*bheight) + (1-2*np.random.random())*jf*bheight + + return pet_bbox + + +def square_bbox(bbox): + ''' + Converts a bbox to have a square shape by increasing size along non-max dimension. + ''' + sq_bbox = [int(round(coord)) for coord in bbox] + bwidth = sq_bbox[2] - sq_bbox[0] + 1 + bheight = sq_bbox[3] - sq_bbox[1] + 1 + maxdim = float(max(bwidth, bheight)) + + dw_b_2 = int(round((maxdim-bwidth)/2.0)) + dh_b_2 = int(round((maxdim-bheight)/2.0)) + + sq_bbox[0] -= dw_b_2 + sq_bbox[1] -= dh_b_2 + sq_bbox[2] = sq_bbox[0] + maxdim - 1 + sq_bbox[3] = sq_bbox[1] + maxdim - 1 + + return sq_bbox + + +def crop(img, bbox, bgval=0): + ''' + Crops a region from the image corresponding to the bbox. + If some regions specified go outside the image boundaries, the pixel values are set to bgval. 
+ Args: + img: image to crop + bbox: bounding box to crop + bgval: default background for regions outside image + ''' + bbox = [int(round(c)) for c in bbox] + bwidth = bbox[2] - bbox[0] + 1 + bheight = bbox[3] - bbox[1] + 1 + + im_shape = np.shape(img) + im_h, im_w = im_shape[0], im_shape[1] + + nc = 1 if len(im_shape) < 3 else im_shape[2] + + img_out = np.ones((bheight, bwidth, nc))*bgval + x_min_src = max(0, bbox[0]) + x_max_src = min(im_w, bbox[2]+1) + y_min_src = max(0, bbox[1]) + y_max_src = min(im_h, bbox[3]+1) + + x_min_trg = x_min_src - bbox[0] + x_max_trg = x_max_src - x_min_src + x_min_trg + y_min_trg = y_min_src - bbox[1] + y_max_trg = y_max_src - y_min_src + y_min_trg + + img_out[y_min_trg:y_max_trg, x_min_trg:x_max_trg, :] = img[y_min_src:y_max_src, x_min_src:x_max_src, :] + return img_out + + +# https://github.com/akanazawa/cmr/blob/master/utils/transformations.py +import math +import numpy +_EPS = numpy.finfo(float).eps * 4.0 + +def quaternion_matrix(quaternion): + """Return homogeneous rotation matrix from quaternion. + >>> M = quaternion_matrix([0.99810947, 0.06146124, 0, 0]) + >>> numpy.allclose(M, rotation_matrix(0.123, [1, 0, 0])) + True + >>> M = quaternion_matrix([1, 0, 0, 0]) + >>> numpy.allclose(M, numpy.identity(4)) + True + >>> M = quaternion_matrix([0, 1, 0, 0]) + >>> numpy.allclose(M, numpy.diag([1, -1, -1, 1])) + True + """ + q = numpy.array(quaternion, dtype=numpy.float64, copy=True) + n = numpy.dot(q, q) + if n < _EPS: + return numpy.identity(4) + q *= math.sqrt(2.0 / n) + q = numpy.outer(q, q) + return numpy.array([ + [1.0-q[2, 2]-q[3, 3], q[1, 2]-q[3, 0], q[1, 3]+q[2, 0], 0.0], + [ q[1, 2]+q[3, 0], 1.0-q[1, 1]-q[3, 3], q[2, 3]-q[1, 0], 0.0], + [ q[1, 3]-q[2, 0], q[2, 3]+q[1, 0], 1.0-q[1, 1]-q[2, 2], 0.0], + [ 0.0, 0.0, 0.0, 1.0]]) + +def quaternion_from_matrix(matrix, isprecise=False): + """Return quaternion from rotation matrix. + If isprecise is True, the input matrix is assumed to be a precise rotation + matrix and a faster algorithm is used. + >>> q = quaternion_from_matrix(numpy.identity(4), True) + >>> numpy.allclose(q, [1, 0, 0, 0]) + True + >>> q = quaternion_from_matrix(numpy.diag([1, -1, -1, 1])) + >>> numpy.allclose(q, [0, 1, 0, 0]) or numpy.allclose(q, [0, -1, 0, 0]) + True + >>> R = rotation_matrix(0.123, (1, 2, 3)) + >>> q = quaternion_from_matrix(R, True) + >>> numpy.allclose(q, [0.9981095, 0.0164262, 0.0328524, 0.0492786]) + True + >>> R = [[-0.545, 0.797, 0.260, 0], [0.733, 0.603, -0.313, 0], + ... [-0.407, 0.021, -0.913, 0], [0, 0, 0, 1]] + >>> q = quaternion_from_matrix(R) + >>> numpy.allclose(q, [0.19069, 0.43736, 0.87485, -0.083611]) + True + >>> R = [[0.395, 0.362, 0.843, 0], [-0.626, 0.796, -0.056, 0], + ... [-0.677, -0.498, 0.529, 0], [0, 0, 0, 1]] + >>> q = quaternion_from_matrix(R) + >>> numpy.allclose(q, [0.82336615, -0.13610694, 0.46344705, -0.29792603]) + True + >>> R = random_rotation_matrix() + >>> q = quaternion_from_matrix(R) + >>> is_same_transform(R, quaternion_matrix(q)) + True + >>> is_same_quaternion(quaternion_from_matrix(R, isprecise=False), + ... quaternion_from_matrix(R, isprecise=True)) + True + >>> R = euler_matrix(0.0, 0.0, numpy.pi/2.0) + >>> is_same_quaternion(quaternion_from_matrix(R, isprecise=False), + ... 
quaternion_from_matrix(R, isprecise=True)) + True + """ + M = numpy.array(matrix, dtype=numpy.float64, copy=False)[:4, :4] + if isprecise: + q = numpy.empty((4, )) + t = numpy.trace(M) + if t > M[3, 3]: + q[0] = t + q[3] = M[1, 0] - M[0, 1] + q[2] = M[0, 2] - M[2, 0] + q[1] = M[2, 1] - M[1, 2] + else: + i, j, k = 0, 1, 2 + if M[1, 1] > M[0, 0]: + i, j, k = 1, 2, 0 + if M[2, 2] > M[i, i]: + i, j, k = 2, 0, 1 + t = M[i, i] - (M[j, j] + M[k, k]) + M[3, 3] + q[i] = t + q[j] = M[i, j] + M[j, i] + q[k] = M[k, i] + M[i, k] + q[3] = M[k, j] - M[j, k] + q = q[[3, 0, 1, 2]] + q *= 0.5 / math.sqrt(t * M[3, 3]) + else: + m00 = M[0, 0] + m01 = M[0, 1] + m02 = M[0, 2] + m10 = M[1, 0] + m11 = M[1, 1] + m12 = M[1, 2] + m20 = M[2, 0] + m21 = M[2, 1] + m22 = M[2, 2] + # symmetric matrix K + K = numpy.array([[m00-m11-m22, 0.0, 0.0, 0.0], + [m01+m10, m11-m00-m22, 0.0, 0.0], + [m02+m20, m12+m21, m22-m00-m11, 0.0], + [m21-m12, m02-m20, m10-m01, m00+m11+m22]]) + K /= 3.0 + # quaternion is eigenvector of K that corresponds to largest eigenvalue + w, V = numpy.linalg.eigh(K) + q = V[[3, 0, 1, 2], numpy.argmax(w)] + if q[0] < 0.0: + numpy.negative(q, q) + return q diff --git a/video3d/cub_dataloaders_ddp.py b/video3d/cub_dataloaders_ddp.py new file mode 100755 index 0000000000000000000000000000000000000000..babeab0c8b066f0407677b12c1b92f8f98399e1a --- /dev/null +++ b/video3d/cub_dataloaders_ddp.py @@ -0,0 +1,434 @@ +import os.path as osp +import cv2 +import numpy as np +import scipy.io as sio +import torch +from PIL import Image +from torch.utils.data import Dataset +from types import SimpleNamespace + + +def get_cub_loader(data_dir, split='test', is_validation=False, batch_size=256, num_workers=4, image_size=256): + opts = SimpleNamespace() + opts.data_dir = data_dir + opts.padding_frac = 0.05 + opts.jitter_frac = 0.05 + opts.input_size = image_size + opts.split = split + + dataset = CUBDataset(opts) + loader = torch.utils.data.DataLoader( + dataset, + batch_size=batch_size, + shuffle=not is_validation, + num_workers=num_workers, + pin_memory=True + ) + return loader + + +def get_cub_loader_ddp(data_dir, world_size, rank, split='test', is_validation=False, batch_size=256, num_workers=4, image_size=256): + opts = SimpleNamespace() + opts.data_dir = data_dir + opts.padding_frac = 0.05 + opts.jitter_frac = 0.05 + opts.input_size = image_size + opts.split = split + + dataset = CUBDataset(opts) + + sampler = torch.utils.data.distributed.DistributedSampler( + dataset, + num_replicas=world_size, + rank=rank, + ) + + loader = torch.utils.data.DataLoader( + dataset, + sampler=sampler, + batch_size=batch_size, + shuffle=not is_validation, + drop_last=True, + num_workers=num_workers, + pin_memory=True + ) + return loader + + +class CUBDataset(Dataset): + def __init__(self, opts): + super().__init__() + + self.opts = opts + self.img_size = opts.input_size + self.jitter_frac = opts.jitter_frac + self.padding_frac = opts.padding_frac + self.split = opts.split + self.data_dir = opts.data_dir + self.data_cache_dir = osp.join(self.data_dir, 'cachedir/cub') + self.img_dir = osp.join(self.data_dir, 'images') + + self.anno_path = osp.join(self.data_cache_dir, 'data', '%s_cub_cleaned.mat' % self.split) + self.anno_sfm_path = osp.join(self.data_cache_dir, 'sfm', 'anno_%s.mat' % self.split) + + if not osp.exists(self.anno_path): + print('%s doesnt exist!' % self.anno_path) + import pdb; pdb.set_trace() + + # Load the annotation file. 
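+        # The .mat annotations provide, per image, the relative path, mask, bbox and keypoints,
+        # plus SfM camera parameters (scale, translation, rotation) used to build sfm_pose in forward_img.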
+ print('loading %s' % self.anno_path) + self.anno = sio.loadmat( + self.anno_path, struct_as_record=False, squeeze_me=True)['images'] + self.anno_sfm = sio.loadmat( + self.anno_sfm_path, struct_as_record=False, squeeze_me=True)['sfm_anno'] + + self.kp_perm = np.array([1, 2, 3, 4, 5, 6, 11, 12, 13, 10, 7, 8, 9, 14, 15]) - 1; + + self.num_imgs = len(self.anno) + print('%d images' % self.num_imgs) + + def forward_img(self, index): + data = self.anno[index] + data_sfm = self.anno_sfm[0] + + # sfm_pose = (sfm_c, sfm_t, sfm_r) + sfm_pose = [np.copy(data_sfm.scale), np.copy(data_sfm.trans), np.copy(data_sfm.rot)] + + sfm_rot = np.pad(sfm_pose[2], (0,1), 'constant') + sfm_rot[3, 3] = 1 + sfm_pose[2] = quaternion_from_matrix(sfm_rot, isprecise=True) + + img_path = osp.join(self.img_dir, str(data.rel_path)) + #img_path = img_path.replace("JPEG", "jpg") + img = np.array(Image.open(img_path)) + + # Some are grayscale: + if len(img.shape) == 2: + img = np.repeat(np.expand_dims(img, 2), 3, axis=2) + mask = data.mask + mask = np.expand_dims(mask, 2) + h,w,_ = mask.shape + + # Adjust to 0 indexing + bbox = np.array( + [data.bbox.x1, data.bbox.y1, data.bbox.x2, data.bbox.y2], + float) - 1 + + parts = data.parts.T.astype(float) + kp = np.copy(parts) + vis = kp[:, 2] > 0 + kp[vis, :2] -= 1 + + # Peturb bbox + if self.split == 'train': + bbox = peturb_bbox( + bbox, pf=self.padding_frac, jf=self.jitter_frac) + else: + bbox = peturb_bbox( + bbox, pf=self.padding_frac, jf=0) + bbox = square_bbox(bbox) + + # crop image around bbox, translate kps + img, mask, kp, sfm_pose = self.crop_image(img, mask, bbox, kp, vis, sfm_pose) + + # scale image, and mask. And scale kps. + img, mask, kp, sfm_pose = self.scale_image(img, mask, kp, vis, sfm_pose) + + # Mirror image on random. + if self.split == 'train': + img, mask, kp, sfm_pose = self.mirror_image(img, mask, kp, sfm_pose) + + # Normalize kp to be [-1, 1] + img_h, img_w = img.shape[:2] + kp_norm, sfm_pose = self.normalize_kp(kp, sfm_pose, img_h, img_w) + + # img = Image.fromarray(np.asarray(img, np.uint8)) + mask = np.asarray(mask, np.float32) + return img, kp_norm, mask, sfm_pose, img_path + + def normalize_kp(self, kp, sfm_pose, img_h, img_w): + vis = kp[:, 2, None] > 0 + new_kp = np.stack([2 * (kp[:, 0] / img_w) - 1, + 2 * (kp[:, 1] / img_h) - 1, + kp[:, 2]]).T + sfm_pose[0] *= (1.0/img_w + 1.0/img_h) + sfm_pose[1][0] = 2.0 * (sfm_pose[1][0] / img_w) - 1 + sfm_pose[1][1] = 2.0 * (sfm_pose[1][1] / img_h) - 1 + new_kp = vis * new_kp + + return new_kp, sfm_pose + + def crop_image(self, img, mask, bbox, kp, vis, sfm_pose): + # crop image and mask and translate kps + img = crop(img, bbox, bgval=1) + mask = crop(mask, bbox, bgval=0) + kp[vis, 0] -= bbox[0] + kp[vis, 1] -= bbox[1] + sfm_pose[1][0] -= bbox[0] + sfm_pose[1][1] -= bbox[1] + return img, mask, kp, sfm_pose + + def scale_image(self, img, mask, kp, vis, sfm_pose): + # Scale image so largest bbox size is img_size + bwidth = np.shape(img)[0] + bheight = np.shape(img)[1] + scale = self.img_size / float(max(bwidth, bheight)) + img_scale, _ = resize_img(img, scale) + # if img_scale.shape[0] != self.img_size: + # print('bad!') + # import ipdb; ipdb.set_trace() + # mask_scale, _ = resize_img(mask, scale) +# mask_scale, _ = resize_img(mask, scale, interpolation=cv2.INTER_NEAREST) + mask_scale, _ = resize_img(mask, scale) + kp[vis, :2] *= scale + sfm_pose[0] *= scale + sfm_pose[1] *= scale + + return img_scale, mask_scale, kp, sfm_pose + + def mirror_image(self, img, mask, kp, sfm_pose): + kp_perm = self.kp_perm + if 
np.random.rand(1) > 0.5: + # Need copy bc torch collate doesnt like neg strides + img_flip = img[:, ::-1, :].copy() + mask_flip = mask[:, ::-1].copy() + + # Flip kps. + new_x = img.shape[1] - kp[:, 0] - 1 + kp_flip = np.hstack((new_x[:, None], kp[:, 1:])) + kp_flip = kp_flip[kp_perm, :] + # Flip sfm_pose Rot. + R = quaternion_matrix(sfm_pose[2]) + flip_R = np.diag([-1, 1, 1, 1]).dot(R.dot(np.diag([-1, 1, 1, 1]))) + sfm_pose[2] = quaternion_from_matrix(flip_R, isprecise=True) + # Flip tx + tx = img.shape[1] - sfm_pose[1][0] - 1 + sfm_pose[1][0] = tx + return img_flip, mask_flip, kp_flip, sfm_pose + else: + return img, mask, kp, sfm_pose + + def __len__(self): + return self.num_imgs + + def __getitem__(self, index): + img, kp, mask, sfm_pose, img_path = self.forward_img(index) + sfm_pose[0].shape = 1 + mask = np.expand_dims(mask, 2) + + images = torch.FloatTensor(img /255.).permute(2,0,1).unsqueeze(0) + masks = torch.FloatTensor(mask).permute(2,0,1).repeat(1,3,1,1) + mask_dt = compute_distance_transform(masks) + # flows = torch.zeros(1,2, self.img_size, self.img_size) + flows = torch.zeros(1) + bboxs = torch.FloatTensor([0, 0, 0, self.img_size, self.img_size, 1, 1, 0]).unsqueeze(0) # frame_id, crop_x0, crop_y0, crop_w, crop_h, resize_sx, resize_sy, sharpness + bg_image = images[0] + seq_idx = torch.LongTensor([index]) + frame_idx = torch.LongTensor([0]) + return images, masks, mask_dt, flows, bboxs, bg_image, seq_idx, frame_idx + + +def compute_distance_transform(mask): + mask_dt = [] + for m in mask: + dt = torch.FloatTensor(cv2.distanceTransform(np.uint8(m[0]), cv2.DIST_L2, cv2.DIST_MASK_PRECISE)) + inv_dt = torch.FloatTensor(cv2.distanceTransform(np.uint8(1 - m[0]), cv2.DIST_L2, cv2.DIST_MASK_PRECISE)) + mask_dt += [torch.stack([dt, inv_dt], 0)] + return torch.stack(mask_dt, 0) # Bx2xHxW + + +def resize_img(img, scale_factor): + new_size = (np.round(np.array(img.shape[:2]) * scale_factor)).astype(int) + new_img = cv2.resize(img, (new_size[1], new_size[0])) + # This is scale factor of [height, width] i.e. [y, x] + actual_factor = [new_size[0] / float(img.shape[0]), + new_size[1] / float(img.shape[1])] + return new_img, actual_factor + + +def peturb_bbox(bbox, pf=0, jf=0): + ''' + Jitters and pads the input bbox. + Args: + bbox: Zero-indexed tight bbox. + pf: padding fraction. + jf: jittering fraction. + Returns: + pet_bbox: Jittered and padded box. Might have -ve or out-of-image coordinates + ''' + pet_bbox = [coord for coord in bbox] + bwidth = bbox[2] - bbox[0] + 1 + bheight = bbox[3] - bbox[1] + 1 + + pet_bbox[0] -= (pf*bwidth) + (1-2*np.random.random())*jf*bwidth + pet_bbox[1] -= (pf*bheight) + (1-2*np.random.random())*jf*bheight + pet_bbox[2] += (pf*bwidth) + (1-2*np.random.random())*jf*bwidth + pet_bbox[3] += (pf*bheight) + (1-2*np.random.random())*jf*bheight + + return pet_bbox + + +def square_bbox(bbox): + ''' + Converts a bbox to have a square shape by increasing size along non-max dimension. + ''' + sq_bbox = [int(round(coord)) for coord in bbox] + bwidth = sq_bbox[2] - sq_bbox[0] + 1 + bheight = sq_bbox[3] - sq_bbox[1] + 1 + maxdim = float(max(bwidth, bheight)) + + dw_b_2 = int(round((maxdim-bwidth)/2.0)) + dh_b_2 = int(round((maxdim-bheight)/2.0)) + + sq_bbox[0] -= dw_b_2 + sq_bbox[1] -= dh_b_2 + sq_bbox[2] = sq_bbox[0] + maxdim - 1 + sq_bbox[3] = sq_bbox[1] + maxdim - 1 + + return sq_bbox + + +def crop(img, bbox, bgval=0): + ''' + Crops a region from the image corresponding to the bbox. 
+ If some regions specified go outside the image boundaries, the pixel values are set to bgval. + Args: + img: image to crop + bbox: bounding box to crop + bgval: default background for regions outside image + ''' + bbox = [int(round(c)) for c in bbox] + bwidth = bbox[2] - bbox[0] + 1 + bheight = bbox[3] - bbox[1] + 1 + + im_shape = np.shape(img) + im_h, im_w = im_shape[0], im_shape[1] + + nc = 1 if len(im_shape) < 3 else im_shape[2] + + img_out = np.ones((bheight, bwidth, nc))*bgval + x_min_src = max(0, bbox[0]) + x_max_src = min(im_w, bbox[2]+1) + y_min_src = max(0, bbox[1]) + y_max_src = min(im_h, bbox[3]+1) + + x_min_trg = x_min_src - bbox[0] + x_max_trg = x_max_src - x_min_src + x_min_trg + y_min_trg = y_min_src - bbox[1] + y_max_trg = y_max_src - y_min_src + y_min_trg + + img_out[y_min_trg:y_max_trg, x_min_trg:x_max_trg, :] = img[y_min_src:y_max_src, x_min_src:x_max_src, :] + return img_out + + +# https://github.com/akanazawa/cmr/blob/master/utils/transformations.py +import math +import numpy +_EPS = numpy.finfo(float).eps * 4.0 + + +def quaternion_matrix(quaternion): + """Return homogeneous rotation matrix from quaternion. + >>> M = quaternion_matrix([0.99810947, 0.06146124, 0, 0]) + >>> numpy.allclose(M, rotation_matrix(0.123, [1, 0, 0])) + True + >>> M = quaternion_matrix([1, 0, 0, 0]) + >>> numpy.allclose(M, numpy.identity(4)) + True + >>> M = quaternion_matrix([0, 1, 0, 0]) + >>> numpy.allclose(M, numpy.diag([1, -1, -1, 1])) + True + """ + q = numpy.array(quaternion, dtype=numpy.float64, copy=True) + n = numpy.dot(q, q) + if n < _EPS: + return numpy.identity(4) + q *= math.sqrt(2.0 / n) + q = numpy.outer(q, q) + return numpy.array([ + [1.0-q[2, 2]-q[3, 3], q[1, 2]-q[3, 0], q[1, 3]+q[2, 0], 0.0], + [ q[1, 2]+q[3, 0], 1.0-q[1, 1]-q[3, 3], q[2, 3]-q[1, 0], 0.0], + [ q[1, 3]-q[2, 0], q[2, 3]+q[1, 0], 1.0-q[1, 1]-q[2, 2], 0.0], + [ 0.0, 0.0, 0.0, 1.0]]) + + +def quaternion_from_matrix(matrix, isprecise=False): + """Return quaternion from rotation matrix. + If isprecise is True, the input matrix is assumed to be a precise rotation + matrix and a faster algorithm is used. + >>> q = quaternion_from_matrix(numpy.identity(4), True) + >>> numpy.allclose(q, [1, 0, 0, 0]) + True + >>> q = quaternion_from_matrix(numpy.diag([1, -1, -1, 1])) + >>> numpy.allclose(q, [0, 1, 0, 0]) or numpy.allclose(q, [0, -1, 0, 0]) + True + >>> R = rotation_matrix(0.123, (1, 2, 3)) + >>> q = quaternion_from_matrix(R, True) + >>> numpy.allclose(q, [0.9981095, 0.0164262, 0.0328524, 0.0492786]) + True + >>> R = [[-0.545, 0.797, 0.260, 0], [0.733, 0.603, -0.313, 0], + ... [-0.407, 0.021, -0.913, 0], [0, 0, 0, 1]] + >>> q = quaternion_from_matrix(R) + >>> numpy.allclose(q, [0.19069, 0.43736, 0.87485, -0.083611]) + True + >>> R = [[0.395, 0.362, 0.843, 0], [-0.626, 0.796, -0.056, 0], + ... [-0.677, -0.498, 0.529, 0], [0, 0, 0, 1]] + >>> q = quaternion_from_matrix(R) + >>> numpy.allclose(q, [0.82336615, -0.13610694, 0.46344705, -0.29792603]) + True + >>> R = random_rotation_matrix() + >>> q = quaternion_from_matrix(R) + >>> is_same_transform(R, quaternion_matrix(q)) + True + >>> is_same_quaternion(quaternion_from_matrix(R, isprecise=False), + ... quaternion_from_matrix(R, isprecise=True)) + True + >>> R = euler_matrix(0.0, 0.0, numpy.pi/2.0) + >>> is_same_quaternion(quaternion_from_matrix(R, isprecise=False), + ... 
quaternion_from_matrix(R, isprecise=True)) + True + """ + M = numpy.array(matrix, dtype=numpy.float64, copy=False)[:4, :4] + if isprecise: + q = numpy.empty((4, )) + t = numpy.trace(M) + if t > M[3, 3]: + q[0] = t + q[3] = M[1, 0] - M[0, 1] + q[2] = M[0, 2] - M[2, 0] + q[1] = M[2, 1] - M[1, 2] + else: + i, j, k = 0, 1, 2 + if M[1, 1] > M[0, 0]: + i, j, k = 1, 2, 0 + if M[2, 2] > M[i, i]: + i, j, k = 2, 0, 1 + t = M[i, i] - (M[j, j] + M[k, k]) + M[3, 3] + q[i] = t + q[j] = M[i, j] + M[j, i] + q[k] = M[k, i] + M[i, k] + q[3] = M[k, j] - M[j, k] + q = q[[3, 0, 1, 2]] + q *= 0.5 / math.sqrt(t * M[3, 3]) + else: + m00 = M[0, 0] + m01 = M[0, 1] + m02 = M[0, 2] + m10 = M[1, 0] + m11 = M[1, 1] + m12 = M[1, 2] + m20 = M[2, 0] + m21 = M[2, 1] + m22 = M[2, 2] + # symmetric matrix K + K = numpy.array([[m00-m11-m22, 0.0, 0.0, 0.0], + [m01+m10, m11-m00-m22, 0.0, 0.0], + [m02+m20, m12+m21, m22-m00-m11, 0.0], + [m21-m12, m02-m20, m10-m01, m00+m11+m22]]) + K /= 3.0 + # quaternion is eigenvector of K that corresponds to largest eigenvalue + w, V = numpy.linalg.eigh(K) + q = V[[3, 0, 1, 2], numpy.argmax(w)] + if q[0] < 0.0: + numpy.negative(q, q) + return q diff --git a/video3d/dataloaders.py b/video3d/dataloaders.py new file mode 100755 index 0000000000000000000000000000000000000000..99e111820cfb684ae05bd261428d5363f2be3c04 --- /dev/null +++ b/video3d/dataloaders.py @@ -0,0 +1,375 @@ +import os +from glob import glob +import random +import numpy as np +from PIL import Image +import cv2 +import torch +from torch.utils.data import Dataset +import torchvision.datasets.folder +import torchvision.transforms as transforms +from einops import rearrange + + +def compute_distance_transform(mask): + mask_dt = [] + for m in mask: + dt = torch.FloatTensor(cv2.distanceTransform(np.uint8(m[0]), cv2.DIST_L2, cv2.DIST_MASK_PRECISE)) + inv_dt = torch.FloatTensor(cv2.distanceTransform(np.uint8(1 - m[0]), cv2.DIST_L2, cv2.DIST_MASK_PRECISE)) + mask_dt += [torch.stack([dt, inv_dt], 0)] + return torch.stack(mask_dt, 0) # Bx2xHxW + + +def crop_image(image, boxs, size): + crops = [] + for box in boxs: + crop_x0, crop_y0, crop_w, crop_h = box + crop = transforms.functional.resized_crop(image, crop_y0, crop_x0, crop_h, crop_w, size) + crop = transforms.functional.to_tensor(crop) + crops += [crop] + return torch.stack(crops, 0) + + +def box_loader(fpath): + box = np.loadtxt(fpath, 'str') + box[0] = box[0].split('_')[0] + return box.astype(np.float32) + + +def read_feat_from_img(path, n_channels): + feat = np.array(Image.open(path)) + return dencode_feat_from_img(feat, n_channels) + + +def dencode_feat_from_img(img, n_channels): + n_addon_channels = int(np.ceil(n_channels / 3) * 3) - n_channels + n_tiles = int((n_channels + n_addon_channels) / 3) + feat = rearrange(img, 'h (t w) c -> h w (t c)', t=n_tiles, c=3) + feat = feat[:, :, :-n_addon_channels] + feat = feat.astype('float32') / 255 + return feat.transpose(2, 0, 1) + + +def dino_loader(fpath, n_channels): + dino_map = read_feat_from_img(fpath, n_channels) + return dino_map + + +def get_valid_mask(boxs, image_size): + valid_masks = [] + for box in boxs: + crop_x0, crop_y0, crop_w, crop_h, full_w, full_h = box[1:7].int().numpy() + # Discard a small margin near the boundary. 
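+        # A ones-mask of the (slightly shrunk) full frame is padded with zeros and re-cropped with
+        # the same box, so crop pixels that fell outside the original image become invalid (zero).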
+ margin_w = int(crop_w * 0.02) + margin_h = int(crop_h * 0.02) + mask_full = torch.ones(full_h-margin_h*2, full_w-margin_w*2) + mask_full_pad = torch.nn.functional.pad(mask_full, (crop_w+margin_w, crop_w+margin_w, crop_h+margin_h, crop_h+margin_h), mode='constant', value=0.0) + mask_full_crop = mask_full_pad[crop_y0+crop_h:crop_y0+crop_h*2, crop_x0+crop_w:crop_x0+crop_w*2] + mask_crop = torch.nn.functional.interpolate(mask_full_crop[None, None, :, :], image_size, mode='nearest')[0,0] + valid_masks += [mask_crop] + return torch.stack(valid_masks, 0) # NxHxW + + +def horizontal_flip_box(box): + frame_id, crop_x0, crop_y0, crop_w, crop_h, full_w, full_h, sharpness, label = box.unbind(1) + box[:,1] = full_w - crop_x0 - crop_w # x0 + return box + + +def horizontal_flip_all(images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features=None, dino_clusters=None): + images = images.flip(3) # NxCxHxW + masks = masks.flip(3) # NxCxHxW + mask_dt = mask_dt.flip(3) # NxCxHxW + mask_valid = mask_valid.flip(2) # NxHxW + if flows.dim() > 1: + flows = flows.flip(3) # (N-1)x(x,y)xHxW + flows[:,0] *= -1 # invert delta x + bboxs = horizontal_flip_box(bboxs) # NxK + bg_images = bg_images.flip(3) # NxCxHxW + if dino_features.dim() > 1: + dino_features = dino_features.flip(3) + if dino_clusters.dim() > 1: + dino_clusters = dino_clusters.flip(3) + return images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters + + +class BaseSequenceDataset(Dataset): + def __init__(self, root, skip_beginning=4, skip_end=4, min_seq_len=10, debug_seq=False): + super().__init__() + + self.skip_beginning = skip_beginning + self.skip_end = skip_end + self.min_seq_len = min_seq_len + # self.pattern = "{:07d}_{}" + self.sequences = self._make_sequences(root) + + if debug_seq: + # self.sequences = [self.sequences[0][20:160]] * 100 + seq_len = 0 + while seq_len < min_seq_len: + i = np.random.randint(len(self.sequences)) + rand_seq = self.sequences[i] + seq_len = len(rand_seq) + self.sequences = [rand_seq] + + self.samples = [] + + def _make_sequences(self, path): + result = [] + for d in sorted(os.scandir(path), key=lambda e: e.name): + if d.is_dir(): + files = self._parse_folder(d) + if len(files) >= self.min_seq_len: + result.append(files) + return result + + def _parse_folder(self, path): + result = sorted(glob(os.path.join(path, '*'+self.image_loaders[0][0]))) + result = [p.replace(self.image_loaders[0][0], '{}') for p in result] + + if len(result) <= self.skip_beginning + self.skip_end: + return [] + if self.skip_end == 0: + return result[self.skip_beginning:] + return result[self.skip_beginning:-self.skip_end] + + def _load_ids(self, path_patterns, loaders, transform=None): + result = [] + for loader in loaders: + for p in path_patterns: + x = loader[1](p.format(loader[0]), *loader[2:]) + if transform: + x = transform(x) + result.append(x) + return tuple(result) + + def __len__(self): + return len(self.samples) + + def __getitem__(self, index): + raise NotImplemented("This is a base class and should not be used directly") + + +class NFrameSequenceDataset(BaseSequenceDataset): + def __init__(self, root, cat_name=None, num_sample_frames=2, skip_beginning=4, skip_end=4, min_seq_len=10, in_image_size=256, out_image_size=256, debug_seq=False, random_sample=False, shuffle=False, dense_sample=True, color_jitter=None, load_background=False, random_flip=False, rgb_suffix='.png', load_dino_feature=False, load_dino_cluster=False, dino_feature_dim=64, **kwargs): + self.cat_name = cat_name + 
self.image_loaders = [("rgb"+rgb_suffix, torchvision.datasets.folder.default_loader)] + self.mask_loaders = [("mask.png", torchvision.datasets.folder.default_loader)] + self.bbox_loaders = [("box.txt", box_loader)] + super().__init__(root, skip_beginning, skip_end, min_seq_len, debug_seq) + if num_sample_frames > 1: + self.flow_loaders = [("flow.png", cv2.imread, cv2.IMREAD_UNCHANGED)] + else: + self.flow_loaders = None + + self.num_sample_frames = num_sample_frames + self.random_sample = random_sample + if self.random_sample: + if shuffle: + random.shuffle(self.sequences) + self.samples = self.sequences + else: + for i, s in enumerate(self.sequences): + stride = 1 if dense_sample else self.num_sample_frames + self.samples += [(i, k) for k in range(0, len(s), stride)] + if shuffle: + random.shuffle(self.samples) + + self.in_image_size = in_image_size + self.out_image_size = out_image_size + self.load_background = load_background + self.color_jitter = color_jitter + self.image_transform = transforms.Compose([transforms.Resize(self.in_image_size), transforms.ToTensor()]) + self.mask_transform = transforms.Compose([transforms.Resize(self.out_image_size, interpolation=Image.NEAREST), transforms.ToTensor()]) + if self.flow_loaders is not None: + self.flow_transform = lambda x: (torch.FloatTensor(x.astype(np.float32)).flip(2)[:,:,:2] / 65535. ) *2 -1 + self.random_flip = random_flip + self.load_dino_feature = load_dino_feature + if load_dino_feature: + self.dino_feature_loaders = [(f"feat{dino_feature_dim}.png", dino_loader, dino_feature_dim)] + self.load_dino_cluster = load_dino_cluster + if load_dino_cluster: + self.dino_cluster_loaders = [("clusters.png", torchvision.datasets.folder.default_loader)] + + def __getitem__(self, index): + if self.random_sample: + seq_idx = index % len(self.sequences) + seq = self.sequences[seq_idx] + if len(seq) < self.num_sample_frames: + start_frame_idx = 0 + else: + start_frame_idx = np.random.randint(len(seq)-self.num_sample_frames+1) + paths = seq[start_frame_idx:start_frame_idx+self.num_sample_frames] + else: + seq_idx, start_frame_idx = self.samples[index % len(self.samples)] + seq = self.sequences[seq_idx] + # Handle edge case: when only last frame is left, sample last two frames, except if the sequence only has one frame + if len(seq) <= start_frame_idx +1: + start_frame_idx = max(0, start_frame_idx-1) + paths = seq[start_frame_idx:start_frame_idx+self.num_sample_frames] + + masks = torch.stack(self._load_ids(paths, self.mask_loaders, transform=self.mask_transform), 0) # load all images + mask_dt = compute_distance_transform(masks) + jitter = False + if self.color_jitter is not None: + prob, b, h = self.color_jitter + if np.random.rand() < prob: + jitter = True + color_jitter_tsf_fg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h)) + image_transform_fg = transforms.Compose([transforms.Resize(self.in_image_size), color_jitter_tsf_fg, transforms.ToTensor()]) + color_jitter_tsf_bg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h)) + image_transform_bg = transforms.Compose([transforms.Resize(self.in_image_size), color_jitter_tsf_bg, transforms.ToTensor()]) + if jitter: + images_fg = torch.stack(self._load_ids(paths, self.image_loaders, transform=image_transform_fg), 0) # load all images + images_bg = torch.stack(self._load_ids(paths, self.image_loaders, transform=image_transform_bg), 0) # load all images + images = images_fg * masks + images_bg * 
(1-masks) + else: + images = torch.stack(self._load_ids(paths, self.image_loaders, transform=self.image_transform), 0) # load all images + if len(paths) > 1: + flows = torch.stack(self._load_ids(paths[:-1], self.flow_loaders, transform=self.flow_transform), 0).permute(0,3,1,2) # load flow for first image, (N-1)x(x,y)xHxW, -1~1 + flows = torch.nn.functional.interpolate(flows, size=self.out_image_size, mode="bilinear") + else: + flows = torch.zeros(1) + bboxs = torch.stack(self._load_ids(paths, self.bbox_loaders, transform=torch.FloatTensor), 0) # load bounding boxes for all images + mask_valid = get_valid_mask(bboxs, (self.out_image_size, self.out_image_size)) # exclude pixels cropped outside the original image + if self.load_background: + bg_image = torchvision.datasets.folder.default_loader(os.path.join(os.path.dirname(paths[0]), 'background_frame.jpg')) + if jitter: + bg_image = color_jitter_tsf_bg(bg_image) + bg_images = crop_image(bg_image, bboxs[:, 1:5].int().numpy(), (self.in_image_size, self.in_image_size)) + else: + bg_images = torch.zeros_like(images) + if self.load_dino_feature: + dino_features = torch.stack(self._load_ids(paths, self.dino_feature_loaders, transform=torch.FloatTensor), 0) # BxFx64x224x224 + else: + dino_features = torch.zeros(1) + if self.load_dino_cluster: + dino_clusters = torch.stack(self._load_ids(paths, self.dino_cluster_loaders, transform=transforms.ToTensor()), 0) # BxFx3x55x55 + else: + dino_clusters = torch.zeros(1) + seq_idx = torch.LongTensor([seq_idx]) + frame_idx = torch.arange(start_frame_idx, start_frame_idx+len(paths)).long() + + if self.random_flip and np.random.rand() < 0.5: + images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters = horizontal_flip_all(images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters) + + ## pad shorter sequence + if len(paths) < self.num_sample_frames: + num_pad = self.num_sample_frames - len(paths) + images = torch.cat([images[:1]] *num_pad + [images], 0) + masks = torch.cat([masks[:1]] *num_pad + [masks], 0) + mask_dt = torch.cat([mask_dt[:1]] *num_pad + [mask_dt], 0) + mask_valid = torch.cat([mask_valid[:1]] *num_pad + [mask_valid], 0) + if flows.dim() > 1: + flows = torch.cat([flows[:1]*0] *num_pad + [flows], 0) + bboxs = torch.cat([bboxs[:1]] * num_pad + [bboxs], 0) + bg_images = torch.cat([bg_images[:1]] *num_pad + [bg_images], 0) + if dino_features.dim() > 1: + dino_features = torch.cat([dino_features[:1]] *num_pad + [dino_features], 0) + if dino_clusters.dim() > 1: + dino_clusters = torch.cat([dino_clusters[:1]] *num_pad + [dino_clusters], 0) + frame_idx = torch.cat([frame_idx[:1]] *num_pad + [frame_idx], 0) + + return images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters, seq_idx, frame_idx, self.cat_name + + +def get_sequence_loader(data_dir, **kwargs): + if isinstance(data_dir, dict): + loaders = [] + for k, v in data_dir.items(): + dataset= NFrameSequenceDataset(v, cat_name=k, **kwargs) + loader = torch.utils.data.DataLoader(dataset, batch_size=kwargs['batch_size'], shuffle=kwargs['shuffle'], num_workers=kwargs['num_workers'], pin_memory=True) + loaders += [loader] + return loaders + else: + return [get_sequence_loader_single(data_dir, **kwargs)] + + +def get_sequence_loader_single(data_dir, mode='all_frame', is_validation=False, batch_size=256, num_workers=4, in_image_size=256, out_image_size=256, debug_seq=False, num_sample_frames=2, skip_beginning=4, skip_end=4, min_seq_len=10, max_seq_len=256, 
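`get_sequence_loader` above returns a list of loaders: one per category when a dict of `{category: directory}` is passed (each dataset keeps its `cat_name`), otherwise a single-element list. A hypothetical call; the directories and hyper-parameter values are placeholders:

```python
# Illustrative usage; paths and keyword values are placeholders.
loaders = get_sequence_loader(
    {"horse": "data/horse_seqs", "cow": "data/cow_seqs"},
    num_sample_frames=2, batch_size=6, shuffle=True, num_workers=4,
    in_image_size=256, out_image_size=256,
    load_dino_feature=True, dino_feature_dim=16, rgb_suffix=".png",
)
for loader in loaders:          # one DataLoader per category
    for batch in loader:
        images = batch[0]       # (B, num_sample_frames, 3, H, W)
        break
```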
random_sample=False, shuffle=False, dense_sample=True, color_jitter=None, load_background=False, random_flip=False, rgb_suffix='.jpg', load_dino_feature=False, load_dino_cluster=False, dino_feature_dim=64): + if mode == 'n_frame': + dataset = NFrameSequenceDataset(data_dir, num_sample_frames=num_sample_frames, skip_beginning=skip_beginning, skip_end=skip_end, min_seq_len=min_seq_len, in_image_size=in_image_size, out_image_size=out_image_size, debug_seq=debug_seq, random_sample=random_sample, shuffle=shuffle, dense_sample=dense_sample, color_jitter=color_jitter, load_background=load_background, random_flip=random_flip, rgb_suffix=rgb_suffix, load_dino_feature=load_dino_feature, load_dino_cluster=load_dino_cluster, dino_feature_dim=dino_feature_dim) + else: + raise NotImplementedError + loader = torch.utils.data.DataLoader( + dataset, + batch_size=batch_size, + shuffle=not is_validation, + num_workers=num_workers, + pin_memory=True + ) + return loader + + +class ImageDataset(Dataset): + def __init__(self, root, is_validation=False, image_size=256, color_jitter=None): + super().__init__() + self.image_loader = ("rgb.jpg", torchvision.datasets.folder.default_loader) + self.mask_loader = ("mask.png", torchvision.datasets.folder.default_loader) + self.bbox_loader = ("box.txt", np.loadtxt, 'str') + self.samples = self._parse_folder(root) + self.image_size = image_size + self.color_jitter = color_jitter + self.image_transform = transforms.Compose([transforms.Resize(self.image_size), transforms.ToTensor()]) + self.mask_transform = transforms.Compose([transforms.Resize(self.image_size, interpolation=Image.NEAREST), transforms.ToTensor()]) + + def _parse_folder(self, path): + result = sorted(glob(os.path.join(path, '**/*'+self.image_loader[0]), recursive=True)) + result = [p.replace(self.image_loader[0], '{}') for p in result] + return result + + def _load_ids(self, path, loader, transform=None): + x = loader[1](path.format(loader[0]), *loader[2:]) + if transform: + x = transform(x) + return x + + def __len__(self): + return len(self.samples) + + def __getitem__(self, index): + path = self.samples[index % len(self.samples)] + masks = self._load_ids(path, self.mask_loader, transform=self.mask_transform).unsqueeze(0) + mask_dt = compute_distance_transform(masks) + jitter = False + if self.color_jitter is not None: + prob, b, h = self.color_jitter + if np.random.rand() < prob: + jitter = True + color_jitter_tsf_fg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h)) + image_transform_fg = transforms.Compose([transforms.Resize(self.image_size), color_jitter_tsf_fg, transforms.ToTensor()]) + color_jitter_tsf_bg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h)) + image_transform_bg = transforms.Compose([transforms.Resize(self.image_size), color_jitter_tsf_bg, transforms.ToTensor()]) + if jitter: + images_fg = self._load_ids(path, self.image_loader, transform=image_transform_fg).unsqueeze(0) + images_bg = self._load_ids(path, self.image_loader, transform=image_transform_bg).unsqueeze(0) + images = images_fg * masks + images_bg * (1-masks) + else: + images = self._load_ids(path, self.image_loader, transform=self.image_transform).unsqueeze(0) + flows = torch.zeros(1) + bboxs = self._load_ids(path, self.bbox_loader, transform=None) + bboxs[0] = '0' + bboxs = torch.FloatTensor(bboxs.astype('float')).unsqueeze(0) + bg_fpath = os.path.join(os.path.dirname(path), 'background_frame.jpg') + if 
os.path.isfile(bg_fpath): + bg_image = torchvision.datasets.folder.default_loader(bg_fpath) + if jitter: + bg_image = color_jitter_tsf_bg(bg_image) + bg_image = transforms.ToTensor()(bg_image) + else: + bg_image = images[0] + seq_idx = torch.LongTensor([index]) + frame_idx = torch.LongTensor([0]) + return images, masks, mask_dt, flows, bboxs, bg_image, seq_idx, frame_idx + + +def get_image_loader(data_dir, is_validation=False, batch_size=256, num_workers=4, image_size=256, color_jitter=None): + dataset = ImageDataset(data_dir, is_validation=is_validation, image_size=image_size, color_jitter=color_jitter) + + loader = torch.utils.data.DataLoader( + dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + pin_memory=True + ) + return loader diff --git a/video3d/dataloaders_ddp.py b/video3d/dataloaders_ddp.py new file mode 100755 index 0000000000000000000000000000000000000000..187be1a83fbe2d2ecac57b090e380109d8b22ebc --- /dev/null +++ b/video3d/dataloaders_ddp.py @@ -0,0 +1,1210 @@ +import os +from glob import glob +import random +import numpy as np +from PIL import Image +import cv2 +import itertools +import torch +import copy +from torch.utils.data import Dataset +import torchvision.datasets.folder +import torchvision.transforms as transforms +from einops import rearrange + + +def compute_distance_transform(mask): + mask_dt = [] + for m in mask: + dt = torch.FloatTensor(cv2.distanceTransform(np.uint8(m[0]), cv2.DIST_L2, cv2.DIST_MASK_PRECISE)) + inv_dt = torch.FloatTensor(cv2.distanceTransform(np.uint8(1 - m[0]), cv2.DIST_L2, cv2.DIST_MASK_PRECISE)) + mask_dt += [torch.stack([dt, inv_dt], 0)] + return torch.stack(mask_dt, 0) # Bx2xHxW + + +def crop_image(image, boxs, size): + crops = [] + for box in boxs: + crop_x0, crop_y0, crop_w, crop_h = box + crop = transforms.functional.resized_crop(image, crop_y0, crop_x0, crop_h, crop_w, size) + crop = transforms.functional.to_tensor(crop) + crops += [crop] + return torch.stack(crops, 0) + + +def box_loader(fpath): + box = np.loadtxt(fpath, 'str') + box[0] = box[0].split('_')[0] + return box.astype(np.float32) + + +def read_feat_from_img(path, n_channels): + feat = np.array(Image.open(path)) + return dencode_feat_from_img(feat, n_channels) + + +def dencode_feat_from_img(img, n_channels): + n_addon_channels = int(np.ceil(n_channels / 3) * 3) - n_channels + n_tiles = int((n_channels + n_addon_channels) / 3) + feat = rearrange(img, 'h (t w) c -> h w (t c)', t=n_tiles, c=3) + if n_addon_channels != 0: + feat = feat[:, :, :-n_addon_channels] + feat = feat.astype('float32') / 255 + return feat.transpose(2, 0, 1) + + +def dino_loader(fpath, n_channels): + dino_map = read_feat_from_img(fpath, n_channels) + return dino_map + + +def get_valid_mask(boxs, image_size): + valid_masks = [] + for box in boxs: + crop_x0, crop_y0, crop_w, crop_h, full_w, full_h = box[1:7].int().numpy() + margin_w = int(crop_w * 0.02) + margin_h = int(crop_h * 0.02) + mask_full = torch.ones(full_h-margin_h*2, full_w-margin_w*2) + mask_full_pad = torch.nn.functional.pad(mask_full, (crop_w+margin_w, crop_w+margin_w, crop_h+margin_h, crop_h+margin_h), mode='constant', value=0.0) + mask_full_crop = mask_full_pad[(crop_y0+crop_h):crop_y0+(crop_h*2), (crop_x0+crop_w):crop_x0+(crop_w*2)] + mask_crop = torch.nn.functional.interpolate(mask_full_crop[None, None, :, :], image_size, mode='nearest')[0,0] + valid_masks += [mask_crop] + return torch.stack(valid_masks, 0) # NxHxW + + +def horizontal_flip_box(box): + frame_id, crop_x0, crop_y0, crop_w, crop_h, full_w, full_h, 
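`read_feat_from_img` / `dencode_feat_from_img` above unpack a DINO feature map that was stored as an ordinary RGB PNG: the channel dimension is padded up to a multiple of 3 and the 3-channel groups are tiled side by side along the width. The encoder is not included in this patch; the following is only a sketch of what the inverse packing would look like:

```python
import numpy as np
from einops import rearrange

def encode_feat_to_img(feat):
    """Illustrative inverse of dencode_feat_from_img: feat is (C, H, W), values in [0, 1]."""
    n_channels = feat.shape[0]
    n_addon = int(np.ceil(n_channels / 3) * 3) - n_channels
    feat = feat.transpose(1, 2, 0)                          # (H, W, C)
    if n_addon != 0:
        pad = np.zeros((*feat.shape[:2], n_addon), dtype=feat.dtype)
        feat = np.concatenate([feat, pad], axis=2)          # pad channels to a multiple of 3
    img = rearrange(feat, 'h w (t c) -> h (t w) c', c=3)    # tile 3-channel groups along width
    return np.round(img * 255).astype(np.uint8)

feat = np.random.rand(16, 224, 224).astype(np.float32)      # e.g. a 16-channel feature map
img = encode_feat_to_img(feat)
assert np.allclose(dencode_feat_from_img(img, 16), np.round(feat * 255) / 255, atol=1e-6)
```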
sharpness, label = box.unbind(1) + box[:,1] = full_w - crop_x0 - crop_w # x0 + return box + + +def horizontal_flip_all(images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features=None, dino_clusters=None): + images = images.flip(3) # NxCxHxW + masks = masks.flip(3) # NxCxHxW + mask_dt = mask_dt.flip(3) # NxCxHxW + mask_valid = mask_valid.flip(2) # NxHxW + if flows.dim() > 1: + flows = flows.flip(3) # (N-1)x(x,y)xHxW + flows[:,0] *= -1 # invert delta x + bboxs = horizontal_flip_box(bboxs) # NxK + bg_images = bg_images.flip(3) # NxCxHxW + if dino_features.dim() > 1: + dino_features = dino_features.flip(3) + if dino_clusters.dim() > 1: + dino_clusters = dino_clusters.flip(3) + return images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters + + +def none_to_nan(x): + return torch.FloatTensor([float('nan')]) if x is None else x + + +class BaseSequenceDataset(Dataset): + def __init__(self, root, skip_beginning=4, skip_end=4, min_seq_len=10, debug_seq=False): + super().__init__() + + self.skip_beginning = skip_beginning + self.skip_end = skip_end + self.min_seq_len = min_seq_len + # self.pattern = "{:07d}_{}" + self.sequences = self._make_sequences(root) + + if debug_seq: + # self.sequences = [self.sequences[0][20:160]] * 100 + seq_len = 0 + while seq_len < min_seq_len: + i = np.random.randint(len(self.sequences)) + rand_seq = self.sequences[i] + seq_len = len(rand_seq) + self.sequences = [rand_seq] + + self.samples = [] + + def _make_sequences(self, path): + result = [] + for d in sorted(os.scandir(path), key=lambda e: e.name): + if d.is_dir(): + files = self._parse_folder(d) + if len(files) >= self.min_seq_len: + result.append(files) + return result + + def _parse_folder(self, path): + result = sorted(glob(os.path.join(path, '*'+self.image_loaders[0][0]))) + result = [p.replace(self.image_loaders[0][0], '{}') for p in result] + + if len(result) <= self.skip_beginning + self.skip_end: + return [] + if self.skip_end == 0: + return result[self.skip_beginning:] + return result[self.skip_beginning:-self.skip_end] + + def _load_ids(self, path_patterns, loaders, transform=None): + result = [] + for loader in loaders: + for p in path_patterns: + x = loader[1](p.format(loader[0]), *loader[2:]) + if transform: + x = transform(x) + result.append(x) + return tuple(result) + + def __len__(self): + return len(self.samples) + + def __getitem__(self, index): + raise NotImplemented("This is a base class and should not be used directly") + + +class NFrameSequenceDataset(BaseSequenceDataset): + def __init__(self, root, cat_name=None, num_sample_frames=2, skip_beginning=4, skip_end=4, min_seq_len=10, in_image_size=256, out_image_size=256, debug_seq=False, random_sample=False, shuffle=False, dense_sample=True, color_jitter=None, load_background=False, random_flip=False, rgb_suffix='.png', load_dino_feature=False, load_dino_cluster=False, dino_feature_dim=64, flow_bool=False, **kwargs): + self.cat_name = cat_name + self.flow_bool=flow_bool + + self.image_loaders = [("rgb"+rgb_suffix, torchvision.datasets.folder.default_loader)] + self.mask_loaders = [("mask.png", torchvision.datasets.folder.default_loader)] + self.bbox_loaders = [("box.txt", box_loader)] + super().__init__(root, skip_beginning, skip_end, min_seq_len, debug_seq) + # from IPython import embed; embed() + if flow_bool and num_sample_frames > 1: + self.flow_loaders = [("flow.png", cv2.imread, cv2.IMREAD_UNCHANGED)] + else: + self.flow_loaders = None + + self.num_sample_frames = num_sample_frames + 
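`none_to_nan` above exists because the default collate function cannot batch `None` values (for example a missing `cat_name`); swapping `None` for a single-element NaN tensor keeps every slot of the returned tuple stackable. For example:

```python
import torch

print(none_to_nan(None))        # tensor([nan]) -- placeholder that collates cleanly
x = torch.zeros(2, 3)
print(none_to_nan(x) is x)      # True -- non-None values pass through untouched
```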
self.random_sample = random_sample + if self.random_sample: + if shuffle: + random.shuffle(self.sequences) + self.samples = self.sequences + else: + + for i, s in enumerate(self.sequences): + stride = 1 if dense_sample else self.num_sample_frames + self.samples += [(i, k) for k in range(0, len(s), stride)] + if shuffle: + random.shuffle(self.samples) + + self.in_image_size = in_image_size + self.out_image_size = out_image_size + self.load_background = load_background + self.color_jitter = color_jitter + self.image_transform = transforms.Compose([transforms.Resize(self.in_image_size), transforms.ToTensor()]) + self.mask_transform = transforms.Compose([transforms.Resize(self.out_image_size, interpolation=Image.NEAREST), transforms.ToTensor()]) + if self.flow_loaders is not None: + self.flow_transform = lambda x: (torch.FloatTensor(x.astype(np.float32)).flip(2)[:,:,:2] / 65535. ) *2 -1 + self.random_flip = random_flip + self.load_dino_feature = load_dino_feature + if load_dino_feature: + self.dino_feature_loaders = [(f"feat{dino_feature_dim}.png", dino_loader, dino_feature_dim)] + self.load_dino_cluster = load_dino_cluster + if load_dino_cluster: + self.dino_cluster_loaders = [("clusters.png", torchvision.datasets.folder.default_loader)] + + def __getitem__(self, index): + if self.random_sample: + seq_idx = index % len(self.sequences) + seq = self.sequences[seq_idx] + if len(seq) < self.num_sample_frames: + start_frame_idx = 0 + else: + start_frame_idx = np.random.randint(len(seq)-self.num_sample_frames+1) + paths = seq[start_frame_idx:start_frame_idx+self.num_sample_frames] + else: + seq_idx, start_frame_idx = self.samples[index % len(self.samples)] + seq = self.sequences[seq_idx] + # Handle edge case: when only last frame is left, sample last two frames, except if the sequence only has one frame + if len(seq) <= start_frame_idx +1: + start_frame_idx = max(0, start_frame_idx-1) + paths = seq[start_frame_idx:start_frame_idx+self.num_sample_frames] + + masks = torch.stack(self._load_ids(paths, self.mask_loaders, transform=self.mask_transform), 0) # load all images + mask_dt = compute_distance_transform(masks) + jitter = False + if self.color_jitter is not None: + prob, b, h = self.color_jitter + if np.random.rand() < prob: + jitter = True + color_jitter_tsf_fg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h)) + image_transform_fg = transforms.Compose([transforms.Resize(self.in_image_size), color_jitter_tsf_fg, transforms.ToTensor()]) + color_jitter_tsf_bg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h)) + image_transform_bg = transforms.Compose([transforms.Resize(self.in_image_size), color_jitter_tsf_bg, transforms.ToTensor()]) + if jitter: + images_fg = torch.stack(self._load_ids(paths, self.image_loaders, transform=image_transform_fg), 0) # load all images + images_bg = torch.stack(self._load_ids(paths, self.image_loaders, transform=image_transform_bg), 0) # load all images + images = images_fg * masks + images_bg * (1-masks) + else: + images = torch.stack(self._load_ids(paths, self.image_loaders, transform=self.image_transform), 0) # load all images + if self.flow_bool==True and len(paths) > 1: + flows = torch.stack(self._load_ids(paths[:-1], self.flow_loaders, transform=self.flow_transform), 0).permute(0,3,1,2) # load flow for first image, (N-1)x(x,y)xHxW, -1~1 + flows = torch.nn.functional.interpolate(flows, size=self.out_image_size, mode="bilinear") + else: + flows = 
torch.zeros(1) + bboxs = torch.stack(self._load_ids(paths, self.bbox_loaders, transform=torch.FloatTensor), 0) # load bounding boxes for all images + mask_valid = get_valid_mask(bboxs, (self.out_image_size, self.out_image_size)) # exclude pixels cropped outside the original image + if self.load_background: + bg_image = torchvision.datasets.folder.default_loader(os.path.join(os.path.dirname(paths[0]), 'background_frame.jpg')) + if jitter: + bg_image = color_jitter_tsf_bg(bg_image) + bg_images = crop_image(bg_image, bboxs[:, 1:5].int().numpy(), (self.in_image_size, self.in_image_size)) + else: + bg_images = torch.zeros_like(images) + if self.load_dino_feature: + dino_paths = [ + x.replace( + "/viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new", + "/viscam/projects/articulated/zzli/data_dino_5000/7_cat" + ) + for x in paths + ] + dino_features = torch.stack(self._load_ids(dino_paths, self.dino_feature_loaders, transform=torch.FloatTensor), 0) + # dino_features = torch.stack(self._load_ids(paths, self.dino_feature_loaders, transform=torch.FloatTensor), 0) # BxFx64x224x224 + else: + dino_features = torch.zeros(1) + if self.load_dino_cluster: + dino_clusters = torch.stack(self._load_ids(paths, self.dino_cluster_loaders, transform=transforms.ToTensor()), 0) # BxFx3x55x55 + else: + dino_clusters = torch.zeros(1) + seq_idx = torch.LongTensor([seq_idx]) + frame_idx = torch.arange(start_frame_idx, start_frame_idx+len(paths)).long() + + if self.random_flip and np.random.rand() < 0.5: + images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters = horizontal_flip_all(images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters) + + ## pad shorter sequence + if len(paths) < self.num_sample_frames: + num_pad = self.num_sample_frames - len(paths) + images = torch.cat([images[:1]] *num_pad + [images], 0) + masks = torch.cat([masks[:1]] *num_pad + [masks], 0) + mask_dt = torch.cat([mask_dt[:1]] *num_pad + [mask_dt], 0) + mask_valid = torch.cat([mask_valid[:1]] *num_pad + [mask_valid], 0) + if flows.dim() > 1: + flows = torch.cat([flows[:1]*0] *num_pad + [flows], 0) + bboxs = torch.cat([bboxs[:1]] * num_pad + [bboxs], 0) + bg_images = torch.cat([bg_images[:1]] *num_pad + [bg_images], 0) + if dino_features.dim() > 1: + dino_features = torch.cat([dino_features[:1]] *num_pad + [dino_features], 0) + if dino_clusters.dim() > 1: + dino_clusters = torch.cat([dino_clusters[:1]] *num_pad + [dino_clusters], 0) + frame_idx = torch.cat([frame_idx[:1]] *num_pad + [frame_idx], 0) + + out = (*map(none_to_nan, (images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters, seq_idx, frame_idx, self.cat_name)), ) + return out + # return images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters, seq_idx, frame_idx, self.cat_name + + +def few_shot_box_loader(fpath): + box = np.loadtxt(fpath, 'str') + # box[0] = box[0].split('_')[0] + return box.astype(np.float32) + + +class FewShotImageDataset(Dataset): + def __init__(self, root, cat_name=None, cat_num=0, num_sample_frames=2, in_image_size=256, out_image_size=256, shuffle=False, color_jitter=None, load_background=False, random_flip=False, rgb_suffix='.png', load_dino_feature=False, dino_feature_dim=64, flow_bool=False, **kwargs): + super().__init__() + self.cat_name = cat_name + self.cat_num = cat_num # this is actually useless + self.flow_bool=flow_bool + + self.image_loaders = [("rgb"+rgb_suffix, 
torchvision.datasets.folder.default_loader)] + self.mask_loaders = [("mask.png", torchvision.datasets.folder.default_loader)] + self.bbox_loaders = [("box.txt", few_shot_box_loader)] + self.flow_loaders = None + + # get all the valid paths, since it's just image-wise, in get_item, we will make it like a len=1 sequence + result = sorted(glob(os.path.join(root, '*'+self.image_loaders[0][0]))) + result = [p.replace(self.image_loaders[0][0], '{}') for p in result] + self.sequences = result + + self.num_sample_frames = num_sample_frames + if shuffle: + random.shuffle(self.sequences) + self.samples = self.sequences + + self.in_image_size = in_image_size + self.out_image_size = out_image_size + self.load_background = load_background + self.color_jitter = color_jitter + self.image_transform = transforms.Compose([transforms.Resize(self.in_image_size), transforms.ToTensor()]) + self.mask_transform = transforms.Compose([transforms.Resize(self.out_image_size, interpolation=Image.NEAREST), transforms.ToTensor()]) + self.random_flip = random_flip + self.load_dino_feature = load_dino_feature + if load_dino_feature: + self.dino_feature_loaders = [(f"feat{dino_feature_dim}.png", dino_loader, dino_feature_dim)] + + def _load_ids(self, path_patterns, loaders, transform=None): + result = [] + for loader in loaders: + for p in path_patterns: + x = loader[1](p.format(loader[0]), *loader[2:]) + if transform: + x = transform(x) + result.append(x) + return tuple(result) + + def __len__(self): + return len(self.samples) + + def __getitem__(self, index): + paths = [self.samples[index]] # len 1 sequence + + masks = torch.stack(self._load_ids(paths, self.mask_loaders, transform=self.mask_transform), 0) # load all images + mask_dt = compute_distance_transform(masks) + jitter = False + if self.color_jitter is not None: + prob, b, h = self.color_jitter + if np.random.rand() < prob: + jitter = True + color_jitter_tsf_fg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h)) + image_transform_fg = transforms.Compose([transforms.Resize(self.in_image_size), color_jitter_tsf_fg, transforms.ToTensor()]) + color_jitter_tsf_bg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h)) + image_transform_bg = transforms.Compose([transforms.Resize(self.in_image_size), color_jitter_tsf_bg, transforms.ToTensor()]) + if jitter: + images_fg = torch.stack(self._load_ids(paths, self.image_loaders, transform=image_transform_fg), 0) # load all images + images_bg = torch.stack(self._load_ids(paths, self.image_loaders, transform=image_transform_bg), 0) # load all images + images = images_fg * masks + images_bg * (1-masks) + else: + images = torch.stack(self._load_ids(paths, self.image_loaders, transform=self.image_transform), 0) # load all images + + flows = torch.zeros(1) + bboxs = torch.stack(self._load_ids(paths, self.bbox_loaders, transform=torch.FloatTensor), 0) # load bounding boxes for all images + bboxs=torch.cat([bboxs, torch.Tensor([[self.cat_num]]).float()],dim=-1) # pad a label number + + mask_valid = get_valid_mask(bboxs, (self.out_image_size, self.out_image_size)) # exclude pixels cropped outside the original image + if self.load_background: + bg_image = torchvision.datasets.folder.default_loader(os.path.join(os.path.dirname(paths[0]), 'background_frame.jpg')) + if jitter: + bg_image = color_jitter_tsf_bg(bg_image) + bg_images = crop_image(bg_image, bboxs[:, 1:5].int().numpy(), (self.in_image_size, self.in_image_size)) + else: + 
bg_images = torch.zeros_like(images) + if self.load_dino_feature: + dino_features = torch.stack(self._load_ids(paths, self.dino_feature_loaders, transform=torch.FloatTensor), 0) # BxFx64x224x224 + else: + dino_features = torch.zeros(1) + + dino_clusters = torch.zeros(1) + + # These are actually no use + seq_idx = 0 + seq_idx = torch.LongTensor([seq_idx]) + frame_idx = torch.arange(0, 1).long() + + if self.random_flip and np.random.rand() < 0.5: + images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters = horizontal_flip_all(images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters) + + ## pad shorter sequence + if len(paths) < self.num_sample_frames: + num_pad = self.num_sample_frames - len(paths) + images = torch.cat([images[:1]] *num_pad + [images], 0) + masks = torch.cat([masks[:1]] *num_pad + [masks], 0) + mask_dt = torch.cat([mask_dt[:1]] *num_pad + [mask_dt], 0) + mask_valid = torch.cat([mask_valid[:1]] *num_pad + [mask_valid], 0) + if flows.dim() > 1: + flows = torch.cat([flows[:1]*0] *num_pad + [flows], 0) + bboxs = torch.cat([bboxs[:1]] * num_pad + [bboxs], 0) + bg_images = torch.cat([bg_images[:1]] *num_pad + [bg_images], 0) + if dino_features.dim() > 1: + dino_features = torch.cat([dino_features[:1]] *num_pad + [dino_features], 0) + if dino_clusters.dim() > 1: + dino_clusters = torch.cat([dino_clusters[:1]] *num_pad + [dino_clusters], 0) + frame_idx = torch.cat([frame_idx[:1]] *num_pad + [frame_idx], 0) + + out = (*map(none_to_nan, (images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters, seq_idx, frame_idx, self.cat_name)), ) + return out + # return images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters, seq_idx, frame_idx, self.cat_name + + +class Quadrupeds_Image_Dataset(Dataset): + def __init__(self, original_data_dirs, few_shot_data_dirs, original_num=7, few_shot_num=93, num_sample_frames=2, + in_image_size=256, out_image_size=256, is_validation=False, val_image_num=5, shuffle=False, color_jitter=None, + load_background=False, random_flip=False, rgb_suffix='.png', load_dino_feature=False, dino_feature_dim=64, + flow_bool=False, disable_fewshot=False, dataset_split_num=-1, **kwargs): + self.original_data_dirs = original_data_dirs + self.few_shot_data_dirs = few_shot_data_dirs + self.original_num = original_num + self.few_shot_num = few_shot_num + + self.image_loaders = [("rgb"+rgb_suffix, torchvision.datasets.folder.default_loader)] + self.mask_loaders = [("mask.png", torchvision.datasets.folder.default_loader)] + self.original_bbox_loaders = [("box.txt", box_loader)] + self.few_shot_bbox_loaders = [("box.txt", few_shot_box_loader)] + + assert len(self.original_data_dirs.keys()) == self.original_num + assert len(self.few_shot_data_dirs.keys()) == self.few_shot_num + self.num_sample_frames = num_sample_frames + + self.batch_size = kwargs['batch_size'] # a hack way here + + # for debug, just use some categories + if "override_categories" in kwargs: + self.override_categories = kwargs["override_categories"] + else: + self.override_categories = None + + # original dataset + original_data_paths = {} + for k,v in self.original_data_dirs.items(): + + # categories override + if self.override_categories is not None: + if k not in self.override_categories: + continue + + sequences = self._make_sequences(v) + samples = [] + for seq in sequences: + samples += seq + if shuffle: + random.shuffle(samples) + original_data_paths.update({k: samples}) + + # 
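Both the few-shot image dataset above and the video datasets pad a clip that is shorter than `num_sample_frames` by repeating the first frame at the front (flow is padded with zeros instead, presumably because a duplicated frame carries no motion). A minimal sketch of the effect:

```python
import torch

num_sample_frames = 2
images = torch.randn(1, 3, 256, 256)        # a single-frame "sequence"
frame_idx = torch.arange(0, 1).long()

num_pad = num_sample_frames - images.shape[0]
images = torch.cat([images[:1]] * num_pad + [images], 0)
frame_idx = torch.cat([frame_idx[:1]] * num_pad + [frame_idx], 0)
print(images.shape, frame_idx.tolist())     # torch.Size([2, 3, 256, 256]) [0, 0]
```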
few-shot dataset + enhance_back_view = kwargs['enhance_back_view'] + if enhance_back_view: + enhance_back_view_path = kwargs['enhance_back_view_path'] + + few_shot_data_paths = {} + for k,v in self.few_shot_data_dirs.items(): + + # categories override + if self.override_categories is not None: + if k not in self.override_categories: + continue + if k.startswith('_'): + # a boundary here for dealing with when in new data, we have same categories as in 7-cat + v = v.replace(k, k[1:]) + + if isinstance(v, str): + result = sorted(glob(os.path.join(v, '*'+self.image_loaders[0][0]))) + elif isinstance(v, list): + result = [] + for _v in v: + result = result + sorted(glob(os.path.join(_v, '*'+self.image_loaders[0][0]))) + else: + raise NotImplementedError + + # result = sorted(glob(os.path.join(v, '*'+self.image_loaders[0][0]))) + result = [p.replace(self.image_loaders[0][0], '{}') for p in result] + sequences = result + + # the original 7 categories are using pre-defined paths to separate train and test + # here the few-shot we use is_validation to decide if this dataset is train or test + # if use enhanced back view, we first pad the multiplied back view image paths at the front of seq + # i.e., we don't use back view images for validation + if enhance_back_view: + back_view_dir = os.path.join(enhance_back_view_path, k, 'train') + back_view_result = sorted(glob(os.path.join(back_view_dir, '*'+self.image_loaders[0][0]))) + back_view_result = [p.replace(self.image_loaders[0][0], '{}') for p in back_view_result] + mul_bv_sequences = self._more_back_views(back_view_result, result) + sequences = mul_bv_sequences + sequences + + if is_validation: + # sequences = sequences[-2:] + sequences = sequences[-val_image_num:] + else: + # sequences = sequences[:-2] + sequences = sequences[:-val_image_num] + + if shuffle: + random.shuffle(sequences) + few_shot_data_paths.update({k: sequences}) + + # for visualization purpose + self.pure_ori_data_path = original_data_paths + self.pure_fs_data_path = few_shot_data_paths + + self.few_shot_data_length = self._get_data_length(few_shot_data_paths) # get the original length of each few-shot category + + if disable_fewshot: + few_shot_data_paths = {} + + self.dataset_split_num = dataset_split_num # if -1 then pad to longest, otherwise follow this number to pad and split + if is_validation: + self.dataset_split_num = -1 # validation we don't split dataset + + if self.dataset_split_num == -1: + self.all_data_paths, self.one_category_num = self._pad_paths(original_data_paths, few_shot_data_paths) + self.all_category_num = len(self.all_data_paths.keys()) + self.all_category_names = list(self.all_data_paths.keys()) + self.original_category_names = list(self.original_data_dirs.keys()) + elif self.dataset_split_num > 0: + self.all_data_paths, self.one_category_num, self.original_category_names = self._pad_paths_withnum(original_data_paths, few_shot_data_paths, self.dataset_split_num) + self.all_category_num = len(self.all_data_paths.keys()) + self.all_category_names = list(self.all_data_paths.keys()) + else: + raise NotImplementedError + + self.in_image_size = in_image_size + self.out_image_size = out_image_size + self.load_background = load_background + self.color_jitter = color_jitter + self.image_transform = transforms.Compose([transforms.Resize(self.in_image_size), transforms.ToTensor()]) + self.mask_transform = transforms.Compose([transforms.Resize(self.out_image_size, interpolation=Image.NEAREST), transforms.ToTensor()]) + self.random_flip = random_flip + 
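The `_pad_paths` helper used above (defined just below) equalizes the number of image paths per category so every category can contribute the same number of batches: the target is the largest category count rounded down to a multiple of `batch_size`; smaller categories are repeated (reshuffling between repeats) and larger ones truncated. With hypothetical counts:

```python
batch_size = 6
counts = {"zebra": 20, "giraffe": 50}                        # hypothetical per-category image counts
img_num = (max(counts.values()) // batch_size) * batch_size  # 48
for cat, n in counts.items():
    if n < img_num:
        mul_time, pad_time = img_num // n, img_num % n
        print(cat, "-> repeated", mul_time, "times plus", pad_time, "extra =", mul_time * n + pad_time)
    elif n > img_num:
        print(cat, "-> truncated to", img_num)
# zebra -> repeated 2 times plus 8 extra = 48
# giraffe -> truncated to 48
```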
self.load_dino_feature = load_dino_feature + if load_dino_feature: + self.dino_feature_loaders = [(f"feat{dino_feature_dim}.png", dino_loader, dino_feature_dim)] + + def _more_back_views(self, back_view_seq, seq): + if len(back_view_seq) == 0: + # for category without back views + return [] + factor = 5 + # length = (len(seq) // factor) * factor + length = (len(seq) // factor) * (factor - 1) + mul_f = length // len(back_view_seq) + pad_f = length % len(back_view_seq) + new_seq = mul_f * back_view_seq + back_view_seq[:pad_f] + return new_seq + + def _get_data_length(self, paths): + data_length = {} + for k,v in paths.items(): + length = len(v) + data_length.update({k: length}) + return data_length + + def _make_sequences(self, path): + result = [] + for d in sorted(os.scandir(path), key=lambda e: e.name): + if d.is_dir(): + files = self._parse_folder(d) + if len(files) >= 1: + result.append(files) + return result + + def _parse_folder(self, path): + result = sorted(glob(os.path.join(path, '*'+self.image_loaders[0][0]))) + result = [p.replace(self.image_loaders[0][0], '{}') for p in result] + + if len(result) <= 0: + return [] + return result + + def _pad_paths(self, ori_paths, fs_paths): + img_nums = [] + all_paths = copy.deepcopy(ori_paths) + all_paths.update(fs_paths) + for _, v in all_paths.items(): + img_nums.append(len(v)) + + img_num = max(img_nums) + img_num = (img_num // self.batch_size) * self.batch_size + + for k,v in all_paths.items(): + if len(v) < img_num: + mul_time = img_num // len(v) + pad_time = img_num % len(v) + # for each v, shuffle it + shuffle_v = copy.deepcopy(v) + new_v = [] + for i in range(mul_time): + new_v = new_v + shuffle_v + random.shuffle(shuffle_v) + del shuffle_v + new_v = new_v + v[0:pad_time] + # new_v = mul_time * v + v[0:pad_time] + all_paths[k] = new_v + elif len(v) > img_num: + all_paths[k] = v[:img_num] + else: + continue + + return all_paths, img_num + + def _pad_paths_withnum(self, ori_paths, fs_paths, split_num=1000): + img_num = (split_num // self.batch_size) * self.batch_size + all_paths = {} + orig_cat_names = [] + + for k, v in ori_paths.items(): + total_num = ((len(v) // img_num) + 1) * img_num + pad_num = total_num - len(v) + split_num = total_num // img_num + + new_v = copy.deepcopy(v) + random.shuffle(new_v) + all_v = v + new_v[:pad_num] + del new_v + + for sn in range(split_num): + split_cat_name = f'{k}_' + '%03d' % sn + all_paths.update({ + split_cat_name: all_v[sn*img_num: (sn+1)*img_num] + }) + orig_cat_names.append(split_cat_name) + + for k, v in fs_paths.items(): + if len(v) < img_num: + mul_time = img_num // len(v) + pad_time = img_num % len(v) + # for each v, shuffle it + shuffle_v = copy.deepcopy(v) + new_v = [] + for i in range(mul_time): + new_v = new_v + shuffle_v + random.shuffle(shuffle_v) + del shuffle_v + new_v = new_v + v[0:pad_time] + # new_v = mul_time * v + v[0:pad_time] + all_paths.update({ + k: new_v + }) + elif len(v) > img_num: + all_paths.update({ + k: v[:img_num] + }) + else: + continue + + return all_paths, img_num, orig_cat_names + + + def _load_ids(self, path_patterns, loaders, transform=None): + result = [] + for loader in loaders: + for p in path_patterns: + x = loader[1](p.format(loader[0]), *loader[2:]) + if transform: + x = transform(x) + result.append(x) + return tuple(result) + + def _shuffle_all(self): + for k,v in self.all_data_paths.items(): + new_v = copy.deepcopy(v) + random.shuffle(new_v) + self.all_data_paths[k] = new_v + return None + + def __len__(self): + return self.all_category_num * 
self.one_category_num + + def __getitem__(self, index): + ''' + This dataset must have non-shuffled index!! + ''' + category_idx = (index % (self.batch_size * self.all_category_num)) // self.batch_size + path_idx = (index // (self.batch_size * self.all_category_num)) * self.batch_size + (index % (self.batch_size * self.all_category_num)) - category_idx * self.batch_size + category_name = self.all_category_names[category_idx] + paths = [self.all_data_paths[category_name][path_idx]] # len 1 sequence + + if category_name in self.original_category_names: + bbox_loaders = self.original_bbox_loaders + use_original_bbox = True + else: + bbox_loaders = self.few_shot_bbox_loaders + use_original_bbox = False + + masks = torch.stack(self._load_ids(paths, self.mask_loaders, transform=self.mask_transform), 0) # load all images + mask_dt = compute_distance_transform(masks) + jitter = False + if self.color_jitter is not None: + prob, b, h = self.color_jitter + if np.random.rand() < prob: + jitter = True + color_jitter_tsf_fg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h)) + image_transform_fg = transforms.Compose([transforms.Resize(self.in_image_size), color_jitter_tsf_fg, transforms.ToTensor()]) + color_jitter_tsf_bg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h)) + image_transform_bg = transforms.Compose([transforms.Resize(self.in_image_size), color_jitter_tsf_bg, transforms.ToTensor()]) + if jitter: + images_fg = torch.stack(self._load_ids(paths, self.image_loaders, transform=image_transform_fg), 0) # load all images + images_bg = torch.stack(self._load_ids(paths, self.image_loaders, transform=image_transform_bg), 0) # load all images + images = images_fg * masks + images_bg * (1-masks) + else: + images = torch.stack(self._load_ids(paths, self.image_loaders, transform=self.image_transform), 0) # load all images + + flows = torch.zeros(1) + bboxs = torch.stack(self._load_ids(paths, bbox_loaders, transform=torch.FloatTensor), 0) # load bounding boxes for all images + if not use_original_bbox: + bboxs=torch.cat([bboxs, torch.Tensor([[category_idx]]).float()],dim=-1) # pad a label number + + mask_valid = get_valid_mask(bboxs, (self.out_image_size, self.out_image_size)) # exclude pixels cropped outside the original image + if self.load_background: + bg_image = torchvision.datasets.folder.default_loader(os.path.join(os.path.dirname(paths[0]), 'background_frame.jpg')) + if jitter: + bg_image = color_jitter_tsf_bg(bg_image) + bg_images = crop_image(bg_image, bboxs[:, 1:5].int().numpy(), (self.in_image_size, self.in_image_size)) + else: + bg_images = torch.zeros_like(images) + if self.load_dino_feature: + # print(paths) + new_dino_data_name = "data_dino_5000" + new_dino_data_path = os.path.join("/viscam/projects/articulated/dor/combine_all_data_for_ablation_magicpony", new_dino_data_name) + + # TODO: use another version of DINO here by changing the path + if paths[0].startswith("/viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new"): + # 7 cat data + new_dino_path = paths[0].replace( + "/viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new", + "/viscam/projects/articulated/zzli/data_dino_5000/7_cat" + ) + dino_paths = [new_dino_path] + elif paths[0].startswith("/viscam/u/zzli/workspace/Animal-Data-Engine/data/data_resize_update/few_shot_data_all"): + # 100 cat + dino_path = paths[0].replace( + 
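The index arithmetic in `__getitem__` above relies on the sampler not shuffling (see the docstring): consecutive indices are grouped into blocks of `batch_size`, the blocks rotate through the categories, and after one full rotation the per-category path index advances by `batch_size`, so each batch is single-category. A small trace with hypothetical sizes:

```python
batch_size, all_category_num = 2, 3       # hypothetical; the real values come from the config
for index in range(12):
    r = index % (batch_size * all_category_num)
    category_idx = r // batch_size
    path_idx = (index // (batch_size * all_category_num)) * batch_size + r - category_idx * batch_size
    print(index, category_idx, path_idx)
# indices 0-1 -> category 0, paths 0-1; indices 2-3 -> category 1, paths 0-1; ...
# indices 6-7 -> category 0 again, but now paths 2-3.
```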
"/viscam/u/zzli/workspace/Animal-Data-Engine/data/data_resize_update/few_shot_data_all", + os.path.join(new_dino_data_path, "100_cat") + ) + dino_path_list = dino_path.split("/") + new_dino_path = dino_path_list[:-2] + dino_path_list[-1:] # remove "/train/" + new_dino_path = '/'.join(new_dino_path) + dino_paths = [new_dino_path] + + elif paths[0].startswith("/viscam/projects/articulated/zzli/fs_data/data_resize_update/few_shot_data_all"): + # 100 cat + dino_path = paths[0].replace( + "/viscam/projects/articulated/zzli/fs_data/data_resize_update/few_shot_data_all", + os.path.join(new_dino_data_path, "100_cat") + ) + dino_path_list = dino_path.split("/") + new_dino_path = dino_path_list[:-2] + dino_path_list[-1:] # remove "/train/" + new_dino_path = '/'.join(new_dino_path) + dino_paths = [new_dino_path] + + elif paths[0].startswith("/viscam/u/zzli/workspace/Animal-Data-Engine/data/data_resize_update/segmented_back_view_data"): + # back 100 cat + dino_path = paths[0].replace( + "/viscam/u/zzli/workspace/Animal-Data-Engine/data/data_resize_update/segmented_back_view_data", + os.path.join(new_dino_data_path, "back_100_cat") + ) + dino_path_list = dino_path.split("/") + new_dino_path = dino_path_list[:-2] + dino_path_list[-1:] # remove "/train/" + new_dino_path = '/'.join(new_dino_path) + dino_paths = [new_dino_path] + + elif paths[0].startswith("/viscam/projects/articulated/dor/Animal-Data-Engine/data/data_resize_update/train_with_classes_filtered"): + # animal3d + dino_path = paths[0].replace( + "/viscam/projects/articulated/dor/Animal-Data-Engine/data/data_resize_update/train_with_classes_filtered", + os.path.join(new_dino_data_path, "animal3D") + ) + dino_path_list = dino_path.split("/") + new_dino_path = dino_path_list[:-2] + dino_path_list[-1:] # remove "/train/" + new_dino_path = '/'.join(new_dino_path) + dino_paths = [new_dino_path] + else: + raise NotImplementedError + dino_features = torch.stack(self._load_ids(dino_paths, self.dino_feature_loaders, transform=torch.FloatTensor), 0) + # dino_features = torch.stack(self._load_ids(paths, self.dino_feature_loaders, transform=torch.FloatTensor), 0) # BxFx64x224x224 + else: + dino_features = torch.zeros(1) + + dino_clusters = torch.zeros(1) + + # These are actually no use + seq_idx = 0 + seq_idx = torch.LongTensor([seq_idx]) + frame_idx = torch.arange(0, 1).long() + + if self.random_flip and np.random.rand() < 0.5: + images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters = horizontal_flip_all(images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters) + + ## pad shorter sequence + if len(paths) < self.num_sample_frames: + num_pad = self.num_sample_frames - len(paths) + images = torch.cat([images[:1]] *num_pad + [images], 0) + masks = torch.cat([masks[:1]] *num_pad + [masks], 0) + mask_dt = torch.cat([mask_dt[:1]] *num_pad + [mask_dt], 0) + mask_valid = torch.cat([mask_valid[:1]] *num_pad + [mask_valid], 0) + if flows.dim() > 1: + flows = torch.cat([flows[:1]*0] *num_pad + [flows], 0) + bboxs = torch.cat([bboxs[:1]] * num_pad + [bboxs], 0) + bg_images = torch.cat([bg_images[:1]] *num_pad + [bg_images], 0) + if dino_features.dim() > 1: + dino_features = torch.cat([dino_features[:1]] *num_pad + [dino_features], 0) + if dino_clusters.dim() > 1: + dino_clusters = torch.cat([dino_clusters[:1]] *num_pad + [dino_clusters], 0) + frame_idx = torch.cat([frame_idx[:1]] *num_pad + [frame_idx], 0) + + out = (*map(none_to_nan, (images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, 
dino_features, dino_clusters, seq_idx, frame_idx, category_name)), ) + return out + # return images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters, seq_idx, frame_idx, category_name + +def get_sequence_loader_quadrupeds(original_data_dirs, few_shot_data_dirs, original_num, few_shot_num, rank, world_size, **kwargs): + dataset = Quadrupeds_Image_Dataset(original_data_dirs, few_shot_data_dirs, original_num, few_shot_num, **kwargs) + sampler = torch.utils.data.distributed.DistributedSampler( + dataset, + num_replicas=world_size, + rank=rank, + shuffle=False + ) + loaders = [] + loaders += [torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=kwargs['batch_size'], shuffle=False, drop_last=True, num_workers=kwargs['num_workers'], pin_memory=True)] + + return loaders + + +class Quadrupeds_Image_Test_Dataset(Dataset): + def __init__(self, test_data_dirs, num_sample_frames=2, in_image_size=256, out_image_size=256, shuffle=False, color_jitter=None, load_background=False, random_flip=False, rgb_suffix='.png', load_dino_feature=False, dino_feature_dim=64, flow_bool=False, **kwargs): + self.few_shot_data_dirs = test_data_dirs + + self.image_loaders = [("rgb"+rgb_suffix, torchvision.datasets.folder.default_loader)] + self.mask_loaders = [("mask.png", torchvision.datasets.folder.default_loader)] + self.original_bbox_loaders = [("box.txt", box_loader)] + self.few_shot_bbox_loaders = [("box.txt", few_shot_box_loader)] + + self.num_sample_frames = num_sample_frames + + self.batch_size = kwargs['batch_size'] # a hack way here + + few_shot_data_paths = {} + for k,v in self.few_shot_data_dirs.items(): + + if k.startswith('_'): + # a boundary here for dealing with when in new data, we have same categories as in 7-cat + v = v.replace(k, k[1:]) + + if isinstance(v, str): + result = sorted(glob(os.path.join(v, '*'+self.image_loaders[0][0]))) + elif isinstance(v, list): + result = [] + for _v in v: + result = result + sorted(glob(os.path.join(_v, '*'+self.image_loaders[0][0]))) + else: + raise NotImplementedError + + # result = sorted(glob(os.path.join(v, '*'+self.image_loaders[0][0]))) + result = [p.replace(self.image_loaders[0][0], '{}') for p in result] + sequences = result + + if shuffle: + random.shuffle(sequences) + few_shot_data_paths.update({k: sequences}) + + # for visualization purpose + self.pure_fs_data_path = few_shot_data_paths + + self.all_data_paths, self.one_category_num = self._pad_paths(few_shot_data_paths) + self.all_category_num = len(self.all_data_paths.keys()) + self.all_category_names = list(self.all_data_paths.keys()) + + self.in_image_size = in_image_size + self.out_image_size = out_image_size + self.load_background = load_background + self.color_jitter = color_jitter + self.image_transform = transforms.Compose([transforms.Resize(self.in_image_size), transforms.ToTensor()]) + self.mask_transform = transforms.Compose([transforms.Resize(self.out_image_size, interpolation=Image.NEAREST), transforms.ToTensor()]) + self.random_flip = random_flip + self.load_dino_feature = load_dino_feature + if load_dino_feature: + self.dino_feature_loaders = [(f"feat{dino_feature_dim}.png", dino_loader, dino_feature_dim)] + + def _pad_paths(self, fs_paths): + img_nums = [] + all_paths = copy.deepcopy(fs_paths) + for _, v in all_paths.items(): + img_nums.append(len(v)) + + img_num = max(img_nums) + img_num = (img_num // self.batch_size) * self.batch_size + + for k,v in all_paths.items(): + if len(v) < img_num: + mul_time = img_num // len(v) + pad_time = img_num % 
len(v) + # for each v, shuffle it + shuffle_v = copy.deepcopy(v) + new_v = [] + for i in range(mul_time): + new_v = new_v + shuffle_v + random.shuffle(shuffle_v) + del shuffle_v + new_v = new_v + v[0:pad_time] + # new_v = mul_time * v + v[0:pad_time] + all_paths[k] = new_v + elif len(v) > img_num: + all_paths[k] = v[:img_num] + else: + continue + + return all_paths, img_num + + def _load_ids(self, path_patterns, loaders, transform=None): + result = [] + for loader in loaders: + for p in path_patterns: + x = loader[1](p.format(loader[0]), *loader[2:]) + if transform: + x = transform(x) + result.append(x) + return tuple(result) + + def _shuffle_all(self): + for k,v in self.all_data_paths.items(): + new_v = copy.deepcopy(v) + random.shuffle(new_v) + self.all_data_paths[k] = new_v + return None + + def __len__(self): + return self.all_category_num * self.one_category_num + + def __getitem__(self, index): + ''' + This dataset must have non-shuffled index!! + ''' + category_idx = (index % (self.batch_size * self.all_category_num)) // self.batch_size + path_idx = (index // (self.batch_size * self.all_category_num)) * self.batch_size + (index % (self.batch_size * self.all_category_num)) - category_idx * self.batch_size + category_name = self.all_category_names[category_idx] + paths = [self.all_data_paths[category_name][path_idx]] # len 1 sequence + + # if category_name in self.original_category_names: + # bbox_loaders = self.original_bbox_loaders + # use_original_bbox = True + # else: + bbox_loaders = self.few_shot_bbox_loaders + use_original_bbox = False + + masks = torch.stack(self._load_ids(paths, self.mask_loaders, transform=self.mask_transform), 0) # load all images + mask_dt = compute_distance_transform(masks) + jitter = False + if self.color_jitter is not None: + prob, b, h = self.color_jitter + if np.random.rand() < prob: + jitter = True + color_jitter_tsf_fg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h)) + image_transform_fg = transforms.Compose([transforms.Resize(self.in_image_size), color_jitter_tsf_fg, transforms.ToTensor()]) + color_jitter_tsf_bg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h)) + image_transform_bg = transforms.Compose([transforms.Resize(self.in_image_size), color_jitter_tsf_bg, transforms.ToTensor()]) + if jitter: + images_fg = torch.stack(self._load_ids(paths, self.image_loaders, transform=image_transform_fg), 0) # load all images + images_bg = torch.stack(self._load_ids(paths, self.image_loaders, transform=image_transform_bg), 0) # load all images + images = images_fg * masks + images_bg * (1-masks) + else: + images = torch.stack(self._load_ids(paths, self.image_loaders, transform=self.image_transform), 0) # load all images + + flows = torch.zeros(1) + bboxs = torch.stack(self._load_ids(paths, bbox_loaders, transform=torch.FloatTensor), 0) # load bounding boxes for all images + if not use_original_bbox: + bboxs=torch.cat([bboxs, torch.Tensor([[category_idx]]).float()],dim=-1) # pad a label number + + mask_valid = get_valid_mask(bboxs, (self.out_image_size, self.out_image_size)) # exclude pixels cropped outside the original image + if self.load_background: + bg_image = torchvision.datasets.folder.default_loader(os.path.join(os.path.dirname(paths[0]), 'background_frame.jpg')) + if jitter: + bg_image = color_jitter_tsf_bg(bg_image) + bg_images = crop_image(bg_image, bboxs[:, 1:5].int().numpy(), (self.in_image_size, self.in_image_size)) + else: + bg_images = 
torch.zeros_like(images) + if self.load_dino_feature: + dino_features = torch.stack(self._load_ids(paths, self.dino_feature_loaders, transform=torch.FloatTensor), 0) # BxFx64x224x224 + else: + dino_features = torch.zeros(1) + + dino_clusters = torch.zeros(1) + + # These are actually no use + seq_idx = 0 + seq_idx = torch.LongTensor([seq_idx]) + frame_idx = torch.arange(0, 1).long() + + if self.random_flip and np.random.rand() < 0.5: + images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters = horizontal_flip_all(images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters) + + ## pad shorter sequence + if len(paths) < self.num_sample_frames: + num_pad = self.num_sample_frames - len(paths) + images = torch.cat([images[:1]] *num_pad + [images], 0) + masks = torch.cat([masks[:1]] *num_pad + [masks], 0) + mask_dt = torch.cat([mask_dt[:1]] *num_pad + [mask_dt], 0) + mask_valid = torch.cat([mask_valid[:1]] *num_pad + [mask_valid], 0) + if flows.dim() > 1: + flows = torch.cat([flows[:1]*0] *num_pad + [flows], 0) + bboxs = torch.cat([bboxs[:1]] * num_pad + [bboxs], 0) + bg_images = torch.cat([bg_images[:1]] *num_pad + [bg_images], 0) + if dino_features.dim() > 1: + dino_features = torch.cat([dino_features[:1]] *num_pad + [dino_features], 0) + if dino_clusters.dim() > 1: + dino_clusters = torch.cat([dino_clusters[:1]] *num_pad + [dino_clusters], 0) + frame_idx = torch.cat([frame_idx[:1]] *num_pad + [frame_idx], 0) + + out = (*map(none_to_nan, (images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters, seq_idx, frame_idx, category_name)), ) + return out + # return images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters, seq_idx, frame_idx, category_name + + + +def get_test_loader_quadrupeds(test_data_dirs, rank, world_size, **kwargs): + dataset = Quadrupeds_Image_Test_Dataset(test_data_dirs, **kwargs) + sampler = torch.utils.data.distributed.DistributedSampler( + dataset, + num_replicas=world_size, + rank=rank, + shuffle=False + ) + loaders = [] + loaders += [torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=kwargs['batch_size'], shuffle=False, drop_last=True, num_workers=kwargs['num_workers'], pin_memory=True)] + + return loaders + +def get_sequence_loader(data_dir, **kwargs): + if isinstance(data_dir, dict): + loaders = [] + for k, v in data_dir.items(): + dataset= NFrameSequenceDataset(v, cat_name=k, **kwargs) + loader = torch.utils.data.DataLoader(dataset, batch_size=kwargs['batch_size'], shuffle=kwargs['shuffle'], num_workers=kwargs['num_workers'], pin_memory=True) + loaders += [loader] + return loaders + else: + return [get_sequence_loader_single(data_dir, **kwargs)] + + +def get_sequence_loader_single(data_dir, mode='all_frame', is_validation=False, batch_size=256, num_workers=4, in_image_size=256, out_image_size=256, debug_seq=False, num_sample_frames=2, skip_beginning=4, skip_end=4, min_seq_len=10, max_seq_len=256, random_sample=False, shuffle=False, dense_sample=True, color_jitter=None, load_background=False, random_flip=False, rgb_suffix='.jpg', load_dino_feature=False, load_dino_cluster=False, dino_feature_dim=64): + if mode == 'n_frame': + dataset = NFrameSequenceDataset(data_dir, num_sample_frames=num_sample_frames, skip_beginning=skip_beginning, skip_end=skip_end, min_seq_len=min_seq_len, in_image_size=in_image_size, out_image_size=out_image_size, debug_seq=debug_seq, random_sample=random_sample, shuffle=shuffle, dense_sample=dense_sample, 
color_jitter=color_jitter, load_background=load_background, random_flip=random_flip, rgb_suffix=rgb_suffix, load_dino_feature=load_dino_feature, load_dino_cluster=load_dino_cluster, dino_feature_dim=dino_feature_dim) + else: + raise NotImplementedError + loader = torch.utils.data.DataLoader( + dataset, + batch_size=batch_size, + shuffle=not is_validation, + num_workers=num_workers, + pin_memory=True + ) + return loader + + +def get_sequence_loader_ddp(data_dir, world_size, rank, use_few_shot=False, **kwargs): + original_classes_num = 0 + use_few_shot = use_few_shot + if isinstance(data_dir, list) and len(data_dir) == 2 and isinstance(data_dir[-1], dict): + # a hack way for few shot experiment + original_classes_num = data_dir[0] + data_dir = data_dir[-1] + if isinstance(data_dir, dict): + loaders = [] + cnt = original_classes_num + for k, v in data_dir.items(): + if use_few_shot: + dataset = FewShotImageDataset(v, cat_name=k, cat_num=cnt, **kwargs) + cnt += 1 + else: + dataset = NFrameSequenceDataset(v, cat_name=k, **kwargs) + sampler = torch.utils.data.distributed.DistributedSampler( + dataset, + num_replicas=world_size, + rank=rank, + ) + loaders += [torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=kwargs['batch_size'], shuffle=False, drop_last=True, num_workers=kwargs['num_workers'], pin_memory=True)] + return loaders + else: + return [get_sequence_loader_single_ddp(data_dir, world_size, rank, **kwargs)] + + +def get_sequence_loader_single_ddp(data_dir, world_size, rank, mode='all_frame', is_validation=False, batch_size=256, num_workers=4, in_image_size=256, out_image_size=256, debug_seq=False, num_sample_frames=2, skip_beginning=4, skip_end=4, min_seq_len=10, max_seq_len=256, random_sample=False, shuffle=False, dense_sample=True, color_jitter=None, load_background=False, random_flip=False, rgb_suffix='.jpg', load_dino_feature=False, load_dino_cluster=False, dino_feature_dim=64, flow_bool=False): + if mode == 'n_frame': + dataset = NFrameSequenceDataset(data_dir, num_sample_frames=num_sample_frames, skip_beginning=skip_beginning, skip_end=skip_end, min_seq_len=min_seq_len, in_image_size=in_image_size, out_image_size=out_image_size, debug_seq=debug_seq, random_sample=random_sample, shuffle=shuffle, dense_sample=dense_sample, color_jitter=color_jitter, load_background=load_background, random_flip=random_flip, rgb_suffix=rgb_suffix, load_dino_feature=load_dino_feature, load_dino_cluster=load_dino_cluster, dino_feature_dim=dino_feature_dim, flow_bool=flow_bool) + else: + raise NotImplementedError + sampler = torch.utils.data.distributed.DistributedSampler( + dataset, + num_replicas=world_size, + rank=rank, + ) + loader = torch.utils.data.DataLoader( + dataset, + sampler=sampler, + batch_size=batch_size, + shuffle=False, + drop_last=True, + num_workers=num_workers, + pin_memory=True + ) + return loader + + +class ImageDataset(Dataset): + def __init__(self, root, is_validation=False, image_size=256, color_jitter=None): + super().__init__() + self.image_loader = ("rgb.jpg", torchvision.datasets.folder.default_loader) + self.mask_loader = ("mask.png", torchvision.datasets.folder.default_loader) + self.bbox_loader = ("box.txt", np.loadtxt, 'str') + self.samples = self._parse_folder(root) + self.image_size = image_size + self.color_jitter = color_jitter + self.image_transform = transforms.Compose([transforms.Resize(self.image_size), transforms.ToTensor()]) + self.mask_transform = transforms.Compose([transforms.Resize(self.image_size, interpolation=Image.NEAREST), 
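`get_sequence_loader_single_ddp` above pairs the dataset with a `DistributedSampler`, so shuffling is delegated to the sampler and `drop_last=True` keeps per-rank batch counts equal. A hypothetical call from a training loop; `rank`, `world_size`, `num_epochs` and the directory are placeholders supplied by the launcher/config:

```python
# Illustrative usage inside a DDP worker.
loader = get_sequence_loader_single_ddp(
    "data/horse_seqs", world_size=world_size, rank=rank, mode="n_frame",
    batch_size=6, num_workers=4, num_sample_frames=2,
    load_dino_feature=True, dino_feature_dim=16, rgb_suffix=".png",
)
for epoch in range(num_epochs):
    loader.sampler.set_epoch(epoch)   # re-seed the sampler so each epoch uses a different shard order
    for batch in loader:
        ...
```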
transforms.ToTensor()]) + + def _parse_folder(self, path): + result = sorted(glob(os.path.join(path, '**/*'+self.image_loader[0]), recursive=True)) + result = [p.replace(self.image_loader[0], '{}') for p in result] + return result + + def _load_ids(self, path, loader, transform=None): + x = loader[1](path.format(loader[0]), *loader[2:]) + if transform: + x = transform(x) + return x + + def __len__(self): + return len(self.samples) + + def __getitem__(self, index): + path = self.samples[index % len(self.samples)] + masks = self._load_ids(path, self.mask_loader, transform=self.mask_transform).unsqueeze(0) + mask_dt = compute_distance_transform(masks) + jitter = False + if self.color_jitter is not None: + prob, b, h = self.color_jitter + if np.random.rand() < prob: + jitter = True + color_jitter_tsf_fg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h)) + image_transform_fg = transforms.Compose([transforms.Resize(self.image_size), color_jitter_tsf_fg, transforms.ToTensor()]) + color_jitter_tsf_bg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h)) + image_transform_bg = transforms.Compose([transforms.Resize(self.image_size), color_jitter_tsf_bg, transforms.ToTensor()]) + if jitter: + images_fg = self._load_ids(path, self.image_loader, transform=image_transform_fg).unsqueeze(0) + images_bg = self._load_ids(path, self.image_loader, transform=image_transform_bg).unsqueeze(0) + images = images_fg * masks + images_bg * (1-masks) + else: + images = self._load_ids(path, self.image_loader, transform=self.image_transform).unsqueeze(0) + flows = torch.zeros(1) + bboxs = self._load_ids(path, self.bbox_loader, transform=None) + bboxs[0] = '0' + bboxs = torch.FloatTensor(bboxs.astype('float')).unsqueeze(0) + bg_fpath = os.path.join(os.path.dirname(path), 'background_frame.jpg') + if os.path.isfile(bg_fpath): + bg_image = torchvision.datasets.folder.default_loader(bg_fpath) + if jitter: + bg_image = color_jitter_tsf_bg(bg_image) + bg_image = transforms.ToTensor()(bg_image) + else: + bg_image = images[0] + seq_idx = torch.LongTensor([index]) + frame_idx = torch.LongTensor([0]) + return images, masks, mask_dt, flows, bboxs, bg_image, seq_idx, frame_idx + + +def get_image_loader(data_dir, is_validation=False, batch_size=256, num_workers=4, image_size=256, color_jitter=None): + dataset = ImageDataset(data_dir, is_validation=is_validation, image_size=image_size, color_jitter=color_jitter) + + loader = torch.utils.data.DataLoader( + dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + pin_memory=True + ) + return loader + + +def get_image_loader_ddp(data_dir, world_size, rank, is_validation=False, batch_size=256, num_workers=4, image_size=256, color_jitter=None): + dataset = ImageDataset(data_dir, is_validation=is_validation, image_size=image_size, color_jitter=color_jitter) + + sampler = torch.utils.data.distributed.DistributedSampler( + dataset, + num_replicas=world_size, + rank=rank, + ) + loader = torch.utils.data.DataLoader( + dataset, + sampler=sampler, + batch_size=batch_size, + shuffle=False, + drop_last=True, + num_workers=num_workers, + pin_memory=True + ) + return loader diff --git a/video3d/diffusion/sd.py b/video3d/diffusion/sd.py new file mode 100644 index 0000000000000000000000000000000000000000..7c668f46b9a1fe57dee6e6d86a83790055e3b5ef --- /dev/null +++ b/video3d/diffusion/sd.py @@ -0,0 +1,252 @@ +import os +# os.environ['HUGGINGFACE_HUB_CACHE'] = 
'/work/tomj/cache/huggingface_hub' +# os.environ['HF_HOME'] = '/work/tomj/cache/huggingface_hub' +os.environ['HUGGINGFACE_HUB_CACHE'] = '/viscam/u/zzli' +os.environ['HF_HOME'] = '/viscam/u/zzli' + +from transformers import CLIPTextModel, CLIPTokenizer, logging +from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler + +# Suppress partial model loading warning +logging.set_verbosity_error() + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from torch.cuda.amp import custom_bwd, custom_fwd + +class SpecifyGradient(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward(ctx, input_tensor, gt_grad): + ctx.save_for_backward(gt_grad) + return torch.zeros([1], device=input_tensor.device, dtype=input_tensor.dtype) # Dummy loss value + + @staticmethod + @custom_bwd + def backward(ctx, grad): + gt_grad, = ctx.saved_tensors + batch_size = len(gt_grad) + return gt_grad / batch_size, None + +def seed_everything(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + +class StableDiffusion(nn.Module): + def __init__(self, device, sd_version='2.1', hf_key=None, torch_dtype=torch.float32): + super().__init__() + + self.device = device + self.sd_version = sd_version + self.torch_dtype = torch_dtype + + print(f'[INFO] loading stable diffusion...') + + if hf_key is not None: + print(f'[INFO] using hugging face custom model key: {hf_key}') + model_key = hf_key + elif self.sd_version == '2.1': + model_key = "stabilityai/stable-diffusion-2-1-base" + elif self.sd_version == '2.0': + model_key = "stabilityai/stable-diffusion-2-base" + elif self.sd_version == '1.5': + model_key = "runwayml/stable-diffusion-v1-5" + else: + raise ValueError(f'Stable-diffusion version {self.sd_version} not supported.') + + # Create model + self.vae = AutoencoderKL.from_pretrained(model_key, subfolder="vae", torch_dtype=torch_dtype).to(self.device) + self.tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer") + self.text_encoder = CLIPTextModel.from_pretrained(model_key, subfolder="text_encoder").to(self.device) + self.unet = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet", torch_dtype=torch_dtype).to(self.device) + + self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler") + # self.scheduler = PNDMScheduler.from_pretrained(model_key, subfolder="scheduler") + + self.num_train_timesteps = self.scheduler.config.num_train_timesteps + self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience + + print(f'[INFO] loaded stable diffusion!') + + def get_text_embeds(self, prompt, negative_prompt): + # prompt, negative_prompt: [str] + + # Tokenize text and get embeddings + text_input = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length, truncation=True, return_tensors='pt') + + with torch.no_grad(): + text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] + + # Do the same for unconditional embeddings + uncond_input = self.tokenizer(negative_prompt, padding='max_length', max_length=self.tokenizer.model_max_length, return_tensors='pt') + + with torch.no_grad(): + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + + # Cat for final embeddings + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + return text_embeddings + + def train_step(self, text_embeddings, pred_rgb, + guidance_scale=100, loss_weight=1.0, min_step_pct=0.02, max_step_pct=0.98, return_aux=False): + pred_rgb = 
pred_rgb.to(self.torch_dtype) + text_embeddings = text_embeddings.to(self.torch_dtype) + b = pred_rgb.shape[0] + + # interp to 512x512 to be fed into vae. + + # _t = time.time() + pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False) + # torch.cuda.synchronize(); print(f'[TIME] guiding: interp {time.time() - _t:.4f}s') + + # timestep ~ U(0.02, 0.98) to avoid very high/low noise level + min_step = int(self.num_train_timesteps * min_step_pct) + max_step = int(self.num_train_timesteps * max_step_pct) + t = torch.randint(min_step, max_step + 1, [b], dtype=torch.long, device=self.device) + + # encode image into latents with vae, requires grad! + # _t = time.time() + latents = self.encode_imgs(pred_rgb_512) + # torch.cuda.synchronize(); print(f'[TIME] guiding: vae enc {time.time() - _t:.4f}s') + + # predict the noise residual with unet, NO grad! + # _t = time.time() + with torch.no_grad(): + # add noise + noise = torch.randn_like(latents) + latents_noisy = self.scheduler.add_noise(latents, noise, t) + # pred noise + latent_model_input = torch.cat([latents_noisy] * 2) + t_input = torch.cat([t, t]) + noise_pred = self.unet(latent_model_input, t_input, encoder_hidden_states=text_embeddings).sample + # torch.cuda.synchronize(); print(f'[TIME] guiding: unet {time.time() - _t:.4f}s') + + # perform guidance (high scale from paper!) + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + # noise_pred = noise_pred_text + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # w(t), sigma_t^2 + w = (1 - self.alphas[t]) + # w = self.alphas[t] ** 0.5 * (1 - self.alphas[t]) + grad = loss_weight * w[:, None, None, None] * (noise_pred - noise) + + # clip grad for stable training? + # grad = grad.clamp(-10, 10) + grad = torch.nan_to_num(grad) + + # since we omitted an item in grad, we need to use the custom function to specify the gradient + # _t = time.time() + # loss = SpecifyGradient.apply(latents, grad) + # torch.cuda.synchronize(); print(f'[TIME] guiding: backward {time.time() - _t:.4f}s') + + targets = (latents - grad).detach() + loss = 0.5 * F.mse_loss(latents.float(), targets, reduction='sum') / latents.shape[0] + + if return_aux: + aux = {'grad': grad, 't': t, 'w': w} + return loss, aux + else: + return loss + + + def produce_latents(self, text_embeddings, height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None): + + if latents is None: + latents = torch.randn((text_embeddings.shape[0] // 2, self.unet.config.in_channels, height // 8, width // 8), device=self.device) + + self.scheduler.set_timesteps(num_inference_steps) + + with torch.autocast('cuda'): + for i, t in enumerate(self.scheduler.timesteps): + # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes. 
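+                # (Note) get_text_embeds above stacks the embeddings as [uncond, text],
+                # so duplicating the latents lets a single UNet forward pass score both
+                # branches; the two halves are split and recombined with guidance_scale
+                # a few lines below.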
+ latent_model_input = torch.cat([latents] * 2) + + # predict the noise residual + with torch.no_grad(): + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample'] + + # perform guidance + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_text + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents)['prev_sample'] + + return latents + + def decode_latents(self, latents): + + latents = 1 / self.vae.config.scaling_factor * latents + + with torch.no_grad(): + imgs = self.vae.decode(latents).sample + + imgs = (imgs / 2 + 0.5).clamp(0, 1) + + return imgs + + def encode_imgs(self, imgs): + # imgs: [B, 3, H, W] + + imgs = 2 * imgs - 1 + + posterior = self.vae.encode(imgs).latent_dist + latents = posterior.sample() * self.vae.config.scaling_factor + + return latents + + def prompt_to_img(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None): + + if isinstance(prompts, str): + prompts = [prompts] + + if isinstance(negative_prompts, str): + negative_prompts = [negative_prompts] + + # Prompts -> text embeds + text_embeds = self.get_text_embeds(prompts, negative_prompts) # [2, 77, 768] + + # Text embeds -> img latents + latents = self.produce_latents(text_embeds, height=height, width=width, latents=latents, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale) # [1, 4, 64, 64] + + # Img latents -> imgs + imgs = self.decode_latents(latents) # [1, 3, 512, 512] + + # Img to Numpy + imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy() + imgs = (imgs * 255).round().astype('uint8') + + return imgs + + +if __name__ == '__main__': + import argparse + import matplotlib.pyplot as plt + + parser = argparse.ArgumentParser() + parser.add_argument('prompt', type=str) + parser.add_argument('--negative', default='', type=str) + parser.add_argument('--sd_version', type=str, default='2.1', choices=['1.5', '2.0', '2.1'], help="stable diffusion version") + parser.add_argument('--hf_key', type=str, default=None, help="hugging face Stable diffusion model key") + parser.add_argument('-H', type=int, default=512) + parser.add_argument('-W', type=int, default=512) + parser.add_argument('--seed', type=int, default=0) + parser.add_argument('--steps', type=int, default=50) + opt = parser.parse_args() + + seed_everything(opt.seed) + + device = torch.device('cuda') + + sd = StableDiffusion(device, opt.sd_version, opt.hf_key) + + imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps) + + # visualize image + plt.imshow(imgs[0]) + plt.show() + plt.savefig(f'{opt.prompt}.png') \ No newline at end of file diff --git a/video3d/diffusion/sd_utils.py b/video3d/diffusion/sd_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2c9e29b23ca39e1ee0ab7986273a2a003ba1055b --- /dev/null +++ b/video3d/diffusion/sd_utils.py @@ -0,0 +1,123 @@ +import torch +import numpy as np +import random +import torch.nn.functional as F + +from ..render.light import DirectionalLight + +def safe_normalize(x, eps=1e-20): + return x / torch.sqrt(torch.clamp(torch.sum(x * x, -1, keepdim=True), min=eps)) + +def get_view_direction(thetas, phis, overhead, front, phi_offset=0): + # phis [B,]; thetas: [B,] + # front = 0 [360 - front / 2, front / 2) + # side (left) = 1 [front / 2, 180 - front / 2) + # back = 2 [180 - front / 2, 180 + front / 2) + # side (right) = 3 [180 + 
front / 2, 360 - front / 2) + # top = 4 [0, overhead] + # bottom = 5 [180-overhead, 180] + res = torch.zeros(thetas.shape[0], dtype=torch.long) + + # first determine by phis + phi_offset = np.deg2rad(phi_offset) + phis = phis + phi_offset + phis = phis % (2 * np.pi) + half_front = front / 2 + + res[(phis >= (2*np.pi - half_front)) | (phis < half_front)] = 0 + res[(phis >= half_front) & (phis < (np.pi - half_front))] = 1 + res[(phis >= (np.pi - half_front)) & (phis < (np.pi + half_front))] = 2 + res[(phis >= (np.pi + half_front)) & (phis < (2*np.pi - half_front))] = 3 + + # override by thetas + res[thetas <= overhead] = 4 + res[thetas >= (np.pi - overhead)] = 5 + return res + + +def view_direction_id_to_text(view_direction_id): + dir_texts = ['front', 'side', 'back', 'side', 'overhead', 'bottom'] + return [dir_texts[i] for i in view_direction_id] + + +def append_text_direction(prompts, dir_texts): + return [f'{prompt}, {dir_text} view' for prompt, dir_text in zip(prompts, dir_texts)] + + +def rand_lights(camera_dir, fixed_ambient, fixed_diffuse): + size = camera_dir.shape[0] + device = camera_dir.device + random_fixed_dir = F.normalize(torch.randn_like(camera_dir) + camera_dir, dim=-1) # Centered around camera_dir + random_fixed_intensity = torch.tensor([fixed_ambient, fixed_diffuse], device=device)[None, :].repeat(size, 1) # ambient, diffuse + return DirectionalLight(mlp_in=1, mlp_layers=1, mlp_hidden_size=1, # Dummy values + intensity_min_max=[0.5, 1],fixed_dir=random_fixed_dir, fixed_intensity=random_fixed_intensity).to(device) + +def rand_poses(size, device, radius_range=[1, 1], theta_range=[0, 120], phi_range=[0, 360], cam_z_offset=10, return_dirs=False, angle_overhead=30, angle_front=60, phi_offset=0, jitter=False, uniform_sphere_rate=0.5): + ''' generate random poses from an orbit camera + Args: + size: batch size of generated poses. + device: where to allocate the output. 
+ radius_range: [min, max] + theta_range: [min, max], should be in [0, pi] + phi_range: [min, max], should be in [0, 2 * pi] + Return: + poses: [size, 4, 4] + ''' + + theta_range = np.deg2rad(theta_range) + phi_range = np.deg2rad(phi_range) + angle_overhead = np.deg2rad(angle_overhead) + angle_front = np.deg2rad(angle_front) + + radius = torch.rand(size, device=device) * (radius_range[1] - radius_range[0]) + radius_range[0] + + phis = torch.rand(size, device=device) * (phi_range[1] - phi_range[0]) + phi_range[0] + if random.random() < uniform_sphere_rate: + # based on http://corysimon.github.io/articles/uniformdistn-on-sphere/ + # acos takes in [-1, 1], first convert theta range to fit in [-1, 1] + theta_range = torch.from_numpy(np.array(theta_range)).to(device) + theta_amplitude_range = torch.cos(theta_range) + # sample uniformly in amplitude space range + thetas_amplitude = torch.rand(size, device=device) * (theta_amplitude_range[1] - theta_amplitude_range[0]) + theta_amplitude_range[0] + # convert back + thetas = torch.acos(thetas_amplitude) + else: + thetas = torch.rand(size, device=device) * (theta_range[1] - theta_range[0]) + theta_range[0] + + centers = -torch.stack([ + radius * torch.sin(thetas) * torch.sin(phis), + radius * torch.cos(thetas), + radius * torch.sin(thetas) * torch.cos(phis), + ], dim=-1) # [B, 3] + + targets = 0 + + # jitters + if jitter: + centers = centers + (torch.rand_like(centers) * 0.2 - 0.1) + targets = targets + torch.randn_like(centers) * 0.2 + + # lookat + forward_vector = safe_normalize(targets - centers) + up_vector = torch.FloatTensor([0, 1, 0]).to(device).unsqueeze(0).repeat(size, 1) + right_vector = safe_normalize(torch.cross(up_vector, forward_vector, dim=-1)) + + if jitter: + up_noise = torch.randn_like(up_vector) * 0.02 + else: + up_noise = 0 + + up_vector = safe_normalize(torch.cross(forward_vector, right_vector, dim=-1) + up_noise) + + poses = torch.stack([right_vector, up_vector, forward_vector], dim=-1) + radius = radius[..., None] - cam_z_offset + translations = torch.cat([torch.zeros_like(radius), torch.zeros_like(radius), radius], dim=-1) + poses = torch.cat([poses.view(-1, 9), translations], dim=-1) + + if return_dirs: + dirs = get_view_direction(thetas, phis, angle_overhead, angle_front, phi_offset=phi_offset) + dirs = view_direction_id_to_text(dirs) + else: + dirs = None + + return poses, dirs diff --git a/video3d/diffusion/vsd.py b/video3d/diffusion/vsd.py new file mode 100644 index 0000000000000000000000000000000000000000..86df4adf4132bc62fc1730fe7eadb1c0e6806b99 --- /dev/null +++ b/video3d/diffusion/vsd.py @@ -0,0 +1,323 @@ +import os +os.environ['HUGGINGFACE_HUB_CACHE'] = '/viscam/u/zzli' +os.environ['HF_HOME'] = '/viscam/u/zzli' + +from transformers import CLIPTextModel, CLIPTokenizer, logging +from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler + +from diffusers.loaders import AttnProcsLayers +from diffusers.models.attention_processor import LoRAAttnProcessor +from diffusers.models.embeddings import TimestepEmbedding +from diffusers.utils.import_utils import is_xformers_available + +# Suppress partial model loading warning +logging.set_verbosity_error() + +import gc +import random +import torch +import torch.nn as nn +import torch.nn.functional as F +import tinycudann as tcnn +from video3d.diffusion.sd import StableDiffusion +from torch.cuda.amp import custom_bwd, custom_fwd + + +def seed_everything(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + +def cleanup(): + gc.collect() + 
torch.cuda.empty_cache() + tcnn.free_temporary_memory() + +class StableDiffusion_VSD(StableDiffusion): + def __init__(self, device, sd_version='2.1', hf_key=None, torch_dtype=torch.float32, lora_n_timestamp_samples=1): + super().__init__(device, sd_version=sd_version, hf_key=hf_key, torch_dtype=torch_dtype) + + # self.device = device + # self.sd_version = sd_version + # self.torch_dtype = torch_dtype + + if hf_key is not None: + print(f'[INFO] using hugging face custom model key: {hf_key}') + model_key = hf_key + elif self.sd_version == '2.1': + model_key = "stabilityai/stable-diffusion-2-1-base" + elif self.sd_version == '2.0': + model_key = "stabilityai/stable-diffusion-2-base" + elif self.sd_version == '1.5': + model_key = "runwayml/stable-diffusion-v1-5" + else: + raise ValueError(f'Stable-diffusion version {self.sd_version} not supported.') + + # # Create model + # self.vae = AutoencoderKL.from_pretrained(model_key, subfolder="vae", torch_dtype=torch_dtype).to(self.device) + # self.tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer") + # self.text_encoder = CLIPTextModel.from_pretrained(model_key, subfolder="text_encoder").to(self.device) + # self.unet = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet", torch_dtype=torch_dtype).to(self.device) + + # self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler") + # # self.scheduler = PNDMScheduler.from_pretrained(model_key, subfolder="scheduler") + + # self.num_train_timesteps = self.scheduler.config.num_train_timesteps + # self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience + + print(f'[INFO] loading stable diffusion VSD modules...') + + self.unet_lora = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet", torch_dtype=torch_dtype).to(self.device) + cleanup() + + for p in self.vae.parameters(): + p.requires_grad_(False) + for p in self.text_encoder.parameters(): + p.requires_grad_(False) + for p in self.unet.parameters(): + p.requires_grad_(False) + for p in self.unet_lora.parameters(): + p.requires_grad_(False) + + # set up LoRA layers + lora_attn_procs = {} + for name in self.unet_lora.attn_processors.keys(): + cross_attention_dim = ( + None + if name.endswith("attn1.processor") + else self.unet_lora.config.cross_attention_dim + ) + if name.startswith("mid_block"): + hidden_size = self.unet_lora.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(self.unet_lora.config.block_out_channels))[ + block_id + ] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = self.unet_lora.config.block_out_channels[block_id] + + lora_attn_procs[name] = LoRAAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim + ) + + self.unet_lora.set_attn_processor(lora_attn_procs) + + self.lora_layers = AttnProcsLayers(self.unet_lora.attn_processors).to( + self.device + ) + self.lora_layers._load_state_dict_pre_hooks.clear() + self.lora_layers._state_dict_hooks.clear() + self.lora_n_timestamp_samples = lora_n_timestamp_samples + self.scheduler_lora = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler") + + print(f'[INFO] loaded stable diffusion VSD modules!') + + def train_lora( + self, + latents, + text_embeddings, + camera_condition + ): + B = latents.shape[0] + lora_n_timestamp_samples = self.lora_n_timestamp_samples + latents = latents.detach().repeat(lora_n_timestamp_samples, 1, 1, 1) + + t = 
torch.randint( + int(self.num_train_timesteps * 0.0), + int(self.num_train_timesteps * 1.0), + [B * lora_n_timestamp_samples], + dtype=torch.long, + device=self.device, + ) + + noise = torch.randn_like(latents) + noisy_latents = self.scheduler_lora.add_noise(latents, noise, t) + if self.scheduler_lora.config.prediction_type == "epsilon": + target = noise + elif self.scheduler_lora.config.prediction_type == "v_prediction": + target = self.scheduler_lora.get_velocity(latents, noise, t) + else: + raise ValueError( + f"Unknown prediction type {self.scheduler_lora.config.prediction_type}" + ) + + # use view-independent text embeddings in LoRA + _, text_embeddings_cond = text_embeddings.chunk(2) + + if random.random() < 0.1: + camera_condition = torch.zeros_like(camera_condition) + + noise_pred = self.unet_lora( + noisy_latents, + t, + encoder_hidden_states=text_embeddings_cond.repeat( + lora_n_timestamp_samples, 1, 1 + ), + class_labels=camera_condition.reshape(B, -1).repeat( + lora_n_timestamp_samples, 1 + ), + cross_attention_kwargs={"scale": 1.0} + ).sample + + loss_lora = 0.5 * F.mse_loss(noise_pred.float(), target.float(), reduction="mean") + return loss_lora + + + def train_step( + self, + text_embeddings, + text_embeddings_vd, + pred_rgb, + camera_condition, + im_features, + guidance_scale=7.5, + guidance_scale_lora=7.5, + loss_weight=1.0, + min_step_pct=0.02, + max_step_pct=0.98, + return_aux=False + ): + pred_rgb = pred_rgb.to(self.torch_dtype) + text_embeddings = text_embeddings.to(self.torch_dtype) + text_embeddings_vd = text_embeddings_vd.to(self.torch_dtype) + camera_condition = camera_condition.to(self.torch_dtype) + im_features = im_features.to(self.torch_dtype) + + # condition_label = camera_condition + condition_label = im_features + + b = pred_rgb.shape[0] + + # interp to 512x512 to be fed into vae. + # _t = time.time() + pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False) + # torch.cuda.synchronize(); print(f'[TIME] guiding: interp {time.time() - _t:.4f}s') + + # timestep ~ U(0.02, 0.98) to avoid very high/low noise level + min_step = int(self.num_train_timesteps * min_step_pct) + max_step = int(self.num_train_timesteps * max_step_pct) + t = torch.randint(min_step, max_step + 1, [b], dtype=torch.long, device=self.device) + + # encode image into latents with vae, requires grad! + # _t = time.time() + latents = self.encode_imgs(pred_rgb_512) + # torch.cuda.synchronize(); print(f'[TIME] guiding: vae enc {time.time() - _t:.4f}s') + + # predict the noise residual with unet, NO grad! 
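+        # Two denoisers are evaluated on the same noisy latents (no grad):
+        #   1) the frozen pretrained UNet with the view-dependent prompt embeddings
+        #      (text_embeddings_vd) and classifier-free guidance, and
+        #   2) the LoRA UNet conditioned on the instance image features.
+        # Their difference, weighted by w(t) = 1 - alpha_bar_t, gives the VSD update
+        # direction, which is converted into an MSE loss on the latents further down.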
+ # _t = time.time() + with torch.no_grad(): + # add noise + noise = torch.randn_like(latents) + latents_noisy = self.scheduler.add_noise(latents, noise, t) + # pred noise + latent_model_input = torch.cat([latents_noisy] * 2) + + # disable unet class embedding here + cls_embedding = self.unet.class_embedding + self.unet.class_embedding = None + + cross_attention_kwargs = None + noise_pred_pretrain = self.unet( + latent_model_input, + torch.cat([t, t]), + encoder_hidden_states=text_embeddings_vd, + class_labels=None, + cross_attention_kwargs=cross_attention_kwargs + ).sample + + self.unet.class_embedding = cls_embedding + + # use view-independent text embeddings in LoRA + _, text_embeddings_cond = text_embeddings.chunk(2) + + noise_pred_est = self.unet_lora( + latent_model_input, + torch.cat([t, t]), + encoder_hidden_states=torch.cat([text_embeddings_cond] * 2), + class_labels=torch.cat( + [ + condition_label.reshape(b, -1), + torch.zeros_like(condition_label.reshape(b, -1)), + ], + dim=0, + ), + cross_attention_kwargs={"scale": 1.0}, + ).sample + + noise_pred_pretrain_uncond, noise_pred_pretrain_text = noise_pred_pretrain.chunk(2) + + noise_pred_pretrain = noise_pred_pretrain_uncond + guidance_scale * ( + noise_pred_pretrain_text - noise_pred_pretrain_uncond + ) + + assert self.scheduler.config.prediction_type == "epsilon" + if self.scheduler_lora.config.prediction_type == "v_prediction": + alphas_cumprod = self.scheduler_lora.alphas_cumprod.to( + device=latents_noisy.device, dtype=latents_noisy.dtype + ) + alpha_t = alphas_cumprod[t] ** 0.5 + sigma_t = (1 - alphas_cumprod[t]) ** 0.5 + + noise_pred_est = latent_model_input * torch.cat([sigma_t] * 2, dim=0).reshape( + -1, 1, 1, 1 + ) + noise_pred_est * torch.cat([alpha_t] * 2, dim=0).reshape(-1, 1, 1, 1) + + noise_pred_est_uncond, noise_pred_est_camera = noise_pred_est.chunk(2) + + noise_pred_est = noise_pred_est_uncond + guidance_scale_lora * ( + noise_pred_est_camera - noise_pred_est_uncond + ) + + # w(t), sigma_t^2 + w = (1 - self.alphas[t]) + # w = self.alphas[t] ** 0.5 * (1 - self.alphas[t]) + grad = loss_weight * w[:, None, None, None] * (noise_pred_pretrain - noise_pred_est) + + grad = torch.nan_to_num(grad) + + targets = (latents - grad).detach() + loss_vsd = 0.5 * F.mse_loss(latents.float(), targets, reduction='sum') / latents.shape[0] + + loss_lora = self.train_lora(latents, text_embeddings, condition_label) + + loss = { + 'loss_vsd': loss_vsd, + 'loss_lora': loss_lora + } + + if return_aux: + aux = {'grad': grad, 't': t, 'w': w} + return loss, aux + else: + return loss + + + +if __name__ == '__main__': + import argparse + import matplotlib.pyplot as plt + + parser = argparse.ArgumentParser() + parser.add_argument('prompt', type=str) + parser.add_argument('--negative', default='', type=str) + parser.add_argument('--sd_version', type=str, default='2.1', choices=['1.5', '2.0', '2.1'], help="stable diffusion version") + parser.add_argument('--hf_key', type=str, default=None, help="hugging face Stable diffusion model key") + parser.add_argument('-H', type=int, default=512) + parser.add_argument('-W', type=int, default=512) + parser.add_argument('--seed', type=int, default=0) + parser.add_argument('--steps', type=int, default=50) + opt = parser.parse_args() + + seed_everything(opt.seed) + + device = torch.device('cuda') + + sd = StableDiffusion_VSD(device, opt.sd_version, opt.hf_key) + + imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps) + + # visualize image + plt.imshow(imgs[0]) + plt.show() + 
plt.savefig(f'{opt.prompt}.png') \ No newline at end of file diff --git a/video3d/discriminator_architecture.py b/video3d/discriminator_architecture.py new file mode 100644 index 0000000000000000000000000000000000000000..07e675f4e682ef0dbb5c528e5adea8932d23c569 --- /dev/null +++ b/video3d/discriminator_architecture.py @@ -0,0 +1,83 @@ +import torch.nn as nn +import torch +from math import log2 +import torch.nn.functional as F +from torch import autograd + + +class DCDiscriminator(nn.Module): + ''' DC Discriminator class. + + Args: + in_dim (int): input dimension + n_feat (int): features of final hidden layer + img_size (int): input image size + ''' + def __init__(self, in_dim=1, out_dim=1, n_feat=512, img_size=256, last_bias=False): + super().__init__() + + self.in_dim = in_dim + self.out_dim = out_dim + n_layers = int(log2(img_size) - 2) + self.blocks = nn.ModuleList( + [nn.Conv2d( + in_dim, + int(n_feat / (2 ** (n_layers - 1))), + 4, 2, 1, bias=False)] + [nn.Conv2d( + int(n_feat / (2 ** (n_layers - i))), + int(n_feat / (2 ** (n_layers - 1 - i))), + 4, 2, 1, bias=False) for i in range(1, n_layers)]) + + self.conv_out = nn.Conv2d(n_feat, out_dim, 4, 1, 0, bias=last_bias) + self.actvn = nn.LeakyReLU(0.2, inplace=True) + + def forward(self, x): + batch_size = x.shape[0] + if x.shape[1] != self.in_dim: + import ipdb; ipdb.set_trace() + x = x[:, :self.in_dim] + for layer in self.blocks: + x = self.actvn(layer(x)) + + out = self.conv_out(x) + out = out.reshape(batch_size, self.out_dim) + return out + + +# class ADADiscriminator(DCDiscriminator): +# def __init__(self, aug, aug_p, **kwargs): +# super().__init__(**kwargs) +# self.aug = build_from_config(aug) +# self.aug.p.copy_(torch.tensor(aug_p, dtype=torch.float32)) +# self.resolution = kwargs['img_size'] + +# def get_resolution(self): +# return self.resolution + +# def forward(self, x, **kwargs): +# x = self.aug(x) +# return super().forward(x, **kwargs) + + +# class ADADiscriminatorView(ADADiscriminator): +# def __init__(self, out_dim_position, out_dim_latent, **kwargs): +# self.out_dim_position = out_dim_position +# self.out_dim_latent = out_dim_latent + +# super().__init__(**kwargs) + +def bce_loss_target(d_out, target): + targets = d_out.new_full(size=d_out.size(), fill_value=target) + loss = F.binary_cross_entropy_with_logits(d_out, targets) + return loss.mean() + +def compute_grad2(d_out, x_in): + batch_size = x_in.size(0) + grad_dout = autograd.grad( + outputs=d_out.sum(), inputs=x_in, + create_graph=True, retain_graph=True, only_inputs=True + )[0] + grad_dout2 = grad_dout.pow(2) + assert(grad_dout2.size() == x_in.size()) + reg = grad_dout2.reshape(batch_size, -1).sum(1) + return reg.mean() \ No newline at end of file diff --git a/video3d/flow/__init__.py b/video3d/flow/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/video3d/flow/flow.py b/video3d/flow/flow.py new file mode 100755 index 0000000000000000000000000000000000000000..e18270c0a9db9847c6b560b3c4b30eebe9505e5b --- /dev/null +++ b/video3d/flow/flow.py @@ -0,0 +1,51 @@ +from numpy.lib.npyio import load +from torch._C import device +import sys +sys.path.append('/scratch/shared/beegfs/szwu/projects/video3d/RAFT') +from core.raft import RAFT + +from .utils import InputPadder +import torch + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + + +class FlowModel(): + def __init__(self, model, device): + args = 
AttrDict({'model': model, 'small': False, 'mixed_precision': False, 'alternate_corr': False}) + self.model = self.load_model(args, device) + self.device = device + + + @staticmethod + def load_model(args, device): + model = torch.nn.DataParallel(RAFT(args)) + model.load_state_dict(torch.load(args.model)) + + model = model.module + model.to(device) + model.eval() + return model + + + def preprocess_image(self, image): + # image = image[:, :, ::-1].copy() + image = torch.from_numpy(image).permute(2, 0, 1).float() + image = image.to(self.device) + image = image[None] + # size = [540, 960] + # image = torch.nn.functional.interpolate(image, size=size, mode='bilinear', align_corners=False) + padder = InputPadder(image.shape) + return padder.pad(image)[0], padder + + + def compute_flow(self, frame, next_frame, iters=20): + frame, padder = self.preprocess_image(frame) + next_frame, padder = self.preprocess_image(next_frame) + _, flow = self.model(frame, next_frame, iters=iters, test_mode=True) + return padder.unpad(flow)[0].permute(1, 2, 0).cpu().numpy() diff --git a/video3d/flow/utils.py b/video3d/flow/utils.py new file mode 100755 index 0000000000000000000000000000000000000000..d39d2a5da1da49371252bb66afd73f93c3f9c5e8 --- /dev/null +++ b/video3d/flow/utils.py @@ -0,0 +1,23 @@ +# Taken from RAFT + +import torch.nn.functional as F + + +class InputPadder: + """ Pads images such that dimensions are divisible by 8 """ + def __init__(self, dims, mode='sintel'): + self.ht, self.wd = dims[-2:] + pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 + pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 + if mode == 'sintel': + self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] + else: + self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] + + def pad(self, *inputs): + return [F.pad(x, self._pad, mode='replicate') for x in inputs] + + def unpad(self,x): + ht, wd = x.shape[-2:] + c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] + return x[..., c[0]:c[1], c[2]:c[3]] diff --git a/video3d/geometry/dlmesh.py b/video3d/geometry/dlmesh.py new file mode 100755 index 0000000000000000000000000000000000000000..3d90450cf8fab1e3b3b68fb0b35dfafa94bd1bc4 --- /dev/null +++ b/video3d/geometry/dlmesh.py @@ -0,0 +1,85 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import torch + +from ..render import mesh +from ..render import render +from ..render import regularizer + +############################################################################### +# Geometry interface +############################################################################### + +class DLMesh(torch.nn.Module): + def __init__(self, initial_guess, FLAGS): + super(DLMesh, self).__init__() + + self.FLAGS = FLAGS + + self.initial_guess = initial_guess + self.mesh = initial_guess.clone() + print("Base mesh has %d triangles and %d vertices." 
% (self.mesh.t_pos_idx.shape[0], self.mesh.v_pos.shape[0])) + + self.mesh.v_pos = torch.nn.Parameter(self.mesh.v_pos, requires_grad=True) + self.register_parameter('vertex_pos', self.mesh.v_pos) + + @torch.no_grad() + def getAABB(self): + return mesh.aabb(self.mesh) + + def getMesh(self, material): + self.mesh.material = material + + imesh = mesh.Mesh(base=self.mesh) + # Compute normals and tangent space + imesh = mesh.auto_normals(imesh) + imesh = mesh.compute_tangents(imesh) + return imesh + + def render(self, glctx, target, lgt, opt_material, bsdf=None): + opt_mesh = self.getMesh(opt_material) + return render.render_mesh(glctx, opt_mesh, target['mvp'], target['campos'], lgt, target['resolution'], spp=target['spp'], + num_layers=self.FLAGS.layers, msaa=True, background=target['background'], bsdf=bsdf) + + def tick(self, glctx, target, lgt, opt_material, loss_fn, iteration): + + # ============================================================================================== + # Render optimizable object with identical conditions + # ============================================================================================== + buffers = self.render(glctx, target, lgt, opt_material) + + # ============================================================================================== + # Compute loss + # ============================================================================================== + t_iter = iteration / self.FLAGS.iter + + # Image-space loss, split into a coverage component and a color component + color_ref = target['img'] + img_loss = torch.nn.functional.mse_loss(buffers['shaded'][..., 3:], color_ref[..., 3:]) + img_loss += loss_fn(buffers['shaded'][..., 0:3] * color_ref[..., 3:], color_ref[..., 0:3] * color_ref[..., 3:]) + + reg_loss = torch.tensor([0], dtype=torch.float32, device="cuda") + + # Compute regularizer. + if self.FLAGS.laplace == "absolute": + reg_loss += regularizer.laplace_regularizer_const(self.mesh.v_pos, self.mesh.t_pos_idx) * self.FLAGS.laplace_scale * (1 - t_iter) + elif self.FLAGS.laplace == "relative": + reg_loss += regularizer.laplace_regularizer_const(self.mesh.v_pos - self.initial_guess.v_pos, self.mesh.t_pos_idx) * self.FLAGS.laplace_scale * (1 - t_iter) + + # Albedo (k_d) smoothnesss regularizer + reg_loss += torch.mean(buffers['kd_grad'][..., :-1] * buffers['kd_grad'][..., -1:]) * 0.03 * min(1.0, iteration / 500) + + # Visibility regularizer + reg_loss += torch.mean(buffers['occlusion'][..., :-1] * buffers['occlusion'][..., -1:]) * 0.001 * min(1.0, iteration / 500) + + # Light white balance regularizer + reg_loss = reg_loss + lgt.regularizer() * 0.005 + + return img_loss, reg_loss \ No newline at end of file diff --git a/video3d/geometry/dmtet.py b/video3d/geometry/dmtet.py new file mode 100755 index 0000000000000000000000000000000000000000..062b6658db42f98dcc658a48b274cece8a4625bb --- /dev/null +++ b/video3d/geometry/dmtet.py @@ -0,0 +1,361 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. 
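+
+# Illustrative usage sketch (values assumed for illustration only): DMTetGeometry
+# below wraps a fixed tetrahedral grid plus an SDF (either a free per-vertex
+# parameter or an MLP) and extracts a triangle mesh with marching tetrahedra every
+# time getMesh() is called, e.g.
+#
+#   geom = DMTetGeometry(grid_res=64, scale=5.0, sdf_mode='mlp',
+#                        num_layers=5, hidden_size=64, embedder_freq=8)
+#   prior_mesh = geom.getMesh(material=None, perturb_sdf=False, total_iter=it)
+#
+# The tetrahedra are loaded from ./data/tets/{grid_res}_tets.npz in load_tets().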
+ +from multiprocessing.spawn import get_preparation_data +import numpy as np +import torch + +from ..render import mesh +from ..render import render +from ..networks import MLPWithPositionalEncoding, MLPWithPositionalEncoding_Style + +############################################################################### +# Marching tetrahedrons implementation (differentiable), adapted from +# https://github.com/NVIDIAGameWorks/kaolin/blob/master/kaolin/ops/conversions/tetmesh.py +# +# Note this only supports batch size = 1. +############################################################################### + +class DMTet: + def __init__(self): + self.triangle_table = torch.tensor([ + [-1, -1, -1, -1, -1, -1], + [ 1, 0, 2, -1, -1, -1], + [ 4, 0, 3, -1, -1, -1], + [ 1, 4, 2, 1, 3, 4], + [ 3, 1, 5, -1, -1, -1], + [ 2, 3, 0, 2, 5, 3], + [ 1, 4, 0, 1, 5, 4], + [ 4, 2, 5, -1, -1, -1], + [ 4, 5, 2, -1, -1, -1], + [ 4, 1, 0, 4, 5, 1], + [ 3, 2, 0, 3, 5, 2], + [ 1, 3, 5, -1, -1, -1], + [ 4, 1, 2, 4, 3, 1], + [ 3, 0, 4, -1, -1, -1], + [ 2, 0, 1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1] + ], dtype=torch.long, device='cuda') + + self.num_triangles_table = torch.tensor([0,1,1,2,1,2,2,1,1,2,2,1,2,1,1,0], dtype=torch.long, device='cuda') + self.base_tet_edges = torch.tensor([0,1,0,2,0,3,1,2,1,3,2,3], dtype=torch.long, device='cuda') + + ############################################################################### + # Utility functions + ############################################################################### + + def sort_edges(self, edges_ex2): + with torch.no_grad(): + order = (edges_ex2[:,0] > edges_ex2[:,1]).long() + order = order.unsqueeze(dim=1) + + a = torch.gather(input=edges_ex2, index=order, dim=1) + b = torch.gather(input=edges_ex2, index=1-order, dim=1) + + return torch.stack([a, b],-1) + + def map_uv(self, faces, face_gidx, max_idx): + N = int(np.ceil(np.sqrt((max_idx+1)//2))) + tex_y, tex_x = torch.meshgrid( + torch.linspace(0, 1 - (1 / N), N, dtype=torch.float32, device="cuda"), + torch.linspace(0, 1 - (1 / N), N, dtype=torch.float32, device="cuda"), + indexing='ij' + ) + + pad = 0.9 / N + + uvs = torch.stack([ + tex_x , tex_y, + tex_x + pad, tex_y, + tex_x + pad, tex_y + pad, + tex_x , tex_y + pad + ], dim=-1).view(-1, 2) + + def _idx(tet_idx, N): + x = tet_idx % N + y = torch.div(tet_idx, N, rounding_mode='trunc') + return y * N + x + + tet_idx = _idx(torch.div(face_gidx, 2, rounding_mode='trunc'), N) + tri_idx = face_gidx % 2 + + uv_idx = torch.stack(( + tet_idx * 4, tet_idx * 4 + tri_idx + 1, tet_idx * 4 + tri_idx + 2 + ), dim = -1). 
view(-1, 3) + + return uvs, uv_idx + + ############################################################################### + # Marching tets implementation + ############################################################################### + + def __call__(self, pos_nx3, sdf_n, tet_fx4): + with torch.no_grad(): + occ_n = sdf_n > 0 + occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1,4) + occ_sum = torch.sum(occ_fx4, -1) + valid_tets = (occ_sum>0) & (occ_sum<4) + occ_sum = occ_sum[valid_tets] + + # find all vertices + all_edges = tet_fx4[valid_tets][:,self.base_tet_edges].reshape(-1,2) + all_edges = self.sort_edges(all_edges) + unique_edges, idx_map = torch.unique(all_edges,dim=0, return_inverse=True) + + unique_edges = unique_edges.long() + mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1,2).sum(-1) == 1 + mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device="cuda") * -1 + mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long,device="cuda") + idx_map = mapping[idx_map] # map edges to verts + + interp_v = unique_edges[mask_edges] + edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1,2,3) + edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1,2,1) + edges_to_interp_sdf[:,-1] *= -1 + + denominator = edges_to_interp_sdf.sum(1,keepdim = True) + + edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1])/denominator + verts = (edges_to_interp * edges_to_interp_sdf).sum(1) + + idx_map = idx_map.reshape(-1,6) + + v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device="cuda")) + tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1) + num_triangles = self.num_triangles_table[tetindex] + + # Generate triangle indices + faces = torch.cat(( + torch.gather(input=idx_map[num_triangles == 1], dim=1, index=self.triangle_table[tetindex[num_triangles == 1]][:, :3]).reshape(-1,3), + torch.gather(input=idx_map[num_triangles == 2], dim=1, index=self.triangle_table[tetindex[num_triangles == 2]][:, :6]).reshape(-1,3), + ), dim=0) + + # Get global face index (static, does not depend on topology) + num_tets = tet_fx4.shape[0] + tet_gidx = torch.arange(num_tets, dtype=torch.long, device="cuda")[valid_tets] + face_gidx = torch.cat(( + tet_gidx[num_triangles == 1]*2, + torch.stack((tet_gidx[num_triangles == 2]*2, tet_gidx[num_triangles == 2]*2 + 1), dim=-1).view(-1) + ), dim=0) + + uvs, uv_idx = self.map_uv(faces, face_gidx, num_tets*2) + + return verts, faces, uvs, uv_idx + +############################################################################### +# Regularizer +############################################################################### + +def sdf_bce_reg_loss(sdf, all_edges): + sdf_f1x6x2 = sdf[all_edges.reshape(-1)].reshape(-1,2) + mask = torch.sign(sdf_f1x6x2[...,0]) != torch.sign(sdf_f1x6x2[...,1]) + sdf_f1x6x2 = sdf_f1x6x2[mask] + sdf_diff = torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[...,0], (sdf_f1x6x2[...,1] > 0).float()) + \ + torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[...,1], (sdf_f1x6x2[...,0] > 0).float()) + if torch.isnan(sdf_diff).any(): + import ipdb; ipdb.set_trace() + return sdf_diff + +############################################################################### +# Geometry interface +############################################################################### + +class DMTetGeometry(torch.nn.Module): + def __init__(self, grid_res, scale, sdf_mode, num_layers=None, hidden_size=None, embedder_freq=None, embed_concat_pts=True, init_sdf=None, jitter_grid=0., perturb_sdf_iter=10000, 
sym_prior_shape=False, dim_of_classes=0, condition_choice='concat'): + super(DMTetGeometry, self).__init__() + + self.sdf_mode = sdf_mode + self.grid_res = grid_res + self.marching_tets = DMTet() + self.grid_scale = scale + self.init_sdf = init_sdf + self.jitter_grid = jitter_grid + self.perturb_sdf_iter = perturb_sdf_iter + self.sym_prior_shape = sym_prior_shape + self.load_tets(self.grid_res, self.grid_scale) + + if sdf_mode == "param": + sdf = torch.rand_like(self.verts[:,0]) - 0.1 # Random init. + self.sdf = torch.nn.Parameter(sdf.clone().detach(), requires_grad=True) + self.register_parameter('sdf', self.sdf) + self.deform = torch.nn.Parameter(torch.zeros_like(self.verts), requires_grad=True) + self.register_parameter('deform', self.deform) + else: + embedder_scaler = 2 * np.pi / self.grid_scale * 0.9 # originally (-0.5*s, 0.5*s) rescale to (-pi, pi) * 0.9 + + if dim_of_classes == 0 or (dim_of_classes != 0 and condition_choice == 'concat'): + self.mlp = MLPWithPositionalEncoding( + 3, + 1, + num_layers, + nf=hidden_size, + extra_dim=dim_of_classes, + dropout=0, + activation=None, + n_harmonic_functions=embedder_freq, + omega0=embedder_scaler, + embed_concat_pts=embed_concat_pts) + + elif condition_choice == 'film' or condition_choice == 'mod': + self.mlp = MLPWithPositionalEncoding_Style( + 3, + 1, + num_layers, + nf=hidden_size, + extra_dim=dim_of_classes, + dropout=0, + activation=None, + n_harmonic_functions=embedder_freq, + omega0=embedder_scaler, + embed_concat_pts=embed_concat_pts, + style_choice=condition_choice) + + else: + raise NotImplementedError + + def load_tets(self, grid_res=None, scale=None): + if grid_res is None: + grid_res = self.grid_res + else: + self.grid_res = grid_res + if scale is None: + scale = self.grid_scale + else: + self.grid_scale = scale + tets = np.load('./data/tets/{}_tets.npz'.format(grid_res)) + self.verts = torch.tensor(tets['vertices'], dtype=torch.float32, device='cuda') * scale # verts original scale (-0.5, 0.5) + self.indices = torch.tensor(tets['indices'], dtype=torch.long, device='cuda') + self.generate_edges() + + def get_sdf(self, pts=None, perturb_sdf=False, total_iter=0, class_vector=None): + if self.sdf_mode == 'param': + sdf = self.sdf + else: + if pts is None: + pts = self.verts + if self.sym_prior_shape: + xs, ys, zs = pts.unbind(-1) + pts = torch.stack([xs.abs(), ys, zs], -1) # mirror -x to +x + feat = None + if class_vector is not None: + feat = class_vector.unsqueeze(0).repeat(pts.shape[0], 1) + sdf = self.mlp(pts, feat=feat) + + if self.init_sdf is None: + pass + elif type(self.init_sdf) in [float, int]: + sdf = sdf + self.init_sdf + elif self.init_sdf == 'sphere': + init_radius = self.grid_scale * 0.25 + init_sdf = init_radius - pts.norm(dim=-1, keepdim=True) # init sdf is a sphere centered at origin + sdf = sdf + init_sdf + elif self.init_sdf == 'ellipsoid': + rxy = self.grid_scale * 0.15 + xs, ys, zs = pts.unbind(-1)[:3] + init_sdf = rxy - torch.stack([xs, ys, zs/2], -1).norm(dim=-1, keepdim=True) # init sdf is approximately an ellipsoid centered at origin + sdf = sdf + init_sdf + else: + raise NotImplementedError + + if perturb_sdf: + sdf = sdf + torch.randn_like(sdf) * 0.1 * max(0, 1-total_iter/self.perturb_sdf_iter) + return sdf + + def get_sdf_gradient(self, class_vector=None): + assert self.sdf_mode == 'mlp', "Only MLP supports gradient computation." 
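+        # Eikonal-style regularization: the SDF gradient is evaluated at uniform random
+        # points inside the grid plus a random subset of current mesh vertices jittered
+        # around the surface; get_sdf_reg_loss below penalizes deviations of its norm
+        # from 1.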
+ num_samples = 5000 + sample_points = (torch.rand(num_samples, 3, device=self.verts.device) - 0.5) * self.grid_scale + mesh_verts = self.mesh_verts.detach() + (torch.rand_like(self.mesh_verts) -0.5) * 0.1 * self.grid_scale + rand_idx = torch.randperm(len(mesh_verts), device=mesh_verts.device)[:5000] + mesh_verts = mesh_verts[rand_idx] + sample_points = torch.cat([sample_points, mesh_verts], 0) + sample_points.requires_grad = True + y = self.get_sdf(pts=sample_points, perturb_sdf=False, class_vector=class_vector) + d_output = torch.ones_like(y, requires_grad=False, device=y.device) + try: + gradients = torch.autograd.grad( + outputs=[y], + inputs=sample_points, + grad_outputs=d_output, + create_graph=True, + retain_graph=True, + only_inputs=True)[0] + except RuntimeError: # For validation, we have disabled gradient calculation. + return torch.zeros_like(sample_points) + return gradients + + def get_sdf_reg_loss(self, class_vector=None): + reg_loss = {"sdf_bce_reg_loss": sdf_bce_reg_loss(self.current_sdf, self.all_edges).mean()} + if self.sdf_mode == 'mlp': + reg_loss["sdf_gradient_reg_loss"] = ((self.get_sdf_gradient(class_vector=class_vector).norm(dim=-1) - 1) ** 2).mean() + reg_loss['sdf_inflate_reg_loss'] = -self.current_sdf.mean() + return reg_loss + + def generate_edges(self): + with torch.no_grad(): + edges = torch.tensor([0,1,0,2,0,3,1,2,1,3,2,3], dtype = torch.long, device = "cuda") + all_edges = self.indices[:,edges].reshape(-1,2) + all_edges_sorted = torch.sort(all_edges, dim=1)[0] + self.all_edges = torch.unique(all_edges_sorted, dim=0) + + @torch.no_grad() + def getAABB(self): + return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values + + def getMesh(self, material=None, perturb_sdf=False, total_iter=0, jitter_grid=True, class_vector=None): + # Run DM tet to get a base mesh + v_deformed = self.verts + + # if self.FLAGS.deform_grid: + # v_deformed = self.verts + 2 / (self.grid_res * 2) * torch.tanh(self.deform) + # else: + # v_deformed = self.verts + if jitter_grid and self.jitter_grid > 0: + jitter = (torch.rand(1, device=v_deformed.device)*2-1) * self.jitter_grid * self.grid_scale + v_deformed = v_deformed + jitter + + self.current_sdf = self.get_sdf(v_deformed, perturb_sdf=perturb_sdf, total_iter=total_iter, class_vector=class_vector) + verts, faces, uvs, uv_idx = self.marching_tets(v_deformed, self.current_sdf, self.indices) + self.mesh_verts = verts + return mesh.make_mesh(verts[None], faces[None], uvs[None], uv_idx[None], material) + + def render(self, glctx, target, lgt, opt_material, bsdf=None): + opt_mesh = self.getMesh(opt_material) + return render.render_mesh(glctx, opt_mesh, target['mvp'], target['campos'], lgt, target['resolution'], spp=target['spp'], msaa=True, background=target['background'], bsdf=bsdf) + + def tick(self, glctx, target, lgt, opt_material, loss_fn, iteration): + # ============================================================================================== + # Render optimizable object with identical conditions + # ============================================================================================== + buffers = self.render(glctx, target, lgt, opt_material) + + # ============================================================================================== + # Compute loss + # ============================================================================================== + t_iter = iteration / 20000 + + # Image-space loss, split into a coverage component and a color component + color_ref = target['img'] + img_loss = 
torch.nn.functional.mse_loss(buffers['shaded'][..., 3:], color_ref[..., 3:]) + img_loss = img_loss + loss_fn(buffers['shaded'][..., 0:3] * color_ref[..., 3:], color_ref[..., 0:3] * color_ref[..., 3:]) + + # SDF regularizer + # sdf_weight = self.sdf_regularizer - (self.sdf_regularizer - 0.01) * min(1.0, 4.0 * t_iter) # Dropoff to 0.01 + reg_loss = sum(self.get_sdf_reg_loss().values) + + # Albedo (k_d) smoothnesss regularizer + reg_loss += torch.mean(buffers['kd_grad'][..., :-1] * buffers['kd_grad'][..., -1:]) * 0.03 * min(1.0, iteration / 500) + + # Visibility regularizer + reg_loss += torch.mean(buffers['occlusion'][..., :-1] * buffers['occlusion'][..., -1:]) * 0.001 * min(1.0, iteration / 500) + + # Light white balance regularizer + reg_loss = reg_loss + lgt.regularizer() * 0.005 + + return img_loss, reg_loss diff --git a/video3d/model.py b/video3d/model.py new file mode 100755 index 0000000000000000000000000000000000000000..7d1bd834e26904c0f95c4021d51fa8dd44284afc --- /dev/null +++ b/video3d/model.py @@ -0,0 +1,1526 @@ +from multiprocessing.spawn import prepare +from turtle import forward +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.models as models +import nvdiffrast.torch as dr +import numpy as np +import matplotlib.pyplot as plt +import os +import os.path as osp + +from video3d.render.regularizer import get_edge_length, normal_consistency +from . import networks +from .renderer import * +from .utils import misc, meters, flow_viz, arap, custom_loss +from .dataloaders import get_sequence_loader, get_image_loader +from .cub_dataloaders import get_cub_loader +from .utils.skinning_v4 import estimate_bones, skinning +import lpips +from einops import rearrange + +from .geometry.dmtet import DMTetGeometry +from .geometry.dlmesh import DLMesh + +from .render import renderutils as ru +from .render import material +from .render import mlptexture +from .render import util +from .render import mesh +from .render import light +from .render import render + +EPS = 1e-7 + + +def get_optimizer(model, lr=0.0001, betas=(0.9, 0.999), weight_decay=0): + return torch.optim.Adam( + filter(lambda p: p.requires_grad, model.parameters()), + lr=lr, betas=betas, weight_decay=weight_decay) + + +def set_requires_grad(model, requires_grad): + if model is not None: + for param in model.parameters(): + param.requires_grad = requires_grad + + +def forward_to_matrix(vec_forward, up=[0,1,0]): + up = torch.FloatTensor(up).to(vec_forward.device) + # vec_forward = nn.functional.normalize(vec_forward, p=2, dim=-1) # x right, y up, z forward + vec_right = up.expand_as(vec_forward).cross(vec_forward, dim=-1) + vec_right = nn.functional.normalize(vec_right, p=2, dim=-1) + vec_up = vec_forward.cross(vec_right, dim=-1) + vec_up = nn.functional.normalize(vec_up, p=2, dim=-1) + rot_mat = torch.stack([vec_right, vec_up, vec_forward], -2) + return rot_mat + + +def sample_pose_hypothesis_from_quad_prediction(poses_raw, total_iter, batch_size, num_frames, pose_xflip_recon=False, input_image_xflip_flag=None, rot_temp_scalar=1., num_hypos=4, naive_probs_iter=2000, best_pose_start_iter=6000, random_sample=True): + rots_pred = poses_raw[..., :num_hypos*4].view(-1, num_hypos, 4) + rots_logits = rots_pred[..., 0] # Nx4 + temp = 1 / np.clip(total_iter / 1000 / rot_temp_scalar, 1., 100.) 
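+    # temp anneals from 1 towards 0.01 as total_iter grows (rate set by rot_temp_scalar),
+    # sharpening the softmax over the negated hypothesis scores below; early in training
+    # it is additionally blended with a uniform prior (naive_probs) that is phased out
+    # over ~2000 iterations after naive_probs_iter.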
+ + rots_probs = torch.nn.functional.softmax(-rots_logits / temp, dim=1) # N x K + # naive_probs = torch.FloatTensor([10] + [1] * (num_hypos - 1)).to(rots_logits.device) + naive_probs = torch.ones(num_hypos).to(rots_logits.device) + naive_probs = naive_probs / naive_probs.sum() + naive_probs_weight = np.clip(1 - (total_iter - naive_probs_iter) / 2000, 0, 1) + rots_probs = naive_probs.view(1, num_hypos) * naive_probs_weight + rots_probs * (1 - naive_probs_weight) + + rots_pred = rots_pred[..., 1:4] + trans_pred = poses_raw[..., -3:] + best_rot_idx = torch.argmax(rots_probs, dim=1) # N + if random_sample: + # rand_rot_idx = torch.randint(0, 4, (batch_size * num_frames,), device=poses_raw.device) # N + rand_rot_idx = torch.randperm(batch_size * num_frames, device=poses_raw.device) % num_hypos # N + # rand_rot_idx = torch.randperm(batch_size, device=poses_raw.device)[:,None].repeat(1, num_frames).view(-1) % 4 # N + best_flag = (torch.randperm(batch_size * num_frames, device=poses_raw.device) / (batch_size * num_frames) < np.clip((total_iter - best_pose_start_iter)/2000, 0, 0.8)).long() + rand_flag = 1 - best_flag + # best_flag = torch.zeros_like(best_rot_idx) + rot_idx = best_rot_idx * best_flag + rand_rot_idx * (1 - best_flag) + else: + rand_flag = torch.zeros_like(best_rot_idx) + rot_idx = best_rot_idx + rot_pred = torch.gather(rots_pred, 1, rot_idx[:, None, None].expand(-1, 1, 3))[:, 0] # Nx3 + pose_raw = torch.cat([rot_pred, trans_pred], -1) + rot_prob = torch.gather(rots_probs, 1, rot_idx[:, None].expand(-1, 1))[:, 0] # N + rot_logit = torch.gather(rots_logits, 1, rot_idx[:, None].expand(-1, 1))[:, 0] # N + + if pose_xflip_recon: + raise NotImplementedError + rot_mat = forward_to_matrix(pose_raw[:, :3], up=[0, 1, 0]) + pose = torch.cat([rot_mat.view(batch_size * num_frames, -1), pose_raw[:, 3:]], -1) + return pose_raw, pose, rot_idx, rot_prob, rot_logit, rots_probs, rand_flag + + +class PriorPredictor(nn.Module): + def __init__(self, cfgs): + super().__init__() + dmtet_grid = cfgs.get('dmtet_grid', 64) + grid_scale = cfgs.get('grid_scale', 5) + prior_sdf_mode = cfgs.get('prior_sdf_mode', 'mlp') + num_layers_shape = cfgs.get('num_layers_shape', 5) + hidden_size = cfgs.get('hidden_size', 64) + embedder_freq_shape = cfgs.get('embedder_freq_shape', 8) + embed_concat_pts = cfgs.get('embed_concat_pts', True) + init_sdf = cfgs.get('init_sdf', None) + jitter_grid = cfgs.get('jitter_grid', 0.) + perturb_sdf_iter = cfgs.get('perturb_sdf_iter', 10000) + sym_prior_shape = cfgs.get('sym_prior_shape', False) + self.netShape = DMTetGeometry(dmtet_grid, grid_scale, prior_sdf_mode, num_layers=num_layers_shape, hidden_size=hidden_size, embedder_freq=embedder_freq_shape, embed_concat_pts=embed_concat_pts, init_sdf=init_sdf, jitter_grid=jitter_grid, perturb_sdf_iter=perturb_sdf_iter, sym_prior_shape=sym_prior_shape) + + mlp_hidden_size = cfgs.get('hidden_size', 64) + tet_bbox = self.netShape.getAABB() + self.render_dino_mode = cfgs.get('render_dino_mode', None) + num_layers_dino = cfgs.get("num_layers_dino", 5) + dino_feature_recon_dim = cfgs.get('dino_feature_recon_dim', 64) + sym_dino = cfgs.get("sym_dino", False) + dino_min = torch.zeros(dino_feature_recon_dim) + cfgs.get('dino_min', 0.) + dino_max = torch.zeros(dino_feature_recon_dim) + cfgs.get('dino_max', 1.) 
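+        # dino_min/dino_max define per-channel bounds for the reconstructed DINO
+        # features; stacked into min_max below, they are handed to the feature head so
+        # its normalized (sigmoid) output can be mapped back into this range.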
+ min_max = torch.stack((dino_min, dino_max), dim=0) + if self.render_dino_mode is None: + pass + elif self.render_dino_mode == 'feature_mlpnv': + self.netDINO = mlptexture.MLPTexture3D(tet_bbox, channels=dino_feature_recon_dim, internal_dims=mlp_hidden_size, hidden=num_layers_dino-1, feat_dim=0, min_max=min_max, bsdf=None, perturb_normal=False, symmetrize=sym_dino) + elif self.render_dino_mode == 'feature_mlp': + embedder_scaler = 2 * np.pi / grid_scale * 0.9 # originally (-0.5*s, 0.5*s) rescale to (-pi, pi) * 0.9 + embed_concat_pts = cfgs.get('embed_concat_pts', True) + self.netDINO = networks.MLPTextureSimple( + 3, # x, y, z coordinates + dino_feature_recon_dim, + num_layers_dino, + nf=mlp_hidden_size, + dropout=0, + activation="sigmoid", + min_max=min_max, + n_harmonic_functions=cfgs.get('embedder_freq_dino', 8), + omega0=embedder_scaler, + extra_dim=0, + embed_concat_pts=embed_concat_pts, + perturb_normal=False, + symmetrize=sym_dino + ) + elif self.render_dino_mode == 'cluster': + num_layers_dino = cfgs.get("num_layers_dino", 5) + dino_cluster_dim = cfgs.get('dino_cluster_dim', 64) + self.netDINO = mlptexture.MLPTexture3D(tet_bbox, channels=dino_cluster_dim, internal_dims=mlp_hidden_size, hidden=num_layers_dino-1, feat_dim=0, min_max=None, bsdf=None, perturb_normal=False, symmetrize=sym_dino) + else: + raise NotImplementedError + + def forward(self, perturb_sdf=False, total_iter=None, is_training=True): + prior_shape = self.netShape.getMesh(perturb_sdf=perturb_sdf, total_iter=total_iter, jitter_grid=is_training) + return prior_shape, self.netDINO + + +class InstancePredictor(nn.Module): + def __init__(self, cfgs, tet_bbox=None): + super().__init__() + self.cfgs = cfgs + self.grid_scale = cfgs.get('grid_scale', 5) + + self.enable_encoder = cfgs.get('enable_encoder', False) + if self.enable_encoder: + encoder_latent_dim = cfgs.get('latent_dim', 256) + encoder_pretrained = cfgs.get('encoder_pretrained', False) + encoder_frozen = cfgs.get('encoder_frozen', False) + encoder_arch = cfgs.get('encoder_arch', 'simple') + in_image_size = cfgs.get('in_image_size', 256) + self.dino_feature_input = cfgs.get('dino_feature_input', False) + dino_feature_dim = cfgs.get('dino_feature_dim', 64) + if encoder_arch == 'simple': + if self.dino_feature_input: + self.netEncoder = networks.EncoderWithDINO(cin_rgb=3, cin_dino=dino_feature_dim, cout=encoder_latent_dim, in_size=in_image_size, zdim=None, nf=64, activation=None) + else: + self.netEncoder = networks.Encoder(cin=3, cout=encoder_latent_dim, in_size=in_image_size, zdim=None, nf=64, activation=None) + elif encoder_arch == 'vgg': + self.netEncoder = networks.VGGEncoder(cout=encoder_latent_dim, pretrained=encoder_pretrained) + elif encoder_arch == 'resnet': + self.netEncoder = networks.ResnetEncoder(cout=encoder_latent_dim, pretrained=encoder_pretrained) + elif encoder_arch == 'vit': + which_vit = cfgs.get('which_vit', 'dino_vits8') + vit_final_layer_type = cfgs.get('vit_final_layer_type', 'conv') + self.netEncoder = networks.ViTEncoder(cout=encoder_latent_dim, which_vit=which_vit, pretrained=encoder_pretrained, frozen=encoder_frozen, in_size=in_image_size, final_layer_type=vit_final_layer_type) + else: + raise NotImplementedError + else: + encoder_latent_dim = 0 + + mlp_hidden_size = cfgs.get('hidden_size', 64) + + bsdf = cfgs.get("bsdf", 'diffuse') + num_layers_tex = cfgs.get("num_layers_tex", 5) + feat_dim = cfgs.get("latent_dim", 64) if self.enable_encoder else 0 + perturb_normal = cfgs.get("perturb_normal", False) + sym_texture = 
cfgs.get("sym_texture", False) + kd_min = torch.FloatTensor(cfgs.get('kd_min', [0., 0., 0., 0.])) + kd_max = torch.FloatTensor(cfgs.get('kd_max', [1., 1., 1., 1.])) + ks_min = torch.FloatTensor(cfgs.get('ks_min', [0., 0., 0.])) + ks_max = torch.FloatTensor(cfgs.get('ks_max', [0., 0., 0.])) + nrm_min = torch.FloatTensor(cfgs.get('nrm_min', [-1., -1., 0.])) + nrm_max = torch.FloatTensor(cfgs.get('nrm_max', [1., 1., 1.])) + mlp_min = torch.cat((kd_min[0:3], ks_min, nrm_min), dim=0) + mlp_max = torch.cat((kd_max[0:3], ks_max, nrm_max), dim=0) + min_max = torch.stack((mlp_min, mlp_max), dim=0) + out_chn = 9 + # TODO: if the tet verts are deforming, we need to recompute tet_bbox + texture_mode = cfgs.get("texture_mode", 'mlp') + if texture_mode == 'mlpnv': + self.netTexture = mlptexture.MLPTexture3D(tet_bbox, channels=out_chn, internal_dims=mlp_hidden_size, hidden=num_layers_tex-1, feat_dim=feat_dim, min_max=min_max, bsdf=bsdf, perturb_normal=perturb_normal, symmetrize=sym_texture) + elif texture_mode == 'mlp': + embedder_scaler = 2 * np.pi / self.grid_scale * 0.9 # originally (-0.5*s, 0.5*s) rescale to (-pi, pi) * 0.9 + embed_concat_pts = cfgs.get('embed_concat_pts', True) + self.netTexture = networks.MLPTextureSimple( + 3, # x, y, z coordinates + out_chn, + num_layers_tex, + nf=mlp_hidden_size, + dropout=0, + activation="sigmoid", + min_max=min_max, + n_harmonic_functions=cfgs.get('embedder_freq_tex', 10), + omega0=embedder_scaler, + extra_dim=feat_dim, + embed_concat_pts=embed_concat_pts, + perturb_normal=perturb_normal, + symmetrize=sym_texture + ) + + self.rot_rep = cfgs.get('rot_rep', 'euler_angle') + self.enable_pose = cfgs.get('enable_pose', False) + if self.enable_pose: + cam_pos_z_offset = cfgs.get('cam_pos_z_offset', 10.) + fov = cfgs.get('crop_fov_approx', 25) + half_range = np.tan(fov /2 /180 * np.pi) * cam_pos_z_offset # 2.22 + self.max_trans_xy_range = half_range * cfgs.get('max_trans_xy_range_ratio', 1.) + self.max_trans_z_range = half_range * cfgs.get('max_trans_z_range_ratio', 1.) + self.lookat_init = cfgs.get('lookat_init', None) + self.lookat_zeroy = cfgs.get('lookat_zeroy', False) + self.rot_temp_scalar = cfgs.get('rot_temp_scalar', 1.) 
+ self.naive_probs_iter = cfgs.get('naive_probs_iter', 2000) + self.best_pose_start_iter = cfgs.get('best_pose_start_iter', 6000) + + if self.rot_rep == 'euler_angle': + pose_cout = 6 + elif self.rot_rep == 'quaternion': + pose_cout = 7 + elif self.rot_rep == 'lookat': + pose_cout = 6 + elif self.rot_rep == 'quadlookat': + self.num_pose_hypos = 4 + pose_cout = (3 + 1) * self.num_pose_hypos + 3 # 4 forward vectors for 4 quadrants, 4 quadrant classification logits, 3 for translation + self.orthant_signs = torch.FloatTensor([[1,1,1], [-1,1,1], [-1,1,-1], [1,1,-1]]) + elif self.rot_rep == 'octlookat': + self.num_pose_hypos = 8 + pose_cout = (3 + 1) * self.num_pose_hypos + 3 # 4 forward vectors for 8 octants, 8 octant classification logits, 3 for translation + self.orthant_signs = torch.stack(torch.meshgrid([torch.arange(1, -2, -2)] *3), -1).view(-1, 3) # 8x3 + else: + raise NotImplementedError + + self.pose_arch = cfgs.get('pose_arch', 'mlp') + if self.pose_arch == 'mlp': + num_layers_pose = cfgs.get('num_layers_pose', 5) + self.netPose = networks.MLP( + encoder_latent_dim, + pose_cout, + num_layers_pose, + nf=mlp_hidden_size, + dropout=0, + activation=None + ) + elif self.pose_arch == 'encoder': + if self.dino_feature_input: + dino_feature_dim = cfgs.get('dino_feature_dim', 64) + self.netPose = networks.EncoderWithDINO(cin_rgb=3, cin_dino=dino_feature_dim, cout=pose_cout, in_size=in_image_size, zdim=None, nf=64, activation=None) + else: + self.netPose = networks.Encoder(cin=3, cout=pose_cout, in_size=in_image_size, zdim=None, nf=64, activation=None) + elif self.pose_arch in ['encoder_dino_patch_out', 'encoder_dino_patch_key']: + if which_vit == 'dino_vits8': + dino_feat_dim = 384 + elif which_vit == 'dinov2_vits14': + dino_feat_dim = 384 + elif which_vit == 'dino_vitb8': + dino_feat_dim = 768 + self.netPose = networks.Encoder32(cin=dino_feat_dim, cout=pose_cout, nf=256, activation=None) + elif self.pose_arch == 'vit': + encoder_pretrained = cfgs.get('encoder_pretrained', False) + encoder_frozen = cfgs.get('encoder_frozen', False) + which_vit = cfgs.get('which_vit', 'dino_vits8') + vit_final_layer_type = cfgs.get('vit_final_layer_type', 'conv') + self.netPose = networks.ViTEncoder(cout=encoder_latent_dim, which_vit=which_vit, pretrained=encoder_pretrained, frozen=encoder_frozen, in_size=in_image_size, final_layer_type=vit_final_layer_type) + else: + raise NotImplementedError + + self.enable_deform = cfgs.get('enable_deform', False) + if self.enable_deform: + embedder_scaler = 2 * np.pi / self.grid_scale * 0.9 # originally (-0.5*s, 0.5*s) rescale to (-pi, pi) * 0.9 + embed_concat_pts = cfgs.get('embed_concat_pts', True) + num_layers_deform = cfgs.get('num_layers_deform', 5) + self.deform_epochs = np.arange(*cfgs.get('deform_epochs', [0, 0])) + sym_deform = cfgs.get("sym_deform", False) + self.netDeform = networks.MLPWithPositionalEncoding( + 3, # x, y, z coordinates + 3, # dx, dy, dz deformation + num_layers_deform, + nf=mlp_hidden_size, + dropout=0, + activation=None, + n_harmonic_functions=cfgs.get('embedder_freq_deform', 10), + omega0=embedder_scaler, + extra_dim=encoder_latent_dim, + embed_concat_pts=embed_concat_pts, + symmetrize=sym_deform + ) + + self.enable_articulation = cfgs.get('enable_articulation', False) + if self.enable_articulation: + self.num_body_bones = cfgs.get('num_body_bones', 4) + self.articulation_multiplier = cfgs.get('articulation_multiplier', 1) + self.static_root_bones = cfgs.get('static_root_bones', False) + self.skinning_temperature = 
cfgs.get('skinning_temperature', 1) + self.articulation_epochs = np.arange(*cfgs.get('articulation_epochs', [0, 0])) + self.num_legs = cfgs.get('num_legs', 0) + self.num_leg_bones = cfgs.get('num_leg_bones', 0) + self.body_bones_type = cfgs.get('body_bones_type', 'z_minmax') + self.perturb_articulation_epochs = np.arange(*cfgs.get('perturb_articulation_epochs', [0, 0])) + self.num_bones = self.num_body_bones + self.num_legs * self.num_leg_bones + self.constrain_legs = cfgs.get('constrain_legs', False) + self.attach_legs_to_body_epochs = np.arange(*cfgs.get('attach_legs_to_body_epochs', [0, 0])) + self.max_arti_angle = cfgs.get('max_arti_angle', 60) + + num_layers_arti = cfgs.get('num_layers_arti', 5) + which_vit = cfgs.get('which_vit', 'dino_vits8') + if which_vit == 'dino_vits8': + dino_feat_dim = 384 + elif which_vit == 'dino_vitb8': + dino_feat_dim = 768 + self.articulation_arch = cfgs.get('articulation_arch', 'mlp') + self.articulation_feature_mode = cfgs.get('articulation_feature_mode', 'sample') + embedder_freq_arti = cfgs.get('embedder_freq_arti', 8) + if self.articulation_feature_mode == 'global': + feat_dim = encoder_latent_dim + elif self.articulation_feature_mode == 'sample': + feat_dim = dino_feat_dim + elif self.articulation_feature_mode == 'sample+global': + feat_dim = encoder_latent_dim + dino_feat_dim + if self.articulation_feature_mode == 'attention': + arti_feat_attn_zdim = cfgs.get('arti_feat_attn_zdim', 128) + pos_dim = 1 + 2 + 3*2 + self.netFeatureAttn = networks.FeatureAttention(which_vit, pos_dim, embedder_freq_arti, arti_feat_attn_zdim, img_size=in_image_size) + embedder_scaler = np.pi * 0.9 # originally (-1, 1) rescale to (-pi, pi) * 0.9 + self.netArticulation = networks.ArticulationNetwork(self.articulation_arch, feat_dim, 1+2+3*2, num_layers_arti, mlp_hidden_size, n_harmonic_functions=embedder_freq_arti, omega0=embedder_scaler) + self.kinematic_tree_epoch = -1 + + self.enable_lighting = cfgs.get('enable_lighting', False) + if self.enable_lighting: + num_layers_light = cfgs.get('num_layers_light', 5) + amb_diff_min = torch.FloatTensor(cfgs.get('amb_diff_min', [0., 0.])) + amb_diff_max = torch.FloatTensor(cfgs.get('amb_diff_max', [1., 1.])) + intensity_min_max = torch.stack((amb_diff_min, amb_diff_max), dim=0) + self.netLight = light.DirectionalLight(encoder_latent_dim, num_layers_light, mlp_hidden_size, intensity_min_max=intensity_min_max) + + self.cam_pos_z_offset = cfgs.get('cam_pos_z_offset', 10.) 
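# An illustrative computation of the articulation feature width chosen above (not part
# of the model code; assumes the defaults latent_dim=256 and which_vit='dino_vits8',
# whose patch features are 384-dimensional):
encoder_latent_dim, dino_feat_dim = 256, 384
feat_dims = {
    'global': encoder_latent_dim,                         # global image code only
    'sample': dino_feat_dim,                              # sampled patch feature only
    'sample+global': encoder_latent_dim + dino_feat_dim,  # concatenation of both
}
print(feat_dims['sample+global'])  # 640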
+        self.crop_fov_approx = cfgs.get("crop_fov_approx", 25)
+
+    def forward_encoder(self, images, dino_features=None):
+        images_in = images.view(-1, *images.shape[2:]) * 2 - 1 # rescale to (-1, 1)
+        feat_key = patch_out = patch_key = None # only the ViT encoder produces feat_key and patch features
+        if self.dino_feature_input and self.cfgs.get('encoder_arch', 'simple') != 'vit':
+            dino_features_in = dino_features.view(-1, *dino_features.shape[2:]) * 2 - 1 # rescale to (-1, 1)
+            feat_out = self.netEncoder(images_in, dino_features_in) # Shape: (B, latent_dim)
+        elif self.cfgs.get('encoder_arch', 'simple') == 'vit':
+            feat_out, feat_key, patch_out, patch_key = self.netEncoder(images_in, return_patches=True)
+        else:
+            feat_out = self.netEncoder(images_in) # Shape: (B, latent_dim)
+        return feat_out, feat_key, patch_out, patch_key
+
+    def forward_pose(self, images, feat, patch_out, patch_key, dino_features):
+        if self.pose_arch == 'mlp':
+            pose = self.netPose(feat)
+        elif self.pose_arch == 'encoder':
+            images_in = images.view(-1, *images.shape[2:]) * 2 - 1 # rescale to (-1, 1)
+            if self.dino_feature_input:
+                dino_features_in = dino_features.view(-1, *dino_features.shape[2:]) * 2 - 1 # rescale to (-1, 1)
+                pose = self.netPose(images_in, dino_features_in) # Shape: (B, latent_dim)
+            else:
+                pose = self.netPose(images_in) # Shape: (B, latent_dim)
+        elif self.pose_arch == 'vit':
+            images_in = images.view(-1, *images.shape[2:]) * 2 - 1 # rescale to (-1, 1)
+            pose = self.netPose(images_in)
+        elif self.pose_arch == 'encoder_dino_patch_out':
+            pose = self.netPose(patch_out) # Shape: (B, latent_dim)
+        elif self.pose_arch == 'encoder_dino_patch_key':
+            pose = self.netPose(patch_key) # Shape: (B, latent_dim)
+        else:
+            raise NotImplementedError
+        trans_pred = pose[...,-3:].tanh() * torch.FloatTensor([self.max_trans_xy_range, self.max_trans_xy_range, self.max_trans_z_range]).to(pose.device)
+        if self.rot_rep == 'euler_angle':
+            multiplier = 1.
+            if self.gradually_expand_yaw:
+                # multiplier += (min(iteration, 20000) // 500) * 0.25
+                multiplier *= 1.2 ** (min(iteration, 20000) // 500) # 1.2^40 ~ 1470 after 20k iterations
+            rot_pred = torch.cat([pose[...,:1], pose[...,1:2]*multiplier, pose[...,2:3]], -1).tanh()
+            rot_pred = rot_pred * torch.FloatTensor([self.max_rot_x_range, self.max_rot_y_range, self.max_rot_z_range]).to(pose.device) /180 * np.pi
+
+        elif self.rot_rep == 'quaternion':
+            quat_init = torch.FloatTensor([0.01,0,0,0]).to(pose.device)
+            rot_pred = pose[...,:4] + quat_init
+            rot_pred = nn.functional.normalize(rot_pred, p=2, dim=-1)
+            # rot_pred = torch.cat([rot_pred[...,:1].abs(), rot_pred[...,1:]], -1) # make real part non-negative
+            rot_pred = rot_pred * rot_pred[...,:1].sign() # make real part non-negative
+
+        elif self.rot_rep == 'lookat':
+            vec_forward_raw = pose[...,:3]
+            if self.lookat_init is not None:
+                vec_forward_raw = vec_forward_raw + torch.FloatTensor(self.lookat_init).to(pose.device)
+            if self.lookat_zeroy:
+                vec_forward_raw = vec_forward_raw * torch.FloatTensor([1,0,1]).to(pose.device)
+            vec_forward_raw = nn.functional.normalize(vec_forward_raw, p=2, dim=-1) # x right, y up, z forward
+            rot_pred = vec_forward_raw
+
+        elif self.rot_rep in ['quadlookat', 'octlookat']:
+            rots_pred = pose[..., :self.num_pose_hypos*4].view(-1, self.num_pose_hypos, 4) # (B, T, K, 4)
+            rots_logits = rots_pred[..., :1]
+            vec_forward_raw = rots_pred[..., 1:4]
+            xs, ys, zs = vec_forward_raw.unbind(-1)
+            margin = 0.
+            xs = nn.functional.softplus(xs, beta=np.log(2)/(0.5+margin)) - margin # initialize to 0.5
+            if self.rot_rep == 'octlookat':
+                ys = nn.functional.softplus(ys, beta=np.log(2)/(0.5+margin)) - margin # initialize to 0.5
+            if self.lookat_zeroy:
+                ys = ys * 0
+            zs = nn.functional.softplus(zs, beta=2*np.log(2)) # initialize to 0.5
+            vec_forward_raw = torch.stack([xs, ys, zs], -1)
+            vec_forward_raw = vec_forward_raw * self.orthant_signs.to(pose.device)
+            vec_forward_raw = nn.functional.normalize(vec_forward_raw, p=2, dim=-1) # x right, y up, z forward
+            rot_pred = torch.cat([rots_logits, vec_forward_raw], -1).view(-1, self.num_pose_hypos*4)
+
+        else:
+            raise NotImplementedError
+
+        pose = torch.cat([rot_pred, trans_pred], -1)
+        return pose
+
+    def forward_deformation(self, shape, feat=None):
+        original_verts = shape.v_pos
+        num_verts = original_verts.shape[1]
+        deform_feat = None # no per-instance feature when the image encoder is disabled
+        if feat is not None:
+            deform_feat = feat[:, None, :].repeat(1, num_verts, 1) # Shape: (B, num_verts, latent_dim)
+            original_verts = original_verts.repeat(len(feat),1,1)
+        deformation = self.netDeform(original_verts, deform_feat) * 0.1 # Shape: (B, num_verts, 3)
+        shape = shape.deform(deformation)
+        return shape, deformation
+
+    def forward_articulation(self, shape, feat, patch_feat, mvp, w2c, batch_size, num_frames, epoch):
+        """
+        Forward propagation of articulation. For each bone, the network takes: 1) the 3D location of the bone; 2) the feature of the patch
+        onto which the bone is projected; and 3) an encoding of the bone's index, and predicts the bone's rotation (represented as Euler angles).
+
+        Args:
+            shape: a Mesh object, whose v_pos has batch size BxF or 1.
+            feat: the global image feature. Shape: (BxF, feat_dim)
+            patch_feat: the patch feature map from which per-bone features are sampled. Shape: (BxF, feat_dim, num_patches_per_axis, num_patches_per_axis)
+            mvp: the model-view-projection matrix. Shape: (BxF, 4, 4)
+            w2c: the world-to-camera matrix. Shape: (BxF, 4, 4)
+
+        Returns:
+            shape: a Mesh object, whose v_pos has batch size BxF (collapsed).
+            articulation_angles: the predicted bone rotations. Shape: (B, F, num_bones, 3)
+            aux: a dictionary containing auxiliary information.
+        """
+        verts = shape.v_pos
+        if len(verts) == 1:
+            verts = verts[None]
+        else:
+            verts = verts.view(batch_size, num_frames, *verts.shape[1:])
+
+        if self.kinematic_tree_epoch != epoch:
+            # if (epoch == self.articulation_epochs[0]) and (self.kinematic_tree_epoch != epoch):
+            # if (epoch in [self.articulation_epochs[0], self.articulation_epochs[0]+2, self.articulation_epochs[0]+4]) and (self.kinematic_tree_epoch != epoch):
+            attach_legs_to_body = epoch in self.attach_legs_to_body_epochs
+            bones, self.kinematic_tree, self.bone_aux = estimate_bones(verts.detach(), self.num_body_bones, n_legs=self.num_legs, n_leg_bones=self.num_leg_bones, body_bones_type=self.body_bones_type, compute_kinematic_chain=True, attach_legs_to_body=attach_legs_to_body)
+            self.kinematic_tree_epoch = epoch
+        else:
+            bones = estimate_bones(verts.detach(), self.num_body_bones, n_legs=self.num_legs, n_leg_bones=self.num_leg_bones, body_bones_type=self.body_bones_type, compute_kinematic_chain=False, aux=self.bone_aux)
+
+        bones_pos = bones # Shape: (B, F, K, 2, 3)
+        if batch_size > bones_pos.shape[0] or num_frames > bones_pos.shape[1]:
+            assert bones_pos.shape[0] == 1 and bones_pos.shape[1] == 1, "If there is a mismatch, then there must be only one canonical mesh."
+ bones_pos = bones_pos.repeat(batch_size, num_frames, 1, 1, 1) + num_bones = bones_pos.shape[2] + bones_pos = bones_pos.view(batch_size*num_frames, num_bones, 2, 3) # NxKx2x3 + bones_mid_pos = bones_pos.mean(2) # NxKx3 + bones_idx = torch.arange(num_bones).to(bones_pos.device) + + bones_mid_pos_world4 = torch.cat([bones_mid_pos, torch.ones_like(bones_mid_pos[..., :1])], -1) # NxKx4 + bones_mid_pos_clip4 = bones_mid_pos_world4 @ mvp.transpose(-1, -2) + bones_mid_pos_uv = bones_mid_pos_clip4[..., :2] / bones_mid_pos_clip4[..., 3:4] + bones_mid_pos_uv = bones_mid_pos_uv.detach() + + bones_pos_world4 = torch.cat([bones_pos, torch.ones_like(bones_pos[..., :1])], -1) # NxKx2x4 + bones_pos_cam4 = bones_pos_world4 @ w2c[:,None].transpose(-1, -2) + bones_pos_cam3 = bones_pos_cam4[..., :3] / bones_pos_cam4[..., 3:4] + bones_pos_cam3 = bones_pos_cam3 + torch.FloatTensor([0, 0, self.cam_pos_z_offset]).to(bones_pos_cam3.device).view(1, 1, 1, 3) + bones_pos_in = bones_pos_cam3.view(batch_size*num_frames, num_bones, 2*3) / self.grid_scale * 2 # (-1, 1), NxKx(2*3) + + bones_idx_in = ((bones_idx[None, :, None] + 0.5) / num_bones * 2 - 1).repeat(batch_size * num_frames, 1, 1) # (-1, 1) + bones_pos_in = torch.cat([bones_mid_pos_uv, bones_pos_in, bones_idx_in], -1).detach() + + if self.articulation_feature_mode == 'global': + bones_patch_features = feat[:, None].repeat(1, num_bones, 1) # (BxF, K, feat_dim) + elif self.articulation_feature_mode == 'sample': + bones_patch_features = F.grid_sample(patch_feat, bones_mid_pos_uv.view(batch_size * num_frames, 1, -1, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # (BxF, K, feat_dim) + elif self.articulation_feature_mode == 'sample+global': + bones_patch_features = F.grid_sample(patch_feat, bones_mid_pos_uv.view(batch_size * num_frames, 1, -1, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # (BxF, K, feat_dim) + bones_patch_features = torch.cat([feat[:, None].repeat(1, num_bones, 1), bones_patch_features], -1) + elif self.articulation_feature_mode == 'attention': + bones_patch_features = self.netFeatureAttn(bones_pos_in, patch_feat) + else: + raise NotImplementedError + + articulation_angles = self.netArticulation(bones_patch_features, bones_pos_in).view(batch_size, num_frames, num_bones, 3) * self.articulation_multiplier + + if self.static_root_bones: + root_bones = [self.num_body_bones // 2 - 1, self.num_body_bones - 1] + tmp_mask = torch.ones_like(articulation_angles) + tmp_mask[:, :, root_bones] = 0 + articulation_angles = articulation_angles * tmp_mask + + articulation_angles = articulation_angles.tanh() + + if self.constrain_legs: + leg_bones_posx = [self.num_body_bones + i for i in range(self.num_leg_bones * self.num_legs // 2)] + leg_bones_negx = [self.num_body_bones + self.num_leg_bones * self.num_legs // 2 + i for i in range(self.num_leg_bones * self.num_legs // 2)] + + tmp_mask = torch.zeros_like(articulation_angles) + tmp_mask[:, :, leg_bones_posx + leg_bones_negx, 2] = 1 + articulation_angles = tmp_mask * (articulation_angles * 0.3) + (1 - tmp_mask) * articulation_angles # no twist + + tmp_mask = torch.zeros_like(articulation_angles) + tmp_mask[:, :, leg_bones_posx + leg_bones_negx, 1] = 1 + articulation_angles = tmp_mask * (articulation_angles * 0.3) + (1 - tmp_mask) * articulation_angles # (-0.4, 0.4), limit side bending + + if epoch in self.perturb_articulation_epochs: + articulation_angles = articulation_angles + torch.randn_like(articulation_angles) * 0.1 + articulation_angles = articulation_angles * self.max_arti_angle / 180 * np.pi 
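# A minimal sketch of the angle post-processing above (illustrative values; assumes the
# default max_arti_angle=60): unbounded predictions are squashed by tanh and rescaled,
# so each per-bone Euler angle ends up within (-60, 60) degrees, expressed in radians.
import numpy as np
import torch
raw_angles = torch.tensor([-10.0, 0.0, 10.0])   # unbounded network outputs
angles = raw_angles.tanh() * 60 / 180 * np.pi   # ~[-1.047, 0.000, 1.047] rad
print(angles)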
+ + verts_articulated, aux = skinning(verts, bones, self.kinematic_tree, articulation_angles, + output_posed_bones=True, temperature=self.skinning_temperature) + verts_articulated = verts_articulated.view(batch_size*num_frames, *verts_articulated.shape[2:]) + v_tex = shape.v_tex + if len(v_tex) != len(verts_articulated): + v_tex = v_tex.repeat(len(verts_articulated), 1, 1) + shape = mesh.make_mesh( + verts_articulated, + shape.t_pos_idx, + v_tex, + shape.t_tex_idx, + shape.material) + return shape, articulation_angles, aux + + def get_camera_extrinsics_from_pose(self, pose, znear=0.1, zfar=1000.): + N = len(pose) + cam_pos_offset = torch.FloatTensor([0, 0, -self.cam_pos_z_offset]).to(pose.device) + pose_R = pose[:, :9].view(N, 3, 3).transpose(2, 1) + pose_T = pose[:, -3:] + cam_pos_offset[None, None, :] + pose_T = pose_T.view(N, 3, 1) + pose_RT = torch.cat([pose_R, pose_T], axis=2) # Nx3x4 + w2c = torch.cat([pose_RT, torch.FloatTensor([0, 0, 0, 1]).repeat(N, 1, 1).to(pose.device)], axis=1) # Nx4x4 + # We assume the images are perfect square. + proj = util.perspective(self.crop_fov_approx / 180 * np.pi, 1, znear, zfar)[None].to(pose.device) + mvp = torch.matmul(proj, w2c) + campos = -torch.matmul(pose_R.transpose(2, 1), pose_T).view(N, 3) + return mvp, w2c, campos + + def forward(self, images=None, prior_shape=None, epoch=None, dino_features=None, dino_clusters=None, total_iter=None, is_training=True): + batch_size, num_frames = images.shape[:2] + if self.enable_encoder: + feat_out, feat_key, patch_out, patch_key = self.forward_encoder(images, dino_features) + else: + feat_out = feat_key = patch_out = patch_key = None + shape = prior_shape + texture = self.netTexture + + multi_hypothesis_aux = {} + if self.enable_pose: + poses_raw = self.forward_pose(images, feat_out, patch_out, patch_key, dino_features) + pose_raw, pose, rot_idx, rot_prob, rot_logit, rots_probs, rand_pose_flag = sample_pose_hypothesis_from_quad_prediction(poses_raw, total_iter, batch_size, num_frames, rot_temp_scalar=self.rot_temp_scalar, num_hypos=self.num_pose_hypos, naive_probs_iter=self.naive_probs_iter, best_pose_start_iter=self.best_pose_start_iter, random_sample=is_training) + multi_hypothesis_aux['rot_idx'] = rot_idx + multi_hypothesis_aux['rot_prob'] = rot_prob + multi_hypothesis_aux['rot_logit'] = rot_logit + multi_hypothesis_aux['rots_probs'] = rots_probs + multi_hypothesis_aux['rand_pose_flag'] = rand_pose_flag + else: + raise NotImplementedError + mvp, w2c, campos = self.get_camera_extrinsics_from_pose(pose) + + deformation = None + if self.enable_deform and epoch in self.deform_epochs: + shape, deformation = self.forward_deformation(shape, feat_key) + + arti_params, articulation_aux = None, {} + if self.enable_articulation and epoch in self.articulation_epochs: + shape, arti_params, articulation_aux = self.forward_articulation(shape, feat_key, patch_key, mvp, w2c, batch_size, num_frames, epoch) + + if self.enable_lighting: + light = self.netLight + else: + light = None + + aux = articulation_aux + aux.update(multi_hypothesis_aux) + + return shape, pose_raw, pose, mvp, w2c, campos, texture, feat_out, deformation, arti_params, light, aux + + +class Unsup3D: + def __init__(self, cfgs): + self.cfgs = cfgs + self.device = cfgs.get('device', 'cpu') + self.in_image_size = cfgs.get('in_image_size', 128) + self.out_image_size = cfgs.get('out_image_size', 128) + + self.num_epochs = cfgs.get('num_epochs', 10) + self.lr = cfgs.get('lr', 1e-4) + self.use_scheduler = cfgs.get('use_scheduler', False) + if 
self.use_scheduler: + scheduler_milestone = cfgs.get('scheduler_milestone', [1,2,3,4,5]) + scheduler_gamma = cfgs.get('scheduler_gamma', 0.5) + self.make_scheduler = lambda optim: torch.optim.lr_scheduler.MultiStepLR(optim, milestones=scheduler_milestone, gamma=scheduler_gamma) + + self.cam_pos_z_offset = cfgs.get('cam_pos_z_offset', 10.) + self.full_size_h = cfgs.get('full_size_h', 1080) + self.full_size_w = cfgs.get('full_size_w', 1920) + # self.fov_w = cfgs.get('fov_w', 60) + # self.fov_h = np.arctan(np.tan(self.fov_w /2 /180*np.pi) / self.full_size_w * self.full_size_h) *2 /np.pi*180 # 36 + self.crop_fov_approx = cfgs.get("crop_fov_approx", 25) + self.mesh_regularization_mode = cfgs.get('mesh_regularization_mode', 'seq') + + self.enable_prior = cfgs.get('enable_prior', False) + if self.enable_prior: + self.netPrior = PriorPredictor(self.cfgs) + self.prior_lr = cfgs.get('prior_lr', self.lr) + self.prior_weight_decay = cfgs.get('prior_weight_decay', 0.) + self.prior_only_epochs = cfgs.get('prior_only_epochs', 0) + self.netInstance = InstancePredictor(self.cfgs, tet_bbox=self.netPrior.netShape.getAABB()) + self.perturb_sdf = cfgs.get('perturb_sdf', False) + self.blur_mask = cfgs.get('blur_mask', False) + self.blur_mask_iter = cfgs.get('blur_mask_iter', 1) + + self.seqshape_epochs = np.arange(*cfgs.get('seqshape_epochs', [0, self.num_epochs])) + self.avg_texture_epochs = np.arange(*cfgs.get('avg_texture_epochs', [0, 0])) + self.swap_texture_epochs = np.arange(*cfgs.get('swap_texture_epochs', [0, 0])) + self.swap_priorshape_epochs = np.arange(*cfgs.get('swap_priorshape_epochs', [0, 0])) + self.avg_seqshape_epochs = np.arange(*cfgs.get('avg_seqshape_epochs', [0, 0])) + self.swap_seqshape_epochs = np.arange(*cfgs.get('swap_seqshape_epochs', [0, 0])) + self.pose_epochs = np.arange(*cfgs.get('pose_epochs', [0, 0])) + self.pose_iters = cfgs.get('pose_iters', 0) + self.deform_type = cfgs.get('deform_type', None) + self.mesh_reg_decay_epoch = cfgs.get('mesh_reg_decay_epoch', 0) + self.sdf_reg_decay_start_iter = cfgs.get('sdf_reg_decay_start_iter', 0) + self.mesh_reg_decay_rate = cfgs.get('mesh_reg_decay_rate', 1) + self.texture_epochs = np.arange(*cfgs.get('texture_epochs', [0, self.num_epochs])) + self.zflip_epochs = np.arange(*cfgs.get('zflip_epochs', [0, self.num_epochs])) + self.lookat_zflip_loss_epochs = np.arange(*cfgs.get('lookat_zflip_loss_epochs', [0, self.num_epochs])) + self.lookat_zflip_no_other_losses = cfgs.get('lookat_zflip_no_other_losses', False) + self.flow_loss_epochs = np.arange(*cfgs.get('flow_loss_epochs', [0, self.num_epochs])) + self.sdf_inflate_reg_loss_epochs = np.arange(*cfgs.get('sdf_inflate_reg_loss_epochs', [0, self.num_epochs])) + self.arti_reg_loss_epochs = np.arange(*cfgs.get('arti_reg_loss_epochs', [0, self.num_epochs])) + self.background_mode = cfgs.get('background_mode', 'background') + self.shape_prior_type = cfgs.get('shape_prior_type', 'deform') + self.backward_prior = cfgs.get('backward_prior', True) + self.resume_prior_optim = cfgs.get('resume_prior_optim', True) + self.dmtet_grid_smaller_epoch = cfgs.get('dmtet_grid_smaller_epoch', 0) + self.dmtet_grid_smaller = cfgs.get('dmtet_grid_smaller', 128) + self.dmtet_grid = cfgs.get('dmtet_grid', 256) + self.pose_xflip_recon_epochs = np.arange(*cfgs.get('pose_xflip_recon_epochs', [0, 0])) + self.rot_rand_quad_epochs = np.arange(*cfgs.get('rot_rand_quad_epochs', [0, 0])) + self.rot_all_quad_epochs = np.arange(*cfgs.get('rot_all_quad_epochs', [0, 0])) + + ## perceptual loss + if cfgs.get('perceptual_loss_weight', 
0.) > 0: + self.perceptual_loss_use_lin = cfgs.get('perceptual_loss_use_lin', True) + self.perceptual_loss = lpips.LPIPS(net='vgg', lpips=self.perceptual_loss_use_lin) + + self.glctx = dr.RasterizeGLContext() + self.render_flow = self.cfgs.get('flow_loss_weight', 0.) > 0. + self.extra_renders = cfgs.get('extra_renders', []) + self.renderer_spp = cfgs.get('renderer_spp', 1) + self.dino_feature_recon_dim = cfgs.get('dino_feature_recon_dim', 64) + + self.total_loss = 0. + self.all_scores = torch.Tensor() + self.checkpoint_dir = cfgs.get('checkpoint_dir', 'results') + + @staticmethod + def get_data_loaders(cfgs, dataset, in_image_size=256, out_image_size=256, batch_size=64, num_workers=4, run_train=False, run_test=False, train_data_dir=None, val_data_dir=None, test_data_dir=None): + train_loader = val_loader = test_loader = None + color_jitter_train = cfgs.get('color_jitter_train', None) + color_jitter_val = cfgs.get('color_jitter_val', None) + random_flip_train = cfgs.get('random_flip_train', False) + + ## video dataset + if dataset == 'video': + data_loader_mode = cfgs.get('data_loader_mode', 'n_frame') + skip_beginning = cfgs.get('skip_beginning', 4) + skip_end = cfgs.get('skip_end', 4) + num_sample_frames = cfgs.get('num_sample_frames', 2) + min_seq_len = cfgs.get('min_seq_len', 10) + max_seq_len = cfgs.get('max_seq_len', 10) + debug_seq = cfgs.get('debug_seq', False) + random_sample_train_frames = cfgs.get('random_sample_train_frames', False) + shuffle_train_seqs = cfgs.get('shuffle_train_seqs', False) + random_sample_val_frames = cfgs.get('random_sample_val_frames', False) + load_background = cfgs.get('background_mode', 'none') == 'background' + rgb_suffix = cfgs.get('rgb_suffix', '.png') + load_dino_feature = cfgs.get('load_dino_feature', False) + load_dino_cluster = cfgs.get('load_dino_cluster', False) + dino_feature_dim = cfgs.get('dino_feature_dim', 64) + get_loader = lambda **kwargs: get_sequence_loader( + mode=data_loader_mode, + batch_size=batch_size, + num_workers=num_workers, + in_image_size=in_image_size, + out_image_size=out_image_size, + debug_seq=debug_seq, + skip_beginning=skip_beginning, + skip_end=skip_end, + num_sample_frames=num_sample_frames, + min_seq_len=min_seq_len, + max_seq_len=max_seq_len, + load_background=load_background, + rgb_suffix=rgb_suffix, + load_dino_feature=load_dino_feature, + load_dino_cluster=load_dino_cluster, + dino_feature_dim=dino_feature_dim, + **kwargs) + + if run_train: + assert osp.isdir(train_data_dir), f"Training data directory does not exist: {train_data_dir}" + print(f"Loading training data from {train_data_dir}") + train_loader = get_loader(data_dir=train_data_dir, is_validation=False, random_sample=random_sample_train_frames, shuffle=shuffle_train_seqs, dense_sample=True, color_jitter=color_jitter_train, random_flip=random_flip_train) + + if val_data_dir is not None: + assert osp.isdir(val_data_dir), f"Validation data directory does not exist: {val_data_dir}" + print(f"Loading validation data from {val_data_dir}") + val_loader = get_loader(data_dir=val_data_dir, is_validation=True, random_sample=random_sample_val_frames, shuffle=False, dense_sample=False, color_jitter=color_jitter_val, random_flip=False) + if run_test: + assert osp.isdir(test_data_dir), f"Testing data directory does not exist: {test_data_dir}" + print(f"Loading testing data from {test_data_dir}") + test_loader = get_loader(data_dir=test_data_dir, is_validation=True, dense_sample=False, color_jitter=None, random_flip=False) + + ## CUB dataset + elif dataset == 'cub': + 
get_loader = lambda **kwargs: get_cub_loader( + batch_size=batch_size, + num_workers=num_workers, + image_size=in_image_size, + **kwargs) + + if run_train: + assert osp.isdir(train_data_dir), f"Training data directory does not exist: {train_data_dir}" + print(f"Loading training data from {train_data_dir}") + train_loader = get_loader(data_dir=train_data_dir, split='train', is_validation=False) + val_loader = get_loader(data_dir=val_data_dir, split='val', is_validation=True) + + if run_test: + assert osp.isdir(test_data_dir), f"Testing data directory does not exist: {test_data_dir}" + print(f"Loading testing data from {test_data_dir}") + test_loader = get_loader(data_dir=test_data_dir, split='test', is_validation=True) + + ## other datasets + else: + get_loader = lambda **kwargs: get_image_loader( + batch_size=batch_size, + num_workers=num_workers, + image_size=in_image_size, + **kwargs) + + if run_train: + assert osp.isdir(train_data_dir), f"Training data directory does not exist: {train_data_dir}" + print(f"Loading training data from {train_data_dir}") + train_loader = get_loader(data_dir=train_data_dir, is_validation=False, color_jitter=color_jitter_train) + + if val_data_dir is not None: + assert osp.isdir(val_data_dir), f"Validation data directory does not exist: {val_data_dir}" + print(f"Loading validation data from {val_data_dir}") + val_loader = get_loader(data_dir=val_data_dir, is_validation=True, color_jitter=color_jitter_val) + + if run_test: + assert osp.isdir(test_data_dir), f"Testing data directory does not exist: {test_data_dir}" + print(f"Loading testing data from {test_data_dir}") + test_loader = get_loader(data_dir=test_data_dir, is_validation=True, color_jitter=None) + + return train_loader, val_loader, test_loader + + def load_model_state(self, cp): + self.netInstance.load_state_dict(cp["netInstance"]) + if self.enable_prior: + self.netPrior.load_state_dict(cp["netPrior"]) + + def load_optimizer_state(self, cp): + self.optimizerInstance.load_state_dict(cp["optimizerInstance"]) + if self.use_scheduler: + if 'schedulerInstance' in cp: + self.schedulerInstance.load_state_dict(cp["schedulerInstance"]) + if self.enable_prior and self.resume_prior_optim: + self.optimizerPrior.load_state_dict(cp["optimizerPrior"]) + if self.use_scheduler: + if 'schedulerPrior' in cp: + self.schedulerPrior.load_state_dict(cp["schedulerPrior"]) + + def get_model_state(self): + state = {"netInstance": self.netInstance.state_dict()} + if self.enable_prior: + state["netPrior"] = self.netPrior.state_dict() + return state + + def get_optimizer_state(self): + state = {"optimizerInstance": self.optimizerInstance.state_dict()} + if self.use_scheduler: + state["schedulerInstance"] = self.schedulerInstance.state_dict() + if self.enable_prior: + state["optimizerPrior"] = self.optimizerPrior.state_dict() + if self.use_scheduler: + state["schedulerPrior"] = self.schedulerPrior.state_dict() + return state + + def to(self, device): + self.device = device + self.netInstance.to(device) + if self.enable_prior: + self.netPrior.to(device) + if hasattr(self, 'perceptual_loss'): + self.perceptual_loss.to(device) + + def set_train(self): + self.netInstance.train() + if self.enable_prior: + self.netPrior.train() + + def set_eval(self): + self.netInstance.eval() + if self.enable_prior: + self.netPrior.eval() + + def reset_optimizers(self): + print("Resetting optimizers...") + self.optimizerInstance = get_optimizer(self.netInstance, self.lr) + if self.use_scheduler: + self.schedulerInstance = 
self.make_scheduler(self.optimizerInstance) + if self.enable_prior: + self.optimizerPrior = get_optimizer(self.netPrior, lr=self.prior_lr, weight_decay=self.prior_weight_decay) + if self.use_scheduler: + self.schedulerPrior = self.make_scheduler(self.optimizerPrior) + + def backward(self): + self.optimizerInstance.zero_grad() + if self.backward_prior: + self.optimizerPrior.zero_grad() + self.total_loss.backward() + self.optimizerInstance.step() + if self.backward_prior: + self.optimizerPrior.step() + self.total_loss = 0. + + def scheduler_step(self): + if self.use_scheduler: + self.schedulerInstance.step() + if self.enable_prior: + self.schedulerPrior.step() + + def zflip_pose(self, pose): + if self.rot_rep == 'lookat': + vec_forward = pose[:,:,6:9] + vec_forward = vec_forward * torch.FloatTensor([1,1,-1]).view(1,1,3).to(vec_forward.device) + up = torch.FloatTensor([0,1,0]).to(pose.device).view(1,1,3) + vec_right = up.expand_as(vec_forward).cross(vec_forward, dim=-1) + vec_right = nn.functional.normalize(vec_right, p=2, dim=-1) + vec_up = vec_forward.cross(vec_right, dim=-1) + vec_up = nn.functional.normalize(vec_up, p=2, dim=-1) + rot_mat = torch.stack([vec_right, vec_up, vec_forward], 2) + rot_pred = rot_mat.reshape(*pose.shape[:-1], -1) + pose_zflip = torch.cat([rot_pred, pose[:,:,9:]], -1) + else: + raise NotImplementedError + return pose_zflip + + def render(self, shape, texture, mvp, w2c, campos, resolution, background='none', im_features=None, light=None, prior_shape=None, render_flow=True, dino_pred=None, render_mode='diffuse', two_sided_shading=True, num_frames=None, spp=1): + h, w = resolution + N = len(mvp) + if background in ['none', 'black']: + bg_image = torch.zeros((N, h, w, 3), device=mvp.device) + elif background == 'white': + bg_image = torch.ones((N, h, w, 3), device=mvp.device) + elif background == 'checkerboard': + bg_image = torch.FloatTensor(util.checkerboard((h, w), 8), device=self.device).repeat(N, 1, 1, 1) # NxHxWxC + else: + raise NotImplementedError + + frame_rendered = render.render_mesh( + self.glctx, + shape, + mtx_in=mvp, + w2c=w2c, + view_pos=campos, + material=texture, + lgt=light, + resolution=resolution, + spp=spp, + msaa=True, + background=bg_image, + bsdf=render_mode, + feat=im_features, + prior_mesh=prior_shape, + two_sided_shading=two_sided_shading, + render_flow=render_flow, + dino_pred=dino_pred, + num_frames=num_frames) + shaded = frame_rendered['shaded'].permute(0, 3, 1, 2) + image_pred = shaded[:, :3, :, :] + mask_pred = shaded[:, 3, :, :] + albedo = frame_rendered['kd'].permute(0, 3, 1, 2)[:, :3, :, :] + if 'shading' in frame_rendered: + shading = frame_rendered['shading'].permute(0, 3, 1, 2)[:, :1, :, :] + else: + shading = None + if render_flow: + flow_pred = frame_rendered['flow'] + flow_pred = flow_pred.permute(0, 3, 1, 2)[:, :2, :, :] + else: + flow_pred = None + if dino_pred is not None: + dino_feat_im_pred = frame_rendered['dino_feat_im_pred'] + dino_feat_im_pred = dino_feat_im_pred.permute(0, 3, 1, 2)[:, :-1] + else: + dino_feat_im_pred = None + + return image_pred, mask_pred, flow_pred, dino_feat_im_pred, albedo, shading + + def compute_reconstruction_losses(self, image_pred, image_gt, mask_pred, mask_gt, mask_dt, mask_valid, flow_pred, flow_gt, dino_feat_im_gt, dino_feat_im_pred, background_mode='none', reduce=False): + losses = {} + batch_size, num_frames, _, h, w = image_pred.shape # BxFxCxHxW + + # image_loss = (image_pred - image_gt) ** 2 + image_loss = (image_pred - image_gt).abs() + + ## silhouette loss + mask_pred_valid = 
mask_pred * mask_valid
+        # mask_pred_valid = mask_pred
+        # losses["silhouette_loss"] = ((mask_pred - mask_gt) ** 2).mean()
+        # mask_loss_mask = (image_loss.mean(2).detach() > 0.05).float()
+        mask_loss = (mask_pred_valid - mask_gt) ** 2
+        # mask_loss = nn.functional.mse_loss(mask_pred, mask_gt)
+        # num_mask_pixels = mask_loss_mask.reshape(batch_size*num_frames, -1).sum(1).clamp(min=1)
+        # losses["silhouette_loss"] = (mask_loss.reshape(batch_size*num_frames, -1).sum(1) / num_mask_pixels).mean()
+        losses['silhouette_loss'] = mask_loss.view(batch_size, num_frames, -1).mean(2)
+        losses['silhouette_dt_loss'] = (mask_pred * mask_dt[:,:,1]).view(batch_size, num_frames, -1).mean(2)
+        losses['silhouette_inv_dt_loss'] = ((1-mask_pred) * mask_dt[:,:,0]).view(batch_size, num_frames, -1).mean(2)
+
+        mask_pred_binary = (mask_pred_valid > 0.).float().detach()
+        mask_both_binary = (mask_pred_binary * mask_gt).view(batch_size*num_frames, 1, *mask_pred.shape[2:])
+        mask_both_binary = (nn.functional.avg_pool2d(mask_both_binary, 3, stride=1, padding=1).view(batch_size, num_frames, *mask_pred.shape[2:]) > 0.99).float().detach() # erode by 1 pixel
+
+        ## reconstruction loss
+        # image_loss_mask = (mask_pred*mask_gt).unsqueeze(2).expand_as(image_gt)
+        # image_loss = image_loss * image_loss_mask
+        # num_mask_pixels = image_loss_mask.reshape(batch_size*num_frames, -1).sum(1).clamp(min=1)
+        # losses["rgb_loss"] = (image_loss.reshape(batch_size*num_frames, -1).sum(1) / num_mask_pixels).mean()
+        if background_mode in ['background', 'input']:
+            pass
+        else:
+            image_loss = image_loss * mask_both_binary.unsqueeze(2)
+        losses['rgb_loss'] = image_loss.reshape(batch_size, num_frames, -1).mean(2)
+
+        if self.cfgs.get('perceptual_loss_weight', 0.) > 0:
+            if background_mode in ['background', 'input']:
+                perc_image_pred = image_pred
+                perc_image_gt = image_gt
+            else:
+                perc_image_pred = image_pred * mask_pred_binary.unsqueeze(2) + 0.5 * (1-mask_pred_binary.unsqueeze(2))
+                perc_image_gt = image_gt * mask_pred_binary.unsqueeze(2) + 0.5 * (1-mask_pred_binary.unsqueeze(2))
+            losses['perceptual_loss'] = self.perceptual_loss(perc_image_pred.view(-1, *image_pred.shape[2:]) *2-1, perc_image_gt.view(-1, *image_gt.shape[2:]) *2-1).view(batch_size, num_frames)
+
+        ## flow loss - between first and second frame
+        if flow_pred is not None:
+            flow_loss = (flow_pred - flow_gt).abs()
+            flow_loss_mask = mask_both_binary[:,:-1].unsqueeze(2).expand_as(flow_gt).detach()
+
+            ## ignore frames where GT flow is too large (likely inaccurate)
+            large_flow = (flow_gt.abs() > 0.5).float() * flow_loss_mask
+            large_flow = (large_flow.view(batch_size, num_frames-1, -1).sum(2) > 0).float()
+            self.large_flow = large_flow
+
+            flow_loss = flow_loss * flow_loss_mask * (1 - large_flow[:,:,None,None,None])
+            num_mask_pixels = flow_loss_mask.reshape(batch_size, num_frames-1, -1).sum(2).clamp(min=1)
+            losses['flow_loss'] = (flow_loss.reshape(batch_size, num_frames-1, -1).sum(2) / num_mask_pixels)
+            # losses["flow_loss"] = flow_loss.mean()
+
+        if dino_feat_im_pred is not None:
+            dino_feat_loss = (dino_feat_im_pred - dino_feat_im_gt) ** 2
+            dino_feat_loss = dino_feat_loss * mask_both_binary.unsqueeze(2)
+            losses['dino_feat_im_loss'] = dino_feat_loss.reshape(batch_size, num_frames, -1).mean(2)
+
+        if reduce:
+            for k, v in losses.items():
+                losses[k] = v.mean()
+        return losses
+
+    def compute_pose_xflip_reg_loss(self, input_image, dino_feat_im, pose_raw, input_image_xflip_flag=None):
+        image_xflip = input_image.flip(4)
+        if dino_feat_im is not None:
+            dino_feat_im_xflip = 
dino_feat_im.flip(4) + else: + dino_feat_im_xflip = None + feat_xflip, _ = self.netInstance.forward_encoder(image_xflip, dino_feat_im_xflip) + batch_size, num_frames = input_image.shape[:2] + pose_xflip_raw = self.netInstance.forward_pose(image_xflip, feat_xflip, dino_feat_im_xflip) + + if input_image_xflip_flag is not None: + pose_xflip_raw_xflip = pose_xflip_raw * torch.FloatTensor([-1,1,1,-1,1,1]).to(pose_raw.device) # forward x, trans x + pose_xflip_raw = pose_xflip_raw * (1 - input_image_xflip_flag.view(batch_size * num_frames, 1)) + pose_xflip_raw_xflip * input_image_xflip_flag.view(batch_size * num_frames, 1) + + rot_rep = self.netInstance.rot_rep + if rot_rep == 'euler_angle' or rot_rep == 'soft_calss': + pose_xflip_xflip = pose_xflip * torch.FloatTensor([1,-1,-1,-1,1,1]).to(pose_xflip.device) # rot y+z, trans x + pose_xflip_reg_loss = ((pose_xflip_xflip - pose) ** 2.).mean() + elif rot_rep == 'quaternion': + rot_euler = pytorch3d.transforms.matrix_to_euler_angles(pytorch3d.transforms.quaternion_to_matrix(pose[...,:4]), convention='XYZ') + pose_euler = torch.cat([rot_euler, pose[...,4:]], -1) + rot_xflip_euler = pytorch3d.transforms.matrix_to_euler_angles(pytorch3d.transforms.quaternion_to_matrix(pose_xflip[...,:4]), convention='XYZ') + pose_xflip_euler = torch.cat([rot_xflip_euler, pose_xflip[...,4:]], -1) + pose_xflip_euler_xflip = pose_xflip_euler * torch.FloatTensor([1,-1,-1,-1,1,1]).to(pose_xflip.device) # rot y+z, trans x + pose_xflip_reg_loss = ((pose_xflip_euler_xflip - pose_euler) ** 2.).mean() + elif rot_rep == 'lookat': + pose_xflip_raw_xflip = pose_xflip_raw * torch.FloatTensor([-1,1,1,-1,1,1]).to(pose_raw.device) # forward x, trans x + pose_xflip_reg_loss = ((pose_xflip_raw_xflip - pose_raw)[...,0] ** 2.) # compute x only + # if epoch >= self.nolookat_zflip_loss_epochs and self.lookat_zflip_no_other_losses: + # pose_xflip_reg_loss = pose_xflip_reg_loss.mean(1) * is_pose_1_better + pose_xflip_reg_loss = pose_xflip_reg_loss.mean() + return pose_xflip_reg_loss, pose_xflip_raw + + def compute_edge_length_reg_loss(self, mesh, prior_mesh): + prior_edge_lengths = get_edge_length(prior_mesh.v_pos, prior_mesh.t_pos_idx) + max_length = prior_edge_lengths.max().detach() *1.1 + edge_lengths = get_edge_length(mesh.v_pos, mesh.t_pos_idx) + mesh_edge_length_loss = ((edge_lengths - max_length).clamp(min=0)**2).mean() + return mesh_edge_length_loss, edge_lengths + + def compute_regularizers(self, mesh, prior_mesh, input_image, dino_feat_im, pose_raw, input_image_xflip_flag=None, arti_params=None, deformation=None): + losses = {} + aux = {} + + if self.enable_prior: + losses.update(self.netPrior.netShape.get_sdf_reg_loss()) + + if self.cfgs.get('pose_xflip_reg_loss_weight', 0.) > 0: + losses["pose_xflip_reg_loss"], aux['pose_xflip_raw'] = self.compute_pose_xflip_reg_loss(input_image, dino_feat_im, pose_raw, input_image_xflip_flag) + + b, f = input_image.shape[:2] + if b >= 2: + vec_forward = pose_raw[..., :3] + losses['pose_entropy_loss'] = (vec_forward[:b//2] * vec_forward[b//2:(b//2)*2]).sum(-1).mean() + else: + losses['pose_entropy_loss'] = 0. 
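# A small illustrative example of the pose entropy term above (hypothetical tensors, not
# part of the model code): the loss is the mean dot product between the forward-vector
# part of pose_raw for the two halves of the batch (unit vectors under the lookat-style
# representations), so with a positive weight, identical viewing directions are penalized
# the most and diverse directions are encouraged.
import torch
vec_a = torch.tensor([[0.0, 0.0, 1.0]])  # forward vectors, first half of the batch
vec_b = torch.tensor([[0.0, 0.0, 1.0]])  # forward vectors, second half of the batch
print((vec_a * vec_b).sum(-1).mean())    # tensor(1.) -> maximal penalty; orthogonal views give 0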
+ + losses['mesh_normal_consistency_loss'] = normal_consistency(mesh.v_pos, mesh.t_pos_idx) + losses['mesh_edge_length_loss'], aux['edge_lengths'] = self.compute_edge_length_reg_loss(mesh, prior_mesh) + if arti_params is not None: + losses['arti_reg_loss'] = (arti_params ** 2).mean() + + if deformation is not None: + losses['deformation_reg_loss'] = (deformation ** 2).mean() + # losses['deformation_reg_loss'] = deformation.abs().mean() + + return losses, aux + + def forward(self, batch, epoch, iter, is_train=True, viz_logger=None, total_iter=None, save_results=False, save_dir=None, which_data='', logger_prefix='', is_training=True): + batch = [x.to(self.device) if x is not None else None for x in batch] + input_image, mask_gt, mask_dt, mask_valid, flow_gt, bbox, bg_image, dino_feat_im, dino_cluster_im, seq_idx, frame_idx = batch + batch_size, num_frames, _, h0, w0 = input_image.shape # BxFxCxHxW + h = w = self.out_image_size + + def collapseF(x): + return None if x is None else x.view(batch_size * num_frames, *x.shape[2:]) + def expandF(x): + return None if x is None else x.view(batch_size, num_frames, *x.shape[1:]) + + if flow_gt.dim() == 2: # dummy tensor for not loading flow + flow_gt = None + if dino_feat_im.dim() == 2: # dummy tensor for not loading dino features + dino_feat_im = None + dino_feat_im_gt = None + else: + dino_feat_im_gt = expandF(torch.nn.functional.interpolate(collapseF(dino_feat_im), size=[h, w], mode="bilinear"))[:, :, :self.dino_feature_recon_dim] + if dino_cluster_im.dim() == 2: # dummy tensor for not loading dino clusters + dino_cluster_im = None + dino_cluster_im_gt = None + else: + dino_cluster_im_gt = expandF(torch.nn.functional.interpolate(collapseF(dino_cluster_im), size=[h, w], mode="nearest")) + + seq_idx = seq_idx.squeeze(1) + # seq_idx = seq_idx * 0 # single sequnce model + frame_id, crop_x0, crop_y0, crop_w, crop_h, full_w, full_h, sharpness = bbox.unbind(2) # BxFx7 + bbox = torch.stack([crop_x0, crop_y0, crop_w, crop_h], 2) + mask_gt = (mask_gt[:, :, 0, :, :] > 0.9).float() # BxFxHxW + mask_dt = mask_dt / self.in_image_size + + if which_data != 'video': + flow_gt = None + + aux_viz = {} + + ## GT + image_gt = input_image + if self.out_image_size != self.in_image_size: + image_gt = expandF(torch.nn.functional.interpolate(collapseF(image_gt), size=[h, w], mode='bilinear')) + if flow_gt is not None: + flow_gt = torch.nn.functional.interpolate(flow_gt.view(batch_size*(num_frames-1), 2, h0, w0), size=[h, w], mode="bilinear").view(batch_size, num_frames-1, 2, h, w) + + self.train_pose_only = False + if epoch in self.pose_epochs: + if (total_iter // self.pose_iters) % 2 == 0: + self.train_pose_only = True + + ## flip input and pose + if epoch in self.pose_xflip_recon_epochs: + input_image_xflip = input_image.flip(-1) + input_image_xflip_flag = torch.randint(0, 2, (batch_size, num_frames), device=input_image.device) + input_image = input_image * (1 - input_image_xflip_flag[:,:,None,None,None]) + input_image_xflip * input_image_xflip_flag[:,:,None,None,None] + else: + input_image_xflip_flag = None + + ## 1st pose hypothesis with original predictions + + # ============================================================================================== + # Predict prior mesh. 
+ # ============================================================================================== + if self.enable_prior: + if epoch < self.dmtet_grid_smaller_epoch: + if self.netPrior.netShape.grid_res != self.dmtet_grid_smaller: + self.netPrior.netShape.load_tets(self.dmtet_grid_smaller) + else: + if self.netPrior.netShape.grid_res != self.dmtet_grid: + self.netPrior.netShape.load_tets(self.dmtet_grid) + + perturb_sdf = self.perturb_sdf if is_train else False + prior_shape, dino_pred = self.netPrior(perturb_sdf=perturb_sdf, total_iter=total_iter, is_training=is_training) + else: + prior_shape = None + raise NotImplementedError + + shape, pose_raw, pose, mvp, w2c, campos, texture, im_features, deformation, arti_params, light, forward_aux = self.netInstance(input_image, prior_shape, epoch, dino_feat_im, dino_cluster_im, total_iter, is_training=is_training) # frame dim collapsed N=(B*F) + rot_logit = forward_aux['rot_logit'] + rot_idx = forward_aux['rot_idx'] + rot_prob = forward_aux['rot_prob'] + aux_viz.update(forward_aux) + + if self.train_pose_only: + safe_detach = lambda x: x.detach() if x is not None else None + prior_shape = safe_detach(prior_shape) + shape = safe_detach(shape) + im_features = safe_detach(im_features) + arti_params = safe_detach(arti_params) + deformation = safe_detach(deformation) + set_requires_grad(texture, False) + set_requires_grad(light, False) + set_requires_grad(dino_pred, False) + else: + set_requires_grad(texture, True) + set_requires_grad(light, True) + set_requires_grad(dino_pred, True) + + render_flow = self.render_flow and num_frames > 1 + image_pred, mask_pred, flow_pred, dino_feat_im_pred, albedo, shading = self.render(shape, texture, mvp, w2c, campos, (h, w), background=self.background_mode, im_features=im_features, light=light, prior_shape=prior_shape, render_flow=render_flow, dino_pred=dino_pred, num_frames=num_frames, spp=self.renderer_spp) + image_pred, mask_pred, flow_pred, dino_feat_im_pred = map(expandF, (image_pred, mask_pred, flow_pred, dino_feat_im_pred)) + if flow_pred is not None: + flow_pred = flow_pred[:, :-1] # Bx(F-1)x2xHxW + + if self.blur_mask: + sigma = max(0.5, 3 * (1 - total_iter / self.blur_mask_iter)) + if sigma > 0.5: + mask_gt = util.blur_image(mask_gt, kernel_size=9, sigma=sigma, mode='gaussian') + # mask_pred = util.blur_image(mask_pred, kernel_size=7, mode='average') + + losses = self.compute_reconstruction_losses(image_pred, image_gt, mask_pred, mask_gt, mask_dt, mask_valid, flow_pred, flow_gt, dino_feat_im_gt, dino_feat_im_pred, background_mode=self.background_mode, reduce=False) + + ## TODO: assume flow loss is not used + logit_loss_target = torch.zeros_like(expandF(rot_logit)) + final_losses = {} + for name, loss in losses.items(): + loss_weight_logit = self.cfgs.get(f"{name}_weight", 0.) + # if (name in ['flow_loss'] and epoch not in self.flow_loss_epochs) or (name in ['rgb_loss', 'perceptual_loss'] and epoch not in self.texture_epochs): + # if name in ['flow_loss', 'rgb_loss', 'perceptual_loss']: + # loss_weight_logit = 0. + if name in ['sdf_bce_reg_loss', 'sdf_gradient_reg_loss', 'sdf_inflate_reg_loss']: + if total_iter >= self.sdf_reg_decay_start_iter: + decay_rate = max(0, 1 - (total_iter-self.sdf_reg_decay_start_iter) / 10000) + loss_weight_logit = max(loss_weight_logit * decay_rate, self.cfgs.get(f"{name}_min_weight", 0.)) + if name in ['dino_feat_im_loss']: + loss_weight_logit = loss_weight_logit * self.cfgs.get("logit_loss_dino_feat_im_loss_multiplier", 1.) 
+ if loss_weight_logit > 0: + logit_loss_target += loss * loss_weight_logit + + if self.netInstance.rot_rep in ['quadlookat', 'octlookat']: + loss = loss * rot_prob.detach().view(batch_size, num_frames)[:, :loss.shape[1]] *self.netInstance.num_pose_hypos + if name == 'flow_loss' and num_frames > 1: + ri = rot_idx.view(batch_size, num_frames) + same_rot_idx = (ri[:, 1:] == ri[:, :-1]).float() + loss = loss * same_rot_idx + final_losses[name] = loss.mean() + final_losses['logit_loss'] = ((expandF(rot_logit) - logit_loss_target.detach())**2.).mean() + + ## regularizers + regularizers, aux = self.compute_regularizers(shape, prior_shape, input_image, dino_feat_im, pose_raw, input_image_xflip_flag, arti_params, deformation) + final_losses.update(regularizers) + aux_viz.update(aux) + + total_loss = 0 + for name, loss in final_losses.items(): + loss_weight = self.cfgs.get(f"{name}_weight", 0.) + if loss_weight <= 0: + continue + + if self.train_pose_only: + if name not in ['silhouette_loss', 'silhouette_dt_loss', 'silhouette_inv_dt_loss', 'flow_loss', 'pose_xflip_reg_loss', 'lookat_zflip_loss', 'dino_feat_im_loss']: + continue + if epoch not in self.flow_loss_epochs: + if name in ['flow_loss']: + continue + if epoch not in self.texture_epochs: + if name in ['rgb_loss', 'perceptual_loss']: + continue + if epoch not in self.lookat_zflip_loss_epochs: + if name in ['lookat_zflip_loss']: + continue + if name in ['mesh_laplacian_smoothing_loss', 'mesh_normal_consistency_loss']: + if total_iter < self.cfgs.get('mesh_reg_start_iter', 0): + continue + if epoch >= self.mesh_reg_decay_epoch: + decay_rate = self.mesh_reg_decay_rate ** (epoch - self.mesh_reg_decay_epoch) + loss_weight = max(loss_weight * decay_rate, self.cfgs.get(f"{name}_min_weight", 0.)) + if epoch not in self.sdf_inflate_reg_loss_epochs: + if name in ['sdf_inflate_reg_loss']: + continue + if epoch not in self.arti_reg_loss_epochs: + if name in ['arti_reg_loss']: + continue + if name in ['sdf_bce_reg_loss', 'sdf_gradient_reg_loss', 'sdf_inflate_reg_loss']: + if total_iter >= self.sdf_reg_decay_start_iter: + decay_rate = max(0, 1 - (total_iter-self.sdf_reg_decay_start_iter) / 10000) + loss_weight = max(loss_weight * decay_rate, self.cfgs.get(f"{name}_min_weight", 0.)) + + total_loss += loss * loss_weight + + self.total_loss += total_loss # reset to 0 in backward step + + if torch.isnan(self.total_loss): + print("NaN in loss...") + import ipdb; ipdb.set_trace() + + final_losses['logit_loss_target'] = logit_loss_target.mean() + + metrics = {'loss': total_loss, **final_losses} + + ## log visuals + if viz_logger is not None: + b0 = max(min(batch_size, 16//num_frames), 1) + viz_logger.add_image(logger_prefix+'image/image_gt', misc.image_grid(image_gt.detach().cpu()[:b0,:].reshape(-1,*input_image.shape[2:]).clamp(0,1)), total_iter) + viz_logger.add_image(logger_prefix+'image/image_pred', misc.image_grid(image_pred.detach().cpu()[:b0,:].reshape(-1,*image_pred.shape[2:]).clamp(0,1)), total_iter) + # viz_logger.add_image(logger_prefix+'image/flow_loss_mask', misc.image_grid(flow_loss_mask[:b0,:,:1].reshape(-1,1,*flow_loss_mask.shape[3:]).repeat(1,3,1,1).clamp(0,1)), total_iter) + viz_logger.add_image(logger_prefix+'image/mask_gt', misc.image_grid(mask_gt.detach().cpu()[:b0,:].reshape(-1,*mask_gt.shape[2:]).unsqueeze(1).repeat(1,3,1,1).clamp(0,1)), total_iter) + viz_logger.add_image(logger_prefix+'image/mask_pred', misc.image_grid(mask_pred.detach().cpu()[:b0,:].reshape(-1,*mask_pred.shape[2:]).unsqueeze(1).repeat(1,3,1,1).clamp(0,1)), total_iter) + + 
if self.render_flow and flow_gt is not None: + flow_gt = flow_gt.detach().cpu() + flow_gt_viz = torch.cat([flow_gt[:b0], torch.zeros_like(flow_gt[:b0,:,:1])], 2) + 0.5 # -0.5~1.5 + flow_gt_viz = torch.nn.functional.pad(flow_gt_viz, pad=[0, 0, 0, 0, 0, 0, 0, 1]) + + ## draw marker on large flow frames + large_flow_marker_mask = torch.zeros_like(flow_gt_viz) + large_flow_marker_mask[:,:,:,:8,:8] = 1. + large_flow = torch.cat([self.large_flow, self.large_flow[:,:1] *0.], 1).detach().cpu()[:b0] + large_flow_marker_mask = large_flow_marker_mask * large_flow[:,:,None,None,None] + red = torch.FloatTensor([1,0,0])[None,None,:,None,None] + flow_gt_viz = large_flow_marker_mask * red + (1-large_flow_marker_mask) * flow_gt_viz + + viz_logger.add_image(logger_prefix+'image/flow_gt', misc.image_grid(flow_gt_viz.reshape(-1,*flow_gt_viz.shape[2:])), total_iter) + + if self.render_flow and flow_pred is not None: + flow_pred = flow_pred.detach().cpu() + flow_pred_viz = torch.cat([flow_pred[:b0], torch.zeros_like(flow_pred[:b0,:,:1])], 2) + 0.5 # -0.5~1.5 + flow_pred_viz = torch.nn.functional.pad(flow_pred_viz, pad=[0, 0, 0, 0, 0, 0, 0, 1]) + viz_logger.add_image(logger_prefix+'image/flow_pred', misc.image_grid(flow_pred_viz.reshape(-1,*flow_pred_viz.shape[2:])), total_iter) + + if light is not None: + param_names = ['dir_x', 'dir_y', 'dir_z', 'int_ambient', 'int_diffuse'] + for name, param in zip(param_names, light.light_params.unbind(-1)): + viz_logger.add_histogram(logger_prefix+'light/'+name, param, total_iter) + viz_logger.add_image( + logger_prefix + f'image/albedo', + misc.image_grid(expandF(albedo)[:b0, ...].view(-1, *albedo.shape[1:])), + total_iter) + viz_logger.add_image( + logger_prefix + f'image/shading', + misc.image_grid(expandF(shading)[:b0, ...].view(-1, *shading.shape[1:]).repeat(1, 3, 1, 1) /2.), + total_iter) + + viz_logger.add_histogram(logger_prefix+'sdf', self.netPrior.netShape.get_sdf(perturb_sdf=False), total_iter) + viz_logger.add_histogram(logger_prefix+'coordinates', shape.v_pos, total_iter) + if arti_params is not None: + viz_logger.add_histogram(logger_prefix+'arti_params', arti_params, total_iter) + viz_logger.add_histogram(logger_prefix+'edge_lengths', aux_viz['edge_lengths'], total_iter) + + if deformation is not None: + viz_logger.add_histogram(logger_prefix+'deformation', deformation, total_iter) + + rot_rep = self.netInstance.rot_rep + if rot_rep == 'euler_angle' or rot_rep == 'soft_calss': + for i, name in enumerate(['rot_x', 'rot_y', 'rot_z', 'trans_x', 'trans_y', 'trans_z']): + viz_logger.add_histogram(logger_prefix+'pose/'+name, pose[...,i], total_iter) + elif rot_rep == 'quaternion': + for i, name in enumerate(['qt_0', 'qt_1', 'qt_2', 'qt_3', 'trans_x', 'trans_y', 'trans_z']): + viz_logger.add_histogram(logger_prefix+'pose/'+name, pose[...,i], total_iter) + rot_euler = pytorch3d.transforms.matrix_to_euler_angles(pytorch3d.transforms.quaternion_to_matrix(pose.detach().cpu()[...,:4]), convention='XYZ') + for i, name in enumerate(['rot_x', 'rot_y', 'rot_z']): + viz_logger.add_histogram(logger_prefix+'pose/'+name, rot_euler[...,i], total_iter) + elif rot_rep in ['lookat', 'quadlookat', 'octlookat']: + for i, name in enumerate(['fwd_x', 'fwd_y', 'fwd_z']): + viz_logger.add_histogram(logger_prefix+'pose/'+name, pose_raw[...,i], total_iter) + for i, name in enumerate(['trans_x', 'trans_y', 'trans_z']): + viz_logger.add_histogram(logger_prefix+'pose/'+name, pose_raw[...,-3+i], total_iter) + + if rot_rep in ['quadlookat', 'octlookat']: + for i, rp in 
enumerate(forward_aux['rots_probs'].unbind(-1)): + viz_logger.add_histogram(logger_prefix+'pose/rot_prob_%d'%i, rp, total_iter) + + if 'pose_xflip_raw' in aux_viz: + pose_xflip_raw = aux_viz['pose_xflip_raw'] + if rot_rep == 'euler_angle' or rot_rep == 'soft_calss': + for i, name in enumerate(['rot_x', 'rot_y', 'rot_z', 'trans_x', 'trans_y', 'trans_z']): + viz_logger.add_histogram(logger_prefix+'pose_xflip/'+name, pose_xflip[...,i], total_iter) + elif rot_rep == 'quaternion': + for i, name in enumerate(['qt_0', 'qt_1', 'qt_2', 'qt_3', 'trans_x', 'trans_y', 'trans_z']): + viz_logger.add_histogram(logger_prefix+'pose_xflip/'+name, pose_xflip[...,i], total_iter) + rot_euler = pytorch3d.transforms.matrix_to_euler_angles(pytorch3d.transforms.quaternion_to_matrix(pose_xflip.detach().cpu()[...,:4]), convention='XYZ') + for i, name in enumerate(['rot_x', 'rot_y', 'rot_z']): + viz_logger.add_histogram(logger_prefix+'pose_xflip/'+name, rot_euler[...,i], total_iter) + elif rot_rep in ['lookat', 'quadlookat', 'octlookat']: + for i, name in enumerate(['fwd_x', 'fwd_y', 'fwd_z']): + viz_logger.add_histogram(logger_prefix+'pose_xflip/'+name, pose_xflip_raw[...,i], total_iter) + for i, name in enumerate(['trans_x', 'trans_y', 'trans_z']): + viz_logger.add_histogram(logger_prefix+'pose_xflip/'+name, pose_xflip_raw[...,-3+i], total_iter) + + if dino_feat_im_gt is not None: + dino_feat_im_gt_first3 = dino_feat_im_gt[:,:,:3] + viz_logger.add_image(logger_prefix+'image/dino_feat_im_gt', misc.image_grid(dino_feat_im_gt_first3.detach().cpu()[:b0,:].reshape(-1,*dino_feat_im_gt_first3.shape[2:]).clamp(0,1)), total_iter) + + if dino_cluster_im_gt is not None: + viz_logger.add_image(logger_prefix+'image/dino_cluster_im_gt', misc.image_grid(dino_cluster_im_gt.detach().cpu()[:b0,:].reshape(-1,*dino_cluster_im_gt.shape[2:]).clamp(0,1)), total_iter) + + if dino_feat_im_pred is not None: + dino_feat_im_pred_first3 = dino_feat_im_pred[:,:,:3] + viz_logger.add_image(logger_prefix+'image/dino_feat_im_pred', misc.image_grid(dino_feat_im_pred_first3.detach().cpu()[:b0,:].reshape(-1,*dino_feat_im_pred_first3.shape[2:]).clamp(0,1)), total_iter) + + for which_shape, modes in self.extra_renders.items(): + # This is wrong + # if which_shape == "prior": + # shape_to_render = prior_shape.extend(im_features.shape[0]) + # needed_im_features = None + if which_shape == "instance": + shape_to_render = shape + needed_im_features = im_features + else: + raise NotImplementedError + + for mode in modes: + rendered, _, _, _, _, _ = self.render(shape_to_render, texture, mvp, w2c, campos, (h, w), background=self.background_mode, im_features=needed_im_features, prior_shape=prior_shape, render_mode=mode, render_flow=False, dino_pred=None) + if 'kd' in mode: + rendered = util.rgb_to_srgb(rendered) + rendered = rendered.detach().cpu() + + if 'posed_bones' in aux_viz: + rendered_bone_image = self.render_bones(mvp, aux_viz['posed_bones'], (h, w)) + rendered_bone_image_mask = (rendered_bone_image < 1).any(1, keepdim=True).float() + # viz_logger.add_image(logger_prefix+'image/articulation_bones', misc.image_grid(self.render_bones(mvp, aux_viz['posed_bones'])), total_iter) + rendered = rendered_bone_image_mask*0.8 * rendered_bone_image + (1-rendered_bone_image_mask*0.8) * rendered + + if rot_rep in ['quadlookat', 'octlookat']: + rand_pose_flag = forward_aux['rand_pose_flag'].detach().cpu() + rand_pose_marker_mask = torch.zeros_like(rendered) + rand_pose_marker_mask[:,:,:16,:16] = 1. 
+ rand_pose_marker_mask = rand_pose_marker_mask * rand_pose_flag[:,None,None,None] + red = torch.FloatTensor([1,0,0])[None,:,None,None] + rendered = rand_pose_marker_mask * red + (1-rand_pose_marker_mask) * rendered + + viz_logger.add_image( + logger_prefix + f'image/{which_shape}_{mode}', + misc.image_grid(expandF(rendered)[:b0, ...].view(-1, *rendered.shape[1:])), + total_iter) + + viz_logger.add_video( + logger_prefix + f'animation/{which_shape}_{mode}', + self.render_rotation_frames(shape_to_render, texture, light, (h, w), background=self.background_mode, im_features=needed_im_features, prior_shape=prior_shape, num_frames=15, render_mode=mode, b=1).detach().cpu().unsqueeze(0), + total_iter, + fps=2) + + viz_logger.add_video( + logger_prefix+'animation/prior_image_rotation', + self.render_rotation_frames(prior_shape, texture, light, (h, w), background=self.background_mode, im_features=im_features, num_frames=15, b=1).detach().cpu().unsqueeze(0).clamp(0,1), + total_iter, + fps=2) + + viz_logger.add_video( + logger_prefix+'animation/prior_normal_rotation', + self.render_rotation_frames(prior_shape, texture, light, (h, w), background=self.background_mode, im_features=im_features, num_frames=15, render_mode='geo_normal', b=1).detach().cpu().unsqueeze(0), + total_iter, + fps=2) + + if save_results: + b0 = self.cfgs.get('num_saved_from_each_batch', batch_size*num_frames) + fnames = [f'{total_iter:07d}_{fid:10d}' for fid in collapseF(frame_id.int())][:b0] + + misc.save_images(save_dir, collapseF(image_gt)[:b0].clamp(0,1).detach().cpu().numpy(), suffix='image_gt', fnames=fnames) + misc.save_images(save_dir, collapseF(image_pred)[:b0].clamp(0,1).detach().cpu().numpy(), suffix='image_pred', fnames=fnames) + misc.save_images(save_dir, collapseF(mask_gt)[:b0].unsqueeze(1).repeat(1,3,1,1).clamp(0,1).detach().cpu().numpy(), suffix='mask_gt', fnames=fnames) + misc.save_images(save_dir, collapseF(mask_pred)[:b0].unsqueeze(1).repeat(1,3,1,1).clamp(0,1).detach().cpu().numpy(), suffix='mask_pred', fnames=fnames) + # tmp_shape = shape.first_n(b0).clone() + # tmp_shape.material = texture + # feat = im_features[:b0] if im_features is not None else None + # misc.save_obj(save_dir, tmp_shape, save_material=False, feat=feat, suffix="mesh", fnames=fnames) # Save the first mesh. 
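+            # Note: only the first b0 (= cfgs 'num_saved_from_each_batch', defaulting to the whole batch)
+            # samples are written to disk above; the commented-out blocks around this point would
+            # additionally dump meshes and flow visualizations for the same frames.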
+ # if self.render_flow and flow_gt is not None: + # flow_gt_viz = torch.cat([flow_gt, torch.zeros_like(flow_gt[:,:,:1])], 2) + 0.5 # -0.5~1.5 + # flow_gt_viz = flow_gt_viz.view(-1, *flow_gt_viz.shape[2:]) + # misc.save_images(save_dir, flow_gt_viz[:b0].clamp(0,1).detach().cpu().numpy(), suffix='flow_gt', fnames=fnames) + # if flow_pred is not None: + # flow_pred_viz = torch.cat([flow_pred, torch.zeros_like(flow_pred[:,:,:1])], 2) + 0.5 # -0.5~1.5 + # flow_pred_viz = flow_pred_viz.view(-1, *flow_pred_viz.shape[2:]) + # misc.save_images(save_dir, flow_pred_viz[:b0].clamp(0,1).detach().cpu().numpy(), suffix='flow_pred', fnames=fnames) + + misc.save_txt(save_dir, pose[:b0].detach().cpu().numpy(), suffix='pose', fnames=fnames) + + return metrics + + def save_scores(self, path): + header = 'mask_mse, \ + mask_iou, \ + image_mse, \ + flow_mse' + mean = self.all_scores.mean(0) + std = self.all_scores.std(0) + header = header + '\nMean: ' + ',\t'.join(['%.8f'%x for x in mean]) + header = header + '\nStd: ' + ',\t'.join(['%.8f'%x for x in std]) + misc.save_scores(path, self.all_scores, header=header) + print(header) + + def render_rotation_frames(self, mesh, texture, light, resolution, background='none', im_features=None, prior_shape=None, num_frames=36, render_mode='diffuse', b=None): + frames = [] + if b is None: + b = len(mesh) + else: + mesh = mesh.first_n(b) + feat = im_features[:b] if im_features is not None else None + + delta_angle = np.pi / num_frames * 2 + delta_rot_matrix = torch.FloatTensor([ + [np.cos(delta_angle), 0, np.sin(delta_angle), 0], + [0, 1, 0, 0], + [-np.sin(delta_angle), 0, np.cos(delta_angle), 0], + [0, 0, 0, 1], + ]).to(self.device).repeat(b, 1, 1) + + w2c = torch.FloatTensor(np.diag([1., 1., 1., 1])) + w2c[:3, 3] = torch.FloatTensor([0, 0, -self.cam_pos_z_offset *1.1]) + w2c = w2c.repeat(b, 1, 1).to(self.device) + proj = util.perspective(self.crop_fov_approx / 180 * np.pi, 1, n=0.1, f=1000.0).repeat(b, 1, 1).to(self.device) + mvp = torch.bmm(proj, w2c) + campos = -w2c[:, :3, 3] + + def rotate_pose(mvp, campos): + mvp = torch.matmul(mvp, delta_rot_matrix) + campos = torch.matmul(delta_rot_matrix[:,:3,:3].transpose(2,1), campos[:,:,None])[:,:,0] + return mvp, campos + + for _ in range(num_frames): + image_pred, _, _, _, _, _ = self.render(mesh, texture, mvp, w2c, campos, resolution, background=background, im_features=feat, light=light, prior_shape=prior_shape, render_flow=False, dino_pred=None, render_mode=render_mode, two_sided_shading=False) + frames += [misc.image_grid(image_pred)] + mvp, campos = rotate_pose(mvp, campos) + return torch.stack(frames, dim=0) # Shape: (T, C, H, W) + + def render_bones(self, mvp, bones_pred, size=(256, 256)): + bone_world4 = torch.concat([bones_pred, torch.ones_like(bones_pred[..., :1]).to(bones_pred.device)], dim=-1) + b, f, num_bones = bone_world4.shape[:3] + bones_clip4 = (bone_world4.view(b, f, num_bones*2, 1, 4) @ mvp.transpose(-1, -2).reshape(b, f, 1, 4, 4)).view(b, f, num_bones, 2, 4) + bones_uv = bones_clip4[..., :2] / bones_clip4[..., 3:4] # b, f, num_bones, 2, 2 + dpi = 32 + fx, fy = size[1] // dpi, size[0] // dpi + + rendered = [] + for b_idx in range(b): + for f_idx in range(f): + frame_bones_uv = bones_uv[b_idx, f_idx].cpu().numpy() + fig = plt.figure(figsize=(fx, fy), dpi=dpi, frameon=False) + ax = plt.Axes(fig, [0., 0., 1., 1.]) + ax.set_axis_off() + for bone in frame_bones_uv: + ax.plot(bone[:, 0], bone[:, 1], marker='o', linewidth=8, markersize=20) + ax.set_xlim(-1, 1) + ax.set_ylim(-1, 1) + ax.invert_yaxis() + # Convert 
to image + fig.add_axes(ax) + fig.canvas.draw_idle() + image = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) + w, h = fig.canvas.get_width_height() + image.resize(h, w, 3) + rendered += [image / 255.] + return torch.from_numpy(np.stack(rendered, 0).transpose(0, 3, 1, 2)) + + def render_deformation_frames(self, mesh, texture, batch_size, num_frames, resolution, background='none', im_features=None, render_mode='diffuse', b=None): + # frames = [] + # if b is None: + # b = batch_size + # im_features = im_features[] + # mesh = mesh.first_n(num_frames * b) + # for i in range(b): + # tmp_mesh = mesh.get_m_to_n(i*num_frames:(i+1)*num_frames) + pass diff --git a/video3d/model_ddp.py b/video3d/model_ddp.py new file mode 100755 index 0000000000000000000000000000000000000000..9286bb4ba4a9a0c08ef5e5a556e2ffd9485b7ecb --- /dev/null +++ b/video3d/model_ddp.py @@ -0,0 +1,3515 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.parallel import DistributedDataParallel as DDP +import torch.nn.init as init + +import torchvision.models as models +import nvdiffrast.torch as dr +import numpy as np +import matplotlib.pyplot as plt +import os +import os.path as osp +import pickle + +from video3d.render.regularizer import get_edge_length, normal_consistency, laplace_regularizer_const +from . import networks +from .renderer import * +from .utils import misc, meters, flow_viz, arap, custom_loss +from .dataloaders import get_sequence_loader, get_image_loader +from .dataloaders_ddp import get_sequence_loader_ddp, get_image_loader_ddp +from .cub_dataloaders import get_cub_loader +from .cub_dataloaders_ddp import get_cub_loader_ddp +from .utils.skinning_v4 import estimate_bones, skinning +import lpips +from einops import rearrange, repeat + +import clip +import torchvision.transforms.functional as tvf +from . 
import discriminator_architecture + +from .geometry.dmtet import DMTetGeometry +from .geometry.dlmesh import DLMesh + +from .triplane_texture.triplane_predictor import TriPlaneTex + +from .render import renderutils as ru +from .render import material +from .render import mlptexture +from .render import util +from .render import mesh +from .render import light +from .render import render + +from .diffusion.sd import StableDiffusion +from .diffusion.vsd import StableDiffusion_VSD +from .diffusion.sd_utils import rand_poses, rand_lights, append_text_direction + +EPS = 1e-7 + + +def get_optimizer(model, lr=0.0001, betas=(0.9, 0.999), weight_decay=0): + return torch.optim.Adam( + filter(lambda p: p.requires_grad, model.parameters()), + lr=lr, betas=betas, weight_decay=weight_decay) + + +def set_requires_grad(model, requires_grad): + if model is not None: + for param in model.parameters(): + param.requires_grad = requires_grad + + +def forward_to_matrix(vec_forward, up=[0,1,0]): + up = torch.FloatTensor(up).to(vec_forward.device) + # vec_forward = nn.functional.normalize(vec_forward, p=2, dim=-1) # x right, y up, z forward + vec_right = up.expand_as(vec_forward).cross(vec_forward, dim=-1) + vec_right = nn.functional.normalize(vec_right, p=2, dim=-1) + vec_up = vec_forward.cross(vec_right, dim=-1) + vec_up = nn.functional.normalize(vec_up, p=2, dim=-1) + rot_mat = torch.stack([vec_right, vec_up, vec_forward], -2) + return rot_mat + + +def sample_pose_hypothesis_from_quad_prediction(poses_raw, total_iter, batch_size, num_frames, pose_xflip_recon=False, input_image_xflip_flag=None, rot_temp_scalar=1., num_hypos=4, naive_probs_iter=2000, best_pose_start_iter=6000, random_sample=True, temp_clip_low = 1., temp_clip_high=100.): + rots_pred = poses_raw[..., :num_hypos*4].view(-1, num_hypos, 4) + rots_logits = rots_pred[..., 0] # Nx4 + # temp = 1 / np.clip(total_iter / 1000 / rot_temp_scalar, 1., 100.) 
+ temp = 1 / np.clip(total_iter / 1000 / rot_temp_scalar, temp_clip_low, temp_clip_high) + + rots_probs = torch.nn.functional.softmax(-rots_logits / temp, dim=1) # N x K + # naive_probs = torch.FloatTensor([10] + [1] * (num_hypos - 1)).to(rots_logits.device) + naive_probs = torch.ones(num_hypos).to(rots_logits.device) + naive_probs = naive_probs / naive_probs.sum() + naive_probs_weight = np.clip(1 - (total_iter - naive_probs_iter) / 2000, 0, 1) + rots_probs = naive_probs.view(1, num_hypos) * naive_probs_weight + rots_probs * (1 - naive_probs_weight) + + rots_pred = rots_pred[..., 1:4] + trans_pred = poses_raw[..., -3:] + best_rot_idx = torch.argmax(rots_probs, dim=1) # N + #print("best_rot_idx", best_rot_idx) + #print("best_of_best", torch.argmax(rots_probs)) + #print("similar 7", torch.zeros_like(best_rot_idx) + 7) + #print("similar 2", torch.zeros_like(best_rot_idx) + torch.argmax(rots_probs)) + + if random_sample: + # rand_rot_idx = torch.randint(0, 4, (batch_size * num_frames,), device=poses_raw.device) # N + rand_rot_idx = torch.randperm(batch_size * num_frames, device=poses_raw.device) % num_hypos # N + # rand_rot_idx = torch.randperm(batch_size, device=poses_raw.device)[:,None].repeat(1, num_frames).view(-1) % 4 # N + best_flag = (torch.randperm(batch_size * num_frames, device=poses_raw.device) / (batch_size * num_frames) < np.clip((total_iter - best_pose_start_iter)/2000, 0, 0.8)).long() + rand_flag = 1 - best_flag + # best_flag = torch.zeros_like(best_rot_idx) + rot_idx = best_rot_idx * best_flag + rand_rot_idx * (1 - best_flag) + else: + rand_flag = torch.zeros_like(best_rot_idx) + #rot_idx = torch.full_like(torch.argmax(rots_probs, dim=1), torch.argmax(rots_probs), device=poses_raw.device) + rot_idx = best_rot_idx + + + + rot_pred = torch.gather(rots_pred, 1, rot_idx[:, None, None].expand(-1, 1, 3))[:, 0] # Nx3 + pose_raw = torch.cat([rot_pred, trans_pred], -1) + rot_prob = torch.gather(rots_probs, 1, rot_idx[:, None].expand(-1, 1))[:, 0] # N + rot_logit = torch.gather(rots_logits, 1, rot_idx[:, None].expand(-1, 1))[:, 0] # N + + if pose_xflip_recon: + raise NotImplementedError + + #up = torch.FloatTensor([0, 1, 0]).to(pose_raw.device) + rot_mat = forward_to_matrix(pose_raw[:, :3], up=[0, 1, 0]) + pose = torch.cat([rot_mat.view(batch_size * num_frames, -1), pose_raw[:, 3:]], -1) + return pose_raw, pose, rot_idx, rot_prob, rot_logit, rots_probs, rand_flag + + +def get_joints_20_bones(bones, aux): + # the bones shape is [1, 1, 20, 2, 3] + body_bones_to_joints = aux['bones_to_joints'] + body_bones = bones[:, :, :len(body_bones_to_joints), :, :] + body_joints = torch.empty(bones.shape[0], bones.shape[1], len(body_bones_to_joints) + 1, 3) + + for i, (a, b) in enumerate(body_bones_to_joints): + body_joints[:, :, a, :] = body_bones[:, :, i, 0, :] + body_joints[:, :, b, :] = body_bones[:, :, i, 1, :] + + leg_aux = aux['legs'] + all_leg_joints = [] + for i in range(len(leg_aux)): + leg_bones = bones[:, :, 8+i*3:11+i*3, :, :] + leg_joints = torch.empty(bones.shape[0], bones.shape[1], len(leg_aux[i]['leg_bones_to_joints']), 3) + + for j in range(len(leg_aux[i]['leg_bones_to_joints'])-1): + leg_joint_idx_a = leg_aux[i]['leg_bones_to_joints'][j][0] + leg_joint_idx_b = leg_aux[i]['leg_bones_to_joints'][j][1] + + leg_joints[:, :, leg_joint_idx_a, :] = leg_bones[:, :, j, 0, :] + leg_joints[:, :, leg_joint_idx_b, :] = leg_bones[:, :, j, 1, :] + + all_leg_joints.append(leg_joints) + + all_joints = [body_joints] + all_leg_joints + all_joints = torch.cat(all_joints, dim=2) + return all_joints + + 
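+# Illustrative note (annotation, not part of the original code): the helpers above/below convert
+# between the two skeleton layouts used in this file. For the default skeleton with 8 body bones
+# and 4 legs x 3 bones (20 bones, 21 joints), roughly:
+#   joints = get_joints_20_bones(bones, aux)        # bones [1, 1, 20, 2, 3] -> joints [1, 1, 21, 3]
+#   bones_back = get_20_bones_joints(joints, aux)   # joints -> bones, re-attaching each leg to the body
+# where 'aux' is expected to carry the kinematic-chain info ('bones_to_joints', 'legs', ...)
+# produced by estimate_bones elsewhere in this file.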
+def get_20_bones_joints(joints, aux): + # the joints shape is [1, 1, 21, 3] + body_bones_to_joints = aux['bones_to_joints'] + body_bones = [] + for a,b in body_bones_to_joints: + body_bones += [torch.stack([joints[:, :, a, :], joints[:, :, b, :]], dim=2)] + body_bones = torch.stack(body_bones, dim=2) # [1, 1, 8, 2, 3] + + legs_bones = [] + legs_aux = aux['legs'] + for i in range(len(legs_aux)): + leg_aux = legs_aux[i] + leg_bones = [] + + leg_bones_to_joints = leg_aux['leg_bones_to_joints'] + for j in range(len(leg_bones_to_joints)-1): + leg_bones += [torch.stack([joints[:, :, 9+i*3+leg_bones_to_joints[j][0], :], joints[:, :, 9+i*3+leg_bones_to_joints[j][1], :]], dim=2)] + # the last bone is attached to the body + leg_bones += [torch.stack([ + body_bones[:, :, leg_aux['body_bone_idx'], 1, :], joints[:, :, 9+i*3+leg_bones_to_joints[-1][1], :] + ], dim=2)] + + leg_bones = torch.stack(leg_bones, dim=2) + legs_bones.append(leg_bones) + + bones = torch.cat([body_bones] + legs_bones, dim=2) + return bones + + +class FixedDirectionLight(torch.nn.Module): + def __init__(self, direction, amb, diff): + super(FixedDirectionLight, self).__init__() + self.light_dir = direction + self.amb = amb + self.diff = diff + self.is_hacking = not (isinstance(self.amb, float) + or isinstance(self.amb, int)) + + def forward(self, feat): + batch_size = feat.shape[0] + if self.is_hacking: + return torch.concat([self.light_dir, self.amb, self.diff], -1) + else: + return torch.concat([self.light_dir, torch.FloatTensor([self.amb, self.diff]).to(self.light_dir.device)], -1).expand(batch_size, -1) + + def shade(self, feat, kd, normal): + light_params = self.forward(feat) + light_dir = light_params[..., :3][:, None, None, :] + int_amb = light_params[..., 3:4][:, None, None, :] + int_diff = light_params[..., 4:5][:, None, None, :] + shading = (int_amb + int_diff * + torch.clamp(util.dot(light_dir, normal), min=0.0)) + shaded = shading * kd + return shaded, shading + + +class SmoothLoss(nn.Module): + def __init__(self, dim=0, smooth_type=None, loss_type="l2"): + super(SmoothLoss, self).__init__() + self.dim = dim + + supported_smooth_types = ['mid_frame', 'dislocation', 'avg'] + assert smooth_type in supported_smooth_types, f"supported smooth type: {supported_smooth_types}" + self.smooth_type = smooth_type + + supported_loss_types = ['l2', 'mse', 'l1'] + assert loss_type in supported_loss_types, f"supported loss type: {supported_loss_types}" + self.loss_type = loss_type + + if self.loss_type in ['l2', 'mse']: + self.loss_fn = torch.nn.MSELoss(reduction='mean') + elif self.loss_type in ['l1']: + self.loss_fn = torch.nn.L1Loss() + else: + raise NotImplementedError + + def mid_frame_smooth(self, inputs): + nframe = inputs.shape[self.dim] + mid_num = (nframe-1) // 2 + # from IPython import embed; embed(); + mid_frame = torch.index_select(inputs, self.dim, torch.tensor([mid_num], device=inputs.device)) + repeat_num = self.get_repeat_num(inputs) + smooth = mid_frame.repeat(repeat_num) + loss = self.loss_fn(inputs, smooth) + # print(loss) + return loss + + def dislocation_smooth(self, inputs): + # from IPython import embed; embed() + nframe = inputs.shape[self.dim] + t = torch.index_select(inputs, self.dim, torch.arange(0, nframe-1).to(inputs.device)) + t_1 = torch.index_select(inputs, self.dim, torch.arange(1, nframe).to(inputs.device)) + loss = self.loss_fn(t, t_1) + return loss + + def avg_smooth(self, inputs): + # nframe = inputs.shape[self.dim] + # from IPython import embed; embed() + avg = inputs.mean(dim=self.dim, 
keepdim=True) + repeat_num = self.get_repeat_num(inputs) + smooth = avg.repeat(repeat_num) + loss = self.loss_fn(inputs, smooth) + return loss + + def get_repeat_num(self, inputs): + repeat_num = [1] * inputs.dim() + repeat_num[self.dim] = inputs.shape[self.dim] + return repeat_num + + def forward(self, inputs): + print(f"smooth_type: {self.smooth_type}") + if self.smooth_type is None: + return 0. + elif self.smooth_type == 'mid_frame': + return self.mid_frame_smooth(inputs) + elif self.smooth_type == 'dislocation': + return self.dislocation_smooth(inputs) + elif self.smooth_type == 'avg': + return self.avg_smooth(inputs) + else: + raise NotImplementedError() + + +class PriorPredictor(nn.Module): + def __init__(self, cfgs): + super().__init__() + + #add nnParameters + dmtet_grid = cfgs.get('dmtet_grid', 64) + grid_scale = cfgs.get('grid_scale', 5) + prior_sdf_mode = cfgs.get('prior_sdf_mode', 'mlp') + num_layers_shape = cfgs.get('num_layers_shape', 5) + hidden_size = cfgs.get('hidden_size', 64) + embedder_freq_shape = cfgs.get('embedder_freq_shape', 8) + embed_concat_pts = cfgs.get('embed_concat_pts', True) + init_sdf = cfgs.get('init_sdf', None) + jitter_grid = cfgs.get('jitter_grid', 0.) + perturb_sdf_iter = cfgs.get('perturb_sdf_iter', 10000) + sym_prior_shape = cfgs.get('sym_prior_shape', False) + train_data_dir = cfgs.get("train_data_dir", None) + if isinstance(train_data_dir, str): + num_of_classes = 1 + elif isinstance(train_data_dir, dict): + self.category_id_map = {} + num_of_classes = len(train_data_dir) + for i, (k, _) in enumerate(train_data_dir.items()): + self.category_id_map[k] = i + dim_of_classes = cfgs.get('dim_of_classes', 256) if num_of_classes > 1 else 0 + condition_choice = cfgs.get('prior_condition_choice', 'concat') + self.netShape = DMTetGeometry(dmtet_grid, grid_scale, prior_sdf_mode, num_layers=num_layers_shape, hidden_size=hidden_size, embedder_freq=embedder_freq_shape, embed_concat_pts=embed_concat_pts, init_sdf=init_sdf, jitter_grid=jitter_grid, perturb_sdf_iter=perturb_sdf_iter, sym_prior_shape=sym_prior_shape, + dim_of_classes=dim_of_classes, condition_choice=condition_choice) + + mlp_hidden_size = cfgs.get('hidden_size', 64) + tet_bbox = self.netShape.getAABB() + self.render_dino_mode = cfgs.get('render_dino_mode', None) + num_layers_dino = cfgs.get("num_layers_dino", 5) + dino_feature_recon_dim = cfgs.get('dino_feature_recon_dim', 64) + + sym_dino = cfgs.get("sym_dino", False) + dino_min = torch.zeros(dino_feature_recon_dim) + cfgs.get('dino_min', 0.) + dino_max = torch.zeros(dino_feature_recon_dim) + cfgs.get('dino_max', 1.) + min_max = torch.stack((dino_min, dino_max), dim=0) + if self.render_dino_mode is None: + pass + elif self.render_dino_mode == 'feature_mlpnv': + #MLPTexture3D predict the dino for each single point. 
+ self.netDINO = mlptexture.MLPTexture3D(tet_bbox, channels=dino_feature_recon_dim, internal_dims=mlp_hidden_size, hidden=num_layers_dino-1, feat_dim=0, min_max=min_max, bsdf=None, perturb_normal=False, symmetrize=sym_dino) + elif self.render_dino_mode == 'feature_mlp': + embedder_scaler = 2 * np.pi / grid_scale * 0.9 # originally (-0.5*s, 0.5*s) rescale to (-pi, pi) * 0.9 + embed_concat_pts = cfgs.get('embed_concat_pts', True) + self.netDINO = networks.MLPTextureSimple( + 3, # x, y, z coordinates + dino_feature_recon_dim, + num_layers_dino, + nf=mlp_hidden_size, + dropout=0, + activation="sigmoid", + min_max=min_max, + n_harmonic_functions=cfgs.get('embedder_freq_dino', 8), + omega0=embedder_scaler, + extra_dim=dim_of_classes, + embed_concat_pts=embed_concat_pts, + perturb_normal=False, + symmetrize=sym_dino + ) + elif self.render_dino_mode == 'cluster': + num_layers_dino = cfgs.get("num_layers_dino", 5) + dino_cluster_dim = cfgs.get('dino_cluster_dim', 64) + self.netDINO = mlptexture.MLPTexture3D(tet_bbox, channels=dino_cluster_dim, internal_dims=mlp_hidden_size, hidden=num_layers_dino-1, feat_dim=0, min_max=None, bsdf=None, perturb_normal=False, symmetrize=sym_dino) + else: + raise NotImplementedError + + self.classes_vectors = None + if num_of_classes > 1: + self.classes_vectors = torch.nn.Parameter(torch.nn.init.uniform_(torch.empty(num_of_classes, dim_of_classes), a=-0.05, b=0.05)) + + def forward(self, category_name=None, perturb_sdf=False, total_iter=None, is_training=True, class_embedding=None): + class_vector = None + if category_name is not None: + # print(category_name) + if class_embedding is not None: + class_vector = class_embedding[0] # [128] + return_classes_vectors = class_vector + else: + class_vector = self.classes_vectors[self.category_id_map[category_name]] + return_classes_vectors = self.classes_vectors + prior_shape = self.netShape.getMesh(perturb_sdf=perturb_sdf, total_iter=total_iter, jitter_grid=is_training, class_vector=class_vector) + # print(prior_shape.v_pos.shape) + # return prior_shape, self.netDINO, self.classes_vectors + return prior_shape, self.netDINO, return_classes_vectors + + +class InstancePredictor(nn.Module): + def __init__(self, cfgs, tet_bbox=None): + super().__init__() + self.cfgs = cfgs + self.grid_scale = cfgs.get('grid_scale', 5) + + self.enable_encoder = cfgs.get('enable_encoder', False) + if self.enable_encoder: + encoder_latent_dim = cfgs.get('latent_dim', 256) + encoder_pretrained = cfgs.get('encoder_pretrained', False) + encoder_frozen = cfgs.get('encoder_frozen', False) + encoder_arch = cfgs.get('encoder_arch', 'simple') + in_image_size = cfgs.get('in_image_size', 256) + self.dino_feature_input = cfgs.get('dino_feature_input', False) + dino_feature_dim = cfgs.get('dino_feature_dim', 64) + if encoder_arch == 'simple': + if self.dino_feature_input: + self.netEncoder = networks.EncoderWithDINO(cin_rgb=3, cin_dino=dino_feature_dim, cout=encoder_latent_dim, in_size=in_image_size, zdim=None, nf=64, activation=None) + else: + self.netEncoder = networks.Encoder(cin=3, cout=encoder_latent_dim, in_size=in_image_size, zdim=None, nf=64, activation=None) + elif encoder_arch == 'vgg': + self.netEncoder = networks.VGGEncoder(cout=encoder_latent_dim, pretrained=encoder_pretrained) + elif encoder_arch == 'resnet': + self.netEncoder = networks.ResnetEncoder(cout=encoder_latent_dim, pretrained=encoder_pretrained) + elif encoder_arch == 'vit': + which_vit = cfgs.get('which_vit', 'dino_vits8') + vit_final_layer_type = cfgs.get('vit_final_layer_type', 
'conv') + root_dir = cfgs.get('root_dir', '/root') + self.netEncoder = networks.ViTEncoder(cout=encoder_latent_dim, which_vit=which_vit, pretrained=encoder_pretrained, frozen=encoder_frozen, in_size=in_image_size, final_layer_type=vit_final_layer_type, root=root_dir) + else: + raise NotImplementedError + else: + encoder_latent_dim = 0 + + mlp_hidden_size = cfgs.get('hidden_size', 64) + + bsdf = cfgs.get("bsdf", 'diffuse') + num_layers_tex = cfgs.get("num_layers_tex", 5) + feat_dim = cfgs.get("latent_dim", 64) if self.enable_encoder else 0 + perturb_normal = cfgs.get("perturb_normal", False) + sym_texture = cfgs.get("sym_texture", False) + kd_min = torch.FloatTensor(cfgs.get('kd_min', [0., 0., 0., 0.])) + kd_max = torch.FloatTensor(cfgs.get('kd_max', [1., 1., 1., 1.])) + ks_min = torch.FloatTensor(cfgs.get('ks_min', [0., 0., 0.])) + ks_max = torch.FloatTensor(cfgs.get('ks_max', [0., 0., 0.])) + nrm_min = torch.FloatTensor(cfgs.get('nrm_min', [-1., -1., 0.])) + nrm_max = torch.FloatTensor(cfgs.get('nrm_max', [1., 1., 1.])) + mlp_min = torch.cat((kd_min[0:3], ks_min, nrm_min), dim=0) + mlp_max = torch.cat((kd_max[0:3], ks_max, nrm_max), dim=0) + min_max = torch.stack((mlp_min, mlp_max), dim=0) + out_chn = 9 + # TODO: if the tet verts are deforming, we need to recompute tet_bbox + texture_mode = cfgs.get("texture_mode", 'mlp') + if texture_mode == 'mlpnv': + self.netTexture = mlptexture.MLPTexture3D(tet_bbox, channels=out_chn, internal_dims=mlp_hidden_size, hidden=num_layers_tex-1, feat_dim=feat_dim, min_max=min_max, bsdf=bsdf, perturb_normal=perturb_normal, symmetrize=sym_texture) + elif texture_mode == 'mlp': + embedder_scaler = 2 * np.pi / self.grid_scale * 0.9 # originally (-0.5*s, 0.5*s) rescale to (-pi, pi) * 0.9 + embed_concat_pts = cfgs.get('embed_concat_pts', True) + + self.texture_way = cfgs.get('texture_way', None) + + if self.texture_way is None: + texture_act = cfgs.get('texture_act', 'relu') + texture_bias = cfgs.get('texture_bias', False) + self.netTexture = networks.MLPTextureSimple( + 3, # x, y, z coordinates + out_chn, + num_layers_tex, + nf=mlp_hidden_size, + dropout=0, + activation="sigmoid", + min_max=min_max, + n_harmonic_functions=cfgs.get('embedder_freq_tex', 10), + omega0=embedder_scaler, + extra_dim=feat_dim, + embed_concat_pts=embed_concat_pts, + perturb_normal=perturb_normal, + symmetrize=sym_texture, + texture_act=texture_act, + linear_bias=texture_bias + ) + else: + self.netTexture = networks.MLPTextureTriplane( + 3, # x, y, z coordinates + out_chn, + num_layers_tex, + nf=mlp_hidden_size, + dropout=0, + activation="sigmoid", + min_max=min_max, + n_harmonic_functions=cfgs.get('embedder_freq_tex', 10), + omega0=embedder_scaler, + extra_dim=feat_dim, + embed_concat_pts=embed_concat_pts, + perturb_normal=perturb_normal, + symmetrize=sym_texture, + texture_act='relu', + linear_bias=False, + cam_pos_z_offset=cfgs.get('cam_pos_z_offset', 10.), + grid_scale=self.grid_scale + ) + # if 'lift' in self.texture_way: + # # GET3D use global feature to get a tri-plane + # self.netTexture = TriPlaneTex( + # w_dim=512, + # img_channels=out_chn, + # tri_plane_resolution=256, + # device=cfgs.get('device', 'cpu'), + # mlp_latent_channel=32, + # n_implicit_layer=1, + # feat_dim=256, + # n_mapping_layer=8, + # sym_texture=sym_texture, + # grid_scale=self.grid_scale, + # min_max=min_max, + # perturb_normal=perturb_normal + # ) + + # # # project the local feature map into a grid + # # self.netTexture = networks.LiftTexture( + # # 3, # x, y, z coordinates + # # out_chn, + # # 
num_layers_tex, + # # nf=mlp_hidden_size, + # # dropout=0, + # # activation="sigmoid", + # # min_max=min_max, + # # n_harmonic_functions=cfgs.get('embedder_freq_tex', 10), + # # omega0=embedder_scaler, + # # extra_dim=feat_dim, + # # embed_concat_pts=embed_concat_pts, + # # perturb_normal=perturb_normal, + # # symmetrize=sym_texture, + # # texture_way=self.texture_way, + # # cam_pos_z_offset=cfgs.get('cam_pos_z_offset', 10.), + # # grid_scale=self.grid_scale, + # # local_feat_dim=cfgs.get("lift_local_feat_dim", 128), + # # grid_size=cfgs.get("lift_grid_size", 32), + # # optim_latent=cfgs.get("lift_optim_latent", False) + # # ) + # else: + # # a texture mlp with local feature map from patch_out + # self.netTexture = networks.MLPTextureLocal( + # 3, # x, y, z coordinates + # out_chn, + # num_layers_tex, + # nf=mlp_hidden_size, + # dropout=0, + # activation="sigmoid", + # min_max=min_max, + # n_harmonic_functions=cfgs.get('embedder_freq_tex', 10), + # omega0=embedder_scaler, + # extra_dim=feat_dim, + # embed_concat_pts=embed_concat_pts, + # perturb_normal=perturb_normal, + # symmetrize=sym_texture, + # texture_way=self.texture_way, + # larger_tex_dim=cfgs.get('larger_tex_dim', False), + # cam_pos_z_offset=cfgs.get('cam_pos_z_offset', 10.), + # grid_scale=self.grid_scale + # ) + + self.rot_rep = cfgs.get('rot_rep', 'euler_angle') + self.enable_pose = cfgs.get('enable_pose', False) + if self.enable_pose: + cam_pos_z_offset = cfgs.get('cam_pos_z_offset', 10.) + fov = cfgs.get('crop_fov_approx', 25) + half_range = np.tan(fov /2 /180 * np.pi) * cam_pos_z_offset # 2.22 + self.max_trans_xy_range = half_range * cfgs.get('max_trans_xy_range_ratio', 1.) + self.max_trans_z_range = half_range * cfgs.get('max_trans_z_range_ratio', 1.) + self.lookat_init = cfgs.get('lookat_init', None) + self.lookat_zeroy = cfgs.get('lookat_zeroy', False) + self.rot_temp_scalar = cfgs.get('rot_temp_scalar', 1.) 
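+            # rot_temp_scalar (above) and the two iteration thresholds below are consumed by
+            # sample_pose_hypothesis_from_quad_prediction(): rot_temp_scalar rescales the softmax
+            # temperature over the per-hypothesis logits, the hypothesis distribution is blended with
+            # a uniform prior until ~naive_probs_iter + 2k iterations, and from best_pose_start_iter
+            # onwards an increasing fraction of samples (up to 0.8 after a further 2k iterations)
+            # take the most likely hypothesis instead of a random one.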
+ self.naive_probs_iter = cfgs.get('naive_probs_iter', 2000) + self.best_pose_start_iter = cfgs.get('best_pose_start_iter', 6000) + + if self.rot_rep == 'euler_angle': + pose_cout = 6 + elif self.rot_rep == 'quaternion': + pose_cout = 7 + elif self.rot_rep == 'lookat': + pose_cout = 6 + elif self.rot_rep == 'quadlookat': + self.num_pose_hypos = 4 + pose_cout = (3 + 1) * self.num_pose_hypos + 3 # 4 forward vectors for 4 quadrants, 4 quadrant classification logits, 3 for translation + self.orthant_signs = torch.FloatTensor([[1,1,1], [-1,1,1], [-1,1,-1], [1,1,-1]]) + elif self.rot_rep == 'octlookat': + self.num_pose_hypos = 8 + pose_cout = (3 + 1) * self.num_pose_hypos + 3 # 4 forward vectors for 8 octants, 8 octant classification logits, 3 for translation + self.orthant_signs = torch.stack(torch.meshgrid([torch.arange(1, -2, -2)] *3), -1).view(-1, 3) # 8x3 + else: + raise NotImplementedError + + self.pose_arch = cfgs.get('pose_arch', 'mlp') + if self.pose_arch == 'mlp': + num_layers_pose = cfgs.get('num_layers_pose', 5) + self.netPose = networks.MLP( + encoder_latent_dim, + pose_cout, + num_layers_pose, + nf=mlp_hidden_size, + dropout=0, + activation=None + ) + elif self.pose_arch == 'encoder': + if self.dino_feature_input: + dino_feature_dim = cfgs.get('dino_feature_dim', 64) + self.netPose = networks.EncoderWithDINO(cin_rgb=3, cin_dino=dino_feature_dim, cout=pose_cout, in_size=in_image_size, zdim=None, nf=64, activation=None) + else: + self.netPose = networks.Encoder(cin=3, cout=pose_cout, in_size=in_image_size, zdim=None, nf=64, activation=None) + elif self.pose_arch in ['encoder_dino_patch_out', 'encoder_dino_patch_key']: + if which_vit == 'dino_vits8': + dino_feat_dim = 384 + elif which_vit == 'dinov2_vits14': + dino_feat_dim = 384 + elif which_vit == 'dino_vitb8': + dino_feat_dim = 768 + self.netPose = networks.Encoder32(cin=dino_feat_dim, cout=pose_cout, nf=256, activation=None) + elif self.pose_arch == 'vit': + encoder_pretrained = cfgs.get('encoder_pretrained', False) + encoder_frozen = cfgs.get('encoder_frozen', False) + which_vit = cfgs.get('which_vit', 'dino_vits8') + vit_final_layer_type = cfgs.get('vit_final_layer_type', 'conv') + root_dir = cfgs.get('root_dir', '/root') + self.netPose = networks.ViTEncoder(cout=encoder_latent_dim, which_vit=which_vit, pretrained=encoder_pretrained, frozen=encoder_frozen, in_size=in_image_size, final_layer_type=vit_final_layer_type, root=root_dir) + else: + raise NotImplementedError + + self.enable_deform = cfgs.get('enable_deform', False) + if self.enable_deform: + embedder_scaler = 2 * np.pi / self.grid_scale * 0.9 # originally (-0.5*s, 0.5*s) rescale to (-pi, pi) * 0.9 + embed_concat_pts = cfgs.get('embed_concat_pts', True) + num_layers_deform = cfgs.get('num_layers_deform', 5) + self.deform_epochs = np.arange(*cfgs.get('deform_epochs', [0, 0])) + sym_deform = cfgs.get("sym_deform", False) + self.netDeform = networks.MLPWithPositionalEncoding( + 3, # x, y, z coordinates + 3, # dx, dy, dz deformation + num_layers_deform, + nf=mlp_hidden_size, + dropout=0, + activation=None, + n_harmonic_functions=cfgs.get('embedder_freq_deform', 10), + omega0=embedder_scaler, + extra_dim=encoder_latent_dim, + embed_concat_pts=embed_concat_pts, + symmetrize=sym_deform + ) + # self.avg_deform = cfgs.get('avg_deform', False) + # print(f'********avg_deform: {self.avg_deform}********') + + self.enable_articulation = cfgs.get('enable_articulation', False) + if self.enable_articulation: + self.num_body_bones = cfgs.get('num_body_bones', 4) + 
self.articulation_multiplier = cfgs.get('articulation_multiplier', 1) + self.static_root_bones = cfgs.get('static_root_bones', False) + self.skinning_temperature = cfgs.get('skinning_temperature', 1) + self.articulation_epochs = np.arange(*cfgs.get('articulation_epochs', [0, 0])) + self.num_legs = cfgs.get('num_legs', 0) + self.num_leg_bones = cfgs.get('num_leg_bones', 0) + self.body_bones_type = cfgs.get('body_bones_type', 'z_minmax') + self.perturb_articulation_epochs = np.arange(*cfgs.get('perturb_articulation_epochs', [0, 0])) + self.num_bones = self.num_body_bones + self.num_legs * self.num_leg_bones + self.constrain_legs = cfgs.get('constrain_legs', False) + self.attach_legs_to_body_epochs = np.arange(*cfgs.get('attach_legs_to_body_epochs', [0, 0])) + self.max_arti_angle = cfgs.get('max_arti_angle', 60) + + num_layers_arti = cfgs.get('num_layers_arti', 5) + which_vit = cfgs.get('which_vit', 'dino_vits8') + if which_vit == 'dino_vits8': + dino_feat_dim = 384 + elif which_vit == 'dino_vitb8': + dino_feat_dim = 768 + self.articulation_arch = cfgs.get('articulation_arch', 'mlp') + self.articulation_feature_mode = cfgs.get('articulation_feature_mode', 'sample') + embedder_freq_arti = cfgs.get('embedder_freq_arti', 8) + if self.articulation_feature_mode == 'global': + feat_dim = encoder_latent_dim + elif self.articulation_feature_mode == 'sample': + feat_dim = dino_feat_dim + elif self.articulation_feature_mode == 'sample+global': + feat_dim = encoder_latent_dim + dino_feat_dim + if self.articulation_feature_mode == 'attention': + arti_feat_attn_zdim = cfgs.get('arti_feat_attn_zdim', 128) + pos_dim = 1 + 2 + 3*2 + self.netFeatureAttn = networks.FeatureAttention(which_vit, pos_dim, embedder_freq_arti, arti_feat_attn_zdim, img_size=in_image_size) + embedder_scaler = np.pi * 0.9 # originally (-1, 1) rescale to (-pi, pi) * 0.9 + enable_articulation_idadd = cfgs.get('enable_articulation_idadd', False) + self.netArticulation = networks.ArticulationNetwork(self.articulation_arch, feat_dim, 1+2+3*2, num_layers_arti, mlp_hidden_size, n_harmonic_functions=embedder_freq_arti, omega0=embedder_scaler, + enable_articulation_idadd=enable_articulation_idadd) + self.kinematic_tree_epoch = -1 + + self.enable_lighting = cfgs.get('enable_lighting', False) + if self.enable_lighting: + num_layers_light = cfgs.get('num_layers_light', 5) + amb_diff_min = torch.FloatTensor(cfgs.get('amb_diff_min', [0., 0.])) + amb_diff_max = torch.FloatTensor(cfgs.get('amb_diff_max', [1., 1.])) + intensity_min_max = torch.stack((amb_diff_min, amb_diff_max), dim=0) + self.netLight = light.DirectionalLight(encoder_latent_dim, num_layers_light, mlp_hidden_size, intensity_min_max=intensity_min_max) + + self.cam_pos_z_offset = cfgs.get('cam_pos_z_offset', 10.) + self.crop_fov_approx = cfgs.get("crop_fov_approx", 25) + + self.temp_clip_low = cfgs.get('temp_clip_low', 1.) + self.temp_clip_high = cfgs.get('temp_clip_high', 100.) 
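+        # Illustrative schedule (annotation, not from the original code): with rot_temp_scalar = 1 and
+        # the default clips (1, 100), the hypothesis-softmax temperature used above is
+        #   temp = 1 / np.clip(total_iter / 1000, 1, 100)
+        # i.e. temp stays at 1.0 for the first 1k iterations, then decays as 1000 / total_iter and
+        # bottoms out at 0.01 after 100k iterations, progressively sharpening the pose distribution.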
+ + # if the articulation and deformation is set as iterations, then use iteration to decide, not epoch + self.iter_articulation_start = cfgs.get('iter_articulation_start', None) + self.iter_deformation_start = cfgs.get('iter_deformation_start', None) + + self.iter_nozeroy_start = cfgs.get('iter_nozeroy_start', None) + self.iter_attach_leg_to_body_start = cfgs.get('iter_attach_leg_to_body_start', None) + + def forward_encoder(self, images, dino_features=None): + images_in = images.view(-1, *images.shape[2:]) * 2 - 1 # rescale to (-1, 1) + patch_out = patch_key = None + if self.dino_feature_input and self.cfgs.get('encoder_arch', 'simple') != 'vit': + dino_features_in = dino_features.view(-1, *dino_features.shape[2:]) * 2 - 1 # rescale to (-1, 1) + feat_out = self.netEncoder(images_in, dino_features_in) # Shape: (B, latent_dim) + elif self.cfgs.get('encoder_arch', 'simple') == 'vit': + feat_out, feat_key, patch_out, patch_key = self.netEncoder(images_in, return_patches=True) + else: + feat_out = self.netEncoder(images_in) # Shape: (B, latent_dim) + return feat_out, feat_key, patch_out, patch_key + + + def forward_pose(self, images, feat, patch_out, patch_key, dino_features): + if self.pose_arch == 'mlp': + pose = self.netPose(feat) + elif self.pose_arch == 'encoder': + images_in = images.view(-1, *images.shape[2:]) * 2 - 1 # rescale to (-1, 1) + if self.dino_feature_input: + dino_features_in = dino_features.view(-1, *dino_features.shape[2:]) * 2 - 1 # rescale to (-1, 1) + pose = self.netPose(images_in, dino_features_in) # Shape: (B, latent_dim) + else: + pose = self.netPose(images_in) # Shape: (B, latent_dim) + elif self.pose_arch == 'vit': + images_in = images.view(-1, *images.shape[2:]) * 2 - 1 # rescale to (-1, 1) + pose = self.netPose(images_in) + elif self.pose_arch == 'encoder_dino_patch_out': + pose = self.netPose(patch_out) # Shape: (B, latent_dim) + elif self.pose_arch == 'encoder_dino_patch_key': + pose = self.netPose(patch_key) # Shape: (B, latent_dim) + else: + raise NotImplementedError + trans_pred = pose[...,-3:].tanh() * torch.FloatTensor([self.max_trans_xy_range, self.max_trans_xy_range, self.max_trans_z_range]).to(pose.device) + if self.rot_rep == 'euler_angle': + multiplier = 1. 
+ if self.gradually_expand_yaw: + # multiplier += (min(iteration, 20000) // 500) * 0.25 + multiplier *= 1.2 ** (min(iteration, 20000) // 500) # 1.125^40 = 111.200 + rot_pred = torch.cat([pose[...,:1], pose[...,1:2]*multiplier, pose[...,2:3]], -1).tanh() + rot_pred = rot_pred * torch.FloatTensor([self.max_rot_x_range, self.max_rot_y_range, self.max_rot_z_range]).to(pose.device) /180 * np.pi + + elif self.rot_rep == 'quaternion': + quat_init = torch.FloatTensor([0.01,0,0,0]).to(pose.device) + rot_pred = pose[...,:4] + quat_init + rot_pred = nn.functional.normalize(rot_pred, p=2, dim=-1) + # rot_pred = torch.cat([rot_pred[...,:1].abs(), rot_pred[...,1:]], -1) # make real part non-negative + rot_pred = rot_pred * rot_pred[...,:1].sign() # make real part non-negative + + elif self.rot_rep == 'lookat': + vec_forward_raw = pose[...,:3] + if self.lookat_init is not None: + vec_forward_raw = vec_forward_raw + torch.FloatTensor(self.lookat_init).to(pose.device) + if self.lookat_zeroy: + vec_forward_raw = vec_forward_raw * torch.FloatTensor([1,0,1]).to(pose.device) + vec_forward_raw = nn.functional.normalize(vec_forward_raw, p=2, dim=-1) # x right, y up, z forward + rot_pred = vec_forward_raw + + elif self.rot_rep in ['quadlookat', 'octlookat']: + rots_pred = pose[..., :self.num_pose_hypos*4].view(-1, self.num_pose_hypos, 4) # (B, T, K, 4) + rots_logits = rots_pred[..., :1] + vec_forward_raw = rots_pred[..., 1:4] + xs, ys, zs = vec_forward_raw.unbind(-1) + margin = 0. + xs = nn.functional.softplus(xs, beta=np.log(2)/(0.5+margin)) - margin # initialize to 0.5 + if self.rot_rep == 'octlookat': + ys = nn.functional.softplus(ys, beta=np.log(2)/(0.5+margin)) - margin # initialize to 0.5 + if self.lookat_zeroy: + ys = ys * 0 + zs = nn.functional.softplus(zs, beta=2*np.log(2)) # initialize to 0.5 + vec_forward_raw = torch.stack([xs, ys, zs], -1) + vec_forward_raw = vec_forward_raw * self.orthant_signs.to(pose.device) + vec_forward_raw = nn.functional.normalize(vec_forward_raw, p=2, dim=-1) # x right, y up, z forward + rot_pred = torch.cat([rots_logits, vec_forward_raw], -1).view(-1, self.num_pose_hypos*4) + + else: + raise NotImplementedError + + pose = torch.cat([rot_pred, trans_pred], -1) + return pose + + def forward_deformation(self, shape, feat=None, batch_size=None, num_frames=None): + original_verts = shape.v_pos + num_verts = original_verts.shape[1] + if feat is not None: + deform_feat = feat[:, None, :].repeat(1, num_verts, 1) # Shape: (B, num_verts, latent_dim) + original_verts = original_verts.repeat(len(feat),1,1) + deformation = self.netDeform(original_verts, deform_feat) * 0.1 # Shape: (B, num_verts, 3) + # if self.avg_deform: + # assert batch_size is not None and num_frames is not None + # assert deformation.shape[0] == batch_size * num_frames + # deformation = deformation.view(batch_size, num_frames, *deformation.shape[1:]) + # deformation = deformation.mean(dim=1, keepdim=True) + # deformation = deformation.repeat(1,num_frames,*[1]*(deformation.dim()-2)) + # deformation = deformation.view(batch_size*num_frames, *deformation.shape[2:]) + shape = shape.deform(deformation) + return shape, deformation + + def forward_articulation(self, shape, feat, patch_feat, mvp, w2c, batch_size, num_frames, epoch, category, total_iter=None): + """ + Forward propagation of articulation. 
For each bone, the network takes: 1) the 3D location of the bone; 2) the feature of the patch which + the bone is projected to; and 3) an encoding of the bone's index to predict the bone's rotation (represented by an Euler angle). + + Args: + shape: a Mesh object, whose v_pos has batch size BxF or 1. + feat: the feature of the patches. Shape: (BxF, feat_dim, num_patches_per_axis, num_patches_per_axis) + mvp: the model-view-projection matrix. Shape: (BxF, 4, 4) + + Returns: + shape: a Mesh object, whose v_pos has batch size BxF (collapsed). + articulation_angles: the predicted bone rotations. Shape: (B, F, num_bones, 3) + aux: a dictionary containing auxiliary information. + """ + verts = shape.v_pos + if len(verts) == 1: + verts = verts[None] + else: + verts = verts.view(batch_size, num_frames, *verts.shape[1:]) + + if self.kinematic_tree_epoch != epoch: + # if (epoch == self.articulation_epochs[0]) and (self.kinematic_tree_epoch != epoch): + # if (epoch in [self.articulation_epochs[0], self.articulation_epochs[0]+2, self.articulation_epochs[0]+4]) and (self.kinematic_tree_epoch != epoch): + if total_iter is not None and self.iter_attach_leg_to_body_start is not None: + attach_legs_to_body = total_iter > self.iter_attach_leg_to_body_start + else: + attach_legs_to_body = epoch in self.attach_legs_to_body_epochs + + # bone_y_thresh = None if category is None or not category == "giraffe" else 0.1 + bone_y_thresh = self.cfgs.get('bone_y_thresh', None) + + # trivial set here + body_bone_idx_preset_cfg = self.cfgs.get('body_bone_idx_preset', [0, 0, 0, 0]) + if isinstance(body_bone_idx_preset_cfg, list): + body_bone_idx_preset = body_bone_idx_preset_cfg + elif isinstance(body_bone_idx_preset_cfg, dict): + iter_point = list(body_bone_idx_preset_cfg.keys())[1] + if total_iter <= iter_point: + body_bone_idx_preset = body_bone_idx_preset_cfg[0] # the first is start from 0 iter + else: + body_bone_idx_preset = body_bone_idx_preset_cfg[iter_point] + else: + raise NotImplementedError + + bones, self.kinematic_tree, self.bone_aux = estimate_bones(verts.detach(), self.num_body_bones, n_legs=self.num_legs, n_leg_bones=self.num_leg_bones, body_bones_type=self.body_bones_type, compute_kinematic_chain=True, attach_legs_to_body=attach_legs_to_body, bone_y_threshold=bone_y_thresh, body_bone_idx_preset=body_bone_idx_preset) + # self.kinematic_tree_epoch = epoch + else: + bones = estimate_bones(verts.detach(), self.num_body_bones, n_legs=self.num_legs, n_leg_bones=self.num_leg_bones, body_bones_type=self.body_bones_type, compute_kinematic_chain=False, aux=self.bone_aux) + + bones_pos = bones # Shape: (B, F, K, 2, 3) + if batch_size > bones_pos.shape[0] or num_frames > bones_pos.shape[1]: + assert bones_pos.shape[0] == 1 and bones_pos.shape[1] == 1, "If there is a mismatch, then there must be only one canonical mesh." 
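+        # The canonical shape (and hence the estimated bones) may have batch/frame size 1;
+        # broadcast it to every batch element and frame below before projecting the bone midpoints
+        # into image space for feature sampling.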
+ bones_pos = bones_pos.repeat(batch_size, num_frames, 1, 1, 1) + num_bones = bones_pos.shape[2] + bones_pos = bones_pos.view(batch_size*num_frames, num_bones, 2, 3) # NxKx2x3 + bones_mid_pos = bones_pos.mean(2) # NxKx3 + bones_idx = torch.arange(num_bones).to(bones_pos.device) + + bones_mid_pos_world4 = torch.cat([bones_mid_pos, torch.ones_like(bones_mid_pos[..., :1])], -1) # NxKx4 + bones_mid_pos_clip4 = bones_mid_pos_world4 @ mvp.transpose(-1, -2) + bones_mid_pos_uv = bones_mid_pos_clip4[..., :2] / bones_mid_pos_clip4[..., 3:4] + bones_mid_pos_uv = bones_mid_pos_uv.detach() + + bones_pos_world4 = torch.cat([bones_pos, torch.ones_like(bones_pos[..., :1])], -1) # NxKx2x4 + bones_pos_cam4 = bones_pos_world4 @ w2c[:,None].transpose(-1, -2) + bones_pos_cam3 = bones_pos_cam4[..., :3] / bones_pos_cam4[..., 3:4] + bones_pos_cam3 = bones_pos_cam3 + torch.FloatTensor([0, 0, self.cam_pos_z_offset]).to(bones_pos_cam3.device).view(1, 1, 1, 3) + bones_pos_in = bones_pos_cam3.view(batch_size*num_frames, num_bones, 2*3) / self.grid_scale * 2 # (-1, 1), NxKx(2*3) + + bones_idx_in = ((bones_idx[None, :, None] + 0.5) / num_bones * 2 - 1).repeat(batch_size * num_frames, 1, 1) # (-1, 1) + bones_pos_in = torch.cat([bones_mid_pos_uv, bones_pos_in, bones_idx_in], -1).detach() + + if self.articulation_feature_mode == 'global': + bones_patch_features = feat[:, None].repeat(1, num_bones, 1) # (BxF, K, feat_dim) + elif self.articulation_feature_mode == 'sample': + bones_patch_features = F.grid_sample(patch_feat, bones_mid_pos_uv.view(batch_size * num_frames, 1, -1, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # (BxF, K, feat_dim) + elif self.articulation_feature_mode == 'sample+global': + bones_patch_features = F.grid_sample(patch_feat, bones_mid_pos_uv.view(batch_size * num_frames, 1, -1, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # (BxF, K, feat_dim) + bones_patch_features = torch.cat([feat[:, None].repeat(1, num_bones, 1), bones_patch_features], -1) + elif self.articulation_feature_mode == 'attention': + bones_patch_features = self.netFeatureAttn(bones_pos_in, patch_feat) + else: + raise NotImplementedError + + articulation_angles = self.netArticulation(bones_patch_features, bones_pos_in).view(batch_size, num_frames, num_bones, 3) * self.articulation_multiplier + + if self.static_root_bones: + root_bones = [self.num_body_bones // 2 - 1, self.num_body_bones - 1] + tmp_mask = torch.ones_like(articulation_angles) + tmp_mask[:, :, root_bones] = 0 + articulation_angles = articulation_angles * tmp_mask + + articulation_angles = articulation_angles.tanh() + + if self.cfgs.get('iter_leg_rotation_start', -1) > 0: + if total_iter <= self.cfgs.get('iter_leg_rotation_start', -1): + self.constrain_legs = True + else: + self.constrain_legs = False + + if self.constrain_legs: + leg_bones_posx = [self.num_body_bones + i for i in range(self.num_leg_bones * self.num_legs // 2)] + leg_bones_negx = [self.num_body_bones + self.num_leg_bones * self.num_legs // 2 + i for i in range(self.num_leg_bones * self.num_legs // 2)] + + tmp_mask = torch.zeros_like(articulation_angles) + tmp_mask[:, :, leg_bones_posx + leg_bones_negx, 2] = 1 + articulation_angles = tmp_mask * (articulation_angles * 0.3) + (1 - tmp_mask) * articulation_angles # no twist + + tmp_mask = torch.zeros_like(articulation_angles) + tmp_mask[:, :, leg_bones_posx + leg_bones_negx, 1] = 1 + articulation_angles = tmp_mask * (articulation_angles * 0.3) + (1 - tmp_mask) * articulation_angles # (-0.4, 0.4), limit side bending + + # new regularizations, for 
bottom 2 bones of each leg, they can only rotate around x-axis, + # and for the toppest bone of legs, restrict its angles in a smaller range + if (self.cfgs.get('iter_leg_rotation_start', -1) > 0) and (total_iter > self.cfgs.get('iter_leg_rotation_start', -1)): + if self.cfgs.get('forbid_leg_rotate', False): + if self.cfgs.get('small_leg_angle', False): + # regularize the rotation angle of first leg bones + leg_bones_top = [8, 11, 14, 17] + # leg_bones_top = [10, 13, 16, 19] + tmp_mask = torch.zeros_like(articulation_angles) + tmp_mask[:, :, leg_bones_top, 1] = 1 + tmp_mask[:, :, leg_bones_top, 2] = 1 + articulation_angles = tmp_mask * (articulation_angles * 0.05) + (1 - tmp_mask) * articulation_angles + + leg_bones_bottom = [9, 10, 12, 13, 15, 16, 18, 19] + # leg_bones_bottom = [8, 9, 11, 12, 14, 15, 17, 18] + tmp_mask = torch.ones_like(articulation_angles) + tmp_mask[:, :, leg_bones_bottom, 1] = 0 + tmp_mask[:, :, leg_bones_bottom, 2] = 0 + # tmp_mask[:, :, leg_bones_bottom, 0] = 0.3 + articulation_angles = tmp_mask * articulation_angles + + if epoch in self.perturb_articulation_epochs: + articulation_angles = articulation_angles + torch.randn_like(articulation_angles) * 0.1 + articulation_angles = articulation_angles * self.max_arti_angle / 180 * np.pi + + # check if regularize the leg-connecting body bones z-rotation first + # then check if regularize all the body bones z-rotation + # regularize z-rotation using 0.1 in pi-space + body_rotate_mult = self.cfgs.get('reg_body_rotate_mult', 0.1) + body_rotate_mult = body_rotate_mult * 180 * 1.0 / (self.max_arti_angle * np.pi) # the max angle = mult*original_max_angle + body_rotate_reg_mode = self.cfgs.get('body_rotate_reg_mode', 'nothing') + if body_rotate_reg_mode == 'leg-connect': + body_bones_mask = [2, 3, 4, 5] + tmp_body_mask = torch.zeros_like(articulation_angles) + tmp_body_mask[:, :, body_bones_mask, 2] = 1 + articulation_angles = tmp_body_mask * (articulation_angles * body_rotate_mult) + (1 - tmp_body_mask) * articulation_angles + + elif body_rotate_reg_mode == 'all-bones': + body_bones_mask = [0, 1, 2, 3, 4, 5, 6, 7] + tmp_body_mask = torch.zeros_like(articulation_angles) + tmp_body_mask[:, :, body_bones_mask, 2] = 1 + articulation_angles = tmp_body_mask * (articulation_angles * body_rotate_mult) + (1 - tmp_body_mask) * articulation_angles + + elif body_rotate_reg_mode == 'nothing': + articulation_angles = articulation_angles * 1. 
+ + else: + raise NotImplementedError + + verts_articulated, aux = skinning(verts, bones, self.kinematic_tree, articulation_angles, + output_posed_bones=True, temperature=self.skinning_temperature) + verts_articulated = verts_articulated.view(batch_size*num_frames, *verts_articulated.shape[2:]) + v_tex = shape.v_tex + if len(v_tex) != len(verts_articulated): + v_tex = v_tex.repeat(len(verts_articulated), 1, 1) + shape = mesh.make_mesh( + verts_articulated, + shape.t_pos_idx, + v_tex, + shape.t_tex_idx, + shape.material) + return shape, articulation_angles, aux + + def get_camera_extrinsics_from_pose(self, pose, znear=0.1, zfar=1000., crop_fov_approx=None, offset_extra=None): + if crop_fov_approx is None: + crop_fov_approx = self.crop_fov_approx + N = len(pose) + if offset_extra is not None: + cam_pos_offset = torch.FloatTensor([0, 0, -self.cam_pos_z_offset - offset_extra]).to(pose.device) + else: + cam_pos_offset = torch.FloatTensor([0, 0, -self.cam_pos_z_offset]).to(pose.device) + pose_R = pose[:, :9].view(N, 3, 3).transpose(2, 1) + pose_T = pose[:, -3:] + cam_pos_offset[None, None, :] + pose_T = pose_T.view(N, 3, 1) + pose_RT = torch.cat([pose_R, pose_T], axis=2) # Nx3x4 + w2c = torch.cat([pose_RT, torch.FloatTensor([0, 0, 0, 1]).repeat(N, 1, 1).to(pose.device)], axis=1) # Nx4x4 + # We assume the images are perfect square. + if isinstance(crop_fov_approx, float) or isinstance(crop_fov_approx, int): + proj = util.perspective(crop_fov_approx / 180 * np.pi, 1, znear, zfar)[None].to(pose.device) + elif isinstance(crop_fov_approx, torch.Tensor): + proj = util.batched_perspective(crop_fov_approx / 180 * np.pi, 1, znear, zfar).to(pose.device) + else: + raise ValueError('crop_fov_approx must be float or torch.Tensor') + mvp = torch.matmul(proj, w2c) + campos = -torch.matmul(pose_R.transpose(2, 1), pose_T).view(N, 3) + return mvp, w2c, campos + + def forward(self, category=None, images=None, prior_shape=None, epoch=None, dino_features=None, dino_clusters=None, total_iter=None, is_training=True): + batch_size, num_frames = images.shape[:2] + if self.enable_encoder: + feat_out, feat_key, patch_out, patch_key = self.forward_encoder(images, dino_features) + else: + feat_out = feat_key = patch_out = patch_key = None + shape = prior_shape + texture = self.netTexture + + multi_hypothesis_aux = {} + if self.iter_nozeroy_start is not None and total_iter >= self.iter_nozeroy_start: + self.lookat_zeroy = False + + if self.enable_pose: + poses_raw = self.forward_pose(images, feat_out, patch_out, patch_key, dino_features) + pose_raw, pose, rot_idx, rot_prob, rot_logit, rots_probs, rand_pose_flag = sample_pose_hypothesis_from_quad_prediction(poses_raw, total_iter, batch_size, num_frames, rot_temp_scalar=self.rot_temp_scalar, num_hypos=self.num_pose_hypos, naive_probs_iter=self.naive_probs_iter, best_pose_start_iter=self.best_pose_start_iter, random_sample=is_training, temp_clip_low=self.temp_clip_low, temp_clip_high=self.temp_clip_high) + multi_hypothesis_aux['rot_idx'] = rot_idx + multi_hypothesis_aux['rot_prob'] = rot_prob + multi_hypothesis_aux['rot_logit'] = rot_logit + multi_hypothesis_aux['rots_probs'] = rots_probs + multi_hypothesis_aux['rand_pose_flag'] = rand_pose_flag + else: + raise NotImplementedError + mvp, w2c, campos = self.get_camera_extrinsics_from_pose(pose) + + deformation = None + if self.iter_deformation_start is not None: + if self.enable_deform and total_iter >= self.iter_deformation_start: + shape, deformation = self.forward_deformation(shape, feat_key, batch_size, num_frames) + else: + 
if self.enable_deform and epoch in self.deform_epochs: + shape, deformation = self.forward_deformation(shape, feat_key, batch_size, num_frames) + + arti_params, articulation_aux = None, {} + if self.iter_articulation_start is not None: + if self.enable_articulation and total_iter >= self.iter_articulation_start: + shape, arti_params, articulation_aux = self.forward_articulation(shape, feat_key, patch_key, mvp, w2c, batch_size, num_frames, epoch, category, total_iter=total_iter) + else: + if self.enable_articulation and epoch in self.articulation_epochs: + shape, arti_params, articulation_aux = self.forward_articulation(shape, feat_key, patch_key, mvp, w2c, batch_size, num_frames, epoch, category, total_iter=None) + + if self.enable_lighting: + light = self.netLight + else: + light = None + + aux = articulation_aux + aux.update(multi_hypothesis_aux) + + # if using texture_way to control a local texture, output patch_out + if self.texture_way is None: + return shape, pose_raw, pose, mvp, w2c, campos, texture, feat_out, patch_key, deformation, arti_params, light, aux + else: + return shape, pose_raw, pose, mvp, w2c, campos, texture, feat_out, patch_key, deformation, arti_params, light, aux, patch_out + +class Unsup3DDDP: + def __init__(self, cfgs): + self.cfgs = cfgs + self.device = cfgs.get('device', 'cpu') + self.in_image_size = cfgs.get('in_image_size', 128) + self.out_image_size = cfgs.get('out_image_size', 128) + + self.num_epochs = cfgs.get('num_epochs', 10) + self.lr = cfgs.get('lr', 1e-4) + self.use_scheduler = cfgs.get('use_scheduler', False) + if self.use_scheduler: + scheduler_milestone = cfgs.get('scheduler_milestone', [1,2,3,4,5]) + scheduler_gamma = cfgs.get('scheduler_gamma', 0.5) + self.make_scheduler = lambda optim: torch.optim.lr_scheduler.MultiStepLR(optim, milestones=scheduler_milestone, gamma=scheduler_gamma) + + self.cam_pos_z_offset = cfgs.get('cam_pos_z_offset', 10.) + self.full_size_h = cfgs.get('full_size_h', 1080) + self.full_size_w = cfgs.get('full_size_w', 1920) + # self.fov_w = cfgs.get('fov_w', 60) + # self.fov_h = np.arctan(np.tan(self.fov_w /2 /180*np.pi) / self.full_size_w * self.full_size_h) *2 /np.pi*180 # 36 + self.crop_fov_approx = cfgs.get("crop_fov_approx", 25) + self.mesh_regularization_mode = cfgs.get('mesh_regularization_mode', 'seq') + + self.enable_prior = cfgs.get('enable_prior', False) + if self.enable_prior: + self.netPrior = PriorPredictor(self.cfgs) #DOR - add label + self.prior_lr = cfgs.get('prior_lr', self.lr) + self.prior_weight_decay = cfgs.get('prior_weight_decay', 0.) 
+ self.prior_only_epochs = cfgs.get('prior_only_epochs', 0) + self.netInstance = InstancePredictor(self.cfgs, tet_bbox=self.netPrior.netShape.getAABB()) + self.perturb_sdf = cfgs.get('perturb_sdf', False) + self.blur_mask = cfgs.get('blur_mask', False) + self.blur_mask_iter = cfgs.get('blur_mask_iter', 1) + + self.seqshape_epochs = np.arange(*cfgs.get('seqshape_epochs', [0, self.num_epochs])) + self.avg_texture_epochs = np.arange(*cfgs.get('avg_texture_epochs', [0, 0])) + self.swap_texture_epochs = np.arange(*cfgs.get('swap_texture_epochs', [0, 0])) + self.swap_priorshape_epochs = np.arange(*cfgs.get('swap_priorshape_epochs', [0, 0])) + self.avg_seqshape_epochs = np.arange(*cfgs.get('avg_seqshape_epochs', [0, 0])) + self.swap_seqshape_epochs = np.arange(*cfgs.get('swap_seqshape_epochs', [0, 0])) + self.pose_epochs = np.arange(*cfgs.get('pose_epochs', [0, 0])) + self.pose_iters = cfgs.get('pose_iters', 0) + self.deform_type = cfgs.get('deform_type', None) + self.mesh_reg_decay_epoch = cfgs.get('mesh_reg_decay_epoch', 0) + self.sdf_reg_decay_start_iter = cfgs.get('sdf_reg_decay_start_iter', 0) + self.mesh_reg_decay_rate = cfgs.get('mesh_reg_decay_rate', 1) + self.texture_epochs = np.arange(*cfgs.get('texture_epochs', [0, self.num_epochs])) + self.zflip_epochs = np.arange(*cfgs.get('zflip_epochs', [0, self.num_epochs])) + self.lookat_zflip_loss_epochs = np.arange(*cfgs.get('lookat_zflip_loss_epochs', [0, self.num_epochs])) + self.lookat_zflip_no_other_losses = cfgs.get('lookat_zflip_no_other_losses', False) + self.flow_loss_epochs = np.arange(*cfgs.get('flow_loss_epochs', [0, self.num_epochs])) + self.sdf_inflate_reg_loss_epochs = np.arange(*cfgs.get('sdf_inflate_reg_loss_epochs', [0, self.num_epochs])) + self.arti_reg_loss_epochs = np.arange(*cfgs.get('arti_reg_loss_epochs', [0, self.num_epochs])) + self.background_mode = cfgs.get('background_mode', 'background') + self.shape_prior_type = cfgs.get('shape_prior_type', 'deform') + self.backward_prior = cfgs.get('backward_prior', True) + self.resume_prior_optim = cfgs.get('resume_prior_optim', True) + self.dmtet_grid_smaller_epoch = cfgs.get('dmtet_grid_smaller_epoch', 0) + self.dmtet_grid_smaller = cfgs.get('dmtet_grid_smaller', 128) + self.dmtet_grid = cfgs.get('dmtet_grid', 256) + self.pose_xflip_recon_epochs = np.arange(*cfgs.get('pose_xflip_recon_epochs', [0, 0])) + self.rot_rand_quad_epochs = np.arange(*cfgs.get('rot_rand_quad_epochs', [0, 0])) + self.rot_all_quad_epochs = np.arange(*cfgs.get('rot_all_quad_epochs', [0, 0])) + self.calc_dino_features = cfgs.get('calc_dino_features', False) + + # self.smooth_type = cfgs.get('smooth_type', 'None') + # print(f"****smooth_type: {self.smooth_type}****") + + ## smooth losses + # smooth articulation + self.arti_smooth_type = cfgs.get('arti_smooth_type', None) + self.arti_smooth_loss_type = cfgs.get('arti_smooth_loss_type', None) + self.arti_smooth_loss_weight = cfgs.get('arti_smooth_loss_weight', 0.) + self.using_arti_smooth_loss = self.arti_smooth_type and self.arti_smooth_loss_type and self.arti_smooth_loss_weight > 0. + if self.using_arti_smooth_loss: + self.arti_smooth_loss_fn = SmoothLoss(dim=1, smooth_type=self.arti_smooth_type, loss_type=self.arti_smooth_loss_type) + else: + self.arti_smooth_loss_fn = None + # smooth deformation + self.deform_smooth_type = cfgs.get('deform_smooth_type', None) + self.deform_smooth_loss_type = cfgs.get('deform_smooth_loss_type', None) + self.deform_smooth_loss_weight = cfgs.get('deform_smooth_loss_weight', 0.) 
+ self.using_deform_smooth_loss = self.deform_smooth_type and self.deform_smooth_loss_type and self.deform_smooth_loss_weight > 0. + if self.using_deform_smooth_loss: + self.deform_smooth_loss_fn = SmoothLoss(dim=1, smooth_type=self.deform_smooth_type, loss_type=self.deform_smooth_loss_type) + else: + self.deform_smooth_loss_fn = None + # smooth camera pose + self.campos_smooth_type = cfgs.get('campos_smooth_type', None) + self.campos_smooth_loss_type = cfgs.get('campos_smooth_loss_type', None) + self.campos_smooth_loss_weight = cfgs.get('campos_smooth_loss_weight', 0.) + self.using_campos_smooth_loss = self.campos_smooth_type and self.campos_smooth_loss_type and self.campos_smooth_loss_weight > 0. + if self.using_campos_smooth_loss: + self.campos_smooth_loss_fn = SmoothLoss(dim=1, smooth_type=self.campos_smooth_type, loss_type=self.campos_smooth_loss_type) + else: + self.campos_smooth_loss_fn = None + # smooth articulation velocity + self.artivel_smooth_type = cfgs.get('artivel_smooth_type', None) + self.artivel_smooth_loss_type = cfgs.get('artivel_smooth_loss_type', None) + self.artivel_smooth_loss_weight = cfgs.get('artivel_smooth_loss_weight', 0.) + self.using_artivel_smooth_loss = self.artivel_smooth_type and self.artivel_smooth_loss_type and self.artivel_smooth_loss_weight > 0. + if self.using_artivel_smooth_loss: + self.artivel_smooth_loss_fn = SmoothLoss(dim=1, smooth_type=self.artivel_smooth_type, loss_type=self.artivel_smooth_loss_type) + else: + self.artivel_smooth_loss_fn = None + # smooth bone + self.bone_smooth_type = cfgs.get('bone_smooth_type', None) + self.bone_smooth_loss_type = cfgs.get('bone_smooth_loss_type', None) + self.bone_smooth_loss_weight = cfgs.get('bone_smooth_loss_weight', 0.) + self.using_bone_smooth_loss = self.bone_smooth_type and self.bone_smooth_loss_type and self.bone_smooth_loss_weight > 0. + if self.using_bone_smooth_loss: + self.bone_smooth_loss_fn = SmoothLoss(dim=1, smooth_type=self.bone_smooth_type, loss_type=self.bone_smooth_loss_type) + else: + self.bone_smooth_loss_fn = None + # smooth bone velocity + self.bonevel_smooth_type = cfgs.get('bonevel_smooth_type', None) + self.bonevel_smooth_loss_type = cfgs.get('bonevel_smooth_loss_type', None) + self.bonevel_smooth_loss_weight = cfgs.get('bonevel_smooth_loss_weight', 0.) + self.using_bonevel_smooth_loss = self.bonevel_smooth_type and self.bonevel_smooth_loss_type and self.bonevel_smooth_loss_weight > 0. + if self.using_bonevel_smooth_loss: + self.bonevel_smooth_loss_fn = SmoothLoss(dim=1, smooth_type=self.bonevel_smooth_type, loss_type=self.bonevel_smooth_loss_type) + else: + self.bonevel_smooth_loss_fn = None + + + ## perceptual loss + if cfgs.get('perceptual_loss_weight', 0.) > 0: + self.perceptual_loss_use_lin = cfgs.get('perceptual_loss_use_lin', True) + self.perceptual_loss = lpips.LPIPS(net='vgg', lpips=self.perceptual_loss_use_lin) + + self.glctx = dr.RasterizeGLContext() + self.render_flow = self.cfgs.get('flow_loss_weight', 0.) > 0. + self.extra_renders = cfgs.get('extra_renders', []) + self.renderer_spp = cfgs.get('renderer_spp', 1) + self.dino_feature_recon_dim = cfgs.get('dino_feature_recon_dim', 64) + + self.total_loss = 0. 
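+ # lpips.LPIPS expects inputs in [-1, 1], so rendered and ground-truth images (kept in [0, 1]) are
+ # mapped with `img * 2 - 1` before being passed to self.perceptual_loss.
+ # The single nvdiffrast RasterizeGLContext created here is reused by every subsequent render() call.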
+ self.all_scores = torch.Tensor() + self.checkpoint_dir = cfgs.get('checkpoint_dir', 'results') + + # iter + self.iter_arti_reg_loss_start = cfgs.get('iter_arti_reg_loss_start', None) + + # mask distribution + self.enable_mask_distribution = cfgs.get('enable_mask_distribution', False) + self.random_mask_law = cfgs.get('random_mask_law', 'batch_swap_noy') # batch_swap, batch_swap_noy, # random_azimuth # random_all + self.mask_distribution_path = cfgs.get('mask_distribution_path', None) + if self.enable_mask_distribution and (self.mask_distribution_path is not None): + self.class_mask_distribution = {} + for category in os.listdir(self.mask_distribution_path): + # Here we assume the category names are identical + distribution_file = osp.join(self.mask_distribution_path, category, "raw_mask_distribution.npy") + distribution = np.load(distribution_file) + self.class_mask_distribution.update( + { + category: distribution # [256, 256] + } + ) + self.mask_distribution_loss_weight = cfgs.get("mask_distribution_loss_weight", 0.1) + self.mask_distribution_loss_freq = cfgs.get("mask_distribution_loss_freq", 1) + + self.mask_distribution_average = cfgs.get("mask_distribution_average", False) + + else: + self.enable_mask_distribution = False + + self.enable_clip = cfgs.get('enable_clip', False) + if self.enable_clip: + self.clip_model, _ = clip.load('ViT-B/32', self.device) + self.clip_model = self.clip_model.eval().requires_grad_(False) + self.clip_mean = [0.48145466, 0.4578275, 0.40821073] + self.clip_std = [0.26862954, 0.26130258, 0.27577711] + self.clip_reso = 224 + self.clip_render_size = 64 + self.enable_clip_text = cfgs.get('enable_clip_text', False) + if self.enable_clip_text: + self.clip_text_feature = {} + for category_name in ['bear', 'elephant', 'horse', 'sheep', 'cow', 'zebra', 'giraffe']: + text_input = clip.tokenize(['A photo of ' + category_name]).to(self.device) + text_feature = self.clip_model.encode_text(text_input).detach() # [1, 512] + self.clip_text_feature.update({category_name: text_feature}) + + self.enable_disc = cfgs.get('enable_disc', False) + if self.enable_disc: + self.mask_discriminator_iter = cfgs.get('mask_discriminator_iter', [0, 0]) + # this module is not in netInstance or netPrior + + self.mask_disc_feat_condition = cfgs.get('mask_disc_feat_condition', False) + if self.mask_disc_feat_condition: + self.mask_disc = discriminator_architecture.DCDiscriminator(in_dim=(cfgs.get('dim_of_classes', 128) + 1)).to(self.device) + else: + self.mask_disc = discriminator_architecture.DCDiscriminator(in_dim=(len(list(self.netPrior.category_id_map.keys())) + 1)).to(self.device) + + self.disc_gt = cfgs.get('disc_gt', True) + self.disc_iv = cfgs.get('disc_iv', False) # whether to use input view render in disc loss + self.disc_iv_label = cfgs.get('disc_iv_label', 'Fake') + self.disc_reg_mul = cfgs.get('disc_reg_mul', 10.) + + self.record_mask_gt = None + self.record_mask_iv = None + self.record_mask_rv = None + self.discriminator_loss = 0. 
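+ # The mask discriminator input is either the learned class embedding (`dim_of_classes` channels) or a
+ # one-hot category vector, concatenated with one extra channel (presumably the rendered/GT mask itself,
+ # hence the `+ 1`). record_mask_gt / record_mask_iv / record_mask_rv start as None and are meant to hold
+ # the most recently seen ground-truth, input-view and random-view masks.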
+ self.discriminator_loss_weight = cfgs.get('discriminator_loss_weight', 0.1) + + # the local texture for fine-tune process stage + if (self.cfgs.get('texture_way', None) is not None) or self.cfgs.get('gan_tex', False): + if self.cfgs.get('gan_tex', False): + self.few_shot_gan_tex = True + self.few_shot_gan_tex_reso = self.cfgs.get('few_shot_gan_tex_reso', 64) # used to render novel view, will upsample to out_image_size ASAP + self.few_shot_gan_tex_patch = self.cfgs.get('few_shot_gan_tex_patch', 0) # used to sample patch size on out_image_size image + if self.few_shot_gan_tex_patch > 0: + self.few_shot_gan_tex_patch_max = self.cfgs.get('few_shot_gan_tex_patch_max', 128) + assert self.few_shot_gan_tex_patch_max > self.few_shot_gan_tex_patch + self.few_shot_gan_tex_patch_num = self.cfgs.get('few_shot_gan_tex_patch_num', 1) + self.discriminator_texture = discriminator_architecture.DCDiscriminator(in_dim=3, img_size=self.few_shot_gan_tex_patch).to(self.device) + else: + self.discriminator_texture = discriminator_architecture.DCDiscriminator(in_dim=3, img_size=self.out_image_size).to(self.device) + + self.few_shot_gan_tex_real = self.cfgs.get('few_shot_gan_tex_real', 'gt') + self.few_shot_gan_tex_fake = self.cfgs.get('few_shot_gan_tex_fake', 'rv') + else: + self.few_shot_gan_tex = False + + if self.cfgs.get('clip_tex', False): + self.few_shot_clip_tex = True + self.clip_model, _ = clip.load('ViT-B/32', self.device) + self.clip_model = self.clip_model.eval().requires_grad_(False) + self.clip_mean = [0.48145466, 0.4578275, 0.40821073] + self.clip_std = [0.26862954, 0.26130258, 0.27577711] + self.clip_reso = 224 + self.enable_clip_text = False + else: + self.few_shot_clip_tex = False + + else: + self.few_shot_gan_tex = False + self.few_shot_clip_tex = False + + self.enable_sds = cfgs.get('enable_sds', False) + self.enable_vsd = cfgs.get('enable_vsd', False) + if self.enable_sds: + diffusion_torch_dtype = torch.float16 if cfgs.get('diffusion_precision', 'float16') == 'float16' else torch.float32 + + # decide if use SDS or VSD + if self.enable_vsd: + # self.stable_diffusion = misc.LazyClass(StableDiffusion_VSD, device=self.device, torch_dtype=diffusion_torch_dtype) + self.stable_diffusion = StableDiffusion_VSD(device=self.device, torch_dtype=diffusion_torch_dtype) + self.diffusion_guidance_scale_lora = cfgs.get('diffusion_guidance_scale_lora', 1.) + self.diffusion_guidance_scale = cfgs.get('diffusion_guidance_scale', 7.5) + else: + self.stable_diffusion = misc.LazyClass(StableDiffusion, device=self.device, torch_dtype=diffusion_torch_dtype) + self.diffusion_guidance_scale = cfgs.get('diffusion_guidance_scale', 100.) + + self.diffusion_loss_weight = cfgs.get('diffusion_loss_weight', 1.) 
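+ # Guidance setup: with `enable_vsd`, the variational score distillation variant of the Stable Diffusion
+ # guidance is instantiated eagerly and carries a separate guidance scale for its LoRA branch (default 1.0,
+ # with the main scale defaulting to 7.5); plain SDS guidance is instead wrapped in misc.LazyClass so the
+ # diffusion model is only loaded on first use, with a stronger default guidance scale of 100.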
+ self.diffusion_num_random_cameras = cfgs.get('diffusion_num_random_cameras', 1) + + # For prompts + self.diffusion_prompt = cfgs.get('diffusion_prompt', '') + self.diffusion_negative_prompt = cfgs.get('diffusion_negative_prompt', '') + + # For image sampling + self.diffusion_albedo_ratio = cfgs.get('diffusion_albedo_ratio', 0.2) + self.diffusion_shading_ratio = cfgs.get('diffusion_shading_ratio', 0.4) + self.diffusion_light_ambient = cfgs.get('diffusion_light_ambient', 0.5) + self.diffusion_light_diffuse = cfgs.get('diffusion_light_diffuse', 0.8) + self.diffusion_radius_range = cfgs.get('diffusion_radius_range', [0.8, 1.4]) + self.diffusion_uniform_sphere_rate = cfgs.get('diffusion_uniform_sphere_rate', 0.5) + self.diffusion_theta_range = cfgs.get('diffusion_theta_range', [0, 120]) + self.diffusion_phi_offset = cfgs.get('diffusion_phi_offset', 180) + self.diffusion_resolution = cfgs.get('diffusion_resolution', 256) + + print('-----------------------------------------------') + print(f"!!!!!! the phi offset for diffusion is set as {self.diffusion_phi_offset}!!!!!!!!!!!!!") + print('-----------------------------------------------') + + # For randomizing light + self.diffusion_random_light = cfgs.get('diffusion_random_light', False) + self.diffusion_light_ambient = cfgs.get('diffusion_light_ambient', 0.5) + self.diffusion_light_diffuse = cfgs.get('diffusion_light_diffuse', 0.8) + + # For noise scheduling + self.diffusion_max_step = cfgs.get('diffusion_max_step', 0.98) + + # For view-dependent prompting + self.diffusion_append_prompt_directions = cfgs.get('diffusion_append_prompt_directions', False) + self.diffusion_angle_overhead = cfgs.get('diffusion_angle_overhead', 30) + self.diffusion_angle_front = cfgs.get('diffusion_angle_front', 60) + + @staticmethod + def get_data_loaders(cfgs, dataset, in_image_size=256, out_image_size=256, batch_size=64, num_workers=4, run_train=False, run_test=False, train_data_dir=None, val_data_dir=None, test_data_dir=None, flow_bool=False): + train_loader = val_loader = test_loader = None + color_jitter_train = cfgs.get('color_jitter_train', None) + color_jitter_val = cfgs.get('color_jitter_val', None) + random_flip_train = cfgs.get('random_flip_train', False) + + ## video dataset + if dataset == 'video': + data_loader_mode = cfgs.get('data_loader_mode', 'n_frame') + skip_beginning = cfgs.get('skip_beginning', 4) + skip_end = cfgs.get('skip_end', 4) + num_sample_frames = cfgs.get('num_sample_frames', 2) + min_seq_len = cfgs.get('min_seq_len', 10) + max_seq_len = cfgs.get('max_seq_len', 10) + debug_seq = cfgs.get('debug_seq', False) + random_sample_train_frames = cfgs.get('random_sample_train_frames', False) + shuffle_train_seqs = cfgs.get('shuffle_train_seqs', False) + random_sample_val_frames = cfgs.get('random_sample_val_frames', False) + load_background = cfgs.get('background_mode', 'none') == 'background' + rgb_suffix = cfgs.get('rgb_suffix', '.png') + load_dino_feature = cfgs.get('load_dino_feature', False) + load_dino_cluster = cfgs.get('load_dino_cluster', False) + dino_feature_dim = cfgs.get('dino_feature_dim', 64) + get_loader = lambda **kwargs: get_sequence_loader( + mode=data_loader_mode, + batch_size=batch_size, + num_workers=num_workers, + in_image_size=in_image_size, + out_image_size=out_image_size, + debug_seq=debug_seq, + skip_beginning=skip_beginning, + skip_end=skip_end, + num_sample_frames=num_sample_frames, + min_seq_len=min_seq_len, + max_seq_len=max_seq_len, + load_background=load_background, + rgb_suffix=rgb_suffix, + 
load_dino_feature=load_dino_feature, + load_dino_cluster=load_dino_cluster, + dino_feature_dim=dino_feature_dim, + flow_bool=flow_bool, + **kwargs) + + if run_train: + assert osp.isdir(train_data_dir), f"Training data directory does not exist: {train_data_dir}" + print(f"Loading training data from {train_data_dir}") + train_loader = get_loader(data_dir=train_data_dir, is_validation=False, random_sample=random_sample_train_frames, shuffle=shuffle_train_seqs, dense_sample=True, color_jitter=color_jitter_train, random_flip=random_flip_train) + + if val_data_dir is not None: + assert osp.isdir(val_data_dir), f"Validation data directory does not exist: {val_data_dir}" + print(f"Loading validation data from {val_data_dir}") + val_loader = get_loader(data_dir=val_data_dir, is_validation=True, random_sample=random_sample_val_frames, shuffle=False, dense_sample=False, color_jitter=color_jitter_val, random_flip=False) + if run_test: + assert osp.isdir(test_data_dir), f"Testing data directory does not exist: {test_data_dir}" + print(f"Loading testing data from {test_data_dir}") + test_loader = get_loader(data_dir=test_data_dir, is_validation=True, dense_sample=False, color_jitter=None, random_flip=False) + + ## CUB dataset + elif dataset == 'cub': + get_loader = lambda **kwargs: get_cub_loader( + batch_size=batch_size, + num_workers=num_workers, + image_size=in_image_size, + **kwargs) + + if run_train: + assert osp.isdir(train_data_dir), f"Training data directory does not exist: {train_data_dir}" + print(f"Loading training data from {train_data_dir}") + train_loader = get_loader(data_dir=train_data_dir, split='train', is_validation=False) + val_loader = get_loader(data_dir=val_data_dir, split='val', is_validation=True) + + if run_test: + assert osp.isdir(test_data_dir), f"Testing data directory does not exist: {test_data_dir}" + print(f"Loading testing data from {test_data_dir}") + test_loader = get_loader(data_dir=test_data_dir, split='test', is_validation=True) + + ## other datasets + else: + get_loader = lambda **kwargs: get_image_loader( + batch_size=batch_size, + num_workers=num_workers, + image_size=in_image_size, + **kwargs) + + if run_train: + assert osp.isdir(train_data_dir), f"Training data directory does not exist: {train_data_dir}" + print(f"Loading training data from {train_data_dir}") + train_loader = get_loader(data_dir=train_data_dir, is_validation=False, color_jitter=color_jitter_train) + + if val_data_dir is not None: + assert osp.isdir(val_data_dir), f"Validation data directory does not exist: {val_data_dir}" + print(f"Loading validation data from {val_data_dir}") + val_loader = get_loader(data_dir=val_data_dir, is_validation=True, color_jitter=color_jitter_val) + + if run_test: + assert osp.isdir(test_data_dir), f"Testing data directory does not exist: {test_data_dir}" + print(f"Loading testing data from {test_data_dir}") + test_loader = get_loader(data_dir=test_data_dir, is_validation=True, color_jitter=None) + + return train_loader, val_loader, test_loader + + @staticmethod + def get_data_loaders_ddp(cfgs, dataset, rank, world_size, in_image_size=256, out_image_size=256, batch_size=64, num_workers=4, run_train=False, run_test=False, train_data_dir=None, val_data_dir=None, test_data_dir=None, flow_bool=False): + train_loader = val_loader = test_loader = None + color_jitter_train = cfgs.get('color_jitter_train', None) + color_jitter_val = cfgs.get('color_jitter_val', None) + random_flip_train = cfgs.get('random_flip_train', False) + + ## video dataset + if dataset == 'video': + 
data_loader_mode = cfgs.get('data_loader_mode', 'n_frame') + skip_beginning = cfgs.get('skip_beginning', 4) + skip_end = cfgs.get('skip_end', 4) + num_sample_frames = cfgs.get('num_sample_frames', 2) + min_seq_len = cfgs.get('min_seq_len', 10) + max_seq_len = cfgs.get('max_seq_len', 10) + debug_seq = cfgs.get('debug_seq', False) + random_sample_train_frames = cfgs.get('random_sample_train_frames', False) + shuffle_train_seqs = cfgs.get('shuffle_train_seqs', False) + random_sample_val_frames = cfgs.get('random_sample_val_frames', False) + load_background = cfgs.get('background_mode', 'none') == 'background' + rgb_suffix = cfgs.get('rgb_suffix', '.png') + load_dino_feature = cfgs.get('load_dino_feature', False) + load_dino_cluster = cfgs.get('load_dino_cluster', False) + dino_feature_dim = cfgs.get('dino_feature_dim', 64) + get_loader_ddp = lambda **kwargs: get_sequence_loader_ddp( + mode=data_loader_mode, + batch_size=batch_size, + num_workers=num_workers, + in_image_size=in_image_size, + out_image_size=out_image_size, + debug_seq=debug_seq, + skip_beginning=skip_beginning, + skip_end=skip_end, + num_sample_frames=num_sample_frames, + min_seq_len=min_seq_len, + max_seq_len=max_seq_len, + load_background=load_background, + rgb_suffix=rgb_suffix, + load_dino_feature=load_dino_feature, + load_dino_cluster=load_dino_cluster, + dino_feature_dim=dino_feature_dim, + flow_bool=flow_bool, + **kwargs) + get_loader = lambda **kwargs: get_sequence_loader( + mode=data_loader_mode, + batch_size=batch_size, + num_workers=num_workers, + in_image_size=in_image_size, + out_image_size=out_image_size, + debug_seq=debug_seq, + skip_beginning=skip_beginning, + skip_end=skip_end, + num_sample_frames=num_sample_frames, + min_seq_len=min_seq_len, + max_seq_len=max_seq_len, + load_background=load_background, + rgb_suffix=rgb_suffix, + load_dino_feature=load_dino_feature, + load_dino_cluster=load_dino_cluster, + dino_feature_dim=dino_feature_dim, + **kwargs) + + if run_train: + if isinstance(train_data_dir, dict): + for data_path in train_data_dir.values(): + assert osp.isdir(data_path), f"Training data directory does not exist: {data_path}" + elif isinstance(train_data_dir, str): + assert osp.isdir(train_data_dir), f"Training data directory does not exist: {train_data_dir}" + else: + raise ValueError("train_data_dir must be a string or a dict of strings") + + print(f"Loading training data...") + train_loader = get_loader_ddp(data_dir=train_data_dir, rank=rank, world_size=world_size, is_validation=False, random_sample=random_sample_train_frames, shuffle=shuffle_train_seqs, dense_sample=True, color_jitter=color_jitter_train, random_flip=random_flip_train) + + if val_data_dir is not None: + if isinstance(val_data_dir, dict): + for data_path in val_data_dir.values(): + assert osp.isdir(data_path), f"Training data directory does not exist: {data_path}" + elif isinstance(val_data_dir, str): + assert osp.isdir(val_data_dir), f"Training data directory does not exist: {val_data_dir}" + else: + raise ValueError("train_data_dir must be a string or a dict of strings") + print(f"Loading validation data...") + # No need for data parallel for the validation data loader. 
+ val_loader = get_loader(data_dir=val_data_dir, is_validation=True, random_sample=random_sample_val_frames, shuffle=False, dense_sample=False, color_jitter=color_jitter_val, random_flip=False) + + if run_test: + assert osp.isdir(test_data_dir), f"Testing data directory does not exist: {test_data_dir}" + print(f"Loading testing data from {test_data_dir}") + test_loader = get_loader_ddp(data_dir=test_data_dir, rank=rank, world_size=world_size, is_validation=True, dense_sample=False, color_jitter=None, random_flip=False) + + ## CUB dataset + elif dataset == 'cub': + get_loader = lambda **kwargs: get_cub_loader_ddp( + batch_size=batch_size, + num_workers=num_workers, + image_size=in_image_size, + **kwargs) + + if run_train: + assert osp.isdir(train_data_dir), f"Training data directory does not exist: {train_data_dir}" + print(f"Loading training data from {train_data_dir}") + train_loader = get_loader(data_dir=train_data_dir, rank=rank, world_size=world_size, split='train', is_validation=False) + val_loader = get_loader(data_dir=val_data_dir, rank=rank, world_size=world_size, split='val', is_validation=True) + + if run_test: + assert osp.isdir(test_data_dir), f"Testing data directory does not exist: {test_data_dir}" + print(f"Loading testing data from {test_data_dir}") + test_loader = get_loader(data_dir=test_data_dir, rank=rank, world_size=world_size, split='test', is_validation=True) + + ## other datasets + else: + get_loader = lambda **kwargs: get_image_loader_ddp( + batch_size=batch_size, + num_workers=num_workers, + image_size=in_image_size, + **kwargs) + + if run_train: + assert osp.isdir(train_data_dir), f"Training data directory does not exist: {train_data_dir}" + print(f"Loading training data from {train_data_dir}") + train_loader = get_loader(data_dir=train_data_dir, rank=rank, world_size=world_size, is_validation=False, color_jitter=color_jitter_train) + + if val_data_dir is not None: + assert osp.isdir(val_data_dir), f"Validation data directory does not exist: {val_data_dir}" + print(f"Loading validation data from {val_data_dir}") + val_loader = get_loader(data_dir=val_data_dir, rank=rank, world_size=world_size, is_validation=True, color_jitter=color_jitter_val) + + if run_test: + assert osp.isdir(test_data_dir), f"Testing data directory does not exist: {test_data_dir}" + print(f"Loading testing data from {test_data_dir}") + test_loader = get_loader(data_dir=test_data_dir, rank=rank, world_size=world_size, is_validation=True, color_jitter=None) + + return train_loader, val_loader, test_loader + + def load_model_state(self, cp): + # TODO: very hacky: if using local texture, which is also usually finetuned from global texture + # we need to check if needs some handcrafted load in netInstance + if (self.netInstance.texture_way is not None) or (self.cfgs.get('texture_act', 'relu') != 'relu'): + new_netInstance_weights = {k: v for k, v in cp['netInstance'].items() if 'netTexture' not in k} + #find the new texture weights + texture_weights = self.netInstance.netTexture.state_dict() + #add the new weights to the new model weights + for k, v in texture_weights.items(): + new_netInstance_weights['netTexture.' 
+ k] = v + self.netInstance.load_state_dict(new_netInstance_weights) + else: + self.netInstance.load_state_dict(cp["netInstance"]) + if self.enable_disc and "net_mask_disc" in cp: + self.mask_disc.load_state_dict(cp["net_mask_disc"]) + if self.enable_prior: + self.netPrior.load_state_dict(cp["netPrior"]) + + + def load_optimizer_state(self, cp): + # TODO: also very hacky here, as the load_model_state above + if self.netInstance.texture_way is not None: + opt_state_dict = self.optimizerInstance.state_dict() + param_ids = [id(p) for p in self.netInstance.netTexture.parameters()] + new_opt_state_dict = {} + new_opt_state_dict['state'] = {k: v for k, v in opt_state_dict['state'].items() if k not in param_ids} + + new_param_groups = [] + for param_group in opt_state_dict['param_groups']: + new_param_group = {k: v for k, v in param_group.items() if k != 'params'} + new_param_group['params'] = [p_id for p_id in param_group['params'] if p_id not in param_ids] + new_param_groups.append(new_param_group) + + new_opt_state_dict['param_groups'] = new_param_groups + + self.optimizerInstance.load_state_dict(new_opt_state_dict) + else: + self.optimizerInstance.load_state_dict(cp["optimizerInstance"]) + + # add parameters into optimizerInstance here + # if self.enable_disc: + # print('add mask discriminator parameters to Instance optimizer') + # self.optimizerInstance.add_param_group({'params': self.mask_disc.parameters()}) + + if self.use_scheduler: + if 'schedulerInstance' in cp: + self.schedulerInstance.load_state_dict(cp["schedulerInstance"]) + if self.enable_disc and "optimizerDiscriminator" in cp: + self.optimizerDiscriminator.load_state_dict(cp["optimizerDiscriminator"]) + if self.enable_prior and self.resume_prior_optim: + self.optimizerPrior.load_state_dict(cp["optimizerPrior"]) + if self.use_scheduler: + if 'schedulerPrior' in cp: + self.schedulerPrior.load_state_dict(cp["schedulerPrior"]) + + def get_model_state(self): + state = {"netInstance": self.netInstance.state_dict()} + if self.enable_disc: + state["net_mask_disc"] = self.mask_disc.state_dict() + if self.enable_prior: + state["netPrior"] = self.netPrior.state_dict() + return state + + def get_optimizer_state(self): + state = {"optimizerInstance": self.optimizerInstance.state_dict()} + if self.enable_disc: + state['optimizerDiscriminator'] = self.optimizerDiscriminator.state_dict() + if self.use_scheduler: + state["schedulerInstance"] = self.schedulerInstance.state_dict() + if self.enable_prior: + state["optimizerPrior"] = self.optimizerPrior.state_dict() + if self.use_scheduler: + state["schedulerPrior"] = self.schedulerPrior.state_dict() + return state + + def to(self, device): + self.device = device + self.netInstance.to(device) + if self.enable_prior: + self.netPrior.to(device) + for v in vars(self.netPrior.netShape): + attr = getattr(self.netPrior.netShape,v) + if type(attr) == torch.Tensor: + setattr(self.netPrior.netShape, v, attr.to(device)) + if hasattr(self, 'perceptual_loss'): + self.perceptual_loss.to(device) + + def ddp(self, rank, world_size): + self.rank = rank + self.world_size = world_size + + if self.world_size > 1: + self.netInstance_ddp = DDP( + self.netInstance, device_ids=[rank], + find_unused_parameters=True) + self.netInstance_ddp._set_static_graph() + self.netInstance = self.netInstance_ddp.module + + if self.enable_prior: + self.netPrior_ddp = DDP( + self.netPrior, device_ids=[rank], + find_unused_parameters=True) + self.netPrior_ddp._set_static_graph() + self.netPrior = self.netPrior_ddp.module + + if 
hasattr(self, 'perceptual_loss'): + self.perceptual_loss_ddp = DDP( + self.perceptual_loss, device_ids=[rank], + find_unused_parameters=True) + self.perceptual_loss = self.perceptual_loss_ddp.module + else: + print('actually no DDP for model') + + def set_train(self): + if self.world_size > 1: + self.netInstance_ddp.train() + if self.enable_prior: + self.netPrior_ddp.train() + else: + self.netInstance.train() + if self.enable_disc: + self.mask_disc.train() + if self.enable_prior: + self.netPrior.train() + + def set_eval(self): + if self.world_size > 1: + self.netInstance_ddp.eval() + if self.enable_prior: + self.netPrior_ddp.eval() + else: + self.netInstance.eval() + if self.enable_disc: + self.mask_disc.eval() + if self.enable_prior: + self.netPrior.eval() + + def reset_optimizers(self): + print("Resetting optimizers...") + self.optimizerInstance = get_optimizer(self.netInstance, self.lr) + + if self.enable_disc: + self.optimizerDiscriminator = get_optimizer(self.mask_disc, self.lr) + + if self.use_scheduler: + self.schedulerInstance = self.make_scheduler(self.optimizerInstance) + if self.enable_prior: + self.optimizerPrior = get_optimizer(self.netPrior, lr=self.prior_lr, weight_decay=self.prior_weight_decay) + if self.use_scheduler: + self.schedulerPrior = self.make_scheduler(self.optimizerPrior) + + def reset_only_disc_optimizer(self): + if self.enable_disc: + self.optimizerDiscriminator = get_optimizer(self.mask_disc, self.lr) + + def backward(self): + self.optimizerInstance.zero_grad() + if self.backward_prior: + self.optimizerPrior.zero_grad() + # self.total_loss = self.add_unused() + self.total_loss.backward() + self.optimizerInstance.step() + if self.backward_prior: + self.optimizerPrior.step() + self.total_loss = 0. + + def scheduler_step(self): + if self.use_scheduler: + self.schedulerInstance.step() + if self.enable_prior: + self.schedulerPrior.step() + + def zflip_pose(self, pose): + if self.rot_rep == 'lookat': + vec_forward = pose[:,:,6:9] + vec_forward = vec_forward * torch.FloatTensor([1,1,-1]).view(1,1,3).to(vec_forward.device) + up = torch.FloatTensor([0,1,0]).to(pose.device).view(1,1,3) + vec_right = up.expand_as(vec_forward).cross(vec_forward, dim=-1) + vec_right = nn.functional.normalize(vec_right, p=2, dim=-1) + vec_up = vec_forward.cross(vec_right, dim=-1) + vec_up = nn.functional.normalize(vec_up, p=2, dim=-1) + rot_mat = torch.stack([vec_right, vec_up, vec_forward], 2) + rot_pred = rot_mat.reshape(*pose.shape[:-1], -1) + pose_zflip = torch.cat([rot_pred, pose[:,:,9:]], -1) + else: + raise NotImplementedError + return pose_zflip + + def render(self, shape, texture, mvp, w2c, campos, resolution, background='none', im_features=None, light=None, prior_shape=None, render_flow=False, dino_pred=None, class_vector=None, render_mode='diffuse', two_sided_shading=True, num_frames=None, spp=1, bg_image=None, im_features_map=None): + h, w = resolution + N = len(mvp) + if bg_image is None: + if background in ['none', 'black']: + bg_image = torch.zeros((N, h, w, 3), device=mvp.device) + elif background == 'white': + bg_image = torch.ones((N, h, w, 3), device=mvp.device) + elif background == 'checkerboard': + bg_image = torch.FloatTensor(util.checkerboard((h, w), 8), device=self.device).repeat(N, 1, 1, 1) # NxHxWxC + elif background == 'random': + bg_image = torch.rand((N, h, w, 3), device=mvp.device) # NxHxWxC + elif background == 'random-pure': + random_values = torch.rand(N) + bg_image = random_values[..., None, None, None].repeat(1, h, w, 3).to(self.device) + else: + raise 
NotImplementedError + + #insider render_mesh -> render_layer -> shade DOR + frame_rendered = render.render_mesh( + self.glctx, + shape, + mtx_in=mvp, + w2c=w2c, + view_pos=campos, + material=texture, + lgt=light, + resolution=resolution, + spp=spp, + msaa=True, + background=bg_image, + bsdf=render_mode, + feat=im_features, + prior_mesh=prior_shape, + two_sided_shading=two_sided_shading, + render_flow=render_flow, + dino_pred=dino_pred, + class_vector=class_vector, + num_frames=num_frames, + im_features_map=im_features_map) + shaded = frame_rendered['shaded'].permute(0, 3, 1, 2) + image_pred = shaded[:, :3, :, :] + mask_pred = shaded[:, 3, :, :] + albedo = frame_rendered['kd'].permute(0, 3, 1, 2)[:, :3, :, :] + if 'shading' in frame_rendered: + shading = frame_rendered['shading'].permute(0, 3, 1, 2)[:, :1, :, :] + else: + shading = None + if render_flow: + flow_pred = frame_rendered['flow'] + flow_pred = flow_pred.permute(0, 3, 1, 2)[:, :2, :, :] + else: + flow_pred = None + if dino_pred is not None: + dino_feat_im_pred = frame_rendered['dino_feat_im_pred'] + dino_feat_im_pred = dino_feat_im_pred.permute(0, 3, 1, 2)[:, :-1] + else: + dino_feat_im_pred = None + + return image_pred, mask_pred, flow_pred, dino_feat_im_pred, albedo, shading + + def compute_reconstruction_losses(self, image_pred, image_gt, mask_pred, mask_gt, mask_dt, mask_valid, flow_pred, flow_gt, dino_feat_im_gt, dino_feat_im_pred, background_mode='none', reduce=False): + losses = {} + batch_size, num_frames, _, h, w = image_pred.shape # BxFxCxHxW + + # image_loss = (image_pred - image_gt) ** 2 + image_loss = (image_pred - image_gt).abs() + + ## silhouette loss + mask_pred_valid = mask_pred * mask_valid + # mask_pred_valid = mask_pred + # losses["silhouette_loss"] = ((mask_pred - mask_gt) ** 2).mean() + # mask_loss_mask = (image_loss.mean(2).detach() > 0.05).float() + mask_loss = (mask_pred_valid - mask_gt) ** 2 + # mask_loss = nn.functional.mse_loss(mask_pred, mask_gt) + # num_mask_pixels = mask_loss_mask.reshape(batch_size*num_frames, -1).sum(1).clamp(min=1) + # losses["silhouette_loss"] = (mask_loss.reshape(batch_size*num_frames, -1).sum(1) / num_mask_pixels).mean() + losses['silhouette_loss'] = mask_loss.view(batch_size, num_frames, -1).mean(2) + losses['silhouette_dt_loss'] = (mask_pred * mask_dt[:,:,1]).view(batch_size, num_frames, -1).mean(2) + losses['silhouette_inv_dt_loss'] = ((1-mask_pred) * mask_dt[:,:,0]).view(batch_size, num_frames, -1).mean(2) + + mask_pred_binary = (mask_pred_valid > 0.).float().detach() + mask_both_binary = (mask_pred_binary * mask_gt).view(batch_size*num_frames, 1, *mask_pred.shape[2:]) + mask_both_binary = (nn.functional.avg_pool2d(mask_both_binary, 3, stride=1, padding=1).view(batch_size, num_frames, *mask_pred.shape[2:]) > 0.99).float().detach() # erode by 1 pixel + + ## reconstruction loss + # image_loss_mask = (mask_pred*mask_gt).unsqueeze(2).expand_as(image_gt) + # image_loss = image_loss * image_loss_mask + # num_mask_pixels = image_loss_mask.reshape(batch_size*num_frames, -1).sum(1).clamp(min=1) + # losses["rgb_loss"] = (image_loss.reshape(batch_size*num_frames, -1).sum(1) / num_mask_pixels).mean() + if background_mode in ['background', 'input']: + pass + else: + image_loss = image_loss * mask_both_binary.unsqueeze(2) + losses['rgb_loss'] = image_loss.reshape(batch_size, num_frames, -1).mean(2) + + if self.cfgs.get('perceptual_loss_weight', 0.) 
> 0: + if background_mode in ['background', 'input']: + perc_image_pred = image_pred + perc_image_gt = image_gt + else: + perc_image_pred = image_pred * mask_pred_binary.unsqueeze(2) + 0.5 * (1-mask_pred_binary.unsqueeze(2)) + perc_image_gt = image_gt * mask_pred_binary.unsqueeze(2) + 0.5 * (1-mask_pred_binary.unsqueeze(2)) + losses['perceptual_loss'] = self.perceptual_loss(perc_image_pred.view(-1, *image_pred.shape[2:]) *2-1, perc_image_gt.view(-1, *image_gt.shape[2:]) *2-1).view(batch_size, num_frames) + + ## flow loss - between first and second frame + if flow_pred is not None: + flow_loss = (flow_pred - flow_gt).abs() + flow_loss_mask = mask_both_binary[:,:-1].unsqueeze(2).expand_as(flow_gt).detach() + + ## ignore frames where GT flow is too large (likely inaccurate) + large_flow = (flow_gt.abs() > 0.5).float() * flow_loss_mask + large_flow = (large_flow.view(batch_size, num_frames-1, -1).sum(2) > 0).float() + self.large_flow = large_flow + + flow_loss = flow_loss * flow_loss_mask * (1 - large_flow[:,:,None,None,None]) + num_mask_pixels = flow_loss_mask.reshape(batch_size, num_frames-1, -1).sum(2).clamp(min=1) + losses['flow_loss'] = (flow_loss.reshape(batch_size, num_frames-1, -1).sum(2) / num_mask_pixels) + # losses["flow_loss"] = flow_loss.mean() + + if dino_feat_im_pred is not None and dino_feat_im_gt is not None: + dino_feat_loss = (dino_feat_im_pred - dino_feat_im_gt) ** 2 + dino_feat_loss = dino_feat_loss * mask_both_binary.unsqueeze(2) + losses['dino_feat_im_loss'] = dino_feat_loss.reshape(batch_size, num_frames, -1).mean(2) + + if reduce: + for k, v in losses.item(): + losses[k] = v.mean() + return losses + + def compute_pose_xflip_reg_loss(self, input_image, dino_feat_im, pose_raw, input_image_xflip_flag=None): + image_xflip = input_image.flip(4) + if dino_feat_im is not None: + dino_feat_im_xflip = dino_feat_im.flip(4) + else: + dino_feat_im_xflip = None + + if self.world_size > 1: + netInst = self.netInstance_ddp + else: + netInst = self.netInstance + + # feat_xflip, _ = self.netInstance_ddp.forward_encoder(image_xflip, dino_feat_im_xflip) + feat_xflip, _ = netInst.forward_encoder(image_xflip, dino_feat_im_xflip) + batch_size, num_frames = input_image.shape[:2] + # pose_xflip_raw = self.netInstance_ddp.forward_pose(image_xflip, feat_xflip, dino_feat_im_xflip) + pose_xflip_raw = netInst.forward_pose(image_xflip, feat_xflip, dino_feat_im_xflip) + + if input_image_xflip_flag is not None: + pose_xflip_raw_xflip = pose_xflip_raw * torch.FloatTensor([-1,1,1,-1,1,1]).to(pose_raw.device) # forward x, trans x + pose_xflip_raw = pose_xflip_raw * (1 - input_image_xflip_flag.view(batch_size * num_frames, 1)) + pose_xflip_raw_xflip * input_image_xflip_flag.view(batch_size * num_frames, 1) + + # rot_rep = self.netInstance_ddp.rot_rep + rot_rep = netInst.rot_rep + if rot_rep == 'euler_angle' or rot_rep == 'soft_calss': + pose_xflip_xflip = pose_xflip * torch.FloatTensor([1,-1,-1,-1,1,1]).to(pose_xflip.device) # rot y+z, trans x + pose_xflip_reg_loss = ((pose_xflip_xflip - pose) ** 2.).mean() + elif rot_rep == 'quaternion': + rot_euler = pytorch3d.transforms.matrix_to_euler_angles(pytorch3d.transforms.quaternion_to_matrix(pose[...,:4]), convention='XYZ') + pose_euler = torch.cat([rot_euler, pose[...,4:]], -1) + rot_xflip_euler = pytorch3d.transforms.matrix_to_euler_angles(pytorch3d.transforms.quaternion_to_matrix(pose_xflip[...,:4]), convention='XYZ') + pose_xflip_euler = torch.cat([rot_xflip_euler, pose_xflip[...,4:]], -1) + pose_xflip_euler_xflip = pose_xflip_euler * 
torch.FloatTensor([1,-1,-1,-1,1,1]).to(pose_xflip.device) # rot y+z, trans x + pose_xflip_reg_loss = ((pose_xflip_euler_xflip - pose_euler) ** 2.).mean() + elif rot_rep == 'lookat': + pose_xflip_raw_xflip = pose_xflip_raw * torch.FloatTensor([-1,1,1,-1,1,1]).to(pose_raw.device) # forward x, trans x + pose_xflip_reg_loss = ((pose_xflip_raw_xflip - pose_raw)[...,0] ** 2.) # compute x only + # if epoch >= self.nolookat_zflip_loss_epochs and self.lookat_zflip_no_other_losses: + # pose_xflip_reg_loss = pose_xflip_reg_loss.mean(1) * is_pose_1_better + pose_xflip_reg_loss = pose_xflip_reg_loss.mean() + return pose_xflip_reg_loss, pose_xflip_raw + + def compute_edge_length_reg_loss(self, mesh, prior_mesh): + prior_edge_lengths = get_edge_length(prior_mesh.v_pos, prior_mesh.t_pos_idx) + max_length = prior_edge_lengths.max().detach() *1.1 + edge_lengths = get_edge_length(mesh.v_pos, mesh.t_pos_idx) + mesh_edge_length_loss = ((edge_lengths - max_length).clamp(min=0)**2).mean() + return mesh_edge_length_loss, edge_lengths + + def compute_regularizers(self, mesh, prior_mesh, input_image, dino_feat_im, pose_raw, input_image_xflip_flag=None, arti_params=None, deformation=None, mid_img_idx=0, posed_bones=None, class_vector=None): + losses = {} + aux = {} + + if self.enable_prior: + losses.update(self.netPrior.netShape.get_sdf_reg_loss(class_vector=class_vector)) + + if self.cfgs.get('pose_xflip_reg_loss_weight', 0.) > 0: + losses["pose_xflip_reg_loss"], aux['pose_xflip_raw'] = self.compute_pose_xflip_reg_loss(input_image, dino_feat_im, pose_raw, input_image_xflip_flag) + + if self.using_campos_smooth_loss: + # from IPython import embed; embed() + pose_raw_ = pose_raw.view(self.bs, self.nf, *pose_raw.shape[1:]) + losses['campos_smooth_loss'] = self.campos_smooth_loss_fn(pose_raw_) + + b, f = input_image.shape[:2] + if b >= 2: + vec_forward = pose_raw[..., :3] + losses['pose_entropy_loss'] = (vec_forward[:b//2] * vec_forward[b//2:(b//2)*2]).sum(-1).mean() + else: + losses['pose_entropy_loss'] = 0. + + losses['mesh_normal_consistency_loss'] = normal_consistency(mesh.v_pos, mesh.t_pos_idx) + losses['mesh_laplacian_consistency_loss'] = laplace_regularizer_const(mesh.v_pos, mesh.t_pos_idx) + losses['mesh_edge_length_loss'], aux['edge_lengths'] = self.compute_edge_length_reg_loss(mesh, prior_mesh) + if arti_params is not None: + #losses['arti_reg_loss'] = (arti_params ** 2).mean() + losses['arti_reg_loss'] = (arti_params ** 2).mean() #TODO dor Rart + + if arti_params is not None and self.using_arti_smooth_loss: + arti_smooth_loss = self.arti_smooth_loss_fn(arti_params) + losses['arti_smooth_loss'] = arti_smooth_loss + # if arti_params is not None and self.cfgs.get('arti_smooth_loss_weight', 0.) 
> 0: + # if self.smooth_type == 'loss' and mid_img_idx > 0: + # # print("+++++++++++++++++add smooth to *articulation* loss") + # # from IPython import embed; embed() + # arti_smooth_loss = ( + # ((arti_params[:,mid_img_idx,:,:] - arti_params[:,0:mid_img_idx,:,:])**2) + # + ((arti_params[:,mid_img_idx,:,:] - arti_params[:,mid_img_idx+1:2*mid_img_idx+1,:,:])**2) + # ).mean() + # losses['arti_smooth_loss'] = arti_smooth_loss + + if arti_params is not None and self.using_artivel_smooth_loss: + # from IPython import embed; embed() + _, nf, _, _= arti_params.shape + arti_vel = arti_params[:,1:nf,:,:] - arti_params[:,:(nf-1),:,:] + artivel_smooth_loss = self.artivel_smooth_loss_fn(arti_vel) + losses['artivel_smooth_loss'] = artivel_smooth_loss + + if deformation is not None: + #losses['deformation_reg_loss'] = (deformation ** 2).mean() + losses['deformation_reg_loss'] = (deformation ** 2).mean() #TODO dor - Rdef + + d1 = deformation[:, mesh.t_pos_idx[0, :, 0], :] + d2 = deformation[:, mesh.t_pos_idx[0, :, 1], :] + d3 = deformation[:, mesh.t_pos_idx[0, :, 2], :] + + num_samples = 5000 + sample_idx1 = torch.randperm(d1.shape[1])[:num_samples].to(self.device) + sample_idx2 = torch.randperm(d1.shape[1])[:num_samples].to(self.device) + sample_idx3 = torch.randperm(d1.shape[1])[:num_samples].to(self.device) + + dist1 = ((d1[:, sample_idx1, :] - d2[:, sample_idx1, :]) ** 2).mean() + dist2 = ((d2[:, sample_idx2, :] - d3[:, sample_idx2, :]) ** 2).mean() + dist3 = ((d3[:, sample_idx3, :] - d1[:, sample_idx3, :]) ** 2).mean() + + losses['smooth_deformation_loss'] = dist1 + dist2 + dist3 + + if deformation is not None and self.using_deform_smooth_loss: + deformation_ = deformation.view(self.bs, self.nf, *deformation.shape[1:]) + losses['deform_smooth_loss'] = self.deform_smooth_loss_fn(deformation_) + # if deformation is not None and self.cfgs.get('deformation_smooth_loss_weight', 0.) > 0: + # if self.smooth_type == 'loss' and mid_img_idx > 0: + # # print("+++++++++++++++++add smooth to *deformation* loss") + # deformation = deformation.view(self.bs, self.nf, *deformation.shape[1:]) + # deformation_smooth_loss = ( + # ((deformation[:, mid_img_idx,:,:] - deformation[:, 0:mid_img_idx,:,:]) ** 2) + # + ((deformation[:, mid_img_idx,:,:] - deformation[:, mid_img_idx+1:2*mid_img_idx+1,:,:]) ** 2) + # ).mean() + # losses['deformation_smooth_loss'] = deformation_smooth_loss + # # deformation = deformation.view(self.bs * self.nf, *deformation.shape[2:]) + # # losses['deformation_reg_loss'] = deformation.abs().mean() + + ## posed bones. + if posed_bones is not None and self.using_bone_smooth_loss: + bone_smooth_loss = self.bone_smooth_loss_fn(posed_bones) + losses['bone_smooth_loss'] = bone_smooth_loss + + if posed_bones is not None and self.using_bonevel_smooth_loss: + _, nf, _, _, _= posed_bones.shape + bone_vel = posed_bones[:,1:nf,...] - posed_bones[:,:(nf-1),...] 
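+ # bone_vel holds frame-to-frame differences of the posed bones; the same kind of smoothness penalty used
+ # for the articulation velocities above is applied to it below, discouraging abrupt changes in bone motion
+ # between consecutive frames.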
+ bonevel_smooth_loss = self.bonevel_smooth_loss_fn(bone_vel) + losses['bonevel_smooth_loss'] = bonevel_smooth_loss + + return losses, aux + + def score_distillation_sampling(self, shape, texture, resolution, im_features, light, prior_shape, random_light=False, prompts=None, classes_vectors=None, im_features_map=None, w2c_pred=None): + num_instances = im_features.shape[0] + n_total_random_cameras = num_instances * self.diffusion_num_random_cameras + + poses, dirs = rand_poses( + n_total_random_cameras, self.device, radius_range=self.diffusion_radius_range, uniform_sphere_rate=self.diffusion_uniform_sphere_rate, + cam_z_offset=self.cam_pos_z_offset, theta_range=self.diffusion_theta_range, phi_offset=self.diffusion_phi_offset, return_dirs=True, + angle_front=self.diffusion_angle_front, angle_overhead=self.diffusion_angle_overhead, + ) + mvp, w2c, campos = self.netInstance.get_camera_extrinsics_from_pose(poses, crop_fov_approx=self.crop_fov_approx) + + if random_light: + lights = rand_lights(campos, fixed_ambient=self.diffusion_light_ambient, fixed_diffuse=self.diffusion_light_diffuse) + else: + lights = light + + proj = util.perspective(self.crop_fov_approx / 180 * np.pi, 1, n=0.1, f=1000.0).repeat(num_instances, 1, 1).to(self.device) + original_mvp = torch.bmm(proj, w2c_pred) + + im_features = im_features.repeat(self.diffusion_num_random_cameras, 1) if im_features is not None else None + num_shapes = shape.v_pos.shape[0] + assert n_total_random_cameras % num_shapes == 0 + shape = shape.extend(n_total_random_cameras // num_shapes) + + bg_color = torch.rand((n_total_random_cameras, 3), device=self.device) # channel-wise random + background = repeat(bg_color, 'b c -> b h w c', h=resolution[0], w=resolution[1]) + + # only train the texture + safe_detach = lambda x: x.detach() if x is not None else None + shape = safe_detach(shape) + im_features = safe_detach(im_features) + im_features_map = safe_detach(im_features_map) + + set_requires_grad(texture, True) + set_requires_grad(light, True) + + image_pred, mask_pred, _, _, albedo, shading = self.render( + shape, + texture, + mvp, + w2c, + campos, + resolution, + im_features=im_features, + light=lights, + prior_shape=prior_shape, + dino_pred=None, + spp=self.renderer_spp, + bg_image=background, + im_features_map={"original_mvp": original_mvp, "im_features_map": im_features_map} if im_features_map is not None else None + ) + if self.enable_vsd: + if prompts is None: + prompts = n_total_random_cameras * [self.diffusion_prompt] + else: + if '_' in prompts: + prompts = prompts.replace('_', ' ') + prompts = n_total_random_cameras * [prompts] + + prompts = ['a high-resolution DSLR image of ' + x for x in prompts] + assert self.diffusion_append_prompt_directions + # TODO: check if this implementation is aligned with stable-diffusion-prompt-processor + prompts_vd = append_text_direction(prompts, dirs) + negative_prompts = n_total_random_cameras * [self.diffusion_negative_prompt] + + text_embeddings = self.stable_diffusion.get_text_embeds(prompts, negative_prompts) # [BB, 77, 768] + text_embeddings_vd = self.stable_diffusion.get_text_embeds(prompts_vd, negative_prompts) + + camera_condition_type = 'c2w' + if camera_condition_type == 'c2w': + camera_condition = torch.linalg.inv(w2c).detach() + elif camera_condition_type == 'mvp': + camera_condition = mvp.detach() + else: + raise NotImplementedError + + # Alternate among albedo, shading, and image + rand = torch.rand(n_total_random_cameras, device=self.device) + rendered_component = 
torch.zeros_like(image_pred) + mask_pred = mask_pred[:, None] + background = rearrange(background, 'b h w c -> b c h w') + albedo_flag = rand > (1 - self.diffusion_albedo_ratio) + rendered_component[albedo_flag] = albedo[albedo_flag] * mask_pred[albedo_flag] + (1 - mask_pred[albedo_flag]) * background[albedo_flag] + shading_flag = (rand > (1 - self.diffusion_albedo_ratio - self.diffusion_shading_ratio)) & (rand <= (1 - self.diffusion_albedo_ratio)) + rendered_component[shading_flag] = shading.repeat(1, 3, 1, 1)[shading_flag] / 2 * mask_pred[shading_flag] + (1 - mask_pred[shading_flag]) * background[shading_flag] + rendered_component[~(albedo_flag | shading_flag)] = image_pred[~(albedo_flag | shading_flag)] + + condition_label = classes_vectors + # condition_label = im_features + + sd_loss, sd_aux = self.stable_diffusion.train_step( + text_embeddings, + text_embeddings_vd, + rendered_component, + camera_condition, # TODO: can we input category condition in lora? + condition_label, + guidance_scale=self.diffusion_guidance_scale, + guidance_scale_lora=self.diffusion_guidance_scale_lora, + loss_weight=self.diffusion_loss_weight, + max_step_pct=self.diffusion_max_step, + return_aux=True + ) + + aux = {'loss': sd_loss['loss_vsd'], 'loss_lora': sd_loss['loss_lora'], 'dirs': dirs, 'sd_aux': sd_aux, 'rendered_shape': shape} + + else: + # Prompt to text embeds + if prompts is None: + prompts = n_total_random_cameras * [self.diffusion_prompt] + else: + if '_' in prompts: + prompts = prompts.replace('_', ' ') + prompts = n_total_random_cameras * [prompts] + prompts = ['a high-resolution DSLR image of ' + x for x in prompts] + if self.diffusion_append_prompt_directions: + prompts = append_text_direction(prompts, dirs) + negative_prompts = n_total_random_cameras * [self.diffusion_negative_prompt] + text_embeddings = self.stable_diffusion.get_text_embeds(prompts, negative_prompts) # [2, 77, 768] + + # Alternate among albedo, shading, and image + rand = torch.rand(n_total_random_cameras, device=self.device) + rendered_component = torch.zeros_like(image_pred) + mask_pred = mask_pred[:, None] + background = rearrange(background, 'b h w c -> b c h w') + albedo_flag = rand > (1 - self.diffusion_albedo_ratio) + rendered_component[albedo_flag] = albedo[albedo_flag] * mask_pred[albedo_flag] + (1 - mask_pred[albedo_flag]) * background[albedo_flag] + shading_flag = (rand > (1 - self.diffusion_albedo_ratio - self.diffusion_shading_ratio)) & (rand <= (1 - self.diffusion_albedo_ratio)) + rendered_component[shading_flag] = shading.repeat(1, 3, 1, 1)[shading_flag] / 2 * mask_pred[shading_flag] + (1 - mask_pred[shading_flag]) * background[shading_flag] + rendered_component[~(albedo_flag | shading_flag)] = image_pred[~(albedo_flag | shading_flag)] + sd_loss, sd_aux = self.stable_diffusion.train_step( + text_embeddings, rendered_component, guidance_scale=self.diffusion_guidance_scale, loss_weight=self.diffusion_loss_weight, max_step_pct=self.diffusion_max_step, return_aux=True) + aux = {'loss':sd_loss, 'dirs': dirs, 'sd_aux': sd_aux, 'rendered_shape': shape} + + return rendered_component, aux + + def parse_dict_definition(self, dict_config, total_iter): + ''' + The dict_config is a diction-based configuration with ascending order + The key: value is the NUM_ITERATION_WEIGHT_BEGIN: WEIGHT + For example, + {0: 0.1, 1000: 0.2, 10000: 0.3} + means at beginning, the weight is 0.1, from 1k iterations, weight is 0.2, and after 10k, weight is 0.3 + ''' + length = len(dict_config) + all_iters = list(dict_config.keys()) + 
all_weights = list(dict_config.values()) + + weight = all_weights[-1] + + for i in range(length-1): + # this works for dict having at least two items, otherwise you don't need dict to set config + iter_num = all_iters[i] + iter_num_next = all_iters[i+1] + if iter_num <= total_iter and total_iter < iter_num_next: + weight = all_weights[i] + break + + return weight + + def compute_clip_loss(self, random_image_pred, image_pred, category): + # image preprocess for CLIP + random_image = torch.nn.functional.interpolate(random_image_pred, (self.clip_reso, self.clip_reso), mode='bilinear') + image_pred = torch.nn.functional.interpolate(image_pred.squeeze(1), (self.clip_reso, self.clip_reso), mode='bilinear') + random_image = tvf.normalize(random_image, self.clip_mean, self.clip_std) + image_pred = tvf.normalize(image_pred, self.clip_mean, self.clip_std) + + feat_img_1 = self.clip_model.encode_image(random_image) + feat_img_2 = self.clip_model.encode_image(image_pred) + + clip_all_loss = torch.nn.functional.cosine_similarity(feat_img_1, feat_img_2) + clip_all_loss = 1 - clip_all_loss.mean() + + # feat_img_1 = torch.mean(feat_img_1, dim=0) + # feat_img_2 = torch.mean(feat_img_2, dim=0) + # clip_all_loss = torch.nn.functional.cosine_similarity(feat_img_1, feat_img_2, dim=0) + # clip_all_loss = 1 - clip_all_loss + + if self.enable_clip_text: + text_feature = self.clip_text_feature[category].repeat(feat_img_1.shape[0], 1) + + text_loss_1 = torch.nn.functional.cosine_similarity(feat_img_1, text_feature).mean() + text_loss_2 = torch.nn.functional.cosine_similarity(feat_img_2, text_feature).mean() + + # text_feature = self.clip_text_feature[category][0] + + # text_loss_1 = torch.nn.functional.cosine_similarity(feat_img_1, text_feature, dim=0) + # text_loss_2 = torch.nn.functional.cosine_similarity(feat_img_2, text_feature, dim=0) + + clip_all_loss = clip_all_loss + (1 - text_loss_1) + (1 - text_loss_2) + + return {'clip_all_loss': clip_all_loss} + + def generate_patch_crop(self, images, masks, patch_size=128, patch_num_per_mask=1): + b, _, H, W = masks.shape + + patches = [] + for i in range(masks.shape[0]): + mask = masks[i] + # mask: [1, H, W] + nonzero_indices = torch.nonzero(mask > 0, as_tuple=False) # [K', 3] + valid_mask = (nonzero_indices[:, 1] > patch_size // 2) & (nonzero_indices[:, 1] < (H - 1 - patch_size // 2)) & (nonzero_indices[:, 2] > patch_size // 2) & (nonzero_indices[:, 2] < (W - 1 - patch_size // 2)) + valid_idx = nonzero_indices[valid_mask] + patch_idx = valid_idx[torch.randperm(valid_idx.shape[0])[:patch_num_per_mask]] # [K, 3] + + if patch_idx.shape[0] < patch_num_per_mask: + patches_this_img = torch.zeros(patch_num_per_mask, 3, self.few_shot_gan_tex_patch, self.few_shot_gan_tex_patch).to(self.device) + else: + patches_this_img = [] + + for idx in range(patch_idx.shape[0]): + _, y, x = patch_idx[idx] + + y_start = max(0, y - patch_size // 2) + y_end = min(H, y_start + patch_size) + x_start = max(0, x - patch_size // 2) + x_end = min(W, x_start + patch_size) + + patch_content = images[i, :, y_start:y_end, x_start:x_end] + + patch = F.interpolate(patch_content.unsqueeze(0), size=self.few_shot_gan_tex_patch, mode='bilinear') # [1, 3, ps, ps] + patches_this_img.append(patch) + + patches_this_img = torch.cat(patches_this_img, dim=0) # [K, 3, ps, ps] + + patches.append(patches_this_img) + + patches = torch.concat(patches, dim=0) # [B*K, 3, ps, ps] + return patches + + + def compute_gan_tex_loss(self, category, image_gt, mask_gt, iv_image_pred, iv_mask_pred, w2c_pred, campos_pred, shape, 
prior_shape, texture, dino_pred, im_features, light, class_vector, num_frames, im_features_map, bins=360): + ''' + This part is used to do gan training on texture, this is meant to only be used in fine-tuning, with local texture network + Ideally this loss only contributes to the Texture + ''' + delta_angle = 2 * np.pi / bins + b = len(shape) + rand_degree = torch.randint(120, [b]) + rand_degree = rand_degree + 120 + # rand_degree = torch.ones(b) * 180 # we want to see the reversed side + delta_angle = delta_angle * rand_degree + delta_rot_matrix = [] + for i in range(b): + angle = delta_angle[i].item() + angle_matrix = torch.FloatTensor([ + [np.cos(angle), 0, np.sin(angle), 0], + [0, 1, 0, 0], + [-np.sin(angle), 0, np.cos(angle), 0], + [0, 0, 0, 1], + ]).to(self.device) + delta_rot_matrix.append(angle_matrix) + delta_rot_matrix = torch.stack(delta_rot_matrix, dim=0) + + proj = util.perspective(self.crop_fov_approx / 180 * np.pi, 1, n=0.1, f=1000.0).repeat(b, 1, 1).to(self.device) + + original_mvp = torch.bmm(proj, w2c_pred) + # original_campos = -w2c_pred[:, :3, 3] + original_campos = campos_pred + mvp = torch.matmul(original_mvp, delta_rot_matrix) + campos = torch.matmul(delta_rot_matrix[:,:3,:3].transpose(2,1), original_campos[:,:,None])[:,:,0] + w2c = w2c_pred + + resolution = (self.few_shot_gan_tex_reso, self.few_shot_gan_tex_reso) + + # only train the texture + safe_detach = lambda x: x.detach() if x is not None else None + mesh = safe_detach(shape) + im_features = safe_detach(im_features) + im_features_map = safe_detach(im_features_map) + class_vector = safe_detach(class_vector) + + set_requires_grad(texture, True) + set_requires_grad(dino_pred, False) + set_requires_grad(light, False) + + background_for_reverse = 'none' + # background_for_reverse = 'random-pure' + + image_pred, mask_pred, _, _, _, _ = self.render( + mesh, + texture, + mvp, + w2c, + campos, + resolution, + background=background_for_reverse, + im_features=im_features, + light=light, + prior_shape=prior_shape, + render_flow=False, + dino_pred=dino_pred, + spp=self.renderer_spp, + class_vector=class_vector, + render_mode='diffuse', + two_sided_shading=False, + num_frames=num_frames, + im_features_map={"original_mvp": original_mvp, "im_features_map": im_features_map} if im_features_map is not None else None # in other views we need to pass the original mvp + ) + + mask_pred = mask_pred.unsqueeze(1) + if self.few_shot_gan_tex_reso != self.out_image_size: + image_pred = torch.nn.functional.interpolate(image_pred, (self.out_image_size, self.out_image_size), mode='bilinear') + mask_pred = torch.nn.functional.interpolate(mask_pred, (self.out_image_size, self.out_image_size), mode='bilinear') + + # image_pred = image_pred.clamp(0, 1) + # mask_pred = mask_pred.clamp(0, 1) # [B, 1, H, W] + + if background_for_reverse == 'random': + # as we set a random background for rendering, we also need another random background for input view + # for background, we use the same as random view: a small resolution then upsample + random_bg = torch.rand(self.bs, self.nf, 3, self.few_shot_gan_tex_reso, self.few_shot_gan_tex_reso).to(self.device) + random_bg = torch.nn.functional.interpolate(random_bg.squeeze(1), (self.out_image_size, self.out_image_size), mode='bilinear').unsqueeze(1) + iv_mask_pred = iv_mask_pred.unsqueeze(2).repeat(1, 1, 3, 1, 1) + iv_image_pred = iv_image_pred * iv_mask_pred + random_bg * (1. 
- iv_mask_pred) + iv_image_pred = iv_image_pred.squeeze(1) + + random_bg_gt = torch.rand(self.bs, self.nf, 3, self.few_shot_gan_tex_reso, self.few_shot_gan_tex_reso).to(self.device) + random_bg_gt = torch.nn.functional.interpolate(random_bg_gt.squeeze(1), (self.out_image_size, self.out_image_size), mode='bilinear').unsqueeze(1) + mask_gt = mask_gt.unsqueeze(2).repeat(1, 1, 3, 1, 1) + image_gt = image_gt * mask_gt + random_bg_gt * (1. - mask_gt) + image_gt = image_gt.squeeze(1) + + elif background_for_reverse == 'random-pure': + # the background is random but with one color + random_values = torch.rand(b) + random_bg = random_values[..., None, None, None, None].repeat(1, 1, 3, self.few_shot_gan_tex_reso, self.few_shot_gan_tex_reso).to(self.device) + random_bg = torch.nn.functional.interpolate(random_bg.squeeze(1), (self.out_image_size, self.out_image_size), mode='bilinear').unsqueeze(1) + iv_mask_pred = iv_mask_pred.unsqueeze(2).repeat(1, 1, 3, 1, 1) + iv_image_pred = iv_image_pred * iv_mask_pred + random_bg * (1. - iv_mask_pred) + iv_image_pred = iv_image_pred.squeeze(1) + + random_values_gt = torch.rand(b) + random_bg_gt = random_values_gt[..., None, None, None, None].repeat(1, 1, 3, self.few_shot_gan_tex_reso, self.few_shot_gan_tex_reso).to(self.device) + random_bg_gt = torch.nn.functional.interpolate(random_bg_gt.squeeze(1), (self.out_image_size, self.out_image_size), mode='bilinear').unsqueeze(1) + mask_gt = mask_gt.unsqueeze(2).repeat(1, 1, 3, 1, 1) + image_gt = image_gt * mask_gt + random_bg_gt * (1. - mask_gt) + image_gt = image_gt.squeeze(1) + + elif background_for_reverse == 'none': + iv_image_pred = iv_image_pred.squeeze(1) + iv_mask_pred = iv_mask_pred.unsqueeze(2).repeat(1, 1, 3, 1, 1) + # image_gt = image_gt * mask_gt + random_bg_gt * (1. 
- mask_gt) + mask_gt = mask_gt.unsqueeze(2).repeat(1, 1, 3, 1, 1) + image_gt = image_gt * mask_gt + image_gt = image_gt.squeeze(1) + + else: + raise NotImplementedError + + # image_gt = torch.nn.functional.interpolate(image_gt, (32, 32), mode='bilinear') + # image_gt = torch.nn.functional.interpolate(image_gt, (256, 256), mode='bilinear') + + # we need to let discriminator think this reverse view is Real sample + if self.cfgs.get('few_shot_gan_tex_patch', 0) > 0: + patch_size = torch.randint(self.few_shot_gan_tex_patch, self.few_shot_gan_tex_patch_max, (1,)).item() + # random view + image_pred = self.generate_patch_crop(image_pred, mask_pred, patch_size, self.few_shot_gan_tex_patch_num) + # input view + iv_image_pred = self.generate_patch_crop(iv_image_pred, iv_mask_pred.squeeze(1)[:, 0:1, :, :], patch_size, self.few_shot_gan_tex_patch_num) + # gt view + image_gt = self.generate_patch_crop(image_gt, mask_gt.squeeze(1)[:, 0:1, :, :], patch_size, self.few_shot_gan_tex_patch_num) + + return_loss = {} + if self.few_shot_gan_tex: + # here we compute the fake sample as real loss + gan_tex_loss = 0.0 + if 'rv' in self.few_shot_gan_tex_fake: + d_rv = self.discriminator_texture(image_pred) + gan_tex_loss_rv = discriminator_architecture.bce_loss_target(d_rv, 1) + gan_tex_loss += gan_tex_loss_rv + + if 'iv' in self.few_shot_gan_tex_fake: + d_iv = self.discriminator_texture(iv_image_pred) + gan_tex_loss_iv = discriminator_architecture.bce_loss_target(d_iv, 1) + gan_tex_loss += gan_tex_loss_iv + + return_loss['gan_tex_loss'] = gan_tex_loss + + if self.few_shot_clip_tex: + clip_tex_loss_rv_iv = self.compute_clip_loss(image_pred, iv_image_pred.unsqueeze(1), category='none') + clip_tex_loss_rv_gt = self.compute_clip_loss(image_pred, image_gt.unsqueeze(1), category='none') + clip_tex_loss = clip_tex_loss_rv_iv['clip_all_loss'] + clip_tex_loss_rv_gt['clip_all_loss'] + return_loss['clip_tex_loss'] = clip_tex_loss + + return_aux = { + 'gan_tex_render_image': image_pred.clone().clamp(0, 1), + 'gan_tex_inpview_image': iv_image_pred.clone().clamp(0, 1), + 'gan_tex_gt_image': image_gt.clone().clamp(0, 1) + } + + with torch.no_grad(): + # self.record_image_iv = iv_image_pred.clone().clamp(0, 1) + # self.record_image_rv = image_pred.clone().clamp(0, 1) + # self.record_image_gt = image_gt.clone().clamp(0, 1) + self.record_image_iv = iv_image_pred.clone() + self.record_image_rv = image_pred.clone() + self.record_image_gt = image_gt.clone() + + return return_loss, return_aux + + def compute_mask_distribution_loss(self, category, w2c_pred, shape, prior_shape, texture, dino_pred, im_features, light, class_vector, num_frames, im_features_map, bins=360): + delta_angle = 2 * np.pi / bins + b = len(shape) + + if self.random_mask_law == 'batch_swap': + # shuffle in predicted poses + rand_degree_1 = torch.randperm(int(w2c_pred.shape[0] // 2)) + rand_degree_2 = torch.randperm(w2c_pred.shape[0] - int(w2c_pred.shape[0] // 2)) + int(w2c_pred.shape[0] // 2) + rand_degree = torch.cat([rand_degree_2, rand_degree_1], dim=0).long().to(w2c_pred.device) + w2c = w2c_pred[rand_degree] + + proj = util.perspective(self.crop_fov_approx / 180 * np.pi, 1, n=0.1, f=1000.0).repeat(b, 1, 1).to(self.device) + mvp = torch.bmm(proj, w2c) + campos = -w2c[:, :3, 3] + + elif self.random_mask_law == 'batch_swap_noy': + # shuffle in predicted poses + rand_degree_1 = torch.randperm(int(w2c_pred.shape[0] // 2)) + rand_degree_2 = torch.randperm(w2c_pred.shape[0] - int(w2c_pred.shape[0] // 2)) + int(w2c_pred.shape[0] // 2) + rand_degree = 
torch.cat([rand_degree_2, rand_degree_1], dim=0).long().to(w2c_pred.device) + w2c = w2c_pred[rand_degree] + # we don't random swap the y-translation in discriminator loss + w2c[:, 1, 3] = w2c_pred[:, 1, 3] + + proj = util.perspective(self.crop_fov_approx / 180 * np.pi, 1, n=0.1, f=1000.0).repeat(b, 1, 1).to(self.device) + mvp = torch.bmm(proj, w2c) + campos = -w2c[:, :3, 3] + + elif self.random_mask_law == 'random_azimuth': + # the render rotation matrix is different + rand_degree = torch.randint(bins, [b]) + delta_angle = delta_angle * rand_degree + delta_rot_matrix = [] + for i in range(b): + angle = delta_angle[i].item() + angle_matrix = torch.FloatTensor([ + [np.cos(angle), 0, np.sin(angle), 0], + [0, 1, 0, 0], + [-np.sin(angle), 0, np.cos(angle), 0], + [0, 0, 0, 1], + ]).to(self.device) + delta_rot_matrix.append(angle_matrix) + delta_rot_matrix = torch.stack(delta_rot_matrix, dim=0) + + w2c = torch.FloatTensor(np.diag([1., 1., 1., 1])) + w2c[:3, 3] = torch.FloatTensor([0, 0, -self.cam_pos_z_offset *1.4]) + w2c = w2c.repeat(b, 1, 1).to(self.device) + # use the predicted transition + w2c_pred = w2c_pred.detach() + w2c[:, :3, 3] = w2c_pred[:b][:, :3, 3] + + proj = util.perspective(self.crop_fov_approx / 180 * np.pi, 1, n=0.1, f=1000.0).repeat(b, 1, 1).to(self.device) + mvp = torch.bmm(proj, w2c) + campos = -w2c[:, :3, 3] + + mvp = torch.matmul(mvp, delta_rot_matrix) + campos = torch.matmul(delta_rot_matrix[:,:3,:3].transpose(2,1), campos[:,:,None])[:,:,0] + + elif self.random_mask_law == 'random_all': + # the render rotation matrix is different, and actually the translation are just pre-set + rand_degree = torch.randint(bins, [b]) + delta_angle = delta_angle * rand_degree + delta_rot_matrix = [] + for i in range(b): + angle = delta_angle[i].item() + angle_matrix = torch.FloatTensor([ + [np.cos(angle), 0, np.sin(angle), 0], + [0, 1, 0, 0], + [-np.sin(angle), 0, np.cos(angle), 0], + [0, 0, 0, 1], + ]).to(self.device) + delta_rot_matrix.append(angle_matrix) + delta_rot_matrix = torch.stack(delta_rot_matrix, dim=0) + + w2c = torch.FloatTensor(np.diag([1., 1., 1., 1])) + w2c[:3, 3] = torch.FloatTensor([0, 0, -self.cam_pos_z_offset *1.4]) + w2c = w2c.repeat(b, 1, 1).to(self.device) + + proj = util.perspective(self.crop_fov_approx / 180 * np.pi, 1, n=0.1, f=1000.0).repeat(b, 1, 1).to(self.device) + mvp = torch.bmm(proj, w2c) + campos = -w2c[:, :3, 3] + + mvp = torch.matmul(mvp, delta_rot_matrix) + campos = torch.matmul(delta_rot_matrix[:,:3,:3].transpose(2,1), campos[:,:,None])[:,:,0] + + else: + raise NotImplementedError + + resolution = (self.out_image_size, self.out_image_size) + # render the articulated shape + mesh = shape + if self.enable_clip: + resolution = (self.clip_render_size, self.clip_render_size) + set_requires_grad(texture, False) + image_pred, mask_pred, _, _, _, _ = self.render( + mesh, + texture, + mvp, + w2c, + campos, + resolution, + background='none', + im_features=im_features, + light=light, + prior_shape=prior_shape, + render_flow=False, + dino_pred=dino_pred, + spp=self.renderer_spp, + class_vector=class_vector, + render_mode='diffuse', + two_sided_shading=False, + num_frames=num_frames, + im_features_map=im_features_map + ) + + if resolution[0] != self.out_image_size: + image_pred = torch.nn.functional.interpolate(image_pred, (self.out_image_size, self.out_image_size), mode='bilinear') + mask_pred = torch.nn.functional.interpolate(mask_pred.unsqueeze(1), (self.out_image_size, self.out_image_size), mode='bilinear').squeeze(1) + else: + _, mask_pred, _, _, _, _ = 
self.render( + mesh, + None, + mvp, + w2c, + campos, + resolution, + background='none', + im_features=None, + light=None, + prior_shape=prior_shape, + render_flow=False, + dino_pred=None, + class_vector=class_vector, + render_mode='diffuse', + two_sided_shading=False, + num_frames=num_frames, + im_features_map=None + ) + image_pred = None + + # TODO: disable mask distribution and isolate mask discriminator loss + # mask_distribution = self.class_mask_distribution[category] + # mask_distribution = torch.Tensor(mask_distribution).to(self.device).unsqueeze(0).repeat(b, 1, 1) + mask_distribution = torch.Tensor(self.class_mask_distribution["zebra"]).to(self.device).unsqueeze(0).repeat(b, 1, 1) + + if self.mask_distribution_average: + # if use mask_distribution_average, then first average across batch then compute the loss + mask_pred = mask_pred.mean(dim=0).unsqueeze(0).repeat(b, 1, 1) + + mask_pred = mask_pred.clamp(0,1) + mask_distribution = mask_distribution.clamp(0,1) + distribution_loss = torch.nn.functional.binary_cross_entropy(mask_pred, mask_distribution) + + out_loss = {'mask_distribution_loss': 0 * distribution_loss} + out_aux = { + 'mask_random_pred': mask_pred.unsqueeze(1), + 'mask_distribution': mask_distribution.unsqueeze(1), + 'rand_degree': rand_degree + } + + if self.enable_clip: + out_aux.update({'random_render_image': image_pred}) + + return out_loss, out_aux + + def use_line_correct_valid_mask(self, mask_valid, p1, p2, mvp, mask_gt): + line = torch.cat([p1.unsqueeze(-2), p2.unsqueeze(-2)], dim=-2) # [B, 2, 3] + line_world4 = torch.cat([line, torch.ones_like(line[..., :1])], -1) + line_clip4 = line_world4 @ mvp.transpose(-1, -2) + line_uv = line_clip4[..., :2] / line_clip4[..., 3:4] + line_uv = line_uv.detach() + b, _, n_uv = line_uv.shape + line_uv = line_uv * torch.Tensor([mask_valid.shape[-2] // 2, mask_valid.shape[-1] // 2]).to(line_uv.device).unsqueeze(0).unsqueeze(-1).repeat(b, 1, n_uv) + line_uv = line_uv + torch.Tensor([mask_valid.shape[-2] // 2, mask_valid.shape[-1] // 2]).to(line_uv.device).unsqueeze(0).unsqueeze(-1).repeat(b, 1, n_uv) + from pdb import set_trace; set_trace() + line_slope = (line_uv[:, 0, 1] - line_uv[:, 1, 1]) / (line_uv[:, 0, 0] - line_uv[:, 1, 0]) + + uv = np.mgrid[0:mask_valid.shape[-2], 0:mask_valid.shape[-1]].astype(np.int32) + uv = torch.from_numpy(np.flip(uv, axis=0).copy()).float().unsqueeze(0).repeat(b, 1, 1, 1) # [B, 2, 256, 256] + tmp_u = uv[:, 0, ...][mask_gt[:, 0, ...].bool()] + tmp_v = uv[:, 1, ...][mask_gt[:, 0, ...].bool()] + return mask_valid + + def discriminator_step(self): + mask_gt = self.record_mask_gt + mask_pred = self.record_mask_iv + mask_random_pred = self.record_mask_rv + + self.optimizerDiscriminator.zero_grad() + + # the random view mask are False + d_random_pred = self.mask_disc(mask_random_pred) + disc_loss = discriminator_architecture.bce_loss_target(d_random_pred, 0) # in gen loss, train it to be real + + grad_loss = 0.0 + count = 1 + + discriminator_loss_rv = disc_loss.detach() + discriminator_loss_gt = 0.0 + discriminator_loss_iv = 0. 
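+ # The two optional branches below add "real" samples to the discriminator update:
+ # ground-truth masks (disc_gt) and/or input-view predicted masks (disc_iv, labelled
+ # real or fake according to disc_iv_label). When gradients are available, an
+ # R1-style penalty (disc_reg_mul * compute_grad2) on those inputs regularizes the
+ # discriminator.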
+ d_gt = None + d_iv = None + + if self.disc_gt: + mask_gt.requires_grad_() + d_gt = self.mask_disc(mask_gt) + if d_gt.requires_grad is False: + # in the test case + disc_gt_loss = discriminator_architecture.bce_loss_target(d_gt, 1) + else: + grad_penalty = self.disc_reg_mul * discriminator_architecture.compute_grad2(d_gt, mask_gt) + disc_gt_loss = discriminator_architecture.bce_loss_target(d_gt, 1) + grad_penalty + grad_loss += grad_penalty + disc_loss = disc_loss + disc_gt_loss + discriminator_loss_gt = disc_gt_loss + count = count + 1 + + if self.disc_iv: + mask_pred.requires_grad_() + d_iv = self.mask_disc(mask_pred) + if self.disc_iv_label == 'Real': + if d_iv.requires_grad is False: + # in the test case + disc_iv_loss = discriminator_architecture.bce_loss_target(d_iv, 1) + else: + grad_penalty = self.disc_reg_mul * discriminator_architecture.compute_grad2(d_iv, mask_pred) + disc_iv_loss = discriminator_architecture.bce_loss_target(d_iv, 1) + grad_penalty + grad_loss += grad_penalty + + else: + disc_iv_loss = discriminator_architecture.bce_loss_target(d_iv, 0) + disc_loss = disc_loss + disc_iv_loss + count = count + 1 + discriminator_loss_iv = disc_iv_loss + + disc_loss = disc_loss / count + grad_loss = grad_loss / count + + self.discriminator_loss = disc_loss * self.discriminator_loss_weight + self.discriminator_loss.backward() + self.optimizerDiscriminator.step() + self.discriminator_loss = 0. + return { + 'discriminator_loss': disc_loss, + 'discriminator_loss_rv': discriminator_loss_rv, + 'discriminator_loss_iv': discriminator_loss_iv, + 'discriminator_loss_gt': discriminator_loss_gt, + 'd_rv': d_random_pred, + 'd_iv': d_iv if d_iv is not None else None, + 'd_gt': d_gt if d_gt is not None else None, + }, grad_loss + + def compute_mask_disc_loss_gen(self, mask_gt, mask_pred, mask_random_pred, category_name=None, condition_feat=None): + # mask_gt[mask_gt < 1.] = 0. + # mask_pred[mask_pred > 0.] = 1. + # mask_random_pred[mask_random_pred > 0.] = 1. 
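+ # Generator-side mask discriminator loss. The masks are conditioned on the category
+ # by concatenating either a class one-hot map or the detached class embedding,
+ # broadcast to the mask resolution, along the channel dimension. The random-view
+ # mask (and, if disc_iv_label is not 'Real', also the input-view mask) is then pushed
+ # towards the real label, encouraging silhouettes that look plausible from novel
+ # viewpoints. (bce_loss_target is assumed here to be a binary cross-entropy against a
+ # constant real/fake target; its exact form lives in discriminator_architecture.)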
+ + if not self.mask_disc_feat_condition: + try: + class_idx = list(self.netPrior.category_id_map.keys()).index(category_name) + except: + class_idx = 100 + num_classes = len(list(self.netPrior.category_id_map.keys())) + class_idx = torch.LongTensor([class_idx]) + # class_one_hot = torch.nn.functional.one_hot(class_idx, num_classes=7).unsqueeze(-1).unsqueeze(-1).to(mask_gt.device) # [1, 7, 1, 1] + class_one_hot = torch.nn.functional.one_hot(class_idx, num_classes=num_classes).unsqueeze(-1).unsqueeze(-1).to(mask_gt.device) + class_one_hot = class_one_hot.repeat(mask_gt.shape[0], 1, mask_gt.shape[-2], mask_gt.shape[-1]) + # TODO: a hack try here + class_one_hot = class_one_hot[:, :(self.mask_disc.in_dim-1), :, :] + else: + class_one_hot = condition_feat.detach() + class_one_hot = class_one_hot.reshape(1, -1, 1, 1).repeat(mask_gt.shape[0], 1, mask_gt.shape[-2], mask_gt.shape[-1]) + + # concat + mask_gt = torch.cat([mask_gt, class_one_hot], dim=1) + mask_pred = torch.cat([mask_pred, class_one_hot], dim=1) + mask_random_pred = torch.cat([mask_random_pred, class_one_hot], dim=1) + + # mask shape are all [B,1,256,256] + # the random view mask are False + d_random_pred = self.mask_disc(mask_random_pred) + disc_loss = discriminator_architecture.bce_loss_target(d_random_pred, 1) # in gen loss, train it to be real + count = 1 + + disc_loss_rv = disc_loss.detach() + disc_loss_iv = 0.0 + + if self.disc_iv: + if self.disc_iv_label != 'Real': # consider the input view also fake + d_iv = self.mask_disc(mask_pred) + disc_iv_loss = discriminator_architecture.bce_loss_target(d_iv, 1) # so now we need to train them to be real + disc_loss = disc_loss + disc_iv_loss + count = count + 1 + disc_loss_iv = disc_iv_loss.detach() + + disc_loss = disc_loss / count + + # record the masks for discriminator training + self.record_mask_gt = mask_gt.clone().detach() + self.record_mask_iv = mask_pred.clone().detach() + self.record_mask_rv = mask_random_pred.clone().detach() + + return { + 'mask_disc_loss': disc_loss, + 'mask_disc_loss_rv': disc_loss_rv, + 'mask_disc_loss_iv': disc_loss_iv, + } + + def forward(self, batch, epoch, iter, is_train=True, viz_logger=None, total_iter=None, save_results=False, save_dir=None, which_data='', logger_prefix='', is_training=True, bank_embedding=None): + batch = [x.to(self.device) if x is not None and isinstance(x, torch.Tensor) else x for x in batch] + input_image, mask_gt, mask_dt, mask_valid, flow_gt, bbox, bg_image, dino_feat_im, dino_cluster_im, seq_idx, frame_idx, category_name = batch + + # if save_results: + # save_for_pkl = { + # "image": input_image.cpu(), + # "mask_gt": mask_gt.cpu(), + # "mask_dt": mask_dt.cpu(), + # "mask_valid": mask_valid.cpu(), + # "flow_gt": None, + # "bbox": bbox.cpu(), + # "bg_image": bg_image.cpu(), + # "dino_feat_im": dino_feat_im.cpu(), + # "dino_cluster_im": dino_cluster_im.cpu(), + # "seq_idx": seq_idx.cpu(), + # "frame_idx": frame_idx.cpu(), + # "category_name": category_name + # } + + batch_size, num_frames, _, h0, w0 = input_image.shape # BxFxCxHxW + self.bs = batch_size + self.nf = num_frames + mid_img_idx = int((input_image.shape[1]-1)//2) + # print(f"mid_img_idx: {mid_img_idx}") + + h = w = self.out_image_size + + def collapseF(x): + return None if x is None else x.view(batch_size * num_frames, *x.shape[2:]) + def expandF(x): + return None if x is None else x.view(batch_size, num_frames, *x.shape[1:]) + + if flow_gt.dim() == 2: # dummy tensor for not loading flow + flow_gt = None + + if dino_cluster_im.dim() == 2: # dummy tensor for not 
loading dino clusters + dino_cluster_im = None + dino_cluster_im_gt = None + else: + dino_cluster_im_gt = expandF(torch.nn.functional.interpolate(collapseF(dino_cluster_im), size=[h, w], mode="nearest")) + + seq_idx = seq_idx.squeeze(1) + # seq_idx = seq_idx * 0 # single sequnce model + frame_id, crop_x0, crop_y0, crop_w, crop_h, full_w, full_h, sharpness, label = bbox.unbind(2) # BxFx7 + bbox = torch.stack([crop_x0, crop_y0, crop_w, crop_h], 2) + mask_gt = (mask_gt[:, :, 0, :, :] > 0.9).float() # BxFxHxW + mask_dt = mask_dt / self.in_image_size + + if which_data != 'video': + flow_gt = None + + aux_viz = {} + + ## GT + image_gt = input_image + if self.out_image_size != self.in_image_size: + image_gt = expandF(torch.nn.functional.interpolate(collapseF(image_gt), size=[h, w], mode='bilinear')) + if flow_gt is not None: + flow_gt = torch.nn.functional.interpolate(flow_gt.view(batch_size*(num_frames-1), 2, h0, w0), size=[h, w], mode="bilinear").view(batch_size, num_frames-1, 2, h, w) + + self.train_pose_only = False + if epoch in self.pose_epochs: + if (total_iter // self.pose_iters) % 2 == 0: + self.train_pose_only = True + + ## flip input and pose + if epoch in self.pose_xflip_recon_epochs: + input_image_xflip = input_image.flip(-1) + input_image_xflip_flag = torch.randint(0, 2, (batch_size, num_frames), device=input_image.device) + input_image = input_image * (1 - input_image_xflip_flag[:,:,None,None,None]) + input_image_xflip * input_image_xflip_flag[:,:,None,None,None] + else: + input_image_xflip_flag = None + + ## 1st pose hypothesis with original predictions + + # ============================================================================================== + # Predict prior mesh. + # ============================================================================================== + if self.enable_prior: + if self.world_size > 1: + if epoch < self.dmtet_grid_smaller_epoch: + if self.netPrior_ddp.module.netShape.grid_res != self.dmtet_grid_smaller: + self.netPrior_ddp.module.netShape.load_tets(self.dmtet_grid_smaller) + else: + if self.netPrior_ddp.module.netShape.grid_res != self.dmtet_grid: + self.netPrior_ddp.module.netShape.load_tets(self.dmtet_grid) + + else: + if epoch < self.dmtet_grid_smaller_epoch: + if self.netPrior.netShape.grid_res != self.dmtet_grid_smaller: + self.netPrior.netShape.load_tets(self.dmtet_grid_smaller) + else: + if self.netPrior.netShape.grid_res != self.dmtet_grid: + self.netPrior.netShape.load_tets(self.dmtet_grid) + + perturb_sdf = self.perturb_sdf if is_train else False + # DINO prior category specific - DOR + if self.world_size > 1: + prior_shape, dino_pred, classes_vectors = self.netPrior_ddp(category_name=category_name[0], perturb_sdf=perturb_sdf, total_iter=total_iter, is_training=is_training, class_embedding=bank_embedding) + else: + prior_shape, dino_pred, classes_vectors = self.netPrior(category_name=category_name[0], perturb_sdf=perturb_sdf, total_iter=total_iter, is_training=is_training, class_embedding=bank_embedding) + else: + prior_shape = None + raise NotImplementedError + + if self.world_size > 1: + shape, pose_raw, pose, mvp, w2c, campos, texture, im_features, dino_feat_im_calc, deformation, arti_params, light, forward_aux = self.netInstance_ddp(category_name, input_image, prior_shape, epoch, dino_feat_im, dino_cluster_im, total_iter, is_training=is_training) # frame dim collapsed N=(B*F) + else: + Instance_out = self.netInstance(category_name, input_image, prior_shape, epoch, dino_feat_im, dino_cluster_im, total_iter, is_training=is_training) 
# frame dim collapsed N=(B*F) + + # if no patch_out as output from netInstance, then set im_features_map as None in following part + if len(Instance_out) == 13: + shape, pose_raw, pose, mvp, w2c, campos, texture, im_features, dino_feat_im_calc, deformation, arti_params, light, forward_aux = Instance_out + im_features_map = None + else: + shape, pose_raw, pose, mvp, w2c, campos, texture, im_features, dino_feat_im_calc, deformation, arti_params, light, forward_aux, im_features_map = Instance_out + + # if save_results: + # save_for_pkl.update( + # { + # "pose_raw": pose_raw.cpu(), + # "pose": pose.cpu(), + # "mvp": mvp.cpu(), + # "w2c": w2c.cpu(), + # "campos": campos.cpu(), + # "campos_z_offset": self.netInstance.cam_pos_z_offset + # } + # ) + + if self.calc_dino_features == True: + + # get the shape parameters of the tensor + batch_size, height, width, channels = dino_feat_im_calc.shape #3 X 384 X 32 X 32 + + + # reshape the tensor to have 2 dimensions, with the last dimension being preserved + dino_feat_im = dino_feat_im_calc.reshape(batch_size , height, -1) + + # normalize the tensor using L2 normalization + norm = torch.norm(dino_feat_im, dim=-1, keepdim=True) + + dino_feat_im = dino_feat_im / norm + + # reshape the tensor back to the original shape with an additional singleton dimension along the first dimension + dino_feat_im = dino_feat_im.reshape(batch_size, height, width, channels) + dino_feat_im = dino_feat_im.unsqueeze(1) + + + if dino_feat_im.dim() == 2: # dummy tensor for not loading dino features + dino_feat_im = None + dino_feat_im_gt = None + else: + dino_feat_im_gt = expandF(torch.nn.functional.interpolate(collapseF(dino_feat_im), size=[h, w], mode="bilinear"))[:, :, :self.dino_feature_recon_dim] + + rot_logit = forward_aux['rot_logit'] + rot_idx = forward_aux['rot_idx'] + rot_prob = forward_aux['rot_prob'] + + if self.using_bonevel_smooth_loss: + posed_bones = forward_aux['posed_bones'] + else: + posed_bones = None + + aux_viz.update(forward_aux) + + if self.train_pose_only: + safe_detach = lambda x: x.detach() if x is not None else None + prior_shape = safe_detach(prior_shape) + shape = safe_detach(shape) + im_features = safe_detach(im_features) + arti_params = safe_detach(arti_params) + deformation = safe_detach(deformation) + set_requires_grad(texture, False) + set_requires_grad(light, False) + set_requires_grad(dino_pred, False) + else: + set_requires_grad(texture, True) + set_requires_grad(light, True) + set_requires_grad(dino_pred, True) + + render_flow = self.render_flow and num_frames > 1 #false + # from IPython import embed; embed() + + # if num_frames > 1 and self.smooth_type == 'rend': + # print("rendererr smoothness !!!!") + # image_pred, mask_pred, flow_pred, dino_feat_im_pred, albedo, shading = self.render(shape, texture, mvp, w2c, campos, (h, w), background=self.background_mode, im_features=im_features[torch.randperm(im_features.size(0))], light=light, prior_shape=prior_shape, render_flow=render_flow, dino_pred=dino_pred, num_frames=num_frames, spp=self.renderer_spp) #the real rendering process + # else: + # print("regular render") + #print("a cecond before rendering .... 
need to get the correct label and the correct vector") + #print("label", label) + #print("classes_vectors", classes_vectors) + #print("im_features", im_features.shape) + + class_vector = None + if classes_vectors is not None: + if len(classes_vectors.shape) == 1: + class_vector = classes_vectors + else: + class_vector = classes_vectors[self.netPrior.category_id_map[category_name[0]], :] + + image_pred, mask_pred, flow_pred, dino_feat_im_pred, albedo, shading = self.render(shape, texture, mvp, w2c, campos, (h, w), background=self.background_mode, im_features=im_features, light=light, prior_shape=prior_shape, render_flow=render_flow, dino_pred=dino_pred, class_vector=class_vector[None, :].expand(batch_size * num_frames, -1), num_frames=num_frames, spp=self.renderer_spp, im_features_map=im_features_map) #the real rendering process + image_pred, mask_pred, flow_pred, dino_feat_im_pred = map(expandF, (image_pred, mask_pred, flow_pred, dino_feat_im_pred)) + + if flow_pred is not None: + flow_pred = flow_pred[:, :-1] # Bx(F-1)x2xHxW + + if self.blur_mask: + sigma = max(0.5, 3 * (1 - total_iter / self.blur_mask_iter)) + if sigma > 0.5: + mask_gt = util.blur_image(mask_gt, kernel_size=9, sigma=sigma, mode='gaussian') + # mask_pred = util.blur_image(mask_pred, kernel_size=7, mode='average') + + # back_line_p1 = forward_aux['posed_bones'][:, :, 3, -1].squeeze(1) # [8, 3] + # back_line_p2 = forward_aux['posed_bones'][:, :, 7, -1].squeeze(1) + # mask_valid = self.use_line_correct_valid_mask(mask_valid, back_line_p1, back_line_p2, mvp, mask_gt) + + losses = self.compute_reconstruction_losses(image_pred, image_gt, mask_pred, mask_gt, mask_dt, mask_valid, flow_pred, flow_gt, dino_feat_im_gt, dino_feat_im_pred, background_mode=self.background_mode, reduce=False) + + ## TODO: assume flow loss is not used + logit_loss_target = torch.zeros_like(expandF(rot_logit)) + final_losses = {} + for name, loss in losses.items(): + if name == 'flow_loss': + continue + loss_weight_logit = self.cfgs.get(f"{name}_weight", 0.) + + if isinstance(loss_weight_logit, dict): + loss_weight_logit = self.parse_dict_definition(loss_weight_logit, total_iter) + + # from IPython import embed; embed() + # print("-"*10) + # print(f"{name}_weight: {loss_weight_logit}.") + # print(f"logit_loss_target.shape: {logit_loss_target.shape}.") + # print(f"loss.shape: {loss.shape}.") + # if (name in ['flow_loss'] and epoch not in self.flow_loss_epochs) or (name in ['rgb_loss', 'perceptual_loss'] and epoch not in self.texture_epochs): + # if name in ['flow_loss', 'rgb_loss', 'perceptual_loss']: + # loss_weight_logit = 0. + if name in ['sdf_bce_reg_loss', 'sdf_gradient_reg_loss', 'sdf_inflate_reg_loss']: + if total_iter >= self.sdf_reg_decay_start_iter: + decay_rate = max(0, 1 - (total_iter-self.sdf_reg_decay_start_iter) / 10000) + loss_weight_logit = max(loss_weight_logit * decay_rate, self.cfgs.get(f"{name}_min_weight", 0.)) + if name in ['dino_feat_im_loss']: + dino_feat_im_loss_multipler = self.cfgs.get("logit_loss_dino_feat_im_loss_multiplier", 1.) + + if isinstance(dino_feat_im_loss_multipler, dict): + dino_feat_im_loss_multipler = self.parse_dict_definition(dino_feat_im_loss_multipler, total_iter) + + loss_weight_logit = loss_weight_logit * dino_feat_im_loss_multipler + # loss_weight_logit = loss_weight_logit * self.cfgs.get("logit_loss_dino_feat_im_loss_multiplier", 1.) 
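+ # Accumulate the weighted, un-reduced reconstruction losses into logit_loss_target:
+ # for multi-hypothesis poses (quadlookat/octlookat) the rotation logit is later
+ # regressed towards this target, so each hypothesis learns to predict the
+ # reconstruction loss it would incur.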
+ if loss_weight_logit > 0: + logit_loss_target += loss * loss_weight_logit + + if self.netInstance.rot_rep in ['quadlookat', 'octlookat']: + loss = loss * rot_prob.detach().view(batch_size, num_frames)[:, :loss.shape[1]] *self.netInstance.num_pose_hypos + if name == 'flow_loss' and num_frames > 1: + ri = rot_idx.view(batch_size, num_frames) + same_rot_idx = (ri[:, 1:] == ri[:, :-1]).float() + loss = loss * same_rot_idx + final_losses[name] = loss.mean() + final_losses['logit_loss'] = ((expandF(rot_logit) - logit_loss_target.detach())**2.).mean() + + ## score distillation sampling + sds_random_images = None + if self.enable_sds: + prompts = None + if classes_vectors is not None: + prompts = category_name[0] + sds_random_images, sds_aux = self.score_distillation_sampling(shape, texture, [self.diffusion_resolution, self.diffusion_resolution], im_features, light, prior_shape, prompts=prompts, classes_vectors=class_vector[None, :].expand(batch_size * num_frames, -1), im_features_map=im_features_map, w2c_pred=w2c) + if self.enable_vsd: + final_losses.update({'vsd_loss': sds_aux['loss']}) + final_losses.update({'vsd_lora_loss': sds_aux['loss_lora']}) + else: + final_losses.update({'sds_loss': sds_aux['loss']}) + + ## mask distribution loss + mask_distribution_aux = None + if self.enable_mask_distribution: + if total_iter % self.mask_distribution_loss_freq == 0: + mask_distribution_loss, mask_distribution_aux = self.compute_mask_distribution_loss(category_name[0], w2c, shape, prior_shape, texture, dino_pred, im_features, light, class_vector[None, :].expand(batch_size * num_frames, -1), num_frames, im_features_map) + final_losses.update(mask_distribution_loss) + # this also follows the iteration frequency + if self.enable_clip: + random_render_image = mask_distribution_aux["random_render_image"] + clip_all_loss = self.compute_clip_loss(random_render_image, image_pred, category_name[0]) # a dict + final_losses.update(clip_all_loss) + + # implement the mask discriminator + if self.enable_disc and (self.mask_discriminator_iter[0] < total_iter) and (self.mask_discriminator_iter[1] > total_iter): + disc_loss = self.compute_mask_disc_loss_gen(mask_gt, mask_pred, mask_distribution_aux['mask_random_pred'], category_name=category_name[0], condition_feat=class_vector) + final_losses.update(disc_loss) + + # implement the gan training for local texture in fine-tuning + gan_tex_aux = None + if (self.few_shot_gan_tex and viz_logger is None) or (self.few_shot_gan_tex and viz_logger is not None and logger_prefix == 'train_'): + gan_tex_loss, gan_tex_aux = self.compute_gan_tex_loss(category_name[0], image_gt, mask_gt, image_pred, mask_pred, w2c, campos, shape, prior_shape, texture, dino_pred, im_features, light, class_vector[None, :].expand(batch_size * num_frames, -1), num_frames, im_features_map) + final_losses.update(gan_tex_loss) + + # implement the memory bank related loss + if bank_embedding is not None: + batch_embedding = bank_embedding[0] # [d] + embeddings = bank_embedding[1] # [B, d] + bank_mean_dist = torch.nn.functional.mse_loss(embeddings, batch_embedding.unsqueeze(0).repeat(batch_size, 1)) + final_losses.update({'bank_mean_dist_loss': bank_mean_dist}) + + + ## regularizers + regularizers, aux = self.compute_regularizers(shape, prior_shape, input_image, dino_feat_im, pose_raw, input_image_xflip_flag, arti_params, deformation, mid_img_idx, posed_bones=posed_bones, class_vector=class_vector.detach() if class_vector is not None else None) + final_losses.update(regularizers) + aux_viz.update(aux) + + 
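+ # Aggregate the final objective: every entry in final_losses is scaled by its
+ # "<name>_weight" from the config (optionally an iteration-keyed schedule resolved by
+ # parse_dict_definition), with additional gating by epoch ranges and decay schedules
+ # for the mesh and SDF regularizers.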
total_loss = 0 + for name, loss in final_losses.items(): + loss_weight = self.cfgs.get(f"{name}_weight", 0.) + + if isinstance(loss_weight, dict): + loss_weight = self.parse_dict_definition(loss_weight, total_iter) + + if loss_weight <= 0: + continue + + if self.train_pose_only: + if name not in ['silhouette_loss', 'silhouette_dt_loss', 'silhouette_inv_dt_loss', 'flow_loss', 'pose_xflip_reg_loss', 'lookat_zflip_loss', 'dino_feat_im_loss']: + continue + if epoch not in self.flow_loss_epochs: + if name in ['flow_loss']: + continue + if epoch not in self.texture_epochs: + if name in ['rgb_loss', 'perceptual_loss']: + continue + if epoch not in self.lookat_zflip_loss_epochs: + if name in ['lookat_zflip_loss']: + continue + if name in ['mesh_laplacian_smoothing_loss', 'mesh_normal_consistency_loss']: + if total_iter < self.cfgs.get('mesh_reg_start_iter', 0): + continue + if epoch >= self.mesh_reg_decay_epoch: + decay_rate = self.mesh_reg_decay_rate ** (epoch - self.mesh_reg_decay_epoch) + loss_weight = max(loss_weight * decay_rate, self.cfgs.get(f"{name}_min_weight", 0.)) + if epoch not in self.sdf_inflate_reg_loss_epochs: + if name in ['sdf_inflate_reg_loss']: + continue + if self.iter_arti_reg_loss_start is not None: + if total_iter <= self.iter_arti_reg_loss_start: + if name in ['arti_reg_loss']: + continue + else: + if epoch not in self.arti_reg_loss_epochs: + if name in ['arti_reg_loss']: + continue + if name in ['sdf_bce_reg_loss', 'sdf_gradient_reg_loss', 'sdf_inflate_reg_loss']: + if total_iter >= self.sdf_reg_decay_start_iter: + decay_rate = max(0, 1 - (total_iter-self.sdf_reg_decay_start_iter) / 10000) + loss_weight = max(loss_weight * decay_rate, self.cfgs.get(f"{name}_min_weight", 0.)) + + total_loss += loss * loss_weight + + self.total_loss += total_loss # reset to 0 in backward step + + if torch.isnan(self.total_loss): + print("NaN in loss...") + import ipdb; ipdb.set_trace() + + final_losses['logit_loss_target'] = logit_loss_target.mean() + + metrics = {'loss': total_loss, **final_losses} + ## log visuals + if viz_logger is not None: + b0 = max(min(batch_size, 16//num_frames), 1) + viz_logger.add_image(logger_prefix+'image/image_gt', misc.image_grid(image_gt.detach().cpu()[:b0,:].reshape(-1,*input_image.shape[2:]).clamp(0,1)), total_iter) + viz_logger.add_image(logger_prefix+'image/image_pred', misc.image_grid(image_pred.detach().cpu()[:b0,:].reshape(-1,*image_pred.shape[2:]).clamp(0,1)), total_iter) + # viz_logger.add_image(logger_prefix+'image/flow_loss_mask', misc.image_grid(flow_loss_mask[:b0,:,:1].reshape(-1,1,*flow_loss_mask.shape[3:]).repeat(1,3,1,1).clamp(0,1)), total_iter) + viz_logger.add_image(logger_prefix+'image/mask_gt', misc.image_grid(mask_gt.detach().cpu()[:b0,:].reshape(-1,*mask_gt.shape[2:]).unsqueeze(1).repeat(1,3,1,1).clamp(0,1)), total_iter) + viz_logger.add_image(logger_prefix+'image/mask_pred', misc.image_grid(mask_pred.detach().cpu()[:b0,:].reshape(-1,*mask_pred.shape[2:]).unsqueeze(1).repeat(1,3,1,1).clamp(0,1)), total_iter) + + if self.render_flow and flow_gt is not None: + # if False: + flow_gt = flow_gt.detach().cpu() + flow_gt_viz = torch.cat([flow_gt[:b0], torch.zeros_like(flow_gt[:b0,:,:1])], 2) + 0.5 # -0.5~1.5 + flow_gt_viz = torch.nn.functional.pad(flow_gt_viz, pad=[0, 0, 0, 0, 0, 0, 0, 1]) + + # ## draw marker on large flow frames + # large_flow_marker_mask = torch.zeros_like(flow_gt_viz) + # large_flow_marker_mask[:,:,:,:8,:8] = 1. 
+ # large_flow = torch.cat([self.large_flow, self.large_flow[:,:1] *0.], 1).detach().cpu()[:b0] + # large_flow_marker_mask = large_flow_marker_mask * large_flow[:,:,None,None,None] + # red = torch.FloatTensor([1,0,0])[None,None,:,None,None] + # flow_gt_viz = large_flow_marker_mask * red + (1-large_flow_marker_mask) * flow_gt_viz + + viz_logger.add_image(logger_prefix+'image/flow_gt', misc.image_grid(flow_gt_viz.reshape(-1,*flow_gt_viz.shape[2:])), total_iter) + + if self.render_flow and flow_pred is not None: + # if False + flow_pred = flow_pred.detach().cpu() + flow_pred_viz = torch.cat([flow_pred[:b0], torch.zeros_like(flow_pred[:b0,:,:1])], 2) + 0.5 # -0.5~1.5 + flow_pred_viz = torch.nn.functional.pad(flow_pred_viz, pad=[0, 0, 0, 0, 0, 0, 0, 1]) + viz_logger.add_image(logger_prefix+'image/flow_pred', misc.image_grid(flow_pred_viz.reshape(-1,*flow_pred_viz.shape[2:])), total_iter) + + if sds_random_images is not None: + viz_logger.add_image( + logger_prefix + 'image/sds_image', + self.vis_sds_image(sds_random_images, sds_aux), + total_iter) + viz_logger.add_image( + logger_prefix + 'image/sds_grad', + self.vis_sds_grads(sds_aux), total_iter) + + if mask_distribution_aux is not None: + degree_text = mask_distribution_aux['rand_degree'] + mask_random_pred = mask_distribution_aux['mask_random_pred'].detach().cpu().clamp(0, 1) + mask_distribution_data = mask_distribution_aux['mask_distribution'].detach().cpu().clamp(0, 1) + + mask_random_pred_image = [misc.add_text_to_image(img, str(text.item())) for img, text in zip(mask_random_pred, degree_text)] + mask_random_pred_image = misc.image_grid(mask_random_pred_image) + mask_distribution_image = misc.image_grid(mask_distribution_data) + + viz_logger.add_image( + logger_prefix + 'image/mask_random_pred', + mask_random_pred_image, + total_iter) + viz_logger.add_image( + logger_prefix + 'image/mask_distribution', + mask_distribution_image, + total_iter) + + if gan_tex_aux is not None: + gan_tex_render_image = gan_tex_aux['gan_tex_render_image'].detach().cpu().clamp(0, 1) + gan_tex_render_image = misc.image_grid(gan_tex_render_image) + viz_logger.add_image( + logger_prefix + 'image/gan_tex_render_image', + gan_tex_render_image, + total_iter) + + gan_tex_render_image_iv = gan_tex_aux['gan_tex_inpview_image'].detach().cpu().clamp(0, 1) + gan_tex_render_image_iv = misc.image_grid(gan_tex_render_image_iv) + viz_logger.add_image( + logger_prefix + 'image/gan_tex_inpview_image', + gan_tex_render_image_iv, + total_iter) + + gan_tex_render_image_gt = gan_tex_aux['gan_tex_gt_image'].detach().cpu().clamp(0, 1) + gan_tex_render_image_gt = misc.image_grid(gan_tex_render_image_gt) + viz_logger.add_image( + logger_prefix + 'image/gan_tex_gt_image', + gan_tex_render_image_gt, + total_iter) + + # if self.render_flow and flow_gt is not None and flow_pred is not None: + # flow_gt = flow_gt.detach().cpu() + # # flow_gt_viz = torch.cat([flow_gt[:b0], torch.zeros_like(flow_gt[:b0,:,:1])], 2) + 0.5 # -0.5~1.5 + # # flow_gt_viz = torch.nn.functional.pad(flow_gt_viz, pad=[0, 0, 0, 0, 0, 0, 0, 1]) + + # # ## draw marker on large flow frames + # # large_flow_marker_mask = torch.zeros_like(flow_gt_viz) + # # large_flow_marker_mask[:,:,:,:8,:8] = 1. 
+ # # large_flow = torch.cat([self.large_flow, self.large_flow[:,:1] *0.], 1).detach().cpu()[:b0] + # # large_flow_marker_mask = large_flow_marker_mask * large_flow[:,:,None,None,None] + # # red = torch.FloatTensor([1,0,0])[None,None,:,None,None] + # # flow_gt_viz = large_flow_marker_mask * red + (1-large_flow_marker_mask) * flow_gt_viz + + # # viz_logger.add_image(logger_prefix+'image/flow_gt', misc.image_grid(flow_gt_viz.reshape(-1,*flow_gt_viz.shape[2:])), total_iter) + + # flow_pred = flow_pred.detach().cpu() + # # flow_pred_viz = torch.cat([flow_pred[:b0], torch.zeros_like(flow_pred[:b0,:,:1])], 2) + 0.5 # -0.5~1.5 + # # flow_pred_viz = torch.nn.functional.pad(flow_pred_viz, pad=[0, 0, 0, 0, 0, 0, 0, 1]) + + # flow_gt_pred = torch.cat([flow_gt, flow_pred], dim=-1) + # flow_gt_pred = flow_gt_pred.permute(0,1,3,4,2).detach().cpu().reshape(flow_gt_pred.shape[0]*flow_gt_pred.shape[1],*flow_gt_pred.shape[2:]) + # flow_gt_pred = flow_viz.flow_batch_to_images(flow_gt_pred) + # # flow_gt_pred = torch.tensor(flow_gt_pred).permute(0,3,1,2) + + # # viz_logger.add_image(logger_prefix+'image/flow_gt_pred', misc.image_grid(flow_gt_pred.reshape(-1,*flow_gt_pred.shape[2:])), total_iter) + # viz_logger.add_image(logger_prefix+'image/flow_gt_pred', misc.image_grid(flow_gt_pred), total_iter) + + if light is not None: + param_names = ['dir_x', 'dir_y', 'dir_z', 'int_ambient', 'int_diffuse'] + for name, param in zip(param_names, light.light_params.unbind(-1)): + viz_logger.add_histogram(logger_prefix+'light/'+name, param, total_iter) + viz_logger.add_image( + logger_prefix + f'image/albedo', + misc.image_grid(expandF(albedo)[:b0, ...].view(-1, *albedo.shape[1:])), + total_iter) + viz_logger.add_image( + logger_prefix + f'image/shading', + misc.image_grid(expandF(shading)[:b0, ...].view(-1, *shading.shape[1:]).repeat(1, 3, 1, 1) /2.), + total_iter) + + viz_logger.add_histogram(logger_prefix+'sdf', self.netPrior.netShape.get_sdf(perturb_sdf=False, class_vector=class_vector), total_iter) + viz_logger.add_histogram(logger_prefix+'coordinates', shape.v_pos, total_iter) + if arti_params is not None: + viz_logger.add_histogram(logger_prefix+'arti_params', arti_params, total_iter) + viz_logger.add_histogram(logger_prefix+'edge_lengths', aux_viz['edge_lengths'], total_iter) + + if deformation is not None: + viz_logger.add_histogram(logger_prefix+'deformation', deformation, total_iter) + + rot_rep = self.netInstance.rot_rep + if rot_rep == 'euler_angle' or rot_rep == 'soft_calss': + for i, name in enumerate(['rot_x', 'rot_y', 'rot_z', 'trans_x', 'trans_y', 'trans_z']): + viz_logger.add_histogram(logger_prefix+'pose/'+name, pose[...,i], total_iter) + elif rot_rep == 'quaternion': + for i, name in enumerate(['qt_0', 'qt_1', 'qt_2', 'qt_3', 'trans_x', 'trans_y', 'trans_z']): + viz_logger.add_histogram(logger_prefix+'pose/'+name, pose[...,i], total_iter) + rot_euler = pytorch3d.transforms.matrix_to_euler_angles(pytorch3d.transforms.quaternion_to_matrix(pose.detach().cpu()[...,:4]), convention='XYZ') + for i, name in enumerate(['rot_x', 'rot_y', 'rot_z']): + viz_logger.add_histogram(logger_prefix+'pose/'+name, rot_euler[...,i], total_iter) + elif rot_rep in ['lookat', 'quadlookat', 'octlookat']: + for i, name in enumerate(['fwd_x', 'fwd_y', 'fwd_z']): + viz_logger.add_histogram(logger_prefix+'pose/'+name, pose_raw[...,i], total_iter) + for i, name in enumerate(['trans_x', 'trans_y', 'trans_z']): + viz_logger.add_histogram(logger_prefix+'pose/'+name, pose_raw[...,-3+i], total_iter) + + if rot_rep in ['quadlookat', 
'octlookat']: + for i, rp in enumerate(forward_aux['rots_probs'].unbind(-1)): + viz_logger.add_histogram(logger_prefix+'pose/rot_prob_%d'%i, rp, total_iter) + + if bank_embedding is not None: + weights_for_emb = bank_embedding[2]['weights'] # [B, k] + for i, weight_for_emb in enumerate(weights_for_emb.unbind(-1)): + viz_logger.add_histogram(logger_prefix+'bank_embedding/emb_weight_%d'%i, weight_for_emb, total_iter) + + indices_for_emb = bank_embedding[2]['pick_idx'] # [B, k] + for i, idx_for_emb in enumerate(indices_for_emb.unbind(-1)): + viz_logger.add_histogram(logger_prefix+'bank_embedding/emb_idx_%d'%i, idx_for_emb, total_iter) + + + if 'pose_xflip_raw' in aux_viz: + pose_xflip_raw = aux_viz['pose_xflip_raw'] + if rot_rep == 'euler_angle' or rot_rep == 'soft_calss': + for i, name in enumerate(['rot_x', 'rot_y', 'rot_z', 'trans_x', 'trans_y', 'trans_z']): + viz_logger.add_histogram(logger_prefix+'pose_xflip/'+name, pose_xflip[...,i], total_iter) + elif rot_rep == 'quaternion': + for i, name in enumerate(['qt_0', 'qt_1', 'qt_2', 'qt_3', 'trans_x', 'trans_y', 'trans_z']): + viz_logger.add_histogram(logger_prefix+'pose_xflip/'+name, pose_xflip[...,i], total_iter) + rot_euler = pytorch3d.transforms.matrix_to_euler_angles(pytorch3d.transforms.quaternion_to_matrix(pose_xflip.detach().cpu()[...,:4]), convention='XYZ') + for i, name in enumerate(['rot_x', 'rot_y', 'rot_z']): + viz_logger.add_histogram(logger_prefix+'pose_xflip/'+name, rot_euler[...,i], total_iter) + elif rot_rep in ['lookat', 'quadlookat', 'octlookat']: + for i, name in enumerate(['fwd_x', 'fwd_y', 'fwd_z']): + viz_logger.add_histogram(logger_prefix+'pose_xflip/'+name, pose_xflip_raw[...,i], total_iter) + for i, name in enumerate(['trans_x', 'trans_y', 'trans_z']): + viz_logger.add_histogram(logger_prefix+'pose_xflip/'+name, pose_xflip_raw[...,-3+i], total_iter) + + if dino_feat_im_gt is not None: + dino_feat_im_gt_first3 = dino_feat_im_gt[:,:,:3] + viz_logger.add_image(logger_prefix+'image/dino_feat_im_gt', misc.image_grid(dino_feat_im_gt_first3.detach().cpu()[:b0,:].reshape(-1,*dino_feat_im_gt_first3.shape[2:]).clamp(0,1)), total_iter) + + if dino_cluster_im_gt is not None: + viz_logger.add_image(logger_prefix+'image/dino_cluster_im_gt', misc.image_grid(dino_cluster_im_gt.detach().cpu()[:b0,:].reshape(-1,*dino_cluster_im_gt.shape[2:]).clamp(0,1)), total_iter) + + if dino_feat_im_pred is not None: + dino_feat_im_pred_first3 = dino_feat_im_pred[:,:,:3] + viz_logger.add_image(logger_prefix+'image/dino_feat_im_pred', misc.image_grid(dino_feat_im_pred_first3.detach().cpu()[:b0,:].reshape(-1,*dino_feat_im_pred_first3.shape[2:]).clamp(0,1)), total_iter) + + for which_shape, modes in self.extra_renders.items(): + # This is wrong + # if which_shape == "prior": + # shape_to_render = prior_shape.extend(im_features.shape[0]) + # needed_im_features = None + if which_shape == "instance": + shape_to_render = shape + needed_im_features = im_features + else: + raise NotImplementedError + + for mode in modes: + if mode in ['gray']: + gray_light = FixedDirectionLight(direction=torch.FloatTensor([0, 0, 1]).to(self.device), amb=0.2, diff=0.7) + _, render_mask, _, _, _, rendered = self.render(shape_to_render, texture, mvp, w2c, campos, (h, w), background=self.background_mode, im_features=needed_im_features, prior_shape=prior_shape, render_mode='diffuse', light=gray_light, render_flow=False, dino_pred=None, im_features_map=im_features_map) #renderer for visualization only!!! 
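+ # The 'gray' visualization renders shading only under a fixed frontal light
+ # (FixedDirectionLight along +z); when the configured background is white, the render
+ # mask is used below to paint the background white and the single shading channel is
+ # replicated to RGB.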
+ if self.background_mode == 'white': + # we want to render shading here, which is always black background, so modify here + render_mask = render_mask.unsqueeze(1) + rendered[render_mask == 0] = 1 + rendered = rendered.repeat(1, 3, 1, 1) + else: + rendered, _, _, _, _, _ = self.render(shape_to_render, texture, mvp, w2c, campos, (h, w), background=self.background_mode, im_features=needed_im_features, prior_shape=prior_shape, render_mode=mode, render_flow=False, dino_pred=None, im_features_map=im_features_map) #renderer for visualization only!!! + if 'kd' in mode: + rendered = util.rgb_to_srgb(rendered) + rendered = rendered.detach().cpu() + rendered_wo_bones = rendered + + if 'posed_bones' in aux_viz: + rendered_bone_image = self.render_bones(mvp, aux_viz['posed_bones'], (h, w)) + rendered_bone_image_mask = (rendered_bone_image < 1).any(1, keepdim=True).float() + # viz_logger.add_image(logger_prefix+'image/articulation_bones', misc.image_grid(self.render_bones(mvp, aux_viz['posed_bones'])), total_iter) + rendered = rendered_bone_image_mask*0.8 * rendered_bone_image + (1-rendered_bone_image_mask*0.8) * rendered + + if rot_rep in ['quadlookat', 'octlookat']: + rand_pose_flag = forward_aux['rand_pose_flag'].detach().cpu() + rand_pose_marker_mask = torch.zeros_like(rendered) + rand_pose_marker_mask[:,:,:16,:16] = 1. + rand_pose_marker_mask = rand_pose_marker_mask * rand_pose_flag[:,None,None,None] + red = torch.FloatTensor([1,0,0])[None,:,None,None] + rendered = rand_pose_marker_mask * red + (1-rand_pose_marker_mask) * rendered + + viz_logger.add_image( + logger_prefix + f'image/{which_shape}_{mode}', + misc.image_grid(expandF(rendered)[:b0, ...].view(-1, *rendered.shape[1:])), + total_iter) + + if rendered_wo_bones is not None: + viz_logger.add_image( + logger_prefix + f'image/{which_shape}_{mode}_raw', + misc.image_grid(expandF(rendered_wo_bones)[:b0, ...].view(-1, *rendered_wo_bones.shape[1:])), + total_iter) + + if mode in ['gray']: + viz_logger.add_video( + logger_prefix + f'animation/{which_shape}_{mode}', + self.render_rotation_frames(shape_to_render, texture, gray_light, (h, w), background=self.background_mode, im_features=needed_im_features, prior_shape=prior_shape, num_frames=15, render_mode='diffuse', b=1, im_features_map=im_features_map, original_mvp=mvp, original_w2c=w2c, original_campos=campos, render_gray=True).detach().cpu().unsqueeze(0), + total_iter, + fps=2) + else: + viz_logger.add_video( + logger_prefix + f'animation/{which_shape}_{mode}', + self.render_rotation_frames(shape_to_render, texture, light, (h, w), background=self.background_mode, im_features=needed_im_features, prior_shape=prior_shape, num_frames=15, render_mode=mode, b=1, im_features_map=im_features_map, original_mvp=mvp, original_w2c=w2c, original_campos=campos).detach().cpu().unsqueeze(0), + total_iter, + fps=2) + + viz_logger.add_video( + logger_prefix+'animation/prior_image_rotation', + self.render_rotation_frames(prior_shape, texture, light, (h, w), background=self.background_mode, im_features=im_features, num_frames=15, b=1, text=category_name[0], im_features_map=im_features_map, original_mvp=mvp).detach().cpu().unsqueeze(0).clamp(0,1), + total_iter, + fps=2) + + viz_logger.add_video( + logger_prefix+'animation/prior_normal_rotation', + self.render_rotation_frames(prior_shape, texture, light, (h, w), background=self.background_mode, im_features=im_features, num_frames=15, render_mode='geo_normal', b=1, text=category_name[0], im_features_map=im_features_map, original_mvp=mvp).detach().cpu().unsqueeze(0), + 
total_iter, + fps=2) + + if save_results and self.rank == 0: + b0 = self.cfgs.get('num_saved_from_each_batch', batch_size*num_frames) + # from IPython import embed; embed() + fnames = [f'{total_iter:07d}_{fid:010d}' for fid in collapseF(frame_id.int())][:b0] + + # pkl_str = osp.join(save_dir, f'{total_iter:07d}_animal_data.pkl') + os.makedirs(save_dir, exist_ok=True) + # with open(pkl_str, 'wb') as fpkl: + # pickle.dump(save_for_pkl, fpkl) + # fpkl.close() + + misc.save_images(save_dir, collapseF(image_gt)[:b0].clamp(0,1).detach().cpu().numpy(), suffix='image_gt', fnames=fnames) + misc.save_images(save_dir, collapseF(image_pred)[:b0].clamp(0,1).detach().cpu().numpy(), suffix='image_pred', fnames=fnames) + misc.save_images(save_dir, collapseF(mask_gt)[:b0].unsqueeze(1).repeat(1,3,1,1).clamp(0,1).detach().cpu().numpy(), suffix='mask_gt', fnames=fnames) + misc.save_images(save_dir, collapseF(mask_pred)[:b0].unsqueeze(1).repeat(1,3,1,1).clamp(0,1).detach().cpu().numpy(), suffix='mask_pred', fnames=fnames) + # tmp_shape = shape.first_n(b0).clone() + # tmp_shape.material = texture + # feat = im_features[:b0] if im_features is not None else None + # misc.save_obj(save_dir, tmp_shape, save_material=False, feat=feat, suffix="mesh", fnames=fnames) # Save the first mesh. + if self.render_flow and flow_gt is not None: + flow_gt_viz = torch.cat([flow_gt, torch.zeros_like(flow_gt[:,:,:1])], 2) + 0.5 # -0.5~1.5 + flow_gt_viz = flow_gt_viz.view(-1, *flow_gt_viz.shape[2:]) + misc.save_images(save_dir, flow_gt_viz[:b0].clamp(0,1).detach().cpu().numpy(), suffix='flow_gt', fnames=fnames) + if flow_pred is not None: + flow_pred_viz = torch.cat([flow_pred, torch.zeros_like(flow_pred[:,:,:1])], 2) + 0.5 # -0.5~1.5 + flow_pred_viz = flow_pred_viz.view(-1, *flow_pred_viz.shape[2:]) + misc.save_images(save_dir, flow_pred_viz[:b0].clamp(0,1).detach().cpu().numpy(), suffix='flow_pred', fnames=fnames) + + misc.save_txt(save_dir, pose[:b0].detach().cpu().numpy(), suffix='pose', fnames=fnames) + return metrics + + def save_scores(self, path): + header = 'mask_mse, \ + mask_iou, \ + image_mse, \ + flow_mse' + mean = self.all_scores.mean(0) + std = self.all_scores.std(0) + header = header + '\nMean: ' + ',\t'.join(['%.8f'%x for x in mean]) + header = header + '\nStd: ' + ',\t'.join(['%.8f'%x for x in std]) + misc.save_scores(path, self.all_scores, header=header) + print(header) + + def render_rotation_frames(self, mesh, texture, light, resolution, background='none', im_features=None, prior_shape=None, num_frames=36, render_mode='diffuse', b=None, text=None, im_features_map=None, original_mvp=None, original_w2c=None, original_campos=None, render_gray=False): + frames = [] + if b is None: + b = len(mesh) + else: + mesh = mesh.first_n(b) + feat = im_features[:b] if im_features is not None else None + im_features_map = im_features_map[:b] if im_features_map is not None else None + original_mvp = original_mvp[:b] if original_mvp is not None else None # [b, 4, 4] + + if im_features_map is not None: + im_features_map = {'im_features_map': im_features_map, 'original_mvp':original_mvp} + + delta_angle = np.pi / num_frames * 2 + delta_rot_matrix = torch.FloatTensor([ + [np.cos(delta_angle), 0, np.sin(delta_angle), 0], + [0, 1, 0, 0], + [-np.sin(delta_angle), 0, np.cos(delta_angle), 0], + [0, 0, 0, 1], + ]).to(self.device).repeat(b, 1, 1) + + w2c = torch.FloatTensor(np.diag([1., 1., 1., 1])) + w2c[:3, 3] = torch.FloatTensor([0, 0, -self.cam_pos_z_offset *1.1]) + w2c = w2c.repeat(b, 1, 1).to(self.device) + proj = 
util.perspective(self.crop_fov_approx / 180 * np.pi, 1, n=0.1, f=1000.0).repeat(b, 1, 1).to(self.device) + mvp = torch.bmm(proj, w2c) + campos = -w2c[:, :3, 3] + + if original_w2c is not None and original_campos is not None and original_mvp is not None: + w2c = original_w2c[:b] + campos = original_campos[:b] + mvp = original_mvp[:b] + + def rotate_pose(mvp, campos): + mvp = torch.matmul(mvp, delta_rot_matrix) + campos = torch.matmul(delta_rot_matrix[:,:3,:3].transpose(2,1), campos[:,:,None])[:,:,0] + return mvp, campos + + for _ in range(num_frames): + if render_gray: + _, render_mask, _, _, _, image_pred = self.render(mesh, texture, mvp, w2c, campos, resolution, background=background, im_features=feat, light=light, prior_shape=prior_shape, render_flow=False, dino_pred=None, render_mode=render_mode, two_sided_shading=False, im_features_map=im_features_map) + if self.background_mode == 'white': + # we want to render shading here, which is always black background, so modify here + render_mask = render_mask.unsqueeze(1) + image_pred[render_mask == 0] = 1 + image_pred = image_pred.repeat(1, 3, 1, 1) + else: + image_pred, _, _, _, _, _ = self.render(mesh, texture, mvp, w2c, campos, resolution, background=background, im_features=feat, light=light, prior_shape=prior_shape, render_flow=False, dino_pred=None, render_mode=render_mode, two_sided_shading=False, im_features_map=im_features_map) #for rotation frames only! + image_pred = image_pred.clamp(0, 1) + frames += [misc.image_grid(image_pred)] + mvp, campos = rotate_pose(mvp, campos) + + if text is not None: + frames = [torch.Tensor(misc.add_text_to_image(f, text)).permute(2, 0, 1) for f in frames] + + return torch.stack(frames, dim=0) # Shape: (T, C, H, W) + + def render_bones(self, mvp, bones_pred, size=(256, 256)): + bone_world4 = torch.concat([bones_pred, torch.ones_like(bones_pred[..., :1]).to(bones_pred.device)], dim=-1) + b, f, num_bones = bone_world4.shape[:3] + bones_clip4 = (bone_world4.view(b, f, num_bones*2, 1, 4) @ mvp.transpose(-1, -2).reshape(b, f, 1, 4, 4)).view(b, f, num_bones, 2, 4) + bones_uv = bones_clip4[..., :2] / bones_clip4[..., 3:4] # b, f, num_bones, 2, 2 + dpi = 32 + fx, fy = size[1] // dpi, size[0] // dpi + + rendered = [] + for b_idx in range(b): + for f_idx in range(f): + frame_bones_uv = bones_uv[b_idx, f_idx].cpu().numpy() + fig = plt.figure(figsize=(fx, fy), dpi=dpi, frameon=False) + ax = plt.Axes(fig, [0., 0., 1., 1.]) + ax.set_axis_off() + for bone in frame_bones_uv: + ax.plot(bone[:, 0], bone[:, 1], marker='o', linewidth=8, markersize=20) + ax.set_xlim(-1, 1) + ax.set_ylim(-1, 1) + ax.invert_yaxis() + # Convert to image + fig.add_axes(ax) + fig.canvas.draw_idle() + image = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) + w, h = fig.canvas.get_width_height() + image.resize(h, w, 3) + rendered += [image / 255.] 
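+ # Each bone frame is drawn with Matplotlib: projected endpoints are plotted in
+ # clip-space UV, the canvas is read back via tostring_rgb into an HxWx3 uint8 array,
+ # and normalized to [0, 1] before all frames are stacked into a tensor.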
+ return torch.from_numpy(np.stack(rendered, 0).transpose(0, 3, 1, 2)) + + def render_deformation_frames(self, mesh, texture, batch_size, num_frames, resolution, background='none', im_features=None, render_mode='diffuse', b=None): + # frames = [] + # if b is None: + # b = batch_size + # im_features = im_features[] + # mesh = mesh.first_n(num_frames * b) + # for i in range(b): + # tmp_mesh = mesh.get_m_to_n(i*num_frames:(i+1)*num_frames) + pass + + def vis_sds_image(self, sds_image, sds_aux): + sds_image = sds_image.detach().cpu().clamp(0, 1) + sds_image = [misc.add_text_to_image(img, text) for img, text in zip(sds_image, sds_aux['dirs'])] + return misc.image_grid(sds_image) + + def vis_sds_grads(self, sds_aux): + grads = sds_aux['sd_aux']['grad'] + grads = grads.detach().cpu() + # compute norm + grads_norm = grads.norm(dim=1, keepdim=True) + # interpolate to 4x size + grads_norm = F.interpolate(grads_norm, scale_factor=4, mode='nearest') + # add time step and weight + t = sds_aux['sd_aux']['t'] + w = sds_aux['sd_aux']['w'] + # max norm for each sample over dim (1, 2, 3) + n = grads_norm.view(grads_norm.shape[0], -1).max(dim=1)[0] + texts = [f"t: {t_} w: {w_:.2f} n: {n_:.2e}" for t_, w_ , n_ in zip(t, w, n)] + return misc.image_grid_multi_channel(grads_norm, texts=texts, font_scale=0.5) \ No newline at end of file diff --git a/video3d/networks.py b/video3d/networks.py new file mode 100755 index 0000000000000000000000000000000000000000..b80895ad6a3d5c1233e4d1ee729de3f2b03c9425 --- /dev/null +++ b/video3d/networks.py @@ -0,0 +1,1724 @@ +import numpy as np +import torch +import torch.nn as nn +import torchvision +import torchvision.models as models +from typing import Union, List, Tuple +import os +import video3d.utils.misc as misc +import torch.nn.functional as F +from siren_pytorch import SirenNet +from video3d.triplane_texture.lift_architecture import Lift_Encoder +from video3d.triplane_texture.triplane_transformer import Triplane_Transformer + + +EPS = 1e-7 + + +def get_activation(name, inplace=True, lrelu_param=0.2): + if name == 'tanh': + return nn.Tanh() + elif name == 'sigmoid': + return nn.Sigmoid() + elif name == 'relu': + return nn.ReLU(inplace=inplace) + elif name == 'lrelu': + return nn.LeakyReLU(lrelu_param, inplace=inplace) + else: + raise NotImplementedError + + +class MLPWithPositionalEncoding(nn.Module): + def __init__(self, + cin, + cout, + num_layers, + nf=256, + dropout=0, + activation=None, + n_harmonic_functions=10, + omega0=1, + extra_dim=0, + embed_concat_pts=True, + symmetrize=False): + super().__init__() + self.extra_dim = extra_dim + + if n_harmonic_functions > 0: + self.embedder = HarmonicEmbedding(n_harmonic_functions, omega0) + dim_in = cin * 2 * n_harmonic_functions + self.embed_concat_pts = embed_concat_pts + if embed_concat_pts: + dim_in += cin + else: + self.embedder = None + dim_in = cin + + self.in_layer = nn.Linear(dim_in, nf) + self.relu = nn.ReLU(inplace=True) + self.mlp = MLP(nf + extra_dim, cout, num_layers, nf, dropout, activation) + self.symmetrize = symmetrize + + def forward(self, x, feat=None): + assert (feat is None and self.extra_dim == 0) or feat.shape[-1] == self.extra_dim + if self.symmetrize: + xs, ys, zs = x.unbind(-1) + x = torch.stack([xs.abs(), ys, zs], -1) # mirror -x to +x + + if self.embedder is not None: + x_in = self.embedder(x) + if self.embed_concat_pts: + x_in = torch.cat([x, x_in], -1) + else: + x_in = x + + x_in = self.relu(self.in_layer(x_in)) + + if feat is not None: + # if len(feat.shape) == 1: + # for _ in 
range(len(x_in.shape) - 1): + # feat = feat.unsqueeze(0) + # feat = feat.repeat(*x_in.shape[:-1], 1) + x_in = torch.concat([x_in, feat], dim=-1) + + return self.mlp(x_in) + + +class MLPWithPositionalEncoding_Style(nn.Module): + def __init__(self, + cin, + cout, + num_layers, + nf=256, + dropout=0, + activation=None, + n_harmonic_functions=10, + omega0=1, + extra_dim=0, + embed_concat_pts=True, + symmetrize=False, + style_choice='film'): + super().__init__() + self.extra_dim = extra_dim + + if n_harmonic_functions > 0: + self.embedder = HarmonicEmbedding(n_harmonic_functions, omega0) + dim_in = cin * 2 * n_harmonic_functions + self.embed_concat_pts = embed_concat_pts + if embed_concat_pts: + dim_in += cin + else: + self.embedder = None + dim_in = cin + + self.in_layer = nn.Linear(dim_in, nf) + self.relu = nn.ReLU(inplace=True) + + if extra_dim == 0: + self.mlp = MLP(nf + extra_dim, cout, num_layers, nf, dropout, activation) + + else: + if style_choice == 'film': + self.mlp = MLP_FiLM(nf, cout, num_layers, nf, dropout, activation) + self.style_mlp = MLP(extra_dim, nf*2, 2, nf, dropout, None) + + elif style_choice == 'mod': + self.mlp = MLP_Mod(nf, cout, num_layers, nf, dropout, activation) + self.style_mlp = MLP(extra_dim, nf, 2, nf, dropout, None) + + else: + raise NotImplementedError + + self.style_choice = style_choice + + self.symmetrize = symmetrize + + def forward(self, x, feat=None): + assert (feat is None and self.extra_dim == 0) or feat.shape[-1] == self.extra_dim + if self.symmetrize: + xs, ys, zs = x.unbind(-1) + x = torch.stack([xs.abs(), ys, zs], -1) # mirror -x to +x + + if self.embedder is not None: + x_in = self.embedder(x) + if self.embed_concat_pts: + x_in = torch.cat([x, x_in], -1) + else: + x_in = x + + x_in = self.relu(self.in_layer(x_in)) + + if feat is not None: + style = self.style_mlp(feat) + + if self.style_choice == 'film': + style = style.reshape(style.shape[:-1] + (-1, 2)) + + out = self.mlp(x_in, style) + + else: + out = self.mlp(x_in) + + return out + + +class MLP_FiLM(nn.Module): + def __init__(self, cin, cout, num_layers, nf=256, dropout=0, activation=None): + # default no dropout + super().__init__() + assert num_layers >= 1 + self.num_layers = num_layers + if num_layers == 1: + self.network = Linear_FiLM(cin, cout, bias=False) + else: + self.relu = nn.ReLU(inplace=True) + for i in range(num_layers): + if i == 0: + setattr(self, f'linear_{i}', Linear_FiLM(cin, nf, bias=False)) + elif i == (num_layers-1): + setattr(self, f'linear_{i}', Linear_FiLM(nf, cout, bias=False)) + else: + setattr(self, f'linear_{i}', Linear_FiLM(nf, nf, bias=False)) + + def forward(self, input, style): + if self.num_layers == 1: + out = self.network(input, style) + else: + x = input + for i in range(self.num_layers): + linear_layer = getattr(self, f'linear_{i}') + if i == (self.num_layers - 1): + x = linear_layer(x, style) + else: + x = linear_layer(x, style) + x = self.relu(x) + + out = x + return out + + +class MLP_Mod(nn.Module): + def __init__(self, cin, cout, num_layers, nf=256, dropout=0, activation=None): + # default no dropout + super().__init__() + assert num_layers >= 1 + self.num_layers = num_layers + if num_layers == 1: + self.network = Linear_Mod(cin, cout, bias=False) + else: + self.relu = nn.ReLU(inplace=True) + for i in range(num_layers): + if i == 0: + setattr(self, f'linear_{i}', Linear_Mod(cin, nf, bias=False)) + elif i == (num_layers-1): + setattr(self, f'linear_{i}', Linear_Mod(nf, cout, bias=False)) + else: + setattr(self, f'linear_{i}', Linear_Mod(nf, nf, 
bias=False)) + + def forward(self, input, style): + if self.num_layers == 1: + out = self.network(input, style) + else: + x = input + for i in range(self.num_layers): + linear_layer = getattr(self, f'linear_{i}') + if i == (self.num_layers - 1): + x = linear_layer(x, style) + else: + x = linear_layer(x, style) + x = self.relu(x) + + out = x + return out + + +import math + +class Linear_FiLM(nn.Module): + def __init__(self, in_features: int, out_features: int, bias: bool = True, + device=None, dtype=None) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = nn.Parameter(torch.empty((out_features, in_features), **factory_kwargs)) + if bias: + self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs)) + else: + self.register_parameter('bias', None) + self.reset_parameters() + + def reset_parameters(self) -> None: + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(self.bias, -bound, bound) + + def forward(self, input, style): + # if input is [..., D], style should be [..., D, 2] + x = input * style[..., 0] + style[..., 1] + return torch.nn.functional.linear(x, self.weight, self.bias) + + def extra_repr(self) -> str: + return 'in_features={}, out_features={}, bias={}'.format( + self.in_features, self.out_features, self.bias is not None + ) + + +class Linear_Mod(nn.Module): + def __init__(self, in_features: int, out_features: int, bias: bool = True, + device=None, dtype=None) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = nn.Parameter(torch.empty((out_features, in_features), **factory_kwargs)) + if bias: + self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs)) + else: + self.register_parameter('bias', None) + self.reset_parameters() + + def reset_parameters(self) -> None: + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(self.bias, -bound, bound) + + def forward(self, input, style): + # weight: [out_features, in_features] + # style: [..., in_features] + if len(style.shape) > 1: + style = style.reshape(-1, style.shape[-1]) + style = style[0] + + weight = self.weight * style.unsqueeze(0) + decoefs = ((weight * weight).sum(dim=-1, keepdim=True) + 1e-5).sqrt() + weight = weight / decoefs + + return torch.nn.functional.linear(input, weight, self.bias) + + def extra_repr(self) -> str: + return 'in_features={}, out_features={}, bias={}'.format( + self.in_features, self.out_features, self.bias is not None + ) + + +class MLPTextureSimple(nn.Module): + def __init__(self, + cin, + cout, + num_layers, + nf=256, + dropout=0, + activation=None, + min_max=None, + n_harmonic_functions=10, + omega0=1, + extra_dim=0, + embed_concat_pts=True, + perturb_normal=False, + symmetrize=False, + texture_act='relu', + linear_bias=False): + super().__init__() + self.extra_dim = extra_dim + + if n_harmonic_functions > 0: + self.embedder = HarmonicEmbedding(n_harmonic_functions, omega0) + dim_in = cin * 2 * n_harmonic_functions + self.embed_concat_pts = embed_concat_pts + if embed_concat_pts: + dim_in += cin + else: + self.embedder = 
None + dim_in = cin + + self.in_layer = nn.Linear(dim_in, nf) + self.relu = nn.ReLU(inplace=True) + + if texture_act == 'sin': + print('using siren network for texture mlp here') + self.mlp = SirenNet( + dim_in=(nf + extra_dim), + dim_hidden=nf, + dim_out=cout, + num_layers=num_layers, + final_activation=get_activation(activation), + w0_initial=30, + use_bias=linear_bias, + dropout=dropout + ) + else: + self.mlp = MLP(nf + extra_dim, cout, num_layers, nf, dropout, activation, inner_act=texture_act, linear_bias=linear_bias) + self.perturb_normal = perturb_normal + self.symmetrize = symmetrize + if min_max is not None: + self.register_buffer('min_max', min_max) + else: + self.min_max = None + self.bsdf = None + + def sample(self, x, feat=None): + assert (feat is None and self.extra_dim == 0) or (feat.shape[-1] == self.extra_dim) + b, h, w, c = x.shape + + if self.symmetrize: + xs, ys, zs = x.unbind(-1) + x = torch.stack([xs.abs(), ys, zs], -1) # mirror -x to +x + + x = x.view(-1, c) + if self.embedder is not None: + x_in = self.embedder(x) + if self.embed_concat_pts: + x_in = torch.cat([x, x_in], -1) + else: + x_in = x + + x_in = self.in_layer(x_in) + if feat is not None: + feat = feat[:,None,None].expand(b, h, w, -1).reshape(b*h*w, -1) + x_in = torch.concat([x_in, feat], dim=-1) + out = self.mlp(self.relu(x_in)) + if self.min_max is not None: + out = out * (self.min_max[1][None, :] - self.min_max[0][None, :]) + self.min_max[0][None, :] + return out.view(b, h, w, -1) + + +class MLPTextureTriplane(nn.Module): + def __init__(self, + cin, + cout, + num_layers, + nf=256, + dropout=0, + activation=None, + min_max=None, + n_harmonic_functions=10, + omega0=1, + extra_dim=0, + embed_concat_pts=True, + perturb_normal=False, + symmetrize=False, + texture_act='relu', + linear_bias=False, + cam_pos_z_offset=10., + grid_scale=7,): + super().__init__() + self.extra_dim = extra_dim + + if n_harmonic_functions > 0: + self.embedder = HarmonicEmbedding(n_harmonic_functions, omega0) + dim_in = cin * 2 * n_harmonic_functions + self.embed_concat_pts = embed_concat_pts + if embed_concat_pts: + dim_in += cin + else: + self.embedder = None + dim_in = cin + + self.in_layer = nn.Linear(dim_in, nf) + self.relu = nn.ReLU(inplace=True) + + self.feat_net = Triplane_Transformer( + emb_dim=256, + num_layers=8, + triplane_dim=80, + triplane_scale=grid_scale + ) + self.extra_dim -= extra_dim + self.extra_dim += (self.feat_net.triplane_dim * 3) + + if texture_act == 'sin': + print('using siren network for texture mlp here') + self.mlp = SirenNet( + dim_in=(nf + self.extra_dim), + dim_hidden=nf, + dim_out=cout, + num_layers=num_layers, + final_activation=get_activation(activation), + w0_initial=30, + use_bias=linear_bias, + dropout=dropout + ) + else: + self.mlp = MLP(nf + self.extra_dim, cout, num_layers, nf, dropout, activation, inner_act=texture_act, linear_bias=linear_bias) + self.perturb_normal = perturb_normal + self.symmetrize = symmetrize + if min_max is not None: + self.register_buffer('min_max', min_max) + else: + self.min_max = None + self.bsdf = None + + def sample(self, x, feat=None, feat_map=None, mvp=None, w2c=None, deform_xyz=None): + # assert (feat is None and self.extra_dim == 0) or (feat.shape[-1] == self.extra_dim) + b, h, w, c = x.shape + + if self.symmetrize: + xs, ys, zs = x.unbind(-1) + x = torch.stack([xs.abs(), ys, zs], -1) # mirror -x to +x + + if isinstance(feat_map, dict): + feat_map = feat_map["im_features_map"] + + feat_map = feat_map.permute(0, 2, 3, 1) + _, ph, pw, _ = feat_map.shape + feat_map 
= feat_map.reshape(feat_map.shape[0], ph*pw, feat_map.shape[-1]) + pts_feat = self.feat_net(feat_map, x.reshape(b, -1, 3)) + pts_c = pts_feat.shape[-1] + pts_feat = pts_feat.reshape(-1, pts_c) + + x = x.view(-1, c) + if self.embedder is not None: + x_in = self.embedder(x) + if self.embed_concat_pts: + x_in = torch.cat([x, x_in], -1) + else: + x_in = x + + x_in = self.in_layer(x_in) + + x_in = torch.concat([x_in, pts_feat], dim=-1) + + out = self.mlp(self.relu(x_in)) + if self.min_max is not None: + out = out * (self.min_max[1][None, :] - self.min_max[0][None, :]) + self.min_max[0][None, :] + return out.view(b, h, w, -1) + + +class LocalFeatureBlock(nn.Module): + def __init__(self, local_feat_dim, input_dim=384, output_dim=384, upscale_num=3): + super().__init__() + self.local_feat_dim = local_feat_dim + self.conv_list = nn.ModuleList([]) + self.upscale_list = nn.ModuleList([]) + + for i in range(upscale_num): + if i == 0: + self.conv_list.append(nn.Conv2d(input_dim, 4 * local_feat_dim, 3, stride=1, padding=1, dilation=1)) + else: + self.conv_list.append(nn.Conv2d(local_feat_dim, 4 * local_feat_dim, 3, stride=1, padding=1, dilation=1)) + self.upscale_list.append(nn.PixelShuffle(2)) + + self.conv_head = nn.Conv2d(local_feat_dim, output_dim, 3, stride=1, padding=1, dilation=1) + + def forward(self, x): + for idx, conv in enumerate(self.conv_list): + x = conv(x) + x = self.upscale_list[idx](x) + + out = self.conv_head(x) + return out + + +class MLPTextureLocal(nn.Module): + def __init__(self, + cin, + cout, + num_layers, + nf=256, + dropout=0, + activation=None, + min_max=None, + n_harmonic_functions=10, + omega0=1, + extra_dim=0, + embed_concat_pts=True, + perturb_normal=False, + symmetrize=False, + texture_way=None, + larger_tex_dim=False, + cam_pos_z_offset=10., + grid_scale=7.): + super().__init__() + self.extra_dim = extra_dim + self.cam_pos_z_offset = cam_pos_z_offset + self.grid_scale = grid_scale + + local_feat_dim = 64 + + assert texture_way is not None + self.texture_way = texture_way + if 'local' in texture_way and 'global' in texture_way: + # self.extra_dim = extra_dim + local_feat_dim + self.extra_dim = extra_dim + elif 'local' in texture_way and 'global' not in texture_way: + # self.extra_dim = local_feat_dim + self.extra_dim = extra_dim + elif 'local' not in texture_way and 'global' in texture_way: + self.extra_dim = extra_dim + + if n_harmonic_functions > 0: + self.embedder = HarmonicEmbedding(n_harmonic_functions, omega0) + dim_in = cin * 2 * n_harmonic_functions + self.embed_concat_pts = embed_concat_pts + if embed_concat_pts: + dim_in += cin + else: + self.embedder = None + dim_in = cin + + # self.local_feature_block = LocalFeatureBlock(local_feat_dim=local_feat_dim, input_dim=384, output_dim=256) + self.local_feature_block = nn.Linear(384, nf, bias=False) + + self.in_layer = nn.Linear(dim_in, nf) + self.relu = nn.ReLU(inplace=True) + self.mlp = MLP(nf + self.extra_dim, cout, num_layers, nf, dropout, activation) + self.perturb_normal = perturb_normal + self.symmetrize = symmetrize + if min_max is not None: + self.register_buffer('min_max', min_max) + else: + self.min_max = None + self.bsdf = None + + def get_uv_depth(self, xyz, mvp): + # xyz: [b, k, 3] + # mvp: [b, 4, 4] + cam4 = torch.matmul(torch.nn.functional.pad(xyz, pad=(0,1), mode='constant', value=1.0), torch.transpose(mvp, 1, 2)) + cam3 = cam4[..., :3] / cam4[..., 3:4] + cam_uv = cam3[..., :2] + # cam_uv = cam_uv.detach() + cam_depth = cam3 + torch.FloatTensor([0, 0, self.cam_pos_z_offset]).to(xyz.device).view(1, 1, 
3) + cam_depth = cam_depth / self.grid_scale * 2 + cam_depth = cam_depth[..., 2:3] + # cam_depth = cam_depth.detach() + return cam_uv, cam_depth + + def proj_sample_deform(self, xyz, feat_map, mvp, w2c, img_h, img_w): + # here the xyz is deformed points + # and we don't cast any symmtery here + b, k, c = xyz.shape + THRESHOLD = 1e-4 + if isinstance(feat_map, torch.Tensor): + coordinates = xyz + # use pre-symmetry points to get feature and record depth + cam_uv, cam_depth = self.get_uv_depth(coordinates, mvp) + cam_uv = cam_uv.detach() + cam_depth = cam_depth.detach() + + # get local feature + feature = F.grid_sample(feat_map, cam_uv.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, c] + + self.input_depth = cam_depth.reshape(b, 256, 256, 1) # [B, 256, 256, 1] + self.input_pts = coordinates.detach() + + elif isinstance(feat_map, dict): + original_mvp = feat_map['original_mvp'] + local_feat_map = feat_map['im_features_map'] + original_depth = self.input_depth[0:b] + + coordinates = xyz + cam_uv, cam_depth = self.get_uv_depth(coordinates, original_mvp) + cam_uv = cam_uv.detach() + cam_depth = cam_depth.detach() + + project_feature = F.grid_sample(local_feat_map, cam_uv.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, c] + project_depth = F.grid_sample(original_depth.permute(0, 3, 1, 2), cam_uv.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, 1] + + use_mask = cam_depth <= project_depth + THRESHOLD + feature = project_feature * use_mask.repeat(1, 1, project_feature.shape[-1]) + + ret_feature = self.local_feature_block(feature.reshape(b*k, -1)) # the linear is without bias, so 0 value feature will still get 0 value + return ret_feature + + def proj_sample(self, xyz, feat_map, mvp, w2c, img_h, img_w, xyz_before_sym=None): + # the new one with no input feature map upsampling + # feat_map: [B, C, H, W] + b, k, c = xyz.shape + if isinstance(feat_map, torch.Tensor): + if xyz_before_sym is None: + coordinates = xyz + else: + coordinates = xyz_before_sym + # use pre-symmetry points to get feature and record depth + cam_uv, cam_depth = self.get_uv_depth(coordinates, mvp) + cam_uv = cam_uv.detach() + cam_depth = cam_depth.detach() + + # get local feature + feature = F.grid_sample(feat_map, cam_uv.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, c] + + self.input_depth = cam_depth.reshape(b, 256, 256, 1) # [B, 256, 256, 1] + self.input_pts = coordinates.detach() + + elif isinstance(feat_map, dict): + original_mvp = feat_map['original_mvp'] + local_feat_map = feat_map['im_features_map'] + THRESHOLD = 1e-4 + original_depth = self.input_depth[0:b] + # if b == 1: + # from pdb import set_trace; set_trace() + # tmp_mask = xyz[0].reshape(256, 256, 3).sum(dim=-1) != 0 + # tmp_mask = tmp_mask.cpu().numpy() + # tmp_mask = tmp_mask * 255 + # src_dp = self.input_depth[0,:,:,0].cpu().numpy() + # input_pts = self.input_pts[0].cpu().numpy() + # input_mask = self.input_pts[0].reshape(256, 256, 3).sum(dim=-1) != 0 + # input_mask = input_mask.int().cpu().numpy() + # input_mask = input_mask * 255 + # np.save('./tmp_save/src_dp.npy', src_dp) + # np.save('./tmp_save/input_pts.npy', input_pts) + # import cv2 + # cv2.imwrite('./tmp_save/input_mask.png', input_mask) + # cv2.imwrite('./tmp_save/mask.png', tmp_mask) + # test_pts_pos = xyz[0].cpu().numpy() + # np.save('./tmp_save/test_pts_pos.npy', test_pts_pos) + # test_pts_raw = xyz_before_sym[0].cpu().numpy() + # np.save('./tmp_save/test_pts_raw.npy', test_pts_raw) + # 
mvp_now = mvp[0].detach().cpu().numpy() + # mvp_original = original_mvp[0].detach().cpu().numpy() + # np.save('./tmp_save/mvp_now.npy', mvp_now) + # np.save('./tmp_save/mvp_original.npy', mvp_original) + if xyz_before_sym is None: + # just check the project depth of xyz + coordinates = xyz + cam_uv, cam_depth = self.get_uv_depth(coordinates, original_mvp) + cam_uv = cam_uv.detach() + cam_depth = cam_depth.detach() + + project_feature = F.grid_sample(local_feat_map, cam_uv.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, c] + project_depth = F.grid_sample(original_depth.permute(0, 3, 1, 2), cam_uv.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, 1] + + use_mask = cam_depth <= project_depth + THRESHOLD + feature = project_feature * use_mask.repeat(1, 1, project_feature.shape[-1]) + else: + # need to double check, but now we are still use symmetry! Even if the two points are all visible in input view + coords_inp = xyz + x_check, y_check, z_check = xyz.unbind(-1) + xyz_check = torch.stack([-1 * x_check, y_check, z_check], -1) + coords_rev = xyz_check # we directly use neg-x to get the points of another side + + uv_inp, dp_inp = self.get_uv_depth(coords_inp, original_mvp) + uv_rev, dp_rev = self.get_uv_depth(coords_rev, original_mvp) + uv_inp = uv_inp.detach() + uv_rev = uv_rev.detach() + dp_inp = dp_inp.detach() + dp_rev = dp_rev.detach() + + proj_feat_inp = F.grid_sample(local_feat_map, uv_inp.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, c] + proj_feat_rev = F.grid_sample(local_feat_map, uv_rev.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, c] + + proj_dp_inp = F.grid_sample(original_depth.permute(0, 3, 1, 2), uv_inp.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, 1] + proj_dp_rev = F.grid_sample(original_depth.permute(0, 3, 1, 2), uv_rev.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, 1] + + use_mask_inp = dp_inp <= proj_dp_inp + THRESHOLD + use_mask_rev = dp_rev <= proj_dp_rev + THRESHOLD + + # for those points we can see in two sides, we use average + use_mask_inp = use_mask_inp.int() + use_mask_rev = use_mask_rev.int() + both_vis = (use_mask_inp == 1) & (use_mask_rev == 1) + use_mask_inp[both_vis] = 0.5 + use_mask_rev[both_vis] = 0.5 + + feature = proj_feat_inp * use_mask_inp.repeat(1, 1, proj_feat_inp.shape[-1]) + proj_feat_rev * use_mask_rev.repeat(1, 1, proj_feat_rev.shape[-1]) + else: + raise NotImplementedError + + ret_feature = self.local_feature_block(feature.reshape(b*k, -1)) # the linear is without bias, so 0 value feature will still get 0 value + return ret_feature + + def sample(self, x, feat=None, feat_map=None, mvp=None, w2c=None, deform_xyz=None): + # assert (feat is None and self.extra_dim == 0) or (feat.shape[-1] <= self.extra_dim) + b, h, w, c = x.shape + + xyz_before_sym = None + if self.symmetrize: + xyz_before_sym = x.reshape(b, -1, c) + xs, ys, zs = x.unbind(-1) + x = torch.stack([xs.abs(), ys, zs], -1) # mirror -x to +x + + mvp = mvp.detach() # [b, 4, 4] + w2c = w2c.detach() # [b, 4, 4] + + pts_xyz = x.reshape(b, -1, c) + deform_xyz = deform_xyz.reshape(b, -1, c) + + if 'global' in self.texture_way and 'local' in self.texture_way: + global_feat = feat[:,None,None].expand(b, h, w, -1).reshape(b*h*w, -1) + # local_feat = self.proj_sample(pts_xyz, feat_map, mvp, w2c, h, w, xyz_before_sym=xyz_before_sym) + local_feat = self.proj_sample_deform(deform_xyz, feat_map, mvp, w2c, h, w) + # feature_rep = 
torch.concat([global_feat, local_feat], dim=-1) + feature_rep = global_feat + local_feat + elif 'global' not in self.texture_way and 'local' in self.texture_way: + # local_feat = self.proj_sample(pts_xyz, feat_map, mvp, w2c, h, w, xyz_before_sym=xyz_before_sym) + local_feat = self.proj_sample_deform(deform_xyz, feat_map, mvp, w2c, h, w) + feature_rep = local_feat + elif 'global' in self.texture_way and 'local' not in self.texture_way: + global_feat = feat[:,None,None].expand(b, h, w, -1).reshape(b*h*w, -1) + feature_rep = global_feat + else: + global_feat = feat[:,None,None].expand(b, h, w, -1).reshape(b*h*w, -1) + feature_rep = global_feat + + x = x.view(-1, c) + + if self.embedder is not None: + x_in = self.embedder(x) + if self.embed_concat_pts: + x_in = torch.cat([x, x_in], -1) + else: + x_in = x + + x_in = self.in_layer(x_in) + + # if feat is not None: + # feat = feat[:,None,None].expand(b, h, w, -1).reshape(b*h*w, -1) + # x_in = torch.concat([x_in, feat], dim=-1) + + x_in = torch.concat([x_in, feature_rep], dim=-1) + + out = self.mlp(self.relu(x_in)) + if self.min_max is not None: + out = out * (self.min_max[1][None, :] - self.min_max[0][None, :]) + self.min_max[0][None, :] + return out.view(b, h, w, -1) + + +class LiftTexture(nn.Module): + def __init__(self, + cin, + cout, + num_layers, + nf=256, + dropout=0, + activation=None, + min_max=None, + n_harmonic_functions=10, + omega0=1, + extra_dim=0, + embed_concat_pts=True, + perturb_normal=False, + symmetrize=False, + texture_way=None, + cam_pos_z_offset=10., + grid_scale=7., + local_feat_dim=128, + grid_size=32, + optim_latent=False): + super().__init__() + self.extra_dim = extra_dim + self.cam_pos_z_offset = cam_pos_z_offset + self.grid_scale = grid_scale + + assert texture_way is not None + self.extra_dim = local_feat_dim + extra_dim + + if n_harmonic_functions > 0: + self.embedder = HarmonicEmbedding(n_harmonic_functions, omega0) + dim_in = cin * 2 * n_harmonic_functions + self.embed_concat_pts = embed_concat_pts + if embed_concat_pts: + dim_in += cin + else: + self.embedder = None + dim_in = cin + + self.encoder = Lift_Encoder( + cin=384, + feat_dim=local_feat_dim, + grid_scale=grid_scale / 2, # the dmtet is initialized in (-0.5, 0.5) + grid_size=grid_size, + optim_latent=optim_latent, + with_z_feature=True, + cam_pos_z_offset=cam_pos_z_offset + ) + + + self.in_layer = nn.Linear(dim_in, nf) + self.relu = nn.ReLU(inplace=True) + self.mlp = MLP(nf + self.extra_dim, cout, num_layers, nf, dropout, activation) + self.perturb_normal = perturb_normal + self.symmetrize = symmetrize + if min_max is not None: + self.register_buffer('min_max', min_max) + else: + self.min_max = None + self.bsdf = None + + def get_uv_depth(self, xyz, mvp): + # xyz: [b, k, 3] + # mvp: [b, 4, 4] + cam4 = torch.matmul(torch.nn.functional.pad(xyz, pad=(0,1), mode='constant', value=1.0), torch.transpose(mvp, 1, 2)) + cam3 = cam4[..., :3] / cam4[..., 3:4] + cam_uv = cam3[..., :2] + # cam_uv = cam_uv.detach() + cam_depth = cam3 + torch.FloatTensor([0, 0, self.cam_pos_z_offset]).to(xyz.device).view(1, 1, 3) + cam_depth = cam_depth / self.grid_scale * 2 + cam_depth = cam_depth[..., 2:3] + # cam_depth = cam_depth.detach() + return cam_uv, cam_depth + + def proj_sample_deform(self, xyz, feat_map, mvp, w2c, img_h, img_w): + # here the xyz is deformed points + # and we don't cast any symmtery here + if isinstance(feat_map, torch.Tensor): + feature = self.encoder(feat_map, mvp, xyz, inference="unproject") + + elif isinstance(feat_map, dict): + feature = 
self.encoder(feat_map['im_features_map'], mvp, xyz, inference="sample") + C = feature.shape[-1] + feature = feature.reshape(-1, C) + return feature + + def sample(self, x, feat=None, feat_map=None, mvp=None, w2c=None, deform_xyz=None): + # assert (feat is None and self.extra_dim == 0) or (feat.shape[-1] <= self.extra_dim) + b, h, w, c = x.shape + + xyz_before_sym = None + if self.symmetrize: + xyz_before_sym = x.reshape(b, -1, c) + xs, ys, zs = x.unbind(-1) + x = torch.stack([xs.abs(), ys, zs], -1) # mirror -x to +x + + mvp = mvp.detach() # [b, 4, 4] + w2c = w2c.detach() # [b, 4, 4] + + pts_xyz = x.reshape(b, -1, c) + deform_xyz = deform_xyz.reshape(b, -1, c) + + global_feat = feat[:,None,None].expand(b, h, w, -1).reshape(b*h*w, -1) + local_feat = self.proj_sample_deform(deform_xyz, feat_map, mvp, w2c, h, w) + feature_rep = torch.concat([global_feat, local_feat], dim=-1) + x = x.view(-1, c) + + if self.embedder is not None: + x_in = self.embedder(x) + if self.embed_concat_pts: + x_in = torch.cat([x, x_in], -1) + else: + x_in = x + + x_in = self.in_layer(x_in) + + # if feat is not None: + # feat = feat[:,None,None].expand(b, h, w, -1).reshape(b*h*w, -1) + # x_in = torch.concat([x_in, feat], dim=-1) + + x_in = torch.concat([x_in, feature_rep], dim=-1) + + out = self.mlp(self.relu(x_in)) + if self.min_max is not None: + out = out * (self.min_max[1][None, :] - self.min_max[0][None, :]) + self.min_max[0][None, :] + return out.view(b, h, w, -1) + + +class HarmonicEmbedding(nn.Module): + def __init__(self, n_harmonic_functions=10, omega0=1): + """ + Positional Embedding implementation (adapted from Pytorch3D). + Given an input tensor `x` of shape [minibatch, ... , dim], + the harmonic embedding layer converts each feature + in `x` into a series of harmonic features `embedding` + as follows: + embedding[..., i*dim:(i+1)*dim] = [ + sin(x[..., i]), + sin(2*x[..., i]), + sin(4*x[..., i]), + ... + sin(2**self.n_harmonic_functions * x[..., i]), + cos(x[..., i]), + cos(2*x[..., i]), + cos(4*x[..., i]), + ... + cos(2**self.n_harmonic_functions * x[..., i]) + ] + Note that `x` is also premultiplied by `omega0` before + evaluting the harmonic functions. 
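+        Example (illustrative shape check): with n_harmonic_functions=2 and
+        omega0=1, an input of shape [..., 3] maps to an embedding of shape
+        [..., 12], i.e. dim * n_harmonic_functions * 2 = 3 * 2 * 2:
+            >>> emb = HarmonicEmbedding(n_harmonic_functions=2, omega0=1)
+            >>> emb(torch.zeros(2, 3)).shape
+            torch.Size([2, 12])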
+ """ + super().__init__() + self.frequencies = omega0 * (2.0 ** torch.arange(n_harmonic_functions)) + + def forward(self, x): + """ + Args: + x: tensor of shape [..., dim] + Returns: + embedding: a harmonic embedding of `x` + of shape [..., n_harmonic_functions * dim * 2] + """ + embed = (x[..., None] * self.frequencies.to(x.device)).view(*x.shape[:-1], -1) + return torch.cat((embed.sin(), embed.cos()), dim=-1) + + +class VGGEncoder(nn.Module): + def __init__(self, cout, pretrained=False): + super().__init__() + if pretrained: + raise NotImplementedError + vgg = models.vgg16() + self.vgg_encoder = nn.Sequential(vgg.features, vgg.avgpool) + self.linear1 = nn.Linear(25088, 4096) + self.linear2 = nn.Linear(4096, cout) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + batch_size, _, _, _ = x.shape + out = self.relu(self.linear1(self.vgg_encoder(x).view(batch_size, -1))) + return self.linear2(out) + + +class ResnetEncoder(nn.Module): + def __init__(self, cout, pretrained=False): + super().__init__() + self.resnet = nn.Sequential(list(models.resnet18(weights="DEFAULT" if pretrained else None).modules())[:-1]) + self.final_linear = nn.Linear(512, cout) + + def forward(self, x): + return self.final_linear(self.resnet(x)) + + +class Encoder(nn.Module): + def __init__(self, cin, cout, in_size=128, zdim=None, nf=64, activation=None): + super().__init__() + network = [ + nn.Conv2d(cin, nf, kernel_size=4, stride=2, padding=1, bias=False), # 128x128 -> 64x64 + nn.GroupNorm(16, nf), + # nn.ReLU(inplace=True), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(nf, nf*2, kernel_size=4, stride=2, padding=1, bias=False), # 64x64 -> 32x32 + nn.GroupNorm(16*2, nf*2), + # nn.ReLU(inplace=True), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(nf*2, nf*4, kernel_size=4, stride=2, padding=1, bias=False), # 32x32 -> 16x16 + nn.GroupNorm(16*4, nf*4), + # nn.ReLU(inplace=True), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(nf*4, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 16x16 -> 8x8 + # nn.GroupNorm(16*8, nf*8), + # nn.ReLU(inplace=True), + nn.LeakyReLU(0.2, inplace=True), + ] + + add_downsample = int(np.log2(in_size//128)) + if add_downsample > 0: + for _ in range(add_downsample): + network += [ + nn.Conv2d(nf*8, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 16x16 -> 8x8 + # nn.GroupNorm(16*8, nf*8), + # nn.ReLU(inplace=True), + nn.LeakyReLU(0.2, inplace=True), + ] + + network += [ + nn.Conv2d(nf*8, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 8x8 -> 4x4 + nn.LeakyReLU(0.2, inplace=True), + ] + + if zdim is None: + network += [ + nn.Conv2d(nf*8, cout, kernel_size=4, stride=1, padding=0, bias=False), # 4x4 -> 1x1 + ] + else: + network += [ + nn.Conv2d(nf*8, zdim, kernel_size=4, stride=1, padding=0, bias=False), # 4x4 -> 1x1 + # nn.ReLU(inplace=True), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(zdim, cout, kernel_size=1, stride=1, padding=0, bias=False), + ] + + if activation is not None: + network += [get_activation(activation)] + self.network = nn.Sequential(*network) + + def forward(self, input): + return self.network(input).reshape(input.size(0), -1) + + +class EncoderWithDINO(nn.Module): + def __init__(self, cin_rgb, cin_dino, cout, in_size=128, zdim=None, nf=64, activation=None): + super().__init__() + network_rgb_in = [ + nn.Conv2d(cin_rgb, nf, kernel_size=4, stride=2, padding=1, bias=False), # 128x128 -> 64x64 + nn.GroupNorm(16, nf), + # nn.ReLU(inplace=True), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(nf, nf*2, kernel_size=4, stride=2, padding=1, 
bias=False), # 64x64 -> 32x32 + nn.GroupNorm(16*2, nf*2), + # nn.ReLU(inplace=True), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(nf*2, nf*4, kernel_size=4, stride=2, padding=1, bias=False), # 32x32 -> 16x16 + nn.GroupNorm(16*4, nf*4), + # nn.ReLU(inplace=True), + nn.LeakyReLU(0.2, inplace=True), + ] + self.network_rgb_in = nn.Sequential(*network_rgb_in) + network_dino_in = [ + nn.Conv2d(cin_dino, nf, kernel_size=4, stride=2, padding=1, bias=False), # 128x128 -> 64x64 + nn.GroupNorm(16, nf), + # nn.ReLU(inplace=True), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(nf, nf*2, kernel_size=4, stride=2, padding=1, bias=False), # 64x64 -> 32x32 + nn.GroupNorm(16*2, nf*2), + # nn.ReLU(inplace=True), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(nf*2, nf*4, kernel_size=4, stride=2, padding=1, bias=False), # 32x32 -> 16x16 + nn.GroupNorm(16*4, nf*4), + # nn.ReLU(inplace=True), + nn.LeakyReLU(0.2, inplace=True), + ] + self.network_dino_in = nn.Sequential(*network_dino_in) + + network_fusion = [ + nn.Conv2d(nf*4*2, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 16x16 -> 8x8 + # nn.GroupNorm(16*8, nf*8), + # nn.ReLU(inplace=True), + nn.LeakyReLU(0.2, inplace=True), + ] + + add_downsample = int(np.log2(in_size//128)) + if add_downsample > 0: + for _ in range(add_downsample): + network_fusion += [ + nn.Conv2d(nf*8, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 16x16 -> 8x8 + # nn.GroupNorm(16*8, nf*8), + # nn.ReLU(inplace=True), + nn.LeakyReLU(0.2, inplace=True), + ] + + network_fusion += [ + nn.Conv2d(nf*8, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 8x8 -> 4x4 + nn.LeakyReLU(0.2, inplace=True), + ] + + if zdim is None: + network_fusion += [ + nn.Conv2d(nf*8, cout, kernel_size=4, stride=1, padding=0, bias=False), # 4x4 -> 1x1 + ] + else: + network_fusion += [ + nn.Conv2d(nf*8, zdim, kernel_size=4, stride=1, padding=0, bias=False), # 4x4 -> 1x1 + # nn.ReLU(inplace=True), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(zdim, cout, kernel_size=1, stride=1, padding=0, bias=False), + ] + + if activation is not None: + network_fusion += [get_activation(activation)] + self.network_fusion = nn.Sequential(*network_fusion) + + def forward(self, rgb_image, dino_image): + rgb_feat = self.network_rgb_in(rgb_image) + dino_feat = self.network_dino_in(dino_image) + out = self.network_fusion(torch.cat([rgb_feat, dino_feat], dim=1)) + return out.reshape(rgb_image.size(0), -1) + + +class Encoder32(nn.Module): + def __init__(self, cin, cout, nf=256, activation=None): + super().__init__() + network = [ + nn.Conv2d(cin, nf, kernel_size=4, stride=2, padding=1, bias=False), # 32x32 -> 16x16 + nn.GroupNorm(nf//4, nf), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(nf, nf, kernel_size=4, stride=2, padding=1, bias=False), # 16x16 -> 8x8 + nn.GroupNorm(nf//4, nf), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(nf, nf, kernel_size=4, stride=2, padding=1, bias=False), # 8x8 -> 4x4 + nn.GroupNorm(nf//4, nf), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(nf, cout, kernel_size=4, stride=1, padding=0, bias=False), # 4x4 -> 1x1 + ] + if activation is not None: + network += [get_activation(activation)] + self.network = nn.Sequential(*network) + + def forward(self, input): + return self.network(input).reshape(input.size(0), -1) + + +class MLP(nn.Module): + def __init__(self, cin, cout, num_layers, nf=256, dropout=0, activation=None, inner_act='relu', linear_bias=False): + super().__init__() + assert num_layers >= 1 + layer_act = get_activation(inner_act) + if num_layers == 1: + network = 
[nn.Linear(cin, cout, bias=linear_bias)] + else: + # network = [nn.Linear(cin, nf, bias=False)] + # for _ in range(num_layers-2): + # network += [ + # nn.ReLU(inplace=True), + # nn.Linear(nf, nf, bias=False)] + # if dropout: + # network += [nn.Dropout(dropout)] + # network += [ + # nn.ReLU(inplace=True), + # nn.Linear(nf, cout, bias=False)] + network = [nn.Linear(cin, nf, bias=linear_bias)] + for _ in range(num_layers-2): + network += [ + layer_act, + nn.Linear(nf, nf, bias=linear_bias)] + if dropout: + network += [nn.Dropout(dropout)] + network += [ + layer_act, + nn.Linear(nf, cout, bias=linear_bias)] + if activation is not None: + network += [get_activation(activation)] + self.network = nn.Sequential(*network) + + def forward(self, input): + return self.network(input) + + +class Embedding(nn.Module): + def __init__(self, cin, cout, zdim=128, nf=64, activation=None): + super().__init__() + network = [ + nn.Linear(cin, nf, bias=False), + nn.ReLU(inplace=True), + nn.Linear(nf, zdim, bias=False), + nn.ReLU(inplace=True), + nn.Linear(zdim, cout, bias=False)] + if activation is not None: + network += [get_activation(activation)] + self.network = nn.Sequential(*network) + + def forward(self, input): + return self.network(input.reshape(input.size(0), -1)).reshape(input.size(0), -1) + + +class PerceptualLoss(nn.Module): + def __init__(self, requires_grad=False): + super(PerceptualLoss, self).__init__() + mean_rgb = torch.FloatTensor([0.485, 0.456, 0.406]) + std_rgb = torch.FloatTensor([0.229, 0.224, 0.225]) + self.register_buffer('mean_rgb', mean_rgb) + self.register_buffer('std_rgb', std_rgb) + + vgg_pretrained_features = torchvision.models.vgg16(pretrained=True).features + self.slice1 = nn.Sequential() + self.slice2 = nn.Sequential() + self.slice3 = nn.Sequential() + self.slice4 = nn.Sequential() + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def normalize(self, x): + out = x/2 + 0.5 + out = (out - self.mean_rgb.view(1,3,1,1)) / self.std_rgb.view(1,3,1,1) + return out + + def __call__(self, im1, im2, mask=None, conf_sigma=None): + im = torch.cat([im1,im2], 0) + im = self.normalize(im) # normalize input + + ## compute features + feats = [] + f = self.slice1(im) + feats += [torch.chunk(f, 2, dim=0)] + f = self.slice2(f) + feats += [torch.chunk(f, 2, dim=0)] + f = self.slice3(f) + feats += [torch.chunk(f, 2, dim=0)] + f = self.slice4(f) + feats += [torch.chunk(f, 2, dim=0)] + + losses = [] + for f1, f2 in feats[2:3]: # use relu3_3 features only + loss = (f1-f2)**2 + if conf_sigma is not None: + loss = loss / (2*conf_sigma**2 +EPS) + (conf_sigma +EPS).log() + if mask is not None: + b, c, h, w = loss.shape + _, _, hm, wm = mask.shape + sh, sw = hm//h, wm//w + mask0 = nn.functional.avg_pool2d(mask, kernel_size=(sh,sw), stride=(sh,sw)).expand_as(loss) + loss = (loss * mask0).sum() / mask0.sum() + else: + loss = loss.mean() + losses += [loss] + return sum(losses) + + +## from: https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, 
stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(BasicBlock, self).__init__() + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + + self.norm_layer = norm_layer + if norm_layer is not None: + self.bn1 = norm_layer(planes) + self.bn2 = norm_layer(planes) + + if inplanes != planes: + self.downsample = nn.Sequential( + conv1x1(inplanes, planes, stride), + norm_layer(planes), + ) + else: + self.downsample = None + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + if self.norm_layer is not None: + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + if self.norm_layer is not None: + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResEncoder(nn.Module): + def __init__(self, cin, cout, in_size=128, zdim=None, nf=64, activation=None): + super().__init__() + network = [ + nn.Conv2d(cin, nf, kernel_size=4, stride=2, padding=1, bias=False), # 128x128 -> 64x64 + # nn.GroupNorm(16, nf), + # nn.ReLU(inplace=True), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(nf, nf*2, kernel_size=4, stride=2, padding=1, bias=False), # 64x64 -> 32x32 + # nn.GroupNorm(16*2, nf*2), + # nn.ReLU(inplace=True), + nn.LeakyReLU(0.2, inplace=True), + BasicBlock(nf*2, nf*2, norm_layer=None), + BasicBlock(nf*2, nf*2, norm_layer=None), + nn.Conv2d(nf*2, nf*4, kernel_size=4, stride=2, padding=1, bias=False), # 32x32 -> 16x16 + # nn.GroupNorm(16*4, nf*4), + # nn.ReLU(inplace=True), + nn.LeakyReLU(0.2, inplace=True), + BasicBlock(nf*4, nf*4, norm_layer=None), + BasicBlock(nf*4, nf*4, norm_layer=None), + nn.Conv2d(nf*4, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 16x16 -> 8x8 + # nn.ReLU(inplace=True), + nn.LeakyReLU(0.2, inplace=True), + BasicBlock(nf*8, nf*8, norm_layer=None), + BasicBlock(nf*8, nf*8, norm_layer=None), + ] + + add_downsample = int(np.log2(in_size//64)) + if add_downsample > 0: + for _ in range(add_downsample): + network += [ + nn.Conv2d(nf*8, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 8x8 -> 4x4 + # nn.ReLU(inplace=True), + nn.LeakyReLU(0.2, inplace=True), + BasicBlock(nf*8, nf*8, norm_layer=None), + BasicBlock(nf*8, nf*8, norm_layer=None), + ] + + if zdim is None: + network += [ + nn.Conv2d(nf*8, cout, kernel_size=4, stride=1, padding=0, bias=False), # 4x4 -> 1x1 + ] + else: + network += [ + nn.Conv2d(nf*8, zdim, kernel_size=4, stride=1, padding=0, bias=False), # 4x4 -> 1x1 + # nn.ReLU(inplace=True), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(zdim, cout, kernel_size=1, stride=1, padding=0, bias=False), + ] + + if activation is not None: + network += [get_activation(activation)] + self.network = nn.Sequential(*network) + + def forward(self, input): + return self.network(input).reshape(input.size(0), 
-1) + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x, attn + + +class ViTEncoder(nn.Module): + def __init__(self, cout, which_vit='dino_vits8', pretrained=False, frozen=False, in_size=256, final_layer_type='none', root='/root'): + super().__init__() + if misc.is_main_process(): + force_reload = not os.path.exists(os.path.join(root, ".cache/torch/hub/checkpoints/")) + else: + force_reload = False + if "dinov2" in which_vit: + self.ViT = torch.hub.load('facebookresearch/dinov2:main', which_vit, pretrained=pretrained, force_reload=force_reload) + else: + self.ViT = torch.hub.load('facebookresearch/dino:main', which_vit, pretrained=pretrained, force_reload=force_reload) + + if frozen: + for p in self.ViT.parameters(): + p.requires_grad = False + if which_vit == 'dino_vits8': + self.vit_feat_dim = 384 + self.patch_size = 8 + elif which_vit == 'dinov2_vits14': + self.vit_feat_dim = 384 + self.patch_size = 14 + elif which_vit == 'dino_vitb8': + self.vit_feat_dim = 768 + self.patch_size = 8 + + self._feats = [] + self.hook_handlers = [] + + if final_layer_type == 'none': + pass + elif final_layer_type == 'conv': + self.final_layer_patch_out = Encoder32(self.vit_feat_dim, cout, nf=256, activation=None) + self.final_layer_patch_key = Encoder32(self.vit_feat_dim, cout, nf=256, activation=None) + elif final_layer_type == 'attention': + raise NotImplementedError + self.final_layer = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + self.fc = nn.Linear(self.vit_feat_dim, cout) + else: + raise NotImplementedError + self.final_layer_type = final_layer_type + + def _get_hook(self, facet: str): + """ + generate a hook method for a specific block and facet. + """ + if facet in ['attn', 'token']: + def _hook(model, input, output): + self._feats.append(output) + return _hook + + if facet == 'query': + facet_idx = 0 + elif facet == 'key': + facet_idx = 1 + elif facet == 'value': + facet_idx = 2 + else: + raise TypeError(f"{facet} is not a supported facet.") + + def _inner_hook(module, input, output): + input = input[0] + B, N, C = input.shape + qkv = module.qkv(input).reshape(B, N, 3, module.num_heads, C // module.num_heads).permute(2, 0, 3, 1, 4) + self._feats.append(qkv[facet_idx]) #Bxhxtxd + return _inner_hook + + def _register_hooks(self, layers: List[int], facet: str) -> None: + """ + register hook to extract features. + :param layers: layers from which to extract features. + :param facet: facet to extract. 
One of the following options: ['key' | 'query' | 'value' | 'token' | 'attn'] + """ + for block_idx, block in enumerate(self.ViT.blocks): + if block_idx in layers: + if facet == 'token': + self.hook_handlers.append(block.register_forward_hook(self._get_hook(facet))) + elif facet == 'attn': + self.hook_handlers.append(block.attn.attn_drop.register_forward_hook(self._get_hook(facet))) + elif facet in ['key', 'query', 'value']: + self.hook_handlers.append(block.attn.register_forward_hook(self._get_hook(facet))) + else: + raise TypeError(f"{facet} is not a supported facet.") + + def _unregister_hooks(self) -> None: + """ + unregisters the hooks. should be called after feature extraction. + """ + for handle in self.hook_handlers: + handle.remove() + self.hook_handlers = [] + + def forward(self, x, return_patches=False): + b, c, h, w = x.shape + self._feats = [] + self._register_hooks([11], 'key') + #self._register_hooks([11], 'token') + x = self.ViT.prepare_tokens(x) + #x = self.ViT.prepare_tokens_with_masks(x) + + for blk in self.ViT.blocks: + x = blk(x) + out = self.ViT.norm(x) + self._unregister_hooks() + + ph, pw = h // self.patch_size, w // self.patch_size + patch_out = out[:, 1:] # first is class token + patch_out = patch_out.reshape(b, ph, pw, self.vit_feat_dim).permute(0, 3, 1, 2) + + patch_key = self._feats[0][:,:,1:] # B, num_heads, num_patches, dim + patch_key = patch_key.permute(0, 1, 3, 2).reshape(b, self.vit_feat_dim, ph, pw) + + if self.final_layer_type == 'none': + global_feat_out = out[:, 0].reshape(b, -1) # first is class token + global_feat_key = self._feats[0][:, :, 0].reshape(b, -1) # first is class token + elif self.final_layer_type == 'conv': + global_feat_out = self.final_layer_patch_out(patch_out).view(b, -1) + global_feat_key = self.final_layer_patch_key(patch_key).view(b, -1) + elif self.final_layer_type == 'attention': + raise NotImplementedError + else: + raise NotImplementedError + if not return_patches: + patch_out = patch_key = None + return global_feat_out, global_feat_key, patch_out, patch_key + + +class ArticulationNetwork(nn.Module): + def __init__(self, net_type, feat_dim, pos_dim, num_layers, nf, n_harmonic_functions=0, omega0=1, activation=None, enable_articulation_idadd=False): + super().__init__() + if n_harmonic_functions > 0: + self.posenc = HarmonicEmbedding(n_harmonic_functions=n_harmonic_functions, omega0=omega0) + pos_dim = pos_dim * (n_harmonic_functions * 2 + 1) + else: + self.posenc = None + pos_dim = 4 + cout = 3 + + if net_type == 'mlp': + self.network = MLP( + feat_dim + pos_dim, # + bone xyz pos and index + cout, # We represent the rotation of each bone by its Euler angles ψ, θ, and φ + num_layers, + nf=nf, + dropout=0, + activation=activation + ) + elif net_type == 'attention': + self.in_layer = nn.Sequential( + nn.Linear(feat_dim + pos_dim, nf), + nn.GELU(), + nn.LayerNorm(nf), + ) + self.blocks = nn.ModuleList([ + Block( + dim=nf, num_heads=8, mlp_ratio=2., qkv_bias=False, qk_scale=None, + drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm) + for i in range(num_layers)]) + out_layer = [nn.Linear(nf, cout)] + if activation: + out_layer += [get_activation(activation)] + self.out_layer = nn.Sequential(*out_layer) + else: + raise NotImplementedError + self.net_type = net_type + self.enable_articulation_idadd = enable_articulation_idadd + + def forward(self, x, pos): + pos_inp = pos + if self.posenc is not None: + pos = torch.cat([pos, self.posenc(pos)], dim=-1) + x = torch.cat([x, pos], dim=-1) + if self.enable_articulation_idadd: + 
articulation_id = pos_inp[..., -1:] + x = x + articulation_id + if self.net_type == 'mlp': + out = self.network(x) + elif self.net_type == 'attention': + x = self.in_layer(x) + for blk in self.blocks: + x = blk(x) + out = self.out_layer(x) + else: + raise NotImplementedError + return out + + +## Attention block from ViT (https://github.com/facebookresearch/dino/blob/main/vision_transformer.py) +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x, attn + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class Block(nn.Module): + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, return_attention=False): + y, attn = self.attn(self.norm1(x)) + if return_attention: + return attn + x = x + self.drop_path(y) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class FeatureAttention(nn.Module): + def __init__(self, vit_type, pos_dim, embedder_freq=0, zdim=128, img_size=256, activation=None): + super().__init__() + self.zdim = zdim + if embedder_freq > 0: + self.posenc = HarmonicEmbedding(n_harmonic_functions=embedder_freq, omega0=1) + pos_dim = pos_dim * (embedder_freq * 2 + 1) + else: + self.posenc = None + self.pos_dim = pos_dim + + if vit_type == 'dino_vits8': + self.vit_feat_dim = 384 + patch_size = 8 + elif vit_type == 'dinov2_vits14': + self.vit_feat_dim = 384 + patch_size = 14 + elif vit_type == 'dino_vitb8': + self.vit_feat_dim = 768 + patch_size = 8 + else: + raise NotImplementedError + self.num_patches_per_dim = img_size // patch_size + + self.kv = nn.Sequential( + nn.Linear(self.vit_feat_dim, zdim), + nn.ReLU(inplace=True), + nn.LayerNorm(zdim), + nn.Linear(zdim, zdim*2), + ) + + self.q = nn.Sequential( + nn.Linear(pos_dim, zdim), + nn.ReLU(inplace=True), + nn.LayerNorm(zdim), + nn.Linear(zdim, zdim), + ) + + final_mlp = [ + nn.Linear(zdim, zdim), + nn.ReLU(inplace=True), + nn.LayerNorm(zdim), + nn.Linear(zdim, self.vit_feat_dim) + ] + if activation is not None: + final_mlp += [get_activation(activation)] + self.final_ln = nn.Sequential(*final_mlp) + + # Cross-attention: queries come from the (positionally encoded) input points, keys/values from the ViT patch feature map. + def forward(self, x, feat): + _, vit_feat_dim, ph, pw = feat.shape + assert ph == pw and ph == self.num_patches_per_dim and vit_feat_dim == self.vit_feat_dim + + if self.posenc is not None: + x = torch.cat([x, self.posenc(x)], dim=-1) + bxf, k, c = x.shape + assert c == self.pos_dim + + query = self.q(x) + feat_in = feat.view(bxf, vit_feat_dim, ph*pw).permute(0, 2, 1) # N, K, C + k, v = self.kv(feat_in).chunk(2, dim=-1) + attn = torch.einsum('bnd,bpd->bnp', query, k).softmax(dim=-1) + out = torch.einsum('bnp,bpd->bnd', attn, v) + out = self.final_ln(out) + return out diff --git a/video3d/render/light.py b/video3d/render/light.py new file mode 100755 index 0000000000000000000000000000000000000000..4514e02020fed7ac92b93655076b78caef38388a --- /dev/null +++ b/video3d/render/light.py @@ -0,0 +1,191 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import os +import numpy as np +import torch +import torch.nn.functional as F +import nvdiffrast.torch as dr + +from . import util +from . 
import renderutils as ru +from ..networks import MLP + +###################################################################################### +# Utility functions +###################################################################################### + +class cubemap_mip(torch.autograd.Function): + @staticmethod + def forward(ctx, cubemap): + return util.avg_pool_nhwc(cubemap, (2,2)) + + @staticmethod + def backward(ctx, dout): + res = dout.shape[1] * 2 + out = torch.zeros(6, res, res, dout.shape[-1], dtype=torch.float32, device="cuda") + for s in range(6): + gy, gx = torch.meshgrid(torch.linspace(-1.0 + 1.0 / res, 1.0 - 1.0 / res, res, device="cuda"), + torch.linspace(-1.0 + 1.0 / res, 1.0 - 1.0 / res, res, device="cuda"), + indexing='ij') + v = util.safe_normalize(util.cube_to_dir(s, gx, gy)) + out[s, ...] = dr.texture(dout[None, ...] * 0.25, v[None, ...].contiguous(), filter_mode='linear', boundary_mode='cube') + return out + +###################################################################################### +# Split-sum environment map light source with automatic mipmap generation +###################################################################################### + +class EnvironmentLight(torch.nn.Module): + LIGHT_MIN_RES = 16 + + MIN_ROUGHNESS = 0.08 + MAX_ROUGHNESS = 0.5 + + def __init__(self, base): + super(EnvironmentLight, self).__init__() + self.mtx = None + self.base = torch.nn.Parameter(base.clone().detach(), requires_grad=True) + self.register_parameter('env_base', self.base) + + def xfm(self, mtx): + self.mtx = mtx + + def clone(self): + return EnvironmentLight(self.base.clone().detach()) + + def clamp_(self, min=None, max=None): + self.base.clamp_(min, max) + + def get_mip(self, roughness): + return torch.where(roughness < self.MAX_ROUGHNESS + , (torch.clamp(roughness, self.MIN_ROUGHNESS, self.MAX_ROUGHNESS) - self.MIN_ROUGHNESS) / (self.MAX_ROUGHNESS - self.MIN_ROUGHNESS) * (len(self.specular) - 2) + , (torch.clamp(roughness, self.MAX_ROUGHNESS, 1.0) - self.MAX_ROUGHNESS) / (1.0 - self.MAX_ROUGHNESS) + len(self.specular) - 2) + + def build_mips(self, cutoff=0.99): + self.specular = [self.base] + while self.specular[-1].shape[1] > self.LIGHT_MIN_RES: + self.specular += [cubemap_mip.apply(self.specular[-1])] + + self.diffuse = ru.diffuse_cubemap(self.specular[-1]) + + for idx in range(len(self.specular) - 1): + roughness = (idx / (len(self.specular) - 2)) * (self.MAX_ROUGHNESS - self.MIN_ROUGHNESS) + self.MIN_ROUGHNESS + self.specular[idx] = ru.specular_cubemap(self.specular[idx], roughness, cutoff) + self.specular[-1] = ru.specular_cubemap(self.specular[-1], 1.0, cutoff) + + def regularizer(self): + white = (self.base[..., 0:1] + self.base[..., 1:2] + self.base[..., 2:3]) / 3.0 + return torch.mean(torch.abs(self.base - white)) + + def shade(self, gb_pos, gb_normal, kd, ks, view_pos, specular=True): + wo = util.safe_normalize(view_pos - gb_pos) + + if specular: + roughness = ks[..., 1:2] # y component + metallic = ks[..., 2:3] # z component + spec_col = (1.0 - metallic)*0.04 + kd * metallic + diff_col = kd * (1.0 - metallic) + else: + diff_col = kd + + reflvec = util.safe_normalize(util.reflect(wo, gb_normal)) + nrmvec = gb_normal + if self.mtx is not None: # Rotate lookup + mtx = torch.as_tensor(self.mtx, dtype=torch.float32, device='cuda') + reflvec = ru.xfm_vectors(reflvec.view(reflvec.shape[0], reflvec.shape[1] * reflvec.shape[2], reflvec.shape[3]), mtx).view(*reflvec.shape) + nrmvec = ru.xfm_vectors(nrmvec.view(nrmvec.shape[0], nrmvec.shape[1] * nrmvec.shape[2], 
nrmvec.shape[3]), mtx).view(*nrmvec.shape) + + # Diffuse lookup + diffuse = dr.texture(self.diffuse[None, ...], nrmvec.contiguous(), filter_mode='linear', boundary_mode='cube') + shaded_col = diffuse * diff_col + + if specular: + # Lookup FG term from lookup texture + NdotV = torch.clamp(util.dot(wo, gb_normal), min=1e-4) + fg_uv = torch.cat((NdotV, roughness), dim=-1) + if not hasattr(self, '_FG_LUT'): + self._FG_LUT = torch.as_tensor(np.fromfile('data/irrmaps/bsdf_256_256.bin', dtype=np.float32).reshape(1, 256, 256, 2), dtype=torch.float32, device='cuda') + fg_lookup = dr.texture(self._FG_LUT, fg_uv, filter_mode='linear', boundary_mode='clamp') + + # Roughness adjusted specular env lookup + miplevel = self.get_mip(roughness) + spec = dr.texture(self.specular[0][None, ...], reflvec.contiguous(), mip=list(m[None, ...] for m in self.specular[1:]), mip_level_bias=miplevel[..., 0], filter_mode='linear-mipmap-linear', boundary_mode='cube') + + # Compute aggregate lighting + reflectance = spec_col * fg_lookup[...,0:1] + fg_lookup[...,1:2] + shaded_col += spec * reflectance + + return shaded_col * (1.0 - ks[..., 0:1]) # Modulate by hemisphere visibility + +###################################################################################### +# Load and store +###################################################################################### + +# Load from latlong .HDR file +def _load_env_hdr(fn, scale=1.0): + latlong_img = torch.tensor(util.load_image(fn), dtype=torch.float32, device='cuda')*scale + cubemap = util.latlong_to_cubemap(latlong_img, [512, 512]) + + l = EnvironmentLight(cubemap) + l.build_mips() + + return l + +def load_env(fn, scale=1.0): + if os.path.splitext(fn)[1].lower() == ".hdr": + return _load_env_hdr(fn, scale) + else: + assert False, "Unknown envlight extension %s" % os.path.splitext(fn)[1] + +def save_env_map(fn, light): + assert isinstance(light, EnvironmentLight), "Can only save EnvironmentLight currently" + if isinstance(light, EnvironmentLight): + color = util.cubemap_to_latlong(light.base, [512, 1024]) + util.save_image_raw(fn, color.detach().cpu().numpy()) + +###################################################################################### +# Create trainable env map with random initialization +###################################################################################### + +def create_trainable_env_rnd(base_res, scale=0.5, bias=0.25): + base = torch.rand(6, base_res, base_res, 3, dtype=torch.float32, device='cuda') * scale + bias + return EnvironmentLight(base) + + +###################################################################################### +# Directional light source +###################################################################################### + +class DirectionalLight(torch.nn.Module): + def __init__(self, mlp_in, mlp_layers, mlp_hidden_size, intensity_min_max=None): + super(DirectionalLight, self).__init__() + self.mlp = MLP(mlp_in, 4, mlp_layers, nf=mlp_hidden_size, activation='sigmoid') + if intensity_min_max is not None: + self.register_buffer('intensity_min_max', intensity_min_max) + else: + self.intensity_min_max = None + + def forward(self, feat): + # print('----------------- forward light !!! 
-----------------') + out = self.mlp(feat) + light_dir = F.normalize(torch.cat([out[..., 0:1] *2-1, torch.ones_like(out[..., :1]) * 0.5, out[..., 1:2] *2-1], dim=-1), dim=-1) # upper hemisphere + if self.intensity_min_max is not None: + int = out[..., 2:] * (self.intensity_min_max[1][None, :] - self.intensity_min_max[0][None, :]) + self.intensity_min_max[0][None, :] + self.light_params = torch.cat([light_dir, int], -1) + return self.light_params + + def shade(self, feat, kd, normal): + light_params = self.forward(feat) + light_dir = light_params[..., :3][:, None, None, :] + int_amb = light_params[..., 3:4][:, None, None, :] + int_diff = light_params[..., 4:5][:, None, None, :] + shading = (int_amb + int_diff * torch.clamp(util.dot(light_dir, normal), min=0.0)) + shaded = shading * kd + return shaded, shading diff --git a/video3d/render/material.py b/video3d/render/material.py new file mode 100755 index 0000000000000000000000000000000000000000..5e812404d17374f131b76cae67242a03b15e0b16 --- /dev/null +++ b/video3d/render/material.py @@ -0,0 +1,282 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import os +import numpy as np +import torch +import nvdiffrast.torch as dr +import cv2 + +from video3d.render.render import render_uv + +from . import util +from . import texture +from . import mlptexture +from ..utils import misc + +###################################################################################### +# Wrapper to make materials behave like a python dict, but register textures as +# torch.nn.Module parameters. +###################################################################################### +class Material(torch.nn.Module): + def __init__(self, mat_dict): + super(Material, self).__init__() + self.mat_keys = set() + for key in mat_dict.keys(): + self.mat_keys.add(key) + self[key] = mat_dict[key] + + def __contains__(self, key): + return hasattr(self, key) + + def __getitem__(self, key): + return getattr(self, key) + + def __setitem__(self, key, val): + self.mat_keys.add(key) + setattr(self, key, val) + + def __delitem__(self, key): + self.mat_keys.remove(key) + delattr(self, key) + + def keys(self): + return self.mat_keys + +###################################################################################### +# .mtl material format loading / storing +###################################################################################### +@torch.no_grad() +def load_mtl(fn, clear_ks=True): + import re + mtl_path = os.path.dirname(fn) + + # Read file + with open(fn, 'r') as f: + lines = f.readlines() + + # Parse materials + materials = [] + for line in lines: + split_line = re.split(' +|\t+|\n+', line.strip()) + prefix = split_line[0].lower() + data = split_line[1:] + if 'newmtl' in prefix: + material = Material({'name' : data[0]}) + materials += [material] + elif materials: + if 'bsdf' in prefix or 'map_kd' in prefix or 'map_ks' in prefix or 'bump' in prefix: + material[prefix] = data[0] + else: + material[prefix] = torch.tensor(tuple(float(d) for d in data), dtype=torch.float32, device='cuda') + + # Convert everything to textures. 
Our code expects 'kd' and 'ks' to be texture maps. So replace constants with 1x1 maps + for mat in materials: + if not 'bsdf' in mat: + mat['bsdf'] = 'pbr' + + if 'map_kd' in mat: + mat['kd'] = texture.load_texture2D(os.path.join(mtl_path, mat['map_kd'])) + else: + mat['kd'] = texture.Texture2D(mat['kd']) + + if 'map_ks' in mat: + mat['ks'] = texture.load_texture2D(os.path.join(mtl_path, mat['map_ks']), channels=3) + else: + mat['ks'] = texture.Texture2D(mat['ks']) + + if 'bump' in mat: + mat['normal'] = texture.load_texture2D(os.path.join(mtl_path, mat['bump']), lambda_fn=lambda x: x * 2 - 1, channels=3) + + # Convert Kd from sRGB to linear RGB + mat['kd'] = texture.srgb_to_rgb(mat['kd']) + + if clear_ks: + # Override ORM occlusion (red) channel by zeros. We hijack this channel + for mip in mat['ks'].getMips(): + mip[..., 0] = 0.0 + + return materials + +@torch.no_grad() +def save_mtl(fn, material, mesh=None, feat=None, resolution=[256, 256], prior_shape=None): + folder = os.path.dirname(fn) + file = os.path.basename(fn) + prefix = '_'.join(file.split('_')[:-1]) + '_' + with open(fn, "w") as f: + f.write('newmtl defaultMat\n') + if material is not None: + f.write('bsdf %s\n' % material['bsdf']) + if 'kd_ks_normal' in material.keys(): + assert mesh is not None + glctx = dr.RasterizeGLContext() + mask, kd, ks, normal = render_uv(glctx, mesh, resolution, material['kd_ks_normal'], feat=feat, prior_shape=prior_shape) + + hole_mask = 1. - mask + hole_mask = hole_mask.int()[0] + def uv_padding(image): + uv_padding_size = 4 + inpaint_image = ( + cv2.inpaint( + (image.detach().cpu().numpy() * 255).astype(np.uint8), + (hole_mask.detach().cpu().numpy() * 255).astype(np.uint8), + uv_padding_size, + cv2.INPAINT_TELEA, + ) + / 255.0 + ) + return torch.from_numpy(inpaint_image).to(image) + + kd = uv_padding(kd[0])[None] + + batch_size = kd.shape[0] + f.write(f'map_Kd {prefix}texture_kd.png\n') + misc.save_images(folder, kd.permute(0,3,1,2).detach().cpu().numpy(), fnames=[prefix + "texture_kd"] * batch_size) + f.write(f'map_Ks {prefix}texture_ks.png\n') + misc.save_images(folder, ks.permute(0,3,1,2).detach().cpu().numpy(), fnames=[prefix + "texture_ks"] * batch_size) + # disable normal + # f.write(f'bump {prefix}texture_n.png\n') + # misc.save_images(folder, normal.permute(0,3,1,2).detach().cpu().numpy(), fnames=[prefix + "texture_n"] * batch_size) + if 'kd' in material.keys(): + f.write('map_Kd texture_kd.png\n') + texture.save_texture2D(os.path.join(folder, 'texture_Kd.png'), texture.rgb_to_srgb(material['kd'])) + if 'ks' in material.keys(): + f.write('map_Ks texture_ks.png\n') + texture.save_texture2D(os.path.join(folder, 'texture_Ks.png'), material['ks']) + if 'normal' in material.keys(): + f.write('bump texture_n.png\n') + texture.save_texture2D(os.path.join(folder, 'texture_n.png'), material['normal'], lambda_fn=lambda x:(util.safe_normalize(x)+1)*0.5) + else: + f.write('Kd 1 1 1\n') + f.write('Ks 0 0 0\n') + f.write('Ka 0 0 0\n') + f.write('Tf 1 1 1\n') + f.write('Ni 1\n') + f.write('Ns 0\n') + +###################################################################################### +# Merge multiple materials into a single uber-material +###################################################################################### + +def _upscale_replicate(x, full_res): + x = x.permute(0, 3, 1, 2) + x = torch.nn.functional.pad(x, (0, full_res[1] - x.shape[3], 0, full_res[0] - x.shape[2]), 'replicate') + return x.permute(0, 2, 3, 1).contiguous() + +def merge_materials(materials, texcoords, tfaces, mfaces): 
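+ # Merge several materials into one uber-material by laying their textures out
+ # side by side and remapping the texture coordinates accordingly. Illustrative
+ # usage sketch, mirroring the call site in obj.load_obj later in this patch:
+ #   uber_material, texcoords, tfaces = merge_materials(used_materials, texcoords, tfaces, mfaces)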
+ assert len(materials) > 0 + for mat in materials: + assert mat['bsdf'] == materials[0]['bsdf'], "All materials must have the same BSDF (uber shader)" + assert ('normal' in mat) is ('normal' in materials[0]), "All materials must have either normal map enabled or disabled" + + uber_material = Material({ + 'name' : 'uber_material', + 'bsdf' : materials[0]['bsdf'], + }) + + textures = ['kd', 'ks', 'normal'] + + # Find maximum texture resolution across all materials and textures + max_res = None + for mat in materials: + for tex in textures: + tex_res = np.array(mat[tex].getRes()) if tex in mat else np.array([1, 1]) + max_res = np.maximum(max_res, tex_res) if max_res is not None else tex_res + + # Compute size of compund texture and round up to nearest PoT + full_res = 2**np.ceil(np.log2(max_res * np.array([1, len(materials)]))).astype(np.int) + + # Normalize texture resolution across all materials & combine into a single large texture + for tex in textures: + if tex in materials[0]: + tex_data = torch.cat(tuple(util.scale_img_nhwc(mat[tex].data, tuple(max_res)) for mat in materials), dim=2) # Lay out all textures horizontally, NHWC so dim2 is x + tex_data = _upscale_replicate(tex_data, full_res) + uber_material[tex] = texture.Texture2D(tex_data) + + # Compute scaling values for used / unused texture area + s_coeff = [full_res[0] / max_res[0], full_res[1] / max_res[1]] + + # Recompute texture coordinates to cooincide with new composite texture + new_tverts = {} + new_tverts_data = [] + for fi in range(len(tfaces)): + matIdx = mfaces[fi] + for vi in range(3): + ti = tfaces[fi][vi] + if not (ti in new_tverts): + new_tverts[ti] = {} + if not (matIdx in new_tverts[ti]): # create new vertex + new_tverts_data.append([(matIdx + texcoords[ti][0]) / s_coeff[1], texcoords[ti][1] / s_coeff[0]]) # Offset texture coodrinate (x direction) by material id & scale to local space. 
Note, texcoords are (u,v) but texture is stored (w,h) so the indexes swap here + new_tverts[ti][matIdx] = len(new_tverts_data) - 1 + tfaces[fi][vi] = new_tverts[ti][matIdx] # reindex vertex + + return uber_material, new_tverts_data, tfaces + +###################################################################################### +# Utility functions for material +###################################################################################### + +def initial_guess_material(cfgs, mlp=False, init_mat=None, tet_bbox=None): + kd_min = torch.tensor(cfgs.get('kd_min', [0., 0., 0., 0.]), dtype=torch.float32) + kd_max = torch.tensor(cfgs.get('kd_max', [1., 1., 1., 1.]), dtype=torch.float32) + ks_min = torch.tensor(cfgs.get('ks_min', [0., 0., 0.]), dtype=torch.float32) + ks_max = torch.tensor(cfgs.get('ks_max', [0., 0., 0.]), dtype=torch.float32) + nrm_min = torch.tensor(cfgs.get('nrm_min', [-1., -1., 0.]), dtype=torch.float32) + nrm_max = torch.tensor(cfgs.get('nrm_max', [1., 1., 1.]), dtype=torch.float32) + if mlp: + num_layers = cfgs.get("num_layers_tex", 5) + nf = cfgs.get("hidden_size", 128) + enable_encoder = cfgs.get("enable_encoder", False) + feat_dim = cfgs.get("latent_dim", 64) if enable_encoder else 0 + + mlp_min = torch.cat((kd_min[0:3], ks_min, nrm_min), dim=0) + mlp_max = torch.cat((kd_max[0:3], ks_max, nrm_max), dim=0) + min_max = torch.stack((mlp_min, mlp_max), dim=0) + out_chn = 9 + mlp_map_opt = mlptexture.MLPTexture3D(tet_bbox, channels=out_chn, internal_dims=nf, hidden=num_layers-1, feat_dim=feat_dim, min_max=min_max) + mat = Material({'kd_ks_normal' : mlp_map_opt}) + else: + # Setup Kd (albedo) and Ks (x, roughness, metalness) textures + if cfgs.random_textures or init_mat is None: + num_channels = 4 if cfgs.layers > 1 else 3 + kd_init = torch.rand(size=cfgs.texture_res + [num_channels]) * (kd_max - kd_min)[None, None, 0:num_channels] + kd_min[None, None, 0:num_channels] + kd_map_opt = texture.create_trainable(kd_init , cfgs.texture_res, not cfgs.custom_mip, [kd_min, kd_max]) + + ksR = np.random.uniform(size=cfgs.texture_res + [1], low=0.0, high=0.01) + ksG = np.random.uniform(size=cfgs.texture_res + [1], low=ks_min[1].cpu(), high=ks_max[1].cpu()) + ksB = np.random.uniform(size=cfgs.texture_res + [1], low=ks_min[2].cpu(), high=ks_max[2].cpu()) + + ks_map_opt = texture.create_trainable(np.concatenate((ksR, ksG, ksB), axis=2), cfgs.texture_res, not cfgs.custom_mip, [ks_min, ks_max]) + else: + kd_map_opt = texture.create_trainable(init_mat['kd'], cfgs.texture_res, not cfgs.custom_mip, [kd_min, kd_max]) + ks_map_opt = texture.create_trainable(init_mat['ks'], cfgs.texture_res, not cfgs.custom_mip, [ks_min, ks_max]) + + # Setup normal map + if cfgs.random_textures or init_mat is None or 'normal' not in init_mat: + normal_map_opt = texture.create_trainable(np.array([0, 0, 1]), cfgs.texture_res, not cfgs.custom_mip, [nrm_min, nrm_max]) + else: + normal_map_opt = texture.create_trainable(init_mat['normal'], cfgs.texture_res, not cfgs.custom_mip, [nrm_min, nrm_max]) + + mat = Material({ + 'kd' : kd_map_opt, + 'ks' : ks_map_opt, + 'normal' : normal_map_opt + }) + + if init_mat is not None: + mat['bsdf'] = init_mat['bsdf'] + elif "bsdf" in cfgs: + mat['bsdf'] = cfgs["bsdf"] + else: + mat['bsdf'] = 'pbr' + + if not cfgs.get("perturb_normal", False): + mat['no_perturbed_nrm'] = True + + return mat \ No newline at end of file diff --git a/video3d/render/mesh.py b/video3d/render/mesh.py new file mode 100755 index 
0000000000000000000000000000000000000000..902bfbd594800ad9573bf2e904de89bf3b1c5927 --- /dev/null +++ b/video3d/render/mesh.py @@ -0,0 +1,377 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +from difflib import unified_diff +import os +import numpy as np +import torch + +from . import obj +from . import util + +######################################################################################### +# Base mesh class +# +# Minibatch in mesh is supported, as long as each mesh shares the same edge connectivity. +######################################################################################### +class Mesh: + def __init__(self, + v_pos=None, + t_pos_idx=None, + v_nrm=None, + t_nrm_idx=None, + v_tex=None, + t_tex_idx=None, + v_tng=None, + t_tng_idx=None, + material=None, + base=None): + self.v_pos = v_pos + self.v_nrm = v_nrm + self.v_tex = v_tex + self.v_tng = v_tng + self.t_pos_idx = t_pos_idx + self.t_nrm_idx = t_nrm_idx + self.t_tex_idx = t_tex_idx + self.t_tng_idx = t_tng_idx + self.material = material + + if base is not None: + self.copy_none(base) + + def __len__(self): + return len(self.v_pos) + + def copy_none(self, other): + if self.v_pos is None: + self.v_pos = other.v_pos + if self.t_pos_idx is None: + self.t_pos_idx = other.t_pos_idx + if self.v_nrm is None: + self.v_nrm = other.v_nrm + if self.t_nrm_idx is None: + self.t_nrm_idx = other.t_nrm_idx + if self.v_tex is None: + self.v_tex = other.v_tex + if self.t_tex_idx is None: + self.t_tex_idx = other.t_tex_idx + if self.v_tng is None: + self.v_tng = other.v_tng + if self.t_tng_idx is None: + self.t_tng_idx = other.t_tng_idx + if self.material is None: + self.material = other.material + + def clone(self): + out = Mesh(base=self) + if out.v_pos is not None: + out.v_pos = out.v_pos.clone().detach() + if out.t_pos_idx is not None: + out.t_pos_idx = out.t_pos_idx.clone().detach() + if out.v_nrm is not None: + out.v_nrm = out.v_nrm.clone().detach() + if out.t_nrm_idx is not None: + out.t_nrm_idx = out.t_nrm_idx.clone().detach() + if out.v_tex is not None: + out.v_tex = out.v_tex.clone().detach() + if out.t_tex_idx is not None: + out.t_tex_idx = out.t_tex_idx.clone().detach() + if out.v_tng is not None: + out.v_tng = out.v_tng.clone().detach() + if out.t_tng_idx is not None: + out.t_tng_idx = out.t_tng_idx.clone().detach() + return out + + def detach(self): + return self.clone() + + def extend(self, N: int): + """ + Create new Mesh class which contains each input mesh N times. + + Args: + N: number of new copies of each mesh. + + Returns: + new Mesh object. + """ + verts = self.v_pos.repeat(N, 1, 1) + faces = self.t_pos_idx + uvs = self.v_tex.repeat(N, 1, 1) + uv_idx = self.t_tex_idx + mat = self.material + + return make_mesh(verts, faces, uvs, uv_idx, self.material) + + def deform(self, deformation): + """ + Create new Mesh class which is obtained by performing the deformation to the self. + + Args: + deformation: tensor with shape (B, V, 3) + + Returns: + new Mesh object after the deformation. 
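+ 
+        Example (illustrative; B and V denote batch size and vertex count):
+            offsets = torch.zeros(B, V, 3, device=mesh.v_pos.device)
+            deformed = mesh.deform(offsets)  # zero offsets: same geometry, returned as a new Mesh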
+ """ + assert deformation.shape[1] == self.v_pos.shape[1] and deformation.shape[2] == 3 + verts = self.v_pos + deformation + return make_mesh(verts, self.t_pos_idx, self.v_tex.repeat(len(verts), 1, 1), self.t_tex_idx, self.material) + + def get_m_to_n(self, m: int, n: int): + """ + Create new Mesh class with the n-th (included) mesh to the m-th (not included) mesh in the batch. + + Args: + m: the index of the starting mesh to be contained. + n: the index of the first mesh not to be contained. + """ + verts = self.v_pos[m:n, ...] + faces = self.t_pos_idx + uvs = self.v_tex[m:n, ...] + uv_idx = self.t_tex_idx + mat = self.material + + return make_mesh(verts, faces, uvs, uv_idx, mat) + + def first_n(self, n: int): + """ + Create new Mesh class with only the first n meshes in the batch. + + Args: + n: number of meshes to be contained. + + Returns: + new Mesh object with the first n meshes. + """ + return self.get_m_to_n(0, n) + verts = self.v_pos[:n, ...] + faces = self.t_pos_idx + uvs = self.v_tex[:n, ...] + uv_idx = self.t_tex_idx + mat = self.material + + return make_mesh(verts, faces, uvs, uv_idx, mat) + + def get_n(self, n: int): + """ + Create new Mesh class with only the n-th meshes in the batch. + + Args: + n: the index of the mesh to be contained. + + Returns: + new Mesh object with the n-th mesh. + """ + verts = self.v_pos[n:n+1, ...] + faces = self.t_pos_idx + uvs = self.v_tex[n:n+1, ...] + uv_idx = self.t_tex_idx + mat = self.material + + return make_mesh(verts, faces, uvs, uv_idx, mat) + + +###################################################################################### +# Mesh loading helper +###################################################################################### +def load_mesh(filename, mtl_override=None): + name, ext = os.path.splitext(filename) + if ext == ".obj": + return obj.load_obj(filename, clear_ks=True, mtl_override=mtl_override) + assert False, "Invalid mesh file extension" + +###################################################################################### +# Compute AABB +###################################################################################### +def aabb(mesh): + return torch.min(mesh.v_pos, dim=0).values, torch.max(mesh.v_pos, dim=0).values + +###################################################################################### +# Compute unique edge list from attribute/vertex index list +###################################################################################### +def compute_edges(attr_idx, return_inverse=False): + with torch.no_grad(): + # Create all edges, packed by triangle + idx = attr_idx[0] + all_edges = torch.cat(( + torch.stack((idx[:, 0], idx[:, 1]), dim=-1), + torch.stack((idx[:, 1], idx[:, 2]), dim=-1), + torch.stack((idx[:, 2], idx[:, 0]), dim=-1), + ), dim=-1).view(-1, 2) + + # Swap edge order so min index is always first + order = (all_edges[:, 0] > all_edges[:, 1]).long().unsqueeze(dim=1) + sorted_edges = torch.cat(( + torch.gather(all_edges, 1, order), + torch.gather(all_edges, 1, 1 - order) + ), dim=-1) + + # Eliminate duplicates and return inverse mapping + return torch.unique(sorted_edges, dim=0, return_inverse=return_inverse) + +###################################################################################### +# Compute unique edge to face mapping from attribute/vertex index list +###################################################################################### +def compute_edge_to_face_mapping(attr_idx, return_inverse=False): + with torch.no_grad(): + # Get unique edges + # Create 
all edges, packed by triangle + idx = attr_idx[0] + all_edges = torch.cat(( + torch.stack((idx[:, 0], idx[:, 1]), dim=-1), + torch.stack((idx[:, 1], idx[:, 2]), dim=-1), + torch.stack((idx[:, 2], idx[:, 0]), dim=-1), + ), dim=-1).view(-1, 2) + + # Swap edge order so min index is always first + order = (all_edges[:, 0] > all_edges[:, 1]).long().unsqueeze(dim=1) + sorted_edges = torch.cat(( + torch.gather(all_edges, 1, order), + torch.gather(all_edges, 1, 1 - order) + ), dim=-1) + + # Elliminate duplicates and return inverse mapping + unique_edges, idx_map = torch.unique(sorted_edges, dim=0, return_inverse=True) + + tris = torch.arange(idx.shape[0]).repeat_interleave(3).cuda() + + tris_per_edge = torch.zeros((unique_edges.shape[0], 2), dtype=torch.int64).cuda() + + # Compute edge to face table + mask0 = order[:,0] == 0 + mask1 = order[:,0] == 1 + tris_per_edge[idx_map[mask0], 0] = tris[mask0] + tris_per_edge[idx_map[mask1], 1] = tris[mask1] + + return tris_per_edge + +###################################################################################### +# Align base mesh to reference mesh:move & rescale to match bounding boxes. +###################################################################################### +def unit_size(mesh): + with torch.no_grad(): + vmin, vmax = aabb(mesh) + scale = 2 / torch.max(vmax - vmin).item() + v_pos = mesh.v_pos - (vmax + vmin) / 2 # Center mesh on origin + v_pos = v_pos * scale # Rescale to unit size + + return Mesh(v_pos, base=mesh) + +###################################################################################### +# Center & scale mesh for rendering +###################################################################################### +def center_by_reference(base_mesh, ref_aabb, scale): + center = (ref_aabb[0] + ref_aabb[1]) * 0.5 + scale = scale / torch.max(ref_aabb[1] - ref_aabb[0]).item() + v_pos = (base_mesh.v_pos - center[None, ...]) * scale + return Mesh(v_pos, base=base_mesh) + +###################################################################################### +# Simple smooth vertex normal computation +###################################################################################### +def auto_normals(imesh): + batch_size = imesh.v_pos.shape[0] + + i0 = imesh.t_pos_idx[0, :, 0] # Shape: (F) + i1 = imesh.t_pos_idx[0, :, 1] # Shape: (F) + i2 = imesh.t_pos_idx[0, :, 2] # Shape: (F) + + v0 = imesh.v_pos[:, i0, :] # Shape: (B, F, 3) + v1 = imesh.v_pos[:, i1, :] # Shape: (B, F, 3) + v2 = imesh.v_pos[:, i2, :] # Shape: (B, F, 3) + + face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1) # Shape: (B, F, 3) + + # Splat face normals to vertices + v_nrm = torch.zeros_like(imesh.v_pos) # Shape: (B, V, 3) + v_nrm.scatter_add_(1, i0[None, :, None].repeat(batch_size, 1, 3), face_normals) + v_nrm.scatter_add_(1, i1[None, :, None].repeat(batch_size, 1, 3), face_normals) + v_nrm.scatter_add_(1, i2[None, :, None].repeat(batch_size, 1, 3), face_normals) + + # Normalize, replace zero (degenerated) normals with some default value + v_nrm = torch.where(util.dot(v_nrm, v_nrm) > 1e-20, + v_nrm, torch.tensor([0.0, 0.0, 1.0], + dtype=torch.float32, device='cuda')) + v_nrm = util.safe_normalize(v_nrm) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(v_nrm)) + + return Mesh(v_nrm=v_nrm, t_nrm_idx=imesh.t_pos_idx, base=imesh) + +###################################################################################### +# Compute tangent space from texture map coordinates +# Follows http://www.mikktspace.com/ conventions 
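+# For each triangle, the face tangent computed below is
+#   T = (pe1 * uve2.y - pe2 * uve1.y) / (uve1.x * uve2.y - uve1.y * uve2.x),
+# i.e. the position edge vectors expressed in UV space. Per-vertex tangents are the
+# average of the incident face tangents, then re-orthogonalized against the vertex
+# normal (Gram-Schmidt) before normalization.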
+###################################################################################### +def compute_tangents(imesh): + batch_size = imesh.v_pos.shape[0] + + vn_idx = [None] * 3 + pos = [None] * 3 + tex = [None] * 3 + for i in range(0,3): + pos[i] = imesh.v_pos[:, imesh.t_pos_idx[0, :, i]] + tex[i] = imesh.v_tex[:, imesh.t_tex_idx[0, :, i]] + vn_idx[i] = imesh.t_nrm_idx[..., i:i+1] + + tangents = torch.zeros_like(imesh.v_nrm) + tansum = torch.zeros_like(imesh.v_nrm) + + # Compute tangent space for each triangle + uve1 = tex[1] - tex[0] # Shape: (B, F, 2) + uve2 = tex[2] - tex[0] # Shape: (B, F, 2) + pe1 = pos[1] - pos[0] # Shape: (B, F, 3) + pe2 = pos[2] - pos[0] # Shape: (B, F, 3) + + nom = pe1 * uve2[..., 1:2] - pe2 * uve1[..., 1:2] # Shape: (B, F, 3) + denom = uve1[..., 0:1] * uve2[..., 1:2] - uve1[..., 1:2] * uve2[..., 0:1] # Shape: (B, F, 1) + + # Avoid division by zero for degenerated texture coordinates + tang = nom / torch.where(denom > 0.0, torch.clamp(denom, min=1e-6), torch.clamp(denom, max=-1e-6)) # Shape: (B, F, 3) + + # Update all 3 vertices + for i in range(0,3): + idx = vn_idx[i].repeat(batch_size, 1, 3) # Shape: (B, F, 3) + tangents.scatter_add_(1, idx, tang) # tangents[n_i] = tangents[n_i] + tang + tansum.scatter_add_(1, idx, torch.ones_like(tang)) # tansum[n_i] = tansum[n_i] + 1 + tangents = tangents / tansum + + # Normalize and make sure tangent is perpendicular to normal + tangents = util.safe_normalize(tangents) + tangents = util.safe_normalize(tangents - util.dot(tangents, imesh.v_nrm) * imesh.v_nrm) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(tangents)) + + return Mesh(v_tng=tangents, t_tng_idx=imesh.t_nrm_idx, base=imesh) + +###################################################################################### +# Create new Mesh from verts, faces, uvs, and uv_idx. The rest is auto computed. +###################################################################################### +def make_mesh(verts, faces, uvs, uv_idx, material): + """ + Create new Mesh class with given verts, faces, uvs, and uv_idx. + + Args: + verts: tensor of shape (B, V, 3) + faces: tensor of shape (1, F, 3) + uvs: tensor of shape (B, V, 2) + uv_idx: tensor of shape (1, F, 3) + material: an Material instance, specifying the material of the mesh. + + Returns: + new Mesh object. + """ + assert len(verts.shape) == 3 and len(faces.shape) == 3 and len(uvs.shape) == 3 and len(uv_idx.shape) == 3, "All components must be batched." + assert faces.shape[0] == 1 and uv_idx.shape[0] == 1, "Every mesh must share the same edge connectivity." + assert verts.shape[0] == uvs.shape[0], "Batch size must be consistent." + ret = Mesh(verts, faces, v_tex=uvs, t_tex_idx=uv_idx, material=material) + ret = auto_normals(ret) + ret = compute_tangents(ret) + return ret diff --git a/video3d/render/mlptexture.py b/video3d/render/mlptexture.py new file mode 100755 index 0000000000000000000000000000000000000000..5df6ea3fcb29043806a5cf9d3d1511da0ea487f9 --- /dev/null +++ b/video3d/render/mlptexture.py @@ -0,0 +1,122 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. 
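+#
+# This module defines a volumetric MLP texture: 3D query points are encoded with a
+# tiny-cuda-nn hash grid and decoded by a small fully-connected network into per-point
+# material values (e.g. kd/ks/normal), optionally conditioned on an instance feature vector.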
+ +import torch +import tinycudann as tcnn +import numpy as np + +####################################################################################################################################################### +# Small MLP using PyTorch primitives, internal helper class +####################################################################################################################################################### + +class _MLP(torch.nn.Module): + def __init__(self, cfg, loss_scale=1.0): + super(_MLP, self).__init__() + self.loss_scale = loss_scale + net = (torch.nn.Linear(cfg['n_input_dims'], cfg['n_neurons'], bias=False), torch.nn.ReLU()) + for i in range(cfg['n_hidden_layers']-1): + net = net + (torch.nn.Linear(cfg['n_neurons'], cfg['n_neurons'], bias=False), torch.nn.ReLU()) + net = net + (torch.nn.Linear(cfg['n_neurons'], cfg['n_output_dims'], bias=False),) + self.net = torch.nn.Sequential(*net).cuda() + + self.net.apply(self._init_weights) + + if self.loss_scale != 1.0: + self.net.register_full_backward_hook(lambda module, grad_i, grad_o: (grad_i[0] * self.loss_scale, )) + + def forward(self, x): + return self.net(x.to(torch.float32)) + + @staticmethod + def _init_weights(m): + if type(m) == torch.nn.Linear: + torch.nn.init.kaiming_uniform_(m.weight, nonlinearity='relu') + if hasattr(m.bias, 'data'): + m.bias.data.fill_(0.0) + +####################################################################################################################################################### +# Outward visible MLP class +####################################################################################################################################################### + +class MLPTexture3D(torch.nn.Module): + def __init__(self, AABB, channels=3, internal_dims=32, hidden=2, feat_dim=0, min_max=None, bsdf='diffuse', perturb_normal=False, symmetrize=False): + super(MLPTexture3D, self).__init__() + + self.channels = channels + self.feat_dim = feat_dim + self.internal_dims = internal_dims + self.AABB = AABB + self.bsdf = bsdf + self.perturb_normal = perturb_normal + self.symmetrize = symmetrize + if min_max is not None: + self.register_buffer('min_max', min_max) + else: + self.min_max = None + + # Setup positional encoding, see https://github.com/NVlabs/tiny-cuda-nn for details. 
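+ # The growth factor below makes the hash-grid resolution increase geometrically
+ # from base_grid_resolution to desired_resolution across num_levels levels:
+ #   per_level_scale = (desired_resolution / base_grid_resolution) ** (1 / (num_levels - 1))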
+ desired_resolution = 4096 + base_grid_resolution = 16 + num_levels = 16 + per_level_scale = np.exp(np.log(desired_resolution / base_grid_resolution) / (num_levels-1)) + + enc_cfg = { + "otype": "HashGrid", + "n_levels": num_levels, + "n_features_per_level": 2, + "log2_hashmap_size": 19, + "base_resolution": base_grid_resolution, + "per_level_scale" : per_level_scale + } + + # gradient_scaling = 128.0 + gradient_scaling = 1.0 + self.encoder = tcnn.Encoding(3, enc_cfg) + self.encoder.register_full_backward_hook(lambda module, grad_i, grad_o: (grad_i[0] / gradient_scaling, )) + + # Setup MLP + mlp_cfg = { + "n_input_dims" : internal_dims + feat_dim, + "n_output_dims" : self.channels, + "n_hidden_layers" : hidden, + "n_neurons" : self.internal_dims + } + self.linear = torch.nn.Linear(self.encoder.n_output_dims, internal_dims) + self.net = _MLP(mlp_cfg, gradient_scaling) + self.relu = torch.nn.ReLU(inplace=True) + print("Encoder output: %d dims" % (self.encoder.n_output_dims)) + + # Sample texture at a given location + def sample(self, texc, feat=None): + assert (feat is None and self.feat_dim == 0) or feat.shape[-1] == self.feat_dim + + if self.symmetrize: + xs, ys, zs = texc.unbind(-1) + texc = torch.stack([xs.abs(), ys, zs], -1) # mirror -x to +x + + _texc = (texc.view(-1, 3) - self.AABB[0][None, ...]) / (self.AABB[1][None, ...] - self.AABB[0][None, ...]) + _texc = torch.clamp(_texc, min=0, max=1) + + _, image_h, image_w, _ = texc.shape + p_enc = self.encoder(_texc.contiguous()) + x_in = self.linear(p_enc.type(texc.dtype)) + if feat is not None: + feat_in = feat[:, None, None, :].repeat(1, image_h, image_w, 1).view(-1, self.feat_dim) + x_in = torch.concat([x_in, feat_in], dim=-1) + out = self.net(self.relu(x_in)) + + # Sigmoid limit and scale to the allowed range + out = torch.sigmoid(out) + if self.min_max is not None: + out = out * (self.min_max[1][None, :] - self.min_max[0][None, :]) + self.min_max[0][None, :] + + return out.view(*texc.shape[:-1], self.channels) # Remap to [n, h, w, c] + + def cleanup(self): + tcnn.free_temporary_memory() diff --git a/video3d/render/obj.py b/video3d/render/obj.py new file mode 100755 index 0000000000000000000000000000000000000000..cdfa3e8a1ba602d1a79db05dcc13eb048f4dedd6 --- /dev/null +++ b/video3d/render/obj.py @@ -0,0 +1,288 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import os +import torch +import xatlas +import trimesh +import numpy as np +import cv2 +import nvdiffrast.torch as dr +from video3d.render.render import render_uv +from video3d.render.mesh import Mesh +from . import texture +from . import mesh +from . 
import material + +###################################################################################### +# Utility functions +###################################################################################### + +def _find_mat(materials, name): + for mat in materials: + if mat['name'] == name: + return mat + return materials[0] # Materials 0 is the default + +###################################################################################### +# Create mesh object from objfile +###################################################################################### + +def load_obj(filename, clear_ks=True, mtl_override=None): + obj_path = os.path.dirname(filename) + + # Read entire file + with open(filename, 'r') as f: + lines = f.readlines() + + # Load materials + all_materials = [ + { + 'name' : '_default_mat', + 'bsdf' : 'pbr', + 'kd' : texture.Texture2D(torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32, device='cuda')), + 'ks' : texture.Texture2D(torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device='cuda')) + } + ] + if mtl_override is None: + for line in lines: + if len(line.split()) == 0: + continue + if line.split()[0] == 'mtllib': + all_materials += material.load_mtl(os.path.join(obj_path, line.split()[1]), clear_ks) # Read in entire material library + else: + all_materials += material.load_mtl(mtl_override) + + # load vertices + vertices, texcoords, normals = [], [], [] + for line in lines: + if len(line.split()) == 0: + continue + + prefix = line.split()[0].lower() + if prefix == 'v': + vertices.append([float(v) for v in line.split()[1:]]) + elif prefix == 'vt': + val = [float(v) for v in line.split()[1:]] + texcoords.append([val[0], 1.0 - val[1]]) + elif prefix == 'vn': + normals.append([float(v) for v in line.split()[1:]]) + + # load faces + activeMatIdx = None + used_materials = [] + faces, tfaces, nfaces, mfaces = [], [], [], [] + for line in lines: + if len(line.split()) == 0: + continue + + prefix = line.split()[0].lower() + if prefix == 'usemtl': # Track used materials + mat = _find_mat(all_materials, line.split()[1]) + if not mat in used_materials: + used_materials.append(mat) + activeMatIdx = used_materials.index(mat) + elif prefix == 'f': # Parse face + vs = line.split()[1:] + nv = len(vs) + vv = vs[0].split('/') + v0 = int(vv[0]) - 1 + t0 = int(vv[1]) - 1 if vv[1] != "" else -1 + n0 = int(vv[2]) - 1 if vv[2] != "" else -1 + for i in range(nv - 2): # Triangulate polygons + vv = vs[i + 1].split('/') + v1 = int(vv[0]) - 1 + t1 = int(vv[1]) - 1 if vv[1] != "" else -1 + n1 = int(vv[2]) - 1 if vv[2] != "" else -1 + vv = vs[i + 2].split('/') + v2 = int(vv[0]) - 1 + t2 = int(vv[1]) - 1 if vv[1] != "" else -1 + n2 = int(vv[2]) - 1 if vv[2] != "" else -1 + mfaces.append(activeMatIdx) + faces.append([v0, v1, v2]) + tfaces.append([t0, t1, t2]) + nfaces.append([n0, n1, n2]) + assert len(tfaces) == len(faces) and len(nfaces) == len (faces) + + # Create an "uber" material by combining all textures into a larger texture + if len(used_materials) > 1: + uber_material, texcoords, tfaces = material.merge_materials(used_materials, texcoords, tfaces, mfaces) + else: + uber_material = used_materials[0] + + vertices = torch.tensor(vertices, dtype=torch.float32, device='cuda') + texcoords = torch.tensor(texcoords, dtype=torch.float32, device='cuda') if len(texcoords) > 0 else None + normals = torch.tensor(normals, dtype=torch.float32, device='cuda') if len(normals) > 0 else None + + faces = torch.tensor(faces, dtype=torch.int64, device='cuda') + tfaces = torch.tensor(tfaces, 
dtype=torch.int64, device='cuda') if texcoords is not None else None + nfaces = torch.tensor(nfaces, dtype=torch.int64, device='cuda') if normals is not None else None + + return mesh.Mesh(vertices, faces, normals, nfaces, texcoords, tfaces, material=uber_material) + +###################################################################################### +# Save mesh object to objfile +###################################################################################### + +def write_obj(folder, fname, mesh, idx, save_material=True, feat=None, resolution=[256, 256]): + obj_file = os.path.join(folder, fname + '.obj') + print("Writing mesh: ", obj_file) + with open(obj_file, "w") as f: + f.write(f"mtllib {fname}.mtl\n") + f.write("g default\n") + + v_pos = mesh.v_pos[idx].detach().cpu().numpy() if mesh.v_pos is not None else None + v_nrm = mesh.v_nrm[idx].detach().cpu().numpy() if mesh.v_nrm is not None else None + v_tex = mesh.v_tex[idx].detach().cpu().numpy() if mesh.v_tex is not None else None + + t_pos_idx = mesh.t_pos_idx[0].detach().cpu().numpy() if mesh.t_pos_idx is not None else None + t_nrm_idx = mesh.t_nrm_idx[0].detach().cpu().numpy() if mesh.t_nrm_idx is not None else None + t_tex_idx = mesh.t_tex_idx[0].detach().cpu().numpy() if mesh.t_tex_idx is not None else None + + print(" writing %d vertices" % len(v_pos)) + for v in v_pos: + f.write('v {} {} {} \n'.format(v[0], v[1], v[2])) + + if v_tex is not None and save_material: + print(" writing %d texcoords" % len(v_tex)) + assert(len(t_pos_idx) == len(t_tex_idx)) + for v in v_tex: + f.write('vt {} {} \n'.format(v[0], 1.0 - v[1])) + + if v_nrm is not None: + print(" writing %d normals" % len(v_nrm)) + assert(len(t_pos_idx) == len(t_nrm_idx)) + for v in v_nrm: + f.write('vn {} {} {}\n'.format(v[0], v[1], v[2])) + + # faces + f.write("s 1 \n") + f.write("g pMesh1\n") + f.write("usemtl defaultMat\n") + + # Write faces + print(" writing %d faces" % len(t_pos_idx)) + for i in range(len(t_pos_idx)): + f.write("f ") + for j in range(3): + f.write(' %s/%s/%s' % (str(t_pos_idx[i][j]+1), '' if v_tex is None else str(t_tex_idx[i][j]+1), '' if v_nrm is None else str(t_nrm_idx[i][j]+1))) + f.write("\n") + + if save_material and mesh.material is not None: + mtl_file = os.path.join(folder, fname + '.mtl') + print("Writing material: ", mtl_file) + material.save_mtl(mtl_file, mesh.material, mesh=mesh.get_n(idx), feat=feat, resolution=resolution) + + print("Done exporting mesh") + + +def write_textured_obj(folder, fname, mesh, idx, save_material=True, feat=None, resolution=[256, 256], prior_shape=None): + mesh = mesh.get_n(idx) + obj_file = os.path.join(folder, fname + '.obj') + print("Writing mesh: ", obj_file) + + # Create uvs with xatlas + v_pos = mesh.v_pos.detach().cpu().numpy() + t_pos_idx = mesh.t_pos_idx.detach().cpu().numpy() + + # v_color = torch.Tensor(v_pos)[None].to("cuda") + # v_color = mesh.material.sample(v_color, feat) + # v_color = v_color[0,0,:,:3].detach().cpu() + # v_color = torch.concat([v_color, torch.ones((v_color.shape[0], 1))], dim=-1) + # v_color = v_color.numpy() * 255 + # v_color = v_color.astype(np.int32) + # tmp = trimesh.Trimesh(vertices=v_pos[0], faces=t_pos_idx[0], vertex_colors=v_color) + # _ = tmp.export("tmp.obj") + # from pdb import set_trace; set_trace() + + atlas = xatlas.Atlas() + atlas.add_mesh( + v_pos[0], + t_pos_idx[0], + ) + co = xatlas.ChartOptions() + po = xatlas.PackOptions() + # for k, v in xatlas_chart_options.items(): + # setattr(co, k, v) + # for k, v in xatlas_pack_options.items(): + # setattr(po, 
k, v) + atlas.generate(co, po) + vmapping, indices, uvs = atlas.get_mesh(0) + # vmapping, indices, uvs = xatlas.parametrize(v_pos[0], t_pos_idx[0]) + + # Convert to tensors + indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64) + + uvs = torch.tensor(uvs, dtype=torch.float32, device='cuda') + faces = torch.tensor(indices_int64, dtype=torch.int64, device='cuda') + + # new_mesh = Mesh(v_tex=uvs, t_tex_idx=faces, base=mesh) + new_mesh = Mesh(v_tex=uvs[None], t_tex_idx=faces[None], base=mesh) + + # glctx = dr.RasterizeGLContext() + # mask, kd, ks, normal = render_uv(glctx, new_mesh, resolution, mesh.material, feat=feat) + + # kd_min, kd_max = torch.tensor([ 0.0, 0.0, 0.0, 0.0], dtype=torch.float32, device='cuda'), torch.tensor([ 1.0, 1.0, 1.0, 1.0], dtype=torch.float32, device='cuda') + # ks_min, ks_max = torch.tensor([ 0.0, 0.0, 0.0] , dtype=torch.float32, device='cuda'), torch.tensor([ 0.0, 0.0, 0.0] , dtype=torch.float32, device='cuda') + # nrm_min, nrm_max = torch.tensor([-1.0, -1.0, 0.0], dtype=torch.float32, device='cuda'), torch.tensor([ 1.0, 1.0, 1.0], dtype=torch.float32, device='cuda') + + new_mesh.material = material.Material({ + 'bsdf' : 'diffuse', + # 'kd' : texture.Texture2D(kd, min_max=[kd_min, kd_max]), + # 'ks' : texture.Texture2D(ks, min_max=[ks_min, ks_max]), + # 'normal' : texture.Texture2D(normal, min_max=[nrm_min, nrm_max]), + 'kd_ks_normal': mesh.material + }) + + with open(obj_file, "w") as f: + f.write(f"mtllib {fname}.mtl\n") + f.write("g default\n") + + v_pos = new_mesh.v_pos[idx].detach().cpu().numpy() if new_mesh.v_pos is not None else None + v_nrm = new_mesh.v_nrm[idx].detach().cpu().numpy() if new_mesh.v_nrm is not None else None + v_tex = new_mesh.v_tex[idx].detach().cpu().numpy() if new_mesh.v_tex is not None else None + + t_pos_idx = new_mesh.t_pos_idx[0].detach().cpu().numpy() if new_mesh.t_pos_idx is not None else None + t_nrm_idx = new_mesh.t_nrm_idx[0].detach().cpu().numpy() if new_mesh.t_nrm_idx is not None else None + t_tex_idx = new_mesh.t_tex_idx[0].detach().cpu().numpy() if new_mesh.t_tex_idx is not None else None + + print(" writing %d vertices" % len(v_pos)) + for v in v_pos: + f.write('v {} {} {} \n'.format(v[0], v[1], v[2])) + + if v_tex is not None and save_material: + print(" writing %d texcoords" % len(v_tex)) + assert(len(t_pos_idx) == len(t_tex_idx)) + for v in v_tex: + f.write('vt {} {} \n'.format(v[0], 1.0 - v[1])) + + if v_nrm is not None: + print(" writing %d normals" % len(v_nrm)) + assert(len(t_pos_idx) == len(t_nrm_idx)) + for v in v_nrm: + f.write('vn {} {} {}\n'.format(v[0], v[1], v[2])) + + # faces + f.write("s 1 \n") + f.write("g pMesh1\n") + f.write("usemtl defaultMat\n") + + # Write faces + print(" writing %d faces" % len(t_pos_idx)) + for i in range(len(t_pos_idx)): + f.write("f ") + for j in range(3): + f.write(' %s/%s/%s' % (str(t_pos_idx[i][j]+1), '' if v_tex is None else str(t_tex_idx[i][j]+1), '' if v_nrm is None else str(t_nrm_idx[i][j]+1))) + f.write("\n") + + mtl_file = os.path.join(folder, fname + '.mtl') + print("Writing material: ", mtl_file) + material.save_mtl(mtl_file, new_mesh.material, mesh=new_mesh, feat=feat, resolution=resolution, prior_shape=prior_shape) + + print("Done exporting mesh") \ No newline at end of file diff --git a/video3d/render/regularizer.py b/video3d/render/regularizer.py new file mode 100755 index 0000000000000000000000000000000000000000..27f918d1cabe171f5514b63f41c79a77e997de4e --- /dev/null +++ b/video3d/render/regularizer.py @@ -0,0 +1,93 @@ +# Copyright (c) 
2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import torch +import nvdiffrast.torch as dr + +from . import util +from . import mesh + +###################################################################################### +# Computes the image gradient, useful for kd/ks smoothness losses +###################################################################################### +def image_grad(buf, std=0.01): + t, s = torch.meshgrid(torch.linspace(-1.0 + 1.0 / buf.shape[1], 1.0 - 1.0 / buf.shape[1], buf.shape[1], device="cuda"), + torch.linspace(-1.0 + 1.0 / buf.shape[2], 1.0 - 1.0 / buf.shape[2], buf.shape[2], device="cuda"), + indexing='ij') + tc = torch.normal(mean=0, std=std, size=(buf.shape[0], buf.shape[1], buf.shape[2], 2), device="cuda") + torch.stack((s, t), dim=-1)[None, ...] + tap = dr.texture(buf, tc, filter_mode='linear', boundary_mode='clamp') + return torch.abs(tap[..., :-1] - buf[..., :-1]) * tap[..., -1:] * buf[..., -1:] + +###################################################################################### +# Computes the avergage edge length of a mesh. +# Rough estimate of the tessellation of a mesh. Can be used e.g. to clamp gradients +###################################################################################### +def avg_edge_length(v_pos, t_pos_idx): + e_pos_idx = mesh.compute_edges(t_pos_idx) + edge_len = util.length(v_pos[:, e_pos_idx[:, 0]] - v_pos[:, e_pos_idx[:, 1]]) + return torch.mean(edge_len) + +###################################################################################### +# Laplacian regularization using umbrella operator (Fujiwara / Desbrun). 
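+# The umbrella operator approximates the Laplacian at a vertex by the mean offset to its
+# one-ring neighbours, L(v_i) ~ (1/|N(i)|) * sum_j (v_j - v_i); the loss below penalizes
+# the mean squared norm of this vector, favouring locally smooth surfaces.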
+# https://mgarland.org/class/geom04/material/smoothing.pdf +###################################################################################### +def laplace_regularizer_const(v_pos, t_pos_idx): + batch_size = v_pos.shape[0] + + term = torch.zeros_like(v_pos) + norm = torch.zeros_like(v_pos[..., 0:1]) + + v0 = v_pos[:, t_pos_idx[0, :, 0], :] + v1 = v_pos[:, t_pos_idx[0, :, 1], :] + v2 = v_pos[:, t_pos_idx[0, :, 2], :] + + term.scatter_add_(1, t_pos_idx[..., 0:1].repeat(batch_size, 1, 3), (v1 - v0) + (v2 - v0)) + term.scatter_add_(1, t_pos_idx[..., 1:2].repeat(batch_size, 1, 3), (v0 - v1) + (v2 - v1)) + term.scatter_add_(1, t_pos_idx[..., 2:3].repeat(batch_size, 1, 3), (v0 - v2) + (v1 - v2)) + + two = torch.ones_like(v0) * 2.0 + # norm.scatter_add_(1, t_pos_idx[..., 0:1].repeat(batch_size, 1, 3), two) + # norm.scatter_add_(1, t_pos_idx[..., 1:2].repeat(batch_size, 1, 3), two) + # norm.scatter_add_(1, t_pos_idx[..., 2:3].repeat(batch_size, 1, 3), two) + norm.scatter_add_(1, t_pos_idx[..., 0:1].repeat(batch_size, 1, 1), two) + norm.scatter_add_(1, t_pos_idx[..., 1:2].repeat(batch_size, 1, 1), two) + norm.scatter_add_(1, t_pos_idx[..., 2:3].repeat(batch_size, 1, 1), two) + + term = term / torch.clamp(norm, min=1.0) + + return torch.mean(term ** 2) + +###################################################################################### +# Smooth vertex normals +###################################################################################### +def normal_consistency(v_pos, t_pos_idx): + # Compute face normals + v0 = v_pos[:, t_pos_idx[0, :, 0]] + v1 = v_pos[:, t_pos_idx[0, :, 1]] + v2 = v_pos[:, t_pos_idx[0, :, 2]] + + face_normals = util.safe_normalize(torch.cross(v1 - v0, v2 - v0, dim=-1)) + + tris_per_edge = mesh.compute_edge_to_face_mapping(t_pos_idx) + + # Fetch normals for both faces sharing an edge + n0 = face_normals[:, tris_per_edge[:, 0], :] + n1 = face_normals[:, tris_per_edge[:, 1], :] + + # Compute error metric based on normal difference + term = torch.clamp(util.dot(n0, n1), min=-1.0, max=1.0) + term = (1.0 - term) * 0.5 + + return torch.mean(torch.abs(term)) + + +def get_edge_length(v_pos, t_pos_idx): + e_pos_idx = mesh.compute_edges(t_pos_idx) + edge_len = util.length(v_pos[:, e_pos_idx[:, 0]] - v_pos[:, e_pos_idx[:, 1]]) + return edge_len diff --git a/video3d/render/render.py b/video3d/render/render.py new file mode 100755 index 0000000000000000000000000000000000000000..657bb45ad2b3d1e78a5c83578a91d2f6bdad8d8a --- /dev/null +++ b/video3d/render/render.py @@ -0,0 +1,369 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import torch +import nvdiffrast.torch as dr + +from . import util +from . import renderutils as ru +from . 
import light + +# ============================================================================================== +# Helper functions +# ============================================================================================== +def interpolate(attr, rast, attr_idx, rast_db=None): + return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db, diff_attrs=None if rast_db is None else 'all') + +# ============================================================================================== +# pixel shader +# ============================================================================================== +def shade( + gb_pos, + gb_geometric_normal, + gb_normal, + gb_tangent, + gb_tex_pos, + gb_texc, + gb_texc_deriv, + w2c, + view_pos, + lgt, + material, + bsdf, + feat, + two_sided_shading, + delta_xy_interp=None, + dino_pred=None, + class_vector=None, + im_features_map=None, + mvp=None + ): + + ################################################################################ + # Texture lookups + ################################################################################ + perturbed_nrm = None + # Combined texture, used for MLPs because lookups are expensive + # all_tex_jitter = material.sample(gb_tex_pos + torch.normal(mean=0, std=0.01, size=gb_tex_pos.shape, device="cuda"), feat=feat) + if material is not None: + if im_features_map is None: + all_tex = material.sample(gb_tex_pos, feat=feat) + else: + all_tex = material.sample(gb_tex_pos, feat=feat, feat_map=im_features_map, mvp=mvp, w2c=w2c, deform_xyz=gb_pos) + else: + all_tex = torch.ones(*gb_pos.shape[:-1], 9, device=gb_pos.device) + kd, ks, perturbed_nrm = all_tex[..., :3], all_tex[..., 3:6], all_tex[..., 6:9] + + # Compute albedo (kd) gradient, used for material regularizer + # kd_grad = torch.sum(torch.abs(all_tex_jitter[..., :-6] - all_tex[..., :-6]), dim=-1, keepdim=True) / + + if dino_pred is not None and class_vector is None: + # DOR: predive the dino value using x,y,z, we would concatenate the label vector. + # trained together, generated image as the supervision for the one-hot-vector. 
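+ # Without a class vector, the DINO feature field is queried from the canonical
+ # surface position alone; with a class vector (second branch below), the query is
+ # additionally conditioned on that vector.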
+ dino_feat_im_pred = dino_pred.sample(gb_tex_pos) + # dino_feat_im_pred = dino_pred.sample(gb_tex_pos.detach()) + if dino_pred is not None and class_vector is not None: + dino_feat_im_pred = dino_pred.sample(gb_tex_pos, feat=class_vector) + + # else: + # kd_jitter = material['kd'].sample(gb_texc + torch.normal(mean=0, std=0.005, size=gb_texc.shape, device="cuda"), gb_texc_deriv) + # kd = material['kd'].sample(gb_texc, gb_texc_deriv) + # ks = material['ks'].sample(gb_texc, gb_texc_deriv)[..., 0:3] # skip alpha + # if 'normal' in material: + # perturbed_nrm = material['normal'].sample(gb_texc, gb_texc_deriv) + # kd_grad = torch.sum(torch.abs(kd_jitter[..., 0:3] - kd[..., 0:3]), dim=-1, keepdim=True) / 3 + + # Separate kd into alpha and color, default alpha = 1 + # alpha = kd[..., 3:4] if kd.shape[-1] == 4 else torch.ones_like(kd[..., 0:1]) + # kd = kd[..., 0:3] + alpha = torch.ones_like(kd[..., 0:1]) + + ################################################################################ + # Normal perturbation & normal bend + ################################################################################ + if material is None or not material.perturb_normal: + perturbed_nrm = None + + gb_normal = ru.prepare_shading_normal(gb_pos, view_pos, perturbed_nrm, gb_normal, gb_tangent, gb_geometric_normal, two_sided_shading=two_sided_shading, opengl=True, use_python=True) + + # if two_sided_shading: + # view_vec = util.safe_normalize(view_pos - gb_pos, -1) + # gb_normal = torch.where(torch.sum(gb_geometric_normal * view_vec, -1, keepdim=True) > 0, gb_geometric_normal, -gb_geometric_normal) + # else: + # gb_normal = gb_geometric_normal + + b, h, w, _ = gb_normal.shape + cam_normal = util.safe_normalize(torch.matmul(gb_normal.view(b, -1, 3), w2c[:,:3,:3].transpose(2,1))).view(b, h, w, 3) + + ################################################################################ + # Evaluate BSDF + ################################################################################ + + assert bsdf is not None or material.bsdf is not None, "Material must specify a BSDF type" + bsdf = bsdf if bsdf is not None else material.bsdf + shading = None + if bsdf == 'pbr': + if isinstance(lgt, light.EnvironmentLight): + shaded_col = lgt.shade(gb_pos, gb_normal, kd, ks, view_pos, specular=True) + else: + assert False, "Invalid light type" + elif bsdf == 'diffuse': + if lgt is None: + shaded_col = kd + elif isinstance(lgt, light.EnvironmentLight): + shaded_col = lgt.shade(gb_pos, gb_normal, kd, ks, view_pos, specular=False) + # elif isinstance(lgt, light.DirectionalLight): + # shaded_col, shading = lgt.shade(feat, kd, cam_normal) + # else: + # assert False, "Invalid light type" + else: + shaded_col, shading = lgt.shade(feat, kd, cam_normal) + elif bsdf == 'normal': + shaded_col = (gb_normal + 1.0) * 0.5 + elif bsdf == 'geo_normal': + shaded_col = (gb_geometric_normal + 1.0) * 0.5 + elif bsdf == 'tangent': + shaded_col = (gb_tangent + 1.0) * 0.5 + elif bsdf == 'kd': + shaded_col = kd + elif bsdf == 'ks': + shaded_col = ks + else: + assert False, "Invalid BSDF '%s'" % bsdf + + # Return multiple buffers + buffers = { + 'kd' : torch.cat((kd, alpha), dim=-1), + 'shaded' : torch.cat((shaded_col, alpha), dim=-1), + # 'kd_grad' : torch.cat((kd_grad, alpha), dim=-1), + # 'occlusion' : torch.cat((ks[..., :1], alpha), dim=-1), + } + + if dino_pred is not None: + buffers['dino_feat_im_pred'] = torch.cat((dino_feat_im_pred, alpha), dim=-1) + + if delta_xy_interp is not None: + buffers['flow'] = torch.cat((delta_xy_interp, alpha), dim=-1) + + 
if shading is not None: + buffers['shading'] = torch.cat((shading, alpha), dim=-1) + + return buffers + +# ============================================================================================== +# Render a depth slice of the mesh (scene), some limitations: +# - Single light +# - Single material +# ============================================================================================== +def render_layer( + rast, + rast_deriv, + mesh, + w2c, + view_pos, + material, + lgt, + resolution, + spp, + msaa, + bsdf, + feat, + prior_mesh, + two_sided_shading, + render_flow, + delta_xy=None, + dino_pred=None, + class_vector=None, + im_features_map=None, + mvp=None + ): + + full_res = [resolution[0]*spp, resolution[1]*spp] + + if prior_mesh is None: + prior_mesh = mesh + + ################################################################################ + # Rasterize + ################################################################################ + + # Scale down to shading resolution when MSAA is enabled, otherwise shade at full resolution + if spp > 1 and msaa: + rast_out_s = util.scale_img_nhwc(rast, resolution, mag='nearest', min='nearest') + rast_out_deriv_s = util.scale_img_nhwc(rast_deriv, resolution, mag='nearest', min='nearest') * spp + else: + rast_out_s = rast + rast_out_deriv_s = rast_deriv + + if render_flow: + delta_xy_interp, _ = interpolate(delta_xy, rast_out_s, mesh.t_pos_idx[0].int()) + else: + delta_xy_interp = None + + ################################################################################ + # Interpolate attributes + ################################################################################ + + # Interpolate world space position + gb_pos, _ = interpolate(mesh.v_pos, rast_out_s, mesh.t_pos_idx[0].int()) + + # Compute geometric normals. 
We need those because of bent normals trick (for bump mapping) + v0 = mesh.v_pos[:, mesh.t_pos_idx[0, :, 0], :] + v1 = mesh.v_pos[:, mesh.t_pos_idx[0, :, 1], :] + v2 = mesh.v_pos[:, mesh.t_pos_idx[0, :, 2], :] + face_normals = util.safe_normalize(torch.cross(v1 - v0, v2 - v0, dim=-1)) + num_faces = face_normals.shape[1] + face_normal_indices = (torch.arange(0, num_faces, dtype=torch.int64, device='cuda')[:, None]).repeat(1, 3) + gb_geometric_normal, _ = interpolate(face_normals, rast_out_s, face_normal_indices.int()) + + # Compute tangent space + assert mesh.v_nrm is not None and mesh.v_tng is not None + gb_normal, _ = interpolate(mesh.v_nrm, rast_out_s, mesh.t_nrm_idx[0].int()) + gb_tangent, _ = interpolate(mesh.v_tng, rast_out_s, mesh.t_tng_idx[0].int()) # Interpolate tangents + + # Texture coordinate + assert mesh.v_tex is not None + gb_texc, gb_texc_deriv = interpolate(mesh.v_tex, rast_out_s, mesh.t_tex_idx[0].int(), rast_db=rast_out_deriv_s) + + ################################################################################ + # Shade + ################################################################################ + + gb_tex_pos, _ = interpolate(prior_mesh.v_pos, rast_out_s, mesh.t_pos_idx[0].int()) + buffers = shade(gb_pos, gb_geometric_normal, gb_normal, gb_tangent, gb_tex_pos, gb_texc, gb_texc_deriv, w2c, view_pos, lgt, material, bsdf, feat=feat, two_sided_shading=two_sided_shading, delta_xy_interp=delta_xy_interp, dino_pred=dino_pred, class_vector=class_vector, im_features_map=im_features_map, mvp=mvp) + + ################################################################################ + # Prepare output + ################################################################################ + + # Scale back up to visibility resolution if using MSAA + if spp > 1 and msaa: + for key in buffers.keys(): + buffers[key] = util.scale_img_nhwc(buffers[key], full_res, mag='nearest', min='nearest') + + # Return buffers + return buffers + +# ============================================================================================== +# Render a depth peeled mesh (scene), some limitations: +# - Single light +# - Single material +# ============================================================================================== +def render_mesh( + ctx, + mesh, + mtx_in, + w2c, + view_pos, + material, + lgt, + resolution, + spp = 1, + num_layers = 1, + msaa = False, + background = None, + bsdf = None, + feat = None, + prior_mesh = None, + two_sided_shading = True, + render_flow = False, + dino_pred = None, + class_vector = None, + num_frames = None, + im_features_map = None + ): + + def prepare_input_vector(x): + x = torch.tensor(x, dtype=torch.float32, device='cuda') if not torch.is_tensor(x) else x + return x[:, None, None, :] if len(x.shape) == 2 else x + + def composite_buffer(key, layers, background, antialias): + accum = background + for buffers, rast in reversed(layers): + alpha = (rast[..., -1:] > 0).float() * buffers[key][..., -1:] + accum = torch.lerp(accum, torch.cat((buffers[key][..., :-1], torch.ones_like(buffers[key][..., -1:])), dim=-1), alpha) + if antialias: + accum = dr.antialias(accum.contiguous(), rast, v_pos_clip, mesh.t_pos_idx[0].int()) + return accum + + assert mesh.t_pos_idx.shape[1] > 0, "Got empty training triangle mesh (unrecoverable discontinuity)" + assert background is None or (background.shape[1] == resolution[0] and background.shape[2] == resolution[1]) + + full_res = [resolution[0] * spp, resolution[1] * spp] + + # Convert numpy arrays to torch tensors + mtx_in = 
torch.tensor(mtx_in, dtype=torch.float32, device='cuda') if not torch.is_tensor(mtx_in) else mtx_in + view_pos = prepare_input_vector(view_pos) # Shape: (B, 1, 1, 3) + + # clip space transform + v_pos_clip = ru.xfm_points(mesh.v_pos, mtx_in, use_python=True) + + # render flow + if render_flow: + v_pos_clip2 = v_pos_clip[..., :2] / v_pos_clip[..., -1:] + v_pos_clip2 = v_pos_clip2.view(-1, num_frames, *v_pos_clip2.shape[1:]) + delta_xy = v_pos_clip2[:, 1:] - v_pos_clip2[:, :-1] + delta_xy = torch.cat([delta_xy, torch.zeros_like(delta_xy[:, :1])], dim=1) + delta_xy = delta_xy.view(-1, *delta_xy.shape[2:]) + else: + delta_xy = None + + # Render all layers front-to-back + layers = [] + with dr.DepthPeeler(ctx, v_pos_clip, mesh.t_pos_idx[0].int(), full_res) as peeler: + for _ in range(num_layers): + rast, db = peeler.rasterize_next_layer() + rendered = render_layer(rast, db, mesh, w2c, view_pos, material, lgt, resolution, spp, msaa, bsdf, feat=feat, prior_mesh=prior_mesh, two_sided_shading=two_sided_shading, render_flow=render_flow, delta_xy=delta_xy, dino_pred=dino_pred, class_vector=class_vector, im_features_map=im_features_map, mvp=mtx_in) + layers += [(rendered, rast)] + + # Setup background + if background is not None: + if spp > 1: + background = util.scale_img_nhwc(background, full_res, mag='nearest', min='nearest') + background = torch.cat((background, torch.zeros_like(background[..., 0:1])), dim=-1) + else: + background = torch.zeros(1, full_res[0], full_res[1], 4, dtype=torch.float32, device='cuda') + + # Composite layers front-to-back + out_buffers = {} + for key in layers[0][0].keys(): + antialias = key in ['shaded', 'dino_feat_im_pred', 'flow'] + bg = background if key in ['shaded'] else torch.zeros_like(layers[0][0][key]) + accum = composite_buffer(key, layers, bg, antialias) + + # Downscale to framebuffer resolution. Use avg pooling + out_buffers[key] = util.avg_pool_nhwc(accum, spp) if spp > 1 else accum + + return out_buffers + +# ============================================================================================== +# Render UVs +# ============================================================================================== +def render_uv(ctx, mesh, resolution, mlp_texture, feat=None, prior_shape=None): + + # clip space transform + uv_clip = mesh.v_tex * 2.0 - 1.0 + + # pad to four component coordinate + uv_clip4 = torch.cat((uv_clip, torch.zeros_like(uv_clip[...,0:1]), torch.ones_like(uv_clip[...,0:1])), dim = -1) + + # rasterize + rast, _ = dr.rasterize(ctx, uv_clip4, mesh.t_tex_idx[0].int(), resolution) + + # Interpolate world space position + if prior_shape is not None: + gb_pos, _ = interpolate(prior_shape.v_pos, rast, mesh.t_pos_idx[0].int()) + else: + gb_pos, _ = interpolate(mesh.v_pos, rast, mesh.t_pos_idx[0].int()) + + # Sample out textures from MLP + all_tex = mlp_texture.sample(gb_pos, feat=feat) + assert all_tex.shape[-1] == 9 or all_tex.shape[-1] == 10, "Combined kd_ks_normal must be 9 or 10 channels" + perturbed_nrm = all_tex[..., -3:] + return (rast[..., -1:] > 0).float(), all_tex[..., :-6], all_tex[..., -6:-3], util.safe_normalize(perturbed_nrm) diff --git a/video3d/render/renderutils/__init__.py b/video3d/render/renderutils/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..f29739f961e48de71c58b4bbc45801654df49a70 --- /dev/null +++ b/video3d/render/renderutils/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +from .ops import xfm_points, xfm_vectors, image_loss, diffuse_cubemap, specular_cubemap, prepare_shading_normal, lambert, frostbite_diffuse, pbr_specular, pbr_bsdf, _fresnel_shlick, _ndf_ggx, _lambda_ggx, _masking_smith +__all__ = ["xfm_vectors", "xfm_points", "image_loss", "diffuse_cubemap","specular_cubemap", "prepare_shading_normal", "lambert", "frostbite_diffuse", "pbr_specular", "pbr_bsdf", "_fresnel_shlick", "_ndf_ggx", "_lambda_ggx", "_masking_smith", ] diff --git a/video3d/render/renderutils/bsdf.py b/video3d/render/renderutils/bsdf.py new file mode 100755 index 0000000000000000000000000000000000000000..b36dd13a589f1e45882f78ed58355720ca1d3866 --- /dev/null +++ b/video3d/render/renderutils/bsdf.py @@ -0,0 +1,151 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import math +import torch + +NORMAL_THRESHOLD = 0.1 + +################################################################################ +# Vector utility functions +################################################################################ + +def _dot(x, y): + return torch.sum(x*y, -1, keepdim=True) + +def _reflect(x, n): + return 2*_dot(x, n)*n - x + +def _safe_normalize(x): + return torch.nn.functional.normalize(x, dim = -1) + +def _bend_normal(view_vec, smooth_nrm, geom_nrm, two_sided_shading): + # Swap normal direction for backfacing surfaces + if two_sided_shading: + smooth_nrm = torch.where(_dot(geom_nrm, view_vec) > 0, smooth_nrm, -smooth_nrm) + geom_nrm = torch.where(_dot(geom_nrm, view_vec) > 0, geom_nrm, -geom_nrm) + + t = torch.clamp(_dot(view_vec, smooth_nrm) / NORMAL_THRESHOLD, min=0, max=1) + return torch.lerp(geom_nrm, smooth_nrm, t) + + +def _perturb_normal(perturbed_nrm, smooth_nrm, smooth_tng, opengl): + smooth_bitang = _safe_normalize(torch.cross(smooth_tng, smooth_nrm, dim=-1)) + if opengl: + shading_nrm = smooth_tng * perturbed_nrm[..., 0:1] - smooth_bitang * perturbed_nrm[..., 1:2] + smooth_nrm * torch.clamp(perturbed_nrm[..., 2:3], min=0.0) + else: + shading_nrm = smooth_tng * perturbed_nrm[..., 0:1] + smooth_bitang * perturbed_nrm[..., 1:2] + smooth_nrm * torch.clamp(perturbed_nrm[..., 2:3], min=0.0) + return _safe_normalize(shading_nrm) + +def bsdf_prepare_shading_normal(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl): + smooth_nrm = _safe_normalize(smooth_nrm) + smooth_tng = _safe_normalize(smooth_tng) + view_vec = _safe_normalize(view_pos - pos) + shading_nrm = _perturb_normal(perturbed_nrm, smooth_nrm, smooth_tng, opengl) + return _bend_normal(view_vec, shading_nrm, geom_nrm, two_sided_shading) + +################################################################################ +# Simple lambertian diffuse BSDF 
+################################################################################ + +def bsdf_lambert(nrm, wi): + return torch.clamp(_dot(nrm, wi), min=0.0) / math.pi + +################################################################################ +# Frostbite diffuse +################################################################################ + +def bsdf_frostbite(nrm, wi, wo, linearRoughness): + wiDotN = _dot(wi, nrm) + woDotN = _dot(wo, nrm) + + h = _safe_normalize(wo + wi) + wiDotH = _dot(wi, h) + + energyBias = 0.5 * linearRoughness + energyFactor = 1.0 - (0.51 / 1.51) * linearRoughness + f90 = energyBias + 2.0 * wiDotH * wiDotH * linearRoughness + f0 = 1.0 + + wiScatter = bsdf_fresnel_shlick(f0, f90, wiDotN) + woScatter = bsdf_fresnel_shlick(f0, f90, woDotN) + res = wiScatter * woScatter * energyFactor + return torch.where((wiDotN > 0.0) & (woDotN > 0.0), res, torch.zeros_like(res)) + +################################################################################ +# Phong specular, loosely based on mitsuba implementation +################################################################################ + +def bsdf_phong(nrm, wo, wi, N): + dp_r = torch.clamp(_dot(_reflect(wo, nrm), wi), min=0.0, max=1.0) + dp_l = torch.clamp(_dot(nrm, wi), min=0.0, max=1.0) + return (dp_r ** N) * dp_l * (N + 2) / (2 * math.pi) + +################################################################################ +# PBR's implementation of GGX specular +################################################################################ + +specular_epsilon = 1e-4 + +def bsdf_fresnel_shlick(f0, f90, cosTheta): + _cosTheta = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon) + return f0 + (f90 - f0) * (1.0 - _cosTheta) ** 5.0 + +def bsdf_ndf_ggx(alphaSqr, cosTheta): + _cosTheta = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon) + d = (_cosTheta * alphaSqr - _cosTheta) * _cosTheta + 1 + return alphaSqr / (d * d * math.pi) + +def bsdf_lambda_ggx(alphaSqr, cosTheta): + _cosTheta = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon) + cosThetaSqr = _cosTheta * _cosTheta + tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr + res = 0.5 * (torch.sqrt(1 + alphaSqr * tanThetaSqr) - 1.0) + return res + +def bsdf_masking_smith_ggx_correlated(alphaSqr, cosThetaI, cosThetaO): + lambdaI = bsdf_lambda_ggx(alphaSqr, cosThetaI) + lambdaO = bsdf_lambda_ggx(alphaSqr, cosThetaO) + return 1 / (1 + lambdaI + lambdaO) + +def bsdf_pbr_specular(col, nrm, wo, wi, alpha, min_roughness=0.08): + _alpha = torch.clamp(alpha, min=min_roughness*min_roughness, max=1.0) + alphaSqr = _alpha * _alpha + + h = _safe_normalize(wo + wi) + woDotN = _dot(wo, nrm) + wiDotN = _dot(wi, nrm) + woDotH = _dot(wo, h) + nDotH = _dot(nrm, h) + + D = bsdf_ndf_ggx(alphaSqr, nDotH) + G = bsdf_masking_smith_ggx_correlated(alphaSqr, woDotN, wiDotN) + F = bsdf_fresnel_shlick(col, 1, woDotH) + + w = F * D * G * 0.25 / torch.clamp(woDotN, min=specular_epsilon) + + frontfacing = (woDotN > specular_epsilon) & (wiDotN > specular_epsilon) + return torch.where(frontfacing, w, torch.zeros_like(w)) + +def bsdf_pbr(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF): + wo = _safe_normalize(view_pos - pos) + wi = _safe_normalize(light_pos - pos) + + spec_str = arm[..., 0:1] # x component + roughness = arm[..., 1:2] # y component + metallic = arm[..., 2:3] # z component + ks = (0.04 * (1.0 - metallic) + kd * metallic) * (1 - spec_str) + kd = kd * (1.0 - metallic) + + if BSDF == 0: + diffuse = kd * 
bsdf_lambert(nrm, wi) + else: + diffuse = kd * bsdf_frostbite(nrm, wi, wo, roughness) + specular = bsdf_pbr_specular(ks, nrm, wo, wi, roughness*roughness, min_roughness=min_roughness) + return diffuse + specular diff --git a/video3d/render/renderutils/c_src/bsdf.cu b/video3d/render/renderutils/c_src/bsdf.cu new file mode 100755 index 0000000000000000000000000000000000000000..c167214f9a4cb42b8d640202969e3950be8b806d --- /dev/null +++ b/video3d/render/renderutils/c_src/bsdf.cu @@ -0,0 +1,710 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#include "common.h" +#include "bsdf.h" + +#define SPECULAR_EPSILON 1e-4f + +//------------------------------------------------------------------------ +// Lambert functions + +__device__ inline float fwdLambert(const vec3f nrm, const vec3f wi) +{ + return max(dot(nrm, wi) / M_PI, 0.0f); +} + +__device__ inline void bwdLambert(const vec3f nrm, const vec3f wi, vec3f& d_nrm, vec3f& d_wi, const float d_out) +{ + if (dot(nrm, wi) > 0.0f) + bwdDot(nrm, wi, d_nrm, d_wi, d_out / M_PI); +} + +//------------------------------------------------------------------------ +// Fresnel Schlick + +__device__ inline float fwdFresnelSchlick(const float f0, const float f90, const float cosTheta) +{ + float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); + float scale = powf(1.0f - _cosTheta, 5.0f); + return f0 * (1.0f - scale) + f90 * scale; +} + +__device__ inline void bwdFresnelSchlick(const float f0, const float f90, const float cosTheta, float& d_f0, float& d_f90, float& d_cosTheta, const float d_out) +{ + float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); + float scale = pow(max(1.0f - _cosTheta, 0.0f), 5.0f); + d_f0 += d_out * (1.0 - scale); + d_f90 += d_out * scale; + if (cosTheta >= SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON) + { + d_cosTheta += d_out * (f90 - f0) * -5.0f * powf(1.0f - cosTheta, 4.0f); + } +} + +__device__ inline vec3f fwdFresnelSchlick(const vec3f f0, const vec3f f90, const float cosTheta) +{ + float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); + float scale = powf(1.0f - _cosTheta, 5.0f); + return f0 * (1.0f - scale) + f90 * scale; +} + +__device__ inline void bwdFresnelSchlick(const vec3f f0, const vec3f f90, const float cosTheta, vec3f& d_f0, vec3f& d_f90, float& d_cosTheta, const vec3f d_out) +{ + float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); + float scale = pow(max(1.0f - _cosTheta, 0.0f), 5.0f); + d_f0 += d_out * (1.0 - scale); + d_f90 += d_out * scale; + if (cosTheta >= SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON) + { + d_cosTheta += sum(d_out * (f90 - f0) * -5.0f * powf(1.0f - cosTheta, 4.0f)); + } +} + +//------------------------------------------------------------------------ +// Frostbite diffuse + +__device__ inline float fwdFrostbiteDiffuse(const vec3f nrm, const vec3f wi, const vec3f wo, float linearRoughness) +{ + float wiDotN = dot(wi, nrm); + float woDotN = dot(wo, nrm); + if (wiDotN > 0.0f && woDotN > 0.0f) + { + vec3f h = safeNormalize(wo + wi); + float 
wiDotH = dot(wi, h); + + float energyBias = 0.5f * linearRoughness; + float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness; + float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness; + float f0 = 1.f; + + float wiScatter = fwdFresnelSchlick(f0, f90, wiDotN); + float woScatter = fwdFresnelSchlick(f0, f90, woDotN); + + return wiScatter * woScatter * energyFactor; + } + else return 0.0f; +} + +__device__ inline void bwdFrostbiteDiffuse(const vec3f nrm, const vec3f wi, const vec3f wo, float linearRoughness, vec3f& d_nrm, vec3f& d_wi, vec3f& d_wo, float &d_linearRoughness, const float d_out) +{ + float wiDotN = dot(wi, nrm); + float woDotN = dot(wo, nrm); + + if (wiDotN > 0.0f && woDotN > 0.0f) + { + vec3f h = safeNormalize(wo + wi); + float wiDotH = dot(wi, h); + + float energyBias = 0.5f * linearRoughness; + float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness; + float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness; + float f0 = 1.f; + + float wiScatter = fwdFresnelSchlick(f0, f90, wiDotN); + float woScatter = fwdFresnelSchlick(f0, f90, woDotN); + + // -------------- BWD -------------- + // Backprop: return wiScatter * woScatter * energyFactor; + float d_wiScatter = d_out * woScatter * energyFactor; + float d_woScatter = d_out * wiScatter * energyFactor; + float d_energyFactor = d_out * wiScatter * woScatter; + + // Backprop: float woScatter = fwdFresnelSchlick(f0, f90, woDotN); + float d_woDotN = 0.0f, d_f0 = 0.0, d_f90 = 0.0f; + bwdFresnelSchlick(f0, f90, woDotN, d_f0, d_f90, d_woDotN, d_woScatter); + + // Backprop: float wiScatter = fwdFresnelSchlick(fd0, fd90, wiDotN); + float d_wiDotN = 0.0f; + bwdFresnelSchlick(f0, f90, wiDotN, d_f0, d_f90, d_wiDotN, d_wiScatter); + + // Backprop: float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness; + float d_energyBias = d_f90; + float d_wiDotH = d_f90 * 4 * wiDotH * linearRoughness; + d_linearRoughness += d_f90 * 2 * wiDotH * wiDotH; + + // Backprop: float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness; + d_linearRoughness -= (0.51f / 1.51f) * d_energyFactor; + + // Backprop: float energyBias = 0.5f * linearRoughness; + d_linearRoughness += 0.5 * d_energyBias; + + // Backprop: float wiDotH = dot(wi, h); + vec3f d_h(0); + bwdDot(wi, h, d_wi, d_h, d_wiDotH); + + // Backprop: vec3f h = safeNormalize(wo + wi); + vec3f d_wo_wi(0); + bwdSafeNormalize(wo + wi, d_wo_wi, d_h); + d_wi += d_wo_wi; d_wo += d_wo_wi; + + bwdDot(wo, nrm, d_wo, d_nrm, d_woDotN); + bwdDot(wi, nrm, d_wi, d_nrm, d_wiDotN); + } +} + +//------------------------------------------------------------------------ +// Ndf GGX + +__device__ inline float fwdNdfGGX(const float alphaSqr, const float cosTheta) +{ + float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); + float d = (_cosTheta * alphaSqr - _cosTheta) * _cosTheta + 1.0f; + return alphaSqr / (d * d * M_PI); +} + +__device__ inline void bwdNdfGGX(const float alphaSqr, const float cosTheta, float& d_alphaSqr, float& d_cosTheta, const float d_out) +{ + // Torch only back propagates if clamp doesn't trigger + float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); + float cosThetaSqr = _cosTheta * _cosTheta; + d_alphaSqr += d_out * (1.0f - (alphaSqr + 1.0f) * cosThetaSqr) / (M_PI * powf((alphaSqr - 1.0) * cosThetaSqr + 1.0f, 3.0f)); + if (cosTheta > SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON) + { + d_cosTheta += d_out * -(4.0f * (alphaSqr - 1.0f) * alphaSqr * cosTheta) / (M_PI * powf((alphaSqr - 1.0) * cosThetaSqr + 1.0f, 
3.0f)); + } +} + +//------------------------------------------------------------------------ +// Lambda GGX + +__device__ inline float fwdLambdaGGX(const float alphaSqr, const float cosTheta) +{ + float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); + float cosThetaSqr = _cosTheta * _cosTheta; + float tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr; + float res = 0.5f * (sqrtf(1.0f + alphaSqr * tanThetaSqr) - 1.0f); + return res; +} + +__device__ inline void bwdLambdaGGX(const float alphaSqr, const float cosTheta, float& d_alphaSqr, float& d_cosTheta, const float d_out) +{ + float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); + float cosThetaSqr = _cosTheta * _cosTheta; + float tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr; + float res = 0.5f * (sqrtf(1.0f + alphaSqr * tanThetaSqr) - 1.0f); + + d_alphaSqr += d_out * (0.25 * tanThetaSqr) / sqrtf(alphaSqr * tanThetaSqr + 1.0f); + if (cosTheta > SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON) + d_cosTheta += d_out * -(0.5 * alphaSqr) / (powf(_cosTheta, 3.0f) * sqrtf(alphaSqr / cosThetaSqr - alphaSqr + 1.0f)); +} + +//------------------------------------------------------------------------ +// Masking GGX + +__device__ inline float fwdMaskingSmithGGXCorrelated(const float alphaSqr, const float cosThetaI, const float cosThetaO) +{ + float lambdaI = fwdLambdaGGX(alphaSqr, cosThetaI); + float lambdaO = fwdLambdaGGX(alphaSqr, cosThetaO); + return 1.0f / (1.0f + lambdaI + lambdaO); +} + +__device__ inline void bwdMaskingSmithGGXCorrelated(const float alphaSqr, const float cosThetaI, const float cosThetaO, float& d_alphaSqr, float& d_cosThetaI, float& d_cosThetaO, const float d_out) +{ + // FWD eval + float lambdaI = fwdLambdaGGX(alphaSqr, cosThetaI); + float lambdaO = fwdLambdaGGX(alphaSqr, cosThetaO); + + // BWD eval + float d_lambdaIO = -d_out / powf(1.0f + lambdaI + lambdaO, 2.0f); + bwdLambdaGGX(alphaSqr, cosThetaI, d_alphaSqr, d_cosThetaI, d_lambdaIO); + bwdLambdaGGX(alphaSqr, cosThetaO, d_alphaSqr, d_cosThetaO, d_lambdaIO); +} + +//------------------------------------------------------------------------ +// GGX specular + +__device__ vec3f fwdPbrSpecular(const vec3f col, const vec3f nrm, const vec3f wo, const vec3f wi, const float alpha, const float min_roughness) +{ + float _alpha = clamp(alpha, min_roughness * min_roughness, 1.0f); + float alphaSqr = _alpha * _alpha; + + vec3f h = safeNormalize(wo + wi); + float woDotN = dot(wo, nrm); + float wiDotN = dot(wi, nrm); + float woDotH = dot(wo, h); + float nDotH = dot(nrm, h); + + float D = fwdNdfGGX(alphaSqr, nDotH); + float G = fwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN); + vec3f F = fwdFresnelSchlick(col, 1.0f, woDotH); + vec3f w = F * D * G * 0.25 / woDotN; + + bool frontfacing = (woDotN > SPECULAR_EPSILON) & (wiDotN > SPECULAR_EPSILON); + return frontfacing ? 
w : 0.0f; +} + +__device__ void bwdPbrSpecular( + const vec3f col, const vec3f nrm, const vec3f wo, const vec3f wi, const float alpha, const float min_roughness, + vec3f& d_col, vec3f& d_nrm, vec3f& d_wo, vec3f& d_wi, float& d_alpha, const vec3f d_out) +{ + /////////////////////////////////////////////////////////////////////// + // FWD eval + + float _alpha = clamp(alpha, min_roughness * min_roughness, 1.0f); + float alphaSqr = _alpha * _alpha; + + vec3f h = safeNormalize(wo + wi); + float woDotN = dot(wo, nrm); + float wiDotN = dot(wi, nrm); + float woDotH = dot(wo, h); + float nDotH = dot(nrm, h); + + float D = fwdNdfGGX(alphaSqr, nDotH); + float G = fwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN); + vec3f F = fwdFresnelSchlick(col, 1.0f, woDotH); + vec3f w = F * D * G * 0.25 / woDotN; + bool frontfacing = (woDotN > SPECULAR_EPSILON) & (wiDotN > SPECULAR_EPSILON); + + if (frontfacing) + { + /////////////////////////////////////////////////////////////////////// + // BWD eval + + vec3f d_F = d_out * D * G * 0.25f / woDotN; + float d_D = sum(d_out * F * G * 0.25f / woDotN); + float d_G = sum(d_out * F * D * 0.25f / woDotN); + + float d_woDotN = -sum(d_out * F * D * G * 0.25f / (woDotN * woDotN)); + + vec3f d_f90(0); + float d_woDotH(0), d_wiDotN(0), d_nDotH(0), d_alphaSqr(0); + bwdFresnelSchlick(col, 1.0f, woDotH, d_col, d_f90, d_woDotH, d_F); + bwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN, d_alphaSqr, d_woDotN, d_wiDotN, d_G); + bwdNdfGGX(alphaSqr, nDotH, d_alphaSqr, d_nDotH, d_D); + + vec3f d_h(0); + bwdDot(nrm, h, d_nrm, d_h, d_nDotH); + bwdDot(wo, h, d_wo, d_h, d_woDotH); + bwdDot(wi, nrm, d_wi, d_nrm, d_wiDotN); + bwdDot(wo, nrm, d_wo, d_nrm, d_woDotN); + + vec3f d_h_unnorm(0); + bwdSafeNormalize(wo + wi, d_h_unnorm, d_h); + d_wo += d_h_unnorm; + d_wi += d_h_unnorm; + + if (alpha > min_roughness * min_roughness) + d_alpha += d_alphaSqr * 2 * alpha; + } +} + +//------------------------------------------------------------------------ +// Full PBR BSDF + +__device__ vec3f fwdPbrBSDF(const vec3f kd, const vec3f arm, const vec3f pos, const vec3f nrm, const vec3f view_pos, const vec3f light_pos, const float min_roughness, int BSDF) +{ + vec3f wo = safeNormalize(view_pos - pos); + vec3f wi = safeNormalize(light_pos - pos); + + float alpha = arm.y * arm.y; + vec3f spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x); + vec3f diff_col = kd * (1.0f - arm.z); + + float diff = 0.0f; + if (BSDF == 0) + diff = fwdLambert(nrm, wi); + else + diff = fwdFrostbiteDiffuse(nrm, wi, wo, arm.y); + vec3f diffuse = diff_col * diff; + vec3f specular = fwdPbrSpecular(spec_col, nrm, wo, wi, alpha, min_roughness); + + return diffuse + specular; +} + +__device__ void bwdPbrBSDF( + const vec3f kd, const vec3f arm, const vec3f pos, const vec3f nrm, const vec3f view_pos, const vec3f light_pos, const float min_roughness, int BSDF, + vec3f& d_kd, vec3f& d_arm, vec3f& d_pos, vec3f& d_nrm, vec3f& d_view_pos, vec3f& d_light_pos, const vec3f d_out) +{ + //////////////////////////////////////////////////////////////////////// + // FWD + vec3f _wi = light_pos - pos; + vec3f _wo = view_pos - pos; + vec3f wi = safeNormalize(_wi); + vec3f wo = safeNormalize(_wo); + + float alpha = arm.y * arm.y; + vec3f spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x); + vec3f diff_col = kd * (1.0f - arm.z); + float diff = 0.0f; + if (BSDF == 0) + diff = fwdLambert(nrm, wi); + else + diff = fwdFrostbiteDiffuse(nrm, wi, wo, arm.y); + + 
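+    // With the forward intermediates recomputed above, the gradients below are accumulated in
+    // reverse order: specular lobe first, then the diffuse lobe, then the kd/arm splits, and
+    // finally the normalized view/light directions.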
//////////////////////////////////////////////////////////////////////// + // BWD + + float d_alpha(0); + vec3f d_spec_col(0), d_wi(0), d_wo(0); + bwdPbrSpecular(spec_col, nrm, wo, wi, alpha, min_roughness, d_spec_col, d_nrm, d_wo, d_wi, d_alpha, d_out); + + float d_diff = sum(diff_col * d_out); + if (BSDF == 0) + bwdLambert(nrm, wi, d_nrm, d_wi, d_diff); + else + bwdFrostbiteDiffuse(nrm, wi, wo, arm.y, d_nrm, d_wi, d_wo, d_arm.y, d_diff); + + // Backprop: diff_col = kd * (1.0f - arm.z) + vec3f d_diff_col = d_out * diff; + d_kd += d_diff_col * (1.0f - arm.z); + d_arm.z -= sum(d_diff_col * kd); + + // Backprop: spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x) + d_kd -= d_spec_col * (arm.x - 1.0f) * arm.z; + d_arm.x += sum(d_spec_col * (arm.z * (0.04f - kd) - 0.04f)); + d_arm.z -= sum(d_spec_col * (kd - 0.04f) * (arm.x - 1.0f)); + + // Backprop: alpha = arm.y * arm.y + d_arm.y += d_alpha * 2 * arm.y; + + // Backprop: vec3f wi = safeNormalize(light_pos - pos); + vec3f d__wi(0); + bwdSafeNormalize(_wi, d__wi, d_wi); + d_light_pos += d__wi; + d_pos -= d__wi; + + // Backprop: vec3f wo = safeNormalize(view_pos - pos); + vec3f d__wo(0); + bwdSafeNormalize(_wo, d__wo, d_wo); + d_view_pos += d__wo; + d_pos -= d__wo; +} + +//------------------------------------------------------------------------ +// Kernels + +__global__ void LambertFwdKernel(LambertKernelParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f nrm = p.nrm.fetch3(px, py, pz); + vec3f wi = p.wi.fetch3(px, py, pz); + + float res = fwdLambert(nrm, wi); + + p.out.store(px, py, pz, res); +} + +__global__ void LambertBwdKernel(LambertKernelParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f nrm = p.nrm.fetch3(px, py, pz); + vec3f wi = p.wi.fetch3(px, py, pz); + float d_out = p.out.fetch1(px, py, pz); + + vec3f d_nrm(0), d_wi(0); + bwdLambert(nrm, wi, d_nrm, d_wi, d_out); + + p.nrm.store_grad(px, py, pz, d_nrm); + p.wi.store_grad(px, py, pz, d_wi); +} + +__global__ void FrostbiteDiffuseFwdKernel(FrostbiteDiffuseKernelParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f nrm = p.nrm.fetch3(px, py, pz); + vec3f wi = p.wi.fetch3(px, py, pz); + vec3f wo = p.wo.fetch3(px, py, pz); + float linearRoughness = p.linearRoughness.fetch1(px, py, pz); + + float res = fwdFrostbiteDiffuse(nrm, wi, wo, linearRoughness); + + p.out.store(px, py, pz, res); +} + +__global__ void FrostbiteDiffuseBwdKernel(FrostbiteDiffuseKernelParams p) +{ + // Calculate pixel position. 
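+    // Per-pixel backward pass: fetch the incoming gradient from p.out and scatter gradients
+    // to nrm, wi, wo and linearRoughness.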
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f nrm = p.nrm.fetch3(px, py, pz); + vec3f wi = p.wi.fetch3(px, py, pz); + vec3f wo = p.wo.fetch3(px, py, pz); + float linearRoughness = p.linearRoughness.fetch1(px, py, pz); + float d_out = p.out.fetch1(px, py, pz); + + float d_linearRoughness = 0.0f; + vec3f d_nrm(0), d_wi(0), d_wo(0); + bwdFrostbiteDiffuse(nrm, wi, wo, linearRoughness, d_nrm, d_wi, d_wo, d_linearRoughness, d_out); + + p.nrm.store_grad(px, py, pz, d_nrm); + p.wi.store_grad(px, py, pz, d_wi); + p.wo.store_grad(px, py, pz, d_wo); + p.linearRoughness.store_grad(px, py, pz, d_linearRoughness); +} + +__global__ void FresnelShlickFwdKernel(FresnelShlickKernelParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f f0 = p.f0.fetch3(px, py, pz); + vec3f f90 = p.f90.fetch3(px, py, pz); + float cosTheta = p.cosTheta.fetch1(px, py, pz); + + vec3f res = fwdFresnelSchlick(f0, f90, cosTheta); + p.out.store(px, py, pz, res); +} + +__global__ void FresnelShlickBwdKernel(FresnelShlickKernelParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f f0 = p.f0.fetch3(px, py, pz); + vec3f f90 = p.f90.fetch3(px, py, pz); + float cosTheta = p.cosTheta.fetch1(px, py, pz); + vec3f d_out = p.out.fetch3(px, py, pz); + + vec3f d_f0(0), d_f90(0); + float d_cosTheta(0); + bwdFresnelSchlick(f0, f90, cosTheta, d_f0, d_f90, d_cosTheta, d_out); + + p.f0.store_grad(px, py, pz, d_f0); + p.f90.store_grad(px, py, pz, d_f90); + p.cosTheta.store_grad(px, py, pz, d_cosTheta); +} + +__global__ void ndfGGXFwdKernel(NdfGGXParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + float alphaSqr = p.alphaSqr.fetch1(px, py, pz); + float cosTheta = p.cosTheta.fetch1(px, py, pz); + float res = fwdNdfGGX(alphaSqr, cosTheta); + + p.out.store(px, py, pz, res); +} + +__global__ void ndfGGXBwdKernel(NdfGGXParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + float alphaSqr = p.alphaSqr.fetch1(px, py, pz); + float cosTheta = p.cosTheta.fetch1(px, py, pz); + float d_out = p.out.fetch1(px, py, pz); + + float d_alphaSqr(0), d_cosTheta(0); + bwdNdfGGX(alphaSqr, cosTheta, d_alphaSqr, d_cosTheta, d_out); + + p.alphaSqr.store_grad(px, py, pz, d_alphaSqr); + p.cosTheta.store_grad(px, py, pz, d_cosTheta); +} + +__global__ void lambdaGGXFwdKernel(NdfGGXParams p) +{ + // Calculate pixel position. 
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + float alphaSqr = p.alphaSqr.fetch1(px, py, pz); + float cosTheta = p.cosTheta.fetch1(px, py, pz); + float res = fwdLambdaGGX(alphaSqr, cosTheta); + + p.out.store(px, py, pz, res); +} + +__global__ void lambdaGGXBwdKernel(NdfGGXParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + float alphaSqr = p.alphaSqr.fetch1(px, py, pz); + float cosTheta = p.cosTheta.fetch1(px, py, pz); + float d_out = p.out.fetch1(px, py, pz); + + float d_alphaSqr(0), d_cosTheta(0); + bwdLambdaGGX(alphaSqr, cosTheta, d_alphaSqr, d_cosTheta, d_out); + + p.alphaSqr.store_grad(px, py, pz, d_alphaSqr); + p.cosTheta.store_grad(px, py, pz, d_cosTheta); +} + +__global__ void maskingSmithFwdKernel(MaskingSmithParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + float alphaSqr = p.alphaSqr.fetch1(px, py, pz); + float cosThetaI = p.cosThetaI.fetch1(px, py, pz); + float cosThetaO = p.cosThetaO.fetch1(px, py, pz); + float res = fwdMaskingSmithGGXCorrelated(alphaSqr, cosThetaI, cosThetaO); + + p.out.store(px, py, pz, res); +} + +__global__ void maskingSmithBwdKernel(MaskingSmithParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + float alphaSqr = p.alphaSqr.fetch1(px, py, pz); + float cosThetaI = p.cosThetaI.fetch1(px, py, pz); + float cosThetaO = p.cosThetaO.fetch1(px, py, pz); + float d_out = p.out.fetch1(px, py, pz); + + float d_alphaSqr(0), d_cosThetaI(0), d_cosThetaO(0); + bwdMaskingSmithGGXCorrelated(alphaSqr, cosThetaI, cosThetaO, d_alphaSqr, d_cosThetaI, d_cosThetaO, d_out); + + p.alphaSqr.store_grad(px, py, pz, d_alphaSqr); + p.cosThetaI.store_grad(px, py, pz, d_cosThetaI); + p.cosThetaO.store_grad(px, py, pz, d_cosThetaO); +} + +__global__ void pbrSpecularFwdKernel(PbrSpecular p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f col = p.col.fetch3(px, py, pz); + vec3f nrm = p.nrm.fetch3(px, py, pz); + vec3f wo = p.wo.fetch3(px, py, pz); + vec3f wi = p.wi.fetch3(px, py, pz); + float alpha = p.alpha.fetch1(px, py, pz); + + vec3f res = fwdPbrSpecular(col, nrm, wo, wi, alpha, p.min_roughness); + + p.out.store(px, py, pz, res); +} + +__global__ void pbrSpecularBwdKernel(PbrSpecular p) +{ + // Calculate pixel position. 
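+    // Per-pixel backward pass for the GGX specular lobe; gradients for col, nrm, wo, wi and
+    // alpha are written back via store_grad.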
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f col = p.col.fetch3(px, py, pz); + vec3f nrm = p.nrm.fetch3(px, py, pz); + vec3f wo = p.wo.fetch3(px, py, pz); + vec3f wi = p.wi.fetch3(px, py, pz); + float alpha = p.alpha.fetch1(px, py, pz); + vec3f d_out = p.out.fetch3(px, py, pz); + + float d_alpha(0); + vec3f d_col(0), d_nrm(0), d_wo(0), d_wi(0); + bwdPbrSpecular(col, nrm, wo, wi, alpha, p.min_roughness, d_col, d_nrm, d_wo, d_wi, d_alpha, d_out); + + p.col.store_grad(px, py, pz, d_col); + p.nrm.store_grad(px, py, pz, d_nrm); + p.wo.store_grad(px, py, pz, d_wo); + p.wi.store_grad(px, py, pz, d_wi); + p.alpha.store_grad(px, py, pz, d_alpha); +} + +__global__ void pbrBSDFFwdKernel(PbrBSDF p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f kd = p.kd.fetch3(px, py, pz); + vec3f arm = p.arm.fetch3(px, py, pz); + vec3f pos = p.pos.fetch3(px, py, pz); + vec3f nrm = p.nrm.fetch3(px, py, pz); + vec3f view_pos = p.view_pos.fetch3(px, py, pz); + vec3f light_pos = p.light_pos.fetch3(px, py, pz); + + vec3f res = fwdPbrBSDF(kd, arm, pos, nrm, view_pos, light_pos, p.min_roughness, p.BSDF); + + p.out.store(px, py, pz, res); +} +__global__ void pbrBSDFBwdKernel(PbrBSDF p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f kd = p.kd.fetch3(px, py, pz); + vec3f arm = p.arm.fetch3(px, py, pz); + vec3f pos = p.pos.fetch3(px, py, pz); + vec3f nrm = p.nrm.fetch3(px, py, pz); + vec3f view_pos = p.view_pos.fetch3(px, py, pz); + vec3f light_pos = p.light_pos.fetch3(px, py, pz); + vec3f d_out = p.out.fetch3(px, py, pz); + + vec3f d_kd(0), d_arm(0), d_pos(0), d_nrm(0), d_view_pos(0), d_light_pos(0); + bwdPbrBSDF(kd, arm, pos, nrm, view_pos, light_pos, p.min_roughness, p.BSDF, d_kd, d_arm, d_pos, d_nrm, d_view_pos, d_light_pos, d_out); + + p.kd.store_grad(px, py, pz, d_kd); + p.arm.store_grad(px, py, pz, d_arm); + p.pos.store_grad(px, py, pz, d_pos); + p.nrm.store_grad(px, py, pz, d_nrm); + p.view_pos.store_grad(px, py, pz, d_view_pos); + p.light_pos.store_grad(px, py, pz, d_light_pos); +} diff --git a/video3d/render/renderutils/c_src/bsdf.h b/video3d/render/renderutils/c_src/bsdf.h new file mode 100755 index 0000000000000000000000000000000000000000..59adbf097490c5a643ebdcff9c3784173522e070 --- /dev/null +++ b/video3d/render/renderutils/c_src/bsdf.h @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. 
+ */ + +#pragma once + +#include "common.h" + +struct LambertKernelParams +{ + Tensor nrm; + Tensor wi; + Tensor out; + dim3 gridSize; +}; + +struct FrostbiteDiffuseKernelParams +{ + Tensor nrm; + Tensor wi; + Tensor wo; + Tensor linearRoughness; + Tensor out; + dim3 gridSize; +}; + +struct FresnelShlickKernelParams +{ + Tensor f0; + Tensor f90; + Tensor cosTheta; + Tensor out; + dim3 gridSize; +}; + +struct NdfGGXParams +{ + Tensor alphaSqr; + Tensor cosTheta; + Tensor out; + dim3 gridSize; +}; + +struct MaskingSmithParams +{ + Tensor alphaSqr; + Tensor cosThetaI; + Tensor cosThetaO; + Tensor out; + dim3 gridSize; +}; + +struct PbrSpecular +{ + Tensor col; + Tensor nrm; + Tensor wo; + Tensor wi; + Tensor alpha; + Tensor out; + dim3 gridSize; + float min_roughness; +}; + +struct PbrBSDF +{ + Tensor kd; + Tensor arm; + Tensor pos; + Tensor nrm; + Tensor view_pos; + Tensor light_pos; + Tensor out; + dim3 gridSize; + float min_roughness; + int BSDF; +}; diff --git a/video3d/render/renderutils/c_src/common.cpp b/video3d/render/renderutils/c_src/common.cpp new file mode 100755 index 0000000000000000000000000000000000000000..445895e57f7d0bcd6a2812f5ba97d7be2ddfbe28 --- /dev/null +++ b/video3d/render/renderutils/c_src/common.cpp @@ -0,0 +1,74 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#include +#include + +//------------------------------------------------------------------------ +// Block and grid size calculators for kernel launches. + +dim3 getLaunchBlockSize(int maxWidth, int maxHeight, dim3 dims) +{ + int maxThreads = maxWidth * maxHeight; + if (maxThreads <= 1 || (dims.x * dims.y) <= 1) + return dim3(1, 1, 1); // Degenerate. + + // Start from max size. + int bw = maxWidth; + int bh = maxHeight; + + // Optimizations for weirdly sized buffers. + if (dims.x < bw) + { + // Decrease block width to smallest power of two that covers the buffer width. + while ((bw >> 1) >= dims.x) + bw >>= 1; + + // Maximize height. + bh = maxThreads / bw; + if (bh > dims.y) + bh = dims.y; + } + else if (dims.y < bh) + { + // Halve height and double width until fits completely inside buffer vertically. + while (bh > dims.y) + { + bh >>= 1; + if (bw < dims.x) + bw <<= 1; + } + } + + // Done. + return dim3(bw, bh, 1); +} + +// returns the size of a block that can be reduced using horizontal SIMD operations (e.g. 
__shfl_xor_sync) +dim3 getWarpSize(dim3 blockSize) +{ + return dim3( + std::min(blockSize.x, 32u), + std::min(std::max(32u / blockSize.x, 1u), std::min(32u, blockSize.y)), + std::min(std::max(32u / (blockSize.x * blockSize.y), 1u), std::min(32u, blockSize.z)) + ); +} + +dim3 getLaunchGridSize(dim3 blockSize, dim3 dims) +{ + dim3 gridSize; + gridSize.x = (dims.x - 1) / blockSize.x + 1; + gridSize.y = (dims.y - 1) / blockSize.y + 1; + gridSize.z = (dims.z - 1) / blockSize.z + 1; + return gridSize; +} + +//------------------------------------------------------------------------ diff --git a/video3d/render/renderutils/c_src/common.h b/video3d/render/renderutils/c_src/common.h new file mode 100755 index 0000000000000000000000000000000000000000..5abaeebdd3f0a0910f7df3e9e0470a9fa682d507 --- /dev/null +++ b/video3d/render/renderutils/c_src/common.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once +#include +#include + +#include "vec3f.h" +#include "vec4f.h" +#include "tensor.h" + +dim3 getLaunchBlockSize(int maxWidth, int maxHeight, dim3 dims); +dim3 getLaunchGridSize(dim3 blockSize, dim3 dims); + +#ifdef __CUDACC__ + +#ifdef _MSC_VER +#define M_PI 3.14159265358979323846f +#endif + +__host__ __device__ static inline dim3 getWarpSize(dim3 blockSize) +{ + return dim3( + min(blockSize.x, 32u), + min(max(32u / blockSize.x, 1u), min(32u, blockSize.y)), + min(max(32u / (blockSize.x * blockSize.y), 1u), min(32u, blockSize.z)) + ); +} + +__device__ static inline float clamp(float val, float mn, float mx) { return min(max(val, mn), mx); } +#else +dim3 getWarpSize(dim3 blockSize); +#endif \ No newline at end of file diff --git a/video3d/render/renderutils/c_src/cubemap.cu b/video3d/render/renderutils/c_src/cubemap.cu new file mode 100755 index 0000000000000000000000000000000000000000..2ce21d83b2dd6759da30874cf8e01b7fd88e9217 --- /dev/null +++ b/video3d/render/renderutils/c_src/cubemap.cu @@ -0,0 +1,350 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. 
+ */ + +#include "common.h" +#include "cubemap.h" +#include + +// https://cgvr.cs.uni-bremen.de/teaching/cg_literatur/Spherical,%20Cubic,%20and%20Parabolic%20Environment%20Mappings.pdf +__device__ float pixel_area(int x, int y, int N) +{ + if (N > 1) + { + int H = N / 2; + x = abs(x - H); + y = abs(y - H); + float dx = atan((float)(x + 1) / (float)H) - atan((float)x / (float)H); + float dy = atan((float)(y + 1) / (float)H) - atan((float)y / (float)H); + return dx * dy; + } + else + return 1; +} + +__device__ vec3f cube_to_dir(int x, int y, int side, int N) +{ + float fx = 2.0f * (((float)x + 0.5f) / (float)N) - 1.0f; + float fy = 2.0f * (((float)y + 0.5f) / (float)N) - 1.0f; + switch (side) + { + case 0: return safeNormalize(vec3f(1, -fy, -fx)); + case 1: return safeNormalize(vec3f(-1, -fy, fx)); + case 2: return safeNormalize(vec3f(fx, 1, fy)); + case 3: return safeNormalize(vec3f(fx, -1, -fy)); + case 4: return safeNormalize(vec3f(fx, -fy, 1)); + case 5: return safeNormalize(vec3f(-fx, -fy, -1)); + } + return vec3f(0,0,0); // Unreachable +} + +__device__ vec3f dir_to_side(int side, vec3f v) +{ + switch (side) + { + case 0: return vec3f(-v.z, -v.y, v.x); + case 1: return vec3f( v.z, -v.y, -v.x); + case 2: return vec3f( v.x, v.z, v.y); + case 3: return vec3f( v.x, -v.z, -v.y); + case 4: return vec3f( v.x, -v.y, v.z); + case 5: return vec3f(-v.x, -v.y, -v.z); + } + return vec3f(0,0,0); // Unreachable +} + +__device__ void extents_1d(float x, float z, float theta, float& _min, float& _max) +{ + float l = sqrtf(x * x + z * z); + float pxr = x + z * tan(theta) * l, pzr = z - x * tan(theta) * l; + float pxl = x - z * tan(theta) * l, pzl = z + x * tan(theta) * l; + if (pzl <= 0.00001f) + _min = pxl > 0.0f ? FLT_MAX : -FLT_MAX; + else + _min = pxl / pzl; + if (pzr <= 0.00001f) + _max = pxr > 0.0f ? FLT_MAX : -FLT_MAX; + else + _max = pxr / pzr; +} + +__device__ void dir_extents(int side, int N, vec3f v, float theta, int &_xmin, int& _xmax, int& _ymin, int& _ymax) +{ + vec3f c = dir_to_side(side, v); // remap to (x,y,z) where side is at z = 1 + + if (theta < 0.785398f) // PI/4 + { + float xmin, xmax, ymin, ymax; + extents_1d(c.x, c.z, theta, xmin, xmax); + extents_1d(c.y, c.z, theta, ymin, ymax); + + if (xmin > 1.0f || xmax < -1.0f || ymin > 1.0f || ymax < -1.0f) + { + _xmin = -1; _xmax = -1; _ymin = -1; _ymax = -1; // Bad aabb + } + else + { + _xmin = (int)min(max((xmin + 1.0f) * (0.5f * (float)N), 0.0f), (float)(N - 1)); + _xmax = (int)min(max((xmax + 1.0f) * (0.5f * (float)N), 0.0f), (float)(N - 1)); + _ymin = (int)min(max((ymin + 1.0f) * (0.5f * (float)N), 0.0f), (float)(N - 1)); + _ymax = (int)min(max((ymax + 1.0f) * (0.5f * (float)N), 0.0f), (float)(N - 1)); + } + } + else + { + _xmin = 0.0f; + _xmax = (float)(N-1); + _ymin = 0.0f; + _ymax = (float)(N-1); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////////////// +// Diffuse kernel +__global__ void DiffuseCubemapFwdKernel(DiffuseCubemapKernelParams p) +{ + // Calculate pixel position. 
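+    // Brute-force cosine convolution: for each output direction N, every cubemap texel L is
+    // accumulated with weight clamped cos(theta) * pixel_area (texel solid angle), normalized by pi.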
+ int px = blockIdx.x * blockDim.x + threadIdx.x; + int py = blockIdx.y * blockDim.y + threadIdx.y; + int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + int Npx = p.cubemap.dims[1]; + vec3f N = cube_to_dir(px, py, pz, Npx); + + vec3f col(0); + + for (int s = 0; s < p.cubemap.dims[0]; ++s) + { + for (int y = 0; y < Npx; ++y) + { + for (int x = 0; x < Npx; ++x) + { + vec3f L = cube_to_dir(x, y, s, Npx); + float costheta = min(max(dot(N, L), 0.0f), 0.999f); + float w = costheta * pixel_area(x, y, Npx) / 3.141592f; // pi = area of positive hemisphere + col += p.cubemap.fetch3(x, y, s) * w; + } + } + } + + p.out.store(px, py, pz, col); +} + +__global__ void DiffuseCubemapBwdKernel(DiffuseCubemapKernelParams p) +{ + // Calculate pixel position. + int px = blockIdx.x * blockDim.x + threadIdx.x; + int py = blockIdx.y * blockDim.y + threadIdx.y; + int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + int Npx = p.cubemap.dims[1]; + vec3f N = cube_to_dir(px, py, pz, Npx); + vec3f grad = p.out.fetch3(px, py, pz); + + for (int s = 0; s < p.cubemap.dims[0]; ++s) + { + for (int y = 0; y < Npx; ++y) + { + for (int x = 0; x < Npx; ++x) + { + vec3f L = cube_to_dir(x, y, s, Npx); + float costheta = min(max(dot(N, L), 0.0f), 0.999f); + float w = costheta * pixel_area(x, y, Npx) / 3.141592f; // pi = area of positive hemisphere + atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 0), grad.x * w); + atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 1), grad.y * w); + atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 2), grad.z * w); + } + } + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////////////// +// GGX splitsum kernel + +__device__ inline float ndfGGX(const float alphaSqr, const float cosTheta) +{ + float _cosTheta = clamp(cosTheta, 0.0, 1.0f); + float d = (_cosTheta * alphaSqr - _cosTheta) * _cosTheta + 1.0f; + return alphaSqr / (d * d * M_PI); +} + +__global__ void SpecularBoundsKernel(SpecularBoundsKernelParams p) +{ + int px = blockIdx.x * blockDim.x + threadIdx.x; + int py = blockIdx.y * blockDim.y + threadIdx.y; + int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + int Npx = p.gridSize.x; + vec3f VNR = cube_to_dir(px, py, pz, Npx); + + const int TILE_SIZE = 16; + + // Brute force entire cubemap and compute bounds for the cone + for (int s = 0; s < p.gridSize.z; ++s) + { + // Assume empty BBox + int _min_x = p.gridSize.x - 1, _max_x = 0; + int _min_y = p.gridSize.y - 1, _max_y = 0; + + // For each (8x8) tile + for (int tx = 0; tx < (p.gridSize.x + TILE_SIZE - 1) / TILE_SIZE; tx++) + { + for (int ty = 0; ty < (p.gridSize.y + TILE_SIZE - 1) / TILE_SIZE; ty++) + { + // Compute tile extents + int tsx = tx * TILE_SIZE, tsy = ty * TILE_SIZE; + int tex = min((tx + 1) * TILE_SIZE, p.gridSize.x), tey = min((ty + 1) * TILE_SIZE, p.gridSize.y); + + // Use some blunt interval arithmetics to cull tiles + vec3f L0 = cube_to_dir(tsx, tsy, s, Npx), L1 = cube_to_dir(tex, tsy, s, Npx); + vec3f L2 = cube_to_dir(tsx, tey, s, Npx), L3 = cube_to_dir(tex, tey, s, Npx); + + float minx = min(min(L0.x, L1.x), min(L2.x, L3.x)), maxx = max(max(L0.x, L1.x), max(L2.x, L3.x)); + float miny = min(min(L0.y, L1.y), min(L2.y, L3.y)), maxy = max(max(L0.y, L1.y), max(L2.y, L3.y)); + float minz = min(min(L0.z, L1.z), min(L2.z, L3.z)), maxz = max(max(L0.z, 
L1.z), max(L2.z, L3.z)); + + float maxdp = max(minx * VNR.x, maxx * VNR.x) + max(miny * VNR.y, maxy * VNR.y) + max(minz * VNR.z, maxz * VNR.z); + if (maxdp >= p.costheta_cutoff) + { + // Test all pixels in tile. + for (int y = tsy; y < tey; ++y) + { + for (int x = tsx; x < tex; ++x) + { + vec3f L = cube_to_dir(x, y, s, Npx); + if (dot(L, VNR) >= p.costheta_cutoff) + { + _min_x = min(_min_x, x); + _max_x = max(_max_x, x); + _min_y = min(_min_y, y); + _max_y = max(_max_y, y); + } + } + } + } + } + } + p.out.store(p.out._nhwcIndex(pz, py, px, s * 4 + 0), _min_x); + p.out.store(p.out._nhwcIndex(pz, py, px, s * 4 + 1), _max_x); + p.out.store(p.out._nhwcIndex(pz, py, px, s * 4 + 2), _min_y); + p.out.store(p.out._nhwcIndex(pz, py, px, s * 4 + 3), _max_y); + } +} + +__global__ void SpecularCubemapFwdKernel(SpecularCubemapKernelParams p) +{ + // Calculate pixel position. + int px = blockIdx.x * blockDim.x + threadIdx.x; + int py = blockIdx.y * blockDim.y + threadIdx.y; + int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + int Npx = p.cubemap.dims[1]; + vec3f VNR = cube_to_dir(px, py, pz, Npx); + + float alpha = p.roughness * p.roughness; + float alphaSqr = alpha * alpha; + + float wsum = 0.0f; + vec3f col(0); + for (int s = 0; s < p.cubemap.dims[0]; ++s) + { + int xmin, xmax, ymin, ymax; + xmin = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 0)); + xmax = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 1)); + ymin = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 2)); + ymax = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 3)); + + if (xmin <= xmax) + { + for (int y = ymin; y <= ymax; ++y) + { + for (int x = xmin; x <= xmax; ++x) + { + vec3f L = cube_to_dir(x, y, s, Npx); + if (dot(L, VNR) >= p.costheta_cutoff) + { + vec3f H = safeNormalize(L + VNR); + + float wiDotN = max(dot(L, VNR), 0.0f); + float VNRDotH = max(dot(VNR, H), 0.0f); + + float w = wiDotN * ndfGGX(alphaSqr, VNRDotH) * pixel_area(x, y, Npx) / 4.0f; + col += p.cubemap.fetch3(x, y, s) * w; + wsum += w; + } + } + } + } + } + + p.out.store(p.out._nhwcIndex(pz, py, px, 0), col.x); + p.out.store(p.out._nhwcIndex(pz, py, px, 1), col.y); + p.out.store(p.out._nhwcIndex(pz, py, px, 2), col.z); + p.out.store(p.out._nhwcIndex(pz, py, px, 3), wsum); +} + +__global__ void SpecularCubemapBwdKernel(SpecularCubemapKernelParams p) +{ + // Calculate pixel position. 
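+    // Backward pass of the GGX prefiltering: the incoming gradient is scattered to every cubemap
+    // texel inside the precomputed bounds, using the same wiDotN * ndfGGX * pixel_area weights as
+    // the forward kernel, via atomicAdd since texels are shared across output directions.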
+ int px = blockIdx.x * blockDim.x + threadIdx.x; + int py = blockIdx.y * blockDim.y + threadIdx.y; + int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + int Npx = p.cubemap.dims[1]; + vec3f VNR = cube_to_dir(px, py, pz, Npx); + + vec3f grad = p.out.fetch3(px, py, pz); + + float alpha = p.roughness * p.roughness; + float alphaSqr = alpha * alpha; + + vec3f col(0); + for (int s = 0; s < p.cubemap.dims[0]; ++s) + { + int xmin, xmax, ymin, ymax; + xmin = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 0)); + xmax = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 1)); + ymin = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 2)); + ymax = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 3)); + + if (xmin <= xmax) + { + for (int y = ymin; y <= ymax; ++y) + { + for (int x = xmin; x <= xmax; ++x) + { + vec3f L = cube_to_dir(x, y, s, Npx); + if (dot(L, VNR) >= p.costheta_cutoff) + { + vec3f H = safeNormalize(L + VNR); + + float wiDotN = max(dot(L, VNR), 0.0f); + float VNRDotH = max(dot(VNR, H), 0.0f); + + float w = wiDotN * ndfGGX(alphaSqr, VNRDotH) * pixel_area(x, y, Npx) / 4.0f; + + atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 0), grad.x * w); + atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 1), grad.y * w); + atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 2), grad.z * w); + } + } + } + } + } +} diff --git a/video3d/render/renderutils/c_src/cubemap.h b/video3d/render/renderutils/c_src/cubemap.h new file mode 100755 index 0000000000000000000000000000000000000000..f395cc237d4a46c660bcde18609068a21f3c3fea --- /dev/null +++ b/video3d/render/renderutils/c_src/cubemap.h @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include "common.h" + +struct DiffuseCubemapKernelParams +{ + Tensor cubemap; + Tensor out; + dim3 gridSize; +}; + +struct SpecularCubemapKernelParams +{ + Tensor cubemap; + Tensor bounds; + Tensor out; + dim3 gridSize; + float costheta_cutoff; + float roughness; +}; + +struct SpecularBoundsKernelParams +{ + float costheta_cutoff; + Tensor out; + dim3 gridSize; +}; diff --git a/video3d/render/renderutils/c_src/loss.cu b/video3d/render/renderutils/c_src/loss.cu new file mode 100755 index 0000000000000000000000000000000000000000..aae5272de3c5364c22ee0bd5fde023d908e9153d --- /dev/null +++ b/video3d/render/renderutils/c_src/loss.cu @@ -0,0 +1,210 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. 
+ */ + +#include + +#include "common.h" +#include "loss.h" + +//------------------------------------------------------------------------ +// Utils + +__device__ inline float bwdAbs(float x) { return x == 0.0f ? 0.0f : x < 0.0f ? -1.0f : 1.0f; } + +__device__ float warpSum(float val) { + for (int i = 1; i < 32; i *= 2) + val += __shfl_xor_sync(0xFFFFFFFF, val, i); + return val; +} + +//------------------------------------------------------------------------ +// Tonemapping + +__device__ inline float fwdSRGB(float x) +{ + return x > 0.0031308f ? powf(max(x, 0.0031308f), 1.0f / 2.4f) * 1.055f - 0.055f : 12.92f * max(x, 0.0f); +} + +__device__ inline void bwdSRGB(float x, float &d_x, float d_out) +{ + if (x > 0.0031308f) + d_x += d_out * 0.439583f / powf(x, 0.583333f); + else if (x > 0.0f) + d_x += d_out * 12.92f; +} + +__device__ inline vec3f fwdTonemapLogSRGB(vec3f x) +{ + return vec3f(fwdSRGB(logf(x.x + 1.0f)), fwdSRGB(logf(x.y + 1.0f)), fwdSRGB(logf(x.z + 1.0f))); +} + +__device__ inline void bwdTonemapLogSRGB(vec3f x, vec3f& d_x, vec3f d_out) +{ + if (x.x > 0.0f && x.x < 65535.0f) + { + bwdSRGB(logf(x.x + 1.0f), d_x.x, d_out.x); + d_x.x *= 1 / (x.x + 1.0f); + } + if (x.y > 0.0f && x.y < 65535.0f) + { + bwdSRGB(logf(x.y + 1.0f), d_x.y, d_out.y); + d_x.y *= 1 / (x.y + 1.0f); + } + if (x.z > 0.0f && x.z < 65535.0f) + { + bwdSRGB(logf(x.z + 1.0f), d_x.z, d_out.z); + d_x.z *= 1 / (x.z + 1.0f); + } +} + +__device__ inline float fwdRELMSE(float img, float target, float eps = 0.1f) +{ + return (img - target) * (img - target) / (img * img + target * target + eps); +} + +__device__ inline void bwdRELMSE(float img, float target, float &d_img, float &d_target, float d_out, float eps = 0.1f) +{ + float denom = (target * target + img * img + eps); + d_img += d_out * 2 * (img - target) * (target * (target + img) + eps) / (denom * denom); + d_target -= d_out * 2 * (img - target) * (img * (target + img) + eps) / (denom * denom); +} + +__device__ inline float fwdSMAPE(float img, float target, float eps=0.01f) +{ + return abs(img - target) / (img + target + eps); +} + +__device__ inline void bwdSMAPE(float img, float target, float& d_img, float& d_target, float d_out, float eps = 0.01f) +{ + float denom = (target + img + eps); + d_img += d_out * bwdAbs(img - target) * (2 * target + eps) / (denom * denom); + d_target -= d_out * bwdAbs(img - target) * (2 * img + eps) / (denom * denom); +} + +//------------------------------------------------------------------------ +// Kernels + +__global__ void imgLossFwdKernel(LossKernelParams p) +{ + // Calculate pixel position. 
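+    // Each thread evaluates one pixel's loss (after clamping and optional log-sRGB tonemapping),
+    // the per-pixel values are reduced across the warp with warpSum, and one thread per warp
+    // stores the partial sum into the downsampled output buffer.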
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + + float floss = 0.0f; + if (px < p.gridSize.x && py < p.gridSize.y && pz < p.gridSize.z) + { + vec3f img = p.img.fetch3(px, py, pz); + vec3f target = p.target.fetch3(px, py, pz); + + img = vec3f(clamp(img.x, 0.0f, 65535.0f), clamp(img.y, 0.0f, 65535.0f), clamp(img.z, 0.0f, 65535.0f)); + target = vec3f(clamp(target.x, 0.0f, 65535.0f), clamp(target.y, 0.0f, 65535.0f), clamp(target.z, 0.0f, 65535.0f)); + + if (p.tonemapper == TONEMAPPER_LOG_SRGB) + { + img = fwdTonemapLogSRGB(img); + target = fwdTonemapLogSRGB(target); + } + + vec3f vloss(0); + if (p.loss == LOSS_MSE) + vloss = (img - target) * (img - target); + else if (p.loss == LOSS_RELMSE) + vloss = vec3f(fwdRELMSE(img.x, target.x), fwdRELMSE(img.y, target.y), fwdRELMSE(img.z, target.z)); + else if (p.loss == LOSS_SMAPE) + vloss = vec3f(fwdSMAPE(img.x, target.x), fwdSMAPE(img.y, target.y), fwdSMAPE(img.z, target.z)); + else + vloss = vec3f(abs(img.x - target.x), abs(img.y - target.y), abs(img.z - target.z)); + + floss = sum(vloss) / 3.0f; + } + + floss = warpSum(floss); + + dim3 warpSize = getWarpSize(blockDim); + if (px < p.gridSize.x && py < p.gridSize.y && pz < p.gridSize.z && threadIdx.x % warpSize.x == 0 && threadIdx.y % warpSize.y == 0 && threadIdx.z % warpSize.z == 0) + p.out.store(px / warpSize.x, py / warpSize.y, pz / warpSize.z, floss); +} + +__global__ void imgLossBwdKernel(LossKernelParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + dim3 warpSize = getWarpSize(blockDim); + + vec3f _img = p.img.fetch3(px, py, pz); + vec3f _target = p.target.fetch3(px, py, pz); + float d_out = p.out.fetch1(px / warpSize.x, py / warpSize.y, pz / warpSize.z); + + ///////////////////////////////////////////////////////////////////// + // FWD + + vec3f img = _img, target = _target; + if (p.tonemapper == TONEMAPPER_LOG_SRGB) + { + img = fwdTonemapLogSRGB(img); + target = fwdTonemapLogSRGB(target); + } + + ///////////////////////////////////////////////////////////////////// + // BWD + + vec3f d_vloss = vec3f(d_out, d_out, d_out) / 3.0f; + + vec3f d_img(0), d_target(0); + if (p.loss == LOSS_MSE) + { + d_img = vec3f(d_vloss.x * 2 * (img.x - target.x), d_vloss.y * 2 * (img.y - target.y), d_vloss.x * 2 * (img.z - target.z)); + d_target = -d_img; + } + else if (p.loss == LOSS_RELMSE) + { + bwdRELMSE(img.x, target.x, d_img.x, d_target.x, d_vloss.x); + bwdRELMSE(img.y, target.y, d_img.y, d_target.y, d_vloss.y); + bwdRELMSE(img.z, target.z, d_img.z, d_target.z, d_vloss.z); + } + else if (p.loss == LOSS_SMAPE) + { + bwdSMAPE(img.x, target.x, d_img.x, d_target.x, d_vloss.x); + bwdSMAPE(img.y, target.y, d_img.y, d_target.y, d_vloss.y); + bwdSMAPE(img.z, target.z, d_img.z, d_target.z, d_vloss.z); + } + else + { + d_img = d_vloss * vec3f(bwdAbs(img.x - target.x), bwdAbs(img.y - target.y), bwdAbs(img.z - target.z)); + d_target = -d_img; + } + + + if (p.tonemapper == TONEMAPPER_LOG_SRGB) + { + vec3f d__img(0), d__target(0); + bwdTonemapLogSRGB(_img, d__img, d_img); + bwdTonemapLogSRGB(_target, d__target, d_target); + d_img = d__img; d_target = d__target; + } + + if (_img.x <= 0.0f || _img.x >= 65535.0f) d_img.x = 0; + if (_img.y <= 0.0f || _img.y >= 65535.0f) d_img.y = 0; + if (_img.z <= 
0.0f || _img.z >= 65535.0f) d_img.z = 0; + if (_target.x <= 0.0f || _target.x >= 65535.0f) d_target.x = 0; + if (_target.y <= 0.0f || _target.y >= 65535.0f) d_target.y = 0; + if (_target.z <= 0.0f || _target.z >= 65535.0f) d_target.z = 0; + + p.img.store_grad(px, py, pz, d_img); + p.target.store_grad(px, py, pz, d_target); +} \ No newline at end of file diff --git a/video3d/render/renderutils/c_src/loss.h b/video3d/render/renderutils/c_src/loss.h new file mode 100755 index 0000000000000000000000000000000000000000..26790bf02de2afd9d27e541edf23d1b064f6f9a9 --- /dev/null +++ b/video3d/render/renderutils/c_src/loss.h @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include "common.h" + +enum TonemapperType +{ + TONEMAPPER_NONE = 0, + TONEMAPPER_LOG_SRGB = 1 +}; + +enum LossType +{ + LOSS_L1 = 0, + LOSS_MSE = 1, + LOSS_RELMSE = 2, + LOSS_SMAPE = 3 +}; + +struct LossKernelParams +{ + Tensor img; + Tensor target; + Tensor out; + dim3 gridSize; + TonemapperType tonemapper; + LossType loss; +}; diff --git a/video3d/render/renderutils/c_src/mesh.cu b/video3d/render/renderutils/c_src/mesh.cu new file mode 100755 index 0000000000000000000000000000000000000000..3690ea3621c38beae03ac9ff228cf5605d303663 --- /dev/null +++ b/video3d/render/renderutils/c_src/mesh.cu @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. 
+ */ + +#include +#include + +#include "common.h" +#include "mesh.h" + + +//------------------------------------------------------------------------ +// Kernels + +__global__ void xfmPointsFwdKernel(XfmKernelParams p) +{ + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int pz = blockIdx.z * blockDim.z + threadIdx.z; + + __shared__ float mtx[4][4]; + if (threadIdx.x < 16) + mtx[threadIdx.x % 4][threadIdx.x / 4] = p.matrix.fetch(p.matrix.nhwcIndex(pz, threadIdx.x / 4, threadIdx.x % 4, 0)); + __syncthreads(); + + if (px >= p.gridSize.x) + return; + + vec3f pos( + p.points.fetch(p.points.nhwcIndex(pz, px, 0, 0)), + p.points.fetch(p.points.nhwcIndex(pz, px, 1, 0)), + p.points.fetch(p.points.nhwcIndex(pz, px, 2, 0)) + ); + + if (p.isPoints) + { + p.out.store(p.out.nhwcIndex(pz, px, 0, 0), pos.x * mtx[0][0] + pos.y * mtx[1][0] + pos.z * mtx[2][0] + mtx[3][0]); + p.out.store(p.out.nhwcIndex(pz, px, 1, 0), pos.x * mtx[0][1] + pos.y * mtx[1][1] + pos.z * mtx[2][1] + mtx[3][1]); + p.out.store(p.out.nhwcIndex(pz, px, 2, 0), pos.x * mtx[0][2] + pos.y * mtx[1][2] + pos.z * mtx[2][2] + mtx[3][2]); + p.out.store(p.out.nhwcIndex(pz, px, 3, 0), pos.x * mtx[0][3] + pos.y * mtx[1][3] + pos.z * mtx[2][3] + mtx[3][3]); + } + else + { + p.out.store(p.out.nhwcIndex(pz, px, 0, 0), pos.x * mtx[0][0] + pos.y * mtx[1][0] + pos.z * mtx[2][0]); + p.out.store(p.out.nhwcIndex(pz, px, 1, 0), pos.x * mtx[0][1] + pos.y * mtx[1][1] + pos.z * mtx[2][1]); + p.out.store(p.out.nhwcIndex(pz, px, 2, 0), pos.x * mtx[0][2] + pos.y * mtx[1][2] + pos.z * mtx[2][2]); + } +} + +__global__ void xfmPointsBwdKernel(XfmKernelParams p) +{ + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int pz = blockIdx.z * blockDim.z + threadIdx.z; + + __shared__ float mtx[4][4]; + if (threadIdx.x < 16) + mtx[threadIdx.x % 4][threadIdx.x / 4] = p.matrix.fetch(p.matrix.nhwcIndex(pz, threadIdx.x / 4, threadIdx.x % 4, 0)); + __syncthreads(); + + if (px >= p.gridSize.x) + return; + + vec3f pos( + p.points.fetch(p.points.nhwcIndex(pz, px, 0, 0)), + p.points.fetch(p.points.nhwcIndex(pz, px, 1, 0)), + p.points.fetch(p.points.nhwcIndex(pz, px, 2, 0)) + ); + + vec4f d_out( + p.out.fetch(p.out.nhwcIndex(pz, px, 0, 0)), + p.out.fetch(p.out.nhwcIndex(pz, px, 1, 0)), + p.out.fetch(p.out.nhwcIndex(pz, px, 2, 0)), + p.out.fetch(p.out.nhwcIndex(pz, px, 3, 0)) + ); + + if (p.isPoints) + { + p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 0, 0), d_out.x * mtx[0][0] + d_out.y * mtx[0][1] + d_out.z * mtx[0][2] + d_out.w * mtx[0][3]); + p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 1, 0), d_out.x * mtx[1][0] + d_out.y * mtx[1][1] + d_out.z * mtx[1][2] + d_out.w * mtx[1][3]); + p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 2, 0), d_out.x * mtx[2][0] + d_out.y * mtx[2][1] + d_out.z * mtx[2][2] + d_out.w * mtx[2][3]); + } + else + { + p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 0, 0), d_out.x * mtx[0][0] + d_out.y * mtx[0][1] + d_out.z * mtx[0][2]); + p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 1, 0), d_out.x * mtx[1][0] + d_out.y * mtx[1][1] + d_out.z * mtx[1][2]); + p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 2, 0), d_out.x * mtx[2][0] + d_out.y * mtx[2][1] + d_out.z * mtx[2][2]); + } +} \ No newline at end of file diff --git a/video3d/render/renderutils/c_src/mesh.h b/video3d/render/renderutils/c_src/mesh.h new file mode 100755 index 0000000000000000000000000000000000000000..16e2166cc55f41c4482b2c5010529e9c75182d7b --- /dev/null +++ 
b/video3d/render/renderutils/c_src/mesh.h @@ -0,0 +1,23 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include "common.h" + +struct XfmKernelParams +{ + bool isPoints; + Tensor points; + Tensor matrix; + Tensor out; + dim3 gridSize; +}; diff --git a/video3d/render/renderutils/c_src/normal.cu b/video3d/render/renderutils/c_src/normal.cu new file mode 100755 index 0000000000000000000000000000000000000000..a50e49e6b5b4061a60ec4d5d8edca2fb0833570e --- /dev/null +++ b/video3d/render/renderutils/c_src/normal.cu @@ -0,0 +1,182 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#include "common.h" +#include "normal.h" + +#define NORMAL_THRESHOLD 0.1f + +//------------------------------------------------------------------------ +// Perturb shading normal by tangent frame + +__device__ vec3f fwdPerturbNormal(const vec3f perturbed_nrm, const vec3f smooth_nrm, const vec3f smooth_tng, bool opengl) +{ + vec3f _smooth_bitng = cross(smooth_tng, smooth_nrm); + vec3f smooth_bitng = safeNormalize(_smooth_bitng); + vec3f _shading_nrm = smooth_tng * perturbed_nrm.x + (opengl ? -1 : 1) * smooth_bitng * perturbed_nrm.y + smooth_nrm * max(perturbed_nrm.z, 0.0f); + return safeNormalize(_shading_nrm); +} + +__device__ void bwdPerturbNormal(const vec3f perturbed_nrm, const vec3f smooth_nrm, const vec3f smooth_tng, vec3f &d_perturbed_nrm, vec3f &d_smooth_nrm, vec3f &d_smooth_tng, const vec3f d_out, bool opengl) +{ + //////////////////////////////////////////////////////////////////////// + // FWD + vec3f _smooth_bitng = cross(smooth_tng, smooth_nrm); + vec3f smooth_bitng = safeNormalize(_smooth_bitng); + vec3f _shading_nrm = smooth_tng * perturbed_nrm.x + (opengl ? -1 : 1) * smooth_bitng * perturbed_nrm.y + smooth_nrm * max(perturbed_nrm.z, 0.0f); + + //////////////////////////////////////////////////////////////////////// + // BWD + vec3f d_shading_nrm(0); + bwdSafeNormalize(_shading_nrm, d_shading_nrm, d_out); + + vec3f d_smooth_bitng(0); + + if (perturbed_nrm.z > 0.0f) + { + d_smooth_nrm += d_shading_nrm * perturbed_nrm.z; + d_perturbed_nrm.z += sum(d_shading_nrm * smooth_nrm); + } + + d_smooth_bitng += (opengl ? -1 : 1) * d_shading_nrm * perturbed_nrm.y; + d_perturbed_nrm.y += (opengl ? 
-1 : 1) * sum(d_shading_nrm * smooth_bitng); + + d_smooth_tng += d_shading_nrm * perturbed_nrm.x; + d_perturbed_nrm.x += sum(d_shading_nrm * smooth_tng); + + vec3f d__smooth_bitng(0); + bwdSafeNormalize(_smooth_bitng, d__smooth_bitng, d_smooth_bitng); + + bwdCross(smooth_tng, smooth_nrm, d_smooth_tng, d_smooth_nrm, d__smooth_bitng); +} + +//------------------------------------------------------------------------ +#define bent_nrm_eps 0.001f + +__device__ vec3f fwdBendNormal(const vec3f view_vec, const vec3f smooth_nrm, const vec3f geom_nrm) +{ + float dp = dot(view_vec, smooth_nrm); + float t = clamp(dp / NORMAL_THRESHOLD, 0.0f, 1.0f); + return geom_nrm * (1.0f - t) + smooth_nrm * t; +} + +__device__ void bwdBendNormal(const vec3f view_vec, const vec3f smooth_nrm, const vec3f geom_nrm, vec3f& d_view_vec, vec3f& d_smooth_nrm, vec3f& d_geom_nrm, const vec3f d_out) +{ + //////////////////////////////////////////////////////////////////////// + // FWD + float dp = dot(view_vec, smooth_nrm); + float t = clamp(dp / NORMAL_THRESHOLD, 0.0f, 1.0f); + + //////////////////////////////////////////////////////////////////////// + // BWD + if (dp > NORMAL_THRESHOLD) + d_smooth_nrm += d_out; + else + { + // geom_nrm * (1.0f - t) + smooth_nrm * t; + d_geom_nrm += d_out * (1.0f - t); + d_smooth_nrm += d_out * t; + float d_t = sum(d_out * (smooth_nrm - geom_nrm)); + + float d_dp = dp < 0.0f || dp > NORMAL_THRESHOLD ? 0.0f : d_t / NORMAL_THRESHOLD; + + bwdDot(view_vec, smooth_nrm, d_view_vec, d_smooth_nrm, d_dp); + } +} + +//------------------------------------------------------------------------ +// Kernels + +__global__ void PrepareShadingNormalFwdKernel(PrepareShadingNormalKernelParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f pos = p.pos.fetch3(px, py, pz); + vec3f view_pos = p.view_pos.fetch3(px, py, pz); + vec3f perturbed_nrm = p.perturbed_nrm.fetch3(px, py, pz); + vec3f _smooth_nrm = p.smooth_nrm.fetch3(px, py, pz); + vec3f _smooth_tng = p.smooth_tng.fetch3(px, py, pz); + vec3f geom_nrm = p.geom_nrm.fetch3(px, py, pz); + + vec3f smooth_nrm = safeNormalize(_smooth_nrm); + vec3f smooth_tng = safeNormalize(_smooth_tng); + vec3f view_vec = safeNormalize(view_pos - pos); + vec3f shading_nrm = fwdPerturbNormal(perturbed_nrm, smooth_nrm, smooth_tng, p.opengl); + + vec3f res; + if (p.two_sided_shading && dot(view_vec, geom_nrm) < 0.0f) + res = fwdBendNormal(view_vec, -shading_nrm, -geom_nrm); + else + res = fwdBendNormal(view_vec, shading_nrm, geom_nrm); + + p.out.store(px, py, pz, res); +} + +__global__ void PrepareShadingNormalBwdKernel(PrepareShadingNormalKernelParams p) +{ + // Calculate pixel position. 
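+ // The backward pass first repeats the forward computation (normalizing the smooth normal, tangent + // and view vector), then chains the incoming gradient back through bend -> perturb -> normalize in + // reverse order. Since view_vec = view_pos - pos, the gradient w.r.t. pos is the negated view_pos gradient.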
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f pos = p.pos.fetch3(px, py, pz); + vec3f view_pos = p.view_pos.fetch3(px, py, pz); + vec3f perturbed_nrm = p.perturbed_nrm.fetch3(px, py, pz); + vec3f _smooth_nrm = p.smooth_nrm.fetch3(px, py, pz); + vec3f _smooth_tng = p.smooth_tng.fetch3(px, py, pz); + vec3f geom_nrm = p.geom_nrm.fetch3(px, py, pz); + vec3f d_out = p.out.fetch3(px, py, pz); + + /////////////////////////////////////////////////////////////////////////////////////////////////// + // FWD + + vec3f smooth_nrm = safeNormalize(_smooth_nrm); + vec3f smooth_tng = safeNormalize(_smooth_tng); + vec3f _view_vec = view_pos - pos; + vec3f view_vec = safeNormalize(view_pos - pos); + + vec3f shading_nrm = fwdPerturbNormal(perturbed_nrm, smooth_nrm, smooth_tng, p.opengl); + + /////////////////////////////////////////////////////////////////////////////////////////////////// + // BWD + + vec3f d_view_vec(0), d_shading_nrm(0), d_geom_nrm(0); + if (p.two_sided_shading && dot(view_vec, geom_nrm) < 0.0f) + { + bwdBendNormal(view_vec, -shading_nrm, -geom_nrm, d_view_vec, d_shading_nrm, d_geom_nrm, d_out); + d_shading_nrm = -d_shading_nrm; + d_geom_nrm = -d_geom_nrm; + } + else + bwdBendNormal(view_vec, shading_nrm, geom_nrm, d_view_vec, d_shading_nrm, d_geom_nrm, d_out); + + vec3f d_perturbed_nrm(0), d_smooth_nrm(0), d_smooth_tng(0); + bwdPerturbNormal(perturbed_nrm, smooth_nrm, smooth_tng, d_perturbed_nrm, d_smooth_nrm, d_smooth_tng, d_shading_nrm, p.opengl); + + vec3f d__view_vec(0), d__smooth_nrm(0), d__smooth_tng(0); + bwdSafeNormalize(_view_vec, d__view_vec, d_view_vec); + bwdSafeNormalize(_smooth_nrm, d__smooth_nrm, d_smooth_nrm); + bwdSafeNormalize(_smooth_tng, d__smooth_tng, d_smooth_tng); + + p.pos.store_grad(px, py, pz, -d__view_vec); + p.view_pos.store_grad(px, py, pz, d__view_vec); + p.perturbed_nrm.store_grad(px, py, pz, d_perturbed_nrm); + p.smooth_nrm.store_grad(px, py, pz, d__smooth_nrm); + p.smooth_tng.store_grad(px, py, pz, d__smooth_tng); + p.geom_nrm.store_grad(px, py, pz, d_geom_nrm); +} \ No newline at end of file diff --git a/video3d/render/renderutils/c_src/normal.h b/video3d/render/renderutils/c_src/normal.h new file mode 100755 index 0000000000000000000000000000000000000000..8882c225cfba5e747462c056d6fcf0b04dd48751 --- /dev/null +++ b/video3d/render/renderutils/c_src/normal.h @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. 
+ */ + +#pragma once + +#include "common.h" + +struct PrepareShadingNormalKernelParams +{ + Tensor pos; + Tensor view_pos; + Tensor perturbed_nrm; + Tensor smooth_nrm; + Tensor smooth_tng; + Tensor geom_nrm; + Tensor out; + dim3 gridSize; + bool two_sided_shading, opengl; +}; diff --git a/video3d/render/renderutils/c_src/tensor.h b/video3d/render/renderutils/c_src/tensor.h new file mode 100755 index 0000000000000000000000000000000000000000..1dfb4e85c46f0394821f2533dc98468e5b7248af --- /dev/null +++ b/video3d/render/renderutils/c_src/tensor.h @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once +#if defined(__CUDACC__) && defined(BFLOAT16) +#include // bfloat16 is float32 compatible with less mantissa bits +#endif + +//--------------------------------------------------------------------------------- +// CUDA-side Tensor class for in/out parameter parsing. Can be float32 or bfloat16 + +struct Tensor +{ + void* val; + void* d_val; + int dims[4], _dims[4]; + int strides[4]; + bool fp16; + +#if defined(__CUDA__) && !defined(__CUDA_ARCH__) + Tensor() : val(nullptr), d_val(nullptr), fp16(true), dims{ 0, 0, 0, 0 }, _dims{ 0, 0, 0, 0 }, strides{ 0, 0, 0, 0 } {} +#endif + +#ifdef __CUDACC__ + // Helpers to index and read/write a single element + __device__ inline int _nhwcIndex(int n, int h, int w, int c) const { return n * strides[0] + h * strides[1] + w * strides[2] + c * strides[3]; } + __device__ inline int nhwcIndex(int n, int h, int w, int c) const { return (dims[0] == 1 ? 0 : n * strides[0]) + (dims[1] == 1 ? 0 : h * strides[1]) + (dims[2] == 1 ? 0 : w * strides[2]) + (dims[3] == 1 ? 0 : c * strides[3]); } + __device__ inline int nhwcIndexContinuous(int n, int h, int w, int c) const { return ((n * _dims[1] + h) * _dims[2] + w) * _dims[3] + c; } +#ifdef BFLOAT16 + __device__ inline float fetch(unsigned int idx) const { return fp16 ? 
__bfloat162float(((__nv_bfloat16*)val)[idx]) : ((float*)val)[idx]; } + __device__ inline void store(unsigned int idx, float _val) { if (fp16) ((__nv_bfloat16*)val)[idx] = __float2bfloat16(_val); else ((float*)val)[idx] = _val; } + __device__ inline void store_grad(unsigned int idx, float _val) { if (fp16) ((__nv_bfloat16*)d_val)[idx] = __float2bfloat16(_val); else ((float*)d_val)[idx] = _val; } +#else + __device__ inline float fetch(unsigned int idx) const { return ((float*)val)[idx]; } + __device__ inline void store(unsigned int idx, float _val) { ((float*)val)[idx] = _val; } + __device__ inline void store_grad(unsigned int idx, float _val) { ((float*)d_val)[idx] = _val; } +#endif + + ////////////////////////////////////////////////////////////////////////////////////////// + // Fetch, use broadcasting for tensor dimensions of size 1 + __device__ inline float fetch1(unsigned int x, unsigned int y, unsigned int z) const + { + return fetch(nhwcIndex(z, y, x, 0)); + } + + __device__ inline vec3f fetch3(unsigned int x, unsigned int y, unsigned int z) const + { + return vec3f( + fetch(nhwcIndex(z, y, x, 0)), + fetch(nhwcIndex(z, y, x, 1)), + fetch(nhwcIndex(z, y, x, 2)) + ); + } + + ///////////////////////////////////////////////////////////////////////////////////////////////////////////// + // Store, no broadcasting here. Assume we output full res gradient and then reduce using torch.sum outside + __device__ inline void store(unsigned int x, unsigned int y, unsigned int z, float _val) + { + store(_nhwcIndex(z, y, x, 0), _val); + } + + __device__ inline void store(unsigned int x, unsigned int y, unsigned int z, vec3f _val) + { + store(_nhwcIndex(z, y, x, 0), _val.x); + store(_nhwcIndex(z, y, x, 1), _val.y); + store(_nhwcIndex(z, y, x, 2), _val.z); + } + + ///////////////////////////////////////////////////////////////////////////////////////////////////////////// + // Store gradient , no broadcasting here. Assume we output full res gradient and then reduce using torch.sum outside + __device__ inline void store_grad(unsigned int x, unsigned int y, unsigned int z, float _val) + { + store_grad(nhwcIndexContinuous(z, y, x, 0), _val); + } + + __device__ inline void store_grad(unsigned int x, unsigned int y, unsigned int z, vec3f _val) + { + store_grad(nhwcIndexContinuous(z, y, x, 0), _val.x); + store_grad(nhwcIndexContinuous(z, y, x, 1), _val.y); + store_grad(nhwcIndexContinuous(z, y, x, 2), _val.z); + } +#endif + +}; diff --git a/video3d/render/renderutils/c_src/torch_bindings.cpp b/video3d/render/renderutils/c_src/torch_bindings.cpp new file mode 100755 index 0000000000000000000000000000000000000000..64c9e70f79507944490cb978233c34ac9e3e97a6 --- /dev/null +++ b/video3d/render/renderutils/c_src/torch_bindings.cpp @@ -0,0 +1,1062 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. 
+ */ + +#ifdef _MSC_VER +#pragma warning(push, 0) +#include <torch/extension.h> +#pragma warning(pop) +#else +#include <torch/extension.h> +#endif + +#include <ATen/cuda/CUDAContext.h> +#include <ATen/cuda/CUDAUtils.h> +#include <algorithm> +#include <string> + +#define NVDR_CHECK_CUDA_ERROR(CUDA_CALL) { cudaError_t err = CUDA_CALL; AT_CUDA_CHECK(cudaGetLastError()); } +#define NVDR_CHECK_GL_ERROR(GL_CALL) { GL_CALL; GLenum err = glGetError(); TORCH_CHECK(err == GL_NO_ERROR, "OpenGL error: ", getGLErrorString(err), "[", #GL_CALL, ";]"); } +#define CHECK_TENSOR(X, DIMS, CHANNELS) \ + TORCH_CHECK(X.is_cuda(), #X " must be a cuda tensor") \ + TORCH_CHECK(X.scalar_type() == torch::kFloat || X.scalar_type() == torch::kBFloat16, #X " must be fp32 or bf16") \ + TORCH_CHECK(X.dim() == DIMS, #X " must have " #DIMS " dimensions") \ + TORCH_CHECK(X.size(DIMS - 1) == CHANNELS, #X " must have " #CHANNELS " channels") + +#include "common.h" +#include "loss.h" +#include "normal.h" +#include "cubemap.h" +#include "bsdf.h" +#include "mesh.h" + +#define BLOCK_X 8 +#define BLOCK_Y 8 + +//------------------------------------------------------------------------ +// mesh.cu + +void xfmPointsFwdKernel(XfmKernelParams p); +void xfmPointsBwdKernel(XfmKernelParams p); + +//------------------------------------------------------------------------ +// loss.cu + +void imgLossFwdKernel(LossKernelParams p); +void imgLossBwdKernel(LossKernelParams p); + +//------------------------------------------------------------------------ +// normal.cu + +void PrepareShadingNormalFwdKernel(PrepareShadingNormalKernelParams p); +void PrepareShadingNormalBwdKernel(PrepareShadingNormalKernelParams p); + +//------------------------------------------------------------------------ +// cubemap.cu + +void DiffuseCubemapFwdKernel(DiffuseCubemapKernelParams p); +void DiffuseCubemapBwdKernel(DiffuseCubemapKernelParams p); +void SpecularBoundsKernel(SpecularBoundsKernelParams p); +void SpecularCubemapFwdKernel(SpecularCubemapKernelParams p); +void SpecularCubemapBwdKernel(SpecularCubemapKernelParams p); + +//------------------------------------------------------------------------ +// bsdf.cu + +void LambertFwdKernel(LambertKernelParams p); +void LambertBwdKernel(LambertKernelParams p); + +void FrostbiteDiffuseFwdKernel(FrostbiteDiffuseKernelParams p); +void FrostbiteDiffuseBwdKernel(FrostbiteDiffuseKernelParams p); + +void FresnelShlickFwdKernel(FresnelShlickKernelParams p); +void FresnelShlickBwdKernel(FresnelShlickKernelParams p); + +void ndfGGXFwdKernel(NdfGGXParams p); +void ndfGGXBwdKernel(NdfGGXParams p); + +void lambdaGGXFwdKernel(NdfGGXParams p); +void lambdaGGXBwdKernel(NdfGGXParams p); + +void maskingSmithFwdKernel(MaskingSmithParams p); +void maskingSmithBwdKernel(MaskingSmithParams p); + +void pbrSpecularFwdKernel(PbrSpecular p); +void pbrSpecularBwdKernel(PbrSpecular p); + +void pbrBSDFFwdKernel(PbrBSDF p); +void pbrBSDFBwdKernel(PbrBSDF p); + +//------------------------------------------------------------------------ +// Tensor helpers + +void update_grid(dim3 &gridSize, torch::Tensor x) +{ + gridSize.x = std::max(gridSize.x, (uint32_t)x.size(2)); + gridSize.y = std::max(gridSize.y, (uint32_t)x.size(1)); + gridSize.z = std::max(gridSize.z, (uint32_t)x.size(0)); +} + +template<typename... Ts> +void update_grid(dim3& gridSize, torch::Tensor x, Ts&&...
vs) +{ + gridSize.x = std::max(gridSize.x, (uint32_t)x.size(2)); + gridSize.y = std::max(gridSize.y, (uint32_t)x.size(1)); + gridSize.z = std::max(gridSize.z, (uint32_t)x.size(0)); + update_grid(gridSize, std::forward<Ts>(vs)...); +} + +Tensor make_cuda_tensor(torch::Tensor val) +{ + Tensor res; + for (int i = 0; i < val.dim(); ++i) + { + res.dims[i] = val.size(i); + res.strides[i] = val.stride(i); + } + res.fp16 = val.scalar_type() == torch::kBFloat16; + res.val = res.fp16 ? (void*)val.data_ptr<torch::BFloat16>() : (void*)val.data_ptr<float>(); + res.d_val = nullptr; + return res; +} + +Tensor make_cuda_tensor(torch::Tensor val, dim3 outDims, torch::Tensor* grad = nullptr) +{ + Tensor res; + for (int i = 0; i < val.dim(); ++i) + { + res.dims[i] = val.size(i); + res.strides[i] = val.stride(i); + } + if (val.dim() == 4) + res._dims[0] = outDims.z, res._dims[1] = outDims.y, res._dims[2] = outDims.x, res._dims[3] = val.size(3); + else + res._dims[0] = outDims.z, res._dims[1] = outDims.x, res._dims[2] = val.size(2), res._dims[3] = 1; // Add a trailing one for indexing math to work out + + res.fp16 = val.scalar_type() == torch::kBFloat16; + res.val = res.fp16 ? (void*)val.data_ptr<torch::BFloat16>() : (void*)val.data_ptr<float>(); + res.d_val = nullptr; + if (grad != nullptr) + { + if (val.dim() == 4) + *grad = torch::empty({ outDims.z, outDims.y, outDims.x, val.size(3) }, torch::TensorOptions().dtype(res.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA)); + else // 3 + *grad = torch::empty({ outDims.z, outDims.x, val.size(2) }, torch::TensorOptions().dtype(res.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA)); + + res.d_val = res.fp16 ? (void*)grad->data_ptr<torch::BFloat16>() : (void*)grad->data_ptr<float>(); + } + return res; +} + +//------------------------------------------------------------------------ +// prepare_shading_normal + +torch::Tensor prepare_shading_normal_fwd(torch::Tensor pos, torch::Tensor view_pos, torch::Tensor perturbed_nrm, torch::Tensor smooth_nrm, torch::Tensor smooth_tng, torch::Tensor geom_nrm, bool two_sided_shading, bool opengl, bool fp16) +{ + CHECK_TENSOR(pos, 4, 3); + CHECK_TENSOR(view_pos, 4, 3); + CHECK_TENSOR(perturbed_nrm, 4, 3); + CHECK_TENSOR(smooth_nrm, 4, 3); + CHECK_TENSOR(smooth_tng, 4, 3); + CHECK_TENSOR(geom_nrm, 4, 3); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + PrepareShadingNormalKernelParams p; + p.two_sided_shading = two_sided_shading; + p.opengl = opengl; + p.out.fp16 = fp16; + update_grid(p.gridSize, pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 3 }, opts); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + // Setup tensors + p.pos = make_cuda_tensor(pos, p.gridSize); + p.view_pos = make_cuda_tensor(view_pos, p.gridSize); + p.perturbed_nrm = make_cuda_tensor(perturbed_nrm, p.gridSize); + p.smooth_nrm = make_cuda_tensor(smooth_nrm, p.gridSize); + p.smooth_tng = make_cuda_tensor(smooth_tng, p.gridSize); + p.geom_nrm = make_cuda_tensor(geom_nrm, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel.
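+ // All host wrappers in this file share this launch pattern: the kernel parameter struct is passed + // by address through cudaLaunchKernel's args array and received by value on the device.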
+ void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)PrepareShadingNormalFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +std::tuple prepare_shading_normal_bwd(torch::Tensor pos, torch::Tensor view_pos, torch::Tensor perturbed_nrm, torch::Tensor smooth_nrm, torch::Tensor smooth_tng, torch::Tensor geom_nrm, torch::Tensor grad, bool two_sided_shading, bool opengl) +{ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + PrepareShadingNormalKernelParams p; + p.two_sided_shading = two_sided_shading; + p.opengl = opengl; + update_grid(p.gridSize, pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + // Setup tensors + torch::Tensor pos_grad, view_pos_grad, perturbed_nrm_grad, smooth_nrm_grad, smooth_tng_grad, geom_nrm_grad; + p.pos = make_cuda_tensor(pos, p.gridSize, &pos_grad); + p.view_pos = make_cuda_tensor(view_pos, p.gridSize, &view_pos_grad); + p.perturbed_nrm = make_cuda_tensor(perturbed_nrm, p.gridSize, &perturbed_nrm_grad); + p.smooth_nrm = make_cuda_tensor(smooth_nrm, p.gridSize, &smooth_nrm_grad); + p.smooth_tng = make_cuda_tensor(smooth_tng, p.gridSize, &smooth_tng_grad); + p.geom_nrm = make_cuda_tensor(geom_nrm, p.gridSize, &geom_nrm_grad); + p.out = make_cuda_tensor(grad, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)PrepareShadingNormalBwdKernel, gridSize, blockSize, args, 0, stream)); + + return std::tuple(pos_grad, view_pos_grad, perturbed_nrm_grad, smooth_nrm_grad, smooth_tng_grad, geom_nrm_grad); +} + +//------------------------------------------------------------------------ +// lambert + +torch::Tensor lambert_fwd(torch::Tensor nrm, torch::Tensor wi, bool fp16) +{ + CHECK_TENSOR(nrm, 4, 3); + CHECK_TENSOR(wi, 4, 3); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + LambertKernelParams p; + p.out.fp16 = fp16; + update_grid(p.gridSize, nrm, wi); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 1 }, opts); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + p.nrm = make_cuda_tensor(nrm, p.gridSize); + p.wi = make_cuda_tensor(wi, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)LambertFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +std::tuple lambert_bwd(torch::Tensor nrm, torch::Tensor wi, torch::Tensor grad) +{ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + LambertKernelParams p; + update_grid(p.gridSize, nrm, wi); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + torch::Tensor nrm_grad, wi_grad; + p.nrm = make_cuda_tensor(nrm, p.gridSize, &nrm_grad); + p.wi = make_cuda_tensor(wi, p.gridSize, &wi_grad); + p.out = make_cuda_tensor(grad, p.gridSize); + + // Launch CUDA kernel. 
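+ // In the *_bwd wrappers, p.out carries the incoming (upstream) gradient, while make_cuda_tensor() + // with a grad pointer allocates the per-input gradient tensors (nrm_grad, wi_grad here) and wires + // them to d_val so the kernel's store_grad() calls write directly into them.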
+ void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)LambertBwdKernel, gridSize, blockSize, args, 0, stream)); + + return std::tuple(nrm_grad, wi_grad); +} + +//------------------------------------------------------------------------ +// frostbite diffuse + +torch::Tensor frostbite_fwd(torch::Tensor nrm, torch::Tensor wi, torch::Tensor wo, torch::Tensor linearRoughness, bool fp16) +{ + CHECK_TENSOR(nrm, 4, 3); + CHECK_TENSOR(wi, 4, 3); + CHECK_TENSOR(wo, 4, 3); + CHECK_TENSOR(linearRoughness, 4, 1); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + FrostbiteDiffuseKernelParams p; + p.out.fp16 = fp16; + update_grid(p.gridSize, nrm, wi, wo, linearRoughness); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 1 }, opts); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + p.nrm = make_cuda_tensor(nrm, p.gridSize); + p.wi = make_cuda_tensor(wi, p.gridSize); + p.wo = make_cuda_tensor(wo, p.gridSize); + p.linearRoughness = make_cuda_tensor(linearRoughness, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)FrostbiteDiffuseFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +std::tuple frostbite_bwd(torch::Tensor nrm, torch::Tensor wi, torch::Tensor wo, torch::Tensor linearRoughness, torch::Tensor grad) +{ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + FrostbiteDiffuseKernelParams p; + update_grid(p.gridSize, nrm, wi, wo, linearRoughness); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + torch::Tensor nrm_grad, wi_grad, wo_grad, linearRoughness_grad; + p.nrm = make_cuda_tensor(nrm, p.gridSize, &nrm_grad); + p.wi = make_cuda_tensor(wi, p.gridSize, &wi_grad); + p.wo = make_cuda_tensor(wo, p.gridSize, &wo_grad); + p.linearRoughness = make_cuda_tensor(linearRoughness, p.gridSize, &linearRoughness_grad); + p.out = make_cuda_tensor(grad, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)FrostbiteDiffuseBwdKernel, gridSize, blockSize, args, 0, stream)); + + return std::tuple(nrm_grad, wi_grad, wo_grad, linearRoughness_grad); +} + +//------------------------------------------------------------------------ +// fresnel_shlick + +torch::Tensor fresnel_shlick_fwd(torch::Tensor f0, torch::Tensor f90, torch::Tensor cosTheta, bool fp16) +{ + CHECK_TENSOR(f0, 4, 3); + CHECK_TENSOR(f90, 4, 3); + CHECK_TENSOR(cosTheta, 4, 1); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + FresnelShlickKernelParams p; + p.out.fp16 = fp16; + update_grid(p.gridSize, f0, f90, cosTheta); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 3 }, opts); + + // Choose launch parameters. 
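+ // Outputs are tiled with up to 8x8 (BLOCK_X x BLOCK_Y) thread blocks; blockIdx.z indexes the + // minibatch dimension of gridSize.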
+ dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + p.f0 = make_cuda_tensor(f0, p.gridSize); + p.f90 = make_cuda_tensor(f90, p.gridSize); + p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)FresnelShlickFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +std::tuple fresnel_shlick_bwd(torch::Tensor f0, torch::Tensor f90, torch::Tensor cosTheta, torch::Tensor grad) +{ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + FresnelShlickKernelParams p; + update_grid(p.gridSize, f0, f90, cosTheta); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + torch::Tensor f0_grad, f90_grad, cosT_grad; + p.f0 = make_cuda_tensor(f0, p.gridSize, &f0_grad); + p.f90 = make_cuda_tensor(f90, p.gridSize, &f90_grad); + p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize, &cosT_grad); + p.out = make_cuda_tensor(grad, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)FresnelShlickBwdKernel, gridSize, blockSize, args, 0, stream)); + + return std::tuple(f0_grad, f90_grad, cosT_grad); +} + +//------------------------------------------------------------------------ +// ndf_ggd + +torch::Tensor ndf_ggx_fwd(torch::Tensor alphaSqr, torch::Tensor cosTheta, bool fp16) +{ + CHECK_TENSOR(alphaSqr, 4, 1); + CHECK_TENSOR(cosTheta, 4, 1); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + NdfGGXParams p; + p.out.fp16 = fp16; + update_grid(p.gridSize, alphaSqr, cosTheta); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 1 }, opts); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize); + p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)ndfGGXFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +std::tuple ndf_ggx_bwd(torch::Tensor alphaSqr, torch::Tensor cosTheta, torch::Tensor grad) +{ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + NdfGGXParams p; + update_grid(p.gridSize, alphaSqr, cosTheta); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + torch::Tensor alphaSqr_grad, cosTheta_grad; + p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize, &alphaSqr_grad); + p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize, &cosTheta_grad); + p.out = make_cuda_tensor(grad, p.gridSize); + + // Launch CUDA kernel. 
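+ // Gradients are produced at the full broadcast output resolution; the Python wrapper is expected + // to sum-reduce them back to each input's original shape (see the store_grad note in tensor.h).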
+ void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)ndfGGXBwdKernel, gridSize, blockSize, args, 0, stream)); + + return std::tuple(alphaSqr_grad, cosTheta_grad); +} + +//------------------------------------------------------------------------ +// lambda_ggx + +torch::Tensor lambda_ggx_fwd(torch::Tensor alphaSqr, torch::Tensor cosTheta, bool fp16) +{ + CHECK_TENSOR(alphaSqr, 4, 1); + CHECK_TENSOR(cosTheta, 4, 1); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + NdfGGXParams p; + p.out.fp16 = fp16; + update_grid(p.gridSize, alphaSqr, cosTheta); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 1 }, opts); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize); + p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)lambdaGGXFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +std::tuple lambda_ggx_bwd(torch::Tensor alphaSqr, torch::Tensor cosTheta, torch::Tensor grad) +{ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + NdfGGXParams p; + update_grid(p.gridSize, alphaSqr, cosTheta); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + torch::Tensor alphaSqr_grad, cosTheta_grad; + p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize, &alphaSqr_grad); + p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize, &cosTheta_grad); + p.out = make_cuda_tensor(grad, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)lambdaGGXBwdKernel, gridSize, blockSize, args, 0, stream)); + + return std::tuple(alphaSqr_grad, cosTheta_grad); +} + +//------------------------------------------------------------------------ +// masking_smith + +torch::Tensor masking_smith_fwd(torch::Tensor alphaSqr, torch::Tensor cosThetaI, torch::Tensor cosThetaO, bool fp16) +{ + CHECK_TENSOR(alphaSqr, 4, 1); + CHECK_TENSOR(cosThetaI, 4, 1); + CHECK_TENSOR(cosThetaO, 4, 1); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + MaskingSmithParams p; + p.out.fp16 = fp16; + update_grid(p.gridSize, alphaSqr, cosThetaI, cosThetaO); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 1 }, opts); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize); + p.cosThetaI = make_cuda_tensor(cosThetaI, p.gridSize); + p.cosThetaO = make_cuda_tensor(cosThetaO, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. 
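+ // Note: the fp16 flag actually selects torch::kBFloat16 storage (see make_cuda_tensor), not IEEE half precision.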
+ void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)maskingSmithFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +std::tuple masking_smith_bwd(torch::Tensor alphaSqr, torch::Tensor cosThetaI, torch::Tensor cosThetaO, torch::Tensor grad) +{ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + MaskingSmithParams p; + update_grid(p.gridSize, alphaSqr, cosThetaI, cosThetaO); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + torch::Tensor alphaSqr_grad, cosThetaI_grad, cosThetaO_grad; + p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize, &alphaSqr_grad); + p.cosThetaI = make_cuda_tensor(cosThetaI, p.gridSize, &cosThetaI_grad); + p.cosThetaO = make_cuda_tensor(cosThetaO, p.gridSize, &cosThetaO_grad); + p.out = make_cuda_tensor(grad, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)maskingSmithBwdKernel, gridSize, blockSize, args, 0, stream)); + + return std::tuple(alphaSqr_grad, cosThetaI_grad, cosThetaO_grad); +} + +//------------------------------------------------------------------------ +// pbr_specular + +torch::Tensor pbr_specular_fwd(torch::Tensor col, torch::Tensor nrm, torch::Tensor wo, torch::Tensor wi, torch::Tensor alpha, float min_roughness, bool fp16) +{ + CHECK_TENSOR(col, 4, 3); + CHECK_TENSOR(nrm, 4, 3); + CHECK_TENSOR(wo, 4, 3); + CHECK_TENSOR(wi, 4, 3); + CHECK_TENSOR(alpha, 4, 1); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + PbrSpecular p; + p.out.fp16 = fp16; + p.min_roughness = min_roughness; + update_grid(p.gridSize, col, nrm, wo, wi, alpha); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 3 }, opts); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + p.col = make_cuda_tensor(col, p.gridSize); + p.nrm = make_cuda_tensor(nrm, p.gridSize); + p.wo = make_cuda_tensor(wo, p.gridSize); + p.wi = make_cuda_tensor(wi, p.gridSize); + p.alpha = make_cuda_tensor(alpha, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)pbrSpecularFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +std::tuple pbr_specular_bwd(torch::Tensor col, torch::Tensor nrm, torch::Tensor wo, torch::Tensor wi, torch::Tensor alpha, float min_roughness, torch::Tensor grad) +{ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + PbrSpecular p; + update_grid(p.gridSize, col, nrm, wo, wi, alpha); + p.min_roughness = min_roughness; + + // Choose launch parameters. 
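+ // min_roughness is passed through to the kernel as a plain scalar and is not differentiated + // (no gradient tensor is allocated for it).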
+ dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + torch::Tensor col_grad, nrm_grad, wo_grad, wi_grad, alpha_grad; + p.col = make_cuda_tensor(col, p.gridSize, &col_grad); + p.nrm = make_cuda_tensor(nrm, p.gridSize, &nrm_grad); + p.wo = make_cuda_tensor(wo, p.gridSize, &wo_grad); + p.wi = make_cuda_tensor(wi, p.gridSize, &wi_grad); + p.alpha = make_cuda_tensor(alpha, p.gridSize, &alpha_grad); + p.out = make_cuda_tensor(grad, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)pbrSpecularBwdKernel, gridSize, blockSize, args, 0, stream)); + + return std::tuple(col_grad, nrm_grad, wo_grad, wi_grad, alpha_grad); +} + +//------------------------------------------------------------------------ +// pbr_bsdf + +torch::Tensor pbr_bsdf_fwd(torch::Tensor kd, torch::Tensor arm, torch::Tensor pos, torch::Tensor nrm, torch::Tensor view_pos, torch::Tensor light_pos, float min_roughness, int BSDF, bool fp16) +{ + CHECK_TENSOR(kd, 4, 3); + CHECK_TENSOR(arm, 4, 3); + CHECK_TENSOR(pos, 4, 3); + CHECK_TENSOR(nrm, 4, 3); + CHECK_TENSOR(view_pos, 4, 3); + CHECK_TENSOR(light_pos, 4, 3); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + PbrBSDF p; + p.out.fp16 = fp16; + p.min_roughness = min_roughness; + p.BSDF = BSDF; + update_grid(p.gridSize, kd, arm, pos, nrm, view_pos, light_pos); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 3 }, opts); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + p.kd = make_cuda_tensor(kd, p.gridSize); + p.arm = make_cuda_tensor(arm, p.gridSize); + p.pos = make_cuda_tensor(pos, p.gridSize); + p.nrm = make_cuda_tensor(nrm, p.gridSize); + p.view_pos = make_cuda_tensor(view_pos, p.gridSize); + p.light_pos = make_cuda_tensor(light_pos, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)pbrBSDFFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +std::tuple pbr_bsdf_bwd(torch::Tensor kd, torch::Tensor arm, torch::Tensor pos, torch::Tensor nrm, torch::Tensor view_pos, torch::Tensor light_pos, float min_roughness, int BSDF, torch::Tensor grad) +{ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + PbrBSDF p; + update_grid(p.gridSize, kd, arm, pos, nrm, view_pos, light_pos); + p.min_roughness = min_roughness; + p.BSDF = BSDF; + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + torch::Tensor kd_grad, arm_grad, pos_grad, nrm_grad, view_pos_grad, light_pos_grad; + p.kd = make_cuda_tensor(kd, p.gridSize, &kd_grad); + p.arm = make_cuda_tensor(arm, p.gridSize, &arm_grad); + p.pos = make_cuda_tensor(pos, p.gridSize, &pos_grad); + p.nrm = make_cuda_tensor(nrm, p.gridSize, &nrm_grad); + p.view_pos = make_cuda_tensor(view_pos, p.gridSize, &view_pos_grad); + p.light_pos = make_cuda_tensor(light_pos, p.gridSize, &light_pos_grad); + p.out = make_cuda_tensor(grad, p.gridSize); + + // Launch CUDA kernel. 
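+ // Six gradient tensors are returned, matching the differentiable inputs in order: + // kd, arm, pos, nrm, view_pos, light_pos.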
+ void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)pbrBSDFBwdKernel, gridSize, blockSize, args, 0, stream)); + + return std::tuple(kd_grad, arm_grad, pos_grad, nrm_grad, view_pos_grad, light_pos_grad); +} + +//------------------------------------------------------------------------ +// filter_cubemap + +torch::Tensor diffuse_cubemap_fwd(torch::Tensor cubemap) +{ + CHECK_TENSOR(cubemap, 4, 3); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + DiffuseCubemapKernelParams p; + update_grid(p.gridSize, cubemap); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 3 }, opts); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + // Setup tensors + p.cubemap = make_cuda_tensor(cubemap, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)DiffuseCubemapFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +torch::Tensor diffuse_cubemap_bwd(torch::Tensor cubemap, torch::Tensor grad) +{ + CHECK_TENSOR(cubemap, 4, 3); + CHECK_TENSOR(grad, 4, 3); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + DiffuseCubemapKernelParams p; + update_grid(p.gridSize, cubemap); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + // Setup tensors + torch::Tensor cubemap_grad; + p.cubemap = make_cuda_tensor(cubemap, p.gridSize); + p.out = make_cuda_tensor(grad, p.gridSize); + + cubemap_grad = torch::zeros({ p.gridSize.z, p.gridSize.y, p.gridSize.x, cubemap.size(3) }, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); + p.cubemap.d_val = (void*)cubemap_grad.data_ptr(); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)DiffuseCubemapBwdKernel, gridSize, blockSize, args, 0, stream)); + + return cubemap_grad; +} + +torch::Tensor specular_bounds(int resolution, float costheta_cutoff) +{ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + SpecularBoundsKernelParams p; + p.costheta_cutoff = costheta_cutoff; + p.gridSize = dim3(resolution, resolution, 6); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::zeros({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 6*4 }, opts); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + // Setup tensors + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)SpecularBoundsKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +torch::Tensor specular_cubemap_fwd(torch::Tensor cubemap, torch::Tensor bounds, float roughness, float costheta_cutoff) +{ + CHECK_TENSOR(cubemap, 4, 3); + CHECK_TENSOR(bounds, 4, 6*4); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. 
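+ // 'bounds' is the 6*4-channel tensor produced by specular_bounds(); the output carries one channel + // beyond RGB, which appears to hold the accumulated filter weight that the caller normalizes by.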
+ SpecularCubemapKernelParams p; + p.roughness = roughness; + p.costheta_cutoff = costheta_cutoff; + update_grid(p.gridSize, cubemap); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 4 }, opts); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + // Setup tensors + p.cubemap = make_cuda_tensor(cubemap, p.gridSize); + p.bounds = make_cuda_tensor(bounds, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)SpecularCubemapFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +torch::Tensor specular_cubemap_bwd(torch::Tensor cubemap, torch::Tensor bounds, torch::Tensor grad, float roughness, float costheta_cutoff) +{ + CHECK_TENSOR(cubemap, 4, 3); + CHECK_TENSOR(bounds, 4, 6*4); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + SpecularCubemapKernelParams p; + p.roughness = roughness; + p.costheta_cutoff = costheta_cutoff; + update_grid(p.gridSize, cubemap); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + // Setup tensors + torch::Tensor cubemap_grad; + p.cubemap = make_cuda_tensor(cubemap, p.gridSize); + p.bounds = make_cuda_tensor(bounds, p.gridSize); + p.out = make_cuda_tensor(grad, p.gridSize); + + cubemap_grad = torch::zeros({ p.gridSize.z, p.gridSize.y, p.gridSize.x, cubemap.size(3) }, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); + p.cubemap.d_val = (void*)cubemap_grad.data_ptr(); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)SpecularCubemapBwdKernel, gridSize, blockSize, args, 0, stream)); + + return cubemap_grad; +} + +//------------------------------------------------------------------------ +// loss function + +LossType strToLoss(std::string str) +{ + if (str == "mse") + return LOSS_MSE; + else if (str == "relmse") + return LOSS_RELMSE; + else if (str == "smape") + return LOSS_SMAPE; + else + return LOSS_L1; +} + +torch::Tensor image_loss_fwd(torch::Tensor img, torch::Tensor target, std::string loss, std::string tonemapper, bool fp16) +{ + CHECK_TENSOR(img, 4, 3); + CHECK_TENSOR(target, 4, 3); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + LossKernelParams p; + p.out.fp16 = fp16; + p.loss = strToLoss(loss); + p.tonemapper = tonemapper == "log_srgb" ? TONEMAPPER_LOG_SRGB : TONEMAPPER_NONE; + update_grid(p.gridSize, img, target); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 warpSize = getWarpSize(blockSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? 
torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({ (p.gridSize.z - 1)/ warpSize.z + 1, (p.gridSize.y - 1) / warpSize.y + 1, (p.gridSize.x - 1) / warpSize.x + 1, 1 }, opts); + + p.img = make_cuda_tensor(img, p.gridSize); + p.target = make_cuda_tensor(target, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)imgLossFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +std::tuple image_loss_bwd(torch::Tensor img, torch::Tensor target, torch::Tensor grad, std::string loss, std::string tonemapper) +{ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + LossKernelParams p; + p.loss = strToLoss(loss); + p.tonemapper = tonemapper == "log_srgb" ? TONEMAPPER_LOG_SRGB : TONEMAPPER_NONE; + update_grid(p.gridSize, img, target); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 warpSize = getWarpSize(blockSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + torch::Tensor img_grad, target_grad; + p.img = make_cuda_tensor(img, p.gridSize, &img_grad); + p.target = make_cuda_tensor(target, p.gridSize, &target_grad); + p.out = make_cuda_tensor(grad, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)imgLossBwdKernel, gridSize, blockSize, args, 0, stream)); + + return std::tuple(img_grad, target_grad); +} + +//------------------------------------------------------------------------ +// transform function + +torch::Tensor xfm_fwd(torch::Tensor points, torch::Tensor matrix, bool isPoints, bool fp16) +{ + CHECK_TENSOR(points, 3, 3); + CHECK_TENSOR(matrix, 3, 4); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + XfmKernelParams p; + p.out.fp16 = fp16; + p.isPoints = isPoints; + p.gridSize.x = points.size(1); + p.gridSize.y = 1; + p.gridSize.z = std::max(matrix.size(0), points.size(0)); + + // Choose launch parameters. + dim3 blockSize(BLOCK_X * BLOCK_Y, 1, 1); + dim3 warpSize = getWarpSize(blockSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = isPoints ? torch::empty({ matrix.size(0), points.size(1), 4 }, opts) : torch::empty({ matrix.size(0), points.size(1), 3 }, opts); + + p.points = make_cuda_tensor(points, p.gridSize); + p.matrix = make_cuda_tensor(matrix, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)xfmPointsFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +torch::Tensor xfm_bwd(torch::Tensor points, torch::Tensor matrix, torch::Tensor grad, bool isPoints) +{ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + XfmKernelParams p; + p.isPoints = isPoints; + p.gridSize.x = points.size(1); + p.gridSize.y = 1; + p.gridSize.z = std::max(matrix.size(0), points.size(0)); + + // Choose launch parameters. 
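+ // The transform kernels use a 1D block of BLOCK_X*BLOCK_Y threads over points and blockIdx.z over + // the batch; the first 16 threads of each block stage the current 4x4 matrix in shared memory + // (see xfmPointsFwdKernel / xfmPointsBwdKernel in mesh.cu).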
+ dim3 blockSize(BLOCK_X * BLOCK_Y, 1, 1); + dim3 warpSize = getWarpSize(blockSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + torch::Tensor points_grad; + p.points = make_cuda_tensor(points, p.gridSize, &points_grad); + p.matrix = make_cuda_tensor(matrix, p.gridSize); + p.out = make_cuda_tensor(grad, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)xfmPointsBwdKernel, gridSize, blockSize, args, 0, stream)); + + return points_grad; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("prepare_shading_normal_fwd", &prepare_shading_normal_fwd, "prepare_shading_normal_fwd"); + m.def("prepare_shading_normal_bwd", &prepare_shading_normal_bwd, "prepare_shading_normal_bwd"); + m.def("lambert_fwd", &lambert_fwd, "lambert_fwd"); + m.def("lambert_bwd", &lambert_bwd, "lambert_bwd"); + m.def("frostbite_fwd", &frostbite_fwd, "frostbite_fwd"); + m.def("frostbite_bwd", &frostbite_bwd, "frostbite_bwd"); + m.def("fresnel_shlick_fwd", &fresnel_shlick_fwd, "fresnel_shlick_fwd"); + m.def("fresnel_shlick_bwd", &fresnel_shlick_bwd, "fresnel_shlick_bwd"); + m.def("ndf_ggx_fwd", &ndf_ggx_fwd, "ndf_ggx_fwd"); + m.def("ndf_ggx_bwd", &ndf_ggx_bwd, "ndf_ggx_bwd"); + m.def("lambda_ggx_fwd", &lambda_ggx_fwd, "lambda_ggx_fwd"); + m.def("lambda_ggx_bwd", &lambda_ggx_bwd, "lambda_ggx_bwd"); + m.def("masking_smith_fwd", &masking_smith_fwd, "masking_smith_fwd"); + m.def("masking_smith_bwd", &masking_smith_bwd, "masking_smith_bwd"); + m.def("pbr_specular_fwd", &pbr_specular_fwd, "pbr_specular_fwd"); + m.def("pbr_specular_bwd", &pbr_specular_bwd, "pbr_specular_bwd"); + m.def("pbr_bsdf_fwd", &pbr_bsdf_fwd, "pbr_bsdf_fwd"); + m.def("pbr_bsdf_bwd", &pbr_bsdf_bwd, "pbr_bsdf_bwd"); + m.def("diffuse_cubemap_fwd", &diffuse_cubemap_fwd, "diffuse_cubemap_fwd"); + m.def("diffuse_cubemap_bwd", &diffuse_cubemap_bwd, "diffuse_cubemap_bwd"); + m.def("specular_bounds", &specular_bounds, "specular_bounds"); + m.def("specular_cubemap_fwd", &specular_cubemap_fwd, "specular_cubemap_fwd"); + m.def("specular_cubemap_bwd", &specular_cubemap_bwd, "specular_cubemap_bwd"); + m.def("image_loss_fwd", &image_loss_fwd, "image_loss_fwd"); + m.def("image_loss_bwd", &image_loss_bwd, "image_loss_bwd"); + m.def("xfm_fwd", &xfm_fwd, "xfm_fwd"); + m.def("xfm_bwd", &xfm_bwd, "xfm_bwd"); +} \ No newline at end of file diff --git a/video3d/render/renderutils/c_src/vec3f.h b/video3d/render/renderutils/c_src/vec3f.h new file mode 100755 index 0000000000000000000000000000000000000000..7e6745430f19e9fe1834c8cd3dfeb6e68d730297 --- /dev/null +++ b/video3d/render/renderutils/c_src/vec3f.h @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. 
+ */ + +#pragma once + +struct vec3f +{ + float x, y, z; + +#ifdef __CUDACC__ + __device__ vec3f() { } + __device__ vec3f(float v) { x = v; y = v; z = v; } + __device__ vec3f(float _x, float _y, float _z) { x = _x; y = _y; z = _z; } + __device__ vec3f(float3 v) { x = v.x; y = v.y; z = v.z; } + + __device__ inline vec3f& operator+=(const vec3f& b) { x += b.x; y += b.y; z += b.z; return *this; } + __device__ inline vec3f& operator-=(const vec3f& b) { x -= b.x; y -= b.y; z -= b.z; return *this; } + __device__ inline vec3f& operator*=(const vec3f& b) { x *= b.x; y *= b.y; z *= b.z; return *this; } + __device__ inline vec3f& operator/=(const vec3f& b) { x /= b.x; y /= b.y; z /= b.z; return *this; } +#endif +}; + +#ifdef __CUDACC__ +__device__ static inline vec3f operator+(const vec3f& a, const vec3f& b) { return vec3f(a.x + b.x, a.y + b.y, a.z + b.z); } +__device__ static inline vec3f operator-(const vec3f& a, const vec3f& b) { return vec3f(a.x - b.x, a.y - b.y, a.z - b.z); } +__device__ static inline vec3f operator*(const vec3f& a, const vec3f& b) { return vec3f(a.x * b.x, a.y * b.y, a.z * b.z); } +__device__ static inline vec3f operator/(const vec3f& a, const vec3f& b) { return vec3f(a.x / b.x, a.y / b.y, a.z / b.z); } +__device__ static inline vec3f operator-(const vec3f& a) { return vec3f(-a.x, -a.y, -a.z); } + +__device__ static inline float sum(vec3f a) +{ + return a.x + a.y + a.z; +} + +__device__ static inline vec3f cross(vec3f a, vec3f b) +{ + vec3f out; + out.x = a.y * b.z - a.z * b.y; + out.y = a.z * b.x - a.x * b.z; + out.z = a.x * b.y - a.y * b.x; + return out; +} + +__device__ static inline void bwdCross(vec3f a, vec3f b, vec3f &d_a, vec3f &d_b, vec3f d_out) +{ + d_a.x += d_out.z * b.y - d_out.y * b.z; + d_a.y += d_out.x * b.z - d_out.z * b.x; + d_a.z += d_out.y * b.x - d_out.x * b.y; + + d_b.x += d_out.y * a.z - d_out.z * a.y; + d_b.y += d_out.z * a.x - d_out.x * a.z; + d_b.z += d_out.x * a.y - d_out.y * a.x; +} + +__device__ static inline float dot(vec3f a, vec3f b) +{ + return a.x * b.x + a.y * b.y + a.z * b.z; +} + +__device__ static inline void bwdDot(vec3f a, vec3f b, vec3f& d_a, vec3f& d_b, float d_out) +{ + d_a.x += d_out * b.x; d_a.y += d_out * b.y; d_a.z += d_out * b.z; + d_b.x += d_out * a.x; d_b.y += d_out * a.y; d_b.z += d_out * a.z; +} + +__device__ static inline vec3f reflect(vec3f x, vec3f n) +{ + return n * 2.0f * dot(n, x) - x; +} + +__device__ static inline void bwdReflect(vec3f x, vec3f n, vec3f& d_x, vec3f& d_n, const vec3f d_out) +{ + d_x.x += d_out.x * (2 * n.x * n.x - 1) + d_out.y * (2 * n.x * n.y) + d_out.z * (2 * n.x * n.z); + d_x.y += d_out.x * (2 * n.x * n.y) + d_out.y * (2 * n.y * n.y - 1) + d_out.z * (2 * n.y * n.z); + d_x.z += d_out.x * (2 * n.x * n.z) + d_out.y * (2 * n.y * n.z) + d_out.z * (2 * n.z * n.z - 1); + + d_n.x += d_out.x * (2 * (2 * n.x * x.x + n.y * x.y + n.z * x.z)) + d_out.y * (2 * n.y * x.x) + d_out.z * (2 * n.z * x.x); + d_n.y += d_out.x * (2 * n.x * x.y) + d_out.y * (2 * (n.x * x.x + 2 * n.y * x.y + n.z * x.z)) + d_out.z * (2 * n.z * x.y); + d_n.z += d_out.x * (2 * n.x * x.z) + d_out.y * (2 * n.y * x.z) + d_out.z * (2 * (n.x * x.x + n.y * x.y + 2 * n.z * x.z)); +} + +__device__ static inline vec3f safeNormalize(vec3f v) +{ + float l = sqrtf(v.x * v.x + v.y * v.y + v.z * v.z); + return l > 0.0f ? 
(v / l) : vec3f(0.0f); +} + +__device__ static inline void bwdSafeNormalize(const vec3f v, vec3f& d_v, const vec3f d_out) +{ + + float l = sqrtf(v.x * v.x + v.y * v.y + v.z * v.z); + if (l > 0.0f) + { + float fac = 1.0 / powf(v.x * v.x + v.y * v.y + v.z * v.z, 1.5f); + d_v.x += (d_out.x * (v.y * v.y + v.z * v.z) - d_out.y * (v.x * v.y) - d_out.z * (v.x * v.z)) * fac; + d_v.y += (d_out.y * (v.x * v.x + v.z * v.z) - d_out.x * (v.y * v.x) - d_out.z * (v.y * v.z)) * fac; + d_v.z += (d_out.z * (v.x * v.x + v.y * v.y) - d_out.x * (v.z * v.x) - d_out.y * (v.z * v.y)) * fac; + } +} + +#endif \ No newline at end of file diff --git a/video3d/render/renderutils/c_src/vec4f.h b/video3d/render/renderutils/c_src/vec4f.h new file mode 100755 index 0000000000000000000000000000000000000000..e3f30776af334597475002275b8b40c584a05035 --- /dev/null +++ b/video3d/render/renderutils/c_src/vec4f.h @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +struct vec4f +{ + float x, y, z, w; + +#ifdef __CUDACC__ + __device__ vec4f() { } + __device__ vec4f(float v) { x = v; y = v; z = v; w = v; } + __device__ vec4f(float _x, float _y, float _z, float _w) { x = _x; y = _y; z = _z; w = _w; } + __device__ vec4f(float4 v) { x = v.x; y = v.y; z = v.z; w = v.w; } +#endif +}; + diff --git a/video3d/render/renderutils/loss.py b/video3d/render/renderutils/loss.py new file mode 100755 index 0000000000000000000000000000000000000000..92a24c02885380937762698eec578eb81bc80f9e --- /dev/null +++ b/video3d/render/renderutils/loss.py @@ -0,0 +1,41 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. 
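+# Reference PyTorch implementations of the fused HDR image losses in c_src/loss.cu; ops.py calls
+# image_loss_fn through image_loss(..., use_python=True) to validate the CUDA kernels.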
+ +import torch + +#---------------------------------------------------------------------------- +# HDR image losses +#---------------------------------------------------------------------------- + +def _tonemap_srgb(f): + return torch.where(f > 0.0031308, torch.pow(torch.clamp(f, min=0.0031308), 1.0/2.4)*1.055 - 0.055, 12.92*f) + +def _SMAPE(img, target, eps=0.01): + nom = torch.abs(img - target) + denom = torch.abs(img) + torch.abs(target) + 0.01 + return torch.mean(nom / denom) + +def _RELMSE(img, target, eps=0.1): + nom = (img - target) * (img - target) + denom = img * img + target * target + 0.1 + return torch.mean(nom / denom) + +def image_loss_fn(img, target, loss, tonemapper): + if tonemapper == 'log_srgb': + img = _tonemap_srgb(torch.log(torch.clamp(img, min=0, max=65535) + 1)) + target = _tonemap_srgb(torch.log(torch.clamp(target, min=0, max=65535) + 1)) + + if loss == 'mse': + return torch.nn.functional.mse_loss(img, target) + elif loss == 'smape': + return _SMAPE(img, target) + elif loss == 'relmse': + return _RELMSE(img, target) + else: + return torch.nn.functional.l1_loss(img, target) diff --git a/video3d/render/renderutils/ops.py b/video3d/render/renderutils/ops.py new file mode 100755 index 0000000000000000000000000000000000000000..b23bf5ecb019cf6f4d140687530fceb06d4590b5 --- /dev/null +++ b/video3d/render/renderutils/ops.py @@ -0,0 +1,554 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import numpy as np +import os +import sys +import torch +import torch.utils.cpp_extension + +from .bsdf import * +from .loss import * + +#---------------------------------------------------------------------------- +# C++/Cuda plugin compiler/loader. + +_cached_plugin = None +def _get_plugin(): + # Return cached plugin if already loaded. + global _cached_plugin + if _cached_plugin is not None: + return _cached_plugin + + # Make sure we can find the necessary compiler and libary binaries. + if os.name == 'nt': + def find_cl_path(): + import glob + for edition in ['Enterprise', 'Professional', 'BuildTools', 'Community']: + paths = sorted(glob.glob(r"C:\Program Files (x86)\Microsoft Visual Studio\*\%s\VC\Tools\MSVC\*\bin\Hostx64\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. + if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ['PATH'] += ';' + cl_path + + # Compiler options. + opts = ['-DNVDR_TORCH'] + + # Linker options. + if os.name == 'posix': + ldflags = ['-lcuda', '-lnvrtc'] + elif os.name == 'nt': + ldflags = ['cuda.lib', 'advapi32.lib', 'nvrtc.lib'] + + # List of sources. + source_files = [ + 'c_src/mesh.cu', + 'c_src/loss.cu', + 'c_src/bsdf.cu', + 'c_src/normal.cu', + 'c_src/cubemap.cu', + 'c_src/common.cpp', + 'c_src/torch_bindings.cpp' + ] + + # Some containers set this to contain old architectures that won't compile. We only need the one installed in the machine. 
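+    # With TORCH_CUDA_ARCH_LIST left empty, torch.utils.cpp_extension auto-detects the compute
+    # capability of the GPU(s) visible at build time instead of compiling for every listed arch.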
+ os.environ['TORCH_CUDA_ARCH_LIST'] = '' + + # Try to detect if a stray lock file is left in cache directory and show a warning. This sometimes happens on Windows if the build is interrupted at just the right moment. + try: + lock_fn = os.path.join(torch.utils.cpp_extension._get_build_directory('renderutils_plugin', False), 'lock') + if os.path.exists(lock_fn): + print("Warning: Lock file exists in build directory: '%s'" % lock_fn) + except: + pass + + # Compile and load. + source_paths = [os.path.join(os.path.dirname(__file__), fn) for fn in source_files] + torch.utils.cpp_extension.load(name='renderutils_plugin', sources=source_paths, extra_cflags=opts, + extra_cuda_cflags=opts, extra_ldflags=ldflags, with_cuda=True, verbose=True) + + # Import, cache, and return the compiled module. + import renderutils_plugin + _cached_plugin = renderutils_plugin + return _cached_plugin + +#---------------------------------------------------------------------------- +# Internal kernels, just used for testing functionality + +class _fresnel_shlick_func(torch.autograd.Function): + @staticmethod + def forward(ctx, f0, f90, cosTheta): + out = _get_plugin().fresnel_shlick_fwd(f0, f90, cosTheta, False) + ctx.save_for_backward(f0, f90, cosTheta) + return out + + @staticmethod + def backward(ctx, dout): + f0, f90, cosTheta = ctx.saved_variables + return _get_plugin().fresnel_shlick_bwd(f0, f90, cosTheta, dout) + (None,) + +def _fresnel_shlick(f0, f90, cosTheta, use_python=False): + if use_python: + out = bsdf_fresnel_shlick(f0, f90, cosTheta) + else: + out = _fresnel_shlick_func.apply(f0, f90, cosTheta) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of _fresnel_shlick contains inf or NaN" + return out + + +class _ndf_ggx_func(torch.autograd.Function): + @staticmethod + def forward(ctx, alphaSqr, cosTheta): + out = _get_plugin().ndf_ggx_fwd(alphaSqr, cosTheta, False) + ctx.save_for_backward(alphaSqr, cosTheta) + return out + + @staticmethod + def backward(ctx, dout): + alphaSqr, cosTheta = ctx.saved_variables + return _get_plugin().ndf_ggx_bwd(alphaSqr, cosTheta, dout) + (None,) + +def _ndf_ggx(alphaSqr, cosTheta, use_python=False): + if use_python: + out = bsdf_ndf_ggx(alphaSqr, cosTheta) + else: + out = _ndf_ggx_func.apply(alphaSqr, cosTheta) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of _ndf_ggx contains inf or NaN" + return out + +class _lambda_ggx_func(torch.autograd.Function): + @staticmethod + def forward(ctx, alphaSqr, cosTheta): + out = _get_plugin().lambda_ggx_fwd(alphaSqr, cosTheta, False) + ctx.save_for_backward(alphaSqr, cosTheta) + return out + + @staticmethod + def backward(ctx, dout): + alphaSqr, cosTheta = ctx.saved_variables + return _get_plugin().lambda_ggx_bwd(alphaSqr, cosTheta, dout) + (None,) + +def _lambda_ggx(alphaSqr, cosTheta, use_python=False): + if use_python: + out = bsdf_lambda_ggx(alphaSqr, cosTheta) + else: + out = _lambda_ggx_func.apply(alphaSqr, cosTheta) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of _lambda_ggx contains inf or NaN" + return out + +class _masking_smith_func(torch.autograd.Function): + @staticmethod + def forward(ctx, alphaSqr, cosThetaI, cosThetaO): + ctx.save_for_backward(alphaSqr, cosThetaI, cosThetaO) + out = _get_plugin().masking_smith_fwd(alphaSqr, cosThetaI, cosThetaO, False) + return out + + @staticmethod + def backward(ctx, dout): + alphaSqr, cosThetaI, cosThetaO = ctx.saved_variables + return _get_plugin().masking_smith_bwd(alphaSqr, 
cosThetaI, cosThetaO, dout) + (None,) + +def _masking_smith(alphaSqr, cosThetaI, cosThetaO, use_python=False): + if use_python: + out = bsdf_masking_smith_ggx_correlated(alphaSqr, cosThetaI, cosThetaO) + else: + out = _masking_smith_func.apply(alphaSqr, cosThetaI, cosThetaO) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of _masking_smith contains inf or NaN" + return out + +#---------------------------------------------------------------------------- +# Shading normal setup (bump mapping + bent normals) + +class _prepare_shading_normal_func(torch.autograd.Function): + @staticmethod + def forward(ctx, pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl): + ctx.two_sided_shading, ctx.opengl = two_sided_shading, opengl + out = _get_plugin().prepare_shading_normal_fwd(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl, False) + ctx.save_for_backward(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm) + return out + + @staticmethod + def backward(ctx, dout): + pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm = ctx.saved_variables + return _get_plugin().prepare_shading_normal_bwd(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, dout, ctx.two_sided_shading, ctx.opengl) + (None, None, None) + +def prepare_shading_normal(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading=True, opengl=True, use_python=False): + '''Takes care of all corner cases and produces a final normal used for shading: + - Constructs tangent space + - Flips normal direction based on geometric normal for two sided Shading + - Perturbs shading normal by normal map + - Bends backfacing normals towards the camera to avoid shading artifacts + + All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent. + + Args: + pos: World space g-buffer position. + view_pos: Camera position in world space (typically using broadcasting). + perturbed_nrm: Trangent-space normal perturbation from normal map lookup. + smooth_nrm: Interpolated vertex normals. + smooth_tng: Interpolated vertex tangents. + geom_nrm: Geometric (face) normals. + two_sided_shading: Use one/two sided shading + opengl: Use OpenGL/DirectX normal map conventions + use_python: Use PyTorch implementation (for validation) + Returns: + Final shading normal + ''' + + if perturbed_nrm is None: + perturbed_nrm = torch.tensor([0, 0, 1], dtype=torch.float32, device='cuda', requires_grad=False)[None, None, None, ...] + + if use_python: + out = bsdf_prepare_shading_normal(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl) + else: + out = _prepare_shading_normal_func.apply(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of prepare_shading_normal contains inf or NaN" + return out + +#---------------------------------------------------------------------------- +# BSDF functions + +class _lambert_func(torch.autograd.Function): + @staticmethod + def forward(ctx, nrm, wi): + out = _get_plugin().lambert_fwd(nrm, wi, False) + ctx.save_for_backward(nrm, wi) + return out + + @staticmethod + def backward(ctx, dout): + nrm, wi = ctx.saved_variables + return _get_plugin().lambert_bwd(nrm, wi, dout) + (None,) + +def lambert(nrm, wi, use_python=False): + '''Lambertian bsdf. 
+ All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent. + + Args: + nrm: World space shading normal. + wi: World space light vector. + use_python: Use PyTorch implementation (for validation) + + Returns: + Shaded diffuse value with shape [minibatch_size, height, width, 1] + ''' + + if use_python: + out = bsdf_lambert(nrm, wi) + else: + out = _lambert_func.apply(nrm, wi) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of lambert contains inf or NaN" + return out + +class _frostbite_diffuse_func(torch.autograd.Function): + @staticmethod + def forward(ctx, nrm, wi, wo, linearRoughness): + out = _get_plugin().frostbite_fwd(nrm, wi, wo, linearRoughness, False) + ctx.save_for_backward(nrm, wi, wo, linearRoughness) + return out + + @staticmethod + def backward(ctx, dout): + nrm, wi, wo, linearRoughness = ctx.saved_variables + return _get_plugin().frostbite_bwd(nrm, wi, wo, linearRoughness, dout) + (None,) + +def frostbite_diffuse(nrm, wi, wo, linearRoughness, use_python=False): + '''Frostbite, normalized Disney Diffuse bsdf. + All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent. + + Args: + nrm: World space shading normal. + wi: World space light vector. + wo: World space camera vector. + linearRoughness: Material roughness + use_python: Use PyTorch implementation (for validation) + + Returns: + Shaded diffuse value with shape [minibatch_size, height, width, 1] + ''' + + if use_python: + out = bsdf_frostbite(nrm, wi, wo, linearRoughness) + else: + out = _frostbite_diffuse_func.apply(nrm, wi, wo, linearRoughness) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of lambert contains inf or NaN" + return out + +class _pbr_specular_func(torch.autograd.Function): + @staticmethod + def forward(ctx, col, nrm, wo, wi, alpha, min_roughness): + ctx.save_for_backward(col, nrm, wo, wi, alpha) + ctx.min_roughness = min_roughness + out = _get_plugin().pbr_specular_fwd(col, nrm, wo, wi, alpha, min_roughness, False) + return out + + @staticmethod + def backward(ctx, dout): + col, nrm, wo, wi, alpha = ctx.saved_variables + return _get_plugin().pbr_specular_bwd(col, nrm, wo, wi, alpha, ctx.min_roughness, dout) + (None, None) + +def pbr_specular(col, nrm, wo, wi, alpha, min_roughness=0.08, use_python=False): + '''Physically-based specular bsdf. + All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent unless otherwise noted. + + Args: + col: Specular lobe color + nrm: World space shading normal. + wo: World space camera vector. 
+ wi: World space light vector + alpha: Specular roughness parameter with shape [minibatch_size, height, width, 1] + min_roughness: Scalar roughness clamping threshold + + use_python: Use PyTorch implementation (for validation) + Returns: + Shaded specular color + ''' + + if use_python: + out = bsdf_pbr_specular(col, nrm, wo, wi, alpha, min_roughness=min_roughness) + else: + out = _pbr_specular_func.apply(col, nrm, wo, wi, alpha, min_roughness) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of pbr_specular contains inf or NaN" + return out + +class _pbr_bsdf_func(torch.autograd.Function): + @staticmethod + def forward(ctx, kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF): + ctx.save_for_backward(kd, arm, pos, nrm, view_pos, light_pos) + ctx.min_roughness = min_roughness + ctx.BSDF = BSDF + out = _get_plugin().pbr_bsdf_fwd(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF, False) + return out + + @staticmethod + def backward(ctx, dout): + kd, arm, pos, nrm, view_pos, light_pos = ctx.saved_variables + return _get_plugin().pbr_bsdf_bwd(kd, arm, pos, nrm, view_pos, light_pos, ctx.min_roughness, ctx.BSDF, dout) + (None, None, None) + +def pbr_bsdf(kd, arm, pos, nrm, view_pos, light_pos, min_roughness=0.08, bsdf="lambert", use_python=False): + '''Physically-based bsdf, both diffuse & specular lobes + All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent unless otherwise noted. + + Args: + kd: Diffuse albedo. + arm: Specular parameters (attenuation, linear roughness, metalness). + pos: World space position. + nrm: World space shading normal. + view_pos: Camera position in world space, typically using broadcasting. + light_pos: Light position in world space, typically using broadcasting. + min_roughness: Scalar roughness clamping threshold + bsdf: Controls diffuse BSDF, can be either 'lambert' or 'frostbite' + + use_python: Use PyTorch implementation (for validation) + + Returns: + Shaded color. 
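+
+    Example (a minimal sketch; shapes and values are illustrative only):
+        kd   = torch.rand(1, 64, 64, 3, device='cuda')                               # diffuse albedo
+        arm  = torch.rand(1, 64, 64, 3, device='cuda')                               # (ao, roughness, metalness)
+        pos  = torch.rand(1, 64, 64, 3, device='cuda')                               # world-space positions
+        nrm  = torch.nn.functional.normalize(torch.randn(1, 64, 64, 3, device='cuda'), dim=-1)
+        view_pos  = torch.tensor([0.0, 0.0, 3.0], device='cuda').view(1, 1, 1, 3)    # broadcast over pixels
+        light_pos = torch.tensor([2.0, 2.0, 2.0], device='cuda').view(1, 1, 1, 3)    # broadcast over pixels
+        shaded = pbr_bsdf(kd, arm, pos, nrm, view_pos, light_pos, bsdf='frostbite')  # -> [1, 64, 64, 3]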
+ ''' + + BSDF = 0 + if bsdf == 'frostbite': + BSDF = 1 + + if use_python: + out = bsdf_pbr(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF) + else: + out = _pbr_bsdf_func.apply(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of pbr_bsdf contains inf or NaN" + return out + +#---------------------------------------------------------------------------- +# cubemap filter with filtering across edges + +class _diffuse_cubemap_func(torch.autograd.Function): + @staticmethod + def forward(ctx, cubemap): + out = _get_plugin().diffuse_cubemap_fwd(cubemap) + ctx.save_for_backward(cubemap) + return out + + @staticmethod + def backward(ctx, dout): + cubemap, = ctx.saved_variables + cubemap_grad = _get_plugin().diffuse_cubemap_bwd(cubemap, dout) + return cubemap_grad, None + +def diffuse_cubemap(cubemap, use_python=False): + if use_python: + assert False + else: + out = _diffuse_cubemap_func.apply(cubemap) + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of diffuse_cubemap contains inf or NaN" + return out + +class _specular_cubemap(torch.autograd.Function): + @staticmethod + def forward(ctx, cubemap, roughness, costheta_cutoff, bounds): + out = _get_plugin().specular_cubemap_fwd(cubemap, bounds, roughness, costheta_cutoff) + ctx.save_for_backward(cubemap, bounds) + ctx.roughness, ctx.theta_cutoff = roughness, costheta_cutoff + return out + + @staticmethod + def backward(ctx, dout): + cubemap, bounds = ctx.saved_variables + cubemap_grad = _get_plugin().specular_cubemap_bwd(cubemap, bounds, dout, ctx.roughness, ctx.theta_cutoff) + return cubemap_grad, None, None, None + +# Compute the bounds of the GGX NDF lobe to retain "cutoff" percent of the energy +def __ndfBounds(res, roughness, cutoff): + def ndfGGX(alphaSqr, costheta): + costheta = np.clip(costheta, 0.0, 1.0) + d = (costheta * alphaSqr - costheta) * costheta + 1.0 + return alphaSqr / (d * d * np.pi) + + # Sample out cutoff angle + nSamples = 1000000 + costheta = np.cos(np.linspace(0, np.pi/2.0, nSamples)) + D = np.cumsum(ndfGGX(roughness**4, costheta)) + idx = np.argmax(D >= D[..., -1] * cutoff) + + # Brute force compute lookup table with bounds + bounds = _get_plugin().specular_bounds(res, costheta[idx]) + + return costheta[idx], bounds +__ndfBoundsDict = {} + +def specular_cubemap(cubemap, roughness, cutoff=0.99, use_python=False): + assert cubemap.shape[0] == 6 and cubemap.shape[1] == cubemap.shape[2], "Bad shape for cubemap tensor: %s" % str(cubemap.shape) + + if use_python: + assert False + else: + key = (cubemap.shape[1], roughness, cutoff) + if key not in __ndfBoundsDict: + __ndfBoundsDict[key] = __ndfBounds(*key) + out = _specular_cubemap.apply(cubemap, roughness, *__ndfBoundsDict[key]) + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of specular_cubemap contains inf or NaN" + return out[..., 0:3] / out[..., 3:] + +#---------------------------------------------------------------------------- +# Fast image loss function + +class _image_loss_func(torch.autograd.Function): + @staticmethod + def forward(ctx, img, target, loss, tonemapper): + ctx.loss, ctx.tonemapper = loss, tonemapper + ctx.save_for_backward(img, target) + out = _get_plugin().image_loss_fwd(img, target, loss, tonemapper, False) + return out + + @staticmethod + def backward(ctx, dout): + img, target = ctx.saved_variables + return _get_plugin().image_loss_bwd(img, target, dout, ctx.loss, ctx.tonemapper) + 
(None, None, None) + +def image_loss(img, target, loss='l1', tonemapper='none', use_python=False): + '''Compute HDR image loss. Combines tonemapping and loss into a single kernel for better perf. + All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent unless otherwise noted. + + Args: + img: Input image. + target: Target (reference) image. + loss: Type of loss. Valid options are ['l1', 'mse', 'smape', 'relmse'] + tonemapper: Tonemapping operations. Valid options are ['none', 'log_srgb'] + use_python: Use PyTorch implementation (for validation) + + Returns: + Image space loss (scalar value). + ''' + if use_python: + out = image_loss_fn(img, target, loss, tonemapper) + else: + out = _image_loss_func.apply(img, target, loss, tonemapper) + out = torch.sum(out) / (img.shape[0]*img.shape[1]*img.shape[2]) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of image_loss contains inf or NaN" + return out + +#---------------------------------------------------------------------------- +# Transform points function + +class _xfm_func(torch.autograd.Function): + @staticmethod + def forward(ctx, points, matrix, isPoints): + ctx.save_for_backward(points, matrix) + ctx.isPoints = isPoints + return _get_plugin().xfm_fwd(points, matrix, isPoints, False) + + @staticmethod + def backward(ctx, dout): + points, matrix = ctx.saved_variables + return (_get_plugin().xfm_bwd(points, matrix, dout, ctx.isPoints),) + (None, None, None) + +def xfm_points(points, matrix, use_python=False): + '''Transform points. + Args: + points: Tensor containing 3D points with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3] + matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4] + use_python: Use PyTorch's torch.matmul (for validation) + Returns: + Transformed points in homogeneous 4D with shape [minibatch_size, num_vertices, 4]. + ''' + if use_python: + out = torch.matmul(torch.nn.functional.pad(points, pad=(0,1), mode='constant', value=1.0), torch.transpose(matrix, 1, 2)) + else: + out = _xfm_func.apply(points, matrix, True) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of xfm_points contains inf or NaN" + return out + +def xfm_vectors(vectors, matrix, use_python=False): + '''Transform vectors. + Args: + vectors: Tensor containing 3D vectors with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3] + matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4] + use_python: Use PyTorch's torch.matmul (for validation) + + Returns: + Transformed vectors in homogeneous 4D with shape [minibatch_size, num_vertices, 4]. + ''' + + if use_python: + out = torch.matmul(torch.nn.functional.pad(vectors, pad=(0,1), mode='constant', value=0.0), torch.transpose(matrix, 1, 2))[..., 0:3].contiguous() + else: + out = _xfm_func.apply(vectors, matrix, False) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of xfm_vectors contains inf or NaN" + return out + + + diff --git a/video3d/render/renderutils/tests/test_bsdf.py b/video3d/render/renderutils/tests/test_bsdf.py new file mode 100755 index 0000000000000000000000000000000000000000..b0b60c350455717826c0f3edb01289b29baac27a --- /dev/null +++ b/video3d/render/renderutils/tests/test_bsdf.py @@ -0,0 +1,296 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import torch + +import os +import sys +sys.path.insert(0, os.path.join(sys.path[0], '../..')) +import renderutils as ru + +RES = 4 +DTYPE = torch.float32 + +def relative_loss(name, ref, cuda): + ref = ref.float() + cuda = cuda.float() + print(name, torch.max(torch.abs(ref - cuda) / torch.abs(ref + 1e-7)).item()) + +def test_normal(): + pos_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + pos_ref = pos_cuda.clone().detach().requires_grad_(True) + view_pos_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + view_pos_ref = view_pos_cuda.clone().detach().requires_grad_(True) + perturbed_nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + perturbed_nrm_ref = perturbed_nrm_cuda.clone().detach().requires_grad_(True) + smooth_nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + smooth_nrm_ref = smooth_nrm_cuda.clone().detach().requires_grad_(True) + smooth_tng_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + smooth_tng_ref = smooth_tng_cuda.clone().detach().requires_grad_(True) + geom_nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + geom_nrm_ref = geom_nrm_cuda.clone().detach().requires_grad_(True) + target = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda') + + ref = ru.prepare_shading_normal(pos_ref, view_pos_ref, perturbed_nrm_ref, smooth_nrm_ref, smooth_tng_ref, geom_nrm_ref, True, use_python=True) + ref_loss = torch.nn.MSELoss()(ref, target) + ref_loss.backward() + + cuda = ru.prepare_shading_normal(pos_cuda, view_pos_cuda, perturbed_nrm_cuda, smooth_nrm_cuda, smooth_tng_cuda, geom_nrm_cuda, True) + cuda_loss = torch.nn.MSELoss()(cuda, target) + cuda_loss.backward() + + print("-------------------------------------------------------------") + print(" bent normal") + print("-------------------------------------------------------------") + relative_loss("res:", ref, cuda) + relative_loss("pos:", pos_ref.grad, pos_cuda.grad) + relative_loss("view_pos:", view_pos_ref.grad, view_pos_cuda.grad) + relative_loss("perturbed_nrm:", perturbed_nrm_ref.grad, perturbed_nrm_cuda.grad) + relative_loss("smooth_nrm:", smooth_nrm_ref.grad, smooth_nrm_cuda.grad) + relative_loss("smooth_tng:", smooth_tng_ref.grad, smooth_tng_cuda.grad) + relative_loss("geom_nrm:", geom_nrm_ref.grad, geom_nrm_cuda.grad) + +def test_schlick(): + f0_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + f0_ref = f0_cuda.clone().detach().requires_grad_(True) + f90_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + f90_ref = f90_cuda.clone().detach().requires_grad_(True) + cosT_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) * 2.0 + cosT_cuda = cosT_cuda.clone().detach().requires_grad_(True) + cosT_ref = cosT_cuda.clone().detach().requires_grad_(True) + target = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda') + + ref = ru._fresnel_shlick(f0_ref, f90_ref, cosT_ref, use_python=True) + ref_loss = torch.nn.MSELoss()(ref, target) + 
ref_loss.backward() + + cuda = ru._fresnel_shlick(f0_cuda, f90_cuda, cosT_cuda) + cuda_loss = torch.nn.MSELoss()(cuda, target) + cuda_loss.backward() + + print("-------------------------------------------------------------") + print(" Fresnel shlick") + print("-------------------------------------------------------------") + relative_loss("res:", ref, cuda) + relative_loss("f0:", f0_ref.grad, f0_cuda.grad) + relative_loss("f90:", f90_ref.grad, f90_cuda.grad) + relative_loss("cosT:", cosT_ref.grad, cosT_cuda.grad) + +def test_ndf_ggx(): + alphaSqr_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) + alphaSqr_cuda = alphaSqr_cuda.clone().detach().requires_grad_(True) + alphaSqr_ref = alphaSqr_cuda.clone().detach().requires_grad_(True) + cosT_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) * 3.0 - 1 + cosT_cuda = cosT_cuda.clone().detach().requires_grad_(True) + cosT_ref = cosT_cuda.clone().detach().requires_grad_(True) + target = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda') + + ref = ru._ndf_ggx(alphaSqr_ref, cosT_ref, use_python=True) + ref_loss = torch.nn.MSELoss()(ref, target) + ref_loss.backward() + + cuda = ru._ndf_ggx(alphaSqr_cuda, cosT_cuda) + cuda_loss = torch.nn.MSELoss()(cuda, target) + cuda_loss.backward() + + print("-------------------------------------------------------------") + print(" Ndf GGX") + print("-------------------------------------------------------------") + relative_loss("res:", ref, cuda) + relative_loss("alpha:", alphaSqr_ref.grad, alphaSqr_cuda.grad) + relative_loss("cosT:", cosT_ref.grad, cosT_cuda.grad) + +def test_lambda_ggx(): + alphaSqr_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) + alphaSqr_ref = alphaSqr_cuda.clone().detach().requires_grad_(True) + cosT_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) * 3.0 - 1 + cosT_cuda = cosT_cuda.clone().detach().requires_grad_(True) + cosT_ref = cosT_cuda.clone().detach().requires_grad_(True) + target = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda') + + ref = ru._lambda_ggx(alphaSqr_ref, cosT_ref, use_python=True) + ref_loss = torch.nn.MSELoss()(ref, target) + ref_loss.backward() + + cuda = ru._lambda_ggx(alphaSqr_cuda, cosT_cuda) + cuda_loss = torch.nn.MSELoss()(cuda, target) + cuda_loss.backward() + + print("-------------------------------------------------------------") + print(" Lambda GGX") + print("-------------------------------------------------------------") + relative_loss("res:", ref, cuda) + relative_loss("alpha:", alphaSqr_ref.grad, alphaSqr_cuda.grad) + relative_loss("cosT:", cosT_ref.grad, cosT_cuda.grad) + +def test_masking_smith(): + alphaSqr_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) + alphaSqr_ref = alphaSqr_cuda.clone().detach().requires_grad_(True) + cosThetaI_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) + cosThetaI_ref = cosThetaI_cuda.clone().detach().requires_grad_(True) + cosThetaO_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) + cosThetaO_ref = cosThetaO_cuda.clone().detach().requires_grad_(True) + target = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda') + + ref = ru._masking_smith(alphaSqr_ref, cosThetaI_ref, cosThetaO_ref, use_python=True) + ref_loss = torch.nn.MSELoss()(ref, target) + ref_loss.backward() + + cuda = ru._masking_smith(alphaSqr_cuda, cosThetaI_cuda, cosThetaO_cuda) + cuda_loss = torch.nn.MSELoss()(cuda, 
target) + cuda_loss.backward() + + print("-------------------------------------------------------------") + print(" Smith masking term") + print("-------------------------------------------------------------") + relative_loss("res:", ref, cuda) + relative_loss("alpha:", alphaSqr_ref.grad, alphaSqr_cuda.grad) + relative_loss("cosThetaI:", cosThetaI_ref.grad, cosThetaI_cuda.grad) + relative_loss("cosThetaO:", cosThetaO_ref.grad, cosThetaO_cuda.grad) + +def test_lambert(): + normals_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + normals_ref = normals_cuda.clone().detach().requires_grad_(True) + wi_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + wi_ref = wi_cuda.clone().detach().requires_grad_(True) + target = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda') + + ref = ru.lambert(normals_ref, wi_ref, use_python=True) + ref_loss = torch.nn.MSELoss()(ref, target) + ref_loss.backward() + + cuda = ru.lambert(normals_cuda, wi_cuda) + cuda_loss = torch.nn.MSELoss()(cuda, target) + cuda_loss.backward() + + print("-------------------------------------------------------------") + print(" Lambert") + print("-------------------------------------------------------------") + relative_loss("res:", ref, cuda) + relative_loss("nrm:", normals_ref.grad, normals_cuda.grad) + relative_loss("wi:", wi_ref.grad, wi_cuda.grad) + +def test_frostbite(): + normals_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + normals_ref = normals_cuda.clone().detach().requires_grad_(True) + wi_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + wi_ref = wi_cuda.clone().detach().requires_grad_(True) + wo_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + wo_ref = wo_cuda.clone().detach().requires_grad_(True) + rough_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) + rough_ref = rough_cuda.clone().detach().requires_grad_(True) + target = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda') + + ref = ru.frostbite_diffuse(normals_ref, wi_ref, wo_ref, rough_ref, use_python=True) + ref_loss = torch.nn.MSELoss()(ref, target) + ref_loss.backward() + + cuda = ru.frostbite_diffuse(normals_cuda, wi_cuda, wo_cuda, rough_cuda) + cuda_loss = torch.nn.MSELoss()(cuda, target) + cuda_loss.backward() + + print("-------------------------------------------------------------") + print(" Frostbite") + print("-------------------------------------------------------------") + relative_loss("res:", ref, cuda) + relative_loss("nrm:", normals_ref.grad, normals_cuda.grad) + relative_loss("wo:", wo_ref.grad, wo_cuda.grad) + relative_loss("wi:", wi_ref.grad, wi_cuda.grad) + relative_loss("rough:", rough_ref.grad, rough_cuda.grad) + +def test_pbr_specular(): + col_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + col_ref = col_cuda.clone().detach().requires_grad_(True) + nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + nrm_ref = nrm_cuda.clone().detach().requires_grad_(True) + wi_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + wi_ref = wi_cuda.clone().detach().requires_grad_(True) + wo_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + wo_ref = wo_cuda.clone().detach().requires_grad_(True) + alpha_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) + alpha_ref = 
alpha_cuda.clone().detach().requires_grad_(True) + target = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda') + + ref = ru.pbr_specular(col_ref, nrm_ref, wo_ref, wi_ref, alpha_ref, use_python=True) + ref_loss = torch.nn.MSELoss()(ref, target) + ref_loss.backward() + + cuda = ru.pbr_specular(col_cuda, nrm_cuda, wo_cuda, wi_cuda, alpha_cuda) + cuda_loss = torch.nn.MSELoss()(cuda, target) + cuda_loss.backward() + + print("-------------------------------------------------------------") + print(" Pbr specular") + print("-------------------------------------------------------------") + + relative_loss("res:", ref, cuda) + if col_ref.grad is not None: + relative_loss("col:", col_ref.grad, col_cuda.grad) + if nrm_ref.grad is not None: + relative_loss("nrm:", nrm_ref.grad, nrm_cuda.grad) + if wi_ref.grad is not None: + relative_loss("wi:", wi_ref.grad, wi_cuda.grad) + if wo_ref.grad is not None: + relative_loss("wo:", wo_ref.grad, wo_cuda.grad) + if alpha_ref.grad is not None: + relative_loss("alpha:", alpha_ref.grad, alpha_cuda.grad) + +def test_pbr_bsdf(bsdf): + kd_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + kd_ref = kd_cuda.clone().detach().requires_grad_(True) + arm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + arm_ref = arm_cuda.clone().detach().requires_grad_(True) + pos_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + pos_ref = pos_cuda.clone().detach().requires_grad_(True) + nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + nrm_ref = nrm_cuda.clone().detach().requires_grad_(True) + view_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + view_ref = view_cuda.clone().detach().requires_grad_(True) + light_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + light_ref = light_cuda.clone().detach().requires_grad_(True) + target = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda') + + ref = ru.pbr_bsdf(kd_ref, arm_ref, pos_ref, nrm_ref, view_ref, light_ref, use_python=True, bsdf=bsdf) + ref_loss = torch.nn.MSELoss()(ref, target) + ref_loss.backward() + + cuda = ru.pbr_bsdf(kd_cuda, arm_cuda, pos_cuda, nrm_cuda, view_cuda, light_cuda, bsdf=bsdf) + cuda_loss = torch.nn.MSELoss()(cuda, target) + cuda_loss.backward() + + print("-------------------------------------------------------------") + print(" Pbr BSDF") + print("-------------------------------------------------------------") + + relative_loss("res:", ref, cuda) + if kd_ref.grad is not None: + relative_loss("kd:", kd_ref.grad, kd_cuda.grad) + if arm_ref.grad is not None: + relative_loss("arm:", arm_ref.grad, arm_cuda.grad) + if pos_ref.grad is not None: + relative_loss("pos:", pos_ref.grad, pos_cuda.grad) + if nrm_ref.grad is not None: + relative_loss("nrm:", nrm_ref.grad, nrm_cuda.grad) + if view_ref.grad is not None: + relative_loss("view:", view_ref.grad, view_cuda.grad) + if light_ref.grad is not None: + relative_loss("light:", light_ref.grad, light_cuda.grad) + +test_normal() + +test_schlick() +test_ndf_ggx() +test_lambda_ggx() +test_masking_smith() + +test_lambert() +test_frostbite() +test_pbr_specular() +test_pbr_bsdf('lambert') +test_pbr_bsdf('frostbite') diff --git a/video3d/render/renderutils/tests/test_cubemap.py b/video3d/render/renderutils/tests/test_cubemap.py new file mode 100755 index 0000000000000000000000000000000000000000..a1ae0a28b3fe6b88201c49c00c5180962d182579 --- /dev/null +++ 
b/video3d/render/renderutils/tests/test_cubemap.py @@ -0,0 +1,47 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import torch + +import os +import sys +sys.path.insert(0, os.path.join(sys.path[0], '../..')) +import renderutils as ru + +RES = 4 +DTYPE = torch.float32 + +def relative_loss(name, ref, cuda): + ref = ref.float() + cuda = cuda.float() + print(name, torch.max(torch.abs(ref - cuda) / torch.abs(ref + 1e-7)).item()) + +def test_cubemap(): + cubemap_cuda = torch.rand(6, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + cubemap_ref = cubemap_cuda.clone().detach().requires_grad_(True) + weights = torch.rand(3, 3, 1, dtype=DTYPE, device='cuda') + target = torch.rand(6, RES, RES, 3, dtype=DTYPE, device='cuda') + + ref = ru.filter_cubemap(cubemap_ref, weights, use_python=True) + ref_loss = torch.nn.MSELoss()(ref, target) + ref_loss.backward() + + cuda = ru.filter_cubemap(cubemap_cuda, weights, use_python=False) + cuda_loss = torch.nn.MSELoss()(cuda, target) + cuda_loss.backward() + + print("-------------------------------------------------------------") + print(" Cubemap:") + print("-------------------------------------------------------------") + + relative_loss("flt:", ref, cuda) + relative_loss("cubemap:", cubemap_ref.grad, cubemap_cuda.grad) + + +test_cubemap() diff --git a/video3d/render/renderutils/tests/test_loss.py b/video3d/render/renderutils/tests/test_loss.py new file mode 100755 index 0000000000000000000000000000000000000000..7a68f3fc4528431fe405d1d6077af0cb31687d31 --- /dev/null +++ b/video3d/render/renderutils/tests/test_loss.py @@ -0,0 +1,61 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. 
+ +import torch + +import os +import sys +sys.path.insert(0, os.path.join(sys.path[0], '../..')) +import renderutils as ru + +RES = 8 +DTYPE = torch.float32 + +def tonemap_srgb(f): + return torch.where(f > 0.0031308, torch.pow(torch.clamp(f, min=0.0031308), 1.0/2.4)*1.055 - 0.055, 12.92*f) + +def l1(output, target): + x = torch.clamp(output, min=0, max=65535) + r = torch.clamp(target, min=0, max=65535) + x = tonemap_srgb(torch.log(x + 1)) + r = tonemap_srgb(torch.log(r + 1)) + return torch.nn.functional.l1_loss(x,r) + +def relative_loss(name, ref, cuda): + ref = ref.float() + cuda = cuda.float() + print(name, torch.max(torch.abs(ref - cuda) / torch.abs(ref + 1e-7)).item()) + +def test_loss(loss, tonemapper): + img_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + img_ref = img_cuda.clone().detach().requires_grad_(True) + target_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + target_ref = target_cuda.clone().detach().requires_grad_(True) + + ref_loss = ru.image_loss(img_ref, target_ref, loss=loss, tonemapper=tonemapper, use_python=True) + ref_loss.backward() + + cuda_loss = ru.image_loss(img_cuda, target_cuda, loss=loss, tonemapper=tonemapper) + cuda_loss.backward() + + print("-------------------------------------------------------------") + print(" Loss: %s, %s" % (loss, tonemapper)) + print("-------------------------------------------------------------") + + relative_loss("res:", ref_loss, cuda_loss) + relative_loss("img:", img_ref.grad, img_cuda.grad) + relative_loss("target:", target_ref.grad, target_cuda.grad) + + +test_loss('l1', 'none') +test_loss('l1', 'log_srgb') +test_loss('mse', 'log_srgb') +test_loss('smape', 'none') +test_loss('relmse', 'none') +test_loss('mse', 'none') \ No newline at end of file diff --git a/video3d/render/renderutils/tests/test_mesh.py b/video3d/render/renderutils/tests/test_mesh.py new file mode 100755 index 0000000000000000000000000000000000000000..4856c5ce07e2d6cd5f1fd463c1d3628791eafccc --- /dev/null +++ b/video3d/render/renderutils/tests/test_mesh.py @@ -0,0 +1,90 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. 
+ +import torch + +import os +import sys +sys.path.insert(0, os.path.join(sys.path[0], '../..')) +import renderutils as ru + +BATCH = 8 +RES = 1024 +DTYPE = torch.float32 + +torch.manual_seed(0) + +def tonemap_srgb(f): + return torch.where(f > 0.0031308, torch.pow(torch.clamp(f, min=0.0031308), 1.0/2.4)*1.055 - 0.055, 12.92*f) + +def l1(output, target): + x = torch.clamp(output, min=0, max=65535) + r = torch.clamp(target, min=0, max=65535) + x = tonemap_srgb(torch.log(x + 1)) + r = tonemap_srgb(torch.log(r + 1)) + return torch.nn.functional.l1_loss(x,r) + +def relative_loss(name, ref, cuda): + ref = ref.float() + cuda = cuda.float() + print(name, torch.max(torch.abs(ref - cuda) / torch.abs(ref)).item()) + +def test_xfm_points(): + points_cuda = torch.rand(1, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + points_ref = points_cuda.clone().detach().requires_grad_(True) + mtx_cuda = torch.rand(BATCH, 4, 4, dtype=DTYPE, device='cuda', requires_grad=False) + mtx_ref = mtx_cuda.clone().detach().requires_grad_(True) + target = torch.rand(BATCH, RES, 4, dtype=DTYPE, device='cuda', requires_grad=True) + + ref_out = ru.xfm_points(points_ref, mtx_ref, use_python=True) + ref_loss = torch.nn.MSELoss()(ref_out, target) + ref_loss.backward() + + cuda_out = ru.xfm_points(points_cuda, mtx_cuda) + cuda_loss = torch.nn.MSELoss()(cuda_out, target) + cuda_loss.backward() + + print("-------------------------------------------------------------") + + relative_loss("res:", ref_out, cuda_out) + relative_loss("points:", points_ref.grad, points_cuda.grad) + +def test_xfm_vectors(): + points_cuda = torch.rand(1, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + points_ref = points_cuda.clone().detach().requires_grad_(True) + points_cuda_p = points_cuda.clone().detach().requires_grad_(True) + points_ref_p = points_cuda.clone().detach().requires_grad_(True) + mtx_cuda = torch.rand(BATCH, 4, 4, dtype=DTYPE, device='cuda', requires_grad=False) + mtx_ref = mtx_cuda.clone().detach().requires_grad_(True) + target = torch.rand(BATCH, RES, 4, dtype=DTYPE, device='cuda', requires_grad=True) + + ref_out = ru.xfm_vectors(points_ref.contiguous(), mtx_ref, use_python=True) + ref_loss = torch.nn.MSELoss()(ref_out, target[..., 0:3]) + ref_loss.backward() + + cuda_out = ru.xfm_vectors(points_cuda.contiguous(), mtx_cuda) + cuda_loss = torch.nn.MSELoss()(cuda_out, target[..., 0:3]) + cuda_loss.backward() + + ref_out_p = ru.xfm_points(points_ref_p.contiguous(), mtx_ref, use_python=True) + ref_loss_p = torch.nn.MSELoss()(ref_out_p, target) + ref_loss_p.backward() + + cuda_out_p = ru.xfm_points(points_cuda_p.contiguous(), mtx_cuda) + cuda_loss_p = torch.nn.MSELoss()(cuda_out_p, target) + cuda_loss_p.backward() + + print("-------------------------------------------------------------") + + relative_loss("res:", ref_out, cuda_out) + relative_loss("points:", points_ref.grad, points_cuda.grad) + relative_loss("points_p:", points_ref_p.grad, points_cuda_p.grad) + +test_xfm_points() +test_xfm_vectors() diff --git a/video3d/render/renderutils/tests/test_perf.py b/video3d/render/renderutils/tests/test_perf.py new file mode 100755 index 0000000000000000000000000000000000000000..ffc143e3004c0fd0a42a1941896823bc2bef939a --- /dev/null +++ b/video3d/render/renderutils/tests/test_perf.py @@ -0,0 +1,57 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import torch + +import os +import sys +sys.path.insert(0, os.path.join(sys.path[0], '../..')) +import renderutils as ru + +DTYPE=torch.float32 + +def test_bsdf(BATCH, RES, ITR): + kd_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + kd_ref = kd_cuda.clone().detach().requires_grad_(True) + arm_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + arm_ref = arm_cuda.clone().detach().requires_grad_(True) + pos_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + pos_ref = pos_cuda.clone().detach().requires_grad_(True) + nrm_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + nrm_ref = nrm_cuda.clone().detach().requires_grad_(True) + view_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + view_ref = view_cuda.clone().detach().requires_grad_(True) + light_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + light_ref = light_cuda.clone().detach().requires_grad_(True) + target = torch.rand(BATCH, RES, RES, 3, device='cuda') + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + ru.pbr_bsdf(kd_cuda, arm_cuda, pos_cuda, nrm_cuda, view_cuda, light_cuda) + + print("--- Testing: [%d, %d, %d] ---" % (BATCH, RES, RES)) + + start.record() + for i in range(ITR): + ref = ru.pbr_bsdf(kd_ref, arm_ref, pos_ref, nrm_ref, view_ref, light_ref, use_python=True) + end.record() + torch.cuda.synchronize() + print("Pbr BSDF python:", start.elapsed_time(end)) + + start.record() + for i in range(ITR): + cuda = ru.pbr_bsdf(kd_cuda, arm_cuda, pos_cuda, nrm_cuda, view_cuda, light_cuda) + end.record() + torch.cuda.synchronize() + print("Pbr BSDF cuda:", start.elapsed_time(end)) + +test_bsdf(1, 512, 1000) +test_bsdf(16, 512, 1000) +test_bsdf(1, 2048, 1000) diff --git a/video3d/render/texture.py b/video3d/render/texture.py new file mode 100755 index 0000000000000000000000000000000000000000..4e4a39d042dc4d356c47133efee897088b9ce5c6 --- /dev/null +++ b/video3d/render/texture.py @@ -0,0 +1,186 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import os +import numpy as np +import torch +import nvdiffrast.torch as dr + +from . 
import util + +###################################################################################### +# Smooth pooling / mip computation with linear gradient upscaling +###################################################################################### + +class texture2d_mip(torch.autograd.Function): + @staticmethod + def forward(ctx, texture): + return util.avg_pool_nhwc(texture, (2,2)) + + @staticmethod + def backward(ctx, dout): + gy, gx = torch.meshgrid(torch.linspace(0.0 + 0.25 / dout.shape[1], 1.0 - 0.25 / dout.shape[1], dout.shape[1]*2, device="cuda"), + torch.linspace(0.0 + 0.25 / dout.shape[2], 1.0 - 0.25 / dout.shape[2], dout.shape[2]*2, device="cuda"), + indexing='ij') + uv = torch.stack((gx, gy), dim=-1) + return dr.texture(dout * 0.25, uv[None, ...].contiguous(), filter_mode='linear', boundary_mode='clamp') + +######################################################################################################## +# Simple texture class. A texture can be either +# - A 3D tensor (using auto mipmaps) +# - A list of 3D tensors (full custom mip hierarchy) +######################################################################################################## + +class Texture2D(torch.nn.Module): + # Initializes a texture from image data. + # Input can be constant value (1D array) or texture (3D array) or mip hierarchy (list of 3d arrays) + def __init__(self, init, min_max=None): + super(Texture2D, self).__init__() + + if isinstance(init, np.ndarray): + init = torch.tensor(init, dtype=torch.float32, device='cuda') + elif isinstance(init, list) and len(init) == 1: + init = init[0] + + if isinstance(init, list): + self.data = list(torch.nn.Parameter(mip.clone().detach(), requires_grad=True) for mip in init) + elif len(init.shape) == 4: + self.data = torch.nn.Parameter(init.clone().detach(), requires_grad=True) + elif len(init.shape) == 3: + self.data = torch.nn.Parameter(init[None, ...].clone().detach(), requires_grad=True) + elif len(init.shape) == 1: + self.data = torch.nn.Parameter(init[None, None, None, :].clone().detach(), requires_grad=True) # Convert constant to 1x1 tensor + else: + assert False, "Invalid texture object" + + self.min_max = min_max + + # Filtered (trilinear) sample texture at a given location + def sample(self, texc, texc_deriv, filter_mode='linear-mipmap-linear'): + if isinstance(self.data, list): + out = dr.texture(self.data[0], texc, texc_deriv, mip=self.data[1:], filter_mode=filter_mode) + else: + if self.data.shape[1] > 1 and self.data.shape[2] > 1: + mips = [self.data] + while mips[-1].shape[1] > 1 and mips[-1].shape[2] > 1: + mips += [texture2d_mip.apply(mips[-1])] + out = dr.texture(mips[0], texc, texc_deriv, mip=mips[1:], filter_mode=filter_mode) + else: + out = dr.texture(self.data, texc, texc_deriv, filter_mode=filter_mode) + return out + + def getRes(self): + return self.getMips()[0].shape[1:3] + + def getChannels(self): + return self.getMips()[0].shape[3] + + def getMips(self): + if isinstance(self.data, list): + return self.data + else: + return [self.data] + + # In-place clamp with no derivative to make sure values are in valid range after training + def clamp_(self): + if self.min_max is not None: + for mip in self.getMips(): + for i in range(mip.shape[-1]): + mip[..., i].clamp_(min=self.min_max[0][i], max=self.min_max[1][i]) + + # In-place clamp with no derivative to make sure values are in valid range after training + def normalize_(self): + with torch.no_grad(): + for mip in self.getMips(): + mip = util.safe_normalize(mip) + 
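# ----------------------------------------------------------------------------
# Editor's sketch (not part of the original NVIDIA file): in the auto-mipmap
# path of Texture2D.sample() above, each mip level is a 2x2 average pool of
# the previous level in NHWC layout (texture2d_mip.forward delegates to
# util.avg_pool_nhwc). The helper below reproduces only that forward
# construction in plain PyTorch so it can be sanity-checked on CPU; the name
# _mip_chain_sketch is hypothetical, and the real texture2d_mip additionally
# routes gradients through a bilinear upscale in its backward pass.
# ----------------------------------------------------------------------------

def _mip_chain_sketch(base):
    # base: NHWC float tensor, e.g. shape [1, H, W, C]
    mips = [base]
    while mips[-1].shape[1] > 1 and mips[-1].shape[2] > 1:
        x = mips[-1].permute(0, 3, 1, 2)                 # NHWC -> NCHW
        x = torch.nn.functional.avg_pool2d(x, 2)         # 2x2 box filter, stride 2
        mips.append(x.permute(0, 2, 3, 1).contiguous())  # back to NHWC
    return mips

if __name__ == "__main__":
    # A 1x8x8x3 texture yields four levels: 8x8, 4x4, 2x2 and 1x1.
    print([tuple(m.shape) for m in _mip_chain_sketch(torch.rand(1, 8, 8, 3))])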
+######################################################################################################## +# Helper function to create a trainable texture from a regular texture. The trainable weights are +# initialized with texture data as an initial guess +######################################################################################################## + +def create_trainable(init, res=None, auto_mipmaps=True, min_max=None): + with torch.no_grad(): + if isinstance(init, Texture2D): + assert isinstance(init.data, torch.Tensor) + min_max = init.min_max if min_max is None else min_max + init = init.data + elif isinstance(init, np.ndarray): + init = torch.tensor(init, dtype=torch.float32, device='cuda') + + # Pad to NHWC if needed + if len(init.shape) == 1: # Extend constant to NHWC tensor + init = init[None, None, None, :] + elif len(init.shape) == 3: + init = init[None, ...] + + # Scale input to desired resolution. + if res is not None: + init = util.scale_img_nhwc(init, res) + + # Genreate custom mipchain + if not auto_mipmaps: + mip_chain = [init.clone().detach().requires_grad_(True)] + while mip_chain[-1].shape[1] > 1 or mip_chain[-1].shape[2] > 1: + new_size = [max(mip_chain[-1].shape[1] // 2, 1), max(mip_chain[-1].shape[2] // 2, 1)] + mip_chain += [util.scale_img_nhwc(mip_chain[-1], new_size)] + return Texture2D(mip_chain, min_max=min_max) + else: + return Texture2D(init, min_max=min_max) + +######################################################################################################## +# Convert texture to and from SRGB +######################################################################################################## + +def srgb_to_rgb(texture): + return Texture2D(list(util.srgb_to_rgb(mip) for mip in texture.getMips())) + +def rgb_to_srgb(texture): + return Texture2D(list(util.rgb_to_srgb(mip) for mip in texture.getMips())) + +######################################################################################################## +# Utility functions for loading / storing a texture +######################################################################################################## + +def _load_mip2D(fn, lambda_fn=None, channels=None): + imgdata = torch.tensor(util.load_image(fn), dtype=torch.float32, device='cuda') + if channels is not None: + imgdata = imgdata[..., 0:channels] + if lambda_fn is not None: + imgdata = lambda_fn(imgdata) + return imgdata.detach().clone() + +def load_texture2D(fn, lambda_fn=None, channels=None): + base, ext = os.path.splitext(fn) + if os.path.exists(base + "_0" + ext): + mips = [] + while os.path.exists(base + ("_%d" % len(mips)) + ext): + mips += [_load_mip2D(base + ("_%d" % len(mips)) + ext, lambda_fn, channels)] + return Texture2D(mips) + else: + return Texture2D(_load_mip2D(fn, lambda_fn, channels)) + +def _save_mip2D(fn, mip, mipidx, lambda_fn): + if lambda_fn is not None: + data = lambda_fn(mip).detach().cpu().numpy() + else: + data = mip.detach().cpu().numpy() + + if mipidx is None: + util.save_image(fn, data) + else: + base, ext = os.path.splitext(fn) + util.save_image(base + ("_%d" % mipidx) + ext, data) + +def save_texture2D(fn, tex, lambda_fn=None): + if isinstance(tex.data, list): + for i, mip in enumerate(tex.data): + _save_mip2D(fn, mip[0,...], i, lambda_fn) + else: + _save_mip2D(fn, tex.data[0,...], None, lambda_fn) diff --git a/video3d/render/util.py b/video3d/render/util.py new file mode 100755 index 0000000000000000000000000000000000000000..cf0a04f9281bced81b4fdcb394064871b455fffa --- /dev/null +++ 
b/video3d/render/util.py @@ -0,0 +1,477 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import os +import numpy as np +import torch +import nvdiffrast.torch as dr +import imageio +import torchvision.transforms.functional as TF +import torch.nn.functional as TNF + +#---------------------------------------------------------------------------- +# Vector operations +#---------------------------------------------------------------------------- + +def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return torch.sum(x*y, -1, keepdim=True) + +def reflect(x: torch.Tensor, n: torch.Tensor) -> torch.Tensor: + return 2*dot(x, n)*n - x + +def length(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor: + return torch.sqrt(torch.clamp(dot(x,x), min=eps)) # Clamp to avoid nan gradients because grad(sqrt(0)) = NaN + +def safe_normalize(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor: + return x / length(x, eps) + +def to_hvec(x: torch.Tensor, w: float) -> torch.Tensor: + return torch.nn.functional.pad(x, pad=(0,1), mode='constant', value=w) + +#---------------------------------------------------------------------------- +# sRGB color transforms +#---------------------------------------------------------------------------- + +def _rgb_to_srgb(f: torch.Tensor) -> torch.Tensor: + return torch.where(f <= 0.0031308, f * 12.92, torch.pow(torch.clamp(f, 0.0031308), 1.0/2.4)*1.055 - 0.055) + +def rgb_to_srgb(f: torch.Tensor) -> torch.Tensor: + assert f.shape[-1] == 3 or f.shape[-1] == 4 + out = torch.cat((_rgb_to_srgb(f[..., 0:3]), f[..., 3:4]), dim=-1) if f.shape[-1] == 4 else _rgb_to_srgb(f) + assert out.shape[0] == f.shape[0] and out.shape[1] == f.shape[1] and out.shape[2] == f.shape[2] + return out + +def _srgb_to_rgb(f: torch.Tensor) -> torch.Tensor: + return torch.where(f <= 0.04045, f / 12.92, torch.pow((torch.clamp(f, 0.04045) + 0.055) / 1.055, 2.4)) + +def srgb_to_rgb(f: torch.Tensor) -> torch.Tensor: + assert f.shape[-1] == 3 or f.shape[-1] == 4 + out = torch.cat((_srgb_to_rgb(f[..., 0:3]), f[..., 3:4]), dim=-1) if f.shape[-1] == 4 else _srgb_to_rgb(f) + assert out.shape[0] == f.shape[0] and out.shape[1] == f.shape[1] and out.shape[2] == f.shape[2] + return out + +def reinhard(f: torch.Tensor) -> torch.Tensor: + return f/(1+f) + +#----------------------------------------------------------------------------------- +# Metrics (taken from jaxNerf source code, in order to replicate their measurements) +# +# https://github.com/google-research/google-research/blob/301451a62102b046bbeebff49a760ebeec9707b8/jaxnerf/nerf/utils.py#L266 +# +#----------------------------------------------------------------------------------- + +def mse_to_psnr(mse): + """Compute PSNR given an MSE (we assume the maximum pixel value is 1).""" + return -10. / np.log(10.) * np.log(mse) + +def psnr_to_mse(psnr): + """Compute MSE given a PSNR (we assume the maximum pixel value is 1).""" + return np.exp(-0.1 * np.log(10.) 
* psnr) + +#---------------------------------------------------------------------------- +# Displacement texture lookup +#---------------------------------------------------------------------------- + +def get_miplevels(texture: np.ndarray) -> float: + minDim = min(texture.shape[0], texture.shape[1]) + return np.floor(np.log2(minDim)) + +def tex_2d(tex_map : torch.Tensor, coords : torch.Tensor, filter='nearest') -> torch.Tensor: + tex_map = tex_map[None, ...] # Add batch dimension + tex_map = tex_map.permute(0, 3, 1, 2) # NHWC -> NCHW + tex = torch.nn.functional.grid_sample(tex_map, coords[None, None, ...] * 2 - 1, mode=filter, align_corners=False) + tex = tex.permute(0, 2, 3, 1) # NCHW -> NHWC + return tex[0, 0, ...] + +#---------------------------------------------------------------------------- +# Cubemap utility functions +#---------------------------------------------------------------------------- + +def cube_to_dir(s, x, y): + if s == 0: rx, ry, rz = torch.ones_like(x), -y, -x + elif s == 1: rx, ry, rz = -torch.ones_like(x), -y, x + elif s == 2: rx, ry, rz = x, torch.ones_like(x), y + elif s == 3: rx, ry, rz = x, -torch.ones_like(x), -y + elif s == 4: rx, ry, rz = x, -y, torch.ones_like(x) + elif s == 5: rx, ry, rz = -x, -y, -torch.ones_like(x) + return torch.stack((rx, ry, rz), dim=-1) + +def latlong_to_cubemap(latlong_map, res): + cubemap = torch.zeros(6, res[0], res[1], latlong_map.shape[-1], dtype=torch.float32, device='cuda') + for s in range(6): + gy, gx = torch.meshgrid(torch.linspace(-1.0 + 1.0 / res[0], 1.0 - 1.0 / res[0], res[0], device='cuda'), + torch.linspace(-1.0 + 1.0 / res[1], 1.0 - 1.0 / res[1], res[1], device='cuda'), + indexing='ij') + v = safe_normalize(cube_to_dir(s, gx, gy)) + + tu = torch.atan2(v[..., 0:1], -v[..., 2:3]) / (2 * np.pi) + 0.5 + tv = torch.acos(torch.clamp(v[..., 1:2], min=-1, max=1)) / np.pi + texcoord = torch.cat((tu, tv), dim=-1) + + cubemap[s, ...] 
= dr.texture(latlong_map[None, ...], texcoord[None, ...], filter_mode='linear')[0] + return cubemap + +def cubemap_to_latlong(cubemap, res): + gy, gx = torch.meshgrid(torch.linspace( 0.0 + 1.0 / res[0], 1.0 - 1.0 / res[0], res[0], device='cuda'), + torch.linspace(-1.0 + 1.0 / res[1], 1.0 - 1.0 / res[1], res[1], device='cuda'), + indexing='ij') + + sintheta, costheta = torch.sin(gy*np.pi), torch.cos(gy*np.pi) + sinphi, cosphi = torch.sin(gx*np.pi), torch.cos(gx*np.pi) + + reflvec = torch.stack(( + sintheta*sinphi, + costheta, + -sintheta*cosphi + ), dim=-1) + return dr.texture(cubemap[None, ...], reflvec[None, ...].contiguous(), filter_mode='linear', boundary_mode='cube')[0] + +#---------------------------------------------------------------------------- +# Image scaling +#---------------------------------------------------------------------------- + +def scale_img_hwc(x : torch.Tensor, size, mag='bilinear', min='area') -> torch.Tensor: + return scale_img_nhwc(x[None, ...], size, mag, min)[0] + +def scale_img_nhwc(x : torch.Tensor, size, mag='bilinear', min='area') -> torch.Tensor: + assert (x.shape[1] >= size[0] and x.shape[2] >= size[1]) or (x.shape[1] < size[0] and x.shape[2] < size[1]), "Trying to magnify image in one dimension and minify in the other" + y = x.permute(0, 3, 1, 2) # NHWC -> NCHW + if x.shape[1] > size[0] and x.shape[2] > size[1]: # Minification, previous size was bigger + y = torch.nn.functional.interpolate(y, size, mode=min) + else: # Magnification + if mag == 'bilinear' or mag == 'bicubic': + y = torch.nn.functional.interpolate(y, size, mode=mag, align_corners=True) + else: + y = torch.nn.functional.interpolate(y, size, mode=mag) + return y.permute(0, 2, 3, 1).contiguous() # NCHW -> NHWC + +def avg_pool_nhwc(x : torch.Tensor, size) -> torch.Tensor: + y = x.permute(0, 3, 1, 2) # NHWC -> NCHW + y = torch.nn.functional.avg_pool2d(y, size) + return y.permute(0, 2, 3, 1).contiguous() # NCHW -> NHWC + +#---------------------------------------------------------------------------- +# Behaves similar to tf.segment_sum +#---------------------------------------------------------------------------- + +def segment_sum(data: torch.Tensor, segment_ids: torch.Tensor) -> torch.Tensor: + num_segments = torch.unique_consecutive(segment_ids).shape[0] + + # Repeats ids until same dimension as data + if len(segment_ids.shape) == 1: + s = torch.prod(torch.tensor(data.shape[1:], dtype=torch.int64, device='cuda')).long() + segment_ids = segment_ids.repeat_interleave(s).view(segment_ids.shape[0], *data.shape[1:]) + + assert data.shape == segment_ids.shape, "data.shape and segment_ids.shape should be equal" + + shape = [num_segments] + list(data.shape[1:]) + result = torch.zeros(*shape, dtype=torch.float32, device='cuda') + result = result.scatter_add(0, segment_ids, data) + return result + +#---------------------------------------------------------------------------- +# Matrix helpers. 
+#---------------------------------------------------------------------------- + +def fovx_to_fovy(fovx, aspect): + return np.arctan(np.tan(fovx / 2) / aspect) * 2.0 + +def focal_length_to_fovy(focal_length, sensor_height): + return 2 * np.arctan(0.5 * sensor_height / focal_length) + +# Reworked so this matches gluPerspective / glm::perspective, using fovy +def perspective(fovy=0.7854, aspect=1.0, n=0.1, f=1000.0, device=None): + y = np.tan(fovy / 2) + return torch.tensor([[1/(y*aspect), 0, 0, 0], + [ 0, 1/-y, 0, 0], + [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)], + [ 0, 0, -1, 0]], dtype=torch.float32, device=device) + +# Reworked so this matches gluPerspective / glm::perspective, using fovy +def perspective_offcenter(fovy, fraction, rx, ry, aspect=1.0, n=0.1, f=1000.0, device=None): + y = np.tan(fovy / 2) + + # Full frustum + R, L = aspect*y, -aspect*y + T, B = y, -y + + # Create a randomized sub-frustum + width = (R-L)*fraction + height = (T-B)*fraction + xstart = (R-L)*rx + ystart = (T-B)*ry + + l = L + xstart + r = l + width + b = B + ystart + t = b + height + + # https://www.scratchapixel.com/lessons/3d-basic-rendering/perspective-and-orthographic-projection-matrix/opengl-perspective-projection-matrix + return torch.tensor([[2/(r-l), 0, (r+l)/(r-l), 0], + [ 0, -2/(t-b), (t+b)/(t-b), 0], + [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)], + [ 0, 0, -1, 0]], dtype=torch.float32, device=device) + +def translate(x, y, z, device=None): + return torch.tensor([[1, 0, 0, x], + [0, 1, 0, y], + [0, 0, 1, z], + [0, 0, 0, 1]], dtype=torch.float32, device=device) + +def rotate_x(a, device=None): + s, c = np.sin(a), np.cos(a) + return torch.tensor([[1, 0, 0, 0], + [0, c, s, 0], + [0, -s, c, 0], + [0, 0, 0, 1]], dtype=torch.float32, device=device) + +def rotate_y(a, device=None): + s, c = np.sin(a), np.cos(a) + return torch.tensor([[ c, 0, s, 0], + [ 0, 1, 0, 0], + [-s, 0, c, 0], + [ 0, 0, 0, 1]], dtype=torch.float32, device=device) + +def scale(s, device=None): + return torch.tensor([[ s, 0, 0, 0], + [ 0, s, 0, 0], + [ 0, 0, s, 0], + [ 0, 0, 0, 1]], dtype=torch.float32, device=device) + +def lookAt(eye, at, up): + a = eye - at + w = a / torch.linalg.norm(a) + u = torch.cross(up, w) + u = u / torch.linalg.norm(u) + v = torch.cross(w, u) + translate = torch.tensor([[1, 0, 0, -eye[0]], + [0, 1, 0, -eye[1]], + [0, 0, 1, -eye[2]], + [0, 0, 0, 1]], dtype=eye.dtype, device=eye.device) + rotate = torch.tensor([[u[0], u[1], u[2], 0], + [v[0], v[1], v[2], 0], + [w[0], w[1], w[2], 0], + [0, 0, 0, 1]], dtype=eye.dtype, device=eye.device) + return rotate @ translate + +@torch.no_grad() +def random_rotation_translation(t, device=None): + m = np.random.normal(size=[3, 3]) + m[1] = np.cross(m[0], m[2]) + m[2] = np.cross(m[0], m[1]) + m = m / np.linalg.norm(m, axis=1, keepdims=True) + m = np.pad(m, [[0, 1], [0, 1]], mode='constant') + m[3, 3] = 1.0 + m[:3, 3] = np.random.uniform(-t, t, size=[3]) + return torch.tensor(m, dtype=torch.float32, device=device) + +@torch.no_grad() +def random_rotation(device=None): + m = np.random.normal(size=[3, 3]) + m[1] = np.cross(m[0], m[2]) + m[2] = np.cross(m[0], m[1]) + m = m / np.linalg.norm(m, axis=1, keepdims=True) + m = np.pad(m, [[0, 1], [0, 1]], mode='constant') + m[3, 3] = 1.0 + m[:3, 3] = np.array([0,0,0]).astype(np.float32) + return torch.tensor(m, dtype=torch.float32, device=device) + +#---------------------------------------------------------------------------- +# Compute focal points of a set of lines using least squares. 
+# handy for poorly centered datasets +#---------------------------------------------------------------------------- + +def lines_focal(o, d): + d = safe_normalize(d) + I = torch.eye(3, dtype=o.dtype, device=o.device) + S = torch.sum(d[..., None] @ torch.transpose(d[..., None], 1, 2) - I[None, ...], dim=0) + C = torch.sum((d[..., None] @ torch.transpose(d[..., None], 1, 2) - I[None, ...]) @ o[..., None], dim=0).squeeze(1) + return torch.linalg.pinv(S) @ C + +#---------------------------------------------------------------------------- +# Cosine sample around a vector N +#---------------------------------------------------------------------------- +@torch.no_grad() +def cosine_sample(N, size=None): + # construct local frame + N = N/torch.linalg.norm(N) + + dx0 = torch.tensor([0, N[2], -N[1]], dtype=N.dtype, device=N.device) + dx1 = torch.tensor([-N[2], 0, N[0]], dtype=N.dtype, device=N.device) + + dx = torch.where(dot(dx0, dx0) > dot(dx1, dx1), dx0, dx1) + #dx = dx0 if np.dot(dx0,dx0) > np.dot(dx1,dx1) else dx1 + dx = dx / torch.linalg.norm(dx) + dy = torch.cross(N,dx) + dy = dy / torch.linalg.norm(dy) + + # cosine sampling in local frame + if size is None: + phi = 2.0 * np.pi * np.random.uniform() + s = np.random.uniform() + else: + phi = 2.0 * np.pi * torch.rand(*size, 1, dtype=N.dtype, device=N.device) + s = torch.rand(*size, 1, dtype=N.dtype, device=N.device) + costheta = np.sqrt(s) + sintheta = np.sqrt(1.0 - s) + + # cartesian vector in local space + x = np.cos(phi)*sintheta + y = np.sin(phi)*sintheta + z = costheta + + # local to world + return dx*x + dy*y + N*z + +#---------------------------------------------------------------------------- +# Bilinear downsample by 2x. +#---------------------------------------------------------------------------- + +def bilinear_downsample(x : torch.tensor) -> torch.Tensor: + w = torch.tensor([[1, 3, 3, 1], [3, 9, 9, 3], [3, 9, 9, 3], [1, 3, 3, 1]], dtype=torch.float32, device=x.device) / 64.0 + w = w.expand(x.shape[-1], 1, 4, 4) + x = torch.nn.functional.conv2d(x.permute(0, 3, 1, 2), w, padding=1, stride=2, groups=x.shape[-1]) + return x.permute(0, 2, 3, 1) + +#---------------------------------------------------------------------------- +# Bilinear downsample log(spp) steps +#---------------------------------------------------------------------------- + +def bilinear_downsample(x : torch.tensor, spp) -> torch.Tensor: + w = torch.tensor([[1, 3, 3, 1], [3, 9, 9, 3], [3, 9, 9, 3], [1, 3, 3, 1]], dtype=torch.float32, device=x.device) / 64.0 + g = x.shape[-1] + w = w.expand(g, 1, 4, 4) + x = x.permute(0, 3, 1, 2) # NHWC -> NCHW + steps = int(np.log2(spp)) + for _ in range(steps): + xp = torch.nn.functional.pad(x, (1,1,1,1), mode='replicate') + x = torch.nn.functional.conv2d(xp, w, padding=0, stride=2, groups=g) + return x.permute(0, 2, 3, 1).contiguous() # NCHW -> NHWC + +#---------------------------------------------------------------------------- +# Singleton initialize GLFW +#---------------------------------------------------------------------------- + +_glfw_initialized = False +def init_glfw(): + global _glfw_initialized + try: + import glfw + glfw.ERROR_REPORTING = 'raise' + glfw.default_window_hints() + glfw.window_hint(glfw.VISIBLE, glfw.FALSE) + test = glfw.create_window(8, 8, "Test", None, None) # Create a window and see if not initialized yet + except glfw.GLFWError as e: + if e.error_code == glfw.NOT_INITIALIZED: + glfw.init() + _glfw_initialized = True + +#---------------------------------------------------------------------------- +# 
Image display function using OpenGL. +#---------------------------------------------------------------------------- + +_glfw_window = None +def display_image(image, title=None): + # Import OpenGL + import OpenGL.GL as gl + import glfw + + # Zoom image if requested. + image = np.asarray(image[..., 0:3]) if image.shape[-1] == 4 else np.asarray(image) + height, width, channels = image.shape + + # Initialize window. + init_glfw() + if title is None: + title = 'Debug window' + global _glfw_window + if _glfw_window is None: + glfw.default_window_hints() + _glfw_window = glfw.create_window(width, height, title, None, None) + glfw.make_context_current(_glfw_window) + glfw.show_window(_glfw_window) + glfw.swap_interval(0) + else: + glfw.make_context_current(_glfw_window) + glfw.set_window_title(_glfw_window, title) + glfw.set_window_size(_glfw_window, width, height) + + # Update window. + glfw.poll_events() + gl.glClearColor(0, 0, 0, 1) + gl.glClear(gl.GL_COLOR_BUFFER_BIT) + gl.glWindowPos2f(0, 0) + gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1) + gl_format = {3: gl.GL_RGB, 2: gl.GL_RG, 1: gl.GL_LUMINANCE}[channels] + gl_dtype = {'uint8': gl.GL_UNSIGNED_BYTE, 'float32': gl.GL_FLOAT}[image.dtype.name] + gl.glDrawPixels(width, height, gl_format, gl_dtype, image[::-1]) + glfw.swap_buffers(_glfw_window) + if glfw.window_should_close(_glfw_window): + return False + return True + +#---------------------------------------------------------------------------- +# Image save/load helper. +#---------------------------------------------------------------------------- + +def save_image(fn, x : np.ndarray): + try: + if os.path.splitext(fn)[1] == ".png": + imageio.imwrite(fn, np.clip(np.rint(x * 255.0), 0, 255).astype(np.uint8), compress_level=3) # Low compression for faster saving + else: + imageio.imwrite(fn, np.clip(np.rint(x * 255.0), 0, 255).astype(np.uint8)) + except: + print("WARNING: FAILED to save image %s" % fn) + +def save_image_raw(fn, x : np.ndarray): + try: + imageio.imwrite(fn, x) + except: + print("WARNING: FAILED to save image %s" % fn) + + +def load_image_raw(fn) -> np.ndarray: + return imageio.imread(fn) + +def load_image(fn) -> np.ndarray: + img = load_image_raw(fn) + if img.dtype == np.float32: # HDR image + return img + else: # LDR image + return img.astype(np.float32) / 255 + +#---------------------------------------------------------------------------- + +def time_to_text(x): + if x > 3600: + return "%.2f h" % (x / 3600) + elif x > 60: + return "%.2f m" % (x / 60) + else: + return "%.2f s" % x + +#---------------------------------------------------------------------------- + +def checkerboard(res, checker_size) -> np.ndarray: + tiles_y = (res[0] + (checker_size*2) - 1) // (checker_size*2) + tiles_x = (res[1] + (checker_size*2) - 1) // (checker_size*2) + check = np.kron([[1, 0] * tiles_x, [0, 1] * tiles_x] * tiles_y, np.ones((checker_size, checker_size)))*0.33 + 0.33 + check = check[:res[0], :res[1]] + return np.stack((check, check, check), axis=-1) + + +def blur_image(image, kernel_size=3, sigma=None, mode='gaussian'): + if mode == 'gaussian': + return TF.gaussian_blur(image, kernel_size, sigma) + elif mode == 'average': + p = kernel_size // 2 + out = TNF.pad(image, (p, p, p, p), mode='replicate') + return TNF.avg_pool2d(out, kernel_size, stride=1, padding=0) + else: + raise Exception("Unknown blur mode") \ No newline at end of file diff --git a/video3d/renderer.py b/video3d/renderer.py new file mode 100755 index 
0000000000000000000000000000000000000000..6ade0f43799489cb0ec2edc088a5c6b49855f2ec --- /dev/null +++ b/video3d/renderer.py @@ -0,0 +1,244 @@ +import numpy as np +import matplotlib.pyplot as plt +import torch +import torch.nn as nn +# import pytorch3d +# import pytorch3d.loss +# import pytorch3d.renderer +# import pytorch3d.structures +# import pytorch3d.io +# import pytorch3d.transforms +from PIL import Image +from .utils import sphere +from einops import rearrange + + +def update_camera_pose(cameras, position, at): + cameras.R = pytorch3d.renderer.look_at_rotation(position, at).to(cameras.device) + cameras.T = -torch.bmm(cameras.R.transpose(1, 2), position[:, :, None])[:, :, 0] + + +def get_soft_rasterizer_settings(image_size, sigma=1e-6, gamma=1e-6, faces_per_pixel=30): + blend_params = pytorch3d.renderer.BlendParams(sigma=sigma, gamma=gamma) + settings = pytorch3d.renderer.RasterizationSettings( + image_size=image_size, + blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma, + faces_per_pixel=faces_per_pixel, + ) + return settings, blend_params + + +class Renderer(nn.Module): + def __init__(self, cfgs): + super().__init__() + self.cfgs = cfgs + self.device = cfgs.get('device', 'cpu') + self.image_size = cfgs.get('out_image_size', 64) + self.full_size_h = cfgs.get('full_size_h', 1080) + self.full_size_w = cfgs.get('full_size_w', 1920) + self.fov_w = cfgs.get('fov_w', 60) + # self.fov_h = cfgs.get('fov_h', 30) + self.fov_h = np.arctan(np.tan(self.fov_w /2 /180*np.pi) / self.full_size_w * self.full_size_h) *2 /np.pi*180 + self.crop_fov_approx = cfgs.get('crop_fov_approx', 25) + self.cam_pos_z_offset = cfgs.get('cam_pos_z_offset', 10.) + self.max_range = np.tan(min(self.fov_h, self.fov_w) /2 /180 * np.pi) * self.cam_pos_z_offset + cam_pos = torch.FloatTensor([[0, 0, self.cam_pos_z_offset]]).to(self.device) + cam_at = torch.FloatTensor([[0, 0, 0]]).to(self.device) + self.rot_rep = cfgs.get('rot_rep', 'euler_angle') + # self.cameras = pytorch3d.renderer.FoVPerspectiveCameras(fov=self.crop_fov_approx).to(self.device) + # update_camera_pose(self.cameras, position=cam_pos, at=cam_at) + # self.full_cameras = pytorch3d.renderer.FoVPerspectiveCameras(fov=self.fov_w).to(self.device) + # update_camera_pose(self.full_cameras, position=cam_pos, at=cam_at) + self.image_renderer = self._create_image_renderer() + self.ico_sphere_subdiv = cfgs.get('ico_sphere_subdiv', 2) + self.init_shape_scale_xy = float(cfgs.get('init_shape_scale_xy', 1.)) + self.init_shape_scale_z = float(cfgs.get('init_shape_scale_z', 1.)) + # init_verts, init_faces, init_aux = pytorch3d.io.load_obj(cfgs['init_shape_obj'], create_texture_atlas=True, device=self.device) + # self.init_verts = init_verts *self.init_shape_scale + # self.meshes = pytorch3d.structures.Meshes(verts=[self.init_verts], faces=[init_faces.verts_idx]).to(self.device) + # self.tex_faces_uv = init_faces.textures_idx.unsqueeze(0) + # self.tex_verts_uv = init_aux.verts_uvs.unsqueeze(0) + # self.texture_atlas = init_aux.texture_atlas.unsqueeze(0) + # self.num_verts_total = init_verts.size(0) + + # cmap = plt.cm.get_cmap('hsv', self.num_verts_total) + # verts_texture = cmap(np.random.permutation(self.num_verts_total))[:,:3] + # self.verts_texture = torch.FloatTensor(verts_texture) + # debug_uvtex = cfgs.get('debug_uvtex', None) + # if debug_uvtex is not None: + # face_tex_map = Image.open(debug_uvtex).convert('RGB').resize((512, 512)) + # self.face_tex_map = torch.FloatTensor(np.array(face_tex_map)).permute(2,0,1) / 255. 
+ # else: + # self.face_tex_map = None + + meshes, aux = sphere.get_symmetric_ico_sphere(subdiv=self.ico_sphere_subdiv, return_tex_uv=True, return_face_tex_map=True, device=self.device) + init_verts = meshes.verts_padded() + self.init_verts = init_verts * torch.FloatTensor([self.init_shape_scale_xy, self.init_shape_scale_xy, self.init_shape_scale_z]).view(1,1,3).to(init_verts.device) + # TODO: is this needed? + self.meshes = meshes.update_padded(init_verts * 0) + self.tex_faces_uv = aux['face_tex_ids'].unsqueeze(0) + self.tex_verts_uv = aux['verts_tex_uv'].unsqueeze(0) + self.face_tex_map = aux['face_tex_map'].permute(2,0,1) + self.tex_map_seam_mask = aux['seam_mask'].permute(2,0,1) + self.num_verts_total = init_verts.size(1) + self.num_verts_seam = aux['num_verts_seam'] + self.num_verts_one_side = aux['num_verts_one_side'] + + # hack to turn off texture symmetry + if cfgs.get('disable_sym_tex', False): + tex_uv_seam1 = self.tex_verts_uv[:,:aux['num_verts_seam']].clone() + tex_uv_seam1[:,:,0] = tex_uv_seam1[:,:,0] /2 + 0.5 + tex_uv_side1 = self.tex_verts_uv[:,aux['num_verts_seam']:aux['num_verts_seam']+aux['num_verts_one_side']].clone() + tex_uv_side1[:,:,0] = tex_uv_side1[:,:,0] /2 + 0.5 + tex_uv_seam2 = self.tex_verts_uv[:,:aux['num_verts_seam']].clone() + tex_uv_seam2[:,:,0] = tex_uv_seam2[:,:,0] /2 + tex_uv_side2 = self.tex_verts_uv[:,aux['num_verts_seam']+aux['num_verts_one_side']:].clone() + tex_uv_side2[:,:,0] = tex_uv_side2[:,:,0] /2 + self.tex_verts_uv = torch.cat([tex_uv_seam1, tex_uv_side1, tex_uv_side2, tex_uv_seam2], 1) + + num_faces = self.tex_faces_uv.shape[1] + face_tex_ids1 = self.tex_faces_uv[:, :num_faces//2].clone() + face_tex_ids2 = self.tex_faces_uv[:, num_faces//2:].clone() + face_tex_ids2[face_tex_ids2 < aux['num_verts_seam']] += aux['num_verts_seam'] + 2*aux['num_verts_one_side'] + self.tex_faces_uv = torch.cat([face_tex_ids1, face_tex_ids2], 1) + self.face_tex_map = torch.cat([self.face_tex_map, self.face_tex_map.flip(2)], 2) + self.tex_map_seam_mask = torch.cat([self.tex_map_seam_mask, self.tex_map_seam_mask.flip(2)], 2) + + def _create_silhouette_renderer(self): + settings, blend_params = get_soft_rasterizer_settings(self.image_size) + return pytorch3d.renderer.MeshRenderer( + rasterizer=pytorch3d.renderer.MeshRasterizer(cameras=self.cameras, raster_settings=settings), + shader=pytorch3d.renderer.SoftSilhouetteShader(cameras=self.cameras, blend_params=blend_params) + ) + + def _create_image_renderer(self): + settings, blend_params = get_soft_rasterizer_settings(self.image_size) + lights = pytorch3d.renderer.DirectionalLights(device=self.device, + ambient_color=((1., 1., 1.),), + diffuse_color=((0., 0., 0.),), + specular_color=((0., 0., 0.),), + direction=((0, 1, 0),)) + return pytorch3d.renderer.MeshRenderer( + rasterizer=pytorch3d.renderer.MeshRasterizer(cameras=self.cameras, raster_settings=settings), + shader=pytorch3d.renderer.SoftPhongShader(device=self.device, lights=lights, cameras=self.cameras, blend_params=blend_params) + ) + + def transform_verts(self, verts, pose): + b, f, _ = pose.shape + if self.rot_rep == 'euler_angle' or self.rot_rep == 'soft_calss': + rot_mat = pytorch3d.transforms.euler_angles_to_matrix(pose[...,:3].view(-1,3), convention='XYZ') + tsf = pytorch3d.transforms.Rotate(rot_mat, device=pose.device) + elif self.rot_rep == 'quaternion': + rot_mat = pytorch3d.transforms.quaternion_to_matrix(pose[...,:4].view(-1,4)) + tsf = pytorch3d.transforms.Rotate(rot_mat, device=pose.device) + elif self.rot_rep == 'lookat': + rot_mat = 
pose[...,:9].view(-1,3,3) + tsf = pytorch3d.transforms.Rotate(rot_mat, device=pose.device) + else: + raise NotImplementedError + tsf = tsf.compose(pytorch3d.transforms.Translate(pose[...,-3:].view(-1,3), device=pose.device)) + new_verts = tsf.transform_points(verts.view(b*f, *verts.shape[2:])) + return new_verts.view(b, f, *new_verts.shape[1:]) + + # def transform_mesh(self, mesh, pose): + # mesh_verts = mesh.verts_padded() + # new_mesh_verts = self.transform_verts(mesh_verts, pose) + # new_mesh = mesh.update_padded(new_mesh_verts) + # return new_mesh + + def symmetrize_shape(self, shape): + verts_seam = shape[:,:,:self.num_verts_seam] * torch.FloatTensor([0,1,1]).to(shape.device) + verts_one_side = shape[:,:,self.num_verts_seam:self.num_verts_seam+self.num_verts_one_side] * torch.FloatTensor([1,1,1]).to(shape.device) + verts_other_side = verts_one_side * torch.FloatTensor([-1,1,1]).to(shape.device) + shape = torch.cat([verts_seam, verts_one_side, verts_other_side], 2) + return shape + + def get_deformed_mesh(self, shape, pose=None, return_shape=False): + b, f, _, _ = shape.shape + if pose is not None: + shape = self.transform_verts(shape, pose) + mesh = self.meshes.extend(b*f) + mesh = mesh.update_padded(rearrange(shape, 'b f ... -> (b f) ...')) + if return_shape: + return shape, mesh + else: + return mesh + + def get_textures(self, tex_im): + b, f, c, h, w = tex_im.shape + + ## top half texture map in ico_sphere.obj is unused, pad with zeros + # if 'sym' not in self.cfgs.get('init_shape_obj', ''): + # tex_im = torch.cat([torch.zeros_like(tex_im), tex_im], 3) + # tex_im = nn.functional.interpolate(tex_im, (h, w), mode='bilinear', align_corners=False) + textures = pytorch3d.renderer.TexturesUV(maps=tex_im.view(b*f, *tex_im.shape[2:]).permute(0, 2, 3, 1), # texture maps are BxHxWx3 + faces_uvs=self.tex_faces_uv.repeat(b*f, 1, 1), + verts_uvs=self.tex_verts_uv.repeat(b*f, 1, 1)) + return textures + + def render_flow(self, meshes, shape, pose, deformed_shape=None): + # verts = meshes.verts_padded() # (B*F)xVx3 + b, f, _, _ = shape.shape + if f < 2: + return None + + if deformed_shape is None: + deformed_shape, meshes = self.get_deformed_mesh(shape.detach(), pose=pose, return_shape=True) + im_size = torch.FloatTensor([self.image_size, self.image_size]).to(shape.device) # (w,h) + verts_2d = self.cameras.transform_points_screen(deformed_shape.view(b*f, *deformed_shape.shape[2:]), im_size.view(1,2).repeat(b*f,1), eps=1e-7) + verts_2d = verts_2d.view(b, f, *verts_2d.shape[1:]) + verts_flow = verts_2d[:, 1:, :, :2] - verts_2d[:, :-1, :, :2] # Bx(F-1)xVx(x,y) + verts_flow = verts_flow / im_size.view(1, 1, 1, 2) * 0.5 + 0.5 # 0~1 + flow_tex = torch.nn.functional.pad(verts_flow, pad=[0, 1, 0, 0, 0, 1]) # BxFxVx3 + + # meshes = meshes.detach() # detach mesh when rendering flow (only texture has gradients) + # meshes = self.get_deformed_mesh(shape.detach()) + meshes.textures = pytorch3d.renderer.TexturesVertex(verts_features=flow_tex.view(b*f, -1, 3)) + flow = self.image_renderer(meshes_world=meshes, cameras=self.cameras) + # settings, blend_params = get_soft_rasterizer_settings(image_size=self.image_size, sigma=1e-6, gamma=1e-6, faces_per_pixel=5) + # flow = self.image_renderer(meshes_world=meshes, cameras=self.cameras, raster_settings=settings, blend_params=blend_params) + flow = flow.view(b, f, *flow.shape[1:])[:, :-1] # Bx(F-1)xHxWx3 + flow_mask = (flow[:, :, :, :, 3:] > 0.01).float() + return (flow[:, :, :, :, :2] - 0.5) * 2 * flow_mask # Bx(F-1)xHxWx2 + + def forward(self, pose, texture, shape, 
crop_bbox=None, render_flow=True): + b, f, _ = pose.shape + + ## compensate crop with intrinsics, assuming square crops + # x0, y0, w, h = crop_bbox.unbind(2) + # fx = 1 / np.tan(self.fov_w / 2 /180*np.pi) + # fy = fx + # sx = w / self.full_size_w + # sy = sx + # cx = ((x0+w/2) - (self.full_size_w/2)) / (self.full_size_w/2) # [0-w] -> [-1,1] + # cy = ((y0+h/2) - (self.full_size_h/2)) / (self.full_size_w/2) + # znear = 1 + # zfar = 100 + # v1 = zfar / (zfar - znear) + # v2 = -(zfar * znear) / (zfar - znear) + # + # # K = [[[ fx/sx, 0.0000, cx/sx, 0.0000], + # # [ 0.0000, fy/sy, cy/sy, 0.0000], + # # [ 0.0000, 0.0000, v1, v2], + # # [ 0.0000, 0.0000, 1.0000, 0.0000]]] + # zeros = torch.zeros_like(sx) + # K_row1 = torch.stack([fx/sx, zeros, cx/sx, zeros], 2) + # K_row2 = torch.stack([zeros, fy/sy, cy/sy, zeros], 2) + # K_row3 = torch.stack([zeros, zeros, zeros+v1, zeros+v2], 2) + # K_row4 = torch.stack([zeros, zeros, zeros+1, zeros], 2) + # K = torch.stack([K_row1, K_row2, K_row3, K_row4], 2) # BxFx4x4 + # self.crop_cameras = pytorch3d.renderer.FoVPerspectiveCameras(K=K.view(-1, 4, 4), R=self.cameras.R, T=self.cameras.T, device=self.device) + # # reset znear, zfar to scalar to bypass broadcast bug in pytorch3d blending + # self.crop_cameras.znear = znear + # self.crop_cameras.zfar = zfar + + deformed_shape, mesh = self.get_deformed_mesh(shape, pose=pose, return_shape=True) + if render_flow: + flow = self.render_flow(mesh, shape, pose, deformed_shape=deformed_shape) # Bx(F-1)xHxWx2 + # flow = self.render_flow(mesh, shape, pose, deformed_shape=None) # Bx(F-1)xHxWx2 + else: + flow = None + mesh.textures = self.get_textures(texture) + image = self.image_renderer(meshes_world=mesh, cameras=self.cameras) + image = image.view(b, f, *image.shape[1:]) + return image, flow, mesh diff --git a/video3d/segmentation.py b/video3d/segmentation.py new file mode 100755 index 0000000000000000000000000000000000000000..e8ffb011e2494016df086eac852671b89d6662d2 --- /dev/null +++ b/video3d/segmentation.py @@ -0,0 +1,132 @@ +import configargparse +import torch +import torch.nn as nn +import torch.utils.data +import torchvision.utils as tvutils +import torchvision.transforms +from video3d.utils.segmentation_transforms import * +from video3d.utils.misc import setup_runtime +from video3d import networks +from video3d.trainer import Trainer +from video3d.dataloaders import SegmentationDataset + + +class Segmentation: + def __init__(self, cfgs, _): + self.cfgs = cfgs + self.device = cfgs.get('device', 'cpu') + self.total_loss = None + self.net = networks.EDDeconv(cin=3, cout=1, zdim=128, nf=64, activation=None) + self.optimizer = torch.optim.Adam( + filter(lambda p: p.requires_grad, self.net.parameters()), + lr=cfgs.get('lr', 1e-4), + betas=(0.9, 0.999), + weight_decay=5e-4) + + def load_model_state(self, cp): + self.net.load_state_dict(cp["net"]) + + def load_optimizer_state(self, cp): + self.net.load_state_dict(cp["optimizer"]) + + @staticmethod + def get_data_loaders(cfgs): + batch_size = cfgs.get('batch_size', 64) + num_workers = cfgs.get('num_workers', 4) + data_dir = cfgs.get('data_dir', './data') + img_size = cfgs.get('image_size', 64) + min_size = int(img_size * cfgs.get('aug_min_resize', 0.5)) + max_size = int(img_size * cfgs.get('aug_max_resize', 2.0)) + transform = Compose([RandomResize(min_size, max_size), + RandomHorizontalFlip(cfgs.get("aug_horizontal_flip", 0.4)), + RandomCrop(img_size), + ImageOnly(torchvision.transforms.ColorJitter(**cfgs.get("aug_color_jitter", {}))), + 
ImageOnly(torchvision.transforms.RandomGrayscale(cfgs.get("aug_grayscale", 0.2))), + ToTensor()]) + train_loader = torch.utils.data.DataLoader( + SegmentationDataset(data_dir, is_validation=False, transform=transform, sequence_range=(0, 0.5)), + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + pin_memory=True + ) + transform = Compose([ToTensor()]) + val_loader = torch.utils.data.DataLoader( + SegmentationDataset(data_dir, is_validation=True, transform=transform, sequence_range=(0.5, 1.0)), + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + pin_memory=True + ) + return train_loader, val_loader, None + + def get_state_dict(self): + return { + "net": self.net.state_dict(), + "optimizer": self.optimizer.state_dict() + } + + def to(self, device): + self.device = device + self.net.to(device) + + def set_train(self): + self.net.train() + + def set_eval(self): + self.net.eval() + + def backward(self): + self.optimizer.zero_grad() + self.total_loss.backward() + self.optimizer.step() + + def forward(self, batch, visualize=False): + image, target = batch + image = image.to(self.device)*2 - 1 + target = target[:, 0, :, :].to(self.device).unsqueeze(1) + pred = self.net(image) + + self.total_loss = nn.functional.binary_cross_entropy_with_logits(pred, target) + + metrics = {'loss': self.total_loss} + + visuals = {} + if visualize: + visuals['rgb'] = self.image_visual(image, normalize=True, range=(-1, 1)) + visuals['target'] = self.image_visual(target, normalize=True, range=(0, 1)) + visuals['pred'] = self.image_visual(nn.functional.sigmoid(pred), normalize=True, range=(0, 1)) + + return metrics, visuals + + return metrics + + def visualize(self, logger, total_iter, max_bs=25): + pass + + def save_results(self, save_dir): + pass + + def save_scores(self, path): + pass + + @staticmethod + def image_visual(tensor, **kwargs): + if tensor.shape[1] == 1: + tensor = tensor.repeat(1, 3, 1, 1) + n = int(tensor.shape[0]**0.5 + 0.5) + tensor = tvutils.make_grid(tensor.detach(), nrow=n, **kwargs).permute(1, 2, 0) + return torch.clamp(tensor[:, :, :3] * 255, 0, 255).byte().cpu() + + +if __name__ == "__main__": + parser = configargparse.ArgumentParser(description='Training configurations.') + parser.add_argument('--config', default="config/train_segmentation.yml", type=str, is_config_file=True, + help='Specify a config file path') + parser.add_argument('--gpu', default=1, type=int, help='Specify a GPU device') + parser.add_argument('--seed', default=0, type=int, help='Specify a random seed') + args, _ = parser.parse_known_args() + + cfgs = setup_runtime(args) + trainer = Trainer(cfgs, Segmentation) + trainer.train() diff --git a/video3d/trainer.py b/video3d/trainer.py new file mode 100755 index 0000000000000000000000000000000000000000..94311190d6475c95778376d155a6bd011f9096e6 --- /dev/null +++ b/video3d/trainer.py @@ -0,0 +1,337 @@ +import os +import os.path as osp +import math +import glob +from datetime import datetime +import imageio +import torch +import video3d.utils.meters as meters +import video3d.utils.misc as misc +import wandb + +def sample_frames(batch, num_sample_frames, iteration, stride=1): + ## window slicing sampling + images, masks, flows, bboxs, bg_image, seq_idx, frame_idx = batch + num_seqs, total_num_frames = images.shape[:2] + # start_frame_idx = iteration % (total_num_frames - num_sample_frames +1) + + ## forward and backward + num_windows = total_num_frames - num_sample_frames +1 + start_frame_idx = (iteration * stride) % (2*num_windows) + ## x' = 
(2n-1)/2 - |(2n-1)/2 - x| : 0,1,2,3,4,5 -> 0,1,2,2,1,0 + mid_val = (2*num_windows -1) /2 + start_frame_idx = int(mid_val - abs(mid_val -start_frame_idx)) + + new_batch = images[:, start_frame_idx:start_frame_idx+num_sample_frames], \ + masks[:, start_frame_idx:start_frame_idx+num_sample_frames], \ + flows[:, start_frame_idx:start_frame_idx+num_sample_frames-1], \ + bboxs[:, start_frame_idx:start_frame_idx+num_sample_frames], \ + bg_image, \ + seq_idx, \ + frame_idx[:, start_frame_idx:start_frame_idx+num_sample_frames] + return new_batch + + +def indefinite_generator(loader): + while True: + for x in loader: + yield x + + +class Trainer: + def __init__(self, cfgs, model): + self.cfgs = cfgs + self.device = cfgs.get('device', 'cpu') + self.num_epochs = cfgs.get('num_epochs', 1) + + # The logic is, if the num_iterations is set in the cfg + # for any 'epoch' in cfg, I rescale it to (epoch / 120) * epoch_now, as in horse exp + # for any 'iter' in cfg, I just keep them the same + self.num_iterations = cfgs.get('num_iterations', 0) + if self.num_iterations != 0: + self.use_total_iterations = True + else: + self.use_total_iterations = False + + self.num_sample_frames = cfgs.get('num_sample_frames', 100) + self.sample_frame_stride = cfgs.get('sample_frame_stride', 1) + self.checkpoint_dir = cfgs.get('checkpoint_dir', 'results') + self.save_checkpoint_freq = cfgs.get('save_checkpoint_freq', 1) + self.keep_num_checkpoint = cfgs.get('keep_num_checkpoint', 2) # -1 for keeping all checkpoints + self.resume = cfgs.get('resume', True) + self.use_logger = cfgs.get('use_logger', True) + self.log_freq_images = cfgs.get('log_freq_images', 1000) + self.log_train_images = cfgs.get('log_train_images', False) + self.log_freq_losses = cfgs.get('log_freq_losses', 100) + self.visualize_validation = cfgs.get('visualize_validation', False) + self.fix_viz_batch = cfgs.get('fix_viz_batch', False) + self.archive_code = cfgs.get('archive_code', True) + self.checkpoint_name = cfgs.get('checkpoint_name', None) + self.test_result_dir = cfgs.get('test_result_dir', None) + self.validate = cfgs.get('validate', False) + self.current_epoch = 0 + self.logger = None + self.viz_input = None + self.dataset = cfgs.get('dataset', 'video') + self.train_with_cub = cfgs.get('train_with_cub', False) + self.train_with_kaggle = cfgs.get('train_with_kaggle', False) + self.cub_start_epoch = cfgs.get('cub_start_epoch', 0) + + self.metrics_trace = meters.MetricsTrace() + self.make_metrics = lambda m=None: meters.StandardMetrics(m) + + self.batch_size = cfgs.get('batch_size', 64) + self.in_image_size = cfgs.get('in_image_size', 256) + self.out_image_size = cfgs.get('out_image_size', 256) + self.num_workers = cfgs.get('num_workers', 4) + self.run_train = cfgs.get('run_train', False) + self.train_data_dir = cfgs.get('train_data_dir', None) + self.val_data_dir = cfgs.get('val_data_dir', None) + self.run_test = cfgs.get('run_test', False) + self.test_data_dir = cfgs.get('test_data_dir', None) + + self.train_loader, self.val_loader, self.test_loader = model.get_data_loaders(cfgs, self.dataset, in_image_size=self.in_image_size, out_image_size=self.out_image_size, batch_size=self.batch_size, num_workers=self.num_workers, run_train=self.run_train, run_test=self.run_test, train_data_dir=self.train_data_dir, val_data_dir=self.val_data_dir, test_data_dir=self.test_data_dir) + if self.train_with_cub: + self.batch_size_cub = cfgs.get('batch_size_cub', 64) + self.data_dir_cub = cfgs.get('data_dir_cub', None) + self.train_loader_cub, self.val_loader_cub, 
self.test_loader_cub = model.get_data_loaders(cfgs, 'cub', in_image_size=self.in_image_size, batch_size=self.batch_size_cub, num_workers=self.num_workers, run_train=self.run_train, run_test=self.run_test, train_data_dir=self.data_dir_cub, val_data_dir=self.data_dir_cub, test_data_dir=self.data_dir_cub) + if self.train_with_kaggle: + self.batch_size_kaggle = cfgs.get('batch_size_kaggle', 64) + self.data_dir_kaggle = cfgs.get('data_dir_kaggle', None) + self.train_loader_kaggle, self.val_loader_kaggle, self.test_loader_kaggle = model.get_data_loaders(cfgs, 'kaggle', in_image_size=self.in_image_size, batch_size=self.batch_size_kaggle, num_workers=self.num_workers, run_train=self.run_train, run_test=self.run_test, train_data_dir=self.data_dir_kaggle, val_data_dir=self.data_dir_kaggle, test_data_dir=self.data_dir_kaggle) + + if self.use_total_iterations: + # reset the epoch related cfgs + + train_data_dir = cfgs.get("train_data_dir", None) + if isinstance(train_data_dir, str): + num_of_classes = 1 + elif isinstance(train_data_dir, dict): + num_of_classes = len(train_data_dir) + + dataloader_length = 0 + for class_idx in range(num_of_classes): + dataloader_length += len(self.train_loader[class_idx]) + + total_epoch = int(self.num_iterations / dataloader_length) + 1 + + print(f'run for {total_epoch} epochs') + + for k, v in cfgs.items(): + if 'epoch' in k: + if isinstance(v, list): + new_v = [int(total_epoch * x / 120) for x in v] + cfgs[k] = new_v + elif isinstance(v, int): + new_v = int(total_epoch * v / 120) + 1 + cfgs[k] = new_v + else: + continue + + self.num_epochs = total_epoch + self.cub_start_epoch = cfgs.get('cub_start_epoch', 0) + self.cfgs = cfgs + + self.model = model(cfgs) + self.model.trainer = self + self.save_result_freq = cfgs.get('save_result_freq', None) + self.train_result_dir = osp.join(self.checkpoint_dir, 'results') + + def load_checkpoint(self, optim=True): + """Search the specified/latest checkpoint in checkpoint_dir and load the model and optimizer.""" + if self.checkpoint_name is not None: + checkpoint_path = osp.join(self.checkpoint_dir, self.checkpoint_name) + else: + checkpoints = sorted(glob.glob(osp.join(self.checkpoint_dir, '*.pth'))) + if len(checkpoints) == 0: + return 0, 0 + checkpoint_path = checkpoints[-1] + self.checkpoint_name = osp.basename(checkpoint_path) + print(f"Loading checkpoint from {checkpoint_path}") + cp = torch.load(checkpoint_path, map_location=self.device) + self.model.load_model_state(cp) + if optim: + self.model.load_optimizer_state(cp) + self.metrics_trace = cp['metrics_trace'] + epoch = cp['epoch'] + total_iter = cp['total_iter'] + return epoch, total_iter + + def save_checkpoint(self, epoch, total_iter=0, optim=True): + """Save model, optimizer, and metrics state to a checkpoint in checkpoint_dir for the specified epoch.""" + misc.xmkdir(self.checkpoint_dir) + checkpoint_path = osp.join(self.checkpoint_dir, f'checkpoint{epoch:03}.pth') + state_dict = self.model.get_model_state() + if optim: + optimizer_state = self.model.get_optimizer_state() + state_dict = {**state_dict, **optimizer_state} + state_dict['metrics_trace'] = self.metrics_trace + state_dict['epoch'] = epoch + state_dict['total_iter'] = total_iter + print(f"Saving checkpoint to {checkpoint_path}") + torch.save(state_dict, checkpoint_path) + if self.keep_num_checkpoint > 0: + misc.clean_checkpoint(self.checkpoint_dir, keep_num=self.keep_num_checkpoint) + + def save_clean_checkpoint(self, path): + """Save model state only to specified path.""" + 
torch.save(self.model.get_model_state(), path) + + def reset_viz_data_iterator(self): + self.viz_data_iterator = iter(self.val_loader) if self.visualize_validation else iter(self.train_loader) + + def reset_cub_train_data_iterator(self): + self.cub_train_data_iterator = iter(self.train_loader_cub) + + def reset_cub_viz_data_iterator(self): + self.cub_viz_data_iterator = iter(self.val_loader_cub) if self.visualize_validation else iter(self.train_loader_cub) + + def test(self): + """Perform testing.""" + self.model.to(self.device) + self.model.set_eval() + epoch, self.total_iter = self.load_checkpoint(optim=False) + + if self.test_result_dir is None: + self.test_result_dir = osp.join(self.checkpoint_dir, f'test_results_{self.checkpoint_name}'.replace('.pth', '')) + print(f"Saving testing results to {self.test_result_dir}") + + with torch.no_grad(): + for iteration, batch in enumerate(self.test_loader): + m = self.model.forward(batch, epoch=epoch, iter=iteration, total_iter=self.total_iter, save_results=True, save_dir=self.test_result_dir, which_data=self.dataset, is_training=False) + print(f"T{epoch:04}/{iteration:05}") + + score_path = osp.join(self.test_result_dir, 'all_metrics.txt') + # self.model.save_scores(score_path) + + def train(self): + """Perform training.""" + # archive code and configs + if self.archive_code: + misc.archive_code(osp.join(self.checkpoint_dir, 'archived_code.zip'), filetypes=['.py']) + misc.dump_yaml(osp.join(self.checkpoint_dir, 'configs.yml'), self.cfgs) + + # initialize + start_epoch = 0 + self.total_iter = 0 + self.metrics_trace.reset() + self.model.to(self.device) + self.model.reset_optimizers() + + # resume from checkpoint + if self.resume: + start_epoch, self.total_iter = self.load_checkpoint(optim=True) + + # train with cub + if self.train_with_cub: + self.cub_train_data_iterator = indefinite_generator(self.train_loader_cub) + + # initialize tensorboard logger + if self.use_logger: + wandb.tensorboard.patch(root_logdir=osp.join(self.checkpoint_dir, 'logs', datetime.now().strftime("%Y%m%d-%H%M%S"))) + wandb.init(name=self.checkpoint_dir.split("/")[-1], project="APT36K") + #wandb.tensorboard.patch(save=False, tensorboard_x=True) + from torch.utils.tensorboard import SummaryWriter + self.logger = SummaryWriter(osp.join(self.checkpoint_dir, 'logs', datetime.now().strftime("%Y%m%d-%H%M%S")), flush_secs=10) + self.viz_data_iterator = indefinite_generator(self.val_loader) if self.visualize_validation else indefinite_generator(self.train_loader) + if self.fix_viz_batch: + self.viz_batch = next(self.viz_data_iterator) + + # train with cub + if self.train_with_cub: + self.cub_viz_data_iterator = indefinite_generator(self.val_loader_cub) if self.visualize_validation else indefinite_generator(self.train_loader_cub) + if self.fix_viz_batch: + self.viz_batch_cub = next(self.cub_viz_data_iterator) + + + # run epochs + epoch = 0 + for epoch in range(start_epoch, self.num_epochs): + metrics = self.run_epoch(epoch) + self.metrics_trace.append("train", metrics) + if (epoch+1) % self.save_checkpoint_freq == 0: + self.save_checkpoint(epoch+1, total_iter=self.total_iter, optim=True) + if self.cfgs.get('pyplot_metrics', True): + self.metrics_trace.plot(pdf_path=osp.join(self.checkpoint_dir, 'metrics.pdf')) + self.metrics_trace.save(osp.join(self.checkpoint_dir, 'metrics.json')) + wandb.finish() + print(f"Training completed for all {epoch+1} epochs.") + + def run_epoch(self, epoch): + metrics = self.make_metrics() + + self.model.set_train() + for iteration, batch in 
enumerate(self.train_loader): + self.total_iter += 1 + + num_seqs, num_frames = batch[0].shape[:2] + total_im_num = num_seqs*num_frames + m = self.model.forward(batch, epoch=epoch, iter=iteration, total_iter=self.total_iter, which_data=self.dataset, is_training=True) + + if self.train_with_cub and epoch >= self.cub_start_epoch: + batch_cub = next(self.cub_train_data_iterator) + num_seqs, num_frames = batch_cub[0].shape[:2] + total_im_num += num_seqs*num_frames + m_cub = self.model.forward(batch_cub, epoch=epoch, iter=iteration, total_iter=self.total_iter, which_data='cub', is_training=True) + m.update({'cub_'+k: v for k,v in m_cub.items()}) + m['total_loss'] = self.model.total_loss + + self.model.backward() + + metrics.update(m, total_im_num) + print(f"T{epoch:04}/{iteration:05}/{metrics}") + + ## reset optimizers + if self.cfgs.get('opt_reset_every_iter', 0) > 0 and self.total_iter < self.cfgs.get('opt_reset_end_iter', 0): + if self.total_iter % self.cfgs.get('opt_reset_every_iter', 0) == 0: + self.model.reset_optimizers() + + if self.use_logger: + if self.total_iter % self.log_freq_losses == 0: + for name, loss in m.items(): + label = f'cub_loss_train/{name[4:]}' if 'cub' in name else f'loss_train/{name}' + self.logger.add_scalar(label, loss, self.total_iter) + + if self.save_result_freq is not None and self.total_iter % self.save_result_freq == 0: + with torch.no_grad(): + m = self.model.forward(batch, epoch=epoch, iter=iteration, total_iter=self.total_iter, save_results=True, save_dir=self.train_result_dir, which_data=self.dataset, is_training=False) + torch.cuda.empty_cache() + + if self.total_iter % self.log_freq_images == 0: + with torch.no_grad(): + if self.log_train_images: + m = self.model.forward(batch, epoch=epoch, iter=iteration, viz_logger=self.logger, total_iter=self.total_iter, which_data=self.dataset, logger_prefix='train_', is_training=True) + + if self.fix_viz_batch: + batch = self.viz_batch + elif self.visualize_validation: + batch = next(self.viz_data_iterator) + # try: + # batch = next(self.viz_data_iterator) + # except: # iterator exhausted + # self.reset_viz_data_iterator() + # batch = next(self.viz_data_iterator) + m = self.model.forward(batch, epoch=epoch, iter=iteration, viz_logger=self.logger, total_iter=self.total_iter, which_data=self.dataset, logger_prefix='val_', is_training=False) + for name, loss in m.items(): + self.logger.add_scalar(f'loss_val/{name}', loss, self.total_iter) + + if self.train_with_cub and epoch >= self.cub_start_epoch: + if self.log_train_images: + m = self.model.forward(batch_cub, epoch=epoch, iter=iteration, viz_logger=self.logger, total_iter=self.total_iter, which_data='cub', logger_prefix='cub_train_', is_training=True) + + if self.fix_viz_batch: + batch_cub = self.viz_batch_cub + elif self.visualize_validation: + batch_cub = next(self.cub_viz_data_iterator) + # try: + # batch = next(self.viz_data_iterator) + # except: # iterator exhausted + # self.reset_viz_data_iterator() + # batch = next(self.viz_data_iterator) + m = self.model.forward(batch_cub, epoch=epoch, iter=iteration, viz_logger=self.logger, total_iter=self.total_iter, which_data='cub', logger_prefix='cub_val_', is_training=False) + for name, loss in m.items(): + self.logger.add_scalar(f'cub_loss_val/{name}', loss, self.total_iter) + torch.cuda.empty_cache() + + self.model.scheduler_step() + return metrics diff --git a/video3d/trainer_ddp.py b/video3d/trainer_ddp.py new file mode 100755 index 0000000000000000000000000000000000000000..e5b3a99436e4cf4b52df4ff46f9682e87f19345c 
--- /dev/null +++ b/video3d/trainer_ddp.py @@ -0,0 +1,563 @@ +import os +import os.path as osp +import glob +from datetime import datetime +import random +import torch +import video3d.utils.meters as meters +import video3d.utils.misc as misc + +from video3d.dataloaders_ddp import get_sequence_loader_quadrupeds + +def sample_frames(batch, num_sample_frames, iteration, stride=1): + ## window slicing sampling + images, masks, flows, bboxs, bg_image, seq_idx, frame_idx = batch + num_seqs, total_num_frames = images.shape[:2] + # start_frame_idx = iteration % (total_num_frames - num_sample_frames +1) + + ## forward and backward + num_windows = total_num_frames - num_sample_frames +1 + start_frame_idx = (iteration * stride) % (2*num_windows) + ## x' = (2n-1)/2 - |(2n-1)/2 - x| : 0,1,2,3,4,5 -> 0,1,2,2,1,0 + mid_val = (2*num_windows -1) /2 + start_frame_idx = int(mid_val - abs(mid_val -start_frame_idx)) + + new_batch = images[:, start_frame_idx:start_frame_idx+num_sample_frames], \ + masks[:, start_frame_idx:start_frame_idx+num_sample_frames], \ + flows[:, start_frame_idx:start_frame_idx+num_sample_frames-1], \ + bboxs[:, start_frame_idx:start_frame_idx+num_sample_frames], \ + bg_image, \ + seq_idx, \ + frame_idx[:, start_frame_idx:start_frame_idx+num_sample_frames] + return new_batch + + +def indefinite_generator(loader): + while True: + for x in loader: + yield x + +def indefinite_generator_from_list(loaders): + while True: + random_idx = random.randint(0, len(loaders)-1) + for x in loaders[random_idx]: + yield x + break + +def definite_generator(loader): + for x in loader: + yield x + while True: + yield None + + +class TrainerDDP: + def __init__(self, cfgs, model): + self.cfgs = cfgs + self.is_dry_run = cfgs.get('is_dry_run', False) + + self.rank = cfgs.get('rank', 0) + self.world_size = cfgs.get('world_size', 1) + self.use_ddp = cfgs.get('use_ddp', True) + + self.device = cfgs.get('device', 'cpu') + self.num_epochs = cfgs.get('num_epochs', 1) + + # The logic is, if the num_iterations is set in the cfg + # for any 'epoch' in cfg, I rescale it to (epoch / 120) * epoch_now, as in horse exp + # for any 'iter' in cfg, I just keep them the same + self.num_iterations = cfgs.get('num_iterations', 0) + if self.num_iterations != 0: + self.use_total_iterations = True + else: + self.use_total_iterations = False + + self.num_sample_frames = cfgs.get('num_sample_frames', 100) + self.sample_frame_stride = cfgs.get('sample_frame_stride', 1) + self.checkpoint_dir = cfgs.get('checkpoint_dir', 'results') + self.save_checkpoint_freq = cfgs.get('save_checkpoint_freq', 1) + self.keep_num_checkpoint = cfgs.get('keep_num_checkpoint', 2) # -1 for keeping all checkpoints + self.resume = cfgs.get('resume', True) + self.reset_epoch = cfgs.get('reset_epoch', False) + self.finetune_ckpt = cfgs.get('finetune_ckpt', None) + # print('!!!!!!!!!!!!!!!!!!!!!!!!!!') + print(f'reset epoch: {self.reset_epoch}') + # print('!!!!!!!!!!!!!!!!!!!!!!!!!!') + self.use_logger = cfgs.get('use_logger', True) + self.log_freq_images = cfgs.get('log_freq_images', 1000) + self.log_train_images = cfgs.get('log_train_images', False) + self.log_freq_losses = cfgs.get('log_freq_losses', 100) + self.visualize_validation = cfgs.get('visualize_validation', False) + self.fix_viz_batch = cfgs.get('fix_viz_batch', False) + self.archive_code = cfgs.get('archive_code', True) + self.checkpoint_name = cfgs.get('checkpoint_name', None) + self.test_result_dir = cfgs.get('test_result_dir', None) + self.validate = cfgs.get('validate', False) + self.current_epoch 
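+ # Example for sample_frames() above: with num_windows = 4 and stride 1, start_frame_idx over
+ # successive iterations is 0,1,2,3,3,2,1,0,0,1,... so the sampling window sweeps forward to the end
+ # of the sequence and then back again instead of jumping from the last window to the first.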
= 0 + self.logger = None + self.viz_input = None + self.dataset = cfgs.get('dataset', 'video') + self.train_with_cub = cfgs.get('train_with_cub', False) + self.train_with_kaggle = cfgs.get('train_with_kaggle', False) + self.cub_start_epoch = cfgs.get('cub_start_epoch', 0) + + self.metrics_trace = meters.MetricsTrace() + self.make_metrics = lambda m=None: meters.StandardMetrics(m) + + self.batch_size = cfgs.get('batch_size', 64) + self.in_image_size = cfgs.get('in_image_size', 256) + self.out_image_size = cfgs.get('out_image_size', 256) + self.num_workers = cfgs.get('num_workers', 4) + self.run_train = cfgs.get('run_train', False) + self.train_data_dir = cfgs.get('train_data_dir', None) + self.val_data_dir = cfgs.get('val_data_dir', None) + self.run_test = cfgs.get('run_test', False) + self.test_data_dir = cfgs.get('test_data_dir', None) + self.flow_bool = cfgs.get('flow_bool', 0) + + if len(self.train_data_dir) <= 10 and len(self.val_data_dir) <= 10: + self.train_loader, self.val_loader, self.test_loader = model.get_data_loaders_ddp(cfgs, self.dataset, self.rank, self.world_size, in_image_size=self.in_image_size, out_image_size=self.out_image_size, batch_size=self.batch_size, num_workers=self.num_workers, run_train=self.run_train, run_test=self.run_test, train_data_dir=self.train_data_dir, val_data_dir=self.val_data_dir, test_data_dir=self.test_data_dir, flow_bool=self.flow_bool) + else: + # for 128 categories specific training + self.train_loader, self.val_loader, self.test_loader = self.get_efficient_data_loaders_ddp( + cfgs, + self.batch_size, + self.num_workers, + self.in_image_size, + self.out_image_size + ) + + print(self.train_loader, self.val_loader, self.test_loader) + if self.train_with_cub: + self.batch_size_cub = cfgs.get('batch_size_cub', 64) + self.data_dir_cub = cfgs.get('data_dir_cub', None) + self.train_loader_cub, self.val_loader_cub, self.test_loader_cub = model.get_data_loaders_ddp(cfgs, 'cub', self.rank, self.world_size, in_image_size=self.in_image_size, batch_size=self.batch_size_cub, num_workers=self.num_workers, run_train=self.run_train, run_test=self.run_test, train_data_dir=self.data_dir_cub, val_data_dir=self.data_dir_cub, test_data_dir=self.data_dir_cub) + if self.train_with_kaggle: + self.batch_size_kaggle = cfgs.get('batch_size_kaggle', 64) + self.data_dir_kaggle = cfgs.get('data_dir_kaggle', None) + self.train_loader_kaggle, self.val_loader_kaggle, self.test_loader_kaggle = model.get_data_loaders_ddp(cfgs, 'kaggle', self.rank, self.world_size, in_image_size=self.in_image_size, batch_size=self.batch_size_kaggle, num_workers=self.num_workers, run_train=self.run_train, run_test=self.run_test, train_data_dir=self.data_dir_kaggle, val_data_dir=self.data_dir_kaggle, test_data_dir=self.data_dir_kaggle) + + if self.use_total_iterations: + # reset the epoch related cfgs + + dataloader_length = max([len(loader) for loader in self.train_loader]) * len(self.train_loader) + print("Total length of data loader is: ", dataloader_length) + + total_epoch = int(self.num_iterations / dataloader_length) + 1 + + print(f'run for {total_epoch} epochs') + + print('is_main_process()?', misc.is_main_process()) + + for k, v in cfgs.items(): + if 'epoch' in k: + if isinstance(v, list): + new_v = [int(total_epoch * x / 120) + 1 for x in v] + cfgs[k] = new_v + elif isinstance(v, int): + new_v = int(total_epoch * v / 120) + 1 + cfgs[k] = new_v + else: + continue + + self.num_epochs = total_epoch + self.cub_start_epoch = cfgs.get('cub_start_epoch', 0) + self.cfgs = cfgs + + self.model = 
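+ # Worked example of the epoch rescaling above (hypothetical numbers): with num_iterations = 800000
+ # and a combined loader length of 10000, total_epoch = 81; a cfg value such as some_epochs: [8, 120],
+ # defined on the reference 120-epoch schedule, becomes [int(81*8/120)+1, int(81*120/120)+1] = [6, 82].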
model(cfgs) + self.model.trainer = self + self.save_result_freq = cfgs.get('save_result_freq', None) + self.train_result_dir = osp.join(self.checkpoint_dir, 'results') + + self.use_wandb = cfgs.get('use_wandb', False) + + def get_efficient_data_loaders_ddp(self, cfgs, batch_size, num_workers, in_image_size, out_image_size): + train_loader = val_loader = test_loader = None + color_jitter_train = cfgs.get('color_jitter_train', None) + color_jitter_val = cfgs.get('color_jitter_val', None) + random_flip_train = cfgs.get('random_flip_train', False) + + data_loader_mode = cfgs.get('data_loader_mode', 'n_frame') + skip_beginning = cfgs.get('skip_beginning', 4) + skip_end = cfgs.get('skip_end', 4) + num_sample_frames = cfgs.get('num_sample_frames', 2) + min_seq_len = cfgs.get('min_seq_len', 10) + max_seq_len = cfgs.get('max_seq_len', 10) + debug_seq = cfgs.get('debug_seq', False) + random_sample_train_frames = cfgs.get('random_sample_train_frames', False) + shuffle_train_seqs = cfgs.get('shuffle_train_seqs', False) + random_sample_val_frames = cfgs.get('random_sample_val_frames', False) + load_background = cfgs.get('background_mode', 'none') == 'background' + rgb_suffix = cfgs.get('rgb_suffix', '.png') + load_dino_feature = cfgs.get('load_dino_feature', False) + load_dino_cluster = cfgs.get('load_dino_cluster', False) + dino_feature_dim = cfgs.get('dino_feature_dim', 64) + + enhance_back_view = cfgs.get('enhance_back_view', False) + enhance_back_view_path = cfgs.get('enhance_back_view_path', None) + + override_categories = None + + get_loader_ddp = lambda **kwargs: get_sequence_loader_quadrupeds( + mode=data_loader_mode, + num_workers=num_workers, + in_image_size=in_image_size, + out_image_size=out_image_size, + debug_seq=debug_seq, + skip_beginning=skip_beginning, + skip_end=skip_end, + num_sample_frames=num_sample_frames, + min_seq_len=min_seq_len, + max_seq_len=max_seq_len, + load_background=load_background, + rgb_suffix=rgb_suffix, + load_dino_feature=load_dino_feature, + load_dino_cluster=load_dino_cluster, + dino_feature_dim=dino_feature_dim, + flow_bool=0, + enhance_back_view=enhance_back_view, + enhance_back_view_path=enhance_back_view_path, + override_categories=override_categories, + **kwargs) + + # just the train now + print(f"Loading training data...") + val_image_num = cfgs.get('few_shot_val_image_num', 5) + # the train_data_dir is a dict and will go into the original dataset type + + #TODO: very hack here, directly assign first 7 as original categories + o_class = ["horse", "elephant", "zebra", "cow", "giraffe", "sheep", "bear"] + self.original_categories_paths = {} + self.few_shot_categories_paths = {} + self.original_val_data_path = {} + + for k,v in self.train_data_dir.items(): + if k in o_class: + self.original_categories_paths.update({k: v}) + self.original_val_data_path.update({k: self.val_data_dir[k]}) + else: + self.few_shot_categories_paths.update({k:v}) + self.new_classes_num = len(self.few_shot_categories_paths) + self.original_classes_num = len(self.original_categories_paths) + + train_loader = get_loader_ddp( + original_data_dirs=self.original_categories_paths, + few_shot_data_dirs=self.few_shot_categories_paths, + original_num=self.original_classes_num, + few_shot_num=self.new_classes_num, + rank=self.rank, + world_size=self.world_size, + batch_size=batch_size, + is_validation=False, + val_image_num=val_image_num, + shuffle=shuffle_train_seqs, + dense_sample=True, + color_jitter=color_jitter_train, + random_flip=random_flip_train + ) + val_loader = get_loader_ddp( + 
original_data_dirs=self.original_val_data_path, + few_shot_data_dirs=self.few_shot_categories_paths, + original_num=self.original_classes_num, + few_shot_num=self.new_classes_num, + rank=self.rank, + world_size=self.world_size, + batch_size=1, + is_validation=True, + val_image_num=val_image_num, + shuffle=False, + dense_sample=True, + color_jitter=color_jitter_val, + random_flip=False + ) + + test_loader = None + + return train_loader, val_loader, test_loader + + + def load_checkpoint(self, optim=True, ckpt_path=None): + """Search the specified/latest checkpoint in checkpoint_dir and load the model and optimizer.""" + if ckpt_path is not None: + checkpoint_path = ckpt_path + self.checkpoint_name = osp.basename(checkpoint_path) + elif self.checkpoint_name is not None: + checkpoint_path = osp.join(self.checkpoint_dir, self.checkpoint_name) + else: + checkpoints = sorted(glob.glob(osp.join(self.checkpoint_dir, '*.pth'))) + if len(checkpoints) == 0: + return 0, 0 + checkpoint_path = checkpoints[-1] + self.checkpoint_name = osp.basename(checkpoint_path) + + print(f"Loading checkpoint from {checkpoint_path}") + cp = torch.load(checkpoint_path, map_location=self.device) + # print(cp) + self.model.load_model_state(cp) + if optim: + self.model.load_optimizer_state(cp) + self.metrics_trace = cp['metrics_trace'] + epoch = cp['epoch'] + total_iter = cp['total_iter'] + + if 'classes_vectors' in cp: + self.model.classes_vectors = cp['classes_vectors'] + + return epoch, total_iter + + def save_checkpoint(self, epoch, total_iter=0, optim=True): + """Save model, optimizer, and metrics state to a checkpoint in checkpoint_dir for the specified epoch.""" + misc.xmkdir(self.checkpoint_dir) + checkpoint_path = osp.join(self.checkpoint_dir, f'checkpoint{epoch:03}.pth') + state_dict = self.model.get_model_state() + if optim: + optimizer_state = self.model.get_optimizer_state() + state_dict = {**state_dict, **optimizer_state} + state_dict['metrics_trace'] = self.metrics_trace + state_dict['epoch'] = epoch + state_dict['total_iter'] = total_iter + print(f"Saving checkpoint to {checkpoint_path}") + torch.save(state_dict, checkpoint_path) + if self.keep_num_checkpoint > 0: + misc.clean_checkpoint(self.checkpoint_dir, keep_num=self.keep_num_checkpoint) + + def save_last_checkpoint(self, epoch, total_iter=0, optim=True): + """Save model, optimizer, and metrics state to a checkpoint in checkpoint_dir for the specified epoch.""" + misc.xmkdir(self.checkpoint_dir) + checkpoint_path = osp.join(self.checkpoint_dir, 'last.pth') + if os.path.exists(checkpoint_path): + os.remove(checkpoint_path) + state_dict = self.model.get_model_state() + if optim: + optimizer_state = self.model.get_optimizer_state() + state_dict = {**state_dict, **optimizer_state} + state_dict['metrics_trace'] = self.metrics_trace + state_dict['epoch'] = epoch + state_dict['total_iter'] = total_iter + print(f"Saving checkpoint to {checkpoint_path}") + torch.save(state_dict, checkpoint_path) + + def save_clean_checkpoint(self, path): + """Save model state only to specified path.""" + torch.save(self.model.get_model_state(), path) + + def test(self): + """Perform testing.""" + self.model.to(self.device) + epoch, self.total_iter = self.load_checkpoint(optim=False) + + if self.use_ddp: + self.model.ddp(self.rank, self.world_size) + self.model.set_eval() + + if self.test_result_dir is None: + self.test_result_dir = osp.join(self.checkpoint_dir, f'test_results_{self.checkpoint_name}'.replace('.pth', '')) + print(f"Saving testing results to 
{self.test_result_dir}") + + with torch.no_grad(): + for iteration, batch in enumerate(self.test_loader): + m = self.model.forward(batch, epoch=epoch, iter=iteration, total_iter=self.total_iter, save_results=True, save_dir=self.test_result_dir, which_data=self.dataset, is_training=False) + print(f"T{epoch:04}/{iteration:05}") + + score_path = osp.join(self.test_result_dir, 'all_metrics.txt') + # self.model.save_scores(score_path) + + def train(self): + """Perform training.""" + # archive code and configs + if self.archive_code: + misc.archive_code(osp.join(self.checkpoint_dir, 'archived_code.zip'), filetypes=['.py']) + misc.dump_yaml(osp.join(self.checkpoint_dir, 'configs.yml'), self.cfgs) + + # initialize + start_epoch = 0 + self.total_iter = 0 + self.metrics_trace.reset() + self.model.to(self.device) + self.model.reset_optimizers() + + # resume from checkpoint + # from IPython import embed; embed() + if self.resume: + start_epoch, self.total_iter = self.load_checkpoint(optim=True) + + if self.reset_epoch: + start_epoch = 0 + self.total_iter = 0 + + if start_epoch == 0 and self.total_iter ==0 and self.finetune_ckpt is not None: + _, _ = self.load_checkpoint(optim=True, ckpt_path=self.finetune_ckpt) + + # distribute model + if self.use_ddp: + self.model.ddp(self.rank, self.world_size) + + # train with cub + if self.train_with_cub: + self.cub_train_data_iterator = indefinite_generator(self.train_loader_cub) + + # initialize tensorboard logger + if misc.is_main_process() and self.use_logger: + if self.use_wandb: + import wandb + wandb.tensorboard.patch(root_logdir=osp.join(self.checkpoint_dir, 'logs', datetime.now().strftime("%Y%m%d-%H%M%S"))) + wandb.init(name=self.checkpoint_dir.split("/")[-1], project="APT36K") + from torch.utils.tensorboard import SummaryWriter + self.logger = SummaryWriter(osp.join(self.checkpoint_dir, 'logs', datetime.now().strftime("%Y%m%d-%H%M%S")), flush_secs=10) + self.viz_data_iterator = indefinite_generator_from_list(self.val_loader) if self.visualize_validation else indefinite_generator_from_list(self.train_loader) + # self.viz_data_iterator = iter(self.viz_data_iterator) + if self.fix_viz_batch: + self.viz_batch = next(self.viz_data_iterator) + + # train with cub + if self.train_with_cub: + self.cub_viz_data_iterator = indefinite_generator(self.val_loader_cub) if self.visualize_validation else indefinite_generator(self.train_loader_cub) + if self.fix_viz_batch: + self.viz_batch_cub = next(self.cub_viz_data_iterator) + + # run epochs + epoch = 0 + for epoch in range(start_epoch, self.num_epochs): + torch.distributed.barrier() + metrics = self.run_epoch(epoch) + if self.rank == 0: + self.metrics_trace.append("train", metrics) + if (epoch+1) % self.save_checkpoint_freq == 0: + self.save_checkpoint(epoch+1, total_iter=self.total_iter, optim=True) + if self.cfgs.get('pyplot_metrics', True): + self.metrics_trace.plot(pdf_path=osp.join(self.checkpoint_dir, 'metrics.pdf')) + self.metrics_trace.save(osp.join(self.checkpoint_dir, 'metrics.json')) + if self.rank == 0: + print(f"Training completed for all {epoch+1} epochs.") + + def dry_run(self): + print(f'rank: {self.rank}, dry_run!!!!!') + self.dry_run_iters = self.cfgs.get('dr_iters', 2) + self.resume = self.cfgs.get('dr_resume', True) + self.use_logger = self.cfgs.get('dr_use_logger', True) + self.log_freq_losses = self.cfgs.get('dr_log_freq_losses', 1) + self.save_result_freq = self.cfgs.get('dr_save_result_freq', 1) + self.log_freq_images = self.cfgs.get('dr_log_freq_images', 1) + self.log_train_images = 
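+ # dry_run() is a quick smoke test: the dr_* overrides here default to logging, result saving and
+ # visualization on every iteration, num_epochs shrinks to dr_num_epochs (1 by default), and
+ # run_epoch() stops early once iteration reaches dry_run_iters because is_dry_run is set.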
self.cfgs.get('dr_log_train_images', True) + self.visualize_validation = self.cfgs.get('dr_visualize_validation', True) + self.num_epochs = self.cfgs.get('dr_num_epochs', 1) + self.train() + + def run_epoch(self, epoch): + metrics = self.make_metrics() + + self.model.set_train() + + max_loader_len = max([len(loader) for loader in self.train_loader]) + train_generators = [indefinite_generator(loader) for loader in self.train_loader] + + iteration = 0 + while iteration < max_loader_len * len(self.train_loader): + for generator in train_generators: + batch = next(generator) + + self.total_iter += 1 + + if self.total_iter % 4000 == 0: + self.save_last_checkpoint(epoch+1, self.total_iter, optim=True) + + num_seqs, num_frames = batch[0].shape[:2] + total_im_num = num_seqs * num_frames + m = self.model.forward(batch, epoch=epoch, iter=iteration, total_iter=self.total_iter, which_data=self.dataset, is_training=True) + + if self.train_with_cub and epoch >= self.cub_start_epoch: + batch_cub = next(self.cub_train_data_iterator) + num_seqs, num_frames = batch_cub[0].shape[:2] + total_im_num += num_seqs * num_frames + m_cub = self.model.forward(batch_cub, epoch=epoch, iter=iteration, total_iter=self.total_iter, which_data='cub', is_training=True) + m.update({'cub_'+k: v for k,v in m_cub.items()}) + m['total_loss'] = self.model.total_loss + + self.model.backward() + + if self.model.enable_disc and (self.model.mask_discriminator_iter[0] < self.total_iter) and (self.model.mask_discriminator_iter[1] > self.total_iter): + # the discriminator training + discriminator_loss_dict, grad_loss = self.model.discriminator_step() + m.update( + { + 'mask_disc_loss_discriminator': discriminator_loss_dict['discriminator_loss'] - grad_loss, + 'mask_disc_loss_discriminator_grad': grad_loss, + 'mask_disc_loss_discriminator_rv': discriminator_loss_dict['discriminator_loss_rv'], + 'mask_disc_loss_discriminator_iv': discriminator_loss_dict['discriminator_loss_iv'], + 'mask_disc_loss_discriminator_gt': discriminator_loss_dict['discriminator_loss_gt'] + } + ) + self.logger.add_histogram('train_'+'discriminator_logits/random_view', discriminator_loss_dict['d_rv'], self.total_iter) + if discriminator_loss_dict['d_iv'] is not None: + self.logger.add_histogram('train_'+'discriminator_logits/input_view', discriminator_loss_dict['d_iv'], self.total_iter) + if discriminator_loss_dict['d_gt'] is not None: + self.logger.add_histogram('train_'+'discriminator_logits/gt_view', discriminator_loss_dict['d_gt'], self.total_iter) + + metrics.update(m, total_im_num) + if self.rank == 0: + print(f"T{epoch:04}/{iteration:05}/{metrics}") + + ## reset optimizers + if self.cfgs.get('opt_reset_every_iter', 0) > 0 and self.total_iter < self.cfgs.get('opt_reset_end_iter', 0): + if self.total_iter % self.cfgs.get('opt_reset_every_iter', 0) == 0: + self.model.reset_optimizers() + + if misc.is_main_process() and self.use_logger: + if self.rank == 0 and self.total_iter % self.log_freq_losses == 0: + for name, loss in m.items(): + label = f'cub_loss_train/{name[4:]}' if 'cub' in name else f'loss_train/{name}' + self.logger.add_scalar(label, loss, self.total_iter) + if self.rank == 0 and self.save_result_freq is not None and self.total_iter % self.save_result_freq == 0: + with torch.no_grad(): + m = self.model.forward(batch, epoch=epoch, iter=iteration, total_iter=self.total_iter, save_results=True, save_dir=self.train_result_dir, which_data=self.dataset, is_training=False) + torch.cuda.empty_cache() + if self.total_iter % self.log_freq_images == 0: + with 
torch.no_grad(): + if self.rank == 0 and self.log_train_images: + m = self.model.forward(batch, epoch=epoch, iter=iteration, viz_logger=self.logger, total_iter=self.total_iter, which_data=self.dataset, logger_prefix='train_', is_training=False) + if self.fix_viz_batch: + print(f'fix_viz_batch:{self.fix_viz_batch}') + batch = self.viz_batch + else: + batch = next(self.viz_data_iterator) + if self.visualize_validation: + import time + vis_start = time.time() + batch = next(self.viz_data_iterator) + # try: + # batch = next(self.viz_data_iterator) + # except: # iterator exhausted + # self.reset_viz_data_iterator() + # batch = next(self.viz_data_iterator) + m = self.model.forward(batch, epoch=epoch, iter=iteration, viz_logger=self.logger, total_iter=self.total_iter, which_data=self.dataset, logger_prefix='val_', is_training=False) + vis_end = time.time() + print(f"vis time: {vis_end - vis_start}") + for name, loss in m.items(): + if self.rank == 0: + self.logger.add_scalar(f'loss_val/{name}', loss, self.total_iter) + + if self.train_with_cub and epoch >= self.cub_start_epoch: + if self.rank == 0 and self.log_train_images: + m = self.model.forward(batch_cub, epoch=epoch, iter=iteration, viz_logger=self.logger, total_iter=self.total_iter, which_data='cub', logger_prefix='cub_train_', is_training=True) + + if self.fix_viz_batch: + batch_cub = self.viz_batch_cub + elif self.visualize_validation: + batch_cub = next(self.cub_viz_data_iterator) + # try: + # batch = next(self.viz_data_iterator) + # except: # iterator exhausted + # self.reset_viz_data_iterator() + # batch = next(self.viz_data_iterator) + if self.rank == 0: + m = self.model.forward(batch_cub, epoch=epoch, iter=iteration, viz_logger=self.logger, total_iter=self.total_iter, which_data='cub', logger_prefix='cub_val_', is_training=False) + for name, loss in m.items(): + self.logger.add_scalar(f'cub_loss_val/{name}', loss, self.total_iter) + torch.cuda.empty_cache() + if self.is_dry_run and iteration >= self.dry_run_iters: + break + + iteration += 1 + + self.model.scheduler_step() + return metrics diff --git a/video3d/trainer_few_shot.py b/video3d/trainer_few_shot.py new file mode 100644 index 0000000000000000000000000000000000000000..b3a654c29dfeaffaaebb1f55fe3e409d0e1ba6d5 --- /dev/null +++ b/video3d/trainer_few_shot.py @@ -0,0 +1,1110 @@ +import os +import os.path as osp +from copy import deepcopy +from collections import OrderedDict +import glob +from datetime import datetime +import random +import copy +import imageio +import torch +import clip +import torchvision.transforms.functional as tvf +import video3d.utils.meters as meters +import video3d.utils.misc as misc +# from video3d.dataloaders import get_image_loader +from video3d.dataloaders_ddp import get_sequence_loader_ddp, get_sequence_loader_quadrupeds, get_test_loader_quadrupeds +from . 
import discriminator_architecture + + +def sample_frames(batch, num_sample_frames, iteration, stride=1): + ## window slicing sampling + images, masks, flows, bboxs, bg_image, seq_idx, frame_idx = batch + num_seqs, total_num_frames = images.shape[:2] + # start_frame_idx = iteration % (total_num_frames - num_sample_frames +1) + + ## forward and backward + num_windows = total_num_frames - num_sample_frames +1 + start_frame_idx = (iteration * stride) % (2*num_windows) + ## x' = (2n-1)/2 - |(2n-1)/2 - x| : 0,1,2,3,4,5 -> 0,1,2,2,1,0 + mid_val = (2*num_windows -1) /2 + start_frame_idx = int(mid_val - abs(mid_val -start_frame_idx)) + + new_batch = images[:, start_frame_idx:start_frame_idx+num_sample_frames], \ + masks[:, start_frame_idx:start_frame_idx+num_sample_frames], \ + flows[:, start_frame_idx:start_frame_idx+num_sample_frames-1], \ + bboxs[:, start_frame_idx:start_frame_idx+num_sample_frames], \ + bg_image, \ + seq_idx, \ + frame_idx[:, start_frame_idx:start_frame_idx+num_sample_frames] + return new_batch + + +def indefinite_generator(loader): + while True: + for x in loader: + yield x + + +def indefinite_generator_from_list(loaders): + while True: + random_idx = random.randint(0, len(loaders)-1) + for x in loaders[random_idx]: + yield x + break + + +def get_optimizer(model, lr=0.0001, betas=(0.9, 0.999), weight_decay=0): + return torch.optim.Adam( + filter(lambda p: p.requires_grad, model.parameters()), + lr=lr, betas=betas, weight_decay=weight_decay) + + +class Fewshot_Trainer: + def __init__(self, cfgs, model): + # only now supports one gpu + self.cfgs = cfgs + # here should be the one gpu ddp setting + self.rank = cfgs.get('rank', 0) + self.world_size = cfgs.get('world_size', 1) + self.use_ddp = cfgs.get('use_ddp', True) + self.device = cfgs.get('device', 'cpu') + + self.num_epochs = cfgs.get('num_epochs', 1) + self.lr = cfgs.get('few_shot_lr', 1e-4) + self.dataset = 'image' + + self.metrics_trace = meters.MetricsTrace() + self.make_metrics = lambda m=None: meters.StandardMetrics(m) + + self.archive_code = cfgs.get('archive_code', True) + self.batch_size = cfgs.get('batch_size', 64) + self.in_image_size = cfgs.get('in_image_size', 256) + self.out_image_size = cfgs.get('out_image_size', 256) + self.num_workers = cfgs.get('num_workers', 4) + self.checkpoint_dir = cfgs.get('checkpoint_dir', 'results') + misc.xmkdir(self.checkpoint_dir) + self.few_shot_resume = cfgs.get('few_shot_resume', False) + self.save_checkpoint_freq = cfgs.get('save_checkpoint_freq', 1) + self.keep_num_checkpoint = cfgs.get('keep_num_checkpoint', 2) # -1 for keeping all checkpoints + + self.few_shot_data_dir = cfgs.get('few_shot_data_dir', None) + assert self.few_shot_data_dir is not None + # in case we add more data source + if isinstance(self.few_shot_data_dir, list): + self.few_shot_data_dir_more = self.few_shot_data_dir[1:] + self.few_shot_data_dir = self.few_shot_data_dir[0] + else: + self.few_shot_data_dir_more = None + + assert "data_resize_update" in self.few_shot_data_dir # TODO: a hack way to make sure not using wrong data, needs to remove + self.few_shot_categories, self.few_shot_categories_paths = self.parse_few_shot_categories(self.few_shot_data_dir, self.few_shot_data_dir_more) + + # if we need test categories, we pop it from self.few_shot_categories and self.few_shot_categories_path + # the test category needs to be category from few-shot, and we're using bs=1 on them, no need for back views enhancement (for now, use back view images, but don't duplicate them) + self.test_category_num = 
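+ # If few_shot_test_category_num > 0, the held-out names (few_shot_test_category_names if given,
+ # otherwise the last few_shot_test_category_num entries of few_shot_categories_paths) are removed
+ # from few-shot training and kept in self.test_categories_paths; get_data_loaders_quadrupeds()
+ # later builds a batch-size-1 test loader from them via get_test_loader_quadrupeds.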
cfgs.get('few_shot_test_category_num', 0) + self.test_category_names = cfgs.get('few_shot_test_category_names', None) + if self.test_category_num > 0: + # if we have valid test_category names, then use them, the number doesn't need to be equal + if self.test_category_names is not None: + test_cats = self.test_category_names + else: + test_cats = list(self.few_shot_categories_paths.keys())[-(self.test_category_num):] + test_categories_paths = {} + for test_cat in test_cats: + test_categories_paths.update({test_cat: self.few_shot_categories_paths[test_cat]}) + assert test_cat in self.few_shot_categories + self.few_shot_categories.remove(test_cat) + self.few_shot_categories_paths.pop(test_cat) + + self.test_categories_paths = test_categories_paths + else: + self.test_categories_paths = None + + # also load the original 7 categories + self.original_train_data_path = cfgs.get('train_data_dir', None) + self.original_val_data_path = cfgs.get('val_data_dir', None) + self.original_categories = [] + self.original_categories_paths = self.original_train_data_path + for k, v in self.original_train_data_path.items(): + self.original_categories.append(k) + + self.categories = self.original_categories + self.few_shot_categories + self.categories_paths = self.original_train_data_path.copy() + self.categories_paths.update(self.few_shot_categories_paths) + + print(f'Using {len(self.categories)} cateogires: ', self.categories) + + # initialize new things + # self.original_classes_num = cfgs.get('few_shot_original_classes_num', 7) + self.original_classes_num = len(self.original_categories) + self.new_classes_num = len(self.categories) - self.original_classes_num + + self.combine_dataset = cfgs.get('combine_dataset', False) + assert self.combine_dataset, "we should use combine dataset, it's up to date" + if self.combine_dataset: + self.train_loader, self.val_loader, self.test_loader = self.get_data_loaders_quadrupeds(self.cfgs, self.batch_size, self.num_workers, self.in_image_size, self.out_image_size) + else: + self.train_loader_few_shot, self.val_loader_few_shot = self.get_data_loaders_few_shot(self.cfgs, self.batch_size, self.num_workers, self.in_image_size, self.out_image_size) + self.train_loader_original, self.val_loader_original = self.get_data_loaders_original(self.cfgs, self.batch_size, self.num_workers, self.in_image_size, self.out_image_size) + self.train_loader = self.train_loader_original + self.train_loader_few_shot + if self.val_loader_few_shot is not None and self.val_loader_original is not None: + self.val_loader = self.val_loader_original + self.val_loader_few_shot + + self.num_iterations = cfgs.get('num_iterations', 0) + if self.num_iterations != 0: + self.use_total_iterations = True + else: + self.use_total_iterations = False + if self.use_total_iterations: + # reset the epoch related cfgs + + dataloader_length = max([len(loader) for loader in self.train_loader]) * len(self.train_loader) + print("Total length of data loader is: ", dataloader_length) + + total_epoch = int(self.num_iterations / dataloader_length) + 1 + + print(f'run for {total_epoch} epochs') + + print('is_main_process()?', misc.is_main_process()) + + for k, v in cfgs.items(): + if 'epoch' in k: + # if isinstance(v, list): + # new_v = [int(total_epoch * x / 120) + 1 for x in v] + # cfgs[k] = new_v + # elif isinstance(v, int): + # new_v = int(total_epoch * v / 120) + 1 + # cfgs[k] = new_v + + # a better transformation + if isinstance(v, int): + # use the floor int + new_v = int(total_epoch * v / 120) + cfgs[k] = new_v + elif 
isinstance(v, list): + if v[0] == v[1]: + # if the values in v are the same, then we use both the floor value + new_v = [int(total_epoch * x / 120) for x in v] + else: + # if the values are not the same, make the first using floor value and others using ceil value + new_v = [int(total_epoch * x / 120) + 1 for x in v] + new_v[0] = new_v[0] - 1 + cfgs[k] = new_v + else: + continue + + self.num_epochs = total_epoch + self.cub_start_epoch = cfgs.get('cub_start_epoch', 0) + self.cfgs = cfgs + + # the model is with nothing now + self.model = model(cfgs) + + self.metrics_trace = meters.MetricsTrace() + self.make_metrics = lambda m=None: meters.StandardMetrics(m) + + self.use_logger = True + self.log_freq_images = cfgs.get('log_freq_images', 1000) + self.log_train_images = cfgs.get('log_train_images', False) + self.log_freq_losses = cfgs.get('log_freq_losses', 100) + self.save_result_freq = cfgs.get('save_result_freq', None) + self.train_result_dir = osp.join(self.checkpoint_dir, 'results') + self.fix_viz_batch = cfgs.get('fix_viz_batch', False) + self.visualize_validation = cfgs.get('visualize_validation', False) + # self.visualize_validation = False + self.iteration_save = cfgs.get('few_shot_iteration_save', False) + self.iteration_save_freq = cfgs.get('few_shot_iteration_save_freq', 2000) + + self.enable_memory_bank = cfgs.get('enable_memory_bank', False) + if self.enable_memory_bank: + self.memory_bank_dim = 128 + self.memory_bank_size = cfgs.get('memory_bank_size', 60) + self.memory_bank_topk = cfgs.get('memory_bank_topk', 10) + # assert self.memory_bank_topk < self.memory_bank_size + assert self.memory_bank_topk <= self.memory_bank_size + self.memory_retrieve = cfgs.get('memory_retrieve', 'cos-linear') + + self.memory_bank_init = cfgs.get('memory_bank_init', 'random') + if self.memory_bank_init == 'copy': + # use trained 7 embeddings to initialize + num_piece = self.memory_bank_size // self.original_classes_num + num_left = self.memory_bank_size - num_piece * self.original_classes_num + + tmp_1 = torch.empty_like(self.model.netPrior.classes_vectors) + tmp_1 = tmp_1.copy_(self.model.netPrior.classes_vectors) + tmp_1 = tmp_1.unsqueeze(0).repeat(num_piece, 1, 1) + tmp_1 = tmp_1.reshape(tmp_1.shape[0] * tmp_1.shape[1], tmp_1.shape[-1]) + + if num_left > 0: + tmp_2 = torch.empty_like(self.model.netPrior.classes_vectors) + tmp_2 = tmp_2.copy_(self.model.netPrior.classes_vectors) + tmp_2 = tmp_2[:num_left] + tmp = torch.cat([tmp_1, tmp_2], dim=0) + else: + tmp = tmp_1 + + self.memory_bank = torch.nn.Parameter(tmp, requires_grad=True) + + elif self.memory_bank_init == 'random': + self.memory_bank = torch.nn.Parameter(torch.nn.init.uniform_(torch.empty(self.memory_bank_size, self.memory_bank_dim), a=-0.05, b=0.05), requires_grad=True) + else: + raise NotImplementedError + + self.memory_encoder = cfgs.get('memory_encoder', 'DINO') # if DINO then just use the network encoder + if self.memory_encoder == 'CLIP': + self.clip_model, _ = clip.load('ViT-B/32', self.device) + self.clip_model = self.clip_model.eval().requires_grad_(False) + self.clip_mean = [0.48145466, 0.4578275, 0.40821073] + self.clip_std = [0.26862954, 0.26130258, 0.27577711] + self.clip_reso = 224 + + self.memory_bank_keys_dim = 512 + + elif self.memory_encoder == 'DINO': + self.memory_bank_keys_dim = 384 + + else: + raise NotImplementedError + + memory_bank_keys = torch.nn.init.uniform_(torch.empty(self.memory_bank_size, self.memory_bank_keys_dim), a=-0.05, b=0.05) + self.memory_bank_keys = torch.nn.Parameter(memory_bank_keys, 
requires_grad=True) + + else: + print("no memory bank, just use image embedding, this is only for one experiment!") + self.memory_encoder = cfgs.get('memory_encoder', 'DINO') # if DINO then just use the network encoder + if self.memory_encoder == 'CLIP': + self.clip_model, _ = clip.load('ViT-B/32', self.device) + self.clip_model = self.clip_model.eval().requires_grad_(False) + self.clip_mean = [0.48145466, 0.4578275, 0.40821073] + self.clip_std = [0.26862954, 0.26130258, 0.27577711] + self.clip_reso = 224 + + self.memory_bank_keys_dim = 512 + + elif self.memory_encoder == 'DINO': + self.memory_bank_keys_dim = 384 + + else: + raise NotImplementedError + + self.prepare_model() + + def parse_few_shot_categories(self, data_dir, data_dir_more=None): + # parse the categories data_dir + few_shot_category_num = self.cfgs.get('few_shot_category_num', -1) + assert few_shot_category_num != 0 + categories = sorted(os.listdir(data_dir)) + cnt = 0 + category_names = [] + category_names_paths = {} + for category in categories: + if osp.isdir(osp.join(self.few_shot_data_dir, category, 'train')): + category_path = osp.join(self.few_shot_data_dir, category, 'train') + category_names.append(category) + category_names_paths.update({category: category_path}) + cnt += 1 + if few_shot_category_num > 0 and cnt >= few_shot_category_num: + break + + # more data + if data_dir_more is not None: + for data_dir_one in data_dir_more: + new_categories = os.listdir(data_dir_one) + for new_category in new_categories: + ''' + if this category is not used before, add a new item + if there is this category before, add the paths to original paths, + if its a str, make it a list + if its already a list, append it + ''' + if new_category not in category_names: + + #TODO: a hacky way here, if in new data there is category used in 7-cat, we just make it a new one + if new_category in list(self.cfgs.get('train_data_dir', None).keys()): + new_category = '_' + new_category + + category_names.append(new_category) + category_names_paths.update({ + new_category: osp.join(data_dir_one, new_category, 'train') + }) + else: + old_category_path = category_names_paths[new_category] + if isinstance(old_category_path, str): + category_names_paths[new_category] = [ + old_category_path, + osp.join(data_dir_one, new_category, 'train') + ] + elif isinstance(old_category_path, list): + old_category_path = old_category_path + [osp.join(data_dir_one, new_category, 'train')] + category_names_paths[new_category] = old_category_path + else: + raise NotImplementedError + + # category_names = sorted(category_names) + + return category_names, category_names_paths + + def prepare_model(self): + # here we prepare the model weights at outside + # 1. load the pretrain weight + # 2. initialize anything new, like new class vectors + # 3. 
initialize new optimizer for chosen parameters + + assert self.original_classes_num == len(self.model.netPrior.category_id_map) + + # load pretrain + # if not assigned few_shot_checkpoint_name, then skip this part + if self.cfgs.get('few_shot_checkpoint_name', None) is not None: + original_checkpoint_path = osp.join(self.checkpoint_dir, self.cfgs.get('few_shot_checkpoint_name', 'checkpoint060.pth')) + assert osp.exists(original_checkpoint_path) + print(f"Loading pre-trained checkpoint from {original_checkpoint_path}") + cp = torch.load(original_checkpoint_path, map_location=self.device) + + # if using local-texture network in fine-tuning, the texture in previous pre-train ckpt is global + # here we use a hack way, we just get rid of original texture ckpt + if (self.cfgs.get('texture_way', None) is not None) or (self.cfgs.get('texture_act', 'relu') != 'relu'): + new_netInstance_weights = {k: v for k, v in cp['netInstance'].items() if 'netTexture' not in k} + #find the new texture weights + texture_weights = self.model.netInstance.netTexture.state_dict() + #add the new weights to the new model weights + for k, v in texture_weights.items(): + # for the overlapping part in netTexture, we also use them + # if ('netTexture.' + k) in cp['netInstance'].keys(): + # new_netInstance_weights['netTexture.' + k] = cp['netInstance']['netTexture.' + k] + # else: + # new_netInstance_weights['netTexture.' + k] = v + new_netInstance_weights['netTexture.' + k] = v + _ = cp.pop("netInstance") + cp.update({"netInstance": new_netInstance_weights}) + + self.model.netInstance.load_state_dict(cp["netInstance"], strict=False) # For Deform + # self.model.netInstance.load_state_dict(cp["netInstance"]) + self.model.netPrior.load_state_dict(cp["netPrior"]) + + self.original_total_iter = cp["total_iter"] + + else: + print("not load any pre-train weight, the iter will start from 0, make sure you set all the needed parameters") + self.original_total_iter = 0 + + if not self.cfgs.get('disable_fewshot', False): + for i, category in enumerate(self.few_shot_categories): + category_id = self.original_classes_num + i + self.model.netPrior.category_id_map.update({category: category_id}) + + few_shot_class_vector_init = self.cfgs.get('few_shot_class_vector_init', 'random') + if few_shot_class_vector_init == 'random': + tmp = torch.nn.init.uniform_(torch.empty(self.new_classes_num, self.model.netPrior.classes_vectors.shape[-1]), a=-0.05, b=0.05) + tmp = tmp.to(self.model.netPrior.classes_vectors.device) + self.model.netPrior.classes_vectors = torch.nn.Parameter(torch.cat([self.model.netPrior.classes_vectors, tmp], dim=0)) + elif few_shot_class_vector_init == 'copy': + num_7_cat_piece = self.new_classes_num // self.original_classes_num if self.new_classes_num > self.original_classes_num else 0 + num_left = self.new_classes_num - num_7_cat_piece * self.original_classes_num + + if num_7_cat_piece > 0: + tmp_1 = torch.empty_like(self.model.netPrior.classes_vectors) + tmp_1 = tmp_1.copy_(self.model.netPrior.classes_vectors) + tmp_1 = tmp_1.unsqueeze(0).repeat(num_7_cat_piece, 1, 1) + tmp_1 = tmp_1.reshape(tmp_1.shape[0] * tmp_1.shape[1], tmp_1.shape[-1]) + else: + tmp_1 = None + + if num_left > 0: + tmp_2 = torch.empty_like(self.model.netPrior.classes_vectors) + tmp_2 = tmp_2.copy_(self.model.netPrior.classes_vectors) + tmp_2 = tmp_2[:num_left] + else: + tmp_2 = None + + if tmp_1 != None and tmp_2 != None: + tmp = torch.cat([tmp_1, tmp_2], dim=0) + elif tmp_1 == None and tmp_2 != None: + tmp = tmp_2 + elif tmp_2 == None and tmp_1 != 
None: + tmp = tmp_1 + else: + raise NotImplementedError + + tmp = tmp.to(self.model.netPrior.classes_vectors.device) + self.model.netPrior.classes_vectors = torch.nn.Parameter(torch.cat([self.model.netPrior.classes_vectors, tmp], dim=0)) + else: + raise NotImplementedError + + else: + print("disable few shot, not increasing embedding vectors") + + # initialize new optimizer + optimize_rule = self.cfgs.get('few_shot_optimize', 'all') + if optimize_rule == 'all': + optimize_list = [ + {'name': 'net_Prior', 'params': list(self.model.netPrior.parameters()), 'lr': self.lr * 10.}, + {'name': 'net_Instance', 'params': list(self.model.netInstance.parameters()), 'lr': self.lr * 1.}, + ] + elif optimize_rule == 'only-emb': + optimize_list = [ + {'name': 'class_embeddings', 'params': list([self.model.netPrior.classes_vectors]), 'lr': self.lr * 10.} + ] + elif optimize_rule == 'emb-instance': + optimize_list = [ + {'name': 'class_embeddings', 'params': list([self.model.netPrior.classes_vectors]), 'lr': self.lr * 10.}, + {'name': 'net_Instance', 'params': list(self.model.netInstance.parameters()), 'lr': self.lr * 1.}, + ] + elif optimize_rule == 'custom': + optimize_list = [ + {'name': 'net_Prior', 'params': list(self.model.netPrior.parameters()), 'lr': self.lr * 10.}, + {'name': 'netEncoder', 'params': list(self.model.netInstance.netEncoder.parameters()), 'lr': self.lr * 1.}, + {'name': 'netTexture', 'params': list(self.model.netInstance.netTexture.parameters()), 'lr': self.lr * 1.}, + {'name': 'netPose', 'params': list(self.model.netInstance.netPose.parameters()), 'lr': self.lr * 0.01}, + {'name': 'netArticulation', 'params': list(self.model.netInstance.netArticulation.parameters()), 'lr': self.lr * 1.}, + {'name': 'netLight', 'params': list(self.model.netInstance.netLight.parameters()), 'lr': self.lr * 1.} + ] + elif optimize_rule == 'custom-deform': + optimize_list = [ + {'name': 'net_Prior', 'params': list(self.model.netPrior.parameters()), 'lr': self.lr * 10.}, + {'name': 'netEncoder', 'params': list(self.model.netInstance.netEncoder.parameters()), 'lr': self.lr * 1.}, + {'name': 'netTexture', 'params': list(self.model.netInstance.netTexture.parameters()), 'lr': self.lr * 1.}, + {'name': 'netPose', 'params': list(self.model.netInstance.netPose.parameters()), 'lr': self.lr * 0.01}, + {'name': 'netArticulation', 'params': list(self.model.netInstance.netArticulation.parameters()), 'lr': self.lr * 1.}, + {'name': 'netLight', 'params': list(self.model.netInstance.netLight.parameters()), 'lr': self.lr * 1.}, + {'name': 'netDeform', 'params': list(self.model.netInstance.netDeform.parameters()), 'lr': self.lr * 1.} + ] + elif optimize_rule == 'texture': + optimize_list = [ + {'name': 'netTexture', 'params': list(self.model.netInstance.netTexture.parameters()), 'lr': self.lr * 1.} + ] + elif optimize_rule == 'texture-light': + optimize_list = [ + {'name': 'netTexture', 'params': list(self.model.netInstance.netTexture.parameters()), 'lr': self.lr * 1.}, + {'name': 'netLight', 'params': list(self.model.netInstance.netLight.parameters()), 'lr': self.lr * 1.} + ] + elif optimize_rule == 'exp': + optimize_list = [ + {'name': 'net_Prior', 'params': list(self.model.netPrior.parameters()), 'lr': self.lr * 10.}, + {'name': 'netEncoder', 'params': list(self.model.netInstance.netEncoder.parameters()), 'lr': self.lr * 1.}, + {'name': 'netTexture', 'params': list(self.model.netInstance.netTexture.parameters()), 'lr': self.lr * 1.}, + {'name': 'netPose', 'params': list(self.model.netInstance.netPose.parameters()), 
'lr': self.lr * 1.}, + {'name': 'netArticulation', 'params': list(self.model.netInstance.netArticulation.parameters()), 'lr': self.lr * 1.}, + {'name': 'netLight', 'params': list(self.model.netInstance.netLight.parameters()), 'lr': self.lr * 1.}, + {'name': 'netDeform', 'params': list(self.model.netInstance.netDeform.parameters()), 'lr': self.lr * 1.} + ] + else: + raise NotImplementedError + + if self.enable_memory_bank and optimize_rule != 'texture': + + optimize_bank_components = self.cfgs.get('few_shot_optimize_bank', 'all') + if optimize_bank_components == 'value': + optimize_list += [ + {'name': 'memory_bank', 'params': list([self.memory_bank]), 'lr': self.lr * 10.} + ] + elif optimize_bank_components == 'key': + optimize_list += [ + {'name': 'memory_bank_keys', 'params': list([self.memory_bank_keys]), 'lr': self.lr * 10.} + ] + else: + optimize_list += [ + {'name': 'memory_bank', 'params': list([self.memory_bank]), 'lr': self.lr * 10.}, + {'name': 'memory_bank_keys', 'params': list([self.memory_bank_keys]), 'lr': self.lr * 10.} + ] + + if self.model.enable_vsd: + optimize_list += [ + {'name': 'lora', 'params': list(self.model.stable_diffusion.parameters()), 'lr': self.lr} + ] + + # self.optimizerFewShot = torch.optim.Adam( + # [ + # # {'name': 'class_embeddings', 'params': list([self.model.netPrior.classes_vectors]), 'lr': self.lr * 1.}, + # {'name': 'net_Prior', 'params': list(self.model.netPrior.parameters()), 'lr': self.lr * 10.}, + # {'name': 'net_Instance', 'params': list(self.model.netInstance.parameters()), 'lr': self.lr * 1.}, + # # {'name': 'net_articulation', 'params': list(self.model.netInstance.netArticulation.parameters()), 'lr': self.lr * 10.} + # ], betas=(0.9, 0.99), eps=1e-15 + # ) + self.optimizerFewShot = torch.optim.Adam(optimize_list, betas=(0.9, 0.99), eps=1e-15) + + # if self.cfgs.get('texture_way', None) is not None and self.cfgs.get('gan_tex', False): + if self.cfgs.get('gan_tex', False): + self.optimizerDiscTex = torch.optim.Adam(filter(lambda p: p.requires_grad, self.model.discriminator_texture.parameters()), lr=self.lr, betas=(0.9, 0.99), eps=1e-15) + + def load_checkpoint(self, optim=True, checkpoint_name=None): + # use to load the checkpoint of model and optimizer in the finetuning + """Search the specified/latest checkpoint in checkpoint_dir and load the model and optimizer.""" + if checkpoint_name is not None: + checkpoint_path = osp.join(self.checkpoint_dir, checkpoint_name) + else: + checkpoints = sorted(glob.glob(osp.join(self.checkpoint_dir, '*.pth'))) + if len(checkpoints) == 0: + return 0, 0 + checkpoint_path = checkpoints[-1] + self.checkpoint_name = osp.basename(checkpoint_path) + print(f"Loading checkpoint from {checkpoint_path}") + cp = torch.load(checkpoint_path, map_location=self.device) + self.model.load_model_state(cp) # the cp has netPrior and netInstance as keys + if optim: + try: + self.optimizerFewShot.load_state_dict(cp['optimizerFewShot']) + except: + print('you should be using the local texture so dont need to load the previous optimizer') + if self.enable_memory_bank: + self.memory_bank_keys = cp['memory_bank_keys'] + self.memory_bank = cp['memory_bank'] + self.metrics_trace = cp['metrics_trace'] + epoch = cp['epoch'] + total_iter = cp['total_iter'] + return epoch, total_iter + + def save_checkpoint(self, epoch, total_iter=0, optim=True, use_iter=False): + """Save model, optimizer, and metrics state to a checkpoint in checkpoint_dir for the specified epoch.""" + misc.xmkdir(self.checkpoint_dir) + if use_iter: + checkpoint_path = 
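+ # In the optimize_list presets above, the prior / class-embedding (and memory bank) parameters use
+ # 10x the base few_shot_lr while the instance sub-networks use 1x, and netPose drops to 0.01x in the
+ # 'custom' presets, which presumably keeps the pre-trained pose predictor close to its
+ # initialization; the extra 'name' key in each param group is bookkeeping only and is not used by Adam.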
osp.join(self.checkpoint_dir, f'iter{total_iter:07}.pth') + prefix = 'iter*.pth' + else: + checkpoint_path = osp.join(self.checkpoint_dir, f'checkpoint{epoch:03}.pth') + prefix = 'checkpoint*.pth' + state_dict = self.model.get_model_state() + if optim: + optimizer_state = {'optimizerFewShot': self.optimizerFewShot.state_dict()} + state_dict = {**state_dict, **optimizer_state} + state_dict['metrics_trace'] = self.metrics_trace + state_dict['epoch'] = epoch + state_dict['total_iter'] = total_iter + if self.enable_memory_bank: + state_dict['memory_bank_keys'] = self.memory_bank_keys + state_dict['memory_bank'] = self.memory_bank + print(f"Saving checkpoint to {checkpoint_path}") + torch.save(state_dict, checkpoint_path) + if self.keep_num_checkpoint > 0: + self.clean_checkpoint(self.checkpoint_dir, keep_num=self.keep_num_checkpoint, prefix=prefix) + + def clean_checkpoint(self, checkpoint_dir, keep_num=2, prefix='checkpoint*.pth'): + if keep_num > 0: + names = list(sorted( + glob.glob(os.path.join(checkpoint_dir, prefix)) + )) + if len(names) > keep_num: + for name in names[:-keep_num]: + print(f"Deleting obslete checkpoint file {name}") + os.remove(name) + + def get_data_loaders_few_shot(self, cfgs, batch_size, num_workers, in_image_size, out_image_size): + # support the train_data_loaders, and also an identical val_data_loader? + train_loader = val_loader = None + + color_jitter_train = cfgs.get('color_jitter_train', None) + color_jitter_val = cfgs.get('color_jitter_val', None) + random_flip_train = cfgs.get('random_flip_train', False) + + data_loader_mode = cfgs.get('data_loader_mode', 'n_frame') + + num_sample_frames = cfgs.get('num_sample_frames', 2) + shuffle_train_seqs = cfgs.get('shuffle_train_seqs', False) + load_background = cfgs.get('background_mode', 'none') == 'background' + rgb_suffix = cfgs.get('rgb_suffix', '.png') + load_dino_feature = cfgs.get('load_dino_feature', False) + dino_feature_dim = cfgs.get('dino_feature_dim', 64) + get_loader_ddp = lambda **kwargs: get_sequence_loader_ddp( + mode=data_loader_mode, + batch_size=batch_size, + num_workers=num_workers, + in_image_size=in_image_size, + out_image_size=out_image_size, + num_sample_frames=num_sample_frames, + load_background=load_background, + rgb_suffix=rgb_suffix, + load_dino_feature=load_dino_feature, + dino_feature_dim=dino_feature_dim, + flow_bool=0, + **kwargs) + + print(f"Loading training data...") + train_loader = get_loader_ddp(data_dir=[self.original_classes_num, self.few_shot_categories_paths], rank=self.rank, world_size=self.world_size, use_few_shot=True, shuffle=False, color_jitter=color_jitter_train, random_flip=random_flip_train) + return train_loader, val_loader + + def get_data_loaders_original(self, cfgs, batch_size, num_workers, in_image_size, out_image_size): + train_loader = val_loader = test_loader = None + color_jitter_train = cfgs.get('color_jitter_train', None) + color_jitter_val = cfgs.get('color_jitter_val', None) + random_flip_train = cfgs.get('random_flip_train', False) + + data_loader_mode = cfgs.get('data_loader_mode', 'n_frame') + skip_beginning = cfgs.get('skip_beginning', 4) + skip_end = cfgs.get('skip_end', 4) + num_sample_frames = cfgs.get('num_sample_frames', 2) + min_seq_len = cfgs.get('min_seq_len', 10) + max_seq_len = cfgs.get('max_seq_len', 10) + debug_seq = cfgs.get('debug_seq', False) + random_sample_train_frames = cfgs.get('random_sample_train_frames', False) + shuffle_train_seqs = cfgs.get('shuffle_train_seqs', False) + random_sample_val_frames = 
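+ # Checkpoint rotation: save_checkpoint() names files iter{N:07}.pth when use_iter is set and
+ # checkpoint{epoch:03}.pth otherwise, and clean_checkpoint() only deletes the oldest files matching
+ # the corresponding prefix, so iteration-based and epoch-based checkpoints are rotated independently,
+ # each keeping at most keep_num_checkpoint files.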
cfgs.get('random_sample_val_frames', False) + load_background = cfgs.get('background_mode', 'none') == 'background' + rgb_suffix = cfgs.get('rgb_suffix', '.png') + load_dino_feature = cfgs.get('load_dino_feature', False) + load_dino_cluster = cfgs.get('load_dino_cluster', False) + dino_feature_dim = cfgs.get('dino_feature_dim', 64) + get_loader_ddp = lambda **kwargs: get_sequence_loader_ddp( + mode=data_loader_mode, + batch_size=batch_size, + num_workers=num_workers, + in_image_size=in_image_size, + out_image_size=out_image_size, + debug_seq=debug_seq, + skip_beginning=skip_beginning, + skip_end=skip_end, + num_sample_frames=num_sample_frames, + min_seq_len=min_seq_len, + max_seq_len=max_seq_len, + load_background=load_background, + rgb_suffix=rgb_suffix, + load_dino_feature=load_dino_feature, + load_dino_cluster=load_dino_cluster, + dino_feature_dim=dino_feature_dim, + flow_bool=0, + **kwargs) + + # just the train now + train_data_dir = self.original_categories_paths + if isinstance(train_data_dir, dict): + for data_path in train_data_dir.values(): + assert osp.isdir(data_path), f"Training data directory does not exist: {data_path}" + elif isinstance(train_data_dir, str): + assert osp.isdir(train_data_dir), f"Training data directory does not exist: {train_data_dir}" + else: + raise ValueError("train_data_dir must be a string or a dict of strings") + + print(f"Loading training data...") + # the train_data_dir is a dict and will go into the original dataset type + train_loader = get_loader_ddp(data_dir=train_data_dir, rank=self.rank, world_size=self.world_size, is_validation=False, use_few_shot=False, random_sample=random_sample_train_frames, shuffle=shuffle_train_seqs, dense_sample=True, color_jitter=color_jitter_train, random_flip=random_flip_train) + + return train_loader, val_loader + + def get_data_loaders_quadrupeds(self, cfgs, batch_size, num_workers, in_image_size, out_image_size): + train_loader = val_loader = test_loader = None + color_jitter_train = cfgs.get('color_jitter_train', None) + color_jitter_val = cfgs.get('color_jitter_val', None) + random_flip_train = cfgs.get('random_flip_train', False) + + data_loader_mode = cfgs.get('data_loader_mode', 'n_frame') + skip_beginning = cfgs.get('skip_beginning', 4) + skip_end = cfgs.get('skip_end', 4) + num_sample_frames = cfgs.get('num_sample_frames', 2) + min_seq_len = cfgs.get('min_seq_len', 10) + max_seq_len = cfgs.get('max_seq_len', 10) + debug_seq = cfgs.get('debug_seq', False) + random_sample_train_frames = cfgs.get('random_sample_train_frames', False) + shuffle_train_seqs = cfgs.get('shuffle_train_seqs', False) + random_sample_val_frames = cfgs.get('random_sample_val_frames', False) + load_background = cfgs.get('background_mode', 'none') == 'background' + rgb_suffix = cfgs.get('rgb_suffix', '.png') + load_dino_feature = cfgs.get('load_dino_feature', False) + load_dino_cluster = cfgs.get('load_dino_cluster', False) + dino_feature_dim = cfgs.get('dino_feature_dim', 64) + + enhance_back_view = cfgs.get('enhance_back_view', False) + enhance_back_view_path = cfgs.get('enhance_back_view_path', None) + + override_categories = cfgs.get('override_categories', None) + + disable_fewshot = cfgs.get('disable_fewshot', False) + dataset_split_num = cfgs.get('dataset_split_num', -1) + + get_loader_ddp = lambda **kwargs: get_sequence_loader_quadrupeds( + mode=data_loader_mode, + num_workers=num_workers, + in_image_size=in_image_size, + out_image_size=out_image_size, + debug_seq=debug_seq, + skip_beginning=skip_beginning, + skip_end=skip_end, + 
num_sample_frames=num_sample_frames, + min_seq_len=min_seq_len, + max_seq_len=max_seq_len, + load_background=load_background, + rgb_suffix=rgb_suffix, + load_dino_feature=load_dino_feature, + load_dino_cluster=load_dino_cluster, + dino_feature_dim=dino_feature_dim, + flow_bool=0, + enhance_back_view=enhance_back_view, + enhance_back_view_path=enhance_back_view_path, + override_categories=override_categories, + disable_fewshot=disable_fewshot, + dataset_split_num=dataset_split_num, + **kwargs) + + # just the train now + + print(f"Loading training data...") + val_image_num = cfgs.get('few_shot_val_image_num', 5) + # the train_data_dir is a dict and will go into the original dataset type + train_loader = get_loader_ddp(original_data_dirs=self.original_categories_paths, few_shot_data_dirs=self.few_shot_categories_paths, original_num=self.original_classes_num, few_shot_num=self.new_classes_num, rank=self.rank, world_size=self.world_size, batch_size=batch_size, is_validation=False, val_image_num=val_image_num, shuffle=shuffle_train_seqs, dense_sample=True, color_jitter=color_jitter_train, random_flip=random_flip_train) + val_loader = get_loader_ddp(original_data_dirs=self.original_val_data_path, few_shot_data_dirs=self.few_shot_categories_paths, original_num=self.original_classes_num, few_shot_num=self.new_classes_num, rank=self.rank, world_size=self.world_size, batch_size=1, is_validation=True, val_image_num=val_image_num, shuffle=False, dense_sample=True, color_jitter=color_jitter_val, random_flip=False) + + if self.test_categories_paths is not None: + get_test_loader_ddp = lambda **kwargs: get_test_loader_quadrupeds( + mode=data_loader_mode, + num_workers=num_workers, + in_image_size=in_image_size, + out_image_size=out_image_size, + debug_seq=debug_seq, + skip_beginning=skip_beginning, + skip_end=skip_end, + num_sample_frames=num_sample_frames, + min_seq_len=min_seq_len, + max_seq_len=max_seq_len, + load_background=load_background, + rgb_suffix=rgb_suffix, + load_dino_feature=load_dino_feature, + load_dino_cluster=load_dino_cluster, + dino_feature_dim=dino_feature_dim, + flow_bool=0, + enhance_back_view=enhance_back_view, + enhance_back_view_path=enhance_back_view_path, + **kwargs) + print(f"Loading testing data...") + test_loader = get_test_loader_ddp(test_data_dirs=self.test_categories_paths, rank=self.rank, world_size=self.world_size, batch_size=1, is_validation=True, shuffle=False, dense_sample=True, color_jitter=color_jitter_val, random_flip=False) + else: + test_loader = None + + return train_loader, val_loader, test_loader + + def forward_frozen_ViT(self, images): + # this part use the frozen pre-train ViT + x = images + with torch.no_grad(): + b, c, h, w = x.shape + self.model.netInstance.netEncoder._feats = [] + self.model.netInstance.netEncoder._register_hooks([11], 'key') + #self._register_hooks([11], 'token') + x = self.model.netInstance.netEncoder.ViT.prepare_tokens(x) + #x = self.ViT.prepare_tokens_with_masks(x) + + for blk in self.model.netInstance.netEncoder.ViT.blocks: + x = blk(x) + out = self.model.netInstance.netEncoder.ViT.norm(x) + self.model.netInstance.netEncoder._unregister_hooks() + + ph, pw = h // self.model.netInstance.netEncoder.patch_size, w // self.model.netInstance.netEncoder.patch_size + patch_out = out[:, 1:] # first is class token + patch_out = patch_out.reshape(b, ph, pw, self.model.netInstance.netEncoder.vit_feat_dim).permute(0, 3, 1, 2) + + patch_key = self.model.netInstance.netEncoder._feats[0][:,:,1:] # B, num_heads, num_patches, dim + patch_key = 
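+ # forward_frozen_ViT() runs the frozen pre-trained ViT encoder under no_grad, hooks the key
+ # projections of attention block 11 into patch_key, and returns the class-token embedding out[:, 0]
+ # as a global image feature; forward_fix_embeddings() then uses that feature (or a CLIP image
+ # embedding) as the query that retrieve_memory_bank() cosine-matches against memory_bank_keys,
+ # taking a weighted sum of the top-k memory values with L1-normalised similarity weights.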
patch_key.permute(0, 1, 3, 2).reshape(b, self.model.netInstance.netEncoder.vit_feat_dim, ph, pw) + + global_feat = out[:, 0] + + return global_feat + + def forward_fix_embeddings(self, batch): + images = batch[0] + images = images.to(self.device) + batch_size, num_frames, _, h0, w0 = images.shape + images = images.reshape(batch_size*num_frames, *images.shape[2:]) # 0~1 + + if self.memory_encoder == 'DINO': + images_in = images * 2 - 1 # rescale to (-1, 1) + batch_features = self.forward_frozen_ViT(images_in) + elif self.memory_encoder == 'CLIP': + images_in = torch.nn.functional.interpolate(images, (self.clip_reso, self.clip_reso), mode='bilinear') + images_in = tvf.normalize(images_in, self.clip_mean, self.clip_std) + batch_features = self.clip_model.encode_image(images_in).float() + else: + raise NotImplementedError + return batch_features + + def retrieve_memory_bank(self, batch_features, batch): + batch_size = batch_features.shape[0] + + if self.memory_retrieve == 'cos-linear': + query = torch.nn.functional.normalize(batch_features.unsqueeze(1), dim=-1) # [B, 1, d_k] + key = torch.nn.functional.normalize(self.memory_bank_keys, dim=-1) # [size, d_k] + key = key.transpose(1, 0).unsqueeze(0).repeat(batch_size, 1, 1).to(query.device) # [B, d_k, size] + + cos_dist = torch.bmm(query, key).squeeze(1) # [B, size], larger the more similar + rank_idx = torch.sort(cos_dist, dim=-1, descending=True)[1][:, :self.memory_bank_topk] # [B, k] + value = self.memory_bank.unsqueeze(0).repeat(batch_size, 1, 1).to(query.device) # [B, size, d_v] + + out = torch.gather(value, dim=1, index=rank_idx[..., None].repeat(1, 1, self.memory_bank_dim)) # [B, k, d_v] + + weights = torch.gather(cos_dist, dim=-1, index=rank_idx) # [B, k] + weights = torch.nn.functional.normalize(weights, p=1.0, dim=-1).unsqueeze(-1).repeat(1, 1, self.memory_bank_dim) # [B, k, d_v] weights have been normalized + + out = weights * out + out = torch.sum(out, dim=1) + + else: + raise NotImplementedError + + batch_mean_out = torch.mean(out, dim=0) + + weight_aux = { + 'weights': weights[:, :, 0], # [B, k], weights from large to small + 'pick_idx': rank_idx, # [B, k] + } + + return batch_mean_out, out, weight_aux + + def discriminator_texture_step(self): + image_iv = self.model.record_image_iv + image_rv = self.model.record_image_rv + image_gt = self.model.record_image_gt + + self.model.record_image_iv = None + self.model.record_image_rv = None + self.model.record_image_gt = None + + image_iv = image_iv.requires_grad_(True) + image_rv = image_rv.requires_grad_(True) + image_gt = image_gt.requires_grad_(True) + + self.optimizerDiscTex.zero_grad() + disc_loss_gt = 0.0 + disc_loss_iv = 0.0 + disc_loss_rv = 0.0 + grad_penalty = 0.0 + # for the gt image, it can only be in real or not + if 'gt' in self.model.few_shot_gan_tex_real: + d_gt = self.model.discriminator_texture(image_gt) + disc_loss_gt += discriminator_architecture.bce_loss_target(d_gt, 1) + if image_gt.requires_grad: + grad_penalty_gt = 10. * discriminator_architecture.compute_grad2(d_gt, image_gt) + disc_loss_gt += grad_penalty_gt + grad_penalty += grad_penalty_gt + + # for the input view image, it can be in real or fake + if 'iv' in self.model.few_shot_gan_tex_real: + d_iv = self.model.discriminator_texture(image_iv) + disc_loss_iv += discriminator_architecture.bce_loss_target(d_iv, 1) + if image_iv.requires_grad: + grad_penalty_iv = 10. 
* discriminator_architecture.compute_grad2(d_iv, image_iv) + disc_loss_iv += grad_penalty_iv + grad_penalty += grad_penalty_iv + elif 'iv' in self.model.few_shot_gan_tex_fake: + d_iv = self.model.discriminator_texture(image_iv) + disc_loss_iv += discriminator_architecture.bce_loss_target(d_iv, 0) + + # for the random view image, it can only be in fake + if 'rv' in self.model.few_shot_gan_tex_fake: + d_rv = self.model.discriminator_texture(image_rv) + disc_loss_rv += discriminator_architecture.bce_loss_target(d_rv, 0) + + all_loss = disc_loss_iv + disc_loss_rv + disc_loss_gt + + all_loss = all_loss * self.cfgs.get('gan_tex_loss_discriminator_weight', 0.1) + self.discriminator_texture_loss = all_loss + self.discriminator_texture_loss.backward() + self.optimizerDiscTex.step() + self.discriminator_texture_loss = 0. + + return { + 'discriminator_loss': all_loss.detach(), + 'discriminator_loss_iv': disc_loss_iv.detach(), + 'discriminator_loss_rv': disc_loss_rv.detach(), + 'discriminator_loss_gt': disc_loss_gt.detach(), + 'discriminator_loss_grad': grad_penalty.detach() + } + + def train(self): + """Perform training.""" + # archive code and configs + if self.archive_code: + misc.archive_code(osp.join(self.checkpoint_dir, 'archived_code.zip'), filetypes=['.py']) + misc.dump_yaml(osp.join(self.checkpoint_dir, 'configs.yml'), self.cfgs) + + # initialize + start_epoch = 0 + self.total_iter = 0 + self.total_iter = self.original_total_iter + self.metrics_trace.reset() + self.model.to(self.device) + + if self.model.enable_disc: + self.model.reset_only_disc_optimizer() + + if self.few_shot_resume: + resume_model_name = self.cfgs.get('few_shot_resume_name', None) + start_epoch, self.total_iter = self.load_checkpoint(optim=True, checkpoint_name=resume_model_name) + + self.model.ddp(self.rank, self.world_size) + + # use tensorboard + if self.use_logger: + from torch.utils.tensorboard import SummaryWriter + self.logger = SummaryWriter(osp.join(self.checkpoint_dir, 'logs', datetime.now().strftime("%Y%m%d-%H%M%S")), flush_secs=10) + # self.viz_data_iterator = indefinite_generator_from_list(self.val_loader) if self.visualize_validation else indefinite_generator_from_list(self.train_loader) + self.viz_data_iterator = indefinite_generator(self.val_loader[0]) if self.visualize_validation else indefinite_generator(self.train_loader[0]) + if self.fix_viz_batch: + self.viz_batch = next(self.viz_data_iterator) + + if self.test_loader is not None: + self.viz_test_data_iterator = indefinite_generator(self.test_loader[0]) if self.visualize_validation else indefinite_generator(self.train_loader[0]) + + # run_epochs + epoch = 0 + + for epoch in range(start_epoch, self.num_epochs): + metrics = self.run_epoch(epoch) + if self.combine_dataset: + self.train_loader[0].dataset._shuffle_all() + self.metrics_trace.append("train", metrics) + if (epoch+1) % self.save_checkpoint_freq == 0: + self.save_checkpoint(epoch+1, total_iter=self.total_iter, optim=True) + # if self.cfgs.get('pyplot_metrics', True): + # self.metrics_trace.plot(pdf_path=osp.join(self.checkpoint_dir, 'metrics.pdf')) + self.metrics_trace.save(osp.join(self.checkpoint_dir, 'metrics.json')) + print(f"Training completed for all {epoch+1} epochs.") + + def run_epoch(self, epoch): + """Run one training epoch.""" + metrics = self.make_metrics() + + self.model.set_train() + + max_loader_len = max([len(loader) for loader in self.train_loader]) + train_generators = [indefinite_generator(loader) for loader in self.train_loader] + + iteration = 0 + while iteration < 
max_loader_len * len(self.train_loader): + for generator in train_generators: + batch = next(generator) + + self.total_iter += 1 + num_seqs, num_frames = batch[0].shape[:2] + total_im_num = num_seqs * num_frames + + if self.enable_memory_bank: + batch_features = self.forward_fix_embeddings(batch) + batch_embedding, embeddings, weights = self.retrieve_memory_bank(batch_features, batch) + bank_embedding_model_input = [batch_embedding, embeddings, weights] + else: + # bank_embedding_model_input = None + batch_features = self.forward_fix_embeddings(batch) + weights = { + "weights": torch.rand(1,10).to(batch_features.device), + "pick_idx": torch.randint(low=0, high=60, size=(1, 10)).to(batch_features.device) + } + bank_embedding_model_input = [batch_features[0], batch_features, weights] + m = self.model.forward(batch, epoch=epoch, iter=iteration, total_iter=self.total_iter, which_data=self.dataset, is_training=True, bank_embedding=bank_embedding_model_input) + + # self.model.backward() + self.optimizerFewShot.zero_grad() + self.model.total_loss.backward() + self.optimizerFewShot.step() + self.model.total_loss = 0. + + # if self.cfgs.get('texture_way', None) is not None and self.cfgs.get('gan_tex', False): + if self.model.few_shot_gan_tex: + # the discriminator for local texture + disc_ret = self.discriminator_texture_step() + m.update(disc_ret) + + if self.model.enable_disc and (self.model.mask_discriminator_iter[0] < self.total_iter) and (self.model.mask_discriminator_iter[1] > self.total_iter): + # the discriminator training + discriminator_loss_dict, grad_loss = self.model.discriminator_step() + m.update( + { + 'mask_disc_loss_discriminator': discriminator_loss_dict['discriminator_loss'] - grad_loss, + 'mask_disc_loss_discriminator_grad': grad_loss, + 'mask_disc_loss_discriminator_rv': discriminator_loss_dict['discriminator_loss_rv'], + 'mask_disc_loss_discriminator_iv': discriminator_loss_dict['discriminator_loss_iv'], + 'mask_disc_loss_discriminator_gt': discriminator_loss_dict['discriminator_loss_gt'] + } + ) + self.logger.add_histogram('train_'+'discriminator_logits/random_view', discriminator_loss_dict['d_rv'], self.total_iter) + if discriminator_loss_dict['d_iv'] is not None: + self.logger.add_histogram('train_'+'discriminator_logits/input_view', discriminator_loss_dict['d_iv'], self.total_iter) + if discriminator_loss_dict['d_gt'] is not None: + self.logger.add_histogram('train_'+'discriminator_logits/gt_view', discriminator_loss_dict['d_gt'], self.total_iter) + + metrics.update(m, total_im_num) + if self.rank == 0: + print(f"T{epoch:04}/{iteration:05}/{metrics}") + + if self.iteration_save and self.total_iter % self.iteration_save_freq == 0: + self.save_checkpoint(epoch+1, total_iter=self.total_iter, optim=True, use_iter=True) + + # ## reset optimizers + # if self.cfgs.get('opt_reset_every_iter', 0) > 0 and self.total_iter < self.cfgs.get('opt_reset_end_iter', 0): + # if self.total_iter % self.cfgs.get('opt_reset_every_iter', 0) == 0: + # self.model.reset_optimizers() + + if misc.is_main_process() and self.use_logger: + if self.rank == 0 and self.total_iter % self.log_freq_losses == 0: + for name, loss in m.items(): + label = f'cub_loss_train/{name[4:]}' if 'cub' in name else f'loss_train/{name}' + self.logger.add_scalar(label, loss, self.total_iter) + if self.rank == 0 and self.save_result_freq is not None and self.total_iter % self.save_result_freq == 0: + with torch.no_grad(): + m = self.model.forward(batch, epoch=epoch, iter=iteration, total_iter=self.total_iter, 
save_results=False, save_dir=self.train_result_dir, which_data=self.dataset, is_training=False, bank_embedding=bank_embedding_model_input) + torch.cuda.empty_cache() + if self.total_iter % self.log_freq_images == 0: + with torch.no_grad(): + if self.rank == 0 and self.log_train_images: + m = self.model.forward(batch, epoch=epoch, iter=iteration, viz_logger=self.logger, total_iter=self.total_iter, which_data=self.dataset, logger_prefix='train_', is_training=False, bank_embedding=bank_embedding_model_input) + if self.fix_viz_batch: + print(f'fix_viz_batch:{self.fix_viz_batch}') + batch_val = self.viz_batch + else: + batch_val = next(self.viz_data_iterator) + if self.visualize_validation: + import time + vis_start = time.time() + # batch = next(self.viz_data_iterator) + # try: + # batch = next(self.viz_data_iterator) + # except: # iterator exhausted + # self.reset_viz_data_iterator() + # batch = next(self.viz_data_iterator) + if self.enable_memory_bank: + batch_features_val = self.forward_fix_embeddings(batch_val) + batch_embedding_val, embeddings_val, weights_val = self.retrieve_memory_bank(batch_features_val, batch_val) + bank_embedding_model_input_val = [batch_embedding_val, embeddings_val, weights_val] + else: + # bank_embedding_model_input_val = None + batch_features_val = self.forward_fix_embeddings(batch_val) + weights_val = { + "weights": torch.rand(1,10).to(batch_features_val.device), + "pick_idx": torch.randint(low=0, high=60, size=(1, 10)).to(batch_features_val.device) + } + bank_embedding_model_input_val = [batch_features_val[0], batch_features_val, weights_val] + + if self.total_iter % self.save_result_freq == 0: + m = self.model.forward(batch_val, epoch=epoch, iter=iteration, viz_logger=self.logger, total_iter=self.total_iter, save_results=False, save_dir=self.train_result_dir, which_data=self.dataset, logger_prefix='val_', is_training=False, bank_embedding=bank_embedding_model_input_val) + torch.cuda.empty_cache() + + vis_end = time.time() + print(f"vis time: {vis_end - vis_start}") + + if self.test_loader is not None: + # unseen category test visualization + batch_test = next(self.viz_test_data_iterator) + if self.enable_memory_bank: + batch_features_test = self.forward_fix_embeddings(batch_test) + batch_embedding_test, embeddings_test, weights_test = self.retrieve_memory_bank(batch_features_test, batch_test) + bank_embedding_model_input_test = [batch_embedding_test, embeddings_test, weights_test] + else: + # bank_embedding_model_input_test = None + batch_features_test = self.forward_fix_embeddings(batch_test) + weights_test = { + "weights": torch.rand(1,10).to(batch_features_test.device), + "pick_idx": torch.randint(low=0, high=60, size=(1, 10)).to(batch_features_test.device) + } + bank_embedding_model_input_test = [batch_features_test[0], batch_features_test, weights_test] + m_test = self.model.forward(batch_test, epoch=epoch, iter=iteration, viz_logger=self.logger, total_iter=self.total_iter, which_data=self.dataset, logger_prefix='test_', is_training=False, bank_embedding=bank_embedding_model_input_test) + vis_test_end = time.time() + print(f"vis test time: {vis_test_end - vis_end}") + for name, loss in m_test.items(): + if self.rank == 0: + self.logger.add_scalar(f'loss_test/{name}', loss, self.total_iter) + + for name, loss in m.items(): + if self.rank == 0: + self.logger.add_scalar(f'loss_val/{name}', loss, self.total_iter) + torch.cuda.empty_cache() + + iteration += 1 + + self.model.scheduler_step() + return metrics + \ No newline at end of file diff --git 
a/video3d/triplane_texture/custom_ops.py b/video3d/triplane_texture/custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6a48e5b4cff11af3e1d43d654a042067622b7e01 --- /dev/null +++ b/video3d/triplane_texture/custom_ops.py @@ -0,0 +1,162 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import glob +import hashlib +import importlib +import os +import re +import shutil +import uuid + +import torch +import torch.utils.cpp_extension +from torch.utils.file_baton import FileBaton + +# ---------------------------------------------------------------------------- +# Global options. + +verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full' + + +# ---------------------------------------------------------------------------- +# Internal helper funcs. + +def _find_compiler_bindir(): + patterns = [ + 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64', + 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64', + 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64', + 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin', + ] + for pattern in patterns: + matches = sorted(glob.glob(pattern)) + if len(matches): + return matches[-1] + return None + + +# ---------------------------------------------------------------------------- + +def _get_mangled_gpu_name(): + name = torch.cuda.get_device_name().lower() + out = [] + for c in name: + if re.match('[a-z0-9_-]+', c): + out.append(c) + else: + out.append('-') + return ''.join(out) + + +# ---------------------------------------------------------------------------- +# Main entry point for compiling and loading C++/CUDA plugins. + +_cached_plugins = dict() + + +def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs): + assert verbosity in ['none', 'brief', 'full'] + if headers is None: + headers = [] + if source_dir is not None: + sources = [os.path.join(source_dir, fname) for fname in sources] + headers = [os.path.join(source_dir, fname) for fname in headers] + + # Already cached? + if module_name in _cached_plugins: + return _cached_plugins[module_name] + + # Print status. + if verbosity == 'full': + print(f'Setting up PyTorch plugin "{module_name}"...') + elif verbosity == 'brief': + print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True) + verbose_build = (verbosity == 'full') + + # Compile and load. + try: # pylint: disable=too-many-nested-blocks + # Make sure we can find the necessary compiler binaries. + if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0: + compiler_bindir = _find_compiler_bindir() + if compiler_bindir is None: + raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".') + os.environ['PATH'] += ';' + compiler_bindir + + # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either + # break the build or unnecessarily restrict what's available to nvcc. 
+ # Unset it to let nvcc decide based on what's available on the + # machine. + os.environ['TORCH_CUDA_ARCH_LIST'] = '' + + # Incremental build md5sum trickery. Copies all the input source files + # into a cached build directory under a combined md5 digest of the input + # source files. Copying is done only if the combined digest has changed. + # This keeps input file timestamps and filenames the same as in previous + # extension builds, allowing for fast incremental rebuilds. + # + # This optimization is done only in case all the source files reside in + # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR + # environment variable is set (we take this as a signal that the user + # actually cares about this.) + # + # EDIT: We now do it regardless of TORCH_EXTENSIOS_DIR, in order to work + # around the *.cu dependency bug in ninja config. + # + all_source_files = sorted(sources + headers) + all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files) + if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ): + + # Compute combined hash digest for all source files. + hash_md5 = hashlib.md5() + for src in all_source_files: + with open(src, 'rb') as f: + hash_md5.update(f.read()) + + # Select cached build directory name. + source_digest = hash_md5.hexdigest() + build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access + cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}') + + if not os.path.isdir(cached_build_dir): + tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}' + os.makedirs(tmpdir) + for src in all_source_files: + shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src))) + try: + os.replace(tmpdir, cached_build_dir) # atomic + except OSError: + # source directory already exists, delete tmpdir and its contents. + shutil.rmtree(tmpdir) + if not os.path.isdir(cached_build_dir): raise + + # Compile. + cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources] + torch.utils.cpp_extension.load( + name=module_name, build_directory=cached_build_dir, + verbose=verbose_build, sources=cached_sources, **build_kwargs) + else: + torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) + + # Load. + module = importlib.import_module(module_name) + + except: + if verbosity == 'brief': + print('Failed!') + raise + + # Print status and add to cache dict. 
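+    # Note (added comment): once cached below, later get_plugin() calls with the same
+    # module_name return this module via the early-out near the top of this function,
+    # so the C++/CUDA sources are compiled at most once per process.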
+ if verbosity == 'full': + print(f'Done setting up PyTorch plugin "{module_name}".') + elif verbosity == 'brief': + print('Done.') + _cached_plugins[module_name] = module + return module + +# ---------------------------------------------------------------------------- \ No newline at end of file diff --git a/video3d/triplane_texture/lift_architecture.py b/video3d/triplane_texture/lift_architecture.py new file mode 100644 index 0000000000000000000000000000000000000000..f9d270f5964fc62556868402190754dcbb4146a9 --- /dev/null +++ b/video3d/triplane_texture/lift_architecture.py @@ -0,0 +1,171 @@ +import numpy as np +import torch +import torch.nn as nn +import torchvision +import torchvision.models as models +from typing import Union, List, Tuple +import os +import video3d.utils.misc as misc +import torch.nn.functional as F + + +class Lift_Encoder(nn.Module): + def __init__( + self, + cin, + feat_dim, + grid_scale=7., + grid_size=32, + optim_latent=False, + img_size=256, + with_z_feature=False, + cam_pos_z_offset=10. + ): + super().__init__() + + ''' + unproject the input feature map to tri-plane, each plane is (-1, -1)*grid_scale to (1, 1)*scale + ''' + self.cin = cin + self.nf = feat_dim + self.grid_scale = grid_scale + self.grid_size = grid_size + self.img_size = img_size + self.with_z_feature = with_z_feature + self.cam_pos_z_offset = cam_pos_z_offset + + self.feature_projector = nn.Linear(cin, feat_dim, bias=False) + + self.plane_latent = None + if optim_latent: + self.optim_latent = nn.Parameter(torch.rand(3, feat_dim, grid_size, grid_size)) + else: + self.optim_latent = None + + if with_z_feature: + self.conv_bottleneck = nn.Conv2d(feat_dim+1, feat_dim, kernel_size=3, stride=1, padding=1, dilation=1, padding_mode="replicate") + else: + self.conv_bottleneck = nn.Conv2d(feat_dim, feat_dim, kernel_size=3, stride=1, padding=1, dilation=1, padding_mode="replicate") + + #TODO: implement an upsampler for input feature map here? 
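+        # Note (added comment): the two 3x3 convs + PixelShuffle(2) stages below already act
+        # as a lightweight sub-pixel upsampler: each conv expands channels to 4*feat_dim and
+        # PixelShuffle folds them back to feat_dim at twice the resolution, giving a 4x
+        # spatial upsampling of the input feature map overall.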
+ self.conv_1 = nn.Conv2d(feat_dim, 4*feat_dim, kernel_size=3, stride=1, padding=1, dilation=1, padding_mode="replicate") + self.conv_2 = nn.Conv2d(feat_dim, 4*feat_dim, kernel_size=3, stride=1, padding=1, dilation=1, padding_mode="replicate") + self.up = nn.PixelShuffle(2) + + self.conv_enc = nn.Conv2d(feat_dim, feat_dim // 2, kernel_size=3, stride=1, padding=1, dilation=1, padding_mode="replicate") + self.conv_dec = nn.Conv2d(feat_dim // 2, feat_dim, kernel_size=3, stride=1, padding=1, dilation=1, padding_mode="replicate") + + self.feature_fusion = nn.Linear(3*feat_dim, feat_dim, bias=False) + + def get_coords(self, grid_size): + with torch.no_grad(): + lines = torch.arange(0, grid_size) + grids_x, grids_y = torch.meshgrid([lines, lines], indexing="ij") + grids = torch.stack([grids_x, grids_y], dim=-1) + grids = (grids - self.grid_size // 2) / (self.grid_size // 2) + grids = grids * self.grid_scale + + plane_z0 = torch.cat([grids, torch.zeros(list(grids.shape[:-1]) + [1])], dim=-1) # [S, S, 3] + plane_y0 = plane_z0.clone()[..., [0, 2, 1]] + plane_x0 = plane_z0.clone()[..., [2, 0, 1]] + + planes = torch.stack([plane_x0, plane_y0, plane_z0], dim=0) + return planes # [3, S, S, 3] + + def get_uv_z(self, pts, mvp): + cam4 = torch.matmul(torch.nn.functional.pad(pts, pad=(0,1), mode='constant', value=1.0), torch.transpose(mvp, 1, 2)) + cam3 = cam4[..., :3] / cam4[..., 3:4] + cam_uv = cam3[..., :2] + # cam_uv = cam_uv.detach() + cam_depth = cam3 + torch.FloatTensor([0, 0, self.cam_pos_z_offset]).to(pts.device).view(1, 1, 3) + cam_depth = cam_depth / self.grid_scale * 2 + cam_depth = cam_depth[..., 2:3] + + return cam_uv, cam_depth + + def unproject(self, feature_map, mvp): + ''' + feature_map: [B, C, h, w] + mvp: [B, 4, 4] + ''' + self.plane_latent = None + bs, C, h, w = feature_map.shape + device = feature_map.device + feature_map = self.feature_projector(feature_map.permute(0, 2, 3, 1).reshape(-1, C)).reshape(bs, h, w, self.nf).permute(0, 3, 1, 2) + + feature_map = self.up(self.conv_1(feature_map)) + feature_map = self.up(self.conv_2(feature_map)) + + plane_coords = self.get_coords(self.grid_size) + plane_coords = plane_coords.unsqueeze(0).repeat(bs, 1, 1, 1, 1) + plane_coords = plane_coords.to(device) + + plane_pts = plane_coords.reshape(bs, -1, 3) # [B, N_POINTS, 3] + plane_uv, plane_z = self.get_uv_z(plane_pts, mvp) + plane_uv = plane_uv.detach() + plane_z = plane_z.detach() + + nP = plane_pts.shape[1] + + plane_feature = F.grid_sample(feature_map, plane_uv.reshape(bs, 1, nP, 2), mode="bilinear", padding_mode="zeros").squeeze(dim=-2).permute(0, 2, 1) + if self.with_z_feature: + plane_feature = torch.cat([plane_feature, plane_z], dim=-1) + + plane_feature = plane_feature.reshape(plane_feature.shape[0], 3, self.grid_size, self.grid_size, plane_feature.shape[-1]) + + return plane_feature + + def conv_plane(self, plane_feature): + bs, _, nh, nw, nC = plane_feature.shape + plane_feature = plane_feature.reshape(-1, nh, nw, nC).permute(0, 3, 1, 2) # [bs*3, nC, nh, nw] + + plane_feature = self.conv_bottleneck(plane_feature) + x = self.conv_dec(self.conv_enc(plane_feature)) + out = x + plane_feature + out = out.reshape(bs, 3, out.shape[-3], out.shape[-2], out.shape[-1]) + + if self.optim_latent is not None: + optim_latent = self.optim_latent.unsqueeze(0).repeat(bs, 1, 1, 1, 1) + out = out + optim_latent + + return out + + def sample_plane(self, pts, feat): + ''' + pts: [B, K, 3] + feat: [B, 3, C, h, w] + ''' + pts_x, pts_y, pts_z = pts.unbind(dim=-1) + + pts_x0 = torch.stack([pts_y, pts_z], dim=-1) 
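+        # pts_x0 / pts_y0 / pts_z0 are the (y,z), (x,z) and (x,y) coordinates used to
+        # index the three axis-aligned feature planes below.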
+ pts_y0 = torch.stack([pts_x, pts_z], dim=-1) + pts_z0 = torch.stack([pts_x, pts_y], dim=-1) + + feat_x0 = F.grid_sample(feat[:, 0, :, :], pts_x0.unsqueeze(1), mode="bilinear", padding_mode="border").squeeze(-2).permute(0, 2, 1) + feat_y0 = F.grid_sample(feat[:, 0, :, :], pts_y0.unsqueeze(1), mode="bilinear", padding_mode="border").squeeze(-2).permute(0, 2, 1) + feat_z0 = F.grid_sample(feat[:, 0, :, :], pts_z0.unsqueeze(1), mode="bilinear", padding_mode="border").squeeze(-2).permute(0, 2, 1) + + pts_feat = torch.cat([feat_x0, feat_y0, feat_z0], dim=-1) + return pts_feat + + def forward(self, feature_map, mvp, pts, inference="unproject"): + ''' + inference = "unproject" or "sample" + ''' + assert inference in ["unproject", "sample"] + if inference == "unproject": + plane_feature = self.unproject(feature_map, mvp) + plane_feature = self.conv_plane(plane_feature) + + self.plane_latent = plane_feature.clone().detach() # this is just for test case + + if inference == "unproject": + feat_to_sample = plane_feature + else: + new_bs = pts.shape[0] + feat_to_sample = self.plane_latent[:new_bs] + + pts_feature = self.sample_plane(pts, feat_to_sample) + pts_feature = self.feature_fusion(pts_feature) # [B, K, C] + + return pts_feature \ No newline at end of file diff --git a/video3d/triplane_texture/misc.py b/video3d/triplane_texture/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..e39477647067b115a97fc729d3b772649c943470 --- /dev/null +++ b/video3d/triplane_texture/misc.py @@ -0,0 +1,330 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import re +import contextlib +import numpy as np +import torch +import warnings +from typing import Any, List, Tuple, Union + + +class EasyDict(dict): + """Convenience class that behaves like a dict but allows access with the attribute syntax.""" + + def __getattr__(self, name: str) -> Any: + try: + return self[name] + except KeyError: + raise AttributeError(name) + + def __setattr__(self, name: str, value: Any) -> None: + self[name] = value + + def __delattr__(self, name: str) -> None: + del self[name] + + +# ---------------------------------------------------------------------------- +# Cached construction of constant tensors. Avoids CPU=>GPU copy when the +# same constant is used multiple times. 
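+# Usage sketch (illustrative values, not part of the original file): repeated calls such as
+#     offset = constant([0.0, 0.0, 1.0], device=some_tensor.device)
+# with identical value/shape/dtype/device return the same cached tensor, so the
+# host-to-device copy is paid only once.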
+ +_constant_cache = dict() + + +def constant(value, shape=None, dtype=None, device=None, memory_format=None): + value = np.asarray(value) + if shape is not None: + shape = tuple(shape) + if dtype is None: + dtype = torch.get_default_dtype() + if device is None: + device = torch.device('cpu') + if memory_format is None: + memory_format = torch.contiguous_format + + key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format) + tensor = _constant_cache.get(key, None) + if tensor is None: + tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) + if shape is not None: + tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) + tensor = tensor.contiguous(memory_format=memory_format) + _constant_cache[key] = tensor + return tensor + + +# ---------------------------------------------------------------------------- +# Replace NaN/Inf with specified numerical values. + +try: + nan_to_num = torch.nan_to_num # 1.8.0a0 +except AttributeError: + def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin + assert isinstance(input, torch.Tensor) + if posinf is None: + posinf = torch.finfo(input.dtype).max + if neginf is None: + neginf = torch.finfo(input.dtype).min + assert nan == 0 + return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) + +# ---------------------------------------------------------------------------- +# Symbolic assert. + +try: + symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access +except AttributeError: + symbolic_assert = torch.Assert # 1.7.0 + + +# ---------------------------------------------------------------------------- +# Context manager to temporarily suppress known warnings in torch.jit.trace(). +# Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672 + +@contextlib.contextmanager +def suppress_tracer_warnings(): + flt = ('ignore', None, torch.jit.TracerWarning, None, 0) + warnings.filters.insert(0, flt) + yield + warnings.filters.remove(flt) + + +# ---------------------------------------------------------------------------- +# Assert that the shape of a tensor matches the given list of integers. +# None indicates that the size of a dimension is allowed to vary. +# Performs symbolic assertion when used in torch.jit.trace(). + +def assert_shape(tensor, ref_shape): + if tensor.ndim != len(ref_shape): + raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}') + for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): + if ref_size is None: + pass + elif isinstance(ref_size, torch.Tensor): + with suppress_tracer_warnings(): # as_tensor results are registered as constants + symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}') + elif isinstance(size, torch.Tensor): + with suppress_tracer_warnings(): # as_tensor results are registered as constants + symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}') + elif size != ref_size: + raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}') + + +# ---------------------------------------------------------------------------- +# Function decorator that calls torch.autograd.profiler.record_function(). 
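+# Usage sketch (hypothetical function name, for illustration only):
+#
+#     @profiled_function
+#     def sample_rays(batch):
+#         ...
+#
+# Calls then appear under the label "sample_rays" in torch.autograd.profiler /
+# torch.profiler traces.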
+ +def profiled_function(fn): + def decorator(*args, **kwargs): + with torch.autograd.profiler.record_function(fn.__name__): + return fn(*args, **kwargs) + + decorator.__name__ = fn.__name__ + return decorator + + +# ---------------------------------------------------------------------------- +# Sampler for torch.utils.data.DataLoader that loops over the dataset +# indefinitely, shuffling items as it goes. + +class InfiniteSampler(torch.utils.data.Sampler): + def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): + assert len(dataset) > 0 + assert num_replicas > 0 + assert 0 <= rank < num_replicas + assert 0 <= window_size <= 1 + super().__init__(dataset) + self.dataset = dataset + self.rank = rank + self.num_replicas = num_replicas + self.shuffle = shuffle + self.seed = seed + self.window_size = window_size + + def __iter__(self): + order = np.arange(len(self.dataset)) + rnd = None + window = 0 + if self.shuffle: + rnd = np.random.RandomState(self.seed) + rnd.shuffle(order) + window = int(np.rint(order.size * self.window_size)) + + idx = 0 + while True: + i = idx % order.size + if idx % self.num_replicas == self.rank: + yield order[i] + if window >= 2: + j = (i - rnd.randint(window)) % order.size + order[i], order[j] = order[j], order[i] + idx += 1 + + +# ---------------------------------------------------------------------------- +# Utilities for operating with torch.nn.Module parameters and buffers. + +def params_and_buffers(module): + assert isinstance(module, torch.nn.Module) + return list(module.parameters()) + list(module.buffers()) + + +def named_params_and_buffers(module): + assert isinstance(module, torch.nn.Module) + return list(module.named_parameters()) + list(module.named_buffers()) + + +def copy_params_and_buffers(src_module, dst_module, require_all=False): + assert isinstance(src_module, torch.nn.Module) + assert isinstance(dst_module, torch.nn.Module) + src_tensors = dict(named_params_and_buffers(src_module)) + for name, tensor in named_params_and_buffers(dst_module): + # if name not in src_tensors: + # print(name) + # continue + ################### + assert (not require_all) or (name in src_tensors) + assert (not require_all) or src_tensors[name].shape == tensor.shape + if name in src_tensors: + try: + tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad) + except: + print('==> Error in loading checkpoint') + print(name) + + +# ---------------------------------------------------------------------------- +# Context manager for easily enabling/disabling DistributedDataParallel +# synchronization. + +@contextlib.contextmanager +def ddp_sync(module, sync): + assert isinstance(module, torch.nn.Module) + # yield # We follow this + # yield # We always do sync for the processing + if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): + # print('sync') + yield + else: + # print('no sync') + with module.no_sync(): + yield + + +# ---------------------------------------------------------------------------- +# Check DistributedDataParallel consistency across processes. + +def check_ddp_consistency(module, ignore_regex=None): + assert isinstance(module, torch.nn.Module) + find_not_equal = False + for name, tensor in named_params_and_buffers(module): + fullname = type(module).__name__ + '.' 
+ name + if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): + continue + tensor = tensor.detach() + if tensor.is_floating_point(): + tensor = nan_to_num(tensor) + other = tensor.clone() + torch.distributed.broadcast(tensor=other, src=0) + try: + assert (tensor == other).all(), fullname + except: + print(fullname) + if not find_not_equal: + print(tensor.shape) + print(tensor[tensor != other]) + print(other[tensor != other]) + print((tensor != other).sum()) + find_not_equal = True + + if find_not_equal: + exit() ############ + # print(tensor.shape) + # print(tensor[tensor != other]) + # print(other[tensor != other]) + # print((tensor != other).sum()) + # exit() + + +# ---------------------------------------------------------------------------- +# Print summary table of module hierarchy. + +def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): + assert isinstance(module, torch.nn.Module) + assert not isinstance(module, torch.jit.ScriptModule) + assert isinstance(inputs, (tuple, list)) + + # Register hooks. + entries = [] + nesting = [0] + + def pre_hook(_mod, _inputs): + nesting[0] += 1 + + def post_hook(mod, _inputs, outputs): + nesting[0] -= 1 + if nesting[0] <= max_nesting: + outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] + outputs = [t for t in outputs if isinstance(t, torch.Tensor)] + entries.append(EasyDict(mod=mod, outputs=outputs)) + + hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] + hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] + + # Run module. + outputs = module(*inputs) + for hook in hooks: + hook.remove() + + # Identify unique outputs, parameters, and buffers. + tensors_seen = set() + for e in entries: + e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] + e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] + e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] + tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs} + + # Filter out redundant entries. + if skip_redundant: + entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)] + + # Construct table. + rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']] + rows += [['---'] * len(rows[0])] + param_total = 0 + buffer_total = 0 + submodule_names = {mod: name for name, mod in module.named_modules()} + for e in entries: + name = '' if e.mod is module else submodule_names[e.mod] + param_size = sum(t.numel() for t in e.unique_params) + buffer_size = sum(t.numel() for t in e.unique_buffers) + output_shapes = [str(list(t.shape)) for t in e.outputs] + output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs] + rows += [[ + name + (':0' if len(e.outputs) >= 2 else ''), + str(param_size) if param_size else '-', + str(buffer_size) if buffer_size else '-', + (output_shapes + ['-'])[0], + (output_dtypes + ['-'])[0], + ]] + for idx in range(1, len(e.outputs)): + rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]] + param_total += param_size + buffer_total += buffer_size + rows += [['---'] * len(rows[0])] + rows += [['Total', str(param_total), str(buffer_total), '-', '-']] + + # Print table. 
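+    # Each column is padded to the width of its longest cell so the summary
+    # prints as an aligned plain-text table.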
+ widths = [max(len(cell) for cell in column) for column in zip(*rows)] + print() + for row in rows: + print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths))) + print() + return outputs + +# ---------------------------------------------------------------------------- \ No newline at end of file diff --git a/video3d/triplane_texture/ops/bias_act.cpp b/video3d/triplane_texture/ops/bias_act.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fe3ebcdfd9ae33c4d66546d8769ee82a2838a819 --- /dev/null +++ b/video3d/triplane_texture/ops/bias_act.cpp @@ -0,0 +1,100 @@ +// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + + +#include +#include +#include +#include "bias_act.h" + +//------------------------------------------------------------------------ + +static bool has_same_layout(torch::Tensor x, torch::Tensor y) +{ + if (x.dim() != y.dim()) + return false; + for (int64_t i = 0; i < x.dim(); i++) + { + if (x.size(i) != y.size(i)) + return false; + if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) + return false; + } + return true; +} + +//------------------------------------------------------------------------ + +static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) +{ + // Validate arguments. + TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); + TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); + TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); + TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); + TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x"); + TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); + TORCH_CHECK(b.dim() == 1, "b must have rank 1"); + TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); + TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); + TORCH_CHECK(grad >= 0, "grad must be non-negative"); + + // Validate layout. + TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); + TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); + TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); + TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); + TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); + + // Create output tensor. 
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + torch::Tensor y = torch::empty_like(x); + TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); + + // Initialize CUDA kernel parameters. + bias_act_kernel_params p; + p.x = x.data_ptr(); + p.b = (b.numel()) ? b.data_ptr() : NULL; + p.xref = (xref.numel()) ? xref.data_ptr() : NULL; + p.yref = (yref.numel()) ? yref.data_ptr() : NULL; + p.dy = (dy.numel()) ? dy.data_ptr() : NULL; + p.y = y.data_ptr(); + p.grad = grad; + p.act = act; + p.alpha = alpha; + p.gain = gain; + p.clamp = clamp; + p.sizeX = (int)x.numel(); + p.sizeB = (int)b.numel(); + p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; + + // Choose CUDA kernel. + void* kernel; + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] + { + kernel = choose_bias_act_kernel(p); + }); + TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); + + // Launch CUDA kernel. + p.loopX = 4; + int blockSize = 4 * 32; + int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; + void* args[] = {&p}; + AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); + return y; +} + +//------------------------------------------------------------------------ + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("bias_act", &bias_act); +} + +//------------------------------------------------------------------------ \ No newline at end of file diff --git a/video3d/triplane_texture/ops/bias_act.cu b/video3d/triplane_texture/ops/bias_act.cu new file mode 100644 index 0000000000000000000000000000000000000000..3eb402cdbf1bb318f774eee541e976573efb8252 --- /dev/null +++ b/video3d/triplane_texture/ops/bias_act.cu @@ -0,0 +1,174 @@ +// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + + +#include +#include "bias_act.h" + +//------------------------------------------------------------------------ +// Helpers. + +template struct InternalType; +template <> struct InternalType { typedef double scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; + +//------------------------------------------------------------------------ +// CUDA kernel. + +template +__global__ void bias_act_kernel(bias_act_kernel_params p) +{ + typedef typename InternalType::scalar_t scalar_t; + int G = p.grad; + scalar_t alpha = (scalar_t)p.alpha; + scalar_t gain = (scalar_t)p.gain; + scalar_t clamp = (scalar_t)p.clamp; + scalar_t one = (scalar_t)1; + scalar_t two = (scalar_t)2; + scalar_t expRange = (scalar_t)80; + scalar_t halfExpRange = (scalar_t)40; + scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946; + scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717; + + // Loop over elements. + int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; + for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) + { + // Load. + scalar_t x = (scalar_t)((const T*)p.x)[xi]; + scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0; + scalar_t xref = (p.xref) ? 
(scalar_t)((const T*)p.xref)[xi] : 0; + scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0; + scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one; + scalar_t yy = (gain != 0) ? yref / gain : 0; + scalar_t y = 0; + + // Apply bias. + ((G == 0) ? x : xref) += b; + + // linear + if (A == 1) + { + if (G == 0) y = x; + if (G == 1) y = x; + } + + // relu + if (A == 2) + { + if (G == 0) y = (x > 0) ? x : 0; + if (G == 1) y = (yy > 0) ? x : 0; + } + + // lrelu + if (A == 3) + { + if (G == 0) y = (x > 0) ? x : x * alpha; + if (G == 1) y = (yy > 0) ? x : x * alpha; + } + + // tanh + if (A == 4) + { + if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); } + if (G == 1) y = x * (one - yy * yy); + if (G == 2) y = x * (one - yy * yy) * (-two * yy); + } + + // sigmoid + if (A == 5) + { + if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one); + if (G == 1) y = x * yy * (one - yy); + if (G == 2) y = x * yy * (one - yy) * (one - two * yy); + } + + // elu + if (A == 6) + { + if (G == 0) y = (x >= 0) ? x : exp(x) - one; + if (G == 1) y = (yy >= 0) ? x : x * (yy + one); + if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one); + } + + // selu + if (A == 7) + { + if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one); + if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha); + if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha); + } + + // softplus + if (A == 8) + { + if (G == 0) y = (x > expRange) ? x : log(exp(x) + one); + if (G == 1) y = x * (one - exp(-yy)); + if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); } + } + + // swish + if (A == 9) + { + if (G == 0) + y = (x < -expRange) ? 0 : x / (exp(-x) + one); + else + { + scalar_t c = exp(xref); + scalar_t d = c + one; + if (G == 1) + y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d); + else + y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d); + yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain; + } + } + + // Apply gain. + y *= gain * dy; + + // Clamp. + if (clamp >= 0) + { + if (G == 0) + y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp; + else + y = (yref > -clamp & yref < clamp) ? y : 0; + } + + // Store. + ((T*)p.y)[xi] = (T)y; + } +} + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template void* choose_bias_act_kernel(const bias_act_kernel_params& p) +{ + if (p.act == 1) return (void*)bias_act_kernel; + if (p.act == 2) return (void*)bias_act_kernel; + if (p.act == 3) return (void*)bias_act_kernel; + if (p.act == 4) return (void*)bias_act_kernel; + if (p.act == 5) return (void*)bias_act_kernel; + if (p.act == 6) return (void*)bias_act_kernel; + if (p.act == 7) return (void*)bias_act_kernel; + if (p.act == 8) return (void*)bias_act_kernel; + if (p.act == 9) return (void*)bias_act_kernel; + return NULL; +} + +//------------------------------------------------------------------------ +// Template specializations. 
+ +template void* choose_bias_act_kernel (const bias_act_kernel_params& p); +template void* choose_bias_act_kernel (const bias_act_kernel_params& p); +template void* choose_bias_act_kernel (const bias_act_kernel_params& p); + +//------------------------------------------------------------------------ \ No newline at end of file diff --git a/video3d/triplane_texture/ops/bias_act.h b/video3d/triplane_texture/ops/bias_act.h new file mode 100644 index 0000000000000000000000000000000000000000..b5d0b2c6354c7a83b1d04f7111a3061f8183b1b8 --- /dev/null +++ b/video3d/triplane_texture/ops/bias_act.h @@ -0,0 +1,39 @@ +// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + + +//------------------------------------------------------------------------ +// CUDA kernel parameters. + +struct bias_act_kernel_params +{ + const void* x; // [sizeX] + const void* b; // [sizeB] or NULL + const void* xref; // [sizeX] or NULL + const void* yref; // [sizeX] or NULL + const void* dy; // [sizeX] or NULL + void* y; // [sizeX] + + int grad; + int act; + float alpha; + float gain; + float clamp; + + int sizeX; + int sizeB; + int stepB; + int loopX; +}; + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template void* choose_bias_act_kernel(const bias_act_kernel_params& p); + +//------------------------------------------------------------------------ \ No newline at end of file diff --git a/video3d/triplane_texture/ops/bias_act.py b/video3d/triplane_texture/ops/bias_act.py new file mode 100644 index 0000000000000000000000000000000000000000..c02df63e6d98ccbe8562fe9c421196131fe26d6b --- /dev/null +++ b/video3d/triplane_texture/ops/bias_act.py @@ -0,0 +1,233 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +"""Custom PyTorch ops for efficient bias and activation.""" + +import os +import numpy as np +import torch + +from .. import custom_ops +from .. 
import misc + +from typing import Any, List, Tuple, Union + + +class EasyDict(dict): + """Convenience class that behaves like a dict but allows access with the attribute syntax.""" + + def __getattr__(self, name: str) -> Any: + try: + return self[name] + except KeyError: + raise AttributeError(name) + + def __setattr__(self, name: str, value: Any) -> None: + self[name] = value + + def __delattr__(self, name: str) -> None: + del self[name] + +# ---------------------------------------------------------------------------- + +activation_funcs = { + 'linear': EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False), + 'relu': EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False), + 'lrelu': EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False), + 'tanh': EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True), + 'sigmoid': EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True), + 'elu': EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True), + 'selu': EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True), + 'softplus': EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True), + 'swish': EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True), +} + +# ---------------------------------------------------------------------------- + +_plugin = None +_null_tensor = torch.empty([0]) + + +def _init(): + global _plugin + if _plugin is None: + _plugin = custom_ops.get_plugin( + module_name='bias_act_plugin', + sources=['bias_act.cpp', 'bias_act.cu'], + headers=['bias_act.h'], + source_dir=os.path.dirname(__file__), + extra_cuda_cflags=['--use_fast_math'], + ) + return True + + +# ---------------------------------------------------------------------------- + +def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'): + r"""Fused bias and activation function. + + Adds bias `b` to activation tensor `x`, evaluates activation function `act`, + and scales the result by `gain`. Each of the steps is optional. In most cases, + the fused op is considerably more efficient than performing the same calculation + using standard PyTorch ops. It supports first and second order gradients, + but not third order gradients. + + Args: + x: Input activation tensor. Can be of any shape. + b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type + as `x`. The shape must be known, and it must match the dimension of `x` + corresponding to `dim`. + dim: The dimension in `x` corresponding to the elements of `b`. + The value of `dim` is ignored if `b` is not specified. + act: Name of the activation function to evaluate, or `"linear"` to disable. + Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc. + See `activation_funcs` for a full list. `None` is not allowed. + alpha: Shape parameter for the activation function, or `None` to use the default. + gain: Scaling factor for the output tensor, or `None` to use default. 
+ See `activation_funcs` for the default scaling of each activation function. + If unsure, consider specifying 1. + clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable + the clamping (default). + impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). + + Returns: + Tensor of the same shape and datatype as `x`. + """ + assert isinstance(x, torch.Tensor) + assert impl in ['ref', 'cuda'] + if impl == 'cuda' and x.device.type == 'cuda' and _init(): + return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b) + return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp) + + +# ---------------------------------------------------------------------------- + +@misc.profiled_function +def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None): + """Slow reference implementation of `bias_act()` using standard TensorFlow ops. + """ + assert isinstance(x, torch.Tensor) + assert clamp is None or clamp >= 0 + spec = activation_funcs[act] + alpha = float(alpha if alpha is not None else spec.def_alpha) + gain = float(gain if gain is not None else spec.def_gain) + clamp = float(clamp if clamp is not None else -1) + + # Add bias. + if b is not None: + assert isinstance(b, torch.Tensor) and b.ndim == 1 + assert 0 <= dim < x.ndim + assert b.shape[0] == x.shape[dim] + x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)]) + + # Evaluate activation function. + alpha = float(alpha) + x = spec.func(x, alpha=alpha) + + # Scale by gain. + gain = float(gain) + if gain != 1: + x = x * gain + + # Clamp. + if clamp >= 0: + x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type + return x + + +# ---------------------------------------------------------------------------- + +_bias_act_cuda_cache = dict() + + +def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None): + """Fast CUDA implementation of `bias_act()` using custom ops. + """ + # Parse arguments. + assert clamp is None or clamp >= 0 + spec = activation_funcs[act] + alpha = float(alpha if alpha is not None else spec.def_alpha) + gain = float(gain if gain is not None else spec.def_gain) + clamp = float(clamp if clamp is not None else -1) + + # Lookup from cache. + key = (dim, act, alpha, gain, clamp) + if key in _bias_act_cuda_cache: + return _bias_act_cuda_cache[key] + + # Forward op. 
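+    # Note (added comment): the autograd Function below is built once per
+    # (dim, act, alpha, gain, clamp) key, stored in _bias_act_cuda_cache, and then
+    # reused by bias_act() via BiasActCuda.apply(x, b).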
+ class BiasActCuda(torch.autograd.Function): + @staticmethod + def forward(ctx, x, b): # pylint: disable=arguments-differ ### + ctx.memory_format = torch.contiguous_format + # ctx.memory_format = torch.channels_last if x.ndim == 4 and x.stride(1) == 1 else torch.contiguous_format + x = x.contiguous(memory_format=ctx.memory_format) + b = b.contiguous() if b is not None else _null_tensor + y = x + if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor: + y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp) + ctx.save_for_backward( + x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, + b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, + y if 'y' in spec.ref else _null_tensor) + return y + + @staticmethod + def backward(ctx, dy): # pylint: disable=arguments-differ + dy = dy.contiguous(memory_format=ctx.memory_format) + x, b, y = ctx.saved_tensors + dx = None + db = None + + if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: + dx = dy + if act != 'linear' or gain != 1 or clamp >= 0: + dx = BiasActCudaGrad.apply(dy, x, b, y) + + if ctx.needs_input_grad[1]: + db = dx.sum([i for i in range(dx.ndim) if i != dim]) + + return dx, db + + # Backward op. + class BiasActCudaGrad(torch.autograd.Function): + @staticmethod + def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ + ctx.memory_format = torch.contiguous_format + # ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride(1) == 1 else torch.contiguous_format + dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp) + ctx.save_for_backward( + dy if spec.has_2nd_grad else _null_tensor, + x, b, y) + return dx + + @staticmethod + def backward(ctx, d_dx): # pylint: disable=arguments-differ + d_dx = d_dx.contiguous(memory_format=ctx.memory_format) + dy, x, b, y = ctx.saved_tensors + d_dy = None + d_x = None + d_b = None + d_y = None + + if ctx.needs_input_grad[0]: + d_dy = BiasActCudaGrad.apply(d_dx, x, b, y) + + if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]): + d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp) + + if spec.has_2nd_grad and ctx.needs_input_grad[2]: + d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim]) + + return d_dy, d_x, d_b, d_y + + # Add to cache. + _bias_act_cuda_cache[key] = BiasActCuda + return BiasActCuda + +# ---------------------------------------------------------------------------- \ No newline at end of file diff --git a/video3d/triplane_texture/ops/conv2d_gradfix.py b/video3d/triplane_texture/ops/conv2d_gradfix.py new file mode 100644 index 0000000000000000000000000000000000000000..9d63b1ac42bbeec09dec9c604496f49d4742bb5f --- /dev/null +++ b/video3d/triplane_texture/ops/conv2d_gradfix.py @@ -0,0 +1,206 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 
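+
+# Usage sketch (illustrative only): the replacements below are inert until the
+# module-level `enabled` flag is set, e.g.
+#
+#   from video3d.triplane_texture.ops import conv2d_gradfix
+#   conv2d_gradfix.enabled = True                       # route through the custom op
+#   y = conv2d_gradfix.conv2d(x, w, bias=b, padding=1)  # higher-order grads supported
+#   with conv2d_gradfix.no_weight_gradients():          # skip weight grads, e.g. in
+#       y = conv2d_gradfix.conv2d(x, w)                 # regularization-only passes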
+ +"""Custom replacement for `torch.nn.functional.conv2d` that supports +arbitrarily high order gradients with zero performance penalty.""" + +import contextlib +import torch + +# pylint: disable=redefined-builtin +# pylint: disable=arguments-differ +# pylint: disable=protected-access + +# ---------------------------------------------------------------------------- + +enabled = False # Enable the custom op by setting this to true. +weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights. + + +@contextlib.contextmanager +def no_weight_gradients(disable=True): + global weight_gradients_disabled + old = weight_gradients_disabled + if disable: + weight_gradients_disabled = True + yield + weight_gradients_disabled = old + + +# ---------------------------------------------------------------------------- + +def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): + if _should_use_custom_op(input): + return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias) + return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) + + +def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): + if _should_use_custom_op(input): + return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias) + return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) + + +# ---------------------------------------------------------------------------- + +def _should_use_custom_op(input): + assert isinstance(input, torch.Tensor) + if (not enabled) or (not torch.backends.cudnn.enabled): + return False + if input.device.type != 'cuda': + return False + return True + + +def _tuple_of_ints(xs, ndim): + xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim + assert len(xs) == ndim + assert all(isinstance(x, int) for x in xs) + return xs + + +# ---------------------------------------------------------------------------- + +_conv2d_gradfix_cache = dict() +_null_tensor = torch.empty([0]) + + +def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups): + # Parse arguments. + ndim = 2 + weight_shape = tuple(weight_shape) + stride = _tuple_of_ints(stride, ndim) + padding = _tuple_of_ints(padding, ndim) + output_padding = _tuple_of_ints(output_padding, ndim) + dilation = _tuple_of_ints(dilation, ndim) + + # Lookup from cache. + key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) + if key in _conv2d_gradfix_cache: + return _conv2d_gradfix_cache[key] + + # Validate arguments. + assert groups >= 1 + assert len(weight_shape) == ndim + 2 + assert all(stride[i] >= 1 for i in range(ndim)) + assert all(padding[i] >= 0 for i in range(ndim)) + assert all(dilation[i] >= 0 for i in range(ndim)) + if not transpose: + assert all(output_padding[i] == 0 for i in range(ndim)) + else: # transpose + assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim)) + + # Helpers. 
+ common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups) + + def calc_output_padding(input_shape, output_shape): + if transpose: + return [0, 0] + return [ + input_shape[i + 2] + - (output_shape[i + 2] - 1) * stride[i] + - (1 - 2 * padding[i]) + - dilation[i] * (weight_shape[i + 2] - 1) + for i in range(ndim) + ] + + # Forward & backward. + class Conv2d(torch.autograd.Function): + @staticmethod + def forward(ctx, input, weight, bias): + assert weight.shape == weight_shape + ctx.save_for_backward( + input if weight.requires_grad else _null_tensor, + weight if input.requires_grad else _null_tensor, + ) + ctx.input_shape = input.shape + + # Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere). + if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0) and torch.cuda.get_device_capability(input.device) < (8, 0): + a = weight.reshape(groups, weight_shape[0] // groups, weight_shape[1]) + b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1) + c = (a.transpose(1, 2) if transpose else a) @ b.permute(1, 2, 0, 3).flatten(2) + c = c.reshape(-1, input.shape[0], *input.shape[2:]).transpose(0, 1) + c = c if bias is None else c + bias.unsqueeze(0).unsqueeze(2).unsqueeze(3) + return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format)) + + # General case => cuDNN. + if transpose: + return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs) + return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + input_shape = ctx.input_shape + grad_input = None + grad_weight = None + grad_bias = None + + if ctx.needs_input_grad[0]: + p = calc_output_padding(input_shape=input_shape, output_shape=grad_output.shape) + op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs) + grad_input = op.apply(grad_output, weight, None) + assert grad_input.shape == input_shape + + if ctx.needs_input_grad[1] and not weight_gradients_disabled: + grad_weight = Conv2dGradWeight.apply(grad_output, input) + assert grad_weight.shape == weight_shape + + if ctx.needs_input_grad[2]: + grad_bias = grad_output.sum([0, 2, 3]) + + return grad_input, grad_weight, grad_bias + + # Gradient with respect to the weights. + class Conv2dGradWeight(torch.autograd.Function): + @staticmethod + def forward(ctx, grad_output, input): + ctx.save_for_backward( + grad_output if input.requires_grad else _null_tensor, + input if grad_output.requires_grad else _null_tensor, + ) + ctx.grad_output_shape = grad_output.shape + ctx.input_shape = input.shape + + # Simple 1x1 convolution => cuBLAS (on both Volta and Ampere). + if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0): + a = grad_output.reshape(grad_output.shape[0], groups, grad_output.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2) + b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2) + c = (b @ a.transpose(1, 2) if transpose else a @ b.transpose(1, 2)).reshape(weight_shape) + return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format)) + + # General case => cuDNN. 
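+            # The fallback below calls a private ATen kernel through
+            # torch._C._jit_get_operation. The cudnn_convolution*_backward_weight
+            # entry points only exist in older PyTorch releases (they were later
+            # folded into aten::convolution_backward), so this path is sensitive
+            # to the installed torch version.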
+ name = 'aten::cudnn_convolution_transpose_backward_weight' if transpose else 'aten::cudnn_convolution_backward_weight' + flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32] + return torch._C._jit_get_operation(name)(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags) + + @staticmethod + def backward(ctx, grad2_grad_weight): + grad_output, input = ctx.saved_tensors + grad_output_shape = ctx.grad_output_shape + input_shape = ctx.input_shape + grad2_grad_output = None + grad2_input = None + + if ctx.needs_input_grad[0]: + grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None) + assert grad2_grad_output.shape == grad_output_shape + + if ctx.needs_input_grad[1]: + p = calc_output_padding(input_shape=input_shape, output_shape=grad_output_shape) + op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs) + grad2_input = op.apply(grad_output, grad2_grad_weight, None) + assert grad2_input.shape == input_shape + + return grad2_grad_output, grad2_input + + _conv2d_gradfix_cache[key] = Conv2d + return Conv2d + +# ---------------------------------------------------------------------------- \ No newline at end of file diff --git a/video3d/triplane_texture/ops/conv2d_resample.py b/video3d/triplane_texture/ops/conv2d_resample.py new file mode 100644 index 0000000000000000000000000000000000000000..9dc7f2e39bae42f4e6dd38b0d91ad3cb1e863599 --- /dev/null +++ b/video3d/triplane_texture/ops/conv2d_resample.py @@ -0,0 +1,146 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +"""2D convolution with optional up/downsampling.""" + +import torch + +from .. import misc +from . import conv2d_gradfix +from . import upfirdn2d +from .upfirdn2d import _parse_padding +from .upfirdn2d import _get_filter_size + + +# ---------------------------------------------------------------------------- + +def _get_weight_shape(w): + with misc.suppress_tracer_warnings(): # this value will be treated as a constant + shape = [int(sz) for sz in w.shape] + misc.assert_shape(w, shape) + return shape + + +# ---------------------------------------------------------------------------- + +def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True): + """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. + """ + _out_channels, _in_channels_per_group, kh, kw = _get_weight_shape(w) + + # Flip weight if requested. + # Note: conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). + if not flip_weight and (kw > 1 or kh > 1): + w = w.flip([2, 3]) + + # Execute using conv2d_gradfix. + op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d + return op(x, w, stride=stride, padding=padding, groups=groups) + + +# ---------------------------------------------------------------------------- + +@misc.profiled_function +def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False): + r"""2D convolution with optional up/downsampling. 
+ + Padding is performed only once at the beginning, not between the operations. + + Args: + x: Input tensor of shape + `[batch_size, in_channels, in_height, in_width]`. + w: Weight tensor of shape + `[out_channels, in_channels//groups, kernel_height, kernel_width]`. + f: Low-pass filter for up/downsampling. Must be prepared beforehand by + calling upfirdn2d.setup_filter(). None = identity (default). + up: Integer upsampling factor (default: 1). + down: Integer downsampling factor (default: 1). + padding: Padding with respect to the upsampled image. Can be a single number + or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + groups: Split input channels into N groups (default: 1). + flip_weight: False = convolution, True = correlation (default: True). + flip_filter: False = convolution, True = correlation (default: False). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + # Validate arguments. + assert isinstance(x, torch.Tensor) and (x.ndim == 4) + assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) + assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) + assert isinstance(up, int) and (up >= 1) + assert isinstance(down, int) and (down >= 1) + assert isinstance(groups, int) and (groups >= 1) + out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) + fw, fh = _get_filter_size(f) + px0, px1, py0, py1 = _parse_padding(padding) + + # Adjust padding to account for up/downsampling. + if up > 1: + px0 += (fw + up - 1) // 2 + px1 += (fw - up) // 2 + py0 += (fh + up - 1) // 2 + py1 += (fh - up) // 2 + if down > 1: + px0 += (fw - down + 1) // 2 + px1 += (fw - down) // 2 + py0 += (fh - down + 1) // 2 + py1 += (fh - down) // 2 + + # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. + if kw == 1 and kh == 1 and (down > 1 and up == 1): + x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0, px1, py0, py1], flip_filter=flip_filter) + x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) + return x + + # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. + if kw == 1 and kh == 1 and (up > 1 and down == 1): + x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) + x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0, px1, py0, py1], gain=up ** 2, flip_filter=flip_filter) + return x + + # Fast path: downsampling only => use strided convolution. + if down > 1 and up == 1: + x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0, px1, py0, py1], flip_filter=flip_filter) + x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight) + return x + + # Fast path: upsampling with optional downsampling => use transpose strided convolution. 
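+    # The block below swaps the weight into conv_transpose2d layout
+    # ([in, out, kh, kw] per group), re-derives the padding for the transposed
+    # op, and applies the remaining padding (or cropping) together with the
+    # low-pass filter via upfirdn2d; gain=up**2 compensates for the zeros
+    # inserted by upsampling.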
+ if up > 1: + if groups == 1: + w = w.transpose(0, 1) + else: + w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) + w = w.transpose(1, 2) + w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw) + px0 -= kw - 1 + px1 -= kw - up + py0 -= kh - 1 + py1 -= kh - up + pxt = max(min(-px0, -px1), 0) + pyt = max(min(-py0, -py1), 0) + x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt, pxt], groups=groups, transpose=True, flip_weight=(not flip_weight)) + x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0 + pxt, px1 + pxt, py0 + pyt, py1 + pyt], gain=up ** 2, flip_filter=flip_filter) + if down > 1: + x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) + return x + + # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. + if up == 1 and down == 1: + if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: + return _conv2d_wrapper(x=x, w=w, padding=[py0, px0], groups=groups, flip_weight=flip_weight) + + # Fallback: Generic reference implementation. + x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0, px1, py0, py1], gain=up ** 2, flip_filter=flip_filter) + x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) + if down > 1: + x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) + return x + +# ---------------------------------------------------------------------------- \ No newline at end of file diff --git a/video3d/triplane_texture/ops/fma.py b/video3d/triplane_texture/ops/fma.py new file mode 100644 index 0000000000000000000000000000000000000000..e227ca22b87cdd3aba30fbeeefc65115cc36eb1e --- /dev/null +++ b/video3d/triplane_texture/ops/fma.py @@ -0,0 +1,63 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 
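+
+# Usage sketch (illustrative only): fma(a, b, c) evaluates a * b + c, like
+# torch.addcmul(c, a, b), but its backward pass un-broadcasts the gradients
+# explicitly, e.g.
+#
+#   a = torch.randn([8, 512, 4, 4], requires_grad=True)
+#   b = torch.randn([8, 512, 4, 4], requires_grad=True)
+#   c = torch.randn([1, 512, 1, 1], requires_grad=True)   # broadcast against a * b
+#   out = fma(a, b, c)                                     # shape [8, 512, 4, 4]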
+ +"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" + +import torch + + +# ---------------------------------------------------------------------------- + +def fma(a, b, c): # => a * b + c + return _FusedMultiplyAdd.apply(a, b, c) + + +# ---------------------------------------------------------------------------- + +class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c + @staticmethod + def forward(ctx, a, b, c): # pylint: disable=arguments-differ + out = torch.addcmul(c, a, b) + ctx.save_for_backward(a, b) + ctx.c_shape = c.shape + return out + + @staticmethod + def backward(ctx, dout): # pylint: disable=arguments-differ + a, b = ctx.saved_tensors + c_shape = ctx.c_shape + da = None + db = None + dc = None + + if ctx.needs_input_grad[0]: + da = _unbroadcast(dout * b, a.shape) + + if ctx.needs_input_grad[1]: + db = _unbroadcast(dout * a, b.shape) + + if ctx.needs_input_grad[2]: + dc = _unbroadcast(dout, c_shape) + + return da, db, dc + + +# ---------------------------------------------------------------------------- + +def _unbroadcast(x, shape): + extra_dims = x.ndim - len(shape) + assert extra_dims >= 0 + dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] + if len(dim): + x = x.sum(dim=dim, keepdim=True) + if extra_dims: + x = x.reshape(-1, *x.shape[extra_dims + 1:]) + assert x.shape == shape + return x + +# ---------------------------------------------------------------------------- \ No newline at end of file diff --git a/video3d/triplane_texture/ops/grid_sample_gradfix.py b/video3d/triplane_texture/ops/grid_sample_gradfix.py new file mode 100644 index 0000000000000000000000000000000000000000..b70272b87d6f03b51c6564abfd79b8c8541cda72 --- /dev/null +++ b/video3d/triplane_texture/ops/grid_sample_gradfix.py @@ -0,0 +1,81 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +"""Custom replacement for `torch.nn.functional.grid_sample` that +supports arbitrarily high order gradients between the input and output. +Only works on 2D images and assumes +`mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" + +import torch + +# pylint: disable=redefined-builtin +# pylint: disable=arguments-differ +# pylint: disable=protected-access + +# ---------------------------------------------------------------------------- + +enabled = False # Enable the custom op by setting this to true. 
+ + +# ---------------------------------------------------------------------------- + +def grid_sample(input, grid): + if _should_use_custom_op(): + return _GridSample2dForward.apply(input, grid) + return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) + + +# ---------------------------------------------------------------------------- + +def _should_use_custom_op(): + return enabled + + +# ---------------------------------------------------------------------------- + +class _GridSample2dForward(torch.autograd.Function): + @staticmethod + def forward(ctx, input, grid): + assert input.ndim == 4 + assert grid.ndim == 4 + output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) + ctx.save_for_backward(input, grid) + return output + + @staticmethod + def backward(ctx, grad_output): + input, grid = ctx.saved_tensors + grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) + return grad_input, grad_grid + + +# ---------------------------------------------------------------------------- + +class _GridSample2dBackward(torch.autograd.Function): + @staticmethod + def forward(ctx, grad_output, input, grid): + op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') + grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) + ctx.save_for_backward(grid) + return grad_input, grad_grid + + @staticmethod + def backward(ctx, grad2_grad_input, grad2_grad_grid): + _ = grad2_grad_grid # unused + grid, = ctx.saved_tensors + grad2_grad_output = None + grad2_input = None + grad2_grid = None + + if ctx.needs_input_grad[0]: + grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid) + + assert not ctx.needs_input_grad[2] + return grad2_grad_output, grad2_input, grad2_grid + +# ---------------------------------------------------------------------------- \ No newline at end of file diff --git a/video3d/triplane_texture/ops/upfirdn2d.cpp b/video3d/triplane_texture/ops/upfirdn2d.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a0b155fab046a113051fce6b26da598a8e63ddfe --- /dev/null +++ b/video3d/triplane_texture/ops/upfirdn2d.cpp @@ -0,0 +1,108 @@ +// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + + +#include +#include +#include +#include "upfirdn2d.h" + +//------------------------------------------------------------------------ + +static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain) +{ + // Validate arguments. 
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); + TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); + TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); + TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); + TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); + TORCH_CHECK(x.numel() > 0, "x has zero size"); + TORCH_CHECK(f.numel() > 0, "f has zero size"); + TORCH_CHECK(x.dim() == 4, "x must be rank 4"); + TORCH_CHECK(f.dim() == 2, "f must be rank 2"); + TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large"); + TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); + TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); + TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); + + // Create output tensor. + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; + int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; + TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); + torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); + TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); + TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large"); + + // Initialize CUDA kernel parameters. + upfirdn2d_kernel_params p; + p.x = x.data_ptr(); + p.f = f.data_ptr(); + p.y = y.data_ptr(); + p.up = make_int2(upx, upy); + p.down = make_int2(downx, downy); + p.pad0 = make_int2(padx0, pady0); + p.flip = (flip) ? 1 : 0; + p.gain = gain; + p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); + p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); + p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); + p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); + p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); + p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); + p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; + p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; + + // Choose CUDA kernel. + upfirdn2d_kernel_spec spec; + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] + { + spec = choose_upfirdn2d_kernel(p); + }); + + // Set looping options. + p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; + p.loopMinor = spec.loopMinor; + p.loopX = spec.loopX; + p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; + p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; + + // Compute grid size. + dim3 blockSize, gridSize; + if (spec.tileOutW < 0) // large + { + blockSize = dim3(4, 32, 1); + gridSize = dim3( + ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, + (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, + p.launchMajor); + } + else // small + { + blockSize = dim3(256, 1, 1); + gridSize = dim3( + ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, + (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, + p.launchMajor); + } + + // Launch CUDA kernel. 
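+    // The parameter struct is passed by address as the single kernel argument and
+    // the kernel selected above is launched on PyTorch's current CUDA stream, so
+    // it is ordered correctly with the surrounding tensor operations.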
+ void* args[] = {&p}; + AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); + return y; +} + +//------------------------------------------------------------------------ + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("upfirdn2d", &upfirdn2d); +} + +//------------------------------------------------------------------------ \ No newline at end of file diff --git a/video3d/triplane_texture/ops/upfirdn2d.cu b/video3d/triplane_texture/ops/upfirdn2d.cu new file mode 100644 index 0000000000000000000000000000000000000000..830c3a3e068737273428caefb3804ef5ad6336ff --- /dev/null +++ b/video3d/triplane_texture/ops/upfirdn2d.cu @@ -0,0 +1,385 @@ +// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + + +#include +#include "upfirdn2d.h" + +//------------------------------------------------------------------------ +// Helpers. + +template struct InternalType; +template <> struct InternalType { typedef double scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; + +static __device__ __forceinline__ int floor_div(int a, int b) +{ + int t = 1 - a / b; + return (a + t * b) / b - t; +} + +//------------------------------------------------------------------------ +// Generic CUDA implementation for large filters. + +template static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p) +{ + typedef typename InternalType::scalar_t scalar_t; + + // Calculate thread index. + int minorBase = blockIdx.x * blockDim.x + threadIdx.x; + int outY = minorBase / p.launchMinor; + minorBase -= outY * p.launchMinor; + int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y; + int majorBase = blockIdx.z * p.loopMajor; + if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor) + return; + + // Setup Y receptive field. + int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y; + int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y); + int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY; + int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y; + if (p.flip) + filterY = p.filterSize.y - 1 - filterY; + + // Loop over major, minor, and X. + for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++) + for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor) + { + int nc = major * p.sizeMinor + minor; + int n = nc / p.inSize.z; + int c = nc - n * p.inSize.z; + for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y) + { + // Setup X receptive field. + int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x; + int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x); + int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX; + int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x; + if (p.flip) + filterX = p.filterSize.x - 1 - filterX; + + // Initialize pointers. 
+ const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w]; + const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y]; + int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x; + int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y; + + // Inner loop. + scalar_t v = 0; + for (int y = 0; y < h; y++) + { + for (int x = 0; x < w; x++) + { + v += (scalar_t)(*xp) * (scalar_t)(*fp); + xp += p.inStride.x; + fp += filterStepX; + } + xp += p.inStride.y - w * p.inStride.x; + fp += filterStepY - w * filterStepX; + } + + // Store result. + v *= p.gain; + ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v; + } + } +} + +//------------------------------------------------------------------------ +// Specialized CUDA implementation for small filters. + +template +static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p) +{ + typedef typename InternalType::scalar_t scalar_t; + const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1; + const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1; + __shared__ volatile scalar_t sf[filterH][filterW]; + __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor]; + + // Calculate tile index. + int minorBase = blockIdx.x; + int tileOutY = minorBase / p.launchMinor; + minorBase -= tileOutY * p.launchMinor; + minorBase *= loopMinor; + tileOutY *= tileOutH; + int tileOutXBase = blockIdx.y * p.loopX * tileOutW; + int majorBase = blockIdx.z * p.loopMajor; + if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor) + return; + + // Load filter (flipped). + for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x) + { + int fy = tapIdx / filterW; + int fx = tapIdx - fy * filterW; + scalar_t v = 0; + if (fx < p.filterSize.x & fy < p.filterSize.y) + { + int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx; + int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy; + v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y]; + } + sf[fy][fx] = v; + } + + // Loop over major and X. + for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++) + { + int baseNC = major * p.sizeMinor + minorBase; + int n = baseNC / p.inSize.z; + int baseC = baseNC - n * p.inSize.z; + for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW) + { + // Load input pixels. + int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x; + int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y; + int tileInX = floor_div(tileMidX, upx); + int tileInY = floor_div(tileMidY, upy); + __syncthreads(); + for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x) + { + int relC = inIdx; + int relInX = relC / loopMinor; + int relInY = relInX / tileInW; + relC -= relInX * loopMinor; + relInX -= relInY * tileInW; + int c = baseC + relC; + int inX = tileInX + relInX; + int inY = tileInY + relInY; + scalar_t v = 0; + if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z) + v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w]; + sx[relInY][relInX][relC] = v; + } + + // Loop over output pixels. 
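+            // Each thread strides over the output tile by blockDim.x, decodes its
+            // (x, y, minor-channel) position from the flat index, and accumulates
+            // filter taps against the shared-memory input tile loaded above; the
+            // __syncthreads() below guarantees that tile is fully written before
+            // any thread starts reading from it.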
+ __syncthreads(); + for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x) + { + int relC = outIdx; + int relOutX = relC / loopMinor; + int relOutY = relOutX / tileOutW; + relC -= relOutX * loopMinor; + relOutX -= relOutY * tileOutW; + int c = baseC + relC; + int outX = tileOutX + relOutX; + int outY = tileOutY + relOutY; + + // Setup receptive field. + int midX = tileMidX + relOutX * downx; + int midY = tileMidY + relOutY * downy; + int inX = floor_div(midX, upx); + int inY = floor_div(midY, upy); + int relInX = inX - tileInX; + int relInY = inY - tileInY; + int filterX = (inX + 1) * upx - midX - 1; // flipped + int filterY = (inY + 1) * upy - midY - 1; // flipped + + // Inner loop. + if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z) + { + scalar_t v = 0; + #pragma unroll + for (int y = 0; y < filterH / upy; y++) + #pragma unroll + for (int x = 0; x < filterW / upx; x++) + v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx]; + v *= p.gain; + ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v; + } + } + } + } +} + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p) +{ + int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y; + upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large, -1,-1,1, 4}; // contiguous + if (s == 1) spec = {(void*)upfirdn2d_kernel_large, -1,-1,4, 1}; // channels_last + + // No up/downsampling. + if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1}; + if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1}; + if (s != 1 && fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (s != 1 && fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (s != 1 && fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (s != 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (s != 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (s != 1 && fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + // channels_last + if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s == 1 && fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (s == 1 && fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (s == 1 && fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (s == 1 && fx <= 24 && fy <= 
1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (s == 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (s == 1 && fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + } + + // 2x upsampling. + if (p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1}; + if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1}; + if (s != 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (s != 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + // channels_last + if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s == 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (s == 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + } + if (p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (s != 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + // channels_last + if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (s == 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + } + if (p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + // channels_last + if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + } + + // 2x downsampling. 
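+    // As in the branches above, each case overwrites `spec` with a kernel
+    // instantiation specialized for the up/down factors and the largest filter it
+    // supports; the trailing initializer fields are {tileOutW, tileOutH,
+    // loopMinor, loopX}, with the channels_last (s == 1) variants using smaller
+    // tiles and a larger per-thread minor loop.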
+ if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) + { + // contiguous + if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + if (s != 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + if (s != 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + // channels_last + if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 16,16,1, 1}; + if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 16,16,1, 1}; + if (s == 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; + if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; + if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; + if (s == 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; + } + if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; + if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; + if (s != 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; + // channels_last + if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; + if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; + if (s == 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; + } + if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) + { + // contiguous + if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + // channels_last + if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; + if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; + if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; + } + + // 4x upsampling. 
+ if (p.up.x == 4 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 48 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1}; + if (s != 1 && fx <= 32 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1}; + // channels_last + if (s == 1 && fx <= 48 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s == 1 && fx <= 32 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + } + if (p.up.x == 4 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (s != 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + // channels_last + if (s == 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (s == 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + } + if (p.up.x == 1 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s != 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + // channels_last + if (s == 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (s == 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + } + + // 4x downsampling (inefficient). + if (p.up.x == 1 && p.up.y == 1 && p.down.x == 4 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + if (s != 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + // channels_last + if (s == 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 32,1,8, 1}; + if (s == 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 32,1,8, 1}; + } + if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 4) + { + // contiguous + if (s != 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + if (s != 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + // channels_last + if (s == 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small, 1,32,8, 1}; + if (s == 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small, 1,32,8, 1}; + } + return spec; +} + +//------------------------------------------------------------------------ +// Template specializations. + +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel (const upfirdn2d_kernel_params& p); +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel (const upfirdn2d_kernel_params& p); +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); + +//------------------------------------------------------------------------ \ No newline at end of file diff --git a/video3d/triplane_texture/ops/upfirdn2d.h b/video3d/triplane_texture/ops/upfirdn2d.h new file mode 100644 index 0000000000000000000000000000000000000000..933c61c49b8fac3b4daa24f98eb96273186cce28 --- /dev/null +++ b/video3d/triplane_texture/ops/upfirdn2d.h @@ -0,0 +1,59 @@ +// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. 
Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +#include + +//------------------------------------------------------------------------ +// CUDA kernel parameters. + +struct upfirdn2d_kernel_params +{ + const void* x; + const float* f; + void* y; + + int2 up; + int2 down; + int2 pad0; + int flip; + float gain; + + int4 inSize; // [width, height, channel, batch] + int4 inStride; + int2 filterSize; // [width, height] + int2 filterStride; + int4 outSize; // [width, height, channel, batch] + int4 outStride; + int sizeMinor; + int sizeMajor; + + int loopMinor; + int loopMajor; + int loopX; + int launchMinor; + int launchMajor; +}; + +//------------------------------------------------------------------------ +// CUDA kernel specialization. + +struct upfirdn2d_kernel_spec +{ + void* kernel; + int tileOutW; + int tileOutH; + int loopMinor; + int loopX; +}; + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); + +//------------------------------------------------------------------------ \ No newline at end of file diff --git a/video3d/triplane_texture/ops/upfirdn2d.py b/video3d/triplane_texture/ops/upfirdn2d.py new file mode 100644 index 0000000000000000000000000000000000000000..bf2853407e038ff38aacb2b470de4f78ad534dc8 --- /dev/null +++ b/video3d/triplane_texture/ops/upfirdn2d.py @@ -0,0 +1,402 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + + +"""Custom PyTorch ops for efficient resampling of 2D images.""" + +import os +import numpy as np +import torch + +from .. import custom_ops +from .. import misc +from . 
import conv2d_gradfix + +# ---------------------------------------------------------------------------- + +_plugin = None + + +def _init(): + global _plugin + if _plugin is None: + _plugin = custom_ops.get_plugin( + module_name='upfirdn2d_plugin', + sources=['upfirdn2d.cpp', 'upfirdn2d.cu'], + headers=['upfirdn2d.h'], + source_dir=os.path.dirname(__file__), + extra_cuda_cflags=['--use_fast_math'], + ) + return True + + +def _parse_scaling(scaling): + if isinstance(scaling, int): + scaling = [scaling, scaling] + assert isinstance(scaling, (list, tuple)) + assert all(isinstance(x, int) for x in scaling) + sx, sy = scaling + assert sx >= 1 and sy >= 1 + return sx, sy + + +def _parse_padding(padding): + if isinstance(padding, int): + padding = [padding, padding] + assert isinstance(padding, (list, tuple)) + assert all(isinstance(x, int) for x in padding) + if len(padding) == 2: + padx, pady = padding + padding = [padx, padx, pady, pady] + padx0, padx1, pady0, pady1 = padding + return padx0, padx1, pady0, pady1 + + +def _get_filter_size(f): + if f is None: + return 1, 1 + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + fw = f.shape[-1] + fh = f.shape[0] + with misc.suppress_tracer_warnings(): + fw = int(fw) + fh = int(fh) + misc.assert_shape(f, [fh, fw][:f.ndim]) + assert fw >= 1 and fh >= 1 + return fw, fh + + +# ---------------------------------------------------------------------------- + +def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None): + r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`. + + Args: + f: Torch tensor, numpy array, or python list of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), + `[]` (impulse), or + `None` (identity). + device: Result device (default: cpu). + normalize: Normalize the filter so that it retains the magnitude + for constant input signal (DC)? (default: True). + flip_filter: Flip the filter? (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + separable: Return a separable filter? (default: select automatically). + + Returns: + Float32 tensor of the shape + `[filter_height, filter_width]` (non-separable) or + `[filter_taps]` (separable). + """ + # Validate. + if f is None: + f = 1 + f = torch.as_tensor(f, dtype=torch.float32) + assert f.ndim in [0, 1, 2] + assert f.numel() > 0 + if f.ndim == 0: + f = f[np.newaxis] + + # Separable? + if separable is None: + separable = (f.ndim == 1 and f.numel() >= 8) + if f.ndim == 1 and not separable: + f = f.ger(f) + assert f.ndim == (1 if separable else 2) + + # Apply normalize, flip, gain, and device. + if normalize: + f /= f.sum() + if flip_filter: + f = f.flip(list(range(f.ndim))) + f = f * (gain ** (f.ndim / 2)) + f = f.to(device=device) + return f + + +# ---------------------------------------------------------------------------- + +def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Pad, upsample, filter, and downsample a batch of 2D images. + + Performs the following sequence of operations for each channel: + + 1. Upsample the image by inserting N-1 zeros after each pixel (`up`). + + 2. Pad the image with the specified number of zeros on each side (`padding`). + Negative padding corresponds to cropping the image. + + 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it + so that the footprint of all output pixels lies within the input image. + + 4. 
Downsample the image by keeping every Nth pixel (`down`). + + This sequence of operations bears close resemblance to scipy.signal.upfirdn(). + The fused op is considerably more efficient than performing the same calculation + using standard PyTorch ops. It supports gradients of arbitrary order. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + up: Integer upsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + down: Integer downsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the upsampled image. Can be a single number + or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + assert isinstance(x, torch.Tensor) + assert impl in ['ref', 'cuda'] + if impl == 'cuda' and x.device.type == 'cuda' and _init(): + return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f) + return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain) + + +# ---------------------------------------------------------------------------- + +@misc.profiled_function +def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): + """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops. + """ + # Validate arguments. + assert isinstance(x, torch.Tensor) and x.ndim == 4 + if f is None: + f = torch.ones([1, 1], dtype=torch.float32, device=x.device) + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + assert f.dtype == torch.float32 and not f.requires_grad + batch_size, num_channels, in_height, in_width = x.shape + upx, upy = _parse_scaling(up) + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + + # Check that upsampled buffer is not smaller than the filter. + upW = in_width * upx + padx0 + padx1 + upH = in_height * upy + pady0 + pady1 + assert upW >= f.shape[-1] and upH >= f.shape[0] + + # Upsample by inserting zeros. + x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1]) + x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1]) + x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx]) + + # Pad or crop. + x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)]) + x = x[:, :, max(-pady0, 0): x.shape[2] - max(-pady1, 0), max(-padx0, 0): x.shape[3] - max(-padx1, 0)] + + # Setup filter. + f = f * (gain ** (f.ndim / 2)) + f = f.to(x.dtype) + if not flip_filter: + f = f.flip(list(range(f.ndim))) + + # Convolve with the filter. + f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim) + if f.ndim == 4: + x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels) + else: + x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels) + x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels) + + # Downsample by throwing away pixels. 
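+    # The strided slice below keeps every downx-th / downy-th sample, finishing
+    # the upsample -> pad -> filter -> downsample pipeline. Usage sketch of the
+    # public wrappers (illustrative only):
+    #
+    #   f = setup_filter([1, 3, 3, 1])       # 4x4 binomial low-pass filter
+    #   y = upsample2d(x, f, up=2)           # filtered 2x upsampling
+    #   z = downsample2d(y, f, down=2)       # back to the original resolution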
+ x = x[:, :, ::downy, ::downx] + return x + + +# ---------------------------------------------------------------------------- + +_upfirdn2d_cuda_cache = dict() + + +def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1): + """Fast CUDA implementation of `upfirdn2d()` using custom ops. + """ + # Parse arguments. + upx, upy = _parse_scaling(up) + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + + # Lookup from cache. + key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) + if key in _upfirdn2d_cuda_cache: + return _upfirdn2d_cuda_cache[key] + + # Forward op. + class Upfirdn2dCuda(torch.autograd.Function): + @staticmethod + def forward(ctx, x, f): # pylint: disable=arguments-differ + assert isinstance(x, torch.Tensor) and x.ndim == 4 + if f is None: + f = torch.ones([1, 1], dtype=torch.float32, device=x.device) + if f.ndim == 1 and f.shape[0] == 1: + f = f.square().unsqueeze(0) # Convert separable-1 into full-1x1. + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + y = x + if f.ndim == 2: + y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) + else: + y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, 1.0) + y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, gain) + ctx.save_for_backward(f) + ctx.x_shape = x.shape + return y + + @staticmethod + def backward(ctx, dy): # pylint: disable=arguments-differ + f, = ctx.saved_tensors + _, _, ih, iw = ctx.x_shape + _, _, oh, ow = dy.shape + fw, fh = _get_filter_size(f) + p = [ + fw - padx0 - 1, + iw * upx - ow * downx + padx0 - upx + 1, + fh - pady0 - 1, + ih * upy - oh * downy + pady0 - upy + 1, + ] + dx = None + df = None + + if ctx.needs_input_grad[0]: + dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f) + + assert not ctx.needs_input_grad[1] + return dx, df + + # Add to cache. + _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda + return Upfirdn2dCuda + + +# ---------------------------------------------------------------------------- + +def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Filter a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape matches the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + padding: Padding with respect to the output. Can be a single number or a + list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 
+ """ + padx0, padx1, pady0, pady1 = _parse_padding(padding) + fw, fh = _get_filter_size(f) + p = [ + padx0 + fw // 2, + padx1 + (fw - 1) // 2, + pady0 + fh // 2, + pady1 + (fh - 1) // 2, + ] + return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) + + +# ---------------------------------------------------------------------------- + +def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Upsample a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape is a multiple of the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + up: Integer upsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the output. Can be a single number or a + list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + upx, upy = _parse_scaling(up) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + fw, fh = _get_filter_size(f) + p = [ + padx0 + (fw + upx - 1) // 2, + padx1 + (fw - upx) // 2, + pady0 + (fh + upy - 1) // 2, + pady1 + (fh - upy) // 2, + ] + return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain * upx * upy, impl=impl) + + +# ---------------------------------------------------------------------------- + +def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Downsample a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape is a fraction of the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + down: Integer downsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the input. Can be a single number or a + list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 
+ """ + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + fw, fh = _get_filter_size(f) + p = [ + padx0 + (fw - downx + 1) // 2, + padx1 + (fw - downx) // 2, + pady0 + (fh - downy + 1) // 2, + pady1 + (fh - downy) // 2, + ] + return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) + +# ---------------------------------------------------------------------------- \ No newline at end of file diff --git a/video3d/triplane_texture/triplane_predictor.py b/video3d/triplane_texture/triplane_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..64d0cdba50179c86bdcecb99f3a59034de44676f --- /dev/null +++ b/video3d/triplane_texture/triplane_predictor.py @@ -0,0 +1,713 @@ +import numpy as np +import torch +from torch import nn +import torch.nn.functional as F + +from video3d.triplane_texture.ops import bias_act +from video3d.triplane_texture.ops import fma +from video3d.triplane_texture.ops import upfirdn2d +from video3d.triplane_texture.ops import conv2d_resample +from video3d.triplane_texture.ops import grid_sample_gradfix +from video3d.triplane_texture import misc + + +def modulated_conv2d( + x, # Input tensor of shape [batch_size, in_channels, in_height, in_width]. + weight, # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width]. + styles, # Modulation coefficients of shape [batch_size, in_channels]. + noise=None, # Optional noise tensor to add to the output activations. + up=1, # Integer upsampling factor. + down=1, # Integer downsampling factor. + padding=0, # Padding with respect to the upsampled image. + resample_filter=None, + # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter(). + demodulate=True, # Apply weight demodulation? + flip_weight=True, # False = convolution, True = correlation (matches torch.nn.functional.conv2d). + fused_modconv=True, # Perform modulation, convolution, and demodulation as a single fused operation? +): + batch_size = x.shape[0] + out_channels, in_channels, kh, kw = weight.shape + misc.assert_shape(weight, [out_channels, in_channels, kh, kw]) # [OIkk] + misc.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW] + misc.assert_shape(styles, [batch_size, in_channels]) # [NI] + + # Pre-normalize inputs to avoid FP16 overflow. + if x.dtype == torch.float16 and demodulate: + weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm( + float('inf'), dim=[1, 2, 3], keepdim=True)) # max_Ikk + styles = styles / styles.norm(float('inf'), dim=1, keepdim=True) # max_I + + # Calculate per-sample weights and demodulation coefficients. + w = None + dcoefs = None + if demodulate or fused_modconv: + w = weight.unsqueeze(0) # [NOIkk] + w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk] + if demodulate: + dcoefs = (w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt() # [NO] + if demodulate and fused_modconv: + w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk] + + # Execute by scaling the activations before and after the convolution. 
+    if not fused_modconv:
+        x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1)
+        x = conv2d_resample.conv2d_resample(
+            x=x, w=weight.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, flip_weight=flip_weight)
+        if demodulate and noise is not None:
+            x = fma.fma(x, dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype))
+        elif demodulate:
+            x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1)
+        elif noise is not None:
+            x = x.add_(noise.to(x.dtype))
+        return x
+
+    # Execute as one fused op using grouped convolution.
+    with misc.suppress_tracer_warnings():  # this value will be treated as a constant
+        batch_size = int(batch_size)
+    misc.assert_shape(x, [batch_size, in_channels, None, None])
+    x = x.reshape(1, -1, *x.shape[2:])
+    w = w.reshape(-1, in_channels, kh, kw)
+    x = conv2d_resample.conv2d_resample(
+        x=x, w=w.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, groups=batch_size,
+        flip_weight=flip_weight)
+    x = x.reshape(batch_size, -1, *x.shape[2:])
+    if noise is not None:
+        x = x.add_(noise)
+    return x
+
+
+def modulated_fc(
+    x,                 # Input tensor of shape [batch_size, n_feature, in_channels].
+    weight,            # Weight tensor of shape [out_channels, in_channels].
+    styles,            # Modulation coefficients of shape [batch_size, in_channels].
+    noise=None,        # Optional noise tensor to add to the output activations.
+    demodulate=True,   # Apply weight demodulation?
+):
+    batch_size = x.shape[0]
+    n_feature = x.shape[1]
+    out_channels, in_channels = weight.shape
+    misc.assert_shape(weight, [out_channels, in_channels])
+    misc.assert_shape(x, [batch_size, n_feature, in_channels])
+    misc.assert_shape(styles, [batch_size, in_channels])
+    assert demodulate
+    # Pre-normalize inputs to avoid FP16 overflow.
+    if x.dtype == torch.float16 and demodulate:
+        weight = weight * (1 / np.sqrt(in_channels) / weight.norm(float('inf'), dim=[1], keepdim=True))  # max_I
+        styles = styles / styles.norm(float('inf'), dim=1, keepdim=True)  # max_I
+
+    # Calculate per-sample weights and demodulation coefficients.
+    w = weight.unsqueeze(0)  # [NOI]
+    w = w * styles.unsqueeze(dim=1)  # [NOI]
+    dcoefs = (w.square().sum(dim=[2]) + 1e-8).rsqrt()  # [NO]
+    w = w * dcoefs.unsqueeze(dim=-1)  # [NOI]
+    x = torch.bmm(x, w.permute(0, 2, 1))
+    if noise is not None:
+        x = x.add_(noise)
+    return x
+
+
+class FullyConnectedLayer(torch.nn.Module):
+    def __init__(
+        self,
+        in_features,            # Number of input features.
+        out_features,           # Number of output features.
+        bias=True,              # Apply additive bias before the activation function?
+        activation='linear',    # Activation function: 'relu', 'lrelu', etc.
+        device='cuda',
+        lr_multiplier=1,        # Learning rate multiplier.
+        bias_init=0,            # Initial value for the additive bias.
+ ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.activation = activation + self.weight = torch.nn.Parameter(torch.randn([out_features, in_features], device=device) / lr_multiplier) + self.bias = torch.nn.Parameter( + torch.full([out_features], np.float32(bias_init), device=device)) if bias else None + self.weight_gain = lr_multiplier / np.sqrt(in_features) + self.bias_gain = lr_multiplier + + def forward(self, x): + w = self.weight.to(x.dtype) * self.weight_gain + b = self.bias + if b is not None: + b = b.to(x.dtype) + if self.bias_gain != 1: + b = b * self.bias_gain + + if self.activation == 'linear' and b is not None: + x = torch.addmm(b.unsqueeze(0), x, w.t()) + else: + x = x.matmul(w.t()) + x = bias_act.bias_act(x, b, act=self.activation) + return x + + def extra_repr(self): + return f'in_features={self.in_features:d}, out_features={self.out_features:d}, activation={self.activation:s}' + + +class SynthesisLayer(torch.nn.Module): + def __init__( + self, + in_channels, # Number of input channels. + out_channels, # Number of output channels. + w_dim, # Intermediate latent (W) dimensionality. + resolution, # Resolution of this layer. + kernel_size=3, # Convolution kernel size. + up=1, # Integer upsampling factor. + use_noise=True, # Enable noise input? + activation='lrelu', # Activation function: 'relu', 'lrelu', etc. + device='cuda', + resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations. + conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping. + channels_last=False, # Use channels_last format for the weights? + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.w_dim = w_dim + self.resolution = resolution + self.up = up + self.use_noise = use_noise + self.activation = activation + self.conv_clamp = conv_clamp + self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter)) + self.padding = kernel_size // 2 + self.act_gain = bias_act.activation_funcs[activation].def_gain + + self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1, device=device) + memory_format = torch.channels_last if channels_last else torch.contiguous_format + self.weight = torch.nn.Parameter( + torch.randn([out_channels, in_channels, kernel_size, kernel_size], device=device).to( + memory_format=memory_format)) + if use_noise: + self.register_buffer('noise_const', torch.randn([resolution, resolution], device=device)) + self.noise_strength = torch.nn.Parameter(torch.zeros([], device=device)) + self.bias = torch.nn.Parameter(torch.zeros([out_channels], device=device)) + + def forward(self, x, w, noise_mode='random', fused_modconv=True, gain=1): + assert noise_mode in ['random', 'const', 'none'] + in_resolution = self.resolution // self.up + misc.assert_shape(x, [None, self.in_channels, in_resolution, in_resolution]) + + styles = self.affine(w) + + noise = None + if self.use_noise and noise_mode == 'random': + noise = torch.randn( + [x.shape[0], 1, self.resolution, self.resolution], device=x.device) * self.noise_strength + if self.use_noise and noise_mode == 'const': + noise = self.noise_const * self.noise_strength + + flip_weight = (self.up == 1) # slightly faster + x = modulated_conv2d( + x=x, weight=self.weight, styles=styles, noise=noise, up=self.up, + padding=self.padding, resample_filter=self.resample_filter, flip_weight=flip_weight, + fused_modconv=fused_modconv) + + act_gain = self.act_gain * gain + act_clamp = 
self.conv_clamp * gain if self.conv_clamp is not None else None
+        x = bias_act.bias_act(x, self.bias.to(x.dtype), act=self.activation, gain=act_gain, clamp=act_clamp)
+        return x
+
+    def extra_repr(self):
+        return ' '.join(
+            [
+                f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, w_dim={self.w_dim:d},',
+                f'resolution={self.resolution:d}, up={self.up}, activation={self.activation:s}'])
+
+
+class ImplicitSynthesisLayer(torch.nn.Module):
+    def __init__(
+        self,
+        in_channels,                     # Number of input channels.
+        out_channels,                    # Number of output channels.
+        w_dim,                           # Intermediate latent (W) dimensionality.
+        use_noise=True,                  # Enable noise input?
+        activation='lrelu',              # Activation function: 'relu', 'lrelu', etc.
+        resample_filter=[1, 3, 3, 1],    # Low-pass filter to apply when resampling activations.
+        device='cuda',
+        conv_clamp=None,                 # Clamp the output of convolution layers to +-X, None = disable clamping.
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.w_dim = w_dim
+        self.use_noise = use_noise
+        self.activation = activation
+        self.conv_clamp = conv_clamp
+        self.act_gain = bias_act.activation_funcs[activation].def_gain
+
+        self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1, device=device)
+        self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels], device=device))
+        self.bias = torch.nn.Parameter(torch.zeros([out_channels], device=device))
+
+    def forward(self, w, x, noise_mode='random', gain=1):
+        # x is the per-point feature.
+        # w is the conditioning latent.
+        assert noise_mode in ['random', 'const', 'none']
+        styles = self.affine(w)
+        noise = None  # noise is not used in this layer for now
+        x = modulated_fc(x=x, weight=self.weight, styles=styles, noise=noise)
+        act_gain = self.act_gain * gain
+        act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
+        x = bias_act.bias_act(
+            x, self.bias.to(x.dtype), act=self.activation, gain=act_gain, clamp=act_clamp,
+            dim=2)  # the last dim is the feature dim
+        return x
+
+    def extra_repr(self):
+        return ' '.join(
+            [
+                f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, w_dim={self.w_dim:d},',
+                f'activation={self.activation:s}'])
+
+
+class Conv2dLayer(torch.nn.Module):
+    def __init__(
+        self,
+        in_channels,                     # Number of input channels.
+        out_channels,                    # Number of output channels.
+        kernel_size,                     # Width and height of the convolution kernel.
+        device='cuda',
+        bias=True,                       # Apply additive bias before the activation function?
+        activation='linear',             # Activation function: 'relu', 'lrelu', etc.
+        up=1,                            # Integer upsampling factor.
+        down=1,                          # Integer downsampling factor.
+        resample_filter=[1, 3, 3, 1],    # Low-pass filter to apply when resampling activations.
+        conv_clamp=None,                 # Clamp the output to +-X, None = disable clamping.
+        channels_last=False,             # Expect the input to have memory_format=channels_last?
+        trainable=True,                  # Update the weights of this layer during training?
+ ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.activation = activation + self.up = up + self.down = down + self.conv_clamp = conv_clamp + self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter)) + self.padding = kernel_size // 2 + self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2)) + self.act_gain = bias_act.activation_funcs[activation].def_gain + + memory_format = torch.channels_last if channels_last else torch.contiguous_format + weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size], device=device).to( + memory_format=memory_format) + bias = torch.zeros([out_channels], device=device) if bias else None + if trainable: + self.weight = torch.nn.Parameter(weight) + self.bias = torch.nn.Parameter(bias) if bias is not None else None + else: + self.register_buffer('weight', weight) + if bias is not None: + self.register_buffer('bias', bias) + else: + self.bias = None + + def forward(self, x, gain=1): + w = self.weight * self.weight_gain + b = self.bias.to(x.dtype) if self.bias is not None else None + flip_weight = (self.up == 1) # slightly faster + x = conv2d_resample.conv2d_resample( + x=x, w=w.to(x.dtype), f=self.resample_filter, up=self.up, down=self.down, padding=self.padding, + flip_weight=flip_weight) + + act_gain = self.act_gain * gain + act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None + x = bias_act.bias_act(x, b, act=self.activation, gain=act_gain, clamp=act_clamp) + return x + + def extra_repr(self): + return ' '.join( + [ + f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, activation={self.activation:s},', + f'up={self.up}, down={self.down}']) + + +class ToRGBLayer(torch.nn.Module): + def __init__( + self, in_channels, out_channels, w_dim, kernel_size=1, conv_clamp=None, channels_last=False, device='cuda'): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.w_dim = w_dim + self.conv_clamp = conv_clamp + self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1, device=device) + memory_format = torch.channels_last if channels_last else torch.contiguous_format + self.weight = torch.nn.Parameter( + torch.randn([out_channels, in_channels, kernel_size, kernel_size], device=device).to( + memory_format=memory_format)) + self.bias = torch.nn.Parameter(torch.zeros([out_channels], device=device)) + self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2)) + + def forward(self, x, w, fused_modconv=True): + styles = self.affine(w) * self.weight_gain + x = modulated_conv2d(x=x, weight=self.weight, styles=styles, demodulate=False, fused_modconv=fused_modconv) + x = bias_act.bias_act(x, self.bias.to(x.dtype), clamp=self.conv_clamp) + return x + + def extra_repr(self): + return f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, w_dim={self.w_dim:d}' + + + +class SynthesisBlock(torch.nn.Module): + def __init__( + self, + in_channels, # Number of input channels, 0 = first block. + out_channels, # Number of output channels. + w_dim, # Intermediate latent (W) dimensionality. + resolution, # Resolution of this block. + img_channels, # Number of output color channels. + is_last, # Is this the last block? + architecture='skip', # Architecture: 'orig', 'skip', 'resnet'. + resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations. + conv_clamp=256, # Clamp the output of convolution layers to +-X, None = disable clamping. 
+ use_fp16=False, # Use FP16 for this block? + fp16_channels_last=False, # Use channels-last memory format with FP16? + fused_modconv_default=True, # Default value of fused_modconv. 'inference_only' = True for inference, False for training. + device='cuda', + first_layer=False, + **layer_kwargs, # Arguments for SynthesisLayer. + ): + assert architecture in ['orig', 'skip', 'resnet'] + super().__init__() + self.first_layer = first_layer + self.in_channels = in_channels + self.w_dim = w_dim + self.resolution = resolution + self.img_channels = img_channels + self.is_last = is_last + self.architecture = architecture + self.use_fp16 = use_fp16 + self.channels_last = (use_fp16 and fp16_channels_last) + self.fused_modconv_default = fused_modconv_default + self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter)) + self.num_conv = 0 + self.num_torgb = 0 + if in_channels == 0: + self.const = torch.nn.Parameter(torch.randn([out_channels, resolution, resolution], device=device)) + + if in_channels != 0: + if self.first_layer: + self.conv0 = SynthesisLayer( + in_channels, out_channels, w_dim=w_dim, resolution=resolution, + conv_clamp=conv_clamp, channels_last=self.channels_last, device=device, **layer_kwargs) + else: + self.conv0 = SynthesisLayer( + in_channels, out_channels, w_dim=w_dim, resolution=resolution, up=2, + resample_filter=resample_filter, conv_clamp=conv_clamp, channels_last=self.channels_last, device=device, + **layer_kwargs) + self.num_conv += 1 + + self.conv1 = SynthesisLayer( + out_channels, out_channels, w_dim=w_dim, resolution=resolution, + conv_clamp=conv_clamp, channels_last=self.channels_last, device=device, **layer_kwargs) + self.num_conv += 1 + + if is_last or architecture == 'skip': + self.torgb = ToRGBLayer( + out_channels, img_channels, w_dim=w_dim, + conv_clamp=conv_clamp, channels_last=self.channels_last, device=device) + self.num_torgb += 1 + + if in_channels != 0 and architecture == 'resnet': + self.skip = Conv2dLayer( + in_channels, out_channels, kernel_size=1, bias=False, up=2, + resample_filter=resample_filter, channels_last=self.channels_last, device=device) + + def forward(self, x, img, ws, force_fp32=False, fused_modconv='inference_only', update_emas=False, **layer_kwargs): + _ = update_emas # unused + misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim]) + w_iter = iter(ws.unbind(dim=1)) + if ws.device.type != 'cuda': + force_fp32 = True + dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32 + + memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format + if fused_modconv is None: + fused_modconv = self.fused_modconv_default + if fused_modconv == 'inference_only': + fused_modconv = (not self.training) + ## + # Input. + if self.in_channels == 0: + x = self.const.to(dtype=dtype, memory_format=memory_format) + x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1]) + else: + # misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 2]) + x = x.to(dtype=dtype, memory_format=memory_format) + + # Main layers. 
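+        # The first block (in_channels == 0) starts from the learned constant and applies
+        # conv1 only. 'resnet' blocks add a 1x1 skip convolution and scale both paths by
+        # sqrt(0.5) so the summed output keeps roughly unit variance. 'orig' and 'skip'
+        # blocks simply stack conv0 and conv1.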
+ if self.in_channels == 0: + x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) + elif self.architecture == 'resnet': + y = self.skip(x, gain=np.sqrt(0.5)) + x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) + x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs) + x = y + x + else: + x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) + x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) + + if img is not None: + misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2]) + img = upfirdn2d.upsample2d(img, self.resample_filter) + if self.is_last or self.architecture == 'skip': + y = self.torgb(x, next(w_iter), fused_modconv=fused_modconv) + y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format) + img = img + y if img is not None else y + + assert x.dtype == dtype + assert img is None or img.dtype == torch.float32 + return x, img + + def extra_repr(self): + return f'resolution={self.resolution:d}, architecture={self.architecture:s}' + + + +class SynthesisNetwork(torch.nn.Module): + def __init__( + self, + w_dim, # Intermediate latent (W) dimensionality. + img_resolution, # Output image resolution. + img_channels, # Number of color channels. + channel_base=32768, # Overall multiplier for the number of channels. + channel_max=512, # Maximum number of channels in any layer. + num_fp16_res=4, # Use FP16 for the N highest resolutions. + device='cuda', + **block_kwargs, # Arguments for SynthesisBlock. + ): + assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0 + super().__init__() + self.w_dim = w_dim + self.img_resolution = img_resolution + self.img_resolution_log2 = int(np.log2(img_resolution)) + self.img_channels = img_channels + self.num_fp16_res = num_fp16_res + self.block_resolutions = [2 ** i for i in range(2, self.img_resolution_log2 + 1)] # [4,8,16,32,64,128] + + # {4: 512, 8: 512, 16: 512, 32: 512, 64: 512, 128: 256} + channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions} + self.num_ws = 0 + for res in self.block_resolutions: + in_channels = channels_dict[res // 2] if res > 4 else 0 + out_channels = channels_dict[res] + is_last = (res == self.img_resolution) + use_fp16 = False + block = SynthesisBlock( + in_channels, out_channels, w_dim=w_dim, resolution=res, + img_channels=img_channels, is_last=is_last, use_fp16=use_fp16, device=device, **block_kwargs) + self.num_ws += block.num_conv + self.num_ws += block.num_torgb + setattr(self, f'b{res}', block) + + def forward(self, ws, **block_kwargs): + block_ws = [] + misc.assert_shape(ws, [None, self.num_ws, self.w_dim]) + ws = ws.to(torch.float32) + w_idx = 0 + for res in self.block_resolutions: + block = getattr(self, f'b{res}') + block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb)) + w_idx += (block.num_conv + block.num_torgb) + x = img = None + for res, cur_ws in zip(self.block_resolutions, block_ws): + block = getattr(self, f'b{res}') + x, img = block(x, img, cur_ws, **block_kwargs) + return img + + def extra_repr(self): + return ' '.join( + [ + f'w_dim={self.w_dim:d}, num_ws={self.num_ws:d},', + f'img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d},', + f'num_fp16_res={self.num_fp16_res:d}']) + + +class ImplicitSynthesisNetwork(torch.nn.Module): + def __init__( + self, + w_dim=512, # Intermediate latent (W) dimensionality. 
+ input_channel=256, + out_channels=3, # Number of color channels. + latent_channel=256, + n_layers=4, + device='cuda' + ): + super().__init__() + self.n_layer = n_layers + self.layers = [] + self.num_ws = 0 + for i_layer in range(self.n_layer): + layer = ImplicitSynthesisLayer( + w_dim=w_dim, + in_channels=input_channel if i_layer == 0 else latent_channel, + out_channels=latent_channel, device=device) + self.layers.append(layer) + self.num_ws += 1 + + self.layers.append( + ImplicitSynthesisLayer( + w_dim=w_dim, in_channels=latent_channel, out_channels=out_channels, + activation='sigmoid', device=device) + ) + self.num_ws += 1 + self.layers = torch.nn.ModuleList(self.layers) + self.w_dim = w_dim + self.out_channels = out_channels + + def forward(self, ws, position, **block_kwargs): + out = position + for i in range(self.n_layer): + out = self.layers[i](ws[:, i], out) + out = self.layers[-1](ws[:, self.n_layer], out) + return out + + def extra_repr(self): + return ' '.join( + [ + f'w_dim={self.w_dim:d}']) + + + +class TriPlaneTex(torch.nn.Module): + def __init__( + self, + w_dim, # Intermediate latent (W) dimensionality. + img_channels, # Number of color channels. + tri_plane_resolution=256, + device='cuda', + mlp_latent_channel=256, + n_implicit_layer=3, + feat_dim=384, # number of feat dim from encoder + n_mapping_layer=8, + sym_texture=True, + grid_scale=7., + min_max=None, + perturb_normal=False, + **block_kwargs, # Arguments for SynthesisBlock. + ): + super().__init__() + self.n_implicit_layer = n_implicit_layer + self.img_feat_dim = 32 # The setting follows Koki's paper + self.w_dim = w_dim + self.tri_plane_resolution = tri_plane_resolution + + # the mapping network + self.feat_dim = feat_dim + self.n_mapping_layer = n_mapping_layer + self.embed = FullyConnectedLayer(feat_dim, w_dim, device=device) + for idx in range(n_mapping_layer): + layer = FullyConnectedLayer(w_dim, w_dim, activation='lrelu', lr_multiplier=0.1, device=device) + setattr(self, f'mapping{idx}', layer) + + # self.w_dim = w_dim * 2 + + self.tri_plane_synthesis = SynthesisNetwork( + w_dim=self.w_dim, img_resolution=self.tri_plane_resolution, + img_channels=self.img_feat_dim * 3, + device=device, + **block_kwargs) + self.num_ws_tri_plane = self.tri_plane_synthesis.num_ws + + mlp_input_channel = self.img_feat_dim + w_dim # + mlp_latent_channel = mlp_latent_channel + + mlp_input_channel -= w_dim + self.mlp_synthesis = ImplicitSynthesisNetwork( + out_channels=img_channels, + n_layers=self.n_implicit_layer, + w_dim=self.w_dim, + latent_channel=mlp_latent_channel, + input_channel=mlp_input_channel, + device=device) + self.num_ws_all = self.num_ws_tri_plane + self.mlp_synthesis.num_ws + + # texture related + self.sym_texture = sym_texture + self.grid_scale = grid_scale + self.shape_min = 0. + self.shape_lenght = grid_scale / 2. 
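+        # shape_min and shape_lenght are used later (in old_forward) to map query positions
+        # into roughly [-1, 1] before tri-plane sampling:
+        #   normalized_tex_pos = (position - shape_min) / shape_lenght, then clamped.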
+ + if min_max is not None: + self.register_buffer('min_max', min_max) + else: + self.min_max = None + + self.perturb_normal = perturb_normal + + def old_forward( + self, feat, position=None, **block_kwargs): + ''' + Predict texture with given latent code + :param feat: image global feat + :param position: position for the surface points + :param block_kwargs: + :return: + ''' + assert feat.shape[-1] == self.feat_dim + + # mapping global feature to ws + ws = self.embed(feat) + for idx in range(self.n_mapping_layer): + layer = getattr(self, f'mapping{idx}') + ws = layer(ws) + + ws = ws.unsqueeze(1).repeat(1, self.num_ws_all, 1) + + plane_feat = self.tri_plane_synthesis(ws[:, :self.num_ws_tri_plane], **block_kwargs) + tri_plane = torch.split(plane_feat, self.img_feat_dim, dim=1) + + normalized_tex_pos = (position - self.shape_min) / self.shape_lenght # in [-1, 1] + normalized_tex_pos = torch.clamp(normalized_tex_pos, -1.0, 1.0) + + if self.sym_texture: + x_pos, y_pos, z_pos = normalized_tex_pos.unbind(-1) + normalized_tex_pos = torch.stack([x_pos.abs(), y_pos, z_pos], dim=-1) + + + x_feat = grid_sample_gradfix.grid_sample( + tri_plane[0], + torch.cat( + [normalized_tex_pos[:, :, 0:1], normalized_tex_pos[:, :, 1:2]], + dim=-1).unsqueeze(dim=1).detach()) + y_feat = grid_sample_gradfix.grid_sample( + tri_plane[1], + torch.cat( + [normalized_tex_pos[:, :, 1:2], normalized_tex_pos[:, :, 2:3]], + dim=-1).unsqueeze(dim=1).detach()) + z_feat = grid_sample_gradfix.grid_sample( + tri_plane[2], + torch.cat( + [normalized_tex_pos[:, :, 0:1], normalized_tex_pos[:, :, 2:3]], + dim=-1).unsqueeze(dim=1).detach()) + + final_feat = (x_feat + y_feat + z_feat) + final_feat = final_feat.squeeze(dim=2).permute(0, 2, 1) # 32dimension + final_feat_tex = final_feat + out = self.mlp_synthesis(ws[:, self.num_ws_tri_plane:], final_feat_tex) + return out + + def sample(self, xyz, feat=None, feat_map=None, mvp=None, w2c=None, deform_xyz=None): + # query the deformed points or canonical points + # x = deform_xyz + x = xyz + + b, h, w, c = x.shape + mvp = mvp.detach() # [b, 4, 4] + w2c = w2c.detach() # [b, 4, 4] + x = x.reshape(b, -1, c) + + global_feat = feat # [b, d] + + out = self.old_forward( + feat=global_feat, + position=x + ) + if self.min_max is not None: + out = out * (self.min_max[1][None, :] - self.min_max[0][None, :]) + self.min_max[0][None, :] + return out.view(b, h, w, -1) \ No newline at end of file diff --git a/video3d/triplane_texture/triplane_transformer.py b/video3d/triplane_texture/triplane_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..a15295a487f191ac448a4ee85b888b5e20968c4a --- /dev/null +++ b/video3d/triplane_texture/triplane_transformer.py @@ -0,0 +1,181 @@ +import numpy as np +import torch +import torch.nn as nn +import torchvision +import torchvision.models as models +from typing import Union, List, Tuple +import os +import video3d.utils.misc as misc +import torch.nn.functional as F +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + + def forward(self, x): + return self.net(x) + + +class Transformer_layer(nn.Module): + def __init__(self, dim_feat=384, dim=1024, hidden_dim=1024, heads=16): + super().__init__() + ''' + dim: the dim between each attention, 
mlp, also the input and output dim for the layer + hidden_dim: the dim inside qkv + dim_feat: condition feature dim + ''' + dim_head = hidden_dim // heads + self.heads = heads + self.scale = dim_head ** -0.5 # 8 + + self.norm = nn.LayerNorm(dim) + self.ffn = FeedForward( + dim=dim, + hidden_dim=(4 * dim), + dropout=0. + ) + + # cross attention part + self.to_cross_q = nn.Linear(dim, hidden_dim, bias=False) + self.to_cross_kv = nn.Linear(dim_feat, hidden_dim*2, bias=False) + self.cross_attend = nn.Softmax(dim=-1) + + # self attention part + self.to_self_qkv = nn.Linear(dim, hidden_dim*3, bias=False) + self.self_attend = nn.Softmax(dim=-1) + + def forward_cross_attn(self, x, feature): + x = self.norm(x) + + q = self.to_cross_q(x) + k, v = self.to_cross_kv(feature).chunk(2, dim=-1) + qkv = (q, k, v) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + attn = self.cross_attend(dots) + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return out + + def forward_self_attn(self, x): + x = self.norm(x) + qkv = self.to_self_qkv(x).chunk(3, dim=-1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + attn = self.self_attend(dots) + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return out + + def forward(self, x, feature): + ''' + x: [B, N, dim] + feature: [B, N, dim_feat] + ''' + cross_token = self.forward_cross_attn(x, feature) + cross_token = cross_token + x + + self_token = self.forward_self_attn(cross_token) + self_token = self_token + cross_token + + out = self.ffn(self_token) + out = out + self_token + + return out + + +class Triplane_Transformer(nn.Module): + def __init__(self, emb_dim=1024, emb_num=1024, num_layers=16, + triplane_dim=80, triplane_scale=7.): + super().__init__() + + self.learnable_embedding = nn.Parameter(torch.randn(1, emb_num, emb_dim)) + self.layers = nn.ModuleList([]) + for _ in range(num_layers): + self.layers.append( + Transformer_layer( + dim_feat=384, + dim=emb_dim, + hidden_dim=emb_dim + ) + ) + + self.triplane_dim = triplane_dim + self.triplane_scale = triplane_scale + + self.to_triplane = nn.ConvTranspose2d( + in_channels=emb_dim, + out_channels=3 * triplane_dim, + kernel_size=4, + padding=1, + stride=2 + ) + + self.norm = nn.LayerNorm(emb_dim) + + def sample_feat(self, feat_maps, pts): + ''' + feat_maps: [B, 3, C, H, W] + pts: [B, K, 3] + ''' + pts = pts / (self.triplane_scale / 2) + + pts_xy = pts[..., [0,1]] + pts_yz = pts[..., [1,2]] + pts_xz = pts[..., [0,2]] + + feat_xy = feat_maps[:, 0, :, :, :] + feat_yz = feat_maps[:, 1, :, :, :] + feat_xz = feat_maps[:, 2, :, :, :] + + sampled_feat_xy = F.grid_sample( + feat_xy, pts_xy.unsqueeze(1), mode='bilinear', align_corners=True + ) + sampled_feat_yz = F.grid_sample( + feat_yz, pts_yz.unsqueeze(1), mode='bilinear', align_corners=True + ) + sampled_feat_xz = F.grid_sample( + feat_xz, pts_xz.unsqueeze(1), mode='bilinear', align_corners=True + ) + + sampled_feat = torch.cat([sampled_feat_xy, sampled_feat_yz, sampled_feat_xz], dim=1).squeeze(-2) # [B, F, K] + sampled_feat = sampled_feat.permute(0, 2, 1) + return sampled_feat + + def forward(self, feature, pts): + ''' + feature: [B, N, dim_feat] + ''' + batch_size = feature.shape[0] + embedding = self.learnable_embedding.repeat(batch_size, 1, 1) + + x = embedding + for layer in self.layers: + x = layer(x, feature) + x = 
self.norm(x) + # x: [B, 32x32, 1024] + batch_size, pwph, feat_dim = x.shape + ph = int(pwph ** 0.5) + pw = int(pwph ** 0.5) + triplane_feat = x.reshape(batch_size, ph, pw, feat_dim).permute(0, 3, 1, 2) + triplane_feat = self.to_triplane(triplane_feat) # [B, C, 64, 64] + + triplane_feat = triplane_feat.reshape(triplane_feat.shape[0], 3, self.triplane_dim, triplane_feat.shape[-2], triplane_feat.shape[-1]) + + pts_feat = self.sample_feat(triplane_feat, pts) + + return pts_feat + diff --git a/video3d/utils/__init__.py b/video3d/utils/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/video3d/utils/arap.py b/video3d/utils/arap.py new file mode 100755 index 0000000000000000000000000000000000000000..b88bf054ed125d02d5706dafe454fdc6d7081bd0 --- /dev/null +++ b/video3d/utils/arap.py @@ -0,0 +1,289 @@ +# import pytorch3d +import torch +from einops import rearrange +from torch._C import device + + +def edges_to_sparse_incidence(edges, num_vertices): + num_edges = edges.shape[0] + row_indexes = torch.arange(num_edges, dtype=torch.long, device=edges.device).repeat_interleave(2) + col_indexes = edges.reshape(-1) + indexes = torch.stack([row_indexes, col_indexes]) + values = torch.FloatTensor([1, -1]).to(edges.device).repeat(num_edges) + return torch.sparse.FloatTensor(indexes, values, torch.Size([num_edges, num_vertices])) + + +def compute_svd_rotation(vertices_rest_pose, vertices_deformed_pose, incidence_mat): + """ + Adapted from: + https://github.com/kzhou23/shape_pose_disent/blob/a8017c405892c98f52fa9775327172633290b1d8/arap.py#L76 + + vertices_rest_pose: B x V x D + vertices_deformed_pose: B x V x D + incidence_mat: E x V + + """ + batch_size, num_vertices, dimensions = vertices_rest_pose.shape + vertices = torch.cat((vertices_rest_pose, vertices_deformed_pose), dim=0) + # 2B x V x D -> V x (D x 2B) + vertices = rearrange(vertices, 'a v d -> v (d a)') + # E x V . V x (D x 2B) - > E x (D x 2B) + edges = torch.sparse.mm(incidence_mat, vertices) + edges = rearrange(edges, 'e (d a) -> a e d', d=dimensions) + rest_edges, deformed_edges = torch.split(edges, batch_size, dim=0) + + edges_outer = torch.matmul(rest_edges[:, :, :, None], deformed_edges[:, :, None, :]) + edges_outer = rearrange(edges_outer, 'b e d1 d2 -> e (b d1 d2)') + + abs_incidence_mat = incidence_mat.clone() + abs_incidence_mat._values()[:] = torch.abs(abs_incidence_mat._values()) + + # transposed S + S = torch.sparse.mm(abs_incidence_mat.t(), edges_outer) + S = rearrange(S, 'v (b d1 d2) -> b v d2 d1', v=num_vertices, b=batch_size, d1=dimensions, d2=dimensions) + + # SVD on gpu is extremely slow! 
https://github.com/pytorch/pytorch/pull/48436 + device = S.device + U, _, V = torch.svd(S.cpu()) + U = U.to(device) + V = V.to(device) + + det_sign = torch.det(torch.matmul(U, V.transpose(-2, -1))) + U = torch.cat([U[..., :-1], U[..., -1:] * det_sign[..., None, None]], axis=-1) + + rotations = torch.matmul(U, V.transpose(-2, -1)) + + return rotations + + +def compute_rotation(vertices_rest_pose, vertices_deformed_pose, edges): + """ + vertices_rest_pose: B x V x D + vertices_deformed_pose: B x V x D + edges: E x 2 + """ + num_vertices = vertices_rest_pose.shape[1] + incidence_mat = edges_to_sparse_incidence(edges, num_vertices) + rot = compute_svd_rotation(vertices_rest_pose, vertices_deformed_pose, incidence_mat) + rot = pytorch3d.transforms.matrix_to_quaternion(rot) + return rot + + +def quaternion_normalize(quaternion, eps=1e-12): + """ + Adapted from tensorflow_graphics + + Normalizes a quaternion. + + Note: + In the following, A1 to An are optional batch dimensions. + + Args: + quaternion: A tensor of shape `[A1, ..., An, 4]`, where the last dimension + represents a quaternion. + eps: A lower bound value for the norm that defaults to 1e-12. + name: A name for this op that defaults to "quaternion_normalize". + + Returns: + A N-D tensor of shape `[?, ..., ?, 1]` where the quaternion elements have + been normalized. + + Raises: + ValueError: If the shape of `quaternion` is not supported. + """ + return l2_normalize(quaternion, dim=-1, epsilon=eps) + + +def l2_normalize(x, dim=-1, epsilon=1e-12): + square_sum = torch.sum(x ** 2, dim=dim, keepdim=True) + x_inv_norm = torch.rsqrt(torch.clamp(square_sum, min=epsilon)) + return x * x_inv_norm + + +def arap_energy(vertices_rest_pose, + vertices_deformed_pose, + quaternions, + edges, + vertex_weight=None, + edge_weight=None, + conformal_energy=True, + aggregate_loss=True): + """ + Adapted from tensorflow_graphics + + Estimates an As Conformal As Possible (ACAP) fitting energy. + For a given mesh in rest pose, this function evaluates a variant of the ACAP + [1] fitting energy for a batch of deformed meshes. The vertex weights and edge + weights are defined on the rest pose. + The method implemented here is similar to [2], but with an added free variable + capturing a scale factor per vertex. + [1]: Yusuke Yoshiyasu, Wan-Chun Ma, Eiichi Yoshida, and Fumio Kanehiro. + "As-Conformal-As-Possible Surface Registration." Computer Graphics Forum. Vol. + 33. No. 5. 2014.
+ [2]: Olga Sorkine, and Marc Alexa. + "As-rigid-as-possible surface modeling". Symposium on Geometry Processing. + Vol. 4. 2007. + Note: + In the description of the arguments, V corresponds to + the number of vertices in the mesh, and E to the number of edges in this + mesh. + Note: + In the following, A1 to An are optional batch dimensions. + Args: + vertices_rest_pose: A tensor of shape `[V, 3]` containing the position of + all the vertices of the mesh in rest pose. + vertices_deformed_pose: A tensor of shape `[A1, ..., An, V, 3]` containing + the position of all the vertices of the mesh in deformed pose. + quaternions: A tensor of shape `[A1, ..., An, V, 4]` defining a rigid + transformation to apply to each vertex of the rest pose. See Section 2 + from [1] for further details. + edges: A tensor of shape `[E, 2]` defining indices of vertices that are + connected by an edge. + vertex_weight: An optional tensor of shape `[V]` defining the weight + associated with each vertex. Defaults to a tensor of ones. + edge_weight: A tensor of shape `[E]` defining the weight of edges. Common + choices for these weights include uniform weighting, and cotangent + weights. Defaults to a tensor of ones. + conformal_energy: A `bool` indicating whether each vertex is associated with + a scale factor or not. If this parameter is True, scaling information must + be encoded in the norm of `quaternions`. If this parameter is False, this + function implements the energy described in [2]. + aggregate_loss: A `bool` defining whether the returned loss should be an + aggregate measure. When True, the mean squared error is returned. When + False, returns two losses for every edge of the mesh. + name: A name for this op. Defaults to "as_conformal_as_possible_energy". + Returns: + When aggregate_loss is `True`, returns a tensor of shape `[A1, ..., An]` + containing the ACAP energies. When aggregate_loss is `False`, returns a + tensor of shape `[A1, ..., An, 2*E]` containing each term of the summation + described in the equation 7 of [2]. + Raises: + ValueError: if the shape of `vertices_rest_pose`, `vertices_deformed_pose`, + `quaternions`, `edges`, `vertex_weight`, or `edge_weight` is not supported. 
+ """ + # with tf.compat.v1.name_scope(name, "as_conformal_as_possible_energy", [ + # vertices_rest_pose, vertices_deformed_pose, quaternions, edges, + # conformal_energy, vertex_weight, edge_weight + # ]): + # vertices_rest_pose = tf.convert_to_tensor(value=vertices_rest_pose) + # vertices_deformed_pose = tf.convert_to_tensor(value=vertices_deformed_pose) + # quaternions = tf.convert_to_tensor(value=quaternions) + # edges = tf.convert_to_tensor(value=edges) + # if vertex_weight is not None: + # vertex_weight = tf.convert_to_tensor(value=vertex_weight) + # if edge_weight is not None: + # edge_weight = tf.convert_to_tensor(value=edge_weight) + + # shape.check_static( + # tensor=vertices_rest_pose, + # tensor_name="vertices_rest_pose", + # has_rank=2, + # has_dim_equals=(-1, 3)) + # shape.check_static( + # tensor=vertices_deformed_pose, + # tensor_name="vertices_deformed_pose", + # has_rank_greater_than=1, + # has_dim_equals=(-1, 3)) + # shape.check_static( + # tensor=quaternions, + # tensor_name="quaternions", + # has_rank_greater_than=1, + # has_dim_equals=(-1, 4)) + # shape.compare_batch_dimensions( + # tensors=(vertices_deformed_pose, quaternions), + # last_axes=(-3, -3), + # broadcast_compatible=False) + # shape.check_static( + # tensor=edges, tensor_name="edges", has_rank=2, has_dim_equals=(-1, 2)) + # tensors_with_vertices = [vertices_rest_pose, + # vertices_deformed_pose, + # quaternions] + # names_with_vertices = ["vertices_rest_pose", + # "vertices_deformed_pose", + # "quaternions"] + # axes_with_vertices = [-2, -2, -2] + # if vertex_weight is not None: + # shape.check_static( + # tensor=vertex_weight, tensor_name="vertex_weight", has_rank=1) + # tensors_with_vertices.append(vertex_weight) + # names_with_vertices.append("vertex_weight") + # axes_with_vertices.append(0) + # shape.compare_dimensions( + # tensors=tensors_with_vertices, + # axes=axes_with_vertices, + # tensor_names=names_with_vertices) + # if edge_weight is not None: + # shape.check_static( + # tensor=edge_weight, tensor_name="edge_weight", has_rank=1) + # shape.compare_dimensions( + # tensors=(edges, edge_weight), + # axes=(0, 0), + # tensor_names=("edges", "edge_weight")) + + if not conformal_energy: + quaternions = quaternion_normalize(quaternions) + # Extracts the indices of vertices. + indices_i, indices_j = torch.unbind(edges, dim=-1) + # Extracts the vertices we need per term. + vertices_i_rest = vertices_rest_pose[..., indices_i, :] + vertices_j_rest = vertices_rest_pose[..., indices_j, :] + vertices_i_deformed = vertices_deformed_pose[..., indices_i, :] + vertices_j_deformed = vertices_deformed_pose[..., indices_j, :] + # Extracts the weights we need per term. + weights_shape = vertices_i_rest.shape[-2] + if vertex_weight is not None: + weight_i = vertex_weight[indices_i] + weight_j = vertex_weight[indices_j] + else: + weight_i = weight_j = torch.ones(weights_shape, dtype=vertices_rest_pose.dtype, device=vertices_rest_pose.device) + weight_i = weight_i[..., None] + weight_j = weight_j[..., None] + if edge_weight is not None: + weight_ij = edge_weight + else: + weight_ij = torch.ones(weights_shape, dtype=vertices_rest_pose.dtype, device=vertices_rest_pose.device) + weight_ij = weight_ij[..., None] + # Extracts the rotation we need per term. + quaternion_i = quaternions[..., indices_i, :] + quaternion_j = quaternions[..., indices_j, :] + # Computes the energy. 
+ deformed_ij = vertices_i_deformed - vertices_j_deformed + rotated_rest_ij = pytorch3d.transforms.quaternion_apply(quaternion_i, (vertices_i_rest - vertices_j_rest)) + energy_ij = weight_i * weight_ij * (deformed_ij - rotated_rest_ij) + deformed_ji = vertices_j_deformed - vertices_i_deformed + rotated_rest_ji = pytorch3d.transforms.quaternion_apply(quaternion_j, (vertices_j_rest - vertices_i_rest)) + energy_ji = weight_j * weight_ij * (deformed_ji - rotated_rest_ji) + energy_ij_squared = torch.sum(energy_ij ** 2, dim=-1) + energy_ji_squared = torch.sum(energy_ji ** 2, dim=-1) + if aggregate_loss: + average_energy_ij = torch.mean(energy_ij_squared, dim=-1) + average_energy_ji = torch.mean(energy_ji_squared, dim=-1) + return (average_energy_ij + average_energy_ji) / 2.0 + return torch.cat((energy_ij_squared, energy_ji_squared), dim=-1) + + +def arap_loss(vertices_rest_pose, vertices_deformed_pose, edges): + # squash batch dimensions + vertices_rest_pose_shape = list(vertices_rest_pose.shape) + vertices_deformed_pose_shape = list(vertices_deformed_pose.shape) + vertices_rest_pose = vertices_rest_pose.reshape([-1] + vertices_rest_pose_shape[-2:]) + vertices_deformed_pose = vertices_deformed_pose.reshape([-1] + vertices_deformed_pose_shape[-2:]) + + # try: + quaternions = compute_rotation(vertices_rest_pose, vertices_deformed_pose, edges) + # except RuntimeError: + # print('SVD did not converge') + # batch_size = vertices_rest_pose.shape[0] + # num_vertices = vertices_rest_pose.shape[-2] + # quaternions = pytorch3d.transforms.matrix_to_quaternion(pytorch3d.transforms.euler_angles_to_matrix(torch.zeros([batch_size, num_vertices, 3], device=vertices_rest_pose.device), 'XYZ')) + + quaternions = quaternions.detach() + + energy = arap_energy( + vertices_rest_pose, + vertices_deformed_pose, + quaternions, + edges, + aggregate_loss=True, + conformal_energy=False) + return energy.reshape(vertices_rest_pose_shape[:-2]) diff --git a/video3d/utils/custom_loss.py b/video3d/utils/custom_loss.py new file mode 100755 index 0000000000000000000000000000000000000000..246c0d1b4c2ceb3817eda2ebf674722c1781e978 --- /dev/null +++ b/video3d/utils/custom_loss.py @@ -0,0 +1,322 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + + +from itertools import islice + +import torch + + +def mesh_normal_consistency(meshes, reduce=True): + r""" + Computes the normal consistency of each mesh in meshes. + We compute the normal consistency for each pair of neighboring faces. + If e = (v0, v1) is the connecting edge of two neighboring faces f0 and f1, + then the normal consistency between f0 and f1 + + .. code-block:: python + + a + /\ + / \ + / f0 \ + / \ + v0 /____e___\ v1 + \ / + \ / + \ f1 / + \ / + \/ + b + + The normal consistency is + + .. code-block:: python + + nc(f0, f1) = 1 - cos(n0, n1) + + where cos(n0, n1) = n0^n1 / ||n0|| / ||n1|| is the cosine of the angle + between the normals n0 and n1, and + + n0 = (v1 - v0) x (a - v0) + n1 = - (v1 - v0) x (b - v0) = (b - v0) x (v1 - v0) + + This means that if nc(f0, f1) = 0 then n0 and n1 point to the same + direction, while if nc(f0, f1) = 2 then n0 and n1 point opposite direction. + + .. note:: + For well-constructed meshes the assumption that only two faces share an + edge is true. This assumption could make the implementation easier and faster. + This implementation does not follow this assumption. All the faces sharing e, + which can be any in number, are discovered. + + Args: + meshes: Meshes object with a batch of meshes. 
+ + Returns: + loss: Average normal consistency across the batch. + Returns 0 if meshes contains no meshes or all empty meshes. + """ + if meshes.isempty(): + return torch.tensor( + [0.0], dtype=torch.float32, device=meshes.device, requires_grad=True + ) + + N = len(meshes) + verts_packed = meshes.verts_packed() # (sum(V_n), 3) + faces_packed = meshes.faces_packed() # (sum(F_n), 3) + edges_packed = meshes.edges_packed() # (sum(E_n), 2) + verts_packed_to_mesh_idx = meshes.verts_packed_to_mesh_idx() # (sum(V_n),) + face_to_edge = meshes.faces_packed_to_edges_packed() # (sum(F_n), 3) + E = edges_packed.shape[0] # sum(E_n) + F = faces_packed.shape[0] # sum(F_n) + + # We don't want gradients for the following operation. The goal is to + # find for each edge e all the vertices associated with e. In the example above, + # the vertices associated with e are (v0, v1, a, b), i.e. points on e (=v0, v1) + # and points connected on faces to e (=a, b). + with torch.no_grad(): + edge_idx = face_to_edge.reshape(F * 3) # (3 * F,) indexes into edges + vert_idx = ( + faces_packed.view(1, F, 3).expand(3, F, 3).transpose(0, 1).reshape(3 * F, 3) + ) + edge_idx, edge_sort_idx = edge_idx.sort() + vert_idx = vert_idx[edge_sort_idx] + + # In well constructed meshes each edge is shared by precisely 2 faces + # However, in many meshes, this assumption is not always satisfied. + # We want to find all faces that share an edge, a number which can + # vary and which depends on the topology. + # In particular, we find the vertices not on the edge on the shared faces. + # In the example above, we want to associate edge e with vertices a and b. + # This operation is done more efficiently in cpu with lists. + # TODO(gkioxari) find a better way to do this. + + # edge_idx represents the index of the edge for each vertex. We can count + # the number of vertices which are associated with each edge. + # There can be a different number for each edge. + edge_num = edge_idx.bincount(minlength=E) + # Create pairs of vertices associated to e. We generate a list of lists: + # each list has the indices of the vertices which are opposite to one edge. + # The length of the list for each edge will vary. + vert_edge_pair_idx = split_list( + list(range(edge_idx.shape[0])), edge_num.tolist() + ) + # For each list find all combinations of pairs in the list. This represents + # all pairs of vertices which are opposite to the same edge. 
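+        # For example, a manifold edge shared by exactly two faces with opposite vertices
+        # (a, b) yields the single pair (a, b); if three faces share the edge with opposite
+        # vertices (a, b, c), the pairs are (a, b), (a, c) and (b, c).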
+ vert_edge_pair_idx = [ + [e[i], e[j]] + for e in vert_edge_pair_idx + for i in range(len(e) - 1) + for j in range(1, len(e)) + if i != j + ] + vert_edge_pair_idx = torch.tensor( + vert_edge_pair_idx, device=meshes.device, dtype=torch.int64 + ) + + v0_idx = edges_packed[edge_idx, 0] + v0 = verts_packed[v0_idx] + v1_idx = edges_packed[edge_idx, 1] + v1 = verts_packed[v1_idx] + + # two of the following cross products are zeros as they are cross product + # with either (v1-v0)x(v1-v0) or (v1-v0)x(v0-v0) + n_temp0 = (v1 - v0).cross(verts_packed[vert_idx[:, 0]] - v0, dim=1) + n_temp1 = (v1 - v0).cross(verts_packed[vert_idx[:, 1]] - v0, dim=1) + n_temp2 = (v1 - v0).cross(verts_packed[vert_idx[:, 2]] - v0, dim=1) + n = n_temp0 + n_temp1 + n_temp2 + n0 = n[vert_edge_pair_idx[:, 0]] + n1 = -n[vert_edge_pair_idx[:, 1]] + loss = 1 - torch.cosine_similarity(n0, n1, dim=1) + + verts_packed_to_mesh_idx = verts_packed_to_mesh_idx[vert_idx[:, 0]] + verts_packed_to_mesh_idx = verts_packed_to_mesh_idx[vert_edge_pair_idx[:, 0]] + num_normals = verts_packed_to_mesh_idx.bincount(minlength=N) + weights = 1.0 / num_normals[verts_packed_to_mesh_idx].float() + + loss = loss * weights + if reduce: + return loss.sum() / N + else: + return loss + + +def split_list(input, length_to_split): + inputt = iter(input) + return [list(islice(inputt, elem)) for elem in length_to_split] + + +## new mesh laplacian loss with bug fix (https://github.com/facebookresearch/pytorch3d/blob/ff9c6612b457a2021d88fea119bdb9b94ba017bd/pytorch3d/loss/mesh_laplacian_smoothing.py) +def mesh_laplacian_smoothing(meshes, method: str = "uniform"): + r""" + Computes the laplacian smoothing objective for a batch of meshes. + This function supports three variants of Laplacian smoothing, + namely with uniform weights("uniform"), with cotangent weights ("cot"), + and cotangent cuvature ("cotcurv").For more details read [1, 2]. + Args: + meshes: Meshes object with a batch of meshes. + method: str specifying the method for the laplacian. + Returns: + loss: Average laplacian smoothing loss across the batch. + Returns 0 if meshes contains no meshes or all empty meshes. + Consider a mesh M = (V, F), with verts of shape Nx3 and faces of shape Mx3. + The Laplacian matrix L is a NxN tensor such that LV gives a tensor of vectors: + for a uniform Laplacian, LuV[i] points to the centroid of its neighboring + vertices, a cotangent Laplacian LcV[i] is known to be an approximation of + the surface normal, while the curvature variant LckV[i] scales the normals + by the discrete mean curvature. For vertex i, assume S[i] is the set of + neighboring vertices to i, a_ij and b_ij are the "outside" angles in the + two triangles connecting vertex v_i and its neighboring vertex v_j + for j in S[i], as seen in the diagram below. + .. code-block:: python + a_ij + /\ + / \ + / \ + / \ + v_i /________\ v_j + \ / + \ / + \ / + \ / + \/ + b_ij + The definition of the Laplacian is LV[i] = sum_j w_ij (v_j - v_i) + For the uniform variant, w_ij = 1 / |S[i]| + For the cotangent variant, + w_ij = (cot a_ij + cot b_ij) / (sum_k cot a_ik + cot b_ik) + For the cotangent curvature, w_ij = (cot a_ij + cot b_ij) / (4 A[i]) + where A[i] is the sum of the areas of all triangles containing vertex v_i. + There is a nice trigonometry identity to compute cotangents. Consider a triangle + with side lengths A, B, C and angles a, b, c. + .. 
code-block:: python + c + /|\ + / | \ + / | \ + B / H| \ A + / | \ + / | \ + /a_____|_____b\ + C + Then cot a = (B^2 + C^2 - A^2) / 4 * area + We know that area = CH/2, and by the law of cosines we have + A^2 = B^2 + C^2 - 2BC cos a => B^2 + C^2 - A^2 = 2BC cos a + Putting these together, we get: + B^2 + C^2 - A^2 2BC cos a + _______________ = _________ = (B/H) cos a = cos a / sin a = cot a + 4 * area 2CH + [1] Desbrun et al, "Implicit fairing of irregular meshes using diffusion + and curvature flow", SIGGRAPH 1999. + [2] Nealan et al, "Laplacian Mesh Optimization", Graphite 2006. + """ + + if meshes.isempty(): + return torch.tensor( + [0.0], dtype=torch.float32, device=meshes.device, requires_grad=True + ) + + N = len(meshes) + verts_packed = meshes.verts_packed() # (sum(V_n), 3) + num_verts_per_mesh = meshes.num_verts_per_mesh() # (N,) + verts_packed_idx = meshes.verts_packed_to_mesh_idx() # (sum(V_n),) + weights = num_verts_per_mesh.gather(0, verts_packed_idx) # (sum(V_n),) + weights = 1.0 / weights.float() + + # We don't want to backprop through the computation of the Laplacian; + # just treat it as a magic constant matrix that is used to transform + # verts into normals + with torch.no_grad(): + if method == "uniform": + L = meshes.laplacian_packed() + elif method in ["cot", "cotcurv"]: + L, inv_areas = laplacian_cot(meshes) + if method == "cot": + norm_w = torch.sparse.sum(L, dim=1).to_dense().view(-1, 1) + idx = norm_w > 0 + norm_w[idx] = 1.0 / norm_w[idx] + else: + L_sum = torch.sparse.sum(L, dim=1).to_dense().view(-1, 1) + norm_w = 0.25 * inv_areas + else: + raise ValueError("Method should be one of {uniform, cot, cotcurv}") + + if method == "uniform": + loss = L.mm(verts_packed) + elif method == "cot": + loss = L.mm(verts_packed) * norm_w - verts_packed + elif method == "cotcurv": + loss = (L.mm(verts_packed) - L_sum * verts_packed) * norm_w + loss = loss.norm(dim=1) + + loss = loss * weights + return loss.sum() / N + + +def laplacian_cot(meshes): + """ + Returns the Laplacian matrix with cotangent weights and the inverse of the + face areas. + Args: + meshes: Meshes object with a batch of meshes. + Returns: + 2-element tuple containing + - **L**: FloatTensor of shape (V,V) for the Laplacian matrix (V = sum(V_n)) + Here, L[i, j] = cot a_ij + cot b_ij iff (i, j) is an edge in meshes. + See the description above for more clarity. 
+ - **inv_areas**: FloatTensor of shape (V,) containing the inverse of sum of + face areas containing each vertex + """ + verts_packed = meshes.verts_packed() # (sum(V_n), 3) + faces_packed = meshes.faces_packed() # (sum(F_n), 3) + # V = sum(V_n), F = sum(F_n) + V, F = verts_packed.shape[0], faces_packed.shape[0] + + face_verts = verts_packed[faces_packed] + v0, v1, v2 = face_verts[:, 0], face_verts[:, 1], face_verts[:, 2] + + # Side lengths of each triangle, of shape (sum(F_n),) + # A is the side opposite v1, B is opposite v2, and C is opposite v3 + A = (v1 - v2).norm(dim=1) + B = (v0 - v2).norm(dim=1) + C = (v0 - v1).norm(dim=1) + + # Area of each triangle (with Heron's formula); shape is (sum(F_n),) + s = 0.5 * (A + B + C) + # note that the area can be negative (close to 0) causing nans after sqrt() + # we clip it to a small positive value + area = (s * (s - A) * (s - B) * (s - C)).clamp_(min=1e-12).sqrt() + + # Compute cotangents of angles, of shape (sum(F_n), 3) + A2, B2, C2 = A * A, B * B, C * C + cota = (B2 + C2 - A2) / area + cotb = (A2 + C2 - B2) / area + cotc = (A2 + B2 - C2) / area + cot = torch.stack([cota, cotb, cotc], dim=1) + cot /= 4.0 + + # Construct a sparse matrix by basically doing: + # L[v1, v2] = cota + # L[v2, v0] = cotb + # L[v0, v1] = cotc + ii = faces_packed[:, [1, 2, 0]] + jj = faces_packed[:, [2, 0, 1]] + idx = torch.stack([ii, jj], dim=0).view(2, F * 3) + L = torch.sparse.FloatTensor(idx, cot.view(-1), (V, V)) + + # Make it symmetric; this means we are also setting + # L[v2, v1] = cota + # L[v0, v2] = cotb + # L[v1, v0] = cotc + L += L.t() + + # For each vertex, compute the sum of areas for triangles containing it. + idx = faces_packed.view(-1) + inv_areas = torch.zeros(V, dtype=torch.float32, device=meshes.device) + val = torch.stack([area] * 3, dim=1).view(-1) + inv_areas.scatter_add_(0, idx, val) + idx = inv_areas > 0 + inv_areas[idx] = 1.0 / inv_areas[idx] + inv_areas = inv_areas.view(-1, 1) + + return L, inv_areas diff --git a/video3d/utils/flow_viz.py b/video3d/utils/flow_viz.py new file mode 100755 index 0000000000000000000000000000000000000000..f89fdff549ed83af208fdc17efec7dc4da9749f3 --- /dev/null +++ b/video3d/utils/flow_viz.py @@ -0,0 +1,142 @@ +# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization + + +# MIT License +# +# Copyright (c) 2018 Tom Runia +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to conditions. +# +# Author: Tom Runia +# Date Created: 2018-08-03 + +import numpy as np +import torch + +def make_colorwheel(): + """ + Generates a color wheel for optical flow visualization as presented in: + Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) + URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf + + Code follows the original C++ source code of Daniel Scharstein. + Code follows the the Matlab source code of Deqing Sun. 
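A quick numeric check (illustrative only) of the cotangent identity that laplacian_cot relies on, evaluated on a 3-4-5 right triangle:

```python
import numpy as np

A, B, C = 3.0, 4.0, 5.0                               # side lengths; angle a is opposite side A
s = 0.5 * (A + B + C)
area = np.sqrt(s * (s - A) * (s - B) * (s - C))       # Heron's formula -> 6.0

cos_a = (B**2 + C**2 - A**2) / (2 * B * C)            # law of cosines
cot_a_direct = cos_a / np.sqrt(1.0 - cos_a**2)        # cos a / sin a = 4/3
cot_a_identity = (B**2 + C**2 - A**2) / (4.0 * area)  # identity used in laplacian_cot = 4/3
assert np.isclose(cot_a_direct, cot_a_identity)
```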
+ + Returns: + np.ndarray: Color wheel + """ + + RY = 15 + YG = 6 + GC = 4 + CB = 11 + BM = 13 + MR = 6 + + ncols = RY + YG + GC + CB + BM + MR + colorwheel = np.zeros((ncols, 3)) + col = 0 + + # RY + colorwheel[0:RY, 0] = 255 + colorwheel[0:RY, 1] = np.floor(255*np.arange(0, RY)/RY) + col = col+RY + # YG + colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0, YG)/YG) + colorwheel[col:col+YG, 1] = 255 + col = col+YG + # GC + colorwheel[col:col+GC, 1] = 255 + colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0, GC)/GC) + col = col+GC + # CB + colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) + colorwheel[col:col+CB, 2] = 255 + col = col+CB + # BM + colorwheel[col:col+BM, 2] = 255 + colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0, BM)/BM) + col = col+BM + # MR + colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) + colorwheel[col:col+MR, 0] = 255 + return colorwheel + + +def flow_uv_to_colors(u, v, convert_to_bgr=False): + """ + Applies the flow color wheel to (possibly clipped) flow components u and v. + + According to the C++ source code of Daniel Scharstein + According to the Matlab source code of Deqing Sun + + Args: + u (np.ndarray): Input horizontal flow of shape [H,W] + v (np.ndarray): Input vertical flow of shape [H,W] + convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. + + Returns: + np.ndarray: Flow visualization image of shape [H,W,3] + """ + flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) + colorwheel = make_colorwheel() # shape [55x3] + ncols = colorwheel.shape[0] + rad = np.sqrt(np.square(u) + np.square(v)) + a = np.arctan2(-v, -u)/np.pi + fk = (a+1) / 2*(ncols-1) + k0 = np.floor(fk).astype(np.int32) + k1 = k0 + 1 + k1[k1 == ncols] = 0 + f = fk - k0 + for i in range(colorwheel.shape[1]): + tmp = colorwheel[:, i] + col0 = tmp[k0] / 255.0 + col1 = tmp[k1] / 255.0 + col = (1-f)*col0 + f*col1 + idx = (rad <= 1) + col[idx] = 1 - rad[idx] * (1-col[idx]) + col[~idx] = col[~idx] * 0.75 # out of range + # Note the 2-i => BGR instead of RGB + ch_idx = 2-i if convert_to_bgr else i + flow_image[:, :, ch_idx] = np.floor(255 * col) + return flow_image + + +def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False, normalize=True): + """ + Expects a two dimensional flow image of shape. + + Args: + flow_uv (np.ndarray): Flow UV image of shape [H,W,2] + clip_flow (float, optional): Clip maximum of flow values. Defaults to None. + convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
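A small sketch of how the colorization helpers above can be exercised on a synthetic flow field (the field itself is made up for illustration):

```python
import numpy as np

H, W = 64, 64
ys, xs = np.meshgrid(np.linspace(-1, 1, H), np.linspace(-1, 1, W), indexing="ij")
u, v = xs.astype(np.float32), ys.astype(np.float32)   # radial flow, roughly within unit radius

rgb = flow_uv_to_colors(u, v)                          # (H, W, 3) uint8 color-coded flow
```

For raw, unnormalized flow, flow_to_image below wraps this call with optional clipping and normalization by the maximum flow magnitude.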
+ + Returns: + np.ndarray: Flow visualization image of shape [H,W,3] + """ + assert flow_uv.ndim == 3, 'input flow must have three dimensions' + assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' + if clip_flow is not None: + flow_uv = np.clip(flow_uv, 0, clip_flow) + u = flow_uv[:, :, 0] + v = flow_uv[:, :, 1] + if normalize: + rad = np.sqrt(np.square(u) + np.square(v)) + rad_max = np.max(rad) + else: + rad_max = np.sqrt(flow_uv.shape[0] ** 2 + flow_uv.shape[0] ** 2) + epsilon = 1e-5 + u = u / (rad_max + epsilon) + v = v / (rad_max + epsilon) + return flow_uv_to_colors(u, v, convert_to_bgr) + + +def flow_batch_to_images(flow_uv, clip_flow=None, convert_to_bgr=False, normalize=True): + flows = [flow_to_image(f.detach().cpu().numpy(), clip_flow, convert_to_bgr, normalize) for f in flow_uv.permute(0, 2, 3, 1)] + flows = torch.cat([torch.from_numpy(f).unsqueeze(0) for f in flows], dim=0) + return flows.to(flow_uv.device).permute(0, 3, 1, 2) / 255.0 diff --git a/video3d/utils/geometry.py b/video3d/utils/geometry.py new file mode 100755 index 0000000000000000000000000000000000000000..a5b68449761c7aa05a1e0dd4b93f03d8fc603781 --- /dev/null +++ b/video3d/utils/geometry.py @@ -0,0 +1,53 @@ +import torch +from einops import repeat + + +def sample_farthest_points(pts, k, return_index=False): + b, c, n = pts.shape + farthest_pts = torch.zeros((b, 3, k), device=pts.device, dtype=pts.dtype) + indexes = torch.zeros((b, k), device=pts.device, dtype=torch.int64) + + index = torch.randint(n, [b], device=pts.device) + + gather_index = repeat(index, 'b -> b c 1', c=c) + farthest_pts[:, :, 0] = torch.gather(pts, 2, gather_index)[:, :, 0] + indexes[:, 0] = index + distances = torch.norm(farthest_pts[:, :, 0][:, :, None] - pts, dim=1) + + for i in range(1, k): + _, index = torch.max(distances, dim=1) + gather_index = repeat(index, 'b -> b c 1', c=c) + farthest_pts[:, :, i] = torch.gather(pts, 2, gather_index)[:, :, 0] + indexes[:, i] = index + distances = torch.min(distances, torch.norm(farthest_pts[:, :, i][:, :, None] - pts, dim=1)) + + if return_index: + return farthest_pts, indexes + else: + return farthest_pts + + +def line_segment_distance(a, b, points, sqrt=True): + """ + compute the distance between a point and a line segment defined by a and b + a, b: ... x D + points: ... x D + """ + def sumprod(x, y, keepdim=True): + return torch.sum(x * y, dim=-1, keepdim=keepdim) + + a, b = a[..., None, :], b[..., None, :] + + t_min = sumprod(points - a, b - a) / torch.max(sumprod(b - a, b - a), torch.tensor(1e-6, device=a.device)) + + t_line = torch.clamp(t_min, 0.0, 1.0) + + # closest points on the line to every point + s = a + t_line * (b - a) + + distance = sumprod(s - points, s - points, keepdim=False) + + if sqrt: + distance = torch.sqrt(distance + 1e-6) + + return distance diff --git a/video3d/utils/meters.py b/video3d/utils/meters.py new file mode 100755 index 0000000000000000000000000000000000000000..dc1e3f5509e967f2b81e66c3cfe1be499e99fc0c --- /dev/null +++ b/video3d/utils/meters.py @@ -0,0 +1,180 @@ +import os +import json +import time +import torch +# import matplotlib.pyplot as plt +import collections + + +class TotalAverage(): + def __init__(self): + self.reset() + + def reset(self): + self.last_value = 0. + self.mass = 0. + self.sum = 0. 
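A usage sketch for the two geometry helpers above; shapes follow the (batch, xyz, num_points) layout that sample_farthest_points expects, and the point clouds are random placeholders.

```python
import torch

pts = torch.rand(2, 3, 1000)                               # (B, 3, N) point clouds
fps_pts, idx = sample_farthest_points(pts, k=64, return_index=True)
print(fps_pts.shape, idx.shape)                            # (2, 3, 64), (2, 64)

a, b = torch.zeros(2, 3), torch.ones(2, 3)                 # one segment per batch element
d = line_segment_distance(a, b, torch.rand(2, 100, 3))     # (2, 100) point-to-segment distances
```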
+ + def update(self, value, mass=1): + self.last_value = value + self.mass += mass + self.sum += value * mass + + def get(self): + return self.sum / self.mass + + +class MovingAverage(): + def __init__(self, inertia=0.9): + self.inertia = inertia + self.reset() + self.last_value = None + + def reset(self): + self.last_value = None + self.average = None + + def update(self, value, mass=1): + self.last_value = value + if self.average is None: + self.average = value + else: + self.average = self.inertia * self.average + (1 - self.inertia) * value + + def get(self): + return self.average + + +class MetricsTrace: + def __init__(self): + self.data = {} + self.reset() + + def reset(self): + self.data = {} + + def append(self, dataset, metric): + if dataset not in self.data: + self.data[dataset] = [] + self.data[dataset].append(metric.get_data_dict()) + + def load(self, path): + """Load the metrics trace from the specified JSON file.""" + with open(path, 'r') as f: + self.data = json.load(f) + + def save(self, path): + """Save the metrics trace to the specified JSON file.""" + if path is None: + return + os.makedirs(os.path.dirname(path), exist_ok=True) + with open(path, 'w') as f: + json.dump(self.data, f, indent=2) + + def plot(self, pdf_path=None): + """Plots and optionally save as PDF the metrics trace.""" + plot_metrics(self.data, pdf_path=pdf_path) + + def get(self): + return self.data + + def __str__(self): + pass + + +class Metrics(): + def __init__(self): + self.iteration_time = MovingAverage(inertia=0.9) + self.now = time.time() + + def update(self, prediction=None, ground_truth=None): + self.iteration_time.update(time.time() - self.now) + self.now = time.time() + + def get_data_dict(self): + return {"objective" : self.objective.get(), "iteration_time" : self.iteration_time.get()} + + +class StandardMetrics(Metrics): + def __init__(self, m=None): + super(StandardMetrics, self).__init__() + self.metrics = m or {} + self.speed = MovingAverage(inertia=0.9) + + def update(self, metric_dict, mass=1): + super(StandardMetrics, self).update() + for metric, val in metric_dict.items(): + if torch.is_tensor(val): + val = val.item() + if metric not in self.metrics: + if 'moving_average' in metric: + try: + p = float(metric.split('moving_average')[-1].split('_')[-1]) + except: + p = 0.9 + self.metrics[metric] = MovingAverage(p) + else: + self.metrics[metric] = TotalAverage() + self.metrics[metric].update(val, mass) + self.speed.update(mass / self.iteration_time.last_value) + + def get_data_dict(self): + data_dict = {k: v.get() for k,v in self.metrics.items()} + data_dict['speed'] = self.speed.get() + return data_dict + + def __str__(self): + pstr = '%7.1fHz\t' %self.speed.get() + pstr += '\t'.join(['%s: %6.5f' %(k,v.get()) for k,v in self.metrics.items()]) + return pstr + + +def plot_metrics(stats, pdf_path=None, fig=1, datasets=None, metrics=None): + """Plot metrics. `stats` should be a dictionary of type + + stats[dataset][t][metric][i] + + where dataset is the dataset name (e.g. `train` or `val`), t is an iteration number, + metric is the name of a metric (e.g. `loss` or `top1`), and i is a loss dimension. + + Alternatively, if a loss has a single dimension, `stats[dataset][t][metric]` can + be a scalar. 
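A sketch of how these meters are typically wired into a training loop; the loss values and the sleep are stand-ins for real work, and the save path is illustrative.

```python
import time

metrics = StandardMetrics()
trace = MetricsTrace()

for step in range(5):
    time.sleep(0.01)                                   # stand-in for a training step
    loss = 1.0 / (step + 1)
    metrics.update({"loss": loss, "loss_moving_average_0.9": loss}, mass=1)

trace.append("train", metrics)
print(metrics)                                         # e.g. "   95.3Hz  loss: ...  loss_moving_average_0.9: ..."
trace.save("results/metrics.json")
```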
+ + The supported options are: + + - pdf_file: path to a PDF file to store the figure (default: None) + - fig: MatPlotLib figure index (default: 1) + - datasets: list of dataset names to plot (default: None) + - metrics: list of metrics to plot (default: None) + """ + plt.figure(fig) + plt.clf() + linestyles = ['-', '--', '-.', ':'] + datasets = list(stats.keys()) if datasets is None else datasets + # Filter out empty datasets + datasets = [d for d in datasets if len(stats[d]) > 0] + duration = len(stats[datasets[0]]) + metrics = list(stats[datasets[0]][0].keys()) if metrics is None else metrics + for m, metric in enumerate(metrics): + plt.subplot(len(metrics),1,m+1) + legend_content = [] + for d, dataset in enumerate(datasets): + ls = linestyles[d % len(linestyles)] + if isinstance(stats[dataset][0][metric], collections.Iterable): + metric_dimension = len(stats[dataset][0][metric]) + for sl in range(metric_dimension): + x = [stats[dataset][t][metric][sl] for t in range(duration)] + plt.plot(x, linestyle=ls) + name = f'{dataset} {metric}[{sl}]' + legend_content.append(name) + else: + x = [stats[dataset][t][metric] for t in range(duration)] + plt.plot(x, linestyle=ls) + name = f'{dataset} {metric}' + legend_content.append(name) + plt.legend(legend_content, loc=(1.04,0)) + plt.grid(True) + if pdf_path is not None: + plt.savefig(pdf_path, format='pdf', bbox_inches='tight') + plt.draw() + plt.pause(0.0001) diff --git a/video3d/utils/misc.py b/video3d/utils/misc.py new file mode 100755 index 0000000000000000000000000000000000000000..2add7b48132e8e31af893bf0961f34fd52096c18 --- /dev/null +++ b/video3d/utils/misc.py @@ -0,0 +1,334 @@ +import os +import glob +import yaml +import random +import numpy as np +import cv2 +import torch +import torchvision.utils as tvutils +import zipfile +import argparse +from ..render.obj import write_obj, write_textured_obj +import einops +import torch.distributed as dist + + + +def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + + +def setup_runtime(args): + """Load configs, initialize CUDA, CuDNN and the random seeds.""" + + # Setup CUDA + cuda_device_id = args.gpu + if cuda_device_id is not None: + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = str(cuda_device_id) + if torch.cuda.is_available(): + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.deterministic = False + + # Setup random seeds for reproducibility + random.seed(args.seed) + np.random.seed(args.seed) + cv2.setRNGSeed(args.seed) + torch.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + + ## Load config + cfgs = {} + if args.config is not None and os.path.isfile(args.config): + cfgs = load_yaml(args.config) + + cfgs['config'] = args.config + cfgs['seed'] = args.seed + cfgs['num_workers'] = args.num_workers + cfgs['device'] = f"cuda:{args.rank}" if torch.cuda.is_available() and cuda_device_id is not None else 'cpu' + + print(f"Environment: GPU {cuda_device_id} - seed {args.seed}") + return cfgs + + +def load_yaml(path): + print(f"Loading configs from {path}") + with open(path, 'r') as f: + return yaml.safe_load(f) + + +def dump_yaml(path, cfgs): + print(f"Saving configs to {path}") + xmkdir(os.path.dirname(path)) + with open(path, 'w') as f: + return 
yaml.safe_dump(cfgs, f) + + +def xmkdir(path): + """Create directory PATH recursively if it does not exist.""" + os.makedirs(path, exist_ok=True) + + +def clean_checkpoint(checkpoint_dir, keep_num=2): + if keep_num > 0: + names = list(sorted( + glob.glob(os.path.join(checkpoint_dir, 'checkpoint*.pth')) + )) + if len(names) > keep_num: + for name in names[:-keep_num]: + print(f"Deleting obslete checkpoint file {name}") + os.remove(name) + + +def archive_code(arc_path, filetypes=['.py']): + print(f"Archiving code to {arc_path}") + xmkdir(os.path.dirname(arc_path)) + zipf = zipfile.ZipFile(arc_path, 'w', zipfile.ZIP_DEFLATED) + cur_dir = os.getcwd() + flist = [] + for ftype in filetypes: + flist.extend(glob.glob(os.path.join(cur_dir, '[!results]*', '**', '*'+ftype), recursive=True)) # ignore results folder + flist.extend(glob.glob(os.path.join(cur_dir, '*'+ftype))) + [zipf.write(f, arcname=f.replace(cur_dir,'archived_code', 1)) for f in flist] + zipf.close() + + +def get_model_device(model): + return next(model.parameters()).device + + +def set_requires_grad(nets, requires_grad=False): + if not isinstance(nets, list): + nets = [nets] + for net in nets: + if net is not None: + for param in net.parameters(): + param.requires_grad = requires_grad + + +def save_videos(out_fold, imgs, prefix='', suffix='', fnames=None, ext='.mp4', cycle=False): + prefix = prefix + '_' if prefix else '' + suffix = '_' + suffix if suffix else '' + + imgs = imgs.transpose(0,1,3,4,2) # BxTxCxHxW -> BxTxHxWxC + for i, fs in enumerate(imgs): + if cycle: + fs = np.concatenate([fs, fs[::-1]], 0) + fourcc = cv2.VideoWriter_fourcc(*'mp4v') + # fourcc = cv2.VideoWriter_fourcc(*'avc1') + + out_fold_i = out_fold[i] if isinstance(out_fold, list) else out_fold + xmkdir(out_fold_i) + + if fnames is None: + idx = len(glob.glob(os.path.join(out_fold_i, prefix+'*'+suffix+ext))) +1 + fname = '%07d' % idx + else: + fname = fnames[i] + fpath = os.path.join(out_fold_i, prefix+fname+suffix+ext) + + vid = cv2.VideoWriter(fpath, fourcc, 5, (fs.shape[2], fs.shape[1])) + [vid.write(np.uint8(f[...,::-1]*255.)) for f in fs] + vid.release() + + +def save_images(out_fold, imgs, prefix='', suffix='', fnames=None, ext='.png'): + prefix = prefix + '_' if prefix else '' + suffix = '_' + suffix if suffix else '' + + imgs = imgs.transpose(0,2,3,1) + for i, img in enumerate(imgs): + img = np.concatenate([np.flip(img[...,:3], -1), img[...,3:]], -1) # RGBA to BGRA + if 'depth' in suffix: + im_out = np.uint16(img*65535.) + else: + im_out = np.uint8(img*255.) 
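+        # Depth outputs are saved as 16-bit images (scaled by 65535) to preserve precision; all other outputs use the usual 8-bit range.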
+ + out_fold_i = out_fold[i] if isinstance(out_fold, list) else out_fold + xmkdir(out_fold_i) + + if fnames is None: + idx = len(glob.glob(os.path.join(out_fold_i, prefix+'*'+suffix+ext))) +1 + fname = '%07d' % idx + else: + fname = fnames[i] + fpath = os.path.join(out_fold_i, prefix+fname+suffix+ext) + + cv2.imwrite(fpath, im_out) + + +def save_txt(out_fold, data, prefix='', suffix='', fnames=None, ext='.txt', fmt='%.6f'): + prefix = prefix + '_' if prefix else '' + suffix = '_' + suffix if suffix else '' + + for i, d in enumerate(data): + out_fold_i = out_fold[i] if isinstance(out_fold, list) else out_fold + xmkdir(out_fold_i) + + if fnames is None: + idx = len(glob.glob(os.path.join(out_fold_i, prefix+'*'+suffix+ext))) +1 + fname = '%07d' % idx + else: + fname = fnames[i] + fpath = os.path.join(out_fold_i, prefix+fname+suffix+ext) + + np.savetxt(fpath, d, fmt=fmt, delimiter=', ') + + +def save_obj(out_fold, meshes=None, save_material=True, feat=None, prefix='', suffix='', fnames=None, resolution=[256, 256], prior_shape=None): + prefix = prefix + '_' if prefix else '' + suffix = '_' + suffix if suffix else '' + + if meshes.v_pos is None: + return + + batch_size = meshes.v_pos.shape[0] + for i in range(batch_size): + out_fold_i = out_fold[i] if isinstance(out_fold, list) else out_fold + xmkdir(out_fold_i) + + if fnames is None: + idx = len(glob.glob(os.path.join(out_fold_i, prefix+'*'+suffix+".obj"))) + 1 + fname = '%07d' % idx + else: + fname = fnames[i] + if save_material: + os.makedirs(os.path.join(out_fold_i, fname), exist_ok=True) + write_textured_obj(out_fold_i, f'{fname}/{prefix+suffix}', meshes, i, save_material=save_material, feat=feat, resolution=resolution, prior_shape=prior_shape) + else: + write_obj(out_fold_i, prefix+fname+suffix, meshes, i, save_material=False, feat=feat, resolution=resolution) + + +def compute_sc_inv_err(d_pred, d_gt, mask=None): + b = d_pred.size(0) + diff = d_pred - d_gt + if mask is not None: + diff = diff * mask + avg = diff.view(b, -1).sum(1) / (mask.view(b, -1).sum(1)) + score = (diff - avg.view(b,1,1))**2 * mask + else: + avg = diff.view(b, -1).mean(1) + score = (diff - avg.view(b,1,1))**2 + return score # masked error maps + + +def compute_angular_distance(n1, n2, mask=None): + dist = (n1*n2).sum(3).clamp(-1,1).acos() /np.pi*180 + return dist*mask if mask is not None else dist + + +def save_scores(out_path, scores, header=''): + print('Saving scores to %s' %out_path) + np.savetxt(out_path, scores, fmt='%.8f', delimiter=',\t', header=header) + + +def image_grid(tensor, nrow=None): + # check if list -> stack to numpy array + if isinstance(tensor, list): + tensor = np.stack(tensor, 0) + # check if numpy array -> convert to torch tensor and swap axes + if isinstance(tensor, np.ndarray): + tensor = torch.from_numpy(tensor).permute(0, 3, 1, 2) + + b, c, h, w = tensor.shape + if nrow is None: + nrow = int(np.ceil(b**0.5)) + if c == 1: + tensor = tensor.repeat(1, 3, 1, 1) + tensor = tvutils.make_grid(tensor, nrow=nrow, normalize=False) + return tensor + + +def video_grid(tensor, nrow=None): + return torch.stack([image_grid(t, nrow=nrow) for t in tensor.unbind(1)], 0) + + +class LazyClass(object): + def __init__(self, cls, *args, **kwargs): + self.cls = cls + self.args = args + self.kwargs = kwargs + self.instance = None + + def get_instance(self): + if self.instance is None: + self.instance = self.cls(*self.args, **self.kwargs) + return self.instance + + def __call__(self, *args, **kwargs): + return self.get_instance()(*args, **kwargs) + + def 
__getattribute__(self, name): + if name in ['cls', 'args', 'kwargs', 'instance', 'get_instance']: + return super().__getattribute__(name) + else: + return getattr(self.get_instance(), name) + +def add_text_to_image(img, text, pos=(12, 12), color=(1, 1, 1), font_scale=1, thickness=2): + if isinstance(img, torch.Tensor): + img = img.permute(1,2,0).cpu().numpy() + # if grayscale -> convert to RGB + if img.shape[2] == 1: + img = np.repeat(img, 3, 2) + img = cv2.putText(np.ascontiguousarray(img), text, pos, cv2.FONT_HERSHEY_SIMPLEX, font_scale, color, thickness) + return img + +def image_grid_multi_channel(tensor, pca=False, texts=None, font_scale=0.5): + """ + visualize multi-channel image, each channel is a different greyscale image + tensor: (b, c, h, w) + texts: list of strings of length b + """ + # rescale to [0, 1] for per each sample in batch + tensor = tensor.detach().cpu() + min_ = einops.reduce(tensor, 'b c h w -> b 1 1 1', 'min') + max_ = einops.reduce(tensor, 'b c h w -> b 1 1 1', 'max') + tensor = (tensor - min_) / (max_ - min_) + if pca: + import faiss + (b, c, h, w) = tensor.shape + # reshape the tensor to (b, c*h*w) + # tensor = tensor.reshape(b, c*h*w) + tensor_flat = einops.rearrange(tensor, 'b c h w -> (b h w) c') + pca_mat = faiss.PCAMatrix(c, 3) + pca_mat.train(tensor_flat.numpy()) + assert pca_mat.is_trained + tensor_flat_pca = pca_mat.apply_py(tensor_flat.numpy()) + tensor = einops.rearrange(tensor_flat_pca, '(b h w) c -> b h w c', b=b, c=3, h=h, w=w) + else: + tensor = einops.rearrange(tensor, 'b c h w -> (b c) 1 h w') + if texts is not None: + # duplicate texts for each channel + texts = [text for text in texts for _ in range(tensor.shape[0] // len(texts))] + tensor = [add_text_to_image(img, text, font_scale=font_scale) for img, text in zip(tensor, texts)] + return image_grid(tensor) + + +########## DDP Part Taken from: https://github.com/fundamentalvision/Deformable-DETR/blob/main/util/misc.py + +def is_main_process(): + return get_rank() == 0 + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + diff --git a/video3d/utils/segmentation_transforms.py b/video3d/utils/segmentation_transforms.py new file mode 100755 index 0000000000000000000000000000000000000000..6caa8413c0332aadb75aa33438641168170c3f37 --- /dev/null +++ b/video3d/utils/segmentation_transforms.py @@ -0,0 +1,101 @@ +import numpy as np +from PIL import Image +import random + +import torch +from torchvision import transforms as T +from torchvision.transforms import functional as F + + +def pad_if_smaller(img, size, fill=0): + min_size = min(img.size) + if min_size < size: + ow, oh = img.size + padh = size - oh if oh < size else 0 + padw = size - ow if ow < size else 0 + img = F.pad(img, (0, 0, padw, padh), fill=fill) + return img + + +class Compose(object): + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, image, target): + for t in self.transforms: + image, target = t(image, target) + return image, target + + +class ImageOnly(object): + def __init__(self, image_transform): + self.transform = image_transform + + def __call__(self, image, target): + image = self.transform(image) + return image, target + + +class RandomResize(object): + def __init__(self, min_size, max_size=None): + self.min_size = min_size + if max_size is None: + max_size = min_size + self.max_size = max_size + + 
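A small sketch of what the masked error map from compute_sc_inv_err above measures: the per-image mean difference is subtracted out, so a constant depth offset between prediction and ground truth contributes nothing (tensors here are random placeholders).

```python
import torch

d_gt = torch.rand(2, 32, 32)
d_pred = d_gt + 0.3                                  # prediction differs by a constant offset only
mask = torch.ones_like(d_gt)
err_map = compute_sc_inv_err(d_pred, d_gt, mask)     # (2, 32, 32), ≈ 0 everywhere
print(err_map.abs().max())
```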
def __call__(self, image, target): + size = random.randint(self.min_size, self.max_size) + image = F.resize(image, size) + target = F.resize(target, size, interpolation=Image.NEAREST) + return image, target + + +class RandomHorizontalFlip(object): + def __init__(self, flip_prob): + self.flip_prob = flip_prob + + def __call__(self, image, target): + if random.random() < self.flip_prob: + image = F.hflip(image) + target = F.hflip(target) + return image, target + + +class RandomCrop(object): + def __init__(self, size): + self.size = size + + def __call__(self, image, target): + image = pad_if_smaller(image, self.size) + target = pad_if_smaller(target, self.size, fill=0) + crop_params = T.RandomCrop.get_params(image, (self.size, self.size)) + image = F.crop(image, *crop_params) + target = F.crop(target, *crop_params) + return image, target + + +class CenterCrop(object): + def __init__(self, size): + self.size = size + + def __call__(self, image, target): + image = F.center_crop(image, self.size) + target = F.center_crop(target, self.size) + return image, target + + +class ToTensor(object): + def __call__(self, image, target): + image = F.to_tensor(image) + target = F.to_tensor(target) + return image, target + + +class Normalize(object): + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, image, target): + image = F.normalize(image, mean=self.mean, std=self.std) + return image, target diff --git a/video3d/utils/skinning_v4.py b/video3d/utils/skinning_v4.py new file mode 100755 index 0000000000000000000000000000000000000000..f71f99f044ac30d6c8b5132096946bed892f7d1f --- /dev/null +++ b/video3d/utils/skinning_v4.py @@ -0,0 +1,525 @@ +import math +import torch +import torch.nn as nn +from . import geometry +from einops import rearrange +import itertools + + +def _joints_to_bones(joints, bones_idxs): + bones = [] + for a, b in bones_idxs: + bones += [torch.stack([joints[:, :, a, :], joints[:, :, b, :]], dim=2)] + bones = torch.stack(bones, dim=2) + return bones + + +def _compute_vertices_to_bones_weights(bones_pred, seq_shape_pred, temperature=1): + vertices_to_bones = [] + for i in range(bones_pred.shape[2]): + vertices_to_bones += [geometry.line_segment_distance(bones_pred[:, :, i, 0], bones_pred[:, :, i, 1], seq_shape_pred)] + # vertices_to_bones = nn.functional.softmax(1 / torch.stack(vertices_to_bones) / temperature, dim=0) + vertices_to_bones = nn.functional.softmax(-torch.stack(vertices_to_bones) / temperature, dim=0) + return vertices_to_bones + + +def build_kinematic_chain(n_bones, start_bone_idx): + # build bones and kinematic chain starting from leaf bone (body joint) + bones_to_joints = [] + kinematic_chain = [] + bone_idx = start_bone_idx + # bones from leaf to root + dependent_bones = [] + for i in range(n_bones): + bones_to_joints += [(i + 1, i)] + kinematic_chain = [(bone_idx, dependent_bones)] + kinematic_chain # parent is always in the front + dependent_bones = dependent_bones + [bone_idx] + bone_idx += 1 + return bones_to_joints, kinematic_chain, dependent_bones + + +def update_body_kinematic_chain(kinematic_chain, leg_kinematic_chain, body_bone_idx, leg_bone_idxs, attach_legs_to_body=True): + if attach_legs_to_body: + for bone_idx, dependent_bones in kinematic_chain: + if bone_idx == body_bone_idx or body_bone_idx in dependent_bones: + dependent_bones += leg_bone_idxs + kinematic_chain = kinematic_chain + leg_kinematic_chain # parent is always in the front + return kinematic_chain + + +def lift_points_mesh(points, seq_shape, 
size_aspect=0.5): + """ + for a set of points that's generated by linear interpolation, lift them in y-axis to match the actual bones + this operates on all the joint points except for the first and last one + """ + + points_to_lift = points[:, :, 1:-1, :] + points_z_range_max = points_to_lift[..., 2] - size_aspect * (points_to_lift[..., 2] - points[:, :, :-2, 2]) + points_z_range_min = points_to_lift[..., 2] - size_aspect * (points_to_lift[..., 2] - points[:, :, 2:, 2]) + + points_z_range_min = points_z_range_min.unsqueeze(-1).expand(-1, -1, -1, seq_shape.shape[-2]) + points_z_range_max = points_z_range_max.unsqueeze(-1).expand(-1, -1, -1, seq_shape.shape[-2]) + + valid_points = seq_shape.unsqueeze(2).expand(-1, -1, points_to_lift.shape[-2], -1, -1) + + valid_idx_1 = valid_points[..., 2] > points_z_range_min + valid_idx_2 = valid_points[..., 2] < points_z_range_max + + valid_idx = valid_idx_1 * valid_idx_2 + valid_idx = valid_idx.float() + + valid_y = valid_points[..., 1] * valid_idx + (-1e6) * (1 - valid_idx) + + valid_y, _ = valid_y.max(dim=-1) + is_valid = valid_y != (-1e6) + is_valid = is_valid.float() + + points[:, :, 1:-1, 1] = points[:, :, 1:-1, 1] * (1-is_valid) + valid_y * is_valid + + return points + + +@torch.no_grad() +def estimate_bones(seq_shape, n_body_bones, resample=False, n_legs=4, n_leg_bones=0, body_bones_type='z_minmax', compute_kinematic_chain=True, aux=None, attach_legs_to_body=True, bone_y_threshold=None, body_bone_idx_preset=[3, 5, 5, 3]): + """ + Estimate the position and structure of bones given the mesh vertex positions. + + Args: + seq_shape: a tensor of shape (B, F, V, 3), the batched position of mesh vertices. + n_body_bones: an integer, the desired number of bones. + Returns: + (bones_pred, kinematic_chain) where + bones_pred: a tensor of shape (B, F, num_bones, 2, 3) + kinematic_chain: a list of tuples of length n_body_bones; for each tuple, the first element is the bone index while + the second element is a list of bones indices of dependent bones. 
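To make the (bone index, dependent bones) bookkeeping concrete, this is what build_kinematic_chain defined above produces for a single 3-bone chain rooted at bone 0 (purely illustrative):

```python
bones_to_joints, kinematic_chain, dependent_bones = build_kinematic_chain(3, start_bone_idx=0)
print(bones_to_joints)    # [(1, 0), (2, 1), (3, 2)]            bone i joins joints (i+1, i)
print(kinematic_chain)    # [(2, [0, 1]), (1, [0]), (0, [])]    parents listed first; each entry
                          # carries the leaf-ward bones whose pose depends on it
print(dependent_bones)    # [0, 1, 2]
```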
+ """ + # preprocess shape + if resample: + b, _, n, _ = seq_shape.shape + seq_shape = geometry.sample_farthest_points(rearrange(seq_shape, 'b f n d -> (b f) d n'), n // 4) + seq_shape = rearrange(seq_shape, '(b f) d n -> b f n d', b=b) + + if body_bones_type == 'max_distance': + raise NotImplementedError + # find two farthest points + # x is the symmetry plane, ignore it + # dists = torch.linalg.norm(seq_shape[:, :, None, :, 1:] - seq_shape[:, :, :, None, 1:], dim=-1) # Shape: (B, F, V, V) + # num_verts = dists.shape[-1] + # indices_flat = rearrange(dists, 'b f d1 d2 -> b f (d1 d2)').argmax(2) # Shape: (B, F) + # indices = torch.cat([(indices_flat // num_verts)[..., None], (indices_flat % num_verts)[..., None]], dim=2) # Shape: (B, F, 2) + # indices_gather = indices[..., None].repeat(1, 1, 1, 3) # Shape: (B, F, 2, 3) + # points = seq_shape.gather(2, indices_gather) # Shape: (B, F, 2, 3) + # fix the points order along z axis + # z_coordinate = points[:, :, :, 2] # Shape: (B, F, 2) + # front = z_coordinate < 0 + # point_a = rearrange(points[~front], '(b f) d -> b f d', b=seq_shape.shape[0]) # Shape: (B, F, 3) + # point_b = rearrange(points[front], '(b f) d -> b f d', b=seq_shape.shape[0]) # Shape: (B, F, 3) + elif body_bones_type == 'z_minmax': + indices_max, indices_min = seq_shape[..., 2].argmax(dim=2), seq_shape[..., 2].argmin(dim=2) + indices = torch.cat([indices_max[..., None], indices_min[..., None]], dim=2) + indices_gather = indices[..., None].repeat(1, 1, 1, 3) # Shape: (B, F, 2, 3) + points = seq_shape.gather(2, indices_gather) + point_a = points[:, :, 0, :] + point_b = points[:, :, 1, :] + elif body_bones_type == 'z_minmax_y+': + ## TODO: mean may not be very robust, as inside is noisy + mid_point = seq_shape.mean(2) + seq_shape_pos_y_mask = (seq_shape[:, :, :, 1] > (mid_point[:, :, None, 1] - 0.5)).float() # y higher than midpoint + seq_shape_z = seq_shape[:, :, :, 2] * seq_shape_pos_y_mask + (-1e6) * (1 - seq_shape_pos_y_mask) + indices = seq_shape_z.argmax(2) + indices_gather = indices[..., None, None].repeat(1, 1, 1, 3) + point_a = seq_shape.gather(2, indices_gather).squeeze(2) + seq_shape_z = seq_shape[:, :, :, 2] * seq_shape_pos_y_mask + 1e6 * (1 - seq_shape_pos_y_mask) + indices = seq_shape_z.argmin(2) + indices_gather = indices[..., None, None].repeat(1, 1, 1, 3) + point_b = seq_shape.gather(2, indices_gather).squeeze(2) + elif body_bones_type == 'mine': + ## TODO: mean may not be very robust, as inside is noisy + mid_point = seq_shape.mean(2) + seq_shape_pos_y_mask = (seq_shape[:, :, :, 1] > (mid_point[:, :, None, 1] - 0.5)).float() # y higher than midpoint + seq_shape_z = seq_shape[:, :, :, 2] * seq_shape_pos_y_mask + (-1e6) * (1 - seq_shape_pos_y_mask) + indices = seq_shape_z.argmax(2) + indices_gather = indices[..., None, None].repeat(1, 1, 1, 3) + point_a = seq_shape.gather(2, indices_gather).squeeze(2) + seq_shape_z = seq_shape[:, :, :, 2] * seq_shape_pos_y_mask + 1e6 * (1 - seq_shape_pos_y_mask) + indices = seq_shape_z.argmin(2) + indices_gather = indices[..., None, None].repeat(1, 1, 1, 3) + point_b = seq_shape.gather(2, indices_gather).squeeze(2) + else: + raise NotImplementedError + + # place points on the symmetry axis + point_a[..., 0] = 0 + point_b[..., 0] = 0 + + mid_point = seq_shape.mean(2) # Shape: (B, F, 3) + # place points on the symmetry axis + mid_point[..., 0] = 0 + if n_leg_bones > 0: + mid_point[..., 1] += 0.5 # lift mid point a bit higher if there are legs + + assert n_body_bones % 2 == 0 + n_joints = n_body_bones + 1 + blend = torch.linspace(0., 
1., math.ceil(n_joints / 2), device=point_a.device)[None, None, :, None] # Shape: (1, 1, (n_joints + 1) / 2, 1) + joints_a = point_a[:, :, None, :] * (1 - blend) + mid_point[:, :, None, :] * blend + # point_a to mid_point + joints_b = point_b[:, :, None, :] * blend + mid_point[:, :, None, :] * (1 - blend) + # mid_point to point_b + joints = torch.cat([joints_a[:, :, :-1], joints_b], 2) # Shape: (B, F, n_joints, 3) + + if body_bones_type == 'mine': + joints = lift_points_mesh(joints, seq_shape) + + # build bones and kinematic chain starting from leaf bones + if compute_kinematic_chain: + aux = {} + half_n_body_bones = n_body_bones // 2 + bones_to_joints = [] + kinematic_chain = [] + bone_idx = 0 + # bones from point_a to mid_point + dependent_bones = [] + for i in range(half_n_body_bones): + bones_to_joints += [(i + 1, i)] + kinematic_chain = [(bone_idx, dependent_bones)] + kinematic_chain # parent is always in the front + dependent_bones = dependent_bones + [bone_idx] + bone_idx += 1 + # bones from point_b to mid_point + dependent_bones = [] + for i in range(n_body_bones - 1, half_n_body_bones - 1, -1): + bones_to_joints += [(i, i + 1)] + kinematic_chain = [(bone_idx, dependent_bones)] + kinematic_chain # parent is always in the front + dependent_bones = dependent_bones + [bone_idx] + bone_idx += 1 + aux['bones_to_joints'] = bones_to_joints + else: + bones_to_joints = aux['bones_to_joints'] + kinematic_chain = aux['kinematic_chain'] + + bones_pred = _joints_to_bones(joints, bones_to_joints) + + if n_leg_bones > 0: + assert n_legs == 4 + # attach four legs + # y, z is symetry plain + # y axis is up + # + # top down view: + # + # | + # 2 | 1 + # -------|------ > x + # 3 | 0 + # ⌄ + # z + # + # find a point with the lowest y in each quadrant + # max_dist = (point_a - point_b).norm(p=2, dim=-1) + xs, ys, zs = seq_shape.unbind(-1) + # if bone_y_threshold is not None: + # flags = (ys < bone_y_threshold) + # x_margin = (xs[flags].quantile(0.95) - xs[flags].quantile(0.05)) * 0.2 + # x0 = xs[flags].quantile(0.5) + # else: + # x_margin = (xs.quantile(0.95) - xs.quantile(0.05)) * 0.2 + # x0 = 0 + if bone_y_threshold is None: + x_margin = (xs.quantile(0.95) - xs.quantile(0.05)) * 0.2 + x0 = 0 + quadrant0 = torch.logical_and(xs - x0 > x_margin, zs > 0) + quadrant1 = torch.logical_and(xs - x0 > x_margin, zs < 0) + quadrant2 = torch.logical_and(xs - x0 < -x_margin, zs < 0) + quadrant3 = torch.logical_and(xs - x0 < -x_margin, zs > 0) + + else: + y_threshold = ys.quantile(bone_y_threshold) + flags = (ys < y_threshold) + + x0 = xs[flags].quantile(0.5) + z0 = zs[flags].quantile(0.5) + x_margin = (xs[flags].quantile(0.95) - xs[flags].quantile(0.05)) * 0.2 + z_margin = (zs[flags].quantile(0.95) - zs[flags].quantile(0.05)) * 0.2 + + # quadrant0 = torch.logical_and(xs - x0 > x_margin, zs > z0) + # quadrant1 = torch.logical_and(xs - x0 > x_margin, zs < z0) + # quadrant2 = torch.logical_and(xs - x0 < -x_margin, zs < z0) + # quadrant3 = torch.logical_and(xs - x0 < -x_margin, zs > z0) + + quadrant0 = torch.logical_and(xs - x0 > x_margin, zs - z0 > z_margin) + quadrant1 = torch.logical_and(xs - x0 > x_margin, zs < z0) + quadrant2 = torch.logical_and(xs - x0 < -x_margin, zs < z0) + quadrant3 = torch.logical_and(xs - x0 < -x_margin, zs - z0 > z_margin) + + def find_leg_in_quadrant(quadrant, n_bones, body_bone_idx, body_bones_type=None): + all_joints = torch.zeros([seq_shape.shape[0], seq_shape.shape[1], n_bones + 1, 3], dtype=seq_shape.dtype, device=seq_shape.device) + for b in range(seq_shape.shape[0]): + for f in 
range(seq_shape.shape[1]): + # find a point with the lowest y + quadrant_points = seq_shape[b, f][quadrant[b, f]] + if len(quadrant_points.view(-1)) < 1: + import pdb; pdb.set_trace() + + idx = torch.argmin(quadrant_points[:, 1]) ## lowest y + foot = quadrant_points[idx] + + # find closest point on the body joints (the end joint of the bone) + if body_bone_idx is None: + body_bone_idx_1 = int(torch.argmin(torch.norm(bones_pred[b, f, :, 1] - foot[None], dim=1))) + body_bone_idx_2 = int(torch.argmin((bones_pred[b, f, :, 1, 2] - foot[None, 2]).abs())) # closest in z axis + # if the body_bone_idx_1 is 4, then should use body_bone_idx_2 + # body_bone_idx = body_bone_idx_1 if body_bone_idx_1 != 4 else body_bone_idx_2 # this is used for distribution loss caused tilt effect + body_bone_idx = body_bone_idx_2 + body_joint = bones_pred[b, f, body_bone_idx, 1] + + # create bone structure from the foot to the body joint + blend = torch.linspace(0., 1., n_bones + 1, device=seq_shape.device)[:, None] + joints = foot[None] * (1 - blend) + body_joint[None] * blend + all_joints[b, f] = joints + return all_joints, body_bone_idx + + quadrants = [quadrant0, quadrant1, quadrant2, quadrant3] + # body_bone_idxs = [None, None, None, None] + # body_bone_idxs = [3, 5, 5, 3] + # body_bone_idxs = [2, 6, 6, 2] + # body_bone_idxs = [2, 7, 7, 2] + # body_bone_idxs = [3, 6, 6, 3] + if body_bone_idx_preset == [0, 0, 0, 0]: + body_bone_idx_preset = [None, None, None, None] + body_bone_idxs = body_bone_idx_preset + + start_bone_idx = n_body_bones + all_leg_bones = [] + if compute_kinematic_chain: + leg_auxs = [] + else: + leg_auxs = aux['legs'] + for i, quadrant in enumerate(quadrants): + if compute_kinematic_chain: + leg_i_aux = {} + body_bone_idx = body_bone_idxs[i] + if i == 2: + body_bone_idx = body_bone_idxs[1] + elif i == 3: + body_bone_idx = body_bone_idxs[0] + + leg_joints, body_bone_idx = find_leg_in_quadrant(quadrant, n_leg_bones, body_bone_idx=body_bone_idx, body_bones_type=body_bones_type) + body_bone_idxs[i] = body_bone_idx + + leg_bones_to_joints, leg_kinematic_chain, leg_bone_idxs = build_kinematic_chain(n_leg_bones, start_bone_idx=start_bone_idx) + kinematic_chain = update_body_kinematic_chain(kinematic_chain, leg_kinematic_chain, body_bone_idx, leg_bone_idxs, attach_legs_to_body=attach_legs_to_body) + leg_i_aux['body_bone_idx'] = body_bone_idx + leg_i_aux['leg_bones_to_joints'] = leg_bones_to_joints + start_bone_idx += n_leg_bones + else: + leg_i_aux = leg_auxs[i] + body_bone_idx = leg_i_aux['body_bone_idx'] + leg_joints, _ = find_leg_in_quadrant(quadrant, n_leg_bones, body_bone_idx, body_bones_type=body_bones_type) + leg_bones_to_joints = leg_i_aux['leg_bones_to_joints'] + leg_bones = _joints_to_bones(leg_joints, leg_bones_to_joints) + all_leg_bones += [leg_bones] + if compute_kinematic_chain: + leg_auxs += [leg_i_aux] + + all_bones = [bones_pred] + all_leg_bones + all_bones = torch.cat(all_bones, dim=2) + else: + all_bones = bones_pred + + if compute_kinematic_chain: + aux['kinematic_chain'] = kinematic_chain + if n_leg_bones > 0: + aux['legs'] = leg_auxs + return all_bones.detach(), kinematic_chain, aux + else: + return all_bones.detach() + + +def _estimate_bone_rotation(forward): + """ + (0, 0, 1) = matmul(b, R^(-1)) + + assumes y, z is a symmetry plane + + returns R + """ + forward = nn.functional.normalize(forward, p=2, dim=-1) + + right = torch.FloatTensor([[1, 0, 0]]).to(forward.device) + right = right.expand_as(forward) + up = torch.cross(forward, right, dim=-1) + up = nn.functional.normalize(up, 
p=2, dim=-1) + right = torch.cross(up, forward, dim=-1) + up = nn.functional.normalize(up, p=2, dim=-1) + + R = torch.stack([right, up, forward], dim=-1) + + return R + + +def children_to_parents(kinematic_tree): + """ + converts list [(bone1, [children1, ...]), (bone2, [children1, ...]), ...] to [(bone1, [parent1, ...]), ....] + """ + parents = [] + for bone_id, _ in kinematic_tree: + # establish a kinematic chain with current bone as the leaf bone + parents_ids = [parent_id for parent_id, children in kinematic_tree if bone_id in children] + parents += [(bone_id, parents_ids)] + return parents + + +def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor: + """ [Borrowed from PyTorch3D] + Return the rotation matrices for one of the rotations about an axis + of which Euler angles describe, for each value of the angle given. + + Args: + axis: Axis label "X" or "Y or "Z". + angle: any shape tensor of Euler angles in radians + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + + cos = torch.cos(angle) + sin = torch.sin(angle) + one = torch.ones_like(angle) + zero = torch.zeros_like(angle) + + if axis == "X": + R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) + elif axis == "Y": + R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) + elif axis == "Z": + R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) + else: + raise ValueError("letter must be either X, Y or Z.") + + return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) + + +def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch.Tensor: + """ [Borrowed from PyTorch3D] + Convert rotations given as Euler angles in radians to rotation matrices. + + Args: + euler_angles: Euler angles in radians as tensor of shape (..., 3). + convention: Convention string of three uppercase letters from + {"X", "Y", and "Z"}. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). 
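A quick, self-contained sanity check of the axis-angle helper above: a 90° rotation about Z should map the x-axis onto the y-axis.

```python
import math
import torch

R = _axis_angle_rotation("Z", torch.tensor(math.pi / 2))   # (3, 3) rotation matrix
x = torch.tensor([1.0, 0.0, 0.0])
print(R @ x)                                                # ≈ tensor([0., 1., 0.])
```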
+ """ + if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3: + raise ValueError("Invalid input euler angles.") + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + matrices = [ + _axis_angle_rotation(c, e) + for c, e in zip(convention, torch.unbind(euler_angles, -1)) + ] + return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2]) + + +def _prepare_transform_mtx(rotation=None, translation=None): + mtx = torch.eye(4)[None] + if rotation is not None: + if len(mtx) != len(rotation): + assert len(mtx) == 1 + mtx = mtx.repeat(len(rotation), 1, 1) + mtx = mtx.to(rotation.device) + mtx[:, :3, :3] = rotation + if translation is not None: + if len(mtx) != len(translation): + assert len(mtx) == 1 + mtx = mtx.repeat(len(translation), 1, 1) + mtx = mtx.to(translation.device) + mtx[:, :3, 3] = translation + return mtx + + +def _invert_transform_mtx(mtx): + inv_mtx = torch.eye(4)[None].repeat(len(mtx), 1, 1).to(mtx.device) + rotation = mtx[:, :3, :3] + translation = mtx[:, :3, 3] + inv_mtx[:, :3, :3] = rotation.transpose(1, 2) + inv_mtx[:, :3, 3] = -torch.bmm(rotation.transpose(1, 2), translation.unsqueeze(-1)).squeeze(-1) + return inv_mtx + + +def skinning(v_pos, bones_pred, kinematic_tree, deform_params, output_posed_bones=False, temperature=1): + """ + """ + device = deform_params.device + batch_size, num_frames = deform_params.shape[:2] + shape = v_pos + + # Associate vertices to bones + vertices_to_bones = _compute_vertices_to_bones_weights(bones_pred, shape.detach(), temperature=temperature) # Shape: (num_bones, B, F, V) + + rots_pred = deform_params + + # Rotate vertices based on bone assignments + frame_shape_pred = [] + if output_posed_bones: + posed_bones = bones_pred.clone() + if posed_bones.shape[0] != batch_size or posed_bones.shape[1] != num_frames: + posed_bones = posed_bones.repeat(batch_size, num_frames, 1, 1, 1) # Shape: (B, F, num_bones, 2, 3) + + # Go through each bone + for bone_id, _ in kinematic_tree: + # Establish a kinematic chain with current bone as the leaf bone + ## TODO: this assumes the parents is always in the front of the list + parents_ids = [parent_id for parent_id, children in kinematic_tree if bone_id in children] + chain_ids = parents_ids + [bone_id] + # Chain from leaf to root + chain_ids = chain_ids[::-1] + + # Go through the kinematic chain from leaf to root and compose transformation + transform_mtx = torch.eye(4)[None].to(device) + for i in chain_ids: + # Establish transformation + rest_joint = bones_pred[:, :, i, 0, :].view(-1, 3) + rest_bone_vector = bones_pred[:, :, i, 1, :] - bones_pred[:, :, i, 0, :] + rest_bone_rot = _estimate_bone_rotation(rest_bone_vector.view(-1, 3)) + rest_bone_mtx = _prepare_transform_mtx(rotation=rest_bone_rot, translation=rest_joint) + rest_bone_inv_mtx = _invert_transform_mtx(rest_bone_mtx) + + # Transform to the bone local frame + transform_mtx = torch.matmul(rest_bone_inv_mtx, transform_mtx) + + # Rotate the mesh in the bone local frame + rot_pred = rots_pred[:, :, i] + rot_pred_mat = euler_angles_to_matrix(rot_pred.view(-1, 3), convention='XYZ') + rot_pred_mtx = _prepare_transform_mtx(rotation=rot_pred_mat, translation=None) + transform_mtx = torch.matmul(rot_pred_mtx, transform_mtx) + + # Transform to the world frame + transform_mtx = 
torch.matmul(rest_bone_mtx, transform_mtx) + + # Transform vertices + shape4 = rearrange(torch.cat([shape, torch.ones_like(shape[...,:1])], dim=-1), 'b f ... -> (b f) ...') + seq_shape_bone = torch.matmul(shape4, transform_mtx.transpose(-2, -1))[..., :3] + seq_shape_bone = rearrange(seq_shape_bone, '(b f) ... -> b f ...', b=batch_size, f=num_frames) + + if output_posed_bones: + bones4 = torch.cat([rearrange(posed_bones[:, :, bone_id], 'b f ... -> (b f) ...'), torch.ones(batch_size * num_frames, 2, 1).to(device)], dim=-1) + posed_bones[:, :, bone_id] = rearrange(torch.matmul(bones4, transform_mtx.transpose(-2, -1))[..., :3], '(b f) ... -> b f ...', b=batch_size, f=num_frames) + + # Transform mesh with weights + frame_shape_pred += [vertices_to_bones[bone_id, ..., None] * seq_shape_bone] + + frame_shape_pred = sum(frame_shape_pred) + + aux = {} + aux['bones_pred'] = bones_pred + aux['vertices_to_bones'] = vertices_to_bones + if output_posed_bones: + aux['posed_bones'] = posed_bones + + return frame_shape_pred, aux diff --git a/video3d/utils/sphere.py b/video3d/utils/sphere.py new file mode 100755 index 0000000000000000000000000000000000000000..4b586a81fb9343c026a50a96f4b64f02a1ac74c2 --- /dev/null +++ b/video3d/utils/sphere.py @@ -0,0 +1,170 @@ +import io +import numpy as np +import cv2 +from PIL import Image +import matplotlib as mpl +mpl.use('Agg') +import matplotlib.pyplot as plt +import torch +# import pytorch3d +# import pytorch3d.renderer +# import pytorch3d.structures +# import pytorch3d.io +# import pytorch3d.transforms +# import pytorch3d.utils + + +## https://stackoverflow.com/a/58641662/11471407 +def fig_to_img(fig, dpi=200, im_size=(512,512)): + buf = io.BytesIO() + fig.savefig(buf, format="png", dpi=dpi) + buf.seek(0) + img = np.array(Image.open(buf).convert('RGB').resize(im_size)) / 255. + return img + + +def get_ico_sphere(subdiv=1): + return pytorch3d.utils.ico_sphere(level=subdiv) + + +def get_symmetric_ico_sphere(subdiv=1, return_tex_uv=True, return_face_tex_map=True, device='cpu'): + sph_mesh = get_ico_sphere(subdiv=subdiv) + sph_verts = sph_mesh.verts_padded()[0] + sph_faces = sph_mesh.faces_padded()[0] + + ## rotate the default mesh s.t. the seam is exactly on yz-plane + rot_z = np.arctan(0.5000/0.3090) # computed from vertices in ico_sphere + tfs = pytorch3d.transforms.RotateAxisAngle(rot_z, 'Z', degrees=False) + rotated_verts = tfs.transform_points(sph_verts) + + ## identify vertices on each side and on the seam + verts_id_seam = [] + verts_id_one_side = [] + verts_id_other_side = [] + for i, v in enumerate(rotated_verts): + ## on the seam, x=0 + if v[0].abs() < 0.001: # threshold 0.001 + verts_id_seam += [i] + rotated_verts[i][0] = 0. 
# force it to be 0 + + ## right side, x>0 + elif v[0] > 0: + verts_id_one_side += [i] + + ## left side, x<0 + else: + verts_id_other_side += [i] + + ## create a new set of symmetric vertices + new_vid = 0 + vid_old_to_new = {} + verts_seam = [] + for vid in verts_id_seam: + verts_seam += [rotated_verts[vid]] + vid_old_to_new[vid] = new_vid + new_vid += 1 + verts_seam = torch.stack(verts_seam, 0) + + verts_one_side = [] + for vid in verts_id_one_side: + verts_one_side += [rotated_verts[vid]] + vid_old_to_new[vid] = new_vid + new_vid += 1 + verts_one_side = torch.stack(verts_one_side, 0) + + verts_other_side = [] + for vid in verts_id_one_side: + verts_other_side += [rotated_verts[vid] * torch.FloatTensor([-1,1,1])] # flip x + new_vid += 1 + verts_other_side = torch.stack(verts_other_side, 0) + + new_verts = torch.cat([verts_seam, verts_one_side, verts_other_side], 0) + + ## create a new set of symmetric faces + faces_one_side = [] + faces_other_side = [] + for old_face in sph_faces: + new_face1 = [] # one side + new_face2 = [] # the other side + for vi in old_face: + vi = vi.item() + if vi in verts_id_seam: + new_face1 += [vid_old_to_new[vi]] + new_face2 += [vid_old_to_new[vi]] + elif vi in verts_id_one_side: + new_face1 += [vid_old_to_new[vi]] + new_face2 += [vid_old_to_new[vi]+len(verts_id_one_side)] # assuming the symmetric vertices are appended right after the original ones + else: + break + + if len(new_face1) == 3: # no vert on the other side + faces_one_side += [new_face1] + faces_other_side += [new_face2[::-1]] # reverse face orientation + new_faces = faces_one_side + faces_other_side + new_faces = torch.LongTensor(new_faces) + sym_sph_mesh = pytorch3d.structures.Meshes(verts=[new_verts], faces=[new_faces]) + + aux = {} + aux['num_verts_seam'] = len(verts_seam) + aux['num_verts_one_side'] = len(verts_one_side) + + ## create texture map uv + if return_tex_uv: + verts_tex_uv = torch.stack([-new_verts[:,2], new_verts[:,1]], 1) # -z,y + verts_tex_uv = verts_tex_uv / ((verts_tex_uv**2).sum(1,keepdim=True)**0.5).clamp(min=1e-8) + magnitude = new_verts[:,:1].abs().acos() # set magnitude to angle deviation from vertical axis, for more even texture mapping + magnitude = magnitude / magnitude.max() *0.95 # max 0.95 + verts_tex_uv = verts_tex_uv * magnitude + verts_tex_uv = verts_tex_uv /2 + 0.5 # rescale to 0~1 + face_tex_ids = new_faces + aux['verts_tex_uv'] = verts_tex_uv.to(device) + aux['face_tex_ids'] = face_tex_ids.to(device) + + ## create face color map + if return_face_tex_map: + dpi = 200 + im_size = (512, 512) + fig = plt.figure(figsize=(8,8), dpi=dpi, frameon=False) + ax = plt.Axes(fig, [0., 0., 1., 1.]) + ax.set_axis_off() + fig.add_axes(ax) + + num_colors = 10 + cmap = plt.get_cmap('tab10', num_colors) + num_faces = len(face_tex_ids) + face_tex_ids_one_side = face_tex_ids[:num_faces//2] # assuming symmetric faces are appended right after the original ones + for i, face in enumerate(face_tex_ids_one_side): + vert_uv = verts_tex_uv[face] # 3x2 + # color = cmap(i%num_colors) + color = cmap(np.random.randint(num_colors)) + t = plt.Polygon(vert_uv, facecolor=color, edgecolor='black', linewidth=2) + ax.add_patch(t) + ## draw arrow + ax.arrow(0.85, 0.5, -0.7, 0., length_includes_head=True, width=0.03, head_width=0.15, overhang=0.2, color='white') + ax.set_xlim(0,1) + ax.set_ylim(0,1) + face_tex_map = torch.FloatTensor(fig_to_img(fig, dpi, im_size)) + plt.close() + + ## draw seam + fig = plt.figure(figsize=(8,8), dpi=dpi, frameon=False) + ax = plt.Axes(fig, [0., 0., 1., 1.]) + 
ax.set_axis_off() + fig.add_axes(ax) + for i, face in enumerate(face_tex_ids_one_side): + vert_uv = verts_tex_uv[face] # 3x2 + vert_on_seam = ((vert_uv-0.5)**2).sum(1)**0.5 > 0.47 + if vert_on_seam.sum() == 2: + ax.plot(*vert_uv[vert_on_seam].t(), color='black', linewidth=10) + ax.set_xlim(0,1) + ax.set_ylim(0,1) + seam_mask = torch.FloatTensor(fig_to_img(fig, dpi, im_size)) + plt.close() + seam_mask = (seam_mask[:,:,:1] < 0.1).float() + + red = torch.FloatTensor([1,0,0]).view(1,1,3) + face_tex_map = seam_mask * red + (1-seam_mask) * face_tex_map + aux['face_tex_map'] = face_tex_map.to(device) + aux['seam_mask'] = seam_mask.to(device) + + return sym_sph_mesh.to(device), aux
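A closing usage sketch for the symmetric sphere template. It assumes pytorch3d and matplotlib are available; note that the pytorch3d imports at the top of sphere.py are commented out in this patch and would need to be re-enabled before these helpers run.

```python
mesh, aux = get_symmetric_ico_sphere(subdiv=2, return_tex_uv=True,
                                     return_face_tex_map=True, device='cpu')
verts = mesh.verts_padded()[0]                  # (V, 3), mirror-symmetric about the x = 0 plane
print(aux['num_verts_seam'], aux['num_verts_one_side'])
print(aux['verts_tex_uv'].shape)                # (V, 2) texture UVs in [0, 1]
print(aux['face_tex_map'].shape)                # (512, 512, 3) debug texture with the seam in red
```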