Spaces:
Running
on
T4
Running
on
T4
rerank model
Browse files- RAG/colpali.py +32 -32
RAG/colpali.py
CHANGED
@@ -19,11 +19,11 @@ import boto3
|
|
19 |
import streamlit as st
|
20 |
from IPython.display import display, Markdown
|
21 |
import base64
|
22 |
-
from colpali_engine.interpretability import (
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
)
|
27 |
import torch
|
28 |
# from colpali_engine.models import ColPali, ColPaliProcessor
|
29 |
# from colpali_engine.utils.torch_utils import get_torch_device
|
@@ -286,37 +286,37 @@ def img_highlight(img,batch_queries,query_tokens):
|
|
286 |
print(f"Number of image patches: {n_patches}")
|
287 |
|
288 |
# # Generate the similarity maps
|
289 |
-
batched_similarity_maps = get_similarity_maps_from_embeddings(
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
)
|
295 |
|
296 |
-
# # Get the similarity map for our (only) input image
|
297 |
-
similarity_maps = batched_similarity_maps[0] # (query_length, n_patches_x, n_patches_y)
|
298 |
|
299 |
-
query_tokens_from_model = query_tokens[0]['tokens']
|
300 |
|
301 |
-
plots = plot_all_similarity_maps(
|
302 |
-
image=image,
|
303 |
-
query_tokens=query_tokens_from_model,
|
304 |
-
similarity_maps=similarity_maps,
|
305 |
-
figsize=(8, 8),
|
306 |
-
show_colorbar=False,
|
307 |
-
add_title=True,
|
308 |
-
)
|
309 |
-
map_images = []
|
310 |
-
for idx, (fig, ax) in enumerate(plots):
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
|
318 |
-
plt.close("all")
|
319 |
-
return map_images
|
320 |
|
321 |
|
322 |
|
|
|
19 |
import streamlit as st
|
20 |
from IPython.display import display, Markdown
|
21 |
import base64
|
22 |
+
# from colpali_engine.interpretability import (
|
23 |
+
# get_similarity_maps_from_embeddings,
|
24 |
+
# plot_all_similarity_maps,
|
25 |
+
# plot_similarity_map,
|
26 |
+
# )
|
27 |
import torch
|
28 |
# from colpali_engine.models import ColPali, ColPaliProcessor
|
29 |
# from colpali_engine.utils.torch_utils import get_torch_device
|
|
|
286 |
print(f"Number of image patches: {n_patches}")
|
287 |
|
288 |
# # Generate the similarity maps
|
289 |
+
# batched_similarity_maps = get_similarity_maps_from_embeddings(
|
290 |
+
# image_embeddings=image_embeddings,
|
291 |
+
# query_embeddings=query_embeddings,
|
292 |
+
# n_patches=n_patches,
|
293 |
+
# image_mask = image_mask
|
294 |
+
# )
|
295 |
|
296 |
+
# # # Get the similarity map for our (only) input image
|
297 |
+
# similarity_maps = batched_similarity_maps[0] # (query_length, n_patches_x, n_patches_y)
|
298 |
|
299 |
+
# query_tokens_from_model = query_tokens[0]['tokens']
|
300 |
|
301 |
+
# plots = plot_all_similarity_maps(
|
302 |
+
# image=image,
|
303 |
+
# query_tokens=query_tokens_from_model,
|
304 |
+
# similarity_maps=similarity_maps,
|
305 |
+
# figsize=(8, 8),
|
306 |
+
# show_colorbar=False,
|
307 |
+
# add_title=True,
|
308 |
+
# )
|
309 |
+
# map_images = []
|
310 |
+
# for idx, (fig, ax) in enumerate(plots):
|
311 |
+
# if(idx<3):
|
312 |
+
# continue
|
313 |
+
# savepath = "/home/user/app/similarity_maps/similarity_map_"+(img.split("/"))[-1]+"_token_"+str(idx)+"_"+query_tokens_from_model[idx]+".png"
|
314 |
+
# fig.savefig(savepath, bbox_inches="tight")
|
315 |
+
# map_images.append({'file':savepath})
|
316 |
+
# print(f"Similarity map for token `{query_tokens_from_model[idx]}` saved at `{savepath}`")
|
317 |
|
318 |
+
# plt.close("all")
|
319 |
+
return ""#map_images
|
320 |
|
321 |
|
322 |
|