BillBojangeles2000 committed on
Commit
9eb3f19
·
1 Parent(s): 5e98df0

Create app.py

Files changed (1)
  1. app.py +71 -0
app.py ADDED
@@ -0,0 +1,71 @@
+import pinecone
+import streamlit as st
+
+API = st.text_area('Enter API key:')
+if not API:
+    st.stop()  # wait for an API key before connecting to Pinecone
+
+# connect to the pinecone environment
+pinecone.init(
+    api_key=API,
+    environment="us-central1-gcp"  # find next to the API key in the console
+)
+
+index_name = "abstractive-question-answering"
+
+# check if the abstractive-question-answering index exists
+if index_name not in pinecone.list_indexes():
+    # create the index if it does not exist
+    pinecone.create_index(
+        index_name,
+        dimension=768,
+        metric="cosine"
+    )
+
+# connect to the abstractive-question-answering index we created
+index = pinecone.Index(index_name)
+
+import torch
+from sentence_transformers import SentenceTransformer
+
+# set device to GPU if available
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+# load the retriever model from the huggingface model hub
+retriever = SentenceTransformer("flax-sentence-embeddings/all_datasets_v3_mpnet-base", device=device)
+
+from transformers import BartTokenizer, BartForConditionalGeneration
+
+# load the bart tokenizer and model from huggingface
+tokenizer = BartTokenizer.from_pretrained('vblagoje/bart_lfqa')
+generator = BartForConditionalGeneration.from_pretrained('vblagoje/bart_lfqa').to('cpu')
+
+def query_pinecone(query, top_k):
+    # generate an embedding for the query
+    xq = retriever.encode([query]).tolist()
+    # search the pinecone index for context passages with the answer
+    xc = index.query(xq, top_k=top_k, include_metadata=True)
+    return xc
+
+def format_query(query, context):
+    # extract the passage text from the Pinecone search results and add the <P> tag
+    context = [f"<P> {m['metadata']['text']}" for m in context]
+    # concatenate all context passages
+    context = " ".join(context)
+    # concatenate the query and context passages
+    query = f"question: {query} context: {context}"
+    return query
+
+def generate_answer(query):
+    # tokenize the query to get input_ids
+    inputs = tokenizer([query], truncation=True, max_length=1024, return_tensors="pt")
+    # use the generator to predict output ids
+    ids = generator.generate(inputs["input_ids"], num_beams=2, min_length=20, max_length=64)
+    # use the tokenizer to decode the output ids
+    answer = tokenizer.batch_decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    return answer
+
+query = st.text_area('Enter your question:')
+if query:  # only run the pipeline once a question has been entered
+    context = query_pinecone(query, top_k=5)
+    query = format_query(query, context["matches"])
+    st.write(generate_answer(query))
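
Note that app.py only queries the index: it assumes abstractive-question-answering has already been populated with 768-dimensional passage embeddings whose metadata carries the "text" field that format_query reads back. A minimal one-off indexing sketch under that assumption, reusing the app's retriever and index (the sample passages and IDs below are hypothetical, not part of the commit):

# hypothetical indexing script, run once before using the app:
# embed sample passages and upsert them with the "text" metadata
# field that query_pinecone/format_query expect to find
passages = [
    "The Berlin Wall fell in November 1989.",
    "Apollo 11 landed on the Moon in July 1969.",
]
vectors = [
    (f"doc-{i}", retriever.encode(passage).tolist(), {"text": passage})
    for i, passage in enumerate(passages)
]
index.upsert(vectors=vectors)

With the index populated, the app can be launched with streamlit run app.py.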