from haystack.document_stores import FAISSDocumentStore
from haystack.nodes import DensePassageRetriever, Seq2SeqGenerator
from haystack.pipelines import DocumentSearchPipeline, GenerativeQAPipeline
from haystack.utils import convert_files_to_docs, fetch_archive_from_http, clean_wiki_text, print_documents, print_answers


# %% Save/Load FAISS and embeddings
# Try out this script. Make sure you have deleted any old saves of the document store,
# including the file called faiss_document_store.db that is saved and loaded by default.

# # Convert files to Haystack Document objects (doc_dir must point to a folder of text files)
# docs = convert_files_to_docs(dir_path=doc_dir, clean_func=clean_wiki_text, split_paragraphs=True)[:10]

# document_store = FAISSDocumentStore(faiss_index_factory_str="Flat", vector_dim=128)
# # document_store = FAISSDocumentStore(sql_url="sqlite:///faiss_document_store.db")

# # Requires: from haystack.nodes import EmbeddingRetriever
# retriever = EmbeddingRetriever(document_store=document_store,
#                                embedding_model="yjernite/retribert-base-uncased",
#                                model_format="retribert",
#                                use_gpu=False)

# # Now, let's write the documents to our DB and compute their embeddings.
# document_store.write_documents(docs)
# document_store.update_embeddings(retriever)

# document_store.save("my_faiss_index.faiss")
# new_document_store = FAISSDocumentStore.load("my_faiss_index.faiss")
# # new_document_store = FAISSDocumentStore.load(faiss_file_path="testfile_path", sql_url="sqlite:///faiss_document_store.db")

# %% ------------------------------------------------------------------------------------------------------------

def prepare():
    """Load the saved FAISS index and build the DPR + BART LFQA generative QA pipeline."""
    # %% Document Store
    document_store = FAISSDocumentStore.load("faiss_index.faiss")

    # %% Initialize Retriever and Reader/Generator

    # Retriever (DPR)

    retriever = DensePassageRetriever(
        document_store=document_store,
        query_embedding_model="vblagoje/dpr-question_encoder-single-lfqa-wiki",
        passage_embedding_model="vblagoje/dpr-ctx_encoder-single-lfqa-wiki",
        use_gpu=False
    )


    # # Test DPR
    # p_retrieval = DocumentSearchPipeline(retriever)
    # res = p_retrieval.run(query="Tell me something about Arya Stark?", params={"Retriever": {"top_k": 5}})
    # print_documents(res, max_text_len=512)

    # Reader/Generator
    # Here we use a Seq2SeqGenerator with the vblagoje/bart_lfqa model (https://huggingface.co/vblagoje/bart_lfqa)

    generator = Seq2SeqGenerator(model_name_or_path="vblagoje/bart_lfqa",
                                 use_gpu=False)

    # %% Pipeline

    pipe = GenerativeQAPipeline(generator, retriever)
    return pipe

def answer(pipe, question, k_retriever=3):
    """Run the generative QA pipeline on a single question and return the result dict."""
    res = pipe.run(query=question, params={"Retriever": {"top_k": k_retriever}})
    # # Example questions:
    # res = pipe.run(
    #     query="How did Arya Stark's character get portrayed in a television adaptation?",
    #     params={"Retriever": {"top_k": 3}},
    # )
    # res = pipe.run(query="Why is Arya Stark an unusual character?", params={"Retriever": {"top_k": 3}})
    return res
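
# A minimal sketch of how the "faiss_index.faiss" that prepare() loads could be built.
# This is an assumption, not part of the original script: build_index and its doc_dir
# argument are illustrative names, and the source documents are assumed to be plain-text
# files in doc_dir. The embeddings must come from the same DPR models used in prepare().
def build_index(doc_dir, index_path="faiss_index.faiss"):
    # Convert the raw files into Haystack Document objects
    docs = convert_files_to_docs(dir_path=doc_dir, clean_func=clean_wiki_text, split_paragraphs=True)

    # DPR embeddings are 768-dimensional, which matches the FAISSDocumentStore default
    document_store = FAISSDocumentStore(faiss_index_factory_str="Flat")
    document_store.write_documents(docs)

    # Embed the stored documents with the same DPR passage encoder used at query time
    retriever = DensePassageRetriever(
        document_store=document_store,
        query_embedding_model="vblagoje/dpr-question_encoder-single-lfqa-wiki",
        passage_embedding_model="vblagoje/dpr-ctx_encoder-single-lfqa-wiki",
        use_gpu=False,
    )
    document_store.update_embeddings(retriever)

    # Persist the FAISS index (and its config) so prepare() can load it later
    document_store.save(index_path)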

if __name__ == '__main__':
    question = 'Tell me something about Arya Stark?'
    pipe = prepare()
    res = answer(pipe, question)
    print_answers(res, details="minimum")