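"""Evaluate a ReLiK relation-extraction reader and optionally dump its predictions.

Loads either a PyTorch Lightning checkpoint (.ckpt) or a Hugging Face model,
runs it over a dataset of reader samples, prints strong-matching metrics when
--is-eval is set, and writes the predicted samples as JSON lines when
--output_path is given.

Example invocation (the script name and paths are illustrative):
    python evaluate_re_reader.py \
        --model_path experiments/relik_reader_re_small \
        --data_path data/re/test.jsonl \
        --is-eval \
        --output_path predictions.jsonl
"""
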
import argparse

import torch

from relik.reader.data.relik_reader_sample import load_relik_reader_samples
from relik.reader.pytorch_modules.hf.modeling_relik import (
    RelikReaderConfig,
    RelikReaderREModel,
)
from relik.reader.pytorch_modules.triplet import RelikReaderForTripletExtraction
from relik.reader.utils.relation_matching_eval import StrongMatching


def eval(model_path, data_path, is_eval, output_path=None):
    if model_path.endswith(".ckpt"):
        # If it is a PyTorch Lightning checkpoint, load the state dict and
        # rebuild the tokenizer from the stored hyper-parameters.
        model_dict = torch.load(model_path)

        additional_special_symbols = model_dict["hyper_parameters"][
            "additional_special_symbols"
        ]
        from transformers import AutoTokenizer

        from relik.reader.utils.special_symbols import get_special_symbols_re

        # Regenerate the relation special symbols and rebuild the tokenizer so
        # they match the setup used at training time.
        special_symbols = get_special_symbols_re(additional_special_symbols - 1)
        tokenizer = AutoTokenizer.from_pretrained(
            model_dict["hyper_parameters"]["transformer_model"],
            additional_special_tokens=special_symbols,
            add_prefix_space=True,
        )
        config_model = RelikReaderConfig(
            model_dict["hyper_parameters"]["transformer_model"],
            len(special_symbols),
            training=False,
        )
        model = RelikReaderREModel(config_model)
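        # The Lightning checkpoint prefixes parameter names with
        # "relik_reader_re_model."; strip it so the keys match the bare model.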
        model_dict["state_dict"] = {
            k.replace("relik_reader_re_model.", ""): v
            for k, v in model_dict["state_dict"].items()
        }
        model.load_state_dict(model_dict["state_dict"], strict=False)
        reader = RelikReaderForTripletExtraction(
            model, training=False, device="cuda", tokenizer=tokenizer
        )
    else:
        # If it is a Hugging Face model, load it directly; the path can also
        # be a model id on the Hub.
        model = RelikReaderREModel.from_pretrained(model_path)
        reader = RelikReaderForTripletExtraction(
            model, training=False, device="cuda"
        )  # pass dataset_kwargs={"use_nme": True} to also use NME

    samples = list(load_relik_reader_samples(data_path))

    predicted_samples = reader.read(samples=samples, progress_bar=True)
    if is_eval:
        strong_matching_metric = StrongMatching()
        predicted_samples = list(predicted_samples)
        for k, v in strong_matching_metric(predicted_samples).items():
            print(f"test_{k}", v)
    if output_path is not None:
        with open(output_path, "w") as f:
            for sample in predicted_samples:
                f.write(sample.to_jsons() + "\n")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path",
        type=str,
        default="/root/relik/experiments/relik_reader_re_small",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        default="/root/relik/data/re/test.jsonl",
    )
    parser.add_argument("--is-eval", action="store_true")
    parser.add_argument("--output_path", type=str, default=None)
    args = parser.parse_args()
    eval(args.model_path, args.data_path, args.is_eval, args.output_path)


if __name__ == "__main__":
    main()