Spaces:
Running
on
Zero
Running
on
Zero
"""
Merge data shards generated from `encode_{extern,openx}_dataset.py` into a single dataset directory.

In addition to the CLI args, `SHARD_DATA_FORMAT` must be changed depending on the dataset.
"""
import argparse
import json
import os
import sys

import numpy as np
from tqdm.auto import tqdm
# Path template for one shard directory, filled as
# SHARD_DATA_FORMAT.format(shard_index, num_shards). Each shard directory is
# read for metadata.json, video.bin, actions/actions.bin and segment_ids.bin.
# Edit this for each dataset being merged.
SHARD_DATA_FORMAT: str = (
    "/private/home/xinleic/LR/HPT-Video-KZ/sharded_data/droid_magvit_shard{}_of_{}_train"
)
def _frame_dims(metadata):
    """Return the per-frame array shape described by a shard's metadata.

    Quantized shards store an ``(h, w)`` array per frame; otherwise each frame
    is ``(latent_channels, h, w)``.
    """
    if metadata["quantized"]:
        return (metadata["h"], metadata["w"])
    return (metadata["latent_channels"], metadata["h"], metadata["w"])


def _load_shard_metadata(shard_data_format, num_shards):
    """Read ``metadata.json`` of every shard that has one.

    Args:
        shard_data_format: template formatted as ``(shard_index, num_shards)``.
        num_shards: number of shards the dataset was split into.

    Returns:
        List of ``(shard_path, metadata)`` pairs in shard order. Shards without
        a ``metadata.json`` are reported on stdout and skipped.
    """
    shards = []
    for shard_ind in range(num_shards):
        shard_path = shard_data_format.format(shard_ind, num_shards)
        metadata_path = os.path.join(shard_path, "metadata.json")
        if not os.path.isfile(metadata_path):
            print(f"{shard_ind=} is invalid.")
            continue
        with open(metadata_path, "r") as f:
            shards.append((shard_path, json.load(f)))
    return shards


def _check_consistent_layout(shards):
    """Raise ``ValueError`` unless all shards share one frame shape and action dim.

    The merged arrays are allocated from a single shard's metadata, so a shard
    with a different layout would otherwise fail mid-copy or be read back
    with the wrong shape.
    """
    ref_path, ref_metadata = shards[-1]
    expected = (_frame_dims(ref_metadata), ref_metadata["action_dim"])
    for shard_path, metadata in shards:
        found = (_frame_dims(metadata), metadata["action_dim"])
        if found != expected:
            raise ValueError(
                f"Shard {shard_path} has (frame_dims, action_dim)={found}, "
                f"expected {expected} (from {ref_path})."
            )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--out_data_dir", type=str, required=True,
                        help="Directory to save merged data, must not exist.")
    parser.add_argument("--num_shards", type=int, required=True, help="Number of shards the dataset was split into.")
    args = parser.parse_args()

    # Explicit check rather than `assert`, so the guard survives `python -O`.
    if os.path.exists(args.out_data_dir):
        parser.error("Will not overwrite existing directory.")

    shards = _load_shard_metadata(SHARD_DATA_FORMAT, args.num_shards)
    num_frames = sum(metadata["num_images"] for _, metadata in shards)
    if num_frames == 0:
        print("No valid shards")
        sys.exit(0)

    _check_consistent_layout(shards)

    # Output dtype/layout and the base of the merged metadata come from the
    # last valid shard.
    last_metadata = shards[-1][1]
    token_dtype = np.dtype(last_metadata["token_dtype"])
    frame_dims = _frame_dims(last_metadata)
    action_dim = last_metadata["action_dim"]

    # Create the output directory only after validation succeeded, so a
    # failed run does not leave an empty directory that blocks a re-run.
    os.makedirs(os.path.join(args.out_data_dir, "actions"), exist_ok=True)

    videos = np.memmap(
        os.path.join(args.out_data_dir, "video.bin"),
        dtype=token_dtype,
        mode="write",
        shape=(num_frames, *frame_dims),
    )
    actions = np.memmap(
        os.path.join(args.out_data_dir, "actions", "actions.bin"),
        dtype=np.float32,
        mode="write",
        shape=(num_frames, action_dim),
    )
    segment_ids = np.memmap(
        os.path.join(args.out_data_dir, "segment_ids.bin"),
        dtype=np.int32,
        mode="write",
        shape=(num_frames,),
    )

    frame_offset = 0
    segment_offset = 0
    for shard_path, metadata in tqdm(shards):
        shard_num_frames = metadata["num_images"]
        if shard_num_frames == 0:
            # Nothing to copy. Skipping also avoids mapping empty files and
            # deriving the next offset from another shard's last slot.
            continue
        end = frame_offset + shard_num_frames

        videos[frame_offset:end] = np.memmap(
            os.path.join(shard_path, "video.bin"),
            dtype=np.dtype(metadata["token_dtype"]),
            mode="r",
            shape=(shard_num_frames, *frame_dims),
        )
        actions[frame_offset:end] = np.memmap(
            os.path.join(shard_path, "actions", "actions.bin"),
            dtype=np.float32,
            mode="r",
            shape=(shard_num_frames, action_dim),
        )
        shard_segment_ids = np.memmap(
            os.path.join(shard_path, "segment_ids.bin"),
            dtype=np.int32,
            mode="r",
            shape=(shard_num_frames,),
        )
        # Shift this shard's ids past every id already written so segments
        # from different shards never collide.
        segment_ids[frame_offset:end] = shard_segment_ids + segment_offset
        # The max (rather than the last entry) keeps ids unique even if a
        # shard's ids are not sorted; for sorted ids the two are equal.
        segment_offset = segment_ids[frame_offset:end].max() + 1
        frame_offset = end

    # Internal invariant: every counted frame was copied exactly once.
    assert frame_offset == num_frames

    # Push merged data to the files before writing metadata that describes them.
    for merged_array in (videos, actions, segment_ids):
        merged_array.flush()

    merged_metadata = (
        last_metadata
        | vars(args)
        | {"num_images": num_frames, "input_path": SHARD_DATA_FORMAT.format(0, args.num_shards)}
    )
    with open(os.path.join(args.out_data_dir, "metadata.json"), "w") as f:
        json.dump(merged_metadata, f)
    print("Finished")