mgbam committed on
Commit
29bc714
Β·
verified Β·
1 Parent(s): bb7769c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -89
app.py CHANGED
@@ -1,6 +1,6 @@
1
- from langchain_openai import OpenAIEmbeddings
2
- from langchain_community.vectorstores import Chroma
3
- from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
4
  from langchain.text_splitter import RecursiveCharacterTextSplitter
5
  from langgraph.graph import END, StateGraph, START
6
  from langgraph.prebuilt import ToolNode
@@ -31,21 +31,16 @@ development_texts = [
31
  # --------------------------
32
  # Process the Data
33
  # --------------------------
34
- # Text splitting settings
35
  splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=10)
36
 
37
- # Generate Document objects from text
38
  research_docs = splitter.create_documents(research_texts)
39
  development_docs = splitter.create_documents(development_texts)
40
 
41
- # Create vector embeddings
42
  embeddings = OpenAIEmbeddings(
43
- model="text-embedding-3-large",
44
- # For text-embedding-3 class models, you can specify dimensions if needed.
45
- # dimensions=1024
46
  )
47
 
48
- # Create vector stores
49
  research_vectorstore = Chroma.from_documents(
50
  documents=research_docs,
51
  embedding=embeddings,
@@ -58,7 +53,6 @@ development_vectorstore = Chroma.from_documents(
58
  collection_name="development_collection"
59
  )
60
 
61
- # Create retrievers from the vector stores
62
  research_retriever = research_vectorstore.as_retriever()
63
  development_retriever = development_vectorstore.as_retriever()
64
 
@@ -66,9 +60,9 @@ development_retriever = development_vectorstore.as_retriever()
66
  # Create Retriever Tools
67
  # --------------------------
68
  research_tool = create_retriever_tool(
69
- research_retriever, # Retriever object
70
- "research_db_tool", # Name of the tool to create
71
- "Search information from the research database." # Description of the tool
72
  )
73
 
74
  development_tool = create_retriever_tool(
@@ -77,25 +71,19 @@ development_tool = create_retriever_tool(
77
  "Search information from the development database."
78
  )
79
 
80
- # Combine the tools into a list
81
  tools = [research_tool, development_tool]
82
 
83
  # --------------------------
84
  # Define the Agent Function
85
  # --------------------------
86
  class AgentState(TypedDict):
87
- messages: Annotated[Sequence[AIMessage | HumanMessage | ToolMessage], add_messages]
88
 
89
  def agent(state: AgentState):
90
  print("---CALL AGENT---")
91
  messages = state["messages"]
 
92
 
93
- if isinstance(messages[0], tuple):
94
- user_message = messages[0][1]
95
- else:
96
- user_message = messages[0].content
97
-
98
- # Structure prompt for consistent text output
99
  prompt = f"""Given this user question: "{user_message}"
100
  If it's about research or academic topics, respond EXACTLY in this format:
101
  SEARCH_RESEARCH: <search terms>
@@ -108,7 +96,7 @@ Otherwise, just answer directly.
108
 
109
  headers = {
110
  "Accept": "application/json",
111
- "Authorization": f"Bearer sk-1cddf19f9dc4466fa3ecea6fe10abec0",
112
  "Content-Type": "application/json"
113
  }
114
 
@@ -130,16 +118,13 @@ Otherwise, just answer directly.
130
  response_text = response.json()['choices'][0]['message']['content']
131
  print("Raw response:", response_text)
132
 
133
- # Format the response into expected tool format
134
  if "SEARCH_RESEARCH:" in response_text:
135
  query = response_text.split("SEARCH_RESEARCH:")[1].strip()
136
- # Use direct call to research retriever
137
  results = research_retriever.invoke(query)
138
  return {"messages": [AIMessage(content=f'Action: research_db_tool\n{{"query": "{query}"}}\n\nResults: {str(results)}')]}
139
 
140
  elif "SEARCH_DEV:" in response_text:
141
  query = response_text.split("SEARCH_DEV:")[1].strip()
142
- # Use direct call to development retriever
143
  results = development_retriever.invoke(query)
144
  return {"messages": [AIMessage(content=f'Action: development_db_tool\n{{"query": "{query}"}}\n\nResults: {str(results)}')]}
145
 
@@ -156,7 +141,6 @@ def simple_grade_documents(state: AgentState):
156
  last_message = messages[-1]
157
  print("Evaluating message:", last_message.content)
158
 
159
- # Check if the content contains retrieved documents
160
  if "Results: [Document" in last_message.content:
161
  print("---DOCS FOUND, GO TO GENERATE---")
162
  return "generate"
@@ -170,10 +154,9 @@ def simple_grade_documents(state: AgentState):
170
  def generate(state: AgentState):
171
  print("---GENERATE FINAL ANSWER---")
172
  messages = state["messages"]
173
- question = messages[0].content if isinstance(messages[0], tuple) else messages[0].content
174
  last_message = messages[-1]
175
 
176
- # Extract the document content from the results
177
  docs = ""
178
  if "Results: [" in last_message.content:
179
  results_start = last_message.content.find("Results: [")
@@ -182,7 +165,7 @@ def generate(state: AgentState):
182
 
183
  headers = {
184
  "Accept": "application/json",
185
- "Authorization": f"Bearer sk-1cddf19f9dc4466fa3ecea6fe10abec0",
186
  "Content-Type": "application/json"
187
  }
188
 
@@ -194,10 +177,7 @@ Focus on extracting and synthesizing the key findings from the research papers.
194
 
195
  data = {
196
  "model": "deepseek-chat",
197
- "messages": [{
198
- "role": "user",
199
- "content": prompt
200
- }],
201
  "temperature": 0.7,
202
  "max_tokens": 1024
203
  }
@@ -227,21 +207,17 @@ def rewrite(state: AgentState):
227
 
228
  headers = {
229
  "Accept": "application/json",
230
- "Authorization": f"Bearer sk-1cddf19f9dc4466fa3ecea6fe10abec0",
231
  "Content-Type": "application/json"
232
  }
233
 
234
  data = {
235
  "model": "deepseek-chat",
236
- "messages": [{
237
- "role": "user",
238
- "content": f"Rewrite this question to be more specific and clearer: {original_question}"
239
- }],
240
  "temperature": 0.7,
241
  "max_tokens": 1024
242
  }
243
 
244
- print("Sending rewrite request...")
245
  response = requests.post(
246
  "https://api.deepseek.com/v1/chat/completions",
247
  headers=headers,
@@ -249,9 +225,6 @@ def rewrite(state: AgentState):
249
  verify=False
250
  )
251
 
252
- print("Status Code:", response.status_code)
253
- print("Response:", response.text)
254
-
255
  if response.status_code == 200:
256
  response_text = response.json()['choices'][0]['message']['content']
257
  print("Rewritten question:", response_text)
@@ -270,50 +243,37 @@ def custom_tools_condition(state: AgentState):
270
  content = last_message.content
271
 
272
  print("Checking tools condition:", content)
273
- if tools_pattern.match(content):
274
- print("Moving to retrieve...")
275
- return "tools"
276
- print("Moving to END...")
277
- return END
278
 
279
  # --------------------------
280
  # LangGraph Workflow Setup
281
  # --------------------------
282
  workflow = StateGraph(AgentState)
283
 
284
- # Define the workflow nodes
285
  workflow.add_node("agent", agent)
286
  retrieve_node = ToolNode(tools)
287
  workflow.add_node("retrieve", retrieve_node)
288
  workflow.add_node("rewrite", rewrite)
289
  workflow.add_node("generate", generate)
290
 
291
- # Set up the initial edge
292
  workflow.add_edge(START, "agent")
293
 
294
- # Conditional edge from agent to either retrieve (if tool is called) or END
295
  workflow.add_conditional_edges(
296
  "agent",
297
  custom_tools_condition,
298
- {
299
- "tools": "retrieve",
300
- END: END
301
- }
302
  )
303
 
304
- # After retrieval, decide to generate or rewrite based on document grading
305
  workflow.add_conditional_edges("retrieve", simple_grade_documents)
306
  workflow.add_edge("generate", END)
307
  workflow.add_edge("rewrite", "agent")
308
 
309
- # Compile the workflow to make it executable
310
  app = workflow.compile()
311
 
312
  # --------------------------
313
  # Process Question Function
314
  # --------------------------
315
  def process_question(user_question, app, config):
316
- """Process user question through the workflow"""
317
  events = []
318
  for event in app.stream({"messages": [("user", user_question)]}, config):
319
  events.append(event)
@@ -329,60 +289,37 @@ def main():
329
  initial_sidebar_state="expanded"
330
  )
331
 
332
- # Custom CSS
333
  st.markdown("""
334
  <style>
335
- .stApp {
336
- background-color: #f8f9fa;
337
- }
338
- .stButton > button {
339
- width: 100%;
340
- margin-top: 20px;
341
- }
342
- .data-box {
343
- padding: 20px;
344
- border-radius: 10px;
345
- margin: 10px 0;
346
- }
347
- .research-box {
348
- background-color: #e3f2fd;
349
- border-left: 5px solid #1976d2;
350
- }
351
- .dev-box {
352
- background-color: #e8f5e9;
353
- border-left: 5px solid #43a047;
354
- }
355
  </style>
356
  """, unsafe_allow_html=True)
357
 
358
- # Sidebar with Data Display
359
  with st.sidebar:
360
  st.header("πŸ“š Available Data")
361
-
362
  st.subheader("Research Database")
363
  for text in research_texts:
364
  st.markdown(f'<div class="data-box research-box">{text}</div>', unsafe_allow_html=True)
365
-
366
  st.subheader("Development Database")
367
  for text in development_texts:
368
  st.markdown(f'<div class="data-box dev-box">{text}</div>', unsafe_allow_html=True)
369
 
370
- # Main Content
371
  st.title("πŸ€– AI Research & Development Assistant")
372
  st.markdown("---")
373
 
374
- # Query Input
375
- query = st.text_area("Enter your question:", height=100, placeholder="e.g., What is the latest advancement in AI research?")
376
 
377
  col1, col2 = st.columns([1, 2])
378
  with col1:
379
  if st.button("πŸ” Get Answer", use_container_width=True):
380
  if query:
381
  with st.spinner('Processing your question...'):
382
- # Process query through workflow
383
  events = process_question(query, app, {"configurable": {"thread_id": "1"}})
384
-
385
- # Display results
386
  for event in events:
387
  if 'agent' in event:
388
  with st.expander("πŸ”„ Processing Step", expanded=True):
@@ -412,4 +349,4 @@ def main():
412
  """)
413
 
414
  if __name__ == "__main__":
415
- main()
 
1
+ from langchain_openai import OpenAIEmbeddings # Updated import path
2
+ from langchain.vectorstores import Chroma
3
+ from langchain.schema import HumanMessage, AIMessage
4
  from langchain.text_splitter import RecursiveCharacterTextSplitter
5
  from langgraph.graph import END, StateGraph, START
6
  from langgraph.prebuilt import ToolNode
 
31
  # --------------------------
32
  # Process the Data
33
  # --------------------------
 
34
  splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=10)
35
 
 
36
  research_docs = splitter.create_documents(research_texts)
37
  development_docs = splitter.create_documents(development_texts)
38
 
 
39
  embeddings = OpenAIEmbeddings(
40
+ model="text-embedding-3-large"
41
+ # dimensions=1024 # Uncomment if needed
 
42
  )
43
 
 
44
  research_vectorstore = Chroma.from_documents(
45
  documents=research_docs,
46
  embedding=embeddings,
 
53
  collection_name="development_collection"
54
  )
55
 
 
56
  research_retriever = research_vectorstore.as_retriever()
57
  development_retriever = development_vectorstore.as_retriever()
58
 
 
60
  # Create Retriever Tools
61
  # --------------------------
62
  research_tool = create_retriever_tool(
63
+ research_retriever,
64
+ "research_db_tool",
65
+ "Search information from the research database."
66
  )
67
 
68
  development_tool = create_retriever_tool(
 
71
  "Search information from the development database."
72
  )
73
 
 
74
  tools = [research_tool, development_tool]
75
 
76
  # --------------------------
77
  # Define the Agent Function
78
  # --------------------------
79
  class AgentState(TypedDict):
80
+ messages: Annotated[Sequence[AIMessage | HumanMessage], add_messages]
81
 
82
  def agent(state: AgentState):
83
  print("---CALL AGENT---")
84
  messages = state["messages"]
85
+ user_message = messages[0][1] if isinstance(messages[0], tuple) else messages[0].content
86
 
 
 
 
 
 
 
87
  prompt = f"""Given this user question: "{user_message}"
88
  If it's about research or academic topics, respond EXACTLY in this format:
89
  SEARCH_RESEARCH: <search terms>
 
96
 
97
  headers = {
98
  "Accept": "application/json",
99
+ "Authorization": "Bearer sk-1cddf19f9dc4466fa3ecea6fe10abec0",
100
  "Content-Type": "application/json"
101
  }
102
 
 
118
  response_text = response.json()['choices'][0]['message']['content']
119
  print("Raw response:", response_text)
120
 
 
121
  if "SEARCH_RESEARCH:" in response_text:
122
  query = response_text.split("SEARCH_RESEARCH:")[1].strip()
 
123
  results = research_retriever.invoke(query)
124
  return {"messages": [AIMessage(content=f'Action: research_db_tool\n{{"query": "{query}"}}\n\nResults: {str(results)}')]}
125
 
126
  elif "SEARCH_DEV:" in response_text:
127
  query = response_text.split("SEARCH_DEV:")[1].strip()
 
128
  results = development_retriever.invoke(query)
129
  return {"messages": [AIMessage(content=f'Action: development_db_tool\n{{"query": "{query}"}}\n\nResults: {str(results)}')]}
130
 
 
141
  last_message = messages[-1]
142
  print("Evaluating message:", last_message.content)
143
 
 
144
  if "Results: [Document" in last_message.content:
145
  print("---DOCS FOUND, GO TO GENERATE---")
146
  return "generate"
 
154
  def generate(state: AgentState):
155
  print("---GENERATE FINAL ANSWER---")
156
  messages = state["messages"]
157
+ question = messages[0].content if not isinstance(messages[0], tuple) else messages[0][1]
158
  last_message = messages[-1]
159
 
 
160
  docs = ""
161
  if "Results: [" in last_message.content:
162
  results_start = last_message.content.find("Results: [")
 
165
 
166
  headers = {
167
  "Accept": "application/json",
168
+ "Authorization": "Bearer sk-1cddf19f9dc4466fa3ecea6fe10abec0",
169
  "Content-Type": "application/json"
170
  }
171
 
 
177
 
178
  data = {
179
  "model": "deepseek-chat",
180
+ "messages": [{"role": "user", "content": prompt}],
 
 
 
181
  "temperature": 0.7,
182
  "max_tokens": 1024
183
  }
 
207
 
208
  headers = {
209
  "Accept": "application/json",
210
+ "Authorization": "Bearer sk-1cddf19f9dc4466fa3ecea6fe10abec0",
211
  "Content-Type": "application/json"
212
  }
213
 
214
  data = {
215
  "model": "deepseek-chat",
216
+ "messages": [{"role": "user", "content": f"Rewrite this question to be more specific and clearer: {original_question}"}],
 
 
 
217
  "temperature": 0.7,
218
  "max_tokens": 1024
219
  }
220
 
 
221
  response = requests.post(
222
  "https://api.deepseek.com/v1/chat/completions",
223
  headers=headers,
 
225
  verify=False
226
  )
227
 
 
 
 
228
  if response.status_code == 200:
229
  response_text = response.json()['choices'][0]['message']['content']
230
  print("Rewritten question:", response_text)
 
243
  content = last_message.content
244
 
245
  print("Checking tools condition:", content)
246
+ return "tools" if tools_pattern.match(content) else END
 
 
 
 
247
 
248
  # --------------------------
249
  # LangGraph Workflow Setup
250
  # --------------------------
251
  workflow = StateGraph(AgentState)
252
 
 
253
  workflow.add_node("agent", agent)
254
  retrieve_node = ToolNode(tools)
255
  workflow.add_node("retrieve", retrieve_node)
256
  workflow.add_node("rewrite", rewrite)
257
  workflow.add_node("generate", generate)
258
 
 
259
  workflow.add_edge(START, "agent")
260
 
 
261
  workflow.add_conditional_edges(
262
  "agent",
263
  custom_tools_condition,
264
+ {"tools": "retrieve", END: END}
 
 
 
265
  )
266
 
 
267
  workflow.add_conditional_edges("retrieve", simple_grade_documents)
268
  workflow.add_edge("generate", END)
269
  workflow.add_edge("rewrite", "agent")
270
 
 
271
  app = workflow.compile()
272
 
273
  # --------------------------
274
  # Process Question Function
275
  # --------------------------
276
  def process_question(user_question, app, config):
 
277
  events = []
278
  for event in app.stream({"messages": [("user", user_question)]}, config):
279
  events.append(event)
 
289
  initial_sidebar_state="expanded"
290
  )
291
 
 
292
  st.markdown("""
293
  <style>
294
+ .stApp { background-color: #f8f9fa; }
295
+ .stButton > button { width: 100%; margin-top: 20px; }
296
+ .data-box { padding: 20px; border-radius: 10px; margin: 10px 0; }
297
+ .research-box { background-color: #e3f2fd; border-left: 5px solid #1976d2; }
298
+ .dev-box { background-color: #e8f5e9; border-left: 5px solid #43a047; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
299
  </style>
300
  """, unsafe_allow_html=True)
301
 
 
302
  with st.sidebar:
303
  st.header("πŸ“š Available Data")
 
304
  st.subheader("Research Database")
305
  for text in research_texts:
306
  st.markdown(f'<div class="data-box research-box">{text}</div>', unsafe_allow_html=True)
 
307
  st.subheader("Development Database")
308
  for text in development_texts:
309
  st.markdown(f'<div class="data-box dev-box">{text}</div>', unsafe_allow_html=True)
310
 
 
311
  st.title("πŸ€– AI Research & Development Assistant")
312
  st.markdown("---")
313
 
314
+ query = st.text_area("Enter your question:", height=100,
315
+ placeholder="e.g., What is the latest advancement in AI research?")
316
 
317
  col1, col2 = st.columns([1, 2])
318
  with col1:
319
  if st.button("πŸ” Get Answer", use_container_width=True):
320
  if query:
321
  with st.spinner('Processing your question...'):
 
322
  events = process_question(query, app, {"configurable": {"thread_id": "1"}})
 
 
323
  for event in events:
324
  if 'agent' in event:
325
  with st.expander("πŸ”„ Processing Step", expanded=True):
 
349
  """)
350
 
351
  if __name__ == "__main__":
352
+ main()