Neda1 commited on
Commit
c2bdf0d
Β·
verified Β·
1 Parent(s): 17cd610

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +27 -6
agent.py CHANGED
@@ -185,13 +185,34 @@ def build_graph(provider: str = "groq"):
185
  """Assistant node"""
186
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
187
 
 
 
 
 
 
 
 
 
188
  def retriever(state: MessagesState):
189
- """Retriever node"""
190
- similar_question = vector_store.similarity_search(state["messages"][0].content)
191
- example_msg = HumanMessage(
192
- content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
193
- )
194
- return {"messages": [sys_msg] + state["messages"] + [example_msg]}
 
 
 
 
 
 
 
 
 
 
 
 
 
195
 
196
  builder = StateGraph(MessagesState)
197
  builder.add_node("retriever", retriever)
 
185
  """Assistant node"""
186
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
187
 
188
+ # def retriever(state: MessagesState):
189
+ # """Retriever node"""
190
+ # similar_question = vector_store.similarity_search(state["messages"][0].content)
191
+ # example_msg = HumanMessage(
192
+ # content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
193
+ # )
194
+ # return {"messages": [sys_msg] + state["messages"] + [example_msg]}
195
+
196
  def retriever(state: MessagesState):
197
+ """Retriever node"""
198
+ messages = state.get("messages", [])
199
+ if not messages:
200
+ print("⚠️ No messages received in retriever node.")
201
+ return {"messages": []}
202
+
203
+ query = messages[0].content
204
+ print(f"πŸ” Running similarity search for: {query}")
205
+ similar_question = vector_store.similarity_search(query)
206
+
207
+ if not similar_question:
208
+ print("⚠️ No similar questions found.")
209
+ return {"messages": messages} # Return unchanged messages
210
+
211
+ example_msg = HumanMessage(
212
+ content=f"Here I provide a similar question and answer for reference:\n\n{similar_question[0].page_content}",
213
+ )
214
+
215
+ return {"messages": [sys_msg] + messages + [example_msg]}
216
 
217
  builder = StateGraph(MessagesState)
218
  builder.add_node("retriever", retriever)