Spaces:
Runtime error
Runtime error
import argparse | |
from typing import List | |
import pandas as pd | |
from mmengine.config import Config | |
from videogen_hub.pipelines.opensora.opensora.datasets.bucket import Bucket | |
def split_by_bucket(
    bucket: "Bucket",
    input_files: List[str],
    output_path: str,
    limit: int,
    frame_interval: int,
):
    """Keep at most ``limit`` rows per bucket from ``input_files`` and write them to one CSV.

    Files are read in the given order and rows are taken first-come, first-served.
    Each row's bucket id comes from ``bucket.get_bucket_id(num_frames, height, width,
    frame_interval)``; rows for which it returns ``None`` are skipped. Scanning stops
    as soon as ``len(bucket) * limit`` rows have been selected.

    Args:
        bucket: object whose ``ar_criteria`` (nested ``{hw_id: {t_id: {ar_id: ...}}}``)
            enumerates all bucket ids and whose ``get_bucket_id`` assigns rows to them.
        input_files: CSV paths; any file with rows must have ``num_frames``,
            ``height`` and ``width`` columns.
        output_path: destination CSV, written without the index.
        limit: maximum number of rows kept per bucket.
        frame_interval: forwarded unchanged to ``bucket.get_bucket_id``.

    Raises:
        KeyError: if ``get_bucket_id`` returns an id not enumerated by ``ar_criteria``.
    """
    print(f"Split {len(input_files)} files into {len(bucket)} buckets")
    total_limit = len(bucket) * limit

    # Every (hw_id, t_id, ar_id) bucket starts with a zero count.
    bucket_cnt = {
        (hw_id, t_id, ar_id): 0
        for hw_id, d in bucket.ar_criteria.items()
        for t_id, v in d.items()
        for ar_id in v.keys()
    }

    first_columns = None
    frames = []
    n_selected = 0
    for path in input_files:
        df = pd.read_csv(path)
        if first_columns is None:
            first_columns = list(df.columns)
        if df.empty:
            continue
        # Record positions of selected rows and slice once per file, instead of
        # concatenating one row at a time (quadratic, and it loses column dtypes).
        selected = []
        for pos, (t, h, w) in enumerate(zip(df["num_frames"], df["height"], df["width"])):
            bucket_id = bucket.get_bucket_id(t, h, w, frame_interval)
            if bucket_id is None:
                continue
            if bucket_cnt[bucket_id] < limit:
                bucket_cnt[bucket_id] += 1
                selected.append(pos)
                n_selected += 1
                if n_selected >= total_limit:
                    break
        if selected:
            frames.append(df.iloc[selected])
        if n_selected >= total_limit:
            break

    if frames:
        output_df = pd.concat(frames, ignore_index=True)
        # Put the first file's columns first (missing ones become NaN), then any
        # extra columns contributed by later files.
        extra = [c for c in output_df.columns if c not in first_columns]
        output_df = output_df.reindex(columns=first_columns + extra)
    else:
        # Nothing selected: header-only CSV (or a bare empty frame with no inputs).
        output_df = pd.DataFrame(columns=first_columns)

    assert len(output_df) <= total_limit
    if len(output_df) == total_limit:
        print(f"All buckets are full ({total_limit} samples)")
    else:
        print(f"Only {len(output_df)} samples are used")
    output_df.to_csv(output_path, index=False)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Keep at most LIMIT rows per bucket from the input CSV files."
    )
    parser.add_argument("input", type=str, nargs="+", help="input CSV files")
    parser.add_argument("-o", "--output", required=True, help="output CSV path")
    parser.add_argument(
        "-c",
        "--config",
        required=True,
        help="config file providing bucket_config and dataset.frame_interval",
    )
    parser.add_argument(
        "-l", "--limit", default=200, type=int, help="maximum rows kept per bucket"
    )
    args = parser.parse_args()
    # Validate through argparse rather than `assert`, which is stripped under `python -O`.
    if args.limit <= 0:
        parser.error("--limit must be a positive integer")

    cfg = Config.fromfile(args.config)
    bucket_config = cfg.bucket_config
    # Rewrite bucket_config in place: each entry is unpacked as (p, bs); any p > 0
    # becomes 1.0 and entries with p <= 0 keep their original p.
    # NOTE(review): presumably this stops Bucket from dropping eligible samples at
    # random, so every bucket can be filled deterministically; confirm against Bucket.
    for d in bucket_config.values():
        for frames, (p, bs) in d.items():
            d[frames] = (1.0 if p > 0.0 else p, bs)
    bucket = Bucket(bucket_config)
    split_by_bucket(bucket, args.input, args.output, args.limit, cfg.dataset.frame_interval)