Spaces:
Running
on
Zero
Running
on
Zero
from tqdm import tqdm | |
import numpy as np | |
import argparse | |
import torch | |
import lmdb | |
import glob | |
import os | |
from utils.lmdb import store_arrays_to_lmdb, process_data_dict | |
def main(): | |
""" | |
Aggregate all ode pairs inside a folder into a lmdb dataset. | |
Each pt file should contain a (key, value) pair representing a | |
video's ODE trajectories. | |
""" | |
parser = argparse.ArgumentParser() | |
parser.add_argument("--data_path", type=str, | |
required=True, help="path to ode pairs") | |
parser.add_argument("--lmdb_path", type=str, | |
required=True, help="path to lmdb") | |
args = parser.parse_args() | |
all_files = sorted(glob.glob(os.path.join(args.data_path, "*.pt"))) | |
# figure out the maximum map size needed | |
total_array_size = 5000000000000 # adapt to your need, set to 5TB by default | |
env = lmdb.open(args.lmdb_path, map_size=total_array_size * 2) | |
counter = 0 | |
seen_prompts = set() # for deduplication | |
for index, file in tqdm(enumerate(all_files)): | |
# read from disk | |
data_dict = torch.load(file) | |
data_dict = process_data_dict(data_dict, seen_prompts) | |
# write to lmdb file | |
store_arrays_to_lmdb(env, data_dict, start_index=counter) | |
counter += len(data_dict['prompts']) | |
# save each entry's shape to lmdb | |
with env.begin(write=True) as txn: | |
for key, val in data_dict.items(): | |
print(key, val) | |
array_shape = np.array(val.shape) | |
array_shape[0] = counter | |
shape_key = f"{key}_shape".encode() | |
shape_str = " ".join(map(str, array_shape)) | |
txn.put(shape_key, shape_str.encode()) | |
if __name__ == "__main__": | |
main() | |