|
import argparse |
|
|
|
import torch |
|
|
|
from relik.reader.data.relik_reader_sample import load_relik_reader_samples |
|
from relik.reader.pytorch_modules.hf.modeling_relik import ( |
|
RelikReaderConfig, |
|
RelikReaderREModel, |
|
) |
|
from relik.reader.pytorch_modules.triplet import RelikReaderForTripletExtraction |
|
from relik.reader.utils.relation_matching_eval import StrongMatching |
|
|
|
|
|
def eval(model_path, data_path, is_eval, output_path=None, device=None):
    """Run a ReLiK relation-extraction reader over a dataset.

    NOTE: the name shadows the builtin ``eval``; it is kept unchanged so that
    existing callers keep working.

    Args:
        model_path: Either a checkpoint file whose name ends in ``.ckpt``
            (a dict holding ``hyper_parameters`` and ``state_dict``) or any
            value accepted by ``RelikReaderREModel.from_pretrained``.
        data_path: Path passed to ``load_relik_reader_samples``.
        is_eval: When true, score the predictions with ``StrongMatching``
            and print each metric as ``test_<name> <value>``.
        output_path: Optional file path; when given, every predicted sample
            is written to it as one ``to_jsons()`` line.
        device: Device handed to the reader. When ``None`` (the default),
            ``"cuda"`` is used if CUDA is available, otherwise ``"cpu"``.
    """
    if device is None:
        # Previously hard-coded to "cuda", which fails on CPU-only hosts.
        device = "cuda" if torch.cuda.is_available() else "cpu"

    if model_path.endswith(".ckpt"):
        # Raw checkpoint: rebuild tokenizer, config and model from the
        # hyper-parameters stored next to the weights.
        # map_location="cpu" keeps loading independent of the device the
        # checkpoint was saved from; the reader receives `device` below.
        checkpoint = torch.load(model_path, map_location="cpu")
        hparams = checkpoint["hyper_parameters"]
        additional_special_symbols = hparams["additional_special_symbols"]
        transformer_model = hparams["transformer_model"]

        # Imported lazily: only this branch needs them.
        from transformers import AutoTokenizer

        from relik.reader.utils.special_symbols import get_special_symbols_re

        # NOTE(review): one fewer symbol than the stored count is requested;
        # presumably the stored value counts an extra non-relation symbol.
        # Confirm against the training code.
        special_symbols = get_special_symbols_re(additional_special_symbols - 1)
        tokenizer = AutoTokenizer.from_pretrained(
            transformer_model,
            additional_special_tokens=special_symbols,
            add_prefix_space=True,
        )
        config_model = RelikReaderConfig(
            transformer_model,
            len(special_symbols),
            training=False,
        )
        model = RelikReaderREModel(config_model)

        # Drop the wrapper name from the parameter keys so they line up with
        # the bare model's own parameter names.
        state_dict = {
            key.replace("relik_reader_re_model.", ""): value
            for key, value in checkpoint["state_dict"].items()
        }
        # strict=False: keys that do not match the model are tolerated.
        model.load_state_dict(state_dict, strict=False)
        reader = RelikReaderForTripletExtraction(
            model, training=False, device=device, tokenizer=tokenizer
        )
    else:
        # Anything else is delegated to ``from_pretrained``.
        model = RelikReaderREModel.from_pretrained(model_path)
        reader = RelikReaderForTripletExtraction(
            model, training=False, device=device
        )

    samples = list(load_relik_reader_samples(data_path))
    predicted_samples = reader.read(samples=samples, progress_bar=True)

    if is_eval:
        # Materialize once so the predictions can be scored here and still
        # be iterated again when writing the output file.
        predicted_samples = list(predicted_samples)
        strong_matching_metric = StrongMatching()
        for name, value in strong_matching_metric(predicted_samples).items():
            print(f"test_{name}", value)

    if output_path is not None:
        # Explicit encoding so the JSON lines do not depend on the locale.
        with open(output_path, "w", encoding="utf-8") as f:
            for sample in predicted_samples:
                f.write(sample.to_jsons() + "\n")
|
|
|
|
|
def main():
    """Parse the command-line options and run :func:`eval` with them."""
    parser = argparse.ArgumentParser(
        description="Run a ReLiK relation-extraction reader over a dataset."
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default="/root/relik/experiments/relik_reader_re_small",
        help="Pretrained model location, or a checkpoint file ending in .ckpt.",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        default="/root/relik/data/re/test.jsonl",
        help="Path of the samples file to read.",
    )
    # "--is-eval" is the historical spelling; "--is_eval" is accepted as well
    # so the flag matches the underscore style of the other options.
    parser.add_argument(
        "--is-eval",
        "--is_eval",
        dest="is_eval",
        action="store_true",
        help="Score the predictions and print the metrics.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default=None,
        help="Optional file that receives one JSON line per predicted sample.",
    )
    args = parser.parse_args()
    eval(args.model_path, args.data_path, args.is_eval, args.output_path)
|
|
|
|
|
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|