""" |
|
Deduplicate downstream tasks from training dataset. 13-grams have been used. |
|
All split documents with less than 200 characters got filtered. Any document |
|
with more than 10 splits got filtered as well. |
|
""" |
|
|
|
import argparse |
|
from functools import partial |
|
import json |
|
import multiprocessing |
|
import nltk |
|
import pickle |
|
import re |
|
import string |
|
import sys |
|
import time |
|
|
|
def get_words(text):
    """Lower-case the text and return its words and their character offsets."""
    words, positions = [], []
|
for match in re.finditer(r'\w+', text.lower()): |
|
words.append(match.group(0)) |
|
positions.append(match.start()) |
|
return words, positions |
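
# Illustrative behaviour of get_words (example made up for documentation):
#   get_words("Who wrote Hamlet?") -> (['who', 'wrote', 'hamlet'], [0, 4, 10])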
|
|
|
|
|
def split_text(text, start_position, remove_char_each_side, seq):
    """Split the text around a matched ngram, removing remove_char_each_side
    characters on each side, extended to the nearest sentence boundary."""

    punctuations = ".!?"

    # Walk backwards to the previous sentence-ending punctuation.
    pos = start_position - remove_char_each_side
    text_first = ""
    while pos > 0 and text[pos] not in punctuations:
|
pos -= 1 |
|
if pos > 0: |
|
text_first = text[0:pos+1] |
|
|
|
|
|
    # Walk forwards to the next sentence-ending punctuation.
    pos = start_position + len(seq) + remove_char_each_side
    text_second = ""
    while pos < len(text) and text[pos] not in punctuations:
|
pos += 1 |
|
if pos + 1 < len(text): |
|
text_second = text[pos+1:len(text)] |
|
|
|
return text_first, text_second |
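
# Rough sketch of split_text above (values chosen for illustration): with
# remove_char_each_side=200 and a match starting at character p, roughly
# text[p-200 : p+len(seq)+200] is dropped, with both cut points pushed further
# out to the nearest '.', '!' or '?'.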
|
|
|
def check_and_clean_text(args, words, ngrams, text, start_position, \
    text_buf_ngram_free, text_buf, local_ngram):
    """Check whether the ngram formed by `words` occurs in the task ngrams.
    If it does, split/trim `text` around the match (or just count it when
    args.get_ngram_freq_only is set) and return False; return True otherwise."""

    seq = " ".join(words)
    if seq in ngrams:
        print(" [matched]: {}".format(seq), flush=True)
|
|
|
        if args.get_ngram_freq_only:
            # Frequency-only pass: count the match, queue the rest of the text.
            if seq in local_ngram:
|
local_ngram[seq] += 1 |
|
else: |
|
local_ngram[seq] = 1 |
|
|
|
if (start_position + len(seq) + 1) < len(text): |
|
text_buf.append(text[start_position + len(seq) + 1:len(text)]) |
|
return False |
|
|
|
|
|
        # Split the text around the match and keep the long-enough pieces.
        text_first, text_second = split_text(text, start_position, \
            args.remove_char_each_side, seq)
|
|
|
|
|
        # The first part is already ngram free; keep it if it is long enough.
        if len(text_first) > args.filter_text_char_len:
            text_buf_ngram_free.append(text_first)

        # The second part may still contain matches; queue it for rescanning.
        if len(text_second) > args.filter_text_char_len:
            text_buf.append(text_second)
|
|
|
        return False  # a task ngram was matched at this position

    return True  # no match: this position is ngram free
|
|
|
|
|
def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
    """Scan one JSON document (one line of the dedup dataset) for task ngrams
    and return (ngram-free text pieces, trimmed flag, parsed json, local ngram
    counts)."""

    myjson = {}
    try:
        myjson = json.loads(line)
        text_buf = [myjson[key]]
    except Exception as e:
        print("Error: {}".format(e), flush=True)
        text_buf = []
|
|
|
text_buf_ngram_free = [] |
|
local_ngram = {} |
|
while len(text_buf) > 0: |
|
|
|
|
|
text = text_buf.pop(0) |
|
words, positions = get_words(text) |
|
|
|
ngram_free = True |
|
|
|
        # Slide a window of max_ngram_size words over the text.
        for i in range(len(words) - args.max_ngram_size + 1):
|
check_ngram_free = check_and_clean_text(args, words[i:\ |
|
i+args.max_ngram_size], ngrams, text, positions[i], \ |
|
text_buf_ngram_free, text_buf, local_ngram) |
|
|
|
|
|
if not check_ngram_free: |
|
ngram_free = False |
|
break |
|
|
|
|
|
|
|
            # Also check every shorter ngram length present in the tasks.
            for ngram_len, _ in ngrams_freq_sorted:
|
check_ngram_free = check_and_clean_text(args, words[i:\ |
|
i+ngram_len], ngrams, text, positions[i], \ |
|
text_buf_ngram_free, text_buf, local_ngram) |
|
|
|
|
|
if not check_ngram_free: |
|
ngram_free = False |
|
break |
|
|
|
|
|
if not ngram_free: |
|
break |
|
|
|
|
|
        # Check the trailing window of the text against shorter ngram lengths.
        if ngram_free and len(words) - args.max_ngram_size > 0:
|
|
|
last_seq_words = words[(len(words)-args.max_ngram_size):len(words)] |
|
last_seq_start_position = len(words) - args.max_ngram_size |
|
|
|
|
|
for pos, (ngram_len, _) in enumerate(ngrams_freq_sorted): |
|
|
|
|
|
if ngram_len == args.max_ngram_size: |
|
continue |
|
|
|
|
|
for i in range(len(last_seq_words) - ngram_len + 1): |
|
check_ngram_free = check_and_clean_text(args, \ |
|
last_seq_words[i:i+ngram_len], ngrams, text,\ |
|
positions[last_seq_start_position+i], \ |
|
text_buf_ngram_free, text_buf, local_ngram) |
|
|
|
if not check_ngram_free: |
|
ngram_free = False |
|
break |
|
|
|
if not ngram_free: |
|
break |
|
|
|
|
|
if ngram_free and not args.get_ngram_freq_only: |
|
text_buf_ngram_free.append(text) |
|
|
|
|
|
trimmed = 0 |
|
if not args.get_ngram_freq_only and len(text_buf_ngram_free) == 1 and \ |
|
len(text_buf_ngram_free[0]) < len(myjson[key]): |
|
trimmed = 1 |
|
|
|
return text_buf_ngram_free, trimmed, myjson, local_ngram |
|
|
|
|
|
def insert_dict(words, ngrams, pos):
    """Insert the ngram formed by `words` into the ngrams dict (count starts at 0)."""
    seq = " ".join(words)
    if seq not in ngrams:
        ngrams[seq] = 0
|
|
|
|
|
|
|
def compute_ngrams_insert_dict(args, text, ngrams):
    """Extract ngrams from a task text and add them to the ngrams dict."""
    words, positions = get_words(text)
    if len(words) < args.min_ngram_size:
        return

    # Short texts contribute a single (shorter) ngram.
    if len(words) < args.max_ngram_size:
        insert_dict(words, ngrams, positions[0])

    # Otherwise use a sliding window of max_ngram_size words.
    for i in range(len(words) - args.max_ngram_size + 1):
        insert_dict(words[i:i+args.max_ngram_size], ngrams, positions[i])
|
|
|
|
|
|
|
def process_task_lambda(args, task_file, ngrams):
    """Compute task ngrams from a LAMBADA-style jsonl file (one JSON object
    with a 'text' field per line)."""
    print(' reading from {} and computing ngrams'.format(task_file))
|
with open(task_file, 'r') as f: |
|
for line in f: |
|
try: |
|
myjson = json.loads(line) |
|
text = myjson['text'] |
|
compute_ngrams_insert_dict(args, text, ngrams) |
|
except Exception as e: |
|
print('Error:', e) |
|
print(" Entities in ngrams {}".format(len(ngrams)), flush=True) |
|
|
|
|
|
|
|
def process_task(args, task_name, ngrams):
    """Compute task ngrams for a task loaded via HuggingFace datasets."""
    print(' loading {} from HuggingFace datasets and computing ngrams'.format(
        task_name), flush=True)
    print(" Current entities in ngrams {}".format(len(ngrams)), flush=True)
|
|
|
from datasets import load_dataset |
|
|
|
entities_in_ngrams = len(ngrams) |
|
|
|
|
|
if task_name == 'squad': |
|
dataset = load_dataset('squad_v2', split='validation') |
|
elif task_name == 'natural_questions': |
|
dataset = load_dataset('natural_questions', split='validation') |
|
elif task_name == 'triviaqa': |
|
dataset = load_dataset('trivia_qa', 'unfiltered', split='test') |
|
elif task_name == 'webqa': |
|
dataset = load_dataset('web_questions', split='test') |
|
elif task_name == 'race': |
|
dataset = load_dataset('race', 'all', split='test') |
|
elif task_name == 'drop': |
|
dataset = load_dataset('drop', split='validation') |
|
elif task_name == 'coqa': |
|
dataset = load_dataset('coqa', split='validation') |
|
elif task_name == 'piqa': |
|
dataset = load_dataset('piqa', split='test') |
|
else: |
|
print("Invalid task name: {}".format(task_name), flush=True) |
|
return |
|
|
|
|
|
for line in dataset: |
|
try: |
|
if task_name in ['squad', 'triviaqa', 'webqa', 'race', 'drop']: |
|
text = line['question'] |
|
compute_ngrams_insert_dict(args, text, ngrams) |
|
elif task_name == 'natural_questions': |
|
text = line['question']['text'] |
|
compute_ngrams_insert_dict(args, text, ngrams) |
|
elif task_name == 'coqa': |
|
all_questions = line['questions'] |
|
for question in all_questions: |
|
compute_ngrams_insert_dict(args, question, ngrams) |
|
elif task_name == 'piqa': |
|
text = line['goal'] |
|
compute_ngrams_insert_dict(args, text, ngrams) |
|
except Exception as e: |
|
print('Error:', e) |
|
|
|
print(" After task {} entities in ngrams {}, added {}".format(task_name, \ |
|
len(ngrams), len(ngrams) - entities_in_ngrams), flush=True) |
|
|
|
def compute_tasks_ngrams(args, ngrams): |
|
start_time = time.time() |
|
    for task_name in args.tasks:
|
print('Task: {}'.format(task_name), flush=True) |
|
if task_name == 'lambada': |
|
assert args.lambada_path is not None |
|
process_task_lambda(args, args.lambada_path, ngrams) |
|
else: |
|
process_task(args, task_name, ngrams) |
|
print(" Taken time to compute ngrams {:.2f}".format(time.time() - \ |
|
start_time), flush=True) |
|
|
|
def compute_ngram_freq_sorted(args, ngrams):
    """Count the task ngrams per ngram length and return (length, count)
    pairs sorted by length."""
    ngrams_freq = {}
    for ngram_key in ngrams.keys():
        length = len(ngram_key.split())
        ngrams_freq[length] = ngrams_freq[length] + 1 if length in \
            ngrams_freq else 1
|
|
|
ngrams_freq_sorted = sorted(ngrams_freq.items(), key=lambda item: item[0]) |
|
print(" Ngram frequencies: {}".format(ngrams_freq_sorted), flush=True) |
|
print(" Entities in ngrams {} min_ngram_size {} max_ngram_size {}".format(\ |
|
len(ngrams), ngrams_freq_sorted[0][0], ngrams_freq_sorted[len(\ |
|
ngrams_freq_sorted) -1 ][0]), flush=True) |
|
return ngrams_freq_sorted |
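
# Illustrative shape of the list returned above (counts are made up):
#   [(8, 12), (13, 48000)]   # (ngram length in words, number of task ngrams)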
|
|
|
def get_ngrams_below_threshold(args, ngrams, ngrams_below_threshold, \
    dedup_file, dedup_key, ngrams_freq_sorted):
    """Count how many training documents contain each task ngram and keep only
    the ngrams occurring fewer than args.key_threshold times."""

    start_time = time.time()
    # First pass: only collect frequencies, do not modify any text.
    args.get_ngram_freq_only = True
|
|
|
|
|
num_workers = args.num_threads |
|
pool = multiprocessing.Pool(num_workers) |
|
fin = open(dedup_file, 'r', encoding='utf-8') |
|
free_ngram_abt_partial=partial(free_ngram, args=args, key=dedup_key, \ |
|
ngrams=ngrams, ngrams_freq_sorted=ngrams_freq_sorted) |
|
free_ngrams_abt = pool.imap(free_ngram_abt_partial, fin, 500) |
|
|
|
counter = 0 |
|
for _, _, _, local_ngram in free_ngrams_abt: |
|
counter += 1 |
|
if counter % 1000 == 0: |
|
print(' [compute_stat]> processed {} documents in {:.2f} seconds ...'. |
|
format(counter, time.time() - start_time), flush=True) |
|
        # Count each ngram once per training document in which it appears.
        for local_key in local_ngram:
            if local_key in ngrams:
                ngrams[local_key] += 1
        local_ngram = {}
|
|
|
print(' Time taken to compute statistics {:.2f} seconds'.format(time.time() - \ |
|
start_time), flush=True) |
|
pool.close() |
|
pool.join() |
|
|
|
start_time = time.time() |
|
counter_threshold = 0 |
|
|
|
    # Task ngrams that are rare in the training data are kept for removal.
    for local_key, local_val in ngrams.items():
        if local_val < args.key_threshold:
            print(" [threshold] {} {}".format(local_key, local_val), flush=True)
            counter_threshold += 1
            ngrams_below_threshold[local_key] = 1
|
|
|
print(' Ngrams below threshold {}'.format(counter_threshold), flush=True) |
|
fin.close() |
|
|
|
def clean_ngrams_below_threshold(args, ngrams_below_threshold, dedup_file, \
    dedup_key):
    """Second pass: remove the below-threshold task ngrams from the training
    dataset and write the cleaned documents to args.output."""

    start_time = time.time()
    # Second pass: actually split and trim the text.
    args.get_ngram_freq_only = False
    id_prefix = '-'.join(args.tasks)
|
|
|
|
|
ngrams_freq_sorted = compute_ngram_freq_sorted(args, ngrams_below_threshold) |
|
|
|
|
|
counter = splitted = ignored = split_mt_thld = trimmed_count = 0 |
|
num_workers = args.num_threads |
|
pool = multiprocessing.Pool(num_workers) |
|
fin = open(dedup_file, 'r', encoding='utf-8') |
|
free_ngram_clean_partial=partial(free_ngram, args=args, key=dedup_key, \ |
|
ngrams=ngrams_below_threshold, ngrams_freq_sorted=ngrams_freq_sorted) |
|
free_ngrams_clean = pool.imap(free_ngram_clean_partial, fin, 500) |
|
|
|
out_f = open(args.output, 'wb') |
|
|
|
for text_buf_ngram_free, trimmed, myjson, _ in free_ngrams_clean: |
|
counter += 1 |
|
try: |
|
|
|
trimmed_count += trimmed |
|
|
|
if len(text_buf_ngram_free) > 1: |
|
splitted += 1 |
|
if len(text_buf_ngram_free) == 0: |
|
ignored += 1 |
|
|
|
            # Drop documents split into more pieces than allowed.
            if len(text_buf_ngram_free) > args.splits_count:
                text_buf_ngram_free = []
                split_mt_thld += 1
|
|
|
if args.output is not None: |
|
if "split_id" in myjson: |
|
use_prefix = myjson["split_id"] + "-" |
|
else: |
|
use_prefix = "" |
|
|
|
for i in range(len(text_buf_ngram_free)): |
|
split_id_string = id_prefix + '-{:010d}'.format(int(\ |
|
counter)) + '-{:04d}'.format(int(i)) |
|
myjson[dedup_key] = text_buf_ngram_free[i] |
|
myjson["split_id"] = use_prefix + split_id_string |
|
outjson = json.dumps(myjson, ensure_ascii=False) |
|
|
|
|
|
|
|
out_f.write(outjson.encode('utf-8')) |
|
out_f.write('\n'.encode('utf-8')) |
|
|
|
if counter % 1000 == 0: |
|
print(' [final]> processed {} documents in {:.2f} seconds ...'. |
|
format(counter, time.time() - start_time), flush=True) |
|
except Exception as e: |
|
print('Error:', e) |
|
|
|
print(' [final]> processed {} documents in {:.2f} seconds ...'. |
|
format(counter, time.time() - start_time), flush=True) |
|
|
|
    print(' Total docs {} split {} ignored {} splits > threshold {} trimmed {}'.format(
        counter, splitted, ignored, split_mt_thld, trimmed_count), flush=True)
|
|
|
pool.close() |
|
pool.join() |
|
|
|
out_f.close() |
|
fin.close() |
|
|
|
if __name__ == '__main__': |
|
|
|
|
|
|
|
|
|
print('parsing the arguments ...') |
|
|
|
parser = argparse.ArgumentParser() |
|
    parser.add_argument('--tasks', nargs='*', required=True, default=None,
                       help='Tasks to use for deduplication: currently'
                       ' supports [lambada, squad, natural_questions,'
                       ' triviaqa, webqa, race, drop, coqa, and piqa]')
|
    parser.add_argument('--lambada-path', type=str, default=None,
                       help='Path to the LAMBADA jsonl file (only needed for'
                       ' the lambada task)')
|
parser.add_argument('--dedup-dataset', nargs = '*', default=None, |
|
help='Dataset to deduplicate with the key to use' |
|
' e.g. cc.json text') |
|
parser.add_argument('--output', type=str, default=None, |
|
help='Output file name to save dedup dataset') |
|
    parser.add_argument('--num-threads', type=int, default=40,
                       help='Number of worker processes to use')
|
|
|
parser.add_argument('--max-ngram-size', type=int, default=13, |
|
help='Maximum size of ngram to use.') |
|
parser.add_argument('--min-ngram-size', type=int, default=8, |
|
help='Minimum size of ngram to use.') |
|
parser.add_argument('--filter-text-char-len', type=int, default=200, |
|
help='Remove any text below this length.') |
|
    parser.add_argument('--key-threshold', type=int, default=10,
                       help='Frequency threshold: only task ngrams occurring in'
                       ' fewer than this many training documents are removed')
|
parser.add_argument('--save-dictionary', type=str, default=None, |
|
help='Save the dictionary') |
|
parser.add_argument('--load-dictionary', type=str, default=None, |
|
help='Load the dictionary') |
|
    parser.add_argument('--splits-count', type=int, default=10,
                       help='Remove any document that is split into more than'
                       ' this many pieces')
|
    parser.add_argument('--remove-char-each-side', type=int, default=200,
                       help='Number of characters to remove on each side of a'
                       ' matched ngram')
|
|
|
args = parser.parse_args() |
|
|
|
    assert args.dedup_dataset is not None and len(args.dedup_dataset) == 2, \
        '--dedup-dataset expects a file name and a JSON key, e.g. "cc.json text"'
|
dedup_file = args.dedup_dataset[0] |
|
dedup_key = args.dedup_dataset[1] |
|
|
|
|
|
num_workers = args.num_threads |
|
if args.load_dictionary is None: |
|
|
|
|
|
ngrams = {} |
|
compute_tasks_ngrams(args, ngrams) |
|
|
|
|
|
ngrams_freq_sorted = compute_ngram_freq_sorted(args, ngrams) |
|
|
|
|
|
|
|
ngrams_below_threshold = {} |
|
get_ngrams_below_threshold(args, ngrams, ngrams_below_threshold, \ |
|
dedup_file, dedup_key, ngrams_freq_sorted) |
|
|
|
|
|
if args.save_dictionary is not None: |
|
with open(args.save_dictionary, 'wb') as save_dict_handle: |
|
pickle.dump(ngrams_below_threshold, save_dict_handle) |
|
else: |
|
with open(args.load_dictionary, 'rb') as load_dict_handle: |
|
ngrams_below_threshold = pickle.load(load_dict_handle) |
|
|
|
|
|
if args.output is not None: |
|
clean_ngrams_below_threshold(args, ngrams_below_threshold, \ |
|
dedup_file, dedup_key) |
|
|
|
print('done :-)') |
|
|