File size: 8,747 Bytes
c73bfb4
 
 
 
 
 
 
 
 
 
 
131998f
c73bfb4
 
 
 
 
37d6123
c73bfb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131998f
73a2ea8
131998f
 
 
 
 
 
 
 
 
 
1e97c1e
 
131998f
 
ed6643b
0ee4f3c
 
 
 
 
 
 
 
ed6643b
 
131998f
c73bfb4
 
 
 
 
37d6123
c73bfb4
 
 
 
 
 
 
 
1b031c5
 
 
 
c73bfb4
1b031c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c73bfb4
a5ed4c9
c73bfb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b6c9315
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c73bfb4
2c33c0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37d6123
 
 
 
c73bfb4
 
410c5d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131998f
410c5d9
 
 
131998f
410c5d9
 
 
 
 
 
 
 
 
 
 
 
c73bfb4
37d6123
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
from fastapi import FastAPI, Depends, HTTPException, UploadFile, File
import pandas as pd
import lancedb
from functools import cached_property, lru_cache
from pydantic import Field, BaseModel
from typing import Optional, Dict, List, Annotated, Any
from fastapi import APIRouter
import uuid
import io
from io import BytesIO
import csv
import sqlite3

# LlamaIndex imports
from llama_index.core import Settings, SimpleDirectoryReader, VectorStoreIndex
from llama_index.vector_stores.lancedb import LanceDBVectorStore
from llama_index.embeddings.fastembed import FastEmbedEmbedding
from llama_index.core.schema import TextNode
from llama_index.core import StorageContext, load_index_from_storage
import json
import os
import shutil

# Router for the RAG endpoints defined in this module; every route declared on
# it is served under the "/rag" prefix and carries the "rag" tag.
router = APIRouter(
    prefix="/rag",
    tags=["rag"]
)

# Configure global LlamaIndex settings
# This assignment runs at import time and sets the embedding model on the
# shared LlamaIndex Settings object to the FastEmbed "BAAI/bge-small-en-v1.5"
# model.
Settings.embed_model = FastEmbedEmbedding(model_name="BAAI/bge-small-en-v1.5")

# Database connection dependency
@lru_cache(maxsize=128)
def get_db_connection(db_path: str = "./lancedb/dev"):
    """Open a LanceDB connection for ``db_path``.

    The result is memoised by ``lru_cache``, so repeated calls with the same
    path return the same connection object.
    """
    connection = lancedb.connect(db_path)
    return connection

def get_db():
    """Open a new SQLite connection to the table-metadata database.

    The database lives at ``./data/tablesv2.db``. The ``./data`` directory is
    created if it is missing, because ``sqlite3.connect`` cannot create parent
    directories and would otherwise fail with "unable to open database file".

    Returns:
        A ``sqlite3.Connection`` whose rows are ``sqlite3.Row`` objects, so
        columns can be read by name. The caller is responsible for closing it.
    """
    os.makedirs('./data', exist_ok=True)
    conn = sqlite3.connect('./data/tablesv2.db')
    conn.row_factory = sqlite3.Row
    return conn

def init_db():
    """Create the metadata tables if they do not exist yet.

    Two tables are created:

    * ``tables`` - one row per (user_id, table_id) knowledge base.
    * ``table_files`` - one row per file of a table, unique on
      (table_id, filename).

    The connection is always closed, even if a statement fails.
    """
    db = get_db()
    try:
        db.execute('''
            CREATE TABLE IF NOT EXISTS tables (
                id INTEGER PRIMARY KEY,
                user_id TEXT NOT NULL,
                table_id TEXT NOT NULL,
                table_name TEXT NOT NULL,
                created_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        ''')
        db.execute('''
            CREATE TABLE IF NOT EXISTS table_files (
                id INTEGER PRIMARY KEY,
                table_id TEXT NOT NULL,
                filename TEXT NOT NULL,
                file_path TEXT NOT NULL,
                FOREIGN KEY (table_id) REFERENCES tables (table_id),
                UNIQUE(table_id, filename)
            )
        ''')
        db.commit()
    finally:
        # Release the connection; each caller of get_db() opens its own.
        db.close()

# Pydantic models
class CreateTableResponse(BaseModel):
    # Response body of POST /create_table.
    # table_id: identifier of the table that was created or updated.
    table_id: str
    # message: human-readable outcome text.
    message: str
    # status: short outcome marker.
    status: str
    # table_name: display name stored for the table.
    table_name: str

class QueryTableResponse(BaseModel):
    # Response body of POST /query_table/{table_id}.
    # results: mapping that carries the retrieved items.
    results: Dict[str, Any]
    # total_results: number of retrieved items.
    total_results: int


@router.post("/create_table", response_model=CreateTableResponse)
async def create_embedding_table(
    user_id: str,
    files: List[UploadFile] = File(...),
    table_id: Optional[str] = None,
    table_name: Optional[str] = None
) -> CreateTableResponse:
    """Store uploaded files, build an index over them and record the table.

    Args:
        user_id: Owner recorded for the table.
        files: Uploaded documents; only .pdf, .docx, .csv, .txt and .md
            are accepted.
        table_id: Existing table to add to, or None to generate a new UUID.
        table_name: Display name; a "knowledge-base-xxxx" name is generated
            when omitted.

    Returns:
        A CreateTableResponse describing the created or updated table.

    Raises:
        HTTPException: 400 for an invalid table_id, a missing filename or an
            unsupported extension; 500 for any other failure.
    """
    allowed_extensions = {".pdf", ".docx", ".csv", ".txt", ".md"}

    try:
        table_id = table_id or str(uuid.uuid4())
        table_name = table_name or f"knowledge-base-{str(uuid.uuid4())[:4]}"

        # table_id is used as a directory name below, so it must not be able
        # to escape ./data or ./lancedb/index.
        if table_id in {".", ".."} or "/" in table_id or "\\" in table_id:
            raise HTTPException(status_code=400, detail="Invalid table_id")

        # Validate every upload before touching the filesystem so a bad file
        # later in the list does not leave a half-written table behind.
        safe_names = []
        for file in files:
            # Client-supplied names may contain directory parts; keep only the
            # final component to prevent writing outside the table directory.
            safe_name = os.path.basename((file.filename or "").replace("\\", "/"))
            if not safe_name or safe_name in {".", ".."}:
                raise HTTPException(status_code=400, detail="Invalid filename")
            if os.path.splitext(safe_name)[1].lower() not in allowed_extensions:
                raise HTTPException(status_code=400, detail="Unsupported file type")
            safe_names.append(safe_name)

        db = get_db()
        try:
            # Check if table exists
            existing = db.execute(
                'SELECT id FROM tables WHERE user_id = ? AND table_id = ?',
                (user_id, table_id)
            ).fetchone()

            directory_path = f"./data/{table_id}"
            os.makedirs(directory_path, exist_ok=True)

            for file, safe_name in zip(files, safe_names):
                file_path = os.path.join(directory_path, safe_name)
                with open(file_path, "wb") as buffer:
                    shutil.copyfileobj(file.file, buffer)

            vector_store = LanceDBVectorStore(
                uri="./lancedb/dev",
                table_name=table_id,
                mode="overwrite",
                query_type="hybrid"
            )

            # The reader loads everything in the table directory, so files
            # from earlier uploads to the same table are re-indexed as well.
            documents = SimpleDirectoryReader(directory_path).load_data()
            # NOTE(review): the store is passed as a plain keyword here;
            # confirm that LlamaIndex actually writes embeddings to LanceDB
            # this way. query_table only relies on the persisted index below.
            index = VectorStoreIndex.from_documents(documents, vector_store=vector_store)
            index.storage_context.persist(persist_dir=f"./lancedb/index/{table_id}")

            if not existing:
                db.execute(
                    'INSERT INTO tables (user_id, table_id, table_name) VALUES (?, ?, ?)',
                    (user_id, table_id, table_name)
                )

            for safe_name in safe_names:
                db.execute(
                    'INSERT OR REPLACE INTO table_files (table_id, filename, file_path) VALUES (?, ?, ?)',
                    (table_id, safe_name, f"./data/{table_id}/{safe_name}")
                )
            db.commit()
        finally:
            db.close()

        return CreateTableResponse(
            table_id=table_id,
            message="Success",
            status="success",
            table_name=table_name
        )

    except HTTPException:
        # Keep deliberate client errors (4xx) instead of masking them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/query_table/{table_id}", response_model=QueryTableResponse)
async def query_table(
    table_id: str,
    query: str,
    user_id: str,
    limit: Optional[int] = 10
) -> QueryTableResponse:
    """Query the database table using LlamaIndex.

    Loads the index persisted under ``./lancedb/index/{table_id}`` and returns
    the retrieved nodes for ``query``.

    Args:
        table_id: Table whose persisted index is queried.
        query: Free-text query passed to the retriever.
        user_id: Caller identity. NOTE(review): currently not used to check
            that the caller owns the table.
        limit: Maximum number of nodes to retrieve; None falls back to 10.

    Returns:
        A QueryTableResponse with ``results={'data': [...]}`` where each item
        has ``text`` and ``score``, plus the item count.

    Raises:
        HTTPException: 404 when no persisted index exists for ``table_id``;
            500 when loading or querying the index fails.
    """
    persist_dir = f"./lancedb/index/{table_id}"

    # A missing index directory means the table was never created (or was
    # deleted); report that as "not found" rather than a server error.
    if not os.path.isdir(persist_dir):
        raise HTTPException(status_code=404, detail="Table not found")

    top_k = 10 if limit is None else limit

    try:
        # load index and retriever
        storage_context = StorageContext.from_defaults(persist_dir=persist_dir)
        index = load_index_from_storage(storage_context)
        retriever = index.as_retriever(similarity_top_k=top_k)

        # Get response
        response = retriever.retrieve(query)

        # Format results
        results = [{
            'text': node.text,
            'score': node.score
        } for node in response]

        return QueryTableResponse(
            results={'data': results},
            total_results=len(results)
        )

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Query failed: {str(e)}")

@router.get("/get_tables/{user_id}")
async def get_tables(user_id: str):
    """List the tables owned by ``user_id`` together with their files.

    Filenames are collected row by row rather than with GROUP_CONCAT, so a
    filename that itself contains a comma is returned intact.

    Args:
        user_id: Owner whose tables are listed.

    Returns:
        A list of dicts with ``table_id``, ``table_name``, ``created_at`` and
        ``documents`` (list of filenames, empty when the table has no files).
    """
    db = get_db()
    try:
        rows = db.execute('''
            SELECT
                t.table_id,
                t.table_name,
                t.created_time as created_at,
                tf.filename
            FROM tables t
            LEFT JOIN table_files tf ON t.table_id = tf.table_id
            WHERE t.user_id = ?
            ORDER BY t.table_id, tf.id
        ''', (user_id,)).fetchall()
    finally:
        db.close()

    # One output entry per table_id; dict insertion order keeps the query order.
    tables = {}
    for row in rows:
        entry = tables.setdefault(row['table_id'], {
            'table_id': row['table_id'],
            'table_name': row['table_name'],
            'created_at': row['created_at'],
            'documents': []
        })
        # The LEFT JOIN yields a NULL filename for tables without files.
        if row['filename']:
            entry['documents'].append(row['filename'])

    return list(tables.values())


@router.delete("/delete_table/{table_id}")
async def delete_table(table_id: str, user_id: str):
    """Delete a table's files, persisted index and metadata rows.

    Args:
        table_id: Table to delete.
        user_id: Caller identity; must match the owner recorded for the table.

    Returns:
        ``{"message": "Table deleted successfully"}`` on success.

    Raises:
        HTTPException: 404 when no table with this id belongs to ``user_id``;
            500 for any other failure.
    """
    try:
        db = get_db()
        try:
            # Verify user owns the table
            table = db.execute(
                'SELECT * FROM tables WHERE table_id = ? AND user_id = ?',
                (table_id, user_id)
            ).fetchone()

            if not table:
                raise HTTPException(status_code=404, detail="Table not found or unauthorized")

            # Delete files from filesystem
            table_path = f"./data/{table_id}"
            index_path = f"./lancedb/index/{table_id}"
            if os.path.exists(table_path):
                shutil.rmtree(table_path)
            if os.path.exists(index_path):
                shutil.rmtree(index_path)

            # Delete from database
            db.execute('DELETE FROM table_files WHERE table_id = ?', (table_id,))
            db.execute('DELETE FROM tables WHERE table_id = ?', (table_id,))
            db.commit()

            return {"message": "Table deleted successfully"}
        finally:
            db.close()

    except HTTPException:
        # Without this, the 404 above would be caught below and reported as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/health")
async def health_check():
    # Liveness endpoint: unconditionally reports a healthy status.
    payload = {"status": "healthy"}
    return payload

@router.on_event("startup")
async def startup():
    """Initialise the metadata database and seed the "digiyatra" table.

    On first start (no ``tables`` row with table_id "digiyatra") every data
    row of ``combined_digi_yatra.csv`` becomes one TextNode, an index is built
    and persisted under ``./lancedb/index/digiyatra``, and the table and its
    source file are recorded in SQLite. Later starts skip the seeding.
    """
    init_db()
    print("RAG Router started")

    table_name = "digiyatra"
    user_id = "digiyatra"

    db = get_db()
    try:
        # Check if table already exists
        existing = db.execute('SELECT id FROM tables WHERE table_id = ?', (table_name,)).fetchone()
        if existing:
            return

        vector_store = LanceDBVectorStore(
            uri="./lancedb/dev",
            table_name=table_name,
            mode="overwrite",
            query_type="hybrid"
        )

        with open('combined_digi_yatra.csv', newline='') as f:
            reader = csv.reader(f)
            # Skip the header row without loading the whole file into a list.
            next(reader, None)
            nodes = [TextNode(text=str(row), id_=str(uuid.uuid4()))
                     for row in reader]

        index = VectorStoreIndex(nodes, vector_store=vector_store)
        index.storage_context.persist(persist_dir=f"./lancedb/index/{table_name}")

        db.execute(
            'INSERT INTO tables (user_id, table_id, table_name) VALUES (?, ?, ?)',
            (user_id, table_name, table_name)
        )
        db.execute(
            'INSERT INTO table_files (table_id, filename, file_path) VALUES (?, ?, ?)',
            (table_name, 'combined_digi_yatra.csv', 'combined_digi_yatra.csv')
        )
        db.commit()
    finally:
        # Always release the connection, including on the early return above.
        db.close()

@router.on_event("shutdown")
async def shutdown():
    """Log that the RAG router is shutting down."""
    farewell = "RAG Router shutdown"
    print(farewell)