pvanand commited on
Commit
37d6123
·
verified ·
1 Parent(s): 58c3ea2

Update rag_routerv2.py

Browse files
Files changed (1) hide show
  1. rag_routerv2.py +26 -16
rag_routerv2.py CHANGED
@@ -12,9 +12,9 @@ import csv
12
 
13
  # LlamaIndex imports
14
  from llama_index.core import Settings, SimpleDirectoryReader, VectorStoreIndex
15
- from llama_index.core.schema import TextNode
16
  from llama_index.vector_stores.lancedb import LanceDBVectorStore
17
  from llama_index.embeddings.fastembed import FastEmbedEmbedding
 
18
  from llama_index.core import StorageContext, load_index_from_storage
19
  import json
20
  import os
@@ -39,6 +39,7 @@ class CreateTableResponse(BaseModel):
39
  table_id: str
40
  message: str
41
  status: str
 
42
 
43
  class QueryTableResponse(BaseModel):
44
  results: Dict[str, Any]
@@ -49,7 +50,8 @@ class QueryTableResponse(BaseModel):
49
  async def create_embedding_table(
50
  user_id: str,
51
  files: List[UploadFile] = File(...),
52
- table_id: Optional[str] = None
 
53
  ) -> CreateTableResponse:
54
  """Create a table and load embeddings from uploaded files using LlamaIndex."""
55
  allowed_extensions = {".pdf", ".docx", ".csv", ".txt", ".md"}
@@ -65,10 +67,10 @@ async def create_embedding_table(
65
 
66
  if table_id is None:
67
  table_id = str(uuid.uuid4())
68
- table_name = table_id #f"{user_id}__table__{table_id}"
69
 
70
  # Create a directory for the uploaded files
71
- directory_path = f"./data/{table_name}"
72
  os.makedirs(directory_path, exist_ok=True)
73
 
74
  # Save each uploaded file to the data directory
@@ -91,8 +93,8 @@ async def create_embedding_table(
91
  # Update the tables dictionary
92
  if user_id not in tables:
93
  tables[user_id] = []
94
- if table_name not in tables[user_id]:
95
- tables[user_id].append(table_name)
96
 
97
  # Write the updated tables back to the JSON file
98
  with open(tables_file_path, 'w') as f:
@@ -104,7 +106,7 @@ async def create_embedding_table(
104
  # Setup LanceDB vector store
105
  vector_store = LanceDBVectorStore(
106
  uri="./lancedb/dev",
107
- table_name=table_name,
108
  mode="overwrite",
109
  query_type="hybrid"
110
  )
@@ -117,12 +119,13 @@ async def create_embedding_table(
117
  documents,
118
  vector_store=vector_store
119
  )
120
- index.storage_context.persist(persist_dir=f"./lancedb/index/{table_name}")
121
 
122
  return CreateTableResponse(
123
  table_id=table_id,
124
  message=f"Table created and documents indexed successfully",
125
- status="success"
 
126
  )
127
 
128
  except Exception as e:
@@ -181,6 +184,10 @@ async def get_tables(user_id: str):
181
  except Exception as e:
182
  raise HTTPException(status_code=500, detail=f"Failed to retrieve tables: {str(e)}")
183
 
 
 
 
 
184
  @router.on_event("startup")
185
  async def startup():
186
  print("RAG Router started")
@@ -205,16 +212,19 @@ async def startup():
205
  index = VectorStoreIndex(nodes, vector_store=vector_store)
206
  index.storage_context.persist(persist_dir=f"./lancedb/index/{table_name}")
207
 
208
-
209
  # Create tables dictionary
210
  tables = {}
211
  user_id = "digiyatra"
212
-
213
- tables[user_id] = [table_name]
 
 
 
 
 
214
  with open(tables_file_path, 'w') as f:
215
  json.dump(tables, f)
216
 
217
-
218
- @router.get("/health")
219
- async def health_check():
220
- return {"status": "healthy"}
 
12
 
13
  # LlamaIndex imports
14
  from llama_index.core import Settings, SimpleDirectoryReader, VectorStoreIndex
 
15
  from llama_index.vector_stores.lancedb import LanceDBVectorStore
16
  from llama_index.embeddings.fastembed import FastEmbedEmbedding
17
+ from llama_index.core.schema import TextNode
18
  from llama_index.core import StorageContext, load_index_from_storage
19
  import json
20
  import os
 
39
  table_id: str
40
  message: str
41
  status: str
42
+ table_name: str
43
 
44
  class QueryTableResponse(BaseModel):
45
  results: Dict[str, Any]
 
50
  async def create_embedding_table(
51
  user_id: str,
52
  files: List[UploadFile] = File(...),
53
+ table_id: Optional[str] = None,
54
+ table_name: Optional[str] = None
55
  ) -> CreateTableResponse:
56
  """Create a table and load embeddings from uploaded files using LlamaIndex."""
57
  allowed_extensions = {".pdf", ".docx", ".csv", ".txt", ".md"}
 
67
 
68
  if table_id is None:
69
  table_id = str(uuid.uuid4())
70
+ #table_name = table_id #f"{user_id}__table__{table_id}"
71
 
72
  # Create a directory for the uploaded files
73
+ directory_path = f"./data/{table_id}"
74
  os.makedirs(directory_path, exist_ok=True)
75
 
76
  # Save each uploaded file to the data directory
 
93
  # Update the tables dictionary
94
  if user_id not in tables:
95
  tables[user_id] = []
96
+ if table_id not in [table['table_id'] for table in tables[user_id]]:
97
+ tables[user_id].append({"table_id": table_id, "table_name": table_name})
98
 
99
  # Write the updated tables back to the JSON file
100
  with open(tables_file_path, 'w') as f:
 
106
  # Setup LanceDB vector store
107
  vector_store = LanceDBVectorStore(
108
  uri="./lancedb/dev",
109
+ table_name=table_id,
110
  mode="overwrite",
111
  query_type="hybrid"
112
  )
 
119
  documents,
120
  vector_store=vector_store
121
  )
122
+ index.storage_context.persist(persist_dir=f"./lancedb/index/{table_id}")
123
 
124
  return CreateTableResponse(
125
  table_id=table_id,
126
  message=f"Table created and documents indexed successfully",
127
+ status="success",
128
+ table_name="" if table_name is None else table_name
129
  )
130
 
131
  except Exception as e:
 
184
  except Exception as e:
185
  raise HTTPException(status_code=500, detail=f"Failed to retrieve tables: {str(e)}")
186
 
187
+ @router.get("/health")
188
+ async def health_check():
189
+ return {"status": "healthy"}
190
+
191
  @router.on_event("startup")
192
  async def startup():
193
  print("RAG Router started")
 
212
  index = VectorStoreIndex(nodes, vector_store=vector_store)
213
  index.storage_context.persist(persist_dir=f"./lancedb/index/{table_name}")
214
 
 
215
  # Create tables dictionary
216
  tables = {}
217
  user_id = "digiyatra"
218
+
219
+ tables[user_id] = [
220
+ {
221
+ "table_id": table_name,
222
+ "table_name": table_name
223
+ }
224
+ ]
225
  with open(tables_file_path, 'w') as f:
226
  json.dump(tables, f)
227
 
228
+ @router.on_event("shutdown")
229
+ async def shutdown():
230
+ print("RAG Router shutdown")