import numpy as np
import torch
import torch.nn as nn
from optimum.onnxruntime import ORTModelForQuestionAnswering
from transformers import AutoTokenizer

from features.graph_utils import find_best_cluster
from features.text_utils import post_process_answer


class QAEnsembleModel(nn.Module):
    """Ensemble of extractive question-answering models run through ONNX Runtime.

    Every model is queried on every candidate passage. Answers whose raw
    score exceeds ``thr`` are collected, and the final answer is chosen by
    ``find_best_cluster`` using the best answer of the first model as the
    reference.
    """

    def __init__(self, model_name, model_checkpoints, entity_dict, thr=0.1, device="cpu"):
        """Load one model/tokenizer pair per checkpoint.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                ORT model and the tokenizer.
            model_checkpoints: Iterable of paths readable by ``torch.load``;
                one ensemble member is created per entry.
            entity_dict: Passed unchanged to ``post_process_answer``.
            thr: A candidate is kept only when its raw score (best start
                logit + best end logit) is strictly greater than this value.
            device: Accepted for backward compatibility; currently unused
                (checkpoints are always mapped to CPU).
        """
        super().__init__()
        # Plain Python lists on purpose: the members are not registered as
        # torch sub-modules.
        self.models = []
        self.tokenizers = []
        for model_checkpoint in model_checkpoints:
            model = ORTModelForQuestionAnswering.from_pretrained(
                model_name, from_transformers=True
            )
            # NOTE(review): with strict=False any key mismatch is silently
            # ignored. Please confirm that loading a PyTorch state dict into
            # an ORT model really applies the fine-tuned weights to the
            # exported ONNX graph.
            state_dict = torch.load(model_checkpoint, map_location=torch.device("cpu"))
            model.load_state_dict(state_dict, strict=False)
            self.models.append(model)
            self.tokenizers.append(AutoTokenizer.from_pretrained(model_name))
        self.entity_dict = entity_dict
        self.thr = thr

    @staticmethod
    def _extract_answer(model, tokenizer, question, text):
        """Return ``(answer_text, answer_score)`` for one question/passage pair.

        ``answer_text`` is ``None`` when the predicted end index precedes the
        predicted start index, i.e. the model produced no usable span.
        """
        inputs = tokenizer(question, text, return_tensors="pt")
        input_ids = inputs["input_ids"]
        outputs = model(input_ids, attention_mask=inputs["attention_mask"])
        start_logits = outputs.start_logits
        end_logits = outputs.end_logits

        # A single sequence is encoded per call, so the flat argmax is the
        # token position inside that sequence.
        start_idx = int(torch.argmax(start_logits))
        end_idx = int(torch.argmax(end_logits))

        # Raw logit sum, converted to a plain float for comparisons.
        answer_score = (torch.max(start_logits) + torch.max(end_logits)).item()
        if end_idx < start_idx:
            return None, answer_score

        # tokenizer.decode already returns a str; no tensor conversion needed.
        answer_text = tokenizer.decode(input_ids[0][start_idx:end_idx + 1])
        return answer_text, answer_score

    def forward(self, question, texts, ranking_scores=None):
        """Answer ``question`` from the candidate passages ``texts``.

        Args:
            question: Question string.
            texts: Sequence of candidate passages.
            ranking_scores: Optional per-passage weights, iterated in
                lockstep with ``texts``; defaults to all ones.

        Returns:
            The post-processed answer selected by ``find_best_cluster``, or
            ``None`` when no candidate scored above ``self.thr``.
        """
        if ranking_scores is None:
            ranking_scores = np.ones((len(texts),))

        candidates = []
        # Best weighted answer of the first model: the preferred reference.
        primary_answer = None
        primary_score = 0
        # Best weighted answer of any model: used only when the first model
        # yields no reference, so the method no longer fails in that case.
        fallback_answer = None
        fallback_score = None

        # NOTE(review): self.thr is compared against a sum of logits, not a
        # probability; confirm the default of 0.1 is calibrated for that.
        for model_idx, (model, tokenizer) in enumerate(zip(self.models, self.tokenizers)):
            for text, ranking_score in zip(texts, ranking_scores):
                answer_text, answer_score = self._extract_answer(
                    model, tokenizer, question, text
                )
                if answer_text is None or not answer_score > self.thr:
                    continue
                candidates.append(answer_text)

                weighted_score = answer_score * ranking_score
                if fallback_score is None or weighted_score > fallback_score:
                    fallback_answer = answer_text
                    fallback_score = weighted_score
                if model_idx == 0 and weighted_score > primary_score:
                    primary_answer = answer_text
                    primary_score = weighted_score

        if not candidates:
            return None
        if primary_answer is None:
            primary_answer = fallback_answer

        candidates = [post_process_answer(c, self.entity_dict) for c in candidates]
        reference = post_process_answer(primary_answer, self.entity_dict)
        best_cluster_answer = find_best_cluster(candidates, reference)
        return post_process_answer(best_cluster_answer, self.entity_dict)