import os
import argparse
import json
import shutil
import re

from datasets import load_dataset, load_metric
from huggingface_hub import hf_hub_download

from nv.evaluate_f1_sft_zeroshot import evaluate_f1

DATASETS = [
    'doc2dial_full_dialogue',
]

PATTERN = re.compile(r'\b[A-D]\b')


def find_answer(s):
    """Return the first standalone multiple-choice letter (A-D) found in s, or None."""
    match = PATTERN.search(s)
    if match is None:
        return None
    return match.group()


def read_json_data(data_path):
    """Read the reference json file and return label mapping, ids, questions and references."""
    references = []
    questions = []
    id_to_labels = dict()
    id_list = list()
    idx = 0
    with open(data_path, "r") as f:
        examples = json.load(f)
        for data_item in examples:
            # Fall back to the running index when the example has no explicit id.
            idx_str = str(idx) if 'id' not in data_item else str(data_item['id'])
            idx += 1
            id_list.append(idx_str)

            questions.append(data_item['question'])
            if "answers" in data_item:
                references.append(data_item['answers'])
                answer_list = list(data_item['answers'])
                id_to_labels[idx_str] = answer_list
            elif "answer" in data_item:
                references.append([data_item['answer']])
                id_to_labels[idx_str] = [data_item['answer']]
            else:
                raise ValueError("each example in the input json must contain an 'answer' or 'answers' field")
    return id_to_labels, id_list, questions, references


def convert_to_seq(aquestion, apred):
    """Map a raw multiple-choice prediction (e.g. "B") back to the full option text taken from the question."""
    if apred is None:
        apred = ""

    matched_pred = find_answer(apred)
    if matched_pred is None:
        matched_pred = apred

    apred = '({})'.format(matched_pred)

    alist = aquestion.split('\n')
    for aitem in alist:
        aitem = aitem.strip()
        if aitem.startswith(apred):
            # Drop the leading "(X)" token and keep the option text.
            pred_out = ' '.join(aitem.split(' ')[1:])
            print('from {} to [{}]'.format(apred, pred_out))
            return pred_out
    print('Warning: could not find {} in question {}'.format(apred, aquestion))

    return apred
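
# Illustrative sketch (not taken from this file) of the QuALITY-style prompt layout that
# convert_to_seq assumes: each option appears on its own line as "(A) <option text>".
#   aquestion = "Which colour is the sky?\n(A) green\n(B) blue\n(C) red\n(D) yellow"
#   convert_to_seq(aquestion, "The answer is B")  ->  "blue"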


def load_prediction(test_file, id_list, id_to_labels, questions, dataset_name):
    """Load one prediction per line and align it with the reference ids."""
    predictions = []
    with open(test_file, "r") as f:
        for line in f.readlines():
            predictions.append(line.strip())
    if len(predictions) != len(id_list):
        print("NOTE: different number of samples, {} in prediction but {} in reference".format(
            len(predictions), len(id_list)))
        id_list = id_list[0: len(predictions)]

    id_to_prediction = dict()
    for aid, apred in zip(id_list, predictions):
        id_to_prediction[aid] = apred

    if dataset_name.startswith('quality'):
        print('quality dataset, rewriting the prediction to the full textual sequence...')
        questions = questions[0: len(predictions)]
        id_to_prediction = dict()
        for aid, aquestion, apred in zip(id_list, questions, predictions):
            apred_seq = convert_to_seq(aquestion, apred)
            id_to_prediction[aid] = apred_seq

    return id_to_prediction, id_list, predictions


def main(args):
    # Evaluate only the requested dataset if it is known; otherwise evaluate all known datasets.
    datasets = [args.dataset] if args.dataset in DATASETS else DATASETS
    for dataset_name in datasets:
        print(dataset_name)

        ground_truth_file = args.datapath
        prediction_file = args.gen_test_file

        evaluate_f1(ground_truth_file, prediction_file, dataset_name)


def main_orig(args, raise_on_errors=False):
    datasets = [args.dataset] if args.dataset in DATASETS else DATASETS
    for dataset_name in datasets:
        print(dataset_name)

        id_to_labels, id_list, questions, answers = read_json_data(args.datapath)
        id_to_pred, id_list, predictions = load_prediction(args.gen_test_file,
                                                           id_list, id_to_labels, questions,
                                                           dataset_name)

        if len(id_to_labels) > len(id_list):
            print('NOTE: prune the reference set from {} to {}'.format(
                len(id_to_labels), len(id_list)))
            id_to_labels = {aid: id_to_labels[aid] for aid in id_list}

        errors, details = verify(id_to_pred, id_to_labels)

        if len(errors) == 0:
            # NOTE: `scorer` is not defined or imported in this file; it is expected to be
            # provided by the metric implementation when main_orig is used.
            score = scorer(dataset_name, predictions, answers, all_classes=None)
            print('final display:', dataset_name, score)
        else:
            errors_msg = errors[0] if len(errors) == 1 else " ".join(f"{i}: {err}" for i, err in enumerate(errors))
            print(json.dumps(errors, indent=4))
            raise ValueError(f"Failed to evaluate due to: {errors_msg}")


def download_metric():
    # Download the SCROLLS metric script and copy it to a path whose basename is a valid
    # module name (scrolls.py -> scrolls_py.py) so it can be loaded with load_metric.
    scrolls_metric_path = hf_hub_download(repo_id="tau/scrolls", filename="metrics/scrolls.py", repo_type="dataset")
    updated_scrolls_metric_path = os.path.join(
        os.path.dirname(scrolls_metric_path),
        os.path.basename(scrolls_metric_path).replace(".", "_") + ".py"
    )
    shutil.copy(scrolls_metric_path, updated_scrolls_metric_path)
    return updated_scrolls_metric_path
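
# Sketch of how download_metric is typically paired with datasets.load_metric. This call is not
# made anywhere in this file; the pairing below is an assumption based on the SCROLLS setup:
#   scrolls_metric = load_metric(download_metric(), dataset_name)
#   score = scrolls_metric.compute(predictions=predictions, references=answers)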


def verify(id_to_pred, id_to_labels):
    """Check that the predictions dictionary is well formed and covers exactly the reference ids."""
    errors = []
    details = {"missing_keys": [], "redundant_keys": []}
    if not isinstance(id_to_pred, dict):
        errors.append('The predictions must be saved as a JSON object: {"id1": "prediction1", "id2": "prediction2", ...}')
    else:
        if not all(isinstance(key, str) for key in id_to_pred.keys()):
            errors.append("All keys of the predictions dictionary must be strings")
        if not all(isinstance(value, str) for value in id_to_pred.values()):
            errors.append("All values of the predictions dictionary must be strings")
        if len(errors) == 0:
            predictions_keys, reference_keys = set(id_to_pred.keys()), set(id_to_labels.keys())
            missing_keys = reference_keys - predictions_keys
            redundant_keys = predictions_keys - reference_keys

            if len(missing_keys) > 0:
                details["missing_keys"] = list(missing_keys)
                errors.append("There are missing example IDs.")
            else:
                del details["missing_keys"]

            if len(redundant_keys) > 0:
                details["redundant_keys"] = list(redundant_keys)
                errors.append("There are redundant example IDs.")
            else:
                del details["redundant_keys"]

    return errors, details


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluate SCROLLS predictions per dataset")

    parser.add_argument("--datapath", type=str,
                        default=None, help="path to the test json file [reference]")
    parser.add_argument("--gen_test_file", type=str,
                        default=None, help="path to the generations for the test file [system prediction]")
    parser.add_argument("--dataset", type=str,
                        default=None, help="name of the scrolls dataset to evaluate: {}".format(DATASETS))

    args = parser.parse_args()
    main(args)
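
# Example invocation (the script name and file paths are illustrative placeholders):
#   python evaluate_scrolls.py \
#       --datapath data/doc2dial_test.json \
#       --gen_test_file outputs/doc2dial_predictions.txt \
#       --dataset doc2dial_full_dialogue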