prasadnu committed on
Commit
5cd2735
·
1 Parent(s): 0b4e50b

colpali fix

Browse files
RAG/colpali.py CHANGED
@@ -156,88 +156,6 @@ def generate_ans(top_result,query):
156
  print(content_text)
157
  return content_text
158
 
159
- @st.cache_resource
160
- def colpali_search_rerank(query):
161
- # Convert to JSON string
162
- payload = {
163
- "queries": [query]
164
- }
165
- body = json.dumps(payload)
166
-
167
- # Call the endpoint
168
- response = runtime.invoke_endpoint(
169
- EndpointName=endpoint_name,
170
- ContentType="application/json",
171
- Body=body
172
- )
173
-
174
- # Read and print the response
175
- result = json.loads(response["Body"].read().decode())
176
- #print(len(result['query_embeddings'][0]))
177
-
178
- final_docs_sorted_20 = []
179
- for i in result['query_embeddings']:
180
- batch_embeddings = i
181
- a = np.array(batch_embeddings)
182
- vec = a.mean(axis=0)
183
- #print(vec)
184
- hits = []
185
- #for v in batch_embeddings:
186
- query_ = {
187
- "size": 200,
188
- "query": {
189
- "nested": {
190
- "path": "page_sub_vectors",
191
- "query": {
192
- "knn": {
193
- "page_sub_vectors.page_sub_vector": {
194
- "vector": vec.tolist(),
195
- "k": 200
196
- }
197
- }
198
- }
199
- }
200
- }
201
- }
202
- response = aos_client.search(
203
- body = query_,
204
- index = 'colpali-vs'
205
- )
206
- #print(response)
207
- query_token_vectors = batch_embeddings
208
- final_docs = []
209
- hits += response['hits']['hits']
210
- #print(len(hits))
211
- for ind,j in enumerate(hits):
212
- max_score_dict_list = []
213
- doc={"id":j["_id"],"score":j["_score"],"image":j["_source"]["image"]}
214
- with_s = j['_source']['page_sub_vectors']
215
- add_score = 0
216
-
217
- for index,i in enumerate(query_token_vectors):
218
- query_token_vector = np.array(i)
219
- scores = []
220
- for m in with_s:
221
- doc_token_vector = np.array(m['page_sub_vector'])
222
- score = np.dot(query_token_vector,doc_token_vector)
223
- scores.append(score)
224
-
225
- scores.sort(reverse=True)
226
- max_score = scores[0]
227
- add_score+=max_score
228
- doc["total_score"] = add_score
229
- final_docs.append(doc)
230
- final_docs_sorted = sorted(final_docs, key=lambda d: d['total_score'], reverse=True)
231
- final_docs_sorted_20.append(final_docs_sorted[:20])
232
- img = "/home/user/app/vs/"+final_docs_sorted_20[0][0]['image']
233
- ans = generate_ans(img,query)
234
- images_highlighted = [{'file':img}]
235
- st.session_state.top_img = img
236
- st.session_state.query_token_vectors = query_token_vectors
237
- st.session_state.query_tokens = result['query_tokens']
238
- return {'text':ans,'source':img,'image':images_highlighted,'table':[]}#[{'file':img}]
239
-
240
-
241
  @st.cache_resource
242
  def img_highlight(img,batch_queries,query_tokens):
243
  # Reference from : https://github.com/tonywu71/colpali-cookbooks/blob/main/examples/gen_colpali_similarity_maps.ipynb
@@ -331,3 +249,94 @@ def img_highlight(img,batch_queries,query_tokens):
331
 
332
 
333
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  print(content_text)
157
  return content_text
158
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  @st.cache_resource
160
  def img_highlight(img,batch_queries,query_tokens):
161
  # Reference from : https://github.com/tonywu71/colpali-cookbooks/blob/main/examples/gen_colpali_similarity_maps.ipynb
 
249
 
250
 
251
 
252
+
253
@st.cache_resource
def colpali_search_rerank(query):
    """Answer *query* with the ColPali pipeline.

    Steps: embed the query via a SageMaker endpoint, recall candidate pages
    from OpenSearch using a mean-pooled query vector, re-rank candidates with
    late-interaction (MaxSim) scoring, then generate an answer from the top
    page image.

    Args:
        query: the user's natural-language question (str).

    Returns:
        dict with keys 'text' (answer), 'source' (top image path),
        'image' (list of {'file': path}), 'table' (always []).

    NOTE(review): decorated with st.cache_resource although it returns
    per-query serializable data; st.cache_data is the usual choice here —
    confirm the caching intent.
    """
    # Re-entry path for the "Show similarity map" button: reuse the stored
    # top image and query vectors to build highlighted images instead of
    # re-running the search.
    if st.session_state.show_columns:
        print("show columns activated------------------------")
        st.session_state.maxSimImages = img_highlight(
            st.session_state.top_img,
            st.session_state.query_token_vectors,
            st.session_state.query_tokens,
        )
        st.session_state.show_columns = False
        st.session_state.input_query = st.session_state.questions_[-1]["question"]
        st.session_state.answers_.pop()
        st.session_state.questions_.pop()
        # NOTE(review): reads answers_[0] after popping the last entry; with
        # more than one Q&A in history this returns the FIRST answer, not the
        # most recent one — confirm whether answers_[-1] was intended.
        return {'text': st.session_state.answers_[0]['answer'],
                'source': st.session_state.answers_[0]['source'],
                'image': st.session_state.maxSimImages,
                'table': []}

    # Embed the query with the ColPali SageMaker endpoint.
    body = json.dumps({"queries": [query]})
    endpoint_response = runtime.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Body=body,
    )
    result = json.loads(endpoint_response["Body"].read().decode())

    final_docs_sorted_20 = []
    for batch_embeddings in result['query_embeddings']:
        # Mean-pool the multi-vector query into a single vector for ANN recall.
        vec = np.array(batch_embeddings).mean(axis=0)
        search_query = {
            "size": 200,
            "query": {
                "nested": {
                    "path": "page_sub_vectors",
                    "query": {
                        "knn": {
                            "page_sub_vectors.page_sub_vector": {
                                "vector": vec.tolist(),
                                "k": 200
                            }
                        }
                    }
                }
            }
        }
        search_response = aos_client.search(
            body=search_query,
            index='colpali-vs',
        )
        query_token_vectors = batch_embeddings
        hits = search_response['hits']['hits']

        # Late-interaction (MaxSim) re-ranking: for every query token, take
        # the best dot product against the page's sub-vectors, then sum over
        # query tokens.
        final_docs = []
        for hit in hits:
            doc = {
                "id": hit["_id"],
                "score": hit["_score"],
                "image": hit["_source"]["image"],
            }
            # Convert the page sub-vectors to a matrix ONCE per hit instead of
            # once per (query token x sub-vector) pair.
            page_matrix = np.array(
                [sv['page_sub_vector'] for sv in hit['_source']['page_sub_vectors']]
            )
            add_score = 0
            for token_vector in query_token_vectors:
                token_scores = page_matrix.dot(np.array(token_vector))
                # max() directly — no need to sort the whole score list.
                add_score += float(token_scores.max())
            doc["total_score"] = add_score
            final_docs.append(doc)

        final_docs.sort(key=lambda d: d['total_score'], reverse=True)
        final_docs_sorted_20.append(final_docs[:20])

    # Only the first batch's top-ranked page is used downstream.
    img = "/home/user/app/vs/" + final_docs_sorted_20[0][0]['image']
    ans = generate_ans(img, query)

    images_highlighted = [{'file': img}]
    # Stash state so the similarity-map button can re-enter via show_columns.
    st.session_state.top_img = img
    st.session_state.query_token_vectors = query_token_vectors
    st.session_state.query_tokens = result['query_tokens']
    return {'text': ans, 'source': img, 'image': images_highlighted, 'table': []}
342
+
pages/Multimodal_Conversational_Search.py CHANGED
@@ -281,18 +281,18 @@ def render_answer(question,answer,index,res_img):
281
  for _ in range(10)])
282
  def show_maxsim():
283
  st.session_state.show_columns = True
284
- st.session_state.maxSimImages = colpali.img_highlight(st.session_state.top_img, st.session_state.query_token_vectors, st.session_state.query_tokens)
285
- st.session_state.input_query = st.session_state.questions_[-1]["question"]
286
- st.session_state.answers_.pop()
287
- st.session_state.questions_.pop()
288
  handle_input()
289
- with placeholder.container():
290
- render_all()
291
  print("*"*20)
292
  print(st.session_state.input_query)
293
  print(st.session_state.answers_)
294
  print(st.session_state.questions_)
295
  print("*"*20)
 
 
 
296
  if(st.session_state.input_is_colpali):
297
  st.button("Show similarity map",key=rdn_key_1,on_click = show_maxsim)
298
 
@@ -309,7 +309,7 @@ def render_answer(question,answer,index,res_img):
309
  for i in range(0,len(res_img)):
310
 
311
  if(st.session_state.input_is_colpali):
312
- if(st.session_state.show_columns == True):
313
  cols_per_row = 3
314
  st.session_state.image_placeholder=st.empty()
315
  with st.session_state.image_placeholder.container():
@@ -324,11 +324,6 @@ def render_answer(question,answer,index,res_img):
324
  col3_,col4_,col5_ = st.columns([33,33,33])
325
  with col3_:
326
  st.image(res_img[i]['file'])
327
-
328
-
329
-
330
-
331
-
332
  else:
333
  if(res_img[i]['file'].lower()!='none' and idx < 2):
334
  col3,col4,col5 = st.columns([33,33,33])
@@ -398,10 +393,7 @@ def render_answer(question,answer,index,res_img):
398
 
399
  #Each answer will have context of the question asked in order to associate the provided feedback with the respective question
400
  def write_chat_message(md, q,index):
401
- if(st.session_state.show_columns):
402
- res_img = st.session_state.maxSimImages
403
- else:
404
- res_img = md['image']
405
  #st.session_state['session_id'] = res['session_id'] to be added in memory
406
  chat = st.container()
407
  with chat:
 
281
  for _ in range(10)])
282
  def show_maxsim():
283
  st.session_state.show_columns = True
284
+ # st.session_state.input_query = st.session_state.questions_[-1]["question"]
285
+ # st.session_state.answers_.pop()
286
+ # st.session_state.questions_.pop()
 
287
  handle_input()
 
 
288
  print("*"*20)
289
  print(st.session_state.input_query)
290
  print(st.session_state.answers_)
291
  print(st.session_state.questions_)
292
  print("*"*20)
293
+ with placeholder.container():
294
+ render_all()
295
+
296
  if(st.session_state.input_is_colpali):
297
  st.button("Show similarity map",key=rdn_key_1,on_click = show_maxsim)
298
 
 
309
  for i in range(0,len(res_img)):
310
 
311
  if(st.session_state.input_is_colpali):
312
+ if(len(res_img)>1):
313
  cols_per_row = 3
314
  st.session_state.image_placeholder=st.empty()
315
  with st.session_state.image_placeholder.container():
 
324
  col3_,col4_,col5_ = st.columns([33,33,33])
325
  with col3_:
326
  st.image(res_img[i]['file'])
 
 
 
 
 
327
  else:
328
  if(res_img[i]['file'].lower()!='none' and idx < 2):
329
  col3,col4,col5 = st.columns([33,33,33])
 
393
 
394
  #Each answer will have context of the question asked in order to associate the provided feedback with the respective question
395
  def write_chat_message(md, q,index):
396
+ res_img = md['image']
 
 
 
397
  #st.session_state['session_id'] = res['session_id'] to be added in memory
398
  chat = st.container()
399
  with chat: