Spaces:
Sleeping
Sleeping
Update main.py
Browse files
main.py
CHANGED
@@ -17,7 +17,7 @@ from typing import Optional, Annotated
|
|
17 |
from langchain_core.runnables import RunnableConfig
|
18 |
from langgraph.prebuilt import InjectedState
|
19 |
from document_rag_router import router as document_rag_router
|
20 |
-
from document_rag_router import QueryInput, query_collection, SearchResult
|
21 |
from fastapi import HTTPException
|
22 |
import requests
|
23 |
from sse_starlette.sse import EventSourceResponse
|
@@ -136,13 +136,40 @@ model = ChatOpenAI(model="gpt-4o-mini", streaming=True)
|
|
136 |
|
137 |
# Create a prompt template for formatting
|
138 |
prompt = ChatPromptTemplate.from_messages([
|
139 |
-
("system", "You are a helpful AI assistant.
|
140 |
("placeholder", "{messages}"),
|
141 |
])
|
142 |
|
143 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
144 |
return prompt.invoke({
|
145 |
-
"
|
146 |
"messages": state["messages"]
|
147 |
})
|
148 |
|
|
|
17 |
from langchain_core.runnables import RunnableConfig
|
18 |
from langgraph.prebuilt import InjectedState
|
19 |
from document_rag_router import router as document_rag_router
|
20 |
+
from document_rag_router import QueryInput, query_collection, SearchResult,db
|
21 |
from fastapi import HTTPException
|
22 |
import requests
|
23 |
from sse_starlette.sse import EventSourceResponse
|
|
|
136 |
|
137 |
# Create a prompt template for formatting
|
138 |
prompt = ChatPromptTemplate.from_messages([
|
139 |
+
("system", "You are a helpful AI assistant. The current collection contains the following files: {collection_files}"),
|
140 |
("placeholder", "{messages}"),
|
141 |
])
|
142 |
|
143 |
+
async def get_collection_files(collection_id: str, user_id: str) -> str:
|
144 |
+
"""Get list of files in the specified collection"""
|
145 |
+
try:
|
146 |
+
# Get the full collection name
|
147 |
+
collection_name = f"{user_id}_{collection_id}"
|
148 |
+
|
149 |
+
# Open the table and convert to pandas
|
150 |
+
table = db.open_table(collection_name)
|
151 |
+
df = table.to_pandas()
|
152 |
+
|
153 |
+
# Get unique file names
|
154 |
+
unique_files = df['file_name'].unique()
|
155 |
+
|
156 |
+
# Join the file names into a string
|
157 |
+
return ", ".join(unique_files)
|
158 |
+
except Exception as e:
|
159 |
+
logging.error(f"Error getting collection files: {str(e)}")
|
160 |
+
return f"Error getting files: {str(e)}"
|
161 |
+
|
162 |
+
async def format_for_model(state):
|
163 |
+
# Get collection_id and user_id from the state's configurable
|
164 |
+
config = state.get("configurable", {})
|
165 |
+
collection_id = config.get("collection_id")
|
166 |
+
user_id = config.get("user_id")
|
167 |
+
|
168 |
+
# Get files in the collection
|
169 |
+
collection_files = await get_collection_files(collection_id, user_id) if collection_id and user_id else "No files available"
|
170 |
+
|
171 |
return prompt.invoke({
|
172 |
+
"collection_files": collection_files,
|
173 |
"messages": state["messages"]
|
174 |
})
|
175 |
|