# flash3d/networks/gaussian_predictor.py
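"""GaussianPredictor: predicts per-pixel 3D Gaussian parameters from images.

Wraps a depth/Gaussian prediction backbone (a ResNet encoder with depth and
Gaussian decoders, or one of the UniDepth variants) and backprojects the
predicted depth into 3D Gaussian means, optionally adding predicted
per-pixel offsets.
"""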
from pathlib import Path
import logging

import torch
import torch.nn as nn
from einops import rearrange

from networks.layers import BackprojectDepth, disp_to_depth
from networks.resnet_encoder import ResnetEncoder
from networks.depth_decoder import DepthDecoder
from networks.gaussian_decoder import GaussianDecoder


def default_param_group(model):
    """Wrap all of a model's parameters into a single optimiser param group."""
    return [{'params': model.parameters()}]


def to_device(inputs, device):
    """Move every tensor value of the inputs dict to the given device (in place)."""
    for key, ipt in inputs.items():
        if isinstance(ipt, torch.Tensor):
            inputs[key] = ipt.to(device)
    return inputs


class GaussianPredictor(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        # checking height and width are multiples of 32
        # assert cfg.dataset.width % 32 == 0, "'width' must be a multiple of 32"

        models = {}
        self.parameters_to_train = []

        self.num_scales = len(cfg.model.scales)

        assert cfg.model.frame_ids[0] == 0, "frame_ids must start with 0"

        if cfg.model.use_stereo:
            cfg.model.frame_ids.append("s")

        model_name = cfg.model.name
        if model_name == "resnet":
            models["encoder"] = ResnetEncoder(
                cfg.model.num_layers,
                cfg.model.weights_init == "pretrained",
                cfg.model.resnet_bn_order
            )
            self.parameters_to_train += default_param_group(models["encoder"])
            if not cfg.model.unified_decoder:
                models["depth"] = DepthDecoder(
                    cfg, models["encoder"].num_ch_enc)
                self.parameters_to_train += default_param_group(models["depth"])
            if cfg.model.gaussian_rendering:
                # one decoder head per predicted Gaussian layer
                for i in range(cfg.model.gaussians_per_pixel):
                    gauss_decoder = GaussianDecoder(
                        cfg, models["encoder"].num_ch_enc,
                    )
                    self.parameters_to_train += default_param_group(gauss_decoder)
                    models["gauss_decoder_" + str(i)] = gauss_decoder
        elif model_name == "unidepth":
            from networks.unidepth import UniDepthSplatter
            models["unidepth"] = UniDepthSplatter(cfg)
            self.parameters_to_train += models["unidepth"].get_parameter_groups()
        elif model_name in ["unidepth_unprojector_vit", "unidepth_unprojector_cnvnxtl"]:
            from networks.unidepth import UniDepthUnprojector
            models["unidepth"] = UniDepthUnprojector(cfg)
            self.parameters_to_train += models["unidepth"].get_parameter_groups()
        elif model_name in ["unidepth_extension_vit", "unidepth_extension_cnvnxtl"]:
            from networks.unidepth_extension import UniDepthExtended
            models["unidepth_extended"] = UniDepthExtended(cfg)
            self.parameters_to_train += models["unidepth_extended"].get_parameter_groups()

        self.models = nn.ModuleDict(models)

        backproject_depth = {}
        H = cfg.dataset.height
        W = cfg.dataset.width
        for scale in cfg.model.scales:
            h = H // (2 ** scale)
            w = W // (2 ** scale)
            if cfg.model.shift_rays_half_pixel == "zero":
                shift_rays_half_pixel = 0
            elif cfg.model.shift_rays_half_pixel == "forward":
                shift_rays_half_pixel = 0.5
            elif cfg.model.shift_rays_half_pixel == "backward":
                shift_rays_half_pixel = -0.5
            else:
                raise NotImplementedError
            backproject_depth[str(scale)] = BackprojectDepth(
                cfg.optimiser.batch_size * cfg.model.gaussians_per_pixel,
                # backprojection can be different if padding was used
                h + 2 * self.cfg.dataset.pad_border_aug,
                w + 2 * self.cfg.dataset.pad_border_aug,
                shift_rays_half_pixel=shift_rays_half_pixel
            )
        self.backproject_depth = nn.ModuleDict(backproject_depth)
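
    # Note: the BackprojectDepth modules appear to size their internal
    # pixel-grid tensors as a function of batch_size * gaussians_per_pixel
    # (see the commented-out handling of "pix_coords" / "ones" in
    # load_model() below), so checkpoints saved with a different batch size
    # would not load cleanly; load_model() therefore keeps the freshly
    # initialised backproject_depth tensors instead of the saved ones.
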
    def set_train(self):
        """Convert all models to training mode
        """
        for m in self.models.values():
            m.train()
        self._is_train = True

    def set_eval(self):
        """Convert all models to testing/evaluation mode
        """
        for m in self.models.values():
            m.eval()
        self._is_train = False

    def is_train(self):
        return self._is_train
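
    # Output dictionaries below are keyed by (name, frame_id, scale) tuples,
    # e.g. ("disp", 0, 0) or ("gauss_means", 0, 0); frame_id is always 0 here
    # since only the reference frame is predicted.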
    def forward(self, inputs):
        cfg = self.cfg
        B = cfg.optimiser.batch_size

        if cfg.model.name == "resnet":
            do_flip = self.is_train() and \
                cfg.train.lazy_flip_augmentation and \
                (torch.rand(1) > .5).item()
            # only the image with frame_id 0 is fed through the depth encoder
            input_img = inputs["color_aug", 0, 0]
            if do_flip:
                input_img = torch.flip(input_img, dims=(-1, ))
            features = self.models["encoder"](input_img)
            if not cfg.model.unified_decoder:
                outputs = self.models["depth"](features)
            else:
                outputs = dict()

            if self.cfg.model.gaussian_rendering:
                # gauss_feats = self.models["gauss_encoder"](inputs["color_aug", 0, 0])
                input_f_id = 0
                gauss_feats = features
                gauss_outs = dict()
                # run every Gaussian decoder head and stack along a new layer axis
                for i in range(self.cfg.model.gaussians_per_pixel):
                    outs = self.models["gauss_decoder_" + str(i)](gauss_feats)
                    for key, v in outs.items():
                        gauss_outs[key] = outs[key][:, None, ...] if i == 0 else \
                            torch.cat([gauss_outs[key], outs[key][:, None, ...]], dim=1)
                for key, v in gauss_outs.items():
                    gauss_outs[key] = rearrange(gauss_outs[key], 'b n ... -> (b n) ...')
                outputs |= gauss_outs
                outputs = {(key[0], input_f_id, key[1]): v for key, v in outputs.items()}
            else:
                for scale in cfg.model.scales:
                    outputs[("disp", 0, scale)] = outputs[("disp", scale)]

            # unflip all outputs
            if do_flip:
                for k, v in outputs.items():
                    outputs[k] = torch.flip(v, dims=(-1, ))
        elif "unidepth" in cfg.model.name:
            if cfg.model.name in ["unidepth",
                                  "unidepth_unprojector_vit",
                                  "unidepth_unprojector_cnvnxtl"]:
                outputs = self.models["unidepth"](inputs)
            elif cfg.model.name in ["unidepth_extension_vit",
                                    "unidepth_extension_cnvnxtl"]:
                outputs = self.models["unidepth_extended"](inputs)
            input_f_id = 0
            outputs = {(key[0], input_f_id, key[1]): v for key, v in outputs.items()}

        input_f_id = 0
        scale = 0
        if ("depth", input_f_id, scale) not in outputs:
            # convert network disparity to depth within [min_depth, max_depth]
            disp = outputs[("disp", input_f_id, scale)]
            _, depth = disp_to_depth(disp, cfg.model.min_depth, cfg.model.max_depth)
            outputs[("depth", input_f_id, scale)] = depth

        self.compute_gauss_means(inputs, outputs)

        return outputs
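
    # Shape bookkeeping for forward() above: each of the N = gaussians_per_pixel
    # decoder heads emits [B, C, H, W]; stacking gives [B, N, C, H, W], and the
    # 'b n ... -> (b n) ...' rearrange folds this to [B*N, C, H, W], matching
    # the B*N batch dimension the BackprojectDepth modules were built with.
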
    def target_tensor_image_dims(self, inputs):
        B, _, H, W = inputs["color", 0, 0].shape
        return B, H, W

    def compute_gauss_means(self, inputs, outputs):
        cfg = self.cfg
        input_f_id = 0
        scale = 0
        depth = outputs[("depth", input_f_id, scale)]
        B, _, H, W = depth.shape
        if ("inv_K_src", scale) in inputs:
            inv_K = inputs[("inv_K_src", scale)]
        else:
            inv_K = outputs[("inv_K_src", input_f_id, scale)]
        if self.cfg.model.gaussians_per_pixel > 1:
            # replicate the inverse intrinsics for each Gaussian layer
            inv_K = rearrange(inv_K[:, None, ...]
                              .repeat(1, self.cfg.model.gaussians_per_pixel, 1, 1),
                              'b n ... -> (b n) ...')
        xyz = self.backproject_depth[str(scale)](
            depth, inv_K
        )
        inputs[("inv_K_src", scale)] = inv_K
        if cfg.model.predict_offset:
            offset = outputs[("gauss_offset", input_f_id, scale)]
            if cfg.model.scaled_offset:
                offset = offset * depth.detach()
            offset = offset.view(B, 3, -1)
            # pad the 3D offset with a zero row so the homogeneous
            # coordinate of the backprojected points stays untouched
            zeros = torch.zeros(B, 1, H * W, device=depth.device)
            offset = torch.cat([offset, zeros], 1)
            xyz = xyz + offset  # [B, 4, H*W]
        outputs[("gauss_means", input_f_id, scale)] = xyz

    def checkpoint_dir(self):
        return Path("checkpoints")

    def save_model(self, optimizer, step, ema=None):
        """Save model weights to disk
        """
        save_folder = self.checkpoint_dir()
        save_folder.mkdir(exist_ok=True, parents=True)

        save_path = save_folder / f"model_{step:07}.pth"
        logging.info(f"saving checkpoint to {str(save_path)}")

        model = ema.ema_model if ema is not None else self
        save_dict = {
            "model": model.state_dict(),
            "version": "1.0",
            "optimiser": optimizer.state_dict(),
            "step": step
        }
        torch.save(save_dict, save_path)

        # keep only the most recent num_keep_ckpts checkpoints
        num_ckpts = self.cfg.optimiser.num_keep_ckpts
        ckpts = sorted(list(save_folder.glob("model_*.pth")), reverse=True)
        if len(ckpts) > num_ckpts:
            for ckpt in ckpts[num_ckpts:]:
                ckpt.unlink()
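
    # Zero-padding the step to seven digits ("model_0000100.pth") makes the
    # lexicographic sort above agree with numeric ordering, so the entries
    # removed in save_model() are always the oldest checkpoints.
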
    def load_model(self, weights_path, optimizer=None):
        """Load model(s) from disk
        """
        weights_path = Path(weights_path)

        # determine if it is an old or new saving format
        if weights_path.is_dir() and weights_path.joinpath("encoder.pth").exists():
            self.load_model_old(weights_path, optimizer)
            return

        logging.info(f"Loading weights from {weights_path}...")
        state_dict = torch.load(weights_path)
        if "version" in state_dict and state_dict["version"] == "1.0":
            new_dict = {}
            for k, v in state_dict["model"].items():
                if "backproject_depth" in k:
                    # keep the freshly initialised backprojection tensors:
                    # their shape depends on the current batch size
                    new_dict[k] = self.state_dict()[k].clone()
                else:
                    new_dict[k] = v.clone()
            # for k, v in state_dict["model"].items():
            #     if "backproject_depth" in k and ("pix_coords" in k or "ones" in k):
            #         # model has these parameters set as a function of batch size
            #         # when batch size changes in eval this results in a loading error
            #         state_dict["model"][k] = v[:1, ...]
            self.load_state_dict(new_dict, strict=False)
        else:
            # TODO remove loading according to the old format
            for name in self.cfg.train.models_to_load:
                if name not in self.models:
                    continue
                self.models[name].load_state_dict(state_dict[name])

        # loading adam state
        if optimizer is not None:
            optimizer.load_state_dict(state_dict["optimiser"])
            self.step = state_dict["step"]

    def load_model_old(self, weights_folder, optimizer=None):
        for n in self.cfg.train.models_to_load:
            print(f"Loading {n} weights...")
            path = weights_folder / f"{n}.pth"
            if n not in self.models:
                continue
            model_dict = self.models[n].state_dict()
            pretrained_dict = torch.load(path)
            # only keep weights whose names match the current model
            pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
            model_dict.update(pretrained_dict)
            self.models[n].load_state_dict(model_dict)

        # loading adam state
        optimizer_load_path = weights_folder / "adam.pth"
        if optimizer is not None and optimizer_load_path.is_file():
            print("Loading Adam weights")
            optimizer_state = torch.load(optimizer_load_path)
            optimizer.load_state_dict(optimizer_state["adam"])
            self.step = optimizer_state["step"]
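

# ---------------------------------------------------------------------------
# Illustrative usage (a minimal sketch, not part of the original module).
# Assumes an OmegaConf-style cfg exposing the fields referenced above
# (cfg.model.*, cfg.dataset.*, cfg.optimiser.*, cfg.train.*) and an `inputs`
# dict holding a [B, 3, H, W] image under the key ("color_aug", 0, 0) plus
# inverse source intrinsics under ("inv_K_src", 0):
#
#     model = GaussianPredictor(cfg).to("cuda")
#     model.set_eval()
#     with torch.no_grad():
#         outputs = model(to_device(inputs, "cuda"))
#     means = outputs[("gauss_means", 0, 0)]  # [B * gaussians_per_pixel, 4, H*W]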