# datasets/encode_extern_dataset.py
# --------------------------------------------------------
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import argparse
import json
import os
import time
import traceback
from typing import Optional
import numpy as np
from tqdm import tqdm
from datasets.encode_openx_dataset import (
    MIN_VAL_EXAMPLES, MAX_VAL_EXAMPLES, VAL_RATIO,
    get_shard_inds, process_dataset_step, DATA_FREQ_TABLE,
)
from datasets.extern.ego4d import ego4d_dataset_size, ego4d_dataset_generator
from datasets.extern.egoexo4d import egoexo4d_dataset_size, egoexo4d_dataset_generator
from datasets.extern.robomimic import robomimic_dataset_generator, robomimic_dataset_size
from . import utils
SCRIPT_DESCRIPTION="""
Similar to encode_openx_dataset.py except for non-OpenX datasets.
Again, each split can be partitioned into multiple shards,
which is useful for parallelized encoding across GPUs.
Example usage:
CUDA_VISIBLE_DEVICES=0 python -m datasets.encode_extern_dataset --dataset_name egoexo4d --data_split train --num_shards 1000 --curr_shard_rank 400
Untested usage (SVD tokenizer):
CUDA_VISIBLE_DEVICES=0 python -m datasets.encode_extern_dataset --dataset_name robomimic --data_split val --no_quantization --encoder_type temporalvae --encoder_name_or_path 'stabilityai/stable-video-diffusion-img2vid'
""".strip()
DATASET_TO_GEN_AND_SIZE = {
"ego4d": (ego4d_dataset_generator, ego4d_dataset_size),
"egoexo4d": (egoexo4d_dataset_generator, egoexo4d_dataset_size),
"robomimic": (robomimic_dataset_generator, robomimic_dataset_size),
}


def encode_dataset_split(
extern_dataset_name: str,
split: str,
max_episodes: Optional[int],
original_res: bool,
no_quantization: bool,
curr_shard_rank: int,
num_shards: int,
root_dir: str,
encoder_type: str,
encoder_name_or_path: str,
dataset_postfix: str = "",
no_encoding: bool = False,
):
"""
Encodes (e.g. tokenizes) dataset.
The data written to disk can be used to load a `RawTokenDataset` (or the continuous version.)
Args:
extern_dataset_name: TODO
split: expected to be either "train" or "val". TODO: decide how to split
max_episodes: the maximum number of trajectories to include in the dataset.
dataset_postfix: will be a suffix of the output dirname.
image_encoder: string specifying the type of image encoder/tokenizer to use.
original_res: if True, will maintain original resolution of the video rather than resizing it to 256x256.
no_quantization: if True, will not perform quantization step in image encoder.
"""
extern_dataset_name = extern_dataset_name.strip() # never modified
suffixed_dataset_name = extern_dataset_name # will modify later
if original_res:
suffixed_dataset_name = f"{suffixed_dataset_name}_originalres"
if no_quantization:
suffixed_dataset_name = f"{suffixed_dataset_name}_noquant"
if no_encoding:
suffixed_dataset_name = f"{suffixed_dataset_name}_noencoding"
save_dirname = "_".join([suffixed_dataset_name, encoder_type, dataset_postfix, split])
dataset_path = os.path.join(root_dir, save_dirname)
print("=" * 25)
print(f"{dataset_path=}")
utils.mkdir_if_missing(dataset_path)
# Load data
generator, size_func = DATASET_TO_GEN_AND_SIZE[extern_dataset_name]
num_examples = size_func()
if max_episodes is not None:
num_examples = min(num_examples, max_episodes) # clip num_examples
# We will only operate on a subset of the training examples, depending on:
# 1) The split (train/val). Some examples are reserved for the other split.
# 2) Sharding
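    # Illustrative example (assumed values): with num_examples=1000 and VAL_RATIO=0.05,
    # 50 episodes land in val (indices [0, 50)) and 950 in train (indices [50, 1000)),
    # subject to the MIN/MAX_VAL_EXAMPLES clamp below.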
assert num_examples > MIN_VAL_EXAMPLES # non-positive number of train examples otherwise
num_val_examples = np.clip(int(VAL_RATIO * num_examples), MIN_VAL_EXAMPLES, MAX_VAL_EXAMPLES)
if split == "train": # first_ind inclusive, last_ind exclusive
first_split_ind, last_split_ind = num_val_examples, num_examples
elif split == "val":
first_split_ind, last_split_ind = 0, num_val_examples
else:
raise NotImplementedError(f"{split=}")
first_shard_ind, last_shard_ind = get_shard_inds(first_split_ind, last_split_ind, curr_shard_rank, num_shards)
print(f"Total number of examples in {suffixed_dataset_name}: {num_examples}")
print(f"Number of examples for {split=}, shard {curr_shard_rank} of {num_shards}: "
f"{last_shard_ind - first_shard_ind}. {first_shard_ind=} {last_shard_ind=}")
##### Encode data #####
traj_lens = [] # only used to print statistics
videos = [] # NOTE: videos/actions for the entire shard are stored in RAM until the end
actions = []
segment_ids = []
    # Process episodes in fixed-size batches so the generator is re-created and RAM
    # can be released between batches.
max_batch_per_loading = 10
pbar = tqdm(range(first_shard_ind, last_shard_ind, max_batch_per_loading), position=0, leave=True)
start_time = time.time()
for start_idx in pbar:
end_idx = min(start_idx + max_batch_per_loading, last_shard_ind)
pbar.set_description(f"{suffixed_dataset_name} caching episodes: {start_idx}:{end_idx}")
ds = generator(range(start_idx, end_idx))
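        # Each yielded episode is expected to be a dict with a "steps" iterable of
        # per-step dicts (image, action, ...), mirroring the OpenX episode format.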
for chunk_idx, episode in enumerate(tqdm(ds, position=1, leave=False)):
segment_id = start_idx + chunk_idx
try:
                # process the episode step by step, accumulating frames and actions
for step_ind, step_data in enumerate(episode["steps"]):
dataset_step = process_dataset_step(
step_data,
encoder_type=encoder_type,
encoder_name_or_path=encoder_name_or_path,
keep_res=original_res,
quantize=not no_quantization,
no_encoding=no_encoding
)
segment_ids.append(segment_id)
videos.append(dataset_step["image"])
actions.append(dataset_step["action"])
traj_lens.append(step_ind + 1) # number of steps in this trajectory
            except Exception:  # keep encoding the remaining episodes if one fails
                print("-" * 25)
                print(f"Adding episode failed: {segment_id=} ({suffixed_dataset_name})\n{traceback.format_exc()}")
        # 2-day timeout: stop encoding this shard if it has run too long
if time.time() - start_time > 86400 * 2:
print(f"Writing dataset {suffixed_dataset_name} timed out")
break
if len(videos) == 0:
print("Empty shard!")
with open(f"{dataset_path}/error.json", "w") as f:
json.dump({"status": "empty_shard"}, f)
return
if no_quantization:
        num_channels, height, width = videos[-1].shape[:3]  # num_channels is stored in metadata as "latent_channels"
else:
height, width = videos[-1].shape[:2]
num_channels = None
##### Write videos, actions, segment_ids, and metadata #####
# align format to save segment_ids.bin, video.bin, actions/action.bin, metadata.json
# save videos
videos = np.stack(videos, axis=0)
    videos.tofile(f'{dataset_path}/video.bin')  # flat binary; loaders can memory-map it
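    # Illustrative sketch (not part of this script) of reading the flat file back,
    # using the fields written to metadata.json below; shape assumes magvit tokens:
    #   meta = json.load(open(f"{dataset_path}/metadata.json"))
    #   videos = np.memmap(f"{dataset_path}/video.bin", dtype=meta["token_dtype"],
    #                      mode="r").reshape(meta["num_images"], meta["h"], meta["w"])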
# save action
utils.mkdir_if_missing(f'{dataset_path}/actions')
actions = np.stack(actions, axis=0)
    actions = actions.astype(np.float32)
    actions.tofile(f'{dataset_path}/actions/actions.bin')
# save segment_ids
segment_ids = np.array(segment_ids)
    segment_ids = segment_ids.astype(np.int32)  # maps each saved frame back to its trajectory index
    segment_ids.tofile(f'{dataset_path}/segment_ids.bin')
# feature_mean = np.mean(videos)
# feature_std = np.std((videos - feature_mean) / 1e9) * 1e9
# save metadata
if encoder_type == "magvit":
vocab_size = int(2 ** 18)
elif encoder_type == "temporalvae":
vocab_size = None
else:
raise NotImplementedError(f"{encoder_type=}")
    with open(f'{dataset_path}/metadata.json', 'w') as f:  # strictly, most of these fields only need to be saved for shard 0
json.dump({
"token_dtype": str(np.dtype(videos.dtype)),
"action_dim": actions[0].shape[-1],
"s": 16,
"h": height,
"w": width,
"vocab_size": vocab_size,
"hz": DATA_FREQ_TABLE.get(extern_dataset_name, 1), # to be loaded from the data code
"encoder_name_or_path": encoder_name_or_path,
"encoder_type": encoder_type,
"num_images": len(videos),
"latent_channels": num_channels,
"name": extern_dataset_name,
# "feature_mean": feature_mean,
# "feature_std": feature_std,
}, f)
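    # A resulting metadata.json might look like (illustrative values only):
    #   {"token_dtype": "int32", "action_dim": 7, "s": 16, "h": 16, "w": 16,
    #    "vocab_size": 262144, "hz": 30, "encoder_type": "magvit", ...}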
print(f"{len(traj_lens)=} {np.mean(traj_lens)=} {np.sum(traj_lens)=}")
print(f"Dataset creation time: {time.time() - start_time:.3f}")


def parse_args():
parser = argparse.ArgumentParser(description=SCRIPT_DESCRIPTION)
parser.add_argument(
"--dataset_name", type=str, required=True, choices=DATASET_TO_GEN_AND_SIZE.keys(),
help="TODO"
)
parser.add_argument(
"--data_split", type=str, choices=["train", "val"], required=True,
help="The split of the dataset to create."
)
parser.add_argument(
"--episode_cnt", type=int,
help="If specified, will limit the maximum number of trajectories to encode."
)
parser.add_argument(
"--original_res", action='store_true',
help="Maintain original resolution of the video rather than resizing it to 256x256."
)
parser.add_argument(
"--no_quantization", action='store_true',
help="Skip quantization step in visual encoder."
)
parser.add_argument(
"--num_shards", type=int, default=1,
help="The number of shards to partition the train/val dataset into."
)
parser.add_argument(
"--curr_shard_rank", type=int, default=0,
help="The (0-indexed) shard number to encode."
)
parser.add_argument(
"--root_dir", type=str, default="data",
help="The root directory to write all datasets to."
)
parser.add_argument(
"--encoder_type", type=str, default="magvit", choices=["magvit", "temporalvae"],
help="Type of the image tokenizer."
)
parser.add_argument(
"--encoder_name_or_path", type=str, default="data/magvit2.ckpt",
help="The path or name of the image encoder."
)
parser.add_argument(
"--no_encoding", action='store_true',
help="Preserve the groundtruth raw images to compute metrics in validation."
)
return parser.parse_args()


if __name__ == "__main__":
args = parse_args()
utils.set_seed(233)
dataset_postfix = f"shard{args.curr_shard_rank}_of_{args.num_shards}"
if args.episode_cnt is not None:
dataset_postfix = f"max{args.episode_cnt}_{dataset_postfix}"
encode_dataset_split(
extern_dataset_name=args.dataset_name,
split=args.data_split,
max_episodes=args.episode_cnt,
dataset_postfix=dataset_postfix,
original_res=args.original_res,
no_quantization=args.no_quantization,
num_shards=args.num_shards,
curr_shard_rank=args.curr_shard_rank,
root_dir=args.root_dir,
encoder_type=args.encoder_type,
encoder_name_or_path=args.encoder_name_or_path,
no_encoding=args.no_encoding,
)