import numpy as np
import torch
import regex as re
from tqdm.auto import tqdm
from datasets import load_dataset

from models.pairwise_model import *
from features.text_utils import *
from models.bm25_utils import BM25Gensim
from models.qa_model import *

tqdm.pandas()

class EndpointHandler():
    def __init__(self, path=""):
        # Load the sliding-window Wikipedia passages (stage 1) and the short article index (stage 2).
        df_wiki_windows = load_dataset("foxxy-hm/e2eqa-wiki", data_files="processed/wikipedia_20220620_cleaned_v2.csv")["train"].to_pandas()
        df_wiki = load_dataset("foxxy-hm/e2eqa-wiki", data_files="wikipedia_20220620_short.csv")["train"].to_pandas()
        df_wiki.title = df_wiki.title.apply(str)

        # Build an entity dictionary mapping both raw and preprocessed surface forms to Wikipedia titles.
        entity_dict = load_dataset("foxxy-hm/e2eqa-wiki", data_files="processed/entities.json")["train"].to_dict()
        new_dict = dict()
        for key, val in entity_dict.items():
            val = val[0].replace("wiki/", "").replace("_", " ")
            entity_dict[key] = val
            key = preprocess(key)
            new_dict[key.lower()] = val
        entity_dict.update(new_dict)
        title2idx = dict([(x.strip(), y) for x, y in zip(df_wiki.title, df_wiki.index.values)])

        # Extractive QA ensemble (reader) and the two pairwise re-rankers.
        qa_model = QAEnsembleModel("nguyenvulebinh/vi-mrc-large", ["qa_model_robust.bin"], entity_dict)
        pairwise_model_stage1 = PairwiseModel("nguyenvulebinh/vi-mrc-base")
        pairwise_model_stage1.load_state_dict(torch.load("pairwise_v2.bin", map_location=torch.device("cpu")))
        pairwise_model_stage1.eval()

        pairwise_model_stage2 = PairwiseModel("nguyenvulebinh/vi-mrc-base")
        pairwise_model_stage2.load_state_dict(torch.load("pairwise_stage2_seed0.bin", map_location=torch.device("cpu")))
        pairwise_model_stage2.eval()

        # BM25 retrievers: stage 1 over passage windows, stage 2 over full texts and titles.
        bm25_model_stage1 = BM25Gensim("bm25_stage1/", entity_dict, title2idx)
        bm25_model_stage2_full = BM25Gensim("bm25_stage2/full_text/", entity_dict, title2idx)
        bm25_model_stage2_title = BM25Gensim("bm25_stage2/title/", entity_dict, title2idx)

        self.df_wiki_windows = df_wiki_windows
        self.df_wiki = df_wiki
        self.qa_model = qa_model
        self.pairwise_model_stage1 = pairwise_model_stage1
        self.pairwise_model_stage2 = pairwise_model_stage2
        self.bm25_model_stage1 = bm25_model_stage1
        self.bm25_model_stage2_full = bm25_model_stage2_full
        self.bm25_model_stage2_title = bm25_model_stage2_title
    
    def get_answer_e2e(self, question):
        # Stage 1: BM25 retrieval over passage windows, then neural re-ranking.
        query = preprocess(question).lower()
        top_n, bm25_scores = self.bm25_model_stage1.get_topk_stage1(query, topk=200)
        titles = [preprocess(self.df_wiki_windows.title.values[i]) for i in top_n]
        texts = [preprocess(self.df_wiki_windows.text.values[i]) for i in top_n]

        question = preprocess(question)
        ranking_preds = self.pairwise_model_stage1.stage1_ranking(question, texts)
        ranking_scores = ranking_preds * bm25_scores

        # Keep the 10 best passages and run the extractive QA ensemble over them.
        best_idxs = np.argsort(ranking_scores)[-10:]
        ranking_scores = np.array(ranking_scores)[best_idxs]
        texts = np.array(texts)[best_idxs]
        best_answer = self.qa_model(question, texts, ranking_scores)
        if best_answer is None:
            return "Chịu"  # Vietnamese for "no answer found"
        bm25_answer = preprocess(str(best_answer).lower(), max_length=128, remove_puncts=True)

        if not check_number(bm25_answer):
            # Stage 2: map the extracted answer span to a Wikipedia title.
            bm25_question = preprocess(str(question).lower(), max_length=128, remove_puncts=True)
            bm25_question_answer = bm25_question + " " + bm25_answer
            candidates, scores = self.bm25_model_stage2_title.get_topk_stage2(bm25_answer, raw_answer=best_answer)
            titles = [self.df_wiki.title.values[i] for i in candidates]
            texts = [self.df_wiki.text.values[i] for i in candidates]
            ranking_preds = self.pairwise_model_stage2.stage2_ranking(question, best_answer, titles, texts)
            if ranking_preds.max() < 0.1:
                # Low confidence: fall back to full-text BM25 retrieval and re-rank the combined candidates.
                candidates, scores = self.bm25_model_stage2_full.get_topk_stage2(bm25_question_answer)
                titles = [self.df_wiki.title.values[i] for i in candidates] + titles
                texts = [self.df_wiki.text.values[i] for i in candidates] + texts
                ranking_preds = np.concatenate(
                    [self.pairwise_model_stage2.stage2_ranking(question, best_answer, titles, texts), ranking_preds])
            final_answer = "wiki/" + titles[ranking_preds.argmax()].replace(" ", "_")
        else:
            # Numeric answers (dates, counts) are returned as plain text rather than as a wiki link.
            final_answer = bm25_answer.lower()
        return final_answer

    
class InferencePipeline:
    def __init__(self):
        self.endpoint_handler = EndpointHandler()  # Instantiate the EndpointHandler class
    
    def __call__(self, question):
        answer = self.endpoint_handler.get_answer_e2e(question)  # Call the get_answer_e2e method from EndpointHandler
        return {"answer": answer}  # Return the answer as a dictionary


# InferencePipeline is a plain callable; it cannot be wrapped with transformers.pipeline()
# because "qa-model" is not a registered task, so it is invoked directly instead.
inference_pipeline = InferencePipeline()
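# Usage sketch (illustrative only): running this module directly loads all models and answers
# a sample question. The question below is a hypothetical example, not part of the original handler.
if __name__ == "__main__":
    # Hypothetical sample question: "Who is the author of The Tale of Kieu?"
    result = inference_pipeline("Ai là tác giả của Truyện Kiều?")
    print(result["answer"])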