# MIT License

# Copyright (c) 2022 Intelligent Systems Lab Org

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# File author: Shariq Farooq Bhat
# This file is partly inspired from BTS (https://github.com/cleinc/bts/blob/master/pytorch/bts_dataloader.py); author: Jin Han Lee
# This file may include modifications from author Zhenyu Li

import itertools
import os
import random

import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.utils.data.distributed
from zoedepth.utils.easydict import EasyDict as edict
from PIL import Image, ImageOps
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from zoedepth.utils.config import change_dataset

from .ddad import get_ddad_loader
from .diml_indoor_test import get_diml_indoor_loader
from .diml_outdoor_test import get_diml_outdoor_loader
from .diode import get_diode_loader
from .hypersim import get_hypersim_loader
from .ibims import get_ibims_loader
from .sun_rgbd_loader import get_sunrgbd_loader
from .vkitti import get_vkitti_loader
from .vkitti2 import get_vkitti2_loader
from .u4k import get_u4k_loader
from .middleburry import get_mid_loader
from .gta import get_gta_loader
from .preprocess import CropParams, get_white_border, get_black_border

import copy
from zoedepth.utils.misc import get_boundaries
from zoedepth.models.base_models.midas import Resize


def _is_pil_image(img):
    return isinstance(img, Image.Image)


def _is_numpy_image(img):
    return isinstance(img, np.ndarray) and (img.ndim in {2, 3})


# def preprocessing_transforms(mode, **kwargs):
#     return transforms.Compose([
#         ToTensor(mode=mode, **kwargs)
#     ])


def preprocessing_transforms(mode, sec_stage=False, **kwargs):
    return transforms.Compose([
        ToTensor(mode=mode, sec_stage=sec_stage, **kwargs)
    ])
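

# Example usage (a hedged sketch; (384, 512) is an illustrative size, not a value fixed
# by this file):
#
#     transform = preprocessing_transforms("train", sec_stage=True, size=(384, 512))
#
# Extra kwargs are forwarded to ToTensor below, so passing `size` enables the
# MiDaS-style Resize inside the transform.
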
class DepthDataLoader(object):
    def __init__(self, config, mode, device='cpu', transform=None, **kwargs):
        """
        Data loader for depth datasets

        Args:
            config (dict): Config dictionary. Refer to utils/config.py
            mode (str): "train", "online_eval" or "test"
            device (str, optional): Device to load the data on. Defaults to 'cpu'.
            transform (torchvision.transforms, optional): Transform to apply to the data. Defaults to None.
        """
        self.config = config

        if config.dataset == 'ibims':
            self.data = get_ibims_loader(config, batch_size=1, num_workers=1)
            return

        if config.dataset == 'sunrgbd':
            self.data = get_sunrgbd_loader(
                data_dir_root=config.sunrgbd_root, batch_size=1, num_workers=1)
            return

        if config.dataset == 'diml_indoor':
            self.data = get_diml_indoor_loader(
                data_dir_root=config.diml_indoor_root, batch_size=1, num_workers=1)
            return

        if config.dataset == 'diml_outdoor':
            self.data = get_diml_outdoor_loader(
                data_dir_root=config.diml_outdoor_root, batch_size=1, num_workers=1)
            return

        if "diode" in config.dataset:
            self.data = get_diode_loader(
                config[config.dataset + "_root"], batch_size=1, num_workers=1)
            return

        if config.dataset == 'hypersim_test':
            self.data = get_hypersim_loader(
                config.hypersim_test_root, batch_size=1, num_workers=1)
            return

        if config.dataset == 'vkitti':
            self.data = get_vkitti_loader(
                config.vkitti_root, batch_size=1, num_workers=1)
            return

        if config.dataset == 'vkitti2':
            self.data = get_vkitti2_loader(
                config.vkitti2_root, batch_size=1, num_workers=1)
            return

        if config.dataset == 'ddad':
            self.data = get_ddad_loader(config.ddad_root, resize_shape=(
                352, 1216), batch_size=1, num_workers=1)
            return

        # under construction
        if config.dataset == 'u4k':
            self.data = get_u4k_loader(config, mode, transform)
            return

        if config.dataset == 'mid':
            self.data = get_mid_loader(config, mode, transform)
            return

        if config.dataset == 'gta':
            self.data = get_gta_loader(config, mode, transform)
            return

        img_size = self.config.get("img_size", None)
        img_size = img_size if self.config.get(
            "do_input_resize", False) else None

        if transform is None:
            # transform = preprocessing_transforms(mode, size=img_size)
            transform = preprocessing_transforms(mode, size=img_size, sec_stage=config.get("sec_stage", False))

        if mode == 'train':
            Dataset = DataLoadPreprocess
            self.training_samples = Dataset(
                config, mode, transform=transform, device=device)

            if config.distributed:
                self.train_sampler = torch.utils.data.distributed.DistributedSampler(
                    self.training_samples)
            else:
                self.train_sampler = None

            self.data = DataLoader(self.training_samples,
                                   batch_size=config.batch_size,
                                   shuffle=(self.train_sampler is None),
                                   num_workers=config.workers,
                                   pin_memory=True,
                                   persistent_workers=True,
                                   # prefetch_factor=2,
                                   sampler=self.train_sampler)

        elif mode == 'online_eval':
            self.testing_samples = DataLoadPreprocess(
                config, mode, transform=transform)

            if config.distributed:  # redundant. here only for readability and to be more explicit
                # Give whole test set to all processes (and report evaluation only on one) regardless
                self.eval_sampler = None
            else:
                self.eval_sampler = None

            self.data = DataLoader(self.testing_samples, 1,
                                   shuffle=False,
                                   num_workers=1,
                                   pin_memory=False,
                                   sampler=self.eval_sampler)

        elif mode == 'test':
            self.testing_samples = DataLoadPreprocess(
                config, mode, transform=transform)
            self.data = DataLoader(self.testing_samples,
                                   1, shuffle=False, num_workers=1)

        else:
            print(
                'mode should be one of \'train, test, online_eval\'. Got {}'.format(mode))
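

# Example usage (a minimal sketch; assumes a config built via zoedepth.utils.config
# that provides the fields referenced above, e.g. `dataset`, `batch_size`, `workers`,
# `distributed`, and the per-dataset root paths):
#
#     train_loader = DepthDataLoader(config, mode="train").data
#     for batch in train_loader:
#         image, depth = batch["image"], batch["depth"]
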
def repetitive_roundrobin(*iterables):
    """
    Cycles through the iterables, but sample-wise:
    first yield the first sample from the first iterable, then the first sample from the second iterable, and so on;
    then the second sample from the first iterable, then the second sample from the second iterable, and so on.
    If one iterable is shorter than the others, it is repeated until all iterables are exhausted.

    repetitive_roundrobin('ABC', 'D', 'EF') --> A D E B D F C D E
    """
    # Repetitive roundrobin
    iterables_ = [iter(it) for it in iterables]
    exhausted = [False] * len(iterables)
    while not all(exhausted):
        for i, it in enumerate(iterables_):
            try:
                yield next(it)
            except StopIteration:
                exhausted[i] = True
                iterables_[i] = itertools.cycle(iterables[i])
                # First elements may get repeated if one iterable is shorter than the others
                yield next(iterables_[i])


class RepetitiveRoundRobinDataLoader(object):
    def __init__(self, *dataloaders):
        self.dataloaders = dataloaders

    def __iter__(self):
        return repetitive_roundrobin(*self.dataloaders)

    def __len__(self):
        # First samples get repeated, that's why the plus one
        return len(self.dataloaders) * (max(len(dl) for dl in self.dataloaders) + 1)
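

# A quick check of __len__ against the docstring example above: for ('ABC', 'D', 'EF')
# the generator yields A D E B D F C D E A D F, i.e. 3 * (3 + 1) = 12 items, because
# the shorter iterables are cycled and the first samples come around once more before
# every iterable has been marked exhausted.
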
class MixedNYUKITTI(object):
    def __init__(self, config, mode, device='cpu', **kwargs):
        config = edict(config)
        config.workers = config.workers // 2
        self.config = config
        nyu_conf = change_dataset(edict(config), 'nyu')
        kitti_conf = change_dataset(edict(config), 'kitti')

        # make nyu default for testing
        self.config = config = nyu_conf
        img_size = self.config.get("img_size", None)
        img_size = img_size if self.config.get(
            "do_input_resize", False) else None
        if mode == 'train':
            nyu_loader = DepthDataLoader(
                nyu_conf, mode, device=device, transform=preprocessing_transforms(mode, size=img_size)).data
            kitti_loader = DepthDataLoader(
                kitti_conf, mode, device=device, transform=preprocessing_transforms(mode, size=img_size)).data
            # It has been changed to repetitive roundrobin
            self.data = RepetitiveRoundRobinDataLoader(
                nyu_loader, kitti_loader)
        else:
            self.data = DepthDataLoader(nyu_conf, mode, device=device).data


def remove_leading_slash(s):
    if s[0] == '/' or s[0] == '\\':
        return s[1:]
    return s


class CachedReader:
    def __init__(self, shared_dict=None):
        if shared_dict:
            self._cache = shared_dict
        else:
            self._cache = {}

    def open(self, fpath):
        im = self._cache.get(fpath, None)
        if im is None:
            im = self._cache[fpath] = Image.open(fpath)
        return im
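

# Note: CachedReader memoizes PIL Image handles keyed by file path. If `shared_dict` is
# provided (e.g. a multiprocessing.Manager().dict() passed in as config.shared_dict, an
# assumption about how callers use it), the cache can be shared across DataLoader worker
# processes; otherwise each process keeps its own plain dict.
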
class ImReader:
    def __init__(self):
        pass

    # @cache
    def open(self, fpath):
        return Image.open(fpath)


class DataLoadPreprocess(Dataset):
    def __init__(self, config, mode, transform=None, is_for_online_eval=False, **kwargs):
        self.config = config
        if mode == 'online_eval':
            with open(config.filenames_file_eval, 'r') as f:
                self.filenames = f.readlines()
        else:
            with open(config.filenames_file, 'r') as f:
                self.filenames = f.readlines()

        self.sec_stage = self.config.get("sec_stage", False)
        # Flag gating the random-crop path in __getitem__. Without an explicit attribute,
        # `self.random_crop` would resolve to the bound method defined below and always be
        # truthy, so a config-driven flag is assumed here.
        self.random_crop = self.config.get("random_crop", False)
        # self.crop_size = [120, 160]  # 1/4 of 480x640 per side
        self.crop_size = [120*2, 160*2]  # 1/2 of 480x640 per side
        self.overlap = self.config.get("overlap", False)
        self.consistency_training = self.config.get("consistency_training", False)
        self.overlap_length_h = self.config.get("overlap_length_h", int(60))
        self.overlap_length_w = self.config.get("overlap_length_w", int(80))
        print("current overlap_length_h and overlap_length_w are {} and {}".format(self.overlap_length_h, self.overlap_length_w))

        self.mode = mode
        self.transform = transform
        self.to_tensor = ToTensor(mode)
        self.is_for_online_eval = is_for_online_eval
        if config.use_shared_dict:
            self.reader = CachedReader(config.shared_dict)
        else:
            self.reader = ImReader()

    def postprocess(self, sample):
        return sample

    def get_crop_bbox(self, img):
        """Randomly get a crop bounding box."""
        margin_h = max(img.shape[0] - self.crop_size[0], 0)
        margin_w = max(img.shape[1] - self.crop_size[1], 0)
        offset_h = np.random.randint(0, margin_h + 1)
        offset_w = np.random.randint(0, margin_w + 1)
        crop_y1, crop_y2 = offset_h, offset_h + self.crop_size[0]
        crop_x1, crop_x2 = offset_w, offset_w + self.crop_size[1]

        return crop_y1, crop_y2, crop_x1, crop_x2

    def crop(self, img, crop_bbox, tmp=False):
        """Crop from ``img``."""
        crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox
        if tmp:
            template = np.zeros((img.shape[0], img.shape[1], 1), dtype=np.float32)
            template[crop_y1:crop_y2, crop_x1:crop_x2, :] = 1.0
            img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...]
            return img, template
        else:
            img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...]
            return img
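
    # Bounding boxes here follow the (crop_y1, crop_y2, crop_x1, crop_x2) convention used
    # by get_crop_bbox. A hedged sketch of how the two methods combine:
    #
    #     bbox = self.get_crop_bbox(image)              # e.g. (17, 257, 42, 362) for a 240x320 crop
    #     patch, crop_area = self.crop(image, bbox, tmp=True)
    #     # `patch` is image[17:257, 42:362, ...]; `crop_area` is a full-size (H, W, 1)
    #     # float mask that is 1.0 inside the cropped region and 0.0 elsewhere.
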
    def __getitem__(self, idx):
        sample_path = self.filenames[idx]
        focal = float(sample_path.split()[2])
        sample = {}

        height = 480
        width = 640

        if self.mode == 'train':
            if self.config.dataset == 'kitti' and self.config.use_right and random.random() > 0.5:
                image_path = os.path.join(
                    self.config.data_path, remove_leading_slash(sample_path.split()[3]))
                depth_path = os.path.join(
                    self.config.gt_path, remove_leading_slash(sample_path.split()[4]))
            else:
                image_path = os.path.join(
                    self.config.data_path, remove_leading_slash(sample_path.split()[0]))
                depth_path = os.path.join(
                    self.config.gt_path, remove_leading_slash(sample_path.split()[1]))

            image = self.reader.open(image_path)
            depth_gt = self.reader.open(depth_path)
            w, h = image.size
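
            # KITTI benchmark ("kb") crop: image and GT are both cropped to the standard
            # 352x1216 evaluation window, centered horizontally and anchored at the bottom.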
            if self.config.do_kb_crop:
                height = image.height
                width = image.width
                top_margin = int(height - 352)
                left_margin = int((width - 1216) / 2)
                depth_gt = depth_gt.crop(
                    (left_margin, top_margin, left_margin + 1216, top_margin + 352))
                image = image.crop(
                    (left_margin, top_margin, left_margin + 1216, top_margin + 352))

            # Avoid blank boundaries due to pixel registration?
            # Train images have white border. Test images have black border.
            if self.config.dataset == 'nyu' and self.config.avoid_boundary:
                # print("Avoiding Blank Boundaries!")
                # We just crop and pad again with reflect padding to original size
                # original_size = image.size
                crop_params = get_white_border(np.array(image, dtype=np.uint8))
                image = image.crop((crop_params.left, crop_params.top, crop_params.right, crop_params.bottom))
                depth_gt = depth_gt.crop((crop_params.left, crop_params.top, crop_params.right, crop_params.bottom))

                # Use reflect padding to fill the blank
                image = np.array(image)
                image = np.pad(image, ((crop_params.top, h - crop_params.bottom), (crop_params.left, w - crop_params.right), (0, 0)), mode='reflect')
                image = Image.fromarray(image)

                depth_gt = np.array(depth_gt)
                depth_gt = np.pad(depth_gt, ((crop_params.top, h - crop_params.bottom), (crop_params.left, w - crop_params.right)), 'constant', constant_values=0)
                depth_gt = Image.fromarray(depth_gt)

            if self.config.do_random_rotate and (self.config.aug):
                # NOTE: YES!
                random_angle = (random.random() - 0.5) * 2 * self.config.degree
                image = self.rotate_image(image, random_angle)
                depth_gt = self.rotate_image(
                    depth_gt, random_angle, flag=Image.NEAREST)

            image = np.asarray(image, dtype=np.float32) / 255.0
            depth_gt = np.asarray(depth_gt, dtype=np.float32)
            depth_gt = np.expand_dims(depth_gt, axis=2)
            disp_gt_copy = depth_gt[:, :, 0].copy()

            if self.config.dataset == 'nyu':
                depth_gt = depth_gt / 1000.0
            else:
                depth_gt = depth_gt / 256.0

            # if self.config.aug and (self.config.random_crop):
            #     image, depth_gt = self.random_crop(
            #         image, depth_gt, self.config.input_height, self.config.input_width)

            image, depth_gt = self.train_preprocess(image, depth_gt)

            img_temp = copy.deepcopy(image)
            depth_gt_temp = copy.deepcopy(depth_gt)

            if self.random_crop:  # use in sec_stage
                if self.consistency_training:
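                    # Consistency training draws a pair of overlapping crops: one random
                    # crop_size crop plus a second crop shifted by exactly (overlap_length_h,
                    # overlap_length_w) pixels, with a random sign per axis, re-drawing the
                    # signs until the shifted box stays inside the image. Further below the two
                    # crops (and their crop-area masks) are concatenated along the channel axis,
                    # together with both bounding boxes and the applied shift.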
                    crop_y1, crop_y2, crop_x1, crop_x2 = self.get_crop_bbox(image)  # ensure the prob of crop is the same
                    while True:
                        # shift_x = random.randint(self.overlap_length//3, self.overlap_length)
                        # shift_y = random.randint(self.overlap_length//3, self.overlap_length)
                        shift_x = self.overlap_length_w
                        shift_y = self.overlap_length_h
                        if random.random() > 0.5:
                            shift_x = shift_x * -1
                        if random.random() > 0.5:
                            shift_y = shift_y * -1
                        crop_y1_shift, crop_y2_shift, crop_x1_shift, crop_x2_shift = crop_y1 + shift_y, crop_y2 + shift_y, crop_x1 + shift_x, crop_x2 + shift_x
                        if crop_y1_shift > 0 and crop_x1_shift > 0 and crop_y2_shift < image.shape[0] and crop_x2_shift < image.shape[1]:
                            break

                    bbox_ori = (crop_y1, crop_y2, crop_x1, crop_x2)
                    bbox_shift = (crop_y1_shift, crop_y2_shift, crop_x1_shift, crop_x2_shift)
                    image_ori, crop_area_ori = self.crop(image, bbox_ori, tmp=True)
                    image_shift, crop_area_shift = self.crop(image, bbox_shift, tmp=True)
                    depth_gt_ori = self.crop(depth_gt, bbox_ori)
                    depth_gt_shift = self.crop(depth_gt, bbox_shift)
                    disp_gt_copy_ori = self.crop(disp_gt_copy, bbox_ori)
                    disp_gt_copy_shift = self.crop(disp_gt_copy, bbox_shift)

                    bboxs_ori = torch.tensor([crop_x1 / width * 160 * 2, crop_y1 / height * 120 * 2, crop_x2 / width * 160 * 2, crop_y2 / height * 120 * 2])
                    bboxs_shift = torch.tensor([crop_x1_shift / width * 160 * 2, crop_y1_shift / height * 120 * 2, crop_x2_shift / width * 160 * 2, crop_y2_shift / height * 120 * 2])
                    bboxs_raw = torch.tensor([crop_x1, crop_y1, crop_x2, crop_y2])
                    bboxs_raw_shift = torch.tensor([crop_x1_shift, crop_y1_shift, crop_x2_shift, crop_y2_shift])

                else:
                    bbox = self.get_crop_bbox(image)
                    image, crop_area = self.crop(image, bbox, tmp=True)
                    depth_gt = self.crop(depth_gt, bbox)
                    disp_gt_copy = self.crop(disp_gt_copy, bbox)

                    crop_y1, crop_y2, crop_x1, crop_x2 = bbox
                    bboxs_res = torch.tensor([crop_x1 / width * 160 * 2, crop_y1 / height * 120 * 2, crop_x2 / width * 160 * 2, crop_y2 / height * 120 * 2])  # coord in 384, 512
                    bboxs_raw = torch.tensor([crop_x1, crop_y1, crop_x2, crop_y2])

            mask = np.logical_and(depth_gt > self.config.min_depth,
                                  depth_gt < self.config.max_depth).squeeze()[None, ...]
            mask_raw = np.logical_and(depth_gt_temp > self.config.min_depth, depth_gt_temp < self.config.max_depth).squeeze()[None, ...]
            sample = {'image': image, 'depth': depth_gt, 'focal': focal, 'mask': mask, 'image_raw': image.copy(), 'mask_raw': mask_raw}

            if self.random_crop:
                if self.consistency_training:
                    image = np.concatenate([image_ori, image_shift], axis=-1)
                    depth_gt = np.concatenate([depth_gt_ori, depth_gt_shift], axis=-1)
                    crop_area = np.concatenate([crop_area_ori, crop_area_shift], axis=-1)
                    bboxs_res = torch.cat([bboxs_ori, bboxs_shift], dim=-1)
                    bboxes_raw_res = torch.cat([bboxs_raw, bboxs_raw_shift], dim=-1)
                    mask = np.logical_and(depth_gt > self.config.min_depth,
                                          depth_gt < self.config.max_depth)

                    # hack the sample dict
                    sample['image'] = image
                    sample['depth'] = depth_gt
                    sample['crop_area'] = crop_area
                    sample['bbox'] = bboxs_res
                    sample['bbox_raw'] = bboxes_raw_res
                    sample['shift'] = torch.tensor([shift_y, shift_x])  # h direction, then w direction
                    sample['mask'] = mask

                else:
                    if bboxs_res is not None:
                        sample['bbox'] = bboxs_res
                        sample['bbox_raw'] = bboxs_raw
                    sample['crop_area'] = crop_area

            if self.config.aug and self.config.random_translate:
                image, depth_gt = self.random_translate(image, depth_gt, self.config.max_translation)

        else:
            if self.mode == 'online_eval':
                data_path = self.config.data_path_eval
            else:
                data_path = self.config.data_path

            image_path = os.path.join(
                data_path, remove_leading_slash(sample_path.split()[0]))
            image = np.asarray(self.reader.open(image_path),
                               dtype=np.float32) / 255.0

            if self.mode == 'online_eval':
                gt_path = self.config.gt_path_eval
                depth_path = os.path.join(
                    gt_path, remove_leading_slash(sample_path.split()[1]))
                has_valid_depth = False
                try:
                    depth_gt = self.reader.open(depth_path)
                    has_valid_depth = True
                except IOError:
                    depth_gt = False
                    # print('Missing gt for {}'.format(image_path))

                if has_valid_depth:
                    depth_gt = np.asarray(depth_gt, dtype=np.float32)
                    depth_gt = np.expand_dims(depth_gt, axis=2)
                    disp_gt_copy = depth_gt[:, :, 0].copy()

                    if self.config.dataset == 'nyu':
                        depth_gt = depth_gt / 1000.0
                    else:
                        depth_gt = depth_gt / 256.0

                    mask = np.logical_and(
                        depth_gt >= self.config.min_depth, depth_gt <= self.config.max_depth).squeeze()[None, ...]
                else:
                    mask = False

            if self.config.do_kb_crop:
                height = image.shape[0]
                width = image.shape[1]
                top_margin = int(height - 352)
                left_margin = int((width - 1216) / 2)
                image = image[top_margin:top_margin + 352,
                              left_margin:left_margin + 1216, :]
                if self.mode == 'online_eval' and has_valid_depth:
                    depth_gt = depth_gt[top_margin:top_margin +
                                        352, left_margin:left_margin + 1216, :]

            # NOTE: extra crop preparation inserted for sec_stage
            bboxs_roi, bboxs_raw, crop_areas = None, None, None  # defaults when sec_stage is off
            if self.sec_stage:
                img_temp = copy.deepcopy(image)
                depth_gt_temp = copy.deepcopy(depth_gt)

                x_start, y_start = [0, 240], [0, 320]
                # x_start, y_start = [0 + 3 * self.overlap / 2, 120 + self.overlap / 2, 240 - self.overlap / 2, 360 - 3 * self.overlap / 2], \
                #     [0 + 3 * self.overlap / 2, 160 + self.overlap / 2, 320 - self.overlap / 2, 480 - 3 * self.overlap / 2]
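
                # Note that x_start indexes rows and y_start indexes columns here: the frame
                # (assumed 480x640) is tiled into four non-overlapping 240x320 patches, and the
                # corresponding ROI boxes and crop-area masks are collected for the second stage.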
                img_crops = []
                bboxs_roi = []
                crop_areas = []
                bboxs_raw_list = []
                for x in x_start:
                    for y in y_start:
                        bbox = (int(x), int(x + 240), int(y), int(y + 320))
                        img_crop, crop_area = self.crop(image, bbox, tmp=True)
                        img_crops.append(img_crop)
                        crop_areas.append(crop_area)

                        crop_y1, crop_y2, crop_x1, crop_x2 = bbox
                        bbox_roi = torch.tensor([crop_x1 / width * 160 * 2, crop_y1 / height * 120 * 2, crop_x2 / width * 160 * 2, crop_y2 / height * 120 * 2])
                        bboxs_roi.append(bbox_roi)
                        bboxs_raw = torch.tensor([crop_x1, crop_y1, crop_x2, crop_y2])
                        bboxs_raw_list.append(bboxs_raw)

                image = img_crops
                bboxs_roi = torch.stack(bboxs_roi, dim=0)
                bboxs_raw = torch.stack(bboxs_raw_list, dim=0)

            disp_gt_edges = get_boundaries(disp_gt_copy, th=1, dilation=0)

            if self.mode == 'online_eval':
                sample = {'image': image, 'depth': depth_gt, 'focal': focal, 'has_valid_depth': has_valid_depth,
                          'image_path': sample_path.split()[0], 'depth_path': sample_path.split()[1],
                          'mask': mask, 'image_raw': image.copy(), 'disp_gt_edges': disp_gt_edges}
                if bboxs_roi is not None:
                    sample['bbox'] = bboxs_roi
                    sample['bbox_raw'] = bboxs_raw
                if crop_areas is not None:
                    sample['crop_area'] = crop_areas
            else:
                sample = {'image': image, 'focal': focal}

        if (self.mode == 'train') or ('has_valid_depth' in sample and sample['has_valid_depth']):
            mask = np.logical_and(depth_gt > self.config.min_depth,
                                  depth_gt < self.config.max_depth).squeeze()[None, ...]
            sample['mask'] = mask

        if self.transform:
            # sample = self.transform(sample)
            sample['img_temp'] = img_temp
            sample['depth_gt_temp'] = depth_gt_temp
            sample = self.transform(sample)

        sample = self.postprocess(sample)
        sample['dataset'] = self.config.dataset
        sample = {**sample, 'image_path': sample_path.split()[0], 'depth_path': sample_path.split()[1]}

        return sample

    def rotate_image(self, image, angle, flag=Image.BILINEAR):
        result = image.rotate(angle, resample=flag)
        return result

    def random_crop(self, img, depth, height, width):
        assert img.shape[0] >= height
        assert img.shape[1] >= width
        assert img.shape[0] == depth.shape[0]
        assert img.shape[1] == depth.shape[1]
        x = random.randint(0, img.shape[1] - width)
        y = random.randint(0, img.shape[0] - height)
        img = img[y:y + height, x:x + width, :]
        depth = depth[y:y + height, x:x + width, :]

        return img, depth

    def random_translate(self, img, depth, max_t=20):
        assert img.shape[0] == depth.shape[0]
        assert img.shape[1] == depth.shape[1]
        p = self.config.translate_prob
        do_translate = random.random()
        if do_translate > p:
            return img, depth
        x = random.randint(-max_t, max_t)
        y = random.randint(-max_t, max_t)
        M = np.float32([[1, 0, x], [0, 1, y]])
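        # M is a 2x3 affine matrix; with cv2.warpAffine it translates the content by
        # x pixels horizontally and y pixels vertically, filling uncovered pixels with 0.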
        # print(img.shape, depth.shape)
        img = cv2.warpAffine(img, M, (img.shape[1], img.shape[0]))
        depth = cv2.warpAffine(depth, M, (depth.shape[1], depth.shape[0]))
        depth = depth.squeeze()[..., None]  # add channel dim back. Affine warp removes it
        # print("after", img.shape, depth.shape)
        return img, depth

    def train_preprocess(self, image, depth_gt):
        if self.config.aug:
            # Random flipping
            do_flip = random.random()
            if do_flip > 0.5:
                image = (image[:, ::-1, :]).copy()
                depth_gt = (depth_gt[:, ::-1, :]).copy()

            # Random gamma, brightness, color augmentation
            do_augment = random.random()
            if do_augment > 0.5:
                image = self.augment_image(image)

        return image, depth_gt

    def augment_image(self, image):
        # gamma augmentation
        gamma = random.uniform(0.9, 1.1)
        image_aug = image ** gamma

        # brightness augmentation
        if self.config.dataset == 'nyu':
            brightness = random.uniform(0.75, 1.25)
        else:
            brightness = random.uniform(0.9, 1.1)
        image_aug = image_aug * brightness

        # color augmentation
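        # Per-channel color jitter: `colors` holds one multiplicative factor per RGB
        # channel; broadcasting it through a constant "white" image scales each channel
        # independently before the result is clipped back to the [0, 1] range.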
        colors = np.random.uniform(0.9, 1.1, size=3)
        white = np.ones((image.shape[0], image.shape[1]))
        color_image = np.stack([white * colors[i] for i in range(3)], axis=2)
        image_aug *= color_image
        image_aug = np.clip(image_aug, 0, 1)

        return image_aug

    def __len__(self):
        return len(self.filenames)


class ToTensor(object):
    def __init__(self, mode, do_normalize=False, size=None, sec_stage=False):
        self.mode = mode
        # don't do normalization as default
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) if do_normalize else nn.Identity()
        self.size = size
        if size is not None:
            # self.resize = transforms.Resize(size=size)
            net_h, net_w = size
            self.resize = Resize(net_w, net_h, keep_aspect_ratio=False, ensure_multiple_of=32, resize_method="minimal")
        else:
            self.resize = nn.Identity()
        self.sec_stage = sec_stage

    def __call__(self, sample):
        image, focal = sample['image'], sample['focal']
        crop_areas = sample.get('crop_area', None)

        if isinstance(image, list):
            # there must be crop_areas
            # only infer on eval sec_stage
            imgs_process = []
            crp_process = []
            for img, crp in zip(image, crop_areas):
                img = self.to_tensor(img)
                img = self.normalize(img)
                img = img.unsqueeze(dim=0)
                img = self.resize(img)
                img = img.squeeze(dim=0)
                imgs_process.append(img)

                crp = self.to_tensor(crp)
                crp = crp.unsqueeze(dim=0)
                crp = self.resize(crp)
                crp = crp.squeeze(dim=0)
                crp_process.append(crp)

            image = torch.cat(imgs_process, dim=0)
            crop_areas = torch.cat(crp_process, dim=0)

            img_temp = sample['img_temp']
            img_temp = self.to_tensor(img_temp)
            img_temp = self.normalize(img_temp)
            img_temp = img_temp.unsqueeze(dim=0)
            img_temp = self.resize(img_temp)  # NOTE: hack
            img_temp = img_temp.squeeze(dim=0)
            image_raw = copy.deepcopy(img_temp)

        else:
            image = self.to_tensor(image)
            image = self.normalize(image)

            if crop_areas is not None:
                crop_areas = self.to_tensor(crop_areas)
                crop_areas = crop_areas.unsqueeze(dim=0)
                crop_areas = self.resize(crop_areas)
                crop_areas = crop_areas.squeeze(dim=0)

            if self.sec_stage:
                img_temp = sample['img_temp']
                img_temp = self.to_tensor(img_temp)
                img_temp = self.normalize(img_temp)

                img_temp = img_temp.unsqueeze(dim=0)
                img_temp = self.resize(img_temp)
                image_raw = img_temp.squeeze(dim=0)

                image = image.unsqueeze(dim=0)
                image = self.resize(image)
                image = image.squeeze(dim=0)
            else:
                # in the first stage, this hr info is reserved
                image_raw = copy.deepcopy(image)
                image = image.unsqueeze(dim=0)
                image = self.resize(image)
                image = image.squeeze(dim=0)

        if self.mode == 'test':
            return_dict = {'image': image, 'focal': focal}
            if crop_areas is not None:
                return_dict['crop_area'] = crop_areas
            return return_dict

        depth = sample['depth']
        depth = self.to_tensor(depth)
        depth_gt_temp = sample['depth_gt_temp']
        depth_gt_raw = self.to_tensor(depth_gt_temp)

        if self.mode == 'train':
            return_dict = {**sample, 'image': image, 'depth': depth, 'focal': focal, 'image_raw': image_raw, 'depth_raw': depth_gt_raw}
            if crop_areas is not None:
                return_dict['crop_area'] = crop_areas
            return return_dict
        else:
            has_valid_depth = sample['has_valid_depth']
            # image = self.resize(image)
            return_dict = {**sample, 'image': image, 'depth': depth, 'focal': focal, 'image_raw': image_raw,
                           'has_valid_depth': has_valid_depth, 'image_path': sample['image_path'], 'depth_path': sample['depth_path'],
                           'depth_raw': depth_gt_raw}
            if crop_areas is not None:
                return_dict['crop_area'] = crop_areas
            return return_dict

    def to_tensor(self, pic):
        if isinstance(pic, np.ndarray):
            img = torch.from_numpy(pic.transpose((2, 0, 1)))  # img here
            return img