import ast
import os
import re
from typing import Any, Dict, List, Optional

from recall_eval.eval_metrics import Metric
from recall_eval.prompt import prompt_sys, prompt_user
from utils import save_json


def pprint_format(result):
    """Pretty-print the macro-averaged recall / F1 metrics for tables and columns."""
    print("tables:")
    print("macro-Averaged Recall", result["table"]["macro-Averaged Recall"])
    print("macro-Averaged F1", result["table"]["macro-Averaged F1"])
    print("columns:")
    print("macro-Averaged Recall", result["column"]["macro-Averaged Recall"])
    print("macro-Averaged F1", result["column"]["macro-Averaged F1"])


def make_pred(samples, pred_ext):
    """Combine test samples with parsed model outputs into prediction records."""
    assert len(samples) == len(pred_ext), "samples/preds must align one-to-one"
    preds = []
    for i in range(len(samples)):
        preds.append(
            {
                "query": samples[i]["query"],
                "output": str(pred_ext[i]["output"]),
                "pred_table": str(pred_ext[i]["tables"]),
                "pred_col": str(pred_ext[i]["columns"]),
                "label_table": samples[i]["label_table"],
                "label_col": samples[i]["label_col"],
            }
        )
    return preds
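
# Shape of each record emitted by make_pred, assuming a parsed extraction
# like the one parser_text returns (all values below are illustrative):
#
#   {
#       "query": "Which orders cost more than 100?",
#       "output": "tables: ['orders'] columns: ['orders.price']",
#       "pred_table": "['orders']",
#       "pred_col": "['orders.price']",
#       "label_table": ["orders"],
#       "label_col": ["orders.price"],
#   }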


def format_inputs(samples: List[Dict]) -> List:
    """
    Format input samples into the chat-message structure expected by `generate`.

    :param samples: sample data to be formatted
    """
    msgs = []
    for sample in samples:
        msg_sys = prompt_sys
        msg_user = prompt_user.format(
            table_infos=sample["table_infos"], query=sample["query"]
        )
        msg = [
            {"role": "system", "content": msg_sys},
            {"role": "user", "content": msg_user},
        ]
        msgs.append(msg)
    return msgs
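
# Each returned element is a two-message chat payload for one sample, shaped
# like the following (prompt text abbreviated; the real wording comes from
# recall_eval.prompt):
#
#   [
#       {"role": "system", "content": "<prompt_sys>"},
#       {"role": "user", "content": "<prompt_user filled with table_infos and query>"},
#   ]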


def parser_text(text: str) -> Any:
    """
    Parse an LLM response and extract the recalled table and column names.

    :param text: response text, e.g. llm_response['output_text']
    """
    pattern_table = r"(?i)tables(?:\s+is)?\s*:\s*\[([^\]]+)\]"
    pattern_column = r"(?i)columns(?:\s+is)?\s*:\s*\[([^\]]+)\]"
    # Normalize full-width brackets and strip backticks before matching.
    text = text.replace("【", "[").replace("】", "]").replace("`", "")
    match_tables = re.findall(pattern_table, text.strip())
    match_columns = re.findall(pattern_column, text.strip())
    tables = []
    columns = []
    if match_tables:
        try:
            tables = ast.literal_eval(f"[{match_tables[0]}]")
        except Exception:
            tables = []
    if match_columns:
        try:
            columns = ast.literal_eval(f"[{match_columns[0]}]")
            # When exactly one table was recalled, qualify bare column names
            # as "table.column" so they match the label format.
            if len(tables) == 1 and len(columns) > 0:
                columns = [
                    (
                        "{}.{}".format(tables[0], column)
                        if not column.startswith("{}.".format(tables[0]))
                        else column
                    )
                    for column in columns
                ]
        except Exception:
            columns = []
    return {"tables": tables, "columns": columns, "output": text}
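
# A quick illustration of the parsing behavior (the sample response below is
# made up, not real model output):
#
#   >>> parsed = parser_text("tables: ['orders']\ncolumns: ['id', 'price']")
#   >>> parsed["tables"]
#   ['orders']
#   >>> parsed["columns"]  # bare names are qualified with the single table
#   ['orders.id', 'orders.price']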


def parser_list(batch_resp):
    """Parse a batch of LLM responses with parser_text."""
    return [parser_text(resp["output_text"]) for resp in batch_resp]


def save_result(preds, report, test_file_path):
    """Save prediction records and the evaluation report next to the test file."""
    parent_path = os.path.dirname(test_file_path)
    save_path = os.path.join(parent_path, "recall_eval_llm_gen_data.json")
    save_json(save_path, preds)
    print(f"Recall Eval Saved: {save_path}")
    save_path = os.path.join(parent_path, "recall_eval_report.json")
    save_json(save_path, report)
    print(f"Recall Eval Saved: {save_path}")


def eval_outputs(
    preds: List[Dict], samples: List[Dict], lang: Optional[str] = None
) -> Dict:
    """
    Compute evaluation results, using the methods in Metric to score table
    and column recall.

    :param preds: model predictions
    :param samples: test samples from the dataset
    :param lang: for "python", labels are reduced to bare column names
                 (the part after "table.") before scoring
    """

    def combine_metrics_under_key(pred_data, label_data, key):
        combined_metrics = {}
        # Extend this list to fold additional Metric methods into the report.
        for metric_name in ["averaged"]:
            metric_results = getattr(Metric, metric_name)(pred_data, label_data)
            combined_metrics.update(metric_results)
        return {key: combined_metrics}

    pred_tables = [pred["tables"] for pred in preds]
    pred_columns = [pred["columns"] for pred in preds]
    label_tables = [sample["label_table"] for sample in samples]
    label_columns = [sample["label_col"] for sample in samples]
    if lang == "python":
        # Labels are stored as "table.column"; keep only the column part.
        label_columns = [[col.split(".")[1] for col in cols] for cols in label_columns]
    table_metrics_combined = combine_metrics_under_key(
        pred_tables, label_tables, "table"
    )
    column_metrics_combined = combine_metrics_under_key(
        pred_columns, label_columns, "column"
    )
    merged_results = {**table_metrics_combined, **column_metrics_combined}
    return merged_results
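

if __name__ == "__main__":
    # Minimal smoke test of the parsing path (illustrative response text, not
    # real model output). The full pipeline is: format_inputs -> LLM ->
    # parser_list -> make_pred -> eval_outputs -> save_result.
    demo_batch = [{"output_text": "tables: ['orders']\ncolumns: ['id', 'price']"}]
    print(parser_list(demo_batch))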