prasadnu committed on
Commit
c77dc87
·
1 Parent(s): bbe0e25

search pipeline updated

Browse files
Files changed (2) hide show
  1. RAG/colpali.py +2 -22
  2. RAG/rag_DocumentSearcher.py +17 -2
RAG/colpali.py CHANGED
@@ -206,39 +206,23 @@ def colpali_search_rerank(query):
206
  add_score = 0
207
 
208
  for index,i in enumerate(query_token_vectors):
209
- #token = vocab_dict[str(token_ids[index])]
210
- #if(token!='[SEP]' and token!='[CLS]'):
211
  query_token_vector = np.array(i)
212
- #print("query token: "+token)
213
- #print("-----------------")
214
  scores = []
215
  for m in with_s:
216
- #m_arr = m.split("-")
217
- #if(m_arr[-1]!='[SEP]' and m_arr[-1]!='[CLS]'):
218
- #print("document token: "+m_arr[3])
219
  doc_token_vector = np.array(m['page_sub_vector'])
220
  score = np.dot(query_token_vector,doc_token_vector)
221
  scores.append(score)
222
- #print({"doc_token":m_arr[3],"score":score})
223
-
224
  scores.sort(reverse=True)
225
  max_score = scores[0]
226
  add_score+=max_score
227
- #max_score_dict_list.append(newlist[0])
228
- #print(newlist[0])
229
- #max_score_dict_list_sorted = sorted(max_score_dict_list, key=lambda d: d['score'], reverse=True)
230
- #print(max_score_dict_list_sorted)
231
- # print(add_score)
232
  doc["total_score"] = add_score
233
- #doc['max_score_dict_list_sorted'] = max_score_dict_list_sorted
234
  final_docs.append(doc)
235
  final_docs_sorted = sorted(final_docs, key=lambda d: d['total_score'], reverse=True)
236
  final_docs_sorted_20.append(final_docs_sorted[:20])
237
  img = "/home/user/app/vs/"+final_docs_sorted_20[0][0]['image']
238
  ans = generate_ans(img,query)
239
  images_highlighted = [{'file':img}]
240
- # if(st.session_state.show_columns == True):
241
- # images_highlighted = img_highlight(img,query_token_vectors,result['query_tokens'])
242
  st.session_state.top_img = img
243
  st.session_state.query_token_vectors = query_token_vectors
244
  st.session_state.query_tokens = result['query_tokens']
@@ -312,12 +296,8 @@ def img_highlight(img,batch_queries,query_tokens):
312
  # # Get the similarity map for our (only) input image
313
  similarity_maps = batched_similarity_maps[0] # (query_length, n_patches_x, n_patches_y)
314
 
315
- print(f"Similarity map shape: (query_length, n_patches_x, n_patches_y) = {tuple(similarity_maps.shape)}")
316
- print(query_tokens)
317
  query_tokens_from_model = query_tokens[0]['tokens']
318
- print(query_tokens_from_model)
319
- print(type(query_tokens_from_model))
320
-
321
  plots = plot_all_similarity_maps(
322
  image=image,
323
  query_tokens=query_tokens_from_model,
 
206
  add_score = 0
207
 
208
  for index,i in enumerate(query_token_vectors):
 
 
209
  query_token_vector = np.array(i)
 
 
210
  scores = []
211
  for m in with_s:
 
 
 
212
  doc_token_vector = np.array(m['page_sub_vector'])
213
  score = np.dot(query_token_vector,doc_token_vector)
214
  scores.append(score)
215
+
 
216
  scores.sort(reverse=True)
217
  max_score = scores[0]
218
  add_score+=max_score
 
 
 
 
 
219
  doc["total_score"] = add_score
 
220
  final_docs.append(doc)
221
  final_docs_sorted = sorted(final_docs, key=lambda d: d['total_score'], reverse=True)
222
  final_docs_sorted_20.append(final_docs_sorted[:20])
223
  img = "/home/user/app/vs/"+final_docs_sorted_20[0][0]['image']
224
  ans = generate_ans(img,query)
225
  images_highlighted = [{'file':img}]
 
 
226
  st.session_state.top_img = img
227
  st.session_state.query_token_vectors = query_token_vectors
228
  st.session_state.query_tokens = result['query_tokens']
 
296
  # # Get the similarity map for our (only) input image
297
  similarity_maps = batched_similarity_maps[0] # (query_length, n_patches_x, n_patches_y)
298
 
 
 
299
  query_tokens_from_model = query_tokens[0]['tokens']
300
+
 
 
301
  plots = plot_all_similarity_maps(
302
  image=image,
303
  query_tokens=query_tokens_from_model,
RAG/rag_DocumentSearcher.py CHANGED
@@ -189,23 +189,38 @@ def query_(awsauth,inputs, session_id,search_types):
189
  # query_sparse = sparse_["inference_results"][0]["output"][0]["dataAsMap"]["response"][0]
190
 
191
  hits = []
192
- if(num_queries>1 or st.session_state.input_is_rerank):
193
  s_pipeline_url = host + s_pipeline_path
194
  r = requests.put(s_pipeline_url, auth=awsauth, json=s_pipeline_payload, headers=headers)
195
  path = st.session_state.input_index+"/_search?search_pipeline=rag-search-pipeline"
196
  else:
197
- path = st.session_state.input_index+"/_search"
 
 
 
198
  url = host+path
199
  if(len(hybrid_payload["query"]["hybrid"]["queries"])==1):
200
  single_query = hybrid_payload["query"]["hybrid"]["queries"][0]
201
  del hybrid_payload["query"]["hybrid"]
202
  hybrid_payload["query"] = single_query
 
 
 
 
 
 
203
  r = requests.get(url, auth=awsauth, json=hybrid_payload, headers=headers)
204
  response_ = json.loads(r.text)
205
  print(response_)
206
  hits = response_['hits']['hits']
207
 
208
  else:
 
 
 
 
 
 
209
  r = requests.get(url, auth=awsauth, json=hybrid_payload, headers=headers)
210
  response_ = json.loads(r.text)
211
  hits = response_['hits']['hits']
 
189
  # query_sparse = sparse_["inference_results"][0]["output"][0]["dataAsMap"]["response"][0]
190
 
191
  hits = []
192
+ if(num_queries>1):
193
  s_pipeline_url = host + s_pipeline_path
194
  r = requests.put(s_pipeline_url, auth=awsauth, json=s_pipeline_payload, headers=headers)
195
  path = st.session_state.input_index+"/_search?search_pipeline=rag-search-pipeline"
196
  else:
197
+ if(input_is_rerank):
198
+ path = st.session_state.input_index+"/_search?search_pipeline=rerank_pipeline_rag"
199
+ else:
200
+ path = st.session_state.input_index+"/_search"
201
  url = host+path
202
  if(len(hybrid_payload["query"]["hybrid"]["queries"])==1):
203
  single_query = hybrid_payload["query"]["hybrid"]["queries"][0]
204
  del hybrid_payload["query"]["hybrid"]
205
  hybrid_payload["query"] = single_query
206
+ if(st.session_state.input_is_rerank):
207
+ hybrid_payload["ext"] = {"rerank": {
208
+ "query_context": {
209
+ "query_text": question
210
+ }
211
+ }}
212
  r = requests.get(url, auth=awsauth, json=hybrid_payload, headers=headers)
213
  response_ = json.loads(r.text)
214
  print(response_)
215
  hits = response_['hits']['hits']
216
 
217
  else:
218
+ if(st.session_state.input_is_rerank):
219
+ hybrid_payload["ext"] = {"rerank": {
220
+ "query_context": {
221
+ "query_text": question
222
+ }
223
+ }}
224
  r = requests.get(url, auth=awsauth, json=hybrid_payload, headers=headers)
225
  response_ = json.loads(r.text)
226
  hits = response_['hits']['hits']