File size: 2,798 Bytes
9eb3f19
 
 
 
abb2086
 
 
 
 
 
9eb3f19
abb2086
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import pinecone
import streamlit as st

# --- Step 1: collect the Pinecone API key ------------------------------------
# A password-type input keeps the key off the screen (text_area echoed it in
# plain text and allowed stray newlines to end up in the key).
API = st.text_input('Enter API key:', type='password')
res = st.button('Submit', key='submit_api_key')
if res:
    if API.strip():
        # st.button() is True only on the single rerun triggered by its own
        # click. Persist the submitted key so later interactions (typing the
        # question, clicking the second button) still reach the QA section.
        st.session_state['api_key'] = API.strip()
    else:
        st.error('Please enter a Pinecone API key.')

if 'api_key' in st.session_state:
    # connect to pinecone environment with the key the user submitted
    pinecone.init(
        api_key=st.session_state['api_key'],
        environment="us-central1-gcp",  # find next to API key in console
    )

    index_name = "abstractive-question-answering"

    # create the abstractive-question-answering index if it does not exist
    if index_name not in pinecone.list_indexes():
        pinecone.create_index(
            index_name,
            # presumably the retriever's embedding size -- confirm if the
            # retriever model below is changed
            dimension=768,
            metric="cosine",
        )

    # connect to the abstractive-question-answering index
    index = pinecone.Index(index_name)

    # Streamlit reruns this whole script on every widget interaction, and the
    # models are slow to load, so load them only once per browser session.
    if 'qa_models' not in st.session_state:
        import torch
        from sentence_transformers import SentenceTransformer
        from transformers import BartTokenizer, BartForConditionalGeneration

        # set device to GPU if available (used for the retriever only)
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # load the retriever model from huggingface model hub
        retriever = SentenceTransformer(
            "flax-sentence-embeddings/all_datasets_v3_mpnet-base", device=device
        )
        # load bart tokenizer and model from huggingface; the generator is
        # pinned to CPU, so the tokenized inputs in generate_answer stay on CPU
        tokenizer = BartTokenizer.from_pretrained('vblagoje/bart_lfqa')
        generator = BartForConditionalGeneration.from_pretrained('vblagoje/bart_lfqa').to('cpu')
        st.session_state['qa_models'] = (retriever, tokenizer, generator)
    retriever, tokenizer, generator = st.session_state['qa_models']

    def query_pinecone(query, top_k):
        """Embed *query* with the retriever and search the Pinecone index.

        Args:
            query: the user's question text.
            top_k: number of matches to request from Pinecone.

        Returns:
            The raw ``index.query`` result, requested with metadata included;
            the caller reads its ``"matches"`` entry.
        """
        # generate embeddings for the query
        xq = retriever.encode([query]).tolist()
        # search pinecone index for context passages that may hold the answer
        return index.query(xq, top_k=top_k, include_metadata=True)

    def format_query(query, context):
        """Build the generator prompt ``"question: ... context: <P> ... <P> ..."``.

        Args:
            query: the user's question text.
            context: iterable of Pinecone matches, each exposing
                ``match['metadata']['text']``.

        Returns:
            The prompt string: the question followed by every passage, each
            prefixed with a ``<P>`` tag and joined with single spaces.
        """
        # tag each passage with <P> and concatenate them
        passages = " ".join(f"<P> {m['metadata']['text']}" for m in context)
        # concatenate the question and the context passages
        return f"question: {query} context: {passages}"

    def generate_answer(query):
        """Generate an answer for a prompt produced by ``format_query``.

        The prompt is truncated to 1024 tokens. Decoding uses beam search with
        2 beams and produces between 20 and 64 tokens.

        Returns:
            The decoded answer string (special tokens removed).
        """
        # tokenize the prompt; truncation keeps overlong prompts within max_length
        inputs = tokenizer([query], truncation=True, max_length=1024, return_tensors="pt")
        # use generator to predict output ids
        ids = generator.generate(inputs["input_ids"], num_beams=2, min_length=20, max_length=64)
        # decode the output ids back into text
        return tokenizer.batch_decode(
            ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]

    # --- Step 2: ask a question ----------------------------------------------
    query = st.text_area('Enter your question:')
    # A distinct key is required: two buttons both labelled 'Submit' with no
    # key would collide (Streamlit raises DuplicateWidgetID).
    s = st.button('Submit', key='submit_question')
    if s:
        if not query.strip():
            st.warning('Please enter a question.')
        else:
            context = query_pinecone(query, top_k=5)
            prompt = format_query(query, context["matches"])
            # render the answer in the app rather than on the server console
            st.write(generate_answer(prompt))