"""FastAPI router (prefix ``/rag``) for building and querying document indexes.

Table metadata (owner, table id, uploaded file names) is kept in a SQLite
database at ``./data/tablesv2.db``; uploaded files are stored under
``./data/<table_id>`` and LlamaIndex storage is persisted under
``./lancedb/index/<table_id>``.
"""

import csv
import io
import json
import os
import shutil
import sqlite3
import uuid
from contextlib import closing
from functools import cached_property, lru_cache
from io import BytesIO
from typing import Annotated, Any, Dict, List, Optional

import lancedb
import pandas as pd
from fastapi import APIRouter, Depends, FastAPI, File, HTTPException, UploadFile
from pydantic import BaseModel, Field

# LlamaIndex imports
from llama_index.core import (
    Settings,
    SimpleDirectoryReader,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.core.schema import TextNode
from llama_index.embeddings.fastembed import FastEmbedEmbedding
from llama_index.vector_stores.lancedb import LanceDBVectorStore

router = APIRouter(
    prefix="/rag",
    tags=["rag"]
)

# Configure global LlamaIndex settings
Settings.embed_model = FastEmbedEmbedding(model_name="BAAI/bge-small-en-v1.5")

# File extensions accepted by the upload endpoint.
ALLOWED_EXTENSIONS = frozenset({".pdf", ".docx", ".csv", ".txt", ".md"})


# Database connection dependency
@lru_cache()
def get_db_connection(db_path: str = "./lancedb/dev"):
    """Return a LanceDB connection for ``db_path`` (cached per path by lru_cache)."""
    return lancedb.connect(db_path)


def get_db():
    """Open a new SQLite connection to the metadata database.

    Rows are returned as ``sqlite3.Row`` so columns can be read by name.
    The caller owns the connection and is responsible for closing it.
    """
    conn = sqlite3.connect('./data/tablesv2.db')
    conn.row_factory = sqlite3.Row
    return conn


def init_db():
    """Create the ``tables`` and ``table_files`` tables if they do not exist."""
    # sqlite3.connect cannot create missing parent directories, so make sure
    # the data directory is there before the first connection is opened.
    os.makedirs('./data', exist_ok=True)
    with closing(get_db()) as db:
        db.execute('''
            CREATE TABLE IF NOT EXISTS tables (
                id INTEGER PRIMARY KEY,
                user_id TEXT NOT NULL,
                table_id TEXT NOT NULL,
                table_name TEXT NOT NULL,
                created_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        ''')
        db.execute('''
            CREATE TABLE IF NOT EXISTS table_files (
                id INTEGER PRIMARY KEY,
                table_id TEXT NOT NULL,
                filename TEXT NOT NULL,
                file_path TEXT NOT NULL,
                FOREIGN KEY (table_id) REFERENCES tables (table_id),
                UNIQUE(table_id, filename)
            )
        ''')
        db.commit()


def _require_safe_table_id(table_id: str) -> None:
    """Raise a 400 HTTPException if ``table_id`` is unsafe as a directory name.

    ``table_id`` comes from the client and is interpolated into filesystem
    paths, so path separators, NUL and the ``.``/``..`` entries are rejected.
    """
    if (
        not table_id
        or table_id in {".", ".."}
        or any(ch in table_id for ch in ("/", "\\", "\x00"))
    ):
        raise HTTPException(status_code=400, detail="Invalid table_id")


def _safe_upload_name(filename: Optional[str]) -> str:
    """Return the base name of an uploaded file after validating it.

    Raises a 400 HTTPException when the name is missing, resolves to nothing
    usable after directory components are stripped, or has an extension that
    is not in ``ALLOWED_EXTENSIONS``.
    """
    if not filename:
        raise HTTPException(status_code=400, detail="Invalid filename")
    # Strip any client-supplied directory part (either separator style) so the
    # file can only ever be written inside the table's own directory.
    name = os.path.basename(filename.replace("\\", "/"))
    if name in {"", ".", ".."} or "\x00" in name:
        raise HTTPException(status_code=400, detail="Invalid filename")
    if os.path.splitext(name)[1].lower() not in ALLOWED_EXTENSIONS:
        raise HTTPException(status_code=400, detail="Unsupported file type")
    return name


# Pydantic models
class CreateTableResponse(BaseModel):
    """Response body of ``POST /rag/create_table``."""

    table_id: str
    message: str
    status: str
    table_name: str


class QueryTableResponse(BaseModel):
    """Response body of ``POST /rag/query_table/{table_id}``."""

    results: Dict[str, Any]
    total_results: int


@router.post("/create_table", response_model=CreateTableResponse)
async def create_embedding_table(
    user_id: str,
    files: List[UploadFile] = File(...),
    table_id: Optional[str] = None,
    table_name: Optional[str] = None
) -> CreateTableResponse:
    """Store the uploaded files and (re)build the index for a table.

    Args:
        user_id: Owner recorded for the table.
        files: Uploaded documents; each must have an allowed extension.
        table_id: Existing or desired table id; a UUID4 is generated if omitted.
        table_name: Display name; a ``knowledge-base-xxxx`` name is generated
            if omitted.

    Returns:
        CreateTableResponse with the table id and name that were used.

    Raises:
        HTTPException: 400 for an invalid table id, filename or file type;
            403 if ``table_id`` is already registered to a different user;
            500 for any other failure.
    """
    try:
        table_id = table_id or str(uuid.uuid4())
        table_name = table_name or f"knowledge-base-{str(uuid.uuid4())[:4]}"
        _require_safe_table_id(table_id)

        # Validate every upload before touching the filesystem so a bad file
        # late in the list does not leave a half-written table behind.
        uploads = [(_safe_upload_name(file.filename), file) for file in files]

        with closing(get_db()) as db:
            # Check if table exists, and who owns it.
            owners = {
                row['user_id']
                for row in db.execute(
                    'SELECT user_id FROM tables WHERE table_id = ?',
                    (table_id,)
                ).fetchall()
            }
            existing = user_id in owners
            if owners and not existing:
                # Without this guard another user's files and index would be
                # overwritten and a duplicate row inserted.
                raise HTTPException(
                    status_code=403,
                    detail="Table belongs to another user"
                )

            directory_path = f"./data/{table_id}"
            os.makedirs(directory_path, exist_ok=True)

            for safe_name, file in uploads:
                file_path = os.path.join(directory_path, safe_name)
                with open(file_path, "wb") as buffer:
                    shutil.copyfileobj(file.file, buffer)

            vector_store = LanceDBVectorStore(
                uri="./lancedb/dev",
                table_name=table_id,
                mode="overwrite",
                query_type="hybrid"
            )

            # The reader loads everything in the table directory, so files
            # from earlier uploads to the same table are indexed again.
            documents = SimpleDirectoryReader(directory_path).load_data()
            # NOTE(review): vector_store is passed as a plain keyword here, not
            # through a StorageContext — confirm LlamaIndex actually uses it.
            # query_table relies on the index being loadable from persist_dir.
            index = VectorStoreIndex.from_documents(documents, vector_store=vector_store)
            index.storage_context.persist(persist_dir=f"./lancedb/index/{table_id}")

            if not existing:
                db.execute(
                    'INSERT INTO tables (user_id, table_id, table_name) VALUES (?, ?, ?)',
                    (user_id, table_id, table_name)
                )

            for safe_name, _ in uploads:
                db.execute(
                    'INSERT OR REPLACE INTO table_files (table_id, filename, file_path) VALUES (?, ?, ?)',
                    (table_id, safe_name, f"./data/{table_id}/{safe_name}")
                )
            db.commit()

        return CreateTableResponse(
            table_id=table_id,
            message="Success",
            status="success",
            table_name=table_name
        )

    except HTTPException:
        # Keep deliberate 4xx responses instead of rewrapping them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.post("/query_table/{table_id}", response_model=QueryTableResponse)
async def query_table(
    table_id: str,
    query: str,
    user_id: str,
    limit: Optional[int] = 10
) -> QueryTableResponse:
    """Query the database table using LlamaIndex.

    Loads the index persisted under ``./lancedb/index/<table_id>`` and returns
    the text and score of up to ``limit`` retrieved nodes.

    Raises:
        HTTPException: 400 for an invalid table id, 404 when no persisted
            index exists for it, 500 if loading or retrieval fails.
    """
    # NOTE(review): user_id is accepted but not checked against the table's
    # owner — confirm whether queries are meant to be open to any user.
    try:
        _require_safe_table_id(table_id)
        persist_dir = f"./lancedb/index/{table_id}"
        if not os.path.isdir(persist_dir):
            raise HTTPException(status_code=404, detail="Table not found")

        # load index and retriever
        storage_context = StorageContext.from_defaults(persist_dir=persist_dir)
        index = load_index_from_storage(storage_context)
        retriever = index.as_retriever(similarity_top_k=limit)

        # Get response
        response = retriever.retrieve(query)

        # Format results
        results = [{
            'text': node.text,
            'score': node.score
        } for node in response]

        return QueryTableResponse(
            results={'data': results},
            total_results=len(results)
        )

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Query failed: {str(e)}") from e


@router.get("/get_tables/{user_id}")
async def get_tables(user_id: str):
    """List the tables registered to ``user_id`` with their document names.

    Returns:
        A list of dicts with ``table_id``, ``table_name``, ``created_at`` and
        ``documents`` (file names; empty when the table has no files).
    """
    with closing(get_db()) as db:
        tables = db.execute('''
            SELECT
                t.table_id,
                t.table_name,
                t.created_time as created_at,
                GROUP_CONCAT(tf.filename) as filenames
            FROM tables t
            LEFT JOIN table_files tf ON t.table_id = tf.table_id
            WHERE t.user_id = ?
            GROUP BY t.table_id
        ''', (user_id,)).fetchall()

    result = []
    for table in tables:
        filenames = table['filenames']
        # GROUP_CONCAT joins with ',' so a file name that itself contains a
        # comma is split into several entries here.
        documents = [name for name in filenames.split(',') if name] if filenames else []
        result.append({
            'table_id': table['table_id'],
            'table_name': table['table_name'],
            'created_at': table['created_at'],
            'documents': documents
        })
    return result


@router.delete("/delete_table/{table_id}")
async def delete_table(table_id: str, user_id: str):
    """Delete a table's files, persisted index and metadata rows.

    Raises:
        HTTPException: 400 for an invalid table id, 404 when no table with
            this id is registered to ``user_id``, 500 for any other failure.
    """
    try:
        _require_safe_table_id(table_id)
        with closing(get_db()) as db:
            # Verify user owns the table
            table = db.execute(
                'SELECT * FROM tables WHERE table_id = ? AND user_id = ?',
                (table_id, user_id)
            ).fetchone()

            if not table:
                raise HTTPException(status_code=404, detail="Table not found or unauthorized")

            # Delete files from filesystem
            table_path = f"./data/{table_id}"
            index_path = f"./lancedb/index/{table_id}"

            if os.path.exists(table_path):
                shutil.rmtree(table_path)
            if os.path.exists(index_path):
                shutil.rmtree(index_path)

            # Delete from database (table row scoped to the verified owner)
            db.execute('DELETE FROM table_files WHERE table_id = ?', (table_id,))
            db.execute(
                'DELETE FROM tables WHERE table_id = ? AND user_id = ?',
                (table_id, user_id)
            )
            db.commit()

        return {"message": "Table deleted successfully"}

    except HTTPException:
        # Let the 400/404 above reach the client unchanged.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.get("/health")
async def health_check():
    """Return a static health payload."""
    return {"status": "healthy"}


@router.on_event("startup")
async def startup():
    """Initialise the metadata DB and seed the ``digiyatra`` table if absent.

    The seed index is built from ``combined_digi_yatra.csv`` (header row
    skipped), one node per CSV row, and registered to user ``digiyatra``.
    """
    init_db()
    print("RAG Router started")

    table_name = "digiyatra"
    user_id = "digiyatra"

    with closing(get_db()) as db:
        # Check if table already exists
        existing = db.execute('SELECT id FROM tables WHERE table_id = ?', (table_name,)).fetchone()
        if existing:
            return

        vector_store = LanceDBVectorStore(
            uri="./lancedb/dev",
            table_name=table_name,
            mode="overwrite",
            query_type="hybrid"
        )

        with open('combined_digi_yatra.csv', newline='') as f:
            reader = csv.reader(f)
            next(reader, None)  # skip the header row
            nodes = [TextNode(text=str(row), id_=str(uuid.uuid4())) for row in reader]

        # NOTE(review): same vector_store keyword caveat as in
        # create_embedding_table — confirm it is honoured by VectorStoreIndex.
        index = VectorStoreIndex(nodes, vector_store=vector_store)
        index.storage_context.persist(persist_dir=f"./lancedb/index/{table_name}")

        db.execute(
            'INSERT INTO tables (user_id, table_id, table_name) VALUES (?, ?, ?)',
            (user_id, table_name, table_name)
        )
        db.execute(
            'INSERT INTO table_files (table_id, filename, file_path) VALUES (?, ?, ?)',
            (table_name, 'combined_digi_yatra.csv', 'combined_digi_yatra.csv')
        )
        db.commit()


@router.on_event("shutdown")
async def shutdown():
    """Log router shutdown."""
    print("RAG Router shutdown")