import sys
sys.path.insert(1, '.')
import numpy as np
from omegaconf import DictConfig
import torch
from PIL import Image
import torchvision
import cv2
import matplotlib.pyplot as plt
from ldm.util import instantiate_from_config
import os
import io
import pickle
import webdataset as wds
import imageio
import time
from torch import distributed as dist
from itertools import chain

class ObjaverseDataDecoder:
    def __init__(self,
                 target_name="albedo",
                 image_transforms=[],
                 default_trans=torch.zeros(3),
                 postprocess=None,
                 return_paths=False,
                 mask_name="alpha",
                 test=False,
                 condition_name=None,
                 bg_color="white",
                 target_name_pool=None,
                 **kwargs
                 ) -> None:
        """Create a dataset from Blender rendering results.

        If you pass in a root directory, it will be searched for images
        ending in ext (ext can be a list).
        """
        # testing behaves differently
        self.test = test
        self.target_name = target_name
        self.mask_name = mask_name
        self.default_trans = default_trans
        self.return_paths = return_paths
        if isinstance(postprocess, DictConfig):
            postprocess = instantiate_from_config(postprocess)
        self.postprocess = postprocess
        # extra condition
        self.condition_name = condition_name
        self.target_name_pool = target_name_pool if target_name_pool is not None else [target_name]
        self.counter = 0
        self.tform = image_transforms["totensor"]
        self.img_size = image_transforms["size"]
        self.tsize = torchvision.transforms.Compose([torchvision.transforms.Resize(self.img_size)])
        if bg_color == "white":
            self.bg_color = [1., 1., 1., 1.]
        elif bg_color == "noise":
            self.bg_color = "noise"
        else:
            raise NotImplementedError
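
    # Minimal usage sketch (values here are hypothetical; the dict layout
    # mirrors the __main__ block at the bottom of this file):
    #   tforms = {"size": 256,
    #             "totensor": torchvision.transforms.Compose([torchvision.transforms.ToTensor()])}
    #   decoder = ObjaverseDataDecoder(image_transforms=tforms, target_name="albedo")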

    def path_parsing(self, filename, cond_name=None):
        # cached path loads albedo
        if 'albedo' in filename:
            filename = filename.replace('albedo', self.target_name)
        if self.target_name == "gloss_shaded":
            filename = filename.replace('gloss_direct', self.target_name).replace("exr", "jpg")
            filename_targets = [filename.replace(self.target_name, "gloss_direct").replace("jpg", "exr"),
                                filename.replace(self.target_name, "gloss_color")]
        elif self.target_name == "diffuse_shaded":
            filename = filename.replace('diffuse_direct', self.target_name).replace("exr", "jpg")
            filename_targets = [filename.replace(self.target_name, "diffuse_direct").replace("jpg", "exr"),
                                filename.replace(self.target_name, "albedo")]
        else:
            filename_targets = None
        normal_condition_filename = None
        if self.test and "images_train" in filename:
            # Currently "images_train" only exists in the test set; we spell this out for clarity.
            condition_filename = filename
            mask_filename = filename.replace('images_train', 'masks')
            if self.condition_name == "normal":
                raise NotImplementedError("Testing with normal conditioning on custom data is not supported")
        else:
            cond_name_prefix = filename.split(".", 1)[0] + "." if cond_name is None else cond_name
            condition_filename = cond_name_prefix + filename.rsplit('.', 1)[1]
            mask_filename = filename.replace(self.target_name, self.mask_name)
            if self.condition_name == "normal":
                normal_condition_filename = filename.replace(self.target_name, "normal")
        return filename, condition_filename, mask_filename, normal_condition_filename, filename_targets

    def read_images(self, filename, condition_filename, normal_condition_filename):
        # image reading
        if self.target_name in ["gloss_shaded", "diffuse_shaded"]:
            target_im_0 = np.array(self.normalized_read(filename[0]))
            target_im_1 = np.array(self.normalized_read(filename[1]))
            target_im = np.clip(target_im_0 * target_im_1, 0, 1)
        else:
            target_im = np.array(self.normalized_read(filename))
        cond_im = np.array(self.normalized_read(condition_filename))
        if self.condition_name == "normal":
            normal_img = np.array(self.normalized_read(normal_condition_filename))
        else:
            normal_img = None
        return target_im, cond_im, normal_img

    def image_post_processing(self, img_mask, target_im, cond_im, normal_img):
        # make sure the mask has 3 dimensions
        if len(img_mask.shape) == 2:
            img_mask = img_mask[:, :, np.newaxis]
        else:
            img_mask = img_mask[:, :, :3]
        # transform into the desired format
        target_im, crop_idx = self.load_im(target_im, img_mask, self.bg_color, crop_idx=True)
        target_im = np.uint8(self.tsize(target_im))
        cond_im = np.uint8(self.tsize(self.load_im(cond_im, img_mask, self.bg_color)))
        if self.condition_name == "normal":
            normal_img = np.uint8(self.tsize(self.load_im(normal_img, img_mask, self.bg_color)))
        else:
            normal_img = None
        return target_im, cond_im, normal_img, crop_idx

    # def cartesian_to_spherical(self, xyz):
    #     ptsnew = np.hstack((xyz, np.zeros(xyz.shape)))
    #     xy = xyz[:, 0]**2 + xyz[:, 1]**2
    #     z = np.sqrt(xy + xyz[:, 2]**2)
    #     theta = np.arctan2(np.sqrt(xy), xyz[:, 2])  # elevation angle, measured from the Z-axis down
    #     # ptsnew[:, 4] = np.arctan2(xyz[:, 2], np.sqrt(xy))  # elevation angle, measured from the XY-plane up
    #     azimuth = np.arctan2(xyz[:, 1], xyz[:, 0])
    #     return np.array([theta, azimuth, z])

    def load_im(self, img, img_mask, color, crop_idx=False):
        '''
        Replace background pixels with the given (or a random) color in a rendering.
        '''
        # Our renderings do not have a valid alpha channel; we use a separate
        # mask instead (which also lacks a valid alpha channel).
        if img.shape[-1] == 3:
            img = np.concatenate([img, np.ones_like(img[..., :1])], axis=-1)
        # align the mask shape with the image size
        if (img.shape[0] != img_mask.shape[0]) or (img.shape[1] != img_mask.shape[1]):
            img_mask = cv2.resize(img_mask,
                                  (img.shape[1], img.shape[0]),
                                  interpolation=cv2.INTER_NEAREST)[:, :, np.newaxis]
        if isinstance(color, str):
            random_img = np.random.rand(*(img.shape))
            img[img_mask[:, :, -1] <= 0.5] = random_img[img_mask[:, :, -1] <= 0.5]
        else:
            img[img_mask[:, :, -1] <= 0.5] = color
        if self.test:
            # crop to the valid mask region
            img, crop_uv = self.center_crop(img[:, :, :3], img_mask)
        else:
            crop_uv = None
        # center crop
        if img.shape[0] > img.shape[1]:
            margin = int((img.shape[0] - img.shape[1]) // 2)
            img = img[margin:margin + img.shape[1]]
        elif img.shape[1] > img.shape[0]:
            margin = int((img.shape[1] - img.shape[0]) // 2)
            img = img[:, margin:margin + img.shape[0]]
        img = Image.fromarray(np.uint8(img[:, :, :3] * 255.))
        if crop_idx:
            return img, crop_uv
        return img
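
    # Toy illustration of the background rule above (values hypothetical):
    #   img = np.zeros((4, 4, 4))     # RGBA image, all pixels black
    #   mask = np.zeros((4, 4, 1))    # everything marked as background
    #   img[mask[:, :, -1] <= 0.5] = [1., 1., 1., 1.]   # background turns white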

    def center_crop(self, img, mask, mask_ratio=.8):
        mask_uvs = np.vstack(np.nonzero(mask[:, :, -1] > 0.5))
        min_uv, max_uv = np.min(mask_uvs, axis=-1), np.max(mask_uvs, axis=-1)
        img = img + (mask[..., -1:] <= 0.5)
        half_size = int(max(max_uv - min_uv) // 2)
        crop_length = (max_uv - min_uv) // 2
        center_uv = min_uv + crop_length
        expand_half_size = int(half_size / mask_ratio)
        size = expand_half_size * 2 + 1
        img_new = np.ones((size, size, 3))
        img_new[expand_half_size - crop_length[0]:expand_half_size + crop_length[0] + 1,
                expand_half_size - crop_length[1]:expand_half_size + crop_length[1] + 1] = \
            img[center_uv[0] - crop_length[0]:center_uv[0] + crop_length[0] + 1,
                center_uv[1] - crop_length[1]:center_uv[1] + crop_length[1] + 1]
        crop_uv = np.array([expand_half_size, crop_length[0], crop_length[1], center_uv[0], center_uv[1], size], dtype=int)
        return img_new, crop_uv

    def transform_normal(self, normal_input, cam):
        # load camera
        img_mask = torch.linalg.norm(normal_input, dim=-1) > 1.5
        extrinsic, K = cam
        extrinsic = np.concatenate([extrinsic, np.zeros(4).reshape(1, 4)], axis=0)
        extrinsic[3, 3] = 1
        pose = np.linalg.inv(extrinsic)
        # swap the Y and Z rows (negating the new Y) to change coordinate conventions
        temp = pose[1] + 0.0
        pose[1] = -pose[2]
        pose[2] = temp
        extrinsic = torch.from_numpy(np.linalg.inv(pose)).float()
        # rotate world-space normals into the camera frame
        normal_img = extrinsic[None, :3, :3] @ normal_input[..., :3].reshape(-1, 3, 1)
        normal_img = normal_img.reshape(normal_input.shape[0], normal_input.shape[1], 3)
        normal_img[img_mask] = 1.0
        return normal_img

    def parse_item(self, target_im, cond_img, normal_img, filename, target_ids, **args):
        data = {}
        # we need to transform normals into the camera frame
        if self.target_name == "normal":
            target_im = self.transform_normal(target_im, self.get_camera(filename, **args))
        # normal conditioning
        if self.condition_name == "normal":
            normal_img = self.transform_normal(normal_img, self.get_camera(filename, **args))
        data["image_target"] = target_im
        data["image_cond"] = cond_img
        if self.condition_name == "normal":
            data["img_normal"] = normal_img
        if self.test or self.return_paths:
            data["path"] = str(filename)
        data["label"] = torch.zeros(1).reshape(1, 1, 1) + target_ids
        if self.postprocess is not None:
            data = self.postprocess(data)
        return data

    def normalized_read(self, imgpath):
        img = np.array(imageio.imread(imgpath))
        if img.dtype == np.uint8:
            # 8-bit LDR image: normalize to [0, 1]
            img = img / 255.0
        else:
            # assume a linear HDR image (e.g. EXR): apply gamma correction
            img = img ** (1 / 2.2)
        return img

    def process_im(self, im):
        im = Image.fromarray(im)
        im = im.convert("RGB")
        return self.tform(im)


class ObjaverseDecoerWDS(ObjaverseDataDecoder):
    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

    def dict2tuple(self, data):
        returns = (data["image_target"], data["image_cond"], data["label"],)
        if self.condition_name == "normal":
            returns += (data["img_normal"],)
        if self.test or self.return_paths:
            returns += (data["path"],)
        return returns

    def tuple2dict(self, data):
        returns = {}
        returns["image_target"] = data[0]
        returns["image_cond"] = data[1]
        returns["label"] = data[2]
        if self.condition_name == "normal":
            returns["img_normal"] = data[3]
        if self.test or self.return_paths:
            returns["path"] = data[-1]
        return returns

    def data_filter(self, albedo, spec, diffuse_shad, spec_shad):
        # NOTE: this method appears to be an unused copy-paste remnant of
        # tuple2dict (its original body referenced an undefined `data`
        # variable and ignored its arguments), so it is stubbed out here.
        raise NotImplementedError("data_filter is unused and was never functional")

    def get_camera(self, input_filename, sample):
        camera_file = input_filename.replace(f'{self.target_name}0001',
                                             'camera').rsplit(".", 1)[0] + ".pkl"
        camera_file_byte = io.BytesIO(sample[camera_file])
        cam = pickle.load(camera_file_byte)
        return cam

    def process_sample(self, sample):
        results = []
        for target_ids, target_name in enumerate(self.target_name_pool):
            _result = self.process_sample_single(sample, target_ids, target_name)
            results.append(self.dict2tuple(_result))
        results = wds.filters.default_collation_fn(results)
        return results

    def batch_reordering(self, sample):
        # Each collated sample stacks the target_name_pool entries along dim 1;
        # split them apart and concatenate along the batch dim instead.
        batch_splits = []
        for data_idx, _ in enumerate(sample):
            batch_splits.append(
                torch.cat(
                    torch.chunk(sample[data_idx], dim=1,
                                chunks=len(self.target_name_pool)),
                    dim=0)[:, 0]
            )
        return self.tuple2dict(batch_splits)

    def process_sample_single(self, sample, target_ids, target_name):
        # get the target image filename
        self.target_name = target_name
        target_file_name = self.target_name
        if self.target_name == "gloss_shaded":
            target_file_name = "gloss_direct"
        elif self.target_name == "diffuse_shaded":
            target_file_name = "diffuse_direct"
        target_key = None
        for k in list(sample.keys()):
            if target_file_name in k:
                target_key = k
                break
        if target_key is None:
            raise KeyError(f"no key containing '{target_file_name}' found in sample")
        filename, condition_filename, \
            mask_filename, normal_condition_filename, filename_targets = self.path_parsing(target_key, "")
        # get file streams
        if filename_targets is None:
            filename_byte = io.BytesIO(sample[filename])
        else:
            filename_byte = [io.BytesIO(sample[filename_target]) for filename_target in filename_targets]
        condition_filename_byte = io.BytesIO(sample[condition_filename])
        normal_condition_filename_byte = io.BytesIO(sample[normal_condition_filename]) \
            if self.condition_name == "normal" else None
        mask_filename_byte = io.BytesIO(sample[mask_filename])
        # image reading
        target_im, cond_im, normal_img = self.read_images(filename_byte,
                                                          condition_filename_byte, normal_condition_filename_byte)
        # mask reading
        img_mask = np.array(self.normalized_read(mask_filename_byte))
        # post processing
        target_im, cond_im, normal_img, _ = self.image_post_processing(img_mask, target_im, cond_im, normal_img)
        # transform
        target_im = self.process_im(target_im)
        cond_im = self.process_im(cond_im)
        normal_img = self.process_im(normal_img) \
            if self.condition_name == "normal" \
            else None
        data = self.parse_item(target_im, cond_im, normal_img, filename, target_ids, sample=sample)
        # override for file path
        if self.test or self.return_paths:
            data["path"] = sample["__key__"]
        result = dict(__key__=sample["__key__"])
        result.update(data)
        return result


if __name__ == "__main__":
    from torchvision import transforms
    from einops import rearrange
    torch.distributed.init_process_group(backend="nccl")
    image_transforms = [transforms.ToTensor(),
                        transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))]
    image_transforms = torchvision.transforms.Compose(image_transforms)
    image_transforms = {
        "size": 256,
        "totensor": image_transforms
    }
    data_list_dir = "/home/chenxi/code/material-diffusion/data/big_data_lists"
    tar_name_list = sorted(os.listdir(data_list_dir))[1:4]
    tar_list = [_name.rsplit("_num")[0] + ".tar" for _name in tar_name_list]
    tar_dir = "/home/chenxi/code/material-diffusion/data/big_data_transed"
    tars = [os.path.join(tar_dir, _name) for _name in tar_list]
    dataset_size = 0
    imgperobj = 10
    print("list dirs...")
    for _name in tar_name_list:
        num_obj = int(_name.rsplit("_num_")[1].rsplit(".")[0])
        print(num_obj, " : ", _name)
        dataset_size += num_obj * imgperobj
    decoder = ObjaverseDecoerWDS(image_transforms=image_transforms,
                                 return_paths=True)
    batch_size = 8
    print('============= length of training dataset %d =============' % (dataset_size // batch_size // 2))
    dataset = (wds.WebDataset(tars,
                              repeat=0,
                              nodesplitter=wds.shardlists.split_by_node)
               .shuffle(100)
               .map(decoder.process_sample)
               .map(decoder.dict2tuple)
               .batched(batch_size, partial=False)
               .map(decoder.tuple2dict)
               .with_epoch(dataset_size // batch_size // 2)
               .with_length(dataset_size // batch_size)
               )
    from torch.utils.data import DataLoader
    # loader = DataLoader(dataset, batch_size=None, num_workers=8, shuffle=False)
    loader = (wds.WebLoader(dataset, batch_size=None, num_workers=2, shuffle=False)
              .map(decoder.dict2tuple)
              .unbatched()
              # .shuffle(100)
              .batched(batch_size)
              .map(decoder.tuple2dict)
              )
print("# loader length", len(dataset)) | |
for epoch in range(2): | |
ind = -1 | |
for sample in loader: | |
assert "image_target" in sample | |
assert "image_cond" in sample | |
assert "path" in sample | |
ind += 1 | |
if ind != 0: | |
continue | |
# replace to this for file path | |
# worker_info = torch.utils.data.get_worker_info() | |
# if worker_info is not None: | |
# worker = worker_info.id | |
# num_workers = worker_info.num_workers | |
# data["path"] = sample["__url__"]+"--"+sample["__key__"] +f".{worker}/{num_workers}" | |
# print(f"{ind}: shape {sample['image_target'].shape} {sample['path'][0].rsplit('/', 1)[-2]}") | |
print("##############") | |
for i in range(len(sample['path'])): | |
print(f"epoch {epoch}, it {ind}: shape {sample['image_target'].shape} {sample['path'][i].rsplit('--', 1)[0].rsplit('/', 2)[-1]} {sample['path'][i].rsplit('--', 1)[1].rsplit('/', 3)[-3]} {sample['path'][i].rsplit('--', 1)[1].rsplit('/',4)[-4]} {sample['path'][i].rsplit('.', 1)[-1]} rank: {dist.get_rank()}") | |
print("##############") | |
print(sample["path"]) | |
print(sample["path"]) | |
print(f"NUmber of samples: {ind} {dataset_size} {len(dataset)} rank: {dist.get_rank()}") | |
    # 1. Remember that samples are batched inside each worker; the outer data loader only sees one sample.
    # 2. All batch, epoch, and length settings are only visible within each worker.
    # 3. Unbatching, shuffling, and then re-batching in the loader yields between-worker shuffling.
    #    It also allows separate control of loader batching and worker batching, to optimize
    #    the CPU cost of worker-to-loader data transfer.
    #    https://github.com/webdataset/webdataset/issues/141#issuecomment-1043190147
    # 4. It seems that the data simply repeats forever to satisfy with_epoch.
    # 5. The torch DataLoader requires the dataset to have a len() method, which it uses to schedule sample indices.
    # 6. The DDP sampler will return only its own length.
    # 7. WebLoader does not need a length; it simply raises the end of iteration when the data runs out.
    # 8. How does the torch loader deal with datasets that are smaller than their claimed length?
    # 9. set_epoch makes sampling start from the beginning when a new epoch starts
    #    (observed by disabling shuffling and repeating one batch),
    #    and each epoch gets a different sampling seed.
    # 10. DataLoader with an IterableDataset expects the sampler option to be unspecified; the DDP sampler is not usable.
    # In summary:
    #    For DDP multi-worker training, the worker splitter and node splitter make sure the tars are split across workers.
    #    We have to manually adjust with_epoch with respect to num_workers, num_nodes, and batch_size.
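
# A minimal sketch of the summary note above (the function name and arguments
# are hypothetical, not part of this module): one way to size with_epoch so
# that every worker emits the same number of batches per "epoch".
def batches_per_worker(dataset_size, batch_size, world_size, num_workers):
    # with_epoch counts batches *inside* a single worker process, and the
    # shards are split across world_size * num_workers workers, so divide
    # the global batch count by the total number of workers.
    global_batches = dataset_size // batch_size
    return max(1, global_batches // (world_size * num_workers))

# e.g. dataset.with_epoch(batches_per_worker(dataset_size, 8, dist.get_world_size(), 2))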


def nodesplitter(src, group=None):
    if torch.distributed.is_initialized():
        if group is None:
            group = torch.distributed.group.WORLD
        rank = torch.distributed.get_rank(group=group)
        size = torch.distributed.get_world_size(group=group)
        print(f"nodesplitter: rank={rank} size={size}")
        count = 0
        for i, item in enumerate(src):
            if i % size == rank:
                yield item
                count += 1
        print(f"nodesplitter: rank={rank} size={size} count={count} DONE")
    else:
        yield from src
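
# Usage sketch: this function matches the nodesplitter hook that wds.WebDataset
# accepts (used above with wds.shardlists.split_by_node), so it can be dropped in:
#   dataset = wds.WebDataset(tars, nodesplitter=nodesplitter)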