predictive_auditing_data / check_length.py
import json

from datasets import load_dataset
from transformers import AutoTokenizer
# Load the tokenizer matching the model the dataset was built for (Qwen2.5-0.5B-Instruct)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
# Load the parquet dataset from its local path
dataset = load_dataset("parquet", data_files="/workspace/0525_zyw/verl/counting/mk_data/counting_dataset_qwen25_max2048.parquet")
# Parquet files loaded this way land in a single 'train' split
data = dataset["train"]
# Record the index, token length, and content of every sample longer than 2048 tokens
long_items = []
for idx, example in enumerate(data):
    prompt = example.get("prompt", "")
    # Tokenize without truncation so the full prompt length is measured
    tokens = tokenizer(prompt, truncation=False, return_tensors="pt")
    input_len = tokens.input_ids.shape[1]
    if input_len > 2048:
        long_items.append({"index": idx, "length": input_len, "prompt": prompt})
print(f"Found {len(long_items)} items with more than 2048 tokens.")
# Optional: save the results to a JSON file for inspection
with open("long_prompts.json", "w", encoding="utf-8") as f:
    json.dump(long_items, f, ensure_ascii=False, indent=2)
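
# A possible follow-up, sketched here as an assumption rather than part of the
# original pipeline: drop the overlong samples and write a cleaned parquet.
# The "_filtered" output filename is hypothetical, chosen for illustration.
long_indices = {item["index"] for item in long_items}
filtered = data.filter(lambda _, idx: idx not in long_indices, with_indices=True)
filtered.to_parquet("/workspace/0525_zyw/verl/counting/mk_data/counting_dataset_qwen25_max2048_filtered.parquet")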