"""Select a per-bucket-capped subset of rows from dataset CSV files."""

import argparse
from typing import List

import pandas as pd
from mmengine.config import Config

from videogen_hub.pipelines.opensora.opensora.datasets.bucket import Bucket


def _pick_rows(df, bucket, bucket_cnt, limit, frame_interval, room):
    """Return positions of the rows of ``df`` accepted into non-full buckets.

    ``bucket_cnt`` is updated in place. Scanning stops as soon as ``room``
    rows have been accepted from this frame.
    """
    picked = []
    if df.empty:
        return picked
    columns = zip(df["num_frames"], df["height"], df["width"])
    for pos, (t, h, w) in enumerate(columns):
        bucket_id = bucket.get_bucket_id(t, h, w, frame_interval)
        # Rows that map to no bucket, or to a bucket already at its cap,
        # are skipped.
        if bucket_id is None or bucket_cnt[bucket_id] >= limit:
            continue
        bucket_cnt[bucket_id] += 1
        picked.append(pos)
        if len(picked) >= room:
            break
    return picked


def split_by_bucket(
    bucket: Bucket,
    input_files: List[str],
    output_path: str,
    limit: int,
    frame_interval: int,
) -> None:
    """Write a CSV holding at most ``limit`` rows for every bucket.

    The input CSVs are read in order; each row's ``num_frames``, ``height``
    and ``width`` are passed to ``bucket.get_bucket_id`` together with
    ``frame_interval``. A row is kept when the returned bucket id is not
    ``None`` and that bucket has not yet received ``limit`` rows. Reading
    stops once ``len(bucket) * limit`` rows have been kept.

    Args:
        bucket: Object providing ``ar_criteria``, ``get_bucket_id`` and
            ``__len__``.
        input_files: Paths of the CSV files to read; must not be empty.
        output_path: Destination path of the resulting CSV (written without
            the index column).
        limit: Maximum number of rows kept per bucket.
        frame_interval: Forwarded unchanged to ``bucket.get_bucket_id``.

    Raises:
        ValueError: If ``input_files`` is empty.
        KeyError: If a required column is missing, or ``get_bucket_id``
            returns an id that is not listed in ``bucket.ar_criteria``.
    """
    if not input_files:
        raise ValueError("input_files must contain at least one CSV path")

    print(f"Split {len(input_files)} files into {len(bucket)} buckets")
    total_limit = len(bucket) * limit

    # One counter per (hw_id, t_id, ar_id) combination known to the bucket.
    bucket_cnt = {
        (hw_id, t_id, ar_id): 0
        for hw_id, by_frames in bucket.ar_criteria.items()
        for t_id, by_ratio in by_frames.items()
        for ar_id in by_ratio
    }

    first_columns = None
    selected = []
    total = 0
    for path in input_files:
        df = pd.read_csv(path)
        if first_columns is None:
            first_columns = df.columns
        picked = _pick_rows(
            df, bucket, bucket_cnt, limit, frame_interval, total_limit - total
        )
        if picked:
            # Slice once per file instead of concatenating row by row: this
            # is linear in the number of rows and keeps the column dtypes.
            selected.append(df.iloc[picked])
            total += len(picked)
        if total >= total_limit:
            break

    if selected:
        output_df = pd.concat(selected, ignore_index=True)
        if not output_df.columns.equals(first_columns):
            # Keep the first file's columns in front, followed by any extra
            # columns contributed by later files.
            ordered = list(first_columns) + [
                c for c in output_df.columns if c not in first_columns
            ]
            output_df = output_df.reindex(columns=ordered)
    else:
        output_df = pd.DataFrame(columns=first_columns)

    assert len(output_df) <= total_limit
    if len(output_df) == total_limit:
        print(f"All buckets are full ({total_limit} samples)")
    else:
        print(f"Only {len(output_df)} files are used")
    output_df.to_csv(output_path, index=False)


def main() -> None:
    """Parse the command line, build the bucket and run the split."""
    parser = argparse.ArgumentParser()
    parser.add_argument("input", type=str, nargs="+")
    parser.add_argument("-o", "--output", required=True)
    parser.add_argument("-c", "--config", required=True)
    parser.add_argument("-l", "--limit", default=200, type=int)
    args = parser.parse_args()
    if args.limit <= 0:
        # Explicit check rather than ``assert`` so it survives ``python -O``.
        parser.error("--limit must be a positive integer")

    cfg = Config.fromfile(args.config)
    bucket_config = cfg.bucket_config
    # Rewrite bucket_config: every strictly positive first entry of a
    # (p, bs) pair becomes 1.0; other values are left as they are.
    # NOTE(review): presumably p is a keep-probability and this makes bucket
    # assignment deterministic - confirm against Bucket.
    for by_frames in bucket_config.values():
        for frames, (p, bs) in by_frames.items():
            by_frames[frames] = (1.0 if p > 0.0 else p, bs)

    bucket = Bucket(bucket_config)
    split_by_bucket(
        bucket, args.input, args.output, args.limit, cfg.dataset.frame_interval
    )


if __name__ == "__main__":
    main()