# HyPA-RAG / app.py
import logging
import os
import traceback
import uuid
from datetime import datetime

import gspread
import streamlit as st
from dotenv import load_dotenv
from oauth2client.service_account import ServiceAccountCredentials
from pinecone import Pinecone

from llama_index.core import (
    StorageContext, VectorStoreIndex, SimpleDirectoryReader,
    get_response_synthesizer, Settings
)
from llama_index.core.indices.query.query_transform import (
    StepDecomposeQueryTransform
)
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.query_engine import (
    RetrieverQueryEngine, FLAREInstructQueryEngine, MultiStepQueryEngine
)
from llama_index.core.retrievers import (
    VectorIndexRetriever, RouterRetriever
)
from llama_index.core.tools import RetrieverTool
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.azure_openai import AzureOpenAI
from llama_index.llms.groq import Groq
from llama_index.readers.file import PyMuPDFReader
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.vector_stores.pinecone import PineconeVectorStore
# Load environment variables
load_dotenv()
# Configure logging
logging.basicConfig(level=logging.INFO)
# Google Sheets setup
scope = ["https://spreadsheets.google.com/feeds", "https://www.googleapis.com/auth/drive"]
creds_dict = {
"type": os.getenv("type"),
"project_id": os.getenv("project_id"),
"private_key_id": os.getenv("private_key_id"),
"private_key": os.getenv("private_key"),
"client_email": os.getenv("client_email"),
"client_id": os.getenv("client_id"),
"auth_uri": os.getenv("auth_uri"),
"token_uri": os.getenv("token_uri"),
"auth_provider_x509_cert_url": os.getenv("auth_provider_x509_cert_url"),
"client_x509_cert_url": os.getenv("client_x509_cert_url")
}
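# Env vars store the private key with escaped "\n" sequences; restore real newlines.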
creds_dict['private_key'] = creds_dict['private_key'].replace('\\n', '\n')
creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
client = gspread.authorize(creds)
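# First worksheet of the "RAG" spreadsheet: one row per question, plus feedback columns.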
sheet = client.open("RAG").sheet1
# Azure OpenAI deployment settings, read from the environment
AZURE_DEPLOYMENT_NAME = os.getenv("AZURE_DEPLOYMENT_NAME")
AZURE_API_VERSION = os.getenv("AZURE_API_VERSION")
AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT")
# Global variables for lazy loading
llm = None
pinecone_index = None
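# Log a fatal error and stop the app.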
def log_and_exit(message):
logging.error(message)
raise SystemExit(message)
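# Lazily initialize the LLM client and the Pinecone index ("ll144") once per process.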
def initialize_apis(api, model, pinecone_api_key, groq_api_key, azure_api_key):
global llm, pinecone_index
try:
if llm is None:
llm = initialize_llm(api, model, groq_api_key, azure_api_key)
if pinecone_index is None:
pinecone_client = Pinecone(pinecone_api_key)
pinecone_index = pinecone_client.Index("ll144")
logging.info("Initialized LLM and Pinecone.")
except Exception as e:
log_and_exit(f"Error initializing APIs: {e}")
def initialize_llm(api, model, groq_api_key, azure_api_key):
    if api == 'groq':
        model_mappings = {
            'mixtral-8x7b': "mixtral-8x7b-32768",
            'llama3-8b': "llama3-8b-8192",
            'llama3-70b': "llama3-70b-8192",
            'gemma-7b': "gemma-7b-it"
        }
        return Groq(model=model_mappings[model], api_key=groq_api_key)
    elif api == 'azure':
        if model == 'gpt35':
            return AzureOpenAI(
                deployment_name=AZURE_DEPLOYMENT_NAME,
                temperature=0,
                api_key=azure_api_key,
                azure_endpoint=AZURE_OPENAI_ENDPOINT,
                api_version=AZURE_API_VERSION
            )
    # Fail loudly instead of silently returning None for unsupported combinations.
    log_and_exit(f"Unsupported API/model combination: {api}/{model}")
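# Load the LL144 source PDFs with PyMuPDF.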
def load_pdf_data():
    reader = PyMuPDFReader()
    file_extractor = {".pdf": reader}
    documents = SimpleDirectoryReader(
        input_files=['LL144.pdf', 'LL144_Definitions.pdf'],
        file_extractor=file_extractor
    ).load_data()
    return documents
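# Build the retrieval structures for the chosen method: BM25 nodes, a
# Pinecone-backed vector index, or both (for the hybrid retriever).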
def create_index(documents, embedding_model_type="HF", embedding_model="BAAI/bge-large-en-v1.5", retriever_method="BM25", chunk_size=512):
global llm, pinecone_index
try:
embed_model = select_embedding_model(embedding_model_type, embedding_model)
Settings.llm = llm
Settings.embed_model = embed_model
Settings.chunk_size = chunk_size
if retriever_method in ["BM25", "BM25+Vector"]:
nodes = create_bm25_nodes(documents, chunk_size)
logging.info("Created BM25 nodes from documents.")
if retriever_method == "BM25+Vector":
vector_store = PineconeVectorStore(pinecone_index=pinecone_index)
storage_context = StorageContext.from_defaults(vector_store=vector_store)
index = VectorStoreIndex.from_documents(documents, storage_context=storage_context)
logging.info("Created index for BM25+Vector from documents.")
return index, nodes
return None, nodes
else:
vector_store = PineconeVectorStore(pinecone_index=pinecone_index)
storage_context = StorageContext.from_defaults(vector_store=vector_store)
index = VectorStoreIndex.from_documents(documents, storage_context=storage_context)
logging.info("Created index from documents.")
return index, None
except Exception as e:
log_and_exit(f"Error creating index: {e}")
def select_embedding_model(embedding_model_type, embedding_model):
    if embedding_model_type == "HF":
        return HuggingFaceEmbedding(model_name=embedding_model)
    elif embedding_model_type == "OAI":
        # Note: the selected embedding_model name is not passed through here,
        # so OpenAIEmbedding falls back to its default model.
        return OpenAIEmbedding()
    log_and_exit(f"Unsupported embedding model type: {embedding_model_type}")
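# Split documents into sentence-aware chunks to serve as BM25 retrieval nodes.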
def create_bm25_nodes(documents, chunk_size):
splitter = SentenceSplitter(chunk_size=chunk_size)
nodes = splitter.get_nodes_from_documents(documents)
return nodes
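# Choose the retriever: plain BM25, plain vector search, or an LLM-routed
# hybrid that can dispatch to either.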
def select_retriever(index, nodes, retriever_method, top_k):
logging.info(f"Selecting retriever with method: {retriever_method}")
    if nodes is not None:
        logging.info(f"Number of BM25 nodes available: {len(nodes)}")
    else:
        logging.warning("Nodes are None")
if retriever_method == 'BM25':
return BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=top_k)
elif retriever_method == "BM25+Vector":
if index is None:
log_and_exit("Index must be initialized when using BM25+Vector retriever method.")
bm25_retriever = RetrieverTool.from_defaults(
retriever=BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=top_k),
description="BM25 Retriever"
)
        vector_retriever = RetrieverTool.from_defaults(
            retriever=VectorIndexRetriever(index=index, similarity_top_k=top_k),
            description="Vector Retriever"
        )
router_retriever = RouterRetriever.from_defaults(
retriever_tools=[bm25_retriever, vector_retriever],
llm=llm,
select_multi=True
)
return router_retriever
elif retriever_method == "Vector Search":
if index is None:
log_and_exit("Index must be initialized when using Vector Search retriever method.")
return VectorIndexRetriever(index=index, similarity_top_k=top_k)
else:
log_and_exit(f"Unsupported retriever method: {retriever_method}")
def setup_query_engine(index, response_mode, nodes=None, query_engine_method=None, retriever_method=None, top_k=2):
global llm
try:
logging.info(f"Setting up query engine with retriever_method: {retriever_method} and query_engine_method: {query_engine_method}")
retriever = select_retriever(index, nodes, retriever_method, top_k)
if retriever is None:
log_and_exit("Failed to create retriever. Index or nodes might be None.")
response_synthesizer = get_response_synthesizer(response_mode=response_mode)
        index_query_engine = index.as_query_engine(similarity_top_k=top_k) if index else None
        if query_engine_method in ("FLARE", "MS") and index_query_engine is None:
            log_and_exit("FLARE and MultiStep query engines require a vector index.")
if query_engine_method == "FLARE":
query_engine = FLAREInstructQueryEngine(
query_engine=index_query_engine,
max_iterations=4,
verbose=False
)
elif query_engine_method == "MS":
query_engine = MultiStepQueryEngine(
query_engine=index_query_engine,
query_transform=StepDecomposeQueryTransform(llm=llm, verbose=False),
index_summary="Used to answer questions about the regulation"
)
else:
query_engine = RetrieverQueryEngine(retriever=retriever, response_synthesizer=response_synthesizer)
if query_engine is None:
log_and_exit("Failed to create query engine.")
return query_engine
except Exception as e:
logging.error(f"Error setting up query engine: {e}")
traceback.print_exc()
log_and_exit(f"Error setting up query engine: {e}")
def log_to_google_sheets(data):
try:
sheet.append_row(data)
logging.info("Logged data to Google Sheets.")
except Exception as e:
logging.error(f"Error logging data to Google Sheets: {e}")
def update_google_sheets(question_id, feedback=None, detailed_feedback=None, annotated_answer=None):
try:
existing_data = sheet.get_all_values()
headers = existing_data[0]
for i, row in enumerate(existing_data):
if row[0] == question_id:
if feedback is not None:
sheet.update_cell(i+1, headers.index("Feedback") + 1, feedback)
if detailed_feedback is not None:
sheet.update_cell(i+1, headers.index("Detailed Feedback") + 1, detailed_feedback)
if annotated_answer is not None:
sheet.update_cell(i+1, headers.index("annotated_answer") + 1, annotated_answer)
logging.info("Updated data in Google Sheets.")
return
except Exception as e:
logging.error(f"Error updating data in Google Sheets: {e}")
def run_streamlit_app():
if 'query_engine' not in st.session_state:
st.session_state.query_engine = None
st.title("RAG Chat Application")
col1, col2 = st.columns(2)
with col1:
        pinecone_api_key = st.text_input("Pinecone API Key", type="password")
        azure_api_key = st.text_input("Azure API Key", type="password")
        groq_api_key = st.text_input("Groq API Key", type="password")
    # Callback: keep the (disabled) API selector in sync with the chosen model.
    def update_api_based_on_model():
        selected_model = st.session_state['selected_model']
        if selected_model == 'gpt35':
            st.session_state['selected_api'] = 'azure'
        else:
            st.session_state['selected_api'] = 'groq'
with col2:
selected_model = st.selectbox("Select Model", ["llama3-8b", "llama3-70b", "mixtral-8x7b", "gemma-7b", "gpt35"], index=4, key='selected_model', on_change=update_api_based_on_model)
selected_api = st.selectbox("Select API", ["azure", "groq"], index=0, key='selected_api', disabled=True)
embedding_model_type = "HF"
embedding_model = st.selectbox("Select Embedding Model", ["BAAI/bge-large-en-v1.5", "other_model"])
retriever_method = st.selectbox("Select Retriever Method", ["Vector Search", "BM25", "BM25+Vector"])
col3, col4 = st.columns(2)
with col3:
chunk_size = st.selectbox("Select Chunk Size", [128, 256, 512, 1024], index=2)
with col4:
top_k = st.selectbox("Select Top K", [1, 2, 3, 5, 6], index=1)
if st.button("Initialize"):
initialize_apis(st.session_state['selected_api'], selected_model, pinecone_api_key, groq_api_key, azure_api_key)
        documents = load_pdf_data()
index, nodes = create_index(documents, embedding_model_type=embedding_model_type, embedding_model=embedding_model, retriever_method=retriever_method, chunk_size=chunk_size)
st.session_state.query_engine = setup_query_engine(index, response_mode="compact", nodes=nodes, query_engine_method=None, retriever_method=retriever_method, top_k=top_k)
st.success("Initialization complete.")
if 'chat_history' not in st.session_state:
st.session_state.chat_history = []
for chat_index, chat in enumerate(st.session_state.chat_history):
with st.chat_message("user"):
st.markdown(chat['user'])
with st.chat_message("bot"):
st.markdown("### Retrieved Contexts")
for node in chat.get('contexts', []):
st.markdown(
f"<div style='border:1px solid #ccc; padding:10px; margin:10px 0; font-size:small;'>{node.text}</div>",
unsafe_allow_html=True
)
st.markdown("### Answer")
st.markdown(chat['response'])
col1, col2 = st.columns([1, 1])
with col1:
if st.button("Annotate πŸ‘Ž", key=f"annotate_{chat_index}"):
chat['annotate'] = True
chat['feedback'] = -1
st.session_state.chat_history[chat_index] = chat
update_google_sheets(chat['id'], feedback=-1)
st.rerun()
with col2:
if st.button("Approve πŸ‘", key=f"approve_{chat_index}"):
chat['approved'] = True
chat['feedback'] = 1
st.session_state.chat_history[chat_index] = chat
update_google_sheets(chat['id'], feedback=1, annotated_answer=chat['response'])
if chat.get('annotate', False):
annotated_answer = st.text_area("Annotate Answer", value=chat['response'], key=f"annotate_text_{chat_index}")
if st.button("Submit Annotated Answer", key=f"submit_annotate_{chat_index}"):
chat['annotated_answer'] = annotated_answer
chat['annotate'] = False
st.session_state.chat_history[chat_index] = chat
update_google_sheets(chat['id'], annotated_answer=annotated_answer)
feedback = st.text_area("How was the response? Does it match the context? Does it answer the question fully?", key=f"textarea_{chat_index}")
if st.button("Submit Feedback", key=f"submit_{chat_index}"):
chat['detailed_feedback'] = feedback
st.session_state.chat_history[chat_index] = chat
update_google_sheets(chat['id'], detailed_feedback=feedback)
if question := st.chat_input("Enter your question"):
if st.session_state.query_engine:
with st.spinner('Generating response...'):
# Compile chat history for context
history = "\n".join([f"Q: {chat['user']}\nA: {chat['response']}" for chat in st.session_state.chat_history])
full_query = f"{history}\nQ: {question}"
response = st.session_state.query_engine.query(full_query)
logging.info(f"Generated response: {response.response}")
logging.info(f"Retrieved contexts: {[node.text for node in response.source_nodes]}")
question_id = str(uuid.uuid4())
timestamp = datetime.now().isoformat()
                st.session_state.chat_history.append({
                    'id': question_id,
                    'user': question,
                    'response': response.response,
                    'contexts': response.source_nodes,
                    'feedback': 0,
                    'detailed_feedback': '',
                    'annotated_answer': '',
                    'timestamp': timestamp
                })
                # Log the initial query and response to Google Sheets without feedback.
                log_to_google_sheets([
                    question_id, question, response.response,
                    st.session_state['selected_api'], selected_model, embedding_model,
                    retriever_method, chunk_size, top_k, 0, "", "", timestamp
                ])
st.rerun()
else:
st.error("Query engine is not initialized. Please initialize it first.")
if __name__ == "__main__":
run_streamlit_app()