yentinglin
committed on
Upload decontaminate.py with huggingface_hub
- decontaminate.py +125 -0
decontaminate.py
ADDED
@@ -0,0 +1,125 @@
import argparse
import difflib
import re
import unicodedata
from pathlib import Path

from tqdm.auto import tqdm
from datasets import load_dataset, Dataset
import jieba


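# Tokenization uses jieba so that n-gram matching works on Chinese text,
# where tokens are not separated by whitespace.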
def tokenize(text):
    """Tokenize Chinese text using Jieba."""
    tokens = jieba.lcut(text)
    return tokens


def get_ngrams(tokens, n):
    """Generate n-grams from tokens."""
    return set(zip(*[tokens[i:] for i in range(n)]))


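# Illustrative: get_ngrams(["a", "b", "c", "d"], 2) yields
# {("a", "b"), ("b", "c"), ("c", "d")}; tuples like these are the keys
# looked up in eval_ngrams below.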
def retrieve_ngrams_batch(batch, eval_ngrams, eval_datasets, eval_texts, ngram_len):
    """Find contaminated samples based on n-grams."""
    new_batch = {"completion": [], "ngram": [], "bench_name": [], "bench_text": []}
    for completion in batch["completion"]:
        tokens = tokenize(completion)
        ngrams = get_ngrams(tokens, ngram_len)
        for ngram in ngrams:
            if ngram in eval_ngrams:
                idx = eval_ngrams[ngram]
                new_batch["completion"].append(completion)
                new_batch["ngram"].append(ngram)
                new_batch["bench_name"].append(eval_datasets[idx])
                new_batch["bench_text"].append(eval_texts[idx])
                # One shared n-gram is enough to flag this completion.
                break
    return new_batch


def diff_strings(string1, string2):
    """Find matching parts between two strings."""
    matcher = difflib.SequenceMatcher(None, string1.lower(), string2.lower(), autojunk=False)
    matching_blocks = matcher.get_matching_blocks()
    matches = []
    for block in matching_blocks:
        start_a, start_b, length = block
        # Keep only substantial matches (longer than 5 characters).
        if length > 5:
            match = string1[start_a:start_a + length]
            matches.append(match)
    return matches


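# For each flagged pair, quantify the overlap: "diff_ratio" is roughly the
# fraction of the (tokenized) benchmark text covered by matching substrings,
# and "longest_diff_part" is the single longest shared span.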
def add_match_stats(example):
    gen_text = " ".join(tokenize(example["completion"]))
    bench_text = " ".join(tokenize(example["bench_text"]))
    matching_parts = diff_strings(gen_text, bench_text)
    match = " ".join("".join(matching_parts).split())
    example["diff"] = matching_parts
    example["diff_ratio"] = len(match) / len(bench_text) if len(bench_text) > 0 else 0
    example["diff_length"] = len(match)
    example["longest_diff_part"] = max(matching_parts, key=len, default="")
    example["longest_diff_part_length"] = len(example["longest_diff_part"])
    return example


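# Pipeline: (1) index every n-gram in the benchmark set, (2) flag training
# completions that share at least one n-gram with a benchmark sample,
# (3) score each flagged pair by character overlap, (4) push the report
# and, optionally, a decontaminated copy of the training set.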
def main(args):
    # Load the evaluation data to build the n-gram index
    eval_ngrams, eval_datasets, eval_texts = {}, [], []
    eval_data = load_dataset(args.eval_dataset, split="train")
    for example in tqdm(eval_data):
        tokens = tokenize(example["text"])
        ngrams = get_ngrams(tokens, args.ngram_length)
        if ngrams:
            idx = len(eval_texts)
            eval_ngrams.update(zip(ngrams, [idx] * len(ngrams)))
            eval_datasets.append(example.get("task_name", "unknown"))
            eval_texts.append(example["text"])

    # Load the training data from a local file or from the Hub
    train_dataset_path = Path(args.train_dataset)
    if train_dataset_path.exists() and train_dataset_path.suffix in [".json", ".csv"]:
        if train_dataset_path.suffix == ".json":
            train_data = Dataset.from_json(args.train_dataset)
        elif train_dataset_path.suffix == ".csv":
            train_data = Dataset.from_csv(args.train_dataset)
    else:
        train_data = load_dataset(args.train_dataset, split="train")

    # Flag completions that share an n-gram with any benchmark sample
    contamination_report = train_data.map(
        lambda batch: retrieve_ngrams_batch(batch, eval_ngrams, eval_datasets, eval_texts, args.ngram_length),
        batched=True, batch_size=1000, num_proc=args.num_proc, remove_columns=train_data.column_names
    )

    contamination_report = contamination_report.map(
        lambda example: add_match_stats(example), num_proc=args.num_proc
    )

    contamination_report.push_to_hub(args.report_dataset_name, private=args.private)

    # Keep only strong matches when building the decontaminated set
    contamination_report = contamination_report.filter(lambda x: x["diff_ratio"] > args.diff_threshold)

    if args.save_decontaminated:
        contaminated_completions = set(contamination_report["completion"])
        filtered_data = train_data.filter(lambda x: x["completion"] not in contaminated_completions)
        filtered_data.push_to_hub(args.decontaminated_dataset_name, private=args.private)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate a decontamination report for a dataset.")
    parser.add_argument("--eval_dataset", type=str,
                        default="HuggingFaceTB/phi2_eval_data_for_decontamination",
                        help="Name of the dataset with benchmark samples to use for decontamination.")
    parser.add_argument("--train_dataset", type=str, required=True,
                        help="Path or name of the training dataset to process.")
    parser.add_argument("--report_dataset_name", type=str, required=True,
                        help="Name for the output dataset with the decontamination report.")
    parser.add_argument("--decontaminated_dataset_name", type=str, help="Name for the decontaminated dataset.")
    parser.add_argument("--private", action="store_true", help="Whether to make the output dataset private.")
    parser.add_argument("--ngram_length", type=int, default=10, help="Length of the n-grams to consider.")
    parser.add_argument("--diff_threshold", type=float, default=0.5,
                        help="Threshold for filtering based on difference ratio.")
    parser.add_argument("--num_proc", type=int, default=90, help="Number of processes to use for map operations.")
    parser.add_argument("--save_decontaminated", action="store_true",
                        help="Whether to save the decontaminated dataset.")

    args = parser.parse_args()
    main(args)
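
# Example invocation (illustrative; the dataset names below are placeholders,
# not real repos):
#
#   python decontaminate.py \
#       --train_dataset path/to/train.json \
#       --report_dataset_name your-org/decontamination-report \
#       --ngram_length 10 --diff_threshold 0.5 \
#       --save_decontaminated \
#       --decontaminated_dataset_name your-org/train-decontaminated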