"""
python create_lmdb_14b_shards.py \
--data_path /mnt/localssd/wanx_14b_data \
--lmdb_path /mnt/localssd/wanx_14B_shift-3.0_cfg-5.0_lmdb
"""
import argparse
import glob
import os

import lmdb
import numpy as np
import torch
from tqdm import tqdm

from utils.lmdb import store_arrays_to_lmdb, process_data_dict
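
# NOTE (added): judging from the checks in main() below, each .pt file yields
# a dict that, after process_data_dict, contains at least a "latents" tensor
# of shape (1, 21, 16, 60, 104) and a "prompts" list. process_data_dict and
# store_arrays_to_lmdb are project-local helpers from utils.lmdb.
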
def main():
    """
    Aggregate all ODE pairs inside a folder into an LMDB dataset.
    Each .pt file should contain a (key, value) pair representing a
    video's ODE trajectories.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_path", type=str, required=True,
                        help="path to the folder of ODE pairs")
    parser.add_argument("--lmdb_path", type=str, required=True,
                        help="output path for the LMDB shards")
    parser.add_argument("--num_shards", type=int, default=16,
                        help="number of LMDB shards to create")
    args = parser.parse_args()
    all_dirs = sorted(os.listdir(args.data_path))

    # upper bound on the LMDB map size; 1 TB by default, adjust to your needs
    map_size = int(1e12)
    os.makedirs(args.lmdb_path, exist_ok=True)
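    # NOTE (added): with the default writemap=False, LMDB grows the data file
    # on demand, so a generous map_size mostly reserves virtual address space.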
    # 1) Open one LMDB environment per shard
    envs = []
    num_shards = args.num_shards
    for shard_id in range(num_shards):
        print(f"Opening shard {shard_id}")
        path = os.path.join(args.lmdb_path, f"shard_{shard_id}")
        env = lmdb.open(path,
                        map_size=map_size,
                        subdir=True,      # one directory per environment
                        readonly=False,
                        metasync=True,
                        sync=True,
                        lock=True,
                        readahead=False,  # better for random access patterns
                        meminit=False)    # skip zeroing write buffers
        envs.append(env)
    counters = [0] * num_shards
    seen_prompts = set()  # for prompt deduplication
    total_samples = 0

    all_files = []
    for part_dir in all_dirs:
        all_files += sorted(glob.glob(os.path.join(args.data_path, part_dir, "*.pt")))
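
    # NOTE (added): directories and files are sorted above, so the round-robin
    # shard assignment below is deterministic across runs.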
    # 2) Round-robin each file into a shard and write it out
    for idx, file in tqdm(enumerate(all_files), total=len(all_files)):
        try:
            data_dict = torch.load(file)
            data_dict = process_data_dict(data_dict, seen_prompts)
        except Exception as e:
            print(f"Error processing {file}: {e}")
            continue

        # skip samples whose latents do not match the expected layout
        if data_dict["latents"].shape != (1, 21, 16, 60, 104):
            continue

        shard_id = idx % num_shards

        # write to the shard's lmdb file
        store_arrays_to_lmdb(envs[shard_id], data_dict, start_index=counters[shard_id])
        counters[shard_id] += len(data_dict["prompts"])
        data_shape = data_dict["latents"].shape
        total_samples += len(data_dict["prompts"])  # count written samples, not files

    print(f"{len(seen_prompts)} unique prompts seen")
    # 3) Save each entry's shape to every shard, using the keys (and the
    # latent shape) from the last successfully processed data_dict
    for shard_id, env in enumerate(envs):
        with env.begin(write=True) as txn:
            for key, val in data_dict.items():
                assert len(data_shape) == 5
                array_shape = np.array(data_shape)
                array_shape[0] = counters[shard_id]  # per-shard sample count
                shape_key = f"{key}_shape".encode()
                print(shape_key, array_shape)
                shape_str = " ".join(map(str, array_shape))
                txn.put(shape_key, shape_str.encode())

    print(f"Finished writing {total_samples} examples into {num_shards} shards under {args.lmdb_path}")
if __name__ == "__main__":
    main()