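# documind-api-v2 / main.py
# Web file-viewer header captured together with the source (not part of the program):
#   pvanand's picture | Update main.py | c6774a0 verified | raw | history blame | 11.4 kB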
import uuid
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from langchain_core.messages import (
BaseMessage,
HumanMessage,
SystemMessage,
trim_messages,
)
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent
from pydantic import BaseModel
import json
from typing import Optional, Annotated
from langchain_core.runnables import RunnableConfig
from langgraph.prebuilt import InjectedState
from document_rag_router import router as document_rag_router
from document_rag_router import QueryInput, query_collection, SearchResult,db
from fastapi import HTTPException
import requests
from sse_starlette.sse import EventSourceResponse
from fastapi.middleware.cors import CORSMiddleware
import re
import os
from langchain_core.prompts import ChatPromptTemplate
# Application object; the routes exported by document_rag_router are mounted on
# it in addition to the endpoints declared further down in this module.
app = FastAPI()
app.include_router(document_rag_router)
# CORS is configured fully open: every origin, method and header is accepted.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive — confirm this is intended outside local development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
def get_current_files():
    """Return the entries of the working directory as one comma-separated string.

    Any failure while listing the directory is reported in the returned string
    instead of being raised.
    """
    try:
        return ", ".join(os.listdir(os.curdir))
    except Exception as exc:
        return f"Error getting files: {exc}"
@tool
def get_user_age(name: str) -> str:
    """Use this tool to find the user's age."""
    # Hard-coded answers: a name containing "bob" (case-insensitive) gets one
    # reply, every other name gets the other. This tool is not in the tool list
    # passed to create_react_agent in this module.
    mentions_bob = "bob" in name.lower()
    return "42 years old" if mentions_bob else "41 years old"
@tool
async def query_documents(
    query: str,
    config: RunnableConfig,
) -> str:
    """Use this tool to retrieve relevant data from the collection.
    Args:
    query: The search query to find relevant document passages
    """
    # The collection to search and its owner come from the run config, not from
    # the tool arguments.
    configurable = config.get("configurable", {})
    collection_id = configurable.get("collection_id")
    user_id = configurable.get("user_id")
    if not (collection_id and user_id):
        return "Error: collection_id and user_id are required in the config"

    try:
        response = await query_collection(
            QueryInput(
                collection_id=collection_id,
                query=query,
                user_id=user_id,
                top_k=6,
            )
        )
        # Keep only the passage text, its distance and two metadata fields; the
        # list is handed back as its Python repr.
        hits = [
            {
                "text": hit.text,
                "distance": hit.distance,
                "metadata": {
                    "document_id": hit.metadata.get("document_id"),
                    "chunk_index": hit.metadata.get("location", {}).get("chunk_index"),
                },
            }
            for hit in response.results
        ]
        return str(hits)
    except Exception as e:
        # Failures are reported back as text rather than raised.
        print(e)
        return f"Error querying documents: {e} PAUSE AND ASK USER FOR HELP"
async def query_documents_raw(
    query: str,
    config: RunnableConfig,
) -> SearchResult:
    """Run the same collection lookup as ``query_documents`` without reformatting.

    Args:
        query: The search query to find relevant document passages.
        config: Run config whose "configurable" section supplies
            ``collection_id`` and ``user_id``.

    Returns:
        ``response.results`` exactly as produced by ``query_collection``. When the
        identifiers are missing or the lookup raises, an error message string is
        returned instead.

    NOTE(review): the declared return annotation does not cover the
    error-string case; callers have to handle both shapes.
    """
    configurable = config.get("configurable", {})
    collection_id = configurable.get("collection_id")
    user_id = configurable.get("user_id")
    if not (collection_id and user_id):
        return "Error: collection_id and user_id are required in the config"

    try:
        request = QueryInput(
            collection_id=collection_id,
            query=query,
            user_id=user_id,
            top_k=6,
        )
        return (await query_collection(request)).results
    except Exception as e:
        print(e)
        return f"Error querying documents: {e} PAUSE AND ASK USER FOR HELP"
# Checkpointer handed to create_react_agent below.
# NOTE(review): presumably keeps conversation state in process memory only —
# confirm whether history is expected to survive a restart.
memory = MemorySaver()
# Chat model driving the agent; streaming is enabled so the endpoints can
# forward tokens as they are produced.
model = ChatOpenAI(model="gpt-4o-mini", streaming=True)
# Create a prompt template for formatting: a system message that lists the
# collection's files, followed by the conversation messages.
prompt = ChatPromptTemplate.from_messages([
    ("system", "You are a helpful AI assistant. The current collection contains the following files: {collection_files}"),
    ("placeholder", "{messages}"),
])
async def get_collection_files(collection_id: str, user_id: str) -> str:
    """Get the list of files in the specified collection.

    Args:
        collection_id: Identifier of the collection.
        user_id: Identifier of the owning user.

    Returns:
        The distinct values of the table's ``file_name`` column joined with
        ", ". If anything goes wrong, the error is logged and a string starting
        with "Error getting files: " is returned instead of raising.
    """
    # Imported here because this module has no file-level `import logging`;
    # without it the error handler below raised NameError and hid the real
    # failure from the caller.
    import logging

    try:
        # Tables are looked up as "<user_id>_<collection_id>".
        collection_name = f"{user_id}_{collection_id}"
        # Open the table and convert to pandas
        table = db.open_table(collection_name)
        df = table.to_pandas()
        # Get unique file names and join them into one string
        unique_files = df['file_name'].unique()
        return ", ".join(unique_files)
    except Exception as e:
        logging.error("Error getting collection files: %s", e)
        return f"Error getting files: {str(e)}"
async def format_for_model(state):
    """Build the model input from the agent state.

    Looks for ``collection_id`` and ``user_id`` under ``state["configurable"]``.
    When both are present, the collection's file list is fetched; otherwise the
    text "No files available" is used. The result of rendering the module-level
    prompt with that text and ``state["messages"]`` is returned.

    NOTE(review): the endpoints in this module pass collection_id/user_id in
    the run config and send only {"messages": [...]} as input — verify that
    "configurable" is ever present in ``state``; if not, the fallback text is
    always used.
    """
    settings = state.get("configurable", {})
    collection_id = settings.get("collection_id")
    user_id = settings.get("user_id")

    if collection_id and user_id:
        collection_files = await get_collection_files(collection_id, user_id)
    else:
        collection_files = "No files available"

    prompt_values = {
        "collection_files": collection_files,
        "messages": state["messages"],
    }
    return prompt.invoke(prompt_values)
async def clean_tool_input(tool_input: str):
    """Extract the first single-quoted key/value pair from a dict-like string.

    Returns a one-entry dict ``{key: value}`` when the text contains a leading
    ``{'key': 'value'`` pair, and otherwise a one-element list holding the
    original text.
    """
    # Only the first pair is captured, and both sides must be single-quoted.
    first_pair = re.search(r"{\s*'([^']+)':\s*'([^']+)'", tool_input)
    if first_pair is None:
        return [tool_input]
    name, value = first_pair.groups()
    return {name: value}
async def clean_tool_response(tool_output: str):
    """Clean and extract relevant information from tool response if it contains query_documents.

    Args:
        tool_output: Text form of a tool result.

    Returns:
        * ``tool_output`` unchanged when it does not mention "query_documents"
          or contains no ``[{ ... }]`` span to parse.
        * A list of ``{"text", "document_id"}`` dicts when the embedded list
          parses as a Python literal.
        * An "Error parsing document results: ..." or
          "Error processing results: ..." string when the span is present but
          cannot be parsed or lacks the expected keys.
    """
    if "query_documents" not in tool_output:
        return tool_output

    try:
        import ast
        print(tool_output)
        # Locate the list literal embedded in the text.
        start = tool_output.find("[{")
        close = tool_output.rfind("}]")
        # rfind() yields -1 when the closing marker is missing; the previous
        # check added 2 first, so it could never detect that and reported a
        # bogus parse error. A closing marker before the opening one is no
        # usable span either.
        if start < 0 or close < start:
            return tool_output
        # literal_eval only accepts Python literals, so nothing is executed.
        results = ast.literal_eval(tool_output[start:close + 2])
        # Return only relevant fields
        return [
            {"text": r["text"], "document_id": r["metadata"]["document_id"]}
            for r in results
        ]
    except SyntaxError as e:
        print(f"Syntax error in parsing: {e}")
        return f"Error parsing document results: {str(e)}"
    except Exception as e:
        print(f"General error: {e}")
        return f"Error processing results: {str(e)}"
# The agent shared by both chat endpoints: the chat model above with
# query_documents as its only tool, the module-level checkpointer, and
# format_for_model applied as the state modifier.
agent = create_react_agent(
    model,
    tools=[query_documents],
    checkpointer=memory,
    state_modifier=format_for_model,
)
class ChatInput(BaseModel):
    # Request body accepted by the /chat and /chat2 endpoints.

    # Text of the user's message; the endpoints wrap it in a HumanMessage.
    message: str
    # Conversation identifier; the endpoints generate a UUID4 when it is omitted.
    thread_id: Optional[str] = None
    # Forwarded in the run config's "configurable" section; the query_documents
    # tool returns an error string when either of the two is missing.
    collection_id: Optional[str] = None
    user_id: Optional[str] = None
@app.post("/chat")
async def chat(input_data: ChatInput):
    # Runs the agent on one user message and streams the run as server-sent
    # events. Each event's data is a JSON object of one of three kinds:
    #   {"type": "token", "content": ...}              model output chunk
    #   {"type": "tool_start", "tool": ..., "input": ...}
    #   {"type": "tool_end", "tool": ..., "output": ...}
    # Tool input and output are forwarded in their str() form.
    run_config = {
        "configurable": {
            # A new thread id is generated when the caller did not supply one.
            "thread_id": input_data.thread_id or str(uuid.uuid4()),
            "collection_id": input_data.collection_id,
            "user_id": input_data.user_id,
        }
    }
    agent_input = {"messages": [HumanMessage(content=input_data.message)]}

    async def generate():
        async for event in agent.astream_events(agent_input, run_config, version="v2"):
            kind = event["event"]
            if kind == "on_chat_model_stream":
                text = event["data"]["chunk"].content
                if not text:
                    # Empty chunks are not forwarded.
                    continue
                payload = {'type': 'token', 'content': text}
            elif kind == "on_tool_start":
                shown_input = str(event['data'].get('input', ''))
                payload = {'type': 'tool_start', 'tool': event['name'], 'input': shown_input}
            elif kind == "on_tool_end":
                shown_output = str(event['data'].get('output', ''))
                payload = {'type': 'tool_end', 'tool': event['name'], 'output': shown_output}
            else:
                # All other event kinds are ignored.
                continue
            yield json.dumps(payload)

    return EventSourceResponse(
        generate(),
        media_type="text/event-stream"
    )
@app.post("/chat2")
async def chat2(input_data: ChatInput):
    # Variant of /chat that post-processes tool events before streaming them as
    # server-sent events. Each event's data is a JSON object:
    #   {"type": "token", "content": ...}                 model output chunk
    #   {"type": "tool_start", "tool": ..., "inputs": ...} input reduced by
    #                                                      clean_tool_input
    #   {"type": "tool_end", "tool": ..., "output": ...}   for query_documents
    #       the search is run again to obtain the unformatted hits (JSON text);
    #       for other tools the output goes through clean_tool_response.
    thread_id = input_data.thread_id or str(uuid.uuid4())
    config = {
        "configurable": {
            "thread_id": thread_id,
            "collection_id": input_data.collection_id,
            "user_id": input_data.user_id
        }
    }
    input_message = HumanMessage(content=input_data.message)

    async def generate():
        async for event in agent.astream_events(
            {"messages": [input_message]},
            config,
            version="v2"
        ):
            kind = event["event"]
            if kind == "on_chat_model_stream":
                content = event["data"]["chunk"].content
                if content:
                    yield json.dumps({'type': 'token', 'content': content})
            elif kind == "on_tool_start":
                tool_name = event['name']
                tool_input = event['data'].get('input', '')
                clean_input = await clean_tool_input(str(tool_input))
                yield json.dumps({'type': 'tool_start', 'tool': tool_name, 'inputs': clean_input})
            elif kind == "on_tool_end":
                tool_name = event['name']
                if "query_documents" in tool_name:
                    print(event)
                    # Re-run the search with the query text itself. Previously
                    # the whole input was stringified, so the lookup searched
                    # for text such as "{'query': '...'}" rather than the query.
                    tool_input = event['data'].get('input', '')
                    if isinstance(tool_input, dict) and isinstance(tool_input.get("query"), str):
                        query = tool_input["query"]
                    else:
                        # Unknown input shape: keep the previous behaviour.
                        query = str(tool_input)
                    raw_output = await query_documents_raw(query, config)
                    if isinstance(raw_output, str):
                        # query_documents_raw reports failures as a message
                        # string; forward it as-is instead of iterating it.
                        output = raw_output
                    else:
                        try:
                            output = json.dumps([
                                {
                                    "text": result.text,
                                    "distance": result.distance,
                                    "metadata": result.metadata
                                }
                                for result in raw_output
                            ])
                        except Exception as e:
                            # E.g. metadata that is not JSON-serializable: fall
                            # back to the plain string form of the results.
                            print(e)
                            output = str(raw_output)
                    yield json.dumps({'type': 'tool_end', 'tool': tool_name, 'output': output})
                else:
                    raw_output = str(event['data'].get('output', ''))
                    clean_output = await clean_tool_response(raw_output)
                    yield json.dumps({'type': 'tool_end', 'tool': tool_name, 'output': clean_output})

    return EventSourceResponse(
        generate(),
        media_type="text/event-stream"
    )
@app.get("/health")
async def health_check():
    # Liveness probe: always answers with a fixed payload and checks nothing else.
    status_payload = {"status": "healthy"}
    return status_payload
# When executed as a script, serve the app with uvicorn on every interface,
# port 8000. uvicorn is imported here so importing this module does not need it.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)