mgbam committed on
Commit
06ee039
·
verified ·
1 Parent(s): 79eee55

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +106 -72
app.py CHANGED
@@ -1,12 +1,13 @@
 
 
 
 
 
1
  from langchain_openai import OpenAIEmbeddings
2
  from langchain_community.vectorstores import Chroma
3
- from langchain_core.messages import HumanMessage, AIMessage
4
- from langchain_text_splitters import RecursiveCharacterTextSplitter
5
- from langgraph.graph import StateGraph # Only import StateGraph from here
6
- # Define START and END manually
7
- START = "__start__"
8
- END = "__end__"
9
-
10
  from langgraph.prebuilt import ToolNode
11
  from langgraph.graph.message import add_messages
12
  from typing_extensions import TypedDict, Annotated
@@ -17,9 +18,9 @@ import streamlit as st
17
  import requests
18
  from langchain.tools.retriever import create_retriever_tool
19
 
20
- # --------------------------
21
- # Create Dummy Data
22
- # --------------------------
23
  research_texts = [
24
  "Research Report: Results of a New AI Model Improving Image Recognition Accuracy to 98%",
25
  "Academic Paper Summary: Why Transformers Became the Mainstream Architecture in Natural Language Processing",
@@ -32,41 +33,41 @@ development_texts = [
32
  "Product Y: In the Performance Optimization Stage Before Release"
33
  ]
34
 
35
- # --------------------------
36
- # Process the Data
37
- # --------------------------
38
  splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=10)
39
-
40
  research_docs = splitter.create_documents(research_texts)
41
  development_docs = splitter.create_documents(development_texts)
42
 
 
 
 
43
  embeddings = OpenAIEmbeddings(
44
  model="text-embedding-3-large",
45
- dimensions=1024
 
46
  )
47
-
48
  research_vectorstore = Chroma.from_documents(
49
  documents=research_docs,
50
  embedding=embeddings,
51
  collection_name="research_collection"
52
  )
53
-
54
  development_vectorstore = Chroma.from_documents(
55
  documents=development_docs,
56
  embedding=embeddings,
57
  collection_name="development_collection"
58
  )
59
-
60
  research_retriever = research_vectorstore.as_retriever()
61
  development_retriever = development_vectorstore.as_retriever()
62
 
63
- # --------------------------
64
- # Create Retriever Tools
65
- # --------------------------
66
  research_tool = create_retriever_tool(
67
- research_retriever,
68
- "research_db_tool",
69
- "Search information from the research database."
70
  )
71
 
72
  development_tool = create_retriever_tool(
@@ -77,17 +78,22 @@ development_tool = create_retriever_tool(
77
 
78
  tools = [research_tool, development_tool]
79
 
80
- # --------------------------
81
- # Define the Agent Function
82
- # --------------------------
83
  class AgentState(TypedDict):
84
- messages: Annotated[Sequence[AIMessage | HumanMessage], add_messages]
85
 
86
  def agent(state: AgentState):
87
  print("---CALL AGENT---")
88
  messages = state["messages"]
89
- user_message = messages[0][1] if isinstance(messages[0], tuple) else messages[0].content
90
 
 
 
 
 
 
 
91
  prompt = f"""Given this user question: "{user_message}"
92
  If it's about research or academic topics, respond EXACTLY in this format:
93
  SEARCH_RESEARCH: <search terms>
@@ -100,7 +106,7 @@ Otherwise, just answer directly.
100
 
101
  headers = {
102
  "Accept": "application/json",
103
- "Authorization": "Bearer sk-1cddf19f9dc4466fa3ecea6fe10abec0",
104
  "Content-Type": "application/json"
105
  }
106
 
@@ -114,13 +120,15 @@ Otherwise, just answer directly.
114
  response = requests.post(
115
  "https://api.deepseek.com/v1/chat/completions",
116
  headers=headers,
117
- json=data
 
118
  )
119
 
120
  if response.status_code == 200:
121
  response_text = response.json()['choices'][0]['message']['content']
122
  print("Raw response:", response_text)
123
 
 
124
  if "SEARCH_RESEARCH:" in response_text:
125
  query = response_text.split("SEARCH_RESEARCH:")[1].strip()
126
  results = research_retriever.invoke(query)
@@ -136,9 +144,6 @@ Otherwise, just answer directly.
136
  else:
137
  raise Exception(f"API call failed: {response.text}")
138
 
139
- # --------------------------
140
- # Grading Function
141
- # --------------------------
142
  def simple_grade_documents(state: AgentState):
143
  messages = state["messages"]
144
  last_message = messages[-1]
@@ -151,15 +156,13 @@ def simple_grade_documents(state: AgentState):
151
  print("---NO DOCS FOUND, TRY REWRITE---")
152
  return "rewrite"
153
 
154
- # --------------------------
155
- # Generation Function
156
- # --------------------------
157
  def generate(state: AgentState):
158
  print("---GENERATE FINAL ANSWER---")
159
  messages = state["messages"]
160
- question = messages[0].content if not isinstance(messages[0], tuple) else messages[0][1]
161
  last_message = messages[-1]
162
 
 
163
  docs = ""
164
  if "Results: [" in last_message.content:
165
  results_start = last_message.content.find("Results: [")
@@ -168,7 +171,7 @@ def generate(state: AgentState):
168
 
169
  headers = {
170
  "Accept": "application/json",
171
- "Authorization": "Bearer sk-1cddf19f9dc4466fa3ecea6fe10abec0",
172
  "Content-Type": "application/json"
173
  }
174
 
@@ -180,7 +183,10 @@ Focus on extracting and synthesizing the key findings from the research papers.
180
 
181
  data = {
182
  "model": "deepseek-chat",
183
- "messages": [{"role": "user", "content": prompt}],
 
 
 
184
  "temperature": 0.7,
185
  "max_tokens": 1024
186
  }
@@ -189,7 +195,8 @@ Focus on extracting and synthesizing the key findings from the research papers.
189
  response = requests.post(
190
  "https://api.deepseek.com/v1/chat/completions",
191
  headers=headers,
192
- json=data
 
193
  )
194
 
195
  if response.status_code == 200:
@@ -199,9 +206,6 @@ Focus on extracting and synthesizing the key findings from the research papers.
199
  else:
200
  raise Exception(f"API call failed: {response.text}")
201
 
202
- # --------------------------
203
- # Rewrite Function
204
- # --------------------------
205
  def rewrite(state: AgentState):
206
  print("---REWRITE QUESTION---")
207
  messages = state["messages"]
@@ -209,23 +213,31 @@ def rewrite(state: AgentState):
209
 
210
  headers = {
211
  "Accept": "application/json",
212
- "Authorization": "Bearer sk-1cddf19f9dc4466fa3ecea6fe10abec0",
213
  "Content-Type": "application/json"
214
  }
215
 
216
  data = {
217
  "model": "deepseek-chat",
218
- "messages": [{"role": "user", "content": f"Rewrite this question to be more specific and clearer: {original_question}"}],
 
 
 
219
  "temperature": 0.7,
220
  "max_tokens": 1024
221
  }
222
 
 
223
  response = requests.post(
224
  "https://api.deepseek.com/v1/chat/completions",
225
  headers=headers,
226
- json=data
 
227
  )
228
 
 
 
 
229
  if response.status_code == 200:
230
  response_text = response.json()['choices'][0]['message']['content']
231
  print("Rewritten question:", response_text)
@@ -233,9 +245,6 @@ def rewrite(state: AgentState):
233
  else:
234
  raise Exception(f"API call failed: {response.text}")
235
 
236
- # --------------------------
237
- # Tools Decision Function
238
- # --------------------------
239
  tools_pattern = re.compile(r"Action: .*")
240
 
241
  def custom_tools_condition(state: AgentState):
@@ -244,11 +253,15 @@ def custom_tools_condition(state: AgentState):
244
  content = last_message.content
245
 
246
  print("Checking tools condition:", content)
247
- return "tools" if tools_pattern.match(content) else END
248
-
249
- # --------------------------
250
- # LangGraph Workflow Setup
251
- # --------------------------
 
 
 
 
252
  workflow = StateGraph(AgentState)
253
 
254
  workflow.add_node("agent", agent)
@@ -262,7 +275,10 @@ workflow.add_edge(START, "agent")
262
  workflow.add_conditional_edges(
263
  "agent",
264
  custom_tools_condition,
265
- {"tools": "retrieve", END: END}
 
 
 
266
  )
267
 
268
  workflow.add_conditional_edges("retrieve", simple_grade_documents)
@@ -271,18 +287,16 @@ workflow.add_edge("rewrite", "agent")
271
 
272
  app = workflow.compile()
273
 
274
- # --------------------------
275
- # Process Question Function
276
- # --------------------------
277
- def process_question(user_question):
278
  events = []
279
- for event in app.stream({"messages": [("user", user_question)]}):
280
  events.append(event)
281
  return events
282
 
283
- # --------------------------
284
- # Streamlit Application
285
- # --------------------------
286
  def main():
287
  st.set_page_config(
288
  page_title="AI Research & Development Assistant",
@@ -290,21 +304,40 @@ def main():
290
  initial_sidebar_state="expanded"
291
  )
292
 
 
293
  st.markdown("""
294
  <style>
295
- .stApp { background-color: #f8f9fa; }
296
- .stButton > button { width: 100%; margin-top: 20px; }
297
- .data-box { padding: 20px; border-radius: 10px; margin: 10px 0; }
298
- .research-box { background-color: #e3f2fd; border-left: 5px solid #1976d2; }
299
- .dev-box { background-color: #e8f5e9; border-left: 5px solid #43a047; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
300
  </style>
301
  """, unsafe_allow_html=True)
302
 
 
303
  with st.sidebar:
304
  st.header("📚 Available Data")
 
305
  st.subheader("Research Database")
306
  for text in research_texts:
307
  st.markdown(f'<div class="data-box research-box">{text}</div>', unsafe_allow_html=True)
 
308
  st.subheader("Development Database")
309
  for text in development_texts:
310
  st.markdown(f'<div class="data-box dev-box">{text}</div>', unsafe_allow_html=True)
@@ -312,15 +345,16 @@ def main():
312
  st.title("🤖 AI Research & Development Assistant")
313
  st.markdown("---")
314
 
315
- query = st.text_area("Enter your question:", height=100,
316
- placeholder="e.g., What is the latest advancement in AI research?")
317
 
318
  col1, col2 = st.columns([1, 2])
319
  with col1:
320
  if st.button("🔍 Get Answer", use_container_width=True):
321
  if query:
322
  with st.spinner('Processing your question...'):
323
- events = process_question(query)
 
324
  for event in events:
325
  if 'agent' in event:
326
  with st.expander("🔄 Processing Step", expanded=True):
 
1
+ # pip install -r requirements.txt
2
+
3
+ # ------------------------------
4
+ # Imports & Dependencies
5
+ # ------------------------------
6
  from langchain_openai import OpenAIEmbeddings
7
  from langchain_community.vectorstores import Chroma
8
+ from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
9
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
10
+ from langgraph.graph import END, StateGraph, START
 
 
 
 
11
  from langgraph.prebuilt import ToolNode
12
  from langgraph.graph.message import add_messages
13
  from typing_extensions import TypedDict, Annotated
 
18
  import requests
19
  from langchain.tools.retriever import create_retriever_tool
20
 
21
+ # ------------------------------
22
+ # Dummy Data: Research & Development Texts
23
+ # ------------------------------
24
  research_texts = [
25
  "Research Report: Results of a New AI Model Improving Image Recognition Accuracy to 98%",
26
  "Academic Paper Summary: Why Transformers Became the Mainstream Architecture in Natural Language Processing",
 
33
  "Product Y: In the Performance Optimization Stage Before Release"
34
  ]
35
 
36
+ # ------------------------------
37
+ # Text Splitting & Document Creation
38
+ # ------------------------------
39
  splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=10)
 
40
  research_docs = splitter.create_documents(research_texts)
41
  development_docs = splitter.create_documents(development_texts)
42
 
43
+ # ------------------------------
44
+ # Creating Vector Stores with Embeddings
45
+ # ------------------------------
46
  embeddings = OpenAIEmbeddings(
47
  model="text-embedding-3-large",
48
+ # You can uncomment and set dimensions if needed:
49
+ # dimensions=1024
50
  )
 
51
  research_vectorstore = Chroma.from_documents(
52
  documents=research_docs,
53
  embedding=embeddings,
54
  collection_name="research_collection"
55
  )
 
56
  development_vectorstore = Chroma.from_documents(
57
  documents=development_docs,
58
  embedding=embeddings,
59
  collection_name="development_collection"
60
  )
 
61
  research_retriever = research_vectorstore.as_retriever()
62
  development_retriever = development_vectorstore.as_retriever()
63
 
64
+ # ------------------------------
65
+ # Creating Retriever Tools
66
+ # ------------------------------
67
  research_tool = create_retriever_tool(
68
+ research_retriever, # Retriever object
69
+ "research_db_tool", # Name of the tool
70
+ "Search information from the research database." # Tool description
71
  )
72
 
73
  development_tool = create_retriever_tool(
 
78
 
79
  tools = [research_tool, development_tool]
80
 
81
+ # ------------------------------
82
+ # Agent Function & Workflow Functions
83
+ # ------------------------------
84
  class AgentState(TypedDict):
85
+ messages: Annotated[Sequence[AIMessage | HumanMessage | ToolMessage], add_messages]
86
 
87
  def agent(state: AgentState):
88
  print("---CALL AGENT---")
89
  messages = state["messages"]
 
90
 
91
+ if isinstance(messages[0], tuple):
92
+ user_message = messages[0][1]
93
+ else:
94
+ user_message = messages[0].content
95
+
96
+ # Structure prompt for consistent text output
97
  prompt = f"""Given this user question: "{user_message}"
98
  If it's about research or academic topics, respond EXACTLY in this format:
99
  SEARCH_RESEARCH: <search terms>
 
106
 
107
  headers = {
108
  "Accept": "application/json",
109
+ "Authorization": f"Bearer sk-1cddf19f9dc4466fa3ecea6fe10abec0",
110
  "Content-Type": "application/json"
111
  }
112
 
 
120
  response = requests.post(
121
  "https://api.deepseek.com/v1/chat/completions",
122
  headers=headers,
123
+ json=data,
124
+ verify=False
125
  )
126
 
127
  if response.status_code == 200:
128
  response_text = response.json()['choices'][0]['message']['content']
129
  print("Raw response:", response_text)
130
 
131
+ # Format the response into expected tool format
132
  if "SEARCH_RESEARCH:" in response_text:
133
  query = response_text.split("SEARCH_RESEARCH:")[1].strip()
134
  results = research_retriever.invoke(query)
 
144
  else:
145
  raise Exception(f"API call failed: {response.text}")
146
 
 
 
 
147
  def simple_grade_documents(state: AgentState):
148
  messages = state["messages"]
149
  last_message = messages[-1]
 
156
  print("---NO DOCS FOUND, TRY REWRITE---")
157
  return "rewrite"
158
 
 
 
 
159
  def generate(state: AgentState):
160
  print("---GENERATE FINAL ANSWER---")
161
  messages = state["messages"]
162
+ question = messages[0].content if isinstance(messages[0], tuple) else messages[0].content
163
  last_message = messages[-1]
164
 
165
+ # Extract the document content from the results
166
  docs = ""
167
  if "Results: [" in last_message.content:
168
  results_start = last_message.content.find("Results: [")
 
171
 
172
  headers = {
173
  "Accept": "application/json",
174
+ "Authorization": f"Bearer sk-1cddf19f9dc4466fa3ecea6fe10abec0",
175
  "Content-Type": "application/json"
176
  }
177
 
 
183
 
184
  data = {
185
  "model": "deepseek-chat",
186
+ "messages": [{
187
+ "role": "user",
188
+ "content": prompt
189
+ }],
190
  "temperature": 0.7,
191
  "max_tokens": 1024
192
  }
 
195
  response = requests.post(
196
  "https://api.deepseek.com/v1/chat/completions",
197
  headers=headers,
198
+ json=data,
199
+ verify=False
200
  )
201
 
202
  if response.status_code == 200:
 
206
  else:
207
  raise Exception(f"API call failed: {response.text}")
208
 
 
 
 
209
  def rewrite(state: AgentState):
210
  print("---REWRITE QUESTION---")
211
  messages = state["messages"]
 
213
 
214
  headers = {
215
  "Accept": "application/json",
216
+ "Authorization": f"Bearer sk-1cddf19f9dc4466fa3ecea6fe10abec0",
217
  "Content-Type": "application/json"
218
  }
219
 
220
  data = {
221
  "model": "deepseek-chat",
222
+ "messages": [{
223
+ "role": "user",
224
+ "content": f"Rewrite this question to be more specific and clearer: {original_question}"
225
+ }],
226
  "temperature": 0.7,
227
  "max_tokens": 1024
228
  }
229
 
230
+ print("Sending rewrite request...")
231
  response = requests.post(
232
  "https://api.deepseek.com/v1/chat/completions",
233
  headers=headers,
234
+ json=data,
235
+ verify=False
236
  )
237
 
238
+ print("Status Code:", response.status_code)
239
+ print("Response:", response.text)
240
+
241
  if response.status_code == 200:
242
  response_text = response.json()['choices'][0]['message']['content']
243
  print("Rewritten question:", response_text)
 
245
  else:
246
  raise Exception(f"API call failed: {response.text}")
247
 
 
 
 
248
  tools_pattern = re.compile(r"Action: .*")
249
 
250
  def custom_tools_condition(state: AgentState):
 
253
  content = last_message.content
254
 
255
  print("Checking tools condition:", content)
256
+ if tools_pattern.match(content):
257
+ print("Moving to retrieve...")
258
+ return "tools"
259
+ print("Moving to END...")
260
+ return END
261
+
262
+ # ------------------------------
263
+ # Workflow Configuration using LangGraph
264
+ # ------------------------------
265
  workflow = StateGraph(AgentState)
266
 
267
  workflow.add_node("agent", agent)
 
275
  workflow.add_conditional_edges(
276
  "agent",
277
  custom_tools_condition,
278
+ {
279
+ "tools": "retrieve",
280
+ END: END
281
+ }
282
  )
283
 
284
  workflow.add_conditional_edges("retrieve", simple_grade_documents)
 
287
 
288
  app = workflow.compile()
289
 
290
+ def process_question(user_question, app, config):
291
+ """Process user question through the workflow"""
 
 
292
  events = []
293
+ for event in app.stream({"messages": [("user", user_question)]}, config):
294
  events.append(event)
295
  return events
296
 
297
+ # ------------------------------
298
+ # Streamlit App UI
299
+ # ------------------------------
300
  def main():
301
  st.set_page_config(
302
  page_title="AI Research & Development Assistant",
 
304
  initial_sidebar_state="expanded"
305
  )
306
 
307
+ # Custom CSS for styling
308
  st.markdown("""
309
  <style>
310
+ .stApp {
311
+ background-color: #f8f9fa;
312
+ }
313
+ .stButton > button {
314
+ width: 100%;
315
+ margin-top: 20px;
316
+ }
317
+ .data-box {
318
+ padding: 20px;
319
+ border-radius: 10px;
320
+ margin: 10px 0;
321
+ }
322
+ .research-box {
323
+ background-color: #e3f2fd;
324
+ border-left: 5px solid #1976d2;
325
+ }
326
+ .dev-box {
327
+ background-color: #e8f5e9;
328
+ border-left: 5px solid #43a047;
329
+ }
330
  </style>
331
  """, unsafe_allow_html=True)
332
 
333
+ # Sidebar with available data
334
  with st.sidebar:
335
  st.header("📚 Available Data")
336
+
337
  st.subheader("Research Database")
338
  for text in research_texts:
339
  st.markdown(f'<div class="data-box research-box">{text}</div>', unsafe_allow_html=True)
340
+
341
  st.subheader("Development Database")
342
  for text in development_texts:
343
  st.markdown(f'<div class="data-box dev-box">{text}</div>', unsafe_allow_html=True)
 
345
  st.title("🤖 AI Research & Development Assistant")
346
  st.markdown("---")
347
 
348
+ # Query input box
349
+ query = st.text_area("Enter your question:", height=100, placeholder="e.g., What is the latest advancement in AI research?")
350
 
351
  col1, col2 = st.columns([1, 2])
352
  with col1:
353
  if st.button("🔍 Get Answer", use_container_width=True):
354
  if query:
355
  with st.spinner('Processing your question...'):
356
+ events = process_question(query, app, {"configurable": {"thread_id": "1"}})
357
+
358
  for event in events:
359
  if 'agent' in event:
360
  with st.expander("🔄 Processing Step", expanded=True):