diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/__pycache__/__init__.cpython-310.pyc b/src/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eedc67183fda2309e19da901a295def8be02ef74
Binary files /dev/null and b/src/__pycache__/__init__.cpython-310.pyc differ
diff --git a/src/dataset/__init__.py b/src/dataset/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8837f5d5b9826ebb44ae72c83b69af622937a23
--- /dev/null
+++ b/src/dataset/__init__.py
@@ -0,0 +1,97 @@
+# Last modified: 2024-04-16
+#
+# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# --------------------------------------------------------------------------
+# If you find this code useful, we kindly ask you to cite our paper in your work.
+# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
+# If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold.
+# More information about the method can be found at https://marigoldmonodepth.github.io
+# --------------------------------------------------------------------------
+
+import os
+
+from .base_depth_dataset import BaseDepthDataset  # noqa: F401
+from .eval_base_dataset import EvaluateBaseDataset, DatasetMode, get_pred_name
+from .diode_dataset import DIODEDataset
+from .eth3d_dataset import ETH3DDataset
+from .hypersim_dataset import HypersimDataset
+from .kitti_dataset import KITTIDataset
+from .nyu_dataset import NYUDataset
+from .scannet_dataset import ScanNetDataset
+from .vkitti_dataset import VirtualKITTIDataset
+from .depthanything_dataset import DepthAnythingDataset
+from .base_inpaint_dataset import BaseInpaintDataset
+
+dataset_name_class_dict = {
+    "hypersim": HypersimDataset,
+    "vkitti": VirtualKITTIDataset,
+    "nyu_v2": NYUDataset,
+    "kitti": KITTIDataset,
+    "eth3d": ETH3DDataset,
+    "diode": DIODEDataset,
+    "scannet": ScanNetDataset,
+    "depthanything": DepthAnythingDataset,
+    "inpainting": BaseInpaintDataset,
+}
+
+
+def get_dataset(
+    cfg_data_split, base_data_dir: str, mode: DatasetMode, **kwargs
+):
+    if "mixed" == cfg_data_split.name:
+        # Unlike get_eval_dataset below, mixed datasets are accepted in any mode here.
+        # assert DatasetMode.TRAIN == mode, "Only training mode supports mixed datasets."
+ dataset_ls = [ + get_dataset(_cfg, base_data_dir, mode, **kwargs) + for _cfg in cfg_data_split.dataset_list + ] + return dataset_ls + elif cfg_data_split.name in dataset_name_class_dict.keys(): + dataset_class = dataset_name_class_dict[cfg_data_split.name] + dataset = dataset_class( + mode=mode, + filename_ls_path=cfg_data_split.filenames, + dataset_dir=os.path.join(base_data_dir, cfg_data_split.dir), + **cfg_data_split, + **kwargs, + ) + else: + raise NotImplementedError + + return dataset + +def get_eval_dataset( + cfg_data_split, base_data_dir: str, mode: DatasetMode, **kwargs +) -> EvaluateBaseDataset: + if "mixed" == cfg_data_split.name: + assert DatasetMode.TRAIN == mode, "Only training mode supports mixed datasets." + dataset_ls = [ + get_dataset(_cfg, base_data_dir, mode, **kwargs) + for _cfg in cfg_data_split.dataset_list + ] + return dataset_ls + elif cfg_data_split.name in dataset_name_class_dict.keys(): + dataset_class = dataset_name_class_dict[cfg_data_split.name] + dataset = dataset_class( + mode=mode, + filename_ls_path=cfg_data_split.filenames, + dataset_dir=os.path.join(base_data_dir, cfg_data_split.dir), + **cfg_data_split, + **kwargs, + ) + else: + raise NotImplementedError + + return dataset diff --git a/src/dataset/__pycache__/__init__.cpython-310.pyc b/src/dataset/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab1c88df53f61aef109de219f53e407f431b87e9 Binary files /dev/null and b/src/dataset/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/dataset/__pycache__/base_depth_dataset.cpython-310.pyc b/src/dataset/__pycache__/base_depth_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84a4fd02b39c6d6c24d4ee90275249f3ebf3a2cd Binary files /dev/null and b/src/dataset/__pycache__/base_depth_dataset.cpython-310.pyc differ diff --git a/src/dataset/__pycache__/base_inpaint_dataset.cpython-310.pyc b/src/dataset/__pycache__/base_inpaint_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe643aaa8db80662187141ba4825272faa8df7b6 Binary files /dev/null and b/src/dataset/__pycache__/base_inpaint_dataset.cpython-310.pyc differ diff --git a/src/dataset/__pycache__/depthanything_dataset.cpython-310.pyc b/src/dataset/__pycache__/depthanything_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eede1d818e62f1540d94ae577f41de6ebf4f0b9a Binary files /dev/null and b/src/dataset/__pycache__/depthanything_dataset.cpython-310.pyc differ diff --git a/src/dataset/__pycache__/diode_dataset.cpython-310.pyc b/src/dataset/__pycache__/diode_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1c2c840938e53923147f94ccff414e2f59c0d1f Binary files /dev/null and b/src/dataset/__pycache__/diode_dataset.cpython-310.pyc differ diff --git a/src/dataset/__pycache__/eth3d_dataset.cpython-310.pyc b/src/dataset/__pycache__/eth3d_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e2e27657baa763bc604e789fec8df1eb83bcd13 Binary files /dev/null and b/src/dataset/__pycache__/eth3d_dataset.cpython-310.pyc differ diff --git a/src/dataset/__pycache__/eval_base_dataset.cpython-310.pyc b/src/dataset/__pycache__/eval_base_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0fda2d4a4208d4da2fd76742011b03b4590d254 Binary files /dev/null and b/src/dataset/__pycache__/eval_base_dataset.cpython-310.pyc differ 
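For orientation, here is a minimal, hypothetical sketch of how the get_dataset factory above could be driven from a training script. The per-entry fields (name, dir, filenames, plus disp_name consumed by the dataset constructors) mirror what the code above reads; the OmegaConf-style config object, the concrete paths, and the split layout are illustrative assumptions, not part of this diff.

from omegaconf import OmegaConf

from src.dataset import DatasetMode, get_dataset

# Hypothetical "mixed" split config; only keys read by get_dataset() are shown,
# and the file paths are placeholders.
cfg_mixed = OmegaConf.create(
    {
        "name": "mixed",
        "dataset_list": [
            {"name": "hypersim", "dir": "hypersim", "filenames": "data_split/hypersim/train.json", "disp_name": "hypersim_train"},
            {"name": "depthanything", "dir": "depth_anything", "filenames": "data_split/depth_anything/train.json", "disp_name": "depthanything_train"},
        ],
    }
)

# "mixed" returns one dataset per entry in dataset_list; a concrete name such as
# "hypersim" would return a single dataset instance instead.
train_dataset_ls = get_dataset(cfg_mixed, base_data_dir="data", mode=DatasetMode.TRAIN)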
diff --git a/src/dataset/__pycache__/hypersim_dataset.cpython-310.pyc b/src/dataset/__pycache__/hypersim_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67ad87390541931944de3b1c5f729ab3da55646c Binary files /dev/null and b/src/dataset/__pycache__/hypersim_dataset.cpython-310.pyc differ diff --git a/src/dataset/__pycache__/kitti_dataset.cpython-310.pyc b/src/dataset/__pycache__/kitti_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d7472f74450eb48467d905f70c3061775045821 Binary files /dev/null and b/src/dataset/__pycache__/kitti_dataset.cpython-310.pyc differ diff --git a/src/dataset/__pycache__/mixed_sampler.cpython-310.pyc b/src/dataset/__pycache__/mixed_sampler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d955bd44452c1e43fd53956fd8337558da0ce79 Binary files /dev/null and b/src/dataset/__pycache__/mixed_sampler.cpython-310.pyc differ diff --git a/src/dataset/__pycache__/nyu_dataset.cpython-310.pyc b/src/dataset/__pycache__/nyu_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f799ed21069ae71417feb5e8aeadf410554b3bf Binary files /dev/null and b/src/dataset/__pycache__/nyu_dataset.cpython-310.pyc differ diff --git a/src/dataset/__pycache__/scannet_dataset.cpython-310.pyc b/src/dataset/__pycache__/scannet_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f935077f824610bee9df21a6c1dc5e22e4b362b1 Binary files /dev/null and b/src/dataset/__pycache__/scannet_dataset.cpython-310.pyc differ diff --git a/src/dataset/__pycache__/vkitti_dataset.cpython-310.pyc b/src/dataset/__pycache__/vkitti_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b031b6f8370b80562a7f3c79d90c4855d519389b Binary files /dev/null and b/src/dataset/__pycache__/vkitti_dataset.cpython-310.pyc differ diff --git a/src/dataset/base_depth_dataset.py b/src/dataset/base_depth_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e8e6666c2ef7e23874cb9b511864a9c1af8b3943 --- /dev/null +++ b/src/dataset/base_depth_dataset.py @@ -0,0 +1,286 @@ +# Last modified: 2024-04-30 +# +# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# If you find this code useful, we kindly ask you to cite our paper in your work. +# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation +# If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold. 
+# More information about the method can be found at https://marigoldmonodepth.github.io +# -------------------------------------------------------------------------- +import glob +import io +import json +import os +import pdb +import random +import tarfile +from enum import Enum +from typing import Union + +import numpy as np +import torch +from PIL import Image +from torch.utils.data import Dataset +from torchvision.transforms import InterpolationMode, Resize, CenterCrop +import torchvision.transforms as transforms +from transformers import CLIPTextModel, CLIPTokenizer +from src.util.depth_transform import DepthNormalizerBase +import random + +from src.dataset.eval_base_dataset import DatasetMode, DepthFileNameMode + + +def read_image_from_tar(tar_obj, img_rel_path): + image = tar_obj.extractfile("./" + img_rel_path) + image = image.read() + image = Image.open(io.BytesIO(image)) + + +class BaseDepthDataset(Dataset): + def __init__( + self, + mode: DatasetMode, + filename_ls_path: str, + dataset_dir: str, + disp_name: str, + min_depth: float, + max_depth: float, + has_filled_depth: bool, + name_mode: DepthFileNameMode, + depth_transform: Union[DepthNormalizerBase, None] = None, + tokenizer: CLIPTokenizer = None, + augmentation_args: dict = None, + resize_to_hw=None, + move_invalid_to_far_plane: bool = True, + rgb_transform=lambda x: x / 255.0 * 2 - 1, # [0, 255] -> [-1, 1], + **kwargs, + ) -> None: + super().__init__() + self.mode = mode + # dataset info + self.filename_ls_path = filename_ls_path + self.disp_name = disp_name + self.has_filled_depth = has_filled_depth + self.name_mode: DepthFileNameMode = name_mode + self.min_depth = min_depth + self.max_depth = max_depth + # training arguments + self.depth_transform: DepthNormalizerBase = depth_transform + self.augm_args = augmentation_args + self.resize_to_hw = resize_to_hw + self.rgb_transform = rgb_transform + self.move_invalid_to_far_plane = move_invalid_to_far_plane + self.tokenizer = tokenizer + # Load filenames + self.filenames = [] + filename_paths = glob.glob(self.filename_ls_path) + for path in filename_paths: + with open(path, "r") as f: + self.filenames += json.load(f) + # Tar dataset + self.tar_obj = None + self.is_tar = ( + True + if os.path.isfile(dataset_dir) and tarfile.is_tarfile(dataset_dir) + else False + ) + + def __len__(self): + return len(self.filenames) + + def __getitem__(self, index): + rasters, other = self._get_data_item(index) + if DatasetMode.TRAIN == self.mode: + rasters = self._training_preprocess(rasters) + # merge + outputs = rasters + outputs.update(other) + return outputs + + def _get_data_item(self, index): + rgb_path = self.filenames[index]['rgb_path'] + depth_path = self.filenames[index]['depth_path'] + mask_path = None + if 'valid_mask' in self.filenames[index]: + mask_path = self.filenames[index]['valid_mask'] + if self.filenames[index]['caption'] is not None: + coca_caption = self.filenames[index]['caption']['coca_caption'] + spatial_caption = self.filenames[index]['caption']['spatial_caption'] + empty_caption = '' + caption_choices = [coca_caption, spatial_caption, empty_caption] + probabilities = [0.4, 0.4, 0.2] + caption = random.choices(caption_choices, probabilities)[0] + else: + caption = '' + + rasters = {} + # RGB data + rasters.update(self._load_rgb_data(rgb_path)) + + # Depth data + if DatasetMode.RGB_ONLY != self.mode and depth_path is not None: + # load data + depth_data = self._load_depth_data(depth_path) + rasters.update(depth_data) + # valid mask + if mask_path is not None: + 
valid_mask_raw = Image.open(mask_path) + valid_mask_filled = Image.open(mask_path) + rasters["valid_mask_raw"] = torch.from_numpy(np.asarray(valid_mask_raw)).unsqueeze(0).bool() + rasters["valid_mask_filled"] = torch.from_numpy(np.asarray(valid_mask_filled)).unsqueeze(0).bool() + else: + rasters["valid_mask_raw"] = self._get_valid_mask( + rasters["depth_raw_linear"] + ).clone() + rasters["valid_mask_filled"] = self._get_valid_mask( + rasters["depth_filled_linear"] + ).clone() + + other = {"index": index, "rgb_path": rgb_path, 'text': caption} + + if self.resize_to_hw is not None: + resize_transform = transforms.Compose([ + Resize(size=max(self.resize_to_hw), interpolation=InterpolationMode.NEAREST_EXACT), + CenterCrop(size=self.resize_to_hw)]) + rasters = {k: resize_transform(v) for k, v in rasters.items()} + + return rasters, other + + def _load_rgb_data(self, rgb_path): + # Read RGB data + rgb = self._read_rgb_file(rgb_path) + rgb_norm = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1] + + outputs = { + "rgb_int": torch.from_numpy(rgb).int(), + "rgb_norm": torch.from_numpy(rgb_norm).float(), + } + return outputs + + def _load_depth_data(self, depth_path, filled_rel_path=None): + # Read depth data + outputs = {} + depth_raw = self._read_depth_file(depth_path).squeeze() + depth_raw_linear = torch.from_numpy(depth_raw.copy()).float().unsqueeze(0) # [1, H, W] + outputs["depth_raw_linear"] = depth_raw_linear.clone() + + if self.has_filled_depth: + depth_filled = self._read_depth_file(filled_rel_path).squeeze() + depth_filled_linear = torch.from_numpy(depth_filled).float().unsqueeze(0) + outputs["depth_filled_linear"] = depth_filled_linear + else: + outputs["depth_filled_linear"] = depth_raw_linear.clone() + + return outputs + + def _get_data_path(self, index): + filename_line = self.filenames[index] + + # Get data path + rgb_rel_path = filename_line[0] + + depth_rel_path, text_rel_path = None, None + if DatasetMode.RGB_ONLY != self.mode: + depth_rel_path = filename_line[1] + if len(filename_line) > 2: + text_rel_path = filename_line[2] + return rgb_rel_path, depth_rel_path, text_rel_path + + def _read_image(self, img_path) -> np.ndarray: + image_to_read = img_path + image = Image.open(image_to_read) # [H, W, rgb] + image = np.asarray(image) + return image + + def _read_rgb_file(self, path) -> np.ndarray: + rgb = self._read_image(path) + rgb = np.transpose(rgb, (2, 0, 1)).astype(int) # [rgb, H, W] + return rgb + + def _read_depth_file(self, path): + depth_in = self._read_image(path) + # Replace code below to decode depth according to dataset definition + depth_decoded = depth_in + return depth_decoded + + def _get_valid_mask(self, depth: torch.Tensor): + valid_mask = torch.logical_and( + (depth > self.min_depth), (depth < self.max_depth) + ).bool() + return valid_mask + + def _training_preprocess(self, rasters): + # Augmentation + if self.augm_args is not None: + rasters = self._augment_data(rasters) + + # Normalization + # rasters["depth_raw_norm"] = rasters["depth_raw_linear"] / 255.0 * 2.0 - 1.0 + # rasters["depth_filled_norm"] = rasters["depth_filled_linear"] / 255.0 * 2.0 - 1.0 + + rasters["depth_raw_norm"] = self.depth_transform( + rasters["depth_raw_linear"], rasters["valid_mask_raw"] + ).clone() + rasters["depth_filled_norm"] = self.depth_transform( + rasters["depth_filled_linear"], rasters["valid_mask_filled"] + ).clone() + + # Set invalid pixel to far plane + if self.move_invalid_to_far_plane: + if self.depth_transform.far_plane_at_max: + 
rasters["depth_filled_norm"][~rasters["valid_mask_filled"]] = ( + self.depth_transform.norm_max + ) + else: + rasters["depth_filled_norm"][~rasters["valid_mask_filled"]] = ( + self.depth_transform.norm_min + ) + + # Resize + if self.resize_to_hw is not None: + resize_transform = transforms.Compose([ + Resize(size=max(self.resize_to_hw), interpolation=InterpolationMode.NEAREST_EXACT), + CenterCrop(size=self.resize_to_hw)]) + rasters = {k: resize_transform(v) for k, v in rasters.items()} + return rasters + + def _augment_data(self, rasters_dict): + # lr flipping + lr_flip_p = self.augm_args.lr_flip_p + if random.random() < lr_flip_p: + rasters_dict = {k: v.flip(-1) for k, v in rasters_dict.items()} + + return rasters_dict + + def __del__(self): + if hasattr(self, "tar_obj") and self.tar_obj is not None: + self.tar_obj.close() + self.tar_obj = None + +def get_pred_name(rgb_basename, name_mode, suffix=".png"): + if DepthFileNameMode.rgb_id == name_mode: + pred_basename = "pred_" + rgb_basename.split("_")[1] + elif DepthFileNameMode.i_d_rgb == name_mode: + pred_basename = rgb_basename.replace("_rgb.", "_pred.") + elif DepthFileNameMode.id == name_mode: + pred_basename = "pred_" + rgb_basename + elif DepthFileNameMode.rgb_i_d == name_mode: + pred_basename = "pred_" + "_".join(rgb_basename.split("_")[1:]) + else: + raise NotImplementedError + # change suffix + pred_basename = os.path.splitext(pred_basename)[0] + suffix + + return pred_basename diff --git a/src/dataset/base_inpaint_dataset.py b/src/dataset/base_inpaint_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..9c4215f6b8c675e2dbe063fe9bc0128ebf889e6d --- /dev/null +++ b/src/dataset/base_inpaint_dataset.py @@ -0,0 +1,280 @@ +# Last modified: 2024-04-30 +# +# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# If you find this code useful, we kindly ask you to cite our paper in your work. +# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation +# If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold. 
+# More information about the method can be found at https://marigoldmonodepth.github.io
+# --------------------------------------------------------------------------
+import glob
+import io
+import json
+import os
+import random
+import tarfile
+from enum import Enum
+from typing import Union
+
+import numpy as np
+import torch
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision.transforms import InterpolationMode, Resize, CenterCrop
+import torchvision.transforms as transforms
+from transformers import CLIPTextModel, CLIPTokenizer
+from src.util.depth_transform import DepthNormalizerBase
+
+from src.dataset.eval_base_dataset import DatasetMode, DepthFileNameMode
+from pycocotools import mask as coco_mask
+from scipy.ndimage import gaussian_filter
+
+def read_image_from_tar(tar_obj, img_rel_path):
+    image = tar_obj.extractfile("./" + img_rel_path)
+    image = image.read()
+    image = Image.open(io.BytesIO(image))
+    return image
+
+
+class BaseInpaintDataset(Dataset):
+    def __init__(
+        self,
+        mode: DatasetMode,
+        filename_ls_path: str,
+        dataset_dir: str,
+        disp_name: str,
+        depth_transform: Union[DepthNormalizerBase, None] = None,
+        tokenizer: CLIPTokenizer = None,
+        augmentation_args: dict = None,
+        resize_to_hw=None,
+        move_invalid_to_far_plane: bool = True,
+        rgb_transform=lambda x: x / 255.0 * 2 - 1,  # [0, 255] -> [-1, 1],
+        **kwargs,
+    ) -> None:
+        super().__init__()
+        self.mode = mode
+        # dataset info
+        self.filename_ls_path = filename_ls_path
+        self.disp_name = disp_name
+        # training arguments
+        self.depth_transform: DepthNormalizerBase = depth_transform
+        self.augm_args = augmentation_args
+        self.resize_to_hw = resize_to_hw
+        self.rgb_transform = rgb_transform
+        self.move_invalid_to_far_plane = move_invalid_to_far_plane
+        self.tokenizer = tokenizer
+        # Load filenames
+        self.filenames = []
+        filename_paths = glob.glob(self.filename_ls_path)
+        for path in filename_paths:
+            with open(path, "r") as f:
+                self.filenames += json.load(f)
+        # Tar dataset
+        self.tar_obj = None
+        self.is_tar = (
+            True
+            if os.path.isfile(dataset_dir) and tarfile.is_tarfile(dataset_dir)
+            else False
+        )
+
+    def __len__(self):
+        return len(self.filenames)
+
+    def __getitem__(self, index):
+        rasters, other = self._get_data_item(index)
+        if DatasetMode.TRAIN == self.mode:
+            rasters = self._training_preprocess(rasters)
+        # merge
+        outputs = rasters
+        outputs.update(other)
+        return outputs
+
+    def _get_data_item(self, index):
+        rgb_path = self.filenames[index]['rgb_path']
+        mask_path = None
+        if 'valid_mask' in self.filenames[index]:
+            mask_path = self.filenames[index]['valid_mask']
+        if self.filenames[index]['caption'] is not None:
+            coca_caption = self.filenames[index]['caption']['coca_caption']
+            spatial_caption = self.filenames[index]['caption']['spatial_caption']
+            empty_caption = ''
+            caption_choices = [coca_caption, spatial_caption, empty_caption]
+            probabilities = [0.4, 0.4, 0.2]
+            caption = random.choices(caption_choices, probabilities)[0]
+        else:
+            caption = ''
+
+        rasters = {}
+        # RGB data
+        rasters.update(self._load_rgb_data(rgb_path))
+
+        try:
+            anno = json.load(open(rgb_path.replace('.jpg', '.json')))['annotations']
+            random.shuffle(anno)
+            object_num = random.randint(5, 10)
+            mask = np.array(coco_mask.decode(anno[0]['segmentation']), dtype=np.uint8)
+            for single_anno in (anno[0:object_num] if len(anno) > object_num else anno):
+                mask += np.array(coco_mask.decode(single_anno['segmentation']), dtype=np.uint8)
+            mask = np.clip(mask, 0, 1)  # keep the union of instance masks binary
+        except Exception:
+            # No annotation file (or malformed entry): fall back to the random grid mask below
+            mask = None
+
+        a = random.random()
+        if a < 
0.1 or mask is None: + mask = np.zeros(rasters['rgb_int'].shape[-2:]) + rows, cols = mask.shape + grid_size = random.randint(5, 14) + grid_rows, grid_cols = rows // grid_size, cols // grid_size + for i in range(grid_rows): + for j in range(grid_cols): + random_prob = np.random.rand() + if random_prob < 0.2: + row_start = i * grid_size + row_end = (i + 1) * grid_size + col_start = j * grid_size + col_end = (j + 1) * grid_size + mask[row_start:row_end, col_start:col_end] = 1 + + rasters['mask'] = torch.from_numpy(mask).unsqueeze(0).to(torch.float32) + + if self.resize_to_hw is not None: + resize_transform = transforms.Compose([ + Resize(size=max(self.resize_to_hw), interpolation=InterpolationMode.NEAREST_EXACT), + CenterCrop(size=self.resize_to_hw)]) + rasters = {k: resize_transform(v) for k, v in rasters.items()} + + # mask = torch.zeros(rasters['rgb_int'].shape[-2:]) + # rows, cols = mask.shape + # grid_size = random.randint(3, 10) + # grid_rows, grid_cols = rows // grid_size, cols // grid_size + # for i in range(grid_rows): + # for j in range(grid_cols): + # random_prob = np.random.rand() + # if random_prob < 0.5: + # row_start = i * grid_size + # row_end = (i + 1) * grid_size + # col_start = j * grid_size + # col_end = (j + 1) * grid_size + # mask[row_start:row_end, col_start:col_end] = 1 + + # rasters['mask'] = mask.unsqueeze(0) + + other = {"index": index, "rgb_path": rgb_path, 'text': caption} + return rasters, other + + def _load_rgb_data(self, rgb_path): + # Read RGB data + rgb = self._read_rgb_file(rgb_path) + rgb_norm = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1] + + outputs = { + "rgb_int": torch.from_numpy(rgb).int(), + "rgb_norm": torch.from_numpy(rgb_norm).float(), + } + return outputs + + def _get_data_path(self, index): + filename_line = self.filenames[index] + + # Get data path + rgb_rel_path = filename_line[0] + + depth_rel_path, text_rel_path = None, None + if DatasetMode.RGB_ONLY != self.mode: + depth_rel_path = filename_line[1] + if len(filename_line) > 2: + text_rel_path = filename_line[2] + return rgb_rel_path, depth_rel_path, text_rel_path + + def _read_image(self, img_path) -> np.ndarray: + image_to_read = img_path + image = Image.open(image_to_read) # [H, W, rgb] + image = np.asarray(image) + return image + + def _read_rgb_file(self, path) -> np.ndarray: + rgb = self._read_image(path) + rgb = np.transpose(rgb, (2, 0, 1)).astype(int) # [rgb, H, W] + return rgb + + def _read_depth_file(self, path): + depth_in = self._read_image(path) + # Replace code below to decode depth according to dataset definition + depth_decoded = depth_in + return depth_decoded + + def _training_preprocess(self, rasters): + # Augmentation + if self.augm_args is not None: + rasters = self._augment_data(rasters) + + # Normalization + # rasters["depth_raw_norm"] = rasters["depth_raw_linear"] / 255.0 * 2.0 - 1.0 + # rasters["depth_filled_norm"] = rasters["depth_filled_linear"] / 255.0 * 2.0 - 1.0 + + rasters["depth_raw_norm"] = self.depth_transform( + rasters["depth_raw_linear"], rasters["valid_mask_raw"] + ).clone() + rasters["depth_filled_norm"] = self.depth_transform( + rasters["depth_filled_linear"], rasters["valid_mask_filled"] + ).clone() + + # Set invalid pixel to far plane + if self.move_invalid_to_far_plane: + if self.depth_transform.far_plane_at_max: + rasters["depth_filled_norm"][~rasters["valid_mask_filled"]] = ( + self.depth_transform.norm_max + ) + else: + rasters["depth_filled_norm"][~rasters["valid_mask_filled"]] = ( + self.depth_transform.norm_min + ) + + # Resize + if 
self.resize_to_hw is not None: + resize_transform = transforms.Compose([ + Resize(size=max(self.resize_to_hw), interpolation=InterpolationMode.NEAREST_EXACT), + CenterCrop(size=self.resize_to_hw)]) + rasters = {k: resize_transform(v) for k, v in rasters.items()} + return rasters + + def _augment_data(self, rasters_dict): + # lr flipping + lr_flip_p = self.augm_args.lr_flip_p + if random.random() < lr_flip_p: + rasters_dict = {k: v.flip(-1) for k, v in rasters_dict.items()} + + return rasters_dict + + def __del__(self): + if hasattr(self, "tar_obj") and self.tar_obj is not None: + self.tar_obj.close() + self.tar_obj = None + +def get_pred_name(rgb_basename, name_mode, suffix=".png"): + if DepthFileNameMode.rgb_id == name_mode: + pred_basename = "pred_" + rgb_basename.split("_")[1] + elif DepthFileNameMode.i_d_rgb == name_mode: + pred_basename = rgb_basename.replace("_rgb.", "_pred.") + elif DepthFileNameMode.id == name_mode: + pred_basename = "pred_" + rgb_basename + elif DepthFileNameMode.rgb_i_d == name_mode: + pred_basename = "pred_" + "_".join(rgb_basename.split("_")[1:]) + else: + raise NotImplementedError + # change suffix + pred_basename = os.path.splitext(pred_basename)[0] + suffix + + return pred_basename diff --git a/src/dataset/depthanything_dataset.py b/src/dataset/depthanything_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..07fae02bd68115696a5c2dc59ccb70419a1e49b4 --- /dev/null +++ b/src/dataset/depthanything_dataset.py @@ -0,0 +1,91 @@ +# Last modified: 2024-02-08 +# +# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# If you find this code useful, we kindly ask you to cite our paper in your work. +# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation +# If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold. 
+# More information about the method can be found at https://marigoldmonodepth.github.io +# -------------------------------------------------------------------------- + +from .base_depth_dataset import BaseDepthDataset, DepthFileNameMode +import torch +from torchvision.transforms import InterpolationMode, Resize, CenterCrop +import torchvision.transforms as transforms + +class DepthAnythingDataset(BaseDepthDataset): + def __init__( + self, + **kwargs, + ) -> None: + super().__init__( + # ScanNet data parameter + min_depth=-1, + max_depth=256, + has_filled_depth=False, + name_mode=DepthFileNameMode.id, + **kwargs, + ) + + def _read_depth_file(self, rel_path): + depth_in = self._read_image(rel_path) + # Decode ScanNet depth + # depth_decoded = depth_in / 1000.0 + return depth_in + + def _training_preprocess(self, rasters): + # Augmentation + if self.augm_args is not None: + rasters = self._augment_data(rasters) + + # Normalization + rasters["depth_raw_norm"] = rasters["depth_raw_linear"] / 255.0 * 2.0 - 1.0 + rasters["depth_filled_norm"] = rasters["depth_filled_linear"] / 255.0 * 2.0 - 1.0 + + # Set invalid pixel to far plane + if self.move_invalid_to_far_plane: + if self.depth_transform.far_plane_at_max: + rasters["depth_filled_norm"][~rasters["valid_mask_filled"]] = ( + self.depth_transform.norm_max + ) + else: + rasters["depth_filled_norm"][~rasters["valid_mask_filled"]] = ( + self.depth_transform.norm_min + ) + + # Resize + if self.resize_to_hw is not None: + T = transforms.Compose([ + Resize(self.resize_to_hw[0]), + CenterCrop(self.resize_to_hw), + ]) + rasters = {k: T(v) for k, v in rasters.items()} + return rasters + + # def _load_depth_data(self, depth_rel_path, filled_rel_path): + # # Read depth data + # outputs = {} + # depth_raw = self._read_depth_file(depth_rel_path).squeeze() + # depth_raw_linear = torch.from_numpy(depth_raw).float().unsqueeze(0) # [1, H, W] [0, 255] + # outputs["depth_raw_linear"] = depth_raw_linear.clone() + # + # if self.has_filled_depth: + # depth_filled = self._read_depth_file(filled_rel_path).squeeze() + # depth_filled_linear = torch.from_numpy(depth_filled).float().unsqueeze(0) + # outputs["depth_filled_linear"] = depth_filled_linear + # else: + # outputs["depth_filled_linear"] = depth_raw_linear.clone() + # + # return outputs \ No newline at end of file diff --git a/src/dataset/diode_dataset.py b/src/dataset/diode_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b69b917fc434c04027d8efe968585f64d2bfdc --- /dev/null +++ b/src/dataset/diode_dataset.py @@ -0,0 +1,91 @@ +# Last modified: 2024-02-26 +# +# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# If you find this code useful, we kindly ask you to cite our paper in your work. +# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation +# If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold. 
+# More information about the method can be found at https://marigoldmonodepth.github.io +# -------------------------------------------------------------------------- + +import os +import tarfile +from io import BytesIO + +import numpy as np +import torch + +from .eval_base_dataset import EvaluateBaseDataset, DepthFileNameMode, DatasetMode + + +class DIODEDataset(EvaluateBaseDataset): + def __init__( + self, + **kwargs, + ) -> None: + super().__init__( + # DIODE data parameter + min_depth=0.6, + max_depth=350, + has_filled_depth=False, + name_mode=DepthFileNameMode.id, + **kwargs, + ) + + def _read_npy_file(self, rel_path): + if self.is_tar: + if self.tar_obj is None: + self.tar_obj = tarfile.open(self.dataset_dir) + fileobj = self.tar_obj.extractfile("./" + rel_path) + npy_path_or_content = BytesIO(fileobj.read()) + else: + npy_path_or_content = os.path.join(self.dataset_dir, rel_path) + data = np.load(npy_path_or_content).squeeze()[np.newaxis, :, :] + return data + + def _read_depth_file(self, rel_path): + depth = self._read_npy_file(rel_path) + return depth + + def _get_data_path(self, index): + return self.filenames[index] + + def _get_data_item(self, index): + # Special: depth mask is read from data + + rgb_rel_path, depth_rel_path, mask_rel_path = self._get_data_path(index=index) + + rasters = {} + + # RGB data + rasters.update(self._load_rgb_data(rgb_rel_path=rgb_rel_path)) + + # Depth data + if DatasetMode.RGB_ONLY != self.mode: + # load data + depth_data = self._load_depth_data( + depth_rel_path=depth_rel_path, filled_rel_path=None + ) + rasters.update(depth_data) + + # valid mask + mask = self._read_npy_file(mask_rel_path).astype(bool) + mask = torch.from_numpy(mask).bool() + rasters["valid_mask_raw"] = mask.clone() + rasters["valid_mask_filled"] = mask.clone() + + other = {"index": index, "rgb_relative_path": rgb_rel_path} + + return rasters, other diff --git a/src/dataset/eth3d_dataset.py b/src/dataset/eth3d_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..3d924c6699f55fc51711dd056903ae2f91d19be7 --- /dev/null +++ b/src/dataset/eth3d_dataset.py @@ -0,0 +1,65 @@ +# Last modified: 2024-02-08 +# +# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# If you find this code useful, we kindly ask you to cite our paper in your work. +# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation +# If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold. 
+# More information about the method can be found at https://marigoldmonodepth.github.io +# -------------------------------------------------------------------------- + +import torch +import tarfile +import os +import numpy as np + +from .eval_base_dataset import DepthFileNameMode, EvaluateBaseDataset + + +class ETH3DDataset(EvaluateBaseDataset): + HEIGHT, WIDTH = 4032, 6048 + + def __init__( + self, + **kwargs, + ) -> None: + super().__init__( + # ETH3D data parameter + min_depth=1e-5, + max_depth=torch.inf, + has_filled_depth=False, + name_mode=DepthFileNameMode.id, + **kwargs, + ) + + def _read_depth_file(self, rel_path): + # Read special binary data: https://www.eth3d.net/documentation#format-of-multi-view-data-image-formats + if self.is_tar: + if self.tar_obj is None: + self.tar_obj = tarfile.open(self.dataset_dir) + binary_data = self.tar_obj.extractfile("./" + rel_path) + binary_data = binary_data.read() + + else: + depth_path = os.path.join(self.dataset_dir, rel_path) + with open(depth_path, "rb") as file: + binary_data = file.read() + # Convert the binary data to a numpy array of 32-bit floats + depth_decoded = np.frombuffer(binary_data, dtype=np.float32).copy() + + depth_decoded[depth_decoded == torch.inf] = 0.0 + + depth_decoded = depth_decoded.reshape((self.HEIGHT, self.WIDTH)) + return depth_decoded diff --git a/src/dataset/eval_base_dataset.py b/src/dataset/eval_base_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..610f6c59ace1c6e8b043824bbf4dc89bf0efcb84 --- /dev/null +++ b/src/dataset/eval_base_dataset.py @@ -0,0 +1,283 @@ +# Last modified: 2024-04-30 +# +# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# If you find this code useful, we kindly ask you to cite our paper in your work. +# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation +# If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold. 
+# More information about the method can be found at https://marigoldmonodepth.github.io +# -------------------------------------------------------------------------- + +import io +import os +import random +import tarfile +from enum import Enum +from typing import Union + +import numpy as np +import torch +from PIL import Image +from torch.utils.data import Dataset +from torchvision.transforms import InterpolationMode, Resize + +from src.util.depth_transform import DepthNormalizerBase + + +class DatasetMode(Enum): + RGB_ONLY = "rgb_only" + EVAL = "evaluate" + TRAIN = "train" + + +class DepthFileNameMode(Enum): + """Prediction file naming modes""" + + id = 1 # id.png + rgb_id = 2 # rgb_id.png + i_d_rgb = 3 # i_d_1_rgb.png + rgb_i_d = 4 + + +def read_image_from_tar(tar_obj, img_rel_path): + image = tar_obj.extractfile("./" + img_rel_path) + image = image.read() + image = Image.open(io.BytesIO(image)) + + +class EvaluateBaseDataset(Dataset): + def __init__( + self, + mode: DatasetMode, + filename_ls_path: str, + dataset_dir: str, + disp_name: str, + min_depth: float, + max_depth: float, + has_filled_depth: bool, + name_mode: DepthFileNameMode, + depth_transform: Union[DepthNormalizerBase, None] = None, + augmentation_args: dict = None, + resize_to_hw=None, + move_invalid_to_far_plane: bool = True, + rgb_transform=lambda x: x / 255.0 * 2 - 1, # [0, 255] -> [-1, 1], + **kwargs, + ) -> None: + super().__init__() + self.mode = mode + # dataset info + self.filename_ls_path = filename_ls_path + self.dataset_dir = dataset_dir + assert os.path.exists( + self.dataset_dir + ), f"Dataset does not exist at: {self.dataset_dir}" + self.disp_name = disp_name + self.has_filled_depth = has_filled_depth + self.name_mode: DepthFileNameMode = name_mode + self.min_depth = min_depth + self.max_depth = max_depth + + # training arguments + self.depth_transform: DepthNormalizerBase = depth_transform + self.augm_args = augmentation_args + self.resize_to_hw = resize_to_hw + self.rgb_transform = rgb_transform + self.move_invalid_to_far_plane = move_invalid_to_far_plane + + # Load filenames + with open(self.filename_ls_path, "r") as f: + self.filenames = [ + s.split() for s in f.readlines() + ] # [['rgb.png', 'depth.tif'], [], ...] 
+ + # Tar dataset + self.tar_obj = None + self.is_tar = ( + True + if os.path.isfile(dataset_dir) and tarfile.is_tarfile(dataset_dir) + else False + ) + + def __len__(self): + return len(self.filenames) + + def __getitem__(self, index): + rasters, other = self._get_data_item(index) + if DatasetMode.TRAIN == self.mode: + rasters = self._training_preprocess(rasters) + # merge + outputs = rasters + outputs.update(other) + return outputs + + def _get_data_item(self, index): + rgb_rel_path, depth_rel_path, filled_rel_path = self._get_data_path(index=index) + + rasters = {} + + # RGB data + rasters.update(self._load_rgb_data(rgb_rel_path=rgb_rel_path)) + + # Depth data + if DatasetMode.RGB_ONLY != self.mode: + # load data + depth_data = self._load_depth_data( + depth_rel_path=depth_rel_path, filled_rel_path=filled_rel_path + ) + rasters.update(depth_data) + # valid mask + rasters["valid_mask_raw"] = self._get_valid_mask( + rasters["depth_raw_linear"] + ).clone() + rasters["valid_mask_filled"] = self._get_valid_mask( + rasters["depth_filled_linear"] + ).clone() + + other = {"index": index, "rgb_relative_path": rgb_rel_path} + + return rasters, other + + def _load_rgb_data(self, rgb_rel_path): + # Read RGB data + rgb = self._read_rgb_file(rgb_rel_path) + rgb_norm = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1] + + outputs = { + "rgb_int": torch.from_numpy(rgb).int(), + "rgb_norm": torch.from_numpy(rgb_norm).float(), + } + return outputs + + def _load_depth_data(self, depth_rel_path, filled_rel_path): + # Read depth data + outputs = {} + depth_raw = self._read_depth_file(depth_rel_path).squeeze() + depth_raw_linear = torch.from_numpy(depth_raw).float().unsqueeze(0) # [1, H, W] + outputs["depth_raw_linear"] = depth_raw_linear.clone() + + if self.has_filled_depth: + depth_filled = self._read_depth_file(filled_rel_path).squeeze() + depth_filled_linear = torch.from_numpy(depth_filled).float().unsqueeze(0) + outputs["depth_filled_linear"] = depth_filled_linear + else: + outputs["depth_filled_linear"] = depth_raw_linear.clone() + + return outputs + + def _get_data_path(self, index): + filename_line = self.filenames[index] + + # Get data path + rgb_rel_path = filename_line[0] + + depth_rel_path, filled_rel_path = None, None + if DatasetMode.RGB_ONLY != self.mode: + depth_rel_path = filename_line[1] + if self.has_filled_depth: + filled_rel_path = filename_line[2] + return rgb_rel_path, depth_rel_path, filled_rel_path + + def _read_image(self, img_rel_path) -> np.ndarray: + if self.is_tar: + if self.tar_obj is None: + self.tar_obj = tarfile.open(self.dataset_dir) + image_to_read = self.tar_obj.extractfile("./" + img_rel_path) + image_to_read = image_to_read.read() + image_to_read = io.BytesIO(image_to_read) + else: + image_to_read = os.path.join(self.dataset_dir, img_rel_path) + image = Image.open(image_to_read) # [H, W, rgb] + image = np.asarray(image) + return image + + def _read_rgb_file(self, rel_path) -> np.ndarray: + rgb = self._read_image(rel_path) + rgb = np.transpose(rgb, (2, 0, 1)).astype(int) # [rgb, H, W] + return rgb + + def _read_depth_file(self, rel_path): + depth_in = self._read_image(rel_path) + # Replace code below to decode depth according to dataset definition + depth_decoded = depth_in + + return depth_decoded + + def _get_valid_mask(self, depth: torch.Tensor): + valid_mask = torch.logical_and( + (depth > self.min_depth), (depth < self.max_depth) + ).bool() + return valid_mask + + def _training_preprocess(self, rasters): + # Augmentation + if self.augm_args is not None: + rasters = 
self._augment_data(rasters) + + # Normalization + rasters["depth_raw_norm"] = self.depth_transform( + rasters["depth_raw_linear"], rasters["valid_mask_raw"] + ).clone() + rasters["depth_filled_norm"] = self.depth_transform( + rasters["depth_filled_linear"], rasters["valid_mask_filled"] + ).clone() + + # Set invalid pixel to far plane + if self.move_invalid_to_far_plane: + if self.depth_transform.far_plane_at_max: + rasters["depth_filled_norm"][~rasters["valid_mask_filled"]] = ( + self.depth_transform.norm_max + ) + else: + rasters["depth_filled_norm"][~rasters["valid_mask_filled"]] = ( + self.depth_transform.norm_min + ) + + # Resize + if self.resize_to_hw is not None: + resize_transform = Resize( + size=self.resize_to_hw, interpolation=InterpolationMode.NEAREST_EXACT + ) + rasters = {k: resize_transform(v) for k, v in rasters.items()} + + return rasters + + def _augment_data(self, rasters_dict): + # lr flipping + lr_flip_p = self.augm_args.lr_flip_p + if random.random() < lr_flip_p: + rasters_dict = {k: v.flip(-1) for k, v in rasters_dict.items()} + + return rasters_dict + + def __del__(self): + if hasattr(self, "tar_obj") and self.tar_obj is not None: + self.tar_obj.close() + self.tar_obj = None + +def get_pred_name(rgb_basename, name_mode, suffix=".png"): + if DepthFileNameMode.rgb_id == name_mode: + pred_basename = "pred_" + rgb_basename.split("_")[1] + elif DepthFileNameMode.i_d_rgb == name_mode: + pred_basename = rgb_basename.replace("_rgb.", "_pred.") + elif DepthFileNameMode.id == name_mode: + pred_basename = "pred_" + rgb_basename + elif DepthFileNameMode.rgb_i_d == name_mode: + pred_basename = "pred_" + "_".join(rgb_basename.split("_")[1:]) + else: + raise NotImplementedError + # change suffix + pred_basename = os.path.splitext(pred_basename)[0] + suffix + + return pred_basename \ No newline at end of file diff --git a/src/dataset/hypersim_dataset.py b/src/dataset/hypersim_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..36f569eac4ef602cb221103c5efdc83a01ddc7ab --- /dev/null +++ b/src/dataset/hypersim_dataset.py @@ -0,0 +1,44 @@ +# Last modified: 2024-02-08 +# +# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# If you find this code useful, we kindly ask you to cite our paper in your work. +# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation +# If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold. 
+# More information about the method can be found at https://marigoldmonodepth.github.io +# -------------------------------------------------------------------------- + + +from .base_depth_dataset import BaseDepthDataset, DepthFileNameMode + +class HypersimDataset(BaseDepthDataset): + def __init__( + self, + **kwargs, + ) -> None: + super().__init__( + # Hypersim data parameter + min_depth=1e-5, + max_depth=65.0, + has_filled_depth=False, + name_mode=DepthFileNameMode.rgb_i_d, + **kwargs, + ) + + def _read_depth_file(self, rel_path): + depth_in = self._read_image(rel_path) + # Decode Hypersim depth + depth_decoded = depth_in / 1000.0 + return depth_decoded \ No newline at end of file diff --git a/src/dataset/inpaint_dataset.py b/src/dataset/inpaint_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e8e6666c2ef7e23874cb9b511864a9c1af8b3943 --- /dev/null +++ b/src/dataset/inpaint_dataset.py @@ -0,0 +1,286 @@ +# Last modified: 2024-04-30 +# +# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# If you find this code useful, we kindly ask you to cite our paper in your work. +# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation +# If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold. 
+# More information about the method can be found at https://marigoldmonodepth.github.io +# -------------------------------------------------------------------------- +import glob +import io +import json +import os +import pdb +import random +import tarfile +from enum import Enum +from typing import Union + +import numpy as np +import torch +from PIL import Image +from torch.utils.data import Dataset +from torchvision.transforms import InterpolationMode, Resize, CenterCrop +import torchvision.transforms as transforms +from transformers import CLIPTextModel, CLIPTokenizer +from src.util.depth_transform import DepthNormalizerBase +import random + +from src.dataset.eval_base_dataset import DatasetMode, DepthFileNameMode + + +def read_image_from_tar(tar_obj, img_rel_path): + image = tar_obj.extractfile("./" + img_rel_path) + image = image.read() + image = Image.open(io.BytesIO(image)) + + +class BaseDepthDataset(Dataset): + def __init__( + self, + mode: DatasetMode, + filename_ls_path: str, + dataset_dir: str, + disp_name: str, + min_depth: float, + max_depth: float, + has_filled_depth: bool, + name_mode: DepthFileNameMode, + depth_transform: Union[DepthNormalizerBase, None] = None, + tokenizer: CLIPTokenizer = None, + augmentation_args: dict = None, + resize_to_hw=None, + move_invalid_to_far_plane: bool = True, + rgb_transform=lambda x: x / 255.0 * 2 - 1, # [0, 255] -> [-1, 1], + **kwargs, + ) -> None: + super().__init__() + self.mode = mode + # dataset info + self.filename_ls_path = filename_ls_path + self.disp_name = disp_name + self.has_filled_depth = has_filled_depth + self.name_mode: DepthFileNameMode = name_mode + self.min_depth = min_depth + self.max_depth = max_depth + # training arguments + self.depth_transform: DepthNormalizerBase = depth_transform + self.augm_args = augmentation_args + self.resize_to_hw = resize_to_hw + self.rgb_transform = rgb_transform + self.move_invalid_to_far_plane = move_invalid_to_far_plane + self.tokenizer = tokenizer + # Load filenames + self.filenames = [] + filename_paths = glob.glob(self.filename_ls_path) + for path in filename_paths: + with open(path, "r") as f: + self.filenames += json.load(f) + # Tar dataset + self.tar_obj = None + self.is_tar = ( + True + if os.path.isfile(dataset_dir) and tarfile.is_tarfile(dataset_dir) + else False + ) + + def __len__(self): + return len(self.filenames) + + def __getitem__(self, index): + rasters, other = self._get_data_item(index) + if DatasetMode.TRAIN == self.mode: + rasters = self._training_preprocess(rasters) + # merge + outputs = rasters + outputs.update(other) + return outputs + + def _get_data_item(self, index): + rgb_path = self.filenames[index]['rgb_path'] + depth_path = self.filenames[index]['depth_path'] + mask_path = None + if 'valid_mask' in self.filenames[index]: + mask_path = self.filenames[index]['valid_mask'] + if self.filenames[index]['caption'] is not None: + coca_caption = self.filenames[index]['caption']['coca_caption'] + spatial_caption = self.filenames[index]['caption']['spatial_caption'] + empty_caption = '' + caption_choices = [coca_caption, spatial_caption, empty_caption] + probabilities = [0.4, 0.4, 0.2] + caption = random.choices(caption_choices, probabilities)[0] + else: + caption = '' + + rasters = {} + # RGB data + rasters.update(self._load_rgb_data(rgb_path)) + + # Depth data + if DatasetMode.RGB_ONLY != self.mode and depth_path is not None: + # load data + depth_data = self._load_depth_data(depth_path) + rasters.update(depth_data) + # valid mask + if mask_path is not None: + 
valid_mask_raw = Image.open(mask_path) + valid_mask_filled = Image.open(mask_path) + rasters["valid_mask_raw"] = torch.from_numpy(np.asarray(valid_mask_raw)).unsqueeze(0).bool() + rasters["valid_mask_filled"] = torch.from_numpy(np.asarray(valid_mask_filled)).unsqueeze(0).bool() + else: + rasters["valid_mask_raw"] = self._get_valid_mask( + rasters["depth_raw_linear"] + ).clone() + rasters["valid_mask_filled"] = self._get_valid_mask( + rasters["depth_filled_linear"] + ).clone() + + other = {"index": index, "rgb_path": rgb_path, 'text': caption} + + if self.resize_to_hw is not None: + resize_transform = transforms.Compose([ + Resize(size=max(self.resize_to_hw), interpolation=InterpolationMode.NEAREST_EXACT), + CenterCrop(size=self.resize_to_hw)]) + rasters = {k: resize_transform(v) for k, v in rasters.items()} + + return rasters, other + + def _load_rgb_data(self, rgb_path): + # Read RGB data + rgb = self._read_rgb_file(rgb_path) + rgb_norm = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1] + + outputs = { + "rgb_int": torch.from_numpy(rgb).int(), + "rgb_norm": torch.from_numpy(rgb_norm).float(), + } + return outputs + + def _load_depth_data(self, depth_path, filled_rel_path=None): + # Read depth data + outputs = {} + depth_raw = self._read_depth_file(depth_path).squeeze() + depth_raw_linear = torch.from_numpy(depth_raw.copy()).float().unsqueeze(0) # [1, H, W] + outputs["depth_raw_linear"] = depth_raw_linear.clone() + + if self.has_filled_depth: + depth_filled = self._read_depth_file(filled_rel_path).squeeze() + depth_filled_linear = torch.from_numpy(depth_filled).float().unsqueeze(0) + outputs["depth_filled_linear"] = depth_filled_linear + else: + outputs["depth_filled_linear"] = depth_raw_linear.clone() + + return outputs + + def _get_data_path(self, index): + filename_line = self.filenames[index] + + # Get data path + rgb_rel_path = filename_line[0] + + depth_rel_path, text_rel_path = None, None + if DatasetMode.RGB_ONLY != self.mode: + depth_rel_path = filename_line[1] + if len(filename_line) > 2: + text_rel_path = filename_line[2] + return rgb_rel_path, depth_rel_path, text_rel_path + + def _read_image(self, img_path) -> np.ndarray: + image_to_read = img_path + image = Image.open(image_to_read) # [H, W, rgb] + image = np.asarray(image) + return image + + def _read_rgb_file(self, path) -> np.ndarray: + rgb = self._read_image(path) + rgb = np.transpose(rgb, (2, 0, 1)).astype(int) # [rgb, H, W] + return rgb + + def _read_depth_file(self, path): + depth_in = self._read_image(path) + # Replace code below to decode depth according to dataset definition + depth_decoded = depth_in + return depth_decoded + + def _get_valid_mask(self, depth: torch.Tensor): + valid_mask = torch.logical_and( + (depth > self.min_depth), (depth < self.max_depth) + ).bool() + return valid_mask + + def _training_preprocess(self, rasters): + # Augmentation + if self.augm_args is not None: + rasters = self._augment_data(rasters) + + # Normalization + # rasters["depth_raw_norm"] = rasters["depth_raw_linear"] / 255.0 * 2.0 - 1.0 + # rasters["depth_filled_norm"] = rasters["depth_filled_linear"] / 255.0 * 2.0 - 1.0 + + rasters["depth_raw_norm"] = self.depth_transform( + rasters["depth_raw_linear"], rasters["valid_mask_raw"] + ).clone() + rasters["depth_filled_norm"] = self.depth_transform( + rasters["depth_filled_linear"], rasters["valid_mask_filled"] + ).clone() + + # Set invalid pixel to far plane + if self.move_invalid_to_far_plane: + if self.depth_transform.far_plane_at_max: + 
rasters["depth_filled_norm"][~rasters["valid_mask_filled"]] = ( + self.depth_transform.norm_max + ) + else: + rasters["depth_filled_norm"][~rasters["valid_mask_filled"]] = ( + self.depth_transform.norm_min + ) + + # Resize + if self.resize_to_hw is not None: + resize_transform = transforms.Compose([ + Resize(size=max(self.resize_to_hw), interpolation=InterpolationMode.NEAREST_EXACT), + CenterCrop(size=self.resize_to_hw)]) + rasters = {k: resize_transform(v) for k, v in rasters.items()} + return rasters + + def _augment_data(self, rasters_dict): + # lr flipping + lr_flip_p = self.augm_args.lr_flip_p + if random.random() < lr_flip_p: + rasters_dict = {k: v.flip(-1) for k, v in rasters_dict.items()} + + return rasters_dict + + def __del__(self): + if hasattr(self, "tar_obj") and self.tar_obj is not None: + self.tar_obj.close() + self.tar_obj = None + +def get_pred_name(rgb_basename, name_mode, suffix=".png"): + if DepthFileNameMode.rgb_id == name_mode: + pred_basename = "pred_" + rgb_basename.split("_")[1] + elif DepthFileNameMode.i_d_rgb == name_mode: + pred_basename = rgb_basename.replace("_rgb.", "_pred.") + elif DepthFileNameMode.id == name_mode: + pred_basename = "pred_" + rgb_basename + elif DepthFileNameMode.rgb_i_d == name_mode: + pred_basename = "pred_" + "_".join(rgb_basename.split("_")[1:]) + else: + raise NotImplementedError + # change suffix + pred_basename = os.path.splitext(pred_basename)[0] + suffix + + return pred_basename diff --git a/src/dataset/kitti_dataset.py b/src/dataset/kitti_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..cea6b583bb06e9aa152d271314e6477a0d0ac96b --- /dev/null +++ b/src/dataset/kitti_dataset.py @@ -0,0 +1,124 @@ +# Last modified: 2024-02-08 +# +# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# If you find this code useful, we kindly ask you to cite our paper in your work. +# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation +# If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold. +# More information about the method can be found at https://marigoldmonodepth.github.io +# -------------------------------------------------------------------------- + +import torch + +from .eval_base_dataset import DepthFileNameMode, EvaluateBaseDataset + + +class KITTIDataset(EvaluateBaseDataset): + def __init__( + self, + kitti_bm_crop, # Crop to KITTI benchmark size + valid_mask_crop, # Evaluation mask. 
[None, garg or eigen] + **kwargs, + ) -> None: + super().__init__( + # KITTI data parameter + min_depth=1e-5, + max_depth=80, + has_filled_depth=False, + name_mode=DepthFileNameMode.id, + **kwargs, + ) + self.kitti_bm_crop = kitti_bm_crop + self.valid_mask_crop = valid_mask_crop + assert self.valid_mask_crop in [ + None, + "garg", # set evaluation mask according to Garg ECCV16 + "eigen", # set evaluation mask according to Eigen NIPS14 + ], f"Unknown crop type: {self.valid_mask_crop}" + + # Filter out empty depth + self.filenames = [f for f in self.filenames if "None" != f[1]] + + def _read_depth_file(self, rel_path): + depth_in = self._read_image(rel_path) + # Decode KITTI depth + depth_decoded = depth_in / 256.0 + return depth_decoded + + def _load_rgb_data(self, rgb_rel_path): + rgb_data = super()._load_rgb_data(rgb_rel_path) + if self.kitti_bm_crop: + rgb_data = {k: self.kitti_benchmark_crop(v) for k, v in rgb_data.items()} + return rgb_data + + def _load_depth_data(self, depth_rel_path, filled_rel_path): + depth_data = super()._load_depth_data(depth_rel_path, filled_rel_path) + if self.kitti_bm_crop: + depth_data = { + k: self.kitti_benchmark_crop(v) for k, v in depth_data.items() + } + return depth_data + + @staticmethod + def kitti_benchmark_crop(input_img): + """ + Crop images to KITTI benchmark size + Args: + `input_img` (torch.Tensor): Input image to be cropped. + + Returns: + torch.Tensor:Cropped image. + """ + KB_CROP_HEIGHT = 352 + KB_CROP_WIDTH = 1216 + + height, width = input_img.shape[-2:] + top_margin = int(height - KB_CROP_HEIGHT) + left_margin = int((width - KB_CROP_WIDTH) / 2) + if 2 == len(input_img.shape): + out = input_img[ + top_margin : top_margin + KB_CROP_HEIGHT, + left_margin : left_margin + KB_CROP_WIDTH, + ] + elif 3 == len(input_img.shape): + out = input_img[ + :, + top_margin : top_margin + KB_CROP_HEIGHT, + left_margin : left_margin + KB_CROP_WIDTH, + ] + return out + + def _get_valid_mask(self, depth: torch.Tensor): + # reference: https://github.com/cleinc/bts/blob/master/pytorch/bts_eval.py + valid_mask = super()._get_valid_mask(depth) # [1, H, W] + + if self.valid_mask_crop is not None: + eval_mask = torch.zeros_like(valid_mask.squeeze()).bool() + gt_height, gt_width = eval_mask.shape + + if "garg" == self.valid_mask_crop: + eval_mask[ + int(0.40810811 * gt_height) : int(0.99189189 * gt_height), + int(0.03594771 * gt_width) : int(0.96405229 * gt_width), + ] = 1 + elif "eigen" == self.valid_mask_crop: + eval_mask[ + int(0.3324324 * gt_height) : int(0.91351351 * gt_height), + int(0.0359477 * gt_width) : int(0.96405229 * gt_width), + ] = 1 + + eval_mask.reshape(valid_mask.shape) + valid_mask = torch.logical_and(valid_mask, eval_mask) + return valid_mask diff --git a/src/dataset/mixed_sampler.py b/src/dataset/mixed_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..3abc60f50b3374fe4b836ac45846e67366491941 --- /dev/null +++ b/src/dataset/mixed_sampler.py @@ -0,0 +1,149 @@ +# Last modified: 2024-04-18 +# +# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# If you find this code useful, we kindly ask you to cite our paper in your work. +# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation +# If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold. +# More information about the method can be found at https://marigoldmonodepth.github.io +# -------------------------------------------------------------------------- + +import torch +from torch.utils.data import ( + BatchSampler, + RandomSampler, + SequentialSampler, +) + + +class MixedBatchSampler(BatchSampler): + """Sample one batch from a selected dataset with given probability. + Compatible with datasets at different resolution + """ + + def __init__( + self, src_dataset_ls, batch_size, drop_last, shuffle, prob=None, generator=None + ): + self.base_sampler = None + self.batch_size = batch_size + self.shuffle = shuffle + self.drop_last = drop_last + self.generator = generator + + self.src_dataset_ls = src_dataset_ls + self.n_dataset = len(self.src_dataset_ls) + + # Dataset length + self.dataset_length = [len(ds) for ds in self.src_dataset_ls] + self.cum_dataset_length = [ + sum(self.dataset_length[:i]) for i in range(self.n_dataset) + ] # cumulative dataset length + + # BatchSamplers for each source dataset + if self.shuffle: + self.src_batch_samplers = [ + BatchSampler( + sampler=RandomSampler( + ds, replacement=False, generator=self.generator + ), + batch_size=self.batch_size, + drop_last=self.drop_last, + ) + for ds in self.src_dataset_ls + ] + else: + self.src_batch_samplers = [ + BatchSampler( + sampler=SequentialSampler(ds), + batch_size=self.batch_size, + drop_last=self.drop_last, + ) + for ds in self.src_dataset_ls + ] + self.raw_batches = [ + list(bs) for bs in self.src_batch_samplers + ] # index in original dataset + self.n_batches = [len(b) for b in self.raw_batches] + self.n_total_batch = sum(self.n_batches) + + # sampling probability + if prob is None: + # if not given, decide by dataset length + self.prob = torch.tensor(self.n_batches) / self.n_total_batch + else: + self.prob = torch.as_tensor(prob) + + def __iter__(self): + """_summary_ + + Yields: + list(int): a batch of indics, corresponding to ConcatDataset of src_dataset_ls + """ + for _ in range(self.n_total_batch): + idx_ds = torch.multinomial( + self.prob, 1, replacement=True, generator=self.generator + ).item() + # if batch list is empty, generate new list + if 0 == len(self.raw_batches[idx_ds]): + self.raw_batches[idx_ds] = list(self.src_batch_samplers[idx_ds]) + # get a batch from list + batch_raw = self.raw_batches[idx_ds].pop() + # shift by cumulative dataset length + shift = self.cum_dataset_length[idx_ds] + batch = [n + shift for n in batch_raw] + + yield batch + + def __len__(self): + return self.n_total_batch + + +# Unit test +if "__main__" == __name__: + from torch.utils.data import ConcatDataset, DataLoader, Dataset + + class SimpleDataset(Dataset): + def __init__(self, start, len) -> None: + super().__init__() + self.start = start + self.len = len + + def __len__(self): + return self.len + + def __getitem__(self, index): + return self.start + index + + dataset_1 = SimpleDataset(0, 10) + dataset_2 = SimpleDataset(200, 20) + dataset_3 = SimpleDataset(1000, 50) + + concat_dataset = ConcatDataset( + [dataset_1, dataset_2, dataset_3] + ) # will directly concatenate + + mixed_sampler = 
MixedBatchSampler( + src_dataset_ls=[dataset_1, dataset_2, dataset_3], + batch_size=4, + drop_last=True, + shuffle=False, + prob=[0.6, 0.3, 0.1], + generator=torch.Generator().manual_seed(0), + ) + + loader = DataLoader(concat_dataset, batch_sampler=mixed_sampler) + + for d in loader: + print(d) diff --git a/src/dataset/nyu_dataset.py b/src/dataset/nyu_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..29e868aab520077bbde7a3b3a0b4cbe2f5f2fae5 --- /dev/null +++ b/src/dataset/nyu_dataset.py @@ -0,0 +1,61 @@ +# Last modified: 2024-02-08 +# +# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# If you find this code useful, we kindly ask you to cite our paper in your work. +# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation +# If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold. +# More information about the method can be found at https://marigoldmonodepth.github.io +# -------------------------------------------------------------------------- + +import torch + +from .eval_base_dataset import DepthFileNameMode, EvaluateBaseDataset + + +class NYUDataset(EvaluateBaseDataset): + def __init__( + self, + eigen_valid_mask: bool, + **kwargs, + ) -> None: + super().__init__( + # NYUv2 dataset parameter + min_depth=1e-3, + max_depth=10.0, + has_filled_depth=True, + name_mode=DepthFileNameMode.rgb_id, + **kwargs, + ) + + self.eigen_valid_mask = eigen_valid_mask + + def _read_depth_file(self, rel_path): + depth_in = self._read_image(rel_path) + # Decode NYU depth + depth_decoded = depth_in / 1000.0 + return depth_decoded + + def _get_valid_mask(self, depth: torch.Tensor): + valid_mask = super()._get_valid_mask(depth) + + # Eigen crop for evaluation + if self.eigen_valid_mask: + eval_mask = torch.zeros_like(valid_mask.squeeze()).bool() + eval_mask[45:471, 41:601] = 1 + eval_mask.reshape(valid_mask.shape) + valid_mask = torch.logical_and(valid_mask, eval_mask) + + return valid_mask diff --git a/src/dataset/scannet_dataset.py b/src/dataset/scannet_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..17c262ac74a06f1be1fe689a9de34af54e610ddb --- /dev/null +++ b/src/dataset/scannet_dataset.py @@ -0,0 +1,44 @@ +# Last modified: 2024-02-08 +# +# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# -------------------------------------------------------------------------- +# If you find this code useful, we kindly ask you to cite our paper in your work. +# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation +# If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold. +# More information about the method can be found at https://marigoldmonodepth.github.io +# -------------------------------------------------------------------------- + +from .eval_base_dataset import DepthFileNameMode, EvaluateBaseDataset + + +class ScanNetDataset(EvaluateBaseDataset): + def __init__( + self, + **kwargs, + ) -> None: + super().__init__( + # ScanNet data parameter + min_depth=1e-3, + max_depth=10, + has_filled_depth=False, + name_mode=DepthFileNameMode.id, + **kwargs, + ) + + def _read_depth_file(self, rel_path): + depth_in = self._read_image(rel_path) + # Decode ScanNet depth + depth_decoded = depth_in / 1000.0 + return depth_decoded diff --git a/src/dataset/vkitti_dataset.py b/src/dataset/vkitti_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..b723b2d1b3e5530b682349ae743d217b05509117 --- /dev/null +++ b/src/dataset/vkitti_dataset.py @@ -0,0 +1,97 @@ +# Last modified: 2024-02-08 +# +# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# If you find this code useful, we kindly ask you to cite our paper in your work. +# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation +# If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold. +# More information about the method can be found at https://marigoldmonodepth.github.io +# -------------------------------------------------------------------------- + +import torch + +from .base_depth_dataset import BaseDepthDataset, DepthFileNameMode +from .kitti_dataset import KITTIDataset + +class VirtualKITTIDataset(BaseDepthDataset): + def __init__( + self, + kitti_bm_crop, # Crop to KITTI benchmark size + valid_mask_crop, # Evaluation mask. 
[None, garg or eigen] + **kwargs, + ) -> None: + super().__init__( + # virtual KITTI data parameter + min_depth=1e-5, + max_depth=80, # 655.35 + has_filled_depth=False, + name_mode=DepthFileNameMode.id, + **kwargs, + ) + self.kitti_bm_crop = kitti_bm_crop + self.valid_mask_crop = valid_mask_crop + assert self.valid_mask_crop in [ + None, + "garg", # set evaluation mask according to Garg ECCV16 + "eigen", # set evaluation mask according to Eigen NIPS14 + ], f"Unknown crop type: {self.valid_mask_crop}" + + # Filter out empty depth + self.filenames = self.filenames + + def _read_depth_file(self, rel_path): + depth_in = self._read_image(rel_path) + # Decode vKITTI depth + depth_decoded = depth_in / 100.0 + return depth_decoded + + def _load_rgb_data(self, rgb_rel_path): + rgb_data = super()._load_rgb_data(rgb_rel_path) + if self.kitti_bm_crop: + rgb_data = { + k: KITTIDataset.kitti_benchmark_crop(v) for k, v in rgb_data.items() + } + return rgb_data + + def _load_depth_data(self, depth_rel_path, filled_rel_path=None): + depth_data = super()._load_depth_data(depth_rel_path, filled_rel_path) + if self.kitti_bm_crop: + depth_data = { + k: KITTIDataset.kitti_benchmark_crop(v) for k, v in depth_data.items() + } + return depth_data + + def _get_valid_mask(self, depth: torch.Tensor): + # reference: https://github.com/cleinc/bts/blob/master/pytorch/bts_eval.py + valid_mask = super()._get_valid_mask(depth) # [1, H, W] + + if self.valid_mask_crop is not None: + eval_mask = torch.zeros_like(valid_mask.squeeze()).bool() + gt_height, gt_width = eval_mask.shape + + if "garg" == self.valid_mask_crop: + eval_mask[ + int(0.40810811 * gt_height) : int(0.99189189 * gt_height), + int(0.03594771 * gt_width) : int(0.96405229 * gt_width), + ] = 1 + elif "eigen" == self.valid_mask_crop: + eval_mask[ + int(0.3324324 * gt_height) : int(0.91351351 * gt_height), + int(0.0359477 * gt_width) : int(0.96405229 * gt_width), + ] = 1 + + eval_mask.reshape(valid_mask.shape) + valid_mask = torch.logical_and(valid_mask, eval_mask) + return valid_mask diff --git a/src/trainer/__init__.py b/src/trainer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4a505daacc68d552f6fe200bb3c8814f3a02a5b --- /dev/null +++ b/src/trainer/__init__.py @@ -0,0 +1,16 @@ +# Author: Bingxin Ke +# Last modified: 2024-05-17 + +from .marigold_trainer import MarigoldTrainer +from .marigold_xl_trainer import MarigoldXLTrainer +from .marigold_inpaint_trainer import MarigoldInpaintTrainer + +trainer_cls_name_dict = { + "MarigoldTrainer": MarigoldTrainer, + "MarigoldXLTrainer": MarigoldXLTrainer, + "MarigoldInpaintTrainer": MarigoldInpaintTrainer +} + + +def get_trainer_cls(trainer_name): + return trainer_cls_name_dict[trainer_name] diff --git a/src/trainer/__pycache__/__init__.cpython-310.pyc b/src/trainer/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c1aee05301248fbb0413d7ea69bad0296a593e8 Binary files /dev/null and b/src/trainer/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/trainer/__pycache__/marigold_inpaint_trainer.cpython-310.pyc b/src/trainer/__pycache__/marigold_inpaint_trainer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..625b77aee188942dbea4aa3a05db60542ce9c417 Binary files /dev/null and b/src/trainer/__pycache__/marigold_inpaint_trainer.cpython-310.pyc differ diff --git a/src/trainer/__pycache__/marigold_trainer.cpython-310.pyc b/src/trainer/__pycache__/marigold_trainer.cpython-310.pyc new file mode 100644 
index 0000000000000000000000000000000000000000..6d5b32dfe663ff779aef4b9282a5836e664e5fd1 Binary files /dev/null and b/src/trainer/__pycache__/marigold_trainer.cpython-310.pyc differ diff --git a/src/trainer/__pycache__/marigold_xl_trainer.cpython-310.pyc b/src/trainer/__pycache__/marigold_xl_trainer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..160410f04646177af9f100f76a68d92e7f05c5b1 Binary files /dev/null and b/src/trainer/__pycache__/marigold_xl_trainer.cpython-310.pyc differ diff --git a/src/trainer/marigold_inpaint_trainer.py b/src/trainer/marigold_inpaint_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..006247adde3c0e195d863066041d2f060706fe42 --- /dev/null +++ b/src/trainer/marigold_inpaint_trainer.py @@ -0,0 +1,665 @@ +# An official reimplemented version of Marigold training script. +# Last modified: 2024-04-29 +# +# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# If you find this code useful, we kindly ask you to cite our paper in your work. +# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation +# If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold. 
+# More information about the method can be found at https://marigoldmonodepth.github.io +# -------------------------------------------------------------------------- +from diffusers import StableDiffusionInpaintPipeline +import logging +import os +import pdb +import cv2 +import shutil +import json +from pycocotools import mask as coco_mask +from datetime import datetime +from typing import List, Union +import random +import safetensors +import numpy as np +import torch +from diffusers import DDPMScheduler +from omegaconf import OmegaConf +from torch.nn import Conv2d +from torch.nn.parameter import Parameter +from torch.optim import Adam +from torch.optim.lr_scheduler import LambdaLR +from torch.utils.data import DataLoader, Dataset +from tqdm import tqdm +from PIL import Image +# import torch.optim.lr_scheduler + +from diffusers.schedulers import PNDMScheduler +from torchvision.transforms.functional import pil_to_tensor +from marigold.marigold_pipeline import MarigoldPipeline, MarigoldDepthOutput +from src.util import metric +from src.util.data_loader import skip_first_batches +from src.util.logging_util import tb_logger, eval_dic_to_text +from src.util.loss import get_loss +from src.util.lr_scheduler import IterExponential +from src.util.metric import MetricTracker +from src.util.multi_res_noise import multi_res_noise_like +from src.util.alignment import align_depth_least_square, depth2disparity, disparity2depth +from src.util.seeding import generate_seed_sequence +from accelerate import Accelerator +import os +from torchvision.transforms import InterpolationMode, Resize, CenterCrop +import torchvision.transforms as transforms +# os.environ['CUDA_LAUNCH_BLOCKING'] = '1' + +class MarigoldInpaintTrainer: + def __init__( + self, + cfg: OmegaConf, + model: MarigoldPipeline, + train_dataloader: DataLoader, + device, + base_ckpt_dir, + out_dir_ckpt, + out_dir_eval, + out_dir_vis, + accumulation_steps: int, + depth_model = None, + separate_list: List = None, + val_dataloaders: List[DataLoader] = None, + vis_dataloaders: List[DataLoader] = None, + train_dataset: Dataset = None, + timestep_method: str = 'unidiffuser', + connection: bool = False + ): + self.cfg: OmegaConf = cfg + self.model: MarigoldPipeline = model + self.depth_model = depth_model + self.device = device + self.seed: Union[int, None] = ( + self.cfg.trainer.init_seed + ) # used to generate seed sequence, set to `None` to train w/o seeding + self.out_dir_ckpt = out_dir_ckpt + self.out_dir_eval = out_dir_eval + self.out_dir_vis = out_dir_vis + self.train_loader: DataLoader = train_dataloader + self.val_loaders: List[DataLoader] = val_dataloaders + self.vis_loaders: List[DataLoader] = vis_dataloaders + self.accumulation_steps: int = accumulation_steps + self.separate_list = separate_list + self.timestep_method = timestep_method + self.train_dataset = train_dataset + self.connection = connection + # Adapt input layers + # if 8 != self.model.unet.config["in_channels"]: + # self._replace_unet_conv_in() + # if 8 != self.model.unet.config["out_channels"]: + # self._replace_unet_conv_out() + + self.train_metrics = MetricTracker(*["loss", 'rgb_loss', 'depth_loss']) + # self.generator = torch.Generator('cuda:0').manual_seed(1024) + + # Encode empty text prompt + self.model.encode_empty_text() + self.empty_text_embed = self.model.empty_text_embed.detach().clone().to(device) + + self.model.unet.enable_xformers_memory_efficient_attention() + + # Trainability + self.model.text_encoder.requires_grad_(False) + # self.model.unet.requires_grad_(True) 
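+ # Only the UNet is meant to be updated during training: the text encoder is frozen above,
+ # and `grad_part` below collects just the UNet parameters that still require gradients,
+ # which is what the Adam optimizer receives.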
+ + grad_part = filter(lambda p: p.requires_grad, self.model.unet.parameters()) + + # Optimizer !should be defined after input layer is adapted + lr = self.cfg.lr + self.optimizer = Adam(grad_part, lr=lr) + + total_params = sum(p.numel() for p in self.model.unet.parameters()) + total_params_m = total_params / 1_000_000 + print(f"Total parameters: {total_params_m:.2f}M") + trainable_params = sum(p.numel() for p in self.model.unet.parameters() if p.requires_grad) + trainable_params_m = trainable_params / 1_000_000 + print(f"Trainable parameters: {trainable_params_m:.2f}M") + + # LR scheduler + lr_func = IterExponential( + total_iter_length=self.cfg.lr_scheduler.kwargs.total_iter, + final_ratio=self.cfg.lr_scheduler.kwargs.final_ratio, + warmup_steps=self.cfg.lr_scheduler.kwargs.warmup_steps, + ) + self.lr_scheduler = LambdaLR(optimizer=self.optimizer, lr_lambda=lr_func) + + # Loss + self.loss = get_loss(loss_name=self.cfg.loss.name, **self.cfg.loss.kwargs) + + # Training noise scheduler + # self.rgb_training_noise_scheduler: PNDMScheduler = PNDMScheduler.from_pretrained( + # os.path.join( + # cfg.trainer.rgb_training_noise_scheduler.pretrained_path, + # "scheduler", + # ) + # ) + + self.rgb_training_noise_scheduler: DDPMScheduler = DDPMScheduler.from_pretrained( + cfg.trainer.depth_training_noise_scheduler.pretrained_path, subfolder="scheduler") + self.depth_training_noise_scheduler: DDPMScheduler = DDPMScheduler.from_pretrained( + cfg.trainer.depth_training_noise_scheduler.pretrained_path, subfolder="scheduler") + + self.rgb_prediction_type = self.rgb_training_noise_scheduler.config.prediction_type + # assert ( + # self.rgb_prediction_type == self.model.rgb_scheduler.config.prediction_type + # ), "Different prediction types" + self.depth_prediction_type = self.depth_training_noise_scheduler.config.prediction_type + assert ( + self.depth_prediction_type == self.model.depth_scheduler.config.prediction_type + ), "Different prediction types" + self.scheduler_timesteps = ( + self.rgb_training_noise_scheduler.config.num_train_timesteps + ) + + # Settings + self.max_epoch = self.cfg.max_epoch + self.max_iter = self.cfg.max_iter + self.gradient_accumulation_steps = accumulation_steps + self.gt_depth_type = self.cfg.gt_depth_type + self.gt_mask_type = self.cfg.gt_mask_type + self.save_period = self.cfg.trainer.save_period + self.backup_period = self.cfg.trainer.backup_period + self.val_period = self.cfg.trainer.validation_period + self.vis_period = self.cfg.trainer.visualization_period + + # Multi-resolution noise + self.apply_multi_res_noise = self.cfg.multi_res_noise is not None + if self.apply_multi_res_noise: + self.mr_noise_strength = self.cfg.multi_res_noise.strength + self.annealed_mr_noise = self.cfg.multi_res_noise.annealed + self.mr_noise_downscale_strategy = ( + self.cfg.multi_res_noise.downscale_strategy + ) + + # Internal variables + self.epoch = 0 + self.n_batch_in_epoch = 0 # batch index in the epoch, used when resume training + self.effective_iter = 0 # how many times optimizer.step() is called + self.in_evaluation = False + self.global_seed_sequence: List = [] # consistent global seed sequence, used to seed random generator, to ensure consistency when resuming + + def _replace_unet_conv_in(self): + # replace the first layer to accept 8 in_channels + _weight = self.model.unet.conv_in.weight.clone() # [320, 4, 3, 3] + _bias = self.model.unet.conv_in.bias.clone() # [320] + zero_weight = torch.zeros(_weight.shape).to(_weight.device) + _weight = torch.cat([_weight, zero_weight], dim=1) + 
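+ # The extra input channels are zero-initialized, so immediately after the replacement the
+ # adapted conv_in behaves exactly like the pretrained 4-channel layer; the new channels
+ # only start contributing once their weights are trained.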
# _weight = _weight.repeat((1, 2, 1, 1)) # Keep selected channel(s) + # half the activation magnitude + # _weight *= 0.5 + # new conv_in channel + _n_convin_out_channel = self.model.unet.conv_in.out_channels + _new_conv_in = Conv2d( + 8, _n_convin_out_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1) + ) + _new_conv_in.weight = Parameter(_weight) + _new_conv_in.bias = Parameter(_bias) + self.model.unet.conv_in = _new_conv_in + logging.info("Unet conv_in layer is replaced") + # replace config + self.model.unet.config["in_channels"] = 8 + logging.info("Unet config is updated") + return + + def parallel_train(self, t_end=None, accelerator=None): + logging.info("Start training") + self.model, self.optimizer, self.train_loader, self.lr_scheduler = accelerator.prepare( + self.model, self.optimizer, self.train_loader, self.lr_scheduler + ) + self.depth_model = accelerator.prepare(self.depth_model) + + self.accelerator = accelerator + if os.path.exists(os.path.join(self.out_dir_ckpt, 'latest')): + accelerator.load_state(os.path.join(self.out_dir_ckpt, 'latest')) + self.load_miscs(os.path.join(self.out_dir_ckpt, 'latest')) + + # if accelerator.is_main_process: + # self._inpaint_rgbd() + + self.train_metrics.reset() + accumulated_step = 0 + for epoch in range(self.epoch, self.max_epoch + 1): + self.epoch = epoch + logging.debug(f"epoch: {self.epoch}") + + # Skip previous batches when resume + for batch in skip_first_batches(self.train_loader, self.n_batch_in_epoch): + self.model.unet.train() + + # globally consistent random generators + if self.seed is not None: + local_seed = self._get_next_seed() + rand_num_generator = torch.Generator(device=self.model.device) + rand_num_generator.manual_seed(local_seed) + else: + rand_num_generator = None + + # >>> With gradient accumulation >>> + + # Get data + rgb = batch["rgb_norm"].to(self.model.device) + with torch.no_grad(): + disparities = self.depth_model(batch["rgb_int"].numpy().astype(np.uint8), 518, device=self.model.device) + + if len(disparities.shape) == 2: + disparities = disparities.unsqueeze(0) + + depth_gt_for_latent = [] + for disparity_map in disparities: + depth_map = ((disparity_map - disparity_map.min()) / (disparity_map.max() - disparity_map.min())) * 2 - 1 + depth_gt_for_latent.append(depth_map) + depth_gt_for_latent = torch.stack(depth_gt_for_latent, dim=0) + + batch_size = rgb.shape[0] + + mask = self.model.mask_processor.preprocess(batch['mask'] * 255).to(self.model.device) + + rgb_timesteps = torch.randint( + 0, + self.scheduler_timesteps, + (batch_size,), + device=self.model.device, + generator=rand_num_generator, + ).long() # [B] + depth_timesteps = rgb_timesteps + + rgb_flag = 1 + depth_flag = 1 + + if self.timestep_method == 'joint': + rgb_mask = mask + depth_mask = mask + + elif self.timestep_method == 'partition': + rand_num = random.random() + if rand_num < 0.5: # joint prediction + rgb_mask = mask + depth_mask = mask + elif rand_num < 0.75: # full rgb; depth prediction + rgb_flag = 0 + rgb_mask = torch.zeros_like(mask) + depth_mask = mask + else: + depth_flag = 0 + rgb_mask = mask + if random.random() < 0.5: + depth_mask = torch.zeros_like(mask) # full depth; rgb prediction + else: + depth_mask = mask # partial depth; rgb prediction + + masked_rgb = rgb * (rgb_mask < 0.5) + masked_depth = depth_gt_for_latent * (depth_mask.squeeze() < 0.5) + with torch.no_grad(): + # Encode image + rgb_latent = self.model.encode_rgb(rgb) # [B, 4, h, w] + mask_rgb_latent = self.model.encode_rgb(masked_rgb) + + if depth_timesteps.sum() 
== 0: + gt_depth_latent = self.encode_depth(masked_depth) + else: + gt_depth_latent = self.encode_depth(depth_gt_for_latent) + mask_depth_latent = self.encode_depth(masked_depth) + + rgb_mask = torch.nn.functional.interpolate(rgb_mask, size=rgb_latent.shape[-2:]) + depth_mask = torch.nn.functional.interpolate(depth_mask, size=gt_depth_latent.shape[-2:]) + + # Sample noise + rgb_noise = torch.randn( + rgb_latent.shape, + device=self.model.device, + generator=rand_num_generator, + ) # [B, 4, h, w] + depth_noise = torch.randn( + gt_depth_latent.shape, + device=self.model.device, + generator=rand_num_generator, + ) # [B, 4, h, w] + + if rgb_timesteps.sum() == 0: + noisy_rgb_latents = rgb_latent + else: + noisy_rgb_latents = self.rgb_training_noise_scheduler.add_noise( + rgb_latent, rgb_noise, rgb_timesteps + ) # [B, 4, h, w] + if depth_timesteps.sum() == 0: + noisy_depth_latents = gt_depth_latent + else: + noisy_depth_latents = self.depth_training_noise_scheduler.add_noise( + gt_depth_latent, depth_noise, depth_timesteps + ) # [B, 4, h, w] + + noisy_latents = torch.cat( + [noisy_rgb_latents, rgb_mask, mask_rgb_latent, mask_depth_latent, noisy_depth_latents, depth_mask, mask_rgb_latent, mask_depth_latent], dim=1 + ).float() # [B, 9*2, h, w] + + # Text embedding + input_ids = self.model.tokenizer( + batch['text'], + padding="max_length", + max_length=self.model.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + input_ids = {k: v.to(self.model.device) for k, v in input_ids.items()} + text_embed = self.model.text_encoder(**input_ids)[0] + + model_pred = self.model.unet( + noisy_latents, rgb_timesteps, depth_timesteps, text_embed, controlnet_connection=self.connection + ).sample # [B, 8, h, w] + + if torch.isnan(model_pred).any(): + logging.warning("model_pred contains NaN.") + + # Get the target for loss depending on the prediction type + if "sample" == self.rgb_prediction_type: + rgb_target = rgb_latent + elif "epsilon" == self.rgb_prediction_type: + rgb_target = rgb_latent + elif "v_prediction" == self.rgb_prediction_type: + rgb_target = self.rgb_training_noise_scheduler.get_velocity( + rgb_latent, rgb_noise, rgb_timesteps + ) # [B, 4, h, w] + else: + raise ValueError(f"Unknown rgb prediction type {self.prediction_type}") + + if "sample" == self.depth_prediction_type: + depth_target = gt_depth_latent + elif "epsilon" == self.depth_prediction_type: + depth_target = gt_depth_latent + elif "v_prediction" == self.depth_prediction_type: + depth_target = self.depth_training_noise_scheduler.get_velocity( + gt_depth_latent, depth_noise, depth_timesteps + ) # [B, 4, h, w] + else: + raise ValueError(f"Unknown depth prediction type {self.prediction_type}") + # Masked latent loss + with accelerator.accumulate(self.model): + + rgb_loss = self.loss(model_pred[:, 0:4, :, :].float(), rgb_target.float()) + depth_loss = self.loss(model_pred[:, 4:, :, :].float(), depth_target.float()) + + if rgb_flag == 0: + loss = depth_loss + elif depth_flag == 0: + loss = rgb_loss + else: + loss = self.cfg.loss.depth_factor * depth_loss + (1 - self.cfg.loss.depth_factor) * rgb_loss + + self.train_metrics.update("loss", loss.item()) + self.train_metrics.update("rgb_loss", rgb_loss.item()) + self.train_metrics.update("depth_loss", depth_loss.item()) + # loss = loss / self.gradient_accumulation_steps + accelerator.backward(loss) + self.optimizer.step() + self.optimizer.zero_grad() + # loss.backward() + self.n_batch_in_epoch += 1 + # print(accelerator.process_index, self.lr_scheduler.get_last_lr()) + 
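+ # The LR schedule below is advanced with the effective (optimizer) iteration rather than
+ # the micro-batch index, so gradient accumulation does not speed up the warmup/decay of
+ # the IterExponential schedule.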
self.lr_scheduler.step(self.effective_iter) + + if accelerator.sync_gradients: + accumulated_step += 1 + + if accumulated_step >= self.gradient_accumulation_steps: + accumulated_step = 0 + self.effective_iter += 1 + + if accelerator.is_main_process: + # Log to tensorboard + if self.effective_iter == 1: + self._inpaint_rgbd() + + accumulated_loss = self.train_metrics.result()["loss"] + rgb_loss = self.train_metrics.result()["rgb_loss"] + depth_loss = self.train_metrics.result()["depth_loss"] + tb_logger.log_dic( + { + f"train/{k}": v + for k, v in self.train_metrics.result().items() + }, + global_step=self.effective_iter, + ) + tb_logger.writer.add_scalar( + "lr", + self.lr_scheduler.get_last_lr()[0], + global_step=self.effective_iter, + ) + tb_logger.writer.add_scalar( + "n_batch_in_epoch", + self.n_batch_in_epoch, + global_step=self.effective_iter, + ) + logging.info( + f"iter {self.effective_iter:5d} (epoch {epoch:2d}): loss={accumulated_loss:.5f}, rgb_loss={rgb_loss:.5f}, depth_loss={depth_loss:.5f}" + ) + accelerator.wait_for_everyone() + + if self.save_period > 0 and 0 == self.effective_iter % self.save_period: + accelerator.save_state(output_dir=os.path.join(self.out_dir_ckpt, 'latest')) + unwrapped_model = accelerator.unwrap_model(self.model) + if accelerator.is_main_process: + accelerator.save_model(unwrapped_model.unet, + os.path.join(self.out_dir_ckpt, 'latest'), safe_serialization=False) + self.save_miscs('latest') + self._inpaint_rgbd() + accelerator.wait_for_everyone() + + if self.backup_period > 0 and 0 == self.effective_iter % self.backup_period: + unwrapped_model = accelerator.unwrap_model(self.model) + if accelerator.is_main_process: + accelerator.save_model(unwrapped_model.unet, + os.path.join(self.out_dir_ckpt, self._get_backup_ckpt_name()), safe_serialization=False) + accelerator.wait_for_everyone() + + # End of training + if self.max_iter > 0 and self.effective_iter >= self.max_iter: + unwrapped_model = accelerator.unwrap_model(self.model) + if accelerator.is_main_process: + unwrapped_model.unet.save_pretrained( + os.path.join(self.out_dir_ckpt, self._get_backup_ckpt_name())) + accelerator.wait_for_everyone() + return + + torch.cuda.empty_cache() + # <<< Effective batch end <<< + + # Epoch end + self.n_batch_in_epoch = 0 + + def _inpaint_rgbd(self): + image_path = ['/dataset/~sa-1b/data/sa_001000/sa_10000335.jpg', + '/dataset/~sa-1b/data/sa_000357/sa_3572319.jpg', + '/dataset/~sa-1b/data/sa_000045/sa_457934.jpg'] + prompt = ['A white car is parked in front of the factory', + 'church with cemetery next to it', + 'A house with a red brick roof'] + + imgs = [pil_to_tensor(Image.open(p)) for p in image_path] + depth_imgs = [self.depth_model(img.unsqueeze(0).cpu().numpy()) for img in imgs] + + masks = [] + for rgb_path in image_path: + anno = json.load(open(rgb_path.replace('.jpg', '.json')))['annotations'] + random.shuffle(anno) + object_num = random.randint(5, 10) + mask = np.array(coco_mask.decode(anno[0]['segmentation']), dtype=np.uint8) + for single_anno in (anno[0:object_num] if len(anno)>object_num else anno): + mask += np.array(coco_mask.decode(single_anno['segmentation']), dtype=np.uint8) + masks.append(torch.from_numpy(mask)) + + resize_transform = transforms.Compose([ + Resize(size=512, interpolation=InterpolationMode.NEAREST_EXACT), + CenterCrop(size=[512, 512])]) + imgs = [resize_transform(img) for img in imgs] + depth_imgs = [resize_transform(depth_img.unsqueeze(0)) for depth_img in depth_imgs] + masks = [resize_transform(mask.unsqueeze(0)) for mask in masks] 
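+ # Each (image, predicted depth, mask, prompt) tuple is then inpainted by the pipeline and
+ # the result is written to TensorBoard, tagged with the prompt and the current effective
+ # iteration.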
+ # pdb.set_trace() + + for i in range(len(imgs)): + output_image = self.model._rgbd_inpaint(imgs[i], depth_imgs[i], masks[i], [prompt[i]], processing_res=512, mode='joint_inpaint') + tb_logger.writer.add_image(f'{prompt[i]}', pil_to_tensor(output_image), self.effective_iter) + + def encode_depth(self, depth_in): + # stack depth into 3-channel + stacked = self.stack_depth_images(depth_in) + # encode using VAE encoder + depth_latent = self.model.encode_rgb(stacked) + return depth_latent + + @staticmethod + def stack_depth_images(depth_in): + if 4 == len(depth_in.shape): + stacked = depth_in.repeat(1, 3, 1, 1) + elif 3 == len(depth_in.shape): + stacked = depth_in.unsqueeze(1) + stacked = stacked.repeat(1, 3, 1, 1) + elif 2 == len(depth_in.shape): + stacked = depth_in.unsqueeze(0).unsqueeze(0) + stacked = stacked.repeat(1, 3, 1, 1) + return stacked + + def visualize(self): + for val_loader in self.vis_loaders: + vis_dataset_name = val_loader.dataset.disp_name + vis_out_dir = os.path.join( + self.out_dir_vis, self._get_backup_ckpt_name(), vis_dataset_name + ) + os.makedirs(vis_out_dir, exist_ok=True) + _ = self.validate_single_dataset( + data_loader=val_loader, + metric_tracker=self.val_metrics, + save_to_dir=vis_out_dir, + ) + + def _get_next_seed(self): + if 0 == len(self.global_seed_sequence): + self.global_seed_sequence = generate_seed_sequence( + initial_seed=self.seed, + length=self.max_iter * self.gradient_accumulation_steps, + ) + logging.info( + f"Global seed sequence is generated, length={len(self.global_seed_sequence)}" + ) + return self.global_seed_sequence.pop() + + def save_miscs(self, ckpt_name): + ckpt_dir = os.path.join(self.out_dir_ckpt, ckpt_name) + state = { + "config": self.cfg, + "effective_iter": self.effective_iter, + "epoch": self.epoch, + "n_batch_in_epoch": self.n_batch_in_epoch, + "global_seed_sequence": self.global_seed_sequence, + } + train_state_path = os.path.join(ckpt_dir, "trainer.ckpt") + torch.save(state, train_state_path) + + logging.info(f"Misc state is saved to: {train_state_path}") + + def load_miscs(self, ckpt_path): + checkpoint = torch.load(os.path.join(ckpt_path, "trainer.ckpt")) + self.effective_iter = checkpoint["effective_iter"] + self.epoch = checkpoint["epoch"] + self.n_batch_in_epoch = checkpoint["n_batch_in_epoch"] + self.global_seed_sequence = checkpoint["global_seed_sequence"] + + logging.info(f"Misc state is loaded from {ckpt_path}") + + + def save_checkpoint(self, ckpt_name, save_train_state): + ckpt_dir = os.path.join(self.out_dir_ckpt, ckpt_name) + logging.info(f"Saving checkpoint to: {ckpt_dir}") + # Backup previous checkpoint + temp_ckpt_dir = None + if os.path.exists(ckpt_dir) and os.path.isdir(ckpt_dir): + temp_ckpt_dir = os.path.join( + os.path.dirname(ckpt_dir), f"_old_{os.path.basename(ckpt_dir)}" + ) + if os.path.exists(temp_ckpt_dir): + shutil.rmtree(temp_ckpt_dir, ignore_errors=True) + os.rename(ckpt_dir, temp_ckpt_dir) + logging.debug(f"Old checkpoint is backed up at: {temp_ckpt_dir}") + + # Save UNet + unet_path = os.path.join(ckpt_dir, "unet") + self.model.unet.save_pretrained(unet_path, safe_serialization=False) + logging.info(f"UNet is saved to: {unet_path}") + + if save_train_state: + state = { + "config": self.cfg, + "effective_iter": self.effective_iter, + "epoch": self.epoch, + "n_batch_in_epoch": self.n_batch_in_epoch, + "best_metric": self.best_metric, + "in_evaluation": self.in_evaluation, + "global_seed_sequence": self.global_seed_sequence, + } + train_state_path = os.path.join(ckpt_dir, "trainer.ckpt") + 
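+ # The trainer state (config, iteration counters, seed sequence, etc.) is stored next to
+ # the saved UNet weights so that training can presumably be resumed from this checkpoint.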
torch.save(state, train_state_path) + # iteration indicator + f = open(os.path.join(ckpt_dir, self._get_backup_ckpt_name()), "w") + f.close() + + logging.info(f"Trainer state is saved to: {train_state_path}") + + # Remove temp ckpt + if temp_ckpt_dir is not None and os.path.exists(temp_ckpt_dir): + shutil.rmtree(temp_ckpt_dir, ignore_errors=True) + logging.debug("Old checkpoint backup is removed.") + + def load_checkpoint( + self, ckpt_path, load_trainer_state=True, resume_lr_scheduler=True + ): + logging.info(f"Loading checkpoint from: {ckpt_path}") + # Load UNet + _model_path = os.path.join(ckpt_path, "unet", "diffusion_pytorch_model.bin") + self.model.unet.load_state_dict( + torch.load(_model_path, map_location=self.device) + ) + self.model.unet.to(self.device) + logging.info(f"UNet parameters are loaded from {_model_path}") + + # Load training states + if load_trainer_state: + checkpoint = torch.load(os.path.join(ckpt_path, "trainer.ckpt")) + self.effective_iter = checkpoint["effective_iter"] + self.epoch = checkpoint["epoch"] + self.n_batch_in_epoch = checkpoint["n_batch_in_epoch"] + self.in_evaluation = checkpoint["in_evaluation"] + self.global_seed_sequence = checkpoint["global_seed_sequence"] + + self.best_metric = checkpoint["best_metric"] + + self.optimizer.load_state_dict(checkpoint["optimizer"]) + logging.info(f"optimizer state is loaded from {ckpt_path}") + + if resume_lr_scheduler: + self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) + logging.info(f"LR scheduler state is loaded from {ckpt_path}") + + logging.info( + f"Checkpoint loaded from: {ckpt_path}. Resume from iteration {self.effective_iter} (epoch {self.epoch})" + ) + return + + def _get_backup_ckpt_name(self): + return f"iter_{self.effective_iter:06d}" \ No newline at end of file diff --git a/src/trainer/marigold_trainer.py b/src/trainer/marigold_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..a9eb24281c13115ddbb6f4da49a0a214d17aeb35 --- /dev/null +++ b/src/trainer/marigold_trainer.py @@ -0,0 +1,968 @@ +# An official reimplemented version of Marigold training script. +# Last modified: 2024-04-29 +# +# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# If you find this code useful, we kindly ask you to cite our paper in your work. +# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation +# If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold. 
+# More information about the method can be found at https://marigoldmonodepth.github.io +# -------------------------------------------------------------------------- + + +import logging +import os +import pdb +import shutil +from datetime import datetime +from typing import List, Union +import random +import safetensors +import numpy as np +import torch +from diffusers import DDPMScheduler +from omegaconf import OmegaConf +from torch.nn import Conv2d +from torch.nn.parameter import Parameter +from torch.optim import Adam +from torch.optim.lr_scheduler import LambdaLR +from torch.utils.data import DataLoader +from tqdm import tqdm +from PIL import Image +# import torch.optim.lr_scheduler + +from marigold.marigold_pipeline import MarigoldPipeline, MarigoldDepthOutput +from src.util import metric +from src.util.data_loader import skip_first_batches +from src.util.logging_util import tb_logger, eval_dic_to_text +from src.util.loss import get_loss +from src.util.lr_scheduler import IterExponential +from src.util.metric import MetricTracker +from src.util.multi_res_noise import multi_res_noise_like +from src.util.alignment import align_depth_least_square, depth2disparity, disparity2depth +from src.util.seeding import generate_seed_sequence +from accelerate import Accelerator +import os +# os.environ['CUDA_LAUNCH_BLOCKING'] = '1' + +class MarigoldTrainer: + def __init__( + self, + cfg: OmegaConf, + model: MarigoldPipeline, + train_dataloader: DataLoader, + device, + base_ckpt_dir, + out_dir_ckpt, + out_dir_eval, + out_dir_vis, + accumulation_steps: int, + depth_model = None, + separate_list: List = None, + val_dataloaders: List[DataLoader] = None, + vis_dataloaders: List[DataLoader] = None, + timestep_method: str = 'unidiffuser' + ): + self.cfg: OmegaConf = cfg + self.model: MarigoldPipeline = model + self.depth_model = depth_model + self.device = device + self.seed: Union[int, None] = ( + self.cfg.trainer.init_seed + ) # used to generate seed sequence, set to `None` to train w/o seeding + self.out_dir_ckpt = out_dir_ckpt + self.out_dir_eval = out_dir_eval + self.out_dir_vis = out_dir_vis + self.train_loader: DataLoader = train_dataloader + self.val_loaders: List[DataLoader] = val_dataloaders + self.vis_loaders: List[DataLoader] = vis_dataloaders + self.accumulation_steps: int = accumulation_steps + self.separate_list = separate_list + self.timestep_method = timestep_method + # Adapt input layers + # if 8 != self.model.unet.config["in_channels"]: + # self._replace_unet_conv_in() + # if 8 != self.model.unet.config["out_channels"]: + # self._replace_unet_conv_out() + + self.prompt = ['a view of a city skyline from a bridge', + 'a man and a woman sitting on a couch', + 'a black car parked in a parking lot next to the water', + 'Enchanted forest with glowing plants, fairies, and ancient castle.', + 'Futuristic city with skyscrapers, neon lights, and hovering vehicles.', + 'Fantasy mountain landscape with waterfalls, dragons, and mythical creatures.'] + # self.generator = torch.Generator('cuda:0').manual_seed(1024) + + # Encode empty text prompt + self.model.encode_empty_text() + self.empty_text_embed = self.model.empty_text_embed.detach().clone().to(device) + + self.model.unet.enable_xformers_memory_efficient_attention() + + # Trainability + self.model.text_encoder.requires_grad_(False) + # self.model.unet.requires_grad_(True) + + grad_part = filter(lambda p: p.requires_grad, self.model.unet.parameters()) + + # Optimizer !should be defined after input layer is adapted + lr = self.cfg.lr + 
self.optimizer = Adam(grad_part, lr=lr) + + total_params = sum(p.numel() for p in self.model.unet.parameters()) + total_params_m = total_params / 1_000_000 + print(f"Total parameters: {total_params_m:.2f}M") + trainable_params = sum(p.numel() for p in self.model.unet.parameters() if p.requires_grad) + trainable_params_m = trainable_params / 1_000_000 + print(f"Trainable parameters: {trainable_params_m:.2f}M") + + # LR scheduler + lr_func = IterExponential( + total_iter_length=self.cfg.lr_scheduler.kwargs.total_iter, + final_ratio=self.cfg.lr_scheduler.kwargs.final_ratio, + warmup_steps=self.cfg.lr_scheduler.kwargs.warmup_steps, + ) + self.lr_scheduler = LambdaLR(optimizer=self.optimizer, lr_lambda=lr_func) + + # Loss + self.loss = get_loss(loss_name=self.cfg.loss.name, **self.cfg.loss.kwargs) + + # Training noise scheduler + self.training_noise_scheduler: DDPMScheduler = DDPMScheduler.from_pretrained( + os.path.join( + cfg.trainer.training_noise_scheduler.pretrained_path, + "scheduler", + ) + ) + # pdb.set_trace() + self.prediction_type = self.training_noise_scheduler.config.prediction_type + assert ( + self.prediction_type == self.model.scheduler.config.prediction_type + ), "Different prediction types" + self.scheduler_timesteps = ( + self.training_noise_scheduler.config.num_train_timesteps + ) + + # Eval metrics + self.metric_funcs = [getattr(metric, _met) for _met in cfg.eval.eval_metrics] + self.train_metrics = MetricTracker(*["loss", 'rgb_loss', 'depth_loss']) + self.val_metrics = MetricTracker(*[m.__name__ for m in self.metric_funcs]) + # main metric for best checkpoint saving + self.main_val_metric = cfg.validation.main_val_metric + self.main_val_metric_goal = cfg.validation.main_val_metric_goal + assert ( + self.main_val_metric in cfg.eval.eval_metrics + ), f"Main eval metric `{self.main_val_metric}` not found in evaluation metrics." 
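+ # best_metric starts at the worst possible value for the chosen goal (a large value when
+ # minimizing, a large negative value when maximizing), so any real validation score can
+ # replace it on the first comparison.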
+ self.best_metric = 1e8 if "minimize" == self.main_val_metric_goal else -1e8 + + # Settings + self.max_epoch = self.cfg.max_epoch + self.max_iter = self.cfg.max_iter + self.gradient_accumulation_steps = accumulation_steps + self.gt_depth_type = self.cfg.gt_depth_type + self.gt_mask_type = self.cfg.gt_mask_type + self.save_period = self.cfg.trainer.save_period + self.backup_period = self.cfg.trainer.backup_period + self.val_period = self.cfg.trainer.validation_period + self.vis_period = self.cfg.trainer.visualization_period + + # Multi-resolution noise + self.apply_multi_res_noise = self.cfg.multi_res_noise is not None + if self.apply_multi_res_noise: + self.mr_noise_strength = self.cfg.multi_res_noise.strength + self.annealed_mr_noise = self.cfg.multi_res_noise.annealed + self.mr_noise_downscale_strategy = ( + self.cfg.multi_res_noise.downscale_strategy + ) + + # Internal variables + self.epoch = 0 + self.n_batch_in_epoch = 0 # batch index in the epoch, used when resume training + self.effective_iter = 0 # how many times optimizer.step() is called + self.in_evaluation = False + self.global_seed_sequence: List = [] # consistent global seed sequence, used to seed random generator, to ensure consistency when resuming + + def _replace_unet_conv_in(self): + # replace the first layer to accept 8 in_channels + _weight = self.model.unet.conv_in.weight.clone() # [320, 4, 3, 3] + _bias = self.model.unet.conv_in.bias.clone() # [320] + zero_weight = torch.zeros(_weight.shape).to(_weight.device) + _weight = torch.cat([_weight, zero_weight], dim=1) + # _weight = _weight.repeat((1, 2, 1, 1)) # Keep selected channel(s) + # half the activation magnitude + # _weight *= 0.5 + # new conv_in channel + _n_convin_out_channel = self.model.unet.conv_in.out_channels + _new_conv_in = Conv2d( + 8, _n_convin_out_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1) + ) + _new_conv_in.weight = Parameter(_weight) + _new_conv_in.bias = Parameter(_bias) + self.model.unet.conv_in = _new_conv_in + logging.info("Unet conv_in layer is replaced") + # replace config + self.model.unet.config["in_channels"] = 8 + logging.info("Unet config is updated") + return + + def _replace_unet_conv_out(self): + # replace the first layer to accept 8 in_channels + _weight = self.model.unet.conv_out.weight.clone() # [8, 320, 3, 3] + _bias = self.model.unet.conv_out.bias.clone() # [320] + _weight = _weight.repeat((2, 1, 1, 1)) # Keep selected channel(s) + _bias = _bias.repeat((2)) + # half the activation magnitude + + # new conv_in channel + _n_convin_out_channel = self.model.unet.conv_out.out_channels + _new_conv_out = Conv2d( + _n_convin_out_channel, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1) + ) + _new_conv_out.weight = Parameter(_weight) + _new_conv_out.bias = Parameter(_bias) + self.model.unet.conv_out = _new_conv_out + logging.info("Unet conv_out layer is replaced") + # replace config + self.model.unet.config["out_channels"] = 8 + logging.info("Unet config is updated") + return + + def parallel_train(self, t_end=None, accelerator=None): + logging.info("Start training") + # pdb.set_trace() + self.model, self.optimizer, self.train_loader, self.lr_scheduler = accelerator.prepare( + self.model, self.optimizer, self.train_loader, self.lr_scheduler + ) + self.depth_model = accelerator.prepare(self.depth_model) + + self.accelerator = accelerator + if self.val_loaders is not None: + for idx, loader in enumerate(self.val_loaders): + self.val_loaders[idx] = accelerator.prepare(loader) + + if 
os.path.exists(os.path.join(self.out_dir_ckpt, 'latest')): + accelerator.load_state(os.path.join(self.out_dir_ckpt, 'latest')) + self.load_miscs(os.path.join(self.out_dir_ckpt, 'latest')) + + self.train_metrics.reset() + accumulated_step = 0 + for epoch in range(self.epoch, self.max_epoch + 1): + self.epoch = epoch + logging.debug(f"epoch: {self.epoch}") + + # Skip previous batches when resume + for batch in skip_first_batches(self.train_loader, self.n_batch_in_epoch): + self.model.unet.train() + + # globally consistent random generators + if self.seed is not None: + local_seed = self._get_next_seed() + rand_num_generator = torch.Generator(device=self.model.device) + rand_num_generator.manual_seed(local_seed) + else: + rand_num_generator = None + + # >>> With gradient accumulation >>> + + # Get data + rgb = batch["rgb_norm"].to(self.model.device) + if self.gt_depth_type not in batch: + with torch.no_grad(): + disparities = self.depth_model(batch["rgb_int"].numpy().astype(np.uint8), 518, device=self.model.device) + depth_gt_for_latent = [] + for disparity_map in disparities: + depth_map = ((disparity_map - disparity_map.min()) / (disparity_map.max() - disparity_map.min())) * 2 - 1 + depth_gt_for_latent.append(depth_map) + depth_gt_for_latent = torch.stack(depth_gt_for_latent, dim=0) + else: + if "least_square_disparity" == self.cfg.eval.alignment: + # convert GT depth -> GT disparity + depth_raw_ts = batch["depth_raw_linear"].squeeze() + depth_raw = depth_raw_ts.cpu().numpy() + # pdb.set_trace() + disparities = depth2disparity( + depth=depth_raw + ) + depth_gt_for_latent = [] + for disparity_map in disparities: + depth_map = ((disparity_map - disparity_map.min()) / ( + disparity_map.max() - disparity_map.min())) * 2 - 1 + depth_gt_for_latent.append(torch.from_numpy(depth_map)) + depth_gt_for_latent = torch.stack(depth_gt_for_latent, dim=0).to(self.model.device) + else: + depth_gt_for_latent = batch[self.gt_depth_type].to(self.model.device) + + batch_size = rgb.shape[0] + + if self.gt_mask_type is not None: + valid_mask_for_latent = batch[self.gt_mask_type].to(self.model.device) + invalid_mask = ~valid_mask_for_latent + valid_mask_down = ~torch.max_pool2d( + invalid_mask.float(), 8, 8 + ).bool() + valid_mask_down = valid_mask_down.repeat((1, 4, 1, 1)) + + with torch.no_grad(): + # Encode image + rgb_latent = self.model.encode_rgb(rgb) # [B, 4, h, w] + # Encode GT depth + gt_depth_latent = self.encode_depth( + depth_gt_for_latent + ) # [B, 4, h, w] + # Sample a random timestep for each image + if self.cfg.loss.depth_factor == 1: + rgb_timesteps = torch.zeros( + (batch_size), + device=self.model.device + ).long() # [B] + depth_timesteps = torch.randint( + 0, + self.scheduler_timesteps, + (batch_size,), + device=self.model.device, + generator=rand_num_generator, + ).long() # [B] + elif self.timestep_method == 'unidiffuser': + rgb_timesteps = torch.randint( + 0, + self.scheduler_timesteps, + (batch_size,), + device=self.model.device, + generator=rand_num_generator, + ).long() # [B] + depth_timesteps = torch.randint( + 0, + self.scheduler_timesteps, + (batch_size,), + device=self.model.device, + generator=rand_num_generator, + ).long() # [B] + elif self.timestep_method == 'joint': + rgb_timesteps = torch.randint( + 0, + self.scheduler_timesteps, + (batch_size,), + device=self.model.device, + generator=rand_num_generator, + ).long() # [B] + depth_timesteps = rgb_timesteps # [B] + elif self.timestep_method == 'partition': + rand_num = random.random() + if rand_num < 0.3333: + # joint generation + 
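+ # (rgb and depth share one sampled timestep in this branch, so both modalities are
+ # denoised at a matched noise level)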
rgb_timesteps = torch.randint( + 0, + self.scheduler_timesteps, + (batch_size,), + device=self.model.device, + generator=rand_num_generator, + ).long() # [B] + depth_timesteps = rgb_timesteps + elif rand_num < 0.6666: + # image2depth generation + rgb_timesteps = torch.zeros( + (batch_size), + device=self.model.device + ).long() # [B] + depth_timesteps = torch.randint( + 0, + self.scheduler_timesteps, + (batch_size,), + device=self.model.device, + generator=rand_num_generator, + ).long() # [B] + else: + # depth2image generation + rgb_timesteps = torch.randint( + 0, + self.scheduler_timesteps, + (batch_size,), + device=self.model.device, + generator=rand_num_generator, + ).long() # [B] + depth_timesteps = torch.zeros( + (batch_size), + device=self.model.device + ).long() # [B] + + # Sample noise + if self.apply_multi_res_noise: + rgb_strength = self.mr_noise_strength + if self.annealed_mr_noise: + # calculate strength depending on t + rgb_strength = rgb_strength * (rgb_timesteps / self.scheduler_timesteps) + rgb_noise = multi_res_noise_like( + rgb_latent, + strength=rgb_strength, + downscale_strategy=self.mr_noise_downscale_strategy, + generator=rand_num_generator, + device=self.model.device, + ) + + depth_strength = self.mr_noise_strength + if self.annealed_mr_noise: + # calculate strength depending on t + depth_strength = depth_strength * (depth_timesteps / self.scheduler_timesteps) + depth_noise = multi_res_noise_like( + gt_depth_latent, + strength=depth_strength, + downscale_strategy=self.mr_noise_downscale_strategy, + generator=rand_num_generator, + device=self.model.device, + ) + else: + rgb_noise = torch.randn( + rgb_latent.shape, + device=self.model.device, + generator=rand_num_generator, + ) # [B, 8, h, w] + + depth_noise = torch.randn( + gt_depth_latent.shape, + device=self.model.device, + generator=rand_num_generator, + ) # [B, 8, h, w] + # Add noise to the latents (diffusion forward process) + + if depth_timesteps.sum() == 0: + noisy_rgb_latents = rgb_latent + else: + noisy_rgb_latents = self.training_noise_scheduler.add_noise( + rgb_latent, rgb_noise, rgb_timesteps + ) # [B, 4, h, w] + + noisy_depth_latents = self.training_noise_scheduler.add_noise( + gt_depth_latent, depth_noise, depth_timesteps + ) # [B, 4, h, w] + + noisy_latents = torch.cat( + [noisy_rgb_latents, noisy_depth_latents], dim=1 + ).float() # [B, 8, h, w] + + # Text embedding + input_ids = self.model.tokenizer( + batch['text'], + padding="max_length", + max_length=self.model.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + input_ids = {k: v.to(self.model.device) for k, v in input_ids.items()} + text_embed = self.model.text_encoder(**input_ids)[0] + # text_embed = self.empty_text_embed.to(device).repeat( + # (batch_size, 1, 1) + # ) # [B, 77, 1024] + model_pred = self.model.unet( + noisy_latents, rgb_timesteps, depth_timesteps, text_embed + ).sample # [B, 4, h, w] + if torch.isnan(model_pred).any(): + logging.warning("model_pred contains NaN.") + + # Get the target for loss depending on the prediction type + if "sample" == self.prediction_type: + rgb_target = rgb_latent + depth_target = gt_depth_latent + elif "epsilon" == self.prediction_type: + rgb_target = rgb_latent + depth_target = gt_depth_latent + elif "v_prediction" == self.prediction_type: + rgb_target = self.training_noise_scheduler.get_velocity( + rgb_latent, rgb_noise, rgb_timesteps + ) # [B, 4, h, w] + depth_target = self.training_noise_scheduler.get_velocity( + gt_depth_latent, depth_noise, depth_timesteps + ) # [B, 4, h, w] 
+ else: + raise ValueError(f"Unknown prediction type {self.prediction_type}") + # Masked latent loss + with accelerator.accumulate(self.model): + if self.gt_mask_type is not None: + depth_loss = self.loss( + model_pred[:, 4:, :, :][valid_mask_down].float(), + depth_target[valid_mask_down].float(), + ) + else: + depth_loss = self.loss(model_pred[:, 4:, :, :].float(),depth_target.float()) + + rgb_loss = self.loss(model_pred[:, 0:4, :, :].float(), rgb_target.float()) + + if torch.sum(rgb_timesteps) == 0 or torch.sum(rgb_timesteps) == len(rgb_timesteps) * self.scheduler_timesteps: + loss = depth_loss + elif torch.sum(depth_timesteps) == 0 or torch.sum(depth_timesteps) == len(depth_timesteps) * self.scheduler_timesteps: + loss = rgb_loss + else: + loss = self.cfg.loss.depth_factor * depth_loss + (1 - self.cfg.loss.depth_factor) * rgb_loss + + self.train_metrics.update("loss", loss.item()) + self.train_metrics.update("rgb_loss", rgb_loss.item()) + self.train_metrics.update("depth_loss", depth_loss.item()) + # loss = loss / self.gradient_accumulation_steps + accelerator.backward(loss) + self.optimizer.step() + self.optimizer.zero_grad() + # loss.backward() + self.n_batch_in_epoch += 1 + # print(accelerator.process_index, self.lr_scheduler.get_last_lr()) + self.lr_scheduler.step(self.effective_iter) + + if accelerator.sync_gradients: + accumulated_step += 1 + + if accumulated_step >= self.gradient_accumulation_steps: + accumulated_step = 0 + self.effective_iter += 1 + + if accelerator.is_main_process: + # Log to tensorboard + if self.effective_iter == 1: + generator = torch.Generator(self.model.device).manual_seed(1024) + img = self.model.generate_rgbd(self.prompt, num_inference_steps=50, generator=generator, + show_pbar=True) + for idx in range(len(self.prompt)): + tb_logger.writer.add_image(f'image/{self.prompt[idx]}', img[idx], self.effective_iter) + self._depth2image() + self._image2depth() + + accumulated_loss = self.train_metrics.result()["loss"] + rgb_loss = self.train_metrics.result()["rgb_loss"] + depth_loss = self.train_metrics.result()["depth_loss"] + tb_logger.log_dic( + { + f"train/{k}": v + for k, v in self.train_metrics.result().items() + }, + global_step=self.effective_iter, + ) + tb_logger.writer.add_scalar( + "lr", + self.lr_scheduler.get_last_lr()[0], + global_step=self.effective_iter, + ) + tb_logger.writer.add_scalar( + "n_batch_in_epoch", + self.n_batch_in_epoch, + global_step=self.effective_iter, + ) + logging.info( + f"iter {self.effective_iter:5d} (epoch {epoch:2d}): loss={accumulated_loss:.5f}, rgb_loss={rgb_loss:.5f}, depth_loss={depth_loss:.5f}" + ) + accelerator.wait_for_everyone() + + if self.save_period > 0 and 0 == self.effective_iter % self.save_period: + accelerator.save_state(output_dir=os.path.join(self.out_dir_ckpt, 'latest')) + unwrapped_model = accelerator.unwrap_model(self.model) + if accelerator.is_main_process: + accelerator.save_model(unwrapped_model.unet, + os.path.join(self.out_dir_ckpt, 'latest'), safe_serialization=False) + self.save_miscs('latest') + + # RGB-D joint generation + generator = torch.Generator(self.model.device).manual_seed(1024) + img = self.model.generate_rgbd(self.prompt, num_inference_steps=50, generator=generator, show_pbar=False, height=64, width=64) + for idx in range(len(self.prompt)): + tb_logger.writer.add_image(f'image/{self.prompt[idx]}', img[idx], self.effective_iter) + + # depth to RGB generation + self._depth2image() + # # RGB to depth generation + self._image2depth() + + accelerator.wait_for_everyone() + + if 
self.backup_period > 0 and 0 == self.effective_iter % self.backup_period: + unwrapped_model = accelerator.unwrap_model(self.model) + if accelerator.is_main_process: + unwrapped_model.unet.save_pretrained( + os.path.join(self.out_dir_ckpt, self._get_backup_ckpt_name())) + accelerator.wait_for_everyone() + + if self.val_period > 0 and 0 == self.effective_iter % self.val_period: + self.validate() + + # End of training + if self.max_iter > 0 and self.effective_iter >= self.max_iter: + unwrapped_model = accelerator.unwrap_model(self.model) + if accelerator.is_main_process: + unwrapped_model.unet.save_pretrained( + os.path.join(self.out_dir_ckpt, self._get_backup_ckpt_name())) + accelerator.wait_for_everyone() + return + + torch.cuda.empty_cache() + # <<< Effective batch end <<< + + # Epoch end + self.n_batch_in_epoch = 0 + + def _image2depth(self): + generator = torch.Generator(self.model.device).manual_seed(1024) + image2dept_paths = ['/home/aiops/wangzh/data/scannet/scene0593_00/color/000100.jpg', + '/home/aiops/wangzh/data/scannet/scene0593_00/color/000700.jpg', + '/home/aiops/wangzh/data/scannet/scene0591_01/color/000600.jpg', + '/home/aiops/wangzh/data/scannet/scene0591_01/color/001500.jpg'] + for img_idx, image_path in enumerate(image2dept_paths): + rgb_input = Image.open(image_path) + depth_pred: MarigoldDepthOutput = self.model.image2depth( + rgb_input, + denoising_steps=self.cfg.validation.denoising_steps, + ensemble_size=self.cfg.validation.ensemble_size, + processing_res=self.cfg.validation.processing_res, + match_input_res=self.cfg.validation.match_input_res, + generator=generator, + batch_size=self.cfg.validation.ensemble_size, + # use batch size 1 to increase reproducibility + color_map="Spectral", + show_progress_bar=False, + resample_method=self.cfg.validation.resample_method, + ) + img = self.model.post_process_rgbd(['None'], [rgb_input], [depth_pred['depth_colored']]) + tb_logger.writer.add_image(f'image2depth_{img_idx}', img[0], self.effective_iter) + + def _depth2image(self): + generator = torch.Generator(self.model.device).manual_seed(1024) + if "least_square_disparity" == self.cfg.eval.alignment: + depth2image_path = ['/home/aiops/wangzh/data/ori_depth_part0-0/sa_10000335.jpg', + '/home/aiops/wangzh/data/ori_depth_part0-0/sa_3572319.jpg', + '/home/aiops/wangzh/data/ori_depth_part0-0/sa_457934.jpg'] + else: + depth2image_path = ['/home/aiops/wangzh/data/sa_001000/sa_10000335.jpg', + '/home/aiops/wangzh/data/sa_000357/sa_3572319.jpg', + '/home/aiops/wangzh/data/sa_000045/sa_457934.jpg'] + prompts = ['Red car parked in the factory', + 'White gothic church with cemetery next to it', + 'House with red roof and starry sky in the background'] + for img_idx, depth_path in enumerate(depth2image_path): + depth_input = Image.open(depth_path) + image_pred = self.model.single_depth2image( + depth_input, + prompts[img_idx], + num_inference_steps=50, + processing_res=self.cfg.validation.processing_res, + generator=generator, + show_pbar=False, + resample_method=self.cfg.validation.resample_method, + ) + img = self.model.post_process_rgbd([prompts[img_idx]], [image_pred], [depth_input]) + tb_logger.writer.add_image(f'depth2image_{img_idx}', img[0], self.effective_iter) + + def encode_depth(self, depth_in): + # stack depth into 3-channel + stacked = self.stack_depth_images(depth_in) + # encode using VAE encoder + depth_latent = self.model.encode_rgb(stacked) + return depth_latent + + @staticmethod + def stack_depth_images(depth_in): + if 4 == len(depth_in.shape): + stacked = 
depth_in.repeat(1, 3, 1, 1) + elif 3 == len(depth_in.shape): + stacked = depth_in.unsqueeze(1) + stacked = stacked.repeat(1, 3, 1, 1) + return stacked + + def validate(self): + for i, val_loader in enumerate(self.val_loaders): + val_dataset_name = val_loader.dataset.disp_name + val_metric_dic = self.validate_single_dataset( + data_loader=val_loader, metric_tracker=self.val_metrics + ) + + if self.accelerator.is_main_process: + val_metric_dic = {k:torch.tensor(v).cuda() for k,v in val_metric_dic.items()} + + tb_logger.log_dic( + {f"val/{val_dataset_name}/{k}": v for k, v in val_metric_dic.items()}, + global_step=self.effective_iter, + ) + # save to file + eval_text = eval_dic_to_text( + val_metrics=val_metric_dic, + dataset_name=val_dataset_name, + sample_list_path=val_loader.dataset.filename_ls_path, + ) + _save_to = os.path.join( + self.out_dir_eval, + f"eval-{val_dataset_name}-iter{self.effective_iter:06d}.txt", + ) + with open(_save_to, "w+") as f: + f.write(eval_text) + + # Update main eval metric + if 0 == i: + main_eval_metric = val_metric_dic[self.main_val_metric] + if ( + "minimize" == self.main_val_metric_goal + and main_eval_metric < self.best_metric + or "maximize" == self.main_val_metric_goal + and main_eval_metric > self.best_metric + ): + self.best_metric = main_eval_metric + logging.info( + f"Best metric: {self.main_val_metric} = {self.best_metric} at iteration {self.effective_iter}" + ) + # Save a checkpoint + self.save_checkpoint( + ckpt_name='best', save_train_state=False + ) + + self.accelerator.wait_for_everyone() + + def visualize(self): + for val_loader in self.vis_loaders: + vis_dataset_name = val_loader.dataset.disp_name + vis_out_dir = os.path.join( + self.out_dir_vis, self._get_backup_ckpt_name(), vis_dataset_name + ) + os.makedirs(vis_out_dir, exist_ok=True) + _ = self.validate_single_dataset( + data_loader=val_loader, + metric_tracker=self.val_metrics, + save_to_dir=vis_out_dir, + ) + + @torch.no_grad() + def validate_single_dataset( + self, + data_loader: DataLoader, + metric_tracker: MetricTracker, + save_to_dir: str = None, + ): + self.model.to(self.device) + metric_tracker.reset() + + # Generate seed sequence for consistent evaluation + val_init_seed = self.cfg.validation.init_seed + val_seed_ls = generate_seed_sequence(val_init_seed, len(data_loader)) + + for i, batch in enumerate( + tqdm(data_loader, desc=f"evaluating on {data_loader.dataset.disp_name}"), + start=1, + ): + + rgb_int = batch["rgb_int"] # [3, H, W] + # GT depth + depth_raw_ts = batch["depth_raw_linear"].squeeze() + depth_raw = depth_raw_ts.cpu().numpy() + depth_raw_ts = depth_raw_ts.to(self.device) + valid_mask_ts = batch["valid_mask_raw"].squeeze() + valid_mask = valid_mask_ts.cpu().numpy() + valid_mask_ts = valid_mask_ts.to(self.device) + + # Random number generator + seed = val_seed_ls.pop() + if seed is None: + generator = None + else: + generator = torch.Generator(device=self.device) + generator.manual_seed(seed) + + # Predict depth + pipe_out: MarigoldDepthOutput = self.model.image2depth( + rgb_int, + denoising_steps=self.cfg.validation.denoising_steps, + ensemble_size=self.cfg.validation.ensemble_size, + processing_res=self.cfg.validation.processing_res, + match_input_res=self.cfg.validation.match_input_res, + generator=generator, + batch_size=self.cfg.validation.ensemble_size, # use batch size 1 to increase reproducibility + color_map=None, + show_progress_bar=False, + resample_method=self.cfg.validation.resample_method, + ) + + depth_pred: np.ndarray = pipe_out.depth_np + + if 
"least_square" == self.cfg.eval.alignment: + depth_pred, scale, shift = align_depth_least_square( + gt_arr=depth_raw, + pred_arr=depth_pred, + valid_mask_arr=valid_mask, + return_scale_shift=True, + max_resolution=self.cfg.eval.align_max_res, + ) + elif "least_square_disparity" == self.cfg.eval.alignment: + # convert GT depth -> GT disparity + gt_disparity, gt_non_neg_mask = depth2disparity( + depth=depth_raw, return_mask=True + ) + + pred_non_neg_mask = depth_pred > 0 + valid_nonnegative_mask = valid_mask & gt_non_neg_mask & pred_non_neg_mask + + disparity_pred, scale, shift = align_depth_least_square( + gt_arr=gt_disparity, + pred_arr=depth_pred, + valid_mask_arr=valid_nonnegative_mask, + return_scale_shift=True, + max_resolution=self.cfg.eval.align_max_res, + ) + # convert to depth + disparity_pred = np.clip( + disparity_pred, a_min=1e-3, a_max=None + ) # avoid 0 disparity + depth_pred = disparity2depth(disparity_pred) + + # Clip to dataset min max + depth_pred = np.clip( + depth_pred, + a_min=data_loader.dataset.min_depth, + a_max=data_loader.dataset.max_depth, + ) + + # clip to d > 0 for evaluation + depth_pred = np.clip(depth_pred, a_min=1e-6, a_max=None) + + # Evaluate + sample_metric = [] + depth_pred_ts = torch.from_numpy(depth_pred).to(self.device) + + for met_func in self.metric_funcs: + _metric_name = met_func.__name__ + _metric = met_func(depth_pred_ts, depth_raw_ts, valid_mask_ts).cuda(self.accelerator.process_index) + self.accelerator.wait_for_everyone() + _metric = self.accelerator.gather_for_metrics(_metric.unsqueeze(0)).mean().item() + sample_metric.append(_metric.__str__()) + metric_tracker.update(_metric_name, _metric) + + self.accelerator.wait_for_everyone() + # Save as 16-bit uint png + if save_to_dir is not None: + img_name = batch["rgb_relative_path"][0].replace("/", "_") + png_save_path = os.path.join(save_to_dir, f"{img_name}.png") + depth_to_save = (pipe_out.depth_np * 65535.0).astype(np.uint16) + Image.fromarray(depth_to_save).save(png_save_path, mode="I;16") + + return metric_tracker.result() + + def _get_next_seed(self): + if 0 == len(self.global_seed_sequence): + self.global_seed_sequence = generate_seed_sequence( + initial_seed=self.seed, + length=self.max_iter * self.gradient_accumulation_steps, + ) + logging.info( + f"Global seed sequence is generated, length={len(self.global_seed_sequence)}" + ) + return self.global_seed_sequence.pop() + + def save_miscs(self, ckpt_name): + ckpt_dir = os.path.join(self.out_dir_ckpt, ckpt_name) + state = { + "config": self.cfg, + "effective_iter": self.effective_iter, + "epoch": self.epoch, + "n_batch_in_epoch": self.n_batch_in_epoch, + "best_metric": self.best_metric, + "in_evaluation": self.in_evaluation, + "global_seed_sequence": self.global_seed_sequence, + } + train_state_path = os.path.join(ckpt_dir, "trainer.ckpt") + torch.save(state, train_state_path) + + logging.info(f"Misc state is saved to: {train_state_path}") + + def load_miscs(self, ckpt_path): + checkpoint = torch.load(os.path.join(ckpt_path, "trainer.ckpt")) + self.effective_iter = checkpoint["effective_iter"] + self.epoch = checkpoint["epoch"] + self.n_batch_in_epoch = checkpoint["n_batch_in_epoch"] + self.in_evaluation = checkpoint["in_evaluation"] + self.global_seed_sequence = checkpoint["global_seed_sequence"] + + self.best_metric = checkpoint["best_metric"] + + logging.info(f"Misc state is loaded from {ckpt_path}") + + + def save_checkpoint(self, ckpt_name, save_train_state): + ckpt_dir = os.path.join(self.out_dir_ckpt, ckpt_name) + 
logging.info(f"Saving checkpoint to: {ckpt_dir}") + # Backup previous checkpoint + temp_ckpt_dir = None + if os.path.exists(ckpt_dir) and os.path.isdir(ckpt_dir): + temp_ckpt_dir = os.path.join( + os.path.dirname(ckpt_dir), f"_old_{os.path.basename(ckpt_dir)}" + ) + if os.path.exists(temp_ckpt_dir): + shutil.rmtree(temp_ckpt_dir, ignore_errors=True) + os.rename(ckpt_dir, temp_ckpt_dir) + logging.debug(f"Old checkpoint is backed up at: {temp_ckpt_dir}") + + # Save UNet + unet_path = os.path.join(ckpt_dir, "unet") + self.model.unet.save_pretrained(unet_path, safe_serialization=False) + logging.info(f"UNet is saved to: {unet_path}") + + if save_train_state: + state = { + "config": self.cfg, + "effective_iter": self.effective_iter, + "epoch": self.epoch, + "n_batch_in_epoch": self.n_batch_in_epoch, + "best_metric": self.best_metric, + "in_evaluation": self.in_evaluation, + "global_seed_sequence": self.global_seed_sequence, + } + train_state_path = os.path.join(ckpt_dir, "trainer.ckpt") + torch.save(state, train_state_path) + # iteration indicator + f = open(os.path.join(ckpt_dir, self._get_backup_ckpt_name()), "w") + f.close() + + logging.info(f"Trainer state is saved to: {train_state_path}") + + # Remove temp ckpt + if temp_ckpt_dir is not None and os.path.exists(temp_ckpt_dir): + shutil.rmtree(temp_ckpt_dir, ignore_errors=True) + logging.debug("Old checkpoint backup is removed.") + + def load_checkpoint( + self, ckpt_path, load_trainer_state=True, resume_lr_scheduler=True + ): + logging.info(f"Loading checkpoint from: {ckpt_path}") + # Load UNet + _model_path = os.path.join(ckpt_path, "unet", "diffusion_pytorch_model.bin") + self.model.unet.load_state_dict( + torch.load(_model_path, map_location=self.device) + ) + self.model.unet.to(self.device) + logging.info(f"UNet parameters are loaded from {_model_path}") + + # Load training states + if load_trainer_state: + checkpoint = torch.load(os.path.join(ckpt_path, "trainer.ckpt")) + self.effective_iter = checkpoint["effective_iter"] + self.epoch = checkpoint["epoch"] + self.n_batch_in_epoch = checkpoint["n_batch_in_epoch"] + self.in_evaluation = checkpoint["in_evaluation"] + self.global_seed_sequence = checkpoint["global_seed_sequence"] + + self.best_metric = checkpoint["best_metric"] + + self.optimizer.load_state_dict(checkpoint["optimizer"]) + logging.info(f"optimizer state is loaded from {ckpt_path}") + + if resume_lr_scheduler: + self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) + logging.info(f"LR scheduler state is loaded from {ckpt_path}") + + logging.info( + f"Checkpoint loaded from: {ckpt_path}. Resume from iteration {self.effective_iter} (epoch {self.epoch})" + ) + return + + def _get_backup_ckpt_name(self): + return f"iter_{self.effective_iter:06d}" diff --git a/src/trainer/marigold_xl_trainer.py b/src/trainer/marigold_xl_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..2a46c820c75915eed92e816418fc5e8cb3b00258 --- /dev/null +++ b/src/trainer/marigold_xl_trainer.py @@ -0,0 +1,948 @@ +# An official reimplemented version of Marigold training script. +# Last modified: 2024-04-29 +# +# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# If you find this code useful, we kindly ask you to cite our paper in your work. +# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation +# If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold. +# More information about the method can be found at https://marigoldmonodepth.github.io +# -------------------------------------------------------------------------- + + +import logging +import os +import pdb +import shutil +from datetime import datetime +from typing import List, Union +import safetensors +import numpy as np +import torch +from diffusers import DDPMScheduler +from omegaconf import OmegaConf +from torch.nn import Conv2d +from torch.nn.parameter import Parameter +from torch.optim import Adam +from torch.optim.lr_scheduler import LambdaLR +from torch.utils.data import DataLoader +from tqdm import tqdm +from PIL import Image +# import torch.optim.lr_scheduler + +from marigold.marigold_pipeline import MarigoldPipeline, MarigoldDepthOutput +from src.util import metric +from src.util.data_loader import skip_first_batches +from src.util.logging_util import tb_logger, eval_dic_to_text +from src.util.loss import get_loss +from src.util.lr_scheduler import IterExponential +from src.util.metric import MetricTracker +from src.util.multi_res_noise import multi_res_noise_like +from src.util.alignment import align_depth_least_square +from src.util.seeding import generate_seed_sequence +from accelerate import Accelerator +import random + +class MarigoldXLTrainer: + def __init__( + self, + cfg: OmegaConf, + model: MarigoldPipeline, + train_dataloader: DataLoader, + device, + base_ckpt_dir, + out_dir_ckpt, + out_dir_eval, + out_dir_vis, + accumulation_steps: int, + separate_list: List = None, + val_dataloaders: List[DataLoader] = None, + vis_dataloaders: List[DataLoader] = None, + timestep_method: str = 'unidiffuser' + ): + self.cfg: OmegaConf = cfg + self.model: MarigoldPipeline = model + self.device = device + self.seed: Union[int, None] = ( + self.cfg.trainer.init_seed + ) # used to generate seed sequence, set to `None` to train w/o seeding + self.out_dir_ckpt = out_dir_ckpt + self.out_dir_eval = out_dir_eval + self.out_dir_vis = out_dir_vis + self.train_loader: DataLoader = train_dataloader + self.val_loaders: List[DataLoader] = val_dataloaders + self.vis_loaders: List[DataLoader] = vis_dataloaders + self.accumulation_steps: int = accumulation_steps + self.separate_list = separate_list + self.timestep_method = timestep_method + # Adapt input layers + # if 8 != self.model.unet.config["in_channels"]: + # self._replace_unet_conv_in() + # if 8 != self.model.unet.config["out_channels"]: + # self._replace_unet_conv_out() + + self.prompt = ['a view of a city skyline from a bridge', + 'a man and a woman sitting on a couch', + 'a black car parked in a parking lot next to the water', + 'Enchanted forest with glowing plants, fairies, and ancient castle.', + 'Futuristic city with skyscrapers, neon lights, and hovering vehicles.', + 'Fantasy mountain landscape with 
waterfalls, dragons, and mythical creatures.'] + # self.generator = torch.Generator('cuda:0').manual_seed(1024) + + # Encode empty text prompt + # self.model.encode_empty_text() + # self.empty_text_embed = self.model.empty_text_embed.detach().clone().to(device) + + self.model.unet.enable_xformers_memory_efficient_attention() + + # Trainability + self.model.vae.requires_grad_(False) + self.model.text_encoder.requires_grad_(False) + # self.model.unet.requires_grad_(True) + + grad_part = filter(lambda p: p.requires_grad, self.model.unet.parameters()) + + # Optimizer !should be defined after input layer is adapted + lr = self.cfg.lr + self.optimizer = Adam(grad_part, lr=lr) + + total_params = sum(p.numel() for p in self.model.unet.parameters()) + total_params_m = total_params / 1_000_000 + print(f"Total parameters: {total_params_m:.2f}M") + trainable_params = sum(p.numel() for p in self.model.unet.parameters() if p.requires_grad) + trainable_params_m = trainable_params / 1_000_000 + print(f"Trainable parameters: {trainable_params_m:.2f}M") + + # LR scheduler + lr_func = IterExponential( + total_iter_length=self.cfg.lr_scheduler.kwargs.total_iter, + final_ratio=self.cfg.lr_scheduler.kwargs.final_ratio, + warmup_steps=self.cfg.lr_scheduler.kwargs.warmup_steps, + ) + self.lr_scheduler = LambdaLR(optimizer=self.optimizer, lr_lambda=lr_func) + + # Loss + self.loss = get_loss(loss_name=self.cfg.loss.name, **self.cfg.loss.kwargs) + + # Training noise scheduler + self.training_noise_scheduler: DDPMScheduler = DDPMScheduler.from_pretrained( + os.path.join( + cfg.trainer.training_noise_scheduler.pretrained_path, + "scheduler", + ) + ) + self.prediction_type = self.training_noise_scheduler.config.prediction_type + assert ( + self.prediction_type == self.model.scheduler.config.prediction_type + ), "Different prediction types" + self.scheduler_timesteps = ( + self.training_noise_scheduler.config.num_train_timesteps + ) + + # Eval metrics + self.metric_funcs = [getattr(metric, _met) for _met in cfg.eval.eval_metrics] + self.train_metrics = MetricTracker(*["loss", 'rgb_loss', 'depth_loss']) + self.val_metrics = MetricTracker(*[m.__name__ for m in self.metric_funcs]) + # main metric for best checkpoint saving + self.main_val_metric = cfg.validation.main_val_metric + self.main_val_metric_goal = cfg.validation.main_val_metric_goal + assert ( + self.main_val_metric in cfg.eval.eval_metrics + ), f"Main eval metric `{self.main_val_metric}` not found in evaluation metrics." 
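+        # Illustrative example (config values below are hypothetical, not taken
+        # from this repo's shipped configs): with
+        #   eval.eval_metrics: [abs_relative_difference, delta1_acc]
+        #   validation.main_val_metric: abs_relative_difference
+        #   validation.main_val_metric_goal: minimize
+        # validate() keeps the 'best' checkpoint whenever the first validation
+        # loader improves on the lowest absolute relative error seen so far.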
+ self.best_metric = 1e8 if "minimize" == self.main_val_metric_goal else -1e8 + + # Settings + self.max_epoch = self.cfg.max_epoch + self.max_iter = self.cfg.max_iter + self.gradient_accumulation_steps = accumulation_steps + self.gt_depth_type = self.cfg.gt_depth_type + self.gt_mask_type = self.cfg.gt_mask_type + self.save_period = self.cfg.trainer.save_period + self.backup_period = self.cfg.trainer.backup_period + self.val_period = self.cfg.trainer.validation_period + self.vis_period = self.cfg.trainer.visualization_period + + # Multi-resolution noise + self.apply_multi_res_noise = self.cfg.multi_res_noise is not None + if self.apply_multi_res_noise: + self.mr_noise_strength = self.cfg.multi_res_noise.strength + self.annealed_mr_noise = self.cfg.multi_res_noise.annealed + self.mr_noise_downscale_strategy = ( + self.cfg.multi_res_noise.downscale_strategy + ) + + # Internal variables + self.epoch = 0 + self.n_batch_in_epoch = 0 # batch index in the epoch, used when resume training + self.effective_iter = 0 # how many times optimizer.step() is called + self.in_evaluation = False + self.global_seed_sequence: List = [] # consistent global seed sequence, used to seed random generator, to ensure consistency when resuming + + def _replace_unet_conv_in(self): + # replace the first layer to accept 8 in_channels + _weight = self.model.unet.conv_in.weight.clone() # [320, 4, 3, 3] + _bias = self.model.unet.conv_in.bias.clone() # [320] + zero_weight = torch.zeros(_weight.shape).to(_weight.device) + _weight = torch.cat([_weight, zero_weight], dim=1) + # _weight = _weight.repeat((1, 2, 1, 1)) # Keep selected channel(s) + # half the activation magnitude + # _weight *= 0.5 + # new conv_in channel + _n_convin_out_channel = self.model.unet.conv_in.out_channels + _new_conv_in = Conv2d( + 8, _n_convin_out_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1) + ) + _new_conv_in.weight = Parameter(_weight) + _new_conv_in.bias = Parameter(_bias) + self.model.unet.conv_in = _new_conv_in + logging.info("Unet conv_in layer is replaced") + # replace config + self.model.unet.config["in_channels"] = 8 + logging.info("Unet config is updated") + return + + def _replace_unet_conv_out(self): + # replace the first layer to accept 8 in_channels + _weight = self.model.unet.conv_out.weight.clone() # [8, 320, 3, 3] + _bias = self.model.unet.conv_out.bias.clone() # [320] + _weight = _weight.repeat((2, 1, 1, 1)) # Keep selected channel(s) + _bias = _bias.repeat((2)) + # half the activation magnitude + + # new conv_in channel + _n_convin_out_channel = self.model.unet.conv_out.out_channels + _new_conv_out = Conv2d( + _n_convin_out_channel, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1) + ) + _new_conv_out.weight = Parameter(_weight) + _new_conv_out.bias = Parameter(_bias) + self.model.unet.conv_out = _new_conv_out + logging.info("Unet conv_out layer is replaced") + # replace config + self.model.unet.config["out_channels"] = 8 + logging.info("Unet config is updated") + return + + def parallel_train(self, t_end=None, accelerator=None): + logging.info("Start training") + + self.model, self.optimizer, self.train_loader, self.lr_scheduler = accelerator.prepare( + self.model, self.optimizer, self.train_loader, self.lr_scheduler + ) + self.accelerator = accelerator + if self.val_loaders is not None: + for idx, loader in enumerate(self.val_loaders): + self.val_loaders[idx] = accelerator.prepare(loader) + + if os.path.exists(os.path.join(self.out_dir_ckpt, 'latest')): + accelerator.load_state(os.path.join(self.out_dir_ckpt, 
'latest')) + self.load_miscs(os.path.join(self.out_dir_ckpt, 'latest')) + + self.train_metrics.reset() + accumulated_step = 0 + for epoch in range(self.epoch, self.max_epoch + 1): + self.epoch = epoch + logging.debug(f"epoch: {self.epoch}") + + # Skip previous batches when resume + for batch in skip_first_batches(self.train_loader, self.n_batch_in_epoch): + self.model.unet.train() + + # globally consistent random generators + if self.seed is not None: + local_seed = self._get_next_seed() + rand_num_generator = torch.Generator(device=self.model.device) + rand_num_generator.manual_seed(local_seed) + else: + rand_num_generator = None + + # >>> With gradient accumulation >>> + + # Get data + rgb = batch["rgb_norm"].to(self.model.device) + depth_gt_for_latent = batch[self.gt_depth_type].to(self.model.device) + batch_size = rgb.shape[0] + + if self.gt_mask_type is not None: + valid_mask_for_latent = batch[self.gt_mask_type].to(self.model.device) + invalid_mask = ~valid_mask_for_latent + valid_mask_down = ~torch.max_pool2d( + invalid_mask.float(), 8, 8 + ).bool() + valid_mask_down = valid_mask_down.repeat((1, 4, 1, 1)) + + with torch.no_grad(): + # Encode image + rgb_latent = self.model.encode_rgb(rgb) # [B, 4, h, w] + # Encode GT depth + gt_depth_latent = self.encode_depth( + depth_gt_for_latent + ) # [B, 4, h, w] + + # Sample a random timestep for each image + if self.cfg.loss.depth_factor == 1: + rgb_timesteps = torch.zeros( + (batch_size), + device=self.model.device + ).long() # [B] + depth_timesteps = torch.randint( + 0, + self.scheduler_timesteps, + (batch_size,), + device=self.model.device, + generator=rand_num_generator, + ).long() # [B] + elif self.timestep_method == 'unidiffuser': + rgb_timesteps = torch.randint( + 0, + self.scheduler_timesteps, + (batch_size,), + device=self.model.device, + generator=rand_num_generator, + ).long() # [B] + depth_timesteps = torch.randint( + 0, + self.scheduler_timesteps, + (batch_size,), + device=self.model.device, + generator=rand_num_generator, + ).long() # [B] + elif self.timestep_method == 'partition': + rand_num = random.random() + if rand_num < 0.3333: + # joint generation + rgb_timesteps = torch.randint( + 0, + self.scheduler_timesteps, + (batch_size,), + device=self.model.device, + generator=rand_num_generator, + ).long() # [B] + depth_timesteps = rgb_timesteps + elif rand_num < 0.6666: + # image2depth generation + rgb_timesteps = torch.zeros( + (batch_size), + device=self.model.device + ).long() # [B] + depth_timesteps = torch.randint( + 0, + self.scheduler_timesteps, + (batch_size,), + device=self.model.device, + generator=rand_num_generator, + ).long() # [B] + else: + # depth2image generation + rgb_timesteps = torch.randint( + 0, + self.scheduler_timesteps, + (batch_size,), + device=self.model.device, + generator=rand_num_generator, + ).long() # [B] + depth_timesteps = torch.zeros( + (batch_size), + device=self.model.device + ).long() # [B] + + # Sample noise + if self.apply_multi_res_noise: + rgb_strength = self.mr_noise_strength + if self.annealed_mr_noise: + # calculate strength depending on t + rgb_strength = rgb_strength * (rgb_timesteps / self.scheduler_timesteps) + rgb_noise = multi_res_noise_like( + rgb_latent, + strength=rgb_strength, + downscale_strategy=self.mr_noise_downscale_strategy, + generator=rand_num_generator, + device=self.model.device, + ) + + depth_strength = self.mr_noise_strength + if self.annealed_mr_noise: + # calculate strength depending on t + depth_strength = depth_strength * (depth_timesteps / 
self.scheduler_timesteps)
+                    depth_noise = multi_res_noise_like(
+                        gt_depth_latent,
+                        strength=depth_strength,
+                        downscale_strategy=self.mr_noise_downscale_strategy,
+                        generator=rand_num_generator,
+                        device=self.model.device,
+                    )
+                else:
+                    rgb_noise = torch.randn(
+                        rgb_latent.shape,
+                        device=self.model.device,
+                        generator=rand_num_generator,
+                    ) # [B, 4, h, w]
+
+                    depth_noise = torch.randn(
+                        gt_depth_latent.shape,
+                        device=self.model.device,
+                        generator=rand_num_generator,
+                    ) # [B, 4, h, w]
+                # Add noise to the latents (diffusion forward process)
+
+                noisy_rgb_latents = self.training_noise_scheduler.add_noise(
+                    rgb_latent, rgb_noise, rgb_timesteps
+                ) # [B, 4, h, w]
+                noisy_depth_latents = self.training_noise_scheduler.add_noise(
+                    gt_depth_latent, depth_noise, depth_timesteps
+                ) # [B, 4, h, w]
+
+                noisy_latents = torch.cat(
+                    [noisy_rgb_latents, noisy_depth_latents], dim=1
+                ).float() # [B, 8, h, w]
+
+                # Text embedding
+                batch_text_embed = []
+                batch_pooled_text_embed = []
+                for p in batch['text']:
+                    prompt_embed, pooled_prompt_embed = self.model.encode_text(p)
+                    batch_text_embed.append(prompt_embed)
+                    batch_pooled_text_embed.append(pooled_prompt_embed)
+                batch_text_embed = torch.cat(batch_text_embed, dim=0)
+                batch_pooled_text_embed = torch.cat(batch_pooled_text_embed, dim=0)
+                # input_ids = {k:v.squeeze().to(self.model.device) for k,v in batch['text'].items()}
+                # prompt_embed, pooled_prompt_embed = self.model.encode_text(batch['text'])
+                # text_embed = self.empty_text_embed.to(device).repeat(
+                #     (batch_size, 1, 1)
+                # ) # [B, 77, 1024]
+                # Predict the noise residual
+                add_time_ids = self.model._get_add_time_ids(
+                    (batch['rgb_int'].shape[-2], batch['rgb_int'].shape[-1]), (0, 0), (batch['rgb_int'].shape[-2], batch['rgb_int'].shape[-1]), dtype=batch_text_embed.dtype
+                )
+                dtype = self.model.unet.dtype
+                added_cond_kwargs = {"text_embeds": batch_pooled_text_embed.to(self.model.device).to(dtype), "time_ids": add_time_ids.to(self.model.device).to(dtype)}
+                model_pred = self.model.unet(
+                    noisy_latents.to(self.model.unet.dtype), rgb_timesteps, depth_timesteps, encoder_hidden_states=batch_text_embed.to(dtype),
+                    added_cond_kwargs=added_cond_kwargs, separate_list=self.separate_list
+                ).sample # [B, 8, h, w]
+                if torch.isnan(model_pred).any():
+                    logging.warning("model_pred contains NaN.")
+
+                # Get the target for loss depending on the prediction type
+                if "sample" == self.prediction_type:
+                    rgb_target = rgb_latent
+                    depth_target = gt_depth_latent
+                elif "epsilon" == self.prediction_type:
+                    # epsilon-prediction: the regression target is the sampled noise
+                    rgb_target = rgb_noise
+                    depth_target = depth_noise
+                elif "v_prediction" == self.prediction_type:
+                    rgb_target = self.training_noise_scheduler.get_velocity(
+                        rgb_latent, rgb_noise, rgb_timesteps
+                    ) # [B, 4, h, w]
+                    depth_target = self.training_noise_scheduler.get_velocity(
+                        gt_depth_latent, depth_noise, depth_timesteps
+                    ) # [B, 4, h, w]
+                else:
+                    raise ValueError(f"Unknown prediction type {self.prediction_type}")
+                # Masked latent loss
+                with accelerator.accumulate(self.model):
+                    if self.gt_mask_type is not None:
+                        depth_loss = self.loss(
+                            model_pred[:, 4:, :, :][valid_mask_down].float(),
+                            depth_target[valid_mask_down].float(),
+                        )
+                    else:
+                        depth_loss = self.cfg.loss.depth_factor * self.loss(model_pred[:, 4:, :, :].float(), depth_target.float())
+
+                    rgb_loss = (1 - self.cfg.loss.depth_factor) * self.loss(model_pred[:, 0:4, :, :].float(), rgb_target.float())
+                    if self.cfg.loss.depth_factor == 1:
+                        loss = depth_loss
+                    else:
+                        loss = rgb_loss + depth_loss
+
+
self.train_metrics.update("loss", loss.item()) + self.train_metrics.update("rgb_loss", rgb_loss.item()) + self.train_metrics.update("depth_loss", depth_loss.item()) + # loss = loss / self.gradient_accumulation_steps + accelerator.backward(loss) + self.optimizer.step() + self.optimizer.zero_grad() + # loss.backward() + self.n_batch_in_epoch += 1 + # print(accelerator.process_index, self.lr_scheduler.get_last_lr()) + self.lr_scheduler.step(self.effective_iter) + + if accelerator.sync_gradients: + accumulated_step += 1 + + if accumulated_step >= self.gradient_accumulation_steps: + accumulated_step = 0 + self.effective_iter += 1 + + if accelerator.is_main_process: + # Log to tensorboard + if self.effective_iter == 1: + generator = torch.Generator(self.model.device).manual_seed(1024) + img = self.model.generate_rgbd(self.prompt, num_inference_steps=50, generator=generator, + show_pbar=True) + for idx in range(len(self.prompt)): + tb_logger.writer.add_image(f'image/{self.prompt[idx]}', img[idx], self.effective_iter) + + accumulated_loss = self.train_metrics.result()["loss"] + rgb_loss = self.train_metrics.result()["rgb_loss"] + depth_loss = self.train_metrics.result()["depth_loss"] + tb_logger.log_dic( + { + f"train/{k}": v + for k, v in self.train_metrics.result().items() + }, + global_step=self.effective_iter, + ) + tb_logger.writer.add_scalar( + "lr", + self.lr_scheduler.get_last_lr()[0], + global_step=self.effective_iter, + ) + tb_logger.writer.add_scalar( + "n_batch_in_epoch", + self.n_batch_in_epoch, + global_step=self.effective_iter, + ) + logging.info( + f"iter {self.effective_iter:5d} (epoch {epoch:2d}): loss={accumulated_loss:.5f}, rgb_loss={rgb_loss:.5f}, depth_loss={depth_loss:.5f}" + ) + accelerator.wait_for_everyone() + + if self.save_period > 0 and 0 == self.effective_iter % self.save_period: + accelerator.save_state(output_dir=os.path.join(self.out_dir_ckpt, 'latest')) + unwrapped_model = accelerator.unwrap_model(self.model) + if accelerator.is_main_process: + accelerator.save_model(unwrapped_model.unet, + os.path.join(self.out_dir_ckpt, 'latest'), safe_serialization=False) + self.save_miscs('latest') + + # RGB-D joint generation + generator = torch.Generator(self.model.device).manual_seed(1024) + img = self.model.generate_rgbd(self.prompt, num_inference_steps=50, generator=generator,show_pbar=False) + for idx in range(len(self.prompt)): + tb_logger.writer.add_image(f'image/{self.prompt[idx]}', img[idx], self.effective_iter) + + # depth to RGB generation + self._depth2image() + from diffusers import StableDiffusionControlNetInpaintPipeline + # RGB to depth generation + self._image2depth() + + accelerator.wait_for_everyone() + + accelerator.wait_for_everyone() + + if self.backup_period > 0 and 0 == self.effective_iter % self.backup_period: + unwrapped_model = accelerator.unwrap_model(self.model) + if accelerator.is_main_process: + unwrapped_model.unet.save_pretrained( + os.path.join(self.out_dir_ckpt, self._get_backup_ckpt_name())) + accelerator.wait_for_everyone() + + if self.val_period > 0 and 0 == self.effective_iter % self.val_period: + self.validate() + + # End of training + if self.max_iter > 0 and self.effective_iter >= self.max_iter: + unwrapped_model = accelerator.unwrap_model(self.model) + if accelerator.is_main_process: + unwrapped_model.unet.save_pretrained( + os.path.join(self.out_dir_ckpt, self._get_backup_ckpt_name())) + accelerator.wait_for_everyone() + return + + torch.cuda.empty_cache() + # <<< Effective batch end <<< + + # Epoch end + self.n_batch_in_epoch = 0 + 
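+    # Minimal usage sketch (illustrative only; names and config values outside
+    # this class are assumptions, not part of this file):
+    #
+    #   accelerator = Accelerator(gradient_accumulation_steps=accumulation_steps)
+    #   trainer = MarigoldXLTrainer(
+    #       cfg=cfg, model=pipeline, train_dataloader=train_loader,
+    #       device=accelerator.device, base_ckpt_dir=base_ckpt_dir,
+    #       out_dir_ckpt=out_dir_ckpt, out_dir_eval=out_dir_eval,
+    #       out_dir_vis=out_dir_vis, accumulation_steps=accumulation_steps,
+    #       val_dataloaders=val_loaders,
+    #   )
+    #   trainer.parallel_train(accelerator=accelerator)
+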
+ def _image2depth(self): + generator = torch.Generator(self.model.device).manual_seed(1024) + image2dept_paths = ['/home/aiops/wangzh/data/scannet/scene0593_00/color/000100.jpg', + '/home/aiops/wangzh/data/scannet/scene0593_00/color/000700.jpg', + '/home/aiops/wangzh/data/scannet/scene0591_01/color/000600.jpg', + '/home/aiops/wangzh/data/scannet/scene0591_01/color/001500.jpg'] + for img_idx, image_path in enumerate(image2dept_paths): + rgb_input = Image.open(image_path) + depth_pred: MarigoldDepthOutput = self.model.image2depth( + rgb_input, + denoising_steps=self.cfg.validation.denoising_steps, + ensemble_size=self.cfg.validation.ensemble_size, + processing_res=self.cfg.validation.processing_res, + match_input_res=self.cfg.validation.match_input_res, + generator=generator, + batch_size=self.cfg.validation.ensemble_size, + # use batch size 1 to increase reproducibility + color_map="Spectral", + show_progress_bar=False, + resample_method=self.cfg.validation.resample_method, + ) + img = self.model.post_process_rgbd(['None'], [rgb_input], [depth_pred['depth_colored']]) + tb_logger.writer.add_image(f'image2depth_{img_idx}', img[0], self.effective_iter) + + def _depth2image(self): + generator = torch.Generator(self.model.device).manual_seed(1024) + if "least_square_disparity" == self.cfg.eval.alignment: + depth2image_path = ['/home/aiops/wangzh/data/ori_depth_part0-0/sa_10000335.jpg', + '/home/aiops/wangzh/data/ori_depth_part0-0/sa_3572319.jpg', + '/home/aiops/wangzh/data/ori_depth_part0-0/sa_457934.jpg'] + else: + depth2image_path = ['/home/aiops/wangzh/data/depth_part0-0/sa_10000335.jpg', + '/home/aiops/wangzh/data/depth_part0-0/sa_3572319.jpg', + '/home/aiops/wangzh/data/depth_part0-0/sa_457934.jpg'] + prompts = ['Red car parked in the factory', + 'White gothic church with cemetery next to it', + 'House with red roof and starry sky in the background'] + for img_idx, depth_path in enumerate(depth2image_path): + depth_input = Image.open(depth_path) + image_pred = self.model.single_depth2image( + depth_input, + prompts[img_idx], + num_inference_steps=50, + processing_res=1024, + generator=generator, + show_pbar=False, + resample_method=self.cfg.validation.resample_method, + ) + img = self.model.post_process_rgbd([prompts[img_idx]], [image_pred], [depth_input]) + tb_logger.writer.add_image(f'depth2image_{img_idx}', img[0], self.effective_iter) + + def encode_depth(self, depth_in): + # stack depth into 3-channel + stacked = self.stack_depth_images(depth_in) + # encode using VAE encoder + depth_latent = self.model.encode_rgb(stacked) + return depth_latent + + @staticmethod + def stack_depth_images(depth_in): + if 4 == len(depth_in.shape): + stacked = depth_in.repeat(1, 3, 1, 1) + elif 3 == len(depth_in.shape): + stacked = depth_in.unsqueeze(1) + stacked = depth_in.repeat(1, 3, 1, 1) + return stacked + + def _train_step_callback(self): + """Executed after every iteration""" + # Save backup (with a larger interval, without training states) + if self.backup_period > 0 and 0 == self.effective_iter % self.backup_period: + self.save_checkpoint( + ckpt_name=self._get_backup_ckpt_name(), save_train_state=False + ) + + _is_latest_saved = False + # Validation + if self.val_period > 0 and 0 == self.effective_iter % self.val_period: + self.in_evaluation = True # flag to do evaluation in resume run if validation is not finished + self.save_checkpoint(ckpt_name="latest", save_train_state=True) + _is_latest_saved = True + self.validate() + self.in_evaluation = False + self.save_checkpoint(ckpt_name="latest", 
save_train_state=True) + + # Save training checkpoint (can be resumed) + if ( + self.save_period > 0 + and 0 == self.effective_iter % self.save_period + and not _is_latest_saved + ): + generator = torch.Generator(self.model.device).manual_seed(1024) + img = self.model.generate_rgbd(self.prompt, num_inference_steps=50, generator=generator, show_pbar=True) + for idx in range(len(self.prompt)): + tb_logger.writer.add_image(f'image/{self.prompt[idx]}', img[idx], self.effective_iter) + + self.save_checkpoint(ckpt_name="latest", save_train_state=True) + + # Visualization + if self.vis_period > 0 and 0 == self.effective_iter % self.vis_period: + self.visualize() + + def validate(self): + for i, val_loader in enumerate(self.val_loaders): + val_dataset_name = val_loader.dataset.disp_name + val_metric_dic = self.validate_single_dataset( + data_loader=val_loader, metric_tracker=self.val_metrics + ) + + if self.accelerator.is_main_process: + val_metric_dic = {k:torch.tensor(v).cuda() for k,v in val_metric_dic.items()} + + tb_logger.log_dic( + {f"val/{val_dataset_name}/{k}": v for k, v in val_metric_dic.items()}, + global_step=self.effective_iter, + ) + # save to file + eval_text = eval_dic_to_text( + val_metrics=val_metric_dic, + dataset_name=val_dataset_name, + sample_list_path=val_loader.dataset.filename_ls_path, + ) + _save_to = os.path.join( + self.out_dir_eval, + f"eval-{val_dataset_name}-iter{self.effective_iter:06d}.txt", + ) + with open(_save_to, "w+") as f: + f.write(eval_text) + + # Update main eval metric + if 0 == i: + main_eval_metric = val_metric_dic[self.main_val_metric] + if ( + "minimize" == self.main_val_metric_goal + and main_eval_metric < self.best_metric + or "maximize" == self.main_val_metric_goal + and main_eval_metric > self.best_metric + ): + self.best_metric = main_eval_metric + logging.info( + f"Best metric: {self.main_val_metric} = {self.best_metric} at iteration {self.effective_iter}" + ) + # Save a checkpoint + self.save_checkpoint( + ckpt_name='best', save_train_state=False + ) + + self.accelerator.wait_for_everyone() + + def visualize(self): + for val_loader in self.vis_loaders: + vis_dataset_name = val_loader.dataset.disp_name + vis_out_dir = os.path.join( + self.out_dir_vis, self._get_backup_ckpt_name(), vis_dataset_name + ) + os.makedirs(vis_out_dir, exist_ok=True) + _ = self.validate_single_dataset( + data_loader=val_loader, + metric_tracker=self.val_metrics, + save_to_dir=vis_out_dir, + ) + + @torch.no_grad() + def validate_single_dataset( + self, + data_loader: DataLoader, + metric_tracker: MetricTracker, + save_to_dir: str = None, + ): + self.model.to(self.device) + metric_tracker.reset() + + # Generate seed sequence for consistent evaluation + val_init_seed = self.cfg.validation.init_seed + val_seed_ls = generate_seed_sequence(val_init_seed, len(data_loader)) + + for i, batch in enumerate( + tqdm(data_loader, desc=f"evaluating on {data_loader.dataset.disp_name}"), + start=1, + ): + + rgb_int = batch["rgb_int"] # [3, H, W] + # GT depth + depth_raw_ts = batch["depth_raw_linear"].squeeze() + depth_raw = depth_raw_ts.cpu().numpy() + depth_raw_ts = depth_raw_ts.to(self.device) + valid_mask_ts = batch["valid_mask_raw"].squeeze() + valid_mask = valid_mask_ts.cpu().numpy() + valid_mask_ts = valid_mask_ts.to(self.device) + + # Random number generator + seed = val_seed_ls.pop() + if seed is None: + generator = None + else: + generator = torch.Generator(device=self.device) + generator.manual_seed(seed) + + # Predict depth + pipe_out: MarigoldDepthOutput = 
self.model.image2depth( + rgb_int, + denoising_steps=self.cfg.validation.denoising_steps, + ensemble_size=self.cfg.validation.ensemble_size, + processing_res=self.cfg.validation.processing_res, + match_input_res=self.cfg.validation.match_input_res, + generator=generator, + batch_size=self.cfg.validation.ensemble_size, # use batch size 1 to increase reproducibility + color_map=None, + show_progress_bar=False, + resample_method=self.cfg.validation.resample_method, + ) + + depth_pred: np.ndarray = pipe_out.depth_np + + if "least_square" == self.cfg.eval.alignment: + depth_pred, scale, shift = align_depth_least_square( + gt_arr=depth_raw, + pred_arr=depth_pred, + valid_mask_arr=valid_mask, + return_scale_shift=True, + max_resolution=self.cfg.eval.align_max_res, + ) + else: + raise RuntimeError(f"Unknown alignment type: {self.cfg.eval.alignment}") + + # Clip to dataset min max + depth_pred = np.clip( + depth_pred, + a_min=data_loader.dataset.min_depth, + a_max=data_loader.dataset.max_depth, + ) + + # clip to d > 0 for evaluation + depth_pred = np.clip(depth_pred, a_min=1e-6, a_max=None) + + # Evaluate + sample_metric = [] + depth_pred_ts = torch.from_numpy(depth_pred).to(self.device) + + for met_func in self.metric_funcs: + _metric_name = met_func.__name__ + _metric = met_func(depth_pred_ts, depth_raw_ts, valid_mask_ts).cuda(self.accelerator.process_index) + self.accelerator.wait_for_everyone() + _metric = self.accelerator.gather_for_metrics(_metric.unsqueeze(0)).mean().item() + sample_metric.append(_metric.__str__()) + metric_tracker.update(_metric_name, _metric) + + self.accelerator.wait_for_everyone() + # Save as 16-bit uint png + if save_to_dir is not None: + img_name = batch["rgb_relative_path"][0].replace("/", "_") + png_save_path = os.path.join(save_to_dir, f"{img_name}.png") + depth_to_save = (pipe_out.depth_np * 65535.0).astype(np.uint16) + Image.fromarray(depth_to_save).save(png_save_path, mode="I;16") + + return metric_tracker.result() + + def _get_next_seed(self): + if 0 == len(self.global_seed_sequence): + self.global_seed_sequence = generate_seed_sequence( + initial_seed=self.seed, + length=self.max_iter * self.gradient_accumulation_steps, + ) + logging.info( + f"Global seed sequence is generated, length={len(self.global_seed_sequence)}" + ) + return self.global_seed_sequence.pop() + + def save_miscs(self, ckpt_name): + ckpt_dir = os.path.join(self.out_dir_ckpt, ckpt_name) + state = { + "config": self.cfg, + "effective_iter": self.effective_iter, + "epoch": self.epoch, + "n_batch_in_epoch": self.n_batch_in_epoch, + "best_metric": self.best_metric, + "in_evaluation": self.in_evaluation, + "global_seed_sequence": self.global_seed_sequence, + } + train_state_path = os.path.join(ckpt_dir, "trainer.ckpt") + torch.save(state, train_state_path) + + logging.info(f"Misc state is saved to: {train_state_path}") + + def load_miscs(self, ckpt_path): + checkpoint = torch.load(os.path.join(ckpt_path, "trainer.ckpt")) + self.effective_iter = checkpoint["effective_iter"] + self.epoch = checkpoint["epoch"] + self.n_batch_in_epoch = checkpoint["n_batch_in_epoch"] + self.in_evaluation = checkpoint["in_evaluation"] + self.global_seed_sequence = checkpoint["global_seed_sequence"] + + self.best_metric = checkpoint["best_metric"] + + logging.info(f"Misc state is loaded from {ckpt_path}") + + + def save_checkpoint(self, ckpt_name, save_train_state): + ckpt_dir = os.path.join(self.out_dir_ckpt, ckpt_name) + logging.info(f"Saving checkpoint to: {ckpt_dir}") + # Backup previous checkpoint + temp_ckpt_dir = 
None + if os.path.exists(ckpt_dir) and os.path.isdir(ckpt_dir): + temp_ckpt_dir = os.path.join( + os.path.dirname(ckpt_dir), f"_old_{os.path.basename(ckpt_dir)}" + ) + if os.path.exists(temp_ckpt_dir): + shutil.rmtree(temp_ckpt_dir, ignore_errors=True) + os.rename(ckpt_dir, temp_ckpt_dir) + logging.debug(f"Old checkpoint is backed up at: {temp_ckpt_dir}") + + # Save UNet + unet_path = os.path.join(ckpt_dir, "unet") + self.model.unet.save_pretrained(unet_path, safe_serialization=False) + logging.info(f"UNet is saved to: {unet_path}") + + if save_train_state: + state = { + "config": self.cfg, + "effective_iter": self.effective_iter, + "epoch": self.epoch, + "n_batch_in_epoch": self.n_batch_in_epoch, + "best_metric": self.best_metric, + "in_evaluation": self.in_evaluation, + "global_seed_sequence": self.global_seed_sequence, + } + train_state_path = os.path.join(ckpt_dir, "trainer.ckpt") + torch.save(state, train_state_path) + # iteration indicator + f = open(os.path.join(ckpt_dir, self._get_backup_ckpt_name()), "w") + f.close() + + logging.info(f"Trainer state is saved to: {train_state_path}") + + # Remove temp ckpt + if temp_ckpt_dir is not None and os.path.exists(temp_ckpt_dir): + shutil.rmtree(temp_ckpt_dir, ignore_errors=True) + logging.debug("Old checkpoint backup is removed.") + + def load_checkpoint( + self, ckpt_path, load_trainer_state=True, resume_lr_scheduler=True + ): + logging.info(f"Loading checkpoint from: {ckpt_path}") + # Load UNet + _model_path = os.path.join(ckpt_path, "unet", "diffusion_pytorch_model.bin") + self.model.unet.load_state_dict( + torch.load(_model_path, map_location=self.device) + ) + self.model.unet.to(self.device) + logging.info(f"UNet parameters are loaded from {_model_path}") + + # Load training states + if load_trainer_state: + checkpoint = torch.load(os.path.join(ckpt_path, "trainer.ckpt")) + self.effective_iter = checkpoint["effective_iter"] + self.epoch = checkpoint["epoch"] + self.n_batch_in_epoch = checkpoint["n_batch_in_epoch"] + self.in_evaluation = checkpoint["in_evaluation"] + self.global_seed_sequence = checkpoint["global_seed_sequence"] + + self.best_metric = checkpoint["best_metric"] + + self.optimizer.load_state_dict(checkpoint["optimizer"]) + logging.info(f"optimizer state is loaded from {ckpt_path}") + + if resume_lr_scheduler: + self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) + logging.info(f"LR scheduler state is loaded from {ckpt_path}") + + logging.info( + f"Checkpoint loaded from: {ckpt_path}. 
Resume from iteration {self.effective_iter} (epoch {self.epoch})" + ) + return + + def _get_backup_ckpt_name(self): + return f"iter_{self.effective_iter:06d}" diff --git a/src/util/__pycache__/alignment.cpython-310.pyc b/src/util/__pycache__/alignment.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4202b0a8582f154d2d5315d665df4d141e0c6f01 Binary files /dev/null and b/src/util/__pycache__/alignment.cpython-310.pyc differ diff --git a/src/util/__pycache__/config_util.cpython-310.pyc b/src/util/__pycache__/config_util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b4ae9c09290073d7587a12eb7bbd5ec88c084d6 Binary files /dev/null and b/src/util/__pycache__/config_util.cpython-310.pyc differ diff --git a/src/util/__pycache__/data_loader.cpython-310.pyc b/src/util/__pycache__/data_loader.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..869ec0d6e357784a3db4272b7fd40597c7b4fc91 Binary files /dev/null and b/src/util/__pycache__/data_loader.cpython-310.pyc differ diff --git a/src/util/__pycache__/depth_transform.cpython-310.pyc b/src/util/__pycache__/depth_transform.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..394420f5fde73a44d7d72dfe04b827ec877e4ca9 Binary files /dev/null and b/src/util/__pycache__/depth_transform.cpython-310.pyc differ diff --git a/src/util/__pycache__/logging_util.cpython-310.pyc b/src/util/__pycache__/logging_util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba428dd3177f1f9345028286defa01b8d5737504 Binary files /dev/null and b/src/util/__pycache__/logging_util.cpython-310.pyc differ diff --git a/src/util/__pycache__/loss.cpython-310.pyc b/src/util/__pycache__/loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a644d781037a0e60d6d7508c46fcf58f07a2f584 Binary files /dev/null and b/src/util/__pycache__/loss.cpython-310.pyc differ diff --git a/src/util/__pycache__/lr_scheduler.cpython-310.pyc b/src/util/__pycache__/lr_scheduler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb175a5503c54b9fe2a2cc0f9bed2c5c4e92cd35 Binary files /dev/null and b/src/util/__pycache__/lr_scheduler.cpython-310.pyc differ diff --git a/src/util/__pycache__/metric.cpython-310.pyc b/src/util/__pycache__/metric.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eab160d0f903245f3c07c904eca0e35c2280c54f Binary files /dev/null and b/src/util/__pycache__/metric.cpython-310.pyc differ diff --git a/src/util/__pycache__/multi_res_noise.cpython-310.pyc b/src/util/__pycache__/multi_res_noise.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e0d60799d2696dc88fc9212a2ffe54b82c1693c4 Binary files /dev/null and b/src/util/__pycache__/multi_res_noise.cpython-310.pyc differ diff --git a/src/util/__pycache__/seeding.cpython-310.pyc b/src/util/__pycache__/seeding.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1294462dc1134f78c5e74831f2763c3ecb82185e Binary files /dev/null and b/src/util/__pycache__/seeding.cpython-310.pyc differ diff --git a/src/util/__pycache__/slurm_util.cpython-310.pyc b/src/util/__pycache__/slurm_util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7beaa95896505ba9c578baa2f576536350b6989f Binary files /dev/null and b/src/util/__pycache__/slurm_util.cpython-310.pyc differ diff --git a/src/util/alignment.py 
b/src/util/alignment.py new file mode 100644 index 0000000000000000000000000000000000000000..2fadd5ec88f8aa7107fc1a95d08aabf30a17e6df --- /dev/null +++ b/src/util/alignment.py @@ -0,0 +1,72 @@ +# Author: Bingxin Ke +# Last modified: 2024-01-11 + +import numpy as np +import torch + + +def align_depth_least_square( + gt_arr: np.ndarray, + pred_arr: np.ndarray, + valid_mask_arr: np.ndarray, + return_scale_shift=True, + max_resolution=None, +): + ori_shape = pred_arr.shape # input shape + + gt = gt_arr.squeeze() # [H, W] + pred = pred_arr.squeeze() + valid_mask = valid_mask_arr.squeeze() + + # Downsample + if max_resolution is not None: + scale_factor = np.min(max_resolution / np.array(ori_shape[-2:])) + if scale_factor < 1: + downscaler = torch.nn.Upsample(scale_factor=scale_factor, mode="nearest") + gt = downscaler(torch.as_tensor(gt).unsqueeze(0)).numpy() + pred = downscaler(torch.as_tensor(pred).unsqueeze(0)).numpy() + valid_mask = ( + downscaler(torch.as_tensor(valid_mask).unsqueeze(0).float()) + .bool() + .numpy() + ) + + assert ( + gt.shape == pred.shape == valid_mask.shape + ), f"{gt.shape}, {pred.shape}, {valid_mask.shape}" + + gt_masked = gt[valid_mask].reshape((-1, 1)) + pred_masked = pred[valid_mask].reshape((-1, 1)) + + # numpy solver + _ones = np.ones_like(pred_masked) + A = np.concatenate([pred_masked, _ones], axis=-1) + X = np.linalg.lstsq(A, gt_masked, rcond=None)[0] + scale, shift = X + + aligned_pred = pred_arr * scale + shift + + # restore dimensions + aligned_pred = aligned_pred.reshape(ori_shape) + + if return_scale_shift: + return aligned_pred, scale, shift + else: + return aligned_pred + + +# ******************** disparity space ******************** +def depth2disparity(depth, return_mask=False): + if isinstance(depth, torch.Tensor): + disparity = torch.zeros_like(depth) + elif isinstance(depth, np.ndarray): + disparity = np.zeros_like(depth) + non_negtive_mask = depth > 0 + disparity[non_negtive_mask] = 1.0 / depth[non_negtive_mask] + if return_mask: + return disparity, non_negtive_mask + else: + return disparity + +def disparity2depth(disparity, **kwargs): + return depth2disparity(disparity, **kwargs) diff --git a/src/util/config_util.py b/src/util/config_util.py new file mode 100644 index 0000000000000000000000000000000000000000..0e9fa45676d0f01d9d5fd62b251eeeceec60e243 --- /dev/null +++ b/src/util/config_util.py @@ -0,0 +1,49 @@ +# Author: Bingxin Ke +# Last modified: 2024-02-14 + +import omegaconf +from omegaconf import OmegaConf + + +def recursive_load_config(config_path: str) -> OmegaConf: + conf = OmegaConf.load(config_path) + + output_conf = OmegaConf.create({}) + + # Load base config. Later configs on the list will overwrite previous + base_configs = conf.get("base_config", default_value=None) + if base_configs is not None: + assert isinstance(base_configs, omegaconf.listconfig.ListConfig) + for _path in base_configs: + assert ( + _path != config_path + ), "Circulate merging, base_config should not include itself." 
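+            # Illustrative example (file names below are hypothetical): a config with
+            #   base_config:
+            #     - config/train_base.yaml
+            #     - config/dataset/dataset_train.yaml
+            # is resolved depth-first here; later list entries, and finally the
+            # config itself, overwrite keys merged from earlier ones.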
+ _base_conf = recursive_load_config(_path) + output_conf = OmegaConf.merge(output_conf, _base_conf) + + # Merge configs and overwrite values + output_conf = OmegaConf.merge(output_conf, conf) + + return output_conf + + +def find_value_in_omegaconf(search_key, config): + result_list = [] + + if isinstance(config, omegaconf.DictConfig): + for key, value in config.items(): + if key == search_key: + result_list.append(value) + elif isinstance(value, (omegaconf.DictConfig, omegaconf.ListConfig)): + result_list.extend(find_value_in_omegaconf(search_key, value)) + elif isinstance(config, omegaconf.ListConfig): + for item in config: + if isinstance(item, (omegaconf.DictConfig, omegaconf.ListConfig)): + result_list.extend(find_value_in_omegaconf(search_key, item)) + + return result_list + + +if "__main__" == __name__: + conf = recursive_load_config("config/train_base.yaml") + print(OmegaConf.to_yaml(conf)) diff --git a/src/util/data_loader.py b/src/util/data_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..0fe42abfa55c0a5a76844153961e942c2e133bb4 --- /dev/null +++ b/src/util/data_loader.py @@ -0,0 +1,111 @@ +# Copied from https://github.com/huggingface/accelerate/blob/e2ae254008061b3e53fc1c97f88d65743a857e75/src/accelerate/data_loader.py + +from torch.utils.data import BatchSampler, DataLoader, IterableDataset + +# kwargs of the DataLoader in min version 1.4.0. +_PYTORCH_DATALOADER_KWARGS = { + "batch_size": 1, + "shuffle": False, + "sampler": None, + "batch_sampler": None, + "num_workers": 0, + "collate_fn": None, + "pin_memory": False, + "drop_last": False, + "timeout": 0, + "worker_init_fn": None, + "multiprocessing_context": None, + "generator": None, + "prefetch_factor": 2, + "persistent_workers": False, +} + + +class SkipBatchSampler(BatchSampler): + """ + A `torch.utils.data.BatchSampler` that skips the first `n` batches of another `torch.utils.data.BatchSampler`. + """ + + def __init__(self, batch_sampler, skip_batches=0): + self.batch_sampler = batch_sampler + self.skip_batches = skip_batches + + def __iter__(self): + for index, samples in enumerate(self.batch_sampler): + if index >= self.skip_batches: + yield samples + + @property + def total_length(self): + return len(self.batch_sampler) + + def __len__(self): + return len(self.batch_sampler) - self.skip_batches + + +class SkipDataLoader(DataLoader): + """ + Subclass of a PyTorch `DataLoader` that will skip the first batches. + + Args: + dataset (`torch.utils.data.dataset.Dataset`): + The dataset to use to build this datalaoder. + skip_batches (`int`, *optional*, defaults to 0): + The number of batches to skip at the beginning. + kwargs: + All other keyword arguments to pass to the regular `DataLoader` initialization. + """ + + def __init__(self, dataset, skip_batches=0, **kwargs): + super().__init__(dataset, **kwargs) + self.skip_batches = skip_batches + + def __iter__(self): + for index, batch in enumerate(super().__iter__()): + if index >= self.skip_batches: + yield batch + + +# Adapted from https://github.com/huggingface/accelerate +def skip_first_batches(dataloader, num_batches=0): + """ + Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`. 
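
A small sketch of the `base_config` inheritance that `recursive_load_config` implements: the child config pulls defaults from its base and overrides them. The YAML file names and values here are invented, written to the working directory purely for illustration:

from omegaconf import OmegaConf

from src.util.config_util import recursive_load_config

# Hypothetical configs, only for this example.
OmegaConf.save(OmegaConf.create({"lr": 1e-4, "batch_size": 32}), "base.yaml")
OmegaConf.save(
    OmegaConf.create({"base_config": ["base.yaml"], "batch_size": 8}), "child.yaml"
)

cfg = recursive_load_config("child.yaml")
print(OmegaConf.to_yaml(cfg))  # lr comes from base.yaml, batch_size is overridden to 8
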
+ """ + dataset = dataloader.dataset + sampler_is_batch_sampler = False + if isinstance(dataset, IterableDataset): + new_batch_sampler = None + else: + sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler) + batch_sampler = ( + dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler + ) + new_batch_sampler = SkipBatchSampler(batch_sampler, skip_batches=num_batches) + + # We ignore all of those since they are all dealt with by our new_batch_sampler + ignore_kwargs = [ + "batch_size", + "shuffle", + "sampler", + "batch_sampler", + "drop_last", + ] + + kwargs = { + k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k]) + for k in _PYTORCH_DATALOADER_KWARGS + if k not in ignore_kwargs + } + + # Need to provide batch_size as batch_sampler is None for Iterable dataset + if new_batch_sampler is None: + kwargs["drop_last"] = dataloader.drop_last + kwargs["batch_size"] = dataloader.batch_size + + if new_batch_sampler is None: + # Need to manually skip batches in the dataloader + dataloader = SkipDataLoader(dataset, skip_batches=num_batches, **kwargs) + else: + dataloader = DataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs) + + return dataloader diff --git a/src/util/depth_transform.py b/src/util/depth_transform.py new file mode 100644 index 0000000000000000000000000000000000000000..b184453f187d228f8d4af82985ba6adc759279c4 --- /dev/null +++ b/src/util/depth_transform.py @@ -0,0 +1,102 @@ +# Author: Bingxin Ke +# Last modified: 2024-04-18 + +import torch +import logging + + +def get_depth_normalizer(cfg_normalizer): + if cfg_normalizer is None: + + def identical(x): + return x + + depth_transform = identical + + elif "scale_shift_depth" == cfg_normalizer.type: + depth_transform = ScaleShiftDepthNormalizer( + norm_min=cfg_normalizer.norm_min, + norm_max=cfg_normalizer.norm_max, + min_max_quantile=cfg_normalizer.min_max_quantile, + clip=cfg_normalizer.clip, + ) + else: + raise NotImplementedError + return depth_transform + + +class DepthNormalizerBase: + is_absolute = None + far_plane_at_max = None + + def __init__( + self, + norm_min=-1.0, + norm_max=1.0, + ) -> None: + self.norm_min = norm_min + self.norm_max = norm_max + raise NotImplementedError + + def __call__(self, depth, valid_mask=None, clip=None): + raise NotImplementedError + + def denormalize(self, depth_norm, **kwargs): + # For metric depth: convert prediction back to metric depth + # For relative depth: convert prediction to [0, 1] + raise NotImplementedError + +class ScaleShiftDepthNormalizer(DepthNormalizerBase): + """ + Use near and far plane to linearly normalize depth, + i.e. d' = d * s + t, + where near plane is mapped to `norm_min`, and far plane is mapped to `norm_max` + Near and far planes are determined by taking quantile values. 
+ """ + + is_absolute = False + far_plane_at_max = True + + def __init__( + self, norm_min=-1.0, norm_max=1.0, min_max_quantile=0.02, clip=True + ) -> None: + self.norm_min = norm_min + self.norm_max = norm_max + self.norm_range = self.norm_max - self.norm_min + self.min_quantile = min_max_quantile + self.max_quantile = 1.0 - self.min_quantile + self.clip = clip + + def __call__(self, depth_linear, valid_mask=None, clip=None): + clip = clip if clip is not None else self.clip + + if valid_mask is None: + valid_mask = torch.ones_like(depth_linear).bool() + valid_mask = valid_mask & (depth_linear > 0) + + # Take quantiles as min and max + _min, _max = torch.quantile( + depth_linear[valid_mask], + torch.tensor([self.min_quantile, self.max_quantile]), + ) + + # scale and shift + depth_norm_linear = (depth_linear - _min) / ( + _max - _min + ) * self.norm_range + self.norm_min + + if clip: + depth_norm_linear = torch.clip( + depth_norm_linear, self.norm_min, self.norm_max + ) + + return depth_norm_linear + + def scale_back(self, depth_norm): + # scale to [0, 1] + depth_linear = (depth_norm - self.norm_min) / self.norm_range + return depth_linear + + def denormalize(self, depth_norm, **kwargs): + logging.warning(f"{self.__class__} is not revertible without GT") + return self.scale_back(depth_norm=depth_norm) diff --git a/src/util/logging_util.py b/src/util/logging_util.py new file mode 100644 index 0000000000000000000000000000000000000000..37dd103baa1958397b150eb9e07a11c02027ba6a --- /dev/null +++ b/src/util/logging_util.py @@ -0,0 +1,102 @@ +# Author: Bingxin Ke +# Last modified: 2024-03-12 + +import logging +import os +import sys +import wandb +from tabulate import tabulate +from torch.utils.tensorboard import SummaryWriter + + +def config_logging(cfg_logging, out_dir=None): + file_level = cfg_logging.get("file_level", 10) + console_level = cfg_logging.get("console_level", 10) + + log_formatter = logging.Formatter(cfg_logging["format"]) + + root_logger = logging.getLogger() + root_logger.handlers.clear() + + root_logger.setLevel(min(file_level, console_level)) + + if out_dir is not None: + _logging_file = os.path.join( + out_dir, cfg_logging.get("filename", "logging.log") + ) + file_handler = logging.FileHandler(_logging_file) + file_handler.setFormatter(log_formatter) + file_handler.setLevel(file_level) + root_logger.addHandler(file_handler) + + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setFormatter(log_formatter) + console_handler.setLevel(console_level) + root_logger.addHandler(console_handler) + + # Avoid pollution by packages + logging.getLogger("PIL").setLevel(logging.INFO) + logging.getLogger("matplotlib").setLevel(logging.INFO) + + +class MyTrainingLogger: + """Tensorboard + wandb logger""" + + writer: SummaryWriter + is_initialized = False + + def __init__(self) -> None: + pass + + def set_dir(self, tb_log_dir): + if self.is_initialized: + raise ValueError("Do not initialize writer twice") + self.writer = SummaryWriter(tb_log_dir) + self.is_initialized = True + + def log_dic(self, scalar_dic, global_step, walltime=None): + for k, v in scalar_dic.items(): + self.writer.add_scalar(k, v, global_step=global_step, walltime=walltime) + return + + +# global instance +tb_logger = MyTrainingLogger() + + +# -------------- wandb tools -------------- +def init_wandb(enable: bool, **kwargs): + if enable: + run = wandb.init(sync_tensorboard=True, **kwargs) + else: + run = wandb.init(mode="disabled") + return run + + +def log_slurm_job_id(step): + global tb_logger + _jobid 
= os.getenv("SLURM_JOB_ID") + if _jobid is None: + _jobid = -1 + tb_logger.writer.add_scalar("job_id", int(_jobid), global_step=step) + logging.debug(f"Slurm job_id: {_jobid}") + + +def load_wandb_job_id(out_dir): + with open(os.path.join(out_dir, "WANDB_ID"), "r") as f: + wandb_id = f.read() + return wandb_id + + +def save_wandb_job_id(run, out_dir): + with open(os.path.join(out_dir, "WANDB_ID"), "w+") as f: + f.write(run.id) + + +def eval_dic_to_text(val_metrics: dict, dataset_name: str, sample_list_path: str): + eval_text = f"Evaluation metrics:\n\ + on dataset: {dataset_name}\n\ + over samples in: {sample_list_path}\n" + + eval_text += tabulate([val_metrics.keys(), val_metrics.values()]) + return eval_text diff --git a/src/util/loss.py b/src/util/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..ee6dace55ed155a691d527bfe45ef3d87823a65e --- /dev/null +++ b/src/util/loss.py @@ -0,0 +1,124 @@ +# Author: Bingxin Ke +# Last modified: 2024-02-22 + +import torch + + +def get_loss(loss_name, **kwargs): + if "silog_mse" == loss_name: + criterion = SILogMSELoss(**kwargs) + elif "silog_rmse" == loss_name: + criterion = SILogRMSELoss(**kwargs) + elif "mse_loss" == loss_name: + criterion = torch.nn.MSELoss(**kwargs) + elif "l1_loss" == loss_name: + criterion = torch.nn.L1Loss(**kwargs) + elif "l1_loss_with_mask" == loss_name: + criterion = L1LossWithMask(**kwargs) + elif "mean_abs_rel" == loss_name: + criterion = MeanAbsRelLoss() + else: + raise NotImplementedError + + return criterion + + +class L1LossWithMask: + def __init__(self, batch_reduction=False): + self.batch_reduction = batch_reduction + + def __call__(self, depth_pred, depth_gt, valid_mask=None): + diff = depth_pred - depth_gt + if valid_mask is not None: + diff[~valid_mask] = 0 + n = valid_mask.sum((-1, -2)) + else: + n = depth_gt.shape[-2] * depth_gt.shape[-1] + + loss = torch.sum(torch.abs(diff)) / n + if self.batch_reduction: + loss = loss.mean() + return loss + + +class MeanAbsRelLoss: + def __init__(self) -> None: + # super().__init__() + pass + + def __call__(self, pred, gt): + diff = pred - gt + rel_abs = torch.abs(diff / gt) + loss = torch.mean(rel_abs, dim=0) + return loss + + +class SILogMSELoss: + def __init__(self, lamb, log_pred=True, batch_reduction=True): + """Scale Invariant Log MSE Loss + + Args: + lamb (_type_): lambda, lambda=1 -> scale invariant, lambda=0 -> L2 loss + log_pred (bool, optional): True if model prediction is logarithmic depht. 
Will not do log for depth_pred + """ + super(SILogMSELoss, self).__init__() + self.lamb = lamb + self.pred_in_log = log_pred + self.batch_reduction = batch_reduction + + def __call__(self, depth_pred, depth_gt, valid_mask=None): + log_depth_pred = ( + depth_pred if self.pred_in_log else torch.log(torch.clip(depth_pred, 1e-8)) + ) + log_depth_gt = torch.log(depth_gt) + + diff = log_depth_pred - log_depth_gt + if valid_mask is not None: + diff[~valid_mask] = 0 + n = valid_mask.sum((-1, -2)) + else: + n = depth_gt.shape[-2] * depth_gt.shape[-1] + + diff2 = torch.pow(diff, 2) + + first_term = torch.sum(diff2, (-1, -2)) / n + second_term = self.lamb * torch.pow(torch.sum(diff, (-1, -2)), 2) / (n**2) + loss = first_term - second_term + if self.batch_reduction: + loss = loss.mean() + return loss + + +class SILogRMSELoss: + def __init__(self, lamb, alpha, log_pred=True): + """Scale Invariant Log RMSE Loss + + Args: + lamb (_type_): lambda, lambda=1 -> scale invariant, lambda=0 -> L2 loss + alpha: + log_pred (bool, optional): True if model prediction is logarithmic depht. Will not do log for depth_pred + """ + super(SILogRMSELoss, self).__init__() + self.lamb = lamb + self.alpha = alpha + self.pred_in_log = log_pred + + def __call__(self, depth_pred, depth_gt, valid_mask): + log_depth_pred = depth_pred if self.pred_in_log else torch.log(depth_pred) + log_depth_gt = torch.log(depth_gt) + # borrowed from https://github.com/aliyun/NeWCRFs + # diff = log_depth_pred[valid_mask] - log_depth_gt[valid_mask] + # return torch.sqrt((diff ** 2).mean() - self.lamb * (diff.mean() ** 2)) * self.alpha + + diff = log_depth_pred - log_depth_gt + if valid_mask is not None: + diff[~valid_mask] = 0 + n = valid_mask.sum((-1, -2)) + else: + n = depth_gt.shape[-2] * depth_gt.shape[-1] + + diff2 = torch.pow(diff, 2) + first_term = torch.sum(diff2, (-1, -2)) / n + second_term = self.lamb * torch.pow(torch.sum(diff, (-1, -2)), 2) / (n**2) + loss = torch.sqrt(first_term - second_term).mean() * self.alpha + return loss diff --git a/src/util/lr_scheduler.py b/src/util/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..cd2d67f512723c31966498616511202eed4d9806 --- /dev/null +++ b/src/util/lr_scheduler.py @@ -0,0 +1,48 @@ +# Author: Bingxin Ke +# Last modified: 2024-02-22 + +import numpy as np + + +class IterExponential: + def __init__(self, total_iter_length, final_ratio, warmup_steps=0) -> None: + """ + Customized iteration-wise exponential scheduler. 
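
A toy example of the scale-invariant log loss wired through `get_loss`. With `lamb=1`, a purely multiplicative error in the prediction contributes nothing to the loss, which is the scale-invariance the docstring refers to; the tensors and mask threshold are made up:

import torch

from src.util.loss import get_loss

criterion = get_loss("silog_mse", lamb=1.0, log_pred=False)

depth_gt = torch.rand(2, 1, 16, 16) * 5 + 0.5   # [B, C, H, W], strictly positive
valid_mask = depth_gt > 1.0                     # pretend some pixels are invalid
depth_pred = 2.0 * depth_gt                     # off by a global scale only

loss = criterion(depth_pred, depth_gt, valid_mask=valid_mask)
print(loss)   # ~0, since lamb=1 cancels a pure scale error
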
+ Re-calculate for every step, to reduce error accumulation + + Args: + total_iter_length (int): Expected total iteration number + final_ratio (float): Expected LR ratio at n_iter = total_iter_length + """ + self.total_length = total_iter_length + self.effective_length = total_iter_length - warmup_steps + self.final_ratio = final_ratio + self.warmup_steps = warmup_steps + + def __call__(self, n_iter) -> float: + if n_iter < self.warmup_steps: + alpha = 1.0 * n_iter / self.warmup_steps + elif n_iter >= self.total_length: + alpha = self.final_ratio + else: + actual_iter = n_iter - self.warmup_steps + alpha = np.exp( + actual_iter / self.effective_length * np.log(self.final_ratio) + ) + return alpha + + +if "__main__" == __name__: + lr_scheduler = IterExponential( + total_iter_length=50000, final_ratio=0.01, warmup_steps=200 + ) + lr_scheduler = IterExponential( + total_iter_length=50000, final_ratio=0.01, warmup_steps=0 + ) + + x = np.arange(100000) + alphas = [lr_scheduler(i) for i in x] + import matplotlib.pyplot as plt + + plt.plot(alphas) + plt.savefig("lr_scheduler.png") diff --git a/src/util/metric.py b/src/util/metric.py new file mode 100644 index 0000000000000000000000000000000000000000..b318d8a461f4c8745b1e381faba0fe9b7c734be0 --- /dev/null +++ b/src/util/metric.py @@ -0,0 +1,157 @@ +# Author: Bingxin Ke +# Last modified: 2024-02-15 + + +import pandas as pd +import torch + + +# Adapted from: https://github.com/victoresque/pytorch-template/blob/master/utils/util.py +class MetricTracker: + def __init__(self, *keys, writer=None): + self.writer = writer + self._data = pd.DataFrame(index=keys, columns=["total", "counts", "average"]) + self.reset() + + def reset(self): + for col in self._data.columns: + self._data[col].values[:] = 0 + + def update(self, key, value, n=1): + if self.writer is not None: + self.writer.add_scalar(key, value) + self._data.loc[key, "total"] += value * n + self._data.loc[key, "counts"] += n + self._data.loc[key, "average"] = self._data.total[key] / self._data.counts[key] + + def avg(self, key): + return self._data.average[key] + + def result(self): + return dict(self._data.average) + +def abs_relative_difference(output, target, valid_mask=None): + actual_output = output + actual_target = target + abs_relative_diff = torch.abs(actual_output - actual_target) / actual_target + if valid_mask is not None: + abs_relative_diff[~valid_mask] = 0 + n = valid_mask.sum((-1, -2)) + else: + n = output.shape[-1] * output.shape[-2] + abs_relative_diff = torch.sum(abs_relative_diff, (-1, -2)) / n + return abs_relative_diff.mean() + + +def squared_relative_difference(output, target, valid_mask=None): + actual_output = output + actual_target = target + square_relative_diff = ( + torch.pow(torch.abs(actual_output - actual_target), 2) / actual_target + ) + if valid_mask is not None: + square_relative_diff[~valid_mask] = 0 + n = valid_mask.sum((-1, -2)) + else: + n = output.shape[-1] * output.shape[-2] + square_relative_diff = torch.sum(square_relative_diff, (-1, -2)) / n + return square_relative_diff.mean() + + +def rmse_linear(output, target, valid_mask=None): + actual_output = output + actual_target = target + diff = actual_output - actual_target + if valid_mask is not None: + diff[~valid_mask] = 0 + n = valid_mask.sum((-1, -2)) + else: + n = output.shape[-1] * output.shape[-2] + diff2 = torch.pow(diff, 2) + mse = torch.sum(diff2, (-1, -2)) / n + rmse = torch.sqrt(mse) + return rmse.mean() + + +def rmse_log(output, target, valid_mask=None): + diff = torch.log(output) - 
torch.log(target) + if valid_mask is not None: + diff[~valid_mask] = 0 + n = valid_mask.sum((-1, -2)) + else: + n = output.shape[-1] * output.shape[-2] + diff2 = torch.pow(diff, 2) + mse = torch.sum(diff2, (-1, -2)) / n # [B] + rmse = torch.sqrt(mse) + return rmse.mean() + + +def log10(output, target, valid_mask=None): + if valid_mask is not None: + diff = torch.abs( + torch.log10(output[valid_mask]) - torch.log10(target[valid_mask]) + ) + else: + diff = torch.abs(torch.log10(output) - torch.log10(target)) + return diff.mean() + + +# adapt from: https://github.com/imran3180/depth-map-prediction/blob/master/main.py +def threshold_percentage(output, target, threshold_val, valid_mask=None): + d1 = output / target + d2 = target / output + max_d1_d2 = torch.max(d1, d2) + zero = torch.zeros(*output.shape) + one = torch.ones(*output.shape) + bit_mat = torch.where(max_d1_d2.cpu() < threshold_val, one, zero) + if valid_mask is not None: + bit_mat[~valid_mask] = 0 + n = valid_mask.sum((-1, -2)) + else: + n = output.shape[-1] * output.shape[-2] + count_mat = torch.sum(bit_mat, (-1, -2)) + threshold_mat = count_mat / n.cpu() + return threshold_mat.mean() + + +def delta1_acc(pred, gt, valid_mask): + return threshold_percentage(pred, gt, 1.25, valid_mask) + + +def delta2_acc(pred, gt, valid_mask): + return threshold_percentage(pred, gt, 1.25**2, valid_mask) + + +def delta3_acc(pred, gt, valid_mask): + return threshold_percentage(pred, gt, 1.25**3, valid_mask) + + +def i_rmse(output, target, valid_mask=None): + output_inv = 1.0 / output + target_inv = 1.0 / target + diff = output_inv - target_inv + if valid_mask is not None: + diff[~valid_mask] = 0 + n = valid_mask.sum((-1, -2)) + else: + n = output.shape[-1] * output.shape[-2] + diff2 = torch.pow(diff, 2) + mse = torch.sum(diff2, (-1, -2)) / n # [B] + rmse = torch.sqrt(mse) + return rmse.mean() + + +def silog_rmse(depth_pred, depth_gt, valid_mask=None): + diff = torch.log(depth_pred) - torch.log(depth_gt) + if valid_mask is not None: + diff[~valid_mask] = 0 + n = valid_mask.sum((-1, -2)) + else: + n = depth_gt.shape[-2] * depth_gt.shape[-1] + + diff2 = torch.pow(diff, 2) + + first_term = torch.sum(diff2, (-1, -2)) / n + second_term = torch.pow(torch.sum(diff, (-1, -2)), 2) / (n**2) + loss = torch.sqrt(torch.mean(first_term - second_term)) * 100 + return loss diff --git a/src/util/multi_res_noise.py b/src/util/multi_res_noise.py new file mode 100644 index 0000000000000000000000000000000000000000..e4d0ee057ec0a02a527c04fc9cd2f1c53252d51c --- /dev/null +++ b/src/util/multi_res_noise.py @@ -0,0 +1,75 @@ +# Author: Bingxin Ke +# Last modified: 2024-04-18 + +import torch +import math + + +# adapted from: https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2?s=31 +def multi_res_noise_like( + x, strength=0.9, downscale_strategy="original", generator=None, device=None +): + if torch.is_tensor(strength): + strength = strength.reshape((-1, 1, 1, 1)) + b, c, w, h = x.shape + + if device is None: + device = x.device + + up_sampler = torch.nn.Upsample(size=(w, h), mode="bilinear") + noise = torch.randn(x.shape, device=x.device, generator=generator) + + if "original" == downscale_strategy: + for i in range(10): + r = ( + torch.rand(1, generator=generator, device=device) * 2 + 2 + ) # Rather than always going 2x, + w, h = max(1, int(w / (r**i))), max(1, int(h / (r**i))) + noise += ( + up_sampler( + torch.randn(b, c, w, h, generator=generator, device=device).to(x) + ) + * strength**i + ) + if w == 1 
or h == 1: + break # Lowest resolution is 1x1 + elif "every_layer" == downscale_strategy: + for i in range(int(math.log2(min(w, h)))): + w, h = max(1, int(w / 2)), max(1, int(h / 2)) + noise += ( + up_sampler( + torch.randn(b, c, w, h, generator=generator, device=device).to(x) + ) + * strength**i + ) + elif "power_of_two" == downscale_strategy: + for i in range(10): + r = 2 + w, h = max(1, int(w / (r**i))), max(1, int(h / (r**i))) + noise += ( + up_sampler( + torch.randn(b, c, w, h, generator=generator, device=device).to(x) + ) + * strength**i + ) + if w == 1 or h == 1: + break # Lowest resolution is 1x1 + elif "random_step" == downscale_strategy: + for i in range(10): + r = ( + torch.rand(1, generator=generator, device=device) * 2 + 2 + ) # Rather than always going 2x, + w, h = max(1, int(w / (r))), max(1, int(h / (r))) + noise += ( + up_sampler( + torch.randn(b, c, w, h, generator=generator, device=device).to(x) + ) + * strength**i + ) + if w == 1 or h == 1: + break # Lowest resolution is 1x1 + else: + raise ValueError(f"unknown downscale strategy: {downscale_strategy}") + + noise = noise / noise.std() # Scaled back to roughly unit variance + return noise diff --git a/src/util/seeding.py b/src/util/seeding.py new file mode 100644 index 0000000000000000000000000000000000000000..b63a778d28be5d094dff083d411d8edeef2d5604 --- /dev/null +++ b/src/util/seeding.py @@ -0,0 +1,54 @@ +# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# If you find this code useful, we kindly ask you to cite our paper in your work. +# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation +# More information about the method can be found at https://marigoldmonodepth.github.io +# -------------------------------------------------------------------------- + + +import numpy as np +import random +import torch +import logging + + +def seed_all(seed: int = 0): + """ + Set random seeds of all components. 
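
A sketch of drawing multi-resolution noise in place of plain Gaussian noise for a latent-sized tensor; the shape, strength, and seed below are arbitrary:

import torch

from src.util.multi_res_noise import multi_res_noise_like

latent = torch.zeros(4, 4, 96, 72)              # [B, C, H, W] latent-shaped tensor
generator = torch.Generator().manual_seed(0)

noise = multi_res_noise_like(
    latent, strength=0.9, downscale_strategy="original", generator=generator
)
print(noise.shape)                              # same shape as the input
print(noise.std())                              # rescaled to roughly unit variance
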
+ """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def generate_seed_sequence( + initial_seed: int, + length: int, + min_val=-0x8000_0000_0000_0000, + max_val=0xFFFF_FFFF_FFFF_FFFF, +): + if initial_seed is None: + logging.warning("initial_seed is None, reproducibility is not guaranteed") + random.seed(initial_seed) + + seed_sequence = [] + + for _ in range(length): + seed = random.randint(min_val, max_val) + + seed_sequence.append(seed) + + return seed_sequence diff --git a/src/util/slurm_util.py b/src/util/slurm_util.py new file mode 100644 index 0000000000000000000000000000000000000000..a983d86c03f11569add80001d230aced342ceeac --- /dev/null +++ b/src/util/slurm_util.py @@ -0,0 +1,15 @@ +# Author: Bingxin Ke +# Last modified: 2024-02-22 + +import os + + +def is_on_slurm(): + cluster_name = os.getenv("SLURM_CLUSTER_NAME") + is_on_slurm = cluster_name is not None + return is_on_slurm + + +def get_local_scratch_dir(): + local_scratch_dir = os.getenv("TMPDIR") + return local_scratch_dir