# predictive_auditing_data / merge_train.py
# Uploaded by s1ghhh to the Hugging Face Hub ("Upload folder using huggingface_hub"), commit 98595da.
import pandas as pd
import glob
import random
# Shared seed for the per-source sampling and the final shuffle, so a rerun on
# the same inputs (in the same order) selects and orders the same rows.
# sample_size is the maximum number of rows kept per data_source group.
random_seed, sample_size = 42, 15_000
# 1. Locate the per-dataset parquet shards in the current working directory.
# glob.glob returns entries in arbitrary, filesystem-dependent order; sort them so
# the concatenation order (and therefore the seeded sampling downstream, whenever
# one data_source spans several files) is reproducible across machines.
parquet_files = sorted(glob.glob("v2_train_counting_dataset_*.parquet"))
# This shard is excluded from the merge (substring match on the path).
excluded_parquet = "v2_train_counting_dataset_OpenThoughts-114k-math_88120.parquet"
selected_parquet_files = [
    parquet_file for parquet_file in parquet_files if excluded_parquet not in parquet_file
]
print("找到的parquet文件:", selected_parquet_files)
# 2. Read every selected shard and stack them into a single DataFrame.
if not selected_parquet_files:
    # pd.concat([]) would fail with the opaque "No objects to concatenate";
    # report the actual cause instead (nothing matched in the working directory).
    raise FileNotFoundError(
        "no input files matched 'v2_train_counting_dataset_*.parquet' "
        "(after exclusions) in the current working directory"
    )
all_data = []
for file in selected_parquet_files:
    print(file)  # progress: the shard currently being loaded
    df = pd.read_parquet(file)
    all_data.append(df)
# ignore_index=True renumbers rows 0..N-1 instead of keeping each shard's own index.
df_all = pd.concat(all_data, ignore_index=True)
print("合并后总数据量:", len(df_all))
# 3. Group rows by data_source and cap each group at sample_size rows using a
#    seeded draw without replacement; groups at or below the cap are kept whole
#    and in their original order.
# groupby drops rows whose data_source is missing (dropna=True by default), so
# report how many rows that silently removes.
n_missing_source = int(df_all["data_source"].isna().sum())
if n_missing_source:
    print(f"warning: {n_missing_source} rows have no data_source and are dropped by groupby")
sampled_dfs = []
for name, group in df_all.groupby("data_source"):
    if len(group) > sample_size:
        sampled = group.sample(n=sample_size, random_state=random_seed)
    else:
        sampled = group
    sampled_dfs.append(sampled)
    print(f"{name}: 原始{len(group)}条,采样{len(sampled)}条")
# 4. Stack the per-source samples back into one frame, then shuffle the rows so
#    that data sources are interleaved rather than grouped together.
df_sampled = pd.concat(sampled_dfs, ignore_index=True)
print("采样后总数据量:", len(df_sampled))
# frac=1 draws every row exactly once (a seeded permutation); dropping the
# permuted index renumbers the result 0..N-1.
permuted = df_sampled.sample(frac=1, random_state=random_seed)
shuffled_df = permuted.reset_index(drop=True)
# 5. Write the shuffled dataset to disk without storing the row index.
# A single constant keeps the written path and the logged path from drifting apart.
# NOTE(review): "4datasets" and "15k" in the name are hard-coded, not derived from
# the number of data_source groups or from sample_size; update it if either changes.
output_path = "merged_sampled_4datasets_15k_each.parquet"
shuffled_df.to_parquet(output_path, index=False)
print(f"已保存到 {output_path}")