HarryLee committed
Commit 39caa01 · 1 Parent(s): 71bf38c

Add application file

Files changed (1)
  1. app.py +133 -0
app.py ADDED
@@ -0,0 +1,133 @@
import streamlit as st
from streamlit_tags import st_tags, st_tags_sidebar
from keytotext import pipeline   # imported but not used in this app
from PIL import Image

import json
from sentence_transformers import SentenceTransformer, CrossEncoder, util
import gzip                      # only needed by the disabled Wikipedia loading path below
import os                        # likewise only used by the commented-out download check
import torch

############
## Main page
############

st.write("# Code for Query Expansion")

st.markdown("***The idea is to build a model which takes a query as input and generates expansion information as output.***")
image = Image.open('top.png')
st.image(image)

st.sidebar.write("# Parameter Selection")
maxtags_sidebar = st.sidebar.slider('Number of queries allowed?', 1, 10, 1, key='ehikwegrjifbwreuk')
user_query = st_tags(
    label='# Enter Query:',
    text='Press enter to add more',
    value=['Mother'],
    suggestions=['five', 'six', 'seven', 'eight', 'nine', 'three', 'eleven', 'ten', 'four'],
    maxtags=maxtags_sidebar,
    key="aljnf")
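# st_tags renders the tag widget and returns the current tags as a list of strings
# (e.g. ['Mother']), so user_query is a list rather than a single query string.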

# Add selectboxes in the sidebar. Note the trailing commas: selectbox expects an
# iterable of options, and a bare ('...') is just a parenthesized string, which
# Streamlit would split into one option per character.
option1 = st.sidebar.selectbox(
    'Which transformer model would you like to use?',
    ('multi-qa-MiniLM-L6-cos-v1',))

option2 = st.sidebar.selectbox(
    'Which cross-encoder model would you like to use?',
    ('cross-encoder/ms-marco-MiniLM-L-6-v2',))

if not torch.cuda.is_available():
    print("Warning: No GPU found. Encoding will fall back to the CPU and be slower.")


# We use the bi-encoder to encode all passages, so that we can use semantic search
bi_encoder = SentenceTransformer(option1)
bi_encoder.max_seq_length = 256     # Truncate long passages to 256 tokens
top_k = 32                          # Number of passages we want to retrieve with the bi-encoder

# The bi-encoder retrieves the top_k candidates; we use a cross-encoder to re-rank that list and improve quality
cross_encoder = CrossEncoder(option2)
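# Retrieve-and-rerank in brief: the bi-encoder embeds query and passages independently,
# so all corpus embeddings can be precomputed and searched quickly by vector similarity;
# the cross-encoder reads each (query, passage) pair jointly, which is slower but more
# accurate, so it is applied only to the top_k candidates the bi-encoder returns.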

# As the corpus, we use an Etsy query/title dump rather than the Simple English
# Wikipedia dataset this script was adapted from; the original Wikipedia download
# check and loading code are kept below, commented out, for reference.

etsy_filepath = '000000000001.json'

# if not os.path.exists(wikipedia_filepath):
#     util.http_get('http://sbert.net/datasets/simplewiki-2020-11-01.jsonl.gz', wikipedia_filepath)

passages = []

# Wikipedia loading path, disabled. (This block was originally wrapped in a bare
# triple-quoted string; Streamlit's "magic" renders stray string literals onto the
# page, so plain comments are used instead.)
# with gzip.open(wikipedia_filepath, 'rt', encoding='utf8') as fIn:
#     for line in fIn:
#         data = json.loads(line.strip())
#
#         # Add all paragraphs
#         # passages.extend(data['paragraphs'])
#
#         # Only add the first paragraph
#         passages.append(data['paragraphs'][0])

with open(etsy_filepath, 'r') as EtsyJson:
    for line in EtsyJson:
        data = json.loads(line.strip())
        # passages.append(data['query'])
        passages.append(data['title'])

print("Passages:", len(passages))

# We encode all passages into our vector space. This takes about 5 minutes (depends on your GPU speed)
corpus_embeddings = bi_encoder.encode(passages, convert_to_tensor=True, show_progress_bar=True)
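# corpus_embeddings is a 2-D tensor with one row per passage (384 dimensions for
# MiniLM-L6-based models such as multi-qa-MiniLM-L6-cos-v1), kept on the GPU when available.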

# This function searches all corpus passages for those that answer the query
def search(query):
    print("Input question:", query)

    ##### BM25 search (lexical search) #####
    # bm25_scores = bm25.get_scores(bm25_tokenizer(query))
    # top_n = np.argpartition(bm25_scores, -5)[-5:]
    # bm25_hits = [{'corpus_id': idx, 'score': bm25_scores[idx]} for idx in top_n]
    # bm25_hits = sorted(bm25_hits, key=lambda x: x['score'], reverse=True)

    # print("Top-10 lexical search (BM25) hits")
    # for hit in bm25_hits[0:10]:
    #     print("\t{:.3f}\t{}".format(hit['score'], passages[hit['corpus_id']].replace("\n", " ")))

    ##### Semantic Search #####
    # Encode the query using the bi-encoder and find potentially relevant passages
    query_embedding = bi_encoder.encode(query, convert_to_tensor=True)
    if torch.cuda.is_available():
        query_embedding = query_embedding.cuda()   # only move to GPU when one exists
    hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=top_k)
    hits = hits[0]  # Get the hits for the first query
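    # Each hit is a dict of the form {'corpus_id': <index into passages>, 'score': <similarity>};
    # util.semantic_search returns them already sorted by descending bi-encoder score.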

    ##### Re-Ranking #####
    # Now, score all retrieved passages with the cross_encoder
    cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
    cross_scores = cross_encoder.predict(cross_inp)

    # Attach the cross-encoder score to each hit so we can sort by it below
    for idx in range(len(cross_scores)):
        hits[idx]['cross-score'] = cross_scores[idx]

    # Output of top-10 hits from bi-encoder
    print("\n-------------------------\n")
    print("Top-10 Bi-Encoder Retrieval hits")
    hits = sorted(hits, key=lambda x: x['score'], reverse=True)
    for hit in hits[0:10]:
        print("\t{:.3f}\t{}".format(hit['score'], passages[hit['corpus_id']].replace("\n", " ")))

    # Output of top-10 hits from re-ranker
    print("\n-------------------------\n")
    print("Top-10 Cross-Encoder Re-ranker hits")
    hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
    for hit in hits[0:10]:
        print("\t{:.3f}\t{}".format(hit['cross-score'], passages[hit['corpus_id']].replace("\n", " ")))

    return hits   # re-ranked hits, so the Streamlit UI below can display them

st.write("## Results:")
if st.button('Generate Sentence'):
    # st_tags yields a list of tags; join them into a single query string, then
    # show the re-ranked passages on the page instead of only printing to stdout.
    out = search(query=" ".join(user_query))
    for hit in out[0:10]:
        st.write("{:.3f}: {}".format(hit['cross-score'], passages[hit['corpus_id']].replace("\n", " ")))
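To try the app locally, install the packages imported above (streamlit, streamlit_tags, keytotext, sentence-transformers, torch, Pillow), place `top.png` and the `000000000001.json` query/title dump next to `app.py`, and launch with `streamlit run app.py`; both models are fetched from the Hugging Face Hub on first use.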