Spaces:
Sleeping
Sleeping
add update to system message
Browse files
main.py
CHANGED
@@ -4,6 +4,7 @@ from fastapi.responses import StreamingResponse
|
|
4 |
from langchain_core.messages import (
|
5 |
BaseMessage,
|
6 |
HumanMessage,
|
|
|
7 |
trim_messages,
|
8 |
)
|
9 |
from langchain_core.tools import tool
|
@@ -22,6 +23,8 @@ import requests
|
|
22 |
from sse_starlette.sse import EventSourceResponse
|
23 |
from fastapi.middleware.cors import CORSMiddleware
|
24 |
import re
|
|
|
|
|
25 |
|
26 |
app = FastAPI()
|
27 |
app.include_router(document_rag_router)
|
@@ -34,6 +37,14 @@ app.add_middleware(
|
|
34 |
allow_headers=["*"],
|
35 |
)
|
36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
@tool
|
38 |
def get_user_age(name: str) -> str:
|
39 |
"""Use this tool to find the user's age."""
|
@@ -45,7 +56,6 @@ def get_user_age(name: str) -> str:
|
|
45 |
async def query_documents(
|
46 |
query: str,
|
47 |
config: RunnableConfig,
|
48 |
-
#state: Annotated[dict, InjectedState]
|
49 |
) -> str:
|
50 |
"""Use this tool to retrieve relevant data from the collection.
|
51 |
|
@@ -89,11 +99,9 @@ async def query_documents(
|
|
89 |
print(e)
|
90 |
return f"Error querying documents: {e} PAUSE AND ASK USER FOR HELP"
|
91 |
|
92 |
-
|
93 |
async def query_documents_raw(
|
94 |
query: str,
|
95 |
config: RunnableConfig,
|
96 |
-
#state: Annotated[dict, InjectedState]
|
97 |
) -> SearchResult:
|
98 |
"""Use this tool to retrieve relevant data from the collection.
|
99 |
|
@@ -126,22 +134,60 @@ async def query_documents_raw(
|
|
126 |
memory = MemorySaver()
|
127 |
model = ChatOpenAI(model="gpt-4o-mini", streaming=True)
|
128 |
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
|
140 |
agent = create_react_agent(
|
141 |
model,
|
142 |
tools=[query_documents],
|
143 |
checkpointer=memory,
|
144 |
-
state_modifier=
|
145 |
)
|
146 |
|
147 |
class ChatInput(BaseModel):
|
@@ -190,43 +236,6 @@ async def chat(input_data: ChatInput):
|
|
190 |
media_type="text/event-stream"
|
191 |
)
|
192 |
|
193 |
-
async def clean_tool_input(tool_input: str):
|
194 |
-
# Use regex to parse the first key and value
|
195 |
-
pattern = r"{\s*'([^']+)':\s*'([^']+)'"
|
196 |
-
match = re.search(pattern, tool_input)
|
197 |
-
if match:
|
198 |
-
key, value = match.groups()
|
199 |
-
return {key: value}
|
200 |
-
return [tool_input]
|
201 |
-
|
202 |
-
async def clean_tool_response(tool_output: str):
|
203 |
-
"""Clean and extract relevant information from tool response if it contains query_documents."""
|
204 |
-
if "query_documents" in tool_output:
|
205 |
-
try:
|
206 |
-
# First safely evaluate the string as a Python literal
|
207 |
-
import ast
|
208 |
-
print(tool_output)
|
209 |
-
# Extract the list string from the content
|
210 |
-
start = tool_output.find("[{")
|
211 |
-
end = tool_output.rfind("}]") + 2
|
212 |
-
if start >= 0 and end > 0:
|
213 |
-
list_str = tool_output[start:end]
|
214 |
-
|
215 |
-
# Convert string to Python object using ast.literal_eval
|
216 |
-
results = ast.literal_eval(list_str)
|
217 |
-
|
218 |
-
# Return only relevant fields
|
219 |
-
return [{"text": r["text"], "document_id": r["metadata"]["document_id"]}
|
220 |
-
for r in results]
|
221 |
-
|
222 |
-
except SyntaxError as e:
|
223 |
-
print(f"Syntax error in parsing: {e}")
|
224 |
-
return f"Error parsing document results: {str(e)}"
|
225 |
-
except Exception as e:
|
226 |
-
print(f"General error: {e}")
|
227 |
-
return f"Error processing results: {str(e)}"
|
228 |
-
return tool_output
|
229 |
-
|
230 |
@app.post("/chat2")
|
231 |
async def chat2(input_data: ChatInput):
|
232 |
thread_id = input_data.thread_id or str(uuid.uuid4())
|
@@ -290,4 +299,8 @@ async def chat2(input_data: ChatInput):
|
|
290 |
|
291 |
@app.get("/health")
|
292 |
async def health_check():
|
293 |
-
return {"status": "healthy"}
|
|
|
|
|
|
|
|
|
|
4 |
from langchain_core.messages import (
|
5 |
BaseMessage,
|
6 |
HumanMessage,
|
7 |
+
SystemMessage,
|
8 |
trim_messages,
|
9 |
)
|
10 |
from langchain_core.tools import tool
|
|
|
23 |
from sse_starlette.sse import EventSourceResponse
|
24 |
from fastapi.middleware.cors import CORSMiddleware
|
25 |
import re
|
26 |
+
import os
|
27 |
+
from langchain_core.prompts import ChatPromptTemplate
|
28 |
|
29 |
app = FastAPI()
|
30 |
app.include_router(document_rag_router)
|
|
|
37 |
allow_headers=["*"],
|
38 |
)
|
39 |
|
40 |
+
def get_current_files():
|
41 |
+
"""Get list of files in current directory"""
|
42 |
+
try:
|
43 |
+
files = os.listdir('.')
|
44 |
+
return ", ".join(files)
|
45 |
+
except Exception as e:
|
46 |
+
return f"Error getting files: {str(e)}"
|
47 |
+
|
48 |
@tool
|
49 |
def get_user_age(name: str) -> str:
|
50 |
"""Use this tool to find the user's age."""
|
|
|
56 |
async def query_documents(
|
57 |
query: str,
|
58 |
config: RunnableConfig,
|
|
|
59 |
) -> str:
|
60 |
"""Use this tool to retrieve relevant data from the collection.
|
61 |
|
|
|
99 |
print(e)
|
100 |
return f"Error querying documents: {e} PAUSE AND ASK USER FOR HELP"
|
101 |
|
|
|
102 |
async def query_documents_raw(
|
103 |
query: str,
|
104 |
config: RunnableConfig,
|
|
|
105 |
) -> SearchResult:
|
106 |
"""Use this tool to retrieve relevant data from the collection.
|
107 |
|
|
|
134 |
memory = MemorySaver()
|
135 |
model = ChatOpenAI(model="gpt-4o-mini", streaming=True)
|
136 |
|
137 |
+
# Prompt that prepends a system message describing the working directory,
# then splices the running conversation in after it.
prompt = ChatPromptTemplate.from_messages([
    ("system", "You are a helpful AI assistant. Current directory contains: {current_files}"),
    ("placeholder", "{messages}"),
])


def format_for_model(state):
    """Build the model input: a live directory listing plus the chat so far.

    Called by the agent before each model invocation; re-reads the
    directory every turn so the system message stays current.
    """
    variables = {
        "current_files": get_current_files(),
        "messages": state["messages"],
    }
    return prompt.invoke(variables)
|
148 |
+
|
149 |
+
async def clean_tool_input(tool_input: str):
    """Extract the first ``'key': 'value'`` pair from a stringified dict.

    Returns:
        ``{key: value}`` for the first single-quoted pair found, or the
        raw input wrapped in a one-element list when nothing matches.
    """
    first_pair = re.search(r"{\s*'([^']+)':\s*'([^']+)'", tool_input)
    if first_pair is None:
        # Nothing parseable — hand the raw text back, wrapped in a list.
        return [tool_input]
    key, value = first_pair.groups()
    return {key: value}
|
157 |
+
|
158 |
+
async def clean_tool_response(tool_output: str):
    """Clean and extract relevant information from tool response if it contains query_documents.

    Looks for a stringified list of result dicts (``[{...}]``) embedded in
    the raw tool output, parses it safely, and keeps only the fields the
    caller needs.

    Returns:
        A list of ``{"text", "document_id"}`` dicts when a payload parses,
        an error string when parsing fails, or the input unchanged when it
        is not a query_documents payload (or contains no complete list).
    """
    if "query_documents" in tool_output:
        try:
            # Safe literal parsing only — never eval() on tool output.
            import ast

            print(tool_output)
            # Locate the embedded "[{...}]" list within the raw content.
            start = tool_output.find("[{")
            end = tool_output.rfind("}]") + 2
            # Fix: require the closing "}]" to come AFTER the opening "[{".
            # The previous guard (end > 0) passed even when rfind() returned
            # -1 (end == 1), producing a bogus backwards slice that always
            # failed to parse and returned an error string instead of
            # falling through to return the output unchanged.
            if start >= 0 and end > start:
                list_str = tool_output[start:end]

                # Convert the string to a Python object without executing code.
                results = ast.literal_eval(list_str)

                # Return only the fields downstream consumers use.
                return [{"text": r["text"], "document_id": r["metadata"]["document_id"]}
                        for r in results]

        except SyntaxError as e:
            print(f"Syntax error in parsing: {e}")
            return f"Error parsing document results: {str(e)}"
        except Exception as e:
            print(f"General error: {e}")
            return f"Error processing results: {str(e)}"
    return tool_output
|
185 |
|
186 |
# ReAct agent over gpt-4o-mini: may call query_documents, checkpoints
# state via the MemorySaver instance above, and routes every model call
# through format_for_model so the system message carries the current
# directory listing.
agent = create_react_agent(
    model,
    tools=[query_documents],
    checkpointer=memory,
    state_modifier=format_for_model,
)
|
192 |
|
193 |
class ChatInput(BaseModel):
|
|
|
236 |
media_type="text/event-stream"
|
237 |
)
|
238 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
239 |
@app.post("/chat2")
|
240 |
async def chat2(input_data: ChatInput):
|
241 |
thread_id = input_data.thread_id or str(uuid.uuid4())
|
|
|
299 |
|
300 |
@app.get("/health")
|
301 |
async def health_check():
|
302 |
+
return {"status": "healthy"}
|
303 |
+
|
304 |
+
if __name__ == "__main__":
|
305 |
+
import uvicorn
|
306 |
+
uvicorn.run(app, host="0.0.0.0", port=8000)
|