import ast
import os
import re
from typing import Any, Dict, List, Optional

from recall_eval.eval_metrics import Metric
from recall_eval.prompt import prompt_sys, prompt_user
from utils import save_json


def pprint_format(result):
    # Pretty-print macro-averaged recall and F1 for table and column recall
    print("tables:")
    print("macro-Averaged Recall", result["table"]["macro-Averaged Recall"])
    print("macro-Averaged F1", result["table"]["macro-Averaged F1"])
    print("columns:")
    print("macro-Averaged Recall", result["column"]["macro-Averaged Recall"])
    print("macro-Averaged F1", result["column"]["macro-Averaged F1"])


def make_pred(samples, pred_ext):
    # Pair each parsed prediction with its source sample and gold labels for saving
    assert len(samples) == len(pred_ext)
    preds = []
    for i in range(len(samples)):
        preds.append(
            {
                "query": samples[i]["query"],
                "output": str(pred_ext[i]["output"]),
                "pred_table": str(pred_ext[i]["tables"]),
                "pred_col": str(pred_ext[i]["columns"]),
                "label_table": samples[i]["label_table"],
                "label_col": samples[i]["label_col"],
            }
        )
    return preds


def format_inputs(samples: List[Dict]) -> List:
    """
    输入数据格式化函数,按照 generate 的格式要求改造 inputs
    :param samples: 待格式化样例数据
    :param mode: 格式化模式
    """
    # 把需要推理的数据拼成 message 形式
    msgs = []
    for sample in samples:
        msg_sys = prompt_sys
        msg_user = prompt_user.format(
            table_infos=sample["table_infos"], query=sample["query"]
        )
        msg = [
            {"role": "system", "content": msg_sys},
            {"role": "user", "content": msg_user},
        ]
        msgs.append(msg)
    return msgs
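
# A minimal sketch of the message structure format_inputs produces for one
# hypothetical sample, assuming prompt_sys / prompt_user are plain chat templates:
#
#   sample = {"table_infos": "users(id, name)", "query": "List all user names"}
#   format_inputs([sample])
#   # -> [[{"role": "system", "content": prompt_sys},
#   #      {"role": "user", "content": prompt_user.format(table_infos="users(id, name)",
#   #                                                     query="List all user names")}]]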


def parser_text(text: str) -> Any:
    """
    llm 推理结果解析函数,提取 生成代码 或 召回的表格和字段信息
    :param text: 文本,形如 llm_response['output_text']
    :param mode: 解析模式,gen 为解析生成的代码,extract 为解析召回结果
    """
    pattern_table = r"(?i)tables(?:\s+is)?\s*:\s*\[([^\]]+)\]"
    pattern_column = r"(?i)columns(?:\s+is)?\s*:\s*\[([^\]]+)\]"
    # Normalize full-width brackets and strip backticks before matching
    text = text.replace("【", "[").replace("】", "]").replace("`", "")
    match_tables = re.findall(pattern_table, text.strip())
    match_columns = re.findall(pattern_column, text.strip())
    tables = []
    columns = []
    if match_tables:
        try:
            tables = ast.literal_eval(f"[{match_tables[0]}]")
        except Exception:
            # Fall back to an empty list if the bracketed content is not valid Python
            tables = []
    if match_columns:
        try:
            columns = ast.literal_eval(f"[{match_columns[0]}]")
            # If exactly one table was recalled, qualify bare column names with it
            if len(tables) == 1 and len(columns) > 0:
                columns = [
                    (
                        "{}.{}".format(tables[0], column)
                        if not column.startswith("{}.".format(tables[0]))
                        else column
                    )
                    for column in columns
                ]
        except Exception:
            columns = []
    return {"tables": tables, "columns": columns, "output": text}


def parser_list(batch_resp):
    # Parse a batch of LLM responses
    return [parser_text(resp["output_text"]) for resp in batch_resp]


def save_result(preds, report, test_file_path):
    # Save the LLM-generated content and the final recall_eval report
    parent_path = os.path.dirname(test_file_path)
    save_path = os.path.join(parent_path, "recall_eval_llm_gen_data.json")
    save_json(save_path, preds)
    print(f"Recall Eval Saved:{save_path}")
    save_path = os.path.join(parent_path, "recall_eval_report.json")
    save_json(save_path, report)
    print(f"Recall Eval Saved:{save_path}")


def eval_outputs(
    preds: List[Dict], samples: List[Dict], lang: Optional[str] = None
) -> Dict:
    """
    Compute the evaluation report with the methods in Metric, covering table and column recall.
    :param preds: parsed model predictions (dicts with "tables" and "columns")
    :param samples: test samples from the dataset
    :param lang: optional mode flag; when "python", the table prefix is stripped from label columns
    """

    def combine_metrics_under_key(pred_data, label_data, key):
        combined_metrics = {}
        # Only "averaged" is used here; Metric also provides "jaccard" and "hamming"
        for metric_name in ["averaged"]:
            metric_results = getattr(Metric, metric_name)(pred_data, label_data)
            combined_metrics.update(metric_results)

        return {key: combined_metrics}

    pred_tables = [pred["tables"] for pred in preds]
    pred_columns = [pred["columns"] for pred in preds]
    label_tables = [sample["label_table"] for sample in samples]
    label_columns = [sample["label_col"] for sample in samples]
    if lang == "python":
        label_columns = [[col.split(".")[1] for col in cols] for cols in label_columns]
    table_metrics_combined = combine_metrics_under_key(
        pred_tables, label_tables, "table"
    )
    column_metrics_combined = combine_metrics_under_key(
        pred_columns, label_columns, "column"
    )
    merged_results = {**table_metrics_combined, **column_metrics_combined}
    return merged_results
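

# Illustrative end-to-end wrapper (a sketch, not part of the original pipeline); the
# function name and arguments are assumptions. It wires together the helpers above:
# parse the raw LLM batch responses, align them with the samples, score recall, print
# a summary, and persist both the generations and the report.
def run_recall_eval(samples, batch_resp, test_file_path, lang=None):
    pred_ext = parser_list(batch_resp)  # parsed tables/columns per response
    preds = make_pred(samples, pred_ext)  # predictions aligned with gold labels
    report = eval_outputs(pred_ext, samples, lang=lang)  # recall / F1 metrics
    pprint_format(report)  # print the macro-averaged summary
    save_result(preds, report, test_file_path)  # write *_llm_gen_data.json and *_report.json
    return report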