File size: 5,979 Bytes
353e791
491af27
96cc439
 
 
 
 
 
 
 
353e791
 
96cc439
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
353e791
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7211b51
 
 
 
 
 
 
 
 
 
 
353e791
7211b51
353e791
7211b51
353e791
 
7211b51
 
 
 
 
 
353e791
 
 
7211b51
353e791
 
 
 
 
 
 
8ae9422
353e791
 
 
 
 
 
 
 
 
 
 
8ae9422
353e791
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7211b51
 
 
 
 
353e791
7211b51
353e791
7211b51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import os
import streamlit as st
from openai import OpenAI
from PyPDF2 import PdfReader
from pinecone import Pinecone
import uuid
from dotenv import load_dotenv
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

# Load API keys (OPENAI_API_KEY, PINECONE_API_KEY) from a local .env file.
load_dotenv()

# Set up OpenAI client
# Shared client used by get_embedding() and chat_with_ai().
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

# Set up Pinecone
pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))

index_name = "main"  # Your index name
# Handle to the vector index; used by upload, query, and clear operations below.
index = pc.Index(index_name)

def get_embedding(text):
    """Return the embedding vector for *text* using OpenAI's
    text-embedding-3-large model."""
    result = client.embeddings.create(input=text, model="text-embedding-3-large")
    return result.data[0].embedding

def process_pdf(file):
    """Extract all text from a PDF.

    Args:
        file: A path or binary file-like object readable by PyPDF2's PdfReader.

    Returns:
        str: The concatenated text of every page, with a newline appended
        after each page.
    """
    reader = PdfReader(file)
    # Fix: extract_text() may return None for pages with no extractable text
    # (e.g. scanned images); the original `text += page.extract_text() + "\n"`
    # would raise TypeError there. Also use join() instead of repeated `+=`
    # to avoid quadratic string concatenation on large documents.
    return "".join((page.extract_text() or "") + "\n" for page in reader.pages)

def process_upload(upload_type, file_or_link, file_name=None):
    """Extract text from an upload, embed it chunk-by-chunk, and upsert to Pinecone.

    Args:
        upload_type: Kind of upload; currently only "PDF" is supported.
        file_or_link: The uploaded file object (passed to process_pdf for PDFs).
        file_name: Optional display name for the document; defaults to
            "Uploaded PDF".

    Returns:
        str: A human-readable status message.
    """
    print(f"Starting process_upload for {upload_type}")
    doc_id = str(uuid.uuid4())
    print(f"Generated doc_id: {doc_id}")

    if upload_type == "PDF":
        content = process_pdf(file_or_link)
        doc_name = file_name or "Uploaded PDF"
    else:
        print("Invalid upload type")
        return "Invalid upload type"

    content_length = len(content)
    print(f"Content extracted, length: {content_length}")

    # Larger documents get larger chunks so the number of embedding calls
    # (and stored vectors) stays bounded.
    if content_length < 10000:
        chunk_size = 1000
    elif content_length < 100000:
        chunk_size = 2000
    else:
        chunk_size = 4000
    print(f"Using chunk size: {chunk_size}")

    chunks = [content[i:i + chunk_size] for i in range(0, content_length, chunk_size)]
    total_chunks = len(chunks)

    # Robustness: an empty/scanned PDF yields no chunks; bail out instead of
    # calling Pinecone with an empty vector list (which errors).
    if total_chunks == 0:
        print("No text extracted; nothing to upsert")
        return f"No text could be extracted for {upload_type}. Document Name: {doc_name}"

    # Fix: the original only assigned `progress_bar` on one branch and then
    # probed `'progress_bar' in locals()` to see whether it existed. Bind it
    # explicitly so the later check is straightforward and cannot silently
    # skip updates.
    progress_bar = None
    if 'upload_progress' in st.session_state and hasattr(st.session_state.upload_progress, 'progress'):
        progress_bar = st.session_state.upload_progress

    vectors = []
    # Embed chunks concurrently. Completion order is arbitrary, but that is
    # fine: each vector carries its own chunk_index in its metadata.
    with ThreadPoolExecutor() as executor:
        futures = {
            executor.submit(process_chunk, chunk, doc_id, i, upload_type, doc_name): i
            for i, chunk in enumerate(chunks)
        }

        processed_count = 0
        for future in as_completed(futures):
            vectors.append(future.result())
            processed_count += 1
            if progress_bar is not None:
                progress_bar.progress(processed_count / total_chunks)

    print(f"Generated {len(vectors)} vectors")

    # Batch upserts so very large documents don't exceed Pinecone's
    # per-request payload limit (the original left this as a TODO comment).
    batch_size = 100
    for start in range(0, len(vectors), batch_size):
        index.upsert(vectors=vectors[start:start + batch_size])
    print("Vectors upserted to Pinecone")

    return f"Processing complete for {upload_type}. Document Name: {doc_name}"

def process_chunk(chunk, doc_id, i, upload_type, doc_name):
    """Embed one text chunk and package it as a Pinecone-ready
    (id, values, metadata) tuple."""
    metadata = {
        "text": chunk,
        "type": upload_type,
        "doc_id": doc_id,
        "doc_name": doc_name,
        "chunk_index": i,
    }
    # Vector id combines the document id with the chunk's position so it is
    # unique per chunk and traceable back to its source document.
    return (f"{doc_id}_{i}", get_embedding(chunk), metadata)

def get_relevant_context(query, top_k=5):
    """Fetch the top_k most similar stored chunks for *query*.

    Returns:
        tuple: (context string joining the chunk texts with newlines,
        list of matches ordered by document id then chunk index).
    """
    print(f"Getting relevant context for query: {query}")
    embedded_query = get_embedding(query)

    search_results = index.query(vector=embedded_query, top_k=top_k, include_metadata=True)
    print(f"Found {len(search_results['matches'])} relevant results")

    # Re-order matches so chunks belonging to the same document appear in
    # their original sequence, which reads more coherently as context.
    def _doc_order(match):
        meta = match['metadata']
        return (meta['doc_id'], meta['chunk_index'])

    ordered = sorted(search_results['matches'], key=_doc_order)
    context = "\n".join(match['metadata']['text'] for match in ordered)
    return context, ordered

def chat_with_ai(message):
    """Answer *message* via RAG: retrieve relevant chunks, then query the chat model.

    Returns:
        tuple: (assistant reply text, list of source dicts describing the
        chunks that were supplied as context).
    """
    print(f"Chatting with AI, message: {message}")
    context, results = get_relevant_context(message)
    print(f"Retrieved context, length: {len(context)}")

    instructions = "You are a helpful assistant. Use the following information to answer the user's question, but don't mention the context directly in your response. If the information isn't in the context, say you don't know."
    conversation = [
        {"role": "system", "content": instructions},
        {"role": "system", "content": f"Context: {context}"},
        {"role": "user", "content": message},
    ]

    completion = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=conversation,
    )
    print("Received response from OpenAI")

    ai_response = completion.choices[0].message.content

    # Surface which chunks informed the answer so the UI can cite sources.
    sources = []
    for match in results:
        meta = match['metadata']
        sources.append({
            "doc_id": meta['doc_id'],
            "doc_name": meta['doc_name'],
            "chunk_index": meta['chunk_index'],
            "text": meta['text'],
        })

    return ai_response, sources

def clear_database():
    """Delete every vector in the Pinecone index.

    Returns:
        str: A confirmation message for display in the UI.
    """
    confirmation = "Database cleared successfully."
    print("Clearing database...")
    index.delete(delete_all=True)
    print("Database cleared")
    return confirmation

# Streamlit Main Page
# Landing page for the multi-page app; the actual Upload/Chat UIs live in
# the 'pages' directory and import the functions defined above.
st.set_page_config(
    page_title="RAG Chat Home",
    page_icon="👋",
)

st.title("Welcome to RAG Chat! 👋")

st.sidebar.success("Select a page above.")

# Intro text rendered on the home page only.
st.markdown(
    """
    This application allows you to upload PDF documents and chat with an AI
    about their content.

    **👈 Select a page from the sidebar** to get started:
    - **Upload:** Add your PDF documents to the knowledge base.
    - **Chat:** Ask questions about the documents you've uploaded.

    The AI uses Retrieval-Augmented Generation (RAG) to find relevant sections
    from your documents and provide informed answers.
"""
)

# No UI elements here, just the core logic and initialization above.
# The pages in the 'pages' directory will handle the UI.