Spaces:
Runtime error
Runtime error
import argparse | |
import os | |
import matplotlib.pyplot as plt | |
import pandas as pd | |
def read_file(input_path):
    """Load a tabular dataset from disk into a pandas DataFrame.

    The reader is selected from the file extension. The extension is compared
    case-insensitively, so names such as ``DATA.CSV`` are accepted as well.

    Args:
        input_path: Path to a ``.csv`` or ``.parquet`` file.

    Returns:
        The DataFrame returned by ``pd.read_csv`` or ``pd.read_parquet``.

    Raises:
        NotImplementedError: If the path ends with neither supported
            extension.
    """
    # Lower-case only for the comparison; the untouched path is what is opened.
    normalized = input_path.lower()
    if normalized.endswith(".csv"):
        return pd.read_csv(input_path)
    if normalized.endswith(".parquet"):
        return pd.read_parquet(input_path)
    raise NotImplementedError(f"Unsupported file format: {input_path}")
def parse_args():
    """Parse the command-line options of this script.

    Returns:
        An ``argparse.Namespace`` with two attributes: ``input`` (the
        positional dataset path) and ``save_img`` (the ``--save-img`` value,
        ``"samples/infos/"`` when the flag is not given).
    """
    cli = argparse.ArgumentParser()
    # Required positional argument: the dataset file to analyse.
    cli.add_argument(
        "input",
        type=str,
        help="Path to the input dataset",
    )
    # Optional output location for generated images.
    cli.add_argument(
        "--save-img",
        type=str,
        default="samples/infos/",
        help="Path to save the image",
    )
    namespace = cli.parse_args()
    return namespace
def plot_data(data, column, bins, name):
    """Save a histogram of one DataFrame column as an image file.

    Args:
        data: DataFrame holding the column to plot.
        column: Name of the column passed to ``DataFrame.hist``.
        bins: Number of histogram bins.
        name: Output image path. Its parent directory is created when the
            path has one; a bare file name is written to the working
            directory.
    """
    data.hist(column=column, bins=bins)
    # NOTE(review): DataFrame.hist is expected to draw on a figure of its own
    # that becomes the current figure -- confirm against the pandas version
    # in use.
    fig = plt.gcf()
    try:
        # os.makedirs("") raises FileNotFoundError, so only create a directory
        # when the target path actually contains one.
        out_dir = os.path.dirname(name)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(name)
    finally:
        # Release the figure so repeated calls do not pile up open figures.
        plt.close(fig)
    print(f"Saved {name}")
def plot_categorical_data(data, column, name):
    """Save a bar chart of the value counts of one column as an image file.

    Args:
        data: DataFrame holding the column to plot.
        column: Name of the column whose ``value_counts()`` are drawn.
        name: Output image path. Its parent directory is created when the
            path has one; a bare file name is written to the working
            directory.
    """
    # Draw on a dedicated figure instead of whatever figure happens to be
    # current, so earlier plots can never leak into this image.
    fig, ax = plt.subplots()
    try:
        data[column].value_counts().plot(kind="bar", ax=ax)
        # os.makedirs("") raises FileNotFoundError, so only create a directory
        # when the target path actually contains one.
        out_dir = os.path.dirname(name)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(name)
    finally:
        # Release the figure so repeated calls do not pile up open figures.
        plt.close(fig)
    print(f"Saved {name}")
# Columns that main() plots when they are present in the dataset.
# An integer value is the histogram bin count handed to plot_data();
# None marks a column drawn with plot_categorical_data() instead.
COLUMNS = {
    **dict.fromkeys(
        ("num_frames", "resolution", "text_len", "aes", "match", "flow"),
        100,
    ),
    "cmotion": None,
}
def _plot_columns(frame, prefix, save_dir, skip=()):
    """Plot every COLUMNS entry that exists in ``frame`` into ``save_dir``.

    Entries mapped to a bin count are drawn with plot_data(); entries mapped
    to None are drawn with plot_categorical_data(). Files are named
    ``{prefix}_{column}.png``. Columns listed in ``skip`` are ignored.
    """
    for column, bins in COLUMNS.items():
        if column in skip or column not in frame.columns:
            continue
        target = os.path.join(save_dir, f"{prefix}_{column}.png")
        if bins is None:
            plot_categorical_data(frame, column, target)
        else:
            plot_data(frame, column, bins, target)


def main(args):
    """Print summary statistics for a dataset and optionally save plots.

    Rows whose ``num_frames`` equals 1 are reported as images; all remaining
    rows are reported as videos. For each non-empty group the row count,
    ``head()`` and ``describe()`` are printed, and when ``args.save_img`` is
    truthy the columns listed in COLUMNS are plotted into that directory.

    Args:
        args: Namespace with ``input`` (dataset path understood by
            read_file()) and ``save_img`` (output directory for plots; a
            falsy value disables plotting).
    """
    data = read_file(args.input)

    if "num_frames" in data.columns:
        image_index = data["num_frames"] == 1
    else:
        # Without frame counts no row can be identified as an image, so every
        # row goes through the video path instead of failing with a KeyError.
        image_index = pd.Series(False, index=data.index, dtype=bool)

    # === Image Data Info ===
    if image_index.any():
        print("=== Image Data Info ===")
        img_data = data[image_index]
        print(f"Number of images: {len(img_data)}")
        print(img_data.head())
        print(img_data.describe())
        if args.save_img:
            # num_frames is constant (1) for images and cmotion is skipped
            # for them, exactly as in the image/video split above.
            _plot_columns(img_data, "image", args.save_img, skip=("num_frames", "cmotion"))

    # === Video Data Info ===
    if not image_index.all():
        print("=== Video Data Info ===")
        video_data = data[~image_index]
        print(f"Number of videos: {len(video_data)}")
        if "num_frames" in video_data.columns:
            total_num_frames = video_data["num_frames"].sum()
            print(f"Number of frames: {total_num_frames}")
            # Duration is an estimate: every video is assumed to run at 30 FPS.
            DEFAULT_FPS = 30
            total_hours = total_num_frames / DEFAULT_FPS / 3600
            print(f"Total hours (30 FPS): {int(total_hours)}")
        print(video_data.head())
        print(video_data.describe())
        if args.save_img:
            _plot_columns(video_data, "video", args.save_img)
if __name__ == "__main__":
    # Script entry point: parse the command line, then run the analysis.
    main(parse_args())