prasadnu committed on
Commit
b481a29
·
1 Parent(s): bc2fefa

rerank model

Browse files
RAG/rag_DocumentSearcher.py CHANGED
@@ -12,7 +12,7 @@ headers = {"Content-Type": "application/json"}
12
  host = "https://search-opensearchservi-shjckef2t7wo-iyv6rajdgxg6jas25aupuxev6i.us-west-2.es.amazonaws.com/"
13
 
14
  parent_dirname = "/".join((os.path.dirname(__file__)).split("/")[0:-1])
15
-
16
  def query_(awsauth,inputs, session_id,search_types):
17
 
18
  print("using index: "+st.session_state.input_index)
 
12
  host = "https://search-opensearchservi-shjckef2t7wo-iyv6rajdgxg6jas25aupuxev6i.us-west-2.es.amazonaws.com/"
13
 
14
  parent_dirname = "/".join((os.path.dirname(__file__)).split("/")[0:-1])
15
+ @st.cache_data
16
  def query_(awsauth,inputs, session_id,search_types):
17
 
18
  print("using index: "+st.session_state.input_index)
pages/Multimodal_Conversational_Search.py CHANGED
@@ -39,6 +39,7 @@ AI_ICON = "images/opensearch-twitter-card.png"
39
  REGENERATE_ICON = "images/regenerate.png"
40
  s3_bucket_ = "pdf-repo-uploads"
41
  #"pdf-repo-uploads"
 
42
  polly_client = boto3.client('polly',aws_access_key_id=st.secrets['user_access_key'],
43
  aws_secret_access_key=st.secrets['user_secret_key'], region_name = 'us-east-1')
44
 
@@ -99,6 +100,7 @@ if "input_rag_searchType" not in st.session_state:
99
 
100
 
101
  region = 'us-east-1'
 
102
  bedrock_runtime_client = boto3.client('bedrock-runtime',region_name=region)
103
  output = []
104
  service = 'es'
@@ -345,8 +347,6 @@ def render_answer(question,answer,index,res_img):
345
 
346
  rdn_key = ''.join([random.choice(string.ascii_letters)
347
  for _ in range(10)])
348
- # rdn_key_1 = ''.join([random.choice(string.ascii_letters)
349
- # for _ in range(10)])
350
  currentValue = ''.join(st.session_state.input_rag_searchType)+str(st.session_state.input_is_rerank)+str(st.session_state.input_table_with_sql)+st.session_state.input_index
351
  oldValue = ''.join(st.session_state.inputs_["rag_searchType"])+str(st.session_state.inputs_["is_rerank"])+str(st.session_state.inputs_["table_with_sql"])+str(st.session_state.inputs_["index"])
352
  def on_button_click():
@@ -358,12 +358,6 @@ def render_answer(question,answer,index,res_img):
358
  handle_input()
359
  with placeholder.container():
360
  render_all()
361
- # def show_maxsim():
362
- # st.session_state.show_columns = True
363
- # st.session_state.maxSimImages = colpali.img_highlight(st.session_state.top_img, st.session_state.query_token_vectors, st.session_state.query_tokens)
364
- # handle_input()
365
- # with placeholder.container():
366
- # render_all()
367
  if("currentValue" in st.session_state):
368
  del st.session_state["currentValue"]
369
 
 
39
  REGENERATE_ICON = "images/regenerate.png"
40
  s3_bucket_ = "pdf-repo-uploads"
41
  #"pdf-repo-uploads"
42
+ @st.cache_data
43
  polly_client = boto3.client('polly',aws_access_key_id=st.secrets['user_access_key'],
44
  aws_secret_access_key=st.secrets['user_secret_key'], region_name = 'us-east-1')
45
 
 
100
 
101
 
102
  region = 'us-east-1'
103
+ @st.cache_data
104
  bedrock_runtime_client = boto3.client('bedrock-runtime',region_name=region)
105
  output = []
106
  service = 'es'
 
347
 
348
  rdn_key = ''.join([random.choice(string.ascii_letters)
349
  for _ in range(10)])
 
 
350
  currentValue = ''.join(st.session_state.input_rag_searchType)+str(st.session_state.input_is_rerank)+str(st.session_state.input_table_with_sql)+st.session_state.input_index
351
  oldValue = ''.join(st.session_state.inputs_["rag_searchType"])+str(st.session_state.inputs_["is_rerank"])+str(st.session_state.inputs_["table_with_sql"])+str(st.session_state.inputs_["index"])
352
  def on_button_click():
 
358
  handle_input()
359
  with placeholder.container():
360
  render_all()
 
 
 
 
 
 
361
  if("currentValue" in st.session_state):
362
  del st.session_state["currentValue"]
363
 
utilities/invoke_models.py CHANGED
@@ -11,6 +11,7 @@ import streamlit as st
11
  #import torch
12
 
13
  region = 'us-east-1'
 
14
  bedrock_runtime_client = boto3.client(
15
  'bedrock-runtime',
16
  aws_access_key_id=st.secrets['user_access_key'],
@@ -29,7 +30,7 @@ bedrock_runtime_client = boto3.client(
29
  # max_length = 16
30
  # num_beams = 4
31
  # gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
32
-
33
  def invoke_model(input):
34
  response = bedrock_runtime_client.invoke_model(
35
  body=json.dumps({
@@ -42,7 +43,7 @@ def invoke_model(input):
42
 
43
  response_body = json.loads(response.get("body").read())
44
  return response_body.get("embedding")
45
-
46
  def invoke_model_mm(text,img):
47
  body_ = {
48
  "inputText": text,
@@ -63,7 +64,7 @@ def invoke_model_mm(text,img):
63
  response_body = json.loads(response.get("body").read())
64
  #print(response_body)
65
  return response_body.get("embedding")
66
-
67
  def invoke_llm_model(input,is_stream):
68
  if(is_stream == False):
69
  response = bedrock_runtime_client.invoke_model(
@@ -144,7 +145,7 @@ def invoke_llm_model(input,is_stream):
144
  # stream = response.get('body')
145
 
146
  # return stream
147
-
148
  def read_from_table(file,question):
149
  print("started table analysis:")
150
  print("-----------------------")
@@ -159,7 +160,7 @@ def read_from_table(file,question):
159
  "top_p":0.7,
160
  "stop_sequences":["\\n\\nHuman:"]
161
  }
162
-
163
  model = BedrockChat(
164
  client=bedrock_runtime_client,
165
  model_id='anthropic.claude-3-sonnet-20240229-v1:0',
@@ -180,7 +181,7 @@ def read_from_table(file,question):
180
  )
181
  agent_res = agent.invoke(question)['output']
182
  return agent_res
183
-
184
  def generate_image_captions_llm(base64_string,question):
185
 
186
  # ant_client = Anthropic()
 
11
  #import torch
12
 
13
  region = 'us-east-1'
14
+ @st.cache_data
15
  bedrock_runtime_client = boto3.client(
16
  'bedrock-runtime',
17
  aws_access_key_id=st.secrets['user_access_key'],
 
30
  # max_length = 16
31
  # num_beams = 4
32
  # gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
33
+ @st.cache_data
34
  def invoke_model(input):
35
  response = bedrock_runtime_client.invoke_model(
36
  body=json.dumps({
 
43
 
44
  response_body = json.loads(response.get("body").read())
45
  return response_body.get("embedding")
46
+ @st.cache_data
47
  def invoke_model_mm(text,img):
48
  body_ = {
49
  "inputText": text,
 
64
  response_body = json.loads(response.get("body").read())
65
  #print(response_body)
66
  return response_body.get("embedding")
67
+ @st.cache_data
68
  def invoke_llm_model(input,is_stream):
69
  if(is_stream == False):
70
  response = bedrock_runtime_client.invoke_model(
 
145
  # stream = response.get('body')
146
 
147
  # return stream
148
+ @st.cache_data
149
  def read_from_table(file,question):
150
  print("started table analysis:")
151
  print("-----------------------")
 
160
  "top_p":0.7,
161
  "stop_sequences":["\\n\\nHuman:"]
162
  }
163
+ @st.cache_data
164
  model = BedrockChat(
165
  client=bedrock_runtime_client,
166
  model_id='anthropic.claude-3-sonnet-20240229-v1:0',
 
181
  )
182
  agent_res = agent.invoke(question)['output']
183
  return agent_res
184
+ @st.cache_data
185
  def generate_image_captions_llm(base64_string,question):
186
 
187
  # ant_client = Anthropic()