search pipeline updated
Files changed:
- RAG/colpali.py (+2, -22)
- RAG/rag_DocumentSearcher.py (+17, -2)
RAG/colpali.py
CHANGED
@@ -206,39 +206,23 @@ def colpali_search_rerank(query):
         add_score = 0
 
         for index,i in enumerate(query_token_vectors):
-            #token = vocab_dict[str(token_ids[index])]
-            #if(token!='[SEP]' and token!='[CLS]'):
             query_token_vector = np.array(i)
-            #print("query token: "+token)
-            #print("-----------------")
             scores = []
             for m in with_s:
-                #m_arr = m.split("-")
-                #if(m_arr[-1]!='[SEP]' and m_arr[-1]!='[CLS]'):
-                #print("document token: "+m_arr[3])
                 doc_token_vector = np.array(m['page_sub_vector'])
                 score = np.dot(query_token_vector,doc_token_vector)
                 scores.append(score)
-
-
+
             scores.sort(reverse=True)
             max_score = scores[0]
             add_score+=max_score
-        #max_score_dict_list.append(newlist[0])
-        #print(newlist[0])
-        #max_score_dict_list_sorted = sorted(max_score_dict_list, key=lambda d: d['score'], reverse=True)
-        #print(max_score_dict_list_sorted)
-        # print(add_score)
         doc["total_score"] = add_score
-        #doc['max_score_dict_list_sorted'] = max_score_dict_list_sorted
         final_docs.append(doc)
     final_docs_sorted = sorted(final_docs, key=lambda d: d['total_score'], reverse=True)
     final_docs_sorted_20.append(final_docs_sorted[:20])
     img = "/home/user/app/vs/"+final_docs_sorted_20[0][0]['image']
     ans = generate_ans(img,query)
     images_highlighted = [{'file':img}]
-    # if(st.session_state.show_columns == True):
-    #     images_highlighted = img_highlight(img,query_token_vectors,result['query_tokens'])
     st.session_state.top_img = img
     st.session_state.query_token_vectors = query_token_vectors
     st.session_state.query_tokens = result['query_tokens']

@@ -312,12 +296,8 @@ def img_highlight(img,batch_queries,query_tokens):
     # # Get the similarity map for our (only) input image
     similarity_maps = batched_similarity_maps[0]  # (query_length, n_patches_x, n_patches_y)
 
-    print(f"Similarity map shape: (query_length, n_patches_x, n_patches_y) = {tuple(similarity_maps.shape)}")
-    print(query_tokens)
     query_tokens_from_model = query_tokens[0]['tokens']
-
-    print(type(query_tokens_from_model))
-
+
     plots = plot_all_similarity_maps(
         image=image,
         query_tokens=query_tokens_from_model,
RAG/rag_DocumentSearcher.py
CHANGED
@@ -189,23 +189,38 @@ def query_(awsauth,inputs, session_id,search_types):
     # query_sparse = sparse_["inference_results"][0]["output"][0]["dataAsMap"]["response"][0]
 
     hits = []
-    if(num_queries>1
+    if(num_queries>1):
         s_pipeline_url = host + s_pipeline_path
         r = requests.put(s_pipeline_url, auth=awsauth, json=s_pipeline_payload, headers=headers)
         path = st.session_state.input_index+"/_search?search_pipeline=rag-search-pipeline"
     else:
-        path = st.session_state.input_index+"/_search"
+        if(input_is_rerank):
+            path = st.session_state.input_index+"/_search?search_pipeline=rerank_pipeline_rag"
+        else:
+            path = st.session_state.input_index+"/_search"
     url = host+path
     if(len(hybrid_payload["query"]["hybrid"]["queries"])==1):
         single_query = hybrid_payload["query"]["hybrid"]["queries"][0]
         del hybrid_payload["query"]["hybrid"]
         hybrid_payload["query"] = single_query
+        if(st.session_state.input_is_rerank):
+            hybrid_payload["ext"] = {"rerank": {
+                "query_context": {
+                    "query_text": question
+                }
+            }}
         r = requests.get(url, auth=awsauth, json=hybrid_payload, headers=headers)
         response_ = json.loads(r.text)
         print(response_)
         hits = response_['hits']['hits']
 
     else:
+        if(st.session_state.input_is_rerank):
+            hybrid_payload["ext"] = {"rerank": {
+                "query_context": {
+                    "query_text": question
+                }
+            }}
        r = requests.get(url, auth=awsauth, json=hybrid_payload, headers=headers)
        response_ = json.loads(r.text)
        hits = response_['hits']['hits']