File size: 4,917 Bytes
16bbc11 4d65fc7 16bbc11 4d65fc7 16bbc11 4d65fc7 16bbc11 4d65fc7 16bbc11 4d65fc7 16bbc11 4d65fc7 16bbc11 4d65fc7 16bbc11 4d65fc7 16bbc11 4d65fc7 16bbc11 4d65fc7 16bbc11 4d65fc7 16bbc11 4d65fc7 16bbc11 4d65fc7 16bbc11 4d65fc7 16bbc11 4d65fc7 16bbc11 4d65fc7 16bbc11 4d65fc7 16bbc11 4d65fc7 16bbc11 4d65fc7 16bbc11 4d65fc7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
from models.pairwise_model import *
from features.text_utils import *
import regex as re
from models.bm25_utils import BM25Gensim
from models.qa_model import *
from tqdm.auto import tqdm
tqdm.pandas()
from datasets import load_dataset
from transformers import pipeline
class EndpointHandler:
    """End-to-end open-domain question answering over a Wikipedia dump.

    ``get_answer_e2e`` runs the pipeline:

    1. A BM25 index retrieves the top 200 passages from ``df_wiki_windows``.
    2. A pairwise model re-scores them; its scores are multiplied by the BM25
       scores and the 10 best passages are kept.
    3. An extractive QA ensemble picks an answer from those passages.
    4. Answers accepted by ``check_number`` are returned as normalised text;
       any other answer is linked to a Wikipedia title (BM25 + a second
       pairwise model) and returned as ``"wiki/<Title_With_Underscores>"``.

    The instance is also callable (``handler({"inputs": question})``) and then
    returns ``{"answer": ...}``.
    """

    def __init__(self, path=""):
        """Load the Wikipedia tables, entity dictionary, BM25 indexes and models.

        Args:
            path: Directory against which the local checkpoint files and BM25
                index folders are resolved. The default ``""`` resolves them
                relative to the current working directory, exactly as before.
        """
        import os  # local import: the module-level imports are star imports

        def local_file(name):
            # os.path.join("", name) == name, so the default path is a no-op.
            return os.path.join(path, name)

        dataset_repo = "foxxy-hm/e2eqa-wiki"
        # Passage table searched by the stage-1 BM25 index (the variable name
        # suggests windowed chunks of articles -- TODO confirm).
        self.df_wiki_windows = load_dataset(
            dataset_repo, data_files="processed/wikipedia_20220620_cleaned_v2.csv"
        )["train"].to_pandas()
        # Article table (title/text) used for entity linking in stage 2.
        self.df_wiki = load_dataset(
            dataset_repo, data_files="wikipedia_20220620_short.csv"
        )["train"].to_pandas()
        self.df_wiki["title"] = self.df_wiki["title"].apply(str)

        # Dataset.to_dict() yields {column: [row values]} and only the first
        # row is used, so entities.json presumably holds a single object that
        # maps surface form -> "wiki/Title" (TODO confirm against the dataset).
        raw_entities = load_dataset(
            dataset_repo, data_files="processed/entities.json"
        )["train"].to_dict()
        entity_dict = {}
        normalized_entities = {}
        for key, val in raw_entities.items():
            # Drop every "wiki/" and turn underscores into spaces.
            title = val[0].replace("wiki/", "").replace("_", " ")
            entity_dict[key] = title
            # Also index the preprocessed, lower-cased surface form.
            normalized_entities[preprocess(key).lower()] = title
        entity_dict.update(normalized_entities)

        # Stripped article title -> df_wiki index (later duplicates win).
        title2idx = {
            title.strip(): idx
            for title, idx in zip(self.df_wiki.title, self.df_wiki.index.values)
        }

        self.qa_model = QAEnsembleModel(
            "nguyenvulebinh/vi-mrc-large",
            [local_file("qa_model_robust.bin")],
            entity_dict,
        )

        # NOTE: torch.load unpickles the checkpoint; only load trusted files.
        cpu = torch.device('cpu')
        self.pairwise_model_stage1 = PairwiseModel("nguyenvulebinh/vi-mrc-base")
        self.pairwise_model_stage1.load_state_dict(
            torch.load(local_file("pairwise_v2.bin"), map_location=cpu)
        )
        self.pairwise_model_stage1.eval()
        self.pairwise_model_stage2 = PairwiseModel("nguyenvulebinh/vi-mrc-base")
        self.pairwise_model_stage2.load_state_dict(
            torch.load(local_file("pairwise_stage2_seed0.bin"), map_location=cpu)
        )
        # eval() switches off training-only behaviour (e.g. dropout); the
        # stage-2 model is used for inference just like stage 1.
        self.pairwise_model_stage2.eval()

        self.bm25_model_stage1 = BM25Gensim(
            local_file("bm25_stage1/"), entity_dict, title2idx
        )
        self.bm25_model_stage2_full = BM25Gensim(
            local_file("bm25_stage2/full_text/"), entity_dict, title2idx
        )
        self.bm25_model_stage2_title = BM25Gensim(
            local_file("bm25_stage2/title/"), entity_dict, title2idx
        )

    def __call__(self, data):
        """Inference-endpoint entry point.

        Args:
            data: A question string, or a dict carrying it under ``"inputs"``.

        Returns:
            ``{"answer": <answer from get_answer_e2e>}``.

        Raises:
            ValueError: If no question string can be extracted from *data*.
        """
        question = self._extract_question(data)
        return {"answer": self.get_answer_e2e(question)}

    @staticmethod
    def _extract_question(data):
        """Return the question string from a raw string or ``{"inputs": ...}`` payload."""
        question = data.get("inputs", data) if isinstance(data, dict) else data
        if not isinstance(question, str):
            raise ValueError(
                "Expected a question string or a dict with an 'inputs' string"
            )
        return question

    @staticmethod
    def _to_wiki_answer(title):
        """Format a Wikipedia title as ``wiki/<Title_With_Underscores>``."""
        return "wiki/" + title.replace(" ", "_")

    def get_answer_e2e(self, question):
        """Answer *question* end to end.

        Returns:
            ``"Chịu"`` when the QA ensemble finds no answer; the normalised,
            lower-cased answer text when ``check_number`` accepts it; otherwise
            the linked Wikipedia title as ``"wiki/<Title_With_Underscores>"``.
        """
        # Stage 1: BM25 retrieval of 200 candidate passages.
        query = preprocess(question).lower()
        top_n, bm25_scores = self.bm25_model_stage1.get_topk_stage1(query, topk=200)
        passage_texts = self.df_wiki_windows.text.values  # hoisted out of the loop
        texts = [preprocess(passage_texts[i]) for i in top_n]

        # Re-rank: pairwise score weighted by the BM25 score, keep the best 10.
        question = preprocess(question)
        ranking_preds = self.pairwise_model_stage1.stage1_ranking(question, texts)
        ranking_scores = ranking_preds * bm25_scores
        best_idxs = np.argsort(ranking_scores)[-10:]
        ranking_scores = np.array(ranking_scores)[best_idxs]
        texts = np.array(texts)[best_idxs]

        # Extractive QA over the top passages.
        best_answer = self.qa_model(question, texts, ranking_scores)
        if best_answer is None:
            return "Chịu"
        bm25_answer = preprocess(
            str(best_answer).lower(), max_length=128, remove_puncts=True
        )
        if check_number(bm25_answer):
            return bm25_answer.lower()
        return self._map_answer_to_wiki_title(question, best_answer, bm25_answer)

    def _map_answer_to_wiki_title(self, question, best_answer, bm25_answer):
        """Link a non-numeric answer to a Wikipedia title (stage 2).

        Candidates come from the title BM25 index. If the stage-2 pairwise
        model scores none of them at least 0.1, extra candidates are retrieved
        from the full-text BM25 index with "question + answer" and scored as
        well. The best-scoring title is returned via ``_to_wiki_answer``.
        """
        wiki_titles = self.df_wiki.title.values
        wiki_texts = self.df_wiki.text.values

        candidates, _ = self.bm25_model_stage2_title.get_topk_stage2(
            bm25_answer, raw_answer=best_answer
        )
        titles = [wiki_titles[i] for i in candidates]
        texts = [wiki_texts[i] for i in candidates]
        ranking_preds = self.pairwise_model_stage2.stage2_ranking(
            question, best_answer, titles, texts
        )

        if ranking_preds.max() >= 0.1:
            best_title = titles[ranking_preds.argmax()]
        else:
            bm25_question = preprocess(
                str(question).lower(), max_length=128, remove_puncts=True
            )
            extra_candidates, _ = self.bm25_model_stage2_full.get_topk_stage2(
                bm25_question + " " + bm25_answer
            )
            extra_titles = [wiki_titles[i] for i in extra_candidates]
            extra_texts = [wiki_texts[i] for i in extra_candidates]
            # Score only the new candidates so that the concatenated scores
            # stay index-aligned with the concatenated titles.
            extra_preds = self.pairwise_model_stage2.stage2_ranking(
                question, best_answer, extra_titles, extra_texts
            )
            titles = extra_titles + titles
            ranking_preds = np.concatenate([extra_preds, ranking_preds])
            best_title = titles[ranking_preds.argmax()]

        return self._to_wiki_answer(best_title)
class InferencePipeline:
    """Thin callable wrapper that returns the end-to-end answer as a dict."""

    def __init__(self):
        # Builds the whole retrieval + QA stack up front (expensive).
        self.endpoint_handler = EndpointHandler()

    def __call__(self, question):
        """Answer *question* and return ``{"answer": <answer>}``."""
        result = self.endpoint_handler.get_answer_e2e(question)
        return dict(answer=result)
if __name__ == "__main__":
    # Manual smoke run: python <this file> "question 1" "question 2" ...
    # Kept behind the main guard so that importing this module (e.g. by an
    # inference-endpoint runtime, which builds EndpointHandler itself) does
    # not load every model a second time. The wrapper is called directly:
    # transformers.pipeline() only accepts registered task names, so
    # pipeline("qa-model", ...) raised at import time and also shadowed the
    # imported `pipeline` function.
    import sys

    inference_pipeline = InferencePipeline()
    for user_question in sys.argv[1:]:
        print(inference_pipeline(user_question))
|