Baraaqasem's picture
Upload 585 files
5d32408 verified
raw
history blame contribute delete
2.46 kB
import argparse
from typing import List
import pandas as pd
from mmengine.config import Config
from videogen_hub.pipelines.opensora.opensora.datasets.bucket import Bucket
def split_by_bucket(
    bucket: Bucket,
    input_files: List[str],
    output_path: str,
    limit: int,
    frame_interval: int,
):
    """Pick at most ``limit`` rows per bucket from CSV files and save them as one CSV.

    Every row of every input file is mapped to a bucket with
    ``bucket.get_bucket_id(num_frames, height, width, frame_interval)``.
    Rows that map to no bucket (``None``) are dropped, and rows whose bucket
    already holds ``limit`` samples are skipped. Files are read in order and
    reading stops as soon as ``len(bucket) * limit`` rows have been selected.

    Args:
        bucket: Object exposing ``ar_criteria`` (nested mapping whose three key
            levels form the bucket id), ``__len__`` and ``get_bucket_id``.
        input_files: Paths of CSV files containing at least the columns
            ``num_frames``, ``height`` and ``width``. The files are assumed to
            share the same columns.
        output_path: Destination CSV path (written without the index).
        limit: Maximum number of rows kept per bucket.
        frame_interval: Forwarded unchanged to ``bucket.get_bucket_id``.

    Raises:
        ValueError: If ``input_files`` is empty.
    """
    if not input_files:
        # Without any file there is no header to write; fail with a clear message
        # instead of crashing later on ``len(None)``.
        raise ValueError("split_by_bucket requires at least one input file")

    print(f"Split {len(input_files)} files into {len(bucket)} buckets")
    total_limit = len(bucket) * limit

    # get all bucket id
    bucket_cnt = {
        (hw_id, t_id, ar_id): 0
        for hw_id, by_time in bucket.ar_criteria.items()
        for t_id, by_ar in by_time.items()
        for ar_id in by_ar
    }

    header = None  # zero-row slice of the first file; keeps its columns if nothing is selected
    selected_parts = []  # one slice per file that contributed rows
    num_selected = 0

    # split files
    for path in input_files:
        df = pd.read_csv(path)
        if header is None:
            header = df.iloc[0:0]

        # Remember row positions and slice once per file: concatenating row by
        # row copies the growing frame every time and degrades column dtypes.
        keep = []
        shapes = zip(df["num_frames"], df["height"], df["width"])
        for pos, (t, h, w) in enumerate(shapes):
            bucket_id = bucket.get_bucket_id(t, h, w, frame_interval)
            if bucket_id is None:
                continue
            if bucket_cnt[bucket_id] < limit:
                bucket_cnt[bucket_id] += 1
                keep.append(pos)
                num_selected += 1
            if num_selected >= total_limit:
                break

        if keep:
            selected_parts.append(df.iloc[keep])
        if num_selected >= total_limit:
            break

    if selected_parts:
        output_df = pd.concat(selected_parts, ignore_index=True)
    else:
        output_df = header

    assert len(output_df) <= total_limit
    if len(output_df) == total_limit:
        print(f"All buckets are full ({total_limit} samples)")
    else:
        print(f"Only {len(output_df)} files are used")
    output_df.to_csv(output_path, index=False)
if __name__ == "__main__":
    # Command-line entry point: read the bucket configuration, then split the
    # given CSV files so that each bucket receives at most --limit rows.
    parser = argparse.ArgumentParser()
    parser.add_argument("input", type=str, nargs="+")
    parser.add_argument("-o", "--output", required=True)
    parser.add_argument("-c", "--config", required=True)
    parser.add_argument("-l", "--limit", default=200, type=int)
    args = parser.parse_args()
    if args.limit <= 0:
        # Report bad user input through argparse; an assert would be stripped
        # when Python runs with -O.
        parser.error("--limit must be a positive integer")

    cfg = Config.fromfile(args.config)
    bucket_config = cfg.bucket_config

    # rewrite bucket_config: every entry is a (p, bs) pair; any positive p is
    # raised to 1.0 (presumably the bucket's keep probability — confirm against
    # Bucket), while bs is carried over untouched.
    for _, frame_settings in bucket_config.items():
        for frames, setting in frame_settings.items():
            p, bs = setting
            if p > 0.0:
                p = 1.0
            frame_settings[frames] = (p, bs)

    bucket = Bucket(bucket_config)
    split_by_bucket(bucket, args.input, args.output, args.limit, cfg.dataset.frame_interval)