rerank model
- RAG/colpali.py +5 -4
- RAG/rag_DocumentSearcher.py +33 -33
- utilities/invoke_models.py +6 -6
- utilities/re_ranker.py +1 -1
RAG/colpali.py
CHANGED
@@ -66,7 +66,7 @@ runtime = boto3.client("sagemaker-runtime",aws_access_key_id=st.secrets['user_ac
 # Prepare your payload (e.g., text-only input)
 
 
-
+@st.cache_resource
 def call_nova(
     model,
     messages,
@@ -110,13 +110,14 @@ def call_nova(
         modelId=model, body=json.dumps(request_body)
     )
     return response["body"]
+@st.cache_resource
 def get_base64_encoded_value(media_path):
     with open(media_path, "rb") as media_file:
         binary_data = media_file.read()
         base_64_encoded_data = base64.b64encode(binary_data)
         base64_string = base_64_encoded_data.decode("utf-8")
         return base64_string
-
+@st.cache_resource
 def generate_ans(top_result,query):
     print(query)
     system_message = "given an image of a PDF page, answer the question. Be accurate to the question. If you don't find the answer in the page, please say, I don't know"
@@ -146,7 +147,7 @@ def generate_ans(top_result,query):
     print(content_text)
     return content_text
 
-
+@st.cache_resource
 def colpali_search_rerank(query):
     # Convert to JSON string
     payload = {
@@ -228,7 +229,7 @@ def colpali_search_rerank(query):
     return {'text':ans,'source':img,'image':images_highlighted,'table':[]}#[{'file':img}]
 
 
-
+@st.cache_resource
 def img_highlight(img,batch_queries,query_tokens):
     # Reference from : https://github.com/tonywu71/colpali-cookbooks/blob/main/examples/gen_colpali_similarity_maps.ipynb
     with open(img, "rb") as f:
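Across this file, the change decorates the helpers with Streamlit's @st.cache_resource, which memoizes a function's return value across app reruns, keyed by its arguments. A minimal sketch of that behaviour, assuming only Streamlit (the load_colpali_processor name and model id below are illustrative, not from this repo):

import streamlit as st

@st.cache_resource
def load_colpali_processor(model_id: str):
    # The body runs once per unique model_id; subsequent Streamlit reruns
    # skip it and hand back the object cached for that argument.
    print(f"loading {model_id} ...")   # printed only on a cache miss
    return {"model_id": model_id}      # stands in for an expensive, unpicklable object

p1 = load_colpali_processor("vidore/colpali-v1.2")  # cache miss: body executes
p2 = load_colpali_processor("vidore/colpali-v1.2")  # cache hit: same object returned
assert p1 is p2

Because cache_resource hands back the cached object itself rather than a copy, it is the cache Streamlit recommends for models and clients; per-query results that should vary with user input are normally cached with st.cache_data instead.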
RAG/rag_DocumentSearcher.py
CHANGED
@@ -12,7 +12,7 @@ headers = {"Content-Type": "application/json"}
 host = "https://search-opensearchservi-shjckef2t7wo-iyv6rajdgxg6jas25aupuxev6i.us-west-2.es.amazonaws.com/"
 
 parent_dirname = "/".join((os.path.dirname(__file__)).split("/")[0:-1])
-
+@st.cache_resource
 def query_(awsauth,inputs, session_id,search_types):
 
     print("using index: "+st.session_state.input_index)
@@ -219,49 +219,49 @@ def query_(awsauth,inputs, session_id,search_types):
     hits = response_['hits']['hits']
 
     ##### GET reference tables separately like *_mm index search for images ######
-    def lazy_get_table():
-        table_ref = []
-        any_table_exists = False
-        for fname in os.listdir(parent_dirname+"/split_pdf_csv"):
-            if fname.startswith(st.session_state.input_index):
-                any_table_exists = True
-                break
-        if(any_table_exists):
-            #################### Basic Match query #################
-            # payload_tables = {
-            #             "query": {
-            #                 "bool":{
+    # def lazy_get_table():
+    #     table_ref = []
+    #     any_table_exists = False
+    #     for fname in os.listdir(parent_dirname+"/split_pdf_csv"):
+    #         if fname.startswith(st.session_state.input_index):
+    #             any_table_exists = True
+    #             break
+    #     if(any_table_exists):
+    #         #################### Basic Match query #################
+    #         # payload_tables = {
+    #         #             "query": {
+    #         #                 "bool":{
 
-            #                     "must":{"match": {
-            #                         "processed_element": question
+    #         #                     "must":{"match": {
+    #         #                         "processed_element": question
 
-            #                          }},
+    #         #                          }},
 
-            #                  "filter":{"term":{"raw_element_type": "table"}}
+    #         #                  "filter":{"term":{"raw_element_type": "table"}}
 
 
-            #         }}}
+    #         #         }}}
 
-            #################### Neural Sparse query #################
-            payload_tables = {"query":{"neural_sparse": {
-                    "processed_element_embedding_sparse": {
-                        "query_text": question,
-                        "model_id": "fkol-ZMBTp0efWqBcO2P"
-                        }
-                    } } }
+    #         #################### Neural Sparse query #################
+    #         payload_tables = {"query":{"neural_sparse": {
+    #                 "processed_element_embedding_sparse": {
+    #                     "query_text": question,
+    #                     "model_id": "fkol-ZMBTp0efWqBcO2P"
+    #                     }
+    #                 } } }
 
 
-            r_ = requests.get(url, auth=awsauth, json=payload_tables, headers=headers)
-            r_tables = json.loads(r_.text)
+    #         r_ = requests.get(url, auth=awsauth, json=payload_tables, headers=headers)
+    #         r_tables = json.loads(r_.text)
 
-            for res_ in r_tables['hits']['hits']:
-                if(res_["_source"]['raw_element_type'] == 'table'):
-                    table_ref.append({'name':res_["_source"]['table'],'text':res_["_source"]['processed_element']})
-                if(len(table_ref) == 2):
-                    break
+    #         for res_ in r_tables['hits']['hits']:
+    #             if(res_["_source"]['raw_element_type'] == 'table'):
+    #                 table_ref.append({'name':res_["_source"]['table'],'text':res_["_source"]['processed_element']})
+    #             if(len(table_ref) == 2):
+    #                 break
 
 
-            return table_ref
+    #         return table_ref
 
 
 ########################### LLM Generation ########################
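For reference, the lazy_get_table helper commented out above retrieved table chunks with an OpenSearch neural_sparse query. A standalone sketch of that request shape, reusing the requests/awsauth style of this file (the index name and sparse model_id are placeholders):

import json
import requests

def search_tables(host, awsauth, index_name, question, sparse_model_id, k=2):
    # neural_sparse expands query_text with the given sparse encoding model
    # and scores it against the sparse-embedding field of each table chunk.
    payload = {
        "query": {
            "neural_sparse": {
                "processed_element_embedding_sparse": {
                    "query_text": question,
                    "model_id": sparse_model_id
                }
            }
        }
    }
    r = requests.get(host + index_name + "/_search", auth=awsauth, json=payload,
                     headers={"Content-Type": "application/json"})
    hits = json.loads(r.text)['hits']['hits']
    # Keep at most k table hits, mirroring the two-table cap in the original helper.
    return [{'name': h['_source']['table'], 'text': h['_source']['processed_element']}
            for h in hits if h['_source'].get('raw_element_type') == 'table'][:k]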
utilities/invoke_models.py
CHANGED
@@ -11,7 +11,7 @@ import streamlit as st
 #import torch
 
 region = 'us-east-1'
-
+@st.cache_resource
 bedrock_runtime_client = boto3.client(
     'bedrock-runtime',
     aws_access_key_id=st.secrets['user_access_key'],
@@ -30,7 +30,7 @@ bedrock_runtime_client = boto3.client(
 # max_length = 16
 # num_beams = 4
 # gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
-
+@st.cache_resource
 def invoke_model(input):
     response = bedrock_runtime_client.invoke_model(
         body=json.dumps({
@@ -43,7 +43,7 @@ def invoke_model(input):
 
     response_body = json.loads(response.get("body").read())
     return response_body.get("embedding")
-
+@st.cache_resource
 def invoke_model_mm(text,img):
     body_ = {
         "inputText": text,
@@ -64,7 +64,7 @@ def invoke_model_mm(text,img):
     response_body = json.loads(response.get("body").read())
     #print(response_body)
     return response_body.get("embedding")
-
+@st.cache_resource
 def invoke_llm_model(input,is_stream):
     if(is_stream == False):
         response = bedrock_runtime_client.invoke_model(
@@ -145,7 +145,7 @@ def invoke_llm_model(input,is_stream):
     # stream = response.get('body')
 
     # return stream
-
+@st.cache_resource
 def read_from_table(file,question):
     print("started table analysis:")
     print("-----------------------")
@@ -181,7 +181,7 @@ def read_from_table(file,question):
     )
     agent_res = agent.invoke(question)['output']
     return agent_res
-
+@st.cache_resource
 def generate_image_captions_llm(base64_string,question):
 
     # ant_client = Anthropic()
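@st.cache_resource attaches to a def (or class); for a module-level client such as bedrock_runtime_client, the usual Streamlit pattern is a small cached factory that builds the boto3 client once and reuses it on every rerun. A sketch under that assumption (the factory name is illustrative; only user_access_key appears in the hunk above, so the secret-access-key name is assumed):

import boto3
import streamlit as st

@st.cache_resource
def get_bedrock_runtime_client(region: str = 'us-east-1'):
    # Built once per unique region argument and shared across reruns and sessions,
    # instead of being recreated at import time.
    return boto3.client(
        'bedrock-runtime',
        aws_access_key_id=st.secrets['user_access_key'],
        aws_secret_access_key=st.secrets['user_secret_key'],  # key name assumed, not shown above
        region_name=region,
    )

bedrock_runtime_client = get_bedrock_runtime_client()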
utilities/re_ranker.py
CHANGED
@@ -46,7 +46,7 @@ from sentence_transformers import CrossEncoder
 # print("Program ends.")
 #########################
 
-
+@st.cache_resource
 def re_rank(self_, rerank_type, search_type, question, answers):
 
     ans = []