File size: 7,796 Bytes
c73bfb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37d6123
c73bfb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37d6123
c73bfb4
 
 
 
 
 
 
 
 
 
37d6123
 
c73bfb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b8e46cf
 
37d6123
c73bfb4
 
37d6123
c73bfb4
 
 
 
 
 
 
 
 
 
 
 
37d6123
d8ec593
 
c73bfb4
 
 
 
 
 
 
 
 
 
37d6123
c73bfb4
aceb069
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2afa572
 
 
c73bfb4
 
a5ed4c9
37d6123
b8e46cf
c73bfb4
 
a5ed4c9
c73bfb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37d6123
 
 
 
c73bfb4
 
 
 
 
e4b3e3a
c73bfb4
d8ec593
c73bfb4
 
d8ec593
 
c73bfb4
 
e4b3e3a
 
843a697
 
 
c73bfb4
 
 
 
 
 
 
 
37d6123
 
 
 
 
 
 
c73bfb4
 
 
37d6123
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
from fastapi import FastAPI, Depends, HTTPException, UploadFile, File
import pandas as pd
import lancedb
from functools import cached_property, lru_cache
from pydantic import Field, BaseModel
from typing import Optional, Dict, List, Annotated, Any
from fastapi import APIRouter
import uuid
import io
from io import BytesIO
import csv

# LlamaIndex imports
from llama_index.core import Settings, SimpleDirectoryReader, VectorStoreIndex
from llama_index.vector_stores.lancedb import LanceDBVectorStore
from llama_index.embeddings.fastembed import FastEmbedEmbedding
from llama_index.core.schema import TextNode
from llama_index.core import StorageContext, load_index_from_storage
import json
import os
import shutil

# Router for all RAG (retrieval-augmented generation) endpoints; mounted under /rag.
router = APIRouter(
    prefix="/rag",
    tags=["rag"]
)

# Configure global LlamaIndex settings
# All indexing/retrieval in this module embeds with this FastEmbed model.
Settings.embed_model = FastEmbedEmbedding(model_name="BAAI/bge-small-en-v1.5")
# JSON registry mapping user_id -> list of {table_id, table_name} records.
tables_file_path = './data/tables.json'

# Database connection dependency
@lru_cache()
def get_db_connection(db_path: str = "./lancedb/dev"):
    """Return a LanceDB connection for *db_path*, memoized per path via lru_cache."""
    connection = lancedb.connect(db_path)
    return connection

# Pydantic models
class CreateTableResponse(BaseModel):
    """Response body for POST /rag/create_table."""
    table_id: str     # UUID (or caller-supplied id) identifying the LanceDB table
    message: str      # human-readable outcome description
    status: str       # "success" on the happy path
    table_name: str   # display name (auto-generated when not provided)

class QueryTableResponse(BaseModel):
    """Response body for POST /rag/query_table/{table_id}."""
    results: Dict[str, Any]   # {'data': [{'text': ..., 'score': ...}, ...]}
    total_results: int        # number of retrieved nodes


@router.post("/create_table", response_model=CreateTableResponse)
async def create_embedding_table(
    user_id: str,
    files: List[UploadFile] = File(...),
    table_id: Optional[str] = None,
    table_name: Optional[str] = None
) -> CreateTableResponse:
    """Create a LanceDB table and index the uploaded files with LlamaIndex.

    Saves the files under ./data/<table_id>/, builds a vector index over them,
    persists the index to ./lancedb/index/<table_id>, and records the table in
    the per-user registry (tables.json).

    Raises:
        HTTPException 400: missing filename or unsupported file extension.
        HTTPException 500: indexing or registry-update failure.
    """
    allowed_extensions = {".pdf", ".docx", ".csv", ".txt", ".md"}
    for file in files:
        if file.filename is None:
            raise HTTPException(status_code=400, detail="File must have a valid name.")
        file_extension = os.path.splitext(file.filename)[1].lower()
        if file_extension not in allowed_extensions:
            # sorted() keeps the error message deterministic (sets are unordered)
            raise HTTPException(
                status_code=400,
                detail=f"File type {file_extension} is not allowed. Supported file types are: {', '.join(sorted(allowed_extensions))}."
            )

    if table_id is None:
        table_id = str(uuid.uuid4())
    table_name = f"knowledge-base-{str(uuid.uuid4())[:4]}" if not table_name else table_name

    # Create a directory for the uploaded files
    directory_path = f"./data/{table_id}"
    os.makedirs(directory_path, exist_ok=True)

    # Save each uploaded file to the data directory.
    for file in files:
        # basename() guards against path traversal via client-supplied names
        # (e.g. "../../etc/passwd").
        file_path = os.path.join(directory_path, os.path.basename(file.filename))
        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

    try:
        # Setup LanceDB vector store
        vector_store = LanceDBVectorStore(
            uri="./lancedb/dev",
            table_name=table_id,
            mode="overwrite",
            query_type="hybrid"
        )

        # Load documents using SimpleDirectoryReader
        documents = SimpleDirectoryReader(directory_path).load_data()

        # NOTE(review): a bare `vector_store=` kwarg to from_documents is not a
        # supported parameter in llama_index core — attach the store through a
        # StorageContext so embeddings actually land in LanceDB. Confirm against
        # the pinned llama_index version.
        storage_context = StorageContext.from_defaults(vector_store=vector_store)
        index = VectorStoreIndex.from_documents(
            documents,
            storage_context=storage_context
        )
        index.storage_context.persist(persist_dir=f"./lancedb/index/{table_id}")

        # Record the table in the per-user registry.
        _register_table(user_id, table_id, table_name)

        return CreateTableResponse(
            table_id=table_id,
            message="Table created and documents indexed successfully",
            status="success",
            table_name=table_name
        )

    except HTTPException:
        # Re-raise intentional HTTP errors (e.g. from _register_table) instead
        # of wrapping them in a generic 500 below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Table creation failed: {str(e)}")


def _register_table(user_id: str, table_id: str, table_name: str) -> None:
    """Append (table_id, table_name) to user_id's entry in the tables registry.

    Creates the registry file if absent; no-op if the table_id is already
    registered for this user. Raises HTTPException 500 on any I/O failure.
    """
    try:
        os.makedirs(os.path.dirname(tables_file_path), exist_ok=True)
        # Load existing tables or start fresh if the file is missing/corrupt.
        try:
            with open(tables_file_path, 'r') as f:
                tables = json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            tables = {}

        user_tables = tables.setdefault(user_id, [])
        if table_id not in [entry['table_id'] for entry in user_tables]:
            user_tables.append({"table_id": table_id, "table_name": table_name})

        with open(tables_file_path, 'w') as f:
            json.dump(tables, f)

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to update tables file: {str(e)}")

@router.post("/query_table/{table_id}", response_model=QueryTableResponse)
async def query_table(
    table_id: str,
    query: str,
    user_id: str,
    limit: Optional[int] = 10
) -> QueryTableResponse:
    """Retrieve up to *limit* nodes from a previously indexed table.

    Loads the persisted LlamaIndex storage for the table and runs a retriever
    over it. `user_id` is accepted for interface compatibility but is not used
    for lookup (tables are keyed by table_id alone).

    Raises:
        HTTPException 404: no persisted index exists for table_id.
        HTTPException 500: retrieval failure.
    """
    persist_dir = f"./lancedb/index/{table_id}"
    # A missing persist directory means the table was never created — surface
    # that as 404 rather than letting load_index_from_storage fail into a 500.
    if not os.path.isdir(persist_dir):
        raise HTTPException(status_code=404, detail=f"Table {table_id} not found.")

    try:
        # Load index and build a retriever capped at `limit` results.
        storage_context = StorageContext.from_defaults(persist_dir=persist_dir)
        index = load_index_from_storage(storage_context)
        retriever = index.as_retriever(similarity_top_k=limit)

        # Get response
        response = retriever.retrieve(query)

        # Format results
        results = [{
            'text': node.text,
            'score': node.score
        } for node in response]

        return QueryTableResponse(
            results={'data': results},
            total_results=len(results)
        )

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Query failed: {str(e)}")

@router.get("/get_tables/{user_id}")
async def get_tables(user_id: str):
    """Return the list of {table_id, table_name} records registered for *user_id*."""

    tables_file_path = './data/tables.json'
    try:
        # Read the whole registry, then pick out this user's entries.
        with open(tables_file_path, 'r') as f:
            registry = json.load(f)
        return registry.get(user_id, [])

    except (FileNotFoundError, json.JSONDecodeError):
        # No registry yet (or unreadable JSON) — the user simply has no tables.
        return []
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to retrieve tables: {str(e)}")

@router.get("/health")
async def health_check():
    """Liveness probe: always reports the router as healthy."""
    return dict(status="healthy")

@router.on_event("startup")
async def startup():
    """Seed a demo 'digiyatra' table from combined_digi_yatra.csv at startup.

    Builds one TextNode per CSV data row, indexes them into LanceDB, persists
    the index, and registers the table for the 'digiyatra' user. Skips quietly
    when the seed CSV is absent so a missing fixture cannot block app startup.

    NOTE(review): @router.on_event is deprecated in current FastAPI in favor
    of lifespan handlers — confirm before migrating.
    """
    print("RAG Router started")
    table_name = "digiyatra"
    csv_path = 'combined_digi_yatra.csv'

    # Missing seed data must not crash application startup.
    if not os.path.exists(csv_path):
        print(f"Seed file {csv_path} not found; skipping digiyatra index build")
        return

    vector_store = LanceDBVectorStore(
        uri="./lancedb/dev",
        table_name=table_name,
        mode="overwrite",
        query_type="hybrid"
    )

    # Load the digiyatra CSV and create one node per data row (header skipped).
    nodes = []
    with open(csv_path, newline='') as f:
        reader = csv.reader(f)
        data = list(reader)
    for row in data[1:]:
        nodes.append(TextNode(text=str(row), id_=str(uuid.uuid4())))

    # NOTE(review): as in create_embedding_table, attach the store through a
    # StorageContext so embeddings are written to LanceDB — a bare
    # `vector_store=` kwarg is not a supported VectorStoreIndex parameter.
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    index = VectorStoreIndex(nodes, storage_context=storage_context)
    index.storage_context.persist(persist_dir=f"./lancedb/index/{table_name}")

    # Merge into the existing registry instead of clobbering it: the original
    # wrote a fresh dict, destroying every table registered by other users.
    os.makedirs(os.path.dirname(tables_file_path), exist_ok=True)
    try:
        with open(tables_file_path, 'r') as f:
            tables = json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        tables = {}

    tables["digiyatra"] = [
        {
            "table_id": table_name,
            "table_name": table_name
        }
    ]
    with open(tables_file_path, 'w') as f:
        json.dump(tables, f)

@router.on_event("shutdown")
async def shutdown():
    """Log router shutdown; no resources are released here.

    NOTE(review): on_event is deprecated in current FastAPI — confirm before
    migrating to a lifespan handler.
    """
    print("RAG Router shutdown")