pvanand commited on
Commit
c73bfb4
·
verified ·
1 Parent(s): e9aadec

Update rag_routerv2.py

Browse files
Files changed (1) hide show
  1. rag_routerv2.py +216 -184
rag_routerv2.py CHANGED
@@ -1,185 +1,217 @@
1
- from fastapi import FastAPI, Depends, HTTPException, UploadFile, File
2
- import pandas as pd
3
- import lancedb
4
- from functools import cached_property, lru_cache
5
- from pydantic import Field, BaseModel
6
- from typing import Optional, Dict, List, Annotated, Any
7
- from fastapi import APIRouter
8
- import uuid
9
- import io
10
- from io import BytesIO
11
- import csv
12
-
13
- # LlamaIndex imports
14
- from llama_index.core import Settings, SimpleDirectoryReader, VectorStoreIndex
15
- from llama_index.vector_stores.lancedb import LanceDBVectorStore
16
- from llama_index.embeddings.fastembed import FastEmbedEmbedding
17
- from llama_index.core import StorageContext, load_index_from_storage
18
- import json
19
- import os
20
- import shutil
21
-
22
- router = APIRouter(
23
- prefix="/rag",
24
- tags=["rag"]
25
- )
26
-
27
- # Configure global LlamaIndex settings
28
- Settings.embed_model = FastEmbedEmbedding(model_name="BAAI/bge-small-en-v1.5")
29
- tables_file_path = './data/tables.json'
30
-
31
- # Database connection dependency
32
- @lru_cache()
33
- def get_db_connection(db_path: str = "./lancedb/dev"):
34
- return lancedb.connect(db_path)
35
-
36
- # Pydantic models
37
- class CreateTableResponse(BaseModel):
38
- table_id: str
39
- message: str
40
- status: str
41
-
42
- class QueryTableResponse(BaseModel):
43
- results: Dict[str, Any]
44
- total_results: int
45
-
46
-
47
- @router.post("/create_table", response_model=CreateTableResponse)
48
- async def create_embedding_table(
49
- user_id: str,
50
- files: List[UploadFile] = File(...),
51
- table_id: Optional[str] = None
52
- ) -> CreateTableResponse:
53
- """Create a table and load embeddings from uploaded files using LlamaIndex."""
54
- allowed_extensions = {".pdf", ".docx", ".csv", ".txt", ".md"}
55
- for file in files:
56
- if file.filename is None:
57
- raise HTTPException(status_code=400, detail="File must have a valid name.")
58
- file_extension = os.path.splitext(file.filename)[1].lower()
59
- if file_extension not in allowed_extensions:
60
- raise HTTPException(
61
- status_code=400,
62
- detail=f"File type {file_extension} is not allowed. Supported file types are: {', '.join(allowed_extensions)}."
63
- )
64
-
65
- if table_id is None:
66
- table_id = str(uuid.uuid4())
67
- table_name = table_id #f"{user_id}__table__{table_id}"
68
-
69
- # Create a directory for the uploaded files
70
- directory_path = f"./data/{table_name}"
71
- os.makedirs(directory_path, exist_ok=True)
72
-
73
- # Save each uploaded file to the data directory
74
- for file in files:
75
- file_path = os.path.join(directory_path, file.filename)
76
- with open(file_path, "wb") as buffer:
77
- shutil.copyfileobj(file.file, buffer)
78
-
79
- # Store user_id and table_name in a JSON file
80
- try:
81
- tables_file_path = './data/tables.json'
82
- os.makedirs(os.path.dirname(tables_file_path), exist_ok=True)
83
- # Load existing tables or create a new file if it doesn't exist
84
- try:
85
- with open(tables_file_path, 'r') as f:
86
- tables = json.load(f)
87
- except (FileNotFoundError, json.JSONDecodeError):
88
- tables = {}
89
-
90
- # Update the tables dictionary
91
- if user_id not in tables:
92
- tables[user_id] = []
93
- if table_name not in tables[user_id]:
94
- tables[user_id].append(table_name)
95
-
96
- # Write the updated tables back to the JSON file
97
- with open(tables_file_path, 'w') as f:
98
- json.dump(tables, f)
99
-
100
- except Exception as e:
101
- raise HTTPException(status_code=500, detail=f"Failed to update tables file: {str(e)}")
102
- try:
103
- # Setup LanceDB vector store
104
- vector_store = LanceDBVectorStore(
105
- uri="./lancedb/dev",
106
- table_name=table_name,
107
- # mode="overwrite",
108
- # query_type="vector"
109
- )
110
-
111
- # Load documents using SimpleDirectoryReader
112
- documents = SimpleDirectoryReader(directory_path).load_data()
113
-
114
- # Create the index
115
- index = VectorStoreIndex.from_documents(
116
- documents,
117
- vector_store=vector_store
118
- )
119
- index.storage_context.persist(persist_dir=f"./lancedb/index/{table_name}")
120
-
121
- return CreateTableResponse(
122
- table_id=table_id,
123
- message=f"Table created and documents indexed successfully",
124
- status="success"
125
- )
126
-
127
- except Exception as e:
128
- raise HTTPException(status_code=500, detail=f"Table creation failed: {str(e)}")
129
-
130
- @router.post("/query_table/{table_id}", response_model=QueryTableResponse)
131
- async def query_table(
132
- table_id: str,
133
- query: str,
134
- user_id: str,
135
- #db: Annotated[Any, Depends(get_db_connection)],
136
- limit: Optional[int] = 10
137
- ) -> QueryTableResponse:
138
- """Query the database table using LlamaIndex."""
139
- try:
140
- table_name = table_id #f"{user_id}__table__{table_id}"
141
-
142
- # load index and retriever
143
- storage_context = StorageContext.from_defaults(persist_dir=f"./lancedb/index/{table_name}")
144
- index = load_index_from_storage(storage_context)
145
- retriever = index.as_retriever(similarity_top_k=limit)
146
-
147
- # Get response
148
- response = retriever.retrieve(query)
149
-
150
- # Format results
151
- results = [{
152
- 'text': node.text,
153
- 'score': node.score
154
- } for node in response]
155
-
156
- return QueryTableResponse(
157
- results={'data': results},
158
- total_results=len(results)
159
- )
160
-
161
- except Exception as e:
162
- raise HTTPException(status_code=500, detail=f"Query failed: {str(e)}")
163
-
164
- @router.get("/get_tables/{user_id}")
165
- async def get_tables(user_id: str):
166
- """Get all tables for a user."""
167
-
168
- tables_file_path = './data/tables.json'
169
- try:
170
- # Load existing tables from the JSON file
171
- with open(tables_file_path, 'r') as f:
172
- tables = json.load(f)
173
-
174
- # Retrieve tables for the specified user
175
- user_tables = tables.get(user_id, [])
176
- return user_tables
177
-
178
- except (FileNotFoundError, json.JSONDecodeError):
179
- return [] # Return an empty list if the file doesn't exist or is invalid
180
- except Exception as e:
181
- raise HTTPException(status_code=500, detail=f"Failed to retrieve tables: {str(e)}")
182
-
183
- @router.get("/health")
184
- async def health_check():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  return {"status": "healthy"}
 
1
+ from fastapi import FastAPI, Depends, HTTPException, UploadFile, File
2
+ import pandas as pd
3
+ import lancedb
4
+ from functools import cached_property, lru_cache
5
+ from pydantic import Field, BaseModel
6
+ from typing import Optional, Dict, List, Annotated, Any
7
+ from fastapi import APIRouter
8
+ import uuid
9
+ import io
10
+ from io import BytesIO
11
+ import csv
12
+
13
+ # LlamaIndex imports
14
+ from llama_index.core import Settings, SimpleDirectoryReader, VectorStoreIndex
15
+ from llama_index.vector_stores.lancedb import LanceDBVectorStore
16
+ from llama_index.embeddings.fastembed import FastEmbedEmbedding
17
+ from llama_index.core import StorageContext, load_index_from_storage
18
+ import json
19
+ import os
20
+ import shutil
21
+
22
+ router = APIRouter(
23
+ prefix="/rag",
24
+ tags=["rag"]
25
+ )
26
+
27
+ # Configure global LlamaIndex settings
28
+ Settings.embed_model = FastEmbedEmbedding(model_name="BAAI/bge-small-en-v1.5")
29
+ tables_file_path = './data/tables.json'
30
+
31
+ # Database connection dependency
32
+ @lru_cache()
33
+ def get_db_connection(db_path: str = "./lancedb/dev"):
34
+ return lancedb.connect(db_path)
35
+
36
+ # Pydantic models
37
+ class CreateTableResponse(BaseModel):
38
+ table_id: str
39
+ message: str
40
+ status: str
41
+
42
+ class QueryTableResponse(BaseModel):
43
+ results: Dict[str, Any]
44
+ total_results: int
45
+
46
+
47
+ @router.post("/create_table", response_model=CreateTableResponse)
48
+ async def create_embedding_table(
49
+ user_id: str,
50
+ files: List[UploadFile] = File(...),
51
+ table_id: Optional[str] = None
52
+ ) -> CreateTableResponse:
53
+ """Create a table and load embeddings from uploaded files using LlamaIndex."""
54
+ allowed_extensions = {".pdf", ".docx", ".csv", ".txt", ".md"}
55
+ for file in files:
56
+ if file.filename is None:
57
+ raise HTTPException(status_code=400, detail="File must have a valid name.")
58
+ file_extension = os.path.splitext(file.filename)[1].lower()
59
+ if file_extension not in allowed_extensions:
60
+ raise HTTPException(
61
+ status_code=400,
62
+ detail=f"File type {file_extension} is not allowed. Supported file types are: {', '.join(allowed_extensions)}."
63
+ )
64
+
65
+ if table_id is None:
66
+ table_id = str(uuid.uuid4())
67
+ table_name = table_id #f"{user_id}__table__{table_id}"
68
+
69
+ # Create a directory for the uploaded files
70
+ directory_path = f"./data/{table_name}"
71
+ os.makedirs(directory_path, exist_ok=True)
72
+
73
+ # Save each uploaded file to the data directory
74
+ for file in files:
75
+ file_path = os.path.join(directory_path, file.filename)
76
+ with open(file_path, "wb") as buffer:
77
+ shutil.copyfileobj(file.file, buffer)
78
+
79
+ # Store user_id and table_name in a JSON file
80
+ try:
81
+ tables_file_path = './data/tables.json'
82
+ os.makedirs(os.path.dirname(tables_file_path), exist_ok=True)
83
+ # Load existing tables or create a new file if it doesn't exist
84
+ try:
85
+ with open(tables_file_path, 'r') as f:
86
+ tables = json.load(f)
87
+ except (FileNotFoundError, json.JSONDecodeError):
88
+ tables = {}
89
+
90
+ # Update the tables dictionary
91
+ if user_id not in tables:
92
+ tables[user_id] = []
93
+ if table_name not in tables[user_id]:
94
+ tables[user_id].append(table_name)
95
+
96
+ # Write the updated tables back to the JSON file
97
+ with open(tables_file_path, 'w') as f:
98
+ json.dump(tables, f)
99
+
100
+ except Exception as e:
101
+ raise HTTPException(status_code=500, detail=f"Failed to update tables file: {str(e)}")
102
+ try:
103
+ # Setup LanceDB vector store
104
+ vector_store = LanceDBVectorStore(
105
+ uri="./lancedb/dev",
106
+ table_name=table_name,
107
+ # mode="overwrite",
108
+ # query_type="vector"
109
+ )
110
+
111
+ # Load documents using SimpleDirectoryReader
112
+ documents = SimpleDirectoryReader(directory_path).load_data()
113
+
114
+ # Create the index
115
+ index = VectorStoreIndex.from_documents(
116
+ documents,
117
+ vector_store=vector_store
118
+ )
119
+ index.storage_context.persist(persist_dir=f"./lancedb/index/{table_name}")
120
+
121
+ return CreateTableResponse(
122
+ table_id=table_id,
123
+ message=f"Table created and documents indexed successfully",
124
+ status="success"
125
+ )
126
+
127
+ except Exception as e:
128
+ raise HTTPException(status_code=500, detail=f"Table creation failed: {str(e)}")
129
+
130
+ @router.post("/query_table/{table_id}", response_model=QueryTableResponse)
131
+ async def query_table(
132
+ table_id: str,
133
+ query: str,
134
+ user_id: str,
135
+ #db: Annotated[Any, Depends(get_db_connection)],
136
+ limit: Optional[int] = 10
137
+ ) -> QueryTableResponse:
138
+ """Query the database table using LlamaIndex."""
139
+ try:
140
+ table_name = table_id #f"{user_id}__table__{table_id}"
141
+
142
+ # load index and retriever
143
+ storage_context = StorageContext.from_defaults(persist_dir=f"./lancedb/index/{table_name}")
144
+ index = load_index_from_storage(storage_context)
145
+ retriever = index.as_retriever(similarity_top_k=limit)
146
+
147
+ # Get response
148
+ response = retriever.retrieve(query)
149
+
150
+ # Format results
151
+ results = [{
152
+ 'text': node.text,
153
+ 'score': node.score
154
+ } for node in response]
155
+
156
+ return QueryTableResponse(
157
+ results={'data': results},
158
+ total_results=len(results)
159
+ )
160
+
161
+ except Exception as e:
162
+ raise HTTPException(status_code=500, detail=f"Query failed: {str(e)}")
163
+
164
+ @router.get("/get_tables/{user_id}")
165
+ async def get_tables(user_id: str):
166
+ """Get all tables for a user."""
167
+
168
+ tables_file_path = './data/tables.json'
169
+ try:
170
+ # Load existing tables from the JSON file
171
+ with open(tables_file_path, 'r') as f:
172
+ tables = json.load(f)
173
+
174
+ # Retrieve tables for the specified user
175
+ user_tables = tables.get(user_id, [])
176
+ return user_tables
177
+
178
+ except (FileNotFoundError, json.JSONDecodeError):
179
+ return [] # Return an empty list if the file doesn't exist or is invalid
180
+ except Exception as e:
181
+ raise HTTPException(status_code=500, detail=f"Failed to retrieve tables: {str(e)}")
182
+
183
+ @router.on_event("startup")
184
+ async def startup():
185
+ print("RAG Router started")
186
+ from llama_index.core.schema import TextNode
187
+ table_name = "digiyatra"
188
+ vector_store = LanceDBVectorStore(
189
+ uri="./lancedb/dev",
190
+ table_name=table_name,
191
+ # mode="overwrite",
192
+ # query_type="vector"
193
+ )
194
+ # load digiyatra csv and create node for each row using csv.reader
195
+ with open("./data/digiyatra.csv", "r") as file:
196
+ reader = csv.reader(file)
197
+ nodes = []
198
+ for row in reader:
199
+ node = TextNode(text=row, id_=str(uuid.uuid4()))
200
+ nodes.append(node)
201
+
202
+ index = VectorStoreIndex(nodes, vector_store=vector_store)
203
+ index.storage_context.persist(persist_dir=f"./lancedb/index/{table_name}")
204
+
205
+
206
+ # Create tables dictionary
207
+ tables = {}
208
+ user_id = "digiyatra"
209
+
210
+ tables[user_id] = [table_name]
211
+ with open(tables_file_path, 'w') as f:
212
+ json.dump(tables, f)
213
+
214
+
215
+ @router.get("/health")
216
+ async def health_check():
217
  return {"status": "healthy"}