import ast
import os
import re
from typing import Any, Dict, List
from recall_eval.eval_metrics import Metric
from recall_eval.prompt import prompt_sys, prompt_user
from utils import save_json


def pprint_format(result):
print("tables:")
print("macro-Averaged Recall", result["table"]["macro-Averaged Recall"])
print("macro-Averaged F1", result["table"]["macro-Averaged F1"])
print("columns:")
print("macro-Averaged Recall", result["column"]["macro-Averaged Recall"])
print("macro-Averaged F1", result["column"]["macro-Averaged F1"])


def make_pred(samples, pred_ext):
assert len(samples) == len(pred_ext)
preds = []
for i in range(len(samples)):
preds.append(
{
"query": samples[i]["query"],
"output": str(pred_ext[i]["output"]),
"pred_table": str(pred_ext[i]["tables"]),
"pred_col": str(pred_ext[i]["columns"]),
"label_table": samples[i]["label_table"],
"label_col": samples[i]["label_col"],
}
)
return preds
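
# Example of a single record produced by make_pred (values are illustrative, not
# taken from any real benchmark): the pred_* fields are stringified so the result
# can be dumped with save_json as-is, while label_* fields are kept from the sample.
# {
#     "query": "List the names of all singers",
#     "output": "tables: ['singer']\ncolumns: ['singer.name']",
#     "pred_table": "['singer']",
#     "pred_col": "['singer.name']",
#     "label_table": ["singer"],
#     "label_col": ["singer.name"],
# }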


def format_inputs(samples: List[Dict]) -> List:
    """
    Format input samples into the message structure expected by generate.
    :param samples: samples to format; each must carry "table_infos" and "query"
    :return: a list of chat-style message lists, one per sample
    """
    # Assemble each sample into a chat-style message list for inference
msgs = []
for sample in samples:
msg_sys = prompt_sys
msg_user = prompt_user.format(
table_infos=sample["table_infos"], query=sample["query"]
)
msg = [
{"role": "system", "content": msg_sys},
{"role": "user", "content": msg_user},
]
msgs.append(msg)
return msgs
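
# Illustrative shape of one element returned by format_inputs (the actual prompt
# strings live in recall_eval.prompt and are not reproduced here):
# [
#     {"role": "system", "content": "<prompt_sys>"},
#     {"role": "user", "content": "<prompt_user rendered with table_infos and query>"},
# ]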


def parser_text(text: str) -> Any:
    """
    Parse an LLM response and extract the recalled tables and columns.
    :param text: response text, e.g. llm_response['output_text']
    :return: dict with keys "tables", "columns" and the cleaned "output" text
    """
pattern_table = r"(?i)tables(?:\s+is)?\s*:\s*\[([^\]]+)\]"
pattern_column = r"(?i)columns(?:\s+is)?\s*:\s*\[([^\]]+)\]"
text = text.replace("【", "[").replace("】", "]").replace("`", "")
match_tables = re.findall(pattern_table, text.strip())
match_columns = re.findall(pattern_column, text.strip())
tables = []
columns = []
if match_tables:
try:
tables = ast.literal_eval(f"[{match_tables[0]}]")
        except Exception:
tables = []
if match_columns:
try:
columns = ast.literal_eval(f"[{match_columns[0]}]")
if len(tables) == 1 and len(columns) > 0:
columns = [
(
"{}.{}".format(tables[0], column)
if not column.startswith("{}.".format(tables[0]))
else column
)
for column in columns
]
        except Exception:
columns = []
return {"tables": tables, "columns": columns, "output": text}


def parser_list(batch_resp):
    # Parse a batch of LLM responses via parser_text
return [parser_text(resp["output_text"]) for resp in batch_resp]


def save_result(preds, report, test_file_path):
    # Persist the LLM-generated predictions and the final recall_eval report
parent_path = os.path.dirname(test_file_path)
save_path = os.path.join(parent_path, "recall_eval_llm_gen_data.json")
save_json(save_path, preds)
print(f"Recall Eval Saved:{save_path}")
save_path = os.path.join(parent_path, "recall_eval_report.json")
save_json(save_path, report)
print(f"Recall Eval Saved:{save_path}")


def eval_outputs(preds: List[Dict], samples: List[Dict], lang: str = None) -> Dict:
    """
    Compute table and column recall metrics with the methods in Metric.
    :param preds: parsed model predictions (dicts with "tables" and "columns")
    :param samples: test samples from the dataset
    :param lang: if "python", label columns are stripped to bare column names
    :return: dict with "table" and "column" metric groups
    """
def combine_metrics_under_key(pred_data, label_data, key):
combined_metrics = {}
# for metric_name in ["averaged", "jaccard", "hamming"]:
for metric_name in ["averaged"]:
metric_results = getattr(Metric, metric_name)(pred_data, label_data)
combined_metrics.update(metric_results)
return {key: combined_metrics}
pred_tables = [pred["tables"] for pred in preds]
pred_columns = [pred["columns"] for pred in preds]
label_tables = [sample["label_table"] for sample in samples]
label_columns = [sample["label_col"] for sample in samples]
if lang == "python":
label_columns = [[col.split(".")[1] for col in cols] for cols in label_columns]
table_metrics_combined = combine_metrics_under_key(
pred_tables, label_tables, "table"
)
column_metrics_combined = combine_metrics_under_key(
pred_columns, label_columns, "column"
)
merged_results = {**table_metrics_combined, **column_metrics_combined}
return merged_results
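

# Minimal end-to-end sketch of how these helpers are chained in a benchmark run.
# The sample file path and the LLM call are placeholders: this module ships no
# model client, so swap in whatever batched generate() wrapper your harness uses.
if __name__ == "__main__":
    import json

    test_file_path = "recall_eval_test_samples.json"  # hypothetical test set path
    with open(test_file_path, "r", encoding="utf-8") as f:
        samples = json.load(f)

    # Build chat-style messages and run (placeholder) batched inference; each
    # response dict is expected to expose "output_text" for parser_list.
    msgs = format_inputs(samples)
    batch_resp = [{"output_text": "tables: ['t1']\ncolumns: ['t1.c1']"} for _ in msgs]
    # batch_resp = your_llm_client.generate(msgs)  # hypothetical client call

    pred_ext = parser_list(batch_resp)
    preds = make_pred(samples, pred_ext)
    report = eval_outputs(pred_ext, samples)
    save_result(preds, report, test_file_path)
    pprint_format(report)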