pvanand commited on
Commit
c6774a0
·
verified ·
1 Parent(s): 1f6995c

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +31 -4
main.py CHANGED
@@ -17,7 +17,7 @@ from typing import Optional, Annotated
17
  from langchain_core.runnables import RunnableConfig
18
  from langgraph.prebuilt import InjectedState
19
  from document_rag_router import router as document_rag_router
20
- from document_rag_router import QueryInput, query_collection, SearchResult
21
  from fastapi import HTTPException
22
  import requests
23
  from sse_starlette.sse import EventSourceResponse
@@ -136,13 +136,40 @@ model = ChatOpenAI(model="gpt-4o-mini", streaming=True)
136
 
137
  # Create a prompt template for formatting
138
  prompt = ChatPromptTemplate.from_messages([
139
- ("system", "You are a helpful AI assistant. Current directory contains: {current_files}"),
140
  ("placeholder", "{messages}"),
141
  ])
142
 
143
- def format_for_model(state):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  return prompt.invoke({
145
- "current_files": get_current_files(),
146
  "messages": state["messages"]
147
  })
148
 
 
17
  from langchain_core.runnables import RunnableConfig
18
  from langgraph.prebuilt import InjectedState
19
  from document_rag_router import router as document_rag_router
20
+ from document_rag_router import QueryInput, query_collection, SearchResult,db
21
  from fastapi import HTTPException
22
  import requests
23
  from sse_starlette.sse import EventSourceResponse
 
136
 
137
  # Create a prompt template for formatting
138
  prompt = ChatPromptTemplate.from_messages([
139
+ ("system", "You are a helpful AI assistant. The current collection contains the following files: {collection_files}"),
140
  ("placeholder", "{messages}"),
141
  ])
142
 
143
+ async def get_collection_files(collection_id: str, user_id: str) -> str:
144
+ """Get list of files in the specified collection"""
145
+ try:
146
+ # Get the full collection name
147
+ collection_name = f"{user_id}_{collection_id}"
148
+
149
+ # Open the table and convert to pandas
150
+ table = db.open_table(collection_name)
151
+ df = table.to_pandas()
152
+
153
+ # Get unique file names
154
+ unique_files = df['file_name'].unique()
155
+
156
+ # Join the file names into a string
157
+ return ", ".join(unique_files)
158
+ except Exception as e:
159
+ logging.error(f"Error getting collection files: {str(e)}")
160
+ return f"Error getting files: {str(e)}"
161
+
162
+ async def format_for_model(state):
163
+ # Get collection_id and user_id from the state's configurable
164
+ config = state.get("configurable", {})
165
+ collection_id = config.get("collection_id")
166
+ user_id = config.get("user_id")
167
+
168
+ # Get files in the collection
169
+ collection_files = await get_collection_files(collection_id, user_id) if collection_id and user_id else "No files available"
170
+
171
  return prompt.invoke({
172
+ "collection_files": collection_files,
173
  "messages": state["messages"]
174
  })
175