prasadnu commited on
Commit
bc2fefa
·
1 Parent(s): dc05755

rerank model

Browse files
RAG/colpali.py CHANGED
@@ -17,13 +17,12 @@ import matplotlib.pyplot as plt
17
  import requests
18
  import boto3
19
  import streamlit as st
20
- from IPython.display import display, Markdown
21
  import base64
22
- # from colpali_engine.interpretability import (
23
- # get_similarity_maps_from_embeddings,
24
- # plot_all_similarity_maps,
25
- # plot_similarity_map,
26
- # )
27
  import torch
28
  # from colpali_engine.models import ColPali, ColPaliProcessor
29
  # from colpali_engine.utils.torch_utils import get_torch_device
@@ -286,37 +285,37 @@ def img_highlight(img,batch_queries,query_tokens):
286
  print(f"Number of image patches: {n_patches}")
287
 
288
  # # Generate the similarity maps
289
- # batched_similarity_maps = get_similarity_maps_from_embeddings(
290
- # image_embeddings=image_embeddings,
291
- # query_embeddings=query_embeddings,
292
- # n_patches=n_patches,
293
- # image_mask = image_mask
294
- # )
295
 
296
- # # # Get the similarity map for our (only) input image
297
- # similarity_maps = batched_similarity_maps[0] # (query_length, n_patches_x, n_patches_y)
298
 
299
- # query_tokens_from_model = query_tokens[0]['tokens']
300
 
301
- # plots = plot_all_similarity_maps(
302
- # image=image,
303
- # query_tokens=query_tokens_from_model,
304
- # similarity_maps=similarity_maps,
305
- # figsize=(8, 8),
306
- # show_colorbar=False,
307
- # add_title=True,
308
- # )
309
- # map_images = []
310
- # for idx, (fig, ax) in enumerate(plots):
311
- # if(idx<3):
312
- # continue
313
- # savepath = "/home/user/app/similarity_maps/similarity_map_"+(img.split("/"))[-1]+"_token_"+str(idx)+"_"+query_tokens_from_model[idx]+".png"
314
- # fig.savefig(savepath, bbox_inches="tight")
315
- # map_images.append({'file':savepath})
316
- # print(f"Similarity map for token `{query_tokens_from_model[idx]}` saved at `{savepath}`")
317
 
318
- # plt.close("all")
319
- return ""#map_images
320
 
321
 
322
 
 
17
  import requests
18
  import boto3
19
  import streamlit as st
 
20
  import base64
21
+ from colpali_engine.interpretability import (
22
+ get_similarity_maps_from_embeddings,
23
+ plot_all_similarity_maps,
24
+ plot_similarity_map,
25
+ )
26
  import torch
27
  # from colpali_engine.models import ColPali, ColPaliProcessor
28
  # from colpali_engine.utils.torch_utils import get_torch_device
 
285
  print(f"Number of image patches: {n_patches}")
286
 
287
  # # Generate the similarity maps
288
+ batched_similarity_maps = get_similarity_maps_from_embeddings(
289
+ image_embeddings=image_embeddings,
290
+ query_embeddings=query_embeddings,
291
+ n_patches=n_patches,
292
+ image_mask = image_mask
293
+ )
294
 
295
+ # # Get the similarity map for our (only) input image
296
+ similarity_maps = batched_similarity_maps[0] # (query_length, n_patches_x, n_patches_y)
297
 
298
+ query_tokens_from_model = query_tokens[0]['tokens']
299
 
300
+ plots = plot_all_similarity_maps(
301
+ image=image,
302
+ query_tokens=query_tokens_from_model,
303
+ similarity_maps=similarity_maps,
304
+ figsize=(8, 8),
305
+ show_colorbar=False,
306
+ add_title=True,
307
+ )
308
+ map_images = []
309
+ for idx, (fig, ax) in enumerate(plots):
310
+ if(idx<3):
311
+ continue
312
+ savepath = "/home/user/app/similarity_maps/similarity_map_"+(img.split("/"))[-1]+"_token_"+str(idx)+"_"+query_tokens_from_model[idx]+".png"
313
+ fig.savefig(savepath, bbox_inches="tight")
314
+ map_images.append({'file':savepath})
315
+ print(f"Similarity map for token `{query_tokens_from_model[idx]}` saved at `{savepath}`")
316
 
317
+ plt.close("all")
318
+ return map_images
319
 
320
 
321
 
pages/Multimodal_Conversational_Search.py CHANGED
@@ -13,7 +13,7 @@ import botocore.session
13
  import json
14
  import random
15
  import string
16
- import rag_DocumentLoader
17
  import rag_DocumentSearcher
18
  import pandas as pd
19
  from PIL import Image
 
13
  import json
14
  import random
15
  import string
16
+ #import rag_DocumentLoader
17
  import rag_DocumentSearcher
18
  import pandas as pd
19
  from PIL import Image