CarlosMalaga's picture
Upload 201 files
2f044c1 verified
raw
history blame
3.21 kB
import argparse
import torch
from relik.reader.data.relik_reader_sample import load_relik_reader_samples
from relik.reader.pytorch_modules.hf.modeling_relik import (
RelikReaderConfig,
RelikReaderREModel,
)
from relik.reader.pytorch_modules.triplet import RelikReaderForTripletExtraction
from relik.reader.utils.relation_matching_eval import StrongMatching
def eval(model_path, data_path, is_eval, output_path=None):
if model_path.endswith(".ckpt"):
# if it is a lightning checkpoint we load the model state dict and the tokenizer from the config
model_dict = torch.load(model_path)
additional_special_symbols = model_dict["hyper_parameters"][
"additional_special_symbols"
]
from transformers import AutoTokenizer
from relik.reader.utils.special_symbols import get_special_symbols_re
special_symbols = get_special_symbols_re(additional_special_symbols - 1)
tokenizer = AutoTokenizer.from_pretrained(
model_dict["hyper_parameters"]["transformer_model"],
additional_special_tokens=special_symbols,
add_prefix_space=True,
)
config_model = RelikReaderConfig(
model_dict["hyper_parameters"]["transformer_model"],
len(special_symbols),
training=False,
)
model = RelikReaderREModel(config_model)
model_dict["state_dict"] = {
k.replace("relik_reader_re_model.", ""): v
for k, v in model_dict["state_dict"].items()
}
model.load_state_dict(model_dict["state_dict"], strict=False)
reader = RelikReaderForTripletExtraction(
model, training=False, device="cuda", tokenizer=tokenizer
)
else:
# if it is a huggingface model we load the model directly. Note that it could even be a string from the hub
model = RelikReaderREModel.from_pretrained(model_path)
reader = RelikReaderForTripletExtraction(
model, training=False, device="cuda"
) # , dataset_kwargs={"use_nme": True}) if we want to use NME
samples = list(load_relik_reader_samples(data_path))
predicted_samples = reader.read(samples=samples, progress_bar=True)
if is_eval:
strong_matching_metric = StrongMatching()
predicted_samples = list(predicted_samples)
for k, v in strong_matching_metric(predicted_samples).items():
print(f"test_{k}", v)
if output_path is not None:
with open(output_path, "w") as f:
for sample in predicted_samples:
f.write(sample.to_jsons() + "\n")
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_path",
type=str,
default="/root/relik/experiments/relik_reader_re_small",
)
parser.add_argument(
"--data_path",
type=str,
default="/root/relik/data/re/test.jsonl",
)
parser.add_argument("--is-eval", action="store_true")
parser.add_argument("--output_path", type=str, default=None)
args = parser.parse_args()
eval(args.model_path, args.data_path, args.is_eval, args.output_path)
if __name__ == "__main__":
main()