# rag_chat_with_analytics / rag_routerv2.py
# Last commit 131998f by pvanand: "use sqlite to replace tables.json"
from fastapi import FastAPI, Depends, HTTPException, UploadFile, File
import pandas as pd
import lancedb
from functools import cached_property, lru_cache
from pydantic import Field, BaseModel
from typing import Optional, Dict, List, Annotated, Any
from fastapi import APIRouter
import uuid
import io
from io import BytesIO
import csv
import sqlite3
# LlamaIndex imports
from llama_index.core import Settings, SimpleDirectoryReader, VectorStoreIndex
from llama_index.vector_stores.lancedb import LanceDBVectorStore
from llama_index.embeddings.fastembed import FastEmbedEmbedding
from llama_index.core.schema import TextNode
from llama_index.core import StorageContext, load_index_from_storage
import json
import os
import shutil
# Router for this module: every route declared on it is served under the
# "/rag" prefix and carries the "rag" tag.
router = APIRouter(prefix="/rag", tags=["rag"])

# Configure global LlamaIndex settings: embeddings are produced by the
# FastEmbed "BAAI/bge-small-en-v1.5" model.
Settings.embed_model = FastEmbedEmbedding(model_name="BAAI/bge-small-en-v1.5")
# Database connection dependency
@lru_cache(maxsize=128)
def get_db_connection(db_path: str = "./lancedb/dev"):
    """Return a LanceDB connection for ``db_path``.

    Results are memoised by ``lru_cache``, so repeated calls with the same
    argument reuse the connection object created by the first call.
    """
    connection = lancedb.connect(db_path)
    return connection
def get_db():
    """Open a new SQLite connection to the table-metadata database.

    The parent directory of the database file is created on demand:
    ``sqlite3.connect`` creates a missing database *file*, but raises
    ``sqlite3.OperationalError`` when the *directory* is missing, which
    would otherwise break the very first start-up on a clean checkout.

    Returns:
        A ``sqlite3.Connection`` whose ``row_factory`` is ``sqlite3.Row``,
        so rows can be read by column name. Every call opens a fresh
        connection; the caller is responsible for closing it.
    """
    db_path = './data/tables.db'
    os.makedirs(os.path.dirname(db_path), exist_ok=True)
    conn = sqlite3.connect(db_path)
    conn.row_factory = sqlite3.Row
    return conn
def init_db():
    """Create the SQLite metadata tables if they do not exist yet.

    Two tables are created:

    * ``tables``      - one row per knowledge base (user_id, table_id,
      table_name).
    * ``table_files`` - one row per file belonging to a knowledge base.

    The connection obtained from ``get_db()`` is always closed, even when
    a statement fails.
    """
    db = get_db()
    try:
        db.execute('''
            CREATE TABLE IF NOT EXISTS tables (
                id INTEGER PRIMARY KEY,
                user_id TEXT NOT NULL,
                table_id TEXT NOT NULL,
                table_name TEXT NOT NULL
            )
        ''')
        # NOTE(review): tables.table_id is neither PRIMARY KEY nor UNIQUE,
        # and nothing visible here enables `PRAGMA foreign_keys`, so this
        # FOREIGN KEY is presumably informational only - confirm before
        # relying on it for integrity.
        db.execute('''
            CREATE TABLE IF NOT EXISTS table_files (
                id INTEGER PRIMARY KEY,
                table_id TEXT NOT NULL,
                filename TEXT NOT NULL,
                file_path TEXT NOT NULL,
                FOREIGN KEY (table_id) REFERENCES tables (table_id)
            )
        ''')
        db.commit()
    finally:
        db.close()
# Pydantic models

# Response body of POST /rag/create_table.
# (Documented with comments rather than a class docstring so the generated
# model schema stays exactly as it was.)
class CreateTableResponse(BaseModel):
    table_id: str      # identifier of the created table
    message: str       # human-readable outcome text
    status: str        # outcome marker; the endpoint sends "success"
    table_name: str    # display name stored for the table
# Response body of POST /rag/query_table/{table_id}.
# (Documented with comments rather than a class docstring so the generated
# model schema stays exactly as it was.)
class QueryTableResponse(BaseModel):
    results: Dict[str, Any]    # the endpoint fills this as {'data': [...]}
    total_results: int         # number of entries in results['data']
def _single_path_component(value: Optional[str]) -> str:
    """Return ``value`` reduced to one safe path component, or "" if none."""
    if not value:
        return ""
    # Treat back-slashes as separators too, so Windows-style client paths
    # such as "..\\..\\x.txt" are reduced to their final component as well.
    name = os.path.basename(value.replace("\\", "/"))
    return "" if name in {".", ".."} else name


@router.post("/create_table", response_model=CreateTableResponse)
async def create_embedding_table(
    user_id: str,
    files: List[UploadFile] = File(...),
    table_id: Optional[str] = None,
    table_name: Optional[str] = None
) -> CreateTableResponse:
    """Store uploaded files, index them and register the table in SQLite.

    Args:
        user_id: Owner recorded for the new table.
        files: Uploads to index. Allowed extensions: .pdf, .docx, .csv,
            .txt, .md.
        table_id: Optional identifier; a UUID4 is generated when omitted.
            It must be a single path component because it names the
            storage directories.
        table_name: Optional display name; defaults to
            ``knowledge-base-<4 hex chars>``.

    Returns:
        CreateTableResponse carrying the table id and name.

    Raises:
        HTTPException: 400 for a missing/unsafe filename, an unsupported
            extension or an unsafe ``table_id``; 500 when indexing or the
            metadata insert fails.
    """
    allowed_extensions = {".pdf", ".docx", ".csv", ".txt", ".md"}

    # Client-supplied filenames are untrusted: keep only the final path
    # component so an upload can never be written outside the table folder.
    safe_names = []
    for file in files:
        safe_name = _single_path_component(file.filename)
        if not safe_name:
            raise HTTPException(status_code=400, detail="Invalid filename")
        if os.path.splitext(safe_name)[1].lower() not in allowed_extensions:
            raise HTTPException(status_code=400, detail="Unsupported file type")
        safe_names.append(safe_name)

    table_id = table_id or str(uuid.uuid4())
    # table_id becomes part of ./data/<table_id> and ./lancedb/index/<table_id>.
    if _single_path_component(table_id) != table_id:
        raise HTTPException(status_code=400, detail="Invalid table_id")
    table_name = table_name or f"knowledge-base-{str(uuid.uuid4())[:4]}"

    directory_path = f"./data/{table_id}"
    os.makedirs(directory_path, exist_ok=True)
    for file, safe_name in zip(files, safe_names):
        with open(os.path.join(directory_path, safe_name), "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

    try:
        vector_store = LanceDBVectorStore(
            uri="./lancedb/dev",
            table_name=table_id,
            mode="overwrite",
            query_type="hybrid"
        )
        documents = SimpleDirectoryReader(directory_path).load_data()
        # NOTE(review): vector_store is passed as a plain keyword rather than
        # through a StorageContext - confirm with the installed LlamaIndex
        # version that the LanceDB store is actually used by the index.
        index = VectorStoreIndex.from_documents(documents, vector_store=vector_store)
        index.storage_context.persist(persist_dir=f"./lancedb/index/{table_id}")

        db = get_db()
        try:
            db.execute(
                'INSERT INTO tables (user_id, table_id, table_name) VALUES (?, ?, ?)',
                (user_id, table_id, table_name)
            )
            db.executemany(
                'INSERT INTO table_files (table_id, filename, file_path) VALUES (?, ?, ?)',
                [
                    (table_id, safe_name, f"./data/{table_id}/{safe_name}")
                    for safe_name in safe_names
                ]
            )
            db.commit()
        finally:
            # Closing without a commit discards a half-written registration.
            db.close()

        return CreateTableResponse(
            table_id=table_id,
            message="Success",
            status="success",
            table_name=table_name
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.post("/query_table/{table_id}", response_model=QueryTableResponse)
async def query_table(
    table_id: str,
    query: str,
    user_id: str,
    limit: Optional[int] = 10
) -> QueryTableResponse:
    """Query the database table using LlamaIndex.

    Args:
        table_id: Identifier of the table whose persisted index is loaded
            from ``./lancedb/index/<table_id>``.
        query: Text to retrieve similar nodes for.
        user_id: Caller identifier. Accepted for interface compatibility;
            it is not used to filter or authorise the lookup here.
        limit: Maximum number of nodes to return; ``None`` falls back to 10.

    Returns:
        QueryTableResponse whose ``results['data']`` is a list of
        ``{'text', 'score'}`` dicts.

    Raises:
        HTTPException: 400 for an unsafe ``table_id``, 404 when no index
            directory exists for it, 500 when loading or retrieval fails.
    """
    # table_id ends up inside a filesystem path, so it must be exactly one
    # path component (no separators, no "." / "..").
    if (
        not table_id
        or table_id in {".", ".."}
        or os.path.basename(table_id.replace("\\", "/")) != table_id
    ):
        raise HTTPException(status_code=400, detail="Invalid table_id")

    persist_dir = f"./lancedb/index/{table_id}"
    if not os.path.isdir(persist_dir):
        raise HTTPException(status_code=404, detail=f"Table not found: {table_id}")

    # An explicit null must not reach the retriever as similarity_top_k=None.
    top_k = 10 if limit is None else limit

    try:
        # load index and retriever
        storage_context = StorageContext.from_defaults(persist_dir=persist_dir)
        index = load_index_from_storage(storage_context)
        retriever = index.as_retriever(similarity_top_k=top_k)

        # Get response
        response = retriever.retrieve(query)

        # Format results
        results = [{
            'text': node.text,
            'score': node.score
        } for node in response]

        return QueryTableResponse(
            results={'data': results},
            total_results=len(results)
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Query failed: {str(e)}") from e
@router.get("/get_tables/{user_id}")
async def get_tables(user_id: str):
    """List the tables registered for ``user_id`` together with their files.

    Files are fetched as individual rows instead of being packed with
    ``GROUP_CONCAT`` and split on commas, so names or paths that contain a
    comma survive intact, and a table without files yields ``'files': []``
    with no stray ``filenames`` / ``file_paths`` keys.

    Args:
        user_id: Owner whose tables are listed.

    Returns:
        A list of dicts, one per ``table_id``, holding the columns of the
        ``tables`` row plus ``files``: a list of
        ``{'filename', 'file_path'}`` dicts.
    """
    db = get_db()
    try:
        # GROUP BY keeps one entry per table_id, as before.
        tables = db.execute(
            'SELECT * FROM tables WHERE user_id = ? GROUP BY table_id',
            (user_id,)
        ).fetchall()
        file_rows = db.execute('''
            SELECT table_id, filename, file_path
            FROM table_files
            WHERE table_id IN (SELECT table_id FROM tables WHERE user_id = ?)
            ORDER BY id
        ''', (user_id,)).fetchall()
    finally:
        db.close()

    files_by_table = {}
    for row in file_rows:
        files_by_table.setdefault(row['table_id'], []).append(
            {'filename': row['filename'], 'file_path': row['file_path']}
        )

    result = []
    for table in tables:
        table_dict = dict(table)
        table_dict['files'] = files_by_table.get(table_dict['table_id'], [])
        result.append(table_dict)
    return result
@router.get("/health")
async def health_check():
    # Liveness probe: always answers with the same static payload.
    # (A comment, not a docstring, so the published route description
    # stays unchanged.)
    payload = {"status": "healthy"}
    return payload
@router.on_event("startup")
async def startup():
    """Initialise the metadata DB and (re)build the built-in "digiyatra" table.

    On every start the function creates the SQLite tables if needed, builds
    an index from ``combined_digi_yatra.csv`` (one ``TextNode`` per row
    after the first), persists it under ``./lancedb/index/digiyatra`` and
    registers the table in SQLite. Registration is idempotent: the rows
    written by a previous start are replaced rather than duplicated.

    Raises:
        OSError: if ``combined_digi_yatra.csv`` cannot be opened.
    """
    init_db()
    print("RAG Router started")

    table_name = "digiyatra"
    user_id = "digiyatra"
    csv_path = 'combined_digi_yatra.csv'

    # Create vector store and index
    vector_store = LanceDBVectorStore(
        uri="./lancedb/dev",
        table_name=table_name,
        mode="overwrite",
        query_type="hybrid"
    )

    # Load CSV and create nodes
    with open(csv_path, newline='') as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the first row (presumably the header)
        nodes = [
            TextNode(text=str(row), id_=str(uuid.uuid4()))
            for row in reader
        ]

    # Create and persist index
    # NOTE(review): vector_store is passed as a plain keyword rather than
    # through a StorageContext - confirm with the installed LlamaIndex
    # version that the LanceDB store is actually used by the index.
    index = VectorStoreIndex(nodes, vector_store=vector_store)
    index.storage_context.persist(persist_dir=f"./lancedb/index/{table_name}")

    # Store in SQLite
    db = get_db()
    try:
        # This hook runs on every boot; remove exactly the rows a previous
        # boot inserted so the registration does not pile up duplicates.
        db.execute(
            'DELETE FROM table_files WHERE table_id = ? AND filename = ? AND file_path = ?',
            (table_name, csv_path, csv_path)
        )
        db.execute(
            'DELETE FROM tables WHERE user_id = ? AND table_id = ? AND table_name = ?',
            (user_id, table_name, table_name)
        )
        db.execute(
            'INSERT INTO tables (user_id, table_id, table_name) VALUES (?, ?, ?)',
            (user_id, table_name, table_name)
        )
        db.execute(
            'INSERT INTO table_files (table_id, filename, file_path) VALUES (?, ?, ?)',
            (table_name, csv_path, csv_path)
        )
        db.commit()
    finally:
        db.close()
@router.on_event("shutdown")
async def shutdown():
    """Announce on stdout that the RAG router is shutting down."""
    farewell = "RAG Router shutdown"
    print(farewell)