prasadnu commited on
Commit
dc05755
·
1 Parent(s): ca368ff

rerank model

Browse files
Files changed (1) hide show
  1. RAG/colpali.py +32 -32
RAG/colpali.py CHANGED
@@ -19,11 +19,11 @@ import boto3
19
  import streamlit as st
20
  from IPython.display import display, Markdown
21
  import base64
22
- from colpali_engine.interpretability import (
23
- get_similarity_maps_from_embeddings,
24
- plot_all_similarity_maps,
25
- plot_similarity_map,
26
- )
27
  import torch
28
  # from colpali_engine.models import ColPali, ColPaliProcessor
29
  # from colpali_engine.utils.torch_utils import get_torch_device
@@ -286,37 +286,37 @@ def img_highlight(img,batch_queries,query_tokens):
286
  print(f"Number of image patches: {n_patches}")
287
 
288
  # # Generate the similarity maps
289
- batched_similarity_maps = get_similarity_maps_from_embeddings(
290
- image_embeddings=image_embeddings,
291
- query_embeddings=query_embeddings,
292
- n_patches=n_patches,
293
- image_mask = image_mask
294
- )
295
 
296
- # # Get the similarity map for our (only) input image
297
- similarity_maps = batched_similarity_maps[0] # (query_length, n_patches_x, n_patches_y)
298
 
299
- query_tokens_from_model = query_tokens[0]['tokens']
300
 
301
- plots = plot_all_similarity_maps(
302
- image=image,
303
- query_tokens=query_tokens_from_model,
304
- similarity_maps=similarity_maps,
305
- figsize=(8, 8),
306
- show_colorbar=False,
307
- add_title=True,
308
- )
309
- map_images = []
310
- for idx, (fig, ax) in enumerate(plots):
311
- if(idx<3):
312
- continue
313
- savepath = "/home/user/app/similarity_maps/similarity_map_"+(img.split("/"))[-1]+"_token_"+str(idx)+"_"+query_tokens_from_model[idx]+".png"
314
- fig.savefig(savepath, bbox_inches="tight")
315
- map_images.append({'file':savepath})
316
- print(f"Similarity map for token `{query_tokens_from_model[idx]}` saved at `{savepath}`")
317
 
318
- plt.close("all")
319
- return map_images
320
 
321
 
322
 
 
19
  import streamlit as st
20
  from IPython.display import display, Markdown
21
  import base64
22
+ # from colpali_engine.interpretability import (
23
+ # get_similarity_maps_from_embeddings,
24
+ # plot_all_similarity_maps,
25
+ # plot_similarity_map,
26
+ # )
27
  import torch
28
  # from colpali_engine.models import ColPali, ColPaliProcessor
29
  # from colpali_engine.utils.torch_utils import get_torch_device
 
286
  print(f"Number of image patches: {n_patches}")
287
 
288
  # # Generate the similarity maps
289
+ # batched_similarity_maps = get_similarity_maps_from_embeddings(
290
+ # image_embeddings=image_embeddings,
291
+ # query_embeddings=query_embeddings,
292
+ # n_patches=n_patches,
293
+ # image_mask = image_mask
294
+ # )
295
 
296
+ # # # Get the similarity map for our (only) input image
297
+ # similarity_maps = batched_similarity_maps[0] # (query_length, n_patches_x, n_patches_y)
298
 
299
+ # query_tokens_from_model = query_tokens[0]['tokens']
300
 
301
+ # plots = plot_all_similarity_maps(
302
+ # image=image,
303
+ # query_tokens=query_tokens_from_model,
304
+ # similarity_maps=similarity_maps,
305
+ # figsize=(8, 8),
306
+ # show_colorbar=False,
307
+ # add_title=True,
308
+ # )
309
+ # map_images = []
310
+ # for idx, (fig, ax) in enumerate(plots):
311
+ # if(idx<3):
312
+ # continue
313
+ # savepath = "/home/user/app/similarity_maps/similarity_map_"+(img.split("/"))[-1]+"_token_"+str(idx)+"_"+query_tokens_from_model[idx]+".png"
314
+ # fig.savefig(savepath, bbox_inches="tight")
315
+ # map_images.append({'file':savepath})
316
+ # print(f"Similarity map for token `{query_tokens_from_model[idx]}` saved at `{savepath}`")
317
 
318
+ # plt.close("all")
319
+ return ""#map_images
320
 
321
 
322