# Uploaded by riccorl (riccorl's picture) — commit 626eca0 ("first commit").
import argparse
import torch
from reader.data.relik_reader_sample import load_relik_reader_samples
from relik.reader.pytorch_modules.hf.modeling_relik import (
RelikReaderConfig,
RelikReaderREModel,
)
from relik.reader.relik_reader_re import RelikReaderForTripletExtraction
from relik.reader.utils.relation_matching_eval import StrongMatching
# Maps NYT relation ids (Freebase-style paths) to the natural-language labels
# that are substituted for them in each sample's candidates and gold triplets.
# Insertion order matches the original literal.
# NOTE(review): "geographic distributi6on" looks like a typo, but the label text
# is what the reader sees, so it is kept byte-for-byte; confirm against the
# label set used at training time before correcting it.
dict_nyt = dict(
    [
        ("/people/person/nationality", "nationality"),
        ("/sports/sports_team/location", "sports team location"),
        ("/location/country/administrative_divisions", "administrative divisions"),
        ("/business/company/major_shareholders", "shareholders"),
        ("/people/ethnicity/people", "ethnicity"),
        ("/people/ethnicity/geographic_distribution", "geographic distributi6on"),
        ("/business/company_shareholder/major_shareholder_of", "major shareholder"),
        ("/location/location/contains", "location"),
        ("/business/company/founders", "founders"),
        ("/business/person/company", "company"),
        ("/business/company/advisors", "advisor"),
        ("/people/deceased_person/place_of_death", "place of death"),
        ("/business/company/industry", "industry"),
        ("/people/person/ethnicity", "ethnic background"),
        ("/people/person/place_of_birth", "place of birth"),
        (
            "/location/administrative_division/country",
            "country of an administration division",
        ),
        ("/people/person/place_lived", "place lived"),
        ("/sports/sports_team_location/teams", "sports team"),
        ("/people/person/children", "child"),
        ("/people/person/religion", "religion"),
        ("/location/neighborhood/neighborhood_of", "neighborhood"),
        ("/location/country/capital", "capital"),
        ("/business/company/place_founded", "company founded location"),
        ("/people/person/profession", "occupation"),
    ]
)
def _load_reader(model_path, device):
    """Build a ``RelikReaderForTripletExtraction`` from ``model_path``.

    Args:
        model_path: Either a Lightning ``.ckpt`` file, or anything accepted by
            ``RelikReaderREModel.from_pretrained`` (a local directory or a hub id).
        device: Device string passed through to the reader (e.g. ``"cuda"``).

    Returns:
        A reader constructed with ``training=False``.
    """
    if not model_path.endswith(".ckpt"):
        # Hugging Face-style model (local path or hub id): load it directly.
        model = RelikReaderREModel.from_pretrained(model_path)
        return RelikReaderForTripletExtraction(model, training=False, device=device)

    # Local imports: only the Lightning-checkpoint path needs them.
    from transformers import AutoTokenizer
    from relik.reader.utils.special_symbols import get_special_symbols_re

    # map_location="cpu" lets a checkpoint saved on GPU load on any host. The
    # model built below is created on CPU, so the weights end up there anyway.
    checkpoint = torch.load(model_path, map_location="cpu")
    hparams = checkpoint["hyper_parameters"]
    transformer_model = hparams["transformer_model"]

    # NOTE(review): the "- 1" is preserved from the original code; presumably the
    # stored count includes one symbol not generated here — confirm.
    special_symbols = get_special_symbols_re(hparams["additional_special_symbols"] - 1)
    tokenizer = AutoTokenizer.from_pretrained(
        transformer_model,
        additional_special_tokens=special_symbols,
        add_prefix_space=True,
    )
    config_model = RelikReaderConfig(
        transformer_model,
        len(special_symbols),
        training=False,
    )
    model = RelikReaderREModel(config_model)

    # Checkpoint keys carry a "relik_reader_re_model." prefix (presumably the
    # attribute name in the training wrapper); strip it so keys match `model`.
    state_dict = {
        k.replace("relik_reader_re_model.", ""): v
        for k, v in checkpoint["state_dict"].items()
    }
    incompatible = model.load_state_dict(state_dict, strict=False)
    # strict=False would otherwise silently ignore mismatched keys; surface them.
    missing = getattr(incompatible, "missing_keys", [])
    unexpected = getattr(incompatible, "unexpected_keys", [])
    if missing or unexpected:
        print(
            f"Warning: checkpoint/model key mismatch "
            f"(missing={len(missing)}, unexpected={len(unexpected)}): "
            f"missing={missing[:10]}, unexpected={unexpected[:10]}"
        )
    return RelikReaderForTripletExtraction(
        model, training=False, device=device, tokenizer=tokenizer
    )


def _map_nyt_relations(samples):
    """Replace NYT relation ids with their ``dict_nyt`` labels, in place.

    Rewrites ``sample.candidates`` and the relation ``name`` inside every
    ``sample.triplets`` entry. Raises ``KeyError`` for any relation id that is
    not a key of ``dict_nyt``.
    """
    for sample in samples:
        sample.candidates = [dict_nyt[cand] for cand in sample.candidates]
        sample.triplets = [
            {
                "subject": triplet["subject"],
                "relation": {
                    "name": dict_nyt[triplet["relation"]["name"]],
                    "type": triplet["relation"]["type"],
                },
                "object": triplet["object"],
            }
            for triplet in sample.triplets
        ]


def eval(model_path, data_path, is_eval, output_path=None, device="cuda"):
    """Run the relation-extraction reader over ``data_path`` and report results.

    Note: this function shadows the builtin ``eval``; the name is kept for
    backward compatibility with existing callers.

    Args:
        model_path: Lightning ``.ckpt`` file, or a Hugging Face model path/hub id.
        data_path: Samples file readable by ``load_relik_reader_samples``.
        is_eval: If true, print ``StrongMatching`` metrics as ``test_<name> <value>``.
        output_path: If given, write one ``sample.to_jsons()`` line per predicted
            sample to this file (UTF-8).
        device: Device passed to the reader; defaults to ``"cuda"`` as before.
    """
    reader = _load_reader(model_path, device)
    samples = list(load_relik_reader_samples(data_path))
    _map_nyt_relations(samples)
    predicted_samples = reader.read(samples=samples, progress_bar=True)
    if is_eval:
        strong_matching_metric = StrongMatching()
        # Materialize the predictions: they are consumed by the metric and may
        # be iterated again below when writing the output file.
        predicted_samples = list(predicted_samples)
        for k, v in strong_matching_metric(predicted_samples).items():
            print(f"test_{k}", v)
    if output_path is not None:
        with open(output_path, "w", encoding="utf-8") as f:
            for sample in predicted_samples:
                f.write(sample.to_jsons() + "\n")
def main():
    """Parse command-line arguments and run ``eval`` with them."""
    parser = argparse.ArgumentParser(
        description="Run a ReLiK relation-extraction reader on a samples file "
        "whose relations use NYT relation ids."
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default="/home/huguetcabot/alby-re/relik/relik/reader/models/relik_re_reader_base",
        help="Lightning .ckpt file, or a Hugging Face model directory / hub id.",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        default="/home/huguetcabot/alby-re/relik/relik/reader/data/testa.jsonl",
        help="Samples file to run the reader on (e.g. a .jsonl file).",
    )
    # "--is-eval" is the original spelling; "--is_eval" matches the other flags.
    parser.add_argument(
        "--is-eval",
        "--is_eval",
        dest="is_eval",
        action="store_true",
        help="Compute and print strong-matching metrics.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default=None,
        help="Optional file to write predicted samples to, one JSON line each.",
    )
    args = parser.parse_args()
    eval(args.model_path, args.data_path, args.is_eval, args.output_path)
# Entry point: parse the CLI arguments only when run as a script, not on import.
if __name__ == "__main__":
    main()