# self-forcing/scripts/create_lmdb_iterative.py
from tqdm import tqdm
import numpy as np
import argparse
import torch
import lmdb
import glob
import os
from utils.lmdb import store_arrays_to_lmdb, process_data_dict
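
# NOTE: the two helpers imported from utils.lmdb are assumed (based only on how
# they are called below) to behave roughly as follows; see utils/lmdb.py for
# the actual implementation:
#   - process_data_dict(data_dict, seen_prompts): normalizes one loaded .pt
#     payload, skipping prompts already present in `seen_prompts`, and returns
#     a dict of batched arrays that includes a "prompts" entry.
#   - store_arrays_to_lmdb(env, data_dict, start_index): writes each array in
#     data_dict into the LMDB environment, offsetting entry keys by start_index.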


def main():
    """
    Aggregate all ODE pairs inside a folder into an LMDB dataset.

    Each .pt file should contain a (key, value) pair representing a
    video's ODE trajectories.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_path", type=str, required=True,
                        help="path to the folder of ODE-pair .pt files")
    parser.add_argument("--lmdb_path", type=str, required=True,
                        help="path to the output LMDB database")
    args = parser.parse_args()
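
    # Example invocation (paths are illustrative):
    #   python scripts/create_lmdb_iterative.py \
    #       --data_path /data/ode_pairs --lmdb_path /data/ode_pairs_lmdb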
    all_files = sorted(glob.glob(os.path.join(args.data_path, "*.pt")))

    # Maximum LMDB map size: an upper bound on how large the database may
    # grow, not an up-front allocation. Adapt to your needs; 5 TB by default.
    total_array_size = 5_000_000_000_000
    env = lmdb.open(args.lmdb_path, map_size=total_array_size * 2)

    counter = 0
    seen_prompts = set()  # used to deduplicate prompts across files

    for file in tqdm(all_files):
        # Read one ODE-pair file from disk and drop already-seen prompts.
        data_dict = torch.load(file)
        data_dict = process_data_dict(data_dict, seen_prompts)

        # Append the arrays to the LMDB database.
        store_arrays_to_lmdb(env, data_dict, start_index=counter)
        counter += len(data_dict["prompts"])

        # Save each entry's shape to LMDB, with the first dimension replaced
        # by the running total of entries written so far.
        with env.begin(write=True) as txn:
            for key, val in data_dict.items():
                # Lightweight progress log of which keys are being indexed.
                print(key, val.shape)
                array_shape = np.array(val.shape)
                array_shape[0] = counter
                shape_key = f"{key}_shape".encode()
                shape_str = " ".join(map(str, array_shape))
                txn.put(shape_key, shape_str.encode())
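

# Illustrative sketch (not part of the original script; the name is
# hypothetical): how a consumer could read back the per-key shape metadata
# that main() writes. The "<key>_shape" entries and their space-separated
# integer encoding match exactly what is stored above; how the array entries
# themselves are keyed is determined by store_arrays_to_lmdb, so they are not
# read here.
def _example_read_shapes(lmdb_path):
    env = lmdb.open(lmdb_path, readonly=True, lock=False)
    shapes = {}
    with env.begin() as txn:
        for key, value in txn.cursor():
            if key.endswith(b"_shape"):
                name = key.decode()[: -len("_shape")]
                shapes[name] = tuple(int(x) for x in value.decode().split())
    env.close()
    return shapes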


if __name__ == "__main__":
    main()