"""
Merge data shards generated by `encode_{extern,openx}_dataset.py` into a single dataset.
In addition to the CLI args, `SHARD_DATA_FORMAT` must be updated to match the dataset being merged.
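
Example invocation (script name and output path are illustrative):
    python merge_shards.py --out_data_dir /path/to/merged_data --num_shards 64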
"""

import argparse
import json
import os

import numpy as np
from tqdm.auto import tqdm

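# Path template for one input shard; the two `{}` placeholders are filled with
# (shard_index, num_shards). Point this at the output of the encode script.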
SHARD_DATA_FORMAT = "/private/home/xinleic/LR/HPT-Video-KZ/sharded_data/droid_magvit_shard{}_of_{}_train"


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--out_data_dir", type=str, required=True,
                        help="Directory to save merged data, must not exist.")
    parser.add_argument("--num_shards", type=int, required=True, help="Number of shards the dataset was split into.")

    args = parser.parse_args()
    assert not os.path.exists(args.out_data_dir), "Will not overwrite existing directory."
    os.makedirs(os.path.join(args.out_data_dir, "actions"), exist_ok=True)

    num_frames = 0
    valid_inds = []

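    # First pass: find which shards were written successfully (have a metadata.json)
    # and count the total number of frames to allocate.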
    for shard_ind in range(args.num_shards):
        shard_path = SHARD_DATA_FORMAT.format(shard_ind, args.num_shards)
        if os.path.isfile(os.path.join(shard_path, "metadata.json")):
            valid_inds.append(shard_ind)
            with open(os.path.join(shard_path, "metadata.json"), "r") as f:
                shard_metadata = json.load(f)

            num_frames += shard_metadata["num_images"]
        else:
            print(f"{shard_ind=} is invalid.")

    if num_frames == 0:
        print("No valid shards found.")
        raise SystemExit(0)

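    # Frame layout is read from the last valid shard's metadata; all shards are
    # assumed to share it. Quantized shards store (h, w) token grids, otherwise
    # latents are stored as (latent_channels, h, w).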
    token_dtype = np.dtype(shard_metadata["token_dtype"])
    if shard_metadata["quantized"]:
        frame_dims = (shard_metadata["h"], shard_metadata["w"])
    else:
        frame_dims = (shard_metadata["latent_channels"], shard_metadata["h"], shard_metadata["w"])

    action_dim = shard_metadata["action_dim"]
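
    # Pre-allocate memory-mapped output files sized for the full merged dataset
    # (memmap mode "write" is numpy's alias for "w+": create or overwrite).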
    videos = np.memmap(
        os.path.join(args.out_data_dir, "video.bin"),
        dtype=token_dtype,
        mode="write",
        shape=(num_frames, *frame_dims)
    )

    actions = np.memmap(
        os.path.join(args.out_data_dir, "actions", "actions.bin"),
        dtype=np.float32,
        mode="write",
        shape=(num_frames, action_dim)
    )

    segment_ids = np.memmap(
        os.path.join(args.out_data_dir, "segment_ids.bin"),
        dtype=np.int32,
        mode="write",
        shape=(num_frames,)
    )

    prev_frame_ind = 0
    prev_segment_id = 0

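    # Second pass: copy each valid shard into the merged buffers, offsetting
    # segment ids so they remain unique across shards.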
    for shard_ind in tqdm(valid_inds):
        shard_path = SHARD_DATA_FORMAT.format(shard_ind, args.num_shards)
        with open(os.path.join(shard_path, "metadata.json"), "r") as f:
            shard_metadata = json.load(f)

        shard_num_frames = shard_metadata["num_images"]
        videos[prev_frame_ind: prev_frame_ind + shard_num_frames] = np.memmap(
            os.path.join(shard_path, "video.bin"),
            dtype=np.dtype(shard_metadata["token_dtype"]),
            mode="r",
            shape=(shard_num_frames, *frame_dims),
        )

        actions[prev_frame_ind: prev_frame_ind + shard_num_frames] = np.memmap(
            os.path.join(shard_path, "actions", "actions.bin"),
            dtype=np.float32,
            mode="r",
            shape=(shard_num_frames, action_dim),
        )

        segment_ids[prev_frame_ind: prev_frame_ind + shard_num_frames] = np.memmap(
            os.path.join(shard_path, "segment_ids.bin"),
            dtype=np.int32,
            mode="r",
            shape=(shard_num_frames,),
        ) + prev_segment_id

        prev_segment_id = segment_ids[prev_frame_ind + shard_num_frames - 1] + 1
        prev_frame_ind += shard_num_frames

    assert prev_frame_ind == num_frames
    print("Finished")

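    # The merged metadata is the last shard's metadata overlaid with the CLI args
    # and the merged totals.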
    merged_metadata = shard_metadata | vars(args) | {
        "num_images": num_frames,
        "input_path": SHARD_DATA_FORMAT.format(0, args.num_shards),
    }

    with open(os.path.join(args.out_data_dir, "metadata.json"), "w") as f:
        json.dump(merged_metadata, f)