|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import json
|
|
import os
|
|
|
|
from .easydict import EasyDict as edict
|
|
from .arg_utils import infer_type
|
|
|
|
import pathlib
|
|
import platform
|
|
|
|
# Parent of the directory that contains this module, as an absolute,
# symlink-resolved path. get_model_config() looks for model config files
# under ROOT/models/<model_name>/.
ROOT = pathlib.Path(__file__).parent.parent.resolve()
|
|
|
|
# Home directory of the current user; the default dataset locations in
# DATASETS_CONFIG are built relative to it.
HOME_DIR = os.path.expanduser("~")
|
|
|
|
# Defaults shared by every model and mode. get_config() starts from these
# values and layers the model config files and caller overrides on top.
COMMON_CONFIG = {
    # Output location, expanded against the current user's home directory.
    "save_dir": os.path.expanduser("~/shortcuts/monodepth3_checkpoints"),

    # Run bookkeeping. get_config() replaces "project" with
    # "MonoDepth3-<dataset>" in train mode when a dataset is given.
    "project": "ZoeDepth",
    "tags": "",
    "notes": "",

    "gpu": None,
    "root": ".",
    "uid": None,
    "print_losses": False,
}
|
|
|
|
# Per-dataset settings, keyed by the dataset name accepted by get_config() and
# change_dataset(). All default paths live under HOME_DIR/shortcuts/datasets/.
#
# "kitti", "kitti_test" and "nyu" carry both training and evaluation inputs
# (data/gt paths plus file lists). Every other entry only defines a
# "<name>_root" directory together with crop flags and depth ranges.
DATASETS_CONFIG = {
    "kitti": {
        "dataset": "kitti",
        "min_depth": 0.001,
        "max_depth": 80,
        "data_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"),
        "gt_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"),
        "filenames_file": "./train_test_inputs/kitti_eigen_train_files_with_gt.txt",
        "input_height": 352,
        "input_width": 1216,
        "data_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"),
        "gt_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"),
        "filenames_file_eval": "./train_test_inputs/kitti_eigen_test_files_with_gt.txt",

        "min_depth_eval": 1e-3,
        "max_depth_eval": 80,

        "do_random_rotate": True,
        "degree": 1.0,
        "do_kb_crop": True,
        "garg_crop": True,
        "eigen_crop": False,
        "use_right": False,
    },
    # Same as "kitti" except that do_random_rotate is disabled.
    "kitti_test": {
        "dataset": "kitti",
        "min_depth": 0.001,
        "max_depth": 80,
        "data_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"),
        "gt_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"),
        "filenames_file": "./train_test_inputs/kitti_eigen_train_files_with_gt.txt",
        "input_height": 352,
        "input_width": 1216,
        "data_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"),
        "gt_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"),
        "filenames_file_eval": "./train_test_inputs/kitti_eigen_test_files_with_gt.txt",

        "min_depth_eval": 1e-3,
        "max_depth_eval": 80,

        "do_random_rotate": False,
        "degree": 1.0,
        "do_kb_crop": True,
        "garg_crop": True,
        "eigen_crop": False,
        "use_right": False,
    },
    "nyu": {
        "dataset": "nyu",
        "avoid_boundary": False,
        "min_depth": 1e-3,
        "max_depth": 10,
        "data_path": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/sync/"),
        "gt_path": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/sync/"),
        "filenames_file": "./train_test_inputs/nyudepthv2_train_files_with_gt.txt",
        "input_height": 480,
        "input_width": 640,
        "data_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/official_splits/test/"),
        "gt_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/official_splits/test/"),
        "filenames_file_eval": "./train_test_inputs/nyudepthv2_test_files_with_gt.txt",
        "min_depth_eval": 1e-3,
        "max_depth_eval": 10,
        "min_depth_diff": -10,
        "max_depth_diff": 10,

        "do_random_rotate": True,
        "degree": 1.0,
        "do_kb_crop": False,
        "garg_crop": False,
        "eigen_crop": True,
    },
    # --- Entries below define only a root directory, crop flags and depth ranges ---
    "ibims": {
        "dataset": "ibims",
        "ibims_root": os.path.join(HOME_DIR, "shortcuts/datasets/ibims/ibims1_core_raw/"),
        "eigen_crop": True,
        "garg_crop": False,
        "do_kb_crop": False,
        "min_depth_eval": 0,
        "max_depth_eval": 10,
        "min_depth": 1e-3,
        "max_depth": 10,
    },
    "sunrgbd": {
        "dataset": "sunrgbd",
        "sunrgbd_root": os.path.join(HOME_DIR, "shortcuts/datasets/SUNRGBD/test/"),
        "eigen_crop": True,
        "garg_crop": False,
        "do_kb_crop": False,
        "min_depth_eval": 0,
        "max_depth_eval": 8,
        "min_depth": 1e-3,
        "max_depth": 10,
    },
    "diml_indoor": {
        "dataset": "diml_indoor",
        "diml_indoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diml_indoor_test/"),
        "eigen_crop": True,
        "garg_crop": False,
        "do_kb_crop": False,
        "min_depth_eval": 0,
        "max_depth_eval": 10,
        "min_depth": 1e-3,
        "max_depth": 10,
    },
    "diml_outdoor": {
        "dataset": "diml_outdoor",
        "diml_outdoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diml_outdoor_test/"),
        "eigen_crop": False,
        "garg_crop": True,
        "do_kb_crop": False,
        "min_depth_eval": 2,
        "max_depth_eval": 80,
        "min_depth": 1e-3,
        "max_depth": 80,
    },
    "diode_indoor": {
        "dataset": "diode_indoor",
        "diode_indoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diode_indoor/"),
        "eigen_crop": True,
        "garg_crop": False,
        "do_kb_crop": False,
        "min_depth_eval": 1e-3,
        "max_depth_eval": 10,
        "min_depth": 1e-3,
        "max_depth": 10,
    },
    "diode_outdoor": {
        "dataset": "diode_outdoor",
        "diode_outdoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diode_outdoor/"),
        "eigen_crop": False,
        "garg_crop": True,
        "do_kb_crop": False,
        "min_depth_eval": 1e-3,
        "max_depth_eval": 80,
        "min_depth": 1e-3,
        "max_depth": 80,
    },
    "hypersim_test": {
        "dataset": "hypersim_test",
        "hypersim_test_root": os.path.join(HOME_DIR, "shortcuts/datasets/hypersim_test/"),
        "eigen_crop": True,
        "garg_crop": False,
        "do_kb_crop": False,
        "min_depth_eval": 1e-3,
        "max_depth_eval": 80,
        "min_depth": 1e-3,
        "max_depth": 10,
    },
    "vkitti": {
        "dataset": "vkitti",
        "vkitti_root": os.path.join(HOME_DIR, "shortcuts/datasets/vkitti_test/"),
        "eigen_crop": False,
        "garg_crop": True,
        "do_kb_crop": True,
        "min_depth_eval": 1e-3,
        "max_depth_eval": 80,
        "min_depth": 1e-3,
        "max_depth": 80,
    },
    "vkitti2": {
        "dataset": "vkitti2",
        "vkitti2_root": os.path.join(HOME_DIR, "shortcuts/datasets/vkitti2/"),
        "eigen_crop": False,
        "garg_crop": True,
        "do_kb_crop": True,
        "min_depth_eval": 1e-3,
        "max_depth_eval": 80,
        "min_depth": 1e-3,
        "max_depth": 80,
    },
    "ddad": {
        "dataset": "ddad",
        "ddad_root": os.path.join(HOME_DIR, "shortcuts/datasets/ddad/ddad_val/"),
        "eigen_crop": False,
        "garg_crop": True,
        "do_kb_crop": True,
        "min_depth_eval": 1e-3,
        "max_depth_eval": 80,
        "min_depth": 1e-3,
        "max_depth": 80,
    },
}
|
|
|
|
# Dataset names (keys of DATASETS_CONFIG) grouped into indoor and outdoor sets.
# NOTE(review): "diml_indoor", "vkitti" and "kitti_test" are defined in
# DATASETS_CONFIG but are not part of either list; confirm that is intended.
ALL_INDOOR = [
    "nyu",
    "ibims",
    "sunrgbd",
    "diode_indoor",
    "hypersim_test",
]
ALL_OUTDOOR = [
    "kitti",
    "diml_outdoor",
    "diode_outdoor",
    "vkitti2",
    "ddad",
]
# Indoor names first, then outdoor names, in a new list.
ALL_EVAL_DATASETS = [*ALL_INDOOR, *ALL_OUTDOOR]
|
|
|
|
# Training-related defaults. get_config() merges these over COMMON_CONFIG
# before any model config file or caller override is applied.
COMMON_TRAINING_CONFIG = {
    "dataset": "nyu",
    "distributed": True,
    "workers": 16,
    "clip_grad": 0.1,
    "use_shared_dict": False,
    "shared_dict": None,
    "use_amp": False,

    # Augmentation switches and their parameters.
    "aug": True,
    "random_crop": False,
    "random_translate": False,
    "translate_prob": 0.2,
    "max_translation": 100,

    # Validation / logging cadence and data loading.
    "validate_every": 0.25,
    "log_images_every": 0.1,
    "prefetch": False,
}
|
|
|
|
|
|
def flatten(config, except_keys=('bin_conf')):
    """Flatten a nested dict into a single-level dict.

    Nested dicts are walked depth-first and their leaf (non-dict) entries are
    hoisted to the top level; the keys of the intermediate dicts are dropped.
    When the same leaf key occurs more than once, the last one visited wins.

    Args:
        config (dict): Possibly nested mapping. Anything that is not a dict
            yields an empty result.
        except_keys (str or iterable of str): Key(s) that are always emitted
            with their original value. A bare string is treated as a single
            key name. Defaults to 'bin_conf'.

    Returns:
        dict: The flattened mapping.
    """
    # The default is spelled ('bin_conf'), which is a plain string rather than
    # a tuple. `key in <str>` is a substring test, so unrelated keys such as
    # "n" or "conf" used to be treated as excepted, and non-string keys raised
    # TypeError. Normalise to a tuple so matching is by exact key.
    if isinstance(except_keys, str):
        except_keys = (except_keys,)

    def recurse(inp):
        if isinstance(inp, dict):
            for key, value in inp.items():
                if key in except_keys:
                    yield (key, value)
                # An excepted key whose value is a dict is still descended
                # into, so its children also appear at the top level. This
                # mirrors the long-standing behaviour of this function.
                if isinstance(value, dict):
                    yield from recurse(value)
                else:
                    yield (key, value)

    return dict(recurse(config))
|
|
|
|
|
|
def split_combined_args(kwargs):
    """Expand '__'-combined arguments into individual key-value pairs.

    A combined argument packs several settings into one entry: the key starts
    with '__' and lists the target names separated by '__', while the value
    lists the matching values separated by ';'. For example the entry
    '__n_bins__lr' -> '256;0.001' expands to n_bins='256' and lr='0.001'.
    Expanded values are left as strings.

    Args:
        kwargs (dict): Arguments, some of which may be combined as described
            above. The dict itself is not modified.

    Returns:
        dict: A copy of ``kwargs`` with the expanded pairs added. The original
        combined entries are kept in the result as well.

    Raises:
        AssertionError: If a combined entry has a different number of names
            and values.
    """
    expanded = dict(kwargs)
    for combined_key, combined_value in kwargs.items():
        if not combined_key.startswith("__"):
            continue

        # "__a__b".split("__") -> ["", "a", "b"]; drop the leading empty piece.
        keys = combined_key.split("__")[1:]
        values = combined_value.split(";")
        assert len(keys) == len(values), (
            f"Combined arguments should have equal number of keys and values. Keys are separated by '__' and Values are separated with ';'. For example, '__n_bins__lr=256;0.001. Given (keys,values) is ({keys}, {values})"
        )
        expanded.update(zip(keys, values))
    return expanded
|
|
|
|
|
|
def parse_list(config, key, dtype=int):
    """Turn a comma-separated string stored under ``key`` into a list, in place.

    Does nothing when ``key`` is absent. A string value is split on ',' and
    each piece is converted with ``dtype``. The stored value is then checked
    to be a list whose elements are all instances of ``dtype``.

    Args:
        config (dict): Mapping to update in place.
        key: Key whose value should be parsed.
        dtype (type): Element type / converter. Defaults to int.

    Raises:
        AssertionError: If the resulting value is not a list of ``dtype``.
    """
    if key not in config:
        return

    if isinstance(config[key], str):
        config[key] = [dtype(piece) for piece in config[key].split(',')]

    is_typed_list = isinstance(config[key], list) and all(
        isinstance(e, dtype) for e in config[key])
    assert is_typed_list, f"{key} should be a list of values dtype {dtype}. Given {config[key]} of type {type(config[key])} with values of type {[type(e) for e in config[key]]}."
|
|
|
|
|
|
def get_model_config(model_name, model_version=None):
    """Find and parse the .json config file for the model.

    Args:
        model_name (str): Name of the model. The config file is looked up as
            ROOT/models/{model_name}/config_{model_name}[_{model_version}].json.
        model_version (str, optional): Specific config version. If given,
            config_{model_name}_{model_version}.json is used, otherwise
            config_{model_name}.json. Defaults to None.

    Returns:
        easydict: The parsed config, or None if the config file does not exist.

    Raises:
        ValueError: If the config names a model to inherit training settings
            from and that model's config file cannot be found.
    """
    if model_version is not None:
        config_fname = f"config_{model_name}_{model_version}.json"
    else:
        config_fname = f"config_{model_name}.json"
    config_file = os.path.join(ROOT, "models", model_name, config_fname)
    if not os.path.exists(config_file):
        return None

    # JSON text is UTF-8; be explicit so parsing does not depend on the
    # platform's locale encoding.
    with open(config_file, "r", encoding="utf-8") as f:
        config = edict(json.load(f))

    # Optionally fill in missing training settings from another model's base
    # config. Keys already present in this config's "train" section win.
    if "inherit" in config.train and config.train.inherit is not None:
        parent_name = config.train["inherit"]
        parent_config = get_model_config(parent_name)
        if parent_config is None:
            # Previously this surfaced as an opaque
            # "'NoneType' object has no attribute 'train'".
            raise ValueError(
                f"Config file for model {parent_name} (inherited by {config_file}) not found.")
        for key, value in parent_config.train.items():
            if key not in config.train:
                config.train[key] = value
    return edict(config)
|
|
|
|
|
|
def update_model_config(config, mode, model_name, model_version=None, strict=False):
    """Overlay a model's config file onto ``config``.

    The model's "model" section and its ``mode`` section are merged (the mode
    section taking precedence), flattened, and layered over ``config``.

    Args:
        config (dict): Current configuration; not modified.
        mode (str): Name of the mode-specific section to apply.
        model_name (str): Model whose config file is loaded.
        model_version (str, optional): Config version forwarded to
            get_model_config(). Defaults to None.
        strict (bool): If True, a missing config file is an error.
            Defaults to False.

    Returns:
        dict: A new merged dict, or ``config`` itself when no config file was
        found and ``strict`` is False.

    Raises:
        ValueError: If no config file was found and ``strict`` is True.
    """
    loaded = get_model_config(model_name, model_version)

    if loaded is None:
        if strict:
            raise ValueError(f"Config file for model {model_name} not found.")
        return config

    model_overrides = flatten({**loaded.model, **loaded[mode]})
    return {**config, **model_overrides}
|
|
|
|
|
|
def check_choices(name, value, choices):
    """Validate that ``value`` is one of ``choices``.

    Args:
        name (str): Label for the value, used in the error message.
        value: Value to validate.
        choices: Container of accepted values.

    Raises:
        ValueError: If ``value`` is not in ``choices``.
    """
    if value in choices:
        return
    raise ValueError(f"{name} {value} not in supported choices {choices}")
|
|
|
|
|
|
# Config keys whose values get_config() coerces to bool when they are present.
KEYS_TYPE_BOOL = [
    "use_amp",
    "distributed",
    "use_shared_dict",
    "same_lr",
    "aug",
    "three_phase",
    "prefetch",
    "cycle_momentum",
]
|
|
|
|
|
|
def _coerce_bool(value):
    """Convert a config value to bool, reading textual 'false' spellings as False."""
    if isinstance(value, str):
        # bool("False") and bool("0") are True, so string overrides need an
        # explicit textual check to be able to switch a flag off.
        return value.strip().lower() not in ("", "false", "0", "no", "off")
    return bool(value)


def get_config(model_name, mode='train', dataset=None, **overwrite_kwargs):
    """Main entry point to get the config for the model.

    Args:
        model_name (str): Name of the desired model ("zoedepth" or "zoedepth_nk").
        mode (str, optional): "train", "infer" or "eval". Defaults to 'train'.
        dataset (str, optional): If specified, the corresponding dataset
            configuration is loaded as well. In train mode only "nyu", "kitti"
            and "mix" are accepted, and "mix" uses the "nyu" dataset settings.
            Defaults to None.

    Keyword Args:
        Key-value pairs that overwrite the default config. Entries combined
        with '__' are expanded by split_combined_args().

    The config is assembled from these layers (higher precedence first):
        1. overwrite_kwargs
        2. "config_version": if given in overwrite_kwargs, the file
           config_{model_name}_{config_version}.json
        3. "version_name": taken from overwrite_kwargs or, failing that, from
           the base model config; the file config_{model_name}_{version_name}.json
        4. the base model config config_{model_name}.json
        5. COMMON_TRAINING_CONFIG, then COMMON_CONFIG
        6. the DATASETS_CONFIG entry for the dataset (only fills in keys that
           are not set by any layer above)
    In addition "dataset", "model" and "hostname" are always set by this
    function, and in train mode with a dataset "project" is set to
    "MonoDepth3-{dataset}".

    Returns:
        easydict: The config dictionary for the model.

    Raises:
        ValueError: If model_name, mode or (in train mode) dataset is not a
            supported choice.
        KeyError: If no "version_name" is available from overwrite_kwargs or
            the base model config, or if ``dataset`` has no DATASETS_CONFIG entry.
    """
    check_choices("Model", model_name, ["zoedepth", "zoedepth_nk"])
    check_choices("Mode", mode, ["train", "infer", "eval"])
    if mode == "train":
        check_choices("Dataset", dataset, ["nyu", "kitti", "mix", None])

    config = flatten({**COMMON_CONFIG, **COMMON_TRAINING_CONFIG})
    config = update_model_config(config, mode, model_name)

    # Apply the version-specific config. Look the fallback up lazily: a
    # dict.get() default is evaluated eagerly and would raise KeyError even
    # when the caller supplied version_name.
    if "version_name" in overwrite_kwargs:
        version_name = overwrite_kwargs["version_name"]
    else:
        version_name = config["version_name"]
    config = update_model_config(config, mode, model_name, version_name)

    # Apply the explicitly requested config version, if any.
    config_version = overwrite_kwargs.get("config_version", None)
    if config_version is not None:
        print("Overwriting config with config_version", config_version)
        config = update_model_config(config, mode, model_name, config_version)

    # Caller overrides take precedence over everything loaded so far.
    overwrite_kwargs = split_combined_args(overwrite_kwargs)
    config = {**config, **overwrite_kwargs}

    # Normalise flag-like keys to real booleans.
    for key in KEYS_TYPE_BOOL:
        if key in config:
            config[key] = _coerce_bool(config[key])

    # n_attractors may arrive as a comma-separated string.
    parse_list(config, "n_attractors")

    # An n_bins override is propagated into every bin_conf entry.
    if 'bin_conf' in config and 'n_bins' in overwrite_kwargs:
        bin_conf = config['bin_conf']
        n_bins = overwrite_kwargs['n_bins']
        new_bin_conf = []
        for conf in bin_conf:
            conf['n_bins'] = n_bins
            new_bin_conf.append(conf)
        config['bin_conf'] = new_bin_conf

    if mode == "train":
        orig_dataset = dataset
        if dataset == "mix":
            dataset = 'nyu'
        if dataset is not None:
            config['project'] = f"MonoDepth3-{orig_dataset}"

    if dataset is not None:
        config['dataset'] = dataset
        # Dataset defaults go first so that anything already in config wins.
        config = {**DATASETS_CONFIG[dataset], **config}

    config['model'] = model_name
    typed_config = {k: infer_type(v) for k, v in config.items()}

    # Record the hostname on the dict that is actually returned. It used to be
    # written to `config` after `typed_config` had been built and was lost.
    # Added after type inference so the name is kept verbatim.
    typed_config['hostname'] = platform.node()
    return edict(typed_config)
|
|
|
|
|
|
def change_dataset(config, new_dataset):
    """Overlay the settings of another dataset onto ``config``.

    Args:
        config: Mapping to update in place; existing keys are overwritten by
            the dataset's values.
        new_dataset (str): Key of the DATASETS_CONFIG entry to apply.

    Returns:
        The same ``config`` object, after the update.

    Raises:
        KeyError: If ``new_dataset`` is not a key of DATASETS_CONFIG.
    """
    dataset_settings = DATASETS_CONFIG[new_dataset]
    config.update(dataset_settings)
    return config
|
|
|