File size: 15,974 Bytes
1fff800
ab7968f
1fff800
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f856eff
1fff800
 
 
 
 
 
 
 
 
 
 
 
 
dc92394
1fff800
 
 
 
 
 
 
cbb024a
1fff800
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
445d770
1fff800
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
445d770
1fff800
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
445d770
1fff800
 
 
 
 
 
 
 
 
445d770
 
1fff800
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa16c7b
 
 
 
 
 
 
1fff800
445d770
 
1fff800
 
 
 
 
 
 
 
 
 
 
fa16c7b
1fff800
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
445d770
 
1fff800
445d770
 
 
 
 
 
1fff800
445d770
 
 
1fff800
445d770
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1fff800
 
 
 
f5f3a5b
 
 
 
1fff800
 
 
f856eff
445d770
1fff800
 
445d770
1fff800
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
import logging
import json
import pandas as pd
import streamlit as st
from pinecone import Pinecone
from llama_index.vector_stores.pinecone import PineconeVectorStore
from llama_index.core import (
    StorageContext, VectorStoreIndex, SimpleDirectoryReader, 
    get_response_synthesizer, Settings
)
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.retrievers import (
    VectorIndexRetriever, RouterRetriever
)
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.core.tools import RetrieverTool
from llama_index.core.query_engine import (
    RetrieverQueryEngine, FLAREInstructQueryEngine, MultiStepQueryEngine
)
from llama_index.core.indices.query.query_transform import (
    StepDecomposeQueryTransform
)
from llama_index.llms.groq import Groq
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.azure_openai import AzureOpenAI
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.readers.file import PyMuPDFReader
import traceback
from oauth2client.service_account import ServiceAccountCredentials
import gspread
import uuid
from dotenv import load_dotenv
import os
from datetime import datetime

# Pull settings from a .env file into the process environment.
load_dotenv()

# Emit INFO-level (and above) records through the root logger.
logging.basicConfig(level=logging.INFO)

# Google Sheets setup: every service-account field is read from an
# environment variable that carries the same name as the field.
scope = [
    "https://spreadsheets.google.com/feeds",
    "https://www.googleapis.com/auth/drive",
]
creds_dict = {
    field: os.getenv(field)
    for field in (
        "type",
        "project_id",
        "private_key_id",
        "private_key",
        "client_email",
        "client_id",
        "auth_uri",
        "token_uri",
        "auth_provider_x509_cert_url",
        "client_x509_cert_url",
    )
}
# Turn literal backslash-n sequences in the key back into real newlines.
creds_dict['private_key'] = creds_dict['private_key'].replace('\\n', '\n')
creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
client = gspread.authorize(creds)
sheet = client.open("RAG").sheet1

# Azure OpenAI settings, read once at import time.
AZURE_DEPLOYMENT_NAME = os.getenv("AZURE_DEPLOYMENT_NAME")
AZURE_API_VERSION = os.getenv("AZURE_API_VERSION")
AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT")

# Lazily populated by initialize_apis().
llm = None
pinecone_index = None

def log_and_exit(message):
    """Log ``message`` at ERROR level, then abort by raising ``SystemExit``.

    Args:
        message: Text to log; it is also the ``SystemExit`` payload.

    Raises:
        SystemExit: Always, carrying ``message``.
    """
    abort = SystemExit(message)
    logging.error(message)
    raise abort

def initialize_apis(api, model, pinecone_api_key, groq_api_key, azure_api_key):
    """Populate the module-level ``llm`` and ``pinecone_index`` if unset.

    Args:
        api: Provider name forwarded to ``initialize_llm``.
        model: Model name forwarded to ``initialize_llm``.
        pinecone_api_key: Key used to construct the Pinecone client.
        groq_api_key: Forwarded to ``initialize_llm``.
        azure_api_key: Forwarded to ``initialize_llm``.

    Raises:
        SystemExit: Through ``log_and_exit`` when any step raises.
    """
    global llm, pinecone_index
    try:
        # Each global is only built while it is still None.
        if llm is None:
            llm = initialize_llm(api, model, groq_api_key, azure_api_key)
        if pinecone_index is None:
            pinecone_index = Pinecone(pinecone_api_key).Index("ll144")
        logging.info("Initialized LLM and Pinecone.")
    except Exception as exc:
        log_and_exit(f"Error initializing APIs: {exc}")

def initialize_llm(api, model, groq_api_key, azure_api_key):
    """Build the LLM client for the requested provider and model.

    Args:
        api: Provider name, either ``'groq'`` or ``'azure'``.
        model: Short model name. For ``'groq'`` one of ``'mixtral-8x7b'``,
            ``'llama3-8b'``, ``'llama3-70b'`` or ``'gemma-7b'``; for
            ``'azure'`` only ``'gpt35'``.
        groq_api_key: API key passed to ``Groq``.
        azure_api_key: API key passed to ``AzureOpenAI``.

    Returns:
        A ``Groq`` or ``AzureOpenAI`` instance.

    Raises:
        ValueError: If the provider/model combination is not supported.
            Failing here avoids handing ``None`` to the rest of the pipeline.
    """
    if api == 'groq':
        model_mappings = {
            'mixtral-8x7b': "mixtral-8x7b-32768",
            'llama3-8b': "llama3-8b-8192",
            'llama3-70b': "llama3-70b-8192",
            'gemma-7b': "gemma-7b-it"
        }
        if model not in model_mappings:
            raise ValueError(
                f"Unsupported Groq model: {model!r}. "
                f"Expected one of {sorted(model_mappings)}."
            )
        return Groq(model=model_mappings[model], api_key=groq_api_key)

    if api == 'azure':
        if model != 'gpt35':
            raise ValueError(
                f"Unsupported Azure model: {model!r}. Expected 'gpt35'."
            )
        return AzureOpenAI(
            deployment_name=AZURE_DEPLOYMENT_NAME,
            temperature=0,
            api_key=azure_api_key,
            azure_endpoint=AZURE_OPENAI_ENDPOINT,
            api_version=AZURE_API_VERSION
        )

    raise ValueError(f"Unsupported API: {api!r}. Expected 'groq' or 'azure'.")

def load_pdf_data(chunk_size):
    """Read the two LL144 PDF files and return the loaded documents.

    Args:
        chunk_size: Not used by this function; kept so existing callers
            keep working.

    Returns:
        Whatever ``SimpleDirectoryReader.load_data()`` yields for
        ``LL144.pdf`` and ``LL144_Definitions.pdf``.
    """
    pdf_paths = ['LL144.pdf', 'LL144_Definitions.pdf']
    directory_reader = SimpleDirectoryReader(
        input_files=pdf_paths,
        file_extractor={".pdf": PyMuPDFReader()},
    )
    return directory_reader.load_data()

def create_index(documents, embedding_model_type="HF", embedding_model="BAAI/bge-large-en-v1.5", retriever_method="BM25", chunk_size=512):
    """Configure llama-index ``Settings`` and build the retrieval artefacts.

    Args:
        documents: Documents to index and/or split into nodes.
        embedding_model_type: Forwarded to ``select_embedding_model``.
        embedding_model: Forwarded to ``select_embedding_model``.
        retriever_method: ``"BM25"``, ``"BM25+Vector"``, or anything else
            (treated as vector-only).
        chunk_size: Stored in ``Settings.chunk_size`` and used for BM25 nodes.

    Returns:
        ``(None, nodes)`` for ``"BM25"``, ``(index, nodes)`` for
        ``"BM25+Vector"``, and ``(index, None)`` otherwise.

    Raises:
        SystemExit: Through ``log_and_exit`` when any step raises.
    """
    global llm, pinecone_index
    try:
        # Resolve the embedding model before touching the shared Settings.
        chosen_embedding = select_embedding_model(embedding_model_type, embedding_model)

        Settings.llm = llm
        Settings.embed_model = chosen_embedding
        Settings.chunk_size = chunk_size

        if retriever_method in ("BM25", "BM25+Vector"):
            nodes = create_bm25_nodes(documents, chunk_size)
            logging.info("Created BM25 nodes from documents.")
            if retriever_method != "BM25+Vector":
                # Pure BM25 needs no vector index.
                return None, nodes
            done_message = "Created index for BM25+Vector from documents."
        else:
            nodes = None
            done_message = "Created index from documents."

        # Both remaining modes build the same Pinecone-backed vector index.
        store = PineconeVectorStore(pinecone_index=pinecone_index)
        context = StorageContext.from_defaults(vector_store=store)
        index = VectorStoreIndex.from_documents(documents, storage_context=context)
        logging.info(done_message)
        return index, nodes
    except Exception as exc:
        log_and_exit(f"Error creating index: {exc}")

def select_embedding_model(embedding_model_type, embedding_model):
    """Return the embedding model for the given type.

    Args:
        embedding_model_type: ``"HF"`` for a HuggingFace model or ``"OAI"``
            for OpenAI embeddings.
        embedding_model: Model name; only used for ``"HF"``.

    Returns:
        A ``HuggingFaceEmbedding`` or ``OpenAIEmbedding`` instance.

    Raises:
        ValueError: If ``embedding_model_type`` is neither ``"HF"`` nor
            ``"OAI"``. Failing here avoids silently returning ``None``.
    """
    if embedding_model_type == "HF":
        return HuggingFaceEmbedding(model_name=embedding_model)
    if embedding_model_type == "OAI":
        # NOTE: ``embedding_model`` is not forwarded; OpenAIEmbedding is
        # constructed with its own defaults.
        return OpenAIEmbedding()
    raise ValueError(
        f"Unsupported embedding model type: {embedding_model_type!r}. "
        "Expected 'HF' or 'OAI'."
    )

def create_bm25_nodes(documents, chunk_size):
    """Split ``documents`` into nodes using a ``SentenceSplitter``.

    Args:
        documents: Documents to split.
        chunk_size: ``chunk_size`` given to the splitter.

    Returns:
        The nodes returned by ``get_nodes_from_documents``.
    """
    return SentenceSplitter(chunk_size=chunk_size).get_nodes_from_documents(documents)

def select_retriever(index, nodes, retriever_method, top_k):
    """Build the retriever that matches ``retriever_method``.

    Args:
        index: Vector index; required for ``"BM25+Vector"`` and
            ``"Vector Search"``.
        nodes: Parsed nodes; required for ``"BM25"`` and ``"BM25+Vector"``.
        retriever_method: ``"BM25"``, ``"BM25+Vector"`` or ``"Vector Search"``.
        top_k: ``similarity_top_k`` applied to every retriever built here.

    Returns:
        A ``BM25Retriever``, a ``RouterRetriever`` over BM25 and vector
        tools, or a ``VectorIndexRetriever``.

    Raises:
        SystemExit: Through ``log_and_exit`` when the method is unsupported
            or a required ``index``/``nodes`` argument is ``None``.
    """
    logging.info(f"Selecting retriever with method: {retriever_method}")
    if nodes is not None:
        logging.info(f"Available document IDs: {list(range(len(nodes)))}")
    else:
        logging.warning("Nodes are None")

    # BM25 cannot be built without nodes; fail with a clear message,
    # mirroring the existing index checks below.
    if retriever_method in ("BM25", "BM25+Vector") and nodes is None:
        log_and_exit(f"Nodes must be created when using {retriever_method} retriever method.")

    if retriever_method == 'BM25':
        return BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=top_k)
    elif retriever_method == "BM25+Vector":
        if index is None:
            log_and_exit("Index must be initialized when using BM25+Vector retriever method.")

        bm25_retriever = RetrieverTool.from_defaults(
            retriever=BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=top_k),
            description="BM25 Retriever"
        )

        # Pass top_k here too, so both halves of the hybrid retriever honour
        # the caller's setting.
        vector_retriever = RetrieverTool.from_defaults(
            retriever=VectorIndexRetriever(index=index, similarity_top_k=top_k),
            description="Vector Retriever"
        )

        router_retriever = RouterRetriever.from_defaults(
            retriever_tools=[bm25_retriever, vector_retriever],
            llm=llm,
            select_multi=True
        )
        return router_retriever
    elif retriever_method == "Vector Search":
        if index is None:
            log_and_exit("Index must be initialized when using Vector Search retriever method.")
        return VectorIndexRetriever(index=index, similarity_top_k=top_k)
    else:
        log_and_exit(f"Unsupported retriever method: {retriever_method}")

def setup_query_engine(index, response_mode, nodes=None, query_engine_method=None, retriever_method=None, top_k=2):
    """Create the query engine used to answer questions.

    Args:
        index: Vector index, or ``None`` when only BM25 nodes exist.
        response_mode: Passed to ``get_response_synthesizer`` for the
            default (retriever-based) engine.
        nodes: Parsed nodes forwarded to ``select_retriever``.
        query_engine_method: ``"FLARE"``, ``"MS"`` (multi-step), or anything
            else for a plain ``RetrieverQueryEngine``.
        retriever_method: Forwarded to ``select_retriever``.
        top_k: Number of results requested from retrievers/query engines.

    Returns:
        A ``FLAREInstructQueryEngine``, ``MultiStepQueryEngine`` or
        ``RetrieverQueryEngine``.

    Raises:
        SystemExit: Through ``log_and_exit`` when the retriever cannot be
            built, when FLARE/MS is requested without an index, or when any
            other step raises.
    """
    global llm
    try:
        logging.info(f"Setting up query engine with retriever_method: {retriever_method} and query_engine_method: {query_engine_method}")
        retriever = select_retriever(index, nodes, retriever_method, top_k)

        if retriever is None:
            log_and_exit("Failed to create retriever. Index or nodes might be None.")

        if query_engine_method in ("FLARE", "MS"):
            # Both wrappers sit on top of an index-backed query engine, so an
            # index is mandatory; fail clearly instead of passing None down.
            if index is None:
                log_and_exit(f"Index must be initialized when using {query_engine_method} query engine method.")
            index_query_engine = index.as_query_engine(similarity_top_k=top_k)

            if query_engine_method == "FLARE":
                return FLAREInstructQueryEngine(
                    query_engine=index_query_engine,
                    max_iterations=4,
                    verbose=False
                )
            return MultiStepQueryEngine(
                query_engine=index_query_engine,
                query_transform=StepDecomposeQueryTransform(llm=llm, verbose=False),
                index_summary="Used to answer questions about the regulation"
            )

        # Default path: only here is the response synthesizer needed.
        response_synthesizer = get_response_synthesizer(response_mode=response_mode)
        return RetrieverQueryEngine(retriever=retriever, response_synthesizer=response_synthesizer)
    except Exception as e:
        # log_and_exit already logs the message, so only add the traceback.
        traceback.print_exc()
        log_and_exit(f"Error setting up query engine: {e}")

def log_to_google_sheets(data):
    """Append one row to the module-level Google Sheet, best effort.

    Args:
        data: Row values handed unchanged to ``sheet.append_row``.

    Any exception is logged at ERROR level and not re-raised.
    """
    try:
        sheet.append_row(data)
    except Exception as exc:
        logging.error(f"Error logging data to Google Sheets: {exc}")
        return
    logging.info("Logged data to Google Sheets.")

def update_google_sheets(question_id, feedback=None, detailed_feedback=None, annotated_answer=None):
    """Update feedback columns of the sheet row whose first cell is ``question_id``.

    Args:
        question_id: Value matched against the first column of each data row.
        feedback: New value for the ``"Feedback"`` column, or ``None`` to
            leave it alone.
        detailed_feedback: New value for the ``"Detailed Feedback"`` column,
            or ``None`` to leave it alone.
        annotated_answer: New value for the ``"annotated_answer"`` column, or
            ``None`` to leave it alone.

    Only the first matching row is updated. Failures are logged at ERROR
    level and not re-raised; an empty sheet or a missing ``question_id`` is
    logged as a warning.
    """
    try:
        existing_data = sheet.get_all_values()
        if not existing_data:
            logging.warning("Google Sheet is empty; nothing to update.")
            return

        headers = existing_data[0]
        # Column header -> new value; None means "do not touch this column".
        pending_updates = {
            "Feedback": feedback,
            "Detailed Feedback": detailed_feedback,
            "annotated_answer": annotated_answer,
        }

        # Sheet row 1 holds the headers, so data rows start at sheet row 2.
        for row_number, row in enumerate(existing_data[1:], start=2):
            if not row or row[0] != question_id:
                continue
            for column_name, value in pending_updates.items():
                if value is not None:
                    # update_cell takes 1-based row/column numbers.
                    sheet.update_cell(row_number, headers.index(column_name) + 1, value)
            logging.info("Updated data in Google Sheets.")
            return

        logging.warning(f"No row found in Google Sheets for question ID: {question_id}")
    except Exception as e:
        logging.error(f"Error updating data in Google Sheets: {e}")

def run_streamlit_app():
    """Render the Streamlit UI for the RAG chat application.

    The page has three parts:

    1. Configuration widgets (API keys, model, embedding model, retriever
       method, chunk size, top-k) and an "Initialize" button that builds the
       query engine and stores it in ``st.session_state.query_engine``.
    2. The chat history, where each entry shows the retrieved contexts, the
       answer, and feedback controls that write back to Google Sheets.
    3. A chat input that queries the engine, appends the exchange to
       ``st.session_state.chat_history`` and logs it to Google Sheets.
    """
    # Local import: only needed to escape retrieved text before it is
    # embedded in raw HTML below.
    import html

    if 'query_engine' not in st.session_state:
        st.session_state.query_engine = None

    st.title("RAG Chat Application")

    col1, col2 = st.columns(2)

    with col1:
        # Mask the keys so they are not displayed in clear text.
        pinecone_api_key = st.text_input("Pinecone API Key", type="password")
        azure_api_key = st.text_input("Azure API Key", type="password")
        groq_api_key = st.text_input("Groq API Key", type="password")

    def update_api_based_on_model():
        # 'gpt35' is served by Azure; every other model goes through Groq.
        selected_model = st.session_state['selected_model']
        if selected_model == 'gpt35':
            st.session_state['selected_api'] = 'azure'
        else:
            st.session_state['selected_api'] = 'groq'

    with col2:
        selected_model = st.selectbox("Select Model", ["llama3-8b", "llama3-70b", "mixtral-8x7b", "gemma-7b", "gpt35"], index=4, key='selected_model', on_change=update_api_based_on_model)
        # Read-only display; the value lives in st.session_state['selected_api'].
        st.selectbox("Select API", ["azure", "groq"], index=0, key='selected_api', disabled=True)
        embedding_model_type = "HF"
        embedding_model = st.selectbox("Select Embedding Model", ["BAAI/bge-large-en-v1.5", "other_model"])
        retriever_method = st.selectbox("Select Retriever Method", ["Vector Search", "BM25", "BM25+Vector"])

    col3, col4 = st.columns(2)
    with col3:
        chunk_size = st.selectbox("Select Chunk Size", [128, 256, 512, 1024], index=2)
    with col4:
        top_k = st.selectbox("Select Top K", [1, 2, 3, 5, 6], index=1)

    if st.button("Initialize"):
        try:
            initialize_apis(st.session_state['selected_api'], selected_model, pinecone_api_key, groq_api_key, azure_api_key)
            documents = load_pdf_data(chunk_size)
            index, nodes = create_index(documents, embedding_model_type=embedding_model_type, embedding_model=embedding_model, retriever_method=retriever_method, chunk_size=chunk_size)
            st.session_state.query_engine = setup_query_engine(index, response_mode="compact", nodes=nodes, query_engine_method=None, retriever_method=retriever_method, top_k=top_k)
        except SystemExit as exc:
            # The helpers signal failure by raising SystemExit via
            # log_and_exit; show the reason in the UI instead of ending the
            # run without feedback.
            st.error(f"Initialization failed: {exc}")
        else:
            st.success("Initialization complete.")

    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []

    for chat_index, chat in enumerate(st.session_state.chat_history):
        with st.chat_message("user"):
            st.markdown(chat['user'])
        with st.chat_message("bot"):
            st.markdown("### Retrieved Contexts")
            for node in chat.get('contexts', []):
                # Escape the document text: it is rendered with
                # unsafe_allow_html, so raw '<' or '&' would be parsed as markup.
                st.markdown(
                    f"<div style='border:1px solid #ccc; padding:10px; margin:10px 0; font-size:small;'>{html.escape(node.text)}</div>",
                    unsafe_allow_html=True
                )
            st.markdown("### Answer")
            st.markdown(chat['response'])

            col1, col2 = st.columns([1, 1])
            with col1:
                if st.button("Annotate 👎", key=f"annotate_{chat_index}"):
                    chat['annotate'] = True
                    chat['feedback'] = -1
                    st.session_state.chat_history[chat_index] = chat
                    update_google_sheets(chat['id'], feedback=-1)
                    # Rerun so the annotation text area appears immediately.
                    st.rerun()
            with col2:
                if st.button("Approve 👍", key=f"approve_{chat_index}"):
                    chat['approved'] = True
                    chat['feedback'] = 1
                    st.session_state.chat_history[chat_index] = chat
                    # An approved answer is recorded as its own annotation.
                    update_google_sheets(chat['id'], feedback=1, annotated_answer=chat['response'])

            if chat.get('annotate', False):
                annotated_answer = st.text_area("Annotate Answer", value=chat['response'], key=f"annotate_text_{chat_index}")
                if st.button("Submit Annotated Answer", key=f"submit_annotate_{chat_index}"):
                    chat['annotated_answer'] = annotated_answer
                    chat['annotate'] = False
                    st.session_state.chat_history[chat_index] = chat
                    update_google_sheets(chat['id'], annotated_answer=annotated_answer)

            feedback = st.text_area("How was the response? Does it match the context? Does it answer the question fully?", key=f"textarea_{chat_index}")
            if st.button("Submit Feedback", key=f"submit_{chat_index}"):
                chat['detailed_feedback'] = feedback
                st.session_state.chat_history[chat_index] = chat
                update_google_sheets(chat['id'], detailed_feedback=feedback)

    if question := st.chat_input("Enter your question"):
        if st.session_state.query_engine:
            with st.spinner('Generating response...'):
                # Compile chat history for context
                history = "\n".join([f"Q: {chat['user']}\nA: {chat['response']}" for chat in st.session_state.chat_history])
                full_query = f"{history}\nQ: {question}"
                response = st.session_state.query_engine.query(full_query)
                logging.info(f"Generated response: {response.response}")
                logging.info(f"Retrieved contexts: {[node.text for node in response.source_nodes]}")
            question_id = str(uuid.uuid4())
            timestamp = datetime.now().isoformat()
            st.session_state.chat_history.append({'id': question_id, 'user': question, 'response': response.response, 'contexts': response.source_nodes, 'feedback': 0, 'detailed_feedback': '', 'annotated_answer': '', 'timestamp': timestamp})

            # Log initial query and response to Google Sheets without feedback
            log_to_google_sheets([question_id, question, response.response, st.session_state['selected_api'], selected_model, embedding_model, retriever_method, chunk_size, top_k, 0, "", "", timestamp])

            # Rerun so the new exchange is drawn by the history loop above.
            st.rerun()
        else:
            st.error("Query engine is not initialized. Please initialize it first.")

# Start the UI only when this file is executed as the main script,
# not when it is imported as a module.
if __name__ == "__main__":
    run_streamlit_app()