foxxy-hm committed on
Commit
4d65fc7
·
1 Parent(s): 16bbc11

Update models/handler.py

Browse files
Files changed (1) hide show
  1. models/handler.py +35 -32
models/handler.py CHANGED
@@ -6,16 +6,14 @@ from models.qa_model import *
6
  from tqdm.auto import tqdm
7
  tqdm.pandas()
8
  from datasets import load_dataset
9
- # from typing import Dict, List, Any
10
- # from transformers import pipeline, AutoTokenizer
11
-
12
 
13
  class EndpointHandler():
14
  def __init__(self, path=""):
15
  df_wiki_windows = load_dataset("foxxy-hm/e2eqa-wiki", data_files="processed/wikipedia_20220620_cleaned_v2.csv")["train"].to_pandas()
16
  df_wiki = load_dataset("foxxy-hm/e2eqa-wiki", data_files="wikipedia_20220620_short.csv")["train"].to_pandas()
17
  df_wiki.title = df_wiki.title.apply(str)
18
-
19
  entity_dict = load_dataset("foxxy-hm/e2eqa-wiki", data_files="processed/entities.json")["train"].to_dict()
20
  new_dict = dict()
21
  for key, val in entity_dict.items():
@@ -25,68 +23,73 @@ class EndpointHandler():
25
  new_dict[key.lower()] = val
26
  entity_dict.update(new_dict)
27
  title2idx = dict([(x.strip(), y) for x, y in zip(df_wiki.title, df_wiki.index.values)])
28
- # load the optimized model
29
  qa_model = QAEnsembleModel("nguyenvulebinh/vi-mrc-large", ["qa_model_robust.bin"], entity_dict)
30
- pairwise_model_stage1 = PairwiseModel("nguyenvulebinh/vi-mrc-base")#.half()
31
  pairwise_model_stage1.load_state_dict(torch.load("pairwise_v2.bin", map_location=torch.device('cpu')))
32
  pairwise_model_stage1.eval()
33
-
34
- pairwise_model_stage2 = PairwiseModel("nguyenvulebinh/vi-mrc-base")#.half()
35
  pairwise_model_stage2.load_state_dict(torch.load("pairwise_stage2_seed0.bin", map_location=torch.device('cpu')))
36
-
37
  bm25_model_stage1 = BM25Gensim("bm25_stage1/", entity_dict, title2idx)
38
  bm25_model_stage2_full = BM25Gensim("bm25_stage2/full_text/", entity_dict, title2idx)
39
  bm25_model_stage2_title = BM25Gensim("bm25_stage2/title/", entity_dict, title2idx)
40
- # # create inference pipeline
41
- # self.pipeline = pipeline("text-classification", model=model, tokenizer=tokenizer)
42
 
 
 
 
 
 
 
 
43
  def get_answer_e2e(self, question):
  # Removed side of the diff: answer `question` end to end (BM25 retrieval,
  # pairwise re-ranking, QA model, then mapping the answer onto a wiki title).
  # NOTE(review): indentation is not preserved in this rendering, so the
  # comments describe statement order only, not nesting.
  # NOTE(review): the models and BM25 indexes are called by bare name here,
  # while the visible __init__ lines bind them only as locals - looks like a
  # NameError at call time unless the star import supplies them; confirm.
44
- #Bm25 retrieval for top200 candidates
45
  query = preprocess(question).lower()
46
- top_n, bm25_scores = bm25_model_stage1.get_topk_stage1(query, topk=200)
47
  titles = [preprocess(df_wiki_windows.title.values[i]) for i in top_n]
48
  texts = [preprocess(df_wiki_windows.text.values[i]) for i in top_n]
49
-
50
- #Reranking with pairwise model for top10
51
  question = preprocess(question)
52
- ranking_preds = pairwise_model_stage1.stage1_ranking(question, texts)
53
  # Combined score = pairwise score multiplied by the BM25 score.
  ranking_scores = ranking_preds * bm25_scores
54
-
55
- #Question answering
56
  # argsort is ascending, so the last 10 indices are the 10 highest scores.
  best_idxs = np.argsort(ranking_scores)[-10:]
57
  ranking_scores = np.array(ranking_scores)[best_idxs]
58
  texts = np.array(texts)[best_idxs]
59
- best_answer = qa_model(question, texts, ranking_scores)
60
  # No answer from the QA model: return the fixed fallback string.
  if best_answer is None:
61
  return "Chịu"
62
  bm25_answer = preprocess(str(best_answer).lower(), max_length=128, remove_puncts=True)
63
-
64
- #Entity mapping
65
  if not check_number(bm25_answer):
66
  bm25_question = preprocess(str(question).lower(), max_length=128, remove_puncts=True)
67
  bm25_question_answer = bm25_question + " " + bm25_answer
68
- candidates, scores = bm25_model_stage2_title.get_topk_stage2(bm25_answer, raw_answer=best_answer)
69
  titles = [df_wiki.title.values[i] for i in candidates]
70
  texts = [df_wiki.text.values[i] for i in candidates]
71
- ranking_preds = pairwise_model_stage2.stage2_ranking(question, best_answer, titles, texts)
72
  if ranking_preds.max() >= 0.1:
73
  final_answer = titles[ranking_preds.argmax()]
74
  else:
75
- candidates, scores = bm25_model_stage2_full.get_topk_stage2(bm25_question_answer)
76
  titles = [df_wiki.title.values[i] for i in candidates] + titles
77
  texts = [df_wiki.text.values[i] for i in candidates] + texts
78
  ranking_preds = np.concatenate(
79
- [pairwise_model_stage2.stage2_ranking(question, best_answer, titles, texts), ranking_preds])
80
  # NOTE(review): cannot tell from this rendering whether the next line sits
  # inside the `else` above or after the whole if/else; confirm in the file.
  final_answer = "wiki/"+titles[ranking_preds.argmax()].replace(" ","_")
81
  else:
82
  final_answer = bm25_answer.lower()
83
- return final_answer
84
 
85
 
 
 
 
 
86
  def __call__(self, question):
  # Removed side of the diff: handler entry point. Forwards `question` to
  # get_answer_e2e and wraps the result as {"answer": <result>}.
87
- """
88
- """
89
- # Call the get_answer_e2e method with the question
90
- answer = self.get_answer_e2e(question)
91
- # Return the answer as a dictionary
92
- return {"answer": answer}
 
6
  from tqdm.auto import tqdm
7
  tqdm.pandas()
8
  from datasets import load_dataset
9
+ from transformers import pipeline
 
 
10
 
11
  class EndpointHandler():
12
  def __init__(self, path=""):
13
  df_wiki_windows = load_dataset("foxxy-hm/e2eqa-wiki", data_files="processed/wikipedia_20220620_cleaned_v2.csv")["train"].to_pandas()
14
  df_wiki = load_dataset("foxxy-hm/e2eqa-wiki", data_files="wikipedia_20220620_short.csv")["train"].to_pandas()
15
  df_wiki.title = df_wiki.title.apply(str)
16
+
17
  entity_dict = load_dataset("foxxy-hm/e2eqa-wiki", data_files="processed/entities.json")["train"].to_dict()
18
  new_dict = dict()
19
  for key, val in entity_dict.items():
 
23
  new_dict[key.lower()] = val
24
  entity_dict.update(new_dict)
25
  title2idx = dict([(x.strip(), y) for x, y in zip(df_wiki.title, df_wiki.index.values)])
26
+
27
  qa_model = QAEnsembleModel("nguyenvulebinh/vi-mrc-large", ["qa_model_robust.bin"], entity_dict)
28
+ pairwise_model_stage1 = PairwiseModel("nguyenvulebinh/vi-mrc-base")
29
  pairwise_model_stage1.load_state_dict(torch.load("pairwise_v2.bin", map_location=torch.device('cpu')))
30
  pairwise_model_stage1.eval()
31
+
32
+ pairwise_model_stage2 = PairwiseModel("nguyenvulebinh/vi-mrc-base")
33
  pairwise_model_stage2.load_state_dict(torch.load("pairwise_stage2_seed0.bin", map_location=torch.device('cpu')))
34
+
35
  bm25_model_stage1 = BM25Gensim("bm25_stage1/", entity_dict, title2idx)
36
  bm25_model_stage2_full = BM25Gensim("bm25_stage2/full_text/", entity_dict, title2idx)
37
  bm25_model_stage2_title = BM25Gensim("bm25_stage2/title/", entity_dict, title2idx)
 
 
38
 
39
+ self.qa_model = qa_model
40
+ self.pairwise_model_stage1 = pairwise_model_stage1
41
+ self.pairwise_model_stage2 = pairwise_model_stage2
42
+ self.bm25_model_stage1 = bm25_model_stage1
43
+ self.bm25_model_stage2_full = bm25_model_stage2_full
44
+ self.bm25_model_stage2_title = bm25_model_stage2_title
45
+
46
  def get_answer_e2e(self, question):
  # Answer `question` end to end and return a string: BM25 retrieval of 200
  # candidates, pairwise re-ranking, QA model over the 10 best passages, then
  # (for non-numeric answers) mapping the answer onto a wiki title.
  # NOTE(review): indentation is not preserved in this rendering, so the
  # comments describe statement order only, not nesting.
  # NOTE(review): `df_wiki_windows` and `df_wiki` are still read as bare names
  # below, but the visible __init__ lines bind them only as locals and do not
  # store them on `self` - looks like a NameError at call time unless the star
  # import supplies them; confirm.
 
47
  # Stage 1 retrieval uses the lower-cased, preprocessed question.
  query = preprocess(question).lower()
48
+ top_n, bm25_scores = self.bm25_model_stage1.get_topk_stage1(query, topk=200)
49
  titles = [preprocess(df_wiki_windows.title.values[i]) for i in top_n]
50
  texts = [preprocess(df_wiki_windows.text.values[i]) for i in top_n]
51
+
 
52
  # Re-rank the retrieved texts with the stage-1 pairwise model.
  question = preprocess(question)
53
+ ranking_preds = self.pairwise_model_stage1.stage1_ranking(question, texts)
54
  # Combined score = pairwise score multiplied by the BM25 score.
  ranking_scores = ranking_preds * bm25_scores
55
+
 
56
  # argsort is ascending, so the last 10 indices are the 10 highest scores.
  best_idxs = np.argsort(ranking_scores)[-10:]
57
  ranking_scores = np.array(ranking_scores)[best_idxs]
58
  texts = np.array(texts)[best_idxs]
59
+ best_answer = self.qa_model(question, texts, ranking_scores)
60
  # No answer from the QA model: return the fixed fallback string.
  if best_answer is None:
61
  return "Chịu"
62
  bm25_answer = preprocess(str(best_answer).lower(), max_length=128, remove_puncts=True)
63
+
 
64
  # Entity mapping is attempted only when check_number() is falsy for the answer.
  if not check_number(bm25_answer):
65
  bm25_question = preprocess(str(question).lower(), max_length=128, remove_puncts=True)
66
  bm25_question_answer = bm25_question + " " + bm25_answer
67
  # First try: look the answer up in the title index and score the candidates.
+ candidates, scores = self.bm25_model_stage2_title.get_topk_stage2(bm25_answer, raw_answer=best_answer)
68
  titles = [df_wiki.title.values[i] for i in candidates]
69
  texts = [df_wiki.text.values[i] for i in candidates]
70
+ ranking_preds = self.pairwise_model_stage2.stage2_ranking(question, best_answer, titles, texts)
71
  # A best score of at least 0.1 selects the top title-index candidate.
  if ranking_preds.max() >= 0.1:
72
  final_answer = titles[ranking_preds.argmax()]
73
  else:
74
  # Fallback: query the full-text index with question + answer, prepend the
  # new candidates to the existing lists, re-score the combined list and
  # append the earlier scores.
  # NOTE(review): assuming stage2_ranking returns one score per title, the
  # concatenated array is longer than `titles`, so argmax could index past
  # the end of `titles`; confirm.
+ candidates, scores = self.bm25_model_stage2_full.get_topk_stage2(bm25_question_answer)
75
  titles = [df_wiki.title.values[i] for i in candidates] + titles
76
  texts = [df_wiki.text.values[i] for i in candidates] + texts
77
  ranking_preds = np.concatenate(
78
+ [self.pairwise_model_stage2.stage2_ranking(question, best_answer, titles, texts), ranking_preds])
79
  # NOTE(review): cannot tell from this rendering whether the next line sits
  # inside the `else` above or after the whole if/else; confirm in the file.
  final_answer = "wiki/"+titles[ranking_preds.argmax()].replace(" ","_")
80
  else:
81
  # check_number() was truthy: return the normalised answer text itself.
  final_answer = bm25_answer.lower()
82
+ return final_answer
83
 
84
 
85
+ class InferencePipeline:
  # Callable wrapper: builds one EndpointHandler and, when called with a
  # question, returns {"answer": <result of EndpointHandler.get_answer_e2e>}.
  # NOTE(review): after this change `__call__` belongs to this class and no
  # `__call__` is visible on EndpointHandler; if the deployment calls
  # EndpointHandler directly (as Hugging Face custom handlers are documented
  # to be used), confirm that this move is intended.
86
+ def __init__(self):
87
+ self.endpoint_handler = EndpointHandler() # Instantiate the EndpointHandler class
88
+
89
  def __call__(self, question):
90
+ answer = self.endpoint_handler.get_answer_e2e(question) # Call the get_answer_e2e method from EndpointHandler
91
+ return {"answer": answer} # Return the answer as a dictionary
92
+
93
+
94
+ inference_pipeline = InferencePipeline()
# NOTE(review): the statement above runs at import time, so importing this
# module constructs EndpointHandler (dataset and model loading in __init__).
95
# NOTE(review): the statement below rebinds the name `pipeline` imported from
# transformers to the call's result. "qa-model" is not among the task names
# listed in the transformers pipeline documentation - presumably this raises
# at import; confirm.
+ pipeline = pipeline("qa-model", model=inference_pipeline, tokenizer=None)