import json
from os import path

import spacy
import torch

from inference.tokenize_doc import basic_tokenize_doc, tokenize_and_segment_doc
from model.entity_ranking_model import EntityRankingModel
from model.utils import action_sequences_to_clusters


class Inference:
    """Run coreference inference with a trained entity ranking model."""

    def __init__(self, model_path):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.best_model_path = path.join(model_path, "best/model.pth")
        self._load_model()
        self.max_segment_len = self.config.model.doc_encoder.transformer.max_segment_len
        self.tokenizer = self.model.mention_proposer.doc_encoder.tokenizer
    def find_repr_and_clean(self, basic_tokenized_doc):
        """Find the spans marked with {{ }} and strip the markers.

        Returns the de-bracketed tokenized document, the token spans of the
        marked representatives, and their surface-form names.
        """
        num_brackets = 0
        start_tok = 0
        tokens_new = []  # Flat token list, with {{ and }} merged into single tokens
        tokens_proc = []  # Flat token list without {{ and }}
        basic_tokenized_doc_proc = []  # Sentence-split tokens without {{ and }}
        skip_next = 0
        for sentence in basic_tokenized_doc:
            tokens_sent = []
            for token_ind, token in enumerate(sentence):
                if skip_next:
                    skip_next = 0
                    continue
                if token_ind + 1 < len(sentence):
                    # The basic tokenizer splits "{{" into two "{" tokens; merge them back
                    if token == "{" and sentence[token_ind + 1] == "{":
                        tokens_new.append("{{")
                        skip_next = 1
                    elif token == "}" and sentence[token_ind + 1] == "}":
                        tokens_new.append("}}")
                        skip_next = 1
                    else:
                        tokens_new.append(token)
                        tokens_sent.append(token)
                else:
                    tokens_new.append(token)
                    tokens_sent.append(token)
            basic_tokenized_doc_proc.append(tokens_sent)
            tokens_proc.extend(tokens_sent)

        # Recover the marked spans as token offsets into the bracket-free token list
        active_ent_toks = []
        ent_toks = []
        for word_ind, word in enumerate(tokens_new):
            if word == "{{":
                num_brackets += 1
                start_tok += 1
            elif word == "}}":
                num_brackets += 1
                active_ent_toks[-1].append(
                    word_ind - num_brackets
                )  # num_brackets already counts the current closing bracket
                new_entity = active_ent_toks.pop()
                ent_toks.append(new_entity)
            else:
                # Open one span per "{{" seen since the last regular token
                while start_tok > 0:
                    active_ent_toks.append([word_ind - num_brackets])
                    start_tok -= 1

        ent_names = []
        for ent in ent_toks:
            ent_names.append(" ".join(tokens_proc[ent[0] : ent[1] + 1]))

        print("Entities: ", ent_toks)
        print("Entity Names: ", ent_names)
        return basic_tokenized_doc_proc, ent_toks, ent_names
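
    # Illustrative walk-through (hypothetical input, not from this repo):
    # for the marked text "{{ Alice }} met {{ Bob }} .", basic tokenization
    # yields ["{", "{", "Alice", "}", "}", "met", "{", "{", "Bob", "}", "}", "."],
    # and find_repr_and_clean returns the cleaned tokens
    # ["Alice", "met", "Bob", "."], the spans [[0, 0], [2, 2]], and the
    # names ["Alice", "Bob"].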
    def get_ts_from_st(self, subtoken_map, representatives):
        """Map token-level representative spans to subtoken-level spans."""
        ts_map = {}
        for subtoken_ind, token_ind in enumerate(subtoken_map):
            if token_ind not in ts_map:
                ts_map[token_ind] = [subtoken_ind]
                if subtoken_ind != 0:
                    # A new token starts here, so the previous subtoken
                    # closes the previous token's span
                    ts_map[token_ind - 1].append(subtoken_ind - 1)
        if len(subtoken_map) > 0:
            # Close the span of the final token, which the loop above misses
            ts_map[subtoken_map[-1]].append(len(subtoken_map) - 1)

        ent_toks_st = []
        for entity in representatives:
            start_st = ts_map[entity[0]][0]
            end_st = ts_map[entity[1]][-1]
            ent_toks_st.append((start_st, end_st))
        return ent_toks_st, ts_map
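
    # Illustrative example (hypothetical values): for subtoken_map
    # [0, 0, 1, 2, 2], ts_map becomes {0: [0, 1], 1: [2, 2], 2: [3, 4]},
    # so the token span (1, 2) maps to the subtoken span (2, 4).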
    def process_doc_str(self, document):
        """Tokenize a raw document string and extract the marked representatives."""
        # Perform basic (word-level) tokenization before subword tokenization
        basic_tokenizer = spacy.load("en_core_web_trf")
        basic_tokenized_doc = basic_tokenize_doc(document, basic_tokenizer)
        basic_tokenized_doc, representatives, representatives_names = (
            self.find_repr_and_clean(basic_tokenized_doc)
        )
        tokenized_doc = tokenize_and_segment_doc(
            basic_tokenized_doc,
            self.tokenizer,
            max_segment_len=self.max_segment_len,
        )
        # Sort representatives by span position, keeping the names aligned
        representatives, representatives_names = zip(
            *sorted(zip(representatives, representatives_names))
        )
        print("Representatives: ", representatives)
        print("Representative Names: ", representatives_names)
        ent_toks_st, ts_map = self.get_ts_from_st(
            tokenized_doc["subtoken_map"], representatives
        )
        return (
            basic_tokenized_doc,
            tokenized_doc,
            representatives,
            representatives_names,
            ent_toks_st,
            ts_map,
        )
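
    # For reference: tokenized_doc is expected to carry at least the keys
    # "subtoken_map", "orig_tokens", and "tensorized_sent", which are
    # consumed by perform_coreference below.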
    def _load_model(self):
        checkpoint = torch.load(self.best_model_path, map_location="cpu")
        self.config = checkpoint["config"]
        self.train_info = checkpoint["train_info"]

        if self.config.model.doc_encoder.finetune:
            # If the encoder was finetuned, point the config at the saved
            # document encoder directory
            doc_encoder_dir = path.join(
                path.dirname(self.best_model_path),
                self.config.paths.doc_encoder_dirname,
            )
            if path.exists(doc_encoder_dir):
                self.config.model.doc_encoder.transformer.model_str = doc_encoder_dir

        self.config.model.memory.thresh = 0.5
        self.model = EntityRankingModel(self.config.model, self.config.trainer)
        # Document encoder parameters are loaded via the Hugging Face
        # initialization, so the state dict is applied non-strictly
        self.model.load_state_dict(checkpoint["model"], strict=False)

        if torch.cuda.is_available():
            self.model.cuda(device=self.config.device)
        self.model.eval()
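
    # For reference: the checkpoint at <model_path>/best/model.pth is expected
    # to be a dict with at least the keys "config", "train_info", and "model",
    # as read above.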
    def perform_coreference(self, document, doc_name):
        """Run coreference on a raw document string and return the clusters."""
        if isinstance(document, str):
            (
                basic_tokenized_doc,
                tokenized_doc,
                ent_toks,
                ent_names,
                ent_toks_st,
                ts_map,
            ) = self.process_doc_str(document)
            tokenized_doc["representatives"] = ent_toks_st
            tokenized_doc["doc_key"] = doc_name
            tokenized_doc["clusters"] = []
        else:
            raise ValueError("Expected the document to be a raw string")

        (
            pred_mentions,
            pred_mention_emb_list,
            mention_scores,
            gt_actions,
            pred_actions,
            coref_scores_doc,
            entity_cluster_states,
            link_time,
        ) = self.model(tokenized_doc)

        idx_clusters = action_sequences_to_clusters(
            pred_actions, pred_mentions, len(ent_toks_st)
        )

        subtoken_map = tokenized_doc["subtoken_map"]
        orig_tokens = tokenized_doc["orig_tokens"]
        clusters = []
        for idx_cluster in idx_clusters:
            cur_cluster = []
            for ment_start, ment_end in idx_cluster:
                # Map each mention from subtoken indices back to token indices
                # and recover its surface string
                cur_cluster.append(
                    (
                        (subtoken_map[ment_start], subtoken_map[ment_end]),
                        " ".join(
                            orig_tokens[
                                subtoken_map[ment_start] : subtoken_map[ment_end] + 1
                            ]
                        ),
                    )
                )
            clusters.append(cur_cluster)

        # Drop tensors so the output dict is JSON-serializable
        keys_tokenized_doc = list(tokenized_doc.keys())
        for key in keys_tokenized_doc:
            if isinstance(tokenized_doc[key], torch.Tensor):
                del tokenized_doc[key]
        tokenized_doc["tensorized_sent"] = [
            sent.tolist() for sent in tokenized_doc["tensorized_sent"]
        ]

        return {
            "tokenized_doc": tokenized_doc["orig_tokens"],
            "clusters": clusters,
            # "subtoken_idx_clusters": idx_clusters,
            # "actions": pred_actions,
            # "mentions": pred_mentions,
            # "representative_embs": entity_cluster_states["mem"],
            "representative_names": ent_names,
        }
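
# A minimal programmatic usage sketch (the model path and text below are
# placeholders, not values from this repo):
#
#   inference = Inference("path/to/model_dir")
#   output = inference.perform_coreference(
#       "{{ Alice }} met {{ Bob }} . She greeted him .", "demo_doc"
#   )
#   for cluster in output["clusters"]:
#       print(cluster)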

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", type=str, help="Specify model path")
    parser.add_argument("-d", "--doc", type=str, help="Specify document path")
    # Note: the --gpu flag is parsed but not used below
    parser.add_argument(
        "-g", "--gpu", type=str, default="cuda:0", help="Specify GPU device"
    )
    parser.add_argument(
        "--doc_name", type=str, default="eval_doc", help="Specify document name"
    )
    parser.add_argument("-r", "--results", type=str, help="Specify results path")
    args = parser.parse_args()

    model = Inference(args.model)
    with open(args.doc) as f:
        doc_str = f.read()

    output_dict = model.perform_coreference(doc_str, args.doc_name)
    print("Keys: ", output_dict.keys())
    # for cluster_ind, cluster in enumerate(output_dict["clusters"]):
    #     print(f"{cluster_ind}:", cluster)

    with open(args.results, "w") as f:
        json.dump(output_dict, f)
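
# Example invocation (assuming this file is saved as inference.py; all paths
# are placeholders):
#   python inference.py -m models/coref_model -d data/marked_doc.txt \
#       --doc_name demo_doc -r outputs/clusters.json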