RAG changes

Files changed:
- =0.3.0,  (+0 -0)
- RAG/colpali.py  (+341 -0)
- requirements.txt  (+2 -1)

=0.3.0,
ADDED
Empty file, added with no content. The name suggests it was created accidentally by an unquoted shell redirect from the version range "colpali-engine>=0.3.0,<0.4.0".

RAG/colpali.py
ADDED
@@ -0,0 +1,341 @@
import base64
import json
import os
import pprint
import random
from io import BytesIO
from pathlib import Path
from typing import Optional, cast

import boto3
import matplotlib.pyplot as plt
import numpy as np
import requests
import streamlit as st
import torch
from colpali_engine.interpretability import (
    get_similarity_maps_from_embeddings,
    plot_all_similarity_maps,
    plot_similarity_map,
)
from IPython.display import display, Markdown
from opensearchpy import OpenSearch, RequestsHttpConnection, AWSV4SignerAuth
from PIL import Image
from requests.auth import HTTPBasicAuth
from requests_aws4auth import AWS4Auth

import utilities.invoke_models as invoke_models

# from colpali_engine.models import ColPali, ColPaliProcessor
# from colpali_engine.utils.torch_utils import get_torch_device

model_name = "vidore/colpali-v1.3"

# The model can also be loaded locally instead of through the SageMaker endpoint:
# colpali_model = ColPali.from_pretrained(
#     model_name,
#     torch_dtype=torch.bfloat16,
#     device_map="cuda:0",  # "cuda:0" for GPU, "cpu" for CPU, or "mps" for Apple Silicon
# ).eval()
# colpali_processor = ColPaliProcessor.from_pretrained(model_name)

# OpenSearch (us-west-2) client for the ColPali vector index
credentials = boto3.Session().get_credentials()
auth = AWSV4SignerAuth(credentials, 'us-west-2', 'es')
aos_client = OpenSearch(
    hosts=[{'host': 'search-opensearchservi-shjckef2t7wo-iyv6rajdgxg6jas25aupuxev6i.us-west-2.es.amazonaws.com', 'port': 443}],
    http_auth=auth,
    use_ssl=True,
    verify_certs=True,
    connection_class=RequestsHttpConnection,
    pool_maxsize=20,
)

# SageMaker runtime client (us-east-1) for the ColPali embedding endpoint
region_endpoint = "us-east-1"
endpoint_name = "colpali-endpoint"
runtime = boto3.client("sagemaker-runtime", region_name=region_endpoint)


def call_nova(
    model,
    messages,
    system_message="",
    streaming=False,
    max_tokens=512,
    temp=0.0001,
    top_p=0.99,
    top_k=20,
    tools=None,
    verbose=False,
):
    """Invoke an Amazon Nova model through Bedrock, optionally streaming."""
    client = boto3.client("bedrock-runtime", region_name=region_endpoint)
    system_list = [{"text": system_message}]
    inf_params = {
        "max_new_tokens": max_tokens,
        "top_p": top_p,
        "top_k": top_k,
        "temperature": temp,
    }
    request_body = {
        "messages": messages,
        "system": system_list,
        "inferenceConfig": inf_params,
    }
    if tools is not None:
        request_body["toolConfig"] = {"tools": [{"toolSpec": tool} for tool in tools]}
    if verbose:
        print("Request Body", request_body)
    if not streaming:
        response = client.invoke_model(modelId=model, body=json.dumps(request_body))
        model_response = json.loads(response["body"].read())
        return model_response, model_response["output"]["message"]["content"][0]["text"]
    else:
        response = client.invoke_model_with_response_stream(
            modelId=model, body=json.dumps(request_body)
        )
        return response["body"]


def get_base64_encoded_value(media_path):
    """Read a file and return its contents as a base64-encoded string."""
    with open(media_path, "rb") as media_file:
        return base64.b64encode(media_file.read()).decode("utf-8")


def generate_ans(top_result, query):
    """Answer the query with Nova Pro, grounded in the top-ranked page image."""
    system_message = (
        "Given an image of a PDF page, answer the question. Be accurate to the "
        "question. If you don't find the answer in the page, say: I don't know."
    )
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "image": {
                        "format": "jpeg",
                        "source": {"bytes": get_base64_encoded_value(top_result)},
                    }
                },
                {"text": query},  # e.g. "what is the proportion of female new hires 2021-2023?"
            ],
        }
    ]
    model_response, content_text = call_nova(
        "amazon.nova-pro-v1:0", messages, system_message=system_message, max_tokens=300
    )
    return content_text


def colpali_search_rerank(query):
    """Two-stage ColPali retrieval: k-NN recall, then late-interaction rerank."""
    # Embed the query with the ColPali SageMaker endpoint
    payload = {"queries": [query]}
    response = runtime.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Body=json.dumps(payload),
    )
    result = json.loads(response["Body"].read().decode())

    final_docs_sorted_20 = []
    for batch_embeddings in result["query_embeddings"]:
        # Stage 1: approximate recall with the mean of the query token vectors
        vec = np.array(batch_embeddings).mean(axis=0)
        knn_query = {
            "size": 200,
            "query": {
                "nested": {
                    "path": "page_sub_vectors",
                    "query": {
                        "knn": {
                            "page_sub_vectors.page_sub_vector": {
                                "vector": vec.tolist(),
                                "k": 200,
                            }
                        }
                    },
                }
            },
        }
        response = aos_client.search(body=knn_query, index="colpali-vs")
        hits = response["hits"]["hits"]
        query_token_vectors = batch_embeddings

        # Stage 2: late-interaction (MaxSim) rerank of the recalled pages
        final_docs = []
        for j in hits:
            doc = {"id": j["_id"], "score": j["_score"], "image": j["_source"]["image"]}
            page_vectors = j["_source"]["page_sub_vectors"]
            add_score = 0
            for i in query_token_vectors:
                query_token_vector = np.array(i)
                # Best-matching page patch for this query token
                scores = [
                    np.dot(query_token_vector, np.array(m["page_sub_vector"]))
                    for m in page_vectors
                ]
                add_score += max(scores)
            doc["total_score"] = add_score
            final_docs.append(doc)

        final_docs_sorted = sorted(final_docs, key=lambda d: d["total_score"], reverse=True)
        final_docs_sorted_20.append(final_docs_sorted[:20])

    # Answer from the best-ranked page image
    img = "/home/ubuntu/AI-search-with-amazon-opensearch-service/vs/" + final_docs_sorted_20[0][0]["image"]
    ans = generate_ans(img, query)
    images_highlighted = [{"file": img}]
    # if st.session_state.show_columns == True:
    #     images_highlighted = img_highlight(img, query_token_vectors, result["query_tokens"])
    st.session_state.top_img = img
    st.session_state.query_token_vectors = query_token_vectors
    st.session_state.query_tokens = result["query_tokens"]
    return {"text": ans, "source": img, "image": images_highlighted, "table": []}

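# A note on the scoring above: the stage-2 loop computes ColPali's
# late-interaction (MaxSim) score: for each query token, take the best
# dot product over all page patches, then sum over tokens. A minimal
# vectorized equivalent (illustrative sketch only, not used by this module):
#
# def maxsim_score(query_vecs, page_vecs):
#     # query_vecs: (num_query_tokens, dim); page_vecs: (num_patches, dim)
#     sims = np.asarray(query_vecs) @ np.asarray(page_vecs).T  # token-patch dot products
#     return float(sims.max(axis=1).sum())
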
def img_highlight(img, batch_queries, query_tokens):
    """Plot per-token similarity maps over the page image.

    Reference: https://github.com/tonywu71/colpali-cookbooks/blob/main/examples/gen_colpali_similarity_maps.ipynb
    """
    with open(img, "rb") as f:
        img_b64 = base64.b64encode(f.read()).decode("utf-8")

    # Embed the page image with the ColPali endpoint
    payload = {"images": [img_b64]}
    response = runtime.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Body=json.dumps(payload),
    )
    img_colpali_res = json.loads(response["Body"].read().decode())

    # Convert endpoint outputs to tensors
    image_embeddings = torch.tensor(img_colpali_res["image_embeddings"][0])  # [T, D]
    query_embeddings = torch.tensor(batch_queries)  # [T_q, D]

    # Use the full 1D mask vector, not a single value
    image_mask_list = img_colpali_res["image_mask"]
    if isinstance(image_mask_list[0], list):
        image_mask = torch.tensor(image_mask_list[0]).bool()  # list of lists
    else:
        image_mask = torch.tensor(image_mask_list).bool()  # already flattened

    print("Valid patch count:", image_mask.sum().item())

    # Add a batch dimension where missing: [1, T_q, D], [1, T, D], [1, T]
    if query_embeddings.dim() == 2:
        query_embeddings = query_embeddings.unsqueeze(0)
    if image_embeddings.dim() == 2:
        image_embeddings = image_embeddings.unsqueeze(0)
    if image_mask.dim() == 1:
        image_mask = image_mask.unsqueeze(0)

    print("query_embeddings shape:", query_embeddings.shape)
    print("image_embeddings shape:", image_embeddings.shape)
    print("image_mask shape:", image_mask.shape)

    # Patch grid of the page image
    image = Image.open(img)
    n_patches = (img_colpali_res["patch_shape"]["height"], img_colpali_res["patch_shape"]["width"])
    print(f"Number of image patches: {n_patches}")

    # Generate the similarity maps
    batched_similarity_maps = get_similarity_maps_from_embeddings(
        image_embeddings=image_embeddings,
        query_embeddings=query_embeddings,
        n_patches=n_patches,
        image_mask=image_mask,
    )

    # Similarity map for our (only) input image: (query_length, n_patches_x, n_patches_y)
    similarity_maps = batched_similarity_maps[0]
    print(f"Similarity map shape: (query_length, n_patches_x, n_patches_y) = {tuple(similarity_maps.shape)}")

    query_tokens_from_model = query_tokens[0]["tokens"]

    plots = plot_all_similarity_maps(
        image=image,
        query_tokens=query_tokens_from_model,
        similarity_maps=similarity_maps,
        figsize=(8, 8),
        show_colorbar=False,
        add_title=True,
    )

    map_images = []
    for idx, (fig, ax) in enumerate(plots):
        if idx < 3:  # skip the leading special tokens of the query
            continue
        savepath = (
            "/home/ubuntu/AI-search-with-amazon-opensearch-service/similarity_maps/"
            "similarity_map_" + img.split("/")[-1] + "_token_" + str(idx) + "_"
            + query_tokens_from_model[idx] + ".png"
        )
        fig.savefig(savepath, bbox_inches="tight")
        map_images.append({"file": savepath})
        print(f"Similarity map for token `{query_tokens_from_model[idx]}` saved at `{savepath}`")

    plt.close("all")
    return map_images
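
The nested k-NN query in colpali_search_rerank assumes an OpenSearch index (colpali-vs) that stores one vector per ColPali page patch under the nested field page_sub_vectors.page_sub_vector. The index definition is not part of this commit; a minimal sketch of a compatible mapping follows (the dimension of 128 is an assumption based on ColPali v1.3's projected token embedding size, and the actual index settings may differ):

index_body = {
    "settings": {"index.knn": True},
    "mappings": {
        "properties": {
            "image": {"type": "keyword"},  # page image filename, as read by the rerank code
            "page_sub_vectors": {
                "type": "nested",
                "properties": {
                    "page_sub_vector": {
                        "type": "knn_vector",
                        "dimension": 128,  # assumed: ColPali v1.3 embedding size
                    }
                },
            },
        }
    },
}
# aos_client.indices.create(index="colpali-vs", body=index_body)
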
requirements.txt
CHANGED
@@ -15,4 +15,5 @@ langchain==0.2.16
 langchain-core==0.2.39
 langchain-community==0.2.16
 langchain-experimental==0.0.65
-lark==1.2.2
+lark==1.2.2
+colpali-engine>=0.3.0,<0.4.0
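
With colpali-engine pinned, the new module can be exercised end to end. A hedged usage sketch (it assumes AWS credentials with Bedrock, SageMaker, and OpenSearch access, the colpali-endpoint SageMaker endpoint, a populated colpali-vs index, and a running Streamlit session for st.session_state; the example query is the one left in the source comments):

from RAG import colpali

out = colpali.colpali_search_rerank(
    "what is the proportion of female new hires 2021-2023?"
)
print(out["text"])    # answer generated by Nova Pro from the top page
print(out["source"])  # local path of the best-matching page image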