yentinglin committed on
Commit f25fae7 · verified · 1 Parent(s): 1e7c9b4

Upload decontaminate.py with huggingface_hub

Files changed (1)
  1. decontaminate.py +125 -0
decontaminate.py ADDED
@@ -0,0 +1,125 @@
+import argparse
+import difflib
+import re
+import unicodedata
+from pathlib import Path
+from tqdm.auto import tqdm
+from datasets import load_dataset, Dataset
+import jieba
+
+
+def tokenize(text):
+    """Tokenize Chinese text using Jieba."""
+    tokens = jieba.lcut(text)
+    return tokens
+
+
+def get_ngrams(tokens, n):
+    """Generate n-grams from tokens."""
+    return set(zip(*[tokens[i:] for i in range(n)]))
+
+
+def retrieve_ngrams_batch(batch, eval_ngrams, eval_datasets, eval_texts, ngram_len):
+    """Find contaminated samples based on n-grams."""
+    new_batch = {"completion": [], "ngram": [], "bench_name": [], "bench_text": []}
+    for completion in batch["completion"]:
+        tokens = tokenize(completion)
+        ngrams = get_ngrams(tokens, ngram_len)
+        for ngram in ngrams:
+            if ngram in eval_ngrams:
+                idx = eval_ngrams[ngram]
+                new_batch["completion"].append(completion)
+                new_batch["ngram"].append(ngram)
+                new_batch["bench_name"].append(eval_datasets[idx])
+                new_batch["bench_text"].append(eval_texts[idx])
+                break
+    return new_batch
+
+
+def diff_strings(string1, string2):
+    """Find matching parts between two strings."""
+    matcher = difflib.SequenceMatcher(None, string1.lower(), string2.lower(), autojunk=False)
+    matching_blocks = matcher.get_matching_blocks()
+    matches = []
+    for block in matching_blocks:
+        start_a, start_b, length = block
+        if length > 5:
+            match = string1[start_a:start_a + length]
+            matches.append(match)
+    return matches
+
+
+def add_match_stats(example):
+    """Compute overlap statistics between a completion and its matched benchmark text."""
+    gen_text = " ".join(tokenize(example["completion"]))
+    bench_text = " ".join(tokenize(example["bench_text"]))
+    matching_parts = diff_strings(gen_text, bench_text)
+    match = " ".join("".join(matching_parts).split())
+    example["diff"] = matching_parts
+    example["diff_ratio"] = len(match) / len(bench_text) if len(bench_text) > 0 else 0
+    example["diff_length"] = len(match)
+    example["longest_diff_part"] = max(matching_parts, key=len, default="")
+    example["longest_diff_part_length"] = len(example["longest_diff_part"])
+    return example
+
+
+def main(args):
+    # Load the evaluation data to build the n-gram index
+    eval_ngrams, eval_datasets, eval_texts = {}, [], []
+    eval_data = load_dataset(args.eval_dataset, split="train")
+    for example in tqdm(eval_data):
+        tokens = tokenize(example["text"])
+        ngrams = get_ngrams(tokens, args.ngram_length)
+        if ngrams:
+            idx = len(eval_texts)
+            eval_ngrams.update(zip(ngrams, [idx] * len(ngrams)))
+            eval_datasets.append(example.get("task_name", "unknown"))
+            eval_texts.append(example["text"])
+
+    # Load the training data, either from a local JSON/CSV file or from the Hub
+    train_dataset_path = Path(args.train_dataset)
+    if train_dataset_path.exists() and train_dataset_path.suffix in [".json", ".csv"]:
+        if train_dataset_path.suffix == ".json":
+            train_data = Dataset.from_json(args.train_dataset)
+        elif train_dataset_path.suffix == ".csv":
+            train_data = Dataset.from_csv(args.train_dataset)
+    else:
+        train_data = load_dataset(args.train_dataset, split="train")
+
+    # Flag every training completion that shares an n-gram with a benchmark sample
+    contamination_report = train_data.map(
+        lambda batch: retrieve_ngrams_batch(batch, eval_ngrams, eval_datasets, eval_texts, args.ngram_length),
+        batched=True, batch_size=1000, num_proc=args.num_proc, remove_columns=train_data.column_names
+    )
+
+    # Add diff-based overlap statistics for each flagged sample
+    contamination_report = contamination_report.map(
+        lambda example: add_match_stats(example), num_proc=args.num_proc
+    )
+
+    contamination_report.push_to_hub(args.report_dataset_name, private=args.private)
+
+    # Keep only matches whose overlap ratio exceeds the threshold
+    contamination_report = contamination_report.filter(lambda x: x["diff_ratio"] > args.diff_threshold)
+
+    if args.save_decontaminated:
+        # Drop every completion that appears in the filtered contamination report
+        contaminated_completions = set(contamination_report["completion"])
+        filtered_data = train_data.filter(lambda x: x["completion"] not in contaminated_completions)
+        filtered_data.push_to_hub(args.decontaminated_dataset_name, private=args.private)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Generate a decontamination report for a dataset.")
+    parser.add_argument("--eval_dataset", type=str,
+                        default="HuggingFaceTB/phi2_eval_data_for_decontamination",
+                        help="Name of the dataset with benchmark samples to use for decontamination.")
+    parser.add_argument("--train_dataset", type=str, required=True,
+                        help="Path or name of the training dataset to process.")
+    parser.add_argument("--report_dataset_name", type=str, required=True,
+                        help="Name for the output dataset with decontamination report.")
+    parser.add_argument("--decontaminated_dataset_name", type=str, help="Name for the decontaminated dataset.")
+    parser.add_argument("--private", action='store_true', help="Whether to make the output dataset private.")
+    parser.add_argument("--ngram_length", type=int, default=10, help="Length of the n-grams to consider.")
+    parser.add_argument("--diff_threshold", type=float, default=0.5,
+                        help="Threshold for filtering based on difference ratio.")
+    parser.add_argument("--num_proc", type=int, default=90, help="Number of processes to use for map operations.")
+    parser.add_argument("--save_decontaminated", action='store_true',
+                        help="Whether to save the decontaminated dataset.")
+
+    args = parser.parse_args()
+    main(args)
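
For context, decontaminate.py builds an index of Jieba-tokenized n-grams from a benchmark dataset, flags training completions that share an n-gram with any benchmark sample, scores the overlap with difflib, and can push both the contamination report and a decontaminated dataset to the Hub. Below is a minimal sketch (not part of the uploaded file) of how the script might be driven programmatically instead of via the CLI; the repo names are placeholders, it assumes decontaminate.py is importable from the working directory, and pushing results requires being logged in to the Hugging Face Hub.

# Hypothetical driver for decontaminate.py; all repo names below are placeholders.
from argparse import Namespace

from decontaminate import main

args = Namespace(
    eval_dataset="HuggingFaceTB/phi2_eval_data_for_decontamination",  # benchmark samples with a "text" column
    train_dataset="my-org/my-sft-dataset",                      # placeholder; must expose a "completion" column
    report_dataset_name="my-org/contamination-report",          # placeholder output repo for the report
    decontaminated_dataset_name="my-org/my-sft-dataset-clean",  # placeholder output repo for the cleaned data
    private=True,
    ngram_length=10,       # same defaults as the argparse block above
    diff_threshold=0.5,
    num_proc=8,
    save_decontaminated=True,
)
main(args)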