File size: 3,534 Bytes
c33d1d0
018fb30
e8faa1f
 
 
 
 
 
 
 
 
 
 
b08204e
e8faa1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b08204e
 
e8faa1f
70bd277
e8faa1f
 
 
 
ea07eae
e8faa1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os
import shutil

import gradio as gr

from rag_tool import RAGTool

# Initialize the RAG Tool with default settings
# Startup configuration, kept in one mapping so the defaults are easy to scan.
_DEFAULT_RAG_SETTINGS = {
    "documents_path": "./documents",
    "embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
    "vector_store_type": "faiss",
    "chunk_size": 1000,
    "chunk_overlap": 200,
    "persist_directory": "./vector_store",
}

# Module-level instance used by the query handler.
rag_tool = RAGTool(**_DEFAULT_RAG_SETTINGS)

# Function to handle document uploads
def upload_documents(files, chunk_size, chunk_overlap, embedding_model, vector_store_type):
    """Store the uploaded files and rebuild the global RAG tool from them.

    Args:
        files: Value of the Gradio ``File`` component. ``None`` or an empty
            list means nothing was uploaded. Each item is either a filesystem
            path or an object exposing its path as ``.name``.
        chunk_size: Chunk size from the slider; converted to ``int``.
        chunk_overlap: Chunk overlap from the slider; converted to ``int``.
        embedding_model: Name of the embedding model to use.
        vector_store_type: Vector store backend selected in the UI.

    Returns:
        A status message for the "Upload Result" textbox.
    """
    global rag_tool

    # Nothing selected: report it instead of failing on iteration over None.
    if not files:
        return "No documents uploaded. Please select at least one file."

    # Create a directory for uploaded files
    upload_dir = "./uploaded_docs"
    os.makedirs(upload_dir, exist_ok=True)

    # Save uploaded files. Copy by path rather than reading from a handle:
    # path strings have no read(), and a handle's content depends on its
    # open mode and current position.
    for file in files:
        if isinstance(file, (str, os.PathLike)):
            source_path = os.fspath(file)
        else:
            source_path = file.name
        target_path = os.path.join(upload_dir, os.path.basename(source_path))
        shutil.copyfile(source_path, target_path)

    # Initialize a new RAG Tool with the uploaded documents
    rag_tool = RAGTool(
        documents_path=upload_dir,
        embedding_model=embedding_model,
        vector_store_type=vector_store_type,
        chunk_size=int(chunk_size),
        chunk_overlap=int(chunk_overlap),
        persist_directory="./uploaded_vector_store",
    )

    return f"Documents uploaded and processed. Vector store created with {embedding_model} model."

# Function to handle queries
def query_documents(query, top_k):
    """Run a query against the currently active RAG tool.

    Args:
        query: Question text entered by the user.
        top_k: Number of results to request; converted to ``int``.

    Returns:
        The value returned by ``rag_tool`` for the query, or a short prompt
        message when the query is empty or only whitespace.
    """
    # Skip the retriever entirely for blank input.
    if not query or not query.strip():
        return "Please enter a question to search the documents."

    # rag_tool is only read here, so no ``global`` declaration is needed.
    return rag_tool(query, top_k=int(top_k))

# Gradio interface
# Embedding models offered in the dropdown; the first entry is preselected.
embedding_models = [
    "sentence-transformers/all-MiniLM-L6-v2",
    "BAAI/bge-small-en-v1.5",
    "BAAI/bge-base-en-v1.5",
    "thenlper/gte-small",
    "thenlper/gte-base",
]

with gr.Blocks(title="Advanced RAG Tool") as demo:
    gr.Markdown("# Advanced RAG Tool")
    gr.Markdown("Upload documents and query them using semantic search")

    # Tab 1: choose files and indexing options, then build the vector store.
    with gr.Tab("Upload & Configure"):
        with gr.Row():
            # Left column: the files and the chunking parameters.
            with gr.Column():
                files = gr.File(file_count="multiple", label="Upload Documents")
                chunk_size = gr.Slider(
                    minimum=200,
                    maximum=2000,
                    value=1000,
                    step=100,
                    label="Chunk Size",
                )
                chunk_overlap = gr.Slider(
                    minimum=0,
                    maximum=500,
                    value=200,
                    step=50,
                    label="Chunk Overlap",
                )

            # Right column: embedding model and vector store backend.
            with gr.Column():
                embedding_model = gr.Dropdown(
                    choices=embedding_models,
                    value=embedding_models[0],
                    label="Embedding Model",
                )
                vector_store_type = gr.Radio(
                    choices=["faiss", "chroma"],
                    value="faiss",
                    label="Vector Store Type",
                )

        upload_button = gr.Button("Upload and Process Documents")
        upload_result = gr.Textbox(label="Upload Result")

        # Clicking the button passes all five settings to upload_documents.
        upload_button.click(
            fn=upload_documents,
            inputs=[files, chunk_size, chunk_overlap, embedding_model, vector_store_type],
            outputs=upload_result,
        )

    # Tab 2: ask a question against the currently active RAG tool.
    with gr.Tab("Query Documents"):
        query = gr.Textbox(
            label="Your Question",
            placeholder="What information are you looking for?",
        )
        top_k = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Number of Results")
        query_button = gr.Button("Search")
        answer = gr.Textbox(label="Results")

        query_button.click(
            fn=query_documents,
            inputs=[query, top_k],
            outputs=answer,
        )

# Launch the app
if __name__ == "__main__":
    demo.launch()