prasadnu committed
Commit 991493a
1 Parent(s): ad41a02

RAG changes

Files changed (3)
  1. =0.3.0, +0 -0
  2. RAG/colpali.py +341 -0
  3. requirements.txt +2 -1
=0.3.0, ADDED
File without changes
RAG/colpali.py ADDED
@@ -0,0 +1,341 @@
import base64
import json
import os
import pprint
import random
from io import BytesIO
from pathlib import Path
from typing import Optional, cast

import boto3
import matplotlib.pyplot as plt
import numpy as np
import requests
import streamlit as st
import torch
from IPython.display import display, Markdown
from opensearchpy import OpenSearch, RequestsHttpConnection, AWSV4SignerAuth
from PIL import Image
from requests.auth import HTTPBasicAuth
from requests_aws4auth import AWS4Auth
# from datasets import load_dataset
from colpali_engine.interpretability import (
    get_similarity_maps_from_embeddings,
    plot_all_similarity_maps,
    plot_similarity_map,
)
# from colpali_engine.models import ColPali, ColPaliProcessor
# from colpali_engine.utils.torch_utils import get_torch_device

import utilities.invoke_models as invoke_models

# ColPali checkpoint served by the SageMaker endpoint configured below.
model_name = "vidore/colpali-v1.3"

# Local model loading is disabled; embedding is done on the SageMaker endpoint instead.
# colpali_model = ColPali.from_pretrained(
#     model_name,
#     torch_dtype=torch.bfloat16,
#     device_map="cuda:0",  # "cuda:0" for GPU, "cpu" for CPU, or "mps" for Apple Silicon
# ).eval()
# colpali_processor = ColPaliProcessor.from_pretrained(model_name)

# Amazon OpenSearch Service client (SigV4-signed) holding the ColPali page vectors.
credentials = boto3.Session().get_credentials()
auth = AWSV4SignerAuth(credentials, 'us-west-2', 'es')
aos_client = OpenSearch(
    hosts=[{'host': 'search-opensearchservi-shjckef2t7wo-iyv6rajdgxg6jas25aupuxev6i.us-west-2.es.amazonaws.com', 'port': 443}],
    http_auth=auth,
    use_ssl=True,
    verify_certs=True,
    connection_class=RequestsHttpConnection,
    pool_maxsize=20,
)

# Region used for the SageMaker (ColPali) and Bedrock (Nova) endpoints.
region_endpoint = "us-east-1"

# SageMaker endpoint that serves the ColPali model.
endpoint_name = "colpali-endpoint"

# SageMaker runtime client used to invoke the endpoint.
runtime = boto3.client("sagemaker-runtime", region_name=region_endpoint)


def call_nova(
    model,
    messages,
    system_message="",
    streaming=False,
    max_tokens=512,
    temp=0.0001,
    top_p=0.99,
    top_k=20,
    tools=None,
    verbose=False,
):
    """Invoke an Amazon Nova model on Bedrock and return the response."""
    client = boto3.client("bedrock-runtime", region_name=region_endpoint)
    system_list = [{"text": system_message}]
    inf_params = {
        "max_new_tokens": max_tokens,
        "top_p": top_p,
        "top_k": top_k,
        "temperature": temp,
    }
    request_body = {
        "messages": messages,
        "system": system_list,
        "inferenceConfig": inf_params,
    }
    if tools is not None:
        tool_config = [{"toolSpec": tool} for tool in tools]
        request_body["toolConfig"] = {"tools": tool_config}
    if verbose:
        print("Request Body", request_body)
    if not streaming:
        response = client.invoke_model(modelId=model, body=json.dumps(request_body))
        model_response = json.loads(response["body"].read())
        # Return the full response plus the first text block of the assistant message.
        return model_response, model_response["output"]["message"]["content"][0]["text"]
    else:
        response = client.invoke_model_with_response_stream(
            modelId=model, body=json.dumps(request_body)
        )
        return response["body"]


def get_base64_encoded_value(media_path):
    """Read a file from disk and return its contents as a base64-encoded string."""
    with open(media_path, "rb") as media_file:
        binary_data = media_file.read()
    base_64_encoded_data = base64.b64encode(binary_data)
    base64_string = base_64_encoded_data.decode("utf-8")
    return base64_string


def generate_ans(top_result, query):
    """Answer the query with Nova Pro, grounded on the top-ranked PDF page image."""
    print(query)
    system_message = (
        "given an image of a PDF page, answer the question. Be accurate to the question. "
        "If you don't find the answer in the page, please say, I don't know"
    )
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "image": {
                        "format": "jpeg",
                        "source": {
                            # Nova expects the image bytes as a base64-encoded string.
                            "bytes": get_base64_encoded_value(top_result)
                        },
                    }
                },
                {
                    "text": query
                },
            ],
        }
    ]
    model_response, content_text = call_nova(
        "amazon.nova-pro-v1:0", messages, system_message=system_message, max_tokens=300
    )
    print(content_text)
    return content_text


def colpali_search_rerank(query):
    """Two-stage ColPali retrieval: ANN candidate search in OpenSearch, then
    late-interaction (MaxSim) re-ranking, then answer generation with Nova."""
    # Embed the query with the ColPali SageMaker endpoint.
    payload = {"queries": [query]}
    response = runtime.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Body=json.dumps(payload),
    )
    result = json.loads(response["Body"].read().decode())

    final_docs_sorted_20 = []
    for batch_embeddings in result['query_embeddings']:
        # Stage 1: mean-pool the query token vectors into a single vector and run
        # an approximate k-NN search over the nested page patch vectors.
        vec = np.array(batch_embeddings).mean(axis=0)
        query_ = {
            "size": 200,
            "query": {
                "nested": {
                    "path": "page_sub_vectors",
                    "query": {
                        "knn": {
                            "page_sub_vectors.page_sub_vector": {
                                "vector": vec.tolist(),
                                "k": 200
                            }
                        }
                    }
                }
            }
        }
        os_response = aos_client.search(body=query_, index='colpali-vs')
        hits = os_response['hits']['hits']

        # Stage 2: late-interaction re-ranking. For every query token vector, take
        # the maximum dot product over the page's patch vectors (MaxSim) and sum
        # those maxima into the page's total score.
        query_token_vectors = batch_embeddings
        final_docs = []
        for hit in hits:
            doc = {"id": hit["_id"], "score": hit["_score"], "image": hit["_source"]["image"]}
            page_sub_vectors = hit['_source']['page_sub_vectors']
            add_score = 0
            for query_token_vector in query_token_vectors:
                query_token_vector = np.array(query_token_vector)
                scores = [
                    np.dot(query_token_vector, np.array(m['page_sub_vector']))
                    for m in page_sub_vectors
                ]
                add_score += max(scores)
            doc["total_score"] = add_score
            final_docs.append(doc)

        final_docs_sorted = sorted(final_docs, key=lambda d: d['total_score'], reverse=True)
        final_docs_sorted_20.append(final_docs_sorted[:20])

    # Answer the query from the top-ranked page image.
    img = "/home/ubuntu/AI-search-with-amazon-opensearch-service/vs/" + final_docs_sorted_20[0][0]['image']
    ans = generate_ans(img, query)
    images_highlighted = [{'file': img}]
    # if st.session_state.show_columns == True:
    #     images_highlighted = img_highlight(img, query_token_vectors, result['query_tokens'])
    st.session_state.top_img = img
    st.session_state.query_token_vectors = query_token_vectors
    st.session_state.query_tokens = result['query_tokens']
    return {'text': ans, 'source': img, 'image': images_highlighted, 'table': []}


def img_highlight(img, batch_queries, query_tokens):
    """Generate per-token similarity maps over the page image.

    Reference: https://github.com/tonywu71/colpali-cookbooks/blob/main/examples/gen_colpali_similarity_maps.ipynb
    """
    with open(img, "rb") as f:
        img_b64 = base64.b64encode(f.read()).decode("utf-8")

    # Send only the image to the ColPali endpoint to get its patch embeddings.
    payload = {"images": [img_b64]}
    response = runtime.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Body=json.dumps(payload)
    )
    img_colpali_res = json.loads(response["Body"].read().decode())

    # Convert endpoint outputs to tensors.
    image_embeddings = torch.tensor(img_colpali_res["image_embeddings"][0])  # [n_image_tokens, dim]
    query_embeddings = torch.tensor(batch_queries)                           # [n_query_tokens, dim]

    # Use the full mask vector for this image, not a single value.
    image_mask_list = img_colpali_res["image_mask"]
    if isinstance(image_mask_list[0], list):
        # Expected case: one mask list per image.
        image_mask = torch.tensor(image_mask_list[0]).bool()
    else:
        # Edge case: already flattened.
        image_mask = torch.tensor(image_mask_list).bool()

    print("Valid patch count:", image_mask.sum().item())

    # Add a batch dimension where needed: [T, D] -> [1, T, D], [T] -> [1, T].
    if query_embeddings.dim() == 2:
        query_embeddings = query_embeddings.unsqueeze(0)
    if image_embeddings.dim() == 2:
        image_embeddings = image_embeddings.unsqueeze(0)
    if image_mask.dim() == 1:
        image_mask = image_mask.unsqueeze(0)

    print("query_embeddings shape:", query_embeddings.shape)
    print("image_embeddings shape:", image_embeddings.shape)
    print("image_mask shape:", image_mask.shape)

    # Patch grid of the page image, as reported by the endpoint.
    image = Image.open(img)
    n_patches = (img_colpali_res["patch_shape"]['height'], img_colpali_res["patch_shape"]['width'])
    print(f"Number of image patches: {n_patches}")

    # Generate the similarity maps.
    batched_similarity_maps = get_similarity_maps_from_embeddings(
        image_embeddings=image_embeddings,
        query_embeddings=query_embeddings,
        n_patches=n_patches,
        image_mask=image_mask,
    )

    # Similarity map for our (only) input image.
    similarity_maps = batched_similarity_maps[0]  # (query_length, n_patches_x, n_patches_y)
    print(f"Similarity map shape: (query_length, n_patches_x, n_patches_y) = {tuple(similarity_maps.shape)}")

    query_tokens_from_model = query_tokens[0]['tokens']
    print(query_tokens_from_model)

    plots = plot_all_similarity_maps(
        image=image,
        query_tokens=query_tokens_from_model,
        similarity_maps=similarity_maps,
        figsize=(8, 8),
        show_colorbar=False,
        add_title=True,
    )

    map_images = []
    for idx, (fig, ax) in enumerate(plots):
        # Skip the first three tokens (query prefix/special tokens).
        if idx < 3:
            continue
        savepath = (
            "/home/ubuntu/AI-search-with-amazon-opensearch-service/similarity_maps/similarity_map_"
            + (img.split("/"))[-1] + "_token_" + str(idx) + "_" + query_tokens_from_model[idx] + ".png"
        )
        fig.savefig(savepath, bbox_inches="tight")
        map_images.append({'file': savepath})
        print(f"Similarity map for token `{query_tokens_from_model[idx]}` saved at `{savepath}`")

    plt.close("all")
    return map_images
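
A minimal usage sketch (not part of this commit; the module path RAG.colpali and the search-box wiring are assumptions based on the file location): the entry point is colpali_search_rerank(query), which returns the generated answer, the path of the top-ranked page image, and the image list used for display.

# Hypothetical caller, e.g. the Streamlit page that handles the search box.
import streamlit as st
from RAG import colpali

query = st.text_input("Ask a question about the indexed PDF pages")
if query:
    result = colpali.colpali_search_rerank(query)
    st.write(result['text'])    # answer generated by Nova Pro
    st.image(result['source'])  # top-ranked page image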
requirements.txt CHANGED
@@ -15,4 +15,5 @@ langchain==0.2.16
  langchain-core==0.2.39
  langchain-community==0.2.16
  langchain-experimental==0.0.65
- lark==1.2.2
+ lark==1.2.2
+ colpali-engine>=0.3.0,<0.4.0