foxxy-hm commited on
Commit
234683b
·
1 Parent(s): 4d65fc7

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +95 -0
handler.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from models.pairwise_model import *
2
+ from features.text_utils import *
3
+ import regex as re
4
+ from models.bm25_utils import BM25Gensim
5
+ from models.qa_model import *
6
+ from tqdm.auto import tqdm
7
+ tqdm.pandas()
8
+ from datasets import load_dataset
9
+ from transformers import pipeline
10
+
11
+ class EndpointHandler():
12
+ def __init__(self, path=""):
13
+ df_wiki_windows = load_dataset("foxxy-hm/e2eqa-wiki", data_files="processed/wikipedia_20220620_cleaned_v2.csv")["train"].to_pandas()
14
+ df_wiki = load_dataset("foxxy-hm/e2eqa-wiki", data_files="wikipedia_20220620_short.csv")["train"].to_pandas()
15
+ df_wiki.title = df_wiki.title.apply(str)
16
+
17
+ entity_dict = load_dataset("foxxy-hm/e2eqa-wiki", data_files="processed/entities.json")["train"].to_dict()
18
+ new_dict = dict()
19
+ for key, val in entity_dict.items():
20
+ val = val[0].replace("wiki/", "").replace("_", " ")
21
+ entity_dict[key] = val
22
+ key = preprocess(key)
23
+ new_dict[key.lower()] = val
24
+ entity_dict.update(new_dict)
25
+ title2idx = dict([(x.strip(), y) for x, y in zip(df_wiki.title, df_wiki.index.values)])
26
+
27
+ qa_model = QAEnsembleModel("nguyenvulebinh/vi-mrc-large", ["qa_model_robust.bin"], entity_dict)
28
+ pairwise_model_stage1 = PairwiseModel("nguyenvulebinh/vi-mrc-base")
29
+ pairwise_model_stage1.load_state_dict(torch.load("pairwise_v2.bin", map_location=torch.device('cpu')))
30
+ pairwise_model_stage1.eval()
31
+
32
+ pairwise_model_stage2 = PairwiseModel("nguyenvulebinh/vi-mrc-base")
33
+ pairwise_model_stage2.load_state_dict(torch.load("pairwise_stage2_seed0.bin", map_location=torch.device('cpu')))
34
+
35
+ bm25_model_stage1 = BM25Gensim("bm25_stage1/", entity_dict, title2idx)
36
+ bm25_model_stage2_full = BM25Gensim("bm25_stage2/full_text/", entity_dict, title2idx)
37
+ bm25_model_stage2_title = BM25Gensim("bm25_stage2/title/", entity_dict, title2idx)
38
+
39
+ self.qa_model = qa_model
40
+ self.pairwise_model_stage1 = pairwise_model_stage1
41
+ self.pairwise_model_stage2 = pairwise_model_stage2
42
+ self.bm25_model_stage1 = bm25_model_stage1
43
+ self.bm25_model_stage2_full = bm25_model_stage2_full
44
+ self.bm25_model_stage2_title = bm25_model_stage2_title
45
+
46
+ def get_answer_e2e(self, question):
47
+ query = preprocess(question).lower()
48
+ top_n, bm25_scores = self.bm25_model_stage1.get_topk_stage1(query, topk=200)
49
+ titles = [preprocess(df_wiki_windows.title.values[i]) for i in top_n]
50
+ texts = [preprocess(df_wiki_windows.text.values[i]) for i in top_n]
51
+
52
+ question = preprocess(question)
53
+ ranking_preds = self.pairwise_model_stage1.stage1_ranking(question, texts)
54
+ ranking_scores = ranking_preds * bm25_scores
55
+
56
+ best_idxs = np.argsort(ranking_scores)[-10:]
57
+ ranking_scores = np.array(ranking_scores)[best_idxs]
58
+ texts = np.array(texts)[best_idxs]
59
+ best_answer = self.qa_model(question, texts, ranking_scores)
60
+ if best_answer is None:
61
+ return "Chịu"
62
+ bm25_answer = preprocess(str(best_answer).lower(), max_length=128, remove_puncts=True)
63
+
64
+ if not check_number(bm25_answer):
65
+ bm25_question = preprocess(str(question).lower(), max_length=128, remove_puncts=True)
66
+ bm25_question_answer = bm25_question + " " + bm25_answer
67
+ candidates, scores = self.bm25_model_stage2_title.get_topk_stage2(bm25_answer, raw_answer=best_answer)
68
+ titles = [df_wiki.title.values[i] for i in candidates]
69
+ texts = [df_wiki.text.values[i] for i in candidates]
70
+ ranking_preds = self.pairwise_model_stage2.stage2_ranking(question, best_answer, titles, texts)
71
+ if ranking_preds.max() >= 0.1:
72
+ final_answer = titles[ranking_preds.argmax()]
73
+ else:
74
+ candidates, scores = self.bm25_model_stage2_full.get_topk_stage2(bm25_question_answer)
75
+ titles = [df_wiki.title.values[i] for i in candidates] + titles
76
+ texts = [df_wiki.text.values[i] for i in candidates] + texts
77
+ ranking_preds = np.concatenate(
78
+ [self.pairwise_model_stage2.stage2_ranking(question, best_answer, titles, texts), ranking_preds])
79
+ final_answer = "wiki/"+titles[ranking_preds.argmax()].replace(" ","_")
80
+ else:
81
+ final_answer = bm25_answer.lower()
82
+ return final_answer
83
+
84
+
85
+ class InferencePipeline:
86
+ def __init__(self):
87
+ self.endpoint_handler = EndpointHandler() # Instantiate the EndpointHandler class
88
+
89
+ def __call__(self, question):
90
+ answer = self.endpoint_handler.get_answer_e2e(question) # Call the get_answer_e2e method from EndpointHandler
91
+ return {"answer": answer} # Return the answer as a dictionary
92
+
93
+
94
+ inference_pipeline = InferencePipeline()
95
+ pipeline = pipeline("qa-model", model=inference_pipeline, tokenizer=None)