self-forcing / scripts /create_lmdb_iterative.py
multimodalart's picture
Upload 80 files
0fd2f06 verified
from tqdm import tqdm
import numpy as np
import argparse
import torch
import lmdb
import glob
import os
from utils.lmdb import store_arrays_to_lmdb, process_data_dict
def main():
    """
    Aggregate all ODE pairs inside a folder into an LMDB dataset.

    Each ``*.pt`` file should contain a (key, value) pair representing a
    video's ODE trajectories. Every file found under ``--data_path`` is
    loaded with ``torch.load``, passed through ``process_data_dict`` (sharing
    one ``seen_prompts`` set across files for deduplication) and written with
    ``store_arrays_to_lmdb`` starting at the running entry index. Once all
    files are written, one ``"<key>_shape"`` record is stored per array,
    with the first dimension replaced by the total number of entries.

    Command-line arguments:
        --data_path: folder containing the ``*.pt`` ODE-pair files.
        --lmdb_path: destination path of the LMDB environment.

    Raises:
        SystemExit: if the arguments are invalid or ``--data_path`` contains
            no ``*.pt`` file (raised through ``parser.error``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_path", type=str,
                        required=True, help="path to ode pairs")
    parser.add_argument("--lmdb_path", type=str,
                        required=True, help="path to lmdb")
    args = parser.parse_args()

    # Sorted so the files are always processed in the same order.
    all_files = sorted(glob.glob(os.path.join(args.data_path, "*.pt")))
    if not all_files:
        # Fail before the LMDB environment is created: with no input there is
        # nothing to derive the shape records from.
        parser.error(f"no .pt files found in {args.data_path!r}")

    # Budget for the stored arrays; adapt to your need (5 TB by default).
    # The environment is opened with twice this value as its map size.
    total_array_size = 5000000000000
    env = lmdb.open(args.lmdb_path, map_size=total_array_size * 2)

    try:
        counter = 0  # total number of entries written so far
        seen_prompts = set()  # for deduplication
        shape_source = None  # batch whose arrays provide the trailing dims

        # Iterating the list directly lets tqdm display the total count.
        for file in tqdm(all_files):
            # read from disk
            data_dict = torch.load(file)
            data_dict = process_data_dict(data_dict, seen_prompts)

            # write to lmdb file
            store_arrays_to_lmdb(env, data_dict, start_index=counter)
            num_new = len(data_dict['prompts'])
            counter += num_new

            # Prefer a batch that actually contributed entries; a batch with
            # no new entries may not carry representative array shapes
            # (NOTE(review): depends on process_data_dict -- confirm).
            if num_new > 0 or shape_source is None:
                shape_source = data_dict

        # save each entry's shape to lmdb
        with env.begin(write=True) as txn:
            for key, val in shape_source.items():
                print(key, val)
                array_shape = np.array(val.shape)
                # The leading dimension becomes the dataset-wide entry count.
                array_shape[0] = counter
                shape_key = f"{key}_shape".encode()
                shape_str = " ".join(map(str, array_shape))
                txn.put(shape_key, shape_str.encode())
    finally:
        # Always release the environment, even if a file fails to load.
        env.close()
# Run the aggregation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()