prasadnu committed
Commit 1f43c77 · 1 Parent(s): 322689b

rerank model

RAG/colpali.py CHANGED
@@ -66,7 +66,7 @@ runtime = boto3.client("sagemaker-runtime",aws_access_key_id=st.secrets['user_ac
 # Prepare your payload (e.g., text-only input)
 
 
-
+@st.cache_resource
 def call_nova(
     model,
     messages,
@@ -110,13 +110,14 @@ def call_nova(
         modelId=model, body=json.dumps(request_body)
     )
     return response["body"]
+@st.cache_resource
 def get_base64_encoded_value(media_path):
     with open(media_path, "rb") as media_file:
         binary_data = media_file.read()
         base_64_encoded_data = base64.b64encode(binary_data)
         base64_string = base_64_encoded_data.decode("utf-8")
         return base64_string
-
+@st.cache_resource
 def generate_ans(top_result,query):
     print(query)
     system_message = "given an image of a PDF page, answer the question. Be accurate to the question. If you don't find the answer in the page, please say, I don't know"
@@ -146,7 +147,7 @@ def generate_ans(top_result,query):
     print(content_text)
     return content_text
 
-
+@st.cache_resource
 def colpali_search_rerank(query):
     # Convert to JSON string
     payload = {
@@ -228,7 +229,7 @@ def colpali_search_rerank(query):
     return {'text':ans,'source':img,'image':images_highlighted,'table':[]}#[{'file':img}]
 
 
-
+@st.cache_resource
 def img_highlight(img,batch_queries,query_tokens):
     # Reference from : https://github.com/tonywu71/colpali-cookbooks/blob/main/examples/gen_colpali_similarity_maps.ipynb
     with open(img, "rb") as f:
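
Note on RAG/colpali.py: st.cache_resource does memoize per argument, but Streamlit documents it for long-lived resources (clients, models), while st.cache_data is the decorator intended for serializable per-input results such as search answers. A minimal sketch of that alternative, with a hypothetical TTL not taken from this commit:

import streamlit as st

@st.cache_data(ttl=3600)  # hypothetical 1-hour TTL; identical queries are served from cache
def cached_search(query: str) -> dict:
    # stand-in for the real colpali_search_rerank body
    return {'text': 'answer for: ' + query, 'source': None, 'image': [], 'table': []}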
RAG/rag_DocumentSearcher.py CHANGED
@@ -12,7 +12,7 @@ headers = {"Content-Type": "application/json"}
 host = "https://search-opensearchservi-shjckef2t7wo-iyv6rajdgxg6jas25aupuxev6i.us-west-2.es.amazonaws.com/"
 
 parent_dirname = "/".join((os.path.dirname(__file__)).split("/")[0:-1])
-
+@st.cache_resource
 def query_(awsauth,inputs, session_id,search_types):
 
     print("using index: "+st.session_state.input_index)
@@ -219,49 +219,49 @@ def query_(awsauth,inputs, session_id,search_types):
     hits = response_['hits']['hits']
 
     ##### GET reference tables separately like *_mm index search for images ######
-    def lazy_get_table():
-        table_ref = []
-        any_table_exists = False
-        for fname in os.listdir(parent_dirname+"/split_pdf_csv"):
-            if fname.startswith(st.session_state.input_index):
-                any_table_exists = True
-                break
-        if(any_table_exists):
-            #################### Basic Match query #################
-            # payload_tables = {
-            #     "query": {
-            #         "bool":{
-            #             "must":{"match": {
-            #                 "processed_element": question
-            #             }},
-            #             "filter":{"term":{"raw_element_type": "table"}}
-            # }}}
-            #################### Neural Sparse query #################
-            payload_tables = {"query":{"neural_sparse": {
-                "processed_element_embedding_sparse": {
-                    "query_text": question,
-                    "model_id": "fkol-ZMBTp0efWqBcO2P"
-                }
-            } } }
-            r_ = requests.get(url, auth=awsauth, json=payload_tables, headers=headers)
-            r_tables = json.loads(r_.text)
-            for res_ in r_tables['hits']['hits']:
-                if(res_["_source"]['raw_element_type'] == 'table'):
-                    table_ref.append({'name':res_["_source"]['table'],'text':res_["_source"]['processed_element']})
-                if(len(table_ref) == 2):
-                    break
-        return table_ref
+    # def lazy_get_table():
+    #     table_ref = []
+    #     any_table_exists = False
+    #     for fname in os.listdir(parent_dirname+"/split_pdf_csv"):
+    #         if fname.startswith(st.session_state.input_index):
+    #             any_table_exists = True
+    #             break
+    #     if(any_table_exists):
+    #         #################### Basic Match query #################
+    #         # payload_tables = {
+    #         #     "query": {
+    #         #         "bool":{
+    #         #             "must":{"match": {
+    #         #                 "processed_element": question
+    #         #             }},
+    #         #             "filter":{"term":{"raw_element_type": "table"}}
+    #         # }}}
+    #         #################### Neural Sparse query #################
+    #         payload_tables = {"query":{"neural_sparse": {
+    #             "processed_element_embedding_sparse": {
+    #                 "query_text": question,
+    #                 "model_id": "fkol-ZMBTp0efWqBcO2P"
+    #             }
+    #         } } }
+    #         r_ = requests.get(url, auth=awsauth, json=payload_tables, headers=headers)
+    #         r_tables = json.loads(r_.text)
+    #         for res_ in r_tables['hits']['hits']:
+    #             if(res_["_source"]['raw_element_type'] == 'table'):
+    #                 table_ref.append({'name':res_["_source"]['table'],'text':res_["_source"]['processed_element']})
+    #             if(len(table_ref) == 2):
+    #                 break
+    #     return table_ref
 
 
 ########################### LLM Generation ########################
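
Note on RAG/rag_DocumentSearcher.py: Streamlit's caching decorators hash every argument of the wrapped function, so the decorated query_ will raise UnhashableParamError if awsauth is an unhashable auth object; prefixing a parameter name with an underscore tells Streamlit to skip hashing it. A minimal sketch of that pattern, with hypothetical names:

import streamlit as st

@st.cache_data
def cached_query(_awsauth, question: str, index: str) -> list:
    # _awsauth is excluded from Streamlit's hasher;
    # stand-in for the real OpenSearch round trip in query_()
    return ["hit for '" + question + "' in index '" + index + "'"]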
utilities/invoke_models.py CHANGED
@@ -11,7 +11,7 @@ import streamlit as st
 #import torch
 
 region = 'us-east-1'
-
+@st.cache_resource
 bedrock_runtime_client = boto3.client(
     'bedrock-runtime',
     aws_access_key_id=st.secrets['user_access_key'],
@@ -30,7 +30,7 @@ bedrock_runtime_client = boto3.client(
 # max_length = 16
 # num_beams = 4
 # gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
-
+@st.cache_resource
 def invoke_model(input):
     response = bedrock_runtime_client.invoke_model(
         body=json.dumps({
@@ -43,7 +43,7 @@ def invoke_model(input):
 
     response_body = json.loads(response.get("body").read())
     return response_body.get("embedding")
-
+@st.cache_resource
 def invoke_model_mm(text,img):
     body_ = {
         "inputText": text,
@@ -64,7 +64,7 @@ def invoke_model_mm(text,img):
     response_body = json.loads(response.get("body").read())
     #print(response_body)
     return response_body.get("embedding")
-
+@st.cache_resource
 def invoke_llm_model(input,is_stream):
     if(is_stream == False):
         response = bedrock_runtime_client.invoke_model(
@@ -145,7 +145,7 @@ def invoke_llm_model(input,is_stream):
     # stream = response.get('body')
 
     # return stream
-
+@st.cache_resource
 def read_from_table(file,question):
     print("started table analysis:")
     print("-----------------------")
@@ -181,7 +181,7 @@ def read_from_table(file,question):
     )
     agent_res = agent.invoke(question)['output']
     return agent_res
-
+@st.cache_resource
 def generate_image_captions_llm(base64_string,question):
 
     # ant_client = Anthropic()
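
Note on utilities/invoke_models.py: a decorator cannot precede an assignment, so @st.cache_resource on bedrock_runtime_client = boto3.client(...) is a Python SyntaxError. The usual pattern is to build the client inside a cached function; a minimal sketch, reusing the secret name visible above (the secret-key name is an assumption):

import boto3
import streamlit as st

@st.cache_resource
def get_bedrock_runtime():
    # created once per process and reused across reruns;
    # 'user_secret_key' below is an assumed secret name
    return boto3.client(
        'bedrock-runtime',
        aws_access_key_id=st.secrets['user_access_key'],
        aws_secret_access_key=st.secrets['user_secret_key'],
        region_name='us-east-1',
    )

bedrock_runtime_client = get_bedrock_runtime()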
utilities/re_ranker.py CHANGED
@@ -46,7 +46,7 @@ from sentence_transformers import CrossEncoder
 # print("Program ends.")
 #########################
 
-
+@st.cache_resource
 def re_rank(self_, rerank_type, search_type, question, answers):
 
     ans = []
  ans = []