Daniel Foley committed on
Commit
fa714bc
·
1 Parent(s): 16f0715

transferring initializations to st.session_state

Browse files
Files changed (2) hide show
  1. RAG.py +1 -2
  2. streamlit_app.py +29 -31
RAG.py CHANGED
@@ -118,7 +118,7 @@ def parse_xml_and_check(xml_string: str) -> str:
118
 
119
  return parsed_response.get('RESPONSE', "No response found in the output")
120
 
121
- def RAG(llm: Any, query: str, index_name: str, embeddings: Any,vectorstore:PineconeVectorStore, top: int = 10, k: int = 100) -> Tuple[str, List[Document]]:
122
  """Main RAG function with improved error handling and validation."""
123
  start = time.time()
124
  try:
@@ -147,7 +147,6 @@ def RAG(llm: Any, query: str, index_name: str, embeddings: Any,vectorstore:Pinec
147
 
148
  <QUERY>{query}</QUERY>
149
  """
150
-
151
  )
152
  query_prompt = query_template.invoke({"query":query})
153
  query_response = llm.invoke(query_prompt)
 
118
 
119
  return parsed_response.get('RESPONSE', "No response found in the output")
120
 
121
+ def RAG(llm: Any, query: str,vectorstore:PineconeVectorStore, top: int = 10, k: int = 100) -> Tuple[str, List[Document]]:
122
  """Main RAG function with improved error handling and validation."""
123
  start = time.time()
124
  try:
 
147
 
148
  <QUERY>{query}</QUERY>
149
  """
 
150
  )
151
  query_prompt = query_template.invoke({"query":query})
152
  query_response = llm.invoke(query_prompt)
streamlit_app.py CHANGED
@@ -26,20 +26,29 @@ def initialize_models() -> Tuple[Optional[ChatOpenAI], HuggingFaceEmbeddings]:
26
  try:
27
  load_dotenv()
28
 
29
- # Initialize OpenAI model
30
- llm = ChatOpenAI(
31
- model="gpt-4", # Changed from gpt-4o-mini which appears to be a typo
32
- temperature=0,
33
- timeout=60, # Added reasonable timeout
34
- max_retries=2
35
- )
36
-
37
- # Initialize embeddings
38
- embeddings = HuggingFaceEmbeddings(
39
- model_name="sentence-transformers/all-MiniLM-L6-v2"
40
- )
41
 
42
- return llm, embeddings
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  except Exception as e:
45
  logger.error(f"Error initializing models: {str(e)}")
@@ -90,35 +99,24 @@ def display_sources(sources: List) -> None:
90
  st.error(f"Error displaying source {i}")
91
 
92
  def main():
93
- st.title("RAG Chatbot")
94
 
95
  INDEX_NAME = 'bpl-rag'
96
 
97
- pinecone_api_key = os.getenv("PINECONE_API_KEY")
98
-
99
  # Initialize session state
100
  if "messages" not in st.session_state:
101
  st.session_state.messages = []
102
 
103
  # Initialize models
104
- llm, embeddings = initialize_models()
105
- if not llm or not embeddings:
106
- st.error("Failed to initialize the application. Please check the logs.")
107
- return
108
-
109
- #initialize vectorstore
110
- pc = Pinecone(api_key=pinecone_api_key)
111
-
112
- index = pc.Index(INDEX_NAME)
113
- vector_store = PineconeVectorStore(index=index, embedding=embeddings)
114
-
115
  # Display chat history
116
  for message in st.session_state.messages:
117
  with st.chat_message(message["role"]):
118
  st.markdown(message["content"])
119
 
120
  # Chat input
121
- user_input = st.chat_input("Type your message here...")
122
  if user_input:
123
  # Display user message
124
  with st.chat_message("user"):
@@ -130,10 +128,10 @@ def main():
130
  with st.spinner("Let Me Think..."):
131
  response, sources = process_message(
132
  query=user_input,
133
- llm=llm,
134
  index_name=INDEX_NAME,
135
- embeddings=embeddings,
136
- vectorstore=vector_store
137
  )
138
 
139
  if isinstance(response, str):
 
26
  try:
27
  load_dotenv()
28
 
29
+ if "llm" not in st.session_state:
30
+ # Initialize OpenAI model
31
+ st.session_state.llm = ChatOpenAI(
32
+ model="gpt-4", # Changed from gpt-4o-mini which appears to be a typo
33
+ temperature=0,
34
+ timeout=60, # Added reasonable timeout
35
+ max_retries=2
36
+ )
 
 
 
 
37
 
38
+ if "embeddings" not in st.session_state:
39
+ # Initialize embeddings
40
+ st.session_state.embeddings = HuggingFaceEmbeddings(
41
+ model_name="sentence-transformers/all-MiniLM-L6-v2"
42
+ )
43
+
44
+ if "pinecone" not in st.session_state:
45
+ pinecone_api_key = os.getenv("PINECONE_API_KEY")
46
+ INDEX_NAME = 'bpl-rag'
47
+ #initialize vectorstore
48
+ pc = Pinecone(api_key=pinecone_api_key)
49
+
50
+ index = pc.Index(INDEX_NAME)
51
+ st.session_state.pinecone = PineconeVectorStore(index=index, embedding=st.session_state.embeddings)
52
 
53
  except Exception as e:
54
  logger.error(f"Error initializing models: {str(e)}")
 
99
  st.error(f"Error displaying source {i}")
100
 
101
  def main():
102
+ st.title("Digital Commonwealth RAG")
103
 
104
  INDEX_NAME = 'bpl-rag'
105
 
 
 
106
  # Initialize session state
107
  if "messages" not in st.session_state:
108
  st.session_state.messages = []
109
 
110
  # Initialize models
111
+ initialize_models()
112
+
 
 
 
 
 
 
 
 
 
113
  # Display chat history
114
  for message in st.session_state.messages:
115
  with st.chat_message(message["role"]):
116
  st.markdown(message["content"])
117
 
118
  # Chat input
119
+ user_input = st.chat_input("Type your query here...")
120
  if user_input:
121
  # Display user message
122
  with st.chat_message("user"):
 
128
  with st.spinner("Let Me Think..."):
129
  response, sources = process_message(
130
  query=user_input,
131
+ llm=st.session_state.llm,
132
  index_name=INDEX_NAME,
133
+ embeddings=st.session_state.embeddings,
134
+ vectorstore=st.session_state.pinecone
135
  )
136
 
137
  if isinstance(response, str):