import os
import argparse
import json
import shutil
import re

# NOTE: load_metric is deprecated in recent versions of the datasets library
# (the evaluate library is the suggested replacement).
from datasets import load_metric
from huggingface_hub import hf_hub_download

DATASETS = [
    "gov_report",
    "summ_screen_fd",
    "qmsum",
    "qasper",
    "narrative_qa",
    "quality",
    "quality_hard",
    "contract_nli",
]

# Matches a single multiple-choice option letter (A-D) on word boundaries.
PATTERN = re.compile(r'\b[A-D]\b')


def find_answer(s):
    match = PATTERN.search(s)
    if match is None:
        return None  # NOTE: None signals that no option letter was found
    return match.group()


def read_json_data(data_path):
    """Read the reference JSON file; return gold labels, example IDs, and questions."""
    questions = []
    id_to_labels = dict()
    id_list = list()
    idx = 0
    with open(data_path, "r") as f:
        examples = json.load(f)
    for data_item in examples:
        # Expected keys: 'source', 'paragraph_id', 'question', 'answer',
        # 'sub-paragraphs', 'word_count', 'id', 'ctxs'
        idx_str = str(idx) if 'id' not in data_item else str(data_item['id'])
        idx += 1
        id_list.append(idx_str)
        questions.append(data_item['question'])
        if "answers" in data_item:
            id_to_labels[idx_str] = list(data_item['answers'])
        elif "answer" in data_item:
            id_to_labels[idx_str] = [data_item['answer']]  # take the single answer
        else:
            raise ValueError("need 'answer' or 'answers' in the input json")
    return id_to_labels, id_list, questions


def convert_to_seq(aquestion, apred):
    """Map a predicted option letter back to the full option text listed in the question."""
    if apred is None:
        apred = ""
    matched_pred = find_answer(apred)
    if matched_pred is None:
        matched_pred = apred
    apred = '({})'.format(matched_pred)
    for aitem in aquestion.split('\n'):
        aitem = aitem.strip()
        if aitem.startswith(apred):
            # Drop the leading "(X)" token and keep the option text.
            pred_out = ' '.join(aitem.split(' ')[1:])
            print('from {} to [{}]'.format(apred, pred_out))
            return pred_out
    print('Warning: could not find {} in question {}'.format(apred, aquestion))
    return apred


def load_prediction(test_file, id_list, id_to_labels, questions, dataset_name):
    predictions = []
    with open(test_file, "r") as f:
        for line in f:
            predictions.append(line.strip())
    if len(predictions) != len(id_list):
        print("NOTE: different number of samples, {} in prediction vs. {} in reference".format(
            len(predictions), len(id_list)))
        id_list = id_list[: len(predictions)]
    id_to_prediction = dict()
    for aid, apred in zip(id_list, predictions):
        id_to_prediction[aid] = apred
    if dataset_name.startswith('quality'):
        print('quality dataset: rewriting each prediction to the full textual sequence...')
        questions = questions[: len(predictions)]
        id_to_prediction = dict()
        for aid, aquestion, apred in zip(id_list, questions, predictions):
            id_to_prediction[aid] = convert_to_seq(aquestion, apred)
    return id_to_prediction, id_list


def main(args, raise_on_errors=False):
    datasets = [args.dataset] if args.dataset in DATASETS else DATASETS
    for dataset_name in datasets:
        print(dataset_name)
        scrolls_metric = load_metric(download_metric(), dataset_name)  # NOTE: loading the metric script takes time
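        # The reference file is expected to be a JSON list of objects. A minimal
        # hypothetical example (field names follow read_json_data above; 'id' is
        # optional, and 'answers' may replace 'answer' for multi-reference data):
        #   [{"id": "0", "question": "What is studied?", "answer": "topic X"}]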
        id_to_labels, id_list, questions = read_json_data(args.datapath)
        id_to_pred, id_list = load_prediction(
            args.gen_test_file, id_list, id_to_labels, questions, dataset_name)
        if len(id_to_labels) > len(id_list):
            print('NOTE: pruning the reference set from {} to {}'.format(
                len(id_to_labels), len(id_list)))
            id_to_labels = {aid: id_to_labels[aid] for aid in id_list}
        errors, details = verify(id_to_pred, id_to_labels)
        if len(errors) == 0:
            metrics = scrolls_metric.compute(
                **scrolls_metric.convert_from_map_format(id_to_pred, id_to_labels))
            print(json.dumps(metrics, indent=4))
            dislist = [str(item) for item in metrics['display']]
            print('final display:', dataset_name, ' '.join(dislist))
        else:
            errors_msg = errors[0] if len(errors) == 1 else " ".join(
                f"{i}: {err}" for i, err in enumerate(errors))
            print(json.dumps(errors, indent=4))
            if raise_on_errors:  # honor the flag instead of raising unconditionally
                raise ValueError(f"Failed to evaluate due to: {errors_msg}")


def download_metric():
    # Download the SCROLLS metric script and copy it to a filename whose stem
    # contains no dots, so that load_metric can import it as a module.
    scrolls_metric_path = hf_hub_download(
        repo_id="tau/scrolls", filename="metrics/scrolls.py", repo_type="dataset")
    updated_scrolls_metric_path = (
        os.path.dirname(scrolls_metric_path)
        + os.path.basename(scrolls_metric_path).replace(".", "_")
        + ".py"
    )
    shutil.copy(scrolls_metric_path, updated_scrolls_metric_path)
    return updated_scrolls_metric_path


def verify(id_to_pred, id_to_labels):
    errors = []
    details = {"missing_keys": [], "redundant_keys": []}
    if not isinstance(id_to_pred, dict):
        errors.append(
            'The predictions must be saved as a JSON object: '
            '{"id1": "prediction1", "id2": "prediction2", ...}')
    else:
        if not all(isinstance(key, str) for key in id_to_pred.keys()):
            errors.append("All keys of the predictions dictionary must be strings")
        if not all(isinstance(value, str) for value in id_to_pred.values()):
            errors.append("All values of the predictions dictionary must be strings")
        if len(errors) == 0:
            predictions_keys, reference_keys = set(id_to_pred.keys()), set(id_to_labels.keys())
            missing_keys = reference_keys - predictions_keys
            redundant_keys = predictions_keys - reference_keys
            if len(missing_keys) > 0:
                details["missing_keys"] = list(missing_keys)
                errors.append("There are missing example IDs.")
            else:
                del details["missing_keys"]
            if len(redundant_keys) > 0:
                details["redundant_keys"] = list(redundant_keys)
                errors.append("There are redundant example IDs.")
            else:
                del details["redundant_keys"]
    return errors, details


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluate SCROLLS predictions per dataset")
    parser.add_argument("--datapath", type=str, default=None,
                        help="path to the test json file [reference]")
    parser.add_argument("--gen_test_file", type=str, default=None,
                        help="generations for the test file [system predictions]")
    parser.add_argument("--dataset", type=str, default=None,
                        help="name of the SCROLLS dataset: {}".format(DATASETS))
    args = parser.parse_args()
    main(args)
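# Example invocation (file names here are hypothetical; if --dataset is not one
# of DATASETS, every SCROLLS dataset is evaluated in turn):
#   python evaluate_scrolls.py \
#       --datapath data/qasper_test.json \
#       --gen_test_file outputs/qasper_preds.txt \
#       --dataset qasper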