pvanand committed
Commit 1f6995c · verified · 1 Parent(s): df76ee8

add update to system message

Files changed (1)
  1. main.py +65 -52
main.py CHANGED
@@ -4,6 +4,7 @@ from fastapi.responses import StreamingResponse
 from langchain_core.messages import (
     BaseMessage,
     HumanMessage,
+    SystemMessage,
     trim_messages,
 )
 from langchain_core.tools import tool
@@ -22,6 +23,8 @@ import requests
 from sse_starlette.sse import EventSourceResponse
 from fastapi.middleware.cors import CORSMiddleware
 import re
+import os
+from langchain_core.prompts import ChatPromptTemplate
 
 app = FastAPI()
 app.include_router(document_rag_router)
@@ -34,6 +37,14 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
+def get_current_files():
+    """Get list of files in current directory"""
+    try:
+        files = os.listdir('.')
+        return ", ".join(files)
+    except Exception as e:
+        return f"Error getting files: {str(e)}"
+
 @tool
 def get_user_age(name: str) -> str:
     """Use this tool to find the user's age."""
@@ -45,7 +56,6 @@ def get_user_age(name: str) -> str:
 async def query_documents(
     query: str,
     config: RunnableConfig,
-    #state: Annotated[dict, InjectedState]
 ) -> str:
     """Use this tool to retrieve relevant data from the collection.
 
@@ -89,11 +99,9 @@ async def query_documents(
         print(e)
         return f"Error querying documents: {e} PAUSE AND ASK USER FOR HELP"
 
-
 async def query_documents_raw(
     query: str,
     config: RunnableConfig,
-    #state: Annotated[dict, InjectedState]
 ) -> SearchResult:
     """Use this tool to retrieve relevant data from the collection.
 
@@ -126,22 +134,60 @@ async def query_documents_raw(
 memory = MemorySaver()
 model = ChatOpenAI(model="gpt-4o-mini", streaming=True)
 
-def state_modifier(state) -> list[BaseMessage]:
-    return trim_messages(
-        state["messages"],
-        token_counter=len,
-        max_tokens=16000,
-        strategy="last",
-        start_on="human",
-        include_system=True,
-        allow_partial=False,
-    )
+# Create a prompt template for formatting
+prompt = ChatPromptTemplate.from_messages([
+    ("system", "You are a helpful AI assistant. Current directory contains: {current_files}"),
+    ("placeholder", "{messages}"),
+])
+
+def format_for_model(state):
+    return prompt.invoke({
+        "current_files": get_current_files(),
+        "messages": state["messages"]
+    })
+
+async def clean_tool_input(tool_input: str):
+    # Use regex to parse the first key and value
+    pattern = r"{\s*'([^']+)':\s*'([^']+)'"
+    match = re.search(pattern, tool_input)
+    if match:
+        key, value = match.groups()
+        return {key: value}
+    return [tool_input]
+
+async def clean_tool_response(tool_output: str):
+    """Clean and extract relevant information from tool response if it contains query_documents."""
+    if "query_documents" in tool_output:
+        try:
+            # First safely evaluate the string as a Python literal
+            import ast
+            print(tool_output)
+            # Extract the list string from the content
+            start = tool_output.find("[{")
+            end = tool_output.rfind("}]") + 2
+            if start >= 0 and end > 0:
+                list_str = tool_output[start:end]
+
+                # Convert string to Python object using ast.literal_eval
+                results = ast.literal_eval(list_str)
+
+                # Return only relevant fields
+                return [{"text": r["text"], "document_id": r["metadata"]["document_id"]}
+                        for r in results]
+
+        except SyntaxError as e:
+            print(f"Syntax error in parsing: {e}")
+            return f"Error parsing document results: {str(e)}"
+        except Exception as e:
+            print(f"General error: {e}")
+            return f"Error processing results: {str(e)}"
+    return tool_output
 
 agent = create_react_agent(
     model,
     tools=[query_documents],
     checkpointer=memory,
-    state_modifier=state_modifier,
+    state_modifier=format_for_model,
 )
 
 class ChatInput(BaseModel):
@@ -190,43 +236,6 @@ async def chat(input_data: ChatInput):
         media_type="text/event-stream"
     )
 
-async def clean_tool_input(tool_input: str):
-    # Use regex to parse the first key and value
-    pattern = r"{\s*'([^']+)':\s*'([^']+)'"
-    match = re.search(pattern, tool_input)
-    if match:
-        key, value = match.groups()
-        return {key: value}
-    return [tool_input]
-
-async def clean_tool_response(tool_output: str):
-    """Clean and extract relevant information from tool response if it contains query_documents."""
-    if "query_documents" in tool_output:
-        try:
-            # First safely evaluate the string as a Python literal
-            import ast
-            print(tool_output)
-            # Extract the list string from the content
-            start = tool_output.find("[{")
-            end = tool_output.rfind("}]") + 2
-            if start >= 0 and end > 0:
-                list_str = tool_output[start:end]
-
-                # Convert string to Python object using ast.literal_eval
-                results = ast.literal_eval(list_str)
-
-                # Return only relevant fields
-                return [{"text": r["text"], "document_id": r["metadata"]["document_id"]}
-                        for r in results]
-
-        except SyntaxError as e:
-            print(f"Syntax error in parsing: {e}")
-            return f"Error parsing document results: {str(e)}"
-        except Exception as e:
-            print(f"General error: {e}")
-            return f"Error processing results: {str(e)}"
-    return tool_output
-
 @app.post("/chat2")
 async def chat2(input_data: ChatInput):
     thread_id = input_data.thread_id or str(uuid.uuid4())
@@ -290,4 +299,8 @@ async def chat2(input_data: ChatInput):
 
 @app.get("/health")
 async def health_check():
-    return {"status": "healthy"}
+    return {"status": "healthy"}
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=8000)
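After this commit, create_react_agent uses format_for_model as its state_modifier, so each model call should see a system message rebuilt from get_current_files() rather than the old trim_messages-based state_modifier. A minimal sketch of exercising the updated agent directly, outside the FastAPI routes — the thread id and question below are illustrative placeholders, not part of the commit:

# Sketch: one agent turn; the MemorySaver checkpointer expects a thread_id in the config,
# as the /chat and /chat2 endpoints provide.
from langchain_core.messages import HumanMessage

config = {"configurable": {"thread_id": "demo-thread"}}  # hypothetical thread id
result = agent.invoke(
    {"messages": [HumanMessage(content="Which files are in the working directory?")]},
    config=config,
)
# format_for_model runs before each LLM call, so the system prompt reflects the
# directory contents at call time rather than at startup.
print(result["messages"][-1].content)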