mc4_tmlu_contamination_report / decontaminate.py
yentinglin's picture
Upload decontaminate.py with huggingface_hub
f25fae7 verified
raw
history blame
5.56 kB
import argparse
import difflib
import re
import unicodedata
from pathlib import Path
from tqdm.auto import tqdm
from datasets import load_dataset, Dataset
import jieba
def tokenize(text):
"""Tokenize Chinese text using Jieba."""
tokens = jieba.lcut(text)
return tokens
def get_ngrams(tokens, n):
"""Generate n-grams from tokens."""
return set(zip(*[tokens[i:] for i in range(n)]))
def retrieve_ngrams_batch(batch, eval_ngrams, eval_datasets, eval_texts, ngram_len):
"""Find contaminated samples based on n-grams."""
new_batch = {"completion": [], "ngram": [], "bench_name": [], "bench_text": []}
for completion in batch["completion"]:
tokens = tokenize(completion)
ngrams = get_ngrams(tokens, ngram_len)
for ngram in ngrams:
if ngram in eval_ngrams:
idx = eval_ngrams[ngram]
new_batch["completion"].append(completion)
new_batch["ngram"].append(ngram)
new_batch["bench_name"].append(eval_datasets[idx])
new_batch["bench_text"].append(eval_texts[idx])
break
return new_batch
def diff_strings(string1, string2):
"""Find matching parts between two strings."""
matcher = difflib.SequenceMatcher(None, string1.lower(), string2.lower(), autojunk=False)
matching_blocks = matcher.get_matching_blocks()
matches = []
for block in matching_blocks:
start_a, start_b, length = block
if length > 5:
match = string1[start_a:start_a + length]
matches.append(match)
return matches
def add_match_stats(example):
gen_text = " ".join(tokenize(example["completion"]))
bench_text = " ".join(tokenize(example["bench_text"]))
matching_parts = diff_strings(gen_text, bench_text)
match = " ".join("".join(matching_parts).split())
example["diff"] = matching_parts
example["diff_ratio"] = len(match) / len(bench_text) if len(bench_text) > 0 else 0
example["diff_length"] = len(match)
example["longest_diff_part"] = max(matching_parts, key=len, default="")
example["longest_diff_part_length"] = len(example["longest_diff_part"])
return example
def main(args):
# Load the evaluation data to build n-grams index
eval_ngrams, eval_datasets, eval_texts = {}, [], []
eval_data = load_dataset(args.eval_dataset, split="train")
for example in tqdm(eval_data):
tokens = tokenize(example["text"])
ngrams = get_ngrams(tokens, args.ngram_length)
if ngrams:
idx = len(eval_texts)
eval_ngrams.update(zip(ngrams, [idx] * len(ngrams)))
eval_datasets.append(example.get("task_name", "unknown"))
eval_texts.append(example["text"])
train_dataset_path = Path(args.train_dataset)
if train_dataset_path.exists() and train_dataset_path.suffix in [".json", ".csv"]:
if train_dataset_path.suffix == ".json":
train_data = Dataset.from_json(args.train_dataset)
elif train_dataset_path.suffix == ".csv":
train_data = Dataset.from_csv(args.train_dataset)
else:
train_data = load_dataset(args.train_dataset, split="train")
contamination_report = train_data.map(
lambda batch: retrieve_ngrams_batch(batch, eval_ngrams, eval_datasets, eval_texts, args.ngram_length),
batched=True, batch_size=1000, num_proc=args.num_proc, remove_columns=train_data.column_names
)
contamination_report = contamination_report.map(
lambda example: add_match_stats(example), num_proc=args.num_proc
)
contamination_report.push_to_hub(args.report_dataset_name, private=args.private)
contamination_report = contamination_report.filter(lambda x: x["diff_ratio"] > args.diff_threshold)
if args.save_decontaminated:
contaminated_completions = set(contamination_report["completion"])
filtered_data = train_data.filter(lambda x: x["completion"] not in contaminated_completions)
filtered_data.push_to_hub(args.decontaminated_dataset_name, private=args.private)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Generate a decontamination report for a dataset.")
parser.add_argument("--eval_dataset", type=str,
default="HuggingFaceTB/phi2_eval_data_for_decontamination",
help="Name of the dataset with benchmark samples to use for decontamination.")
parser.add_argument("--train_dataset", type=str, required=True,
help="Path or name of the training dataset to process.")
parser.add_argument("--report_dataset_name", type=str, required=True,
help="Name for the output dataset with decontamination report.")
parser.add_argument("--decontaminated_dataset_name", type=str, help="Name for the decontaminated dataset.")
parser.add_argument("--private", action='store_true', help="Whether to make the output dataset private.")
parser.add_argument("--ngram_length", type=int, default=10, help="Length of the n-grams to consider.")
parser.add_argument("--diff_threshold", type=float, default=0.5,
help="Threshold for filtering based on difference ratio.")
parser.add_argument("--num_proc", type=int, default=90, help="Number of processes to use for map operations.")
parser.add_argument("--save_decontaminated", action='store_true',
help="Whether to save the decontaminated dataset.")
args = parser.parse_args()
main(args)