import logging import json import pandas as pd import streamlit as st from pinecone import Pinecone from llama_index.vector_stores.pinecone import PineconeVectorStore from llama_index.core import ( StorageContext, VectorStoreIndex, SimpleDirectoryReader, get_response_synthesizer, Settings ) from llama_index.core.node_parser import SentenceSplitter from llama_index.core.retrievers import ( VectorIndexRetriever, RouterRetriever ) from llama_index.retrievers.bm25 import BM25Retriever from llama_index.core.tools import RetrieverTool from llama_index.core.query_engine import ( RetrieverQueryEngine, FLAREInstructQueryEngine, MultiStepQueryEngine ) from llama_index.core.indices.query.query_transform import ( StepDecomposeQueryTransform ) from llama_index.llms.groq import Groq from llama_index.embeddings.huggingface import HuggingFaceEmbedding from llama_index.llms.azure_openai import AzureOpenAI from llama_index.embeddings.openai import OpenAIEmbedding from llama_index.readers.file import PyMuPDFReader import traceback from oauth2client.service_account import ServiceAccountCredentials import gspread import uuid from dotenv import load_dotenv import os from datetime import datetime # Load environment variables load_dotenv() # Configure logging logging.basicConfig(level=logging.INFO) # Google Sheets setup scope = ["https://spreadsheets.google.com/feeds", "https://www.googleapis.com/auth/drive"] creds_dict = { "type": os.getenv("type"), "project_id": os.getenv("project_id"), "private_key_id": os.getenv("private_key_id"), "private_key": os.getenv("private_key"), "client_email": os.getenv("client_email"), "client_id": os.getenv("client_id"), "auth_uri": os.getenv("auth_uri"), "token_uri": os.getenv("token_uri"), "auth_provider_x509_cert_url": os.getenv("auth_provider_x509_cert_url"), "client_x509_cert_url": os.getenv("client_x509_cert_url") } creds_dict['private_key'] = creds_dict['private_key'].replace('\\n', '\n') creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope) client = gspread.authorize(creds) sheet = client.open("RAG").sheet1 # Fixed variables AZURE_DEPLOYMENT_NAME = os.getenv("AZURE_DEPLOYMENT_NAME") AZURE_API_VERSION = os.getenv("AZURE_API_VERSION") AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT") # Global variables for lazy loading llm = None pinecone_index = None def log_and_exit(message): logging.error(message) raise SystemExit(message) def initialize_apis(api, model, pinecone_api_key, groq_api_key, azure_api_key): global llm, pinecone_index try: if llm is None: llm = initialize_llm(api, model, groq_api_key, azure_api_key) if pinecone_index is None: pinecone_client = Pinecone(pinecone_api_key) pinecone_index = pinecone_client.Index("ll144") logging.info("Initialized LLM and Pinecone.") except Exception as e: log_and_exit(f"Error initializing APIs: {e}") def initialize_llm(api, model, groq_api_key, azure_api_key): if api == 'groq': model_mappings = { 'mixtral-8x7b': "mixtral-8x7b-32768", 'llama3-8b': "llama3-8b-8192", 'llama3-70b': "llama3-70b-8192", 'gemma-7b': "gemma-7b-it" } return Groq(model=model_mappings[model], api_key=groq_api_key) elif api == 'azure': if model == 'gpt35': return AzureOpenAI( deployment_name=AZURE_DEPLOYMENT_NAME, temperature=0, api_key=azure_api_key, azure_endpoint=AZURE_OPENAI_ENDPOINT, api_version=AZURE_API_VERSION ) def load_pdf_data(chunk_size): reader = PyMuPDFReader() file_extractor = {".pdf": reader} documents = SimpleDirectoryReader(input_files=['LL144.pdf', 'LL144_Definitions.pdf'], file_extractor=file_extractor).load_data() return documents def create_index(documents, embedding_model_type="HF", embedding_model="BAAI/bge-large-en-v1.5", retriever_method="BM25", chunk_size=512): global llm, pinecone_index try: embed_model = select_embedding_model(embedding_model_type, embedding_model) Settings.llm = llm Settings.embed_model = embed_model Settings.chunk_size = chunk_size if retriever_method in ["BM25", "BM25+Vector"]: nodes = create_bm25_nodes(documents, chunk_size) logging.info("Created BM25 nodes from documents.") if retriever_method == "BM25+Vector": vector_store = PineconeVectorStore(pinecone_index=pinecone_index) storage_context = StorageContext.from_defaults(vector_store=vector_store) index = VectorStoreIndex.from_documents(documents, storage_context=storage_context) logging.info("Created index for BM25+Vector from documents.") return index, nodes return None, nodes else: vector_store = PineconeVectorStore(pinecone_index=pinecone_index) storage_context = StorageContext.from_defaults(vector_store=vector_store) index = VectorStoreIndex.from_documents(documents, storage_context=storage_context) logging.info("Created index from documents.") return index, None except Exception as e: log_and_exit(f"Error creating index: {e}") def select_embedding_model(embedding_model_type, embedding_model): if embedding_model_type == "HF": return HuggingFaceEmbedding(model_name=embedding_model) elif embedding_model_type == "OAI": return OpenAIEmbedding() # Implement OAI Embedding if needed def create_bm25_nodes(documents, chunk_size): splitter = SentenceSplitter(chunk_size=chunk_size) nodes = splitter.get_nodes_from_documents(documents) return nodes def select_retriever(index, nodes, retriever_method, top_k): logging.info(f"Selecting retriever with method: {retriever_method}") if nodes is not None: logging.info(f"Available document IDs: {list(range(len(nodes)))}") else: logging.warning("Nodes are None") if retriever_method == 'BM25': return BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=top_k) elif retriever_method == "BM25+Vector": if index is None: log_and_exit("Index must be initialized when using BM25+Vector retriever method.") bm25_retriever = RetrieverTool.from_defaults( retriever=BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=top_k), description="BM25 Retriever" ) vector_retriever = RetrieverTool.from_defaults( retriever=VectorIndexRetriever(index=index), description="Vector Retriever" ) router_retriever = RouterRetriever.from_defaults( retriever_tools=[bm25_retriever, vector_retriever], llm=llm, select_multi=True ) return router_retriever elif retriever_method == "Vector Search": if index is None: log_and_exit("Index must be initialized when using Vector Search retriever method.") return VectorIndexRetriever(index=index, similarity_top_k=top_k) else: log_and_exit(f"Unsupported retriever method: {retriever_method}") def setup_query_engine(index, response_mode, nodes=None, query_engine_method=None, retriever_method=None, top_k=2): global llm try: logging.info(f"Setting up query engine with retriever_method: {retriever_method} and query_engine_method: {query_engine_method}") retriever = select_retriever(index, nodes, retriever_method, top_k) if retriever is None: log_and_exit("Failed to create retriever. Index or nodes might be None.") response_synthesizer = get_response_synthesizer(response_mode=response_mode) index_query_engine = index.as_query_engine(similarity_top_k=top_k) if index else None if query_engine_method == "FLARE": query_engine = FLAREInstructQueryEngine( query_engine=index_query_engine, max_iterations=4, verbose=False ) elif query_engine_method == "MS": query_engine = MultiStepQueryEngine( query_engine=index_query_engine, query_transform=StepDecomposeQueryTransform(llm=llm, verbose=False), index_summary="Used to answer questions about the regulation" ) else: query_engine = RetrieverQueryEngine(retriever=retriever, response_synthesizer=response_synthesizer) if query_engine is None: log_and_exit("Failed to create query engine.") return query_engine except Exception as e: logging.error(f"Error setting up query engine: {e}") traceback.print_exc() log_and_exit(f"Error setting up query engine: {e}") def log_to_google_sheets(data): try: sheet.append_row(data) logging.info("Logged data to Google Sheets.") except Exception as e: logging.error(f"Error logging data to Google Sheets: {e}") def update_google_sheets(question_id, feedback=None, detailed_feedback=None, annotated_answer=None): try: existing_data = sheet.get_all_values() headers = existing_data[0] for i, row in enumerate(existing_data): if row[0] == question_id: if feedback is not None: sheet.update_cell(i+1, headers.index("Feedback") + 1, feedback) if detailed_feedback is not None: sheet.update_cell(i+1, headers.index("Detailed Feedback") + 1, detailed_feedback) if annotated_answer is not None: sheet.update_cell(i+1, headers.index("annotated_answer") + 1, annotated_answer) logging.info("Updated data in Google Sheets.") return except Exception as e: logging.error(f"Error updating data in Google Sheets: {e}") def run_streamlit_app(): if 'query_engine' not in st.session_state: st.session_state.query_engine = None st.title("RAG Chat Application") col1, col2 = st.columns(2) with col1: pinecone_api_key = st.text_input("Pinecone API Key") azure_api_key = st.text_input("Azure API Key") groq_api_key = st.text_input("Groq API Key") def update_api_based_on_model(): selected_model = st.session_state['selected_model'] if selected_model == 'gpt35': st.session_state['selected_api'] = 'azure' else: st.session_state['selected_api'] = 'groq' with col2: selected_model = st.selectbox("Select Model", ["llama3-8b", "llama3-70b", "mixtral-8x7b", "gemma-7b", "gpt35"], index=4, key='selected_model', on_change=update_api_based_on_model) selected_api = st.selectbox("Select API", ["azure", "groq"], index=0, key='selected_api', disabled=True) embedding_model_type = "HF" embedding_model = st.selectbox("Select Embedding Model", ["BAAI/bge-large-en-v1.5", "other_model"]) retriever_method = st.selectbox("Select Retriever Method", ["Vector Search", "BM25", "BM25+Vector"]) col3, col4 = st.columns(2) with col3: chunk_size = st.selectbox("Select Chunk Size", [128, 256, 512, 1024], index=2) with col4: top_k = st.selectbox("Select Top K", [1, 2, 3, 5, 6], index=1) if st.button("Initialize"): initialize_apis(st.session_state['selected_api'], selected_model, pinecone_api_key, groq_api_key, azure_api_key) documents = load_pdf_data(chunk_size) index, nodes = create_index(documents, embedding_model_type=embedding_model_type, embedding_model=embedding_model, retriever_method=retriever_method, chunk_size=chunk_size) st.session_state.query_engine = setup_query_engine(index, response_mode="compact", nodes=nodes, query_engine_method=None, retriever_method=retriever_method, top_k=top_k) st.success("Initialization complete.") if 'chat_history' not in st.session_state: st.session_state.chat_history = [] for chat_index, chat in enumerate(st.session_state.chat_history): with st.chat_message("user"): st.markdown(chat['user']) with st.chat_message("bot"): st.markdown("### Retrieved Contexts") for node in chat.get('contexts', []): st.markdown( f"