""" |
|
This script will merge prompt-specific train files into a single file per task. |
|
""" |

import json
import os
from argparse import ArgumentParser

tasks = [
    'adversarial_qa',
    'ag_news',
    'ai2_arc_ARC_Challenge',
    'ai2_arc_ARC_Easy',
    'amazon_polarity',
    'anli',
    'app_reviews',
    'cnn_dailymail_3.0.0',
    'common_gen',
    'cos_e_v1.11',
    'cosmos_qa',
    'dbpedia_14',
    'dream',
    'duorc_ParaphraseRC',
    'duorc_SelfRC',
    'gigaword',
    'glue_mrpc',
    'glue_qqp',
    'hellaswag',
    'imdb',
    'kilt_tasks_hotpotqa',
    'multi_news',
    'openbookqa_main',
    'paws_labeled_final',
    'piqa',
    'qasc',
    'quail',
    'quarel',
    'quartz',
    'quoref',
    'race_high',
    'race_middle',
    'ropes',
    'rotten_tomatoes',
    'samsum',
    'sciq',
    'social_i_qa',
    'squad_v2',
    'super_glue_boolq',
    'super_glue_cb',
    'super_glue_copa',
    'super_glue_multirc',
    'super_glue_record',
    'super_glue_rte',
    'super_glue_wic',
    'super_glue_wsc',
    'trec',
    'trivia_qa',
    'web_questions',
    'wiki_bio',
    'wiki_hop',
    'wiki_qa',
    'winogrande_winogrande',
    'wiqa',
    'xsum',
    'yelp_review_full',
]
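

# Each line of a per-prompt train file is expected to be a JSON object with
# 'input' and 'output' string fields. Merging adds a 'task_name_with_prompt'
# field recording the source file, producing records like (illustrative, not a
# real example from the dataset):
#   {"input": "...", "output": "...", "task_name_with_prompt": "imdb_<prompt>.jsonl"}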


def merge_train_folder(train_data_folder, merged_train_data_folder):
    """Merge the per-prompt JSONL files in train_data_folder into one JSONL file per task."""
    os.makedirs(merged_train_data_folder, exist_ok=True)
    # One merged-file counter and one output file handle per task.
    task_counter = {task: 0 for task in tasks}
    fptrs = {task: open(os.path.join(merged_train_data_folder, task + '.jsonl'), 'w') for task in tasks}
    fnames = os.listdir(train_data_folder)
    for idx, fname in enumerate(fnames):
        if idx % 10 == 0:
            print(f'Processed {idx}/{len(fnames)} files ...')
        if fname.endswith('.jsonl') and '_score_eval' not in fname:
            found = False
            for task in tasks:
                if fname.startswith(task):
                    task_counter[task] += 1
                    found = True
                    with open(os.path.join(train_data_folder, fname), 'r') as f:
                        for line in f:
                            record = json.loads(line)
                            # Remember which prompt-specific file the example came from.
                            record['task_name_with_prompt'] = fname
                            # Skip degenerate examples with an empty input or output.
                            if record['input'].strip() == '':
                                print(f'WARNING: Empty input for {fname}')
                                continue
                            if record['output'].strip() == '':
                                print(f'WARNING: Empty output for {fname}')
                                continue
                            fptrs[task].write(json.dumps(record) + '\n')
                    break
            if not found:
                print(f'WARNING: Could not find task for {fname}')

    for fptr in fptrs.values():
        fptr.close()

    # Warn about tasks for which no train files were found.
    for task in tasks:
        if task_counter[task] == 0:
            print('WARNING: No files found for task:', task)

    for task, count in task_counter.items():
        print(f'Task {task} had {count} prompt templates.')


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument(
        "--p3_processed_train_dataset_path",
        type=str,
        required=True,
        help="Path to the processed P3 train dataset. This is the output of the t0_dataset_preproc.py script.",
    )
    parser.add_argument(
        "--p3_processed_merged_train_dataset_path",
        type=str,
        required=True,
        help="Path to the output folder where merged JSONL files will be written.",
    )
    args = parser.parse_args()
    merge_train_folder(args.p3_processed_train_dataset_path, args.p3_processed_merged_train_dataset_path)