import glob

import pandas as pd

# Reproducibility seed and per-source sampling cap
random_seed = 42
sample_size = 15000

# Discover the counting-dataset shards, excluding the OpenThoughts-114k-math split
parquet_files = glob.glob("v2_train_counting_dataset_*.parquet")
selected_parquet_files = [
    parquet_file
    for parquet_file in parquet_files
    if "v2_train_counting_dataset_OpenThoughts-114k-math_88120.parquet" not in parquet_file
]
print("Parquet files found:", selected_parquet_files)

# Load every selected shard and merge them into a single DataFrame
all_data = []
for file in selected_parquet_files:
    print(file)
    df = pd.read_parquet(file)
    all_data.append(df)
df_all = pd.concat(all_data, ignore_index=True)
print("Total rows after merging:", len(df_all))

# Sample at most `sample_size` rows from each data source (smaller sources are kept whole)
sampled_dfs = []
for name, group in df_all.groupby("data_source"):
    if len(group) > sample_size:
        sampled = group.sample(n=sample_size, random_state=random_seed)
    else:
        sampled = group
    sampled_dfs.append(sampled)
    print(f"{name}: {len(group)} original rows, {len(sampled)} sampled rows")

df_sampled = pd.concat(sampled_dfs, ignore_index=True)
print("Total rows after sampling:", len(df_sampled))

# Shuffle the merged sample and write it out as a single parquet file
shuffled_df = df_sampled.sample(frac=1, random_state=random_seed).reset_index(drop=True)
shuffled_df.to_parquet("merged_sampled_4datasets_15k_each.parquet", index=False)
print("Saved to merged_sampled_4datasets_15k_each.parquet")