Spaces:
Running
on
T4
Running
on
T4
rerank model
Browse files
- RAG/colpali.py +32 -33
- pages/Multimodal_Conversational_Search.py +1 -1
RAG/colpali.py
CHANGED
@@ -17,13 +17,12 @@ import matplotlib.pyplot as plt
|
|
17 |
import requests
|
18 |
import boto3
|
19 |
import streamlit as st
|
20 |
-
from IPython.display import display, Markdown
|
21 |
import base64
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
import torch
|
28 |
# from colpali_engine.models import ColPali, ColPaliProcessor
|
29 |
# from colpali_engine.utils.torch_utils import get_torch_device
|
@@ -286,37 +285,37 @@ def img_highlight(img,batch_queries,query_tokens):
|
|
286 |
print(f"Number of image patches: {n_patches}")
|
287 |
|
288 |
# # Generate the similarity maps
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
|
296 |
-
# #
|
297 |
-
|
298 |
|
299 |
-
|
300 |
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
|
318 |
-
|
319 |
-
return
|
320 |
|
321 |
|
322 |
|
|
|
17 |
import requests
|
18 |
import boto3
|
19 |
import streamlit as st
|
|
|
20 |
import base64
|
21 |
+
from colpali_engine.interpretability import (
|
22 |
+
get_similarity_maps_from_embeddings,
|
23 |
+
plot_all_similarity_maps,
|
24 |
+
plot_similarity_map,
|
25 |
+
)
|
26 |
import torch
|
27 |
# from colpali_engine.models import ColPali, ColPaliProcessor
|
28 |
# from colpali_engine.utils.torch_utils import get_torch_device
|
|
|
285 |
print(f"Number of image patches: {n_patches}")
|
286 |
|
287 |
# # Generate the similarity maps
|
288 |
+
batched_similarity_maps = get_similarity_maps_from_embeddings(
|
289 |
+
image_embeddings=image_embeddings,
|
290 |
+
query_embeddings=query_embeddings,
|
291 |
+
n_patches=n_patches,
|
292 |
+
image_mask = image_mask
|
293 |
+
)
|
294 |
|
295 |
+
# # Get the similarity map for our (only) input image
|
296 |
+
similarity_maps = batched_similarity_maps[0] # (query_length, n_patches_x, n_patches_y)
|
297 |
|
298 |
+
query_tokens_from_model = query_tokens[0]['tokens']
|
299 |
|
300 |
+
plots = plot_all_similarity_maps(
|
301 |
+
image=image,
|
302 |
+
query_tokens=query_tokens_from_model,
|
303 |
+
similarity_maps=similarity_maps,
|
304 |
+
figsize=(8, 8),
|
305 |
+
show_colorbar=False,
|
306 |
+
add_title=True,
|
307 |
+
)
|
308 |
+
map_images = []
|
309 |
+
for idx, (fig, ax) in enumerate(plots):
|
310 |
+
if(idx<3):
|
311 |
+
continue
|
312 |
+
savepath = "/home/user/app/similarity_maps/similarity_map_"+(img.split("/"))[-1]+"_token_"+str(idx)+"_"+query_tokens_from_model[idx]+".png"
|
313 |
+
fig.savefig(savepath, bbox_inches="tight")
|
314 |
+
map_images.append({'file':savepath})
|
315 |
+
print(f"Similarity map for token `{query_tokens_from_model[idx]}` saved at `{savepath}`")
|
316 |
|
317 |
+
plt.close("all")
|
318 |
+
return map_images
|
319 |
|
320 |
|
321 |
|
pages/Multimodal_Conversational_Search.py
CHANGED
@@ -13,7 +13,7 @@ import botocore.session
|
|
13 |
import json
|
14 |
import random
|
15 |
import string
|
16 |
-
import rag_DocumentLoader
|
17 |
import rag_DocumentSearcher
|
18 |
import pandas as pd
|
19 |
from PIL import Image
|
|
|
13 |
import json
|
14 |
import random
|
15 |
import string
|
16 |
+
#import rag_DocumentLoader
|
17 |
import rag_DocumentSearcher
|
18 |
import pandas as pd
|
19 |
from PIL import Image
|