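"""Evaluate a trained Mirror checkpoint over a configurable set of datasets.

Loads a task from its run directory, evaluates each (name, path) pair in
``data_pairs``, prints a Rich summary table, and dumps the results to Excel.
"""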
from pathlib import Path

import pandas as pd
from rex.utils.initialization import set_seed_and_log_path
from rex.utils.io import load_json
from rich.console import Console
from rich.table import Table

from src.task import SchemaGuidedInstructBertTask

set_seed_and_log_path(log_path="tmp_eval.log")


if __name__ == "__main__":
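    # Point the evaluation at a finished run directory; swap in another run
    # by changing task_dir (a single-task variant is shown commented below).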
    task_dir = "mirror_outputs/Mirror_Pretrain_AllExcluded_2"
    # task_dir = "mirror_outputs/Mirror_SingleTask_wPTAllExcluded_Event_ACE05"
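    # Rebuild the task from its dumped config and best checkpoint, overriding
    # eval-time options without re-initializing or re-writing the config file.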
    task: SchemaGuidedInstructBertTask = SchemaGuidedInstructBertTask.from_taskdir(
        task_dir,
        load_best_model=True,
        initialize=False,
        dump_configfile=False,
        update_config={
            "regenerate_cache": True,
            "eval_on_data": ["dev"],
            "select_best_on_data": "dev",
            "select_best_by_key": "metric",
            "best_metric_field": "general_spans.micro.f1",
            "eval_batch_size": 32,
        },
    )
    table = Table(title=task_dir)

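    # (dataset name, data path) pairs: the name's prefix (ent_/rel_/event_/...)
    # picks the metric rule in the loop below; uncomment entries to evaluate
    # additional benchmarks.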
    data_pairs = [
        # fmt: off

        # UIE eval data
        # ["ent_ace04_test", "resources/Mirror/uie/ent/ace04/test.jsonl"],
        # ["ent_ace05_test", "resources/Mirror/uie/ent/ace05/test.jsonl"],
        ["ent_conll03_test", "resources/Mirror/uie/ent/conll03/test.jsonl"],
        # ["rel_ace05_test", "resources/Mirror/uie/rel/ace05-rel/test.jsonl"],
        ["rel_conll04_test", "resources/Mirror/uie/rel/conll04/test.jsonl"],
        # ["rel_nyt_test", "resources/Mirror/uie/rel/nyt/test.jsonl"],
        # ["rel_scierc_test", "resources/Mirror/uie/rel/scierc/test.jsonl"],
        ["event_ace05_test", "resources/Mirror/uie/event/ace05-evt/test.jsonl"],
        # ["event_casie_test", "resources/Mirror/uie/event/casie/test.jsonl"],
        # ["absa_14res_test", "resources/Mirror/uie/absa/14res/test.jsonl"],
        # ["absa_14lap_test", "resources/Mirror/uie/absa/14lap/test.jsonl"],
        # ["absa_15res_test", "resources/Mirror/uie/absa/15res/test.jsonl"],
        # ["absa_16res_test", "resources/Mirror/uie/absa/16res/test.jsonl"],
        # # discontinuous NER
        # ["discontinuous_ent", "resources/Mirror/new_abilities_v2/cadec/new/test.jsonl"],
        # # hyper-RE
        # ["hyper_rel", "resources/Mirror/new_abilities_v2/HyperRED/new/test.jsonl"],
        # # zero-shot NER
        # ["ent_movie", "resources/Mirror/v1.3/ent/en/MIT_MOVIE_Review/instructed/test.jsonl"],
        # ["ent_restaurant", "resources/Mirror/v1.3/ent/en/MIT_Restaurant_Review/instructed/test.jsonl"],
        # ["ent_ai", "resources/Mirror/v1.3/ent/en/CrossNER_AI/instructed/test.jsonl"],
        # ["ent_literature", "resources/Mirror/v1.3/ent/en/CrossNER_literature/instructed/test.jsonl"],
        # ["ent_music", "resources/Mirror/v1.3/ent/en/CrossNER_music/instructed/test.jsonl"],
        # ["ent_politics", "resources/Mirror/v1.3/ent/en/CrossNER_politics/instructed/test.jsonl"],
        # ["ent_science", "resources/Mirror/v1.3/ent/en/CrossNER_science/instructed/test.jsonl"],
        # # mrc
        # ["span_squad2", "resources/Mirror/v1.3/span/en/squad_v2/dev.jsonl"],
        # # glue
        # ["cls_glue_cola", "resources/Mirror/v1.3/cls/en/CoLA/formated/dev.jsonl"],
        # ["cls_glue_qqp", "resources/Mirror/v1.3/cls/en/QQP/new/dev.jsonl"],
        # ["cls_glue_mnli", "resources/Mirror/v1.3/cls/en/MNLI/formated/MNLI_dev.jsonl"],
        # ["cls_glue_sst2", "resources/Mirror/v1.3/cls/en/SST-2/instructed/SST-2_dev.jsonl"],
        # ["cls_glue_qnli", "resources/Mirror/v1.3/cls/en/QNLI/processed/QNLI_dev.jsonl"],
        # ["cls_glue_rte", "resources/Mirror/v1.3/cls/en/RTE/formated/RTE_dev.jsonl"],
        # ["cls_glue_mrpc", "resources/Mirror/v1.3/cls/en/MRPC/formated/dev.jsonl"],
        # fmt: on
    ]

    eval_res = {"task": [], "dataset": [], "metric_val": []}
    table.add_column("Task", justify="left", style="cyan")
    table.add_column("Dataset", justify="left", style="magenta")
    table.add_column("Metric (%)", justify="right", style="green")
    for dname, fpath in data_pairs:
        dname = dname.lower()
        task.data_manager.update_datapath(dname, fpath)
        _, res = task.eval(dname, verbose=True, dump=True, dump_middle=True)
        # res = load_json(Path(task_dir) / "measures" / f"{dname}.json")["metrics"]
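        # Previously dumped metrics can be reloaded with load_json instead of
        # re-running eval (this is why Path and load_json are imported).
        # Select the metric field(s) to report based on the dataset prefix: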
        if dname.startswith("ent_"):
            eval_res["task"].append("ent")
            eval_res["dataset"].append(dname)
            eval_res["metric_val"].append(res["ent"]["micro"]["f1"])
        elif dname.startswith("rel_"):
            eval_res["task"].append("rel")
            eval_res["dataset"].append(dname)
            eval_res["metric_val"].append(res["rel"]["rel"]["micro"]["f1"])
        elif dname.startswith("event_"):
            eval_res["task"].append("event")
            eval_res["dataset"].append(dname + "_tgg")
            eval_res["metric_val"].append(res["event"]["trigger_cls"]["f1"])
            eval_res["task"].append("event")
            eval_res["dataset"].append(dname + "_arg")
            eval_res["metric_val"].append(res["event"]["arg_cls"]["f1"])
        elif dname.startswith("absa_"):
            eval_res["task"].append("absa")
            eval_res["dataset"].append(dname)
            eval_res["metric_val"].append(res["rel"]["rel"]["micro"]["f1"])
        elif dname.startswith("cls_"):
            eval_res["task"].append("cls")
            eval_res["dataset"].append(dname)
            if "_glue_" in dname:
                if "_cola" in dname:
                    eval_res["metric_val"].append(res["cls"]["mcc"])
                else:
                    eval_res["metric_val"].append(res["cls"]["acc"])
            else:
                eval_res["metric_val"].append(res["cls"]["mf1"]["micro"]["f1"])
        elif dname.startswith("span"):
            eval_res["task"].append("span_em")
            eval_res["dataset"].append(dname)
            eval_res["metric_val"].append(res["span"]["em"])
            eval_res["task"].append("span_f1")
            eval_res["dataset"].append(dname)
            eval_res["metric_val"].append(res["span"]["f1"]["f1"])
        elif dname.startswith("discontinuous_ent"):
            eval_res["task"].append("discontinuous_ent")
            eval_res["dataset"].append(dname)
            eval_res["metric_val"].append(res["discontinuous_ent"]["micro"]["f1"])
        elif dname.startswith("hyper_rel"):
            eval_res["task"].append("hyper_rel")
            eval_res["dataset"].append(dname)
            eval_res["metric_val"].append(res["hyper_rel"]["micro"]["f1"])
        else:
            raise ValueError(f"no metric rule for dataset: {dname}")

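    # Render the collected metrics as percentages with three decimals.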
    for task_name, dataset, metric_val in zip(
        eval_res["task"], eval_res["dataset"], eval_res["metric_val"]
    ):
        table.add_row(task_name, dataset, f"{100 * metric_val:.3f}")

    console = Console()
    console.print(table)

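    # Persist the raw numbers next to the run's other measures; writing .xlsx
    # requires an Excel engine such as openpyxl to be installed.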
    df = pd.DataFrame(eval_res)
    df.to_excel(task.measures_path.joinpath("data_eval_res.xlsx"))