import json

from datasets import load_dataset
from transformers import AutoTokenizer
# Tokenizer used to measure prompt lengths (Qwen2.5-0.5B-Instruct).
# NOTE(review): the dataset filename mentions "qwen25", so this presumably must
# be the same tokenizer the dataset's length budget was built with — confirm.
MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Load the counting dataset from a local parquet file.
DATA_PATH = "/workspace/0525_zyw/verl/counting/mk_data/counting_dataset_qwen25_max2048.parquet"
dataset = load_dataset("parquet", data_files=DATA_PATH)
# A single file passed via `data_files` is exposed as the "train" split.
data = dataset["train"]
# Token budget for a single prompt; items strictly longer than this are reported.
# NOTE(review): presumably mirrors the "max2048" in the dataset filename — confirm.
MAX_TOKENS = 2048


def _prompt_token_length(tok, prompt):
    """Return how many tokens *tok* produces for *prompt*.

    A plain string is tokenized as-is, with the tokenizer's default
    special-token handling and no truncation. A list/tuple of message dicts
    (``{"role": ..., "content": ...}``) is first rendered to text through the
    tokenizer's chat template, with the generation prompt appended. That text
    is then tokenized without adding extra special tokens. ``None`` counts as
    an empty prompt.

    Raises:
        TypeError: if *prompt* is neither a string nor a sequence of dicts.
    """
    if prompt is None:
        prompt = ""
    if isinstance(prompt, str):
        # Plain length only; no tensors needed (avoids a hard torch dependency).
        return len(tok(prompt, truncation=False)["input_ids"])
    if isinstance(prompt, (list, tuple)) and all(isinstance(m, dict) for m in prompt):
        # NOTE(review): verl-style parquet files usually store `prompt` as a chat
        # message list rather than a string. The tokenizer cannot take a list of
        # dicts directly, so render it through the chat template first. Confirm
        # this matches how the training pipeline measures prompt length.
        text = tok.apply_chat_template(
            list(prompt), add_generation_prompt=True, tokenize=False
        )
        return len(tok(text, truncation=False, add_special_tokens=False)["input_ids"])
    raise TypeError(f"unsupported prompt type: {type(prompt).__name__}")


# Record the row index, token length and content of every over-length sample.
long_items = []
for idx, example in enumerate(data):
    prompt = example.get("prompt", "")
    input_len = _prompt_token_length(tokenizer, prompt)
    if input_len > MAX_TOKENS:
        long_items.append({"index": idx, "length": input_len, "prompt": prompt})
print(f"Found {len(long_items)} items with more than {MAX_TOKENS} tokens.")
# Optional: write the over-length samples to a JSON file for manual inspection.
# ensure_ascii=False keeps non-ASCII prompt text human-readable in the output.
with open("long_prompts.json", "w", encoding="utf-8") as f:
    json.dump(long_items, f, ensure_ascii=False, indent=2)