import torch from os import path from model.utils import action_sequences_to_clusters from model.entity_ranking_model import EntityRankingModel from inference.tokenize_doc import tokenize_and_segment_doc, basic_tokenize_doc from omegaconf import OmegaConf, open_dict from transformers import AutoModel, AutoTokenizer import spacy import json import pytorch_utils.utils as utils class Inference: def __init__(self, model_path): self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.best_model_path = path.join(model_path, "best/model.pth") self._load_model() self.max_segment_len = self.config.model.doc_encoder.transformer.max_segment_len self.tokenizer = self.model.mention_proposer.doc_encoder.tokenizer def find_repr_and_clean(self, basic_tokenized_doc): ## Find marked representatives num_brackets = 0 start_tok = 0 tokens_new = [] ## Contains {{ and }} tokens_proc = [] ## Does not contain {{ and }} basic_tokenized_doc_proc = [] ## Does not contain {{ and }} skip_next = 0 for sentence in basic_tokenized_doc: tokens_sent = [] for token_ind, token in enumerate(sentence): if skip_next: skip_next = 0 continue if token_ind + 1 < len(sentence): if token == "{" and sentence[token_ind + 1] == "{": tokens_new.append("{{") skip_next = 1 elif token == "}" and sentence[token_ind + 1] == "}": tokens_new.append("}}") skip_next = 1 else: tokens_new.append(token) tokens_sent.append(token) else: tokens_new.append(token) tokens_sent.append(token) basic_tokenized_doc_proc.append(tokens_sent) tokens_proc.extend(tokens_sent) active_ent_toks = [] ent_toks = [] for word_ind, word in enumerate(tokens_new): if word == "{{": num_brackets += 1 start_tok += 1 elif word == "}}": num_brackets += 1 active_ent_toks[-1].append( word_ind - num_brackets ) ## Since we included the current bracket upfront new_entity = active_ent_toks.pop() ent_toks.append(new_entity) else: while start_tok > 0: active_ent_toks.append([word_ind - num_brackets]) start_tok -= 1 ent_names = [] for ent in ent_toks: ent_names.append(" ".join(tokens_proc[ent[0] : ent[1] + 1])) print("Entities: ", ent_toks) print("Entity Names: ", ent_names) return basic_tokenized_doc_proc, ent_toks, ent_names def get_ts_from_st(self, subtoken_map, representatives): ts_map = {} for subtoken_ind, token_ind in enumerate(subtoken_map): if token_ind not in ts_map: ts_map[token_ind] = [subtoken_ind] if subtoken_ind != 0: ts_map[token_ind - 1].append(subtoken_ind - 1) ent_toks_st = [] for entity in representatives: start_st = ts_map[entity[0]][0] end_st = ts_map[entity[1]][-1] ent_toks_st.append((start_st, end_st)) return ent_toks_st, ts_map def process_doc_str(self, document): # Raw document string. First perform basic tokenization before further tokenization. basic_tokenizer = spacy.load("en_core_web_trf") basic_tokenized_doc = basic_tokenize_doc(document, basic_tokenizer) basic_tokenized_doc, representatives, representatives_names = ( self.find_repr_and_clean(basic_tokenized_doc) ) tokenized_doc = tokenize_and_segment_doc( basic_tokenized_doc, self.tokenizer, max_segment_len=self.max_segment_len, ) representatives, representatives_names = zip( *sorted(zip(representatives, representatives_names)) ) print("Representatives: ", representatives) print("Representative Names: ", representatives_names) ent_toks_st, ts_map = self.get_ts_from_st( tokenized_doc["subtoken_map"], representatives ) return ( basic_tokenized_doc, tokenized_doc, representatives, representatives_names, ent_toks_st, ts_map, ) def _load_model(self): checkpoint = torch.load(self.best_model_path, map_location="cpu") self.config = checkpoint["config"] self.train_info = checkpoint["train_info"] if self.config.model.doc_encoder.finetune: # Load the document encoder params if encoder is finetuned doc_encoder_dir = path.join( path.dirname(self.best_model_path), self.config.paths.doc_encoder_dirname, ) if path.exists(doc_encoder_dir): self.config.model.doc_encoder.transformer.model_str = doc_encoder_dir self.config.model.memory.thresh = 0.5 self.model = EntityRankingModel(self.config.model, self.config.trainer) # Document encoder parameters will be loaded via the huggingface initialization self.model.load_state_dict(checkpoint["model"], strict=False) if torch.cuda.is_available(): self.model.cuda(device=self.config.device) self.model.eval() @torch.no_grad() def perform_coreference(self, document, doc_name): if isinstance(document, str): ( basic_tokenized_doc, tokenized_doc, ent_toks, ent_names, ent_toks_st, ts_map, ) = self.process_doc_str(document) tokenized_doc["representatives"] = ent_toks_st tokenized_doc["doc_key"] = doc_name tokenized_doc["clusters"] = [] else: raise ValueError ( pred_mentions, pred_mention_emb_list, mention_scores, gt_actions, pred_actions, coref_scores_doc, entity_cluster_states, link_time, ) = self.model(tokenized_doc) idx_clusters = action_sequences_to_clusters( pred_actions, pred_mentions, len(ent_toks_st) ) subtoken_map = tokenized_doc["subtoken_map"] orig_tokens = tokenized_doc["orig_tokens"] clusters = [] for idx_cluster in idx_clusters: cur_cluster = [] for ment_start, ment_end in idx_cluster: cur_cluster.append( ( (subtoken_map[ment_start], subtoken_map[ment_end]), " ".join( orig_tokens[ subtoken_map[ment_start] : subtoken_map[ment_end] + 1 ] ), ) ) clusters.append(cur_cluster) keys_tokenized_doc = list(tokenized_doc.keys()) for key in keys_tokenized_doc: if type(tokenized_doc[key]) == torch.Tensor: del tokenized_doc[key] tokenized_doc["tensorized_sent"] = [ sent.tolist() for sent in tokenized_doc["tensorized_sent"] ] return { "tokenized_doc": tokenized_doc["orig_tokens"], "clusters": clusters, # "subtoken_idx_clusters": idx_clusters, # "actions": pred_actions, # "mentions": pred_mentions, # "representative_embs": entity_cluster_states["mem"], "representative_names": ent_names, } if __name__ == "__main__": ## Arg Parser import argparse parser = argparse.ArgumentParser() parser.add_argument("-m", "--model", type=str, help="Specify model path") parser.add_argument("-d", "--doc", type=str, help="Specify document path") parser.add_argument( "-g", "--gpu", type=str, default="cuda:0", help="Specify GPU device" ) parser.add_argument( "--doc_name", type=str, default="eval_doc", help="Specify encoder name" ) parser.add_argument("-r", "--results", type=str, help="Specify results path") args = parser.parse_args() model_str = args.model doc_str = args.doc model = Inference(model_str) doc_str = open(doc_str).read() output_dict = model.perform_coreference(doc_str, args.doc_name) print("Keys: ", output_dict.keys()) # for cluster_ind, cluster in enumerate(output_dict["clusters"]): # print(f"{cluster_ind}:", cluster) with open(args.results, "w") as f: json.dump(output_dict, f)