hma / datasets /merge_shards.py
LeroyWaa's picture
draft
246c106
raw
history blame
3.91 kB
"""
Merge data shards generated from `encode_{extern,openx}_dataset.py`
In addition to CLI args, `SHARD_DATA_FORMAT` must be changed depending on the dataset.
"""
import argparse
import json
import os
import numpy as np
from tqdm.auto import tqdm
SHARD_DATA_FORMAT = "/private/home/xinleic/LR/HPT-Video-KZ/sharded_data/droid_magvit_shard{}_of_{}_train"
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--out_data_dir", type=str, required=True,
help="Directory to save merged data, must not exist.")
parser.add_argument("--num_shards", type=int, required=True, help="Number of shards the dataset was split into.")
args = parser.parse_args()
assert not os.path.exists(args.out_data_dir), "Will not overwrite existing directory."
os.makedirs(os.path.join(args.out_data_dir, "actions"), exist_ok=True)
num_frames = 0
valid_inds = []
for shard_ind in range(args.num_shards):
shard_path = SHARD_DATA_FORMAT.format(shard_ind, args.num_shards)
if os.path.isfile(os.path.join(shard_path, "metadata.json")):
valid_inds.append(shard_ind)
with open(os.path.join(shard_path, "metadata.json"), "r") as f:
shard_metadata = json.load(f)
num_frames += shard_metadata["num_images"]
else:
print(f"{shard_ind=} is invalid.")
if num_frames == 0:
print("No valid shards")
exit(0)
token_dtype = np.dtype(shard_metadata["token_dtype"])
if shard_metadata["quantized"]:
frame_dims = (shard_metadata["h"], shard_metadata["w"])
else:
frame_dims = (shard_metadata["latent_channels"], shard_metadata["h"], shard_metadata["w"])
action_dim = shard_metadata["action_dim"]
videos = np.memmap(
os.path.join(args.out_data_dir, "video.bin"),
dtype=token_dtype,
mode="write",
shape=(num_frames, *frame_dims)
)
actions = np.memmap(
os.path.join(args.out_data_dir, "actions", "actions.bin"),
dtype=np.float32,
mode="write",
shape=(num_frames, action_dim)
)
segment_ids = np.memmap(
os.path.join(args.out_data_dir, "segment_ids.bin"),
dtype=np.int32,
mode="write",
shape=(num_frames,)
)
prev_frame_ind = 0
prev_segment_id = 0
for shard_ind in tqdm(valid_inds):
shard_path = SHARD_DATA_FORMAT.format(shard_ind, args.num_shards)
with open(os.path.join(shard_path, "metadata.json"), "r") as f:
shard_metadata = json.load(f)
shard_num_frames = shard_metadata["num_images"]
videos[prev_frame_ind: prev_frame_ind + shard_num_frames] = np.memmap(
os.path.join(shard_path, "video.bin"),
dtype=np.dtype(shard_metadata["token_dtype"]),
mode="r",
shape=(shard_num_frames, *frame_dims),
)
actions[prev_frame_ind: prev_frame_ind + shard_num_frames] = np.memmap(
os.path.join(shard_path, "actions", "actions.bin"),
dtype=np.float32,
mode="r",
shape=(shard_num_frames, action_dim),
)
segment_ids[prev_frame_ind: prev_frame_ind + shard_num_frames] = np.memmap(
os.path.join(shard_path, "segment_ids.bin"),
dtype=np.int32,
mode="r",
shape=(shard_num_frames,),
) + prev_segment_id
prev_segment_id = segment_ids[prev_frame_ind + shard_num_frames - 1] + 1
prev_frame_ind += shard_num_frames
assert prev_frame_ind == num_frames
print("Finished")
with (open(os.path.join(args.out_data_dir, "metadata.json"), "w") as f):
merged_metadata = shard_metadata \
| vars(args) \
| {"num_images": num_frames, "input_path": SHARD_DATA_FORMAT.format(0, args.num_shards)}
json.dump(merged_metadata, f)