# Llama3-ChatQA-2-70B/evaluation/long_32k_eval/dataset_evaluator_retro.py
import os
import argparse
import json
import shutil
import re
from datasets import load_dataset, load_metric
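# NOTE: datasets.load_metric is deprecated and has been removed in recent
# releases of the `datasets` library (the `evaluate` package replaces it);
# this script assumes an older `datasets` version that still provides it.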
from huggingface_hub import hf_hub_download
DATASETS = [
"gov_report",
"summ_screen_fd",
"qmsum",
"qasper",
"narrative_qa",
"quality",
"quality_hard",
"contract_nli",
]
PATTERN = re.compile(r'\b[A-D]\b')  # a standalone multiple-choice letter A-D
def find_answer(s):
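    """Return the first standalone A-D letter found in s, or None if there is none."""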
match = PATTERN.search(s)
if match is None:
        return None  # None signals that no answer letter was found
return match.group()
def read_json_data(data_path):
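    """Load the reference test file (a JSON list of examples).

    Each item is expected to provide a 'question' plus either 'answer'
    (a single string) or 'answers' (a list of strings); an 'id' field is
    optional. Returns (id_to_labels, id_list, questions).
    """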
references = []
questions = []
id_to_labels = dict()
id_list = list()
idx = 0
with open(data_path, "r") as f:
examples = json.load(f)
for data_item in examples: # dict_keys(['source', 'paragraph_id', 'question', 'answer', 'sub-paragraphs', 'word_count', 'id', 'ctxs'])
idx_str = str(idx) if 'id' not in data_item else str(data_item['id'])
idx += 1
id_list.append(idx_str)
questions.append(data_item['question'])
if "answers" in data_item:
references.append(data_item['answers'][0])
answer_list = [answer_str for answer_str in data_item['answers']]
id_to_labels[idx_str] = answer_list
elif "answer" in data_item:
references.append(data_item['answer']) # take the single answer
id_to_labels[idx_str] = [data_item['answer']]
else:
raise ValueError("need answer or answers from input json")
return id_to_labels, id_list, questions
def convert_to_seq(aquestion, apred):
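    """Map a predicted multiple-choice letter (e.g. 'A') back to the full
    option text listed in the question. Falls back to the raw prediction
    when no matching option line is found.
    """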
if apred is None:
apred = ""
matched_pred = find_answer(apred)
if matched_pred is None:
matched_pred = apred
apred = '({})'.format(matched_pred)
alist = aquestion.split('\n')
for aitem in alist:
aitem = aitem.strip()
if aitem.startswith(apred):
pred_out = ' '.join(aitem.split(' ')[1:])
            print('mapped prediction {} to option text [{}]'.format(apred, pred_out))
return pred_out
    print('Warning: could not find option {} in question: {}'.format(apred, aquestion))
return apred
# 500 -> 100
def load_prediction(test_file, id_list, id_to_labels, questions, dataset_name):
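    """Read one prediction per line and align predictions with reference ids.

    If there are fewer predictions than references, the id list is truncated
    to match. For the quality datasets, letter predictions are rewritten to
    the full option text so the SCROLLS metric can score them.
    """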
predictions = []
with open(test_file, "r") as f:
for line in f.readlines():
predictions.append(line.strip())
if len(predictions) != len(id_list):
print("NOTE: different number of samples, {} in prediction, yet {} in reference".format(
len(predictions), len(id_list)))
id_list = id_list[0: len(predictions)]
id_to_prediction = dict()
for aid, apred in zip(id_list, predictions):
id_to_prediction[aid] = apred
if dataset_name.startswith('quality'):
print('quality dataset, and rewriting the prediction to the full textual sequence...')
questions = questions[0: len(predictions)]
id_to_prediction = dict()
for aid, aquestion, apred in zip(id_list, questions, predictions):
apred_seq = convert_to_seq(aquestion, apred)
id_to_prediction[aid] = apred_seq
return id_to_prediction, id_list
def main(args, raise_on_errors=False):
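    """Score predictions against references with the SCROLLS metric for each
    requested dataset and print the resulting metrics.
    (raise_on_errors is currently unused; validation errors always raise.)
    """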
datasets = [args.dataset] if args.dataset in DATASETS else DATASETS
for dataset_name in datasets:
print(dataset_name)
        scrolls_metric = load_metric(download_metric(), dataset_name)  # NOTE: loading the metric script can take a while
id_to_labels, id_list, questions = read_json_data(args.datapath)
id_to_pred, id_list = load_prediction(args.gen_test_file,
id_list, id_to_labels, questions,
dataset_name)
if len(id_to_labels) > len(id_list):
print('NOTE: prune the reference set from {} to {}'.format(
len(id_to_labels), len(id_list)))
id_to_labels = {aid:id_to_labels[aid] for aid in id_list}
errors, details = verify(id_to_pred, id_to_labels)
if len(errors) == 0:
metrics = scrolls_metric.compute(**scrolls_metric.convert_from_map_format(id_to_pred, id_to_labels))
print(json.dumps(metrics, indent=4))
dislist = [str(item) for item in metrics['display']]
print('final display:', dataset_name, ' '.join(dislist))
        else:
errors_msg = errors[0] if len(errors) == 1 else " ".join(f"{i}: {err}" for i, err in enumerate(errors))
print(json.dumps(errors, indent=4))
raise ValueError(f"Failed to evaluate due to: {errors_msg}")
def download_metric():
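    """Download the SCROLLS metric script from the Hugging Face Hub and copy
    it to a renamed path so that load_metric can load it as a local script.
    """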
scrolls_metric_path = hf_hub_download(repo_id="tau/scrolls", filename="metrics/scrolls.py", repo_type="dataset")
    updated_scrolls_metric_path = os.path.join(
        os.path.dirname(scrolls_metric_path),
        os.path.basename(scrolls_metric_path).replace(".", "_") + ".py",
    )
shutil.copy(scrolls_metric_path, updated_scrolls_metric_path)
return updated_scrolls_metric_path
def verify(id_to_pred, id_to_labels):
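    """Check that predictions form a str->str dict whose keys match the
    reference ids. Returns (errors, details).
    """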
errors = []
details = {"missing_keys": [], "redundant_keys": []}
if not isinstance(id_to_pred, dict):
        errors.append('The predictions must be provided as a JSON object: {"id1": "prediction1", "id2": "prediction2", ...}')
else:
if not all(isinstance(key, str) for key in id_to_pred.keys()):
errors.append("All keys of the predictions dictionary must be strings")
if not all(isinstance(value, str) for value in id_to_pred.values()):
errors.append("All values of the predictions dictionary must be strings")
if len(errors) == 0:
predictions_keys, reference_keys = set(id_to_pred.keys()), set(id_to_labels.keys())
missing_keys = reference_keys - predictions_keys
redundant_keys = predictions_keys - reference_keys
if len(missing_keys) > 0:
details["missing_keys"] = list(missing_keys)
errors.append(f"There are missing example IDs.")
else:
del details["missing_keys"]
if len(redundant_keys) > 0:
details["redundant_keys"] = list(redundant_keys)
errors.append(f"There are redundant example IDs.")
else:
del details["redundant_keys"]
return errors, details
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Evaluate SCROLLS predictions per dataset")
parser.add_argument("--datapath", type=str,
default=None, help="datapath for test json file [reference]")
parser.add_argument("--gen_test_file", type=str,
default=None, help="generations for test file [system prediction]")
parser.add_argument("--dataset", type=str,
default=None, help="name of the dataset used in scrolls: {}".format(DATASETS))
args = parser.parse_args()
main(args)
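
# Example invocation (file paths are illustrative):
#   python dataset_evaluator_retro.py \
#       --datapath /path/to/qasper_test.json \
#       --gen_test_file /path/to/qasper_predictions.txt \
#       --dataset qasper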